Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions docs/source/grpo_trainer.md
Original file line number Diff line number Diff line change
Expand Up @@ -716,6 +716,26 @@ trainer.train()

`reset` can return either `None` or a string. In GRPO, when it returns a string, that string is appended to the last user message before generation.

### Multimodal Tool Responses

Tools can return images alongside text by returning a list of content blocks. This is useful for VLM agent training where the tool provides visual feedback (e.g., screenshots, plots, camera captures).

```python
from PIL import Image

def take_screenshot() -> list:
"""
Takes a screenshot of the current screen.

Returns:
The screenshot image with a description.
"""
img = Image.open("screenshot.png")
return [{"type": "image", "image": img}, {"type": "text", "text": "Here is the screenshot."}]
```

The returned images are automatically injected into the conversation and passed to the VLM for subsequent generation turns.

### Supported Models

Tested with:
Expand Down
84 changes: 84 additions & 0 deletions tests/test_grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2446,6 +2446,90 @@ def fake_generate(input_ids, **kwargs):
new_param = trainer.model.get_parameter(n)
assert not torch.equal(param, new_param), f"Parameter {n} has not changed."

@pytest.mark.xfail(
condition=Version(transformers.__version__) < Version("5.2.0"),
reason="Qwen3.5 models were introduced in transformers-5.2.0",
strict=True,
)
@require_jmespath
@require_vision
def test_training_with_tools_multimodal_response(self):
# Test that tools returning images (multimodal responses) work correctly with a VLM.
# The tool returns a list of content blocks including an image.
from PIL import Image as PILImage

def screenshot_tool() -> list:
"""
Takes a screenshot and returns it.

Returns:
A list of content blocks with the screenshot image.
"""
img = PILImage.new("RGB", (64, 64), color="red")
return [{"type": "image", "image": img}, {"type": "text", "text": "Here is the screenshot"}]

dataset = load_dataset("trl-internal-testing/zen", "conversational_prompt_only", split="train")

training_args = GRPOConfig(
output_dir=self.tmp_dir,
learning_rate=0.1, # use higher lr because gradients are tiny and default lr can stall updates
per_device_train_batch_size=3, # reduce the batch size to reduce memory usage
num_generations=3, # reduce the number of generations to reduce memory usage
max_completion_length=512,
report_to="none",
)
trainer = GRPOTrainer(
model="trl-internal-testing/tiny-Qwen3_5ForConditionalGeneration",
reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
args=training_args,
train_dataset=dataset,
tools=[screenshot_tool],
)

previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

def fake_generate(input_ids, **kwargs):
if input_ids.shape[0] == 3: # first call
# fmt: off
completion_ids = torch.tensor(
[
# '<tool_call>\n<function=screenshot_tool>\n</function>\n</tool_call><|im_end|>'
[248058, 198, 27, 1628, 13744, 30091, 22076, 29, 198, 510, 1628, 29, 198, 248059, 248046],
# "I don't know any tool<|im_end|>" + padding
[40, 1459, 914, 1366, 866, 5224, 248046, 248044, 248044, 248044, 248044, 248044, 248044, 248044, 248044],
# '<tool_call>\n<function=screenshot_tool>\n</function>\n</tool_call><|im_end|>'
[248058, 198, 27, 1628, 13744, 30091, 22076, 29, 198, 510, 1628, 29, 198, 248059, 248046],
],
device=input_ids.device,
)
# fmt: on
else: # second call: 2 tool calls succeeded
completion_ids = torch.tensor(
[
# 'Done!<|im_end|>'
[16936, 0, 248046],
# 'Done!<|im_end|>'
[16936, 0, 248046],
],
device=input_ids.device,
)
return torch.cat([input_ids, completion_ids], dim=-1)

with patch.object(trainer.model, "generate", side_effect=fake_generate):
trainer.train()

assert trainer.state.log_history[-1]["train_loss"] is not None
assert trainer.state.log_history[-1]["tools/call_frequency"] == pytest.approx(2 / 3)
assert trainer.state.log_history[-1]["tools/failure_frequency"] == pytest.approx(0.0)

# Check that the params have changed (skip vision parts, see test_training_vlm)
params_to_skip = ("model.visual.",)
for n, param in previous_trainable_params.items():
if n.startswith(params_to_skip):
continue
new_param = trainer.model.get_parameter(n)
assert not torch.equal(param, new_param), f"Parameter {n} has not changed."

@pytest.mark.xfail(
condition=Version(transformers.__version__) < Version("5.2.0"),
reason="Environment factory support is not available in transformers versions below 5.2.0",
Expand Down
Loading