Skip to content
Merged
Show file tree
Hide file tree
Changes from 26 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
f10285e
support prompts or token IDs in VLLMClient and update API request han…
qgallouedec Mar 5, 2026
7d2bb67
test
qgallouedec Mar 5, 2026
3b356ac
consistency
qgallouedec Mar 5, 2026
82c4508
fix
qgallouedec Mar 5, 2026
3ea2fcf
another fix
qgallouedec Mar 5, 2026
445f4ba
fix docstring
qgallouedec Mar 5, 2026
8c6c88d
Add support for multi-modal inputs in VLLMClient and vllm_serve
qgallouedec Mar 5, 2026
f617b2d
Merge branch 'main' into vllm-accept-token-ids
qgallouedec Mar 6, 2026
eaffd67
Merge branch 'main' into vllm-accept-token-ids
qgallouedec Mar 6, 2026
f3f6a5d
Move `rollout_func` from `_generate_single_turn` to `_generate`
qgallouedec Mar 6, 2026
d417543
fix style
qgallouedec Mar 6, 2026
4b927d6
support multi-image
qgallouedec Mar 6, 2026
029fc1f
style
qgallouedec Mar 6, 2026
20b4039
Merge branch 'vllm-accept-token-ids' into vllm-support-image-with-raw…
qgallouedec Mar 6, 2026
b8e3912
Merge branch 'vllm-support-image-with-raw-token' into move-rollout-func
qgallouedec Mar 6, 2026
07181cb
Fix handling of images in OnlineDPOTrainer to ensure proper structure…
qgallouedec Mar 7, 2026
6ff1e56
Merge branch 'main' into vllm-accept-token-ids
qgallouedec Mar 7, 2026
9f340e4
Merge branch 'vllm-accept-token-ids' into vllm-support-image-with-raw…
qgallouedec Mar 7, 2026
d138be7
Merge branch 'vllm-support-image-with-raw-token' into move-rollout-func
qgallouedec Mar 7, 2026
f033e63
revert doc modif
qgallouedec Mar 9, 2026
5a1f609
Merge branch 'vllm-accept-token-ids' into vllm-support-image-with-raw…
qgallouedec Mar 9, 2026
1eb3540
Merge branch 'vllm-support-image-with-raw-token' into move-rollout-func
qgallouedec Mar 9, 2026
d3f7971
Merge branch 'main' into vllm-support-image-with-raw-token
qgallouedec Mar 9, 2026
319d52a
simplify multimodal
qgallouedec Mar 9, 2026
d5e1906
Merge branch 'main' into vllm-support-image-with-raw-token
qgallouedec Mar 9, 2026
4ccadcf
Merge branch 'vllm-support-image-with-raw-token' into move-rollout-func
qgallouedec Mar 9, 2026
0558dc9
Merge branch 'main' into move-rollout-func
qgallouedec Mar 9, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 38 additions & 12 deletions tests/test_grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,17 +162,44 @@ def test_compute_entropy_all_masked(self):
class TestGRPORolloutDispatch:
def _make_trainer(self):
trainer = object.__new__(GRPOTrainer)
trainer.accelerator = SimpleNamespace(device=torch.device("cpu"), is_main_process=True)
trainer.accelerator = SimpleNamespace(
device=torch.device("cpu"),
is_main_process=True,
gather=lambda t: t,
)
trainer.args = SimpleNamespace(report_to=[])
trainer.model = SimpleNamespace(training=True)
trainer.state = SimpleNamespace(global_step=2)
trainer.state = SimpleNamespace(global_step=2, num_input_tokens_seen=0)
trainer._last_loaded_step = 1
trainer.use_vllm = False
trainer.use_transformers_paged = False
trainer.vllm_generation = SimpleNamespace(sync_weights=MagicMock())
trainer.processing_class = SimpleNamespace(
batch_decode=MagicMock(return_value=["decoded"]),
)
trainer.tools = None
trainer.eos_token_id = 2
trainer.pad_token_id = 0
trainer._metrics = {
"train": {
"num_tokens": [],
**{
k: []
for k in [
"completions/mean_length",
"completions/min_length",
"completions/max_length",
"completions/clipped_ratio",
"completions/mean_terminated_length",
"completions/min_terminated_length",
"completions/max_terminated_length",
]
},
}
}
return trainer

def test_generate_single_turn_prefers_rollout_func(self):
def test_generate_prefers_rollout_func(self):
trainer = self._make_trainer()
trainer.rollout_func = MagicMock(
return_value={
Expand All @@ -183,33 +210,32 @@ def test_generate_single_turn_prefers_rollout_func(self):
}
)

