Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 17 additions & 12 deletions vllm/model_executor/layers/fused_moe/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,23 +102,26 @@ def _quant_flags_to_group_shape(
class RoutingMethodType(IntEnum):
# Default: Softmax -> TopK
Default = (0,)
# Renormalize: TopK -> Softmax/Sigmoid
# Renormalize: TopK -> Softmax
Renormalize = (1,)
# DeepSeekV3: Sigmoid -> RoutingBiasAdd -> Top2 in group -> Top4 groups
# -> Top8 experts from the Top4 groups
DeepSeekV3 = (2,)
# Llama4: Top1 -> Sigmoid
Llama4 = (3,)
# RenormalizeNaive: Softmax/Sigmoid -> TopK -> Renormalize
# RenormalizeNaive: Softmax -> TopK -> Renormalize
RenormalizeNaive = (4,)
# TopK: TopK (no softmax)
TopK = (5,)
# SigmoidRenorm: Sigmoid -> TopK -> Renormalize (divide by sum of top-K)
SigmoidRenorm = (6,)
# MiniMax2: Sigmoid + Bias -> TopK -> ScaledSumNormalize
# (routeScale=1.0, epsilon=1e-20)
MiniMax2 = (7,)
# Sigmoid: Sigmoid -> TopK (no renormalization)
Sigmoid = (8,)
# Unspecified
Unspecified = (8,)
Unspecified = (9,)
# other routing types (not passed to FlashInfer kernels)
# Deepseek V4 -> sqrtsoftplus + Bias + Normalize
DeepseekV4 = (100,)
Expand All @@ -132,6 +135,7 @@ def get_routing_method_type(
renormalize: bool,
num_expert_group: int | None,
has_e_score_bias: bool,
routed_scaling_factor: float | None = 1.0,
) -> RoutingMethodType:
if scoring_func == "sqrtsoftplus":
# DeepSeek V4 uses sqrtsoftplus routing with optional routing bias
Expand All @@ -142,20 +146,21 @@ def get_routing_method_type(
return RoutingMethodType.Unspecified

if has_e_score_bias:
if (num_expert_group or 0) > 0 and scoring_func == "sigmoid":
return RoutingMethodType.DeepSeekV3
elif scoring_func == "sigmoid":
return RoutingMethodType.MiniMax2
if scoring_func == "sigmoid":
if not renormalize:
return RoutingMethodType.Unspecified
if (num_expert_group or 0) > 0:
return RoutingMethodType.DeepSeekV3
if routed_scaling_factor in (None, 1.0):
return RoutingMethodType.MiniMax2
return RoutingMethodType.Unspecified
else:
return RoutingMethodType.Unspecified

if scoring_func == "sigmoid":
if top_k == 1:
return RoutingMethodType.Llama4
elif renormalize:
if renormalize:
return RoutingMethodType.SigmoidRenorm
else:
return RoutingMethodType.Unspecified
return RoutingMethodType.Sigmoid

if scoring_func == "softmax":
if renormalize:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ def _supports_router_logits_dtype(
router_logits_dtype: torch.dtype | None,
routing_method: RoutingMethodType,
) -> bool:
return router_logits_dtype != torch.float32
return router_logits_dtype in [torch.bfloat16, torch.float32]

@staticmethod
def _supports_routing_method(
Expand All @@ -279,6 +279,7 @@ def _supports_routing_method(
RoutingMethodType.Renormalize,
RoutingMethodType.RenormalizeNaive,
RoutingMethodType.SigmoidRenorm,
RoutingMethodType.Sigmoid,
RoutingMethodType.MiniMax2,
RoutingMethodType.Simulated,
]
Expand All @@ -290,6 +291,7 @@ def _supports_routing_method(
RoutingMethodType.Renormalize,
RoutingMethodType.RenormalizeNaive,
RoutingMethodType.SigmoidRenorm,
RoutingMethodType.Sigmoid,
RoutingMethodType.MiniMax2,
RoutingMethodType.Simulated,
]
Expand Down Expand Up @@ -402,11 +404,6 @@ def _apply_per_tensor(
else:
assert not apply_router_weight_on_input

# Currently FI requires bfloat16 routing bias.
# https://github.com/flashinfer-ai/flashinfer/issues/2909
if e_score_correction_bias is not None:
e_score_correction_bias = e_score_correction_bias.to(torch.bfloat16)

out = flashinfer.fused_moe.trtllm_fp8_per_tensor_scale_moe(
routing_logits=router_logits,
routing_bias=e_score_correction_bias,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ def _supports_router_logits_dtype(
router_logits_dtype: torch.dtype | None,
routing_method: RoutingMethodType,
) -> bool:
return router_logits_dtype != torch.float32
return router_logits_dtype in [torch.bfloat16, torch.float32]

def apply(
self,
Expand Down Expand Up @@ -393,11 +393,6 @@ def apply(
and self.routing_method_type != RoutingMethodType.Llama4
)

# Currently FI requires bfloat16 routing bias.
# https://github.com/flashinfer-ai/flashinfer/issues/2909
if e_score_correction_bias is not None:
e_score_correction_bias = e_score_correction_bias.to(torch.bfloat16)

output1_scale_gate_scalar = self.quant_config.g1_alphas

# Invoke kernel.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,7 @@ def routing_method_type(self) -> RoutingMethodType:
renormalize=self.renormalize,
num_expert_group=None,
has_e_score_bias=True,
routed_scaling_factor=self.routed_scaling_factor,
)

def _compute_routing(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,7 @@ def routing_method_type(self) -> RoutingMethodType:
renormalize=self.renormalize,
num_expert_group=self.num_expert_group,
has_e_score_bias=self.e_score_correction_bias is not None,
routed_scaling_factor=self.routed_scaling_factor,
)

def _compute_routing(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def routing_method_type(self) -> RoutingMethodType:
renormalize=self.renormalize,
num_expert_group=None,
has_e_score_bias=True,
routed_scaling_factor=self.routed_scaling_factor,
)

def _compute_routing(
Expand Down
Loading