69 changes: 58 additions & 11 deletions open_instruct/benchmark_generators.py
@@ -295,8 +295,9 @@ def setup_vllm_engines(
pg = ray.util.placement_group(bundles, strategy="PACK")
ray.get(pg.ready())

param_prompt_Q = ray_queue.Queue(maxsize=10)
inference_results_Q = ray_queue.Queue(maxsize=10)
# Adjust queue sizes for individual prompts
param_prompt_Q = ray_queue.Queue(maxsize=100)
inference_results_Q = ray_queue.Queue(maxsize=100)

vllm_engines = vllm_utils3.create_vllm_engines(
num_engines=args.vllm_num_engines,
@@ -334,6 +335,41 @@ def get_batch_data(
return prompts


def run_generation_batch(
inference_results_Q: ray_queue.Queue, param_prompt_Q: ray_queue.Queue, prompts: list[list[int]], batch_idx: int
) -> dict[str, Any]:
"""Run generation for a batch of prompts and measure performance."""

start_time = time.time()

# Insert individual prompts
for i, prompt in enumerate(prompts):
dataset_idx = batch_idx * len(prompts) + i
param_prompt_Q.put(vllm_utils3.PromptRequest(prompt=prompt, dataset_index=dataset_idx))

# Collect individual results
all_responses = []
all_finish_reasons = []
for _ in range(len(prompts)):
result = inference_results_Q.get()
all_responses.extend(result.responses)
all_finish_reasons.extend(result.finish_reasons)

generation_time = time.time() - start_time

new_tokens = sum(len(response) for response in all_responses)
tokens_per_second = new_tokens / generation_time
return {
"tokens_per_second": tokens_per_second,
"generation_time": generation_time,
"num_new_tokens": new_tokens,
# dict mapping string reasons to counts.
"finish_reasons": collections.Counter(all_finish_reasons),
"response_lengths": [len(response) for response in all_responses],
"prompt_lengths": [len(prompt) for prompt in prompts], # Original unique prompts
}


def run_benchmark(
dataset: datasets.Dataset,
vllm_engines: list[ray.actor.ActorHandle],
@@ -366,6 +402,7 @@ def run_benchmark(
num_batches + 1, # eval_freq (avoid evaluation)
num_batches, # num_training_steps
1, # resume_training_step
args.num_unique_prompts_rollout // args.vllm_num_engines, # batch_size
)

# Wait for engines to be ready
@@ -383,37 +420,47 @@
]
submission_start_time = time.time()
for batch_idx in range(num_batches):
param_prompt_Q.put(vllm_utils3.PromptRequest(prompts=all_prompts[batch_idx], dataset_index=batch_idx))
# Insert individual prompts for this batch
for i, prompt in enumerate(all_prompts[batch_idx]):
dataset_idx = batch_idx * args.num_unique_prompts_rollout + i
param_prompt_Q.put(vllm_utils3.PromptRequest(prompt=prompt, dataset_index=dataset_idx))
submission_time = time.time() - submission_start_time
logger.info(f"All batches submitted in {submission_time:.2f}s")

# Receive results and measure time for each batch
last_completion_time = submission_start_time
for batch_idx in range(num_batches):
result = inference_results_Q.get()
# Collect individual results for this batch
batch_responses = []
batch_finish_reasons = []
for _ in range(args.num_unique_prompts_rollout):
result = inference_results_Q.get()
batch_responses.extend(result.responses)
batch_finish_reasons.extend(result.finish_reasons)

completion_time = time.time()
batch_generation_time = completion_time - last_completion_time
last_completion_time = completion_time

# Process result
new_tokens = sum(len(response) for response in result.responses)
# Process batch results
new_tokens = sum(len(response) for response in batch_responses)
tokens_per_second = new_tokens / batch_generation_time

result_dict = {
"tokens_per_second": tokens_per_second,
"generation_time": batch_generation_time,
"num_new_tokens": new_tokens,
"finish_reasons": collections.Counter(result.finish_reasons),
"response_lengths": [len(response) for response in result.responses],
"batch_idx": result.dataset_index,
"finish_reasons": collections.Counter(batch_finish_reasons),
"response_lengths": [len(response) for response in batch_responses],
"batch_idx": batch_idx,
}
result_dict["mfu"] = 100 * result_dict["tokens_per_second"] * flops_per_token / device_flops

# We incrementally save completion lengths so even if the job dies, we still have data.
save_completion_lengths([result_dict], timestamp, result.dataset_index)
save_completion_lengths([result_dict], timestamp, batch_idx)
results.append(result_dict)
logger.info(
f"Batch {result.dataset_index + 1}: "
f"Batch {batch_idx + 1}: "
f"{result_dict['tokens_per_second']:.2f} new tokens/sec, "
f"MFU: {result_dict['mfu']:.2f}%, "
f"generation time: {batch_generation_time:.2f}s"
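A self-contained sketch of the per-prompt submit/collect accounting that run_generation_batch performs, using plain queue.Queue and a stand-in result type rather than the real Ray queues and vLLM GenerationResult objects:

import collections
import queue
from dataclasses import dataclass

@dataclass
class FakeResult:
    responses: list       # one generated token-id list per prompt
    finish_reasons: list  # one finish reason per response

prompts = [[1, 2, 3], [4, 5], [6, 7, 8, 9]]
param_prompt_Q = queue.Queue()
inference_results_Q = queue.Queue()

# Submit one request per prompt, then fake an engine that echoes one result per prompt.
for i, prompt in enumerate(prompts):
    param_prompt_Q.put((prompt, i))  # stand-in for PromptRequest(prompt=..., dataset_index=i)
    inference_results_Q.put(FakeResult(responses=[prompt + [0]], finish_reasons=["stop"]))

# Collect exactly len(prompts) individual results and aggregate, as run_generation_batch does.
all_responses, all_finish_reasons = [], []
for _ in range(len(prompts)):
    result = inference_results_Q.get()
    all_responses.extend(result.responses)
    all_finish_reasons.extend(result.finish_reasons)

print(sum(len(r) for r in all_responses))       # total new tokens: 12
print(collections.Counter(all_finish_reasons))  # Counter({'stop': 3})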
94 changes: 46 additions & 48 deletions open_instruct/grpo_fast.py
@@ -1106,50 +1106,42 @@ def get_bundle_index(rank, num_gpus_per_node):


def accumulate_inference_batches(
inference_results_Q: ray_queue.Queue, pending_queries_map: dict, vllm_num_engines: int, training_step: int
inference_results_Q: ray_queue.Queue, pending_queries_map: dict, expected_results: int, training_step: int
) -> tuple:
"""Accumulate multiple inference results into a single training batch.
"""Accumulate individual inference results into a single training batch.

Args:
inference_results_Q: Queue containing GenerationResult objects
inference_results_Q: Queue containing GenerationResult objects (one per prompt)
pending_queries_map: Map of dataset_index -> (queries, ground_truths, datasets)
vllm_num_engines: Number of vLLM engines (number of batches to accumulate)
expected_results: Number of individual results to accumulate
training_step: Current training step for error reporting

Returns:
Tuple of (combined_result, combined_queries, combined_ground_truths, combined_datasets)
"""
# Collect results from all engines
# Collect individual results
results = []
all_queries = []
all_ground_truths = []
all_datasets = []

for batch_idx in range(vllm_num_engines):
# Get result from queue
for i in range(expected_results):
# Get individual result from queue
result = inference_results_Q.get()
dataset_indices = result.dataset_index

if dataset_indices is None:
raise RuntimeError(f"Dataset indices is None for batch {batch_idx}")
if result.dataset_index is None or len(result.dataset_index) != 1:
raise RuntimeError(f"Expected single dataset index, got {result.dataset_index}")

# Get corresponding queries, ground_truths, datasets for each individual prompt
batch_queries = []
batch_ground_truths = []
batch_datasets = []
for dataset_idx in dataset_indices:
if dataset_idx not in pending_queries_map:
raise RuntimeError(f"Dataset index {dataset_idx} not found in pending_queries_map")
dataset_idx = result.dataset_index[0]
if dataset_idx not in pending_queries_map:
raise RuntimeError(f"Dataset index {dataset_idx} not found in pending_queries_map")

query, ground_truth, dataset = pending_queries_map.pop(dataset_idx)
batch_queries.append(query)
batch_ground_truths.append(ground_truth)
batch_datasets.append(dataset)
query, ground_truth, dataset = pending_queries_map.pop(dataset_idx)

results.append(result)
all_queries.extend(batch_queries)
all_ground_truths.extend(batch_ground_truths)
all_datasets.extend(batch_datasets)
all_queries.append(query)
all_ground_truths.append(ground_truth)
all_datasets.append(dataset)

# Combine all results into a single GenerationResult
combined_responses = []
@@ -1209,7 +1201,10 @@ def data_preparation_thread(
# Accumulate results from multiple vLLM engines into a single training batch
with Timer("🚀 [Data Preparation Thread] Getting response ids"):
result, queries, ground_truths, datasets = accumulate_inference_batches(
inference_results_Q, pending_queries_map, args.vllm_num_engines, training_step
inference_results_Q,
pending_queries_map,
args.num_unique_prompts_rollout * args.num_samples_per_prompt_rollout,
Contributor:

Is it possible that at the end of an epoch (or something similar) you will have fewer than this number of results? Or are we dropping the last batch?

Collaborator (Author):

ShufflingIterator guarantees that they're all the same size:

class ShufflingIterator:
    def __init__(self, data: np.ndarray, batch_size: int, seed: Optional[int] = None):
        self.data = data.copy()
        self.batch_size = batch_size
        self.index = 0
        self.rng = np.random.default_rng(seed)
        self.rng.shuffle(self.data)

        # Ensure the effective dataset size is divisible by batch_size
        self.effective_size = len(self.data) - (len(self.data) % batch_size)

    def __iter__(self) -> Iterator[List[int]]:
        return self

    def __next__(self) -> List[int]:
        if self.index >= self.effective_size:
            self.index = 0
            self.rng.shuffle(self.data)

        end_index = self.index + self.batch_size
        batch = self.data[self.index : end_index].tolist()
        self.index = end_index

        return batch
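A quick sanity check of the truncation behavior described above (illustrative, not part of the PR), assuming the ShufflingIterator class quoted here is in scope along with its numpy/typing imports: with 10 indices and batch_size=4, effective_size is 8, so every batch has exactly 4 entries and the 2 leftover indices are simply not served on that pass.

import numpy as np

# Assumes ShufflingIterator (and its Optional/Iterator/List imports) is defined as quoted above.
it = ShufflingIterator(np.arange(10), batch_size=4, seed=0)
first_batch = next(it)
second_batch = next(it)
assert len(first_batch) == len(second_batch) == 4  # batches are always full
# 10 % 4 == 2 leftover indices are skipped for this pass, so the consumer
# never sees a short final batch.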

training_step,
)

# ------------------------------------------------------------------------------------------------
@@ -1677,29 +1672,19 @@ def split_and_insert_batch(
param_prompt_Q,
eval_prompt_token_ids=None,
):
"""Split a batch into multiple inference batches and insert individual prompts into queues and mapping."""
# Split the batch over the VLLM engines.
inference_batch_size = len(queries_next) // vllm_num_engines
for batch_idx in range(vllm_num_engines):
start_idx = batch_idx * inference_batch_size
end_idx = start_idx + inference_batch_size if batch_idx < vllm_num_engines - 1 else len(queries_next)

batch_queries = queries_next[start_idx:end_idx]
batch_ground_truths = ground_truths_next[start_idx:end_idx]
batch_datasets = datasets_next[start_idx:end_idx]
batch_dataset_indices = dataset_indices[start_idx:end_idx]

# Store individual prompts in the map using dataset indices as keys
for i, dataset_idx in enumerate(batch_dataset_indices):
pending_queries_map[dataset_idx] = (batch_queries[i], batch_ground_truths[i], batch_datasets[i])

# Use PromptRequest for Ray queue with batch-specific dataset_index list
"""Insert individual prompts into queues and mapping."""
# Insert individual prompts into the queue
for i, dataset_idx in enumerate(dataset_indices):
# Store individual prompt in the map using dataset index as key
pending_queries_map[dataset_idx] = (queries_next[i], ground_truths_next[i], datasets_next[i])

# Create PromptRequest for single prompt
param_prompt_Q.put(
PromptRequest(
prompts=batch_queries,
prompt=queries_next[i],
training_step=training_step,
eval_prompts=eval_prompt_token_ids,
dataset_index=batch_dataset_indices,
dataset_index=dataset_idx,
)
)

@@ -2049,9 +2034,10 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig, num_eval_sa

pprint([args, model_config])

# Create Ray queues
inference_results_Q = ray_queue.Queue(maxsize=args.async_steps)
param_prompt_Q = ray_queue.Queue(maxsize=args.async_steps)
# Create Ray queues - adjust maxsize for individual prompts.
total_prompts_per_step = args.num_unique_prompts_rollout * args.num_samples_per_prompt_rollout
inference_results_Q = ray_queue.Queue(maxsize=args.async_steps * total_prompts_per_step)
param_prompt_Q = ray_queue.Queue(maxsize=args.async_steps * total_prompts_per_step)
evaluation_inference_results_Q = ray_queue.Queue(maxsize=1)

policy_group, vllm_engines, tool_objects, resume_training_step, episode = create_model_and_optimizer(
@@ -2100,9 +2086,21 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig, num_eval_sa
reward_fn = make_reward_fn(args)

# Start vLLM engines to process from queues
if args.num_unique_prompts_rollout * args.num_samples_per_prompt_rollout % args.vllm_num_engines != 0:
raise ValueError(
"The number of unique prompts times the number of samples per prompt must be divisible by the number of vLLM engines."
)
batch_size_per_engine = (
args.num_unique_prompts_rollout * args.num_samples_per_prompt_rollout
) // args.vllm_num_engines
Comment on lines +2093 to +2095
Contributor:

Do we know if this is evenly divisible? And is it alright if it's not?

Collaborator (Author):

It should be divisible; I changed the code to raise a ValueError if it isn't, and we can handle that then.
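For concreteness, a small sketch of the new check with made-up numbers (these are not the repo's config defaults):

# Illustrative values only; not the actual defaults.
num_unique_prompts_rollout = 64
num_samples_per_prompt_rollout = 4
vllm_num_engines = 8

total_prompts_per_step = num_unique_prompts_rollout * num_samples_per_prompt_rollout  # 256
if total_prompts_per_step % vllm_num_engines != 0:
    raise ValueError(
        "The number of unique prompts times the number of samples per prompt "
        "must be divisible by the number of vLLM engines."
    )
batch_size_per_engine = total_prompts_per_step // vllm_num_engines  # 256 // 8 == 32
# With 7 engines instead, 256 % 7 != 0 and the ValueError above fires.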

for engine in vllm_engines:
engine.process_from_queue.remote(
generation_config, eval_generation_config, args.eval_freq, args.num_training_steps, resume_training_step
generation_config,
eval_generation_config,
args.eval_freq,
args.num_training_steps,
resume_training_step,
batch_size_per_engine,
)
logger.info("======== ✅ vllm engines started processing from queues =========")

55 changes: 28 additions & 27 deletions open_instruct/test_grpo_fast.py
@@ -74,10 +74,11 @@ def test_vllm_queue_system_single_prompt(self):
999, # eval_freq (avoid evaluation)
1, # num_training_steps
1, # resume_training_step
1, # batch_size
)

# Put the test prompt in the queue using PromptRequest
request = PromptRequest(prompts=[prompt_token_ids], dataset_index=0)
request = PromptRequest(prompt=prompt_token_ids, dataset_index=0)
param_prompt_Q.put(request)

# Get the result
@@ -104,17 +105,16 @@ def test_batch_splitting_logic(self, vllm_num_engines: int, num_unique_prompts_r
"""Test the batch splitting and accumulation logic using split_and_insert_batch and accumulate_inference_batches."""

# Mock data - simulating num_unique_prompts_rollout * num_samples_per_prompt_rollout
queries_next = [f"query_{i}" for i in range(num_unique_prompts_rollout)]
# Use lists of integers to simulate tokenized prompts
queries_next = [[i, i + 1, i + 2] for i in range(num_unique_prompts_rollout)] # Mock token IDs
ground_truths_next = [f"truth_{i}" for i in range(num_unique_prompts_rollout)]
datasets_next = [f"dataset_{i}" for i in range(num_unique_prompts_rollout)]

pending_queries_map = {}
training_step = 1

# Create mock Ray queue for testing
param_prompt_Q = ray_queue.Queue(maxsize=vllm_num_engines)
param_prompt_Q = ray_queue.Queue(maxsize=num_unique_prompts_rollout)

# Create mock dataset indices
dataset_indices = list(range(num_unique_prompts_rollout))

# Use split_and_insert_batch to split and insert data
@@ -129,48 +129,49 @@
param_prompt_Q,
)

# Verify that we have individual prompts in the map (not batches)
self.assertEqual(len(pending_queries_map), num_unique_prompts_rollout)

# Verify that we have the expected number of items in the queue
self.assertEqual(param_prompt_Q.qsize(), vllm_num_engines)
self.assertEqual(param_prompt_Q.qsize(), num_unique_prompts_rollout)

# Create mock inference results to simulate vLLM engine outputs
# Create mock inference results to simulate vLLM engine outputs (individual results)
mock_inference_results = []
for batch_idx in range(vllm_num_engines):
requests_processed = []
for i in range(num_unique_prompts_rollout):
# Get the request from the queue
request = param_prompt_Q.get()
self.assertIsInstance(request, PromptRequest)
self.assertEqual(request.training_step, training_step)
self.assertIsInstance(request.dataset_index, list) # Now expects a list of indices
self.assertIsInstance(request.dataset_index, int) # Single dataset index
self.assertIsInstance(request.prompt, list) # Single prompt as list of ints

# Create mock GenerationResult
batch_size = len(request.prompts)
# Store request for later verification
requests_processed.append(request)

# Create mock GenerationResult for single prompt
mock_result = GenerationResult(
responses=[[i] for i in range(batch_size)], # Mock token IDs
finish_reasons=["stop"] * batch_size,
masks=[[1] * 5] * batch_size, # Mock masks
responses=[[i]], # Mock token IDs for single response
finish_reasons=["stop"],
masks=[[1] * 5], # Mock masks
request_info=RequestInfo(
num_calls=[0] * batch_size,
timeouts=[0] * batch_size,
tool_errors=[""] * batch_size,
tool_outputs=[""] * batch_size,
tool_runtimes=[0] * batch_size,
tool_calleds=[False] * batch_size,
num_calls=[0],
timeouts=[0],
tool_errors=[""],
tool_outputs=[""],
tool_runtimes=[0],
tool_calleds=[False],
),
is_eval=False,
dataset_index=request.dataset_index,
dataset_index=[request.dataset_index],
)
mock_inference_results.append(mock_result)

# Create mock inference results queue
inference_results_Q = ray_queue.Queue(maxsize=vllm_num_engines)
inference_results_Q = ray_queue.Queue(maxsize=num_unique_prompts_rollout)
for result in mock_inference_results:
inference_results_Q.put(result)

# Use accumulate_inference_batches to combine results
combined_result, combined_queries, combined_ground_truths, combined_datasets = accumulate_inference_batches(
inference_results_Q, pending_queries_map, vllm_num_engines, training_step
inference_results_Q, pending_queries_map, num_unique_prompts_rollout, training_step
)

# Verify that the combined results match the original input
Expand All @@ -184,7 +185,7 @@ def test_batch_splitting_logic(self, vllm_num_engines: int, num_unique_prompts_r
self.assertEqual(len(combined_result.finish_reasons), len(queries_next))
self.assertEqual(len(combined_result.masks), len(queries_next))

# Verify that the pending_queries_map is empty after accumulation
# Verify that the pending_queries_map is empty after accumulation
self.assertEqual(len(pending_queries_map), 0)

# Verify that the inference_results_Q is empty after accumulation