69 changes: 11 additions & 58 deletions open_instruct/benchmark_generators.py
@@ -295,9 +295,8 @@ def setup_vllm_engines(
pg = ray.util.placement_group(bundles, strategy="PACK")
ray.get(pg.ready())

# Adjust queue sizes for individual prompts
param_prompt_Q = ray_queue.Queue(maxsize=100)
inference_results_Q = ray_queue.Queue(maxsize=100)
param_prompt_Q = ray_queue.Queue(maxsize=10)
inference_results_Q = ray_queue.Queue(maxsize=10)

vllm_engines = vllm_utils3.create_vllm_engines(
num_engines=args.vllm_num_engines,
@@ -335,41 +334,6 @@ def get_batch_data(
return prompts


def run_generation_batch(
inference_results_Q: ray_queue.Queue, param_prompt_Q: ray_queue.Queue, prompts: list[list[int]], batch_idx: int
) -> dict[str, Any]:
"""Run generation for a batch of prompts and measure performance."""

start_time = time.time()

# Insert individual prompts
for i, prompt in enumerate(prompts):
dataset_idx = batch_idx * len(prompts) + i
param_prompt_Q.put(vllm_utils3.PromptRequest(prompt=prompt, dataset_index=dataset_idx))

# Collect individual results
all_responses = []
all_finish_reasons = []
for _ in range(len(prompts)):
result = inference_results_Q.get()
all_responses.extend(result.responses)
all_finish_reasons.extend(result.finish_reasons)

generation_time = time.time() - start_time

new_tokens = sum(len(response) for response in all_responses)
tokens_per_second = new_tokens / generation_time
return {
"tokens_per_second": tokens_per_second,
"generation_time": generation_time,
"num_new_tokens": new_tokens,
# dict mapping string reasons to counts.
"finish_reasons": collections.Counter(all_finish_reasons),
"response_lengths": [len(response) for response in all_responses],
"prompt_lengths": [len(prompt) for prompt in prompts], # Original unique prompts
}


def run_benchmark(
dataset: datasets.Dataset,
vllm_engines: list[ray.actor.ActorHandle],
@@ -402,7 +366,6 @@ def run_benchmark(
num_batches + 1, # eval_freq (avoid evaluation)
num_batches, # num_training_steps
1, # resume_training_step
args.num_unique_prompts_rollout // args.vllm_num_engines, # batch_size
)

# Wait for engines to be ready
@@ -420,47 +383,37 @@
]
submission_start_time = time.time()
for batch_idx in range(num_batches):
# Insert individual prompts for this batch
for i, prompt in enumerate(all_prompts[batch_idx]):
dataset_idx = batch_idx * args.num_unique_prompts_rollout + i
param_prompt_Q.put(vllm_utils3.PromptRequest(prompt=prompt, dataset_index=dataset_idx))
param_prompt_Q.put(vllm_utils3.PromptRequest(prompts=all_prompts[batch_idx], dataset_index=batch_idx))
submission_time = time.time() - submission_start_time
logger.info(f"All batches submitted in {submission_time:.2f}s")

# Receive results and measure time for each batch
last_completion_time = submission_start_time
for batch_idx in range(num_batches):
# Collect individual results for this batch
batch_responses = []
batch_finish_reasons = []
for _ in range(args.num_unique_prompts_rollout):
result = inference_results_Q.get()
batch_responses.extend(result.responses)
batch_finish_reasons.extend(result.finish_reasons)

result = inference_results_Q.get()
completion_time = time.time()
batch_generation_time = completion_time - last_completion_time
last_completion_time = completion_time

# Process batch results
new_tokens = sum(len(response) for response in batch_responses)
# Process result
new_tokens = sum(len(response) for response in result.responses)
tokens_per_second = new_tokens / batch_generation_time

result_dict = {
"tokens_per_second": tokens_per_second,
"generation_time": batch_generation_time,
"num_new_tokens": new_tokens,
"finish_reasons": collections.Counter(batch_finish_reasons),
"response_lengths": [len(response) for response in batch_responses],
"batch_idx": batch_idx,
"finish_reasons": collections.Counter(result.finish_reasons),
"response_lengths": [len(response) for response in result.responses],
"batch_idx": result.dataset_index,
}
result_dict["mfu"] = 100 * result_dict["tokens_per_second"] * flops_per_token / device_flops

# We incrementally save completion lengths so even if the job dies, we still have data.
save_completion_lengths([result_dict], timestamp, batch_idx)
save_completion_lengths([result_dict], timestamp, result.dataset_index)
results.append(result_dict)
logger.info(
f"Batch {batch_idx + 1}: "
f"Batch {result.dataset_index + 1}: "
f"{result_dict['tokens_per_second']:.2f} new tokens/sec, "
f"MFU: {result_dict['mfu']:.2f}%, "
f"generation time: {batch_generation_time:.2f}s"
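For context, here is a small standalone sketch (not part of the diff) of the throughput and MFU arithmetic the loop above performs per result. The numbers are made up, and `flops_per_token` / `device_flops` are placeholders for values the benchmark derives from the model and GPU elsewhere:

```python
import collections

# Illustrative numbers only; in the benchmark, flops_per_token and
# device_flops are computed from the model and hardware, not hard-coded.
responses = [[1] * 512, [2] * 384, [3] * 640]  # generated token ids per response
finish_reasons = ["stop", "stop", "length"]
batch_generation_time = 4.0                    # seconds between batch completions
flops_per_token = 2 * 7e9                      # roughly 2 * parameter count for a 7B model
device_flops = 312e12                          # A100 BF16 peak FLOP/s

new_tokens = sum(len(r) for r in responses)               # 1536
tokens_per_second = new_tokens / batch_generation_time    # 384.0
mfu = 100 * tokens_per_second * flops_per_token / device_flops

print(collections.Counter(finish_reasons))    # Counter({'stop': 2, 'length': 1})
print(f"{tokens_per_second:.1f} tokens/s, MFU {mfu:.2f}%")
```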
94 changes: 48 additions & 46 deletions open_instruct/grpo_fast.py
@@ -1106,42 +1106,50 @@ def get_bundle_index(rank, num_gpus_per_node):


def accumulate_inference_batches(
inference_results_Q: ray_queue.Queue, pending_queries_map: dict, expected_results: int, training_step: int
inference_results_Q: ray_queue.Queue, pending_queries_map: dict, vllm_num_engines: int, training_step: int
) -> tuple:
"""Accumulate individual inference results into a single training batch.
"""Accumulate multiple inference results into a single training batch.

Args:
inference_results_Q: Queue containing GenerationResult objects (one per prompt)
inference_results_Q: Queue containing GenerationResult objects
pending_queries_map: Map of dataset_index -> (queries, ground_truths, datasets)
expected_results: Number of individual results to accumulate
vllm_num_engines: Number of vLLM engines (number of batches to accumulate)
training_step: Current training step for error reporting

Returns:
Tuple of (combined_result, combined_queries, combined_ground_truths, combined_datasets)
"""
# Collect individual results
# Collect results from all engines
results = []
all_queries = []
all_ground_truths = []
all_datasets = []

for i in range(expected_results):
# Get individual result from queue
for batch_idx in range(vllm_num_engines):
# Get result from queue
result = inference_results_Q.get()
dataset_indices = result.dataset_index

if result.dataset_index is None or len(result.dataset_index) != 1:
raise RuntimeError(f"Expected single dataset index, got {result.dataset_index}")
if dataset_indices is None:
raise RuntimeError(f"Dataset indices is None for batch {batch_idx}")

dataset_idx = result.dataset_index[0]
if dataset_idx not in pending_queries_map:
raise RuntimeError(f"Dataset index {dataset_idx} not found in pending_queries_map")
# Get corresponding queries, ground_truths, datasets for each individual prompt
batch_queries = []
batch_ground_truths = []
batch_datasets = []
for dataset_idx in dataset_indices:
if dataset_idx not in pending_queries_map:
raise RuntimeError(f"Dataset index {dataset_idx} not found in pending_queries_map")

query, ground_truth, dataset = pending_queries_map.pop(dataset_idx)
query, ground_truth, dataset = pending_queries_map.pop(dataset_idx)
batch_queries.append(query)
batch_ground_truths.append(ground_truth)
batch_datasets.append(dataset)

results.append(result)
all_queries.append(query)
all_ground_truths.append(ground_truth)
all_datasets.append(dataset)
all_queries.extend(batch_queries)
all_ground_truths.extend(batch_ground_truths)
all_datasets.extend(batch_datasets)

# Combine all results into a single GenerationResult
combined_responses = []
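To make the accumulation pattern above easier to follow, here is a minimal, self-contained sketch of the same idea using plain Python queues and a stand-in result type; `FakeResult` and `accumulate` are illustrative names, not the project's API:

```python
import queue
from dataclasses import dataclass

@dataclass
class FakeResult:          # simplified stand-in for GenerationResult
    responses: list
    dataset_index: list    # one entry per prompt in the engine's batch

def accumulate(results_q: queue.Queue, pending: dict, num_engines: int):
    """Drain one result per engine and re-associate responses with their queries."""
    all_responses, all_queries = [], []
    for _ in range(num_engines):
        result = results_q.get()
        for idx in result.dataset_index:
            all_queries.append(pending.pop(idx))   # popping drains the pending map
        all_responses.extend(result.responses)
    return all_responses, all_queries

# Usage: two engines, two prompts each.
results_q, pending = queue.Queue(), {0: "q0", 1: "q1", 2: "q2", 3: "q3"}
results_q.put(FakeResult(responses=[[10], [11]], dataset_index=[0, 1]))
results_q.put(FakeResult(responses=[[12], [13]], dataset_index=[2, 3]))
print(accumulate(results_q, pending, num_engines=2))  # pending is now empty
```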
@@ -1201,10 +1209,7 @@ def data_preparation_thread(
# Accumulate results from multiple vLLM engines into a single training batch
with Timer("🚀 [Data Preparation Thread] Getting response ids"):
result, queries, ground_truths, datasets = accumulate_inference_batches(
inference_results_Q,
pending_queries_map,
args.num_unique_prompts_rollout * args.num_samples_per_prompt_rollout,
training_step,
inference_results_Q, pending_queries_map, args.vllm_num_engines, training_step
)

# ------------------------------------------------------------------------------------------------
@@ -1672,19 +1677,29 @@ def split_and_insert_batch(
param_prompt_Q,
eval_prompt_token_ids=None,
):
"""Insert individual prompts into queues and mapping."""
# Insert individual prompts into the queue
for i, dataset_idx in enumerate(dataset_indices):
# Store individual prompt in the map using dataset index as key
pending_queries_map[dataset_idx] = (queries_next[i], ground_truths_next[i], datasets_next[i])

# Create PromptRequest for single prompt
"""Split a batch into multiple inference batches and insert individual prompts into queues and mapping."""
# Split the batch over the VLLM engines.
inference_batch_size = len(queries_next) // vllm_num_engines
for batch_idx in range(vllm_num_engines):
start_idx = batch_idx * inference_batch_size
end_idx = start_idx + inference_batch_size if batch_idx < vllm_num_engines - 1 else len(queries_next)

batch_queries = queries_next[start_idx:end_idx]
batch_ground_truths = ground_truths_next[start_idx:end_idx]
batch_datasets = datasets_next[start_idx:end_idx]
batch_dataset_indices = dataset_indices[start_idx:end_idx]

# Store individual prompts in the map using dataset indices as keys
for i, dataset_idx in enumerate(batch_dataset_indices):
pending_queries_map[dataset_idx] = (batch_queries[i], batch_ground_truths[i], batch_datasets[i])

# Use PromptRequest for Ray queue with batch-specific dataset_index list
param_prompt_Q.put(
PromptRequest(
prompt=queries_next[i],
prompts=batch_queries,
training_step=training_step,
eval_prompts=eval_prompt_token_ids,
dataset_index=dataset_idx,
dataset_index=batch_dataset_indices,
)
)
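
The slicing above is worth spelling out: the batch is divided evenly across engines and the last engine absorbs any remainder. A hypothetical helper (not in the codebase) that captures just that arithmetic:

```python
def split_indices(num_items: int, num_engines: int) -> list[tuple[int, int]]:
    """Return one (start, end) slice per engine; the last engine takes the remainder."""
    size = num_items // num_engines
    return [
        (engine * size, (engine + 1) * size if engine < num_engines - 1 else num_items)
        for engine in range(num_engines)
    ]

# 10 prompts over 3 engines -> [(0, 3), (3, 6), (6, 10)]
print(split_indices(10, 3))
```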

@@ -2034,10 +2049,9 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig, num_eval_sa

pprint([args, model_config])

# Create Ray queues - adjust maxsize for individual prompts.
total_prompts_per_step = args.num_unique_prompts_rollout * args.num_samples_per_prompt_rollout
inference_results_Q = ray_queue.Queue(maxsize=args.async_steps * total_prompts_per_step)
param_prompt_Q = ray_queue.Queue(maxsize=args.async_steps * total_prompts_per_step)
# Create Ray queues
inference_results_Q = ray_queue.Queue(maxsize=args.async_steps)
param_prompt_Q = ray_queue.Queue(maxsize=args.async_steps)
evaluation_inference_results_Q = ray_queue.Queue(maxsize=1)

policy_group, vllm_engines, tool_objects, resume_training_step, episode = create_model_and_optimizer(
@@ -2086,21 +2100,9 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig, num_eval_sa
reward_fn = make_reward_fn(args)

# Start vLLM engines to process from queues
if args.num_unique_prompts_rollout * args.num_samples_per_prompt_rollout % args.vllm_num_engines != 0:
raise ValueError(
"The number of unique prompts times the number of samples per prompt must be divisible by the number of vLLM engines."
)
batch_size_per_engine = (
args.num_unique_prompts_rollout * args.num_samples_per_prompt_rollout
) // args.vllm_num_engines
for engine in vllm_engines:
engine.process_from_queue.remote(
generation_config,
eval_generation_config,
args.eval_freq,
args.num_training_steps,
resume_training_step,
batch_size_per_engine,
generation_config, eval_generation_config, args.eval_freq, args.num_training_steps, resume_training_step
)
logger.info("======== ✅ vllm engines started processing from queues =========")

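Since each queue item is now an engine-sized batch rather than a single prompt, `maxsize=args.async_steps` bounds how many batches the submitter can run ahead of training. A toy illustration of that backpressure with a plain `queue.Queue` (none of this is the project's API):

```python
import queue
import threading
import time

async_steps = 2
param_prompt_q = queue.Queue(maxsize=async_steps)  # at most 2 batches in flight

def producer():
    for step in range(5):
        # Blocks once async_steps batches are pending and unconsumed.
        param_prompt_q.put({"training_step": step, "prompts": [[0, 1, 2]] * 4})
        print(f"submitted batch for step {step}")

threading.Thread(target=producer, daemon=True).start()
time.sleep(0.5)                                             # producer stalls after 2 puts
print("consumed:", param_prompt_q.get()["training_step"])   # frees a slot, producer resumes
```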
55 changes: 27 additions & 28 deletions open_instruct/test_grpo_fast.py
@@ -74,11 +74,10 @@ def test_vllm_queue_system_single_prompt(self):
999, # eval_freq (avoid evaluation)
1, # num_training_steps
1, # resume_training_step
1, # batch_size
)

# Put the test prompt in the queue using PromptRequest
request = PromptRequest(prompt=prompt_token_ids, dataset_index=0)
request = PromptRequest(prompts=[prompt_token_ids], dataset_index=0)
param_prompt_Q.put(request)

# Get the result
@@ -105,16 +104,17 @@ def test_batch_splitting_logic(self, vllm_num_engines: int, num_unique_prompts_r
"""Test the batch splitting and accumulation logic using split_and_insert_batch and accumulate_inference_batches."""

# Mock data - simulating num_unique_prompts_rollout * num_samples_per_prompt_rollout
# Use lists of integers to simulate tokenized prompts
queries_next = [[i, i + 1, i + 2] for i in range(num_unique_prompts_rollout)] # Mock token IDs
queries_next = [f"query_{i}" for i in range(num_unique_prompts_rollout)]
ground_truths_next = [f"truth_{i}" for i in range(num_unique_prompts_rollout)]
datasets_next = [f"dataset_{i}" for i in range(num_unique_prompts_rollout)]

pending_queries_map = {}
training_step = 1

param_prompt_Q = ray_queue.Queue(maxsize=num_unique_prompts_rollout)
# Create mock Ray queue for testing
param_prompt_Q = ray_queue.Queue(maxsize=vllm_num_engines)

# Create mock dataset indices
dataset_indices = list(range(num_unique_prompts_rollout))

# Use split_and_insert_batch to split and insert data
@@ -129,49 +129,48 @@ def test_batch_splitting_logic(self, vllm_num_engines: int, num_unique_prompts_r
param_prompt_Q,
)

# Verify that we have individual prompts in the map (not batches)
self.assertEqual(len(pending_queries_map), num_unique_prompts_rollout)

self.assertEqual(param_prompt_Q.qsize(), num_unique_prompts_rollout)
# Verify that we have the expected number of items in the queue
self.assertEqual(param_prompt_Q.qsize(), vllm_num_engines)

# Create mock inference results to simulate vLLM engine outputs (individual results)
# Create mock inference results to simulate vLLM engine outputs
mock_inference_results = []
requests_processed = []
for i in range(num_unique_prompts_rollout):
for batch_idx in range(vllm_num_engines):
# Get the request from the queue
request = param_prompt_Q.get()
self.assertIsInstance(request, PromptRequest)
self.assertEqual(request.training_step, training_step)
self.assertIsInstance(request.dataset_index, int) # Single dataset index
self.assertIsInstance(request.prompt, list) # Single prompt as list of ints
self.assertIsInstance(request.dataset_index, list) # Now expects a list of indices

# Store request for later verification
requests_processed.append(request)

# Create mock GenerationResult for single prompt
# Create mock GenerationResult
batch_size = len(request.prompts)
mock_result = GenerationResult(
responses=[[i]], # Mock token IDs for single response
finish_reasons=["stop"],
masks=[[1] * 5], # Mock masks
responses=[[i] for i in range(batch_size)], # Mock token IDs
finish_reasons=["stop"] * batch_size,
masks=[[1] * 5] * batch_size, # Mock masks
request_info=RequestInfo(
num_calls=[0],
timeouts=[0],
tool_errors=[""],
tool_outputs=[""],
tool_runtimes=[0],
tool_calleds=[False],
num_calls=[0] * batch_size,
timeouts=[0] * batch_size,
tool_errors=[""] * batch_size,
tool_outputs=[""] * batch_size,
tool_runtimes=[0] * batch_size,
tool_calleds=[False] * batch_size,
),
is_eval=False,
dataset_index=[request.dataset_index],
dataset_index=request.dataset_index,
)
mock_inference_results.append(mock_result)

inference_results_Q = ray_queue.Queue(maxsize=num_unique_prompts_rollout)
# Create mock inference results queue
inference_results_Q = ray_queue.Queue(maxsize=vllm_num_engines)
for result in mock_inference_results:
inference_results_Q.put(result)

# Use accumulate_inference_batches to combine results
combined_result, combined_queries, combined_ground_truths, combined_datasets = accumulate_inference_batches(
inference_results_Q, pending_queries_map, num_unique_prompts_rollout, training_step
inference_results_Q, pending_queries_map, vllm_num_engines, training_step
)

# Verify that the combined results match the original input
@@ -185,7 +184,7 @@ def test_batch_splitting_logic(self, vllm_num_engines: int, num_unique_prompts_r
self.assertEqual(len(combined_result.finish_reasons), len(queries_next))
self.assertEqual(len(combined_result.masks), len(queries_next))

# Verify that the test_pending_queries_map is empty after accumulation
# Verify that the pending_queries_map is empty after accumulation
self.assertEqual(len(pending_queries_map), 0)

# Verify that the inference_results_Q is empty after accumulation
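The property these tests pin down can also be checked in isolation: splitting a batch across engines and then accumulating the per-engine results must reproduce the original order and leave the pending map empty. A hypothetical, self-contained version of that check (simplified stand-ins, not the repository's test code):

```python
def round_trip(num_prompts: int, num_engines: int) -> None:
    queries = [f"query_{i}" for i in range(num_prompts)]
    pending = dict(enumerate(queries))
    size = num_prompts // num_engines

    # Split: one chunk of dataset indices per engine, remainder to the last engine.
    chunks = [
        list(range(e * size, (e + 1) * size if e < num_engines - 1 else num_prompts))
        for e in range(num_engines)
    ]

    # Accumulate: pop each index back out in engine order.
    recovered = [pending.pop(i) for chunk in chunks for i in chunk]

    assert recovered == queries
    assert not pending  # the pending map is fully drained, as the test asserts

for engines, prompts in [(1, 16), (2, 32), (4, 30)]:
    round_trip(prompts, engines)
print("split/accumulate round-trip holds")
```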