diff --git a/docs/sglang_multiturn/multiturn.rst b/docs/sglang_multiturn/multiturn.rst
index 970ba46c1d2..70db63e23c6 100644
--- a/docs/sglang_multiturn/multiturn.rst
+++ b/docs/sglang_multiturn/multiturn.rst
@@ -40,6 +40,151 @@ Finally, set the ``tools_config_file`` in your rollout config:
 
 This allows integration of customized tool behaviors during actor rollout steps.
 
+Multi-turn Tokenization
+~~~~~~~~~~~~~~~~~~~~~~~
+
+Tokenizing multi-turn rollouts poses a challenge: after applying the chat template and tokenizing the full message list, it is hard to tell which tokens belong to assistant messages. Since the token list is flat, it carries no alignment with the message roles.
+
+To address this, we adopt a **delta-based tokenization** strategy. Each time the LLM generates a new message, we:
+
+1. Apply the chat template to all prior messages (``messages[:i]``).
+2. Apply the chat template again, including the latest message (``messages[:i+1]``).
+3. Tokenize only the *delta* between these two serialized message strings.
+
+This ensures that only tokens generated by the assistant are included in the loss mask.
+
+.. code-block:: python
+
+    # Exclude the assistant prompt (e.g., "<|im_start|>assistant") from the loss by setting add_generation_prompt=True
+    prev = tokenizer.apply_chat_template(messages[:i], add_generation_prompt=True, tokenize=False)
+    curr = tokenizer.apply_chat_template(messages[:i+1], add_generation_prompt=False, tokenize=False)
+    delta_token_ids = tokenizer.encode(curr[len(prev):], add_special_tokens=False)
+    token_ids += delta_token_ids
+    loss_mask += [1] * len(delta_token_ids)  # Mask only the new assistant tokens
+
+While we have validated that this produces results consistent with tokenizing the full message list at once, a future model's chat template could break compatibility. To guard against silent inconsistencies, we compare the delta-based and full tokenization results by default at the end of each rollout.
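+The end-of-rollout comparison can be sketched as follows. This is an illustrative stand-alone version, not verl's exact implementation; ``delta_token_ids`` stands for the token ids accumulated turn by turn during the rollout, and the function name is hypothetical:

```python
# Illustrative sketch of the tokenization sanity check (not verl's exact code).
# delta_token_ids: token ids accumulated turn by turn during the rollout.
# messages: the full conversation at the end of the rollout.
def tokenizations_match(tokenizer, messages, delta_token_ids):
    # Tokenize the whole conversation in one pass for comparison.
    full_token_ids = tokenizer.apply_chat_template(
        messages, add_generation_prompt=False, tokenize=True
    )
    if delta_token_ids != full_token_ids:
        print(
            "Inconsistent training and inference tokenization detected. "
            "This may lead to unexpected behavior during training."
        )
        return False
    return True
```

+If the two token lists differ, only a warning is emitted; training continues with the delta-based ids.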
+
+If you see the following warning, enable the ``INFO`` log level to inspect the mismatched outputs:
+
+.. code-block::
+
+    Inconsistent training and inference tokenization detected. This may lead to unexpected behavior during training. Please review your chat template to determine if this is intentional. For more information, refer to the multiturn README.md.
+
+If the discrepancy is expected, you can disable the sanity check via:
+
+``actor_rollout_ref.rollout.multi_turn.enable_tokenization_sanity_check=False``
+
+Special Cases
+^^^^^^^^^^^^^
+
+Some models (e.g., Qwen/QwQ-32B and the Qwen3 series) remove internal reasoning content during chat template rendering. As a result, the message content can vary across turns, making the delta-based tokenization inaccurate.
+
+For example, consider the following conversation:
+
+.. code-block:: python
+
+    messages = [
+        {"role": "system", "content": "You are a helpful assistant."},
+        {"role": "user", "content": "What is 2 + 2?"},
+        {"role": "assistant", "content": "user asked about a simple math question. 2 + 2 = 4."},
+        {"role": "user", "content": "Explain why."},
+        {"role": "assistant", "content": "user wants to know the reasoning behind the answer. Search for a good explanation",
+         "tool_calls": [{"id": "tool1", "type": "search", "arguments": {"query": "Why is 2 + 2 = 4?"}}]},
+        {"role": "tool", "content": "The sum of two and two is four because it is a basic arithmetic operation."},
+        {"role": "assistant", "content": "The tool provided a good explanation. The sum of two and two is four because it is a basic arithmetic operation."}
+    ]
+
+1. Qwen/QwQ-32B will remove all reasoning content except the last assistant message after applying the chat template.
+
+.. 
code-block:: text + + <|im_start|>system + You are a helpful assistant.<|im_end|> + <|im_start|>user + What is 2 + 2?<|im_end|> + <|im_start|>assistant + 2 + 2 = 4.<|im_end|> + <|im_start|>user + Explain why.<|im_end|> + <|im_start|>assistant + + {"name": "", "arguments": {"query": "Why is 2 + 2 = 4?"}} + <|im_end|> + <|im_start|>user + + The sum of two and two is four because it is a basic arithmetic operation. + <|im_end|> + <|im_start|>assistant + The tool provided a good explanation. The sum of two and two is four because it is a basic arithmetic operation.<|im_end|> + +2. Qwen3 series will remove all reasoning content before the last user message. + +.. code-block:: text + + <|im_start|>system + You are a helpful assistant.<|im_end|> + <|im_start|>user + What is 2 + 2?<|im_end|> + <|im_start|>assistant + 2 + 2 = 4.<|im_end|> + <|im_start|>user + Explain why.<|im_end|> + <|im_start|>assistant + + user wants to know the reasoning behind the answer. Search for a good explanation + + + + {"name": "", "arguments": {"query": "Why is 2 + 2 = 4?"}} + <|im_end|> + <|im_start|>user + + The sum of two and two is four because it is a basic arithmetic operation. + <|im_end|> + <|im_start|>assistant + + The tool provided a good explanation. + + + The sum of two and two is four because it is a basic arithmetic operation.<|im_end|> + +To handle this, we fall back to a **fixed base conversation** containing only a single system and user message. Since this base doesn’t include assistant messages or reasoning content, it remains consistent across turns. + +.. 
code-block:: python
+
+    BASE_CHAT_HISTORY = [
+        {"role": "system", "content": "You are a helpful assistant."},
+        {"role": "user", "content": "I am a user."}
+    ]
+    prev = tokenizer.apply_chat_template(BASE_CHAT_HISTORY, add_generation_prompt=True, tokenize=False)
+    curr = tokenizer.apply_chat_template([*BASE_CHAT_HISTORY, messages[i]], add_generation_prompt=False, tokenize=False)
+    delta_token_ids = tokenizer.encode(curr[len(prev):], add_special_tokens=False)
+    token_ids += delta_token_ids
+    loss_mask += [1] * len(delta_token_ids)
+
+This method works well for the Qwen3 series. However, Qwen/QwQ-32B currently has a bug in its chat template. A fix_ has been proposed but not yet adopted. Until then, use the following command to download the fixed model revision:
+
+.. code-block:: bash
+
+    pip install huggingface_hub
+    huggingface-cli download Qwen/QwQ-32B --revision refs/pr/81
+
+.. _fix: https://huggingface.co/Qwen/QwQ-32B/discussions/81
+
+Discrepancy Between Training and Inference Templates
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Although the above approach fixes the delta mismatch issue, the removal of reasoning content in the inference-time chat template introduces a new discrepancy: training uses the full reasoning content, while inference does not.
+
+This mismatch can affect model performance in unpredictable ways. To avoid it, we default to using the full response (including reasoning) for both training and rollout.
+
+However, this approach comes with trade-offs:
+
+1. Long reasoning contents can easily exceed the model's context window, especially in multi-turn rollouts.
+2. There is now a mismatch between rollout and the production environment: models will not have reasoning content from past turns if you use the default chat template in production.
+
+We are still evaluating the impact of these issues. 
If you experience context length problems or prefer rollouts that match production (i.e., exclude reasoning), you can enable: + +``actor_rollout_ref.rollout.multi_turn.use_inference_chat_template = True`` + GSM8K Multi-turn Training Performance ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/examples/sglang_multiturn/config/gsm8k_multiturn_grpo.yaml b/examples/sglang_multiturn/config/gsm8k_multiturn_grpo.yaml index db133f8af77..737ef2de968 100644 --- a/examples/sglang_multiturn/config/gsm8k_multiturn_grpo.yaml +++ b/examples/sglang_multiturn/config/gsm8k_multiturn_grpo.yaml @@ -19,5 +19,4 @@ actor_rollout_ref: multi_turn: enable: True max_turns: 5 - format: qwen # tool_config_path: "./config/tool_config/gsm8k_tool_config.yaml" diff --git a/examples/sglang_multiturn/config/gsm8k_multiturn_megatron_grpo.yaml b/examples/sglang_multiturn/config/gsm8k_multiturn_megatron_grpo.yaml index 8609d890166..39f2ff7faff 100644 --- a/examples/sglang_multiturn/config/gsm8k_multiturn_megatron_grpo.yaml +++ b/examples/sglang_multiturn/config/gsm8k_multiturn_megatron_grpo.yaml @@ -19,6 +19,5 @@ actor_rollout_ref: multi_turn: enable: True max_turns: 5 - format: qwen # tool_config_path: "./config/tool_config/gsm8k_tool_config.yaml" \ No newline at end of file diff --git a/examples/sglang_multiturn/run_qwen3-4b_gsm8k_multiturn.sh b/examples/sglang_multiturn/run_qwen3-4b_gsm8k_multiturn.sh new file mode 100755 index 00000000000..681375363fc --- /dev/null +++ b/examples/sglang_multiturn/run_qwen3-4b_gsm8k_multiturn.sh @@ -0,0 +1,53 @@ +# run on 8xH100 +# make sure your current working directory is the root of the project + +set -x + +ulimit -n 65535 + +PROJECT_DIR="$(pwd)" +CONFIG_PATH="$PROJECT_DIR/examples/sglang_multiturn/config" + +python3 -m verl.trainer.main_ppo \ + --config-path="$CONFIG_PATH" \ + --config-name='gsm8k_multiturn_grpo' \ + algorithm.adv_estimator=grpo \ + data.train_batch_size=256 \ + data.max_prompt_length=1024 \ + data.max_response_length=1024 \ + 
data.filter_overlong_prompts=True \ + data.truncation='error' \ + data.return_raw_chat=True \ + actor_rollout_ref.model.path=Qwen/Qwen3-4B \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=256 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=32 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.name=sglang \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \ + actor_rollout_ref.rollout.n=16 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger=['console','wandb'] \ + trainer.project_name='gsm8k_async_rl' \ + trainer.experiment_name='qwen3-4b_function_rm-gsm8k-sgl-multi-w-tool-verify-n16' \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + trainer.save_freq=-1 \ + trainer.test_freq=20 \ + data.train_files=$HOME/data/gsm8k/train.parquet \ + data.val_files=$HOME/data/gsm8k/test.parquet \ + actor_rollout_ref.rollout.multi_turn.tool_config_path="$PROJECT_DIR/examples/sglang_multiturn/config/tool_config/gsm8k_tool_config.yaml" \ + trainer.total_epochs=15 $@ + diff --git a/tests/workers/rollout/test_sglang_async_rollout_search_tools.py b/tests/workers/rollout/test_sglang_async_rollout_search_tools.py index 667d4927103..efd6f17d4c6 100644 --- a/tests/workers/rollout/test_sglang_async_rollout_search_tools.py +++ 
b/tests/workers/rollout/test_sglang_async_rollout_search_tools.py @@ -164,9 +164,9 @@ def test_rollout_req_creation(self, mock_env, mock_engine, mock_sampling, search req_list = rollout._preprocess_prompt_to_async_rollout_requests(search_data_proto, n=1) assert len(req_list) == 1 assert req_list[0].state == AsyncRolloutRequestStateEnum.PENDING - assert len(req_list[0].tools) == 1 - print(type(req_list[0].tools[0])) - assert req_list[0].tools[0] == OpenAIFunctionToolSchema( + assert len(req_list[0].tool_schemas) == 1 + print(type(req_list[0].tool_schemas[0])) + assert req_list[0].tool_schemas[0] == OpenAIFunctionToolSchema( type="function", function=OpenAIFunctionSchema( name="search", diff --git a/tests/workers/rollout/test_sglang_async_rollout_sf_tools.py b/tests/workers/rollout/test_sglang_async_rollout_sf_tools.py index fe027a60e27..43a76fefa22 100644 --- a/tests/workers/rollout/test_sglang_async_rollout_sf_tools.py +++ b/tests/workers/rollout/test_sglang_async_rollout_sf_tools.py @@ -220,9 +220,9 @@ def test_rollout_req_creation(self, mock_env, mock_engine, mock_sampling, sandbo req_list = rollout._preprocess_prompt_to_async_rollout_requests(sandbox_data_proto, n=1) assert len(req_list) == 1 assert req_list[0].state == AsyncRolloutRequestStateEnum.PENDING - assert len(req_list[0].tools) == 1 - print(type(req_list[0].tools[0])) - assert req_list[0].tools[0] == OpenAIFunctionToolSchema( + assert len(req_list[0].tool_schemas) == 1 + print(type(req_list[0].tool_schemas[0])) + assert req_list[0].tool_schemas[0] == OpenAIFunctionToolSchema( type="function", function=OpenAIFunctionSchema( name="code_interpreter", diff --git a/tests/workers/rollout/utils_sglang.py b/tests/workers/rollout/utils_sglang.py index 35c43a83a88..2596a074eee 100644 --- a/tests/workers/rollout/utils_sglang.py +++ b/tests/workers/rollout/utils_sglang.py @@ -152,7 +152,8 @@ def get_rollout_config(max_response_length, max_prompt_length, dtype, tensor_par "max_turns": 4, "enable": True, 
"tool_config_path": tool_config_path, - "format": "chatml", + "use_inference_chat_template": False, + "enable_tokenization_sanity_check": True, }, "max_model_len": None, **sampling_params, diff --git a/verl/trainer/config/ppo_megatron_trainer.yaml b/verl/trainer/config/ppo_megatron_trainer.yaml index 9a4fb92a4ab..659243f0450 100644 --- a/verl/trainer/config/ppo_megatron_trainer.yaml +++ b/verl/trainer/config/ppo_megatron_trainer.yaml @@ -178,7 +178,16 @@ actor_rollout_ref: enable: False # set to True for multi-turn tool interaction tasks; should set rollout.name to sglang as well max_turns: null # null for no limit (default max_length // 3) tool_config_path: null # null for no tool - format: chatml # chatml, more formats will be supported in the future + # - When set to True, the model's default chat template is used for multi-turn rollout, which typically matches production behavior. + # - When set to False, the token ids recorded for training are used instead; unlike the default chat template, these always include the model's full output, + # which may contain additional content such as reasoning content, which also leads to longer prompts. + use_inference_chat_template: True + # Tokenization is performed turn by turn and the resulting token ids are concatenated to form the full conversation. + # To ensure this matches the result of tokenizing the entire conversation at once, a sanity check is run at the end of each multi-turn rollout to compare the two sets of token ids. + # Some models are known to produce different tokenization results when tokenizing turn by turn vs. all at once. This behavior has already been validated for them. 
+ # To reduce excessive warnings, you can turn off the sanity check for these models if you are using their default chat template: + # Qwen/QwQ-32B, Qwen/Qwen3-xxB + enable_tokenization_sanity_check: True critic: rollout_n: ${actor_rollout_ref.rollout.n} diff --git a/verl/trainer/config/ppo_trainer.yaml b/verl/trainer/config/ppo_trainer.yaml index fa9299dd613..61974abc2a7 100644 --- a/verl/trainer/config/ppo_trainer.yaml +++ b/verl/trainer/config/ppo_trainer.yaml @@ -434,12 +434,21 @@ actor_rollout_ref: # null for no tool tool_config_path: null - # chatml, more formats will be supported in the future - format: chatml - # null for default callback completion_callback: null + # - When set to True, the model's default chat template is used for multi-turn rollout, which typically matches production behavior. + # - When set to False, the token ids recorded for training are used instead; unlike the default chat template, these always include the model's full output, + # which may contain additional content such as reasoning content. This maintains consistency between training and rollout, but leads to longer prompts. + use_inference_chat_template: False + + # Tokenization is performed turn by turn and the resulting token ids are concatenated to form the full conversation. + # To ensure this matches the result of tokenizing the entire conversation at once, a sanity check is run at the end of each multi-turn rollout to compare the two sets of token ids. + # Some models are known to produce different tokenization results when tokenizing turn by turn vs. all at once. This behavior has already been validated for them. 
+ # To reduce excessive warnings, you can turn off the sanity check for these models if you are using their default chat template: + # Qwen/QwQ-32B, Qwen/Qwen3-xxB + enable_tokenization_sanity_check: True + # configs for the critic critic: diff --git a/verl/workers/rollout/schemas.py b/verl/workers/rollout/schemas.py index 1c5df09fa27..b259d04620e 100644 --- a/verl/workers/rollout/schemas.py +++ b/verl/workers/rollout/schemas.py @@ -12,17 +12,23 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import logging +import os from enum import Enum -from typing import Any, Dict, List, Literal, Optional +from typing import Any, Dict, List, Optional import torch -from pydantic import BaseModel +from pydantic import BaseModel, model_validator from transformers import PreTrainedTokenizer from verl.tools.schemas import OpenAIFunctionToolCall, OpenAIFunctionToolSchema from verl.utils.model import compute_position_id_with_mask +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + +BASE_CHAT_HISTORY = [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "I am a user."}] + class FinishReasonTypeEnum(str, Enum): """The enum for finish reason type.""" @@ -67,7 +73,7 @@ class AsyncRolloutRequest(BaseModel): request_id: str state: AsyncRolloutRequestStateEnum messages: List[Message] - tools: Optional[List[OpenAIFunctionToolSchema]] = None + tool_schemas: Optional[List[OpenAIFunctionToolSchema]] = None tools_kwargs: Dict[str, Any] = {} input_ids: List[int] prompt_ids: List[int] @@ -82,128 +88,87 @@ class AsyncRolloutRequest(BaseModel): prompt_loss_mask: List[int] response_loss_mask: List[int] reward_scores: Dict[str, float] + max_prompt_len: int max_response_len: int = 8192 max_model_len: int = 32768 metrics: Dict[str, List[Any]] = {} - format_config: dict = { - "chatml": 
{ - "assistant_prefix_msg": "\n<|im_start|>assistant\n", - "assistant_suffix_msg": "<|im_end|>", - "tool_prefix_msg": "\n<|im_start|>tool\n", - "tool_suffix_msg": "<|im_end|>", - }, - "qwen": { - "assistant_prefix_msg": "\n<|im_start|>assistant\n", - "assistant_suffix_msg": "<|im_end|>", - "merge_tool_response": True, - "tool_prefix_msg": "\n<|im_start|>user", - "tool_suffix_msg": "<|im_end|>", - "tool_response_prefix_msg": "\n\n", - "tool_response_suffix_msg": "\n", - }, - } - - def get_generation_prompt(self, tokenizer: PreTrainedTokenizer) -> list[int]: - return tokenizer.apply_chat_template( # type: ignore - conversation=[msg.model_dump() for msg in self.messages], - tools=[tool.model_dump() for tool in self.tools] if self.tools else None, - add_generation_prompt=True, - tokenize=True, - ) + use_inference_chat_template: bool + enable_tokenization_sanity_check: bool + generation_prompt_ids: List[int] + base_conv_wo_gen_prompt_end_pos: int + base_conv_with_gen_prompt_end_pos: int + + @model_validator(mode="before") + @classmethod + def initialize_request(cls, values): + if not (messages := values.get("messages")): + raise ValueError("messages is required for AsyncRolloutRequest initialization") + if not (max_prompt_len := values.get("max_prompt_len")): + raise ValueError("max_prompt_len is required for AsyncRolloutRequest initialization") + if not (tokenizer := values.pop("tokenizer", None)): + raise ValueError("tokenizer is required for AsyncRolloutRequest initialization") + + values["messages"] = [Message.model_validate(msg) for msg in messages] + + tools = [tool.model_dump() for tool in tool_schemas] if (tool_schemas := values.get("tool_schemas", [])) else None + tokens_without_prompt = tokenizer.apply_chat_template(messages, tools=tools, add_generation_prompt=False, tokenize=True) + if not values.get("input_ids") or not values.get("attention_mask"): + tokenization_dict_with_prompt = tokenizer.apply_chat_template(messages, tools=[tool.model_dump() for tool in 
tool_schemas], add_generation_prompt=True, tokenize=True, return_dict=True) + values["input_ids"], values["attention_mask"] = tokenization_dict_with_prompt["input_ids"], tokenization_dict_with_prompt["attention_mask"] + if len(values["input_ids"]) > max_prompt_len: + # Only log the warning to avoid truncating in the middle of the generation prompt. Consider raising an error for this case in the future. + logger.warning(f"Prompt {values['batch_data_id']} length {len(values['input_ids'])} greater than max_prompt_len {max_prompt_len} after applying the chat template with tools.") + + values["prompt_ids"], values["prompt_attention_mask"] = values["input_ids"], values["attention_mask"] + values["position_ids"] = values["prompt_position_ids"] = compute_position_id_with_mask(torch.tensor(values["attention_mask"])).tolist() + values["loss_mask"] = values["prompt_loss_mask"] = [0] * len(values["input_ids"]) + values["generation_prompt_ids"] = values["input_ids"][len(tokens_without_prompt) :] + values["base_conv_wo_gen_prompt_end_pos"] = len(tokenizer.apply_chat_template(BASE_CHAT_HISTORY, tools=tools, add_generation_prompt=False, tokenize=False)) + values["base_conv_with_gen_prompt_end_pos"] = len(tokenizer.apply_chat_template(BASE_CHAT_HISTORY, tools=tools, add_generation_prompt=True, tokenize=False)) + return values + + def _update_input_ids(self, new_input_ids: List[int], attention_mask: bool, loss_mask: bool) -> None: + """ + Update the input_ids, attention_mask, position_ids, and loss_mask of the request in an additive manner. 
+ """ + self.input_ids += new_input_ids + attention_mask = [int(attention_mask)] * len(new_input_ids) + self.attention_mask += attention_mask + self.loss_mask += [int(loss_mask)] * len(new_input_ids) + self.position_ids += (compute_position_id_with_mask(torch.tensor(attention_mask)) + (self.position_ids[-1] + 1)).tolist() + + assert len(self.input_ids) == len(self.attention_mask) == len(self.position_ids) == len(self.loss_mask), f"""Request {self.request_id} has different length of {len(self.input_ids)=}, + {len(self.attention_mask)=}, {len(self.position_ids)=}, {len(self.loss_mask)=}""" + + def get_generation_prompt_ids(self, tokenizer: PreTrainedTokenizer) -> list[int]: + generation_prompt_ids = [] if self.input_ids[-len(self.generation_prompt_ids) :] == self.generation_prompt_ids else self.generation_prompt_ids + if generation_prompt_ids: + self._update_input_ids(generation_prompt_ids, attention_mask=True, loss_mask=False) + + if self.use_inference_chat_template: + return tokenizer.apply_chat_template([msg.model_dump() for msg in self.messages], tools=([tool.model_dump() for tool in self.tool_schemas] if self.tool_schemas else None), add_generation_prompt=True, tokenize=True) + else: + return self.input_ids def add_assistant_message( self, tokenizer: PreTrainedTokenizer, content: str, tool_calls: Optional[List[OpenAIFunctionToolCall]] = None, - format: Literal["chatml", "qwen"] = "chatml", - already_over_long: bool = False, ) -> None: - """Currently, we only support chatml format.""" - msg = Message(role="assistant", content=content, tool_calls=tool_calls) - self.messages.append(msg) - if tool_calls is not None: - content_with_tool_calls: str = tokenizer.apply_chat_template( # type: ignore - conversation=[msg.model_dump()], add_generation_prompt=False, tokenize=False - ) - else: - content_with_tool_calls = content - # TODO: support other formats - if format in self.format_config: - prefix_msg = self.format_config[format]["assistant_prefix_msg"] - 
prefix_token_ids = tokenizer.encode(prefix_msg, add_special_tokens=False) - suffix_msg = self.format_config[format]["assistant_suffix_msg"] - suffix_token_ids = tokenizer.encode(suffix_msg, add_special_tokens=False) - if tool_calls is not None: - content = content_with_tool_calls.split(f"{prefix_msg}")[-1].split(f"{suffix_msg}")[0] - content_token_ids = tokenizer.encode(content, add_special_tokens=False) - if self.input_ids[-len(prefix_token_ids) :] == prefix_token_ids: - append_token_ids = content_token_ids - _loss_mask = [1] * len(content_token_ids) - elif self.input_ids[-len(suffix_token_ids) :] == suffix_token_ids: - append_token_ids = prefix_token_ids + content_token_ids - _loss_mask = [0] * len(prefix_token_ids) + [1] * len(content_token_ids) - else: - max_len = max(len(prefix_token_ids), len(suffix_token_ids)) - raise ValueError( - f"""Unsupported end of message format: - {tokenizer.decode(self.input_ids[-max_len:])}, - {tokenizer.decode(self.input_ids)=}, {self.messages=}""" - ) - if not already_over_long: - append_token_ids += suffix_token_ids - _loss_mask += [1] * len(suffix_token_ids) - self.input_ids += append_token_ids - _attention_mask = [1] * len(append_token_ids) - self.attention_mask += _attention_mask - _delta_position_ids = compute_position_id_with_mask(torch.tensor(_attention_mask)).tolist() - last_position_id = self.position_ids[-1] - _position_ids = [pos_id + last_position_id for pos_id in _delta_position_ids] - self.loss_mask += _loss_mask - self.position_ids += _position_ids - else: - raise ValueError(f"Unsupported format: {format}") - assert len(self.input_ids) == len(self.attention_mask) == len(self.position_ids) == len(self.loss_mask), f"""Request {self.request_id} has different length of {len(self.input_ids)=}, - {len(self.attention_mask)=}, {len(self.position_ids)=}, {len(self.loss_mask)=}""" - - def add_tool_response_message(self, tokenizer: PreTrainedTokenizer, content: str, last_tool: bool, format: Literal["chatml", "qwen"] = 
"chatml") -> None: - """Currently, we only support chatml format.""" - msg = Message(role="tool", content=content) - self.messages.append(msg) - # TODO: support other formats - if format in self.format_config: - merge_tool_responses = self.format_config[format].get("merge_tool_response", False) - prefix_msg = self.format_config[format]["tool_prefix_msg"] - prefix_token_ids = tokenizer.encode(prefix_msg, add_special_tokens=False) - suffix_msg = self.format_config[format]["tool_suffix_msg"] - suffix_token_ids = tokenizer.encode(suffix_msg, add_special_tokens=False) - prefix_resp = self.format_config[format].get("tool_response_prefix_msg", "") - prefix_resp_token_ids = tokenizer.encode(prefix_resp, add_special_tokens=False) - suffix_resp = self.format_config[format].get("tool_response_suffix_msg", "") - suffix_resp_token_ids = tokenizer.encode(suffix_resp, add_special_tokens=False) - full_suffix_token_ids = suffix_resp_token_ids + (suffix_token_ids if last_tool or not merge_tool_responses else []) - content_token_ids = tokenizer.encode(content, add_special_tokens=False) - if self.input_ids[-len(prefix_token_ids) :] == prefix_token_ids or self.input_ids[-len(suffix_resp_token_ids) :] == suffix_resp_token_ids: - append_token_ids = prefix_resp_token_ids + content_token_ids + full_suffix_token_ids - elif self.input_ids[-len(prefix_resp_token_ids) :] == prefix_resp_token_ids: - append_token_ids = content_token_ids + full_suffix_token_ids - elif self.input_ids[-len(suffix_token_ids) :] == suffix_token_ids: - append_token_ids = prefix_token_ids + prefix_resp_token_ids + content_token_ids + full_suffix_token_ids - else: - raise ValueError(f"Unsupported end of message format: {tokenizer.decode(self.input_ids[-len(prefix_token_ids) :])}") - self.input_ids += append_token_ids - _attention_mask = [1] * len(append_token_ids) - self.attention_mask += _attention_mask - _delta_position_ids = compute_position_id_with_mask(torch.tensor(_attention_mask)).tolist() - last_position_id = 
self.position_ids[-1] - _position_ids = [pos_id + last_position_id for pos_id in _delta_position_ids] - self.loss_mask += [0] * len(append_token_ids) - self.position_ids += _position_ids - else: - raise ValueError(f"Unsupported format: {format}") - assert len(self.input_ids) == len(self.attention_mask) == len(self.position_ids) == len(self.loss_mask), f"""Request {self.request_id} has different length of {len(self.input_ids)=}, - {len(self.attention_mask)=}, {len(self.position_ids)=}, {len(self.loss_mask)=}""" + self.messages.append(Message(role="assistant", content=content, tool_calls=tool_calls)) + content = tokenizer.apply_chat_template([*BASE_CHAT_HISTORY, self.messages[-1]], tools=([tool.model_dump() for tool in self.tool_schemas] if self.tool_schemas else None), add_generation_prompt=False, tokenize=False) + content_ids = tokenizer.encode(content[self.base_conv_with_gen_prompt_end_pos :], add_special_tokens=False) + self._update_input_ids(content_ids, attention_mask=True, loss_mask=True) + + def add_tool_response_messages(self, tokenizer: PreTrainedTokenizer, contents: list[str]) -> None: + if not contents: + return + self.messages.extend([Message(role="tool", content=content) for content in contents]) + content = tokenizer.apply_chat_template([*BASE_CHAT_HISTORY, *self.messages[-len(contents) :]], tools=([tool.model_dump() for tool in self.tool_schemas] if self.tool_schemas else None), add_generation_prompt=False, tokenize=False) + content_ids = tokenizer.encode(content[self.base_conv_wo_gen_prompt_end_pos :], add_special_tokens=False) + self._update_input_ids(content_ids, attention_mask=True, loss_mask=False) def update_metrics(self, metrics: Any, tool_id: str) -> None: """ @@ -221,6 +186,19 @@ def finalize( ) -> None: self.state = AsyncRolloutRequestStateEnum.COMPLETED self.reward_scores = reward_scores + if self.enable_tokenization_sanity_check: + full_tokens = tokenizer.apply_chat_template([msg.model_dump() for msg in self.messages], 
tools=([tool.model_dump() for tool in self.tool_schemas] if self.tool_schemas else None), add_generation_prompt=False, tokenize=True) + if self.input_ids != full_tokens: + logger.warning("Inconsistent training and inference tokenization detected. This may lead to unexpected behavior during training. Please review your chat template to determine if this is intentional. For more information, refer to the multiturn README.md.") + logger.info(f"Inference tokenization result:\n{tokenizer.decode(full_tokens, skip_special_tokens=False)}\ntraining content:\n{tokenizer.decode(self.input_ids, skip_special_tokens=False)}") + + # In case we failed to generate the assistant message and the generation prompt ids were already added to input_ids, remove them from the end of input_ids + if self.input_ids[-len(self.generation_prompt_ids) :] == self.generation_prompt_ids: + self.input_ids = self.input_ids[: -len(self.generation_prompt_ids)] + self.attention_mask = self.attention_mask[: -len(self.generation_prompt_ids)] + self.position_ids = self.position_ids[: -len(self.generation_prompt_ids)] + self.loss_mask = self.loss_mask[: -len(self.generation_prompt_ids)] + self.response_ids = self.input_ids[len(self.prompt_ids) :] if finish_reason_type == FinishReasonTypeEnum.STOP: pass diff --git a/verl/workers/rollout/sglang_rollout/sglang_rollout.py b/verl/workers/rollout/sglang_rollout/sglang_rollout.py index bd4d635d259..fdadb324f77 100644 --- a/verl/workers/rollout/sglang_rollout/sglang_rollout.py +++ b/verl/workers/rollout/sglang_rollout/sglang_rollout.py @@ -63,7 +63,6 @@ OpenAIFunctionToolCall, ) from verl.utils.debug import GPUMemoryLogger -from verl.utils.model import compute_position_id_with_mask from verl.utils.net_utils import is_ipv6 from verl.utils.torch_functional import ( get_response_mask, @@ -749,16 +748,9 @@ async def _async_rollout_a_request( for tool_call in parsed_tool_calls ] ) - for i, (tool_call, (resp, reward, metrics)) in enumerate(zip(parsed_tool_calls, 
tool_call_results)): - _req.add_tool_response_message( - self.tokenizer, - resp, - (i == len(parsed_tool_calls) - 1), - format=self.config.multi_turn.format, - ) + _req.add_tool_response_messages(self.tokenizer, [resp for resp, _, _ in tool_call_results]) + for tool_call, (resp, reward, metrics) in zip(parsed_tool_calls, tool_call_results): _req.update_metrics(metrics, tool_call.function.name) - if len(_req.input_ids) >= self.config.max_model_len: - break if len(_req.input_ids) >= self.config.max_model_len: finish_reason_type = FinishReasonTypeEnum.STOP break @@ -766,17 +758,17 @@ else: raise ValueError(f"Unexpected tool calling last message state: {_req.messages[-1]}") elif _req.state == AsyncRolloutRequestStateEnum.RUNNING: + # Only continue the conversation if the prompt length is not greater than max_model_len - 1, + # since SGLang raises an error when max_new_tokens + 1 is greater than max_model_len (the extra token accounts for the EOS token). + if len(_req.get_generation_prompt_ids(self.tokenizer)) + 1 >= self.config.max_model_len: + finish_reason_type = FinishReasonTypeEnum.LENGTH + break output = await self._handle_engine_call(_req, do_sample, is_validate, **kwargs) content = output["text"] finish_reason_type = FinishReasonTypeEnum.from_str(output["meta_info"]["finish_reason"]["type"]) current_turns += 1 if finish_reason_type == FinishReasonTypeEnum.LENGTH: - _req.add_assistant_message( - self.tokenizer, - content, - already_over_long=True, - format=self.config.multi_turn.format, - ) + _req.add_assistant_message(self.tokenizer, content) break else: if self._function_call_parser and self._function_call_parser.has_tool_call(content): @@ -808,27 +800,14 @@ ) ) if len(parsed_tool_calls) > 0: - _req.add_assistant_message( - self.tokenizer, - normed_content, - tool_calls=parsed_tool_calls, - format=self.config.multi_turn.format, - ) + _req.add_assistant_message(self.tokenizer, normed_content, 
tool_calls=parsed_tool_calls) else: - _req.add_assistant_message( - self.tokenizer, - content, - format=self.config.multi_turn.format, - ) + _req.add_assistant_message(self.tokenizer, content) finish_reason_type = FinishReasonTypeEnum.STOP _req.state = AsyncRolloutRequestStateEnum.COMPLETED break else: - _req.add_assistant_message( - self.tokenizer, - content, - format=self.config.multi_turn.format, - ) + _req.add_assistant_message(self.tokenizer, content) break if current_turns >= self.config.multi_turn.max_turns: @@ -851,7 +830,9 @@ async def calc_reward_and_release_fn(name: str, tool: BaseTool): return _req async def _handle_engine_call(self, _req: AsyncRolloutRequest, do_sample: bool, is_validate: bool, override_n: bool = True, **kwargs) -> dict: - generation_prompt_ids = _req.get_generation_prompt(self.tokenizer) + generation_prompt_ids = _req.get_generation_prompt_ids(self.tokenizer) + # Adjust max_new_tokens to ensure it is not greater than max_model_len - 1 + # SGLang raises an error when max_new_tokens + 1 is greater to max_model_len (the extra token accounts for the EOS token). 
max_new_tokens = min(self.config.response_length, self.config.max_model_len - len(generation_prompt_ids) - 1) if not do_sample: kwargs = dict( @@ -889,9 +870,9 @@ async def _handle_engine_call(self, _req: AsyncRolloutRequest, do_sample: bool, return output async def _handle_pending_state(self, _req: AsyncRolloutRequest) -> AsyncRolloutRequest: - if _req.tools is not None: + if _req.tool_schemas is not None: tool_creation_coroutines = [] - for tool_schema in _req.tools: + for tool_schema in _req.tool_schemas: tool = self._tool_map[tool_schema.function.name] create_kwargs = _req.tools_kwargs[tool.name].get("create_kwargs", {}) tool_creation_coroutines.append(tool.create(_req.request_id, **create_kwargs)) @@ -1050,67 +1031,36 @@ def _preprocess_prompt_to_async_rollout_requests(self, prompts: DataProto, n: in for rollout_offset in range(n): if self._tool_schemas: _tools_kwargs = prompts.non_tensor_batch["tools_kwargs"][data_idx] - _tool_schemas = [] - for k in _tools_kwargs.keys(): - _tool_schemas.append(self._tool_map[k].get_openai_tool_schema()) - prompt_with_chat_template = self.tokenizer.apply_chat_template( - conversation=raw_prompt, - tools=[tool.model_dump() for tool in _tool_schemas], - add_generation_prompt=True, - tokenize=False, - return_tensors="pt", - ) - input_data = self.tokenizer( - prompt_with_chat_template, - return_tensors="pt", - add_special_tokens=False, - ) - _input_ids = input_data["input_ids"][0].tolist() - _attention_mask = input_data["attention_mask"][0].tolist() - _position_ids = compute_position_id_with_mask(input_data["attention_mask"][0]).tolist() - if len(_input_ids) > self.config.prompt_length: - logger.warning( - "Prompt {} has length {} greater than max_prompt_len {}", - data_idx, - len(_input_ids), - self.config.prompt_length, - ) - _input_ids = _input_ids[: self.config.prompt_length] - _attention_mask = _attention_mask[: self.config.prompt_length] - _position_ids = _position_ids[: self.config.prompt_length] + _tool_schemas = 
[self._tool_map[k].get_openai_tool_schema() for k in _tools_kwargs.keys()] + _input_ids = None + _attention_mask = None else: _input_ids = _pre_process_inputs(self.pad_token_id, prompts.batch["input_ids"][data_idx]) _attention_mask = _pre_process_inputs(0, prompts.batch["attention_mask"][data_idx]) - _position_ids = compute_position_id_with_mask(torch.tensor(_attention_mask)).tolist() - _tool_schemas = [] _tools_kwargs = {} + _tool_schemas = None req = AsyncRolloutRequest( batch_data_id=data_idx, rollout_offset=rollout_offset, request_id=str(uuid4()), state=AsyncRolloutRequestStateEnum.PENDING, - messages=[Message.model_validate(msg) for msg in raw_prompt], - tools=_tool_schemas, + messages=raw_prompt.tolist(), + tool_schemas=_tool_schemas, tools_kwargs=_tools_kwargs, input_ids=_input_ids, - prompt_ids=_input_ids, response_ids=[], attention_mask=_attention_mask, - prompt_attention_mask=_attention_mask, response_attention_mask=[], - position_ids=_position_ids, - prompt_position_ids=_position_ids, response_position_ids=[], - loss_mask=[0] * len(_input_ids), - prompt_loss_mask=[0] * len(_input_ids), response_loss_mask=[], reward_scores={}, + max_prompt_len=self.config.prompt_length, max_response_len=self.config.response_length, - max_model_len=min( - self.config.max_model_len, - self.config.prompt_length + self.config.response_length, - ), + max_model_len=min(self.config.max_model_len, self.config.prompt_length + self.config.response_length), + use_inference_chat_template=self.config.multi_turn.use_inference_chat_template, + enable_tokenization_sanity_check=self.config.multi_turn.enable_tokenization_sanity_check, + tokenizer=self.tokenizer, ) error_message = f"Request {req.request_id} has mismatched lengths: input_ids={len(req.input_ids)}, attention_mask={len(req.attention_mask)}, position_ids={len(req.position_ids)}, loss_mask={len(req.loss_mask)}"