From 19561aee21828b24aa337b7a2a9bc153538ba2ad Mon Sep 17 00:00:00 2001 From: haosdent Date: Wed, 18 Mar 2026 21:49:26 +0800 Subject: [PATCH] [Bugfix] Add autotuning guard to all unprotected FlashInfer MoE kernels Use `with autotune(False):` to disable FlashInfer autotuning for MoE kernels that are incompatible with it (upstream flashinfer#2023). This follows the existing pattern in trtllm_moe.py and avoids shape/dtype mismatches from dummy return values. Kernels wrapped: - TrtLlmNvFp4ExpertsMonolithic (trtllm_fp4_block_scale_moe) - TrtLlmFp8ExpertsMonolithic (trtllm_fp8_block_scale_moe, trtllm_fp8_per_tensor_scale_moe) - flashinfer_fused_moe_bf16 (flashinfer_trtllm_bf16_moe) - FlashInferExperts (flashinfer_cutlass_fused_moe) - FlashInferCuteDSLExperts (flashinfer_cutedsl_moe_masked) - flashinfer_trtllm_mxint4_moe (trtllm_mxint4_block_scale_moe) In addition, the `apply()` methods in flashinfer_cutedsl_moe.py and flashinfer_cutlass_moe.py return early, skipping the kernel, while `vllm.utils.flashinfer._is_fi_autotuning` is set. Signed-off-by: haosdent --- .../fused_moe/experts/trtllm_fp8_moe.py | 95 ++++++++++--------- .../fused_moe/experts/trtllm_nvfp4_moe.py | 72 ++++++++------ .../fused_moe/flashinfer_cutedsl_moe.py | 40 +++++--- .../fused_moe/flashinfer_cutlass_moe.py | 65 ++++++++----- .../layers/fused_moe/flashinfer_trtllm_moe.py | 39 ++++---- .../utils/flashinfer_mxint4_moe.py | 55 ++++++----- 6 files changed, 208 insertions(+), 158 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py b/vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py index 74096ef6ed6f..65f1f1d7e179 100644 --- a/vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py @@ -380,27 +380,32 @@ def _apply_block_scale( use_shuffled_weight = False hidden_states_scale = a1q_scale.t().contiguous() - return flashinfer.fused_moe.trtllm_fp8_block_scale_moe( - routing_logits=router_logits, - routing_bias=e_score_correction_bias, - hidden_states=hidden_states, - hidden_states_scale=hidden_states_scale, - gemm1_weights=w1, - 
gemm1_weights_scale=self.quant_config.w1_scale, - gemm2_weights=w2, - gemm2_weights_scale=self.quant_config.w2_scale, - num_experts=global_num_experts, - top_k=self.topk, - n_group=(num_expert_group or 0), - topk_group=(topk_group or 0), - intermediate_size=self.intermediate_size_per_partition, - local_expert_offset=self.ep_rank * self.local_num_experts, - local_num_experts=self.local_num_experts, - routed_scaling_factor=routed_scaling_factor, - routing_method_type=self.routing_method_type, - use_shuffled_weight=use_shuffled_weight, - fp8_quantization_type=fp8_quant_type, - ) + # Disable autotune until + # https://github.com/flashinfer-ai/flashinfer/issues/2023 is resolved. + from vllm.utils.flashinfer import autotune + + with autotune(False): + return flashinfer.fused_moe.trtllm_fp8_block_scale_moe( + routing_logits=router_logits, + routing_bias=e_score_correction_bias, + hidden_states=hidden_states, + hidden_states_scale=hidden_states_scale, + gemm1_weights=w1, + gemm1_weights_scale=self.quant_config.w1_scale, + gemm2_weights=w2, + gemm2_weights_scale=self.quant_config.w2_scale, + num_experts=global_num_experts, + top_k=self.topk, + n_group=(num_expert_group or 0), + topk_group=(topk_group or 0), + intermediate_size=self.intermediate_size_per_partition, + local_expert_offset=self.ep_rank * self.local_num_experts, + local_num_experts=self.local_num_experts, + routed_scaling_factor=routed_scaling_factor, + routing_method_type=self.routing_method_type, + use_shuffled_weight=use_shuffled_weight, + fp8_quantization_type=fp8_quant_type, + ) def _apply_per_tensor( self, @@ -437,28 +442,32 @@ def _apply_per_tensor( if self.routing_method_type == RoutingMethodType.DeepSeekV3: router_logits = router_logits.to(torch.float32) - out = flashinfer.fused_moe.trtllm_fp8_per_tensor_scale_moe( - routing_logits=router_logits, - routing_bias=e_score_correction_bias, - hidden_states=hidden_states, - gemm1_weights=w1, - output1_scales_scalar=self._g1_scale_c, - 
output1_scales_gate_scalar=self._g1_alphas, - gemm2_weights=w2, - output2_scales_scalar=self._g2_alphas, - num_experts=global_num_experts, - top_k=self.topk, - n_group=num_expert_group or 0, - topk_group=topk_group or 0, - intermediate_size=self.intermediate_size_per_partition, - local_expert_offset=self.ep_rank * self.local_num_experts, - local_num_experts=self.local_num_experts, - routed_scaling_factor=routed_scaling_factor, - use_routing_scales_on_input=apply_router_weight_on_input, - routing_method_type=self.routing_method_type, - activation_type=activation_type, - ) - return out + # Disable autotune until + # https://github.com/flashinfer-ai/flashinfer/issues/2023 is resolved. + from vllm.utils.flashinfer import autotune + + with autotune(False): + return flashinfer.fused_moe.trtllm_fp8_per_tensor_scale_moe( + routing_logits=router_logits, + routing_bias=e_score_correction_bias, + hidden_states=hidden_states, + gemm1_weights=w1, + output1_scales_scalar=self._g1_scale_c, + output1_scales_gate_scalar=self._g1_alphas, + gemm2_weights=w2, + output2_scales_scalar=self._g2_alphas, + num_experts=global_num_experts, + top_k=self.topk, + n_group=num_expert_group or 0, + topk_group=topk_group or 0, + intermediate_size=self.intermediate_size_per_partition, + local_expert_offset=self.ep_rank * self.local_num_experts, + local_num_experts=self.local_num_experts, + routed_scaling_factor=routed_scaling_factor, + use_routing_scales_on_input=apply_router_weight_on_input, + routing_method_type=self.routing_method_type, + activation_type=activation_type, + ) def apply( self, diff --git a/vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py b/vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py index 87b1eb9fd58d..3a9213b2388d 100644 --- a/vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py @@ -309,34 +309,44 @@ def apply( ) # Invoke kernel. 
- return flashinfer.fused_moe.trtllm_fp4_block_scale_moe( - routing_logits=router_logits, - routing_bias=routing_bias, - hidden_states=hidden_states, - hidden_states_scale=a1q_scale.view(torch.float8_e4m3fn).reshape( - *hidden_states.shape[:-1], -1 - ), - gemm1_weights=w1, - gemm1_weights_scale=self.quant_config.w1_scale.view(torch.float8_e4m3fn), - gemm1_bias=None, - gemm1_alpha=None, - gemm1_beta=None, - gemm1_clamp_limit=None, - gemm2_weights=w2, - gemm2_weights_scale=self.quant_config.w2_scale.view(torch.float8_e4m3fn), - gemm2_bias=None, - output1_scale_scalar=self.g1_scale_c, - output1_scale_gate_scalar=self.quant_config.g1_alphas, - output2_scale_scalar=self.quant_config.g2_alphas, - num_experts=global_num_experts, - top_k=self.topk, - n_group=(num_expert_group or 0), - topk_group=(topk_group or 0), - intermediate_size=self.intermediate_size_per_partition, - local_expert_offset=self.ep_rank * self.local_num_experts, - local_num_experts=self.local_num_experts, - routed_scaling_factor=routed_scaling_factor, - routing_method_type=self.routing_method_type, - do_finalize=True, - activation_type=activation_to_flashinfer_int(activation), - )[0] + # Disable autotune until + # https://github.com/flashinfer-ai/flashinfer/issues/2023 is resolved. + from vllm.utils.flashinfer import autotune + + with autotune(False): + # Enable autotune when flashinfer#2023 is resolved. 
+ return flashinfer.fused_moe.trtllm_fp4_block_scale_moe( + routing_logits=router_logits, + routing_bias=routing_bias, + hidden_states=hidden_states, + hidden_states_scale=a1q_scale.view(torch.float8_e4m3fn).reshape( + *hidden_states.shape[:-1], -1 + ), + gemm1_weights=w1, + gemm1_weights_scale=self.quant_config.w1_scale.view( + torch.float8_e4m3fn + ), + gemm1_bias=None, + gemm1_alpha=None, + gemm1_beta=None, + gemm1_clamp_limit=None, + gemm2_weights=w2, + gemm2_weights_scale=self.quant_config.w2_scale.view( + torch.float8_e4m3fn + ), + gemm2_bias=None, + output1_scale_scalar=self.g1_scale_c, + output1_scale_gate_scalar=self.quant_config.g1_alphas, + output2_scale_scalar=self.quant_config.g2_alphas, + num_experts=global_num_experts, + top_k=self.topk, + n_group=(num_expert_group or 0), + topk_group=(topk_group or 0), + intermediate_size=self.intermediate_size_per_partition, + local_expert_offset=self.ep_rank * self.local_num_experts, + local_num_experts=self.local_num_experts, + routed_scaling_factor=routed_scaling_factor, + routing_method_type=self.routing_method_type, + do_finalize=True, + activation_type=activation_to_flashinfer_int(activation), + )[0] diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py index 5805a4dd5bf6..f6645d6f4f61 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py @@ -144,6 +144,13 @@ def apply( expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool | None, ): + # Skip this kernel during autotuning dummy run to avoid + # CUDA errors from incompatible autotuning (flashinfer#2023). + import vllm.utils.flashinfer as fi_utils + + if fi_utils._is_fi_autotuning: + return + assert self.quant_dtype == "nvfp4", ( "Only nvfp4 quantization are currently supported." 
) @@ -165,20 +172,25 @@ def apply( if envs.VLLM_DEEPEPLL_NVFP4_DISPATCH else hidden_states ) - flashinfer_cutedsl_moe_masked( - hidden_states=flashinfer_hidden_states, - input_global_scale=input_global_scale, - w1=w1, - w1_blockscale=self.w1_scale, - w1_alpha=self.g1_alphas, - w2=w2, - a2_global_scale=self.a2_gscale, - w2_blockscale=self.w2_scale, - w2_alpha=self.g2_alphas, - masked_m=expert_num_tokens, - workspace=workspace2, - out=output, - ) + # Disable autotune until + # https://github.com/flashinfer-ai/flashinfer/issues/2023 is resolved. + from vllm.utils.flashinfer import autotune + + with autotune(False): + flashinfer_cutedsl_moe_masked( + hidden_states=flashinfer_hidden_states, + input_global_scale=input_global_scale, + w1=w1, + w1_blockscale=self.w1_scale, + w1_alpha=self.g1_alphas, + w2=w2, + a2_global_scale=self.a2_gscale, + w2_blockscale=self.w2_scale, + w2_alpha=self.g2_alphas, + masked_m=expert_num_tokens, + workspace=workspace2, + out=output, + ) def get_cute_dtype(input: torch.Tensor) -> str: diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py index 91f7a83f6fce..ba85fb8c5d5f 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py @@ -259,6 +259,13 @@ def apply( expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool | None, ): + # Skip this kernel during autotuning dummy run to avoid + # CUDA errors from incompatible autotuning (flashinfer#2023). 
+ import vllm.utils.flashinfer as fi_utils + + if fi_utils._is_fi_autotuning: + return + from flashinfer.fused_moe.core import ActivationType activation_str_to_value_map = { @@ -366,32 +373,38 @@ def apply( fc1_expert_weights = w1 fc2_expert_weights = w2 - _ = flashinfer_cutlass_fused_moe( - input=hidden_states, - token_selected_experts=topk_ids.to(torch.int), - token_final_scales=topk_weights, - fc1_expert_weights=fc1_expert_weights, - fc2_expert_weights=fc2_expert_weights, - fc1_expert_biases=fc1_expert_biases, - fc2_expert_biases=fc2_expert_biases, - swiglu_alpha=swiglu_alpha, - swiglu_beta=swiglu_beta, - swiglu_limit=swiglu_limit, - output=output, - output_dtype=self.out_dtype, - quant_scales=quant_scales, - input_sf=a1q_scale, - tp_size=self.tp_size, - tp_rank=self.tp_rank, - ep_size=self.ep_size, - ep_rank=self.ep_rank, - activation_type=activation_str_to_value_map[activation], - # Informs FlashInfer to use the block-scale decoding path when True - use_deepseek_fp8_block_scale=self.use_deepseek_fp8_block_scale, - use_mxfp8_act_scaling=use_mxfp8_act_scaling, - use_w4_group_scaling=use_w4_group_scaling, - tune_max_num_tokens=max(self.max_capture_size, 1), - ) + # Disable autotune until + # https://github.com/flashinfer-ai/flashinfer/issues/2023 is resolved. 
+ from vllm.utils.flashinfer import autotune + + with autotune(False): + flashinfer_cutlass_fused_moe( + input=hidden_states, + token_selected_experts=topk_ids.to(torch.int), + token_final_scales=topk_weights, + fc1_expert_weights=fc1_expert_weights, + fc2_expert_weights=fc2_expert_weights, + fc1_expert_biases=fc1_expert_biases, + fc2_expert_biases=fc2_expert_biases, + swiglu_alpha=swiglu_alpha, + swiglu_beta=swiglu_beta, + swiglu_limit=swiglu_limit, + output=output, + output_dtype=self.out_dtype, + quant_scales=quant_scales, + input_sf=a1q_scale, + tp_size=self.tp_size, + tp_rank=self.tp_rank, + ep_size=self.ep_size, + ep_rank=self.ep_rank, + activation_type=activation_str_to_value_map[activation], + # Informs FlashInfer to use the block-scale decoding + # path when True + use_deepseek_fp8_block_scale=(self.use_deepseek_fp8_block_scale), + use_mxfp8_act_scaling=use_mxfp8_act_scaling, + use_w4_group_scaling=use_w4_group_scaling, + tune_max_num_tokens=max(self.max_capture_size, 1), + ) def moe_sum(self, input: torch.Tensor, output: torch.Tensor) -> None: # No support for LoRA in flashinfer_cutlass_fused_moe. 
diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py index d04e040c8959..9c339bcf61ab 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py @@ -94,24 +94,27 @@ def flashinfer_fused_moe_bf16( routing_method_type: int, tune_max_num_tokens: int = 8192, ) -> torch.Tensor: - from vllm.utils.flashinfer import flashinfer_trtllm_bf16_moe - - return flashinfer_trtllm_bf16_moe( - routing_logits=routing_logits, - routing_bias=routing_bias, - hidden_states=hidden_states, - gemm1_weights=gemm1_weights, - gemm2_weights=gemm2_weights, - num_experts=num_experts, - top_k=top_k, - n_group=n_group, - topk_group=topk_group, - intermediate_size=intermediate_size, - local_expert_offset=local_expert_offset, - local_num_experts=local_num_experts, - routing_method_type=routing_method_type, - tune_max_num_tokens=tune_max_num_tokens, - ) + from vllm.utils.flashinfer import autotune, flashinfer_trtllm_bf16_moe + + # Disable autotune until + # https://github.com/flashinfer-ai/flashinfer/issues/2023 is resolved. 
+ with autotune(False): + return flashinfer_trtllm_bf16_moe( + routing_logits=routing_logits, + routing_bias=routing_bias, + hidden_states=hidden_states, + gemm1_weights=gemm1_weights, + gemm2_weights=gemm2_weights, + num_experts=num_experts, + top_k=top_k, + n_group=n_group, + topk_group=topk_group, + intermediate_size=intermediate_size, + local_expert_offset=local_expert_offset, + local_num_experts=local_num_experts, + routing_method_type=routing_method_type, + tune_max_num_tokens=tune_max_num_tokens, + ) def flashinfer_fused_moe_bf16_fake( diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_mxint4_moe.py b/vllm/model_executor/layers/quantization/utils/flashinfer_mxint4_moe.py index 98a3d1e12bdc..9be88bebb911 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_mxint4_moe.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_mxint4_moe.py @@ -238,29 +238,32 @@ def flashinfer_trtllm_mxint4_moe( if routing_method_type == RoutingMethodType.DeepSeekV3: router_logits = router_logits.to(torch.float32) - out = trtllm_mxint4_block_scale_moe( - routing_logits=router_logits, - routing_bias=routing_bias, - hidden_states=x, - gemm1_weights=w13_weight_packed.data, - gemm1_weights_scale=w13_weight_scale.data, - gemm1_alpha=None, - gemm1_beta=None, - gemm1_clamp_limit=None, - gemm2_weights=w2_weight_packed.data, - gemm2_weights_scale=w2_weight_scale.data, - num_experts=global_num_experts, - top_k=top_k, - n_group=num_expert_group if num_expert_group is not None else 0, - topk_group=topk_group if topk_group is not None else 0, - intermediate_size=intermediate_size_per_partition, - local_expert_offset=ep_rank * local_num_experts, - local_num_experts=local_num_experts, - routed_scaling_factor=None, - routing_method_type=routing_method_type, - enable_pdl=None, - output=None, - tune_max_num_tokens=8192, - ).to(x.dtype) - - return out + # Disable autotune until + # https://github.com/flashinfer-ai/flashinfer/issues/2023 is resolved. 
+ from vllm.utils.flashinfer import autotune + + with autotune(False): + return trtllm_mxint4_block_scale_moe( + routing_logits=router_logits, + routing_bias=routing_bias, + hidden_states=x, + gemm1_weights=w13_weight_packed.data, + gemm1_weights_scale=w13_weight_scale.data, + gemm1_alpha=None, + gemm1_beta=None, + gemm1_clamp_limit=None, + gemm2_weights=w2_weight_packed.data, + gemm2_weights_scale=w2_weight_scale.data, + num_experts=global_num_experts, + top_k=top_k, + n_group=num_expert_group if num_expert_group is not None else 0, + topk_group=topk_group if topk_group is not None else 0, + intermediate_size=intermediate_size_per_partition, + local_expert_offset=ep_rank * local_num_experts, + local_num_experts=local_num_experts, + routed_scaling_factor=None, + routing_method_type=routing_method_type, + enable_pdl=None, + output=None, + tune_max_num_tokens=8192, + ).to(x.dtype)