Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 99 additions & 0 deletions tests/test_grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,105 @@ def test_generate_single_turn_rollout_func_raises_when_required_keys_are_missing
with pytest.raises(ValueError, match="rollout_func must return keys"):
trainer._generate_single_turn(["prompt"])

def test_generate_single_turn_rollout_func_no_extra_fields(self):
    """When rollout_func returns only the required keys, extra_fields is empty."""
    trainer = self._make_trainer()
    minimal_output = {"prompt_ids": [[1]], "completion_ids": [[2]], "logprobs": [[0.0]]}
    trainer.rollout_func = MagicMock(return_value=minimal_output)

    _, _, _, extras = trainer._generate_single_turn(["prompt"])

    assert extras == {}

def test_generate_single_turn_rollout_func_does_not_sync_when_step_unchanged(self):
    """No vLLM weight sync should happen if weights were already loaded at this step."""
    trainer = self._make_trainer()
    trainer.rollout_func = MagicMock(
        return_value={"prompt_ids": [[1]], "completion_ids": [[2]], "logprobs": [[0.0]]}
    )
    trainer.use_vllm = True
    # Pretend weights were already pushed for the current global step.
    trainer._last_loaded_step = trainer.state.global_step

    trainer._generate_single_turn(["prompt"])

    trainer.vllm_generation.sync_weights.assert_not_called()

def test_generate_single_turn_rollout_func_receives_structured_messages_for_conversational_prompts(self):
    """Regression test for issue #5120.

    rollout_func must be handed the structured messages (list[dict]) as-is. If the
    trainer flattened them through the chat template first, multimodal content
    (images, typed content blocks) would be destroyed before rollout logic runs.
    """
    trainer = self._make_trainer()
    trainer.rollout_func = MagicMock(
        return_value={"prompt_ids": [[1]], "completion_ids": [[2]], "logprobs": [[0.0]]}
    )
    trainer.processing_class = MagicMock()
    trainer.chat_template_kwargs = {}
    messages = [{"role": "user", "content": "hello"}]

    with patch("trl.trainer.grpo_trainer.apply_chat_template") as mock_tpl:
        trainer._generate_single_turn([messages])

    # Templating is rollout_func's responsibility: it must not run beforehand,
    # and rollout_func must see the untouched message structure.
    mock_tpl.assert_not_called()
    trainer.rollout_func.assert_called_once_with([messages], trainer)

def test_generate_single_turn_rollout_func_passes_non_conversational_prompt_unchanged(self):
    """Plain-string prompts are forwarded to rollout_func without templating."""
    trainer = self._make_trainer()
    trainer.processing_class = MagicMock()
    trainer.chat_template_kwargs = {}
    trainer.rollout_func = MagicMock(
        return_value={"prompt_ids": [[1]], "completion_ids": [[2]], "logprobs": [[0.0]]}
    )
    raw_prompts = ["plain string prompt"]

    with patch("trl.trainer.grpo_trainer.apply_chat_template") as mock_tpl:
        trainer._generate_single_turn(raw_prompts)

    mock_tpl.assert_not_called()
    trainer.rollout_func.assert_called_once_with(raw_prompts, trainer)

@require_vision
def test_generate_single_turn_rollout_func_receives_real_multimodal_messages(self):
    """Test for issue #5120: rollout_func must receive structured multimodal messages
    with image objects preserved, not flattened strings that destroy image content.
    """
    from PIL import Image

    trainer = self._make_trainer()
    trainer.processing_class = MagicMock()
    trainer.chat_template_kwargs = {}
    trainer.use_vllm = False
    trainer.use_transformers_paged = False
    trainer._last_loaded_step = trainer.state.global_step

    seen = []

    # Record exactly what the trainer hands to rollout_func.
    def recording_rollout(prompts, trainer):
        seen.append(prompts)
        return {"prompt_ids": [[1]], "completion_ids": [[2]], "logprobs": [[0.0]]}

    trainer.rollout_func = recording_rollout

    pil_img = Image.new("RGB", (10, 10))
    conversation = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": pil_img},
                {"type": "text", "text": "What is in this image?"},
            ],
        }
    ]

    with patch("trl.trainer.grpo_trainer.apply_chat_template") as mock_tpl:
        trainer._generate_single_turn([conversation])

    mock_tpl.assert_not_called()
    assert len(seen) == 1
    received = seen[0][0]

    assert isinstance(received, list), "Prompt should be a list (conversation)"
    first_content = received[0]["content"]
    assert isinstance(first_content, list), "Content should be a list (multimodal)"
    assert isinstance(first_content[0], dict), "Content blocks should be dicts"
    assert first_content[0]["type"] == "image", "First content block should be image type"
    assert "image" in first_content[0], "Image key should be present"
    assert isinstance(first_content[0]["image"], Image.Image), (
        "Image should be preserved as PIL Image object, not flattened to string"
    )


