diff --git a/open_instruct/benchmark_generators.py b/open_instruct/benchmark_generators.py
index a671b1c539..a79744ca6f 100644
--- a/open_instruct/benchmark_generators.py
+++ b/open_instruct/benchmark_generators.py
@@ -295,8 +295,9 @@ def setup_vllm_engines(
     pg = ray.util.placement_group(bundles, strategy="PACK")
     ray.get(pg.ready())

-    param_prompt_Q = ray_queue.Queue(maxsize=10)
-    inference_results_Q = ray_queue.Queue(maxsize=10)
+    # Adjust queue sizes for individual prompts
+    param_prompt_Q = ray_queue.Queue(maxsize=100)
+    inference_results_Q = ray_queue.Queue(maxsize=100)

     vllm_engines = vllm_utils3.create_vllm_engines(
         num_engines=args.vllm_num_engines,
@@ -334,6 +335,41 @@ def get_batch_data(
     return prompts


+def run_generation_batch(
+    inference_results_Q: ray_queue.Queue, param_prompt_Q: ray_queue.Queue, prompts: list[list[int]], batch_idx: int
+) -> dict[str, Any]:
+    """Run generation for a batch of prompts and measure performance."""
+
+    start_time = time.time()
+
+    # Insert individual prompts
+    for i, prompt in enumerate(prompts):
+        dataset_idx = batch_idx * len(prompts) + i
+        param_prompt_Q.put(vllm_utils3.PromptRequest(prompt=prompt, dataset_index=dataset_idx))
+
+    # Collect individual results
+    all_responses = []
+    all_finish_reasons = []
+    for _ in range(len(prompts)):
+        result = inference_results_Q.get()
+        all_responses.extend(result.responses)
+        all_finish_reasons.extend(result.finish_reasons)
+
+    generation_time = time.time() - start_time
+
+    new_tokens = sum(len(response) for response in all_responses)
+    tokens_per_second = new_tokens / generation_time
+    return {
+        "tokens_per_second": tokens_per_second,
+        "generation_time": generation_time,
+        "num_new_tokens": new_tokens,
+        # dict mapping string reasons to counts.
+        "finish_reasons": collections.Counter(all_finish_reasons),
+        "response_lengths": [len(response) for response in all_responses],
+        "prompt_lengths": [len(prompt) for prompt in prompts],  # Original unique prompts
+    }
+
+
 def run_benchmark(
     dataset: datasets.Dataset,
     vllm_engines: list[ray.actor.ActorHandle],
@@ -366,6 +402,7 @@ def run_benchmark(
         num_batches + 1,  # eval_freq (avoid evaluation)
         num_batches,  # num_training_steps
         1,  # resume_training_step
+        args.num_unique_prompts_rollout // args.vllm_num_engines,  # batch_size
     )

     # Wait for engines to be ready
@@ -383,37 +420,47 @@ def run_benchmark(
     ]
     submission_start_time = time.time()
     for batch_idx in range(num_batches):
-        param_prompt_Q.put(vllm_utils3.PromptRequest(prompts=all_prompts[batch_idx], dataset_index=batch_idx))
+        # Insert individual prompts for this batch
+        for i, prompt in enumerate(all_prompts[batch_idx]):
+            dataset_idx = batch_idx * args.num_unique_prompts_rollout + i
+            param_prompt_Q.put(vllm_utils3.PromptRequest(prompt=prompt, dataset_index=dataset_idx))
     submission_time = time.time() - submission_start_time
     logger.info(f"All batches submitted in {submission_time:.2f}s")

     # Receive results and measure time for each batch
     last_completion_time = submission_start_time
     for batch_idx in range(num_batches):
-        result = inference_results_Q.get()
+        # Collect individual results for this batch
+        batch_responses = []
+        batch_finish_reasons = []
+        for _ in range(args.num_unique_prompts_rollout):
+            result = inference_results_Q.get()
+            batch_responses.extend(result.responses)
+            batch_finish_reasons.extend(result.finish_reasons)
+
         completion_time = time.time()
         batch_generation_time = completion_time - last_completion_time
         last_completion_time = completion_time

-        # Process result
-        new_tokens = sum(len(response) for response in result.responses)
+        # Process batch results
+        new_tokens = sum(len(response) for response in batch_responses)
         tokens_per_second = new_tokens / batch_generation_time

         result_dict = {
             "tokens_per_second": tokens_per_second,
             "generation_time": batch_generation_time,
             "num_new_tokens": new_tokens,
-            "finish_reasons": collections.Counter(result.finish_reasons),
-            "response_lengths": [len(response) for response in result.responses],
-            "batch_idx": result.dataset_index,
+            "finish_reasons": collections.Counter(batch_finish_reasons),
+            "response_lengths": [len(response) for response in batch_responses],
+            "batch_idx": batch_idx,
         }
         result_dict["mfu"] = 100 * result_dict["tokens_per_second"] * flops_per_token / device_flops

         # We incrementally save completion lengths so even if the job dies, we still have data.
-        save_completion_lengths([result_dict], timestamp, result.dataset_index)
+        save_completion_lengths([result_dict], timestamp, batch_idx)
         results.append(result_dict)
         logger.info(
-            f"Batch {result.dataset_index + 1}: "
+            f"Batch {batch_idx + 1}: "
             f"{result_dict['tokens_per_second']:.2f} new tokens/sec, "
             f"MFU: {result_dict['mfu']:.2f}%, "
             f"generation time: {batch_generation_time:.2f}s"
diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py
index fe02a5bb29..ac8973e02d 100644
--- a/open_instruct/grpo_fast.py
+++ b/open_instruct/grpo_fast.py
@@ -1106,50 +1106,42 @@ def get_bundle_index(rank, num_gpus_per_node):


 def accumulate_inference_batches(
-    inference_results_Q: ray_queue.Queue, pending_queries_map: dict, vllm_num_engines: int, training_step: int
+    inference_results_Q: ray_queue.Queue, pending_queries_map: dict, expected_results: int, training_step: int
 ) -> tuple:
-    """Accumulate multiple inference results into a single training batch.
+    """Accumulate individual inference results into a single training batch.

     Args:
-        inference_results_Q: Queue containing GenerationResult objects
+        inference_results_Q: Queue containing GenerationResult objects (one per prompt)
         pending_queries_map: Map of dataset_index -> (queries, ground_truths, datasets)
-        vllm_num_engines: Number of vLLM engines (number of batches to accumulate)
+        expected_results: Number of individual results to accumulate
         training_step: Current training step for error reporting

     Returns:
         Tuple of (combined_result, combined_queries, combined_ground_truths, combined_datasets)
     """
-    # Collect results from all engines
+    # Collect individual results
     results = []
     all_queries = []
     all_ground_truths = []
     all_datasets = []
-    for batch_idx in range(vllm_num_engines):
-        # Get result from queue
+    for i in range(expected_results):
+        # Get individual result from queue
         result = inference_results_Q.get()
-        dataset_indices = result.dataset_index
-        if dataset_indices is None:
-            raise RuntimeError(f"Dataset indices is None for batch {batch_idx}")
+        if result.dataset_index is None or len(result.dataset_index) != 1:
+            raise RuntimeError(f"Expected single dataset index, got {result.dataset_index}")

-        # Get corresponding queries, ground_truths, datasets for each individual prompt
-        batch_queries = []
-        batch_ground_truths = []
-        batch_datasets = []
-        for dataset_idx in dataset_indices:
-            if dataset_idx not in pending_queries_map:
-                raise RuntimeError(f"Dataset index {dataset_idx} not found in pending_queries_map")
+        dataset_idx = result.dataset_index[0]
+        if dataset_idx not in pending_queries_map:
+            raise RuntimeError(f"Dataset index {dataset_idx} not found in pending_queries_map")

-            query, ground_truth, dataset = pending_queries_map.pop(dataset_idx)
-            batch_queries.append(query)
-            batch_ground_truths.append(ground_truth)
-            batch_datasets.append(dataset)
+        query, ground_truth, dataset = pending_queries_map.pop(dataset_idx)

         results.append(result)
-        all_queries.extend(batch_queries)
-        all_ground_truths.extend(batch_ground_truths)
-        all_datasets.extend(batch_datasets)
+        all_queries.append(query)
+        all_ground_truths.append(ground_truth)
+        all_datasets.append(dataset)

     # Combine all results into a single GenerationResult
     combined_responses = []
@@ -1209,7 +1201,10 @@ def data_preparation_thread(
         # Accumulate results from multiple vLLM engines into a single training batch
         with Timer("🚀 [Data Preparation Thread] Getting response ids"):
             result, queries, ground_truths, datasets = accumulate_inference_batches(
-                inference_results_Q, pending_queries_map, args.vllm_num_engines, training_step
+                inference_results_Q,
+                pending_queries_map,
+                args.num_unique_prompts_rollout * args.num_samples_per_prompt_rollout,
+                training_step,
             )

         # ------------------------------------------------------------------------------------------------
@@ -1677,29 +1672,19 @@ def split_and_insert_batch(
     param_prompt_Q,
     eval_prompt_token_ids=None,
 ):
-    """Split a batch into multiple inference batches and insert individual prompts into queues and mapping."""
-    # Split the batch over the VLLM engines.
-    inference_batch_size = len(queries_next) // vllm_num_engines
-    for batch_idx in range(vllm_num_engines):
-        start_idx = batch_idx * inference_batch_size
-        end_idx = start_idx + inference_batch_size if batch_idx < vllm_num_engines - 1 else len(queries_next)
-
-        batch_queries = queries_next[start_idx:end_idx]
-        batch_ground_truths = ground_truths_next[start_idx:end_idx]
-        batch_datasets = datasets_next[start_idx:end_idx]
-        batch_dataset_indices = dataset_indices[start_idx:end_idx]
-
-        # Store individual prompts in the map using dataset indices as keys
-        for i, dataset_idx in enumerate(batch_dataset_indices):
-            pending_queries_map[dataset_idx] = (batch_queries[i], batch_ground_truths[i], batch_datasets[i])
-
-        # Use PromptRequest for Ray queue with batch-specific dataset_index list
+    """Insert individual prompts into queues and mapping."""
+    # Insert individual prompts into the queue
+    for i, dataset_idx in enumerate(dataset_indices):
+        # Store individual prompt in the map using dataset index as key
+        pending_queries_map[dataset_idx] = (queries_next[i], ground_truths_next[i], datasets_next[i])
+
+        # Create PromptRequest for single prompt
         param_prompt_Q.put(
             PromptRequest(
-                prompts=batch_queries,
+                prompt=queries_next[i],
                 training_step=training_step,
                 eval_prompts=eval_prompt_token_ids,
-                dataset_index=batch_dataset_indices,
+                dataset_index=dataset_idx,
             )
         )

@@ -2049,9 +2034,10 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig, num_eval_sa

     pprint([args, model_config])

-    # Create Ray queues
-    inference_results_Q = ray_queue.Queue(maxsize=args.async_steps)
-    param_prompt_Q = ray_queue.Queue(maxsize=args.async_steps)
+    # Create Ray queues - adjust maxsize for individual prompts.
+    total_prompts_per_step = args.num_unique_prompts_rollout * args.num_samples_per_prompt_rollout
+    inference_results_Q = ray_queue.Queue(maxsize=args.async_steps * total_prompts_per_step)
+    param_prompt_Q = ray_queue.Queue(maxsize=args.async_steps * total_prompts_per_step)
     evaluation_inference_results_Q = ray_queue.Queue(maxsize=1)

     policy_group, vllm_engines, tool_objects, resume_training_step, episode = create_model_and_optimizer(
@@ -2100,9 +2086,21 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig, num_eval_sa
     reward_fn = make_reward_fn(args)

     # Start vLLM engines to process from queues
+    if args.num_unique_prompts_rollout * args.num_samples_per_prompt_rollout % args.vllm_num_engines != 0:
+        raise ValueError(
+            "The number of unique prompts times the number of samples per prompt must be divisible by the number of vLLM engines."
+        )
+    batch_size_per_engine = (
+        args.num_unique_prompts_rollout * args.num_samples_per_prompt_rollout
+    ) // args.vllm_num_engines
     for engine in vllm_engines:
         engine.process_from_queue.remote(
-            generation_config, eval_generation_config, args.eval_freq, args.num_training_steps, resume_training_step
+            generation_config,
+            eval_generation_config,
+            args.eval_freq,
+            args.num_training_steps,
+            resume_training_step,
+            batch_size_per_engine,
         )
     logger.info("======== ✅ vllm engines started processing from queues =========")

diff --git a/open_instruct/test_grpo_fast.py b/open_instruct/test_grpo_fast.py
index 66aed2c837..1957bad7ee 100644
--- a/open_instruct/test_grpo_fast.py
+++ b/open_instruct/test_grpo_fast.py
@@ -74,10 +74,11 @@ def test_vllm_queue_system_single_prompt(self):
             999,  # eval_freq (avoid evaluation)
             1,  # num_training_steps
             1,  # resume_training_step
+            1,  # batch_size
         )

         # Put the test prompt in the queue using PromptRequest
-        request = PromptRequest(prompts=[prompt_token_ids], dataset_index=0)
+        request = PromptRequest(prompt=prompt_token_ids, dataset_index=0)
         param_prompt_Q.put(request)

         # Get the result
@@ -104,17 +105,16 @@ def test_batch_splitting_logic(self, vllm_num_engines: int, num_unique_prompts_r
         """Test the batch splitting and accumulation logic using split_and_insert_batch and accumulate_inference_batches."""
         # Mock data - simulating num_unique_prompts_rollout * num_samples_per_prompt_rollout
-        queries_next = [f"query_{i}" for i in range(num_unique_prompts_rollout)]
+        # Use lists of integers to simulate tokenized prompts
+        queries_next = [[i, i + 1, i + 2] for i in range(num_unique_prompts_rollout)]  # Mock token IDs
         ground_truths_next = [f"truth_{i}" for i in range(num_unique_prompts_rollout)]
         datasets_next = [f"dataset_{i}" for i in range(num_unique_prompts_rollout)]
         pending_queries_map = {}
         training_step = 1

-        # Create mock Ray queue for testing
-        param_prompt_Q = ray_queue.Queue(maxsize=vllm_num_engines)
+        param_prompt_Q = ray_queue.Queue(maxsize=num_unique_prompts_rollout)

-        # Create mock dataset indices
         dataset_indices = list(range(num_unique_prompts_rollout))

         # Use split_and_insert_batch to split and insert data
@@ -129,48 +129,49 @@ def test_batch_splitting_logic(self, vllm_num_engines: int, num_unique_prompts_r
             param_prompt_Q,
         )

-        # Verify that we have individual prompts in the map (not batches)
         self.assertEqual(len(pending_queries_map), num_unique_prompts_rollout)

-        # Verify that we have the expected number of items in the queue
-        self.assertEqual(param_prompt_Q.qsize(), vllm_num_engines)
+        self.assertEqual(param_prompt_Q.qsize(), num_unique_prompts_rollout)

-        # Create mock inference results to simulate vLLM engine outputs
+        # Create mock inference results to simulate vLLM engine outputs (individual results)
         mock_inference_results = []
-        for batch_idx in range(vllm_num_engines):
+        requests_processed = []
+        for i in range(num_unique_prompts_rollout):
             # Get the request from the queue
             request = param_prompt_Q.get()
             self.assertIsInstance(request, PromptRequest)
             self.assertEqual(request.training_step, training_step)
-            self.assertIsInstance(request.dataset_index, list)  # Now expects a list of indices
+            self.assertIsInstance(request.dataset_index, int)  # Single dataset index
+            self.assertIsInstance(request.prompt, list)  # Single prompt as list of ints

-            # Create mock GenerationResult
-            batch_size = len(request.prompts)
+            # Store request for later verification
+            requests_processed.append(request)
+
+            # Create mock GenerationResult for single prompt
             mock_result = GenerationResult(
-                responses=[[i] for i in range(batch_size)],  # Mock token IDs
-                finish_reasons=["stop"] * batch_size,
-                masks=[[1] * 5] * batch_size,  # Mock masks
+                responses=[[i]],  # Mock token IDs for single response
+                finish_reasons=["stop"],
+                masks=[[1] * 5],  # Mock masks
                 request_info=RequestInfo(
-                    num_calls=[0] * batch_size,
-                    timeouts=[0] * batch_size,
-                    tool_errors=[""] * batch_size,
-                    tool_outputs=[""] * batch_size,
-                    tool_runtimes=[0] * batch_size,
-                    tool_calleds=[False] * batch_size,
+                    num_calls=[0],
+                    timeouts=[0],
+                    tool_errors=[""],
+                    tool_outputs=[""],
+                    tool_runtimes=[0],
+                    tool_calleds=[False],
                 ),
                 is_eval=False,
-                dataset_index=request.dataset_index,
+                dataset_index=[request.dataset_index],
             )
             mock_inference_results.append(mock_result)

-        # Create mock inference results queue
-        inference_results_Q = ray_queue.Queue(maxsize=vllm_num_engines)
+        inference_results_Q = ray_queue.Queue(maxsize=num_unique_prompts_rollout)
         for result in mock_inference_results:
             inference_results_Q.put(result)

         # Use accumulate_inference_batches to combine results
         combined_result, combined_queries, combined_ground_truths, combined_datasets = accumulate_inference_batches(
-            inference_results_Q, pending_queries_map, vllm_num_engines, training_step
+            inference_results_Q, pending_queries_map, num_unique_prompts_rollout, training_step
         )

         # Verify that the combined results match the original input
@@ -184,7 +185,7 @@ def test_batch_splitting_logic(self, vllm_num_engines: int, num_unique_prompts_r
         self.assertEqual(len(combined_result.finish_reasons), len(queries_next))
         self.assertEqual(len(combined_result.masks), len(queries_next))

-        # Verify that the pending_queries_map is empty after accumulation
+        # Verify that the test_pending_queries_map is empty after accumulation
         self.assertEqual(len(pending_queries_map), 0)

         # Verify that the inference_results_Q is empty after accumulation
diff --git a/open_instruct/vllm_utils3.py b/open_instruct/vllm_utils3.py
index b4d36de07c..a1560eec46 100644
--- a/open_instruct/vllm_utils3.py
+++ b/open_instruct/vllm_utils3.py
@@ -23,6 +23,7 @@
 import ray
 import torch
 import torch.distributed
+import vllm
 from ray.util.placement_group import placement_group
 from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
 from torch.distributed.distributed_c10d import (
@@ -57,17 +58,17 @@ class GenerationResult:
     masks: List[List[int]]
     request_info: RequestInfo
     is_eval: bool = False
-    dataset_index: Optional[int] = None
+    dataset_index: Optional[List[int]] = None


 @dataclasses.dataclass
 class PromptRequest:
-    """Container for prompt requests to vLLM."""
+    """Container for a single prompt request to vLLM."""

-    prompts: List[List[int]]
+    prompt: List[int]  # Single prompt
     training_step: Optional[int] = None
-    eval_prompts: Optional[List[List[int]]] = None
-    dataset_index: Optional[int] = None
+    eval_prompts: Optional[List[List[int]]] = None  # Keep as list for eval
+    dataset_index: Optional[int] = None  # Single dataset index


 def ray_noset_visible_devices(env_vars=os.environ):
@@ -199,80 +200,89 @@ def generate(self, *args, **kwargs):

     def process_from_queue(
         self,
-        sampling_params,
-        eval_sampling_params=None,
-        eval_freq=None,
-        num_training_steps=None,
-        resume_training_step=1,
-    ):
+        sampling_params: vllm.SamplingParams,
+        eval_sampling_params: Optional[vllm.SamplingParams] = None,
+        eval_freq: Optional[int] = None,
+        num_training_steps: Optional[int] = None,
+        resume_training_step: int = 1,
+        batch_size: Optional[int] = None,
+    ) -> None:
         """Process prompts from the queue and put results in the results queue."""
         for training_step in range(resume_training_step, num_training_steps + 1):
-            # Get prompts from queue
-            request = self.prompt_queue.get()
-            if request is None:
-                break
-
-            # Process training prompts
-            result = self._generate_batch(request.prompts, sampling_params, request.dataset_index)
-            self.results_queue.put(result)
-
-            # Handle evaluation if needed
-            if (
-                request.eval_prompts is not None
-                and eval_sampling_params is not None
-                and (training_step - 1) % eval_freq == 0
-            ):
-                eval_result = self._generate_batch(request.eval_prompts, eval_sampling_params, request.dataset_index)
-                eval_result.is_eval = True
-                # Put eval results in separate queue if available
-                if self.eval_results_queue is not None:
+            prompts_batch = []
+            dataset_indices_batch = []
+            eval_prompts = None
+
+            while len(prompts_batch) < batch_size:
+                request: PromptRequest = self.prompt_queue.get()
+
+                prompts_batch.append(request.prompt)
+                if request.dataset_index is not None:
+                    dataset_indices_batch.append(request.dataset_index)
+
+                if eval_prompts is None and request.eval_prompts is not None:
+                    eval_prompts = request.eval_prompts
+
+            results = self._generate_batch(prompts_batch, sampling_params, dataset_indices_batch)
+
+            for result in results:
+                self.results_queue.put(result)
+
+            if eval_prompts is not None and eval_sampling_params is not None and (training_step - 1) % eval_freq == 0:
+                eval_results = self._generate_batch(eval_prompts, eval_sampling_params, None)
+                for eval_result in eval_results:
+                    eval_result.is_eval = True
                     self.eval_results_queue.put(eval_result)
-                else:
-                    self.results_queue.put(eval_result)

     def _generate_batch(
-        self, prompts: List[List[int]], sampling_params, dataset_index: Optional[int] = None
-    ) -> GenerationResult:
-        """Generate responses for a batch of prompts."""
+        self, prompts: List[List[int]], sampling_params, dataset_indices: Optional[List[int]] = None
+    ) -> List[GenerationResult]:
+        """Generate responses for a batch of prompts and return individual results."""
         outputs = self.llm.generate(sampling_params=sampling_params, prompt_token_ids=prompts, use_tqdm=False)

-        # Process outputs
-        response_ids = [list(out.token_ids) for output in outputs for out in output.outputs]
-        finish_reasons = [out.finish_reason for output in outputs for out in output.outputs]
-
-        if self.tool_use:
-            masks = [out.mask for output in outputs for out in output.outputs]
-            num_calls = [out.num_calls for output in outputs for out in output.outputs]
-            timeouts = [out.timeout for output in outputs for out in output.outputs]
-            tool_errors = [out.tool_error for output in outputs for out in output.outputs]
-            tool_outputs = [out.tool_output for output in outputs for out in output.outputs]
-            tool_runtimes = [out.tool_runtime for output in outputs for out in output.outputs]
-            tool_calleds = [out.tool_called for output in outputs for out in output.outputs]
-        else:
-            masks = [[1] * len(resp) for resp in response_ids]
-            num_calls = [0] * len(response_ids)
-            timeouts = [0] * len(response_ids)
-            tool_errors = [""] * len(response_ids)
-            tool_outputs = [""] * len(response_ids)
-            tool_runtimes = [0] * len(response_ids)
-            tool_calleds = [False] * len(response_ids)
-
-        request_info = RequestInfo(
-            num_calls=num_calls,
-            timeouts=timeouts,
-            tool_errors=tool_errors,
-            tool_outputs=tool_outputs,
-            tool_runtimes=tool_runtimes,
-            tool_calleds=tool_calleds,
-        )
+        # Process outputs and create individual GenerationResult objects
+        results = []
+        for i, output in enumerate(outputs):
+            response_ids = [list(out.token_ids) for out in output.outputs]
+            finish_reasons = [out.finish_reason for out in output.outputs]
+
+            if self.tool_use:
+                masks = [out.mask for out in output.outputs]
+                num_calls = [out.num_calls for out in output.outputs]
+                timeouts = [out.timeout for out in output.outputs]
+                tool_errors = [out.tool_error for out in output.outputs]
+                tool_outputs = [out.tool_output for out in output.outputs]
+                tool_runtimes = [out.tool_runtime for out in output.outputs]
+                tool_calleds = [out.tool_called for out in output.outputs]
+            else:
+                masks = [[1] * len(resp) for resp in response_ids]
+                num_calls = [0] * len(response_ids)
+                timeouts = [0] * len(response_ids)
+                tool_errors = [""] * len(response_ids)
+                tool_outputs = [""] * len(response_ids)
+                tool_runtimes = [0] * len(response_ids)
+                tool_calleds = [False] * len(response_ids)
+
+            request_info = RequestInfo(
+                num_calls=num_calls,
+                timeouts=timeouts,
+                tool_errors=tool_errors,
+                tool_outputs=tool_outputs,
+                tool_runtimes=tool_runtimes,
+                tool_calleds=tool_calleds,
+            )

-        return GenerationResult(
-            responses=response_ids,
-            finish_reasons=finish_reasons,
-            masks=masks,
-            request_info=request_info,
-            dataset_index=dataset_index,
-        )
+            results.append(
+                GenerationResult(
+                    responses=response_ids,
+                    finish_reasons=finish_reasons,
+                    masks=masks,
+                    request_info=request_info,
+                    dataset_index=[dataset_indices[i]] if dataset_indices else None,
+                )
+            )
+
+        return results

     def init_process_group(
         self, master_address, master_port, rank_offset, world_size, group_name, backend, use_ray=False
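
Note (reviewer illustration, not part of the patch): the sketch below mimics the per-prompt queue protocol this diff moves to, using only the Python standard library. ToyPromptRequest, ToyGenerationResult, and toy_engine_step are hypothetical stand-ins for PromptRequest, GenerationResult, and LLMRayActor.process_from_queue; Ray queues, real generation, and the eval path are omitted, and the reversed-prompt "generation" is purely illustrative.

# Toy, dependency-free illustration of the per-prompt flow: the trainer enqueues
# one request per prompt, each engine drains `batch_size` requests, generates,
# and emits one result per prompt, and the trainer re-associates results with
# their queries via the single-element dataset_index.
import dataclasses
import queue
from typing import List, Optional


@dataclasses.dataclass
class ToyPromptRequest:  # stand-in for vllm_utils3.PromptRequest
    prompt: List[int]
    dataset_index: int


@dataclasses.dataclass
class ToyGenerationResult:  # stand-in for vllm_utils3.GenerationResult
    responses: List[List[int]]
    dataset_index: Optional[List[int]] = None


def toy_engine_step(prompt_q: queue.Queue, results_q: queue.Queue, batch_size: int) -> None:
    """Mimics process_from_queue: batch up requests, emit one result per prompt."""
    batch = [prompt_q.get() for _ in range(batch_size)]
    for req in batch:
        # "Generation" here just echoes the prompt reversed.
        results_q.put(ToyGenerationResult(responses=[req.prompt[::-1]], dataset_index=[req.dataset_index]))


prompt_q: queue.Queue = queue.Queue()
results_q: queue.Queue = queue.Queue()
prompts = {0: [1, 2, 3], 1: [4, 5, 6], 2: [7, 8], 3: [9]}

# Producer side mirrors split_and_insert_batch: one request per prompt, keyed by dataset index.
pending = {}
for idx, prompt in prompts.items():
    pending[idx] = prompt
    prompt_q.put(ToyPromptRequest(prompt=prompt, dataset_index=idx))

toy_engine_step(prompt_q, results_q, batch_size=2)  # first "engine" takes half the prompts
toy_engine_step(prompt_q, results_q, batch_size=2)  # second "engine" takes the rest

# Consumer side mirrors accumulate_inference_batches: pop exactly one pending query per result.
for _ in range(len(prompts)):
    result = results_q.get()
    (idx,) = result.dataset_index
    original = pending.pop(idx)
    assert result.responses[0] == original[::-1]
assert not pending  # every prompt matched exactly one result

The single-element dataset_index list is the piece that lets accumulate_inference_batches pop exactly one pending query per result, which is what the new len(result.dataset_index) != 1 check in grpo_fast.py enforces.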