Routing replay implemented with SGLang and FSDP #4443
KawaiiNotHawaii wants to merge 1 commit into verl-project:main
Conversation
Code Review
This pull request introduces MoE routing replay support for Qwen MoE models using SGLang and FSDP. The implementation involves significant changes, including monkey-patching transformer models, adding complex data synchronization logic for routing information, and updating the training loop. While the overall approach is sound, I have identified several critical issues that need to be addressed. These include the use of assert statements in production code paths, which can lead to unexpected crashes, a potential bug in tensor type casting, hardcoded values specific to an internal environment, and a logic bug in tensor manipulation. Addressing these issues is crucial for the stability, correctness, and portability of the new feature.
selected_experts = routing_map.view(-1, self.top_k)#.long() # TODO cx note: review required
routing_weights = routing_weights.gather(1, selected_experts)
The .long() cast on selected_experts is commented out. The torch.gather operation requires the index tensor to be of type LongTensor. If routing_map is not already a LongTensor, this will cause a RuntimeError at line 140. Given the TODO comment, this seems to be a known risk. To prevent potential crashes, it's safer to ensure the correct dtype by uncommenting the .long() cast.
Suggested change:
selected_experts = routing_map.view(-1, self.top_k).long()
routing_weights = routing_weights.gather(1, selected_experts)
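To see why the dtype matters, a minimal standalone sketch (shapes here are illustrative, not the model's actual dimensions): `torch.gather` only accepts int64 index tensors, so an int32 `routing_map` must be cast before gathering.

```python
import torch

top_k = 2
num_tokens, num_experts = 2, 4
# Simulate a routing map that arrives as int32, e.g. after cross-process transfer.
routing_map = torch.tensor([0, 1, 1, 2], dtype=torch.int32)
routing_weights = torch.ones(num_tokens, num_experts)

# Without .long(), torch.gather raises a RuntimeError about the index dtype.
selected_experts = routing_map.view(-1, top_k).long()
gathered = routing_weights.gather(1, selected_experts)  # shape: (num_tokens, top_k)
```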
with marked_timer("update_actor", timing_raw, color="red"):
    batch.meta_info["multi_turn"] = self.config.actor_rollout_ref.rollout.multi_turn.enable
    actor_output = self.actor_rollout_wg.update_actor(batch)
    assert self.actor_rollout_wg._routing_cache == {} and self.actor_rollout_wg._routing_refs == {} and self.actor_rollout_wg._routing_prepared_batches == set(), f"self.actor_rollout_wg._routing_cache of len {len(self.actor_rollout_wg._routing_cache)} ={self.actor_rollout_wg._routing_cache}, self.actor_rollout_wg._routing_refs of len {len(self.actor_rollout_wg._routing_refs)} = {self.actor_rollout_wg._routing_refs == {}}, self.actor_rollout_wg._routing_prepared_batches of len {len(self.actor_rollout_wg._routing_prepared_batches)} = {self.actor_rollout_wg._routing_prepared_batches}"
An assert statement is used here to check the state of routing caches. Assertions are for debugging and are removed when Python is run in optimized mode (with the -O flag). Using assert for runtime checks in production code is dangerous as it can be disabled, hiding potential issues, or it can crash the program if the assertion fails. This should be replaced with a proper check that raises an exception or logs a warning if the condition is not met.
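A minimal sketch of the suggested replacement, pulled out into a free function for illustration (`check_routing_state` is a hypothetical name; the three worker attributes are the ones from the PR):

```python
def check_routing_state(routing_cache: dict, routing_refs: dict, routing_prepared: set) -> None:
    """Validate that routing state was cleaned up after update_actor.

    Unlike `assert`, this check survives `python -O` and raises a normal
    RuntimeError with a concise summary instead of dumping whole dicts.
    """
    if routing_cache or routing_refs or routing_prepared:
        raise RuntimeError(
            f"[routing] state not cleared after update_actor: "
            f"cache={len(routing_cache)} entries, refs={len(routing_refs)} entries, "
            f"prepared_batches={len(routing_prepared)} entries"
        )
```

The trainer would then call `check_routing_state(wg._routing_cache, wg._routing_refs, wg._routing_prepared_batches)` right after `update_actor` instead of the assert.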
sp_size=self.ulysses_sequence_parallel_size,
)

assert micro_batch.get("routing_ids", None) is not None, f"[ERROR] routing replay not implemented for vlm models."
bad_pad = (mask == 0) & any_nonpad  # non-zero routing under mask==0
if bad_pad.any():
    idx = bad_pad.nonzero(as_tuple=False)[:8].tolist()
    raise AssertionError(
        f"[routing] non-zero routing under mask==0 at {idx} (showing up to 8)"
    )
This AssertionError will crash the program if there's a mismatch in routing data. While useful for debugging, assert statements should not be used for data validation in production code paths, as they can be disabled. Please replace this with a ValueError or RuntimeError that provides a clear error message. This comment also applies to other assertions in this method.
Suggested change:
if bad_pad.any():
    idx = bad_pad.nonzero(as_tuple=False)[:8].tolist()
    raise RuntimeError(
        f"[routing] non-zero routing under mask==0 at {idx} (showing up to 8)"
    )

# set -xuo pipefail
# while true; do

export http_proxy=http://oversea-squid2.ko.txyun:11080 https_proxy=http://oversea-squid2.ko.txyun:11080 no_proxy=localhost,127.0.0.1,localaddress,localdomain.com,internal,corp.kuaishou.com,test.gifshow.com,staging.kuaishou.com
These proxy settings are hardcoded for an internal environment and will not work elsewhere. Please remove them or make them configurable before merging.
try:
    # TODO cx_note add condition to only consolidate if is to CPU, or it is said to be time-consuming on GPU
    # self.batch = self.batch.contiguous().consolidate() # TODO cx_note will this cause memory leak of the old batch? check out "/nlp_group/sunchenxi/qwen_moe/logs/qwen3_moe_2025-10-28-23:08:51.log" for log of enabling it (searching is_view=no and you'll find all no views)
    self.batch = self.batch.to(device)
except RuntimeError as re:
The inner try...except block is redundant because it attempts the same self.batch.to(device) operation that failed in the outer try block. If the first call fails, the second will fail for the same reason. The logic should fall back directly to _safely_move_tensordict in the outer except block.
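A sketch of the suggested restructuring as a standalone helper for illustration (`move_batch_to_device` is a hypothetical name; `_safely_move_tensordict` is the PR's fallback):

```python
def move_batch_to_device(batch, device, safe_mover):
    """Try the fast direct move first; on RuntimeError, fall back to the
    slower safe mover instead of retrying the identical .to() call."""
    try:
        return batch.to(device)
    except RuntimeError:
        return safe_mover(batch, device)
```

In the worker this would read `self.batch = move_batch_to_device(self.batch, device, self._safely_move_tensordict)`.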
Suggested change:
self.batch = self._safely_move_tensordict(self.batch, device)

if 'boxed' not in solution_str[-300:]:
    return False, ""
answer = str(answer)

try:
    solution_val = parse(solution_str[-300:])
    if "boxed" in answer:
        gt_val = parse(answer)
    else:
        boxed_answer = "\\boxed{" + answer + "}"
        gt_val = parse(boxed_answer)
    if m_verify(solution_val, gt_val):
        return True, ""
    else:
        return False, ""
except Exception as e:
    return False, ""
The refactored verify function no longer returns the extracted prediction; it always returns an empty string. This is a regression from the previous implementation, which returned the predicted value. The compute_score function relies on this prediction to include it in its returned dictionary. With this change, the pred field in the output of compute_score will always be empty, which is likely not the intended behavior. The verify function should be updated to return the extracted prediction.
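A sketch of a `verify` that surfaces the prediction again. The `parse`/`m_verify` stand-ins below are simplified string-based substitutes for the math-verify calls in the snippet, just to keep the example self-contained; the structural point is that both success and failure branches return the extracted value instead of `""`.

```python
import re

def parse(expr: str):
    """Stand-in for the real parse: extract the content of the last \\boxed{...}."""
    matches = re.findall(r"\\boxed\{([^{}]*)\}", expr)
    return matches[-1].strip() if matches else None

def m_verify(solution_val, gt_val) -> bool:
    """Stand-in for the real verify: exact string match."""
    return solution_val is not None and solution_val == gt_val

def verify(solution_str: str, answer):
    if "boxed" not in solution_str[-300:]:
        return False, ""
    answer = str(answer)
    try:
        solution_val = parse(solution_str[-300:])
        gt_val = parse(answer) if "boxed" in answer else parse("\\boxed{" + answer + "}")
        pred = "" if solution_val is None else str(solution_val)
        # Return the extracted prediction in both branches so compute_score
        # can populate its `pred` field.
        return m_verify(solution_val, gt_val), pred
    except Exception:
        return False, ""
```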
print(f"Routing replay not enabled.")
return self._batch_level_generate_sequences(prompts, **kwargs)

step_bsz = 32 # NOTE Currently, SGLang expert_distribution_recorder is not functioning as expected when per-entry bsz > 32
The step_bsz is hardcoded to 32 due to a limitation in an external library (SGLang). This magic number makes the code less maintainable and harder to adapt if the library's behavior changes or if it's run in a different environment. This value should be made configurable, for example, by adding it to the rollout configuration.
Suggested change:
step_bsz = self.config.get("routing_replay_step_bsz", 32) # NOTE Currently, SGLang expert_distribution_recorder is not functioning as expected when per-entry bsz > 32
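Whatever the configured value, splitting the rollout batch into sub-batches of at most `step_bsz` entries is straightforward; a sketch (`chunk_prompts` is a hypothetical helper, not code from the PR):

```python
def chunk_prompts(prompts: list, step_bsz: int) -> list:
    """Split a rollout batch into sub-batches of at most step_bsz entries,
    working around the SGLang expert_distribution_recorder limit."""
    if step_bsz <= 0:
        raise ValueError("step_bsz must be positive")
    return [prompts[i:i + step_bsz] for i in range(0, len(prompts), step_bsz)]
```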
For routing replay, we still have to wait for official vLLM and SGLang support.
Hi, I appreciate your implementation of routing replay on verl and sglang. However, when I tried to run the script, I got:

(WorkerDict pid=3380574) Exception: Please set ServerArgs.expert_distribution_recorder_mode to use ExpertDistributionRecorder.

After setting it, I got:

(TaskRunner pid=3317161) Unhandled error (suppress with 'RAY_IGNORE_UNHANDLED_ERRORS=1'): ray::WorkerDict.actor_rollout_compute_log_prob() (pid=3334909, ip=192.168.102.18, actor_id=201afa1f8e08842d00c0e71601000000, repr=<verl.single_controller.ray.base.WorkerDict object at 0x7f38f2d692b0>)
(TaskRunner pid=3317161)   File "/code/RoutingReplay/verl_routing_replay/verl/single_controller/ray/base.py", line 700, in func
(TaskRunner pid=3317161)     return getattr(self.worker_dict[key], name)(*args, **kwargs)
(TaskRunner pid=3317161)   File "/code/RoutingReplay/verl_routing_replay/verl/single_controller/base/decorator.py", line 442, in inner
(TaskRunner pid=3317161)     return func(*args, **kwargs)
(TaskRunner pid=3317161)   File "/code/RoutingReplay/verl_routing_replay/verl/utils/transferqueue_utils.py", line 199, in dummy_inner
(TaskRunner pid=3317161)     return func(*args, **kwargs)
(TaskRunner pid=3317161)   File "/code/RoutingReplay/verl_routing_replay/verl/utils/profiler/profile.py", line 256, in wrapper
(TaskRunner pid=3317161)     return func(self_instance, *args, **kwargs_inner)
(TaskRunner pid=3317161)   File "/code/RoutingReplay/verl_routing_replay/verl/workers/fsdp_workers.py", line 1597, in compute_log_prob
(TaskRunner pid=3317161)     self._ensure_routing_prepared(data, tag='[compute_log_prob]')
(TaskRunner pid=3317161)   File "/code/RoutingReplay/verl_routing_replay/verl/workers/fsdp_workers.py", line 780, in _ensure_routing_prepared
(TaskRunner pid=3317161)     raise AssertionError(
(TaskRunner pid=3317161) AssertionError: [routing] non-zero routing under mask==0 at [[0, 308], [0, 309], [0, 310], [0, 311], [0, 312], [0, 313], [0, 314], [0, 315]] (showing up to 8)

Do you have any thoughts on how to solve these problems? Here is the script:

#!/usr/bin/env bash
set -xeuo pipefail
# set -xuo pipefail
# while true; do
# export http_proxy=http://oversea-squid2.ko.txyun:11080 https_proxy=http://oversea-squid2.ko.txyun:11080 no_proxy=localhost,127.0.0.1,localaddress,localdomain.com,internal,corp.kuaishou.com,test.gifshow.com,staging.kuaishou.com
timestamp=$(date +"%Y-%m-%d-%H:%M:%S")""
loss_mode=vanilla
adv_estimator=grpo
loss_agg_mode="token-mean"
use_kl_in_reward=False
kl_coef=0.0
use_kl_loss=False
kl_loss_coef=0.0
clip_ratio_low=0.2
clip_ratio_high=0.28
max_prompt_length=$((256))
max_response_length=$((256))
enable_overlong_buffer=False
overlong_buffer_len=$((256))
overlong_penalty_factor=1.0
# Recommended setup: (train_prompt_bsz * n_resp_per_prompt) / n_machines <= 32, in respect of SGLang expert_recorder's unstable recording when inference bsz is bigger than 32
train_prompt_bsz=32
train_prompt_mini_bsz=32
n_resp_per_prompt=4
# for TIS
imp_ratio_cap=-1
# Ray
# RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"}
WORKING_DIR=${WORKING_DIR:-"${PWD}"}
RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/recipe/moe/runtime_env.yaml"}
NNODES=${NNODES:-1}
# NNODES=${NNODES:-$(sort -u /etc/mpi/hostfile | wc -l)}
NGPUS_PER_NODE=${NGPUS_PER_NODE:-8}
project_name='DAPO-Qwen3-MOE-30B-FSDP-RoutingReplay'
info_tag=""
exp_name='qwen3moe-'${loss_mode}-${train_prompt_bsz}_${train_prompt_mini_bsz}'-n-'${n_resp_per_prompt}'-len-'${max_response_length}-${info_tag}
# Paths
MODEL_PATH=/model/Qwen/Qwen1.5-MoE-A2.7B-Chat
CKPTS_DIR=${CKPTS_DIR:-"${WORKING_DIR}/ckpts_hdd/${project_name}/${exp_name}"}
OUTPUTS_ROLLOUT_DIR=${OUTPUTS_ROLLOUT_DIR:-"${WORKING_DIR}/outputs/${project_name}/${exp_name}/rollout/"}
OUTPUTS_VALIDATION_DIR=${OUTPUTS_VALIDATION_DIR:-"${WORKING_DIR}/outputs/${project_name}/${exp_name}/validate/"}
mkdir -p $OUTPUTS_ROLLOUT_DIR
mkdir -p $OUTPUTS_VALIDATION_DIR
mkdir -p "${WORKING_DIR}/logs"
TRAIN_FILE="/data/math_data/dapo-math-17k.parquet"
TEST_FILE="/data/math_data/aime-2024.parquet"
echo $OUTPUTS_ROLLOUT_DIR
# rollout
enable_routing_replay=True
rollout_mode="sync"
return_raw_chat="True"
rollout_name="sglang"
if [ "$rollout_mode" = "async" ]; then
# NOTE async rollout mode is not supported yet.
export VLLM_USE_V1=1
fi
# Algorithm
temperature=1.0
top_p=1.0
top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
val_temperature=1.0
val_top_p=0.7
val_top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
# Performance Related Parameter
sp_size=4
use_dynamic_bsz=True
actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 1))
infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 1))
offload=True
gen_tp=8
fsdp_size=8 # default -1
# Trade compute for memory
entropy_checkpointing=True
entropy_from_logits_with_chunking=True
# export RAY_DEDUP_LOGS=0
PYTHONUNBUFFERED=1 python3 -m recipe.moe_routing_replay.main_dapo --config-path=config \
--config-name='dapo_trainer.yaml'\
data.train_files="${TRAIN_FILE}" \
data.val_files="${TEST_FILE}" \
data.prompt_key=prompt \
data.truncation='left' \
data.return_raw_chat=${return_raw_chat} \
data.max_prompt_length=${max_prompt_length} \
data.max_response_length=${max_response_length} \
data.train_batch_size=${train_prompt_bsz} \
actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
algorithm.adv_estimator=${adv_estimator} \
algorithm.use_kl_in_reward=${use_kl_in_reward} \
algorithm.kl_ctrl.kl_coef=${kl_coef} \
actor_rollout_ref.actor.entropy_checkpointing=${entropy_checkpointing} \
actor_rollout_ref.actor.entropy_from_logits_with_chunking=${entropy_from_logits_with_chunking} \
actor_rollout_ref.actor.policy_loss.loss_mode=${loss_mode} \
actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
actor_rollout_ref.actor.clip_ratio_c=10.0 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.model.enable_activation_offload=${offload} \
critic.model.enable_activation_offload=${offload} \
actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \
actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
actor_rollout_ref.model.path="${MODEL_PATH}" \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.actor.optim.lr_warmup_steps=10 \
actor_rollout_ref.actor.optim.weight_decay=0.1 \
actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \
actor_rollout_ref.actor.entropy_coeff=0 \
actor_rollout_ref.actor.grad_clip=1.0 \
actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \
actor_rollout_ref.rollout.name=$rollout_name \
actor_rollout_ref.rollout.mode=$rollout_mode \
actor_rollout_ref.rollout.gpu_memory_utilization=0.60 \
actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
actor_rollout_ref.rollout.enable_chunked_prefill=True \
actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \
actor_rollout_ref.rollout.temperature=${temperature} \
actor_rollout_ref.rollout.top_p=${top_p} \
actor_rollout_ref.rollout.top_k=${top_k} \
actor_rollout_ref.rollout.calculate_log_probs=True \
actor_rollout_ref.rollout.enable_routing_replay=${enable_routing_replay} \
actor_rollout_ref.rollout.val_kwargs.temperature=${val_temperature} \
actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \
actor_rollout_ref.rollout.val_kwargs.top_k=${val_top_k} \
actor_rollout_ref.rollout.val_kwargs.do_sample=True \
actor_rollout_ref.rollout.val_kwargs.n=1 \
actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \
actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \
actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \
reward_model.reward_manager=dapo \
+reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \
+reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \
+reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \
+reward_model.reward_kwargs.overlong_buffer_cfg.log=False \
+reward_model.reward_kwargs.max_resp_len=${max_response_length} \
trainer.rollout_data_dir="${OUTPUTS_ROLLOUT_DIR}" \
trainer.validation_data_dir="${OUTPUTS_VALIDATION_DIR}" \
trainer.logger='["console","wandb"]' \
trainer.project_name="${project_name}" \
trainer.experiment_name="${exp_name}" \
trainer.n_gpus_per_node="${NGPUS_PER_NODE}" \
trainer.nnodes="${NNODES}" \
trainer.val_before_train=False \
trainer.test_freq=25 \
trainer.save_freq=50 \
trainer.total_epochs=1 \
trainer.default_local_dir="${CKPTS_DIR}" \
trainer.resume_mode=auto \
trainer.log_val_generations=5 \
actor_rollout_ref.nccl_timeout=60000 2>&1 | tee logs/${project_name}_${exp_name}_$timestamp.log
Hi, this is a known issue of the SGLang expert distribution recorder. Specifically, when an SGLang inference entry point has a batch size larger than 32, it almost always happens.
Thanks for your reply! I tried your recommended setup by adjusting the batch sizes as follows; however, the problem persists:

# Recommended setup: (train_prompt_bsz * n_resp_per_prompt) / n_machines <= 32, in respect of SGLang expert_recorder's unstable recording when inference bsz is bigger than 32
train_prompt_bsz=4
train_prompt_mini_bsz=4
n_resp_per_prompt=2
What GPU are you using? My test was run on a 94GB H100.
I tried running your code and script, but it seems to hang: there was no response for several hours and no further log output. The experiment was run on a single node with H100 GPUs. I reset the batch size, with the same result, a hang. Found it: when I set infer_tp to 8, it works.
What does this PR do?
This PR adds MoE routing replay support for Qwen MoE models with this SGLang PR and FSDP.
Brief implementation details:
After the inference stage, routing information is returned along with the decoding results on each tp_rank 0. Then, before training, after the batch is sliced and distributed to each node, the routing information is synced via NCCL p2p exchange, which is much faster than the original Ray-based implementation.
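The PR's actual wire format is not shown here, but a common pattern for this kind of p2p sync is to flatten the per-layer routing tensors into one contiguous buffer plus shape metadata, so a single send/recv pair moves everything; a minimal sketch of the pack/unpack half (names and layout are illustrative, not the PR's implementation):

```python
import math
import torch

def pack_routing(tensors: list):
    """Flatten per-layer routing tensors into one contiguous int64 buffer.

    The flat buffer can then be exchanged with a single NCCL send/recv,
    with `shapes` communicated separately (or agreed upon ahead of time).
    """
    shapes = [tuple(t.shape) for t in tensors]
    flat = torch.cat([t.reshape(-1).to(torch.int64) for t in tensors])
    return flat, shapes

def unpack_routing(flat: torch.Tensor, shapes: list):
    """Inverse of pack_routing: slice the buffer back into per-layer tensors."""
    out, offset = [], 0
    for shape in shapes:
        n = math.prod(shape)
        out.append(flat[offset:offset + n].reshape(shape))
        offset += n
    return out
```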
Routing replay significantly reduces probability difference between rollout and training:

Checklist Before Starting
Format the PR title as [{modules}] {type}: {description} (this will be checked by the CI).
- {modules} include fsdp, megatron, sglang, vllm, rollout, trainer, ci, training_utils, recipe, hardware, deployment, ray, worker, single_controller, misc, perf, model, algo, env, tool, ckpt, doc, data, like [megatron, fsdp, doc]
- {type} is in feat, fix, refactor, chore, test
- If this PR breaks any API, prepend [BREAKING] to the beginning of the title, e.g. [BREAKING][fsdp, megatron] feat: dynamic batching

Test
API and Usage Example
# Add code snippet or script demonstrating how to use this

Design & Code Changes
Checklist Before Submitting
Important
Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.
- pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always
- Request CI via the ci-request channel in the verl Slack workspace. (If not accessible, please try the Feishu group (飞书群).)