class TestGRPOTrainer(TrlTestCase):
def test_init_minimal(self):
Expand Down
11 changes: 8 additions & 3 deletions trl/experimental/openenv/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,9 +127,14 @@ def _generate_rollout_completions_server(

with profiling_context(trainer, "vLLM.generate_rollout_server"):
if as_chat:
# For chat mode, we need to pass messages format
# Since prompts are already formatted strings, we use generate instead
output = trainer.vllm_generation.vllm_client.generate(prompts=prompts, **generation_kwargs)
# Prompts are raw message dicts; use .chat() so the vLLM server applies the chat template.
output = trainer.vllm_generation.vllm_client.chat(
messages=prompts,
**generation_kwargs,
chat_template_kwargs=trainer.chat_template_kwargs,
tools=trainer.tools or None,
chat_template=trainer.chat_template,
)
else:
output = trainer.vllm_generation.vllm_client.generate(prompts=prompts, **generation_kwargs)

Expand Down
29 changes: 27 additions & 2 deletions trl/trainer/grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,8 +225,21 @@ class GRPOTrainer(BaseTrainer):
rollout_func (`RolloutFunc`, *optional*):
Function to use for generating completions. It receives the list of prompts allocated to the current
process and the trainer instance. It must return a dict with `"prompt_ids"`, `"completion_ids"`, and
`"logprobs"` fields. Any other fields are forwarded to the reward functions. This feature is experimental
and may change or be removed at any time without prior notice.
`"logprobs"` fields. Any other fields are forwarded to the reward functions.

The `prompts` argument type depends on the dataset format:

- **Non-conversational** datasets: `prompts` is a `list[str]`.
- **Conversational** datasets: `prompts` is a `list[list[dict]]`, where each inner list is a sequence
of `{"role": ..., "content": ...}` messages. Content values may be strings or lists of typed content
blocks (e.g. `[{"type": "image", ...}, {"type": "text", ...}]` for multimodal inputs).

`rollout_func` is responsible for applying any required formatting (chat template, tokenization)
before calling its generation backend. Structured messages are passed through unmodified so that
multimodal content is not lost before rollout logic runs. The function receives the per-process
prompt slice with no duplication; it is responsible for returning the correct number of completions
per prompt (see `num_generations` / `num_generations_eval` on the trainer). This feature is
experimental and may change or be removed at any time without prior notice.
"""

_tag_names = ["trl", "grpo"]
Expand Down Expand Up @@ -425,6 +438,15 @@ def __init__(
"it with `pip install jmespath` to use this feature."
)
self.tools = tools or []

if self.rollout_func is not None and self.tools:
raise ValueError(
"rollout_func and tools cannot be used together. The tool-call loop passes fully-assembled "
"conversation histories to _generate_single_turn, which is incompatible with custom rollout "
"dispatch that expects original prompts. If you need tool-augmented generation, handle the "
"full tool execution loop inside your rollout_func."
)

self._sync_tool_dict = {}
self._async_tool_dict = {}
if self.tools:
Expand Down Expand Up @@ -1156,6 +1178,9 @@ def _generate_single_turn(self, prompts: list):
self.vllm_generation.sync_weights()
self._last_loaded_step = self.state.global_step

# Pass prompts to rollout_func preserving structured messages.
# Chat templating must happen inside rollout_func, at the backend boundary, so that
# multimodal content (images, typed content blocks) is not lost before rollout logic runs.
output = self.rollout_func(prompts, self)
required_keys = {"prompt_ids", "completion_ids", "logprobs"}
missing_keys = required_keys - output.keys()
Expand Down