diff --git a/open_instruct/benchmark_generators.py b/open_instruct/benchmark_generators.py index 075758064e..925b7e18d6 100644 --- a/open_instruct/benchmark_generators.py +++ b/open_instruct/benchmark_generators.py @@ -14,8 +14,6 @@ import json import logging import pathlib -import queue -import threading import time from typing import Any @@ -26,6 +24,7 @@ import torch.utils.flop_counter import transformers import vllm +from ray.util import queue as ray_queue from open_instruct import dataset_transformation, grpo_fast, model_utils, utils, vllm_utils3 @@ -204,8 +203,8 @@ def setup_dataset(args: grpo_fast.Args, tokenizer_config: dataset_transformation def setup_vllm_engines( args: grpo_fast.Args, model_config: model_utils.ModelConfig, max_model_len: int = 20480 -) -> list[ray.actor.ActorHandle]: - """Set up vLLM engines.""" +) -> tuple[list[ray.actor.ActorHandle], ray_queue.Queue, ray_queue.Queue]: + """Set up vLLM engines and queues.""" logger.info("Setting up vLLM engines...") # Initialize Ray @@ -213,12 +212,13 @@ def setup_vllm_engines( ray.shutdown() ray.init(num_cpus=4, num_gpus=1, ignore_reinit_error=True, runtime_env={"excludes": ["/benchmark_cache/"]}) - # Create placement group for multiple engines bundles = [{"GPU": 1, "CPU": 1} for _ in range(args.vllm_num_engines)] pg = ray.util.placement_group(bundles, strategy="PACK") ray.get(pg.ready()) - # Create vLLM engines + param_prompt_Q = ray_queue.Queue(maxsize=10) + inference_results_Q = ray_queue.Queue(maxsize=10) + vllm_engines = vllm_utils3.create_vllm_engines( num_engines=args.vllm_num_engines, tensor_parallel_size=args.vllm_tensor_parallel_size, @@ -234,11 +234,13 @@ def setup_vllm_engines( pg=pg, tools={}, max_tool_calls=[0], + prompt_queue=param_prompt_Q, + results_queue=inference_results_Q, ) logger.info("vLLM engines ready") - return vllm_engines + return vllm_engines, param_prompt_Q, inference_results_Q def get_batch_data( @@ -254,27 +256,24 @@ def get_batch_data( def run_generation_batch( - inference_results_Q: queue.Queue, param_prompt_Q: queue.Queue, prompts: list[list[int]] + inference_results_Q: ray_queue.Queue, param_prompt_Q: ray_queue.Queue, prompts: list[list[int]], batch_idx: int ) -> dict[str, Any]: """Run generation for a batch of prompts and measure performance.""" start_time = time.time() - param_prompt_Q.put((None, prompts)) + param_prompt_Q.put(vllm_utils3.PromptRequest(prompts=prompts, dataset_index=batch_idx)) result = inference_results_Q.get() generation_time = time.time() - start_time - response_ids, finish_reasons, _, _ = result - collated_finish_reasons = collections.Counter(finish_reasons) - - new_tokens = sum(len(response) for response in response_ids) + new_tokens = sum(len(response) for response in result.responses) tokens_per_second = new_tokens / generation_time return { "tokens_per_second": tokens_per_second, "generation_time": generation_time, "num_new_tokens": new_tokens, # dict mapping string reasons to counts. 
- "finish_reasons": collated_finish_reasons, - "response_lengths": [len(response) for response in response_ids], + "finish_reasons": collections.Counter(result.finish_reasons), + "response_lengths": [len(response) for response in result.responses], "prompt_lengths": [len(prompt) for prompt in prompts], # Original unique prompts } @@ -282,6 +281,8 @@ def run_generation_batch( def run_benchmark( dataset: datasets.Dataset, vllm_engines: list[ray.actor.ActorHandle], + param_prompt_Q: ray_queue.Queue, + inference_results_Q: ray_queue.Queue, args: grpo_fast.Args, model_config: model_utils.ModelConfig, timestamp: int, @@ -290,11 +291,6 @@ def run_benchmark( """Run the full benchmark.""" logger.info(f"Starting benchmark with {num_batches} batches of size {args.num_unique_prompts_rollout}") - # Create persistent queues - inference_results_Q = queue.Queue(maxsize=10) - param_prompt_Q = queue.Queue(maxsize=10) - evaluation_inference_results_Q = queue.Queue(maxsize=10) - # Create sampling parameters with 'n' for multiple samples per prompt generation_config = vllm.SamplingParams( temperature=args.temperature, @@ -311,26 +307,17 @@ def run_benchmark( # Unclear why we need this. We didn't need it before torch 2.7.0. free_all_gpu_memory() - # Start persistent vLLM generation thread - def wrapped_vllm_generate_thread() -> None: - grpo_fast.vllm_generate_thread( - vllm_engines, + # Start vLLM engines to process from queues + for engine in vllm_engines: + engine.process_from_queue.remote( generation_config, eval_generation_config, - inference_results_Q, - param_prompt_Q, - num_batches, # num_training_steps - None, # eval_prompt_token_ids - evaluation_inference_results_Q, num_batches + 1, # eval_freq (avoid evaluation) + num_batches, # num_training_steps 1, # resume_training_step - False, # tool_use ) - thread = threading.Thread(target=wrapped_vllm_generate_thread) - thread.start() - - # Wait for thread to be ready + # Wait for engines to be ready time.sleep(0.1) results = [] @@ -343,7 +330,7 @@ def wrapped_vllm_generate_thread() -> None: prompts = get_batch_data(dataset, args.num_unique_prompts_rollout, batch_idx) # Run generation - result = run_generation_batch(inference_results_Q, param_prompt_Q, prompts) + result = run_generation_batch(inference_results_Q, param_prompt_Q, prompts, batch_idx) result["mfu"] = 100 * result["tokens_per_second"] * flops_per_token / device_flops # We incrementally save completion lengths so even if the job dies, we still have data. save_completion_lengths([result], timestamp, batch_idx) @@ -360,9 +347,6 @@ def wrapped_vllm_generate_thread() -> None: # Send stop signal param_prompt_Q.put(None) - # Wait for thread to finish - thread.join(timeout=10) - print_summary(results, total_time, args, model_config) @@ -477,12 +461,12 @@ def main() -> None: DATA_DIR.mkdir(parents=True, exist_ok=True) dataset = setup_dataset(args, tokenizer_config) - vllm_engines = setup_vllm_engines(args, model_config) + vllm_engines, param_prompt_Q, inference_results_Q = setup_vllm_engines(args, model_config) # Create the timestamp here so we use it for both filenames. 
timestamp = int(time.time()) save_config(args, tokenizer_config, model_config, timestamp) - run_benchmark(dataset, vllm_engines, args, model_config, timestamp) + run_benchmark(dataset, vllm_engines, param_prompt_Q, inference_results_Q, args, model_config, timestamp) cleanup(vllm_engines) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 0e53f90640..3c89bfab10 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -67,6 +67,7 @@ import wandb from huggingface_hub import HfApi from peft import PeftModel, get_peft_model_state_dict +from ray.util import queue as ray_queue from ray.util.placement_group import PlacementGroup, placement_group from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy from rich.pretty import pprint @@ -120,7 +121,14 @@ maybe_use_ai2_wandb_entity, sync_gs_bucket, ) -from open_instruct.vllm_utils3 import LLMRayActor, create_vllm_engines, init_process_group +from open_instruct.vllm_utils3 import ( + GenerationResult, + LLMRayActor, + PromptRequest, + RequestInfo, + create_vllm_engines, + init_process_group, +) # Setup logging with filename and line number format logging.basicConfig( @@ -1095,97 +1103,112 @@ def get_bundle_index(rank, num_gpus_per_node): self.models.append(worker_policy) -def vllm_generate_thread( - vllm_engines: List[ray.actor.ActorHandle], - generation_config: SamplingParams, - eval_generation_config: SamplingParams, - inference_results_Q: Queue, - param_prompt_Q: Queue, - num_training_steps: int, - eval_prompt_token_ids: Optional[List[int]], - evaluation_inference_results_Q: Queue, - local_eval_freq: int, - resume_training_step: int = 1, - tool_use: bool = False, -): - def generate_with_engines(prompts: List[List[int]], sampling_params: SamplingParams): - # Split queries between engines - queries_per_engine = (len(prompts) + len(vllm_engines) - 1) // len(vllm_engines) - split_queries = [prompts[i : i + queries_per_engine] for i in range(0, len(prompts), queries_per_engine)] - # Generate responses in parallel across engines - futures = [ - vllm_engine.generate.remote(sampling_params=sampling_params, prompt_token_ids=queries, use_tqdm=False) - for vllm_engine, queries in zip(vllm_engines, split_queries) - ] - # Gather all responses - all_outputs = ray.get(futures) - response_ids = [] - finish_reasons = [] # either "stop" or "length" - masks = [] - num_calls = [] - timeouts = [] - tool_errors = [] - tool_outputs = [] - tool_runtimes = [] - tool_calleds = [] - for outputs in all_outputs: - response_ids.extend([list(out.token_ids) for output in outputs for out in output.outputs]) - finish_reasons.extend([out.finish_reason for output in outputs for out in output.outputs]) - if tool_use: - masks.extend([out.mask for output in outputs for out in output.outputs]) - num_calls.extend([out.num_calls for output in outputs for out in output.outputs]) - timeouts.extend([out.timeout for output in outputs for out in output.outputs]) - tool_errors.extend([out.tool_error for output in outputs for out in output.outputs]) - tool_outputs.extend([out.tool_output for output in outputs for out in output.outputs]) - tool_runtimes.extend([out.tool_runtime for output in outputs for out in output.outputs]) - tool_calleds.extend([out.tool_called for output in outputs for out in output.outputs]) - # if not using the tool, mask is all 1s - if not tool_use: - masks = [[1] * len(response_ids[i]) for i in range(len(response_ids))] - num_calls = [0] * len(response_ids) - timeouts = [0] * len(response_ids) - tool_errors = 
[""] * len(response_ids) - tool_outputs = [""] * len(response_ids) - tool_runtimes = [0] * len(response_ids) - tool_calleds = [False] * len(response_ids) - return ( - response_ids, - finish_reasons, - masks, - (num_calls, timeouts, tool_errors, tool_outputs, tool_runtimes, tool_calleds), - ) +def accumulate_inference_batches( + inference_results_Q: ray_queue.Queue, pending_queries_map: dict, vllm_num_engines: int, training_step: int +) -> tuple: + """Accumulate multiple inference results into a single training batch. - for training_step in range(resume_training_step, num_training_steps + 1): - items = param_prompt_Q.get() - if items is None: - break - _, g_queries_list = items + Args: + inference_results_Q: Queue containing GenerationResult objects + pending_queries_map: Map of dataset_index -> (queries, ground_truths, datasets) + vllm_num_engines: Number of vLLM engines (number of batches to accumulate) + training_step: Current training step for error reporting - with Timer("🔥 Generation time"): - response_ids, finish_reasons, masks, info = generate_with_engines(g_queries_list, generation_config) - inference_results_Q.put((response_ids, finish_reasons, masks, info)) + Returns: + Tuple of (combined_result, combined_queries, combined_ground_truths, combined_datasets) + """ + # Collect results from all engines + results = [] + all_queries = [] + all_ground_truths = [] + all_datasets = [] + + for batch_idx in range(vllm_num_engines): + # Get result from queue + result = inference_results_Q.get() + dataset_indices = result.dataset_index + + if dataset_indices is None: + raise RuntimeError(f"Dataset indices is None for batch {batch_idx}") + + # Get corresponding queries, ground_truths, datasets for each individual prompt + batch_queries = [] + batch_ground_truths = [] + batch_datasets = [] + for dataset_idx in dataset_indices: + if dataset_idx not in pending_queries_map: + raise RuntimeError(f"Dataset index {dataset_idx} not found in pending_queries_map") + + query, ground_truth, dataset = pending_queries_map.pop(dataset_idx) + batch_queries.append(query) + batch_ground_truths.append(ground_truth) + batch_datasets.append(dataset) + + results.append(result) + all_queries.extend(batch_queries) + all_ground_truths.extend(batch_ground_truths) + all_datasets.extend(batch_datasets) + + # Combine all results into a single GenerationResult + combined_responses = [] + combined_finish_reasons = [] + combined_masks = [] + combined_num_calls = [] + combined_timeouts = [] + combined_tool_errors = [] + combined_tool_outputs = [] + combined_tool_runtimes = [] + combined_tool_calleds = [] + + for result in results: + combined_responses.extend(result.responses) + combined_finish_reasons.extend(result.finish_reasons) + combined_masks.extend(result.masks) + combined_num_calls.extend(result.request_info.num_calls) + combined_timeouts.extend(result.request_info.timeouts) + combined_tool_errors.extend(result.request_info.tool_errors) + combined_tool_outputs.extend(result.request_info.tool_outputs) + combined_tool_runtimes.extend(result.request_info.tool_runtimes) + combined_tool_calleds.extend(result.request_info.tool_calleds) + + # Create combined RequestInfo + combined_request_info = RequestInfo( + num_calls=combined_num_calls, + timeouts=combined_timeouts, + tool_errors=combined_tool_errors, + tool_outputs=combined_tool_outputs, + tool_runtimes=combined_tool_runtimes, + tool_calleds=combined_tool_calleds, + ) - # Evaluate the model - if eval_prompt_token_ids is not None and (training_step - 1) % local_eval_freq 
== 0: - response_ids, finish_reasons, masks, info = generate_with_engines( - eval_prompt_token_ids, eval_generation_config - ) - evaluation_inference_results_Q.put((response_ids, finish_reasons, masks, info)) + # Create combined GenerationResult + combined_result = GenerationResult( + responses=combined_responses, + finish_reasons=combined_finish_reasons, + masks=combined_masks, + request_info=combined_request_info, + is_eval=results[0].is_eval, # All should have the same is_eval value + dataset_index=None, # Not meaningful for combined result + ) + + return combined_result, all_queries, all_ground_truths, all_datasets def data_preparation_thread( reward_fn: Callable, - inference_results_Q: Queue, + inference_results_Q: ray_queue.Queue, # Ray queue packed_sequences_Q: Queue, - queries_prompt_Q: Queue, + pending_queries_map: dict, args: Args, tokenizer: PreTrainedTokenizer, num_training_steps: int, ): for training_step in range(1, num_training_steps + 1): - # Get next batch of prompts and responses - items = queries_prompt_Q.get() - queries, ground_truths, datasets = items + # Accumulate results from multiple vLLM engines into a single training batch + with Timer("🚀 [Data Preparation Thread] Getting response ids"): + result, queries, ground_truths, datasets = accumulate_inference_batches( + inference_results_Q, pending_queries_map, args.vllm_num_engines, training_step + ) # ------------------------------------------------------------------------------------------------ # Pack sequences @@ -1193,33 +1216,41 @@ def data_preparation_thread( queries = [item for item in queries for _ in range(args.num_samples_per_prompt_rollout)] ground_truths = [item for item in ground_truths for _ in range(args.num_samples_per_prompt_rollout)] datasets = [item for item in datasets for _ in range(args.num_samples_per_prompt_rollout)] - with Timer("🚀 [Data Preparation Thread] Getting response ids"): - responses, finish_reasons, masks, infos = inference_results_Q.get() - num_calls, timeouts, tool_errors, tool_outputs, tool_runtimes, tool_calleds = infos good_outputs = [ - len(tool_outputs[i]) > 0 and tool_calleds[i] and not timeouts[i] and not tool_errors[i] - for i in range(len(tool_outputs)) + len(result.request_info.tool_outputs[i]) > 0 + and result.request_info.tool_calleds[i] + and not result.request_info.timeouts[i] + and not result.request_info.tool_errors[i] + for i in range(len(result.request_info.tool_outputs)) ] - for i in range(len(finish_reasons)): + for i in range(len(result.finish_reasons)): # edge case: sometimes it outputs eos immediately, and we get an empty response # in that case, we need to add the eos token to the response # note that this also adds eos to the end of reponses that stopped for other reasons. - if finish_reasons[i] == "stop" and ( - len(responses[i]) == 0 or responses[i][-1] != tokenizer.eos_token_id + if result.finish_reasons[i] == "stop" and ( + len(result.responses[i]) == 0 or result.responses[i][-1] != tokenizer.eos_token_id ): - responses[i].append(tokenizer.eos_token_id) - masks[i].append(1) # never mask the eos token for now? + result.responses[i].append(tokenizer.eos_token_id) + result.masks[i].append(1) # never mask the eos token for now? 
with Timer("🔥 [Data Preparation Thread] Decoding responses", noop=True): - decoded_responses = tokenizer.batch_decode(responses, skip_special_tokens=True) + decoded_responses = tokenizer.batch_decode(result.responses, skip_special_tokens=True) decoded_queries = tokenizer.batch_decode(queries, skip_special_tokens=True) decoded_queries = [extract_user_query(query) for query in decoded_queries] - stop_rate = sum(int(finish_reason == "stop") for finish_reason in finish_reasons) / len(finish_reasons) + stop_rate = sum(int(finish_reason == "stop") for finish_reason in result.finish_reasons) / len( + result.finish_reasons + ) with Timer("💰 [Data Preparation Thread] Calculating rewards and advantages"): scores, reward_metrics = asyncio.run( reward_fn( - responses, decoded_responses, ground_truths, datasets, finish_reasons, infos, decoded_queries + result.responses, + decoded_responses, + ground_truths, + datasets, + result.finish_reasons, + result.request_info, + decoded_queries, ) ) scores = np.array(scores) @@ -1251,12 +1282,12 @@ def data_preparation_thread( non_zero_gradient_index = np.where(expanded_mask)[0] advantages = advantages[non_zero_gradient_index] scores = scores[non_zero_gradient_index] - responses = [responses[i] for i in non_zero_gradient_index] - masks = [masks[i] for i in non_zero_gradient_index] + responses = [result.responses[i] for i in non_zero_gradient_index] + masks = [result.masks[i] for i in non_zero_gradient_index] queries = [queries[i] for i in non_zero_gradient_index] ground_truths = [ground_truths[i] for i in non_zero_gradient_index] datasets = [datasets[i] for i in non_zero_gradient_index] - finish_reasons = [finish_reasons[i] for i in non_zero_gradient_index] + finish_reasons = [result.finish_reasons[i] for i in non_zero_gradient_index] if args.mask_truncated_completions: stop_idxes = torch.tensor([i for i in range(len(finish_reasons)) if finish_reasons[i] == "stop"]) scores = scores[stop_idxes] @@ -1400,12 +1431,12 @@ def data_preparation_thread( "val/advantages_min": advantages.min(), "val/advantages_max": advantages.max(), "val/advantages_hist": advantages, - "val/num_calls_rate": np.array(num_calls).mean(), - "val/timeouts_rate": np.array(timeouts).mean(), - "val/tool_errors_rate": np.array([len(item) > 0 for item in tool_errors]).mean(), + "val/num_calls_rate": np.array(result.request_info.num_calls).mean(), + "val/timeouts_rate": np.array(result.request_info.timeouts).mean(), + "val/tool_errors_rate": np.array([len(item) > 0 for item in result.request_info.tool_errors]).mean(), "val/good_outputs_rate": np.array(good_outputs).mean(), - "val/tool_runtimes_rate": np.array(tool_runtimes).mean(), - "val/tool_calleds_rate": np.array(tool_calleds).mean(), + "val/tool_runtimes_rate": np.array(result.request_info.tool_runtimes).mean(), + "val/tool_calleds_rate": np.array(result.request_info.tool_calleds).mean(), **reward_metrics, } @@ -1554,6 +1585,9 @@ def create_model_and_optimizer( beaker_config: BeakerRuntimeConfig, wandb_url: str, tokenizer: PreTrainedTokenizer, + inference_results_Q: ray_queue.Queue, + param_prompt_Q: ray_queue.Queue, + evaluation_inference_results_Q: ray_queue.Queue, ) -> tuple[ModelGroup, list[LLMRayActor], dict, int, int]: """Create the model, optimizer, and vLLM engines.""" # Ray initialization @@ -1594,7 +1628,7 @@ def create_model_and_optimizer( else: raise ValueError(f"Unknown tool: {tool}") - # Create vLLM engines + # Create vLLM engines with queues vllm_engines = create_vllm_engines( args.vllm_num_engines, 
args.vllm_tensor_parallel_size, @@ -1610,6 +1644,9 @@ def create_model_and_optimizer( pg=pg if args.single_gpu_mode else None, tools=tool_objects, max_tool_calls=args.max_tool_calls, + prompt_queue=param_prompt_Q, + results_queue=inference_results_Q, + eval_results_queue=evaluation_inference_results_Q, ) resume_training_step = ray.get(inits)[0] + 1 @@ -1627,21 +1664,62 @@ def create_model_and_optimizer( return policy_group, vllm_engines, tool_objects, resume_training_step, episode +def split_and_insert_batch( + queries_next, + ground_truths_next, + datasets_next, + dataset_indices, + training_step, + vllm_num_engines, + pending_queries_map, + param_prompt_Q, + eval_prompt_token_ids=None, +): + """Split a batch into multiple inference batches and insert individual prompts into queues and mapping.""" + # Split the batch over the VLLM engines. + inference_batch_size = len(queries_next) // vllm_num_engines + for batch_idx in range(vllm_num_engines): + start_idx = batch_idx * inference_batch_size + end_idx = start_idx + inference_batch_size if batch_idx < vllm_num_engines - 1 else len(queries_next) + + batch_queries = queries_next[start_idx:end_idx] + batch_ground_truths = ground_truths_next[start_idx:end_idx] + batch_datasets = datasets_next[start_idx:end_idx] + batch_dataset_indices = dataset_indices[start_idx:end_idx] + + # Store individual prompts in the map using dataset indices as keys + for i, dataset_idx in enumerate(batch_dataset_indices): + pending_queries_map[dataset_idx] = (batch_queries[i], batch_ground_truths[i], batch_datasets[i]) + + # Use PromptRequest for Ray queue with batch-specific dataset_index list + param_prompt_Q.put( + PromptRequest( + prompts=batch_queries, + training_step=training_step, + eval_prompts=eval_prompt_token_ids, + dataset_index=batch_dataset_indices, + ) + ) + + def sync_weights_and_prepare_prompts( training_step: int, args: Args, train_dataset, iter_dataloader, policy_group: ModelGroup, - queries_prompt_Q: Queue, - param_prompt_Q: Queue, + pending_queries_map: dict, + param_prompt_Q: ray_queue.Queue, # Ray queue queries_next=None, ground_truths_next=None, datasets_next=None, + dataset_indices=None, + eval_prompt_token_ids=None, ): """Sync weights and send the next batch of prompts to vLLM.""" if training_step != 1: - data_next = train_dataset[next(iter_dataloader)] + dataset_indices = next(iter_dataloader) + data_next = train_dataset[dataset_indices] queries_next = data_next[INPUT_IDS_PROMPT_KEY] ground_truths_next = data_next[GROUND_TRUTHS_KEY] datasets_next = data_next[DATASET_SOURCE_KEY] @@ -1652,15 +1730,18 @@ def sync_weights_and_prepare_prompts( ): ray.get([m.broadcast_to_vllm.remote() for m in policy_group.models]) - if args.async_mode: - queries_prompt_Q.put((queries_next, ground_truths_next, datasets_next)) - param_prompt_Q.put((None, queries_next)) - else: - if training_step != 1: - queries_prompt_Q.put((queries_next, ground_truths_next, datasets_next)) - param_prompt_Q.put((None, queries_next)) - - return queries_next, ground_truths_next, datasets_next + if args.async_mode or training_step != 1: + split_and_insert_batch( + queries_next, + ground_truths_next, + datasets_next, + dataset_indices, + training_step, + args.vllm_num_engines, + pending_queries_map, + param_prompt_Q, + eval_prompt_token_ids, + ) def load_data_from_packing_thread(packed_sequences_Q: Queue, num_total_tokens: int): @@ -1769,7 +1850,7 @@ def one_training_step( def maybe_evaluate( args: Args, training_step: int, - evaluation_inference_results_Q: Queue, + 
evaluation_inference_results_Q: ray_queue.Queue, # Ray queue tokenizer, eval_prompt_token_ids, eval_ground_truths, @@ -1782,25 +1863,26 @@ def maybe_evaluate( try: # timeout 0.01 if this is the last training step or we're not evaluating # otherwise, wait to get the last evaluation generations (long timeout just in case) - timeout = 0.01 if (training_step < args.num_training_steps or args.local_eval_freq < 0) else 100 - eval_responses, eval_finish_reasons, masks, eval_infos = evaluation_inference_results_Q.get(timeout=timeout) + timeout = 0.01 if (training_step < args.num_training_steps or args.eval_freq < 0) else 100 + eval_result = evaluation_inference_results_Q.get(timeout=timeout) + logger.info("[Main Thread] 📊 Evaluation responses received") - eval_sequence_lengths = np.array([len(response) for response in eval_responses]) - eval_decoded_responses = tokenizer.batch_decode(eval_responses, skip_special_tokens=True) - eval_stop_rate = sum(int(finish_reason == "stop") for finish_reason in eval_finish_reasons) / len( - eval_finish_reasons + eval_sequence_lengths = np.array([len(response) for response in eval_result.responses]) + eval_decoded_responses = tokenizer.batch_decode(eval_result.responses, skip_special_tokens=True) + eval_stop_rate = sum(int(finish_reason == "stop") for finish_reason in eval_result.finish_reasons) / len( + eval_result.finish_reasons ) # get and log evaluation metrics eval_scores, eval_reward_metrics = asyncio.run( reward_fn( - eval_responses, + eval_result.responses, eval_decoded_responses, eval_ground_truths, eval_dataset_names, - eval_finish_reasons, - eval_infos, + eval_result.finish_reasons, + eval_result.request_info, ) ) eval_reward_metrics = {f"eval/{key}": val for key, val in eval_reward_metrics.items()} @@ -1965,8 +2047,21 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig, num_eval_sa pprint([args, model_config]) + # Create Ray queues + inference_results_Q = ray_queue.Queue(maxsize=args.async_steps) + param_prompt_Q = ray_queue.Queue(maxsize=args.async_steps) + evaluation_inference_results_Q = ray_queue.Queue(maxsize=1) + policy_group, vllm_engines, tool_objects, resume_training_step, episode = create_model_and_optimizer( - args, tc, model_config, beaker_config, wandb_url, tokenizer + args, + tc, + model_config, + beaker_config, + wandb_url, + tokenizer, + inference_results_Q, + param_prompt_Q, + evaluation_inference_results_Q, ) # Setup training @@ -1989,11 +2084,9 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig, num_eval_sa train_dataset_idxs = np.arange(len(train_dataset)) iter_dataloader = ShufflingIterator(train_dataset_idxs, args.num_unique_prompts_rollout, seed=args.seed) - inference_results_Q = Queue(maxsize=args.async_steps) - param_prompt_Q = Queue(maxsize=args.async_steps) - evaluation_inference_results_Q = Queue(maxsize=1) - packed_sequences_Q = Queue(maxsize=args.async_steps) - queries_prompt_Q = Queue(maxsize=args.async_steps) + # Create additional queues (main queues already created above) + packed_sequences_Q = Queue(maxsize=args.async_steps) # Keep this as threading Queue for now + pending_queries_map = {} # Map dataset_index -> (queries, ground_truths, datasets) eval_prompt_token_ids = None eval_ground_truths = None @@ -2003,25 +2096,13 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig, num_eval_sa eval_ground_truths = eval_dataset[:num_eval_samples][GROUND_TRUTHS_KEY] eval_dataset_names = eval_dataset[:num_eval_samples][DATASET_SOURCE_KEY] reward_fn = make_reward_fn(args) - 
generate_thread = threading.Thread( - target=vllm_generate_thread, - args=( - vllm_engines, - generation_config, - eval_generation_config, - inference_results_Q, - param_prompt_Q, - args.num_training_steps, - eval_prompt_token_ids, - evaluation_inference_results_Q, - args.local_eval_freq, - resume_training_step, - args.tool_use, - ), - ) - generate_thread.start() - logger.info("======== ✅ vllm generate thread starts =========") - reward_fn = make_reward_fn(args) + + # Start vLLM engines to process from queues + for engine in vllm_engines: + engine.process_from_queue.remote( + generation_config, eval_generation_config, args.eval_freq, args.num_training_steps, resume_training_step + ) + logger.info("======== ✅ vllm engines started processing from queues =========") packing_thread = threading.Thread( target=data_preparation_thread, @@ -2029,7 +2110,7 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig, num_eval_sa reward_fn, inference_results_Q, packed_sequences_Q, - queries_prompt_Q, + pending_queries_map, args, tokenizer, args.num_training_steps, @@ -2039,15 +2120,28 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig, num_eval_sa logger.info("======== ✅ data preparation thread starts =========") # Send initial data to both threads. - data_next = train_dataset[next(iter_dataloader)] + dataset_indices = next(iter_dataloader) + data_next = train_dataset[dataset_indices] queries_next = data_next[INPUT_IDS_PROMPT_KEY] ground_truths_next = data_next[GROUND_TRUTHS_KEY] datasets_next = data_next[DATASET_SOURCE_KEY] - queries_prompt_Q.put((queries_next, ground_truths_next, datasets_next)) - param_prompt_Q.put((None, queries_next)) + + # Split the initial batch using the split_and_insert_batch function + split_and_insert_batch( + queries_next, + ground_truths_next, + datasets_next, + dataset_indices, + 1, # training_step + args.vllm_num_engines, + pending_queries_map, + param_prompt_Q, + eval_prompt_token_ids if eval_dataset is not None else None, + ) num_total_tokens = 0 start_time = time.time() + dataset_indices = None # Initialize for training loop try: for training_step in range(resume_training_step, args.num_training_steps + 1): logger.info("-" * 100) @@ -2061,11 +2155,13 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig, num_eval_sa train_dataset, iter_dataloader, policy_group, - queries_prompt_Q, + pending_queries_map, param_prompt_Q, queries_next, ground_truths_next, datasets_next, + dataset_indices, + eval_prompt_token_ids, ) collated_data, data_thread_metrics, num_total_tokens = load_data_from_packing_thread( packed_sequences_Q, num_total_tokens @@ -2110,8 +2206,6 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig, num_eval_sa os._exit(1) # Clean up threads - generate_thread.join() - logger.info("======== ✅ vllm generate thread ends =========") packing_thread.join() logger.info("======== ✅ data preparation thread ends =========") diff --git a/open_instruct/test_grpo_fast.py b/open_instruct/test_grpo_fast.py index 25fa56a31f..66aed2c837 100644 --- a/open_instruct/test_grpo_fast.py +++ b/open_instruct/test_grpo_fast.py @@ -1,33 +1,34 @@ -import queue -import threading import unittest import ray import torch +from parameterized import parameterized +from ray.util import queue as ray_queue from transformers import AutoTokenizer from vllm import SamplingParams -from open_instruct.grpo_fast import vllm_generate_thread -from open_instruct.vllm_utils3 import create_vllm_engines +from open_instruct.grpo_fast import 
accumulate_inference_batches, split_and_insert_batch +from open_instruct.vllm_utils3 import GenerationResult, PromptRequest, RequestInfo, create_vllm_engines class TestGrpoFastVLLM(unittest.TestCase): - def setUp(self): - # Check if CUDA is available - if not torch.cuda.is_available(): - self.skipTest("CUDA is not available, skipping test") - + @classmethod + def setUpClass(cls): # Initialize Ray if not ray.is_initialized(): ray.init(ignore_reinit_error=True) - def tearDown(self): + @classmethod + def tearDownClass(cls): # Shutdown Ray after test if ray.is_initialized(): ray.shutdown() - def test_vllm_generate_thread_single_prompt(self): - """Test vllm_generate_thread with a single prompt 'What is the capital of France?'""" + def test_vllm_queue_system_single_prompt(self): + """Test the new queue-based vLLM system with a single prompt 'What is the capital of France?'""" + # Check if CUDA is available + if not torch.cuda.is_available(): + self.skipTest("CUDA is not available, skipping test") # Set up tokenizer tokenizer_name = "EleutherAI/pythia-14m" # Using a small model for testing @@ -35,9 +36,13 @@ def test_vllm_generate_thread_single_prompt(self): # Tokenize the test prompt test_prompt = "What is the capital of France?" - prompt_token_ids = tokenizer.encode(test_prompt, return_tensors="pt").tolist() + prompt_token_ids = tokenizer.encode(test_prompt, return_tensors="pt").tolist()[0] + + # Create Ray queues + param_prompt_Q = ray_queue.Queue(maxsize=1) + inference_results_Q = ray_queue.Queue(maxsize=1) - # Create vLLM engines + # Create vLLM engines with queues vllm_engines = create_vllm_engines( num_engines=1, tensor_parallel_size=1, @@ -49,6 +54,8 @@ def test_vllm_generate_thread_single_prompt(self): enable_prefix_caching=False, max_model_len=512, vllm_gpu_memory_utilization=0.5, # Use less GPU memory for testing + prompt_queue=param_prompt_Q, + results_queue=inference_results_Q, ) # Set up generation config @@ -59,41 +66,129 @@ def test_vllm_generate_thread_single_prompt(self): seed=42, ) - # Set up queues - inference_results_Q = queue.Queue() - param_prompt_Q = queue.Queue() - evaluation_inference_results_Q = queue.Queue() - - # Create and start the generation thread - generate_thread = threading.Thread( - target=vllm_generate_thread, - args=( - vllm_engines, + # Start vLLM engines to process from queues + for engine in vllm_engines: + engine.process_from_queue.remote( generation_config, - generation_config, # Using same config for eval - inference_results_Q, - param_prompt_Q, + generation_config, # eval_sampling_params + 999, # eval_freq (avoid evaluation) 1, # num_training_steps - None, # eval_prompt_token_ids - evaluation_inference_results_Q, - 1, # eval_freq 1, # resume_training_step - False, # tool_use - ), - ) - generate_thread.start() - # Put the test prompt in the queue - param_prompt_Q.put((None, prompt_token_ids)) + ) - response_ids, _, _, _ = inference_results_Q.get() + # Put the test prompt in the queue using PromptRequest + request = PromptRequest(prompts=[prompt_token_ids], dataset_index=0) + param_prompt_Q.put(request) + + # Get the result + result = inference_results_Q.get() + + # Verify it's a GenerationResult dataclass + self.assertIsInstance(result, GenerationResult) + + # Check that we got a response + self.assertGreater(len(result.responses), 0) + response_ids = result.responses[0] # Decode the response - generated_text = tokenizer.decode(response_ids[0], skip_special_tokens=True) + generated_text = tokenizer.decode(response_ids, skip_special_tokens=True) 
self.assertIsInstance(generated_text, str) self.assertGreater(len(generated_text), 0) - generate_thread.join(timeout=5) + # Send stop signal + param_prompt_Q.put(None) + + @parameterized.expand([(1,), (2,), (4,), (8,)]) + def test_batch_splitting_logic(self, vllm_num_engines: int, num_unique_prompts_rollout: int = 16): + """Test the batch splitting and accumulation logic using split_and_insert_batch and accumulate_inference_batches.""" + + # Mock data - simulating num_unique_prompts_rollout * num_samples_per_prompt_rollout + queries_next = [f"query_{i}" for i in range(num_unique_prompts_rollout)] + ground_truths_next = [f"truth_{i}" for i in range(num_unique_prompts_rollout)] + datasets_next = [f"dataset_{i}" for i in range(num_unique_prompts_rollout)] + + pending_queries_map = {} + training_step = 1 + + # Create mock Ray queue for testing + param_prompt_Q = ray_queue.Queue(maxsize=vllm_num_engines) + + # Create mock dataset indices + dataset_indices = list(range(num_unique_prompts_rollout)) + + # Use split_and_insert_batch to split and insert data + split_and_insert_batch( + queries_next, + ground_truths_next, + datasets_next, + dataset_indices, + training_step, + vllm_num_engines, + pending_queries_map, + param_prompt_Q, + ) + + # Verify that we have individual prompts in the map (not batches) + self.assertEqual(len(pending_queries_map), num_unique_prompts_rollout) + + # Verify that we have the expected number of items in the queue + self.assertEqual(param_prompt_Q.qsize(), vllm_num_engines) + + # Create mock inference results to simulate vLLM engine outputs + mock_inference_results = [] + for batch_idx in range(vllm_num_engines): + # Get the request from the queue + request = param_prompt_Q.get() + self.assertIsInstance(request, PromptRequest) + self.assertEqual(request.training_step, training_step) + self.assertIsInstance(request.dataset_index, list) # Now expects a list of indices + + # Create mock GenerationResult + batch_size = len(request.prompts) + mock_result = GenerationResult( + responses=[[i] for i in range(batch_size)], # Mock token IDs + finish_reasons=["stop"] * batch_size, + masks=[[1] * 5] * batch_size, # Mock masks + request_info=RequestInfo( + num_calls=[0] * batch_size, + timeouts=[0] * batch_size, + tool_errors=[""] * batch_size, + tool_outputs=[""] * batch_size, + tool_runtimes=[0] * batch_size, + tool_calleds=[False] * batch_size, + ), + is_eval=False, + dataset_index=request.dataset_index, + ) + mock_inference_results.append(mock_result) + + # Create mock inference results queue + inference_results_Q = ray_queue.Queue(maxsize=vllm_num_engines) + for result in mock_inference_results: + inference_results_Q.put(result) + + # Use accumulate_inference_batches to combine results + combined_result, combined_queries, combined_ground_truths, combined_datasets = accumulate_inference_batches( + inference_results_Q, pending_queries_map, vllm_num_engines, training_step + ) + + # Verify that the combined results match the original input + self.assertEqual(combined_queries, queries_next) + self.assertEqual(combined_ground_truths, ground_truths_next) + self.assertEqual(combined_datasets, datasets_next) + + # Verify that the combined result has the correct structure + self.assertIsInstance(combined_result, GenerationResult) + self.assertEqual(len(combined_result.responses), len(queries_next)) + self.assertEqual(len(combined_result.finish_reasons), len(queries_next)) + self.assertEqual(len(combined_result.masks), len(queries_next)) + + # Verify that the pending_queries_map is empty 
after accumulation + self.assertEqual(len(pending_queries_map), 0) + + # Verify that the inference_results_Q is empty after accumulation + self.assertEqual(inference_results_Q.qsize(), 0) if __name__ == "__main__": diff --git a/open_instruct/vllm_utils3.py b/open_instruct/vllm_utils3.py index 7e0fca0bea..b4d36de07c 100644 --- a/open_instruct/vllm_utils3.py +++ b/open_instruct/vllm_utils3.py @@ -15,6 +15,7 @@ """This file is copied from https://github.com/OpenRLHF/OpenRLHF""" +import dataclasses import os from datetime import timedelta from typing import Any, List, Optional, Union @@ -35,6 +36,40 @@ ) +@dataclasses.dataclass +class RequestInfo: + """Container for tool usage information.""" + + num_calls: List[int] + timeouts: List[int] + tool_errors: List[str] + tool_outputs: List[str] + tool_runtimes: List[float] + tool_calleds: List[bool] + + +@dataclasses.dataclass +class GenerationResult: + """Container for generation results from vLLM.""" + + responses: List[List[int]] + finish_reasons: List[str] + masks: List[List[int]] + request_info: RequestInfo + is_eval: bool = False + dataset_index: Optional[int] = None + + +@dataclasses.dataclass +class PromptRequest: + """Container for prompt requests to vLLM.""" + + prompts: List[List[int]] + training_step: Optional[int] = None + eval_prompts: Optional[List[List[int]]] = None + dataset_index: Optional[int] = None + + def ray_noset_visible_devices(env_vars=os.environ): # Refer to # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/nvidia_gpu.py#L95-L96 @@ -116,7 +151,16 @@ def init_process_group( @ray.remote class LLMRayActor: - def __init__(self, *args, bundle_indices: list = None, tool_use: bool = False, **kwargs): + def __init__( + self, + *args, + bundle_indices: list = None, + tool_use: bool = False, + prompt_queue=None, + results_queue=None, + eval_results_queue=None, + **kwargs, + ): noset_visible_devices = kwargs.pop("noset_visible_devices") if kwargs.get("distributed_executor_backend") == "ray": # a hack to make the script work. 
@@ -145,9 +189,91 @@ def __init__(self, *args, bundle_indices: list = None, tool_use: bool = False, * self.llm = LLM(*args, **kwargs) + self.prompt_queue = prompt_queue + self.results_queue = results_queue + self.eval_results_queue = eval_results_queue + self.tool_use = tool_use + def generate(self, *args, **kwargs): return self.llm.generate(*args, **kwargs) + def process_from_queue( + self, + sampling_params, + eval_sampling_params=None, + eval_freq=None, + num_training_steps=None, + resume_training_step=1, + ): + """Process prompts from the queue and put results in the results queue.""" + for training_step in range(resume_training_step, num_training_steps + 1): + # Get prompts from queue + request = self.prompt_queue.get() + if request is None: + break + + # Process training prompts + result = self._generate_batch(request.prompts, sampling_params, request.dataset_index) + self.results_queue.put(result) + + # Handle evaluation if needed + if ( + request.eval_prompts is not None + and eval_sampling_params is not None + and (training_step - 1) % eval_freq == 0 + ): + eval_result = self._generate_batch(request.eval_prompts, eval_sampling_params, request.dataset_index) + eval_result.is_eval = True + # Put eval results in separate queue if available + if self.eval_results_queue is not None: + self.eval_results_queue.put(eval_result) + else: + self.results_queue.put(eval_result) + + def _generate_batch( + self, prompts: List[List[int]], sampling_params, dataset_index: Optional[int] = None + ) -> GenerationResult: + """Generate responses for a batch of prompts.""" + outputs = self.llm.generate(sampling_params=sampling_params, prompt_token_ids=prompts, use_tqdm=False) + + # Process outputs + response_ids = [list(out.token_ids) for output in outputs for out in output.outputs] + finish_reasons = [out.finish_reason for output in outputs for out in output.outputs] + + if self.tool_use: + masks = [out.mask for output in outputs for out in output.outputs] + num_calls = [out.num_calls for output in outputs for out in output.outputs] + timeouts = [out.timeout for output in outputs for out in output.outputs] + tool_errors = [out.tool_error for output in outputs for out in output.outputs] + tool_outputs = [out.tool_output for output in outputs for out in output.outputs] + tool_runtimes = [out.tool_runtime for output in outputs for out in output.outputs] + tool_calleds = [out.tool_called for output in outputs for out in output.outputs] + else: + masks = [[1] * len(resp) for resp in response_ids] + num_calls = [0] * len(response_ids) + timeouts = [0] * len(response_ids) + tool_errors = [""] * len(response_ids) + tool_outputs = [""] * len(response_ids) + tool_runtimes = [0] * len(response_ids) + tool_calleds = [False] * len(response_ids) + + request_info = RequestInfo( + num_calls=num_calls, + timeouts=timeouts, + tool_errors=tool_errors, + tool_outputs=tool_outputs, + tool_runtimes=tool_runtimes, + tool_calleds=tool_calleds, + ) + + return GenerationResult( + responses=response_ids, + finish_reasons=finish_reasons, + masks=masks, + request_info=request_info, + dataset_index=dataset_index, + ) + def init_process_group( self, master_address, master_port, rank_offset, world_size, group_name, backend, use_ray=False ): @@ -188,6 +314,9 @@ def create_vllm_engines( vllm_enable_sleep=False, tools: Optional[List[Any]] = None, max_tool_calls: List[int] = [5], + prompt_queue=None, + results_queue=None, + eval_results_queue=None, ) -> list[LLMRayActor]: import vllm @@ -265,6 +394,9 @@ def create_vllm_engines( 
enable_sleep_mode=vllm_enable_sleep, noset_visible_devices=ray_noset_visible_devices(), tool_use=tool_use, + prompt_queue=prompt_queue, + results_queue=results_queue, + eval_results_queue=eval_results_queue, **additional_kwargs, ) )
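
Taken together, the pieces above replace the old `vllm_generate_thread` with a driver-owned pair of Ray queues: `split_and_insert_batch` fans `PromptRequest`s out to the engines, `LLMRayActor.process_from_queue` turns each request into a `GenerationResult`, and `accumulate_inference_batches` reassembles the training batch. The minimal driver-side sketch below mirrors `test_vllm_queue_system_single_prompt`; it assumes a Ray cluster with one GPU, and the `create_vllm_engines` kwargs not visible in the hunks above (`enforce_eager`, `tokenizer_name_or_path`, `pretrain`, `revision`, `seed`) are taken from the existing signature, so treat them as illustrative rather than definitive.

import ray
from ray.util import queue as ray_queue
from transformers import AutoTokenizer
from vllm import SamplingParams

from open_instruct.vllm_utils3 import GenerationResult, PromptRequest, create_vllm_engines

ray.init(ignore_reinit_error=True)

# The driver owns both queues and hands them to the engines at construction time.
param_prompt_Q = ray_queue.Queue(maxsize=1)
inference_results_Q = ray_queue.Queue(maxsize=1)

tokenizer_name = "EleutherAI/pythia-14m"  # small model, as in the unit test
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

vllm_engines = create_vllm_engines(
    num_engines=1,
    tensor_parallel_size=1,
    enforce_eager=True,
    tokenizer_name_or_path=tokenizer_name,
    pretrain=tokenizer_name,
    revision="main",
    seed=42,
    enable_prefix_caching=False,
    max_model_len=512,
    vllm_gpu_memory_utilization=0.5,
    prompt_queue=param_prompt_Q,
    results_queue=inference_results_Q,
)

sampling_params = SamplingParams(temperature=0.0, top_p=0.95, max_tokens=5, seed=42)

# Each engine pulls PromptRequests from prompt_queue and pushes GenerationResults
# to results_queue until it reaches num_training_steps or sees a None sentinel.
for engine in vllm_engines:
    engine.process_from_queue.remote(
        sampling_params,
        sampling_params,  # eval_sampling_params (no eval prompts are sent here)
        999,              # eval_freq, large so evaluation never triggers
        1,                # num_training_steps
        1,                # resume_training_step
    )

prompt_ids = tokenizer.encode("What is the capital of France?")
param_prompt_Q.put(PromptRequest(prompts=[prompt_ids], dataset_index=0))

result: GenerationResult = inference_results_Q.get()
print(tokenizer.decode(result.responses[0], skip_special_tokens=True))

param_prompt_Q.put(None)  # stop signal, matching run_benchmark and the tests
ray.shutdown()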