Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
f10285e
support prompts or token IDs in VLLMClient and update API request han…
qgallouedec Mar 5, 2026
7d2bb67
test
qgallouedec Mar 5, 2026
3b356ac
consistency
qgallouedec Mar 5, 2026
82c4508
fix
qgallouedec Mar 5, 2026
3ea2fcf
another fix
qgallouedec Mar 5, 2026
445f4ba
fix docstring
qgallouedec Mar 5, 2026
8c6c88d
Add support for multi-modal inputs in VLLMClient and vllm_serve
qgallouedec Mar 5, 2026
f617b2d
Merge branch 'main' into vllm-accept-token-ids
qgallouedec Mar 6, 2026
eaffd67
Merge branch 'main' into vllm-accept-token-ids
qgallouedec Mar 6, 2026
f3f6a5d
Move `rollout_func from `_generate_single_turn` to `_generate`
qgallouedec Mar 6, 2026
d417543
fix style
qgallouedec Mar 6, 2026
4b927d6
support multi-image
qgallouedec Mar 6, 2026
029fc1f
style
qgallouedec Mar 6, 2026
20b4039
Merge branch 'vllm-accept-token-ids' into vllm-support-image-with-raw…
qgallouedec Mar 6, 2026
b8e3912
Merge branch 'vllm-support-image-with-raw-token' into move-rollout-func
qgallouedec Mar 6, 2026
07181cb
Fix handling of images in OnlineDPOTrainer to ensure proper structure…
qgallouedec Mar 7, 2026
6ff1e56
Merge branch 'main' into vllm-accept-token-ids
qgallouedec Mar 7, 2026
9f340e4
Merge branch 'vllm-accept-token-ids' into vllm-support-image-with-raw…
qgallouedec Mar 7, 2026
d138be7
Merge branch 'vllm-support-image-with-raw-token' into move-rollout-func
qgallouedec Mar 7, 2026
f033e63
revert doc modif
qgallouedec Mar 9, 2026
5a1f609
Merge branch 'vllm-accept-token-ids' into vllm-support-image-with-raw…
qgallouedec Mar 9, 2026
1eb3540
Merge branch 'vllm-support-image-with-raw-token' into move-rollout-func
qgallouedec Mar 9, 2026
d3f7971
Merge branch 'main' into vllm-support-image-with-raw-token
qgallouedec Mar 9, 2026
319d52a
simplify multimodal
qgallouedec Mar 9, 2026
d5e1906
Merge branch 'main' into vllm-support-image-with-raw-token
qgallouedec Mar 9, 2026
4ccadcf
Merge branch 'vllm-support-image-with-raw-token' into move-rollout-func
qgallouedec Mar 9, 2026
0558dc9
Merge branch 'main' into move-rollout-func
qgallouedec Mar 9, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 38 additions & 12 deletions tests/test_grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,17 +162,44 @@ def test_compute_entropy_all_masked(self):
class TestGRPORolloutDispatch:
def _make_trainer(self):
trainer = object.__new__(GRPOTrainer)
trainer.accelerator = SimpleNamespace(device=torch.device("cpu"), is_main_process=True)
trainer.accelerator = SimpleNamespace(
device=torch.device("cpu"),
is_main_process=True,
gather=lambda t: t,
)
trainer.args = SimpleNamespace(report_to=[])
trainer.model = SimpleNamespace(training=True)
trainer.state = SimpleNamespace(global_step=2)
trainer.state = SimpleNamespace(global_step=2, num_input_tokens_seen=0)
trainer._last_loaded_step = 1
trainer.use_vllm = False
trainer.use_transformers_paged = False
trainer.vllm_generation = SimpleNamespace(sync_weights=MagicMock())
trainer.processing_class = SimpleNamespace(
batch_decode=MagicMock(return_value=["decoded"]),
)
trainer.tools = None
trainer.eos_token_id = 2
trainer.pad_token_id = 0
trainer._metrics = {
"train": {
"num_tokens": [],
**{
k: []
for k in [
"completions/mean_length",
"completions/min_length",
"completions/max_length",
"completions/clipped_ratio",
"completions/mean_terminated_length",
"completions/min_terminated_length",
"completions/max_terminated_length",
]
},
}
}
return trainer

def test_generate_single_turn_prefers_rollout_func(self):
def test_generate_prefers_rollout_func(self):
trainer = self._make_trainer()
trainer.rollout_func = MagicMock(
return_value={
Expand All @@ -183,33 +210,32 @@ def test_generate_single_turn_prefers_rollout_func(self):
}
)

prompt_ids, completion_ids, logprobs, extra_fields = trainer._generate_single_turn(["prompt"])
result = trainer._generate(["prompt"])

assert prompt_ids == [[1]]
assert completion_ids == [[2]]
assert logprobs == [[-0.1]]
assert extra_fields == {"env_mask": [[1]]}
assert result[0] == [[1]] # prompt_ids
assert result[1] == [[2]] # completion_ids
assert result[2] == [[1]] # tool_mask (from env_mask)
trainer.rollout_func.assert_called_once_with(["prompt"], trainer)

def test_generate_single_turn_rollout_func_syncs_vllm_weights_when_needed(self):
def test_generate_rollout_func_syncs_vllm_weights_when_needed(self):
trainer = self._make_trainer()
trainer.use_vllm = True
trainer.rollout_func = MagicMock(
return_value={"prompt_ids": [[1]], "completion_ids": [[2]], "logprobs": [[0.0]]}
)

trainer._generate_single_turn(["prompt"])
trainer._generate(["prompt"])

trainer.vllm_generation.sync_weights.assert_called_once()
assert trainer._last_loaded_step == trainer.state.global_step
trainer.rollout_func.assert_called_once_with(["prompt"], trainer)

def test_generate_single_turn_rollout_func_raises_when_required_keys_are_missing(self):
def test_generate_rollout_func_raises_when_required_keys_are_missing(self):
trainer = self._make_trainer()
trainer.rollout_func = MagicMock(return_value={"prompt_ids": [[1]], "completion_ids": [[2]]})

with pytest.raises(ValueError, match="rollout_func must return keys"):
trainer._generate_single_turn(["prompt"])
trainer._generate(["prompt"])


class TestGRPOTrainer(TrlTestCase):
Expand Down
40 changes: 20 additions & 20 deletions trl/trainer/grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1216,25 +1216,6 @@ def _generate_single_turn(self, prompts: list):
device = self.accelerator.device
mode = "train" if self.model.training else "eval"

if self.rollout_func is not None:
# Keep vLLM weights in sync for custom rollouts that rely on vLLM utilities.
if self.use_vllm and self.state.global_step != self._last_loaded_step:
with profiling_context(self, "sync_weights"):
self.vllm_generation.sync_weights()
self._last_loaded_step = self.state.global_step

# Pass prompts to rollout_func preserving structured messages.
# Chat templating must happen inside rollout_func, at the backend boundary, so that
# multimodal content (images, typed content blocks) is not lost before rollout logic runs.
output = self.rollout_func(prompts, self)
required_keys = {"prompt_ids", "completion_ids", "logprobs"}
missing_keys = required_keys - output.keys()
if missing_keys:
missing_keys_list = sorted(missing_keys)
raise ValueError(f"rollout_func must return keys {missing_keys_list} in its output dict.")
extra_fields = {k: v for k, v in output.items() if k not in required_keys}
return output["prompt_ids"], output["completion_ids"], output["logprobs"], extra_fields

# Generate completions using either vLLM or regular generation
if self.use_vllm:
# Sync weights if training step changed
Expand Down Expand Up @@ -1521,7 +1502,26 @@ def _generate(self, prompts: list):
# Copy the prompts to avoid modifying the original list
prompts = copy.deepcopy(prompts)

prompt_ids, completion_ids, logprobs, extra_fields = self._generate_single_turn(prompts)
if self.rollout_func is not None:
# Keep vLLM weights in sync for custom rollouts that rely on vLLM utilities.
if self.use_vllm and self.state.global_step != self._last_loaded_step:
with profiling_context(self, "sync_weights"):
self.vllm_generation.sync_weights()
self._last_loaded_step = self.state.global_step

# Pass prompts to rollout_func preserving structured messages.
# Chat templating must happen inside rollout_func, at the backend boundary, so that
# multimodal content (images, typed content blocks) is not lost before rollout logic runs.
output = self.rollout_func(prompts, self)
required_keys = {"prompt_ids", "completion_ids", "logprobs"}
missing_keys = required_keys - output.keys()
if missing_keys:
missing_keys_list = sorted(missing_keys)
raise ValueError(f"rollout_func must return keys {missing_keys_list} in its output dict.")
extra_fields = {k: v for k, v in output.items() if k not in required_keys}
prompt_ids, completion_ids, logprobs = output["prompt_ids"], output["completion_ids"], output["logprobs"]
Comment thread
qgallouedec marked this conversation as resolved.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Reject rollout_func with tools on non-vLLM backends

After moving rollout_func handling into _generate, the first turn can return non-None logprobs from rollout_func, but post-tool turns in _tool_call_loop now always use _generate_single_turn, which returns post_tool_logprobs=None for regular and paged Transformers generation. In that case, _tool_call_loop still takes the if logprobs is not None branch and later indexes post_tool_logprobs[idx], causing a runtime crash when a tool call is present. This affects runs that set both rollout_func and tools without vLLM, so the combination should be blocked or post_tool_logprobs should be normalized before use.

Useful? React with 👍 / 👎.

else:
prompt_ids, completion_ids, logprobs, extra_fields = self._generate_single_turn(prompts)

# Decode completions. It's important to use `parse_response` when possible, because it handles tool calls.
if is_conversational({"prompt": prompts[0]}):
Expand Down
Loading