Now, we run individual prompts through the queue. #796
Changes from all commits
@@ -1106,50 +1106,42 @@ def get_bundle_index(rank, num_gpus_per_node):
 def accumulate_inference_batches(
-    inference_results_Q: ray_queue.Queue, pending_queries_map: dict, vllm_num_engines: int, training_step: int
+    inference_results_Q: ray_queue.Queue, pending_queries_map: dict, expected_results: int, training_step: int
 ) -> tuple:
-    """Accumulate multiple inference results into a single training batch.
+    """Accumulate individual inference results into a single training batch.
 
     Args:
-        inference_results_Q: Queue containing GenerationResult objects
+        inference_results_Q: Queue containing GenerationResult objects (one per prompt)
         pending_queries_map: Map of dataset_index -> (queries, ground_truths, datasets)
-        vllm_num_engines: Number of vLLM engines (number of batches to accumulate)
+        expected_results: Number of individual results to accumulate
         training_step: Current training step for error reporting
 
     Returns:
         Tuple of (combined_result, combined_queries, combined_ground_truths, combined_datasets)
     """
-    # Collect results from all engines
+    # Collect individual results
     results = []
     all_queries = []
     all_ground_truths = []
     all_datasets = []
 
-    for batch_idx in range(vllm_num_engines):
-        # Get result from queue
+    for i in range(expected_results):
+        # Get individual result from queue
         result = inference_results_Q.get()
-        dataset_indices = result.dataset_index
 
-        if dataset_indices is None:
-            raise RuntimeError(f"Dataset indices is None for batch {batch_idx}")
+        if result.dataset_index is None or len(result.dataset_index) != 1:
+            raise RuntimeError(f"Expected single dataset index, got {result.dataset_index}")
 
-        # Get corresponding queries, ground_truths, datasets for each individual prompt
-        batch_queries = []
-        batch_ground_truths = []
-        batch_datasets = []
-        for dataset_idx in dataset_indices:
-            if dataset_idx not in pending_queries_map:
-                raise RuntimeError(f"Dataset index {dataset_idx} not found in pending_queries_map")
+        dataset_idx = result.dataset_index[0]
+        if dataset_idx not in pending_queries_map:
+            raise RuntimeError(f"Dataset index {dataset_idx} not found in pending_queries_map")
 
-            query, ground_truth, dataset = pending_queries_map.pop(dataset_idx)
-            batch_queries.append(query)
-            batch_ground_truths.append(ground_truth)
-            batch_datasets.append(dataset)
+        query, ground_truth, dataset = pending_queries_map.pop(dataset_idx)
 
         results.append(result)
-        all_queries.extend(batch_queries)
-        all_ground_truths.extend(batch_ground_truths)
-        all_datasets.extend(batch_datasets)
+        all_queries.append(query)
+        all_ground_truths.append(ground_truth)
+        all_datasets.append(dataset)
 
     # Combine all results into a single GenerationResult
     combined_responses = []
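To see the new consumer-side control flow in isolation, here is a minimal, runnable sketch of the same pattern. It uses the standard-library `queue.Queue` and a toy `FakeResult` dataclass as stand-ins for the Ray queue and the real `GenerationResult`; all names and values below are illustrative only, not part of this PR.

```python
import queue
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class FakeResult:
    """Stand-in for GenerationResult: one response per prompt."""
    response: str
    dataset_index: Optional[List[int]]  # single-element list, mirroring the PR's convention


def accumulate(results_q: queue.Queue, pending: dict, expected_results: int):
    """Drain one result per prompt and re-associate each with its pending query."""
    results, queries, ground_truths, datasets = [], [], [], []
    for _ in range(expected_results):
        result = results_q.get()
        if result.dataset_index is None or len(result.dataset_index) != 1:
            raise RuntimeError(f"Expected single dataset index, got {result.dataset_index}")
        idx = result.dataset_index[0]
        if idx not in pending:
            raise RuntimeError(f"Dataset index {idx} not found in pending map")
        query, ground_truth, dataset = pending.pop(idx)
        results.append(result)
        queries.append(query)
        ground_truths.append(ground_truth)
        datasets.append(dataset)
    return results, queries, ground_truths, datasets


# Toy usage: two prompts in flight; results arrive out of order.
pending = {0: ("q0", "gt0", "gsm8k"), 1: ("q1", "gt1", "gsm8k")}
q = queue.Queue()
q.put(FakeResult("response for q1", [1]))
q.put(FakeResult("response for q0", [0]))
_, qs, _, _ = accumulate(q, pending, expected_results=2)
print(qs)  # ['q1', 'q0'] - order follows arrival, not dataset order
```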
@@ -1209,7 +1201,10 @@ def data_preparation_thread(
         # Accumulate results from multiple vLLM engines into a single training batch
         with Timer("🚀 [Data Preparation Thread] Getting response ids"):
             result, queries, ground_truths, datasets = accumulate_inference_batches(
-                inference_results_Q, pending_queries_map, args.vllm_num_engines, training_step
+                inference_results_Q,
+                pending_queries_map,
+                args.num_unique_prompts_rollout * args.num_samples_per_prompt_rollout,
+                training_step,
             )
 
         # ------------------------------------------------------------------------------------------------
@@ -1677,29 +1672,19 @@ def split_and_insert_batch(
     param_prompt_Q,
     eval_prompt_token_ids=None,
 ):
-    """Split a batch into multiple inference batches and insert individual prompts into queues and mapping."""
-    # Split the batch over the VLLM engines.
-    inference_batch_size = len(queries_next) // vllm_num_engines
-    for batch_idx in range(vllm_num_engines):
-        start_idx = batch_idx * inference_batch_size
-        end_idx = start_idx + inference_batch_size if batch_idx < vllm_num_engines - 1 else len(queries_next)
-
-        batch_queries = queries_next[start_idx:end_idx]
-        batch_ground_truths = ground_truths_next[start_idx:end_idx]
-        batch_datasets = datasets_next[start_idx:end_idx]
-        batch_dataset_indices = dataset_indices[start_idx:end_idx]
-
-        # Store individual prompts in the map using dataset indices as keys
-        for i, dataset_idx in enumerate(batch_dataset_indices):
-            pending_queries_map[dataset_idx] = (batch_queries[i], batch_ground_truths[i], batch_datasets[i])
-
-        # Use PromptRequest for Ray queue with batch-specific dataset_index list
+    """Insert individual prompts into queues and mapping."""
+    # Insert individual prompts into the queue
+    for i, dataset_idx in enumerate(dataset_indices):
+        # Store individual prompt in the map using dataset index as key
+        pending_queries_map[dataset_idx] = (queries_next[i], ground_truths_next[i], datasets_next[i])
+
+        # Create PromptRequest for single prompt
         param_prompt_Q.put(
             PromptRequest(
-                prompts=batch_queries,
+                prompt=queries_next[i],
                 training_step=training_step,
                 eval_prompts=eval_prompt_token_ids,
-                dataset_index=batch_dataset_indices,
+                dataset_index=dataset_idx,
             )
         )
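For symmetry, here is a matching sketch of the producer side, again with stand-in types (`FakePromptRequest` and a plain `queue.Queue` rather than the real `PromptRequest` and Ray queue); everything in it is illustrative, not the repo's code.

```python
import queue
from dataclasses import dataclass
from typing import Any, List, Optional


@dataclass
class FakePromptRequest:
    """Stand-in for PromptRequest: exactly one prompt per request."""
    prompt: Any
    training_step: int
    dataset_index: int
    eval_prompts: Optional[List[Any]] = None


def split_and_insert(queries, ground_truths, datasets, dataset_indices,
                     pending: dict, prompt_q: queue.Queue, training_step: int):
    """Enqueue one request per prompt and remember each prompt's metadata by dataset index."""
    for i, idx in enumerate(dataset_indices):
        pending[idx] = (queries[i], ground_truths[i], datasets[i])
        prompt_q.put(FakePromptRequest(prompt=queries[i],
                                       training_step=training_step,
                                       dataset_index=idx))


pending, prompt_q = {}, queue.Queue()
split_and_insert(["q0", "q1"], ["gt0", "gt1"], ["gsm8k", "gsm8k"], [0, 1],
                 pending, prompt_q, training_step=1)
print(prompt_q.qsize(), sorted(pending))  # 2 [0, 1]
```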
@@ -2049,9 +2034,10 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig, num_eval_sa
 
     pprint([args, model_config])
 
-    # Create Ray queues
-    inference_results_Q = ray_queue.Queue(maxsize=args.async_steps)
-    param_prompt_Q = ray_queue.Queue(maxsize=args.async_steps)
+    # Create Ray queues - adjust maxsize for individual prompts.
+    total_prompts_per_step = args.num_unique_prompts_rollout * args.num_samples_per_prompt_rollout
+    inference_results_Q = ray_queue.Queue(maxsize=args.async_steps * total_prompts_per_step)
+    param_prompt_Q = ray_queue.Queue(maxsize=args.async_steps * total_prompts_per_step)
     evaluation_inference_results_Q = ray_queue.Queue(maxsize=1)
 
     policy_group, vllm_engines, tool_objects, resume_training_step, episode = create_model_and_optimizer(
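A quick sanity check on the new queue sizing, using made-up argument values (not the repo's defaults):

```python
# Made-up values, for illustration only.
async_steps = 4
num_unique_prompts_rollout = 64
num_samples_per_prompt_rollout = 4

total_prompts_per_step = num_unique_prompts_rollout * num_samples_per_prompt_rollout  # 256
queue_maxsize = async_steps * total_prompts_per_step  # 1024 individual requests can be in flight
print(total_prompts_per_step, queue_maxsize)  # 256 1024
```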
@@ -2100,9 +2086,21 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig, num_eval_sa
     reward_fn = make_reward_fn(args)
 
     # Start vLLM engines to process from queues
+    if args.num_unique_prompts_rollout * args.num_samples_per_prompt_rollout % args.vllm_num_engines != 0:
+        raise ValueError(
+            "The number of unique prompts times the number of samples per prompt must be divisible by the number of vLLM engines."
+        )
+    batch_size_per_engine = (
+        args.num_unique_prompts_rollout * args.num_samples_per_prompt_rollout
+    ) // args.vllm_num_engines
+
     for engine in vllm_engines:
         engine.process_from_queue.remote(
-            generation_config, eval_generation_config, args.eval_freq, args.num_training_steps, resume_training_step
+            generation_config,
+            eval_generation_config,
+            args.eval_freq,
+            args.num_training_steps,
+            resume_training_step,
+            batch_size_per_engine,
         )
     logger.info("======== ✅ vllm engines started processing from queues =========")

Review thread on lines +2093 to +2095:

**Contributor:** Do we know if this is equally divisible? And is it alright if it's not?

**Collaborator (author):** It should be divisible; I changed the code to raise a `ValueError` if it is not.
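Following up on the divisibility question, a worked example with assumed values (again, not the repo's defaults):

```python
# Assumed values, for illustration only.
num_unique_prompts_rollout = 64
num_samples_per_prompt_rollout = 4
vllm_num_engines = 8

total = num_unique_prompts_rollout * num_samples_per_prompt_rollout  # 256
if total % vllm_num_engines != 0:
    raise ValueError("Total prompts must be divisible by the number of vLLM engines.")
batch_size_per_engine = total // vllm_num_engines  # 32 prompts per engine per step
print(batch_size_per_engine)  # 32

# A non-divisible configuration, e.g. vllm_num_engines = 7, now fails fast with the ValueError.
```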
Second review thread:

> Is it possible that at the end of an epoch or something you will have fewer than this number of results? Or are we dropping the last batch?

Reply:

> `ShufflingIterator` guarantees that they're all the same size:
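The quoted `ShufflingIterator` code did not make it into this excerpt. As a hedged illustration of how an iterator can guarantee fixed-size batches, here is one common pattern: truncate each shuffled epoch to a multiple of the batch size, so a short trailing batch is never yielded. This is a sketch of the general technique, not the repo's actual implementation.

```python
import numpy as np


class FixedSizeShufflingIterator:
    """Yield shuffled batches that are always exactly batch_size long."""

    def __init__(self, data: np.ndarray, batch_size: int, seed: int = 0):
        self.data = data
        self.batch_size = batch_size
        self.rng = np.random.default_rng(seed)
        # Keep only a multiple of batch_size so every batch is full.
        self.effective_size = (len(data) // batch_size) * batch_size
        self._shuffle()

    def _shuffle(self):
        perm = self.rng.permutation(len(self.data))[: self.effective_size]
        self.batches = perm.reshape(-1, self.batch_size)
        self.pos = 0

    def __next__(self):
        if self.pos >= len(self.batches):
            self._shuffle()  # start a new epoch; the batch size never changes
        batch = self.data[self.batches[self.pos]]
        self.pos += 1
        return batch.tolist()


it = FixedSizeShufflingIterator(np.arange(10), batch_size=4)
print(len(next(it)), len(next(it)), len(next(it)))  # 4 4 4 - the remainder of 2 is dropped each epoch
```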