56 changes: 55 additions & 1 deletion tests/test_grpo_trainer.py
@@ -16,7 +16,8 @@
import os
import warnings
from collections.abc import Callable
from unittest.mock import patch
from types import SimpleNamespace
from unittest.mock import MagicMock, patch

import numpy as np
import pytest
@@ -157,6 +158,59 @@ def test_compute_entropy_all_masked(self):
torch.testing.assert_close(entropy_mask, expected_mask)


class TestGRPORolloutDispatch:
def _make_trainer(self):
trainer = object.__new__(GRPOTrainer)
trainer.accelerator = SimpleNamespace(device=torch.device("cpu"), is_main_process=True)
trainer.args = SimpleNamespace(report_to=[])
trainer.model = SimpleNamespace(training=True)
trainer.state = SimpleNamespace(global_step=2)
trainer._last_loaded_step = 1
trainer.use_vllm = False
trainer.use_transformers_paged = False
trainer.vllm_generation = SimpleNamespace(sync_weights=MagicMock())
return trainer

def test_generate_single_turn_prefers_rollout_func(self):
trainer = self._make_trainer()
trainer.rollout_func = MagicMock(
return_value={
"prompt_ids": [[1]],
"completion_ids": [[2]],
"logprobs": [[-0.1]],
"env_mask": [[1]],
}
)

prompt_ids, completion_ids, logprobs, extra_fields = trainer._generate_single_turn(["prompt"])

assert prompt_ids == [[1]]
assert completion_ids == [[2]]
assert logprobs == [[-0.1]]
assert extra_fields == {"env_mask": [[1]]}
trainer.rollout_func.assert_called_once_with(["prompt"], trainer)

def test_generate_single_turn_rollout_func_syncs_vllm_weights_when_needed(self):
trainer = self._make_trainer()
trainer.use_vllm = True
trainer.rollout_func = MagicMock(
return_value={"prompt_ids": [[1]], "completion_ids": [[2]], "logprobs": [[0.0]]}
)

trainer._generate_single_turn(["prompt"])

trainer.vllm_generation.sync_weights.assert_called_once()
assert trainer._last_loaded_step == trainer.state.global_step
trainer.rollout_func.assert_called_once_with(["prompt"], trainer)

def test_generate_single_turn_rollout_func_raises_when_required_keys_are_missing(self):
trainer = self._make_trainer()
trainer.rollout_func = MagicMock(return_value={"prompt_ids": [[1]], "completion_ids": [[2]]})

with pytest.raises(ValueError, match="rollout_func must return keys"):
trainer._generate_single_turn(["prompt"])


class TestGRPOTrainer(TrlTestCase):
def test_init_minimal(self):
# Test that GRPOTrainer can be instantiated with only model, reward_model and train_dataset
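For context, the new tests above exercise a trainer-level rollout hook: GRPOTrainer._generate_single_turn calls trainer.rollout_func(prompts, trainer) and expects a dict containing "prompt_ids", "completion_ids" and "logprobs", with any additional keys (such as "env_mask") surfaced as extra fields. The sketch below is a hypothetical rollout function written against that contract; my_rollout_func, the placeholder completions, and the use of trainer.processing_class as a tokenizer are illustrative assumptions, not part of this diff.

def my_rollout_func(prompts, trainer):
    # Hypothetical rollout matching the contract the tests exercise:
    # one entry per prompt for each required key, optional extra keys allowed.
    prompt_ids, completion_ids, logprobs = [], [], []
    for prompt in prompts:
        # Assumes trainer.processing_class behaves like a tokenizer.
        ids = trainer.processing_class(prompt)["input_ids"]
        prompt_ids.append(ids)
        completion_ids.append([0])   # placeholder completion token ids
        logprobs.append([0.0])       # one logprob per completion token
    return {
        "prompt_ids": prompt_ids,
        "completion_ids": completion_ids,
        "logprobs": logprobs,
        "env_mask": [[1] * len(c) for c in completion_ids],  # optional extra field
    }

How the hook is attached (constructor argument vs. setting trainer.rollout_func directly) is not shown in this diff; the tests assign the attribute on a pre-built trainer, and when use_vllm is enabled the trainer syncs vLLM weights before invoking the hook.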
204 changes: 83 additions & 121 deletions trl/generation/vllm_generation.py
@@ -18,7 +18,6 @@
import logging
import math
import os
from collections.abc import Callable
from contextlib import nullcontext
from typing import TYPE_CHECKING

@@ -29,7 +28,7 @@
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from transformers import PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, is_bitsandbytes_available

from ..data_utils import apply_chat_template, is_conversational, prepare_multimodal_messages_vllm
from ..data_utils import is_conversational, prepare_multimodal_messages_vllm
from ..extras.profiling import ProfilingContext
from ..import_utils import is_vllm_available
from ..trainer.utils import ensure_master_addr_port
@@ -171,12 +170,6 @@ class VLLMGeneration:
Additional keyword arguments to customize the chat template used by the model.
tools (`list`, *optional*):
Tools available for tool calling during chat generation.
rollout_func (`Callable`, *optional*): Optional custom rollout function that accepts prompts and returns
a dict with 'prompt_ids', 'completion_ids', 'logprobs', and optional extra fields. Should be a
single-argument callable: rollout_func(prompts) -> dict. To pass additional context (e.g., trainer), use a
closure or functools.partial:
rollout_func = lambda prompts: my_custom_rollout(prompts, trainer)
The closure will hold a reference to trainer and see its state updates.
"""

def __init__(
@@ -213,7 +206,6 @@ def __init__(
chat_template: str | None = None,
chat_template_kwargs: dict | None = None,
tools: list | None = None,
rollout_func: Callable | None = None,
):
self.model = model
self.accelerator = accelerator
@@ -252,7 +244,6 @@ def __init__(
self.chat_template = chat_template
self.chat_template_kwargs = chat_template_kwargs or {}
self.tools = tools
self.rollout_func = rollout_func

self._init_vllm()

@@ -499,14 +490,12 @@ def generate(self, prompts: list, num_generations: int, profiler: ProfilingConte
"""
profiler = profiler or nullcontext()
accelerator = self.accelerator
rollout_func = self.rollout_func
temperature = self.temperature
top_p = self.top_p
top_k = self.top_k
min_p = self.min_p
repetition_penalty = self.repetition_penalty
max_completion_length = self.max_completion_length
processing_class = self.processing_class
chat_template_kwargs = self.chat_template_kwargs
tools = self.tools
chat_template = self.chat_template
@@ -553,26 +542,16 @@
"generation_kwargs": self.generation_kwargs,
}
with profiler: # TODO: profiling_context(trainer, "vLLM.generate"):
if rollout_func is not None:
# Pass all prompts (with duplicates) to rollout_func for consistency with colocate mode
rollout_prompts = all_prompts
if rollout_prompts and is_conversational({"prompt": rollout_prompts[0]}):
rollout_prompts = [
apply_chat_template({"prompt": p}, processing_class, **chat_template_kwargs)["prompt"]
for p in rollout_prompts
]
output = rollout_func(rollout_prompts)
if is_conversational({"prompt": ordered_set_of_prompts[0]}):
output = self.vllm_client.chat(
messages=ordered_set_of_prompts,
**sampling_params,
chat_template_kwargs=chat_template_kwargs,
tools=tools,
chat_template=chat_template,
)
else:
if is_conversational({"prompt": ordered_set_of_prompts[0]}):
output = self.vllm_client.chat(
messages=ordered_set_of_prompts,
**sampling_params,
chat_template_kwargs=chat_template_kwargs,
tools=tools,
chat_template=chat_template,
)
else:
output = self.vllm_client.generate(prompts=ordered_set_of_prompts, **sampling_params)
output = self.vllm_client.generate(prompts=ordered_set_of_prompts, **sampling_params)
# Extract required fields and collect any extra fields for reward functions
required_keys = {"prompt_ids", "completion_ids", "logprobs"}
extra_fields = {k: v for k, v in output.items() if k not in required_keys}
@@ -585,12 +564,9 @@ def generate(self, prompts: list, num_generations: int, profiler: ProfilingConte
broadcast_object_list(obj_list, from_process=0)
all_prompt_ids, all_completion_ids, all_logprobs, all_extra_fields = obj_list[0]

# When using rollout_func, it handles its own generation logic and returns one result per prompt.
# When NOT using rollout_func, vllm_client.generate(n=num_generations) returns num_generations
# completions per prompt, so we need to duplicate prompt_ids to match.
if self.rollout_func is None:
# At this point, we only get 1 copy of each prompt, so we need to repeat them num_generations times
all_prompt_ids = [ids for ids in all_prompt_ids for _ in range(num_generations)]
# vllm_client.generate/chat(n=num_generations) returns num_generations completions per prompt.
# Duplicate prompt_ids to align with per-completion entries.
all_prompt_ids = [ids for ids in all_prompt_ids for _ in range(num_generations)]

process_slice = slice(
accelerator.process_index * len(prompts),
@@ -610,97 +586,83 @@ def generate(self, prompts: list, num_generations: int, profiler: ProfilingConte

        # Generate completions using colocated vLLM instances: each device holds a vLLM copy and works on its own batch of prompts
elif self.mode == "colocate":
if rollout_func is not None:
rollout_prompts = prompts
if rollout_prompts and is_conversational({"prompt": rollout_prompts[0]}):
rollout_prompts = [
apply_chat_template({"prompt": prompt}, processing_class, **chat_template_kwargs)["prompt"]
for prompt in rollout_prompts
]
output = rollout_func(rollout_prompts)
required_keys = {"prompt_ids", "completion_ids", "logprobs"}
extra_fields = {k: v for k, v in output.items() if k not in required_keys}
prompt_ids = output["prompt_ids"]
completion_ids = output["completion_ids"]
logprobs = output["logprobs"]
if Version(vllm.__version__) <= Version("0.10.2"):
structured_outputs_key = "guided_decoding"
if self.structured_outputs_regex:
structured_outputs = GuidedDecodingParams(regex=self.structured_outputs_regex)
else:
structured_outputs = None
else:
if Version(vllm.__version__) <= Version("0.10.2"):
structured_outputs_key = "guided_decoding"
if self.structured_outputs_regex:
structured_outputs = GuidedDecodingParams(regex=self.structured_outputs_regex)
else:
structured_outputs = None
structured_outputs_key = "structured_outputs"
if self.structured_outputs_regex:
structured_outputs = StructuredOutputsParams(regex=self.structured_outputs_regex)
else:
structured_outputs_key = "structured_outputs"
if self.structured_outputs_regex:
structured_outputs = StructuredOutputsParams(regex=self.structured_outputs_regex)
else:
structured_outputs = None
structured_outputs = None

generation_kwargs = {
"n": 1, # vLLM on each GPU generates only 1 in colocate mode
"repetition_penalty": repetition_penalty,
"temperature": temperature,
"top_p": top_p,
"top_k": top_k,
"min_p": 0.0 if min_p is None else min_p,
"max_tokens": max_completion_length,
"logprobs": 0, # enable returning log probabilities; 0 means for the sampled tokens only
}
generation_kwargs[structured_outputs_key] = structured_outputs
generation_kwargs.update(self.generation_kwargs)
sampling_params = SamplingParams(**generation_kwargs)

generation_kwargs = {
"n": 1, # vLLM on each GPU generates only 1 in colocate mode
"repetition_penalty": repetition_penalty,
"temperature": temperature,
"top_p": top_p,
"top_k": top_k,
"min_p": 0.0 if min_p is None else min_p,
"max_tokens": max_completion_length,
"logprobs": 0, # enable returning log probabilities; 0 means for the sampled tokens only
}
generation_kwargs[structured_outputs_key] = structured_outputs
generation_kwargs.update(self.generation_kwargs)
sampling_params = SamplingParams(**generation_kwargs)

if self.tensor_parallel_size > 1:
# Gather prompts from all ranks in the TP group and flatten.
# Each rank starts with its own prompts; after gathering, all ranks see the full group set.
orig_size = len(prompts)
gathered_prompts = [None for _ in range(self.tensor_parallel_size)]
torch.distributed.all_gather_object(gathered_prompts, prompts, group=self.tp_group)
all_prompts = [p for sublist in gathered_prompts for p in sublist]
if self.tensor_parallel_size > 1:
# Gather prompts from all ranks in the TP group and flatten.
# Each rank starts with its own prompts; after gathering, all ranks see the full group set.
orig_size = len(prompts)
gathered_prompts = [None for _ in range(self.tensor_parallel_size)]
torch.distributed.all_gather_object(gathered_prompts, prompts, group=self.tp_group)
all_prompts = [p for sublist in gathered_prompts for p in sublist]
else:
all_prompts = prompts

if self.enable_sleep_mode:
self.llm.wake_up(tags=["kv_cache"])

with profiler: # TODO: profiling_context(trainer, "vLLM.generate"):
if is_conversational({"prompt": prompts[0]}):
all_outputs = self.llm.chat(
all_prompts,
sampling_params=sampling_params,
use_tqdm=False,
chat_template_kwargs=chat_template_kwargs,
tools=tools,
chat_template=chat_template,
)
else:
all_prompts = prompts
all_outputs = self.llm.generate(all_prompts, sampling_params=sampling_params, use_tqdm=False)

if self.enable_sleep_mode:
self.llm.wake_up(tags=["kv_cache"])
all_prompt_ids = [output.prompt_token_ids for output in all_outputs]
all_completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs]
all_logprobs = [
[sanitize_logprob(next(iter(lp.values()))) for lp in output.logprobs]
for outputs in all_outputs
for output in outputs.outputs
]

with profiler: # TODO: profiling_context(trainer, "vLLM.generate"):
if is_conversational({"prompt": prompts[0]}):
all_outputs = self.llm.chat(
all_prompts,
sampling_params=sampling_params,
use_tqdm=False,
chat_template_kwargs=chat_template_kwargs,
tools=tools,
chat_template=chat_template,
)
else:
all_outputs = self.llm.generate(all_prompts, sampling_params=sampling_params, use_tqdm=False)

all_prompt_ids = [output.prompt_token_ids for output in all_outputs]
all_completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs]
all_logprobs = [
[sanitize_logprob(next(iter(lp.values()))) for lp in output.logprobs]
for outputs in all_outputs
for output in outputs.outputs
]

if self.tensor_parallel_size > 1:
# Slice completions for this rank within its TP group.
# Each rank generates all outputs — we keep only our share.
local_rank_in_group = torch.distributed.get_rank(group=self.tp_group)
tp_slice = slice(local_rank_in_group * orig_size, (local_rank_in_group + 1) * orig_size)
prompt_ids = all_prompt_ids[tp_slice]
completion_ids = all_completion_ids[tp_slice]
logprobs = all_logprobs[tp_slice]
else:
prompt_ids = all_prompt_ids
completion_ids = all_completion_ids
logprobs = all_logprobs
if self.tensor_parallel_size > 1:
# Slice completions for this rank within its TP group.
# Each rank generates all outputs — we keep only our share.
local_rank_in_group = torch.distributed.get_rank(group=self.tp_group)
tp_slice = slice(local_rank_in_group * orig_size, (local_rank_in_group + 1) * orig_size)
prompt_ids = all_prompt_ids[tp_slice]
completion_ids = all_completion_ids[tp_slice]
logprobs = all_logprobs[tp_slice]
else:
prompt_ids = all_prompt_ids
completion_ids = all_completion_ids
logprobs = all_logprobs

extra_fields = {} # No extra fields for colocate mode
extra_fields = {} # No extra fields for colocate mode

if self.enable_sleep_mode:
self.llm.sleep(level=2)
if self.enable_sleep_mode:
self.llm.sleep(level=2)

return prompt_ids, completion_ids, logprobs, extra_fields
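As a reading aid for the server-mode path above: with rollout_func handling removed from VLLMGeneration, prompt ids are now always repeated num_generations times so they stay aligned with the per-completion outputs of vllm_client.generate/chat. A minimal sketch with made-up values:

num_generations = 2
all_prompt_ids = [[101, 102], [201]]                  # one entry per unique prompt
all_completion_ids = [[5, 6], [7], [8, 9, 10], [11]]  # num_generations entries per prompt

# Same duplication as in the diff: repeat each prompt's ids num_generations times.
all_prompt_ids = [ids for ids in all_prompt_ids for _ in range(num_generations)]
assert all_prompt_ids == [[101, 102], [101, 102], [201], [201]]
assert len(all_prompt_ids) == len(all_completion_ids)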