diff --git a/.github/workflows/sgl.yml b/.github/workflows/sgl.yml index f07ec527d65..ca6dca82647 100644 --- a/.github/workflows/sgl.yml +++ b/.github/workflows/sgl.yml @@ -132,3 +132,7 @@ jobs: cd tests/workers/rollout pytest -s test_sglang_async_rollout_mcp_tools.py # Note(haibin.lin): for any new test, please update gpu_unit_tests.yaml to avoid repeated tests + - name: Test the latest SGLang Rollout async with multimodal delta + run: | + cd tests/workers/rollout + pytest -s test_sglang_async_rollout_multimodal_delta.py \ No newline at end of file diff --git a/docs/sglang_multiturn/multiturn.rst b/docs/sglang_multiturn/multiturn.rst index 06c3e21063d..5c876178053 100644 --- a/docs/sglang_multiturn/multiturn.rst +++ b/docs/sglang_multiturn/multiturn.rst @@ -56,6 +56,26 @@ If you want rollout with simulated interaction, you can set the ``interaction_co rollout: interaction_config_file: +If your tool creates multi-modal inputs, you should return a list of multi-modal inputs in your tool.execute() implementation. + +Image and video should be processed before returning. For example, if you are using Qwen2.5-VL, you can use the following code to get the representations: + +.. code-block:: python + + async def execute(self, ...) -> Tuple[str | Dict[str, Any], float, dict]: + ... + from verl.utils.dataset.vision_utils import process_image, process_video + + img1 = process_image(img1) + video1 = process_video(video1) + + # due to the (image | video) key is ("image" | "video") instead of ("images" | "videos") in vllm, we need to use ("image" | "video") to specify list of images/videos + # link: https://github.com/vllm-project/vllm/blob/3c545c0c3b98ee642373a308197d750d0e449403/vllm/multimodal/parse.py#L205 + return {"image": [img1, ...], "video": [video1, ...], "text": "..."}, 0, {} + +Remember to set ``return_multi_modal_inputs: False`` in your dataset config in order to process the multi-modal inputs in the rollout correctly. 
+Refer to the `Handling Multi-Modal Inputs in Datasets`_ section for more details. + Multi-turn Tokenization ~~~~~~~~~~~~~~~~~~~~~~~ @@ -103,7 +123,7 @@ The tokenization sanity check mode can be configured using the ``actor_rollout_r - ``ignore_strippable``: Ignores differences in whitespace characters (``\n``, ``\t``, ``\r``, spaces) while still checking for meaningful text mismatches. This is useful when debugging chat template issues where whitespace variations are expected and acceptable. -- ``off``: Completely disables the tokenization sanity check. Only use this if you have thoroughly validated that tokenization discrepancies are expected and won't impact training. +- ``disable``: Completely disables the tokenization sanity check. Only use this if you have thoroughly validated that tokenization discrepancies are expected and won't impact training. Example configuration: @@ -112,7 +132,17 @@ Example configuration: actor_rollout_ref: rollout: multi_turn: - tokenization_sanity_check_mode: "ignore_strippable" # Choose from: "strict", "ignore_strippable", "off" + tokenization_sanity_check_mode: "ignore_strippable" # Choose from: "disable", "ignore_strippable", "strict" + +Handling Multi-Modal Inputs in Datasets +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +If your dataset includes multi-modal inputs (such as images or videos), you can control whether these are pre-processed and included in each sample by setting the return_multi_modal_inputs flag in your dataset config (used by RLHFDataset). + +- ``return_multi_modal_inputs: True`` (default): The dataset will pre-process and include a multi_modal_inputs dictionary for each sample. This dict contains the model-ready representations (e.g., image tensors, video tensors, etc.) as produced by your processor. This is useful for single-turn or SFT-style training, where the model expects all modalities to be present in the batch. 
+ +- ``return_multi_modal_inputs: False``: The dataset will not include the multi_modal_inputs field. This is recommended for multi-turn RL or tool-augmented rollouts, where the model may generate new multi-modal inputs dynamically during rollout, and you want to avoid conflicts or redundant data in the batch. + Special Cases ^^^^^^^^^^^^^ diff --git a/examples/sglang_multiturn/config/geo3k_multiturn_grpo.yaml b/examples/sglang_multiturn/config/geo3k_multiturn_grpo.yaml index cc587378c7c..a9523f19685 100644 --- a/examples/sglang_multiturn/config/geo3k_multiturn_grpo.yaml +++ b/examples/sglang_multiturn/config/geo3k_multiturn_grpo.yaml @@ -11,6 +11,7 @@ data: max_response_length: 2048 train_batch_size: 256 return_raw_chat: True + return_multi_modal_inputs: False actor_rollout_ref: hybrid_engine: True @@ -20,5 +21,5 @@ actor_rollout_ref: name: sglang multi_turn: enable: True - max_turns: 5 + max_assistant_turns: 5 # tool_config_path: "./config/tool_config/gsm8k_tool_config.yaml" diff --git a/examples/sglang_multiturn/config/geo3k_multiturn_megatron_grpo.yaml b/examples/sglang_multiturn/config/geo3k_multiturn_megatron_grpo.yaml index 1c902309ed7..5e208f3336e 100644 --- a/examples/sglang_multiturn/config/geo3k_multiturn_megatron_grpo.yaml +++ b/examples/sglang_multiturn/config/geo3k_multiturn_megatron_grpo.yaml @@ -11,6 +11,7 @@ data: max_response_length: 2048 train_batch_size: 256 return_raw_chat: True + return_multi_modal_inputs: False actor_rollout_ref: hybrid_engine: True @@ -20,5 +21,5 @@ actor_rollout_ref: name: sglang multi_turn: enable: True - max_turns: 5 + max_assistant_turns: 5 # tool_config_path: "./config/tool_config/gsm8k_tool_config.yaml" diff --git a/tests/workers/rollout/test_sglang_async_rollout_multimodal_delta.py b/tests/workers/rollout/test_sglang_async_rollout_multimodal_delta.py new file mode 100644 index 00000000000..c90b731d4c7 --- /dev/null +++ b/tests/workers/rollout/test_sglang_async_rollout_multimodal_delta.py @@ -0,0 +1,187 @@ +# Copyright 
2025 Amazon.com, Inc. or its affiliates +# Copyright 2023-2024 SGLang Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import pytest + +from verl.utils.dataset.vision_utils import process_image +from verl.utils.tokenizer import hf_processor +from verl.workers.rollout.schemas import ( + AsyncRolloutRequest, + AsyncRolloutRequestStateEnum, + TokenizationSanityCheckModeEnum, +) + + +def _test_add_tool_response_messages_image_delta(processor, image_list, description_list, resize_image=False): + assert len(image_list) == len(description_list) + # Get the smallest dimensions across all images + processed_images = [] + for img_url in image_list: + img = process_image(img_url) + processed_images.append(img) + + min_width = min(img.size[0] for img in processed_images) + min_height = min(img.size[1] for img in processed_images) + min_size = (min_width, min_height) + + if resize_image: + processed_images_resized = [] + for img in processed_images: + img = img.resize(min_size) + processed_images_resized.append(img) + processed_images = processed_images_resized + + # Initial message history + system_prompt = ( + "You will be provided with an image. 
Describe this image and then generate a new image for the next round" + ) + messages = [ + { + "role": "system", + "content": system_prompt, + }, + { + "role": "user", + "content": [ + {"type": "text", "text": "Here is the first image provided: "}, + {"type": "image", "image": [processed_images[0]]}, + ], + }, + ] + + # Initial multi_modal_data with one image + multi_modal_data = {"image": [processed_images[0]], "video": []} + # Minimal required fields for AsyncRolloutRequest + + req = AsyncRolloutRequest( + batch_data_id=0, + request_id="test-req-1", + state=AsyncRolloutRequestStateEnum.PENDING, + messages=messages, + multi_modal_keys=["image", "video"], + multi_modal_data=multi_modal_data.copy(), + tool_schemas=[], + tools_kwargs={}, + interaction_kwargs={}, + input_ids=[], + prompt_ids=[], + response_ids=[], + attention_mask=[], + prompt_attention_mask=[], + response_attention_mask=[], + position_ids=[], + prompt_position_ids=[], + response_position_ids=[], + loss_mask=[], + prompt_loss_mask=[], + response_loss_mask=[], + reward_scores={}, + max_prompt_len=8192, + max_response_len=8192, + max_model_len=16384, + metrics={}, + use_inference_chat_template=True, + tokenization_sanity_check_mode=TokenizationSanityCheckModeEnum.STRICT, + generation_prompt_ids=[], + base_conv_wo_gen_prompt_end_pos=0, + base_conv_with_gen_prompt_end_pos=0, + processing_class=processor, + ) + + prev_generated_len = 0 + # Add First Assistant Message and first tool response message(image) + for idx, img in enumerate(processed_images): + if idx == 0: + continue + _ = req.get_generation_prompt_ids(processor) + req.add_assistant_message(processor, content=description_list[idx - 1]) + before_tool_call_len = len(req.input_ids) + req.add_tool_response_messages(processor, [{"image": [img], "text": "Here is the new image you requested: "}]) + after_tool_call_len = len(req.input_ids) + if prev_generated_len == 0: + prev_generated_len = after_tool_call_len - before_tool_call_len + else: + if 
resize_image: + assert after_tool_call_len - before_tool_call_len == prev_generated_len + assert req.multi_modal_data["image"] == processed_images[: idx + 1] + + _ = req.get_generation_prompt_ids(processor) + req.add_assistant_message(processor, content=description_list[-1]) + + messages = [msg.model_dump() for msg in req.messages] + tools = [tool.model_dump() for tool in req.tool_schemas] if req.tool_schemas else None + full_prompt_info = req._handle_apply_chat_template( + processor, + messages, + multi_modal_data=req.multi_modal_data, + tools=tools, + add_generation_prompt=False, + tokenize=True, + return_dict=True, + ) + full_prompt_ids = full_prompt_info["input_ids"] + assert full_prompt_ids == req.input_ids + + # We must use dict(full_prompt_info) to convert BatchFeature values to a new dict + # because np.array() only keeps the keys for BatchFeature. + full_prompt_multi_modal_inputs = dict(full_prompt_info) + full_prompt_multi_modal_inputs.pop("input_ids", None) + full_prompt_multi_modal_inputs.pop("attention_mask", None) + + for key in full_prompt_multi_modal_inputs: + assert full_prompt_multi_modal_inputs[key].eq(req.multi_modal_inputs[key]).all() + + +@pytest.mark.skipif( + hf_processor("Qwen/Qwen2.5-VL-3B-Instruct") is None, reason="Processor not available for Qwen/Qwen2.5-VL-B-Instruct" +) +def test_add_tool_response_messages_image_delta(): + processor = hf_processor("Qwen/Qwen2.5-VL-3B-Instruct") + + # From Qwen2.5-VL-3B-Instruct HF example + img_1_url = {"image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg"} + img_1_description = "A woman sits on the beach at sunset, smiling as she shares a high five with her large dog." 
+ # GitHub Logo + img_2_url = {"image": "https://github.githubassets.com/assets/GitHub-Mark-ea2971cee799.png"} + img_2_description = "A GitHub Logo image" + # Octocat + img_3_url = {"image": "https://octodex.github.com/images/orderedlistocat.png"} + img_3_description = "An Octocat image" + + image_list = [img_1_url, img_2_url, img_3_url] + description_list = [img_1_description, img_2_description, img_3_description] + _test_add_tool_response_messages_image_delta(processor, image_list, description_list, resize_image=False) + + +@pytest.mark.skipif( + hf_processor("Qwen/Qwen2.5-VL-3B-Instruct") is None, reason="Processor not available for Qwen/Qwen2.5-VL-B-Instruct" +) +def test_add_tool_response_messages_image_delta_resize_image(): + processor = hf_processor("Qwen/Qwen2.5-VL-3B-Instruct") + + # From Qwen2.5-VL-3B-Instruct HF example + img_1_url = {"image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg"} + img_1_description = "A woman sits on the beach at sunset, smiling as she shares a high five with her large dog." 
+ # GitHub Logo + img_2_url = {"image": "https://github.githubassets.com/assets/GitHub-Mark-ea2971cee799.png"} + img_2_description = "A GitHub Logo image" + # Octocat + img_3_url = {"image": "https://octodex.github.com/images/orderedlistocat.png"} + img_3_description = "An Octocat image" + + image_list = [img_1_url, img_2_url, img_3_url] + description_list = [img_1_description, img_2_description, img_3_description] + _test_add_tool_response_messages_image_delta(processor, image_list, description_list, resize_image=True) diff --git a/tests/workers/rollout/test_sglang_multi_interaction.py b/tests/workers/rollout/test_sglang_multi_interaction.py index 4fb38bec2e5..b195a20955b 100644 --- a/tests/workers/rollout/test_sglang_multi_interaction.py +++ b/tests/workers/rollout/test_sglang_multi_interaction.py @@ -210,7 +210,7 @@ def test_interaction_selection_by_name(self): max_response_len=16, max_model_len=512, use_inference_chat_template=True, - tokenization_sanity_check_mode="off", + tokenization_sanity_check_mode="disable", processing_class=tokenizer, ) @@ -252,7 +252,7 @@ def test_fallback_to_default_interaction(self): "max_assistant_turns": 5, "max_user_turns": 3, "use_inference_chat_template": True, - "tokenization_sanity_check_mode": "off", + "tokenization_sanity_check_mode": "disable", }, "prompt_length": 32, "response_length": 16, @@ -349,7 +349,7 @@ def test_backward_compatibility_no_interaction_config(self): "max_assistant_turns": 5, "max_user_turns": 3, "use_inference_chat_template": True, - "tokenization_sanity_check_mode": "off", + "tokenization_sanity_check_mode": "disable", }, "prompt_length": 32, "response_length": 16, diff --git a/verl/trainer/config/ppo_megatron_trainer.yaml b/verl/trainer/config/ppo_megatron_trainer.yaml index c142051a9ec..1cf7dfdf7da 100644 --- a/verl/trainer/config/ppo_megatron_trainer.yaml +++ b/verl/trainer/config/ppo_megatron_trainer.yaml @@ -23,6 +23,7 @@ data: class_path: null class_name: null dataloader_num_workers: 8 + 
return_multi_modal_inputs: True actor_rollout_ref: hybrid_engine: True @@ -249,7 +250,7 @@ actor_rollout_ref: # Some models are known to produce different tokenization results when tokenizing turn by turn vs. all at once. This behavior has already been validated for them. # To reduce excessive warnings, you can turn off the sanity check for these models if you are using their default chat template: # Qwen/QwQ-32B, Qwen/Qwen3-xxB - # - off: disable tokenization sanity check + # - disable: disable tokenization sanity check # - strict: enable strict tokenization sanity check (default) # - ignore_strippable: ignore strippable tokens when checking tokenization sanity tokenization_sanity_check_mode: strict diff --git a/verl/trainer/config/ppo_trainer.yaml b/verl/trainer/config/ppo_trainer.yaml index 6842d492e2f..7bfb2a28a7f 100644 --- a/verl/trainer/config/ppo_trainer.yaml +++ b/verl/trainer/config/ppo_trainer.yaml @@ -90,6 +90,9 @@ data: # The name of the dataset class within the specified file. name: null + + # Whether to return multi-modal inputs in the dataset. Set to False if rollout generates new multi-modal inputs. + return_multi_modal_inputs: True # settings related to data sampler sampler: @@ -566,7 +569,7 @@ actor_rollout_ref: # Some models are known to produce different tokenization results when tokenizing turn by turn vs. all at once. This behavior has already been validated for them. 
# To reduce excessive warnings, you can turn off the sanity check for these models if you are using their default chat template: # Qwen/QwQ-32B, Qwen/Qwen3-xxB - # - off: disable tokenization sanity check + # - disable: disable tokenization sanity check # - strict: enable strict tokenization sanity check (default) # - ignore_strippable: ignore strippable tokens when checking tokenization sanity tokenization_sanity_check_mode: strict diff --git a/verl/utils/dataset/rl_dataset.py b/verl/utils/dataset/rl_dataset.py index 6d79f9ef766..afdfa0d9219 100644 --- a/verl/utils/dataset/rl_dataset.py +++ b/verl/utils/dataset/rl_dataset.py @@ -115,6 +115,8 @@ def __init__( self.need_tools_kwargs = config.get("need_tools_kwargs", False) self.filter_prompts = config.get("filter_prompts", True) self.serialize_dataset = False + self.return_multi_modal_inputs = config.get("return_multi_modal_inputs", True) + self._download() self._read_files_and_tokenize() @@ -223,11 +225,17 @@ def __getitem__(self, item): images = None if self.image_key in row_dict and row_dict.get(self.image_key, None) is not None: images = [process_image(image) for image in row_dict.pop(self.image_key)] + + # due to the image key is "image" instead of "images" in vllm, we need to use "image" here + # link: https://github.com/vllm-project/vllm/blob/3c545c0c3b98ee642373a308197d750d0e449403/vllm/multimodal/parse.py#L205 multi_modal_data["image"] = images videos = None if self.video_key in row_dict and row_dict.get(self.video_key, None) is not None: videos = [process_video(video) for video in row_dict.pop(self.video_key)] + + # due to the video key is "video" instead of "videos" in vllm, we need to use "video" here + # link: https://github.com/vllm-project/vllm/blob/3c545c0c3b98ee642373a308197d750d0e449403/vllm/multimodal/parse.py#L205 multi_modal_data["video"] = [video.numpy() for video in videos] model_inputs = self.processor(text=[raw_prompt], images=images, videos=videos, return_tensors="pt") @@ -240,10 +248,14 @@ 
def __getitem__(self, item): # There's a trap here, multi_modal_inputs has to be a dict, not BatchFeature row_dict["multi_modal_data"] = multi_modal_data - row_dict["multi_modal_inputs"] = dict(model_inputs) - # second_per_grid_ts isn't used for training, just for mrope - row_dict["multi_modal_inputs"].pop("second_per_grid_ts", None) + # We will do batch.union() in the trainer, + # so we cannot have "multi_modal_inputs" in row_dict if rollout generates new multi_modal_inputs + if self.return_multi_modal_inputs: + row_dict["multi_modal_inputs"] = dict(model_inputs) + + # second_per_grid_ts isn't used for training, just for mrope + row_dict["multi_modal_inputs"].pop("second_per_grid_ts", None) else: raw_prompt = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False) diff --git a/verl/workers/rollout/schemas.py b/verl/workers/rollout/schemas.py index 09e6acb8cdf..76a0ef27d8f 100644 --- a/verl/workers/rollout/schemas.py +++ b/verl/workers/rollout/schemas.py @@ -73,7 +73,7 @@ class AsyncRolloutRequestStateEnum(str, Enum): class TokenizationSanityCheckModeEnum(str, Enum): """The enum for tokenization sanity check mode.""" - OFF = "off" + DISABLE = "disable" STRICT = "strict" IGNORE_STRIPPABLE = "ignore_strippable" @@ -89,6 +89,7 @@ class AsyncRolloutRequest(BaseModel): messages: List[Message] multi_modal_keys: Optional[List[str]] = None multi_modal_data: Optional[Dict[str, Any]] = None + multi_modal_inputs: Optional[Dict[str, Any]] = None tool_schemas: Optional[List[OpenAIFunctionToolSchema]] = None tools_kwargs: Dict[str, Any] = {} interaction_kwargs: Dict[str, Any] = {} @@ -98,9 +99,9 @@ class AsyncRolloutRequest(BaseModel): attention_mask: List[int] prompt_attention_mask: List[int] response_attention_mask: List[int] - position_ids: List[int] - prompt_position_ids: List[int] - response_position_ids: List[int] + position_ids: List[int] | List[List[int]] + prompt_position_ids: List[int] | List[List[int]] + response_position_ids: 
List[int] | List[List[int]] loss_mask: List[int] prompt_loss_mask: List[int] response_loss_mask: List[int] @@ -138,6 +139,8 @@ def initialize_request(cls, values): for key in values["multi_modal_keys"]: if key not in values["multi_modal_data"]: values["multi_modal_data"][key] = [] + if not values.get("multi_modal_inputs"): + values["multi_modal_inputs"] = {} tools = ( [tool.model_dump() for tool in tool_schemas] if (tool_schemas := values.get("tool_schemas", [])) else None @@ -152,7 +155,7 @@ def initialize_request(cls, values): add_generation_prompt=False, tokenize=True, ) - if not values.get("input_ids") or not values.get("attention_mask"): + if not values.get("input_ids") or not values.get("attention_mask") or not values.get("position_ids"): tokenization_dict_with_prompt = cls._handle_apply_chat_template( processing_class, messages, @@ -175,10 +178,17 @@ def initialize_request(cls, values): f"{max_prompt_len} after applied chat template with tools." ) + # Process multi_modal_inputs + multi_modal_inputs = dict(tokenization_dict_with_prompt) + multi_modal_inputs.pop("input_ids", None) + multi_modal_inputs.pop("attention_mask", None) + values["multi_modal_inputs"] = multi_modal_inputs + + values["position_ids"] = values["prompt_position_ids"] = cls._get_position_ids( + processing_class, values["input_ids"], values["attention_mask"], multi_modal_inputs + ) + values["prompt_ids"], values["prompt_attention_mask"] = values["input_ids"], values["attention_mask"] - values["position_ids"] = values["prompt_position_ids"] = compute_position_id_with_mask( - torch.tensor(values["attention_mask"]) - ).tolist() values["loss_mask"] = values["prompt_loss_mask"] = [0] * len(values["input_ids"]) values["generation_prompt_ids"] = values["input_ids"][len(tokens_without_prompt) :] values["base_conv_wo_gen_prompt_end_pos"] = len( @@ -237,16 +247,66 @@ def _handle_apply_chat_template( images = images if len(images := multi_modal_data.get("image", [])) > 0 else None videos = videos if 
len(videos := multi_modal_data.get("video", [])) > 0 else None model_inputs = processing_class(text=[raw_prompt], images=images, videos=videos, return_tensors="pt") - assert model_inputs["input_ids"].shape[0] == 1, "input_ids should be a 1D array" - model_inputs = {k: v[0].tolist() if hasattr(v, "tolist") else v for k, v in model_inputs.items()} + assert model_inputs["input_ids"].shape[0] == 1, ( + "request level input_ids should be a 2D array with shape (1, seq_len)" + ) + assert model_inputs["attention_mask"].shape[0] == 1, ( + "request level attention_mask should be a 2D array with shape (1, seq_len)" + ) + + # current req level input_ids/attention_mask needs to be 1D array, + # this is specific for request level input_ids/attention_mask + model_inputs["input_ids"] = model_inputs["input_ids"][0].tolist() + model_inputs["attention_mask"] = model_inputs["attention_mask"][0].tolist() + if return_dict: - return model_inputs + return dict(model_inputs) else: return model_inputs["input_ids"] else: raise ValueError(f"Unsupported processing class type: {type(processing_class)}") - def _update_input_ids(self, new_input_ids: List[int], attention_mask: bool, loss_mask: bool) -> None: + @staticmethod + def _get_position_ids( + processing_class: Union[PreTrainedTokenizer, PreTrainedTokenizerFast, ProcessorMixin], + input_ids: List[int], + attention_mask: List[int], + multi_modal_inputs: Optional[Dict[str, Any]] = None, + ) -> List[int]: + # special case for qwen2vl + is_qwen2vl = ( + hasattr(processing_class, "image_processor") + and "Qwen2VLImageProcessor" in processing_class.image_processor.__class__.__name__ + ) + if is_qwen2vl: + from verl.models.transformers.qwen2_vl import get_rope_index + + image_grid_thw = video_grid_thw = second_per_grid_ts = None + if multi_modal_inputs: + image_grid_thw = multi_modal_inputs.get("image_grid_thw") + video_grid_thw = multi_modal_inputs.get("video_grid_thw") + second_per_grid_ts = multi_modal_inputs.get("second_per_grid_ts") + + 
new_position_ids = get_rope_index( + processing_class, + input_ids=torch.tensor(input_ids, dtype=torch.long), + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + attention_mask=torch.tensor(attention_mask, dtype=torch.long), + ) + return new_position_ids.tolist() # (3, seq_len) + else: + return compute_position_id_with_mask(torch.tensor(attention_mask)).tolist() # (seq_len,) + + def _update_input_ids( + self, + processing_class: Union[PreTrainedTokenizer, PreTrainedTokenizerFast, ProcessorMixin], + new_input_ids: List[int], + attention_mask: bool, + loss_mask: bool, + new_multi_modal_inputs: Optional[Dict[str, Any]] = None, + ) -> None: """ Update the input_ids, attention_mask, position_ids, and loss_mask of the request in additive manner. """ @@ -254,14 +314,45 @@ def _update_input_ids(self, new_input_ids: List[int], attention_mask: bool, loss attention_mask = [int(attention_mask)] * len(new_input_ids) self.attention_mask += attention_mask self.loss_mask += [int(loss_mask)] * len(new_input_ids) - self.position_ids += ( - compute_position_id_with_mask(torch.tensor(attention_mask)) + (self.position_ids[-1] + 1) - ).tolist() + + if new_multi_modal_inputs: + self._update_multi_modal_inputs(new_multi_modal_inputs) + + new_position_ids = self._get_position_ids( + processing_class, new_input_ids, attention_mask, new_multi_modal_inputs + ) + if isinstance(self.position_ids[0], list): + self.position_ids = [ + self.position_ids[i] + [j + self.position_ids[i][-1] + 1 for j in new_position_ids[i]] + for i in range(len(new_position_ids)) + ] # (3, seq_len) + position_ids_seq_len = len(self.position_ids[0]) + else: + self.position_ids += [j + (self.position_ids[-1] + 1) for j in new_position_ids] # (seq_len,) + position_ids_seq_len = len(self.position_ids) assert ( - len(self.input_ids) == len(self.attention_mask) == len(self.position_ids) == len(self.loss_mask) + len(self.input_ids) == len(self.attention_mask) == 
position_ids_seq_len == len(self.loss_mask) ), f"""Request {self.request_id} has different length of {len(self.input_ids)=}, - {len(self.attention_mask)=}, {len(self.position_ids)=}, {len(self.loss_mask)=}""" + {len(self.attention_mask)=}, {position_ids_seq_len=}, {len(self.loss_mask)=}""" + + def _update_multi_modal_inputs(self, new_multi_modal_inputs: Dict[str, Any]) -> None: + """ + Update the multi_modal_inputs of the request in additive manner. + """ + + # We just want to have the multi_modal_inputs without input_ids and attention_mask + new_multi_modal_inputs = new_multi_modal_inputs.copy() + new_multi_modal_inputs.pop("input_ids", None) + new_multi_modal_inputs.pop("attention_mask", None) + + for key in new_multi_modal_inputs: + input_tensor = new_multi_modal_inputs[key] + self.multi_modal_inputs[key] = ( + torch.cat([self.multi_modal_inputs[key], input_tensor], dim=0) + if key in self.multi_modal_inputs + else input_tensor + ) def get_generation_prompt_ids( self, processing_class: Union[PreTrainedTokenizer, PreTrainedTokenizerFast, ProcessorMixin] @@ -272,7 +363,7 @@ def get_generation_prompt_ids( else self.generation_prompt_ids ) if generation_prompt_ids: - self._update_input_ids(generation_prompt_ids, attention_mask=True, loss_mask=False) + self._update_input_ids(processing_class, generation_prompt_ids, attention_mask=True, loss_mask=False) if self.use_inference_chat_template: messages = [msg.model_dump() for msg in self.messages] @@ -303,7 +394,7 @@ def add_user_message( content_ids = self._handle_apply_chat_template( processing_class, messages, multi_modal_data={}, tools=tools, add_generation_prompt=False, tokenize=True )[self.base_conv_wo_gen_prompt_end_pos :] - self._update_input_ids(content_ids, attention_mask=True, loss_mask=False) + self._update_input_ids(processing_class, content_ids, attention_mask=True, loss_mask=False) def add_assistant_message( self, @@ -321,24 +412,79 @@ def add_assistant_message( content_ids = 
self._handle_apply_chat_template( processing_class, messages, multi_modal_data={}, tools=tools, add_generation_prompt=False, tokenize=True )[self.base_conv_with_gen_prompt_end_pos :] - self._update_input_ids(content_ids, attention_mask=True, loss_mask=True) + self._update_input_ids(processing_class, content_ids, attention_mask=True, loss_mask=True) def add_tool_response_messages( - self, processing_class: Union[PreTrainedTokenizer, PreTrainedTokenizerFast, ProcessorMixin], contents: list[str] + self, + processing_class: Union[PreTrainedTokenizer, PreTrainedTokenizerFast, ProcessorMixin], + contents: list[str | Dict[str, Any]], ) -> None: if not contents: return + # We also handle the case when tool returns image + # We require the processing of the image and video to be done at tool.execute() level + delta_multi_modal_data = {key: [] for key in self.multi_modal_keys} + for content in contents: + if isinstance(content, dict): + content_list = [] + # When we update multi_model_keys, we also need to update this logic + if "image" in content: + if not isinstance(content["image"], list): + raise ValueError( + f"Image must be a list, but got {type(content['image'])}. Please check the tool.execute(). " + f"For single images, wrap in a list: [image]. " + f"Example: {{'image': [img1]}} or {{'image': [img1, img2, ...]}}." + ) - self.messages.extend([Message(role="tool", content=content) for content in contents]) + content_list.extend([{"type": "image"} for _ in content["image"]]) + delta_multi_modal_data["image"].extend(content["image"]) + if "video" in content: + if not isinstance(content["video"], list): + raise ValueError( + f"Video must be a list, but got {type(content['video'])}. Please check the tool.execute(). " + f"For single videos, wrap in a list: [video]. " + f"Example: {{'video': [video1]}} or {{'video': [video1, video2, ...]}}." 
+ ) + + content_list.extend([{"type": "video"} for _ in content["video"]]) + delta_multi_modal_data["video"].extend(content["video"]) + if "text" in content: + content_list.append({"type": "text", "text": content["text"]}) + for key in content: + if key not in ["image", "video", "text"]: + logger.warning( + f"Tool response message contains unexpected key: {key} " + f"while we only support `image`, `video`, and `text`." + ) + self.messages.append(Message(role="tool", content=content_list)) + else: + self.messages.append(Message(role="tool", content=content)) messages = [*BASE_CHAT_HISTORY, *self.messages[-len(contents) :]] tools = [tool.model_dump() for tool in self.tool_schemas] if self.tool_schemas else None - # Currently we don't support tool creates multi-modal data - content_ids = self._handle_apply_chat_template( - processing_class, messages, multi_modal_data={}, tools=tools, add_generation_prompt=False, tokenize=True - )[self.base_conv_wo_gen_prompt_end_pos :] - self._update_input_ids(content_ids, attention_mask=True, loss_mask=False) + for key in self.multi_modal_keys: + if len(delta_multi_modal_data[key]) > 0: + self.multi_modal_data[key].extend(delta_multi_modal_data[key]) + + # We just passed the new multi-modal data to the chat template to update the input_ids. 
+ content_info = self._handle_apply_chat_template( + processing_class, + messages, + multi_modal_data=delta_multi_modal_data, + tools=tools, + add_generation_prompt=False, + tokenize=True, + return_dict=True, + ) + content_ids = content_info["input_ids"][self.base_conv_wo_gen_prompt_end_pos :] + self._update_input_ids( + processing_class, + content_ids, + attention_mask=True, + loss_mask=False, + new_multi_modal_inputs=content_info, + ) def update_metrics(self, metrics: Any, tool_id: str) -> None: """ @@ -411,20 +557,62 @@ def finalize( ) -> None: self.state = AsyncRolloutRequestStateEnum.COMPLETED self.reward_scores = reward_scores - if self.tokenization_sanity_check_mode != TokenizationSanityCheckModeEnum.OFF: + + # In case we failed to generate the assistant message and the generation prompt ids were already added to + # input_ids, remove them from the end of input_ids + if self.input_ids[-len(self.generation_prompt_ids) :] == self.generation_prompt_ids: + self.input_ids = self.input_ids[: -len(self.generation_prompt_ids)] + self.attention_mask = self.attention_mask[: -len(self.generation_prompt_ids)] + self.position_ids = ( + [position_ids[: -len(self.generation_prompt_ids)] for position_ids in self.position_ids] + if isinstance(self.position_ids[0], list) and isinstance(self.position_ids[0][0], int) + else self.position_ids[: -len(self.generation_prompt_ids)] + ) + self.loss_mask = self.loss_mask[: -len(self.generation_prompt_ids)] + + self.response_ids = self.input_ids[len(self.prompt_ids) :] + + if self.tokenization_sanity_check_mode != TokenizationSanityCheckModeEnum.DISABLE: # When there is a diff, we log the diffs with diff_surrounding_chars context diff_surrounding_chars = 10 messages = [msg.model_dump() for msg in self.messages] tools = [tool.model_dump() for tool in self.tool_schemas] if self.tool_schemas else None - full_prompt_ids = self._handle_apply_chat_template( + full_prompt_info = self._handle_apply_chat_template( processing_class, messages, 
                 multi_modal_data=self.multi_modal_data,
                 tools=tools,
                 add_generation_prompt=False,
                 tokenize=True,
