diff --git a/open_instruct/benchmark_generators.py b/open_instruct/benchmark_generators.py index a79744ca6f..a671b1c539 100644 --- a/open_instruct/benchmark_generators.py +++ b/open_instruct/benchmark_generators.py @@ -295,9 +295,8 @@ def setup_vllm_engines( pg = ray.util.placement_group(bundles, strategy="PACK") ray.get(pg.ready()) - # Adjust queue sizes for individual prompts - param_prompt_Q = ray_queue.Queue(maxsize=100) - inference_results_Q = ray_queue.Queue(maxsize=100) + param_prompt_Q = ray_queue.Queue(maxsize=10) + inference_results_Q = ray_queue.Queue(maxsize=10) vllm_engines = vllm_utils3.create_vllm_engines( num_engines=args.vllm_num_engines, @@ -335,41 +334,6 @@ def get_batch_data( return prompts -def run_generation_batch( - inference_results_Q: ray_queue.Queue, param_prompt_Q: ray_queue.Queue, prompts: list[list[int]], batch_idx: int -) -> dict[str, Any]: - """Run generation for a batch of prompts and measure performance.""" - - start_time = time.time() - - # Insert individual prompts - for i, prompt in enumerate(prompts): - dataset_idx = batch_idx * len(prompts) + i - param_prompt_Q.put(vllm_utils3.PromptRequest(prompt=prompt, dataset_index=dataset_idx)) - - # Collect individual results - all_responses = [] - all_finish_reasons = [] - for _ in range(len(prompts)): - result = inference_results_Q.get() - all_responses.extend(result.responses) - all_finish_reasons.extend(result.finish_reasons) - - generation_time = time.time() - start_time - - new_tokens = sum(len(response) for response in all_responses) - tokens_per_second = new_tokens / generation_time - return { - "tokens_per_second": tokens_per_second, - "generation_time": generation_time, - "num_new_tokens": new_tokens, - # dict mapping string reasons to counts. - "finish_reasons": collections.Counter(all_finish_reasons), - "response_lengths": [len(response) for response in all_responses], - "prompt_lengths": [len(prompt) for prompt in prompts], # Original unique prompts - } - - def run_benchmark( dataset: datasets.Dataset, vllm_engines: list[ray.actor.ActorHandle], @@ -402,7 +366,6 @@ def run_benchmark( num_batches + 1, # eval_freq (avoid evaluation) num_batches, # num_training_steps 1, # resume_training_step - args.num_unique_prompts_rollout // args.vllm_num_engines, # batch_size ) # Wait for engines to be ready @@ -420,47 +383,37 @@ def run_benchmark( ] submission_start_time = time.time() for batch_idx in range(num_batches): - # Insert individual prompts for this batch - for i, prompt in enumerate(all_prompts[batch_idx]): - dataset_idx = batch_idx * args.num_unique_prompts_rollout + i - param_prompt_Q.put(vllm_utils3.PromptRequest(prompt=prompt, dataset_index=dataset_idx)) + param_prompt_Q.put(vllm_utils3.PromptRequest(prompts=all_prompts[batch_idx], dataset_index=batch_idx)) submission_time = time.time() - submission_start_time logger.info(f"All batches submitted in {submission_time:.2f}s") # Receive results and measure time for each batch last_completion_time = submission_start_time for batch_idx in range(num_batches): - # Collect individual results for this batch - batch_responses = [] - batch_finish_reasons = [] - for _ in range(args.num_unique_prompts_rollout): - result = inference_results_Q.get() - batch_responses.extend(result.responses) - batch_finish_reasons.extend(result.finish_reasons) - + result = inference_results_Q.get() completion_time = time.time() batch_generation_time = completion_time - last_completion_time last_completion_time = completion_time - # Process batch results - new_tokens = 
sum(len(response) for response in batch_responses) + # Process result + new_tokens = sum(len(response) for response in result.responses) tokens_per_second = new_tokens / batch_generation_time result_dict = { "tokens_per_second": tokens_per_second, "generation_time": batch_generation_time, "num_new_tokens": new_tokens, - "finish_reasons": collections.Counter(batch_finish_reasons), - "response_lengths": [len(response) for response in batch_responses], - "batch_idx": batch_idx, + "finish_reasons": collections.Counter(result.finish_reasons), + "response_lengths": [len(response) for response in result.responses], + "batch_idx": result.dataset_index, } result_dict["mfu"] = 100 * result_dict["tokens_per_second"] * flops_per_token / device_flops # We incrementally save completion lengths so even if the job dies, we still have data. - save_completion_lengths([result_dict], timestamp, batch_idx) + save_completion_lengths([result_dict], timestamp, result.dataset_index) results.append(result_dict) logger.info( - f"Batch {batch_idx + 1}: " + f"Batch {result.dataset_index + 1}: " f"{result_dict['tokens_per_second']:.2f} new tokens/sec, " f"MFU: {result_dict['mfu']:.2f}%, " f"generation time: {batch_generation_time:.2f}s" diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index ac8973e02d..fe02a5bb29 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -1106,42 +1106,50 @@ def get_bundle_index(rank, num_gpus_per_node): def accumulate_inference_batches( - inference_results_Q: ray_queue.Queue, pending_queries_map: dict, expected_results: int, training_step: int + inference_results_Q: ray_queue.Queue, pending_queries_map: dict, vllm_num_engines: int, training_step: int ) -> tuple: - """Accumulate individual inference results into a single training batch. + """Accumulate multiple inference results into a single training batch. 
Args: - inference_results_Q: Queue containing GenerationResult objects (one per prompt) + inference_results_Q: Queue containing GenerationResult objects pending_queries_map: Map of dataset_index -> (queries, ground_truths, datasets) - expected_results: Number of individual results to accumulate + vllm_num_engines: Number of vLLM engines (number of batches to accumulate) training_step: Current training step for error reporting Returns: Tuple of (combined_result, combined_queries, combined_ground_truths, combined_datasets) """ - # Collect individual results + # Collect results from all engines results = [] all_queries = [] all_ground_truths = [] all_datasets = [] - for i in range(expected_results): - # Get individual result from queue + for batch_idx in range(vllm_num_engines): + # Get result from queue result = inference_results_Q.get() + dataset_indices = result.dataset_index - if result.dataset_index is None or len(result.dataset_index) != 1: - raise RuntimeError(f"Expected single dataset index, got {result.dataset_index}") + if dataset_indices is None: + raise RuntimeError(f"Dataset indices is None for batch {batch_idx}") - dataset_idx = result.dataset_index[0] - if dataset_idx not in pending_queries_map: - raise RuntimeError(f"Dataset index {dataset_idx} not found in pending_queries_map") + # Get corresponding queries, ground_truths, datasets for each individual prompt + batch_queries = [] + batch_ground_truths = [] + batch_datasets = [] + for dataset_idx in dataset_indices: + if dataset_idx not in pending_queries_map: + raise RuntimeError(f"Dataset index {dataset_idx} not found in pending_queries_map") - query, ground_truth, dataset = pending_queries_map.pop(dataset_idx) + query, ground_truth, dataset = pending_queries_map.pop(dataset_idx) + batch_queries.append(query) + batch_ground_truths.append(ground_truth) + batch_datasets.append(dataset) results.append(result) - all_queries.append(query) - all_ground_truths.append(ground_truth) - all_datasets.append(dataset) + all_queries.extend(batch_queries) + all_ground_truths.extend(batch_ground_truths) + all_datasets.extend(batch_datasets) # Combine all results into a single GenerationResult combined_responses = [] @@ -1201,10 +1209,7 @@ def data_preparation_thread( # Accumulate results from multiple vLLM engines into a single training batch with Timer("🚀 [Data Preparation Thread] Getting response ids"): result, queries, ground_truths, datasets = accumulate_inference_batches( - inference_results_Q, - pending_queries_map, - args.num_unique_prompts_rollout * args.num_samples_per_prompt_rollout, - training_step, + inference_results_Q, pending_queries_map, args.vllm_num_engines, training_step ) # ------------------------------------------------------------------------------------------------ @@ -1672,19 +1677,29 @@ def split_and_insert_batch( param_prompt_Q, eval_prompt_token_ids=None, ): - """Insert individual prompts into queues and mapping.""" - # Insert individual prompts into the queue - for i, dataset_idx in enumerate(dataset_indices): - # Store individual prompt in the map using dataset index as key - pending_queries_map[dataset_idx] = (queries_next[i], ground_truths_next[i], datasets_next[i]) - - # Create PromptRequest for single prompt + """Split a batch into multiple inference batches and insert individual prompts into queues and mapping.""" + # Split the batch over the VLLM engines. 
+ inference_batch_size = len(queries_next) // vllm_num_engines + for batch_idx in range(vllm_num_engines): + start_idx = batch_idx * inference_batch_size + end_idx = start_idx + inference_batch_size if batch_idx < vllm_num_engines - 1 else len(queries_next) + + batch_queries = queries_next[start_idx:end_idx] + batch_ground_truths = ground_truths_next[start_idx:end_idx] + batch_datasets = datasets_next[start_idx:end_idx] + batch_dataset_indices = dataset_indices[start_idx:end_idx] + + # Store individual prompts in the map using dataset indices as keys + for i, dataset_idx in enumerate(batch_dataset_indices): + pending_queries_map[dataset_idx] = (batch_queries[i], batch_ground_truths[i], batch_datasets[i]) + + # Use PromptRequest for Ray queue with batch-specific dataset_index list param_prompt_Q.put( PromptRequest( - prompt=queries_next[i], + prompts=batch_queries, training_step=training_step, eval_prompts=eval_prompt_token_ids, - dataset_index=dataset_idx, + dataset_index=batch_dataset_indices, ) ) @@ -2034,10 +2049,9 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig, num_eval_sa pprint([args, model_config]) - # Create Ray queues - adjust maxsize for individual prompts. - total_prompts_per_step = args.num_unique_prompts_rollout * args.num_samples_per_prompt_rollout - inference_results_Q = ray_queue.Queue(maxsize=args.async_steps * total_prompts_per_step) - param_prompt_Q = ray_queue.Queue(maxsize=args.async_steps * total_prompts_per_step) + # Create Ray queues + inference_results_Q = ray_queue.Queue(maxsize=args.async_steps) + param_prompt_Q = ray_queue.Queue(maxsize=args.async_steps) evaluation_inference_results_Q = ray_queue.Queue(maxsize=1) policy_group, vllm_engines, tool_objects, resume_training_step, episode = create_model_and_optimizer( @@ -2086,21 +2100,9 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig, num_eval_sa reward_fn = make_reward_fn(args) # Start vLLM engines to process from queues - if args.num_unique_prompts_rollout * args.num_samples_per_prompt_rollout % args.vllm_num_engines != 0: - raise ValueError( - "The number of unique prompts times the number of samples per prompt must be divisible by the number of vLLM engines." 
- ) - batch_size_per_engine = ( - args.num_unique_prompts_rollout * args.num_samples_per_prompt_rollout - ) // args.vllm_num_engines for engine in vllm_engines: engine.process_from_queue.remote( - generation_config, - eval_generation_config, - args.eval_freq, - args.num_training_steps, - resume_training_step, - batch_size_per_engine, + generation_config, eval_generation_config, args.eval_freq, args.num_training_steps, resume_training_step ) logger.info("======== ✅ vllm engines started processing from queues =========") diff --git a/open_instruct/test_grpo_fast.py b/open_instruct/test_grpo_fast.py index 1957bad7ee..66aed2c837 100644 --- a/open_instruct/test_grpo_fast.py +++ b/open_instruct/test_grpo_fast.py @@ -74,11 +74,10 @@ def test_vllm_queue_system_single_prompt(self): 999, # eval_freq (avoid evaluation) 1, # num_training_steps 1, # resume_training_step - 1, # batch_size ) # Put the test prompt in the queue using PromptRequest - request = PromptRequest(prompt=prompt_token_ids, dataset_index=0) + request = PromptRequest(prompts=[prompt_token_ids], dataset_index=0) param_prompt_Q.put(request) # Get the result @@ -105,16 +104,17 @@ def test_batch_splitting_logic(self, vllm_num_engines: int, num_unique_prompts_r """Test the batch splitting and accumulation logic using split_and_insert_batch and accumulate_inference_batches.""" # Mock data - simulating num_unique_prompts_rollout * num_samples_per_prompt_rollout - # Use lists of integers to simulate tokenized prompts - queries_next = [[i, i + 1, i + 2] for i in range(num_unique_prompts_rollout)] # Mock token IDs + queries_next = [f"query_{i}" for i in range(num_unique_prompts_rollout)] ground_truths_next = [f"truth_{i}" for i in range(num_unique_prompts_rollout)] datasets_next = [f"dataset_{i}" for i in range(num_unique_prompts_rollout)] pending_queries_map = {} training_step = 1 - param_prompt_Q = ray_queue.Queue(maxsize=num_unique_prompts_rollout) + # Create mock Ray queue for testing + param_prompt_Q = ray_queue.Queue(maxsize=vllm_num_engines) + # Create mock dataset indices dataset_indices = list(range(num_unique_prompts_rollout)) # Use split_and_insert_batch to split and insert data @@ -129,49 +129,48 @@ def test_batch_splitting_logic(self, vllm_num_engines: int, num_unique_prompts_r param_prompt_Q, ) + # Verify that we have individual prompts in the map (not batches) self.assertEqual(len(pending_queries_map), num_unique_prompts_rollout) - self.assertEqual(param_prompt_Q.qsize(), num_unique_prompts_rollout) + # Verify that we have the expected number of items in the queue + self.assertEqual(param_prompt_Q.qsize(), vllm_num_engines) - # Create mock inference results to simulate vLLM engine outputs (individual results) + # Create mock inference results to simulate vLLM engine outputs mock_inference_results = [] - requests_processed = [] - for i in range(num_unique_prompts_rollout): + for batch_idx in range(vllm_num_engines): # Get the request from the queue request = param_prompt_Q.get() self.assertIsInstance(request, PromptRequest) self.assertEqual(request.training_step, training_step) - self.assertIsInstance(request.dataset_index, int) # Single dataset index - self.assertIsInstance(request.prompt, list) # Single prompt as list of ints + self.assertIsInstance(request.dataset_index, list) # Now expects a list of indices - # Store request for later verification - requests_processed.append(request) - - # Create mock GenerationResult for single prompt + # Create mock GenerationResult + batch_size = len(request.prompts) mock_result = 
GenerationResult( - responses=[[i]], # Mock token IDs for single response - finish_reasons=["stop"], - masks=[[1] * 5], # Mock masks + responses=[[i] for i in range(batch_size)], # Mock token IDs + finish_reasons=["stop"] * batch_size, + masks=[[1] * 5] * batch_size, # Mock masks request_info=RequestInfo( - num_calls=[0], - timeouts=[0], - tool_errors=[""], - tool_outputs=[""], - tool_runtimes=[0], - tool_calleds=[False], + num_calls=[0] * batch_size, + timeouts=[0] * batch_size, + tool_errors=[""] * batch_size, + tool_outputs=[""] * batch_size, + tool_runtimes=[0] * batch_size, + tool_calleds=[False] * batch_size, ), is_eval=False, - dataset_index=[request.dataset_index], + dataset_index=request.dataset_index, ) mock_inference_results.append(mock_result) - inference_results_Q = ray_queue.Queue(maxsize=num_unique_prompts_rollout) + # Create mock inference results queue + inference_results_Q = ray_queue.Queue(maxsize=vllm_num_engines) for result in mock_inference_results: inference_results_Q.put(result) # Use accumulate_inference_batches to combine results combined_result, combined_queries, combined_ground_truths, combined_datasets = accumulate_inference_batches( - inference_results_Q, pending_queries_map, num_unique_prompts_rollout, training_step + inference_results_Q, pending_queries_map, vllm_num_engines, training_step ) # Verify that the combined results match the original input @@ -185,7 +184,7 @@ def test_batch_splitting_logic(self, vllm_num_engines: int, num_unique_prompts_r self.assertEqual(len(combined_result.finish_reasons), len(queries_next)) self.assertEqual(len(combined_result.masks), len(queries_next)) - # Verify that the test_pending_queries_map is empty after accumulation + # Verify that the pending_queries_map is empty after accumulation self.assertEqual(len(pending_queries_map), 0) # Verify that the inference_results_Q is empty after accumulation diff --git a/open_instruct/vllm_utils3.py b/open_instruct/vllm_utils3.py index a1560eec46..b4d36de07c 100644 --- a/open_instruct/vllm_utils3.py +++ b/open_instruct/vllm_utils3.py @@ -23,7 +23,6 @@ import ray import torch import torch.distributed -import vllm from ray.util.placement_group import placement_group from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy from torch.distributed.distributed_c10d import ( @@ -58,17 +57,17 @@ class GenerationResult: masks: List[List[int]] request_info: RequestInfo is_eval: bool = False - dataset_index: Optional[List[int]] = None + dataset_index: Optional[int] = None @dataclasses.dataclass class PromptRequest: - """Container for a single prompt request to vLLM.""" + """Container for prompt requests to vLLM.""" - prompt: List[int] # Single prompt + prompts: List[List[int]] training_step: Optional[int] = None - eval_prompts: Optional[List[List[int]]] = None # Keep as list for eval - dataset_index: Optional[int] = None # Single dataset index + eval_prompts: Optional[List[List[int]]] = None + dataset_index: Optional[int] = None def ray_noset_visible_devices(env_vars=os.environ): @@ -200,89 +199,80 @@ def generate(self, *args, **kwargs): def process_from_queue( self, - sampling_params: vllm.SamplingParams, - eval_sampling_params: Optional[vllm.SamplingParams] = None, - eval_freq: Optional[int] = None, - num_training_steps: Optional[int] = None, - resume_training_step: int = 1, - batch_size: Optional[int] = None, - ) -> None: + sampling_params, + eval_sampling_params=None, + eval_freq=None, + num_training_steps=None, + resume_training_step=1, + ): """Process prompts from the 
queue and put results in the results queue.""" for training_step in range(resume_training_step, num_training_steps + 1): - prompts_batch = [] - dataset_indices_batch = [] - eval_prompts = None - - while len(prompts_batch) < batch_size: - request: PromptRequest = self.prompt_queue.get() - - prompts_batch.append(request.prompt) - if request.dataset_index is not None: - dataset_indices_batch.append(request.dataset_index) - - if eval_prompts is None and request.eval_prompts is not None: - eval_prompts = request.eval_prompts - - results = self._generate_batch(prompts_batch, sampling_params, dataset_indices_batch) - - for result in results: - self.results_queue.put(result) - - if eval_prompts is not None and eval_sampling_params is not None and (training_step - 1) % eval_freq == 0: - eval_results = self._generate_batch(eval_prompts, eval_sampling_params, None) - for eval_result in eval_results: - eval_result.is_eval = True + # Get prompts from queue + request = self.prompt_queue.get() + if request is None: + break + + # Process training prompts + result = self._generate_batch(request.prompts, sampling_params, request.dataset_index) + self.results_queue.put(result) + + # Handle evaluation if needed + if ( + request.eval_prompts is not None + and eval_sampling_params is not None + and (training_step - 1) % eval_freq == 0 + ): + eval_result = self._generate_batch(request.eval_prompts, eval_sampling_params, request.dataset_index) + eval_result.is_eval = True + # Put eval results in separate queue if available + if self.eval_results_queue is not None: self.eval_results_queue.put(eval_result) + else: + self.results_queue.put(eval_result) def _generate_batch( - self, prompts: List[List[int]], sampling_params, dataset_indices: Optional[List[int]] = None - ) -> List[GenerationResult]: - """Generate responses for a batch of prompts and return individual results.""" + self, prompts: List[List[int]], sampling_params, dataset_index: Optional[int] = None + ) -> GenerationResult: + """Generate responses for a batch of prompts.""" outputs = self.llm.generate(sampling_params=sampling_params, prompt_token_ids=prompts, use_tqdm=False) - # Process outputs and create individual GenerationResult objects - results = [] - for i, output in enumerate(outputs): - response_ids = [list(out.token_ids) for out in output.outputs] - finish_reasons = [out.finish_reason for out in output.outputs] - - if self.tool_use: - masks = [out.mask for out in output.outputs] - num_calls = [out.num_calls for out in output.outputs] - timeouts = [out.timeout for out in output.outputs] - tool_errors = [out.tool_error for out in output.outputs] - tool_outputs = [out.tool_output for out in output.outputs] - tool_runtimes = [out.tool_runtime for out in output.outputs] - tool_calleds = [out.tool_called for out in output.outputs] - else: - masks = [[1] * len(resp) for resp in response_ids] - num_calls = [0] * len(response_ids) - timeouts = [0] * len(response_ids) - tool_errors = [""] * len(response_ids) - tool_outputs = [""] * len(response_ids) - tool_runtimes = [0] * len(response_ids) - tool_calleds = [False] * len(response_ids) - - request_info = RequestInfo( - num_calls=num_calls, - timeouts=timeouts, - tool_errors=tool_errors, - tool_outputs=tool_outputs, - tool_runtimes=tool_runtimes, - tool_calleds=tool_calleds, - ) - - results.append( - GenerationResult( - responses=response_ids, - finish_reasons=finish_reasons, - masks=masks, - request_info=request_info, - dataset_index=[dataset_indices[i]], - ) - ) + # Process outputs + response_ids = 
[list(out.token_ids) for output in outputs for out in output.outputs] + finish_reasons = [out.finish_reason for output in outputs for out in output.outputs] + + if self.tool_use: + masks = [out.mask for output in outputs for out in output.outputs] + num_calls = [out.num_calls for output in outputs for out in output.outputs] + timeouts = [out.timeout for output in outputs for out in output.outputs] + tool_errors = [out.tool_error for output in outputs for out in output.outputs] + tool_outputs = [out.tool_output for output in outputs for out in output.outputs] + tool_runtimes = [out.tool_runtime for output in outputs for out in output.outputs] + tool_calleds = [out.tool_called for output in outputs for out in output.outputs] + else: + masks = [[1] * len(resp) for resp in response_ids] + num_calls = [0] * len(response_ids) + timeouts = [0] * len(response_ids) + tool_errors = [""] * len(response_ids) + tool_outputs = [""] * len(response_ids) + tool_runtimes = [0] * len(response_ids) + tool_calleds = [False] * len(response_ids) + + request_info = RequestInfo( + num_calls=num_calls, + timeouts=timeouts, + tool_errors=tool_errors, + tool_outputs=tool_outputs, + tool_runtimes=tool_runtimes, + tool_calleds=tool_calleds, + ) - return results + return GenerationResult( + responses=response_ids, + finish_reasons=finish_reasons, + masks=masks, + request_info=request_info, + dataset_index=dataset_index, + ) def init_process_group( self, master_address, master_port, rank_offset, world_size, group_name, backend, use_ray=False
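Illustrative sketch (not part of the patch above): the round trip below mirrors the batch-level queue protocol this diff restores, one PromptRequest per vLLM engine carrying a list of prompts plus their dataset indices, one GenerationResult per request, and an accumulation step that drains exactly vllm_num_engines results while restoring per-prompt metadata from pending_queries_map. It uses plain queue.Queue and simplified stand-in dataclasses rather than Ray queues and the real vllm_utils3 types, so the helper names and fields here are assumptions for illustration only.

import dataclasses
import queue
from typing import List, Optional


@dataclasses.dataclass
class PromptRequest:  # simplified stand-in for vllm_utils3.PromptRequest
    prompts: List[List[int]]
    training_step: Optional[int] = None
    dataset_index: Optional[List[int]] = None


@dataclasses.dataclass
class GenerationResult:  # simplified stand-in for vllm_utils3.GenerationResult
    responses: List[List[int]]
    dataset_index: Optional[List[int]] = None


def split_batch(queries, dataset_indices, vllm_num_engines, training_step, param_prompt_Q, pending_queries_map):
    # Mirrors split_and_insert_batch: one request per engine, per-prompt metadata keyed by dataset index.
    inference_batch_size = len(queries) // vllm_num_engines
    for batch_idx in range(vllm_num_engines):
        start = batch_idx * inference_batch_size
        end = start + inference_batch_size if batch_idx < vllm_num_engines - 1 else len(queries)
        for i in range(start, end):
            pending_queries_map[dataset_indices[i]] = queries[i]
        param_prompt_Q.put(
            PromptRequest(
                prompts=queries[start:end],
                training_step=training_step,
                dataset_index=dataset_indices[start:end],
            )
        )


def fake_engine_step(param_prompt_Q, inference_results_Q):
    # Stand-in for LLMRayActor.process_from_queue: one GenerationResult per PromptRequest.
    request = param_prompt_Q.get()
    responses = [prompt[::-1] for prompt in request.prompts]  # pretend generation
    inference_results_Q.put(GenerationResult(responses=responses, dataset_index=request.dataset_index))


def accumulate(inference_results_Q, pending_queries_map, vllm_num_engines):
    # Mirrors accumulate_inference_batches: drain exactly vllm_num_engines results.
    all_responses, all_queries = [], []
    for _ in range(vllm_num_engines):
        result = inference_results_Q.get()
        for dataset_idx in result.dataset_index:
            all_queries.append(pending_queries_map.pop(dataset_idx))
        all_responses.extend(result.responses)
    return all_responses, all_queries


if __name__ == "__main__":
    num_engines = 2
    queries = [[i, i + 1] for i in range(8)]
    prompt_Q, results_Q, pending = queue.Queue(), queue.Queue(), {}
    split_batch(queries, list(range(8)), num_engines, 1, prompt_Q, pending)
    for _ in range(num_engines):
        fake_engine_step(prompt_Q, results_Q)
    responses, restored = accumulate(results_Q, pending, num_engines)
    assert restored == queries and not pending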
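Also illustrative only: the flattening in the new _generate_batch is prompt-major, so with n samples per prompt the responses list comes out as [prompt0_sample0, ..., prompt0_sample{n-1}, prompt1_sample0, ...]. The toy below reproduces the same comprehension shape with stand-in output objects (FakeCompletion and FakeRequestOutput are assumptions standing in for vLLM's per-sample and per-prompt output types) rather than a real engine.

from dataclasses import dataclass
from typing import List


@dataclass
class FakeCompletion:  # stand-in for one sampled completion
    token_ids: List[int]
    finish_reason: str


@dataclass
class FakeRequestOutput:  # stand-in for the per-prompt output object
    outputs: List[FakeCompletion]


def flatten(outputs: List[FakeRequestOutput]):
    # Same nested-comprehension shape as _generate_batch: outer loop over prompts, inner over samples.
    response_ids = [list(out.token_ids) for output in outputs for out in output.outputs]
    finish_reasons = [out.finish_reason for output in outputs for out in output.outputs]
    return response_ids, finish_reasons


outputs = [
    FakeRequestOutput([FakeCompletion([1, 2], "stop"), FakeCompletion([3], "length")]),
    FakeRequestOutput([FakeCompletion([4, 5, 6], "stop"), FakeCompletion([7, 8], "stop")]),
]
ids, reasons = flatten(outputs)
assert ids == [[1, 2], [3], [4, 5, 6], [7, 8]]
assert reasons == ["stop", "length", "stop", "stop"]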