diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py index 713e3fb893..3646e8bff4 100644 --- a/tests/test_grpo_trainer.py +++ b/tests/test_grpo_trainer.py @@ -210,6 +210,105 @@ def test_generate_single_turn_rollout_func_raises_when_required_keys_are_missing with pytest.raises(ValueError, match="rollout_func must return keys"): trainer._generate_single_turn(["prompt"]) + def test_generate_single_turn_rollout_func_no_extra_fields(self): + trainer = self._make_trainer() + trainer.rollout_func = MagicMock( + return_value={"prompt_ids": [[1]], "completion_ids": [[2]], "logprobs": [[0.0]]} + ) + + _, _, _, extra_fields = trainer._generate_single_turn(["prompt"]) + + assert extra_fields == {} + + def test_generate_single_turn_rollout_func_does_not_sync_when_step_unchanged(self): + trainer = self._make_trainer() + trainer.use_vllm = True + trainer._last_loaded_step = trainer.state.global_step # already in sync + trainer.rollout_func = MagicMock( + return_value={"prompt_ids": [[1]], "completion_ids": [[2]], "logprobs": [[0.0]]} + ) + + trainer._generate_single_turn(["prompt"]) + + trainer.vllm_generation.sync_weights.assert_not_called() + + def test_generate_single_turn_rollout_func_receives_structured_messages_for_conversational_prompts(self): + # Regression test for issue #5120: rollout_func must receive structured messages (list[dict]), not + # chat-template-formatted strings. Flattening to strings destroys multimodal content (images, typed + # content blocks) before rollout logic can access it. + trainer = self._make_trainer() + trainer.processing_class = MagicMock() + trainer.chat_template_kwargs = {} + trainer.rollout_func = MagicMock( + return_value={"prompt_ids": [[1]], "completion_ids": [[2]], "logprobs": [[0.0]]} + ) + conversational_prompt = [{"role": "user", "content": "hello"}] + + with patch("trl.trainer.grpo_trainer.apply_chat_template") as mock_tpl: + trainer._generate_single_turn([conversational_prompt]) + + # apply_chat_template must NOT be called before rollout_func — templating is rollout_func's responsibility + mock_tpl.assert_not_called() + # rollout_func receives the raw structured messages, not a formatted string + trainer.rollout_func.assert_called_once_with([conversational_prompt], trainer) + + def test_generate_single_turn_rollout_func_passes_non_conversational_prompt_unchanged(self): + trainer = self._make_trainer() + trainer.processing_class = MagicMock() + trainer.chat_template_kwargs = {} + trainer.rollout_func = MagicMock( + return_value={"prompt_ids": [[1]], "completion_ids": [[2]], "logprobs": [[0.0]]} + ) + + with patch("trl.trainer.grpo_trainer.apply_chat_template") as mock_tpl: + trainer._generate_single_turn(["plain string prompt"]) + + mock_tpl.assert_not_called() + trainer.rollout_func.assert_called_once_with(["plain string prompt"], trainer) + + @require_vision + def test_generate_single_turn_rollout_func_receives_real_multimodal_messages(self): + """Test for issue #5120: rollout_func must receive structured multimodal messages + with image objects preserved, not flattened strings that destroy image content. + """ + from PIL import Image as PILImage + + trainer = self._make_trainer() + trainer.processing_class = MagicMock() + trainer.chat_template_kwargs = {} + trainer.use_vllm = False + trainer.use_transformers_paged = False + trainer._last_loaded_step = trainer.state.global_step + + received_prompts = [] + + def capture_rollout_func(prompts, trainer): + received_prompts.append(prompts) + return {"prompt_ids": [[1]], "completion_ids": [[2]], "logprobs": [[0.0]]} + + trainer.rollout_func = capture_rollout_func + + test_image = PILImage.new("RGB", (10, 10)) + multimodal_prompt = [ + {"role": "user", "content": [{"type": "image", "image": test_image}, {"type": "text", "text": "What is in this image?"}]} + ] + + with patch("trl.trainer.grpo_trainer.apply_chat_template") as mock_tpl: + trainer._generate_single_turn([multimodal_prompt]) + + mock_tpl.assert_not_called() + assert len(received_prompts) == 1 + prompt_received = received_prompts[0][0] + + assert isinstance(prompt_received, list), "Prompt should be a list (conversation)" + assert isinstance(prompt_received[0]["content"], list), "Content should be a list (multimodal)" + assert isinstance(prompt_received[0]["content"][0], dict), "Content blocks should be dicts" + assert prompt_received[0]["content"][0]["type"] == "image", "First content block should be image type" + assert "image" in prompt_received[0]["content"][0], "Image key should be present" + assert isinstance(prompt_received[0]["content"][0]["image"], PILImage.Image), ( + "Image should be preserved as PIL Image object, not flattened to string" + ) + class TestGRPOTrainer(TrlTestCase): def test_init_minimal(self): diff --git a/trl/experimental/openenv/utils.py b/trl/experimental/openenv/utils.py index 5c4c710132..e7eec67b9e 100644 --- a/trl/experimental/openenv/utils.py +++ b/trl/experimental/openenv/utils.py @@ -127,9 +127,14 @@ def _generate_rollout_completions_server( with profiling_context(trainer, "vLLM.generate_rollout_server"): if as_chat: - # For chat mode, we need to pass messages format - # Since prompts are already formatted strings, we use generate instead - output = trainer.vllm_generation.vllm_client.generate(prompts=prompts, **generation_kwargs) + # Prompts are raw message dicts; use .chat() so the vLLM server applies the chat template. + output = trainer.vllm_generation.vllm_client.chat( + messages=prompts, + **generation_kwargs, + chat_template_kwargs=trainer.chat_template_kwargs, + tools=trainer.tools or None, + chat_template=trainer.chat_template, + ) else: output = trainer.vllm_generation.vllm_client.generate(prompts=prompts, **generation_kwargs) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index e34642bd13..d72d83731d 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -225,8 +225,21 @@ class GRPOTrainer(BaseTrainer): rollout_func (`RolloutFunc`, *optional*): Function to use for generating completions. It receives the list of prompts allocated to the current process and the trainer instance. It must return a dict with `"prompt_ids"`, `"completion_ids"`, and - `"logprobs"` fields. Any other fields are forwarded to the reward functions. This feature is experimental - and may change or be removed at any time without prior notice. + `"logprobs"` fields. Any other fields are forwarded to the reward functions. + + The `prompts` argument type depends on the dataset format: + + - **Non-conversational** datasets: `prompts` is a `list[str]`. + - **Conversational** datasets: `prompts` is a `list[list[dict]]`, where each inner list is a sequence + of `{"role": ..., "content": ...}` messages. Content values may be strings or lists of typed content + blocks (e.g. `[{"type": "image", ...}, {"type": "text", ...}]` for multimodal inputs). + + `rollout_func` is responsible for applying any required formatting (chat template, tokenization) + before calling its generation backend. Structured messages are passed through unmodified so that + multimodal content is not lost before rollout logic runs. The function receives the per-process + prompt slice with no duplication; it is responsible for returning the correct number of completions + per prompt (see `num_generations` / `num_generations_eval` on the trainer). This feature is + experimental and may change or be removed at any time without prior notice. """ _tag_names = ["trl", "grpo"] @@ -425,6 +438,15 @@ def __init__( "it with `pip install jmespath` to use this feature." ) self.tools = tools or [] + + if self.rollout_func is not None and self.tools: + raise ValueError( + "rollout_func and tools cannot be used together. The tool-call loop passes fully-assembled " + "conversation histories to _generate_single_turn, which is incompatible with custom rollout " + "dispatch that expects original prompts. If you need tool-augmented generation, handle the " + "full tool execution loop inside your rollout_func." + ) + self._sync_tool_dict = {} self._async_tool_dict = {} if self.tools: @@ -1156,6 +1178,9 @@ def _generate_single_turn(self, prompts: list): self.vllm_generation.sync_weights() self._last_loaded_step = self.state.global_step + # Pass prompts to rollout_func preserving structured messages. + # Chat templating must happen inside rollout_func, at the backend boundary, so that + # multimodal content (images, typed content blocks) is not lost before rollout logic runs. output = self.rollout_func(prompts, self) required_keys = {"prompt_ids", "completion_ids", "logprobs"} missing_keys = required_keys - output.keys()