72 changes: 72 additions & 0 deletions examples/router_replay/README.md
@@ -0,0 +1,72 @@
# Router Replay

Router Replay is a routing-replay capability in the verl framework for Mixture of Experts (MoE) models. It enables deterministic training by recording and replaying expert routing decisions, keeping model behavior consistent across training runs.
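
As a mental model, the record/replay idea can be sketched with a toy top-k router (illustrative only, not the verl/Megatron implementation; the class and attribute names below are invented for the example):

```python
import torch


class ReplayableTopKRouter(torch.nn.Module):
    """Toy MoE router that can record or replay its expert choices."""

    def __init__(self, hidden_dim: int, num_experts: int, top_k: int = 2):
        super().__init__()
        self.gate = torch.nn.Linear(hidden_dim, num_experts, bias=False)
        self.top_k = top_k
        self.recorded: list[torch.Tensor] = []      # expert indices saved per forward call
        self.replay_queue: list[torch.Tensor] = []  # previously recorded indices to reuse

    def forward(self, hidden: torch.Tensor, mode: str = "disabled"):
        logits = self.gate(hidden)                          # [tokens, num_experts]
        if mode == "replay" and self.replay_queue:
            topk_idx = self.replay_queue.pop(0)             # reuse recorded routing decisions
        else:
            topk_idx = logits.topk(self.top_k, dim=-1).indices
            if mode == "record":
                self.recorded.append(topk_idx.clone())      # save for a later deterministic run
        topk_weight = torch.gather(logits, -1, topk_idx).softmax(dim=-1)
        return topk_idx, topk_weight
```

Replaying the same `topk_idx` on a second pass guarantees that the same experts process each token, which is what makes the training forward pass deterministic.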


## Key Features

### Multiple Operating Modes
- **`disabled`**: Router replay functionality is completely disabled
- **`R2`**: Standard router replay mode for recording and replaying routing decisions
- **`R3`**: Rollout-specific router replay mode optimized for reinforcement learning workflows

### Core Capabilities
- **Seamless Integration**: Works with reinforcement learning pipelines including PPO
- **Distributed Training Support**: Compatible with multi-GPU and multi-node training environments
- **Flexible Configuration**: Easy to configure via YAML files or command-line parameters

## Configuration

### RouterReplayConfig Parameters

```yaml
router_replay:
  mode: "disabled"     # Available options: disabled, R2, R3
  record_file: null    # Path for recording routing decisions
  replay_file: null    # Path for replaying recorded decisions
```
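
These keys correspond to the `verl.workers.config.RouterReplayConfig` dataclass referenced by the generated trainer configs. A rough sketch of its shape (field names taken from the YAML above; the real class may carry additional fields or validation):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class RouterReplayConfig:
    """Sketch mirroring the YAML keys above; not the authoritative definition."""

    mode: str = "disabled"             # one of: disabled, R2, R3
    record_file: Optional[str] = None  # where recorded routing decisions are written
    replay_file: Optional[str] = None  # where previously recorded decisions are read from
```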

## Quick Start Guide

### Enabling R2 Mode

#### Configuration File Method
Add the following to your training configuration:

```yaml
actor:
  router_replay:
    mode: "R2"
```

#### Command Line Method
Enable R2 mode via command-line parameters:

```bash
actor_rollout_ref.actor.router_replay.mode="R2"
```
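
These overrides are Hydra-style arguments appended to the training launch command (e.g. `python3 -m verl.trainer.main_ppo ...`); see `run_qwen30_a3b_megatron_vllm.sh` in this directory for a complete invocation.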

### Enabling R3 Mode

#### Configuration File Method
Configure both actor and rollout settings:

```yaml
# Actor configuration
router_replay:
  mode: "R3"

# Rollout configuration
enable_rollout_routing_replay: True
```

#### Command Line Method
Enable R3 mode via command-line parameters:

```bash
actor_rollout_ref.actor.router_replay.mode="R3"
actor_rollout_ref.rollout.enable_rollout_routing_replay=True
```

R3 mode requires the rollout backend to support returning router selection results. This functionality is currently being tested against the vllm implementation in https://github.com/vllm-project/vllm/pull/28284.
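
When the backend does return per-token router selections, they arrive as an array with one entry per token, per MoE layer, per top-k slot, and are carried on `AgentLoopOutput.routed_experts` (see the agent-loop changes below). A small shape sketch with made-up sizes:

```python
import numpy as np

# Hypothetical sizes for illustration: 6 generated tokens, 4 MoE layers,
# top-2 routing over an assumed 64 experts.
num_tokens, num_moe_layers, top_k = 6, 4, 2
routed_experts = np.random.randint(0, 64, size=(num_tokens, num_moe_layers, top_k))

# The agent loop later pads this along the token dimension so it lines up with the
# left-padded prompt + response tensors of the training batch.
print(routed_experts.shape)  # (6, 4, 2)
```
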
108 changes: 108 additions & 0 deletions examples/router_replay/run_qwen30_a3b_megatron_vllm.sh
@@ -0,0 +1,108 @@

set -x

NODES=1

# R2: enable routing replay
# R3: enable rollout routing replay
# If enabling R3, also set actor_rollout_ref.rollout.enable_rollout_routing_replay=True
# The R3 example is based on the related vllm PR: https://github.com/vllm-project/vllm/pull/5322

ROUTING_REPLAY_MODE="R2"

DIST_CKPT_PATH=""
HF_MODEL_PATH=""
TRAIN_DATA_PATH=""
TEST_DATA_PATH=""

export CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping
PP=1
VPP=None
TP=2
EP=8
ETP=1
VLLM_INFER_TP=2
offload=True
gpu_memory_utilization=0.65
bs=8
micro_bs=3
use_dynamic_bsz=True
max_prompt_length=1024
max_response_length=1024
ppo_mini_batch_size=8
actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2))
infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2))


exper_name=Node${NODES}_bs${bs}_${PP}${TP}${EP}${ETP}_${VLLM_INFER_TP}_minbs${ppo_mini_batch_size}_micro_bs${micro_bs}

python3 -m verl.trainer.main_ppo --config-path=config \
--config-name='ppo_megatron_trainer.yaml' \
algorithm.adv_estimator=grpo \
data.train_files=$TRAIN_DATA_PATH \
data.val_files=$TEST_DATA_PATH \
data.train_batch_size=$bs \
data.max_prompt_length=$max_prompt_length \
data.max_response_length=$max_response_length \
data.filter_overlong_prompts=True \
data.truncation='error' \
actor_rollout_ref.model.use_fused_kernels=True \
actor_rollout_ref.model.path=$HF_MODEL_PATH \
actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \
actor_rollout_ref.actor.router_replay.mode=${ROUTING_REPLAY_MODE} \
+actor_rollout_ref.actor.megatron.override_transformer_config.moe_enable_deepep=True \
+actor_rollout_ref.actor.megatron.override_transformer_config.moe_token_dispatcher_type=flex \
+actor_rollout_ref.actor.megatron.override_transformer_config.apply_rope_fusion=True \
+actor_rollout_ref.actor.megatron.override_transformer_config.bias_activation_fusion=True \
+actor_rollout_ref.actor.megatron.override_transformer_config.moe_router_dtype=fp32 \
+actor_rollout_ref.actor.megatron.override_transformer_config.recompute_method=uniform \
+actor_rollout_ref.actor.megatron.override_transformer_config.recompute_granularity=full \
+actor_rollout_ref.actor.megatron.override_transformer_config.recompute_num_layers=1 \
+actor_rollout_ref.actor.megatron.override_transformer_config.gradient_accumulation_fusion=True \
+actor_rollout_ref.actor.megatron.override_transformer_config.moe_permute_fusion=True \
actor_rollout_ref.actor.megatron.param_offload=${offload} \
actor_rollout_ref.actor.megatron.optimizer_offload=${offload} \
actor_rollout_ref.actor.megatron.grad_offload=${offload} \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.actor.ppo_mini_batch_size=$ppo_mini_batch_size \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=$micro_bs \
actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=$PP \
actor_rollout_ref.actor.megatron.tensor_model_parallel_size=$TP \
actor_rollout_ref.actor.megatron.expert_model_parallel_size=$EP \
actor_rollout_ref.actor.megatron.expert_tensor_parallel_size=$ETP \
actor_rollout_ref.actor.use_kl_loss=False \
actor_rollout_ref.actor.kl_loss_coef=0.001 \
actor_rollout_ref.actor.kl_loss_type=low_var_kl \
actor_rollout_ref.actor.entropy_coeff=0 \
actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=$micro_bs \
actor_rollout_ref.rollout.tensor_model_parallel_size=$VLLM_INFER_TP \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.mode=async \
actor_rollout_ref.actor.megatron.use_mbridge=True \
actor_rollout_ref.rollout.gpu_memory_utilization=$gpu_memory_utilization \
actor_rollout_ref.rollout.n=8 \
actor_rollout_ref.rollout.enable_chunked_prefill=True \
actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=$micro_bs \
actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=$PP \
actor_rollout_ref.ref.megatron.tensor_model_parallel_size=$TP \
actor_rollout_ref.ref.megatron.expert_model_parallel_size=$EP \
actor_rollout_ref.ref.megatron.expert_tensor_parallel_size=$ETP \
actor_rollout_ref.ref.megatron.param_offload=${offload} \
actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
algorithm.use_kl_in_reward=False \
trainer.critic_warmup=0 \
trainer.logger=['console'] \
trainer.project_name='verl_grpo_example_gsm8k_math' \
trainer.experiment_name="$exper_name" \
trainer.nnodes=$NODES \
trainer.n_gpus_per_node=8 \
trainer.save_freq=-1 \
trainer.test_freq=10 \
trainer.total_training_steps=50000 \
trainer.balance_batch=False \
trainer.val_before_train=False 2>&1
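
In this script, `ROUTING_REPLAY_MODE` (set to `"R2"` above) is forwarded to `actor_rollout_ref.actor.router_replay.mode`; to exercise R3, change it to `"R3"` and additionally pass `actor_rollout_ref.rollout.enable_rollout_routing_replay=True`, as noted in the header comments.
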
26 changes: 26 additions & 0 deletions verl/experimental/agent_loop/agent_loop.py
@@ -134,6 +134,8 @@ class AgentLoopOutput(BaseModel):
"""Response mask, 1 for LLM generated token, 0 for tool response token."""
response_logprobs: Optional[list[float]] = None
"""Log probabilities for the response tokens."""
routed_experts: Optional[Any] = None
"""Routed experts for the total tokens."""
multi_modal_data: Optional[dict[str, Any]] = None
"""Multi-modal data for multi-modal tools."""
reward_score: Optional[float] = None
@@ -165,6 +167,8 @@ class _InternalAgentLoopOutput(AgentLoopOutput):
"""Padded attention mask."""
response_logprobs: Optional[torch.Tensor] = None
"""Padded log probabilities for the response tokens."""
routed_experts: Optional[torch.Tensor] = None
"""Padded routed experts for the total tokens."""
multi_modal_inputs: Optional[dict[str, torch.Tensor]] = None
"""Multi-modal inputs for processors (e.g., pixel_values, image_grid_thw)."""
extra_fields: dict[str, Any] = {}
@@ -487,6 +491,25 @@ async def _agent_loop_postprocess(self, output, **kwargs) -> _InternalAgentLoopOutput:
attention_mask = torch.cat([prompt_output["attention_mask"], response_output["attention_mask"]], dim=1)
input_ids = torch.cat([prompt_output["input_ids"], response_output["input_ids"]], dim=1)

routed_experts = None
if output.routed_experts is not None:
    total_length = input_ids.shape[1]
    length, layer_num, topk_num = output.routed_experts.shape
    experts_tensor = torch.from_numpy(output.routed_experts)
    routed_experts = torch.zeros(1, total_length, layer_num, topk_num, dtype=experts_tensor.dtype)

    # Calculate start position: left padding means the original prompt starts at the end
    start_pos = prompt_output["input_ids"].shape[1] - len(output.prompt_ids)
    end_pos = min(start_pos + length, total_length)

    # Add boundary checks for robustness
    if start_pos < 0 or end_pos > total_length:
        raise ValueError(
            f"Invalid position range: start_pos={start_pos}, end_pos={end_pos}, total_length={total_length}"
        )

    routed_experts[:, start_pos:end_pos] = experts_tensor.unsqueeze(0)

# Handle multi-modal inputs and position_ids calculation
# Only support Qwen2VLImageProcessor for multi-modal processing currently
# TODO: support other multi-modal inputs
@@ -560,6 +583,7 @@ async def _agent_loop_postprocess(self, output, **kwargs) -> _InternalAgentLoopOutput:
response_mask=response_mask,
attention_mask=attention_mask,
response_logprobs=response_logprobs,
routed_experts=routed_experts,
multi_modal_inputs=multi_modal_inputs,
multi_modal_data=output.multi_modal_data,
reward_score=output.reward_score,
@@ -580,6 +604,8 @@ def _postprocess(self, inputs: list[_InternalAgentLoopOutput]) -> DataProto:
optional_outputs = {}
if inputs[0].response_logprobs is not None:
    optional_outputs["rollout_log_probs"] = torch.cat([input.response_logprobs for input in inputs], dim=0)
if inputs[0].routed_experts is not None:
    optional_outputs["routed_experts"] = torch.cat([input.routed_experts for input in inputs], dim=0)

batch = TensorDict(
{
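
The position arithmetic in `_agent_loop_postprocess` above can be checked with a tiny standalone example (hypothetical sizes, not verl code): with left padding, the real prompt occupies the right end of the padded prompt tensor, so the recorded routing rows are shifted by the amount of padding.

```python
import torch

# Hypothetical sizes for illustration.
padded_prompt_len = 8                            # prompt_output["input_ids"].shape[1]
raw_prompt_len = 5                               # len(output.prompt_ids)
response_len = 4
total_length = padded_prompt_len + response_len  # input_ids.shape[1]

length, layer_num, topk_num = raw_prompt_len + response_len, 2, 2
experts = torch.arange(length * layer_num * topk_num).reshape(length, layer_num, topk_num)

# Left padding: the real prompt occupies the last raw_prompt_len slots of the padded
# prompt, so the recorded routing starts at padded_prompt_len - raw_prompt_len.
start_pos = padded_prompt_len - raw_prompt_len   # -> 3
end_pos = min(start_pos + length, total_length)  # -> 12

routed = torch.zeros(1, total_length, layer_num, topk_num, dtype=experts.dtype)
routed[:, start_pos:end_pos] = experts.unsqueeze(0)
print(start_pos, end_pos, routed.shape)          # 3 12 torch.Size([1, 12, 2, 2])
```
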
5 changes: 5 additions & 0 deletions verl/experimental/agent_loop/single_turn_agent_loop.py
@@ -73,6 +73,11 @@ async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutput:
response_ids=output.token_ids[: self.response_length],
response_mask=response_mask[: self.response_length],
response_logprobs=output.log_probs[: self.response_length] if output.log_probs else None,
routed_experts=(
    output.routed_experts[: len(prompt_ids) + self.response_length]
    if output.routed_experts is not None
    else None
),
multi_modal_data={"image": image_data} if image_data is not None else {},
num_turns=2,
metrics=metrics,
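
Note that the slice above spans both prompt and response tokens (`len(prompt_ids) + self.response_length`), matching the `routed_experts` field on `AgentLoopOutput`, which covers all tokens rather than only the response.
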
3 changes: 3 additions & 0 deletions verl/experimental/agent_loop/tool_agent_loop.py
@@ -233,6 +233,9 @@ async def _handle_generating_state(
if output.log_probs:
    agent_data.response_logprobs += output.log_probs

if output.routed_experts is not None:
    agent_data.routed_experts = output.routed_experts

# Check termination conditions
if not ignore_termination and len(agent_data.response_mask) >= self.response_length:
    return AgentState.TERMINATED
4 changes: 2 additions & 2 deletions verl/models/mcore/util.py
@@ -144,7 +144,7 @@ def postprocess_packed_seqs(
if cp_size > 1:
# output shape: [1, packed_len, hidden_dim]
# need to gather across cp group and concatenate in sequence dimension
output_list = [torch.empty_like(output) for _ in range(cp_size)]
output_list = [torch.empty_like(output, dtype=output.dtype) for _ in range(cp_size)]
torch.distributed.all_gather(output_list, output.detach(), group=mpu.get_context_parallel_group())
output_list[mpu.get_context_parallel_rank()] = output
else:
@@ -159,7 +159,7 @@
half_seqlen = s_len_padded_chunk // 2
s_len = seq_lens_cpu[i]
s_len_padded = s_len_padded_chunk * cp_size
tmp = torch.empty(s_len_padded, *output.shape[2:], device=output.device)
tmp = torch.empty(s_len_padded, *output.shape[2:], device=output.device, dtype=output.dtype)
for j in range(cp_size):
o = output_list[j][0]
# split to 2 chunks
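
The change above passes the activation `dtype` explicitly when allocating the gather and scratch buffers. A standalone illustration of why this matters (not verl code): `torch.empty` falls back to the global default dtype (typically float32) unless `dtype` is given, whereas `torch.empty_like` already inherits it from its input.

```python
import torch

x = torch.randn(4, 8, dtype=torch.bfloat16)  # stand-in for bf16 activations

# Without an explicit dtype, torch.empty uses the global default (usually float32),
# so a gather/scratch buffer allocated this way silently changes precision.
buf_default = torch.empty(x.shape, device=x.device)                 # float32
buf_matched = torch.empty(x.shape, device=x.device, dtype=x.dtype)  # bfloat16

# torch.empty_like inherits dtype from its input tensor.
buf_like = torch.empty_like(x)                                       # bfloat16

print(buf_default.dtype, buf_matched.dtype, buf_like.dtype)
```
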
11 changes: 11 additions & 0 deletions verl/trainer/config/_generated_ppo_megatron_trainer.yaml
@@ -121,6 +121,11 @@ actor_rollout_ref:
_target_: verl.utils.profiler.config.TorchMemoryToolConfig
trace_alloc_max_entries: ${oc.select:global_profiler.global_tool_config.torch_memory.trace_alloc_max_entries,100000}
stack_depth: ${oc.select:global_profiler.global_tool_config.torch_memory.stack_depth,32}
router_replay:
_target_: verl.workers.config.RouterReplayConfig
mode: disabled
record_file: null
replay_file: null
data_loader_seed: 42
load_weight: true
ref:
@@ -156,6 +161,11 @@ actor_rollout_ref:
_target_: verl.utils.profiler.config.TorchMemoryToolConfig
trace_alloc_max_entries: ${oc.select:global_profiler.global_tool_config.torch_memory.trace_alloc_max_entries,100000}
stack_depth: ${oc.select:global_profiler.global_tool_config.torch_memory.stack_depth,32}
router_replay:
_target_: verl.workers.config.RouterReplayConfig
mode: disabled
record_file: null
replay_file: null
megatron:
_target_: verl.workers.config.McoreEngineConfig
param_offload: ${oc.select:actor_rollout_ref.actor.megatron.param_offload,False}
@@ -259,6 +269,7 @@ actor_rollout_ref:
skip_rollout: false
skip_dump_dir: /tmp/rollout_dump
skip_tokenizer_init: true
enable_rollout_routing_replay: false
profiler:
_target_: verl.utils.profiler.ProfilerConfig
tool: ${oc.select:global_profiler.tool,null}
11 changes: 11 additions & 0 deletions verl/trainer/config/_generated_ppo_trainer.yaml
@@ -107,6 +107,11 @@ actor_rollout_ref:
_target_: verl.utils.profiler.config.TorchMemoryToolConfig
trace_alloc_max_entries: ${oc.select:global_profiler.global_tool_config.torch_memory.trace_alloc_max_entries,100000}
stack_depth: ${oc.select:global_profiler.global_tool_config.torch_memory.stack_depth,32}
router_replay:
_target_: verl.workers.config.RouterReplayConfig
mode: disabled
record_file: null
replay_file: null
grad_clip: 1.0
ulysses_sequence_parallel_size: 1
entropy_from_logits_with_chunking: false
@@ -145,6 +150,11 @@ actor_rollout_ref:
_target_: verl.utils.profiler.config.TorchMemoryToolConfig
trace_alloc_max_entries: ${oc.select:global_profiler.global_tool_config.torch_memory.trace_alloc_max_entries,100000}
stack_depth: ${oc.select:global_profiler.global_tool_config.torch_memory.stack_depth,32}
router_replay:
_target_: verl.workers.config.RouterReplayConfig
mode: disabled
record_file: null
replay_file: null
fsdp_config:
_target_: verl.workers.config.FSDPEngineConfig
wrap_policy:
@@ -245,6 +255,7 @@ actor_rollout_ref:
skip_rollout: false
skip_dump_dir: /tmp/rollout_dump
skip_tokenizer_init: true
enable_rollout_routing_replay: false
profiler:
_target_: verl.utils.profiler.ProfilerConfig
tool: ${oc.select:global_profiler.tool,null}
20 changes: 20 additions & 0 deletions verl/trainer/config/actor/actor.yaml
@@ -220,3 +220,23 @@ profiler:

# Stack trace depth for memory allocations
stack_depth: ${oc.select:global_profiler.global_tool_config.torch_memory.stack_depth,32}

# Router replay configuration for MoE models
router_replay:

  # Target dataclass for this configuration
  _target_: verl.workers.config.RouterReplayConfig

  # Router replay mode: disabled, R2, R3
  # - R2: standard router replay (record and replay routing decisions during training)
  # - R3: rollout router replay (requires rollout.enable_rollout_routing_replay=True)
  mode: disabled

  # File path for saving recorded routing decisions (used when recording in R2/R3 mode)
  record_file: null

  # File path for loading previously recorded routing decisions for replay
  replay_file: null

19 changes: 19 additions & 0 deletions verl/trainer/config/ref/ref.yaml
@@ -100,3 +100,22 @@ profiler:

# Stack trace depth for memory allocations
stack_depth: ${oc.select:global_profiler.global_tool_config.torch_memory.stack_depth,32}

# Router replay configuration for MoE models
router_replay:

  # Target dataclass for this configuration
  _target_: verl.workers.config.RouterReplayConfig

  # Router replay mode: disabled, R2, R3
  # - R2: standard router replay (record and replay routing decisions during training)
  # - R3: rollout router replay (requires rollout.enable_rollout_routing_replay=True)
  mode: disabled

  # File path for saving recorded routing decisions (used when recording in R2/R3 mode)
  record_file: null

  # File path for loading previously recorded routing decisions for replay
  replay_file: null