Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/Megatron-SWIFT/Command-line-parameters.md
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,7 @@ Megatron训练参数继承自Megatron参数和基本参数(**与ms-swift共用
- rollout_importance_sampling_threshold: 重要性采样权重的阈值,用于截断或屏蔽极端权重。默认为2.0。
- log_rollout_offpolicy_metrics: 当 `rollout_importance_sampling_mode` 未设置时,是否记录训推不一致诊断指标(KL、PPL、χ²等)。当设置了 `rollout_importance_sampling_mode` 时,指标会自动记录。默认为False。
- off_policy_sequence_mask_delta: Off-Policy Sequence Masking 阈值,来自 DeepSeek-V3.2 论文。当设置此值时,会计算每个序列的 `mean(old_policy_logps - policy_logps)`,若该值大于阈值且该序列的优势为负,则 mask 掉该序列不参与损失计算。默认为None,不启用。具体参考[文档](../Instruction/GRPO/AdvancedResearch/training_inference_mismatch.md#off-policy-sequence-masking)。
- router_replay_mode: 路由重放模式,可选项为`disabled`、`R2`、`R3`。默认为disabled,不启用路由重放。

内置奖励函数参数参考[文档](../Instruction/Command-line-parameters.md#奖励函数参数)

Expand Down
1 change: 1 addition & 0 deletions docs/source_en/Megatron-SWIFT/Command-line-parameters.md
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,7 @@ In addition to inheriting the training parameters, the following parameters are
- rollout_importance_sampling_threshold: Threshold for importance sampling weights, used for truncating or masking extreme weights. Default is 2.0.
- log_rollout_offpolicy_metrics: Whether to log training-inference mismatch diagnostic metrics (KL, PPL, χ², etc.) when `rollout_importance_sampling_mode` is not set. When `rollout_importance_sampling_mode` is set, metrics are always logged. Default is False.
- off_policy_sequence_mask_delta: Off-Policy Sequence Masking threshold from the [DeepSeek-V3.2 paper](https://arxiv.org/abs/2512.02556). When set, computes `mean(old_policy_logps - policy_logps)` for each sequence. If this value exceeds the threshold AND the sequence has negative advantage, the sequence is masked out from loss computation. Default is None (disabled). For details, refer to the [documentation](../Instruction/GRPO/AdvancedResearch/training_inference_mismatch.md#off-policy-sequence-masking).
- router_replay_mode: Router replay mode. Options are `disabled`, `R2`, and `R3`. Default is `disabled` (router replay is not enabled).

Built-in reward function parameters refer to the [documentation](../Instruction/Command-line-parameters.md#reward-function-parameters).

Expand Down
1 change: 1 addition & 0 deletions swift/infer_engine/grpo_vllm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ def _create_chat_completion_response(self, result, inputs, request_config, reque
finish_reason=output.finish_reason,
logprobs=logprobs,
token_ids=token_ids,
routed_experts=getattr(output, 'routed_experts', None)
)
choices.append(choice)
prompt_token_ids = None
Expand Down
32 changes: 30 additions & 2 deletions swift/infer_engine/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,40 @@
from copy import deepcopy
from dataclasses import asdict, dataclass, field, fields
from PIL import Image
from pydantic import BaseModel, Field, field_validator
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
from pydantic import BaseModel, Field, field_validator, PlainSerializer, AfterValidator
from typing import Any, Dict, List, Literal, Optional, Tuple, Union, Annotated
import numpy as np

from swift.template import Messages, Tool
from swift.utils import remove_response


def serialize_ndarray(value: Any) -> Any:
    """Convert a NumPy array into a plain, tagged dictionary.

    Args:
        value: Object to serialize. Only ``np.ndarray`` instances are converted.

    Returns:
        For an ``np.ndarray``, a dict with the keys ``'data'`` (nested lists
        produced by ``tolist()``), ``'shape'``, ``'dtype'`` (the dtype as a
        string) and the marker ``'__ndarray__': True``. ``None`` and every
        other value are returned unchanged.
    """
    if not isinstance(value, np.ndarray):
        # Also covers None: anything that is not an array passes through as-is.
        return value
    return {
        'data': value.tolist(),
        'shape': value.shape,
        'dtype': str(value.dtype),
        # Marker key that deserialize_ndarray checks to recognise this payload.
        '__ndarray__': True,
    }

def deserialize_ndarray(value):
if value is None:
return None
if isinstance(value, dict) and value.get('__ndarray__'):
return np.array(value['data'], dtype=value['dtype']).reshape(value['shape'])

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return np.array(value['data'], dtype=value['dtype']).reshape(value['shape'])
data = base64.b64decode(value['data'])
return np.frombuffer(data, dtype=value['dtype']).reshape(value['shape'])

return value

# Annotated alias for fields that may carry a NumPy array: the underlying type
# is Any, with serialize_ndarray attached as a pydantic PlainSerializer and
# deserialize_ndarray attached as an AfterValidator.
# NOTE(review): serialize_ndarray returns None or the untouched input for
# non-array values, while return_type is declared as Dict here — confirm that
# this mismatch is intended.
NumpyArray = Annotated[
Any,
PlainSerializer(serialize_ndarray, return_type=Dict),
AfterValidator(deserialize_ndarray)
]


@dataclass
class InferRequest:
"""
Expand Down Expand Up @@ -392,6 +419,7 @@ class ChatCompletionResponseChoice:
finish_reason: Literal['stop', 'length', None]
logprobs: Optional[Dict[str, List[Dict[str, Any]]]] = None
token_ids: Optional[List[int]] = None
routed_experts: Optional[NumpyArray] = None

def to_cmpl_choice(self) -> 'CompletionResponseChoice':
self = deepcopy(self)
Expand Down
2 changes: 2 additions & 0 deletions swift/megatron/arguments/megatron_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,8 @@ class RLHFMegatronArgumentsMixin:
log_entropy: bool = False
# Beyond the 80/20 Rule, https://arxiv.org/abs/2506.01939
top_entropy_quantile: float = 1.0

router_replay_mode: Literal['disabled', 'R2', 'R3'] = 'disabled'

# ─────────────────────────── Not Supported Yet ───────────────────────────

Expand Down
6 changes: 6 additions & 0 deletions swift/megatron/model/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,8 @@ class MegatronModelConfig(TransformerConfig):
'none'] = 'aux_loss'
use_shared_expert_gate: bool = False

enable_routing_replay: bool = False

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I recommend using the parameter name moe_enable_routing_replay to maintain
consistency with Megatron-LM.

Reference: NVIDIA/Megatron-LM#2101


# mla
multi_latent_attention: bool = False
q_lora_rank: Optional[int] = None
Expand Down Expand Up @@ -508,6 +510,10 @@ def get_mcore_model_config(args, hf_config):
if num_moe_experts is None:
kwargs['expert_model_parallel_size'] = 1
kwargs['expert_tensor_parallel_size'] = 1

if args.router_replay_mode != "disabled":
kwargs['enable_routing_replay'] = True

config = MegatronModelConfig(**kwargs)
config.hf_config = hf_config
config.args = args
Expand Down
16 changes: 15 additions & 1 deletion swift/megatron/trainers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@
logical_and_across_model_parallel_group, maybe_finalize_async_save,
prepare_mcore_model, reduce_max_stat_across_model_parallel_group,
save_mcore_checkpoint, should_disable_forward_pre_hook, warmup_jit_function,
wrap_model)
wrap_model,
apply_router_replay_patch, RouterReplay, RouterReplayAction)
from swift.template import Template
from swift.trainers import dynamic_gradient_checkpointing
from swift.trainers.utils import patch_modelscope_hub_timeout
Expand All @@ -53,6 +54,11 @@
class BaseMegatronTrainer(ABC):

def __init__(self, args, template: Template):
# patch routing_replay
self.enable_routing_replay = args.router_replay_mode != "disabled"
if self.enable_routing_replay:
apply_router_replay_patch()

self.args = args
self.template = template
self.bridge = args.megatron_model_meta.bridge_cls(args)
Expand Down Expand Up @@ -768,6 +774,10 @@ def train_step(self, train_data_iterator):
self.optimizer.zero_grad()
# TODO: refactor _replace_data_iterator
data_iterator = self._replace_data_iterator(train_data_iterator)

if self.enable_routing_replay:
RouterReplay.set_global_router_replay_action(RouterReplayAction.REPLAY_FORWARD)

metrics = forward_backward_func(
forward_step_func=self.forward_step,
data_iterator=data_iterator,
Expand All @@ -784,6 +794,10 @@ def train_step(self, train_data_iterator):
if update_successful:
self.opt_param_scheduler.step(increment=args.global_batch_size)

if self.enable_routing_replay:
RouterReplay.clear_global_router_replay_action()
RouterReplay.clear_global_indices()

return metrics, grad_norm, update_successful

def _aggregated_metrics(self, metrics, total_metrics):
Expand Down
88 changes: 84 additions & 4 deletions swift/megatron/trainers/grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@
from swift.dataset import RowPreprocessor
from swift.infer_engine.protocol import RequestConfig, RolloutInferRequest, RolloutOutput
from swift.megatron.arguments import MegatronArguments, MegatronRLHFArguments
from swift.megatron.utils import forward_step_helper, get_padding_to, set_random_seed
from swift.megatron.utils import (forward_step_helper, get_padding_to, set_random_seed,
RouterReplay, RouterReplayAction, RouterReplayHelper,
get_router_replay_data, set_router_replay_data, get_local_topk_idx_for_current_rank)
from swift.rewards import orms
from swift.rlhf_trainers.grpo_trainer import DataType
from swift.rlhf_trainers.utils import (aggressive_empty_cache, nanstd, pad_logps_back_to_batch, profiling_context,
Expand Down Expand Up @@ -271,6 +273,8 @@ def _batch_encode(self, infer_requests: List[Dict], template: Template, strict:
return batched_inputs, error_list

def _get_encoded_batch(self, encoded_list, rollout_batch, template):
original_seq_lengths = [item['length'] for item in encoded_list]

args = self.args
encoded_batch = to_device(template.data_collator(encoded_list, padding_to=get_padding_to(args)), self.device)

Expand Down Expand Up @@ -360,6 +364,40 @@ def _get_encoded_batch(self, encoded_list, rollout_batch, template):

encoded_batch['rollout_per_token_logps'] = rollout_per_token_logps

# Validating and processing routed_experts data in R3 mode
if self.args.router_replay_mode == 'R3':
routed_experts_list = []
cur_seq_lengths = seq_lengths
if (seq_lengths.size(0) > batch_size):
cur_seq_lengths = seq_lengths[:batch_size].clone()
cur_seq_lengths[batch_size - 1] = seq_lengths[batch_size-1:].sum()
for data, original_seq_len, cur_seq_len in zip(rollout_batch, original_seq_lengths, cur_seq_lengths):
routed_experts = data.get('routed_experts')
assert routed_experts is not None, 'When router_replay_mode = R3, routed_experts must be in rollout data'
routed_experts = torch.tensor(routed_experts)
# The number of experts in the output can be 1 less than (prompt_length + response_token_count)
# This gap of 1 is expected
# For more details, please refer PR https://github.com/vllm-project/vllm/pull/28284
experts_seq_len = routed_experts.shape[0]
assert (experts_seq_len == original_seq_len
or experts_seq_len + 1 == original_seq_len), \
f'The seq_len of routed_experts({experts_seq_len}) in output does not match the seq_len of data({original_seq_len}), should be equal to or 1 less than the seq_len of data'
# Padding routed_experts(seq_len, layer_num, topk) seq_len to match the seq_len of the input_ids
padding_routed_experts = routed_experts
padding_to = cur_seq_len if template.padding_free else max_seq_len
padding_len = padding_to - experts_seq_len
if padding_len > 0:
padding_right = template.padding_side == 'right'
padding_routed_experts = nn.functional.pad(routed_experts,
(0, 0, 0, 0, 0, padding_len) if padding_right else (0, 0, 0, 0, padding_len, 0),
'constant', 0)
routed_experts_list.append(padding_routed_experts)
if template.padding_free:
gloabl_routed_experts = torch.cat(routed_experts_list, dim=0).unsqueeze(0)
else:
gloabl_routed_experts = torch.stack(routed_experts_list)
encoded_batch['routed_experts'] = gloabl_routed_experts.to(device=self.device)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

There's a typo in the variable name gloabl_routed_experts. It should be global_routed_experts.

Suggested change
if template.padding_free:
gloabl_routed_experts = torch.cat(routed_experts_list, dim=0).unsqueeze(0)
else:
gloabl_routed_experts = torch.stack(routed_experts_list)
encoded_batch['routed_experts'] = gloabl_routed_experts.to(device=self.device)
if template.padding_free:
global_routed_experts = torch.cat(routed_experts_list, dim=0).unsqueeze(0)
else:
global_routed_experts = torch.stack(routed_experts_list)
encoded_batch['routed_experts'] = global_routed_experts.to(device=self.device)


return encoded_batch

def _generate_and_score_completions(self, batch):
Expand Down Expand Up @@ -557,6 +595,9 @@ def merge_output_input_data(input_data: Dict[str, Union[torch.Tensor, Any]], out
if 'content' in choice.logprobs:
rollout_logprobs = [item['logprob'] for item in choice.logprobs['content']]
input_data['rollout_logprobs'] = [rollout_logprobs]

# Step 6: Store rollout routed_experts for routing replay
input_data['routed_experts'] = choice.routed_experts
return input_data

assert len(batch) == len(outputs)
Expand Down Expand Up @@ -951,9 +992,16 @@ def _maybe_compute_logps(self, batch: Dict[str, Any]) -> Dict[str, Any]:
# In non-padding_free mode, logps are already in batch format [batch_size, seq_len]
ref_per_token_logps = ref_per_token_logps_raw
batch['ref_per_token_logps'] = ref_per_token_logps

old_per_token_logps_raw = self.model_forward(
self.unwrapped_models[0], iter([deepcopy(inputs)]), no_grad=True, per_token=True)['logps']

if self.enable_routing_replay:
if self.args.router_replay_mode == 'R2':
RouterReplay.set_global_router_replay_action(RouterReplayAction.RECORD)
if self.args.router_replay_mode == 'R3':
RouterReplay.set_global_router_replay_action(RouterReplayAction.REPLAY_FORWARD)

output = self.model_forward(self.unwrapped_models[0], iter([deepcopy(inputs)]),
no_grad=True, per_token=True)
old_per_token_logps_raw = output['logps']
if self.template.padding_free:
old_per_token_logps, _ = pad_logps_back_to_batch(
logps_rmpad=old_per_token_logps_raw,
Expand All @@ -964,6 +1012,11 @@ def _maybe_compute_logps(self, batch: Dict[str, Any]) -> Dict[str, Any]:
old_per_token_logps = old_per_token_logps_raw
batch['old_per_token_logps'] = old_per_token_logps

if self.enable_routing_replay:
batch['routed_experts'] = output['layers_topk_idx']
RouterReplay.clear_global_indices()
RouterReplay.clear_global_router_replay_action()

return batch

def _compute_kl_from_batches(self, mini_batch_data: List[Dict[str, Any]]) -> torch.Tensor:
Expand Down Expand Up @@ -1043,6 +1096,15 @@ def forward_step(self, data_iterator, model):
'seq_lengths': seq_lengths,
})
data.pop('loss_scale', None)

if RouterReplayHelper.is_replay_backward_action(model.config):
router_instance_list = RouterReplayHelper.get_micro_batch_router_list(model.config)
for router in router_instance_list:
router.set_router_replay_action(RouterReplayAction.REPLAY_FORWARD)
if RouterReplayHelper.is_replay_forward_action(model.config):
layers_topk_idx = data.pop('routed_experts', None)
set_router_replay_data(layers_topk_idx, model.config)

inputs = self._prepare_model_inputs(data)

labels = data['labels']
Expand Down Expand Up @@ -1116,6 +1178,11 @@ def forward_step(self, data_iterator, model):
data['per_token_logps'] = per_token_logps
data['per_token_entropy'] = None

if RouterReplayHelper.is_replay_forward_action(model.config):
router_instance_list = RouterReplayHelper.get_micro_batch_router_list(model.config)
for router in router_instance_list:
router.set_router_replay_action(RouterReplayAction.REPLAY_BACKWARD)

return output_tensor, partial(self.loss_func, data=data)

@profiling_decorator
Expand Down Expand Up @@ -1430,6 +1497,13 @@ def model_forward(self, model, data_iterator, no_grad=True, per_token=False):
labels = data.get('labels')
context = torch.no_grad() if no_grad else nullcontext()

layers_topk_idx = None
global_topk_idx = data.pop('routed_experts', None)
if RouterReplayHelper.is_replay_forward_action(model.config):
assert global_topk_idx is not None, "When router_replay_mode = R3, routed_experts must be in data"
layers_topk_idx = get_local_topk_idx_for_current_rank(global_topk_idx, model.config, data.get('packed_seq_params'))
set_router_replay_data(layers_topk_idx, model.config)

with context:
output_tensor = forward_step_helper(self.args, model, data)

Expand All @@ -1441,6 +1515,12 @@ def model_forward(self, model, data_iterator, no_grad=True, per_token=False):
num_samples = input_ids.shape[0] if input_ids is not None else labels.shape[0]
data['logps'] = None if labels is None else self.get_logps(
output_tensor, labels, packed_seq_params, num_samples, per_token=per_token)

if RouterReplayHelper.is_r2_record_action(model.config):
layers_topk_idx = get_router_replay_data(model.config)
if layers_topk_idx is not None:
data['layers_topk_idx'] = layers_topk_idx

return data

def inputs2requests(self, inputs: Union[DataType, List[RolloutInferRequest]]) -> List[RolloutInferRequest]:
Expand Down
3 changes: 3 additions & 0 deletions swift/megatron/trainers/rollout_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,9 @@ def _prepare_vllm_engine(self):
vllm_engine_kwargs = args.vllm_engine_kwargs or {}
load_format = vllm_engine_kwargs.pop('load_format', 'dummy')

if self.args.router_replay_mode == 'R3':
vllm_engine_kwargs['enable_return_routed_experts'] = True

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we add a vLLM version check? The enable_return_routed_experts
parameter is only available in recent versions


engine = GRPOVllmEngine(
args.model_info.model_dir,
torch_dtype=args.torch_dtype,
Expand Down
2 changes: 2 additions & 0 deletions swift/megatron/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,5 @@
from .patcher import patch_merge_fn, patch_torch_dist_shard
from .utils import (copy_original_module_weight, forward_step_helper, get_local_layer_specs, get_padding_to,
prepare_mcore_model, tuners_sharded_state_dict)
from .router_replay_patch import RouterReplay, RouterReplayAction, apply_router_replay_patch
from .router_replay_utils import *
Loading
Loading