From 6de7d68cc8e60f244f85492cd58b10d183b6e31d Mon Sep 17 00:00:00 2001
From: Xin Yang
Date: Wed, 20 Aug 2025 09:32:20 -0700
Subject: [PATCH 1/4] Add routed_scaling_factor to MoE grouped topk

Signed-off-by: Xin Yang
---
 .../layers/fused_moe/cpu_fused_moe.py              | 12 ++++++++++++
 vllm/model_executor/layers/fused_moe/fused_moe.py  | 12 +++++++++---
 vllm/model_executor/layers/fused_moe/layer.py      | 15 +++++++++++++++
 .../layers/fused_moe/rocm_aiter_fused_moe.py       |  2 ++
 .../layers/quantization/awq_marlin.py              |  2 ++
 .../layers/quantization/bitsandbytes.py            |  2 ++
 .../compressed_tensors/compressed_tensors_moe.py   | 10 ++++++++++
 .../layers/quantization/experts_int8.py            |  2 ++
 vllm/model_executor/layers/quantization/fp8.py     |  4 +++-
 vllm/model_executor/layers/quantization/gguf.py    |  2 ++
 .../layers/quantization/gptq_marlin.py             |  2 ++
 .../layers/quantization/modelopt.py                |  4 ++++
 .../layers/quantization/moe_wna16.py               |  2 ++
 vllm/model_executor/layers/quantization/mxfp4.py   |  2 ++
 .../layers/quantization/quark/quark_moe.py         |  4 ++++
 vllm/model_executor/layers/quantization/rtn.py     |  2 ++
 vllm/model_executor/models/deepseek_v2.py          |  1 +
 17 files changed, 76 insertions(+), 4 deletions(-)

diff --git a/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py b/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py
index 769a04b7de89..0eec93601b3f 100644
--- a/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py
@@ -21,6 +21,7 @@ def grouped_topk(
     num_expert_group: int = 0,
     topk_group: int = 0,
     scoring_func: str = "softmax",
+    routed_scaling_factor: float = 1.0,
     e_score_correction_bias: Optional[torch.Tensor] = None
 ) -> tuple[torch.Tensor, torch.Tensor]:
     assert hidden_states.shape[0] == gating_output.shape[0], (
@@ -65,6 +66,8 @@ def grouped_topk(
     if renormalize:
         topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

+    if routed_scaling_factor != 1.0:
+        topk_weights = topk_weights * routed_scaling_factor
     return topk_weights, topk_ids.to(torch.int32)


@@ -78,6 +81,7 @@ def select_experts(
     num_expert_group: Optional[int] = None,
     custom_routing_function: Optional[Callable] = None,
     scoring_func: str = "softmax",
+    routed_scaling_factor: float = 1.0,
     e_score_correction_bias: Optional[torch.Tensor] = None,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     if use_grouped_topk:
@@ -90,6 +94,7 @@ def select_experts(
             num_expert_group=num_expert_group,
             topk_group=topk_group,
             scoring_func=scoring_func,
+            routed_scaling_factor=routed_scaling_factor,
             e_score_correction_bias=e_score_correction_bias)
     elif custom_routing_function is None:
         assert scoring_func == "softmax"
@@ -131,12 +136,15 @@ def __call__(
         expert_map: Optional[torch.Tensor] = None,
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
+        routed_scaling_factor: float = 1.0,
         e_score_correction_bias: Optional[torch.Tensor] = None,
         apply_router_weight_on_input: bool = False,
         activation: str = "silu",
     ) -> torch.Tensor:
         assert activation == "silu", f"{activation} is not supported."
         assert not apply_router_weight_on_input
+        assert routed_scaling_factor == 1.0, \
+            f"routed_scaling_factor {routed_scaling_factor} is not supported."
         return layer.ipex_fusion(
             x,
             use_grouped_topk,
@@ -170,6 +178,7 @@ def __call__(
         expert_map: Optional[torch.Tensor] = None,
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
+        routed_scaling_factor: float = 1.0,
         e_score_correction_bias: Optional[torch.Tensor] = None,
         apply_router_weight_on_input: bool = False,
         activation: str = "silu",
@@ -186,6 +195,7 @@ def __call__(
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             scoring_func=scoring_func,
+            routed_scaling_factor=routed_scaling_factor,
             e_score_correction_bias=e_score_correction_bias,
         )
@@ -227,6 +237,7 @@ def __call__(
         expert_map: Optional[torch.Tensor] = None,
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
+        routed_scaling_factor: float = 1.0,
         e_score_correction_bias: Optional[torch.Tensor] = None,
         apply_router_weight_on_input: bool = False,
         activation: str = "silu",
@@ -243,6 +254,7 @@ def __call__(
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             scoring_func=scoring_func,
+            routed_scaling_factor=routed_scaling_factor,
             e_score_correction_bias=e_score_correction_bias,
         )
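For intuition, here is a minimal standalone sketch of the grouped-top-k routing that the hunks above extend. This is plain PyTorch with softmax scoring only and no e_score_correction_bias; the function name and shapes are illustrative, not vLLM's API:

    import torch

    def grouped_topk_sketch(gating_output: torch.Tensor,  # [n_tokens, n_experts]
                            topk: int,
                            num_expert_group: int,
                            topk_group: int,
                            renormalize: bool = True,
                            routed_scaling_factor: float = 1.0):
        scores = torch.softmax(gating_output, dim=-1)
        n_tokens, n_experts = scores.shape
        # Score each expert group by its best expert and keep the top groups.
        group_scores = scores.view(n_tokens, num_expert_group,
                                   -1).max(dim=-1).values
        group_idx = torch.topk(group_scores, k=topk_group, dim=-1).indices
        group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1)
        expert_mask = group_mask.unsqueeze(-1).expand(
            n_tokens, num_expert_group,
            n_experts // num_expert_group).reshape(n_tokens, -1)
        # Zero out experts in discarded groups, then take the per-token top-k.
        topk_weights, topk_ids = torch.topk(
            scores.masked_fill(expert_mask == 0, 0.0), k=topk, dim=-1)
        if renormalize:
            topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
        # The new parameter: scale the routed weights after renormalization,
        # as DeepSeek-V2-style checkpoints expect.
        if routed_scaling_factor != 1.0:
            topk_weights = topk_weights * routed_scaling_factor
        return topk_weights, topk_ids.to(torch.int32)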
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index 84dafcf00d82..ebac839b2ed1 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -1728,6 +1728,7 @@ def fused_moe(
     gating_output: torch.Tensor,
     topk: int,
     renormalize: bool,
+    routed_scaling_factor: float = 1.0,
     inplace: bool = False,
     activation: str = "silu",
     is_act_and_mul: bool = True,
@@ -1814,9 +1815,14 @@ def fused_moe(

     if use_grouped_topk:
         assert num_expert_group is not None and topk_group is not None
-        topk_weights, topk_ids = grouped_topk(hidden_states, gating_output,
-                                              topk, renormalize,
-                                              num_expert_group, topk_group)
+        topk_weights, topk_ids = grouped_topk(
+            hidden_states,
+            gating_output,
+            topk,
+            renormalize,
+            num_expert_group,
+            topk_group,
+            routed_scaling_factor=routed_scaling_factor)
     elif custom_routing_function is None:
         topk_weights, topk_ids, token_expert_indices = fused_topk(
             hidden_states, gating_output, topk, renormalize)
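With renormalize=True the routed weights of each token sum to exactly routed_scaling_factor, which is an easy property to check against the sketch above:

    gating = torch.randn(4, 8)  # 4 tokens, 8 experts in 4 groups of 2
    w, ids = grouped_topk_sketch(gating, topk=2, num_expert_group=4,
                                 topk_group=2, routed_scaling_factor=2.5)
    print(w.sum(dim=-1))  # ~ tensor([2.5000, 2.5000, 2.5000, 2.5000])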
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 54406a5a2d87..acc32277be2c 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -243,6 +243,7 @@ def apply(
         expert_map: Optional[torch.Tensor] = None,
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
+        routed_scaling_factor: float = 1.0,
         e_score_correction_bias: Optional[torch.Tensor] = None,
         apply_router_weight_on_input: bool = False,
         activation: str = "silu",
@@ -398,6 +399,7 @@ def apply(
         expert_map: Optional[torch.Tensor] = None,
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
+        routed_scaling_factor: float = 1.0,
         e_score_correction_bias: Optional[torch.Tensor] = None,
         apply_router_weight_on_input: bool = False,
         activation: str = "silu",
@@ -425,6 +427,7 @@ def apply(
             expert_map=expert_map,
             custom_routing_function=custom_routing_function,
             scoring_func=scoring_func,
+            routed_scaling_factor=routed_scaling_factor,
             e_score_correction_bias=e_score_correction_bias,
             activation=activation,
             apply_router_weight_on_input=apply_router_weight_on_input,
@@ -448,6 +451,7 @@ def forward_cuda(
         expert_map: Optional[torch.Tensor] = None,
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
+        routed_scaling_factor: float = 1.0,
         e_score_correction_bias: Optional[torch.Tensor] = None,
         apply_router_weight_on_input: bool = False,
         activation: str = "silu",
@@ -467,6 +471,7 @@ def forward_cuda(
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             scoring_func=scoring_func,
+            routed_scaling_factor=routed_scaling_factor,
             e_score_correction_bias=e_score_correction_bias,
             indices_type=self.topk_indices_dtype,
             enable_eplb=enable_eplb,
@@ -532,6 +537,7 @@ def forward_cpu(
         expert_map: Optional[torch.Tensor] = None,
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
+        routed_scaling_factor: float = 1.0,
         e_score_correction_bias: Optional[torch.Tensor] = None,
         apply_router_weight_on_input: bool = False,
         activation: str = "silu",
@@ -558,6 +564,7 @@ def forward_cpu(
             expert_map,
             custom_routing_function,
             scoring_func,
+            routed_scaling_factor,
             e_score_correction_bias,
             apply_router_weight_on_input,
             activation,
@@ -577,6 +584,7 @@ def forward_xpu(
         expert_map: Optional[torch.Tensor] = None,
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
+        routed_scaling_factor: float = 1.0,
         e_score_correction_bias: Optional[torch.Tensor] = None,
         apply_router_weight_on_input: bool = False,
         activation: str = "silu",
@@ -615,6 +623,7 @@ def forward_tpu(
         expert_map: Optional[torch.Tensor] = None,
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
+        routed_scaling_factor: float = 1.0,
         e_score_correction_bias: Optional[torch.Tensor] = None,
         apply_router_weight_on_input: bool = False,
         activation: str = "silu",
@@ -762,6 +771,7 @@ def __init__(
         prefix: str = "",
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
+        routed_scaling_factor: float = 1.0,
         e_score_correction_bias: Optional[torch.Tensor] = None,
         apply_router_weight_on_input: bool = False,
         activation: str = "silu",
@@ -844,6 +854,7 @@ def __init__(
         self.topk_group = topk_group
         self.custom_routing_function = custom_routing_function
         self.scoring_func = scoring_func
+        self.routed_scaling_factor = routed_scaling_factor
         self.e_score_correction_bias = e_score_correction_bias
         self.apply_router_weight_on_input = apply_router_weight_on_input
         self.activation = activation
@@ -1412,6 +1423,7 @@ def select_experts(
         num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
+        routed_scaling_factor: float = 1.0,
         e_score_correction_bias: Optional[torch.Tensor] = None,
         indices_type: Optional[torch.dtype] = None,
         enable_eplb: bool = False,
@@ -1456,6 +1468,7 @@ def select_experts(
             num_expert_group=num_expert_group,
             topk_group=topk_group,
             scoring_func=scoring_func,
+            routed_scaling_factor=routed_scaling_factor,
             e_score_correction_bias=e_score_correction_bias)
         if indices_type is not None:
             topk_ids = topk_ids.to(dtype=indices_type)
@@ -1623,6 +1636,7 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False):
             num_expert_group=self.num_expert_group,
             custom_routing_function=self.custom_routing_function,
             scoring_func=self.scoring_func,
+            routed_scaling_factor=self.routed_scaling_factor,
             e_score_correction_bias=self.e_score_correction_bias,
             activation=self.activation,
             enable_eplb=self.enable_eplb,
@@ -1691,6 +1705,7 @@ def forward_impl(self, hidden_states: torch.Tensor,
             num_expert_group=self.num_expert_group,
             custom_routing_function=self.custom_routing_function,
             scoring_func=self.scoring_func,
+            routed_scaling_factor=self.routed_scaling_factor,
             e_score_correction_bias=self.e_score_correction_bias,
             activation=self.activation,
             apply_router_weight_on_input=self.apply_router_weight_on_input,
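At the layer level the factor is now a constructor argument, stored on the module and forwarded on every call. A hedged usage sketch — the argument subset shown here is plausible, but the exact required arguments depend on the vLLM version:

    from vllm.model_executor.layers.fused_moe import FusedMoE

    experts = FusedMoE(
        num_experts=64,
        top_k=6,
        hidden_size=2048,
        intermediate_size=1408,
        renormalize=True,
        use_grouped_topk=True,
        num_expert_group=8,
        topk_group=3,
        scoring_func="softmax",
        routed_scaling_factor=2.5,  # new in this series; defaults to 1.0
    )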
diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
index 93e20c3477bb..cd47ac24c9a0 100644
--- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
@@ -267,6 +267,7 @@ def rocm_aiter_grouped_topk(
     num_expert_group: int = 0,
     topk_group: int = 0,
     scoring_func: str = "softmax",
+    routed_scaling_factor: float = 1.0,
     e_score_correction_bias: Optional[torch.Tensor] = None
 ) -> tuple[torch.Tensor, torch.Tensor]:
     token = hidden_states.shape[0]
@@ -298,6 +299,7 @@ def rocm_aiter_grouped_topk(
             scoring_func,
         )

+    topk_weights = topk_weights * routed_scaling_factor
     return topk_weights, topk_ids
diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py
index 287d66b06d6e..8293d42ef455 100644
--- a/vllm/model_executor/layers/quantization/awq_marlin.py
+++ b/vllm/model_executor/layers/quantization/awq_marlin.py
@@ -497,6 +497,7 @@ def apply(
         expert_map: Optional[torch.Tensor] = None,
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
+        routed_scaling_factor: float = 1.0,
         e_score_correction_bias: Optional[torch.Tensor] = None,
         apply_router_weight_on_input: bool = False,
         activation: str = "silu",
@@ -523,6 +524,7 @@ def apply(
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             scoring_func=scoring_func,
+            routed_scaling_factor=routed_scaling_factor,
             e_score_correction_bias=e_score_correction_bias,
             indices_type=self.topk_indices_dtype)
diff --git a/vllm/model_executor/layers/quantization/bitsandbytes.py b/vllm/model_executor/layers/quantization/bitsandbytes.py
index b7897a43793c..9713757df9b0 100644
--- a/vllm/model_executor/layers/quantization/bitsandbytes.py
+++ b/vllm/model_executor/layers/quantization/bitsandbytes.py
@@ -466,6 +466,7 @@ def apply(
         expert_map: Optional[torch.Tensor] = None,
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
+        routed_scaling_factor: float = 1.0,
         e_score_correction_bias: Optional[torch.Tensor] = None,
         apply_router_weight_on_input: bool = False,
         activation: str = "silu",
@@ -490,6 +491,7 @@ def apply(
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             scoring_func=scoring_func,
+            routed_scaling_factor=routed_scaling_factor,
             e_score_correction_bias=e_score_correction_bias,
             indices_type=self.topk_indices_dtype)
         if self.quant_config.load_in_8bit:
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
index 1ee3478aa4f4..4e8691d18ff4 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -319,6 +319,7 @@ def apply(
         expert_map: Optional[torch.Tensor] = None,
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
+        routed_scaling_factor: float = 1.0,
         e_score_correction_bias: Optional[torch.Tensor] = None,
         apply_router_weight_on_input: bool = False,
         activation: str = "silu",
@@ -344,6 +345,7 @@ def apply(
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             scoring_func=scoring_func,
+            routed_scaling_factor=routed_scaling_factor,
             e_score_correction_bias=e_score_correction_bias,
             indices_type=self.topk_indices_dtype,
         )
@@ -779,6 +781,7 @@ def apply(
         expert_map: Optional[torch.Tensor] = None,
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
+        routed_scaling_factor: float = 1.0,
         e_score_correction_bias: Optional[torch.Tensor] = None,
         apply_router_weight_on_input: bool = False,
         activation: str = "silu",
@@ -802,6 +805,7 @@ def apply(
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             scoring_func=scoring_func,
+            routed_scaling_factor=routed_scaling_factor,
             e_score_correction_bias=e_score_correction_bias,
             indices_type=self.topk_indices_dtype,
         )
@@ -1027,6 +1031,7 @@ def apply(
         expert_map: Optional[torch.Tensor] = None,
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
+        routed_scaling_factor: float = 1.0,
         e_score_correction_bias: Optional[torch.Tensor] = None,
         apply_router_weight_on_input: bool = False,
         activation: str = "silu",
@@ -1054,6 +1059,7 @@ def apply(
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             scoring_func=scoring_func,
+            routed_scaling_factor=routed_scaling_factor,
             e_score_correction_bias=e_score_correction_bias,
             indices_type=self.topk_indices_dtype)
@@ -1331,6 +1337,7 @@ def apply(
         expert_map: Optional[torch.Tensor] = None,
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
+        routed_scaling_factor: float = 1.0,
         e_score_correction_bias: Optional[torch.Tensor] = None,
         apply_router_weight_on_input: bool = False,
         activation: str = "silu",
@@ -1359,6 +1366,7 @@ def apply(
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             scoring_func=scoring_func,
+            routed_scaling_factor=routed_scaling_factor,
             e_score_correction_bias=e_score_correction_bias,
             indices_type=self.topk_indices_dtype)
@@ -1562,6 +1570,7 @@ def apply(
         expert_map: Optional[torch.Tensor] = None,
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
+        routed_scaling_factor: float = 1.0,
         e_score_correction_bias: Optional[torch.Tensor] = None,
         apply_router_weight_on_input: bool = False,
         activation: str = "silu",
@@ -1588,6 +1597,7 @@ def apply(
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             scoring_func=scoring_func,
+            routed_scaling_factor=routed_scaling_factor,
             e_score_correction_bias=e_score_correction_bias,
             indices_type=self.topk_indices_dtype)
diff --git a/vllm/model_executor/layers/quantization/experts_int8.py b/vllm/model_executor/layers/quantization/experts_int8.py
index 3e43caa4cbf7..2d8a684bc7d9 100644
--- a/vllm/model_executor/layers/quantization/experts_int8.py
+++ b/vllm/model_executor/layers/quantization/experts_int8.py
@@ -120,6 +120,7 @@ def apply(
         expert_map: Optional[torch.Tensor] = None,
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
+        routed_scaling_factor: float = 1.0,
         e_score_correction_bias: Optional[torch.Tensor] = None,
         apply_router_weight_on_input: bool = False,
         activation: str = "silu",
@@ -146,6 +147,7 @@ def apply(
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             scoring_func=scoring_func,
+            routed_scaling_factor=routed_scaling_factor,
             e_score_correction_bias=e_score_correction_bias,
             indices_type=self.topk_indices_dtype)
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index d45d368b582d..16a74d9cbb8b 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -955,6 +955,7 @@ def apply(
         expert_map: Optional[torch.Tensor] = None,
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
+        routed_scaling_factor: float = 1.0,
         e_score_correction_bias: Optional[torch.Tensor] = None,
         apply_router_weight_on_input: bool = False,
         activation: str = "silu",
@@ -994,7 +995,7 @@ def apply(
                 expert_offset=layer.ep_rank * layer.local_num_experts,
                 local_num_experts=layer.local_num_experts,
                 block_shape=self.quant_config.weight_block_size,
-                routed_scaling=1.0,
+                routed_scaling=routed_scaling_factor,
             )
         else:
             assert (not renormalize
@@ -1020,6 +1021,7 @@ def apply(
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             scoring_func=scoring_func,
+            routed_scaling_factor=routed_scaling_factor,
             e_score_correction_bias=e_score_correction_bias,
             indices_type=self.topk_indices_dtype,
             enable_eplb=enable_eplb,
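The fp8 path above folds the factor into the kernel's routed_scaling argument rather than pre-scaling the weights in select_experts. The two placements are equivalent because expert outputs are combined linearly with the routing weights; a toy check:

    import torch

    torch.manual_seed(0)
    expert_out = torch.randn(4, 2, 16)  # [tokens, topk, hidden]
    weights = torch.rand(4, 2)
    s = 2.5
    combined = (weights.unsqueeze(-1) * expert_out).sum(dim=1)
    scaled = ((weights * s).unsqueeze(-1) * expert_out).sum(dim=1)
    assert torch.allclose(scaled, s * combined, atol=1e-6)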
diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py
index 90222f2e3b0e..ad648df23819 100644
--- a/vllm/model_executor/layers/quantization/gguf.py
+++ b/vllm/model_executor/layers/quantization/gguf.py
@@ -532,6 +532,7 @@ def apply(
         expert_map: Optional[torch.Tensor] = None,
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
+        routed_scaling_factor: float = 1.0,
         e_score_correction_bias: Optional[torch.Tensor] = None,
         apply_router_weight_on_input: bool = False,
         activation: str = "silu",
@@ -562,6 +563,7 @@ def apply(
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             scoring_func=scoring_func,
+            routed_scaling_factor=routed_scaling_factor,
             e_score_correction_bias=e_score_correction_bias,
             indices_type=self.topk_indices_dtype)
         return fused_moe_gguf(x, layer.w13_qweight, layer.w2_qweight,
diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py
index c5d1e017014f..350975966668 100644
--- a/vllm/model_executor/layers/quantization/gptq_marlin.py
+++ b/vllm/model_executor/layers/quantization/gptq_marlin.py
@@ -643,6 +643,7 @@ def apply(
         expert_map: Optional[torch.Tensor] = None,
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
+        routed_scaling_factor: float = 1.0,
         e_score_correction_bias: Optional[torch.Tensor] = None,
         apply_router_weight_on_input: bool = False,
         activation: str = "silu",
@@ -669,6 +670,7 @@ def apply(
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             scoring_func=scoring_func,
+            routed_scaling_factor=routed_scaling_factor,
             e_score_correction_bias=e_score_correction_bias,
             indices_type=self.topk_indices_dtype)
diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py
index 72864853f7e0..de7c05b3277e 100644
--- a/vllm/model_executor/layers/quantization/modelopt.py
+++ b/vllm/model_executor/layers/quantization/modelopt.py
@@ -482,6 +482,7 @@ def apply(
         expert_map: Optional[torch.Tensor] = None,
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
+        routed_scaling_factor: float = 1.0,
         e_score_correction_bias: Optional[torch.Tensor] = None,
         apply_router_weight_on_input: bool = False,
         activation: str = "silu",
@@ -520,6 +521,7 @@ def apply(
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             scoring_func=scoring_func,
+            routed_scaling_factor=routed_scaling_factor,
             e_score_correction_bias=e_score_correction_bias,
             indices_type=self.topk_indices_dtype,
         )
@@ -1352,6 +1354,7 @@ def apply(
         expert_map: Optional[torch.Tensor] = None,
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
+        routed_scaling_factor: float = 1.0,
         e_score_correction_bias: Optional[torch.Tensor] = None,
         apply_router_weight_on_input: bool = False,
         activation: str = "silu",
@@ -1430,6 +1433,7 @@ def apply(
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             scoring_func=scoring_func,
+            routed_scaling_factor=routed_scaling_factor,
             e_score_correction_bias=e_score_correction_bias,
             indices_type=self.topk_indices_dtype)
diff --git a/vllm/model_executor/layers/quantization/moe_wna16.py b/vllm/model_executor/layers/quantization/moe_wna16.py
index 364d1ac314d2..f76e3dac60c2 100644
--- a/vllm/model_executor/layers/quantization/moe_wna16.py
+++ b/vllm/model_executor/layers/quantization/moe_wna16.py
@@ -299,6 +299,7 @@ def apply(
         expert_map: Optional[torch.Tensor] = None,
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
+        routed_scaling_factor: float = 1.0,
         e_score_correction_bias: Optional[torch.Tensor] = None,
         apply_router_weight_on_input: bool = False,
         activation: str = "silu",
@@ -325,6 +326,7 @@ def apply(
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             scoring_func=scoring_func,
+            routed_scaling_factor=routed_scaling_factor,
             e_score_correction_bias=e_score_correction_bias,
             indices_type=self.topk_indices_dtype)
diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py
index df96e5d8c413..58e5d198b0bc 100644
--- a/vllm/model_executor/layers/quantization/mxfp4.py
+++ b/vllm/model_executor/layers/quantization/mxfp4.py
@@ -461,6 +461,7 @@ def apply(
         expert_map: Optional[torch.Tensor] = None,
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
+        routed_scaling_factor: float = 1.0,
         e_score_correction_bias: Optional[torch.Tensor] = None,
         apply_router_weight_on_input: bool = False,
         activation: str = "silu",
@@ -484,6 +485,7 @@ def apply(
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             scoring_func=scoring_func,
+            routed_scaling_factor=routed_scaling_factor,
             e_score_correction_bias=e_score_correction_bias)

         return torch.ops.vllm.fused_marlin_moe(
diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py
index 58f56c6381b3..fdf03ded0448 100644
--- a/vllm/model_executor/layers/quantization/quark/quark_moe.py
+++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py
@@ -218,6 +218,7 @@ def apply(
         expert_map: Optional[torch.Tensor] = None,
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
+        routed_scaling_factor: float = 1.0,
         e_score_correction_bias: Optional[torch.Tensor] = None,
         apply_router_weight_on_input: bool = False,
         activation: str = "silu",
@@ -244,6 +245,7 @@ def apply(
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             scoring_func=scoring_func,
+            routed_scaling_factor=routed_scaling_factor,
             e_score_correction_bias=e_score_correction_bias,
             indices_type=self.topk_indices_dtype)
@@ -380,6 +382,7 @@ def apply(
         expert_map: Optional[torch.Tensor] = None,
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
+        routed_scaling_factor: float = 1.0,
         e_score_correction_bias: Optional[torch.Tensor] = None,
         apply_router_weight_on_input: bool = False,
         activation: str = "silu",
@@ -406,6 +409,7 @@ def apply(
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             scoring_func=scoring_func,
+            routed_scaling_factor=routed_scaling_factor,
             e_score_correction_bias=e_score_correction_bias,
             indices_type=self.topk_indices_dtype)
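The quantization backends change in lockstep because they all follow the same two-phase contract: route first, then feed the returned weights unchanged into the quantized expert kernels. Schematically — the class and the _run_experts helper are hypothetical; only the select_experts call mirrors the patched API:

    from vllm.model_executor.layers.fused_moe.layer import FusedMoE

    class SketchMoEMethod:
        def apply(self, layer, x, router_logits, top_k, renormalize,
                  use_grouped_topk, topk_group=None, num_expert_group=None,
                  routed_scaling_factor=1.0, **kwargs):
            # Phase 1: routing, which now carries the scaling factor.
            topk_weights, topk_ids = FusedMoE.select_experts(
                hidden_states=x,
                router_logits=router_logits,
                top_k=top_k,
                renormalize=renormalize,
                use_grouped_topk=use_grouped_topk,
                topk_group=topk_group,
                num_expert_group=num_expert_group,
                routed_scaling_factor=routed_scaling_factor,  # new pass-through
            )
            # Phase 2: the backend-specific (quantized) expert computation
            # consumes the already-scaled weights unchanged.
            return self._run_experts(layer, x, topk_weights, topk_ids)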
diff --git a/vllm/model_executor/layers/quantization/rtn.py b/vllm/model_executor/layers/quantization/rtn.py
index 8bdb50e07b13..8f72b8cbea7a 100644
--- a/vllm/model_executor/layers/quantization/rtn.py
+++ b/vllm/model_executor/layers/quantization/rtn.py
@@ -283,6 +283,7 @@ def apply(
         expert_map: Optional[torch.Tensor] = None,
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
+        routed_scaling_factor: float = 1.0,
         e_score_correction_bias: Optional[torch.Tensor] = None,
         apply_router_weight_on_input: bool = False,
         activation: str = "silu",
@@ -309,6 +310,7 @@ def apply(
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             scoring_func=scoring_func,
+            routed_scaling_factor=routed_scaling_factor,
             e_score_correction_bias=e_score_correction_bias,
             indices_type=self.topk_indices_dtype)
diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py
index 7657e7cb003d..81ddb8e4586b 100644
--- a/vllm/model_executor/models/deepseek_v2.py
+++ b/vllm/model_executor/models/deepseek_v2.py
@@ -159,6 +159,7 @@ def __init__(
             topk_group=config.topk_group,
             prefix=f"{prefix}.experts",
             scoring_func=config.scoring_func,
+            routed_scaling_factor=config.routed_scaling_factor,
             e_score_correction_bias=self.gate.e_score_correction_bias,
             enable_eplb=self.enable_eplb,
             num_redundant_experts=self.n_redundant_experts)
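Patch 1 ends at the model layer: DeepseekV2MoE now forwards the checkpoint's scaling factor instead of leaving the routed weights unscaled. The value comes from the model's config.json; an illustrative excerpt (field values are made up, real checkpoints define their own):

    import json

    cfg = json.loads('{"scoring_func": "softmax", "routed_scaling_factor": 2.5}')
    # DeepseekV2MoE.__init__ threads it through, per the hunk above:
    #   FusedMoE(..., scoring_func=config.scoring_func,
    #            routed_scaling_factor=config.routed_scaling_factor, ...)
    print(cfg["routed_scaling_factor"])  # 2.5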
From 473fa9c25fe2e3c15826a22567f80d0de956c3ef Mon Sep 17 00:00:00 2001
From: Xin Yang
Date: Mon, 25 Aug 2025 13:25:19 -0700
Subject: [PATCH 2/4] Review changes

Signed-off-by: Xin Yang
---
 vllm/model_executor/layers/fused_moe/layer.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index acc32277be2c..50aeb389e0e8 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -644,6 +644,9 @@ def forward_tpu(
             raise NotImplementedError(
                 "Expert score correction bias is not supported for TPU.")
         assert activation == "silu", f"{activation} is not supported for TPU."
+        assert routed_scaling_factor == 1.0, \
+            f"routed_scaling_factor {routed_scaling_factor} is not supported " \
+            f"for TPU."
         if enable_eplb is not False or expert_load_view is not None or \
             logical_to_physical_map is not None or \
             logical_replica_count is not None:

From 37ee5903dd2a3215705fc009ffd668a4d927e37c Mon Sep 17 00:00:00 2001
From: Xin Yang
Date: Mon, 25 Aug 2025 17:25:31 -0700
Subject: [PATCH 3/4] Review changes

Signed-off-by: Xin Yang
---
 .../layers/fused_moe/fused_moe.py            | 19 +++++++------------
 .../layers/fused_moe/rocm_aiter_fused_moe.py |  3 ++-
 2 files changed, 9 insertions(+), 13 deletions(-)

diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index ebac839b2ed1..dd4af1936193 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -1011,7 +1011,8 @@ def grouped_topk(
     if renormalize:
         topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

-    topk_weights = topk_weights * routed_scaling_factor
+    if routed_scaling_factor != 1.0:
+        topk_weights = topk_weights * routed_scaling_factor
     return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
@@ -1728,7 +1729,6 @@ def fused_moe(
     gating_output: torch.Tensor,
     topk: int,
     renormalize: bool,
-    routed_scaling_factor: float = 1.0,
     inplace: bool = False,
     activation: str = "silu",
     is_act_and_mul: bool = True,
@@ -1792,8 +1792,8 @@ def fused_moe(
         Defaults to False.
     - global_num_experts (int): The total number of experts in the global
         expert space.
-    - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices 
-      from the global expert space to the local expert space of the expert 
+    - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
+      from the global expert space to the local expert space of the expert
         parallel shard.
     - w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1.
@@ -1815,14 +1815,9 @@ def fused_moe(

     if use_grouped_topk:
         assert num_expert_group is not None and topk_group is not None
-        topk_weights, topk_ids = grouped_topk(
-            hidden_states,
-            gating_output,
-            topk,
-            renormalize,
-            num_expert_group,
-            topk_group,
-            routed_scaling_factor=routed_scaling_factor)
+        topk_weights, topk_ids = grouped_topk(hidden_states, gating_output,
+                                              topk, renormalize,
+                                              num_expert_group, topk_group)
     elif custom_routing_function is None:
         topk_weights, topk_ids, token_expert_indices = fused_topk(
             hidden_states, gating_output, topk, renormalize)
diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
index cd47ac24c9a0..348bf8745e09 100644
--- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
@@ -299,7 +299,8 @@ def rocm_aiter_grouped_topk(
             scoring_func,
         )

-    topk_weights = topk_weights * routed_scaling_factor
+    if routed_scaling_factor != 1.0:
+        topk_weights = topk_weights * routed_scaling_factor
     return topk_weights, topk_ids
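The guard added in this patch is a small but real saving on the hot path: multiplying by 1.0 would still launch an elementwise kernel and allocate a fresh tensor. The pattern in isolation:

    import torch

    def scale_weights(topk_weights: torch.Tensor, s: float) -> torch.Tensor:
        if s != 1.0:  # skip the no-op multiply entirely in the common case
            topk_weights = topk_weights * s
        return topk_weights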
From 1a04347eb0338016b48f4e895a2a79e6f366c5ae Mon Sep 17 00:00:00 2001
From: Xin Yang
Date: Mon, 25 Aug 2025 18:39:53 -0700
Subject: [PATCH 4/4] Update dots1 and glm4_moe

Signed-off-by: Xin Yang
---
 vllm/model_executor/models/deepseek_v2.py | 2 +-
 vllm/model_executor/models/dots1.py       | 1 +
 vllm/model_executor/models/glm4_moe.py    | 1 +
 3 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py
index 81ddb8e4586b..a6e9b8169333 100644
--- a/vllm/model_executor/models/deepseek_v2.py
+++ b/vllm/model_executor/models/deepseek_v2.py
@@ -159,7 +159,7 @@ def __init__(
             topk_group=config.topk_group,
             prefix=f"{prefix}.experts",
             scoring_func=config.scoring_func,
-            routed_scaling_factor=config.routed_scaling_factor,
+            routed_scaling_factor=self.routed_scaling_factor,
             e_score_correction_bias=self.gate.e_score_correction_bias,
             enable_eplb=self.enable_eplb,
             num_redundant_experts=self.n_redundant_experts)
diff --git a/vllm/model_executor/models/dots1.py b/vllm/model_executor/models/dots1.py
index 5f410c0ae5fb..dd3108a39467 100644
--- a/vllm/model_executor/models/dots1.py
+++ b/vllm/model_executor/models/dots1.py
@@ -136,6 +136,7 @@ def __init__(
             topk_group=config.topk_group,
             prefix=f"{prefix}.experts",
             scoring_func=config.scoring_func,
+            routed_scaling_factor=self.routed_scaling_factor,
             e_score_correction_bias=self.gate.e_score_correction_bias)

         if config.n_shared_experts is not None:
diff --git a/vllm/model_executor/models/glm4_moe.py b/vllm/model_executor/models/glm4_moe.py
index fe5e46a99826..42072edfa0ed 100644
--- a/vllm/model_executor/models/glm4_moe.py
+++ b/vllm/model_executor/models/glm4_moe.py
@@ -158,6 +158,7 @@ def __init__(
             topk_group=config.topk_group,
             prefix=f"{prefix}.experts",
             scoring_func="sigmoid",
+            routed_scaling_factor=self.routed_scaling_factor,
             e_score_correction_bias=self.gate.e_score_correction_bias,
             enable_eplb=self.enable_eplb,
             num_redundant_experts=self.n_redundant_experts)
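Taken together, the series changes only the magnitude of the routing weights, never the expert choice — a property worth asserting in a quick sanity check built on the first sketch:

    gating = torch.randn(4, 8)
    w1, ids1 = grouped_topk_sketch(gating, topk=2, num_expert_group=4,
                                   topk_group=2, routed_scaling_factor=1.0)
    w2, ids2 = grouped_topk_sketch(gating, topk=2, num_expert_group=4,
                                   topk_group=2, routed_scaling_factor=2.5)
    assert torch.equal(ids1, ids2)       # expert selection is unchanged
    assert torch.allclose(w2, 2.5 * w1)  # only the weights are scaled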