4 changes: 4 additions & 0 deletions .github/workflows/sgl.yml
Original file line number Diff line number Diff line change
Expand Up @@ -132,3 +132,7 @@ jobs:
cd tests/workers/rollout
pytest -s test_sglang_async_rollout_mcp_tools.py
# Note(haibin.lin): for any new test, please update gpu_unit_tests.yaml to avoid repeated tests
- name: Test the latest SGLang Rollout async with multimodal delta
run: |
cd tests/workers/rollout
pytest -s test_sglang_async_rollout_multimodal_delta.py
34 changes: 32 additions & 2 deletions docs/sglang_multiturn/multiturn.rst
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,26 @@ If you want rollout with simulated interaction, you can set the ``interaction_co
rollout:
interaction_config_file: <path_to_interaction_yaml_file>

If your tool creates multi-modal inputs, return them as a list per modality from your ``tool.execute()`` implementation.

Images and videos should be processed before they are returned. For example, if you are using Qwen2.5-VL, you can use the following code to get the representations:

.. code-block:: python

async def execute(self, ...) -> Tuple[str | Dict[str, Any], float, dict]:
...
from verl.utils.dataset.vision_utils import process_image, process_video

img1 = process_image(img1)
video1 = process_video(video1)

# vLLM expects the singular keys "image" / "video" (not "images" / "videos") for lists of images/videos
# link: https://github.com/vllm-project/vllm/blob/3c545c0c3b98ee642373a308197d750d0e449403/vllm/multimodal/parse.py#L205
return {"image": [img1, ...], "video": [video1, ...], "text": "..."}, 0, {}

Remember to set ``return_multi_modal_inputs: False`` in your dataset config so that the multi-modal inputs are processed correctly during rollout.
Refer to the `Handling Multi-Modal Inputs in Datasets`_ section for more details.
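
For example, a minimal dataset config sketch (other fields omitted):

.. code-block:: yaml

   data:
     return_multi_modal_inputs: False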

Multi-turn Tokenization
~~~~~~~~~~~~~~~~~~~~~~~

Expand Down Expand Up @@ -103,7 +123,7 @@ The tokenization sanity check mode can be configured using the ``actor_rollout_r

- ``ignore_strippable``: Ignores differences in whitespace characters (``\n``, ``\t``, ``\r``, spaces) while still checking for meaningful text mismatches. This is useful when debugging chat template issues where whitespace variations are expected and acceptable.

- ``off``: Completely disables the tokenization sanity check. Only use this if you have thoroughly validated that tokenization discrepancies are expected and won't impact training.
- ``disable``: Completely disables the tokenization sanity check. Only use this if you have thoroughly validated that tokenization discrepancies are expected and won't impact training.

Example configuration:

Expand All @@ -112,7 +132,17 @@ Example configuration:
actor_rollout_ref:
rollout:
multi_turn:
tokenization_sanity_check_mode: "ignore_strippable" # Choose from: "strict", "ignore_strippable", "off"
tokenization_sanity_check_mode: "ignore_strippable" # Choose from: "disable", "ignore_strippable", "strict"

Handling Multi-Modal Inputs in Datasets
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

If your dataset includes multi-modal inputs (such as images or videos), you can control whether these are pre-processed and included in each sample by setting the ``return_multi_modal_inputs`` flag in your dataset config (used by ``RLHFDataset``).

- ``return_multi_modal_inputs: True`` (default): The dataset pre-processes and includes a ``multi_modal_inputs`` dictionary for each sample. This dict contains the model-ready representations (e.g., image tensors, video tensors) as produced by your processor. This is useful for single-turn or SFT-style training, where the model expects all modalities to be present in the batch.

- ``return_multi_modal_inputs: False``: The dataset omits the ``multi_modal_inputs`` field. This is recommended for multi-turn RL or tool-augmented rollouts, where the model may generate new multi-modal inputs dynamically during rollout, and you want to avoid conflicts or redundant data in the batch.

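A minimal sketch of the flag's effect (illustrative names only, not a literal excerpt from ``RLHFDataset``):

.. code-block:: python

    def build_sample(model_inputs: dict, return_multi_modal_inputs: bool) -> dict:
        # The raw multi-modal data is always kept for the rollout engine.
        sample = {"multi_modal_data": {"image": [], "video": []}}
        if return_multi_modal_inputs:
            # Attach the pre-processed tensors; second_per_grid_ts is dropped
            # because it is only needed for mrope, not for training.
            sample["multi_modal_inputs"] = dict(model_inputs)
            sample["multi_modal_inputs"].pop("second_per_grid_ts", None)
        return sample

    single_turn = build_sample({"pixel_values": "...", "second_per_grid_ts": 1.0}, True)
    multi_turn = build_sample({"pixel_values": "..."}, False)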

Special Cases
^^^^^^^^^^^^^
Expand Down
3 changes: 2 additions & 1 deletion examples/sglang_multiturn/config/geo3k_multiturn_grpo.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ data:
max_response_length: 2048
train_batch_size: 256
return_raw_chat: True
return_multi_modal_inputs: False

actor_rollout_ref:
hybrid_engine: True
Expand All @@ -20,5 +21,5 @@ actor_rollout_ref:
name: sglang
multi_turn:
enable: True
max_turns: 5
max_assistant_turns: 5
# tool_config_path: "./config/tool_config/gsm8k_tool_config.yaml"
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ data:
max_response_length: 2048
train_batch_size: 256
return_raw_chat: True
return_multi_modal_inputs: False

actor_rollout_ref:
hybrid_engine: True
Expand All @@ -20,5 +21,5 @@ actor_rollout_ref:
name: sglang
multi_turn:
enable: True
max_turns: 5
max_assistant_turns: 5
# tool_config_path: "./config/tool_config/gsm8k_tool_config.yaml"
187 changes: 187 additions & 0 deletions tests/workers/rollout/test_sglang_async_rollout_multimodal_delta.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
# Copyright 2025 Amazon.com, Inc. or its affiliates
# Copyright 2023-2024 SGLang Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import pytest

from verl.utils.dataset.vision_utils import process_image
from verl.utils.tokenizer import hf_processor
from verl.workers.rollout.schemas import (
AsyncRolloutRequest,
AsyncRolloutRequestStateEnum,
TokenizationSanityCheckModeEnum,
)


def _test_add_tool_response_messages_image_delta(processor, image_list, description_list, resize_image=False):
assert len(image_list) == len(description_list)
# Get the smallest dimensions across all images
processed_images = []
for img_url in image_list:
img = process_image(img_url)
processed_images.append(img)

min_width = min(img.size[0] for img in processed_images)
min_height = min(img.size[1] for img in processed_images)
min_size = (min_width, min_height)

if resize_image:
processed_images_resized = []
for img in processed_images:
img = img.resize(min_size)
processed_images_resized.append(img)
processed_images = processed_images_resized

# Initial message history
system_prompt = (
"You will be provided with an image. Describe this image and then generate a new image for the next round"
)
messages = [
{
"role": "system",
"content": system_prompt,
},
{
"role": "user",
"content": [
{"type": "text", "text": "Here is the first image provided: "},
{"type": "image", "image": [processed_images[0]]},
],
},
]

# Initial multi_modal_data with one image
multi_modal_data = {"image": [processed_images[0]], "video": []}
# Minimal required fields for AsyncRolloutRequest

req = AsyncRolloutRequest(
batch_data_id=0,
request_id="test-req-1",
state=AsyncRolloutRequestStateEnum.PENDING,
messages=messages,
multi_modal_keys=["image", "video"],
multi_modal_data=multi_modal_data.copy(),
tool_schemas=[],
tools_kwargs={},
interaction_kwargs={},
input_ids=[],
prompt_ids=[],
response_ids=[],
attention_mask=[],
prompt_attention_mask=[],
response_attention_mask=[],
position_ids=[],
prompt_position_ids=[],
response_position_ids=[],
loss_mask=[],
prompt_loss_mask=[],
response_loss_mask=[],
reward_scores={},
max_prompt_len=8192,
max_response_len=8192,
max_model_len=16384,
metrics={},
use_inference_chat_template=True,
tokenization_sanity_check_mode=TokenizationSanityCheckModeEnum.STRICT,
generation_prompt_ids=[],
base_conv_wo_gen_prompt_end_pos=0,
base_conv_with_gen_prompt_end_pos=0,
processing_class=processor,
)

prev_generated_len = 0
# Add First Assistant Message and first tool response message(image)
for idx, img in enumerate(processed_images):
if idx == 0:
continue
_ = req.get_generation_prompt_ids(processor)
req.add_assistant_message(processor, content=description_list[idx - 1])
before_tool_call_len = len(req.input_ids)
req.add_tool_response_messages(processor, [{"image": [img], "text": "Here is the new image you requested: "}])
after_tool_call_len = len(req.input_ids)
if prev_generated_len == 0:
prev_generated_len = after_tool_call_len - before_tool_call_len
else:
if resize_image:
assert after_tool_call_len - before_tool_call_len == prev_generated_len
assert req.multi_modal_data["image"] == processed_images[: idx + 1]

_ = req.get_generation_prompt_ids(processor)
req.add_assistant_message(processor, content=description_list[-1])

messages = [msg.model_dump() for msg in req.messages]
tools = [tool.model_dump() for tool in req.tool_schemas] if req.tool_schemas else None
full_prompt_info = req._handle_apply_chat_template(
processor,
messages,
multi_modal_data=req.multi_modal_data,
tools=tools,
add_generation_prompt=False,
tokenize=True,
return_dict=True,
)
full_prompt_ids = full_prompt_info["input_ids"]
assert full_prompt_ids == req.input_ids

    # We must use dict(full_prompt_info) to convert the BatchFeature into a plain dict,
    # because np.array() applied to a BatchFeature keeps only its keys.
full_prompt_multi_modal_inputs = dict(full_prompt_info)
full_prompt_multi_modal_inputs.pop("input_ids", None)
full_prompt_multi_modal_inputs.pop("attention_mask", None)

for key in full_prompt_multi_modal_inputs:
assert full_prompt_multi_modal_inputs[key].eq(req.multi_modal_inputs[key]).all()


@pytest.mark.skipif(
hf_processor("Qwen/Qwen2.5-VL-3B-Instruct") is None, reason="Processor not available for Qwen/Qwen2.5-VL-3B-Instruct"
)
def test_add_tool_response_messages_image_delta():
processor = hf_processor("Qwen/Qwen2.5-VL-3B-Instruct")

# From Qwen2.5-VL-3B-Instruct HF example
img_1_url = {"image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg"}
img_1_description = "A woman sits on the beach at sunset, smiling as she shares a high five with her large dog."
# GitHub Logo
img_2_url = {"image": "https://github.githubassets.com/assets/GitHub-Mark-ea2971cee799.png"}
img_2_description = "A GitHub Logo image"
# Octocat
img_3_url = {"image": "https://octodex.github.com/images/orderedlistocat.png"}
img_3_description = "An Octocat image"

image_list = [img_1_url, img_2_url, img_3_url]
description_list = [img_1_description, img_2_description, img_3_description]
_test_add_tool_response_messages_image_delta(processor, image_list, description_list, resize_image=False)


@pytest.mark.skipif(
hf_processor("Qwen/Qwen2.5-VL-3B-Instruct") is None, reason="Processor not available for Qwen/Qwen2.5-VL-3B-Instruct"
)
def test_add_tool_response_messages_image_delta_resize_image():
processor = hf_processor("Qwen/Qwen2.5-VL-3B-Instruct")

# From Qwen2.5-VL-3B-Instruct HF example
img_1_url = {"image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg"}
img_1_description = "A woman sits on the beach at sunset, smiling as she shares a high five with her large dog."
# GitHub Logo
img_2_url = {"image": "https://github.githubassets.com/assets/GitHub-Mark-ea2971cee799.png"}
img_2_description = "A GitHub Logo image"
# Octocat
img_3_url = {"image": "https://octodex.github.com/images/orderedlistocat.png"}
img_3_description = "An Octocat image"

image_list = [img_1_url, img_2_url, img_3_url]
description_list = [img_1_description, img_2_description, img_3_description]
_test_add_tool_response_messages_image_delta(processor, image_list, description_list, resize_image=True)
6 changes: 3 additions & 3 deletions tests/workers/rollout/test_sglang_multi_interaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def test_interaction_selection_by_name(self):
max_response_len=16,
max_model_len=512,
use_inference_chat_template=True,
tokenization_sanity_check_mode="off",
tokenization_sanity_check_mode="disable",
processing_class=tokenizer,
)

Expand Down Expand Up @@ -252,7 +252,7 @@ def test_fallback_to_default_interaction(self):
"max_assistant_turns": 5,
"max_user_turns": 3,
"use_inference_chat_template": True,
"tokenization_sanity_check_mode": "off",
"tokenization_sanity_check_mode": "disable",
},
"prompt_length": 32,
"response_length": 16,
Expand Down Expand Up @@ -349,7 +349,7 @@ def test_backward_compatibility_no_interaction_config(self):
"max_assistant_turns": 5,
"max_user_turns": 3,
"use_inference_chat_template": True,
"tokenization_sanity_check_mode": "off",
"tokenization_sanity_check_mode": "disable",
},
"prompt_length": 32,
"response_length": 16,
Expand Down
3 changes: 2 additions & 1 deletion verl/trainer/config/ppo_megatron_trainer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ data:
class_path: null
class_name: null
dataloader_num_workers: 8
return_multi_modal_inputs: True

actor_rollout_ref:
hybrid_engine: True
Expand Down Expand Up @@ -249,7 +250,7 @@ actor_rollout_ref:
# Some models are known to produce different tokenization results when tokenizing turn by turn vs. all at once. This behavior has already been validated for them.
# To reduce excessive warnings, you can turn off the sanity check for these models if you are using their default chat template:
# Qwen/QwQ-32B, Qwen/Qwen3-xxB
# - off: disable tokenization sanity check
# - disable: disable tokenization sanity check
# - strict: enable strict tokenization sanity check (default)
# - ignore_strippable: ignore strippable tokens when checking tokenization sanity
tokenization_sanity_check_mode: strict
Expand Down
5 changes: 4 additions & 1 deletion verl/trainer/config/ppo_trainer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,9 @@ data:

# The name of the dataset class within the specified file.
name: null

# Whether to return multi-modal inputs in the dataset. Set to False if rollout generates new multi-modal inputs.
return_multi_modal_inputs: True

# settings related to data sampler
sampler:
Expand Down Expand Up @@ -566,7 +569,7 @@ actor_rollout_ref:
# Some models are known to produce different tokenization results when tokenizing turn by turn vs. all at once. This behavior has already been validated for them.
# To reduce excessive warnings, you can turn off the sanity check for these models if you are using their default chat template:
# Qwen/QwQ-32B, Qwen/Qwen3-xxB
# - off: disable tokenization sanity check
# - disable: disable tokenization sanity check
# - strict: enable strict tokenization sanity check (default)
# - ignore_strippable: ignore strippable tokens when checking tokenization sanity
tokenization_sanity_check_mode: strict
Expand Down
18 changes: 15 additions & 3 deletions verl/utils/dataset/rl_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,8 @@ def __init__(
self.need_tools_kwargs = config.get("need_tools_kwargs", False)
self.filter_prompts = config.get("filter_prompts", True)
self.serialize_dataset = False
self.return_multi_modal_inputs = config.get("return_multi_modal_inputs", True)

self._download()
self._read_files_and_tokenize()

Expand Down Expand Up @@ -223,11 +225,17 @@ def __getitem__(self, item):
images = None
if self.image_key in row_dict and row_dict.get(self.image_key, None) is not None:
images = [process_image(image) for image in row_dict.pop(self.image_key)]

# vLLM expects the key "image" (not "images"), so we use "image" here
# link: https://github.com/vllm-project/vllm/blob/3c545c0c3b98ee642373a308197d750d0e449403/vllm/multimodal/parse.py#L205
multi_modal_data["image"] = images

videos = None
if self.video_key in row_dict and row_dict.get(self.video_key, None) is not None:
videos = [process_video(video) for video in row_dict.pop(self.video_key)]

# vLLM expects the key "video" (not "videos"), so we use "video" here
# link: https://github.com/vllm-project/vllm/blob/3c545c0c3b98ee642373a308197d750d0e449403/vllm/multimodal/parse.py#L205
multi_modal_data["video"] = [video.numpy() for video in videos]

model_inputs = self.processor(text=[raw_prompt], images=images, videos=videos, return_tensors="pt")
Expand All @@ -240,10 +248,14 @@ def __getitem__(self, item):

# There's a trap here, multi_modal_inputs has to be a dict, not BatchFeature
row_dict["multi_modal_data"] = multi_modal_data
row_dict["multi_modal_inputs"] = dict(model_inputs)

# second_per_grid_ts isn't used for training, just for mrope
row_dict["multi_modal_inputs"].pop("second_per_grid_ts", None)
# We will do batch.union() in the trainer,
# so we cannot have "multi_modal_inputs" in row_dict if rollout generates new multi_modal_inputs
if self.return_multi_modal_inputs:
row_dict["multi_modal_inputs"] = dict(model_inputs)

# second_per_grid_ts isn't used for training, just for mrope
row_dict["multi_modal_inputs"].pop("second_per_grid_ts", None)

else:
raw_prompt = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
Expand Down