
Routing replay implemented with SGLang and FSDP#4443

Open
KawaiiNotHawaii wants to merge 1 commit into verl-project:main from KawaiiNotHawaii:routingReplay

Conversation


@KawaiiNotHawaii KawaiiNotHawaii commented Dec 7, 2025

What does this PR do?

Add concise overview of what this PR aims to achieve or accomplish. Reference related GitHub issues and PRs that help with the review.

This PR adds MoE routing replay support for Qwen MoE models with this SGLang PR and FSDP.

Brief implementation details:
After the inference stage, routing information is returned along with the decoding results on each tp_rank 0. Then, before training, once the batch has been sliced and distributed to each node, the routing information is synced via an NCCL p2p exchange, which is much faster than the original Ray-based implementation.
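The sync step described above can be sketched as a send/receive plan that point-to-point exchanges would then execute. This is an illustrative sketch only (the rank layout and the function name `build_exchange_plan` are hypothetical, not the PR's actual code):

```python
# Illustrative sketch (not the PR's actual code): compute which rank must
# send which sample's routing tensors to which training rank, so the
# exchange can be done with NCCL point-to-point sends instead of Ray.

def build_exchange_plan(produced_on, needed_on):
    """produced_on / needed_on: dict mapping sample_id -> rank."""
    plan = []  # (src_rank, dst_rank, sample_id) triples
    for sample_id, dst in needed_on.items():
        src = produced_on[sample_id]
        if src != dst:  # samples already local need no transfer
            plan.append((src, dst, sample_id))
    return sorted(plan)

# After rollout, tp_rank 0 of each inference group holds the routing info;
# after the batch is sliced for training, some samples are needed elsewhere.
produced = {0: 0, 1: 0, 2: 1, 3: 1}
needed = {0: 0, 1: 1, 2: 0, 3: 1}
print(build_exchange_plan(produced, needed))  # [(0, 1, 1), (1, 0, 2)]
```

Each triple in the plan would correspond to one matched isend/irecv pair in the actual NCCL exchange.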

Routing replay significantly reduces probability difference between rollout and training:
[Screenshot, 2025-12-09: probability-difference curves between rollout and training]

Checklist Before Starting

  • Search for similar PRs. Paste at least one query link here: ...
  • Format the PR title as [{modules}] {type}: {description} (This will be checked by the CI)
    • {modules} include fsdp, megatron, sglang, vllm, rollout, trainer, ci, training_utils, recipe, hardware, deployment, ray, worker, single_controller, misc, perf, model, algo, env, tool, ckpt, doc, data
    • If this PR involves multiple modules, separate them with , like [megatron, fsdp, doc]
    • {type} is in feat, fix, refactor, chore, test
    • If this PR breaks any API (CLI arguments, config, function signature, etc.), add [BREAKING] to the beginning of the title.
    • Example: [BREAKING][fsdp, megatron] feat: dynamic batching

Test

For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc.

API and Usage Example

Demonstrate how the API changes if any, and provide usage example(s) if possible.

# Add code snippet or script demonstrating how to use this

Design & Code Changes

Demonstrate the high-level design if this PR is complex, and list the specific changes.

Checklist Before Submitting

Important

Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.


@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces MoE routing replay support for Qwen MoE models using SGLang and FSDP. The implementation involves significant changes, including monkey-patching transformer models, adding complex data synchronization logic for routing information, and updating the training loop. While the overall approach is sound, I have identified several critical issues that need to be addressed. These include the use of assert statements in production code paths, which can lead to unexpected crashes, a potential bug in tensor type casting, hardcoded values specific to an internal environment, and a logic bug in tensor manipulation. Addressing these issues is crucial for the stability, correctness, and portability of the new feature.

Comment on lines +139 to +140
selected_experts = routing_map.view(-1, self.top_k)#.long() # TODO cx note: review required
routing_weights = routing_weights.gather(1, selected_experts)

critical

The .long() cast on selected_experts is commented out. The torch.gather operation requires the index tensor to be of type LongTensor. If routing_map is not already a LongTensor, this will cause a RuntimeError at line 140. Given the TODO comment, this seems to be a known risk. To prevent potential crashes, it's safer to ensure the correct dtype by uncommenting the .long() cast.

Suggested change
selected_experts = routing_map.view(-1, self.top_k)#.long() # TODO cx note: review required
routing_weights = routing_weights.gather(1, selected_experts)
selected_experts = routing_map.view(-1, self.top_k).long()
routing_weights = routing_weights.gather(1, selected_experts)
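For context, a minimal demonstration of why the cast matters: `torch.gather` requires an int64 (Long) index tensor, so an int32 `routing_map` would fail at the gather without `.long()`. Shapes here are arbitrary, not the model's:

```python
import torch

routing_weights = torch.rand(4, 8)  # per-token scores over 8 experts
# Recorded expert ids arriving as int32 (the risky case the review flags):
routing_map = torch.randint(0, 8, (4, 2), dtype=torch.int32)

# torch.gather requires an int64 index; an int32 index raises RuntimeError.
try:
    routing_weights.gather(1, routing_map)
except RuntimeError:
    pass  # e.g. "gather(): Expected dtype int64 for index"

selected = routing_map.long()  # explicit cast makes the gather safe
gathered = routing_weights.gather(1, selected)
assert gathered.shape == (4, 2)
```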

with marked_timer("update_actor", timing_raw, color="red"):
batch.meta_info["multi_turn"] = self.config.actor_rollout_ref.rollout.multi_turn.enable
actor_output = self.actor_rollout_wg.update_actor(batch)
assert self.actor_rollout_wg._routing_cache == {} and self.actor_rollout_wg._routing_refs == {} and self.actor_rollout_wg._routing_prepared_batches == set(), f"self.actor_rollout_wg._routing_cache of len {len(self.actor_rollout_wg._routing_cache)} ={self.actor_rollout_wg._routing_cache}, self.actor_rollout_wg._routing_refs of len {len(self.actor_rollout_wg._routing_refs)} = {self.actor_rollout_wg._routing_refs == {}}, self.actor_rollout_wg._routing_prepared_batches of len {len(self.actor_rollout_wg._routing_prepared_batches)} = {self.actor_rollout_wg._routing_prepared_batches}"

critical

An assert statement is used here to check the state of routing caches. Assertions are for debugging and are removed when Python is run in optimized mode (with the -O flag). Using assert for runtime checks in production code is dangerous as it can be disabled, hiding potential issues, or it can crash the program if the assertion fails. This should be replaced with a proper check that raises an exception or logs a warning if the condition is not met.
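The suggested pattern can be sketched as follows (function and argument names are hypothetical stand-ins for the worker-group attributes in the PR): the check becomes an explicit conditional that raises, so it still executes when Python runs with `-O`.

```python
# Sketch of the reviewer's suggestion: replace `assert` (stripped under
# `python -O`) with an explicit runtime check that always executes.

def check_routing_state_cleared(routing_cache, routing_refs, prepared_batches):
    """Raise instead of assert so the check survives optimized mode."""
    if routing_cache or routing_refs or prepared_batches:
        raise RuntimeError(
            f"routing state not cleared after update_actor: "
            f"cache={len(routing_cache)} refs={len(routing_refs)} "
            f"prepared={len(prepared_batches)}"
        )

check_routing_state_cleared({}, {}, set())  # clean state: passes silently
try:
    check_routing_state_cleared({"batch0": 1}, {}, set())
except RuntimeError as e:
    print(f"caught: {e}")  # stale cache entry is reported, not silently skipped
```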

sp_size=self.ulysses_sequence_parallel_size,
)

assert micro_batch.get("routing_ids", None) is not None, f"[ERROR] routing replay not implemented for vlm models."

critical

An assert statement is used here. Assertions should not be used for runtime checks in production code as they can be disabled with the -O flag, potentially hiding bugs. If this check is important, it should be converted to a proper conditional check that raises an exception.

Comment on lines +776 to +781
bad_pad = (mask == 0) & any_nonpad # non-zero routing under mask==0
if bad_pad.any():
idx = bad_pad.nonzero(as_tuple=False)[:8].tolist()
raise AssertionError(
f"[routing] non-zero routing under mask==0 at {idx} (showing up to 8)"
)

critical

This AssertionError will crash the program if there's a mismatch in routing data. While useful for debugging, assert statements should not be used for data validation in production code paths, as they can be disabled. Please replace this with a ValueError or RuntimeError that provides a clear error message. This comment also applies to other assertions in this method.

                if bad_pad.any():
                    idx = bad_pad.nonzero(as_tuple=False)[:8].tolist()
                    raise RuntimeError(
                        f"[routing] non-zero routing under mask==0 at {idx} (showing up to 8)"
                    )

# set -xuo pipefail
# while true; do

export http_proxy=http://oversea-squid2.ko.txyun:11080 https_proxy=http://oversea-squid2.ko.txyun:11080 no_proxy=localhost,127.0.0.1,localaddress,localdomain.com,internal,corp.kuaishou.com,test.gifshow.com,staging.kuaishou.com

high

This script contains hardcoded proxy settings that are specific to an internal environment. This makes the script not portable and leaks internal infrastructure details. These should be removed or parameterized to make the script usable in different environments.
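One way to parameterize this, sketched below: only export the proxy variables when the caller supplies them, so the script runs unchanged outside the internal environment. The variable names `HTTP_PROXY_URL` and `NO_PROXY_LIST` are suggestions, not part of the PR:

```shell
# Sketch: take the proxy from the environment instead of hardcoding
# internal hosts; leave proxying off when nothing is provided.
HTTP_PROXY_URL="${HTTP_PROXY_URL:-}"   # e.g. http://proxy.example.com:8080 (placeholder)
if [ -n "$HTTP_PROXY_URL" ]; then
    export http_proxy="$HTTP_PROXY_URL" https_proxy="$HTTP_PROXY_URL"
    export no_proxy="${NO_PROXY_LIST:-localhost,127.0.0.1}"
fi
echo "proxy=${http_proxy:-unset}"
```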

Comment thread verl/protocol.py
Comment on lines +673 to +677
try:
# TODO cx_note add condition to only consolidate if is to CPU, or it is said to be time-consuming on GPU
# self.batch = self.batch.contiguous().consolidate() # TODO cx_note will this cause memory leak of the old batch? check out "/nlp_group/sunchenxi/qwen_moe/logs/qwen3_moe_2025-10-28-23:08:51.log" for log of enabling it (searching is_view=no and you'll find all no views)
self.batch = self.batch.to(device)
except RuntimeError as re:

high

The inner try...except block is redundant because it attempts the same self.batch.to(device) operation that failed in the outer try block. If the first call fails, the second will fail for the same reason. The logic should fall back directly to _safely_move_tensordict in the outer except block.

                    self.batch = self._safely_move_tensordict(self.batch, device)
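The control flow the reviewer asks for can be sketched with stand-in objects (`FakeBatch` and the `safe_mover` callable are illustrative; `_safely_move_tensordict` is the PR's helper): one fast-path attempt, one fallback, no redundant retry of the call that just failed.

```python
# Sketch of the reviewer's point: fall back directly to the safe mover
# instead of retrying the same .to(device) call inside a nested try/except.

def move_batch(batch, device, safe_mover):
    try:
        return batch.to(device)           # fast path
    except RuntimeError:
        return safe_mover(batch, device)  # single fallback, no retry

class FakeBatch:
    """Stand-in for a TensorDict whose .to() may fail on views."""
    def __init__(self, fail):
        self.fail = fail
    def to(self, device):
        if self.fail:
            raise RuntimeError("non-contiguous view")
        return (self, device)

safe = lambda b, d: ("safely-moved", d)
print(move_batch(FakeBatch(False), "cuda", safe))
print(move_batch(FakeBatch(True), "cuda", safe))  # ('safely-moved', 'cuda')
```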

Comment on lines +309 to +325
if 'boxed' not in solution_str[-300:]:
return False, ""
answer = str(answer)

try:
solution_val = parse(solution_str[-300:])
if "boxed" in answer:
gt_val = parse(answer)
else:
boxed_answer = "\\boxed{" + answer + "}"
gt_val = parse(boxed_answer)
if m_verify(solution_val, gt_val):
return True, ""
else:
return False, ""
except Exception as e:
return False, ""

high

The refactored verify function no longer returns the extracted prediction; it always returns an empty string. This is a regression from the previous implementation, which returned the predicted value. The compute_score function relies on this prediction to include it in its returned dictionary. With this change, the pred field in the output of compute_score will always be empty, which is likely not the intended behavior. The verify function should be updated to return the extracted prediction.
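The shape of the fix can be sketched like this. The real code uses `math_verify`'s `parse` and symbolic verification; here a regex and plain string comparison stand in for both, so this is a simplification of the matching logic, not the PR's implementation:

```python
import re

# Illustrative sketch of the reviewer's point: verify() should return the
# extracted prediction, not always "". A regex stands in for math_verify's
# parse(), and string equality stands in for symbolic verification.

def verify(solution_str, answer):
    tail = solution_str[-300:]
    m = re.search(r"\\boxed\{([^}]*)\}", tail)
    if m is None:
        return False, ""  # no boxed answer in the last 300 chars
    pred = m.group(1).strip()
    return pred == str(answer).strip(), pred  # return the prediction too

print(verify("... final answer: \\boxed{42}", "42"))  # (True, '42')
print(verify("no boxed answer here", "42"))           # (False, '')
```

With `pred` propagated, `compute_score` can populate its `pred` field again instead of always receiving an empty string.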

print(f"Routing replay not enabled.")
return self._batch_level_generate_sequences(prompts, **kwargs)

step_bsz = 32 # NOTE Currently, SGLang expert_distribution_recorder is not functioning as expected when per-entry bsz > 32

high

The step_bsz is hardcoded to 32 due to a limitation in an external library (SGLang). This magic number makes the code less maintainable and harder to adapt if the library's behavior changes or if it's run in a different environment. This value should be made configurable, for example, by adding it to the rollout configuration.

Suggested change
step_bsz = 32 # NOTE Currently, SGLang expert_distribution_recorder is not functioning as expected when per-entry bsz > 32
step_bsz = self.config.get("routing_replay_step_bsz", 32) # NOTE Currently, SGLang expert_distribution_recorder is not functioning as expected when per-entry bsz > 32
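Once the chunk size comes from config, the batch split reduces to fixed-size windows. A minimal sketch (the config key `routing_replay_step_bsz` follows the suggestion above and is not yet an established option):

```python
# Sketch: split a batch of n_samples into chunks no larger than step_bsz,
# working around the recorder's unreliable behavior above 32 samples.

def iter_chunks(n_samples, step_bsz):
    for start in range(0, n_samples, step_bsz):
        yield start, min(start + step_bsz, n_samples)

config = {"routing_replay_step_bsz": 32}  # hypothetical config entry
step_bsz = config.get("routing_replay_step_bsz", 32)
print(list(iter_chunks(80, step_bsz)))  # [(0, 32), (32, 64), (64, 80)]
```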

@wuxibin89
Collaborator


Shadowyuan616 commented Dec 9, 2025


Hi, I appreciate your implementation of routing replay on verl and sglang. However, when I tried to run the script from verl_routing_replay/recipe/moe_routing_replay/run.sh with this verl PR and your SGLang PR, I hit the following exception:

(WorkerDict pid=3380574) Exception: Please set ServerArgs.expert_distribution_recorder_mode to use ExpertDistributionRecorder.

After setting args["expert_distribution_recorder_mode"] = "per_token" in verl_routing_replay/verl/workers/rollout/sglang_rollout/sglang_rollout.py, there was another problem:

(TaskRunner pid=3317161) Unhandled error (suppress with 'RAY_IGNORE_UNHANDLED_ERRORS=1'): ray::WorkerDict.actor_rollout_compute_log_prob() (pid=3334909, ip=192.168.102.18, actor_id=201afa1f8e08842d00c0e71601000000, repr=<verl.single_controller.ray.base.WorkerDict object at 0x7f38f2d692b0>)
(TaskRunner pid=3317161)   File "/code/RoutingReplay/verl_routing_replay/verl/single_controller/ray/base.py", line 700, in func
(TaskRunner pid=3317161)     return getattr(self.worker_dict[key], name)(*args, **kwargs)
(TaskRunner pid=3317161)   File "/code/RoutingReplay/verl_routing_replay/verl/single_controller/base/decorator.py", line 442, in inner
(TaskRunner pid=3317161)     return func(*args, **kwargs)
(TaskRunner pid=3317161)   File "/code/RoutingReplay/verl_routing_replay/verl/utils/transferqueue_utils.py", line 199, in dummy_inner
(TaskRunner pid=3317161)     return func(*args, **kwargs)
(TaskRunner pid=3317161)   File "/code/RoutingReplay/verl_routing_replay/verl/utils/profiler/profile.py", line 256, in wrapper
(TaskRunner pid=3317161)     return func(self_instance, *args, **kwargs_inner)
(TaskRunner pid=3317161)   File "/code/RoutingReplay/verl_routing_replay/verl/workers/fsdp_workers.py", line 1597, in compute_log_prob
(TaskRunner pid=3317161)     self._ensure_routing_prepared(data, tag='[compute_log_prob]')
(TaskRunner pid=3317161)   File "/code/RoutingReplay/verl_routing_replay/verl/workers/fsdp_workers.py", line 780, in _ensure_routing_prepared
(TaskRunner pid=3317161)     raise AssertionError(
(TaskRunner pid=3317161) AssertionError: [routing] non-zero routing under mask==0 at [[0, 308], [0, 309], [0, 310], [0, 311], [0, 312], [0, 313], [0, 314], [0, 315]] (showing up to 8)

Do you have any thoughts on how to solve these problems?
Here is my version of run.sh

#!/usr/bin/env bash
set -xeuo pipefail

# set -xuo pipefail
# while true; do

# export http_proxy=http://oversea-squid2.ko.txyun:11080 https_proxy=http://oversea-squid2.ko.txyun:11080 no_proxy=localhost,127.0.0.1,localaddress,localdomain.com,internal,corp.kuaishou.com,test.gifshow.com,staging.kuaishou.com

timestamp=$(date +"%Y-%m-%d-%H:%M:%S")""

loss_mode=vanilla
adv_estimator=grpo
loss_agg_mode="token-mean"

use_kl_in_reward=False
kl_coef=0.0
use_kl_loss=False
kl_loss_coef=0.0

clip_ratio_low=0.2
clip_ratio_high=0.28

max_prompt_length=$((256))
max_response_length=$((256))

enable_overlong_buffer=False
overlong_buffer_len=$((256))
overlong_penalty_factor=1.0

# Recommended setup: (train_prompt_bsz * n_resp_per_prompt) / n_machines <= 32, since SGLang's expert_recorder records unreliably when the inference batch size exceeds 32
train_prompt_bsz=32
train_prompt_mini_bsz=32
n_resp_per_prompt=4



# for TIS
imp_ratio_cap=-1


# Ray
# RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"}
WORKING_DIR=${WORKING_DIR:-"${PWD}"}
RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/recipe/moe/runtime_env.yaml"}
NNODES=${NNODES:-1}
# NNODES=${NNODES:-$(sort -u /etc/mpi/hostfile | wc -l)}
NGPUS_PER_NODE=${NGPUS_PER_NODE:-8}

project_name='DAPO-Qwen3-MOE-30B-FSDP-RoutingReplay'
info_tag=""
exp_name='qwen3moe-'${loss_mode}-${train_prompt_bsz}_${train_prompt_mini_bsz}'-n-'${n_resp_per_prompt}'-len-'${max_response_length}-${info_tag}


# Paths
MODEL_PATH=/model/Qwen/Qwen1.5-MoE-A2.7B-Chat
CKPTS_DIR=${CKPTS_DIR:-"${WORKING_DIR}/ckpts_hdd/${project_name}/${exp_name}"}
OUTPUTS_ROLLOUT_DIR=${OUTPUTS_ROLLOUT_DIR:-"${WORKING_DIR}/outputs/${project_name}/${exp_name}/rollout/"}
OUTPUTS_VALIDATION_DIR=${OUTPUTS_VALIDATION_DIR:-"${WORKING_DIR}/outputs/${project_name}/${exp_name}/validate/"}
mkdir -p $OUTPUTS_ROLLOUT_DIR
mkdir -p $OUTPUTS_VALIDATION_DIR
mkdir -p "${WORKING_DIR}/logs"

TRAIN_FILE="/data/math_data/dapo-math-17k.parquet"
TEST_FILE="/data/math_data/aime-2024.parquet"

echo $OUTPUTS_ROLLOUT_DIR

# rollout
enable_routing_replay=True
rollout_mode="sync"
return_raw_chat="True"
rollout_name="sglang"
if [ "$rollout_mode" = "async" ]; then
    # NOTE async rollout mode is not supported yet.
    export VLLM_USE_V1=1
fi



# Algorithm
temperature=1.0
top_p=1.0
top_k=-1 # 0 for HF rollout, -1 for vLLM rollout

val_temperature=1.0
val_top_p=0.7
val_top_k=-1 # 0 for HF rollout, -1 for vLLM rollout



# Performance Related Parameter
sp_size=4
use_dynamic_bsz=True
actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 1))
infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 1))
offload=True
gen_tp=8
fsdp_size=8  # default -1

# Trade compute for memory
entropy_checkpointing=True
entropy_from_logits_with_chunking=True

# export RAY_DEDUP_LOGS=0
PYTHONUNBUFFERED=1 python3 -m recipe.moe_routing_replay.main_dapo --config-path=config \
    --config-name='dapo_trainer.yaml'\
    data.train_files="${TRAIN_FILE}" \
    data.val_files="${TEST_FILE}" \
    data.prompt_key=prompt \
    data.truncation='left' \
    data.return_raw_chat=${return_raw_chat} \
    data.max_prompt_length=${max_prompt_length} \
    data.max_response_length=${max_response_length} \
    data.train_batch_size=${train_prompt_bsz} \
    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
    algorithm.adv_estimator=${adv_estimator} \
    algorithm.use_kl_in_reward=${use_kl_in_reward} \
    algorithm.kl_ctrl.kl_coef=${kl_coef} \
    actor_rollout_ref.actor.entropy_checkpointing=${entropy_checkpointing} \
    actor_rollout_ref.actor.entropy_from_logits_with_chunking=${entropy_from_logits_with_chunking} \
    actor_rollout_ref.actor.policy_loss.loss_mode=${loss_mode} \
    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
    actor_rollout_ref.actor.clip_ratio_c=10.0 \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.model.enable_activation_offload=${offload} \
    critic.model.enable_activation_offload=${offload} \
    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \
    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
    actor_rollout_ref.model.path="${MODEL_PATH}" \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.actor.optim.lr_warmup_steps=10 \
    actor_rollout_ref.actor.optim.weight_decay=0.1 \
    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
    actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \
    actor_rollout_ref.actor.entropy_coeff=0 \
    actor_rollout_ref.actor.grad_clip=1.0 \
    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
    actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \
    actor_rollout_ref.rollout.name=$rollout_name \
    actor_rollout_ref.rollout.mode=$rollout_mode \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.60 \
    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
    actor_rollout_ref.rollout.enable_chunked_prefill=True \
    actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \
    actor_rollout_ref.rollout.temperature=${temperature} \
    actor_rollout_ref.rollout.top_p=${top_p} \
    actor_rollout_ref.rollout.top_k=${top_k} \
    actor_rollout_ref.rollout.calculate_log_probs=True \
    actor_rollout_ref.rollout.enable_routing_replay=${enable_routing_replay} \
    actor_rollout_ref.rollout.val_kwargs.temperature=${val_temperature} \
    actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \
    actor_rollout_ref.rollout.val_kwargs.top_k=${val_top_k} \
    actor_rollout_ref.rollout.val_kwargs.do_sample=True \
    actor_rollout_ref.rollout.val_kwargs.n=1 \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
    actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \
    actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \
    actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \
    reward_model.reward_manager=dapo \
    +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \
    +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \
    +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \
    +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \
    +reward_model.reward_kwargs.max_resp_len=${max_response_length} \
    trainer.rollout_data_dir="${OUTPUTS_ROLLOUT_DIR}" \
    trainer.validation_data_dir="${OUTPUTS_VALIDATION_DIR}" \
    trainer.logger='["console","wandb"]' \
    trainer.project_name="${project_name}" \
    trainer.experiment_name="${exp_name}" \
    trainer.n_gpus_per_node="${NGPUS_PER_NODE}" \
    trainer.nnodes="${NNODES}" \
    trainer.val_before_train=False \
    trainer.test_freq=25 \
    trainer.save_freq=50 \
    trainer.total_epochs=1 \
    trainer.default_local_dir="${CKPTS_DIR}" \
    trainer.resume_mode=auto \
    trainer.log_val_generations=5 \
    actor_rollout_ref.nccl_timeout=60000 2>&1 | tee logs/${project_name}_${exp_name}_$timestamp.log

Author

KawaiiNotHawaii commented Dec 9, 2025

    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
    actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \
    actor_rollout_ref.actor.entropy_coeff=0 \
    actor_rollout_ref.actor.grad_clip=1.0 \
    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
    actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \
    actor_rollout_ref.rollout.name=$rollout_name \
    actor_rollout_ref.rollout.mode=$rollout_mode \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.60 \
    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
    actor_rollout_ref.rollout.enable_chunked_prefill=True \
    actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \
    actor_rollout_ref.rollout.temperature=${temperature} \
    actor_rollout_ref.rollout.top_p=${top_p} \
    actor_rollout_ref.rollout.top_k=${top_k} \
    actor_rollout_ref.rollout.calculate_log_probs=True \
    actor_rollout_ref.rollout.enable_routing_replay=${enable_routing_replay} \
    actor_rollout_ref.rollout.val_kwargs.temperature=${val_temperature} \
    actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \
    actor_rollout_ref.rollout.val_kwargs.top_k=${val_top_k} \
    actor_rollout_ref.rollout.val_kwargs.do_sample=True \
    actor_rollout_ref.rollout.val_kwargs.n=1 \
    actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \
    actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \
    actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \
    reward_model.reward_manager=dapo \
    +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \
    +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \
    +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \
    +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \
    +reward_model.reward_kwargs.max_resp_len=${max_response_length} \
    trainer.rollout_data_dir="${OUTPUTS_ROLLOUT_DIR}" \
    trainer.validation_data_dir="${OUTPUTS_VALIDATION_DIR}" \
    trainer.logger='["console","wandb"]' \
    trainer.project_name="${project_name}" \
    trainer.experiment_name="${exp_name}" \
    trainer.n_gpus_per_node="${NGPUS_PER_NODE}" \
    trainer.nnodes="${NNODES}" \
    trainer.val_before_train=False \
    trainer.test_freq=25 \
    trainer.save_freq=50 \
    trainer.total_epochs=1 \
    trainer.default_local_dir="${CKPTS_DIR}" \
    trainer.resume_mode=auto \
    trainer.log_val_generations=5 \
    actor_rollout_ref.nccl_timeout=60000 2>&1 | tee logs/${project_name}_${exp_name}_$timestamp.log

Hi, this is a known issue with the SGLang expert distribution recorder: when an SGLang inference entry point receives a batch larger than 32 sequences, this almost always happens.
In your case, given bsz=32, rollout_n=4 and n_nodes=1, the actual batch size for the entry point is 32*4 = 128.
The easiest solution would be to scale out to more nodes. I also tried manually splitting the batch into smaller pieces of size 32, but to no avail.
I recommend checking out the newer SGLang release; I'm also actively working to update the SGLang routing replay PR to its newest release.
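
The batch-size arithmetic above can be sketched as a small check. This is a hypothetical helper, not part of verl or SGLang; the safe limit of 32 is just the empirical value reported in this thread:

```python
# Hypothetical helper illustrating the arithmetic from the comment above:
# each node's SGLang entry point sees train_prompt_bsz * n_resp_per_prompt / nnodes
# sequences per rollout step, and the expert distribution recorder is reported
# to misbehave above 32.

RECORDER_SAFE_BSZ = 32  # empirical limit reported in this thread


def per_engine_batch_size(train_prompt_bsz: int, n_resp_per_prompt: int, nnodes: int) -> int:
    """Sequences handled by one node's inference entry point per rollout step."""
    return train_prompt_bsz * n_resp_per_prompt // nnodes


def recorder_is_safe(train_prompt_bsz: int, n_resp_per_prompt: int, nnodes: int) -> bool:
    return per_engine_batch_size(train_prompt_bsz, n_resp_per_prompt, nnodes) <= RECORDER_SAFE_BSZ


# The failing setup from this thread: 32 prompts * 4 responses on 1 node.
print(per_engine_batch_size(32, 4, 1))  # 128, well above the safe limit
print(recorder_is_safe(32, 4, 1))       # False
print(recorder_is_safe(4, 2, 1))        # True: only 8 sequences per engine
```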

@Shadowyuan616


Thanks for your reply! I tried your recommended setup by adjusting the batch sizes as follows; however, the non-zero routing under mask==0 issue remains. I might dig deeper into it later. Still looking forward to having the routing replay function in verl/sglang :)

# Recommended setup: (train_prompt_bsz * n_resp_per_prompt) / n_machines <= 32, given the SGLang expert_recorder's unstable recording when the inference bsz is bigger than 32
train_prompt_bsz=4
train_prompt_mini_bsz=4
n_resp_per_prompt=2
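
For anyone debugging the "non-zero routing under mask==0" failure mentioned above, the invariant can be checked in isolation. This is a minimal standalone sketch; the real check lives in verl's `fsdp_workers._ensure_routing_prepared` per the traceback, and the tensor shapes here are assumptions for illustration only:

```python
import torch

# Sketch of the invariant behind the "non-zero routing under mask==0" assertion:
# padded positions (attention_mask == 0) must carry all-zero routing records.
# Shapes are illustrative: routing is [batch, seq_len, *expert_dims].


def find_routing_violations(routing: torch.Tensor, attention_mask: torch.Tensor, max_show: int = 8):
    """Return up to `max_show` [batch, position] pairs where a padded token
    still has a non-zero routing entry."""
    # Collapse any trailing expert dimensions, then flag positions with any
    # non-zero routing mass.
    nonzero_routing = routing.reshape(routing.shape[0], routing.shape[1], -1).abs().sum(dim=-1) > 0
    bad = torch.nonzero(nonzero_routing & (attention_mask == 0), as_tuple=False)
    return bad[:max_show].tolist()


mask = torch.tensor([[1, 1, 0, 0]])  # last two positions are padding
ok_routing = torch.tensor([[[1, 0], [0, 1], [0, 0], [0, 0]]])
bad_routing = torch.tensor([[[1, 0], [0, 1], [1, 0], [0, 0]]])  # routing leaked into padding

print(find_routing_violations(ok_routing, mask))   # []
print(find_routing_violations(bad_routing, mask))  # [[0, 2]]
```

A non-empty result like `[[0, 308], [0, 309], ...]` in the reported error suggests the recorder emitted routing entries past the end of the valid response tokens.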

@KawaiiNotHawaii
Author

What does this PR do?

Add concise overview of what this PR aims to achieve or accomplish. Reference related GitHub issues and PRs that help with the review.

This PR adds MoE routing replay support for Qwen MoE models with this SGLang PR and FSDP.
Brief implementation details: After inference stage, routing information is returned along with the decoding results on each tp_rank 0. Then, before training, after batch sliced and distributed to each node, routing information will be synced via NCCL p2p exchange which is much faster than the original Ray implementation.

Checklist Before Starting

  • Search for similar PRs. Paste at least one query link here: ...

  • Format the PR title as [{modules}] {type}: {description} (This will be checked by the CI)

    • {modules} include fsdp, megatron, sglang, vllm, rollout, trainer, ci, training_utils, recipe, hardware, deployment, ray, worker, single_controller, misc, perf, model, algo, env, tool, ckpt, doc, data
    • If this PR involves multiple modules, separate them with , like [megatron, fsdp, doc]
    • {type} is in feat, fix, refactor, chore, test
    • If this PR breaks any API (CLI arguments, config, function signature, etc.), add [BREAKING] to the beginning of the title.
    • Example: [BREAKING][fsdp, megatron] feat: dynamic batching

Test

For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc.

API and Usage Example

Demonstrate how the API changes if any, and provide usage example(s) if possible.

# Add code snippet or script demonstrating how to use this

Design & Code Changes

Demonstrate the high-level design if this PR is complex, and list the specific changes.

Checklist Before Submitting

Important
Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

Hi, I appreciate your implementation of routing replay on verl and sglang. However when I tried to run the script from verl_routing_replay/recipe/moe_routing_replay/run.sh with this verl PR and your SGLang PR, I had the following exception:

�[36m(WorkerDict pid=3380574)�[0m Exception: Please set ServerArgs.expert_distribution_recorder_mode to use ExpertDistributionRecorder.

After setting args["expert_distribution_recorder_mode"] = "per_token" in verl_routing_replay/verl/workers/rollout/sglang_rollout/sglang_rollout.py, there was an another problem:

�[36m(TaskRunner pid=3317161)�[0m Unhandled error (suppress with 'RAY_IGNORE_UNHANDLED_ERRORS=1'): �[36mray::WorkerDict.actor_rollout_compute_log_prob()�[39m (pid=3334909, ip=192.168.102.18, actor_id=201afa1f8e08842d00c0e71601000000, repr=<verl.single_controller.ray.base.WorkerDict object at 0x7f38f2d692b0>)
�[36m(TaskRunner pid=3317161)�[0m            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
�[36m(TaskRunner pid=3317161)�[0m            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
�[36m(TaskRunner pid=3317161)�[0m   File "/code/RoutingReplay/verl_routing_replay/verl/single_controller/ray/base.py", line 700, in func
�[36m(TaskRunner pid=3317161)�[0m     return getattr(self.worker_dict[key], name)(*args, **kwargs)
�[36m(TaskRunner pid=3317161)�[0m            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
�[36m(TaskRunner pid=3317161)�[0m   File "/code/RoutingReplay/verl_routing_replay/verl/single_controller/base/decorator.py", line 442, in inner
�[36m(TaskRunner pid=3317161)�[0m     return func(*args, **kwargs)
�[36m(TaskRunner pid=3317161)�[0m            ^^^^^^^^^^^^^^^^^^^^^
�[36m(TaskRunner pid=3317161)�[0m   File "/code/RoutingReplay/verl_routing_replay/verl/utils/transferqueue_utils.py", line 199, in dummy_inner
�[36m(TaskRunner pid=3317161)�[0m     return func(*args, **kwargs)
�[36m(TaskRunner pid=3317161)�[0m            ^^^^^^^^^^^^^^^^^^^^^
�[36m(TaskRunner pid=3317161)�[0m   File "/code/RoutingReplay/verl_routing_replay/verl/utils/profiler/profile.py", line 256, in wrapper
�[36m(TaskRunner pid=3317161)�[0m     return func(self_instance, *args, **kwargs_inner)
�[36m(TaskRunner pid=3317161)�[0m            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
�[36m(TaskRunner pid=3317161)�[0m   File "/code/RoutingReplay/verl_routing_replay/verl/workers/fsdp_workers.py", line 1597, in compute_log_prob
�[36m(TaskRunner pid=3317161)�[0m     self._ensure_routing_prepared(data, tag='[compute_log_prob]')
�[36m(TaskRunner pid=3317161)�[0m   File "/code/RoutingReplay/verl_routing_replay/verl/workers/fsdp_workers.py", line 780, in _ensure_routing_prepared
�[36m(TaskRunner pid=3317161)�[0m     raise AssertionError(
�[36m(TaskRunner pid=3317161)�[0m AssertionError: [routing] non-zero routing under mask==0 at [[0, 308], [0, 309], [0, 310], [0, 311], [0, 312], [0, 313], [0, 314], [0, 315]] (showing up to 8)

Do you have any thoughts on how to solve these problems? Here is my version of run.sh

#!/usr/bin/env bash
set -xeuo pipefail

# set -xuo pipefail
# while true; do

# export http_proxy=http://oversea-squid2.ko.txyun:11080 https_proxy=http://oversea-squid2.ko.txyun:11080 no_proxy=localhost,127.0.0.1,localaddress,localdomain.com,internal,corp.kuaishou.com,test.gifshow.com,staging.kuaishou.com

timestamp=$(date +"%Y-%m-%d-%H:%M:%S")""

loss_mode=vanilla
adv_estimator=grpo
loss_agg_mode="token-mean"

use_kl_in_reward=False
kl_coef=0.0
use_kl_loss=False
kl_loss_coef=0.0

clip_ratio_low=0.2
clip_ratio_high=0.28

max_prompt_length=$((256))
max_response_length=$((256))

enable_overlong_buffer=False
overlong_buffer_len=$((256))
overlong_penalty_factor=1.0

# Recommended setup: (train_prompt_bsz * n_resp_per_prompt) / n_machines <= 32, in respect of SGLang expert_recorder's unstable recording when inference bsz is bigger than 32
train_prompt_bsz=32
train_prompt_mini_bsz=32
n_resp_per_prompt=4



# for TIS
imp_ratio_cap=-1


# Ray
# RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"}
WORKING_DIR=${WORKING_DIR:-"${PWD}"}
RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/recipe/moe/runtime_env.yaml"}
NNODES=${NNODES:-1}
# NNODES=${NNODES:-$(sort -u /etc/mpi/hostfile | wc -l)}
NGPUS_PER_NODE=${NGPUS_PER_NODE:-8}

project_name='DAPO-Qwen3-MOE-30B-FSDP-RoutingReplay'
info_tag=""
exp_name='qwen3moe-'${loss_mode}-${train_prompt_bsz}_${train_prompt_mini_bsz}'-n-'${n_resp_per_prompt}'-len-'${max_response_length}-${info_tag}


# Paths
MODEL_PATH=/model/Qwen/Qwen1.5-MoE-A2.7B-Chat
CKPTS_DIR=${CKPTS_DIR:-"${WORKING_DIR}/ckpts_hdd/${project_name}/${exp_name}"}
OUTPUTS_ROLLOUT_DIR=${OUTPUTS_ROLLOUT_DIR:-"${WORKING_DIR}/outputs/${project_name}/${exp_name}/rollout/"}
OUTPUTS_VALIDATION_DIR=${OUTPUTS_VALIDATION_DIR:-"${WORKING_DIR}/outputs/${project_name}/${exp_name}/validate/"}
mkdir -p $OUTPUTS_ROLLOUT_DIR
mkdir -p $OUTPUTS_VALIDATION_DIR
mkdir -p "${WORKING_DIR}/logs"

TRAIN_FILE="/data/math_data/dapo-math-17k.parquet"
TEST_FILE="/data/math_data/aime-2024.parquet"

echo $OUTPUTS_ROLLOUT_DIR

# rollout
enable_routing_replay=True
rollout_mode="sync"
return_raw_chat="True"
rollout_name="sglang"
if [ "$rollout_mode" = "async" ]; then
    # NOTE async rollout mode is not supported yet.
    export VLLM_USE_V1=1
fi



# Algorithm
temperature=1.0
top_p=1.0
top_k=-1 # 0 for HF rollout, -1 for vLLM rollout

val_temperature=1.0
val_top_p=0.7
val_top_k=-1 # 0 for HF rollout, -1 for vLLM rollout



# Performance Related Parameter
sp_size=4
use_dynamic_bsz=True
actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 1))
infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 1))
offload=True
gen_tp=8
fsdp_size=8  # default -1

# Trade compute for memory
entropy_checkpointing=True
entropy_from_logits_with_chunking=True

# export RAY_DEDUP_LOGS=0
PYTHONUNBUFFERED=1 python3 -m recipe.moe_routing_replay.main_dapo --config-path=config \
    --config-name='dapo_trainer.yaml'\
    data.train_files="${TRAIN_FILE}" \
    data.val_files="${TEST_FILE}" \
    data.prompt_key=prompt \
    data.truncation='left' \
    data.return_raw_chat=${return_raw_chat} \
    data.max_prompt_length=${max_prompt_length} \
    data.max_response_length=${max_response_length} \
    data.train_batch_size=${train_prompt_bsz} \
    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
    algorithm.adv_estimator=${adv_estimator} \
    algorithm.use_kl_in_reward=${use_kl_in_reward} \
    algorithm.kl_ctrl.kl_coef=${kl_coef} \
    actor_rollout_ref.actor.entropy_checkpointing=${entropy_checkpointing} \
    actor_rollout_ref.actor.entropy_from_logits_with_chunking=${entropy_from_logits_with_chunking} \
    actor_rollout_ref.actor.policy_loss.loss_mode=${loss_mode} \
    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
    actor_rollout_ref.actor.clip_ratio_c=10.0 \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.model.enable_activation_offload=${offload} \
    critic.model.enable_activation_offload=${offload} \
    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \
    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
    actor_rollout_ref.model.path="${MODEL_PATH}" \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.actor.optim.lr_warmup_steps=10 \
    actor_rollout_ref.actor.optim.weight_decay=0.1 \
    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
    actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \
    actor_rollout_ref.actor.entropy_coeff=0 \
    actor_rollout_ref.actor.grad_clip=1.0 \
    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
    actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \
    actor_rollout_ref.rollout.name=$rollout_name \
    actor_rollout_ref.rollout.mode=$rollout_mode \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.60 \
    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
    actor_rollout_ref.rollout.enable_chunked_prefill=True \
    actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \
    actor_rollout_ref.rollout.temperature=${temperature} \
    actor_rollout_ref.rollout.top_p=${top_p} \
    actor_rollout_ref.rollout.top_k=${top_k} \
    actor_rollout_ref.rollout.calculate_log_probs=True \
    actor_rollout_ref.rollout.enable_routing_replay=${enable_routing_replay} \
    actor_rollout_ref.rollout.val_kwargs.temperature=${val_temperature} \
    actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \
    actor_rollout_ref.rollout.val_kwargs.top_k=${val_top_k} \
    actor_rollout_ref.rollout.val_kwargs.do_sample=True \
    actor_rollout_ref.rollout.val_kwargs.n=1 \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
    actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \
    actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \
    actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \
    reward_model.reward_manager=dapo \
    +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \
    +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \
    +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \
    +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \
    +reward_model.reward_kwargs.max_resp_len=${max_response_length} \
    trainer.rollout_data_dir="${OUTPUTS_ROLLOUT_DIR}" \
    trainer.validation_data_dir="${OUTPUTS_VALIDATION_DIR}" \
    trainer.logger='["console","wandb"]' \
    trainer.project_name="${project_name}" \
    trainer.experiment_name="${exp_name}" \
    trainer.n_gpus_per_node="${NGPUS_PER_NODE}" \
    trainer.nnodes="${NNODES}" \
    trainer.val_before_train=False \
    trainer.test_freq=25 \
    trainer.save_freq=50 \
    trainer.total_epochs=1 \
    trainer.default_local_dir="${CKPTS_DIR}" \
    trainer.resume_mode=auto \
    trainer.log_val_generations=5 \
    actor_rollout_ref.nccl_timeout=60000 2>&1 | tee logs/${project_name}_${exp_name}_$timestamp.log

Hi, this is a known issue of SGLang expert distribution recorder. Specifically speaking, when an SGLang inference entry point has a batch size larger than 32, it almost always happen. In your case, given bsz=32, rollout_n=4 and n_nodes=1, the actual batch size for the entry point is 32*4. The easiest solution would be to scale the nodes. I also tried to manually split it down to pieces of smaller batches of size 32 but results in vain. I recommend to checkout the newer sglang release and I'm also actively trying to update the SGLang routing replay PR to the its newest release.

Thanks for your reply! I tried your recommended setup by adjusting the batch size as follows; however, the issue of non-zero routing under mask==0 remains. I might dig deeper into it later. Still looking forward to having the routing replay function on verl/SGLang :)

# Recommended setup: (train_prompt_bsz * n_resp_per_prompt) / n_machines <= 32, since SGLang's expert_recorder records unreliably when the inference batch size exceeds 32
train_prompt_bsz=4
train_prompt_mini_bsz=4
n_resp_per_prompt=2

What GPU are you using? My test runs on a 94GB H100.

@Cesilina

Cesilina commented Dec 31, 2025

I tried running your code and script, but it seems to hang; there has been no response for several hours.
[image]
[DEBUG][fsdp_workers]: Preparing routing matrices (owner materialization) for batch with id: RID|85df07b73f6b45e3b7772fb487282f68|c8483b01fbbf452189fcba33e799c213|667b029ca1b84d509a39cc747e5462cf|07ef0392895b47b5ada0686924871235|0d7d160ac24c476b9b6ccef39ec0fec1|9b0aa25e48664e7ebe33f04b1c7570e4|def2a00906ce476cafd77fdc04dcf23d|395899a0522f4eb4b58804c7d5da532f|05695b3f8f5044408dca09e8d3d5c42c|487f0d2356824cbd96b32ad2bb17f431|96349107c38b4258947d6ba90c8945ab|c911568bce0d4fab9419c568cd8c326b|0fd87e95a0684de7accb62bb0a14fd7e|e2e9fefb657946c5b862a8f050c4157a|2ed37199ba94497a8e600aea2ab480e1|449a65cef41249eaaf6a4edfda1e0257|8e34fdd3b3ab471cb29690525a01c7f2|c68d0cc0070a45cfa62378659c43fd7a|0b0132f47ccd4593af6ef0c0f0a5e4c7|3a1df0a9d5b143febbf4c4e7d7246c39|62d0299a4ec54bbd9d3c9337b333506f|31a4f92fdf524a64b69c9eb8097ed066|4e63fee864ba494f9fca5f6dc988aa93|61ccac6e17ae4cbcb9c3cb56656102a2|5ebcf44a3ca54c3db9ca191a04e64b85|8f7b0e2c598b4a6cb5644905941ad02e|8da6a07cad6640dbaaa338bc5f7c6bed|7104b7859ec2484797c82bf845f18224|821c5fe04ed8404b920dcb9cef611e44|fef0e8c3cb5c4324a493825629255b68|51c82c3f151e4c259c61d244d8714e61|389b3dffcbef4edabfe35a7344756df9

No more log info....

The experiment was run on a single node with H100 GPUs.

I reset the batch size with
train_prompt_bsz=4
train_prompt_mini_bsz=4
n_resp_per_prompt=2

It hangs in the same way.
May I ask how to solve it?

Found it. When I set infer_tp to 8, it works.
When I set infer_tp to 4, it didn't work.
