diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py index c9890bf1035..2b275bae81e 100644 --- a/tests/test_grpo_trainer.py +++ b/tests/test_grpo_trainer.py @@ -162,17 +162,44 @@ def test_compute_entropy_all_masked(self): class TestGRPORolloutDispatch: def _make_trainer(self): trainer = object.__new__(GRPOTrainer) - trainer.accelerator = SimpleNamespace(device=torch.device("cpu"), is_main_process=True) + trainer.accelerator = SimpleNamespace( + device=torch.device("cpu"), + is_main_process=True, + gather=lambda t: t, + ) trainer.args = SimpleNamespace(report_to=[]) trainer.model = SimpleNamespace(training=True) - trainer.state = SimpleNamespace(global_step=2) + trainer.state = SimpleNamespace(global_step=2, num_input_tokens_seen=0) trainer._last_loaded_step = 1 trainer.use_vllm = False trainer.use_transformers_paged = False trainer.vllm_generation = SimpleNamespace(sync_weights=MagicMock()) + trainer.processing_class = SimpleNamespace( + batch_decode=MagicMock(return_value=["decoded"]), + ) + trainer.tools = None + trainer.eos_token_id = 2 + trainer.pad_token_id = 0 + trainer._metrics = { + "train": { + "num_tokens": [], + **{ + k: [] + for k in [ + "completions/mean_length", + "completions/min_length", + "completions/max_length", + "completions/clipped_ratio", + "completions/mean_terminated_length", + "completions/min_terminated_length", + "completions/max_terminated_length", + ] + }, + } + } return trainer - def test_generate_single_turn_prefers_rollout_func(self): + def test_generate_prefers_rollout_func(self): trainer = self._make_trainer() trainer.rollout_func = MagicMock( return_value={ @@ -183,33 +210,32 @@ def test_generate_single_turn_prefers_rollout_func(self): } ) - prompt_ids, completion_ids, logprobs, extra_fields = trainer._generate_single_turn(["prompt"]) + result = trainer._generate(["prompt"]) - assert prompt_ids == [[1]] - assert completion_ids == [[2]] - assert logprobs == [[-0.1]] - assert extra_fields == {"env_mask": [[1]]} + assert result[0] == [[1]] # prompt_ids + assert result[1] == [[2]] # completion_ids + assert result[2] == [[1]] # tool_mask (from env_mask) trainer.rollout_func.assert_called_once_with(["prompt"], trainer) - def test_generate_single_turn_rollout_func_syncs_vllm_weights_when_needed(self): + def test_generate_rollout_func_syncs_vllm_weights_when_needed(self): trainer = self._make_trainer() trainer.use_vllm = True trainer.rollout_func = MagicMock( return_value={"prompt_ids": [[1]], "completion_ids": [[2]], "logprobs": [[0.0]]} ) - trainer._generate_single_turn(["prompt"]) + trainer._generate(["prompt"]) trainer.vllm_generation.sync_weights.assert_called_once() assert trainer._last_loaded_step == trainer.state.global_step trainer.rollout_func.assert_called_once_with(["prompt"], trainer) - def test_generate_single_turn_rollout_func_raises_when_required_keys_are_missing(self): + def test_generate_rollout_func_raises_when_required_keys_are_missing(self): trainer = self._make_trainer() trainer.rollout_func = MagicMock(return_value={"prompt_ids": [[1]], "completion_ids": [[2]]}) with pytest.raises(ValueError, match="rollout_func must return keys"): - trainer._generate_single_turn(["prompt"]) + trainer._generate(["prompt"]) class TestGRPOTrainer(TrlTestCase): diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 3ec051e4086..811c89a6691 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1216,25 +1216,6 @@ def _generate_single_turn(self, prompts: list): device = self.accelerator.device mode = "train" if self.model.training else "eval" - if self.rollout_func is not None: - # Keep vLLM weights in sync for custom rollouts that rely on vLLM utilities. - if self.use_vllm and self.state.global_step != self._last_loaded_step: - with profiling_context(self, "sync_weights"): - self.vllm_generation.sync_weights() - self._last_loaded_step = self.state.global_step - - # Pass prompts to rollout_func preserving structured messages. - # Chat templating must happen inside rollout_func, at the backend boundary, so that - # multimodal content (images, typed content blocks) is not lost before rollout logic runs. - output = self.rollout_func(prompts, self) - required_keys = {"prompt_ids", "completion_ids", "logprobs"} - missing_keys = required_keys - output.keys() - if missing_keys: - missing_keys_list = sorted(missing_keys) - raise ValueError(f"rollout_func must return keys {missing_keys_list} in its output dict.") - extra_fields = {k: v for k, v in output.items() if k not in required_keys} - return output["prompt_ids"], output["completion_ids"], output["logprobs"], extra_fields - # Generate completions using either vLLM or regular generation if self.use_vllm: # Sync weights if training step changed @@ -1521,7 +1502,26 @@ def _generate(self, prompts: list): # Copy the prompts to avoid modifying the original list prompts = copy.deepcopy(prompts) - prompt_ids, completion_ids, logprobs, extra_fields = self._generate_single_turn(prompts) + if self.rollout_func is not None: + # Keep vLLM weights in sync for custom rollouts that rely on vLLM utilities. + if self.use_vllm and self.state.global_step != self._last_loaded_step: + with profiling_context(self, "sync_weights"): + self.vllm_generation.sync_weights() + self._last_loaded_step = self.state.global_step + + # Pass prompts to rollout_func preserving structured messages. + # Chat templating must happen inside rollout_func, at the backend boundary, so that + # multimodal content (images, typed content blocks) is not lost before rollout logic runs. + output = self.rollout_func(prompts, self) + required_keys = {"prompt_ids", "completion_ids", "logprobs"} + missing_keys = required_keys - output.keys() + if missing_keys: + missing_keys_list = sorted(missing_keys) + raise ValueError(f"rollout_func must return keys {missing_keys_list} in its output dict.") + extra_fields = {k: v for k, v in output.items() if k not in required_keys} + prompt_ids, completion_ids, logprobs = output["prompt_ids"], output["completion_ids"], output["logprobs"] + else: + prompt_ids, completion_ids, logprobs, extra_fields = self._generate_single_turn(prompts) # Decode completions. It's important to use `parse_response` when possible, because it handles tool calls. if is_conversational({"prompt": prompts[0]}):