145 changes: 145 additions & 0 deletions docs/sglang_multiturn/multiturn.rst
@@ -40,6 +40,151 @@ Finally, set the ``tools_config_file`` in your rollout config:

This allows integration of customized tool behaviors during actor rollout steps.

Multi-turn Tokenization
~~~~~~~~~~~~~~~~~~~~~~~

Tokenizing multi-turn rollouts poses a challenge: after applying the chat template and tokenizing the full message list, it’s hard to identify which tokens belong to assistant messages. Since the token list is flat, it lacks direct alignment with the message roles.

To address this, we adopt a **delta-based tokenization** strategy. Each time the LLM generates a new message, we:

1. Apply the chat template to all prior messages (`messages[:i]`).
2. Apply the chat template again including the latest message (`messages[:i+1]`).
3. Tokenize only the *delta* between these two serialized message strings.

This ensures that only tokens generated by the assistant are included in the loss mask.

.. code-block:: python

   # Render the generation prompt (e.g., "<|im_start|>assistant") in prev via
   # add_generation_prompt=True so it is excluded from the delta (and the loss)
   prev = tokenizer.apply_chat_template(messages[:i], add_generation_prompt=True, tokenize=False)
   curr = tokenizer.apply_chat_template(messages[:i+1], add_generation_prompt=False, tokenize=False)
   new_token_ids = tokenizer.encode(curr[len(prev):], add_special_tokens=False)
   token_ids += new_token_ids
   loss_mask += [1] * len(new_token_ids)  # Mask only the new assistant tokens

While we’ve validated that this approach produces results consistent with tokenizing the full message list, future models’ chat templates could break that consistency. To guard against silent mismatches, we compare the delta-based tokenization against full tokenization by default at the end of each rollout.
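
The comparison can be sketched with toy stand-ins for the template and tokenizer (illustrative only; ``apply_chat_template`` and ``encode`` below are pure-Python placeholders, not verl's or Hugging Face's implementations):

.. code-block:: python

   # Toy stand-ins for a chat template and tokenizer; the real check uses
   # tokenizer.apply_chat_template(...) and tokenizer.encode(...).
   def apply_chat_template(messages):
       return "".join(f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>\n" for m in messages)

   def encode(text):
       return [ord(c) for c in text]  # character-level "tokenization"

   messages = [
       {"role": "system", "content": "You are a helpful assistant."},
       {"role": "user", "content": "What is 2 + 2?"},
       {"role": "assistant", "content": "2 + 2 = 4."},
   ]

   # Delta-based: tokenize each turn's serialized delta and concatenate.
   delta_ids, prev = [], ""
   for i in range(len(messages)):
       curr = apply_chat_template(messages[: i + 1])
       delta_ids += encode(curr[len(prev):])
       prev = curr

   # Full: serialize and tokenize the entire conversation at once.
   full_ids = encode(apply_chat_template(messages))

   # The sanity check: both strategies must yield identical token ids.
   assert delta_ids == full_ids

For this simple concatenative template the check always passes; real chat templates that rewrite earlier turns are exactly the cases where it fails.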

If you see the following warning, enable `INFO` log level to inspect the mismatched outputs:

.. code-block::

   Inconsistent training and inference tokenization detected. This may lead to unexpected behavior during training. Please review your chat template to determine if this is intentional. For more information, refer to the multiturn README.md.

If the discrepancy is expected, you can disable the sanity check via:

``actor_rollout_ref.rollout.multi_turn.enable_tokenization_sanity_check=False``

Special Cases
^^^^^^^^^^^^^

Some models (e.g., Qwen/QwQ-32B and Qwen3 series) remove internal reasoning content during chat template rendering. As a result, the message content can vary across turns, making the delta-based tokenization inaccurate.

For example, for the following conversation:

.. code-block:: python

   messages = [
       {"role": "system", "content": "You are a helpful assistant."},
       {"role": "user", "content": "What is 2 + 2?"},
       {"role": "assistant", "content": "<think>user asked about a simple math question.</think> 2 + 2 = 4."},
       {"role": "user", "content": "Explain why."},
       {"role": "assistant", "content": "<think>user wants to know the reasoning behind the answer. Search for a good explanation</think>",
        "tool_calls": [{"id": "tool1", "type": "search", "arguments": {"query": "Why is 2 + 2 = 4?"}}]},
       {"role": "tool", "content": "The sum of two and two is four because it is a basic arithmetic operation."},
       {"role": "assistant", "content": "<think>The tool provided a good explanation.</think>The sum of two and two is four because it is a basic arithmetic operation."}
   ]

1. Qwen/QwQ-32B removes all reasoning content except that of the last assistant message when applying the chat template:

.. code-block:: text

   <|im_start|>system
   You are a helpful assistant.<|im_end|>
   <|im_start|>user
   What is 2 + 2?<|im_end|>
   <|im_start|>assistant
   2 + 2 = 4.<|im_end|>
   <|im_start|>user
   Explain why.<|im_end|>
   <|im_start|>assistant
   <tool_call>
   {"name": "", "arguments": {"query": "Why is 2 + 2 = 4?"}}
   </tool_call><|im_end|>
   <|im_start|>user
   <tool_response>
   The sum of two and two is four because it is a basic arithmetic operation.
   </tool_response><|im_end|>
   <|im_start|>assistant
   <think>The tool provided a good explanation.</think> The sum of two and two is four because it is a basic arithmetic operation.<|im_end|>

2. The Qwen3 series removes all reasoning content before the last user message:

.. code-block:: text

   <|im_start|>system
   You are a helpful assistant.<|im_end|>
   <|im_start|>user
   What is 2 + 2?<|im_end|>
   <|im_start|>assistant
   2 + 2 = 4.<|im_end|>
   <|im_start|>user
   Explain why.<|im_end|>
   <|im_start|>assistant
   <think>
   user wants to know the reasoning behind the answer. Search for a good explanation
   </think>

   <tool_call>
   {"name": "", "arguments": {"query": "Why is 2 + 2 = 4?"}}
   </tool_call><|im_end|>
   <|im_start|>user
   <tool_response>
   The sum of two and two is four because it is a basic arithmetic operation.
   </tool_response><|im_end|>
   <|im_start|>assistant
   <think>
   The tool provided a good explanation.
   </think>

   The sum of two and two is four because it is a basic arithmetic operation.<|im_end|>

To handle this, we fall back to a **fixed base conversation** containing only a single system and user message. Since this base doesn’t include assistant messages or reasoning content, it remains consistent across turns.

.. code-block:: python

   BASE_CHAT_HISTORY = [
       {"role": "system", "content": "You are a helpful assistant."},
       {"role": "user", "content": "I am a user."}
   ]
   prev = tokenizer.apply_chat_template(BASE_CHAT_HISTORY, add_generation_prompt=True, tokenize=False)
   curr = tokenizer.apply_chat_template([*BASE_CHAT_HISTORY, messages[i]], add_generation_prompt=False, tokenize=False)
   new_token_ids = tokenizer.encode(curr[len(prev):], add_special_tokens=False)
   token_ids += new_token_ids
   loss_mask += [1] * len(new_token_ids)  # Mask only the new assistant tokens

This method works well for the Qwen3 series. However, Qwen/QwQ-32B currently has a bug in its chat template; a fix_ has been proposed but not yet merged. Until then, use the following commands to download the fixed model revision:

.. code-block:: bash

   pip install huggingface_hub
   huggingface-cli download Qwen/QwQ-32B --revision refs/pr/81

.. _fix: https://huggingface.co/Qwen/QwQ-32B/discussions/81

Discrepancy Between Training and Inference Templates
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Although the above approach fixes the delta mismatch issue, the removal of reasoning content in the inference-time chat template introduces a new discrepancy: training uses the full reasoning content, while inference does not.

This mismatch can affect model performance in unpredictable ways. To avoid it, we default to using the full response (including reasoning) for both training and rollout.
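
The discrepancy can be sketched with a toy reasoning-stripping pass (``strip_reasoning`` is a hypothetical helper for illustration; real chat templates do this removal during rendering):

.. code-block:: python

   import re

   # Hypothetical helper mimicking an inference-time template that strips
   # <think>...</think> reasoning from past assistant turns.
   def strip_reasoning(content):
       return re.sub(r"<think>.*?</think>\s*", "", content, flags=re.DOTALL)

   past_turn = "<think>user asked a simple math question.</think> 2 + 2 = 4."

   train_view = past_turn                   # our default: reasoning kept
   infer_view = strip_reasoning(past_turn)  # default template in production

   assert infer_view == "2 + 2 = 4."
   assert len(train_view) > len(infer_view)  # training sequences are longer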

However, this approach comes with trade-offs:

1. Long reasoning content can easily exceed the model’s context window, especially in multi-turn rollouts.
2. There is now a mismatch between rollout and the production environment: models will not see reasoning content from past turns if you use the default chat template in production.

We are still evaluating the impact of these issues. If you experience context length problems or prefer rollouts that match production (i.e., exclude reasoning), you can enable:

``actor_rollout_ref.rollout.multi_turn.use_inference_chat_template = True``

GSM8K Multi-turn Training Performance
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

1 change: 0 additions & 1 deletion examples/sglang_multiturn/config/gsm8k_multiturn_grpo.yaml
@@ -19,5 +19,4 @@ actor_rollout_ref:
    multi_turn:
      enable: True
      max_turns: 5
-     format: qwen
      # tool_config_path: "./config/tool_config/gsm8k_tool_config.yaml"
@@ -19,6 +19,5 @@ actor_rollout_ref:
    multi_turn:
      enable: True
      max_turns: 5
-     format: qwen
      # tool_config_path: "./config/tool_config/gsm8k_tool_config.yaml"

53 changes: 53 additions & 0 deletions examples/sglang_multiturn/run_qwen3-4b_gsm8k_multiturn.sh
@@ -0,0 +1,53 @@
# run on 8xH100
# make sure your current working directory is the root of the project

set -x

ulimit -n 65535

PROJECT_DIR="$(pwd)"
CONFIG_PATH="$PROJECT_DIR/examples/sglang_multiturn/config"

python3 -m verl.trainer.main_ppo \
--config-path="$CONFIG_PATH" \
--config-name='gsm8k_multiturn_grpo' \
algorithm.adv_estimator=grpo \
data.train_batch_size=256 \
data.max_prompt_length=1024 \
data.max_response_length=1024 \
data.filter_overlong_prompts=True \
data.truncation='error' \
data.return_raw_chat=True \
actor_rollout_ref.model.path=Qwen/Qwen3-4B \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.ppo_mini_batch_size=256 \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=32 \
actor_rollout_ref.actor.use_kl_loss=True \
actor_rollout_ref.actor.kl_loss_coef=0.001 \
actor_rollout_ref.actor.kl_loss_type=low_var_kl \
actor_rollout_ref.actor.entropy_coeff=0 \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.fsdp_config.param_offload=False \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
actor_rollout_ref.rollout.name=sglang \
actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \
actor_rollout_ref.rollout.n=16 \
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \
actor_rollout_ref.ref.fsdp_config.param_offload=True \
algorithm.use_kl_in_reward=False \
trainer.critic_warmup=0 \
trainer.logger=['console','wandb'] \
trainer.project_name='gsm8k_async_rl' \
trainer.experiment_name='qwen3-4b_function_rm-gsm8k-sgl-multi-w-tool-verify-n16' \
trainer.n_gpus_per_node=8 \
trainer.nnodes=1 \
trainer.save_freq=-1 \
trainer.test_freq=20 \
data.train_files=$HOME/data/gsm8k/train.parquet \
data.val_files=$HOME/data/gsm8k/test.parquet \
actor_rollout_ref.rollout.multi_turn.tool_config_path="$PROJECT_DIR/examples/sglang_multiturn/config/tool_config/gsm8k_tool_config.yaml" \
    trainer.total_epochs=15 "$@"

@@ -164,9 +164,9 @@ def test_rollout_req_creation(self, mock_env, mock_engine, mock_sampling, search
req_list = rollout._preprocess_prompt_to_async_rollout_requests(search_data_proto, n=1)
assert len(req_list) == 1
assert req_list[0].state == AsyncRolloutRequestStateEnum.PENDING
-        assert len(req_list[0].tools) == 1
-        print(type(req_list[0].tools[0]))
-        assert req_list[0].tools[0] == OpenAIFunctionToolSchema(
+        assert len(req_list[0].tool_schemas) == 1
+        print(type(req_list[0].tool_schemas[0]))
+        assert req_list[0].tool_schemas[0] == OpenAIFunctionToolSchema(
type="function",
function=OpenAIFunctionSchema(
name="search",
6 changes: 3 additions & 3 deletions tests/workers/rollout/test_sglang_async_rollout_sf_tools.py
@@ -220,9 +220,9 @@ def test_rollout_req_creation(self, mock_env, mock_engine, mock_sampling, sandbo
req_list = rollout._preprocess_prompt_to_async_rollout_requests(sandbox_data_proto, n=1)
assert len(req_list) == 1
assert req_list[0].state == AsyncRolloutRequestStateEnum.PENDING
-        assert len(req_list[0].tools) == 1
-        print(type(req_list[0].tools[0]))
-        assert req_list[0].tools[0] == OpenAIFunctionToolSchema(
+        assert len(req_list[0].tool_schemas) == 1
+        print(type(req_list[0].tool_schemas[0]))
+        assert req_list[0].tool_schemas[0] == OpenAIFunctionToolSchema(
type="function",
function=OpenAIFunctionSchema(
name="code_interpreter",
3 changes: 2 additions & 1 deletion tests/workers/rollout/utils_sglang.py
@@ -152,7 +152,8 @@ def get_rollout_config(max_response_length, max_prompt_length, dtype, tensor_par
"max_turns": 4,
"enable": True,
"tool_config_path": tool_config_path,
-        "format": "chatml",
+        "use_inference_chat_template": False,
+        "enable_tokenization_sanity_check": True,
},
"max_model_len": None,
**sampling_params,
11 changes: 10 additions & 1 deletion verl/trainer/config/ppo_megatron_trainer.yaml
@@ -178,7 +178,16 @@ actor_rollout_ref:
enable: False # set to True for multi-turn tool interaction tasks; should set rollout.name to sglang as well
max_turns: null # null for no limit (default max_length // 3)
tool_config_path: null # null for no tool
-      format: chatml # chatml, more formats will be supported in the future
+      # - When set to True, the model's default chat template is used for multi-turn rollout, which typically matches production behavior.
+      # - When set to False, the token ids recorded for training are used instead; unlike the default chat template, these always include
+      #   the model's full output, which may contain additional content such as reasoning, leading to longer prompts.
+      use_inference_chat_template: True
+      # Tokenization is performed turn by turn and the resulting token ids are concatenated to form the full conversation.
+      # To ensure this matches the result of tokenizing the entire conversation at once, a sanity check is run at the end of each
+      # multi-turn rollout to compare the two sets of token ids.
+      # Some models are known to produce different tokenization results when tokenizing turn by turn vs. all at once. This behavior
+      # has already been validated for them. To reduce excessive warnings, you can turn off the sanity check for these models if you
+      # are using their default chat template: Qwen/QwQ-32B, Qwen/Qwen3-xxB
+      enable_tokenization_sanity_check: True

critic:
rollout_n: ${actor_rollout_ref.rollout.n}
15 changes: 12 additions & 3 deletions verl/trainer/config/ppo_trainer.yaml
@@ -434,12 +434,21 @@ actor_rollout_ref:
# null for no tool
tool_config_path: null

-      # chatml, more formats will be supported in the future
-      format: chatml

# null for default callback
completion_callback: null

+      # - When set to True, the model's default chat template is used for multi-turn rollout, which typically matches production behavior.
+      # - When set to False, the token ids recorded for training are used instead; unlike the default chat template, these always include
+      #   the model's full output, which may contain additional content such as reasoning. This maintains consistency between training
+      #   and rollout, but leads to longer prompts.
+      use_inference_chat_template: False
+
+      # Tokenization is performed turn by turn and the resulting token ids are concatenated to form the full conversation.
+      # To ensure this matches the result of tokenizing the entire conversation at once, a sanity check is run at the end of each
+      # multi-turn rollout to compare the two sets of token ids.
+      # Some models are known to produce different tokenization results when tokenizing turn by turn vs. all at once. This behavior
+      # has already been validated for them. To reduce excessive warnings, you can turn off the sanity check for these models if you
+      # are using their default chat template: Qwen/QwQ-32B, Qwen/Qwen3-xxB
+      enable_tokenization_sanity_check: True

# configs for the critic
critic:
