From 8a0fd35b6ef1da1e39b6613d376ea7d46e83d608 Mon Sep 17 00:00:00 2001
From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Date: Wed, 18 Feb 2026 11:35:03 +0100
Subject: [PATCH 1/4] Move rollout_func to top-level in _generate_single_turn
 and keep vLLM weight sync

---
 trl/trainer/grpo_trainer.py | 16 ++++++++++++++++
 1 file changed, 16 insertions(+)

diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py
index 3e16c8979d..c5c483ccf7 100644
--- a/trl/trainer/grpo_trainer.py
+++ b/trl/trainer/grpo_trainer.py
@@ -1157,6 +1157,22 @@ def _generate_single_turn(self, prompts: list):
         device = self.accelerator.device
         mode = "train" if self.model.training else "eval"
 
+        if self.rollout_func is not None:
+            # Keep vLLM weights in sync for custom rollouts that rely on vLLM utilities.
+            if self.use_vllm and self.state.global_step != self._last_loaded_step:
+                with profiling_context(self, "sync_weights"):
+                    self.vllm_generation.sync_weights()
+                self._last_loaded_step = self.state.global_step
+
+            output = self.rollout_func(prompts, self)
+            required_keys = {"prompt_ids", "completion_ids", "logprobs"}
+            missing_keys = required_keys - output.keys()
+            if missing_keys:
+                missing_keys_list = sorted(missing_keys)
+                raise ValueError(f"rollout_func must return keys {missing_keys_list} in its output dict.")
+            extra_fields = {k: v for k, v in output.items() if k not in required_keys}
+            return output["prompt_ids"], output["completion_ids"], output["logprobs"], extra_fields
+
         # Generate completions using either vLLM or regular generation
         if self.use_vllm:
             # Sync weights if training step changed
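With this dispatch, generation can be driven by any callable of the form
rollout_func(prompts, trainer) that returns the three required keys; any extra
keys pass through as extra_fields. A minimal sketch of a conforming rollout
function (the function name, token ids, and the env_mask key are illustrative
placeholders, not part of the patch):

    # Hypothetical rollout function satisfying the contract above.
    def my_rollout_func(prompts: list, trainer) -> dict:
        prompt_ids = [[101, 102] for _ in prompts]        # stand-in prompt tokens
        completion_ids = [[7, 8, 9] for _ in prompts]     # stand-in completions
        logprobs = [[-0.5, -0.2, -0.1] for _ in prompts]  # one logprob per completion token
        return {
            "prompt_ids": prompt_ids,
            "completion_ids": completion_ids,
            "logprobs": logprobs,
            "env_mask": [[1, 1, 0] for _ in prompts],  # extra key, surfaced as extra_fields
        }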
From 7e66f62c51a4f167568a547da83c0ddd95eb6386 Mon Sep 17 00:00:00 2001
From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Date: Wed, 18 Feb 2026 11:35:44 +0100
Subject: [PATCH 2/4] Remove rollout wiring from vLLM backend construction

---
 trl/trainer/grpo_trainer.py | 8 --------
 1 file changed, 8 deletions(-)

diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py
index c5c483ccf7..e34642bd13 100644
--- a/trl/trainer/grpo_trainer.py
+++ b/trl/trainer/grpo_trainer.py
@@ -646,13 +646,6 @@ def cast_outputs_to_original_dtype(module, args, output):
 
         if self.use_vllm:
             # Initialize vLLM generation backend
-            # Wrap rollout_func to capture trainer context if provided
-            rollout_func = None
-            if self.rollout_func is not None:
-
-                def rollout_func(prompts):
-                    return self.rollout_func(prompts, self)
-
             self.vllm_generation = VLLMGeneration(
                 model=self.model,
                 accelerator=self.accelerator,
@@ -688,7 +681,6 @@ def rollout_func(prompts):
                 chat_template=self.chat_template,
                 chat_template_kwargs=self.chat_template_kwargs,
                 tools=self.tools,
-                rollout_func=rollout_func,
             )
             self._last_loaded_step = -1  # tag to avoid useless loading during grad accumulation
         else:

From 3f11a74c19520eff49c04cb72c30ebc17f69ecd6 Mon Sep 17 00:00:00 2001
From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Date: Wed, 18 Feb 2026 11:38:48 +0100
Subject: [PATCH 3/4] Add GRPO rollout-dispatch tests

---
 tests/test_grpo_trainer.py | 56 +++++++++++++++++++++++++++++++++++++-
 1 file changed, 55 insertions(+), 1 deletion(-)

diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py
index 6bcaab5fd6..713e3fb893 100644
--- a/tests/test_grpo_trainer.py
+++ b/tests/test_grpo_trainer.py
@@ -16,7 +16,8 @@
 import os
 import warnings
 from collections.abc import Callable
-from unittest.mock import patch
+from types import SimpleNamespace
+from unittest.mock import MagicMock, patch
 
 import numpy as np
 import pytest
@@ -157,6 +158,59 @@ def test_compute_entropy_all_masked(self):
         torch.testing.assert_close(entropy_mask, expected_mask)
 
 
+class TestGRPORolloutDispatch:
+    def _make_trainer(self):
+        trainer = object.__new__(GRPOTrainer)
+        trainer.accelerator = SimpleNamespace(device=torch.device("cpu"), is_main_process=True)
+        trainer.args = SimpleNamespace(report_to=[])
+        trainer.model = SimpleNamespace(training=True)
+        trainer.state = SimpleNamespace(global_step=2)
+        trainer._last_loaded_step = 1
+        trainer.use_vllm = False
+        trainer.use_transformers_paged = False
+        trainer.vllm_generation = SimpleNamespace(sync_weights=MagicMock())
+        return trainer
+
+    def test_generate_single_turn_prefers_rollout_func(self):
+        trainer = self._make_trainer()
+        trainer.rollout_func = MagicMock(
+            return_value={
+                "prompt_ids": [[1]],
+                "completion_ids": [[2]],
+                "logprobs": [[-0.1]],
+                "env_mask": [[1]],
+            }
+        )
+
+        prompt_ids, completion_ids, logprobs, extra_fields = trainer._generate_single_turn(["prompt"])
+
+        assert prompt_ids == [[1]]
+        assert completion_ids == [[2]]
+        assert logprobs == [[-0.1]]
+        assert extra_fields == {"env_mask": [[1]]}
+        trainer.rollout_func.assert_called_once_with(["prompt"], trainer)
+
+    def test_generate_single_turn_rollout_func_syncs_vllm_weights_when_needed(self):
+        trainer = self._make_trainer()
+        trainer.use_vllm = True
+        trainer.rollout_func = MagicMock(
+            return_value={"prompt_ids": [[1]], "completion_ids": [[2]], "logprobs": [[0.0]]}
+        )
+
+        trainer._generate_single_turn(["prompt"])
+
+        trainer.vllm_generation.sync_weights.assert_called_once()
+        assert trainer._last_loaded_step == trainer.state.global_step
+        trainer.rollout_func.assert_called_once_with(["prompt"], trainer)
+
+    def test_generate_single_turn_rollout_func_raises_when_required_keys_are_missing(self):
+        trainer = self._make_trainer()
+        trainer.rollout_func = MagicMock(return_value={"prompt_ids": [[1]], "completion_ids": [[2]]})
+
+        with pytest.raises(ValueError, match="rollout_func must return keys"):
+            trainer._generate_single_turn(["prompt"])
+
+
 class TestGRPOTrainer(TrlTestCase):
     def test_init_minimal(self):
         # Test that GRPOTrainer can be instantiated with only model, reward_model and train_dataset
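The second test pins down the sync cadence: weights are pushed to vLLM only
when state.global_step has advanced past _last_loaded_step, so
gradient-accumulation micro-steps within one optimizer step reuse the weights
already loaded. The guard in isolation (step values illustrative):

    # Illustrative trace of the weight-sync guard from patch 1.
    last_loaded_step = -1
    for global_step in [0, 0, 0, 1, 1, 1]:  # three micro-steps per optimizer step
        if global_step != last_loaded_step:
            print(f"syncing weights at global step {global_step}")  # fires once per step
            last_loaded_step = global_step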
From ed3bbf2cf12e409e2420da4ae3185539412b0e20 Mon Sep 17 00:00:00 2001
From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Date: Wed, 18 Feb 2026 11:56:54 +0100
Subject: [PATCH 4/4] Remove rollout_func support from VLLMGeneration

---
 trl/generation/vllm_generation.py | 204 ++++++++++++------------------
 1 file changed, 83 insertions(+), 121 deletions(-)

diff --git a/trl/generation/vllm_generation.py b/trl/generation/vllm_generation.py
index 3956a70aeb..780117965e 100644
--- a/trl/generation/vllm_generation.py
+++ b/trl/generation/vllm_generation.py
@@ -18,7 +18,6 @@
 import logging
 import math
 import os
-from collections.abc import Callable
 from contextlib import nullcontext
 from typing import TYPE_CHECKING
 
@@ -29,7 +28,7 @@
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from transformers import PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, is_bitsandbytes_available
 
-from ..data_utils import apply_chat_template, is_conversational, prepare_multimodal_messages_vllm
+from ..data_utils import is_conversational, prepare_multimodal_messages_vllm
 from ..extras.profiling import ProfilingContext
 from ..import_utils import is_vllm_available
 from ..trainer.utils import ensure_master_addr_port
@@ -171,12 +170,6 @@ class VLLMGeneration:
             Additional keyword arguments to customize the chat template used by the model.
         tools (`list`, *optional*):
             Tools available for tool calling during chat generation.
-        rollout_func (`Callable`, *optional*): Optional custom rollout function that accepts prompts and returns
-            a dict with 'prompt_ids', 'completion_ids', 'logprobs', and optional extra fields. Should be a
-            single-argument callable: rollout_func(prompts) -> dict. To pass additional context (e.g., trainer), use a
-            closure or functools.partial:
-                rollout_func = lambda prompts: my_custom_rollout(prompts, trainer)
-            The closure will hold a reference to trainer and see its state updates.
     """
 
     def __init__(
@@ -213,7 +206,6 @@ def __init__(
         chat_template: str | None = None,
         chat_template_kwargs: dict | None = None,
         tools: list | None = None,
-        rollout_func: Callable | None = None,
     ):
         self.model = model
         self.accelerator = accelerator
@@ -252,7 +244,6 @@ def __init__(
         self.chat_template = chat_template
         self.chat_template_kwargs = chat_template_kwargs or {}
         self.tools = tools
-        self.rollout_func = rollout_func
 
         self._init_vllm()
 
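The docstring removed above described the old single-argument contract, which
made callers bind trainer context through a closure or functools.partial. With
the top-level dispatch from patch 1, the trainer passes itself explicitly, so
the two styles compare as follows (my_custom_rollout and the stub values are
placeholders):

    # Placeholder rollout and trainer objects, for illustration only.
    def my_custom_rollout(prompts, trainer):
        return {"prompt_ids": [[1]], "completion_ids": [[2]], "logprobs": [[0.0]]}

    trainer = object()  # stands in for a GRPOTrainer instance
    prompts = ["hello"]

    # Old contract: a single-argument callable pre-bound to the trainer.
    rollout_func = lambda prompts: my_custom_rollout(prompts, trainer)
    output_old = rollout_func(prompts)

    # New contract: the trainer calls the function directly and passes itself.
    output_new = my_custom_rollout(prompts, trainer)
    assert output_old == output_new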
@@ -499,14 +490,12 @@ def generate(self, prompts: list, num_generations: int, profiler: ProfilingConte
         """
         profiler = profiler or nullcontext()
         accelerator = self.accelerator
-        rollout_func = self.rollout_func
         temperature = self.temperature
         top_p = self.top_p
         top_k = self.top_k
         min_p = self.min_p
         repetition_penalty = self.repetition_penalty
         max_completion_length = self.max_completion_length
-        processing_class = self.processing_class
         chat_template_kwargs = self.chat_template_kwargs
         tools = self.tools
         chat_template = self.chat_template
@@ -553,26 +542,16 @@ def generate(self, prompts: list, num_generations: int, profiler: ProfilingConte
                 "generation_kwargs": self.generation_kwargs,
             }
             with profiler:  # TODO: profiling_context(trainer, "vLLM.generate"):
-                if rollout_func is not None:
-                    # Pass all prompts (with duplicates) to rollout_func for consistency with colocate mode
-                    rollout_prompts = all_prompts
-                    if rollout_prompts and is_conversational({"prompt": rollout_prompts[0]}):
-                        rollout_prompts = [
-                            apply_chat_template({"prompt": p}, processing_class, **chat_template_kwargs)["prompt"]
-                            for p in rollout_prompts
-                        ]
-                    output = rollout_func(rollout_prompts)
+                if is_conversational({"prompt": ordered_set_of_prompts[0]}):
+                    output = self.vllm_client.chat(
+                        messages=ordered_set_of_prompts,
+                        **sampling_params,
+                        chat_template_kwargs=chat_template_kwargs,
+                        tools=tools,
+                        chat_template=chat_template,
+                    )
                 else:
-                    if is_conversational({"prompt": ordered_set_of_prompts[0]}):
-                        output = self.vllm_client.chat(
-                            messages=ordered_set_of_prompts,
-                            **sampling_params,
-                            chat_template_kwargs=chat_template_kwargs,
-                            tools=tools,
-                            chat_template=chat_template,
-                        )
-                    else:
-                        output = self.vllm_client.generate(prompts=ordered_set_of_prompts, **sampling_params)
+                    output = self.vllm_client.generate(prompts=ordered_set_of_prompts, **sampling_params)
             # Extract required fields and collect any extra fields for reward functions
             required_keys = {"prompt_ids", "completion_ids", "logprobs"}
             extra_fields = {k: v for k, v in output.items() if k not in required_keys}
@@ -585,12 +564,9 @@ def generate(self, prompts: list, num_generations: int, profiler: ProfilingConte
             broadcast_object_list(obj_list, from_process=0)
             all_prompt_ids, all_completion_ids, all_logprobs, all_extra_fields = obj_list[0]
 
-            # When using rollout_func, it handles its own generation logic and returns one result per prompt.
-            # When NOT using rollout_func, vllm_client.generate(n=num_generations) returns num_generations
-            # completions per prompt, so we need to duplicate prompt_ids to match.
-            if self.rollout_func is None:
-                # At this point, we only get 1 copy of each prompt, so we need to repeat them num_generations times
-                all_prompt_ids = [ids for ids in all_prompt_ids for _ in range(num_generations)]
+            # vllm_client.generate/chat(n=num_generations) returns num_generations completions per prompt.
+            # Duplicate prompt_ids to align with per-completion entries.
+            all_prompt_ids = [ids for ids in all_prompt_ids for _ in range(num_generations)]
 
             process_slice = slice(
                 accelerator.process_index * len(prompts),
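vllm_client.generate and vllm_client.chat return num_generations completions
per unique prompt, so the hunk above now repeats each prompt's ids
unconditionally to line up one-to-one with completions. The expansion in
isolation (token ids illustrative):

    # Illustrative expansion; token ids are made up.
    all_prompt_ids = [[101, 102], [103]]  # one entry per unique prompt
    num_generations = 3                   # completions returned per prompt
    expanded = [ids for ids in all_prompt_ids for _ in range(num_generations)]
    assert expanded == [[101, 102]] * 3 + [[103]] * 3  # one prompt entry per completion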
@@ -610,97 +586,83 @@ def generate(self, prompts: list, num_generations: int, profiler: ProfilingConte
 
         # Generate completions using colocated vLLM instances: each device holds vLLM copy and work on their own batch of prompts
         elif self.mode == "colocate":
-            if rollout_func is not None:
-                rollout_prompts = prompts
-                if rollout_prompts and is_conversational({"prompt": rollout_prompts[0]}):
-                    rollout_prompts = [
-                        apply_chat_template({"prompt": prompt}, processing_class, **chat_template_kwargs)["prompt"]
-                        for prompt in rollout_prompts
-                    ]
-                output = rollout_func(rollout_prompts)
-                required_keys = {"prompt_ids", "completion_ids", "logprobs"}
-                extra_fields = {k: v for k, v in output.items() if k not in required_keys}
-                prompt_ids = output["prompt_ids"]
-                completion_ids = output["completion_ids"]
-                logprobs = output["logprobs"]
-            else:
-                if Version(vllm.__version__) <= Version("0.10.2"):
-                    structured_outputs_key = "guided_decoding"
-                    if self.structured_outputs_regex:
-                        structured_outputs = GuidedDecodingParams(regex=self.structured_outputs_regex)
-                    else:
-                        structured_outputs = None
-                else:
-                    structured_outputs_key = "structured_outputs"
-                    if self.structured_outputs_regex:
-                        structured_outputs = StructuredOutputsParams(regex=self.structured_outputs_regex)
-                    else:
-                        structured_outputs = None
-
-                generation_kwargs = {
-                    "n": 1,  # vLLM on each GPU generates only 1 in colocate mode
-                    "repetition_penalty": repetition_penalty,
-                    "temperature": temperature,
-                    "top_p": top_p,
-                    "top_k": top_k,
-                    "min_p": 0.0 if min_p is None else min_p,
-                    "max_tokens": max_completion_length,
-                    "logprobs": 0,  # enable returning log probabilities; 0 means for the sampled tokens only
-                }
-                generation_kwargs[structured_outputs_key] = structured_outputs
-                generation_kwargs.update(self.generation_kwargs)
-                sampling_params = SamplingParams(**generation_kwargs)
-
-                if self.tensor_parallel_size > 1:
-                    # Gather prompts from all ranks in the TP group and flatten.
-                    # Each rank starts with its own prompts; after gathering, all ranks see the full group set.
-                    orig_size = len(prompts)
-                    gathered_prompts = [None for _ in range(self.tensor_parallel_size)]
-                    torch.distributed.all_gather_object(gathered_prompts, prompts, group=self.tp_group)
-                    all_prompts = [p for sublist in gathered_prompts for p in sublist]
-                else:
-                    all_prompts = prompts
-
-                if self.enable_sleep_mode:
-                    self.llm.wake_up(tags=["kv_cache"])
-
-                with profiler:  # TODO: profiling_context(trainer, "vLLM.generate"):
-                    if is_conversational({"prompt": prompts[0]}):
-                        all_outputs = self.llm.chat(
-                            all_prompts,
-                            sampling_params=sampling_params,
-                            use_tqdm=False,
-                            chat_template_kwargs=chat_template_kwargs,
-                            tools=tools,
-                            chat_template=chat_template,
-                        )
-                    else:
-                        all_outputs = self.llm.generate(all_prompts, sampling_params=sampling_params, use_tqdm=False)
-
-                all_prompt_ids = [output.prompt_token_ids for output in all_outputs]
-                all_completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs]
-                all_logprobs = [
-                    [sanitize_logprob(next(iter(lp.values()))) for lp in output.logprobs]
-                    for outputs in all_outputs
-                    for output in outputs.outputs
-                ]
-
-                if self.tensor_parallel_size > 1:
-                    # Slice completions for this rank within its TP group.
-                    # Each rank generates all outputs — we keep only our share.
-                    local_rank_in_group = torch.distributed.get_rank(group=self.tp_group)
-                    tp_slice = slice(local_rank_in_group * orig_size, (local_rank_in_group + 1) * orig_size)
-                    prompt_ids = all_prompt_ids[tp_slice]
-                    completion_ids = all_completion_ids[tp_slice]
-                    logprobs = all_logprobs[tp_slice]
-                else:
-                    prompt_ids = all_prompt_ids
-                    completion_ids = all_completion_ids
-                    logprobs = all_logprobs
-
-                extra_fields = {}  # No extra fields for colocate mode
-
-                if self.enable_sleep_mode:
-                    self.llm.sleep(level=2)
+            if Version(vllm.__version__) <= Version("0.10.2"):
+                structured_outputs_key = "guided_decoding"
+                if self.structured_outputs_regex:
+                    structured_outputs = GuidedDecodingParams(regex=self.structured_outputs_regex)
+                else:
+                    structured_outputs = None
+            else:
+                structured_outputs_key = "structured_outputs"
+                if self.structured_outputs_regex:
+                    structured_outputs = StructuredOutputsParams(regex=self.structured_outputs_regex)
+                else:
+                    structured_outputs = None
+
+            generation_kwargs = {
+                "n": 1,  # vLLM on each GPU generates only 1 in colocate mode
+                "repetition_penalty": repetition_penalty,
+                "temperature": temperature,
+                "top_p": top_p,
+                "top_k": top_k,
+                "min_p": 0.0 if min_p is None else min_p,
+                "max_tokens": max_completion_length,
+                "logprobs": 0,  # enable returning log probabilities; 0 means for the sampled tokens only
+            }
+            generation_kwargs[structured_outputs_key] = structured_outputs
+            generation_kwargs.update(self.generation_kwargs)
+            sampling_params = SamplingParams(**generation_kwargs)
+
+            if self.tensor_parallel_size > 1:
+                # Gather prompts from all ranks in the TP group and flatten.
+                # Each rank starts with its own prompts; after gathering, all ranks see the full group set.
+                orig_size = len(prompts)
+                gathered_prompts = [None for _ in range(self.tensor_parallel_size)]
+                torch.distributed.all_gather_object(gathered_prompts, prompts, group=self.tp_group)
+                all_prompts = [p for sublist in gathered_prompts for p in sublist]
+            else:
+                all_prompts = prompts
+
+            if self.enable_sleep_mode:
+                self.llm.wake_up(tags=["kv_cache"])
+
+            with profiler:  # TODO: profiling_context(trainer, "vLLM.generate"):
+                if is_conversational({"prompt": prompts[0]}):
+                    all_outputs = self.llm.chat(
+                        all_prompts,
+                        sampling_params=sampling_params,
+                        use_tqdm=False,
+                        chat_template_kwargs=chat_template_kwargs,
+                        tools=tools,
+                        chat_template=chat_template,
+                    )
+                else:
+                    all_outputs = self.llm.generate(all_prompts, sampling_params=sampling_params, use_tqdm=False)
+
+            all_prompt_ids = [output.prompt_token_ids for output in all_outputs]
+            all_completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs]
+            all_logprobs = [
+                [sanitize_logprob(next(iter(lp.values()))) for lp in output.logprobs]
+                for outputs in all_outputs
+                for output in outputs.outputs
+            ]
+
+            if self.tensor_parallel_size > 1:
+                # Slice completions for this rank within its TP group.
+                # Each rank generates all outputs — we keep only our share.
+                local_rank_in_group = torch.distributed.get_rank(group=self.tp_group)
+                tp_slice = slice(local_rank_in_group * orig_size, (local_rank_in_group + 1) * orig_size)
+                prompt_ids = all_prompt_ids[tp_slice]
+                completion_ids = all_completion_ids[tp_slice]
+                logprobs = all_logprobs[tp_slice]
+            else:
+                prompt_ids = all_prompt_ids
+                completion_ids = all_completion_ids
+                logprobs = all_logprobs
+
+            extra_fields = {}  # No extra fields for colocate mode
+
+            if self.enable_sleep_mode:
+                self.llm.sleep(level=2)
 
         return prompt_ids, completion_ids, logprobs, extra_fields
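The retained tensor-parallel path has every rank generate the whole group's
outputs and then keep only its own contiguous slice. The slice arithmetic in
isolation (sizes and ids illustrative):

    # Illustrative slice arithmetic; sizes and ids are made up.
    tensor_parallel_size = 2
    orig_size = 4  # prompts this rank owned before the all-gather
    all_completion_ids = [[i] for i in range(tensor_parallel_size * orig_size)]
    for local_rank_in_group in range(tensor_parallel_size):
        tp_slice = slice(local_rank_in_group * orig_size, (local_rank_in_group + 1) * orig_size)
        print(local_rank_in_group, all_completion_ids[tp_slice])
    # rank 0 keeps [[0], [1], [2], [3]]; rank 1 keeps [[4], [5], [6], [7]]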