Merged
Changes from all commits (21 commits)
64 changes: 24 additions & 40 deletions open_instruct/benchmark_generators.py
@@ -14,8 +14,6 @@
import json
import logging
import pathlib
import queue
import threading
import time
from typing import Any

@@ -26,6 +24,7 @@
import torch.utils.flop_counter
import transformers
import vllm
from ray.util import queue as ray_queue

from open_instruct import dataset_transformation, grpo_fast, model_utils, utils, vllm_utils3
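
This PR replaces the in-process queue/threading pair with ray.util.queue.Queue, an actor-backed queue whose handle can be shared between the driver and remote actors. A minimal, standalone sketch of that sharing pattern (for orientation only, not code from this PR):

import ray
from ray.util import queue as ray_queue

ray.init(ignore_reinit_error=True)
q = ray_queue.Queue(maxsize=10)  # bounded queue backed by a Ray actor

@ray.remote
def consume(q: ray_queue.Queue) -> str:
    return q.get()  # blocks until the driver puts an item

ref = consume.remote(q)  # the queue handle is serialized and shared with the task
q.put("hello")
print(ray.get(ref))  # -> hello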

@@ -204,21 +203,22 @@ def setup_dataset(args: grpo_fast.Args, tokenizer_config: dataset_transformation

def setup_vllm_engines(
args: grpo_fast.Args, model_config: model_utils.ModelConfig, max_model_len: int = 20480
) -> list[ray.actor.ActorHandle]:
"""Set up vLLM engines."""
) -> tuple[list[ray.actor.ActorHandle], ray_queue.Queue, ray_queue.Queue]:
"""Set up vLLM engines and queues."""
logger.info("Setting up vLLM engines...")

# Initialize Ray
if ray.is_initialized():
ray.shutdown()
ray.init(num_cpus=4, num_gpus=1, ignore_reinit_error=True, runtime_env={"excludes": ["/benchmark_cache/"]})

# Create placement group for multiple engines
bundles = [{"GPU": 1, "CPU": 1} for _ in range(args.vllm_num_engines)]
pg = ray.util.placement_group(bundles, strategy="PACK")
ray.get(pg.ready())

# Create vLLM engines
param_prompt_Q = ray_queue.Queue(maxsize=10)
inference_results_Q = ray_queue.Queue(maxsize=10)

vllm_engines = vllm_utils3.create_vllm_engines(
num_engines=args.vllm_num_engines,
tensor_parallel_size=args.vllm_tensor_parallel_size,
@@ -234,11 +234,13 @@ def setup_vllm_engines(
pg=pg,
tools={},
max_tool_calls=[0],
prompt_queue=param_prompt_Q,
results_queue=inference_results_Q,
)

logger.info("vLLM engines ready")

return vllm_engines
return vllm_engines, param_prompt_Q, inference_results_Q
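
The queues created here are handed to create_vllm_engines as prompt_queue/results_queue, and each engine's process_from_queue (implemented in vllm_utils3, not shown in this diff) presumably drains the prompt queue and pushes results back. A hypothetical sketch of that consumer loop, with run_inference and make_result as placeholder names rather than real helpers:

# Hypothetical sketch only; the real method lives in vllm_utils3 and receives the
# arguments passed in run_benchmark below. It illustrates the queue-and-sentinel
# protocol the driver-side code in this file relies on.
def process_from_queue_sketch(llm, prompt_queue, results_queue, generation_config):
    while True:
        request = prompt_queue.get()   # blocks on the shared Ray queue
        if request is None:            # stop signal sent by the driver
            return
        # run_inference / make_result are placeholders for the real vLLM call.
        responses, finish_reasons = run_inference(llm, request.prompts, generation_config)
        results_queue.put(make_result(responses, finish_reasons))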


def get_batch_data(
@@ -254,34 +256,33 @@ def get_batch_data(


def run_generation_batch(
inference_results_Q: queue.Queue, param_prompt_Q: queue.Queue, prompts: list[list[int]]
inference_results_Q: ray_queue.Queue, param_prompt_Q: ray_queue.Queue, prompts: list[list[int]], batch_idx: int
) -> dict[str, Any]:
"""Run generation for a batch of prompts and measure performance."""

start_time = time.time()
param_prompt_Q.put((None, prompts))
param_prompt_Q.put(vllm_utils3.PromptRequest(prompts=prompts, dataset_index=batch_idx))
result = inference_results_Q.get()
generation_time = time.time() - start_time

response_ids, finish_reasons, _, _ = result
collated_finish_reasons = collections.Counter(finish_reasons)

new_tokens = sum(len(response) for response in response_ids)
new_tokens = sum(len(response) for response in result.responses)
tokens_per_second = new_tokens / generation_time
return {
"tokens_per_second": tokens_per_second,
"generation_time": generation_time,
"num_new_tokens": new_tokens,
# dict mapping string reasons to counts.
"finish_reasons": collated_finish_reasons,
"response_lengths": [len(response) for response in response_ids],
"finish_reasons": collections.Counter(result.finish_reasons),
"response_lengths": [len(response) for response in result.responses],
"prompt_lengths": [len(prompt) for prompt in prompts], # Original unique prompts
}
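
run_generation_batch now sends a vllm_utils3.PromptRequest and reads .responses and .finish_reasons off the result object; both types are defined elsewhere in the PR. For orientation, minimal stand-ins carrying only the fields used in this file might look like the following (the field names come from this file's usage; everything else is an assumption):

from dataclasses import dataclass

@dataclass
class PromptRequest:
    """Stand-in mirroring the fields used here; the real class lives in vllm_utils3."""
    prompts: list[list[int]]  # tokenized prompts for one batch
    dataset_index: int        # which batch the prompts came from

@dataclass
class GenerationResult:
    """Stand-in for the object read off inference_results_Q."""
    responses: list[list[int]]  # generated token ids, one list per sample
    finish_reasons: list[str]   # e.g. "stop" or "length", one per sample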


def run_benchmark(
dataset: datasets.Dataset,
vllm_engines: list[ray.actor.ActorHandle],
param_prompt_Q: ray_queue.Queue,
inference_results_Q: ray_queue.Queue,
args: grpo_fast.Args,
model_config: model_utils.ModelConfig,
timestamp: int,
@@ -290,11 +291,6 @@
"""Run the full benchmark."""
logger.info(f"Starting benchmark with {num_batches} batches of size {args.num_unique_prompts_rollout}")

# Create persistent queues
inference_results_Q = queue.Queue(maxsize=10)
param_prompt_Q = queue.Queue(maxsize=10)
evaluation_inference_results_Q = queue.Queue(maxsize=10)

# Create sampling parameters with 'n' for multiple samples per prompt
generation_config = vllm.SamplingParams(
temperature=args.temperature,
@@ -311,26 +307,17 @@
# Unclear why we need this. We didn't need it before torch 2.7.0.
free_all_gpu_memory()

# Start persistent vLLM generation thread
def wrapped_vllm_generate_thread() -> None:
grpo_fast.vllm_generate_thread(
vllm_engines,
# Start vLLM engines to process from queues
for engine in vllm_engines:
engine.process_from_queue.remote(
generation_config,
eval_generation_config,
inference_results_Q,
param_prompt_Q,
num_batches, # num_training_steps
None, # eval_prompt_token_ids
evaluation_inference_results_Q,
num_batches + 1, # eval_freq (avoid evaluation)
num_batches, # num_training_steps
1, # resume_training_step
False, # tool_use
)

thread = threading.Thread(target=wrapped_vllm_generate_thread)
thread.start()

# Wait for thread to be ready
# Wait for engines to be ready
time.sleep(0.1)

results = []
@@ -343,7 +330,7 @@ def wrapped_vllm_generate_thread() -> None:
prompts = get_batch_data(dataset, args.num_unique_prompts_rollout, batch_idx)

# Run generation
result = run_generation_batch(inference_results_Q, param_prompt_Q, prompts)
result = run_generation_batch(inference_results_Q, param_prompt_Q, prompts, batch_idx)
result["mfu"] = 100 * result["tokens_per_second"] * flops_per_token / device_flops
# We incrementally save completion lengths so even if the job dies, we still have data.
save_completion_lengths([result], timestamp, batch_idx)
@@ -360,9 +347,6 @@ def wrapped_vllm_generate_thread() -> None:
# Send stop signal
param_prompt_Q.put(None)

# Wait for thread to finish
thread.join(timeout=10)

print_summary(results, total_time, args, model_config)
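
For reference, the mfu value computed above is a percent-of-peak figure, 100 * tokens_per_second * flops_per_token / device_flops. An illustrative calculation with assumed numbers (the benchmark derives flops_per_token and device_flops elsewhere in this file):

# Assumed, illustrative numbers only, not measured values.
tokens_per_second = 2_000     # observed generation throughput
flops_per_token = 2 * 7e9     # ~2 x parameters per token for a 7B model's forward pass
device_flops = 9.89e14        # e.g. an H100's quoted BF16 dense peak, in FLOP/s

mfu = 100 * tokens_per_second * flops_per_token / device_flops
print(f"MFU: {mfu:.1f}%")     # -> MFU: 2.8%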


@@ -477,12 +461,12 @@ def main() -> None:
DATA_DIR.mkdir(parents=True, exist_ok=True)

dataset = setup_dataset(args, tokenizer_config)
vllm_engines = setup_vllm_engines(args, model_config)
vllm_engines, param_prompt_Q, inference_results_Q = setup_vllm_engines(args, model_config)

# Create the timestamp here so we use it for both filenames.
timestamp = int(time.time())
save_config(args, tokenizer_config, model_config, timestamp)
run_benchmark(dataset, vllm_engines, args, model_config, timestamp)
run_benchmark(dataset, vllm_engines, param_prompt_Q, inference_results_Q, args, model_config, timestamp)

cleanup(vllm_engines)