prompt_ids, completion_ids, logprobs, extra_fields = trainer._generate_single_turn(["prompt"])
result = trainer._generate(["prompt"])

assert prompt_ids == [[1]]
assert completion_ids == [[2]]
assert logprobs == [[-0.1]]
assert extra_fields == {"env_mask": [[1]]}
assert result[0] == [[1]] # prompt_ids
assert result[1] == [[2]] # completion_ids
assert result[2] == [[1]] # tool_mask (from env_mask)
trainer.rollout_func.assert_called_once_with(["prompt"], trainer)

def test_generate_single_turn_rollout_func_syncs_vllm_weights_when_needed(self):
def test_generate_rollout_func_syncs_vllm_weights_when_needed(self):
trainer = self._make_trainer()
trainer.use_vllm = True
trainer.rollout_func = MagicMock(
return_value={"prompt_ids": [[1]], "completion_ids": [[2]], "logprobs": [[0.0]]}
)

trainer._generate_single_turn(["prompt"])
trainer._generate(["prompt"])

trainer.vllm_generation.sync_weights.assert_called_once()
assert trainer._last_loaded_step == trainer.state.global_step
trainer.rollout_func.assert_called_once_with(["prompt"], trainer)

def test_generate_single_turn_rollout_func_raises_when_required_keys_are_missing(self):
def test_generate_rollout_func_raises_when_required_keys_are_missing(self):
trainer = self._make_trainer()
trainer.rollout_func = MagicMock(return_value={"prompt_ids": [[1]], "completion_ids": [[2]]})

with pytest.raises(ValueError, match="rollout_func must return keys"):
trainer._generate_single_turn(["prompt"])
trainer._generate(["prompt"])


class TestGRPOTrainer(TrlTestCase):
Expand Down
98 changes: 97 additions & 1 deletion tests/test_vllm_client_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

import pytest
from packaging.version import Version
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import AutoModelForCausalLM, AutoProcessor, AutoTokenizer
from transformers.testing_utils import torch_device

from trl.generation.vllm_client import VLLMClient
Expand All @@ -31,6 +31,7 @@
kill_process,
require_3_accelerators,
require_torch_multi_accelerator,
require_vision,
require_vllm,
)

Expand Down Expand Up @@ -874,3 +875,98 @@ def teardown_class(cls):
# vLLM x pytest (or Popen) seems not to handle process termination well. To avoid zombie processes, we need to
# kill the server process and its children explicitly.
kill_process(cls.server_process)


