diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py
index f958a6322e3..565df1324f6 100644
--- a/vllm/model_executor/layers/fused_moe/config.py
+++ b/vllm/model_executor/layers/fused_moe/config.py
@@ -113,14 +113,17 @@ class RoutingMethodType(IntEnum):
     RenormalizeNaive = (4,)
     # TopK: TopK (no softmax)
     TopK = (5,)
-    # Custom
-    Custom = (6,)
-    # Simulated
-    Simulated = (7,)
-    # Deepseek V4 -> sqrtsoftplus + Bias + Normalize
-    DeepseekV4 = (8,)
+    # SigmoidRenorm: Sigmoid -> TopK -> Renormalize (divide by sum of top-K)
+    SigmoidRenorm = (6,)
+    # MiniMax2: Sigmoid + Bias -> TopK -> ScaledSumNormalize
+    MiniMax2 = (7,)
     # Unspecified
-    Unspecified = 9.0
+    Unspecified = (8,)
+    # other routing types (not passed to FlashInfer kernels)
+    # Deepseek V4 -> sqrtsoftplus + Bias + Normalize
+    DeepseekV4 = (100,)
+    Custom = (101,)
+    Simulated = (102,)
 
 
 def get_routing_method_type(
@@ -141,12 +144,16 @@ def get_routing_method_type(
     if has_e_score_bias:
         if (num_expert_group or 0) > 0 and scoring_func == "sigmoid":
             return RoutingMethodType.DeepSeekV3
+        elif scoring_func == "sigmoid":
+            return RoutingMethodType.MiniMax2
         else:
             return RoutingMethodType.Unspecified
 
     if scoring_func == "sigmoid":
         if top_k == 1:
             return RoutingMethodType.Llama4
+        elif renormalize:
+            return RoutingMethodType.SigmoidRenorm
         else:
             return RoutingMethodType.Unspecified
 
diff --git a/vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py b/vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py
index 1f0258fb657..31af4a32bae 100644
--- a/vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py
+++ b/vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py
@@ -175,13 +175,6 @@ def apply(
         # Pack topk ids and weights into format expected by the kernel.
         packed_topk_ids = trtllm_moe_pack_topk_ids_weights(topk_ids, topk_weights)
 
-        # trtllm_fp8_block_scale_routed_moe does not support autotuning
-        # so skip this kernel during dummy run for autotuning.
-        import vllm.utils.flashinfer as fi_utils
-
-        if fi_utils._is_fi_autotuning:
-            return
-
         assert a1q_scale is not None
 
         is_mxfp8 = self.quant_config.block_shape == [1, 32]
@@ -196,11 +189,7 @@ def apply(
             weight_layout = WeightLayout.BlockMajorK
             hidden_states_scale = a1q_scale.t().contiguous()
 
-        # `trtllm_fp8_block_scale_routed_moe` has a bug and does not write to the
-        # output tensor in-place so we need to manually copy the result to the
-        # output tensor
-        # https://github.com/flashinfer-ai/flashinfer/issues/2703
-        result = flashinfer.fused_moe.trtllm_fp8_block_scale_routed_moe(
+        flashinfer.fused_moe.trtllm_fp8_block_scale_routed_moe(
             topk_ids=packed_topk_ids,
             routing_bias=None,
             hidden_states=hidden_states,
@@ -217,13 +206,12 @@ def apply(
             local_expert_offset=self.ep_rank * self.local_num_experts,
             local_num_experts=self.local_num_experts,
             routed_scaling_factor=None,
-            routing_method_type=1,
+            routing_method_type=1,  # not used
             use_shuffled_weight=use_shuffled_weight,
             weight_layout=weight_layout,
             fp8_quantization_type=fp8_quant_type,
-            # output=output,
+            output=output,
         )
-        output.copy_(result)
 
 
 class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolithic):
@@ -275,20 +263,6 @@ def _supports_router_logits_dtype(
         router_logits_dtype: torch.dtype | None,
         routing_method: RoutingMethodType,
     ) -> bool:
-        """
-        The FlashInfer TRTLLM FP8 kernel expects bfloat16 router_logits by default.
-        DeepSeekV3 routing supports float32 router_logits (converted internally).
-        Simulated routing generates synthetic decisions and is agnostic to dtype.
-        """
-        if router_logits_dtype == torch.float32:
-            # DeepSeekV3 routing handles float32 logits internally.
-            # Simulated routing generates synthetic decisions, so the
-            # kernel doesn't care about the actual logits dtype.
-            # https://github.com/flashinfer-ai/flashinfer/issues/2469
-            return routing_method in (
-                RoutingMethodType.DeepSeekV3,
-                RoutingMethodType.Simulated,
-            )
         return True
 
     @staticmethod
@@ -308,18 +282,22 @@ def _supports_routing_method(
             # NOTE(rob): potentially allow others here. This is a conservative list.
             return routing_method in [
                 RoutingMethodType.DeepSeekV3,
-                RoutingMethodType.Simulated,
                 RoutingMethodType.Renormalize,
                 RoutingMethodType.RenormalizeNaive,
+                RoutingMethodType.SigmoidRenorm,
+                RoutingMethodType.MiniMax2,
+                RoutingMethodType.Simulated,
             ]
         elif (weight_key, activation_key) == (kFp8StaticTensorSym, kFp8StaticTensorSym):
             # NOTE(dbari): as above, potentially allow others here.
             return routing_method in [
                 RoutingMethodType.DeepSeekV3,
                 RoutingMethodType.Llama4,
-                RoutingMethodType.Simulated,
                 RoutingMethodType.Renormalize,
                 RoutingMethodType.RenormalizeNaive,
+                RoutingMethodType.SigmoidRenorm,
+                RoutingMethodType.MiniMax2,
+                RoutingMethodType.Simulated,
             ]
         else:
             raise ValueError("Unsupported quantization scheme.")
@@ -355,14 +333,6 @@ def _apply_block_scale(
         # TODO: fuse into the quant kernel.
         assert a1q_scale is not None
 
-        if self.routing_method_type == RoutingMethodType.DeepSeekV3:
-            router_logits = router_logits.to(torch.float32)
-
-        # Currently FI requires bfloat16 routing bias.
-        # https://github.com/flashinfer-ai/flashinfer/issues/2909
-        if e_score_correction_bias is not None:
-            e_score_correction_bias = e_score_correction_bias.to(torch.bfloat16)
-
         is_mxfp8 = self.quant_config.block_shape == [1, 32]
         if is_mxfp8:
             fp8_quant_type = Fp8QuantizationType.MxFp8
@@ -429,10 +399,6 @@ def _apply_per_tensor(
         else:
            assert not apply_router_weight_on_input
 
-        # The DeepSeekV3 routing method requires float32 router logits.
-        if self.routing_method_type == RoutingMethodType.DeepSeekV3:
-            router_logits = router_logits.to(torch.float32)
-
         # Currently FI requires bfloat16 routing bias.
         # https://github.com/flashinfer-ai/flashinfer/issues/2909
         if e_score_correction_bias is not None:
diff --git a/vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py b/vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py
index c6689bf9fed..baa7d3fd3ee 100644
--- a/vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py
+++ b/vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py
@@ -198,13 +198,6 @@ def apply(
         # Pack topk ids and weights into format expected by the kernel.
         packed_tensor = trtllm_moe_pack_topk_ids_weights(topk_ids, topk_weights)
 
-        # trtllm_fp4_block_scale_routed_moe does not support autotuning
-        # so skip this kernel during dummy run for autotuning.
-        import vllm.utils.flashinfer as fi_utils
-
-        if fi_utils._is_fi_autotuning:
-            return
-
         # Invoke kernel.
         flashinfer.fused_moe.trtllm_fp4_block_scale_routed_moe(
             topk_ids=packed_tensor,
@@ -233,7 +226,7 @@ def apply(
             local_expert_offset=self.ep_rank * self.local_num_experts,
             local_num_experts=self.local_num_experts,
             routed_scaling_factor=None,
-            routing_method_type=1,
+            routing_method_type=1,  # not used
             do_finalize=True,
             activation_type=activation_to_flashinfer_int(activation),
             output=output,
@@ -267,6 +260,8 @@ def _supports_routing_method(
             RoutingMethodType.Renormalize,
             RoutingMethodType.RenormalizeNaive,
             RoutingMethodType.Llama4,
+            RoutingMethodType.SigmoidRenorm,
+            RoutingMethodType.MiniMax2,
             RoutingMethodType.Simulated,
         ]
 
@@ -275,20 +270,6 @@ def _supports_router_logits_dtype(
         router_logits_dtype: torch.dtype | None,
         routing_method: RoutingMethodType,
     ) -> bool:
-        """
-        The FlashInfer TRTLLM NvFp4 kernel expects bfloat16 router_logits by default.
-        DeepSeekV3 routing supports float32 router_logits (converted internally).
-        Simulated routing generates synthetic decisions and is agnostic to dtype.
-        """
-        if router_logits_dtype == torch.float32:
-            # DeepSeekV3 routing handles float32 logits internally.
-            # Simulated routing generates synthetic decisions, so the
-            # kernel doesn't care about the actual logits dtype.
-            # https://github.com/flashinfer-ai/flashinfer/issues/2469
-            return routing_method in (
-                RoutingMethodType.DeepSeekV3,
-                RoutingMethodType.Simulated,
-            )
         return True
 
     def apply(
@@ -322,13 +303,6 @@ def apply(
             and self.routing_method_type != RoutingMethodType.Llama4
         )
 
-        # Prepare router logits for kernel format.
-        router_logits = (
-            router_logits.to(torch.float32)
-            if self.routing_method_type == RoutingMethodType.DeepSeekV3
-            else router_logits
-        )
-
         # Currently FI requires bfloat16 routing bias.
         # https://github.com/flashinfer-ai/flashinfer/issues/2909
         if e_score_correction_bias is not None:
diff --git a/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py b/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py
index aa61fc8e80f..db6d56e3c3a 100644
--- a/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py
+++ b/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py
@@ -275,7 +275,6 @@ def _return_or_raise(
         activation_key,
         activation_format,
     )
-
     if supported:
         logger.info_once(_make_log_backend(backend))
         return backend, k_cls
diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py
index d91a41eaa7f..fbbd5da1fd9 100644
--- a/vllm/model_executor/models/deepseek_v2.py
+++ b/vllm/model_executor/models/deepseek_v2.py
@@ -50,7 +50,6 @@
 from vllm.model_executor.layers.fused_moe import (
     FusedMoE,
     GateLinear,
-    RoutingMethodType,
     fused_moe_make_expert_params_mapping,
 )
 from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm
@@ -338,17 +337,6 @@ def __init__(
             else None,
         )
 
-        # NOTE(rob): this is a hack until we finish off the PR for
-        # merging TRTLLM kernels into the MK framework. Then we can
-        # query the MonolithicMK for the expected router logits.
-        # NOTE(dbari): Use BF16 if routing is not Deepseek, e.g. Mistral Large 3
-        self.gate.set_out_dtype(
-            torch.float32
-            if self.experts.quant_method.is_monolithic
-            and self.experts.routing_method_type == RoutingMethodType.DeepSeekV3
-            else torch.bfloat16
-        )
-
         # Pre-cast the bias to match the gate output dtype so the
         # conversion is not repeated on every forward pass. All
         # downstream references (FusedMoE, router) share the same