+                return_dict=True,
             )
+            full_prompt_ids = full_prompt_info["input_ids"]
+
+            # We must use dict(full_prompt_info) to convert BatchFeature values to a new dict
+            # because np.array() only keeps the keys for BatchFeature.
+            full_prompt_multi_modal_inputs = dict(full_prompt_info)
+            full_prompt_multi_modal_inputs.pop("input_ids", None)
+            full_prompt_multi_modal_inputs.pop("attention_mask", None)
+
+            for multi_modal_inputs_key in self.multi_modal_inputs:
+                if multi_modal_inputs_key in full_prompt_multi_modal_inputs:
+                    if (
+                        not self.multi_modal_inputs[multi_modal_inputs_key]
+                        .eq(full_prompt_multi_modal_inputs[multi_modal_inputs_key])
+                        .all()
+                    ):
+                        logger.warning(
+                            f"Multi-modal data {multi_modal_inputs_key} is not consistent. "
+                            f"This may lead to unexpected behavior during training. "
+                            f"Please review your multi_modal_inputs logic."
+                        )
+                else:
+                    logger.warning(
+                        f"Multi-modal inputs key {multi_modal_inputs_key} is not found in the multi_modal_inputs. "
+                        f"This may lead to unexpected behavior during training. "
+                        f"Please review your multi_modal_inputs logic."
+ ) if diffs := self._get_prompt_diffs( processing_class, full_prompt_ids, self.input_ids, diff_surrounding_chars=diff_surrounding_chars @@ -460,15 +648,6 @@ def finalize( diff_details = "\n".join(diff_details_list) logger.warning(f"Found differences:\n{diff_details}") - # In case we failed to generate the assistant message and the generation prompt ids were already added to - # input_ids, remove them from the end of input_ids - if self.input_ids[-len(self.generation_prompt_ids) :] == self.generation_prompt_ids: - self.input_ids = self.input_ids[: -len(self.generation_prompt_ids)] - self.attention_mask = self.attention_mask[: -len(self.generation_prompt_ids)] - self.position_ids = self.position_ids[: -len(self.generation_prompt_ids)] - self.loss_mask = self.loss_mask[: -len(self.generation_prompt_ids)] - - self.response_ids = self.input_ids[len(self.prompt_ids) :] if finish_reason_type == FinishReasonTypeEnum.STOP: pass elif finish_reason_type == FinishReasonTypeEnum.LENGTH: @@ -476,19 +655,36 @@ def finalize( else: raise ValueError(f"Unsupported finalize finish reason type: {finish_reason_type}") self.truncate_output_ids(processing_class) + + position_ids_seq_len = ( + len(self.position_ids[0]) if isinstance(self.position_ids[0], list) else len(self.position_ids) + ) assert ( - len(self.input_ids) == len(self.attention_mask) == len(self.position_ids) == len(self.loss_mask) + len(self.input_ids) == len(self.attention_mask) == position_ids_seq_len == len(self.loss_mask) ), f"""Request {self.request_id} has different length of {len(self.input_ids)=}, - {len(self.attention_mask)=}, {len(self.position_ids)=}, {len(self.loss_mask)=}""" + {len(self.attention_mask)=}, {position_ids_seq_len=}, {len(self.loss_mask)=}""" def truncate_output_ids( self, processing_class: Union[PreTrainedTokenizer, PreTrainedTokenizerFast, ProcessorMixin] ) -> None: self.input_ids = self.input_ids[: self.max_model_len] self.attention_mask = self.attention_mask[: self.max_model_len] - 
self.position_ids = self.position_ids[: self.max_model_len] + + # this is same as torch.tensor(self.position_ids[..., :self.max_model_len]).tolist() + self.position_ids = ( + [position_ids[: self.max_model_len] for position_ids in self.position_ids] + if isinstance(self.position_ids[0], list) and isinstance(self.position_ids[0][0], int) + else self.position_ids[: self.max_model_len] + ) self.loss_mask = self.loss_mask[: self.max_model_len] self.response_ids = self.input_ids[len(self.prompt_ids) :][: self.max_response_len] self.response_attention_mask = self.attention_mask[len(self.prompt_attention_mask) :][: self.max_response_len] - self.response_position_ids = self.position_ids[len(self.prompt_position_ids) :][: self.max_response_len] + self.response_position_ids = ( + [ + position_ids[len(self.prompt_position_ids[0]) :][: self.max_response_len] + for position_ids in self.position_ids + ] + if isinstance(self.position_ids[0], list) and isinstance(self.position_ids[0][0], int) + else self.position_ids[len(self.prompt_position_ids) :][: self.max_response_len] + ) self.response_loss_mask = self.loss_mask[len(self.prompt_loss_mask) :][: self.max_response_len] diff --git a/verl/workers/rollout/sglang_rollout/sglang_rollout.py b/verl/workers/rollout/sglang_rollout/sglang_rollout.py index 939e9c52463..347ea154ce8 100644 --- a/verl/workers/rollout/sglang_rollout/sglang_rollout.py +++ b/verl/workers/rollout/sglang_rollout/sglang_rollout.py @@ -777,17 +777,6 @@ async def _async_rollout_a_request( finish_reason_type = None output = None - image_data = None - video_data = None - if _req.multi_modal_data is not None and isinstance(_req.multi_modal_data, dict): - if "image" in _req.multi_modal_data and _req.multi_modal_data["image"]: - image_data = _req.multi_modal_data["image"] - if "video" in _req.multi_modal_data and _req.multi_modal_data["video"]: - video_data = _req.multi_modal_data["video"] - logger.warning( - "video support is not implemented yet, current length of video 
data is %d", len(video_data) - ) - current_turns = 0 user_turns = 0 user_turn_rewards = [] @@ -857,7 +846,23 @@ async def _async_rollout_a_request( if len(_req.get_generation_prompt_ids(self.processing_class)) + 1 >= self.config.max_model_len: finish_reason_type = FinishReasonTypeEnum.LENGTH break + # Video support is not implemented yet + image_data = ( + _req.multi_modal_data["image"] + if _req.multi_modal_data and "image" in _req.multi_modal_data + else None + ) + video_data = ( + _req.multi_modal_data["video"] + if _req.multi_modal_data and "video" in _req.multi_modal_data + else None + ) + if video_data: + logger.warning( + "video support is not implemented yet, current length of video data is %d", len(video_data) + ) + output = await self._handle_engine_call(_req, request_sampling_params, image_data=image_data) content = output["text"] finish_reason_type = FinishReasonTypeEnum.from_str(output["meta_info"]["finish_reason"]["type"]) @@ -1056,12 +1061,17 @@ def _req_level_generate_sequences(self, prompts: DataProto, **kwargs) -> DataPro prompt_loss_mask, response_loss_mask = [], [] messages = [] reward_scores = [] + multi_modal_inputs = [] + for req in sorted_output_req_list: assert req.state == AsyncRolloutRequestStateEnum.COMPLETED, f"Request {req.request_id} is not completed" + position_ids_seq_len = ( + len(req.position_ids[0]) if isinstance(req.position_ids[0], list) else len(req.position_ids) + ) assert ( - len(req.input_ids) == len(req.attention_mask) == len(req.position_ids) == len(req.loss_mask) + len(req.input_ids) == len(req.attention_mask) == position_ids_seq_len == len(req.loss_mask) ), f"""Request {req.request_id} has different length of - {len(req.input_ids)=}, {len(req.attention_mask)=}, {len(req.position_ids)=}, {len(req.loss_mask)=}""" + {len(req.input_ids)=}, {len(req.attention_mask)=}, {position_ids_seq_len=}, {len(req.loss_mask)=}""" error_message_lines = [ f"""Request {req.request_id} has input_ids length {len(req.input_ids)} greater than 
max_model_len {self.config.max_model_len}""", @@ -1091,6 +1101,7 @@ def _req_level_generate_sequences(self, prompts: DataProto, **kwargs) -> DataPro response_loss_mask.append(torch.tensor(req.response_loss_mask, dtype=torch.int, device=tgt_device)) messages.append({"messages": req.messages}) reward_scores.append(req.reward_scores) + multi_modal_inputs.append(req.multi_modal_inputs) prompt_ids = pad_sequence( prompt_ids, @@ -1116,15 +1127,45 @@ def _req_level_generate_sequences(self, prompts: DataProto, **kwargs) -> DataPro response_attention_mask = pad_sequence(response_attention_mask, batch_first=True, padding_value=0) if response_attention_mask.shape[1] < self.config.response_length: response_attention_mask = pad_sequence_to_length(response_attention_mask, self.config.response_length, 0) - prompt_position_ids = pad_sequence(prompt_position_ids, batch_first=True, padding_value=0, padding_side="left") - if prompt_position_ids.shape[1] < self.config.prompt_length: + + # padding prompt_position_ids + if prompt_position_ids[0].dim() == 2: + # if prompt_position_ids is a 2D tensor + # e.g. 
from qwen2vl, prompt_position_ids.shape = (3, seq_len) + transposed_prompt_position_ids = [p.transpose(0, 1) for p in prompt_position_ids] + prompt_position_ids = pad_sequence( + transposed_prompt_position_ids, batch_first=True, padding_value=0, padding_side="left" + ) + prompt_position_ids = prompt_position_ids.transpose(1, 2) + else: + prompt_position_ids = pad_sequence( + prompt_position_ids, batch_first=True, padding_value=0, padding_side="left" + ) + prompt_position_ids_seq_len = ( + prompt_position_ids.shape[2] if prompt_position_ids.dim() == 3 else prompt_position_ids.shape[1] + ) + if prompt_position_ids_seq_len < self.config.prompt_length: prompt_position_ids = pad_sequence_to_length( prompt_position_ids, self.config.prompt_length, 0, left_pad=True ) - response_length = response_ids.size(1) - delta_position_id = torch.arange(1, response_length + 1, device=response_ids.device) - delta_position_id = delta_position_id.unsqueeze(0).repeat(len(sorted_output_req_list), 1) - response_position_ids = prompt_position_ids[:, -1:] + delta_position_id + + # padding response_position_ids + if response_position_ids[0].dim() == 2: + # if response_position_ids is a 2D tensor + # e.g. 
from qwen2vl, response_position_ids.shape = (3, seq_len) + transposed_response_position_ids = [p.transpose(0, 1) for p in response_position_ids] + response_position_ids = pad_sequence( + transposed_response_position_ids, batch_first=True, padding_value=0, padding_side="left" + ) + response_position_ids = response_position_ids.transpose(1, 2) + else: + response_position_ids = pad_sequence(response_position_ids, batch_first=True, padding_value=0) + response_position_ids_seq_len = ( + response_position_ids.shape[2] if response_position_ids.dim() == 3 else response_position_ids.shape[1] + ) + if response_position_ids_seq_len < self.config.response_length: + response_position_ids = pad_sequence_to_length(response_position_ids, self.config.response_length, 0) + prompt_loss_mask = pad_sequence(prompt_loss_mask, batch_first=True, padding_value=0, padding_side="left") if prompt_loss_mask.shape[1] < self.config.prompt_length: prompt_loss_mask = pad_sequence_to_length(prompt_loss_mask, self.config.prompt_length, 0, left_pad=True) @@ -1160,6 +1201,7 @@ def _req_level_generate_sequences(self, prompts: DataProto, **kwargs) -> DataPro "messages": np.array(messages), "reward_scores": np.array(reward_scores), "uid": np.array([req.uid for req in sorted_output_req_list]), + "multi_modal_inputs": np.array(multi_modal_inputs, dtype=object), }, ) @@ -1221,13 +1263,15 @@ def _preprocess_prompt_to_async_rollout_requests(self, prompts: DataProto, n: in tokenization_sanity_check_mode=self.config.multi_turn.tokenization_sanity_check_mode, processing_class=self.processing_class, ) - + position_ids_seq_len = ( + len(req.position_ids[0]) if isinstance(req.position_ids[0], list) else len(req.position_ids) + ) error_message = f"""Request {req.request_id} has mismatched lengths: input_ids={len(req.input_ids)}, attention_mask={len(req.attention_mask)}, - position_ids={len(req.position_ids)}, + position_ids={position_ids_seq_len}, loss_mask={len(req.loss_mask)}""" - assert len(req.input_ids) == 
len(req.attention_mask) == len(req.position_ids) == len(req.loss_mask), ( + assert len(req.input_ids) == len(req.attention_mask) == position_ids_seq_len == len(req.loss_mask), ( error_message ) req_list.append(req)