@pytest.mark.slow
@require_vllm
@require_vision
class TestVLLMClientServerVLM(TrlTestCase):
    """Client/server generation tests that send pre-tokenized prompts together with per-prompt image lists."""

    model_id = "Qwen/Qwen2.5-VL-3B-Instruct"

    @classmethod
    def setup_class(cls):
        # Launch `trl vllm-serve` for the class-level model in a child process
        server_command = ["trl", "vllm-serve", "--model", cls.model_id]
        cls.server_process = subprocess.Popen(server_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

        # These tests only generate, so the client is created without setting up a communicator
        cls.client = VLLMClient(connection_timeout=240, host="localhost")

    def test_generate_with_token_ids_and_image(self):
        """Test a batch of two token-ID prompts carrying two images and one image respectively."""
        from PIL import Image

        processor = AutoProcessor.from_pretrained(self.model_id)
        red_image = Image.new("RGB", (64, 64), color="red")
        blue_image = Image.new("RGB", (64, 64), color="blue")
        green_image = Image.new("RGB", (64, 64), color="green")

        two_image_turn = {
            "role": "user",
            "content": [
                {"type": "image", "image": red_image},
                {"type": "image", "image": blue_image},
                {"type": "text", "text": "What are the differences between these two images?"},
            ],
        }
        one_image_turn = {
            "role": "user",
            "content": [
                {"type": "image", "image": green_image},
                {"type": "text", "text": "What is the color of this image?"},
            ],
        }
        conversations = [[two_image_turn], [one_image_turn]]

        prompt_token_ids = processor.apply_chat_template(
            conversation=conversations, tokenize=True, add_generation_prompt=True
        )
        # One image list per prompt, in the same order as the conversations
        outputs = self.client.generate(
            prompt_token_ids, images=[[red_image, blue_image], [green_image]], max_tokens=64
        )
        prompt_ids = outputs["prompt_ids"]
        completion_ids = outputs["completion_ids"]

        assert len(prompt_ids) == 2
        assert len(completion_ids) == 2
        for first_sequence in (prompt_ids[0], completion_ids[0]):
            assert all(isinstance(tok, int) for tok in first_sequence)

    def test_generate_with_token_ids_mixed_images(self):
        """Test a batch where one prompt has an image and the other does not."""
        from PIL import Image

        processor = AutoProcessor.from_pretrained(self.model_id)
        image = Image.new("RGB", (64, 64), color="red")

        image_turn = {
            "role": "user",
            "content": [{"type": "image", "image": image}, {"type": "text", "text": "Describe this image."}],
        }
        text_only_turn = {
            "role": "user",
            "content": [{"type": "text", "text": "What is 1+1?"}],
        }
        conversations = [[image_turn], [text_only_turn]]

        prompt_token_ids = processor.apply_chat_template(
            conversation=conversations, tokenize=True, add_generation_prompt=True
        )
        # `None` marks the prompt that has no images attached
        outputs = self.client.generate(prompt_token_ids, images=[[image], None], max_tokens=64)
        prompt_ids = outputs["prompt_ids"]
        completion_ids = outputs["completion_ids"]

        assert len(prompt_ids) == 2
        assert len(completion_ids) == 2
        for sequence in (prompt_ids[0], prompt_ids[1], completion_ids[0], completion_ids[1]):
            assert all(isinstance(tok, int) for tok in sequence)

    @classmethod
    def teardown_class(cls):
        # Explicitly kill the server process started in `setup_class`
        kill_process(cls.server_process)
4 changes: 3 additions & 1 deletion trl/experimental/online_dpo/online_dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -750,7 +750,9 @@ def _generate_vllm_server(self, prompts, images=None):
# prompt individually.
ordered_set_of_prompts = all_prompts[:: self.num_generations]
if has_images:
ordered_set_of_images = all_images[:: self.num_generations]
ordered_set_of_images = [
[img] if img is not None else None for img in all_images[:: self.num_generations]
]
else:
ordered_set_of_images = None
completion_ids = self.vllm_client.generate(
Expand Down
13 changes: 9 additions & 4 deletions trl/generation/vllm_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,8 +220,9 @@ def generate(
Args:
prompts (`list[str]` or `list[list[int]]`):
List of text prompts or list of token ID lists for which the model will generate completions.
images (`list[PIL.Image]`, *optional*):
List of PIL Images to send along with the prompts. Only valid when `prompts` is a list of strings.
images (`list[list[PIL.Image] | None]`, *optional*):
List of image lists for VLM support. Each element is a list of PIL images for the corresponding prompt,
or `None` if no images for that prompt.
n (`int`, *optional*, defaults to `1`):
Number of completions to generate for each prompt.
repetition_penalty (`float`, *optional*, defaults to `1.0`):
Expand Down Expand Up @@ -260,8 +261,12 @@ def generate(
"""
url = f"{self.base_url}/generate/"

# Convert PIL images to base64 strings
images = [pil_to_base64(img) for img in images] if images else None
# Convert PIL images to base64 strings. Each element is a list of images for the corresponding prompt,
# or None if no images for that prompt.
if images:
images = [
[pil_to_base64(img) for img in img_list] if img_list is not None else None for img_list in images
]

response = self.session.post(
url,
Expand Down
28 changes: 12 additions & 16 deletions trl/scripts/vllm_serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,7 +491,7 @@ async def get_world_size():

class GenerateRequest(BaseModel):
prompts: list[str] | list[list[int]]
images: list[str] | None = None
images: list[list[str] | None] | None = None
n: int = 1
repetition_penalty: float = 1.0
temperature: float = 1.0
Expand All @@ -518,8 +518,8 @@ async def generate(request: GenerateRequest):
request (`GenerateRequest`):
- `prompts` (list of `str` or list of list of `int`): A list of prompts. It accepts either text strings
or pre-tokenized token ID lists. `images` can optionally be included with either form of prompt.
- `images` (list of `str`, *optional*, default to `None`): A list of base64 encoded images to process
along with prompts.
- `images` (list of list of `str` or `None`, *optional*): A list of image lists. Each element is a list
of base64-encoded images for the corresponding prompt, or `None` if no images for that prompt.
- `n` (`int`, *optional*, defaults to `1`): Number of completions to generate for each prompt.
- `repetition_penalty` (`float`, *optional*, defaults to `1.0`): Repetition penalty to apply during
generation.
Expand Down Expand Up @@ -571,19 +571,15 @@ async def generate(request: GenerateRequest):
```
"""
# Build vLLM-compatible prompt inputs
if request.prompts and isinstance(request.prompts[0], list):
# Token IDs path: wrap each list of token IDs as a TokensPrompt dict for vLLM
prompts = [{"prompt_token_ids": ids} for ids in request.prompts]
else:
# Text prompts path: build prompt dicts with optional images
request.images = request.images or [None] * len(request.prompts)

prompts = []
for prompt, image in zip(request.prompts, request.images, strict=True):
row = {"prompt": prompt}
if image is not None:
row["multi_modal_data"] = {"image": Image.open(BytesIO(base64.b64decode(image)))}
prompts.append(row)
is_token_ids = request.prompts and isinstance(request.prompts[0], list)
request.images = request.images or [None] * len(request.prompts)

prompts = []
for prompt, image_list in zip(request.prompts, request.images, strict=True):
row = {"prompt_token_ids": prompt} if is_token_ids else {"prompt": prompt}
if image_list is not None:
row["multi_modal_data"] = {"image": [Image.open(BytesIO(base64.b64decode(img))) for img in image_list]}
prompts.append(row)

generation_kwargs = {
"n": request.n,
Expand Down
40 changes: 20 additions & 20 deletions trl/trainer/grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1216,25 +1216,6 @@ def _generate_single_turn(self, prompts: list):
device = self.accelerator.device
mode = "train" if self.model.training else "eval"

if self.rollout_func is not None:
# Keep vLLM weights in sync for custom rollouts that rely on vLLM utilities.
if self.use_vllm and self.state.global_step != self._last_loaded_step:
with profiling_context(self, "sync_weights"):
self.vllm_generation.sync_weights()
self._last_loaded_step = self.state.global_step

# Pass prompts to rollout_func preserving structured messages.
# Chat templating must happen inside rollout_func, at the backend boundary, so that
# multimodal content (images, typed content blocks) is not lost before rollout logic runs.
output = self.rollout_func(prompts, self)
required_keys = {"prompt_ids", "completion_ids", "logprobs"}
missing_keys = required_keys - output.keys()
if missing_keys:
missing_keys_list = sorted(missing_keys)
raise ValueError(f"rollout_func must return keys {missing_keys_list} in its output dict.")
extra_fields = {k: v for k, v in output.items() if k not in required_keys}
return output["prompt_ids"], output["completion_ids"], output["logprobs"], extra_fields

# Generate completions using either vLLM or regular generation
if self.use_vllm:
# Sync weights if training step changed
Expand Down Expand Up @@ -1521,7 +1502,26 @@ def _generate(self, prompts: list):
# Copy the prompts to avoid modifying the original list
prompts = copy.deepcopy(prompts)

prompt_ids, completion_ids, logprobs, extra_fields = self._generate_single_turn(prompts)
if self.rollout_func is not None:
# Keep vLLM weights in sync for custom rollouts that rely on vLLM utilities.
if self.use_vllm and self.state.global_step != self._last_loaded_step:
with profiling_context(self, "sync_weights"):
self.vllm_generation.sync_weights()
self._last_loaded_step = self.state.global_step

# Pass prompts to rollout_func preserving structured messages.
# Chat templating must happen inside rollout_func, at the backend boundary, so that
# multimodal content (images, typed content blocks) is not lost before rollout logic runs.
output = self.rollout_func(prompts, self)
required_keys = {"prompt_ids", "completion_ids", "logprobs"}
missing_keys = required_keys - output.keys()
if missing_keys:
missing_keys_list = sorted(missing_keys)
raise ValueError(f"rollout_func must return keys {missing_keys_list} in its output dict.")
extra_fields = {k: v for k, v in output.items() if k not in required_keys}
prompt_ids, completion_ids, logprobs = output["prompt_ids"], output["completion_ids"], output["logprobs"]
Comment thread
qgallouedec marked this conversation as resolved.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Reject rollout_func with tools on non-vLLM backends

After moving rollout_func handling into _generate, the first turn can return non-None logprobs from rollout_func, but post-tool turns in _tool_call_loop now always use _generate_single_turn, which returns post_tool_logprobs=None for regular and paged Transformers generation. In that case, _tool_call_loop still takes the if logprobs is not None branch and later indexes post_tool_logprobs[idx], causing a runtime crash when a tool call is present. This affects runs that set both rollout_func and tools without vLLM, so the combination should be blocked or post_tool_logprobs should be normalized before use.

Useful? React with 👍 / 👎.

else:
prompt_ids, completion_ids, logprobs, extra_fields = self._generate_single_turn(prompts)

# Decode completions. It's important to use `parse_response` when possible, because it handles tool calls.
if is_conversational({"prompt": prompts[0]}):
Expand Down
Loading