From ce46183ba5ae2c73bb3c2d49d39b7a9e99d6ffa2 Mon Sep 17 00:00:00 2001 From: list <58580514+farazkh80@users.noreply.github.com> Date: Fri, 8 May 2026 12:50:50 -0700 Subject: [PATCH 01/10] [None][feat] FLASHINFER_NVFP4SM12X NVFP4 MoE backend (SM120/SM121, hybrid CUTLASS-prefill / b12x-decode) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds the FLASHINFER_NVFP4SM12X MoE backend, selectable via moe_config.backend: FLASHINFER_NVFP4SM12X. Targets Nemotron-Super-120B-NVFP4 on SM120 (RTX PRO 6000 / GB202) and SM121 (DGX Spark / GB10). Composition (see MOE_DEVELOPER_GUIDE.md for the full explainer): - Prefill (m >= 64) routes through the inherited CutlassFusedMoE NVFP4 GroupGEMM. The b12x kernel's 12-CTA-per-token MMA pattern is suboptimal at large m. - Decode (m < 64) dispatches to FlashInfer's B12xMoEWrapper.run, a kernel purpose-built for m=1 / small routed-row counts. NVFP4 weights are loaded once via the inherited NVFP4 quant method; post_load_weights then prepares the b12x-shaped weight tensors alongside the existing CUTLASS layout (un-normalize FP8 block scales, apply convert_sf_to_mma_layout, prep w*_alpha for b12x's dual-use convention). Both layouts coexist; the dispatcher picks per call based on x.shape[0]. CUDA graph capture only covers decode in TRT-LLM, so captured graphs always replay the b12x path; eager prefill always runs CUTLASS — no graph-capture conflict. Hard-rejects EP, MoE alltoall, Fp4QuantizedTensor input on the decode path, swiglu_gptoss_style biased SwiGLU, and activations outside {Relu2, Swiglu}. Misconfigured selection raises at get_moe_cls time rather than silently falling back to CUTLASS. Replaces the prior FLASHINFER backend identifier (which exposed only the pure-FlashInfer / b12x path with a +48.6% TTFT regression at prefill). The hybrid composition eliminates that regression and beats CUTLASS by +21.7% throughput / -17.6% TPOT on Nemotron-Super-120B-NVFP4 at conc=1. 
Bench numbers and full investigation in .claude_docs/nemo-fp4-moe-b12x-mr/HYBRID_RESULTS.md. Tests: 23 unit tests in tests/unittest/_torch/modules/moe/test_flashinfer_nvfp4_sm12x_moe_backend.py (19 negative-path can_implement / get_moe_cls tests + 4 hybrid dispatch shape-predicate tests). All pass on container (no GPU required). Signed-off-by: list <58580514+farazkh80@users.noreply.github.com> --- .../_torch/modules/fused_moe/__init__.py | 2 + .../_torch/modules/fused_moe/create_moe.py | 22 +- .../fused_moe_flashinfer_nvfp4_sm12x.py | 359 ++++++++++++++++++ tensorrt_llm/llmapi/llm_args.py | 2 +- ...test_flashinfer_nvfp4_sm12x_moe_backend.py | 170 +++++++++ 5 files changed, 553 insertions(+), 2 deletions(-) create mode 100644 tensorrt_llm/_torch/modules/fused_moe/fused_moe_flashinfer_nvfp4_sm12x.py create mode 100644 tests/unittest/_torch/modules/moe/test_flashinfer_nvfp4_sm12x_moe_backend.py diff --git a/tensorrt_llm/_torch/modules/fused_moe/__init__.py b/tensorrt_llm/_torch/modules/fused_moe/__init__.py index 4b957c86246c..105ea3ddb3dd 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/__init__.py +++ b/tensorrt_llm/_torch/modules/fused_moe/__init__.py @@ -1,6 +1,7 @@ from .create_moe import create_moe, get_moe_cls from .fused_moe_cute_dsl import CuteDslFusedMoE from .fused_moe_cutlass import CutlassFusedMoE +from .fused_moe_flashinfer_nvfp4_sm12x import FlashInferNvfp4Sm12xFusedMoE from .fused_moe_triton import TritonFusedMoE from .fused_moe_trtllm_gen import TRTLLMGenFusedMoE from .fused_moe_vanilla import VanillaMoE @@ -30,6 +31,7 @@ "CutlassFusedMoE", "DeepSeekV3MoeRoutingMethod", "DefaultMoeRoutingMethod", + "FlashInferNvfp4Sm12xFusedMoE", "FusedMoEQuantScalesFP8", "get_moe_cls", "Llama4RenormalizeMoeRoutingMethod", diff --git a/tensorrt_llm/_torch/modules/fused_moe/create_moe.py b/tensorrt_llm/_torch/modules/fused_moe/create_moe.py index 264580986a0f..e7cc4ef0d893 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/create_moe.py +++ 
b/tensorrt_llm/_torch/modules/fused_moe/create_moe.py @@ -15,6 +15,7 @@ from .fused_moe_cutlass import CutlassFusedMoE from .fused_moe_deepgemm import DeepGemmFusedMoE from .fused_moe_densegemm import DenseGEMMFusedMoE +from .fused_moe_flashinfer_nvfp4_sm12x import FlashInferNvfp4Sm12xFusedMoE from .fused_moe_triton import TritonFusedMoE from .fused_moe_trtllm_gen import TRTLLMGenFusedMoE from .fused_moe_vanilla import VanillaMoE @@ -138,6 +139,23 @@ def get_moe_cls( "Falling back to CutlassFusedMoE.") return CutlassFusedMoE return MegaMoEDeepGemm + elif moe_backend.upper() == "FLASHINFER_NVFP4SM12X": + # FlashInferNvfp4Sm12xFusedMoE is the hybrid CUTLASS-prefill / + # b12x-decode NVFP4 MoE backend for SM120/SM121. Hard-error rather + # than silently falling back to CUTLASS so a misconfigured request + # is loud at startup, not a silent perf regression. + if quant_config is None or not quant_config.quant_mode.has_nvfp4(): + raise ValueError( + "FlashInferNvfp4Sm12xFusedMoE requires NVFP4 quantization " + f"(got quant_config={quant_config}).") + from tensorrt_llm._utils import get_sm_version + sm_version = get_sm_version() + if sm_version not in FlashInferNvfp4Sm12xFusedMoE._SUPPORTED_SM_VERSIONS: + sm_list = "/".join(f"SM{v}" for v in sorted( + FlashInferNvfp4Sm12xFusedMoE._SUPPORTED_SM_VERSIONS)) + raise ValueError(f"FlashInferNvfp4Sm12xFusedMoE requires {sm_list} " + f"(got SM{sm_version}).") + return FlashInferNvfp4Sm12xFusedMoE else: raise ValueError(f"Unsupported moe backend: {moe_backend}") @@ -280,7 +298,9 @@ def create_moe_backend( without_comm=without_comm, activation_type=activation_type, ) - elif moe_cls == CutlassFusedMoE: + elif issubclass(moe_cls, CutlassFusedMoE): + # CutlassFusedMoE and any of its subclasses (e.g. FlashInferNvfp4Sm12xFusedMoE) + # share the same constructor signature. 
return moe_cls( routing_method=routing_method, num_experts=num_experts, diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_flashinfer_nvfp4_sm12x.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_flashinfer_nvfp4_sm12x.py new file mode 100644 index 000000000000..83b6fa62571a --- /dev/null +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_flashinfer_nvfp4_sm12x.py @@ -0,0 +1,359 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Tuple, Union + +import torch + +from tensorrt_llm._utils import get_sm_version, nvtx_range +from tensorrt_llm.logger import logger +from tensorrt_llm.models.modeling_utils import QuantAlgo + +from ...utils import ActivationType, Fp4QuantizedTensor +from .fused_moe_cutlass import CutlassFusedMoE +from .interface import _warn_and_return + +# Shared MoE output buffer pool, keyed by (max_num_tokens, hidden_size, dtype, +# device). ``B12xMoEWrapper.__init__`` allocates a private +# ``(max_num_tokens, hidden_size)`` output tensor per instance; with one +# wrapper per MoE layer that is ``num_layers * max_num_tokens * hidden_size`` +# bytes of GPU memory holding identical-shape buffers that are written +# sequentially. 
We fold them into a single shared buffer because MoE layers +# run sequentially on the same CUDA stream, and the wrapper consumes its +# previous output before the next layer is dispatched. +_SHARED_MOE_OUTPUT_BUF: dict = {} + +# ActivationType -> b12x activation string. b12x currently exposes "relu2" +# (Nemotron-style x * relu(x)) and "silu" (SwiGLU-style x * silu(gate)). +_ACTIVATION_MAP = { + ActivationType.Relu2: "relu2", + ActivationType.Swiglu: "silu", +} + + +class FlashInferNvfp4Sm12xFusedMoE(CutlassFusedMoE): + """Hybrid CUTLASS-prefill / b12x-decode NVFP4 fused-MoE backend for SM120 / SM121. + + Composition (see ``MOE_DEVELOPER_GUIDE.md`` for the full explainer): + + - **Prefill (``m >= _PREFILL_VIA_CUTLASS_THRESHOLD``)** routes through the + inherited :class:`CutlassFusedMoE` NVFP4 GroupGEMM. The b12x kernel's + 12-CTA-per-token MMA pattern is suboptimal at large ``m``. + - **Decode (``m < _PREFILL_VIA_CUTLASS_THRESHOLD``)** dispatches to + FlashInfer's ``B12xMoEWrapper.run`` — a kernel purpose-built for + ``m=1`` / small routed-row counts. + + NVFP4 weights are loaded once via the inherited NVFP4 quant method; + ``post_load_weights`` then prepares the b12x-shaped weight tensors + alongside the existing CUTLASS layout. Both layouts coexist in memory + and the dispatcher picks per call based on ``x.shape[0]``. + + CUDA graph capture only covers decode, so captured graphs always replay + the b12x path; eager prefill always runs CUTLASS — there is no graph + capture conflict. + + The backend hard-rejects EP (b12x has no dispatch / combine kernel), + MoE alltoall, ``Fp4QuantizedTensor`` input, ``swiglu_gptoss_style`` + biased SwiGLU, and activations outside ``{Relu2, Swiglu}``. It is + selected via ``moe_config.backend: FLASHINFER_NVFP4SM12X``. + """ + + # SM versions on which the FlashInfer b12x NVFP4 MoE kernel is available. + # SM120 = desktop Blackwell (RTX 5090 / GB202); SM121 = GB10 / DGX Spark. 
+ _SUPPORTED_SM_VERSIONS = frozenset({120, 121}) + + # Prefill chunks (``x.shape[0] >= threshold``) route via CUTLASS NVFP4 + # GroupGEMM; decode (``x.shape[0] < threshold``) uses b12x. 64 cleanly + # separates conc=1 prefill (m=2048 with ``max_num_tokens=2048``) from + # decode (m=1) and stays robust against future chunked-prefill splits + # that might shrink prefill chunk size. + _PREFILL_VIA_CUTLASS_THRESHOLD = 64 + + @classmethod + def can_implement( + cls, + quant_algo: Optional[QuantAlgo], + dtype_activation: torch.dtype = torch.bfloat16, + swiglu_gptoss_style: bool = False, + ) -> Tuple[bool, Optional[str]]: + sm_version = get_sm_version() + if sm_version not in cls._SUPPORTED_SM_VERSIONS: + sm_list = "/".join(f"SM{v}" for v in sorted(cls._SUPPORTED_SM_VERSIONS)) + return _warn_and_return( + f"FlashInferNvfp4Sm12xFusedMoE requires {sm_list}, got SM{sm_version}" + ) + if quant_algo != QuantAlgo.NVFP4: + return _warn_and_return( + f"FlashInferNvfp4Sm12xFusedMoE only supports NVFP4 quantization " + f"(got quant_algo={quant_algo})" + ) + if dtype_activation not in {torch.float16, torch.bfloat16}: + return _warn_and_return( + f"FlashInferNvfp4Sm12xFusedMoE NVFP4 requires float16 or bfloat16 " + f"activation dtype (got {dtype_activation})" + ) + if swiglu_gptoss_style: + return _warn_and_return( + "FlashInferNvfp4Sm12xFusedMoE does not support swiglu_gptoss_style" + ) + return True, None + + def __init__(self, *args, **kwargs): + # ``ModelConfig`` is consumed by the inherited ``__init__`` for cache + # / mapping setup but isn't kept on ``self``. b12x's wrapper needs the + # ``use_cuda_graph`` flag at construction time, so capture it here + # before delegating. + model_config = kwargs.get("model_config", None) + self._b12x_use_cuda_graph = bool(getattr(model_config, "use_cuda_graph", False)) + + super().__init__(*args, **kwargs) + + # b12x has no expert-parallel dispatch/combine kernel, so EP must be + # disabled. 
dp_size > 1 implies the alltoall path which b12x can't run. + if self.ep_size != 1: + raise ValueError( + f"FlashInferNvfp4Sm12xFusedMoE requires ep_size == 1 " + f"(got ep_size={self.ep_size}); use --moe_backend CUTLASS for EP." + ) + if self.enable_alltoall: + raise ValueError( + "FlashInferNvfp4Sm12xFusedMoE does not support MoE alltoall communication." + ) + if self.activation_type not in _ACTIVATION_MAP: + supported = ", ".join(a.name for a in _ACTIVATION_MAP) + raise ValueError( + f"FlashInferNvfp4Sm12xFusedMoE does not support activation " + f"{ActivationType(self.activation_type).name}; " + f"supported: {supported}." + ) + + self._b12x_weights: Optional[dict] = None + self.b12x_wrapper = None + + def _route_to_cutlass(self, x) -> bool: + """Return ``True`` iff this call should fall back to the inherited + CUTLASS path (prefill chunk). ``Fp4QuantizedTensor`` inputs always + stay on the b12x path (which rejects them) so the existing error + message is preserved.""" + return isinstance(x, torch.Tensor) and x.shape[0] >= self._PREFILL_VIA_CUTLASS_THRESHOLD + + def post_load_weights(self): + """Build the b12x weight dict and instantiate ``B12xMoEWrapper``. + + Called by ``model_loader`` after ``load_weights`` finishes. The NVFP4 + quant method's ``process_weights_after_loading`` has already run as + part of ``load_weights``, so the inherited ``w3_w1_weight`` / + ``w2_weight`` / ``*_weight_scale`` / ``*_alpha`` / ``*_input_scale`` + tensors are populated; we just convert them to the layout b12x + expects. + """ + super().post_load_weights() + + try: + from flashinfer import B12xMoEWrapper + from flashinfer.cute_dsl.utils import convert_sf_to_mma_layout + except ImportError as e: + raise RuntimeError( + "FlashInferNvfp4Sm12xFusedMoE requires the `flashinfer` package " + "(B12xMoEWrapper, cute_dsl.utils.convert_sf_to_mma_layout). 
" + f"Original import error: {e}" + ) from e + + num_local_experts = self.w3_w1_weight.shape[0] + # Tensor shapes use the *padded* per-rank dims because TP partitions + # may pad ``intermediate_size`` up to a kernel-friendly boundary. + # Recover them from the actual stored tensors rather than the logical + # model config so reshapes stay valid under TP > 1. + _, w3w1_out_dim, _ = self.w3_w1_weight.shape # (E, 2*I_pad, H//16) + _, w2_out_dim, w2_in_packed = self.w2_weight.shape # (E, H, I_pad//16) + w3w1_in_dim = self.hidden_size + w2_in_dim = w2_in_packed * 16 + + # b12x reuses the per-expert ``w1_alpha`` tensor as both (a) the + # online activation-quant ``global_scale`` and (b) the FC1 epilogue + # output-dequant multiplier. That dual use is only self-consistent + # when the FP4 weight block scales are stored in their *unnormalized* + # form (raw ``max_block / FP4_MAX``), not divided out by the + # per-tensor ``weight_scale_2``. HF / ModelOpt NVFP4 checkpoints + # store the normalized variant so the FP8 block scales fit in range, + # and TRT-LLM's NVFP4 loader preserves that form. To match b12x's + # convention we recover ``weight_scale_2 = fc_alpha * fc_input_scale`` + # and multiply each expert's FP8 block scales by it before handing + # them to ``convert_sf_to_mma_layout``. With the un-normalized scales + # in place we pass ``w1_alpha = w2_alpha = 1 / fc_input_scale`` + # (== ``s_in``) so the kernel's dual-use cancels algebraically and + # the stored input-side block scales remain FP8-representable. + w1_w_scale_2 = (self.fc31_alpha * self.fc31_input_scale).to(torch.float32) + w2_w_scale_2 = (self.fc2_alpha * self.fc2_input_scale).to(torch.float32) + + w1_sf_fp8_norm = self.w3_w1_weight_scale.view(torch.float8_e4m3fn).float() + w2_sf_fp8_norm = self.w2_weight_scale.view(torch.float8_e4m3fn).float() + + # Broadcast per-expert scalar over the trailing dims (E, *). 
+ bcast1 = w1_w_scale_2.view(-1, *([1] * (w1_sf_fp8_norm.dim() - 1))) + bcast2 = w2_w_scale_2.view(-1, *([1] * (w2_sf_fp8_norm.dim() - 1))) + w1_sf_fp8 = (w1_sf_fp8_norm * bcast1).to(torch.float8_e4m3fn) + w2_sf_fp8 = (w2_sf_fp8_norm * bcast2).to(torch.float8_e4m3fn) + + w1_sf_b12x = convert_sf_to_mma_layout( + w1_sf_fp8, m=w3w1_out_dim, k=w3w1_in_dim, num_groups=num_local_experts + ) + w2_sf_b12x = convert_sf_to_mma_layout( + w2_sf_fp8, m=w2_out_dim, k=w2_in_dim, num_groups=num_local_experts + ) + + w1_alpha_b12x = ( + (1.0 / self.fc31_input_scale).expand(self.num_experts).to(torch.float32).contiguous() + ) + w2_alpha_b12x = ( + (1.0 / self.fc2_input_scale).expand(self.num_experts).to(torch.float32).contiguous() + ) + fc2_input_scale_b12x = (1.0 / self.fc2_input_scale).to(torch.float32) + + # TRT-LLM packs 16 FP4 values per int64. flashinfer's internal + # ``view(torch.float4_e2m1fn_x2)`` requires byte-contiguous storage + # (stride[-1] == 1 in bytes); a uint8 view of the int64 tensor + # provides that without copying. + self._b12x_weights = dict( + w1_weight=self.w3_w1_weight.view(torch.uint8), + w1_weight_sf=w1_sf_b12x, + w1_alpha=w1_alpha_b12x, + w2_weight=self.w2_weight.view(torch.uint8), + w2_weight_sf=w2_sf_b12x, + w2_alpha=w2_alpha_b12x, + fc2_input_scale=fc2_input_scale_b12x, + ) + + self.b12x_wrapper = B12xMoEWrapper( + num_experts=self.num_experts, + top_k=self.routing_method.experts_per_token, + hidden_size=self.hidden_size, + intermediate_size=self.intermediate_size_per_partition, + use_cuda_graph=self._b12x_use_cuda_graph, + max_num_tokens=self.moe_max_num_tokens, + activation=_ACTIVATION_MAP[self.activation_type], + ) + + # Replace the wrapper's per-instance output buffer with a shared one. 
+ # Layers run sequentially on a single stream, so a single buffer of the + # right shape is correct and saves + # ``(num_moe_layers - 1) * max_num_tokens * hidden_size * 2`` bytes — + # ~2.5 GB on Nemotron-Super-120B with ``max_num_tokens=2048``, + # ``hidden=8192``, bf16, 80 MoE layers. + if self.b12x_wrapper._moe_output is not None: + buf = self.b12x_wrapper._moe_output + key = (buf.shape[0], buf.shape[1], buf.dtype, str(buf.device)) + shared = _SHARED_MOE_OUTPUT_BUF.get(key) + if shared is None: + _SHARED_MOE_OUTPUT_BUF[key] = buf + else: + # Free the freshly allocated buffer; reuse the existing one. + self.b12x_wrapper._moe_output = shared + + logger.info_once( + f"FlashInferNvfp4Sm12xFusedMoE active: hidden={self.hidden_size}, " + f"intermediate={self.intermediate_size_per_partition}, " + f"experts={self.num_experts}, top_k=" + f"{self.routing_method.experts_per_token}, " + f"activation={_ACTIVATION_MAP[self.activation_type]}.", + key="flashinfer_nvfp4_sm12x_moe_active", + ) + + @nvtx_range("[b12x] quantize_input") + def quantize_input( + self, + x: Union[torch.Tensor, Fp4QuantizedTensor], + post_quant_comm: bool = True, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """Hybrid dispatch entrypoint for activation handling. + + Prefill chunks (``x.shape[0] >= _PREFILL_VIA_CUTLASS_THRESHOLD``) take + the inherited :meth:`CutlassFusedMoE.quantize_input` path so the + downstream ``run_moe`` can call CUTLASS NVFP4 GroupGEMM. Decode + chunks pass through unchanged because b12x quantizes activations + internally (consumes a bf16 / fp16 ``x`` and produces its own scale + factors). + """ + if self._route_to_cutlass(x): + return CutlassFusedMoE.quantize_input( + self, x, post_quant_comm=post_quant_comm, **kwargs + ) + if isinstance(x, Fp4QuantizedTensor): + raise ValueError( + "FlashInferNvfp4Sm12xFusedMoE does not accept Fp4QuantizedTensor input " + "on the b12x decode path; b12x performs its own input quantization." 
+ ) + return x, None + + @nvtx_range("[b12x] run_moe") + def run_moe( + self, + x: torch.Tensor, + token_selected_experts: torch.Tensor, + token_final_scales: torch.Tensor, + x_sf: Optional[torch.Tensor] = None, + is_sf_swizzled: bool = True, + output_dtype: Optional[torch.dtype] = None, + tuner_num_tokens: Optional[int] = None, + tuner_top_k: Optional[int] = None, + moe_output: Optional[torch.Tensor] = None, + enable_alltoall: Optional[bool] = None, + ) -> torch.Tensor: + if self._route_to_cutlass(x): + return CutlassFusedMoE.run_moe( + self, + x, + token_selected_experts=token_selected_experts, + token_final_scales=token_final_scales, + x_sf=x_sf, + is_sf_swizzled=is_sf_swizzled, + output_dtype=output_dtype, + tuner_num_tokens=tuner_num_tokens, + tuner_top_k=tuner_top_k, + moe_output=moe_output, + enable_alltoall=enable_alltoall, + ) + if self.b12x_wrapper is None or self._b12x_weights is None: + raise RuntimeError( + "FlashInferNvfp4Sm12xFusedMoE.run_moe called before " + "process_weights_after_loading completed." + ) + if x_sf is not None: + raise ValueError( + "FlashInferNvfp4Sm12xFusedMoE expects unquantized input (x_sf=None) " + "on the b12x decode path; got a precomputed scale factor." + ) + + # Annotate the kwargs spread + wrapper entry separately so we can + # attribute the per-layer Python dispatch cost vs. the kernel cost. + with nvtx_range("[b12x] wrapper.run"): + out = self.b12x_wrapper.run( + x=x, + token_selected_experts=token_selected_experts, + token_final_scales=token_final_scales, + **self._b12x_weights, + ) + + # B12xMoEWrapper allocates its own output buffer for CUDA-graph + # compatibility. If the caller provided ``moe_output`` (e.g. an + # alltoall workspace tensor), copy into it; FlashInferNvfp4Sm12xFusedMoE + # currently rejects alltoall in __init__, so this is a defensive + # path for future workspace-driven uses. 
+ if moe_output is not None: + with nvtx_range("[b12x] out_copy"): + moe_output.copy_(out) + return moe_output + return out diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index 0a7b2297e3b5..84b76b24e5e0 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -594,7 +594,7 @@ class MoeConfig(StrictBaseModel): """ backend: Literal[ "AUTO", "CUTLASS", "CUTEDSL", "WIDEEP", "TRTLLM", "DEEPGEMM", - "DENSEGEMM", "VANILLA", "TRITON"] = Field( + "DENSEGEMM", "VANILLA", "TRITON", "FLASHINFER_NVFP4SM12X"] = Field( default='AUTO', description="MoE backend to use. " "AUTO selects default backend based on model. It currently doesn\'t always give the best choice for all scenarios. The capabilities of auto selection will be improved in future releases." diff --git a/tests/unittest/_torch/modules/moe/test_flashinfer_nvfp4_sm12x_moe_backend.py b/tests/unittest/_torch/modules/moe/test_flashinfer_nvfp4_sm12x_moe_backend.py new file mode 100644 index 000000000000..ea71d34b9b77 --- /dev/null +++ b/tests/unittest/_torch/modules/moe/test_flashinfer_nvfp4_sm12x_moe_backend.py @@ -0,0 +1,170 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Negative-path + dispatch tests for FlashInferNvfp4Sm12xFusedMoE. 
+ +These checks run without a GPU: they verify the can_implement() gating +matrix, the hard-error policy in create_moe.get_moe_cls, and the +hybrid CUTLASS-prefill / b12x-decode dispatch predicate. Functional +correctness of the b12x kernel is covered by end-to-end model tests on +SM120/SM121 hardware. +""" + +from unittest.mock import patch + +import pytest +import torch + +from tensorrt_llm._torch.model_config import ModelConfig +from tensorrt_llm._torch.modules.fused_moe.create_moe import get_moe_cls +from tensorrt_llm._torch.modules.fused_moe.fused_moe_flashinfer_nvfp4_sm12x import ( + FlashInferNvfp4Sm12xFusedMoE, +) +from tensorrt_llm.models.modeling_utils import QuantAlgo, QuantConfig + +_FUSED_MOE_MODULE = "tensorrt_llm._torch.modules.fused_moe.fused_moe_flashinfer_nvfp4_sm12x" + + +@pytest.mark.parametrize("sm_version", [80, 89, 90, 100, 103]) +def test_can_implement_rejects_unsupported_sm(sm_version): + """can_implement returns False on every SM outside the supported set.""" + with patch(f"{_FUSED_MOE_MODULE}.get_sm_version", return_value=sm_version): + ok, reason = FlashInferNvfp4Sm12xFusedMoE.can_implement(QuantAlgo.NVFP4) + assert not ok + assert reason is not None and f"SM{sm_version}" in reason + + +@pytest.mark.parametrize("sm_version", sorted(FlashInferNvfp4Sm12xFusedMoE._SUPPORTED_SM_VERSIONS)) +def test_can_implement_accepts_supported_sm_with_nvfp4(sm_version): + with patch(f"{_FUSED_MOE_MODULE}.get_sm_version", return_value=sm_version): + ok, reason = FlashInferNvfp4Sm12xFusedMoE.can_implement(QuantAlgo.NVFP4) + assert ok + assert reason is None + + +@pytest.mark.parametrize( + "quant_algo", + [ + None, + QuantAlgo.FP8, + QuantAlgo.FP8_BLOCK_SCALES, + QuantAlgo.W4A16_MXFP4, + QuantAlgo.W4A8_MXFP4_FP8, + ], +) +def test_can_implement_rejects_non_nvfp4(quant_algo): + """Only NVFP4 is supported; everything else must be turned away.""" + with patch(f"{_FUSED_MOE_MODULE}.get_sm_version", return_value=120): + ok, reason = 
FlashInferNvfp4Sm12xFusedMoE.can_implement(quant_algo) + assert not ok + assert reason is not None and "NVFP4" in reason + + +def test_can_implement_rejects_swiglu_gptoss_style(): + with patch(f"{_FUSED_MOE_MODULE}.get_sm_version", return_value=120): + ok, reason = FlashInferNvfp4Sm12xFusedMoE.can_implement( + QuantAlgo.NVFP4, swiglu_gptoss_style=True + ) + assert not ok + assert reason is not None and "swiglu_gptoss_style" in reason + + +@pytest.mark.parametrize("dtype", [torch.float32, torch.float8_e4m3fn]) +def test_can_implement_rejects_unsupported_activation_dtype(dtype): + with patch(f"{_FUSED_MOE_MODULE}.get_sm_version", return_value=120): + ok, reason = FlashInferNvfp4Sm12xFusedMoE.can_implement( + QuantAlgo.NVFP4, dtype_activation=dtype + ) + assert not ok + assert reason is not None + + +def test_get_moe_cls_raises_on_non_nvfp4(): + """create_moe.get_moe_cls must hard-error rather than fall back silently.""" + cfg = ModelConfig() + cfg.moe_backend = "FLASHINFER_NVFP4SM12X" + cfg.quant_config = QuantConfig(quant_algo=QuantAlgo.FP8) + with pytest.raises(ValueError, match="NVFP4"): + get_moe_cls(cfg) + + +def test_get_moe_cls_raises_on_missing_quant(): + cfg = ModelConfig() + cfg.moe_backend = "FLASHINFER_NVFP4SM12X" + cfg.quant_config = None + with pytest.raises(ValueError, match="NVFP4"): + get_moe_cls(cfg) + + +def test_get_moe_cls_raises_on_unsupported_sm(): + cfg = ModelConfig() + cfg.moe_backend = "FLASHINFER_NVFP4SM12X" + cfg.quant_config = QuantConfig(quant_algo=QuantAlgo.NVFP4) + with patch("tensorrt_llm._utils.get_sm_version", return_value=100): + with pytest.raises(ValueError, match="SM"): + get_moe_cls(cfg) + + +def test_get_moe_cls_returns_flashinfer_on_supported_sm(): + cfg = ModelConfig() + cfg.moe_backend = "FLASHINFER_NVFP4SM12X" + cfg.quant_config = QuantConfig(quant_algo=QuantAlgo.NVFP4) + with patch("tensorrt_llm._utils.get_sm_version", return_value=120): + cls = get_moe_cls(cfg) + assert cls is FlashInferNvfp4Sm12xFusedMoE + + +# 
-------------------------------------------------------------------------- +# Hybrid CUTLASS-prefill / b12x-decode dispatch predicate tests +# +# ``_route_to_cutlass`` is a pure shape predicate on its input ``x``; we test +# it via a stub that holds the class constant, sidestepping the full +# CutlassFusedMoE constructor (which needs a routing method, real model +# config, etc.). +# -------------------------------------------------------------------------- + + +class _RoutePredicateStub: + """Minimal carrier for ``_PREFILL_VIA_CUTLASS_THRESHOLD`` so we can call + the unbound ``_route_to_cutlass`` without instantiating the whole MoE + backend.""" + + _PREFILL_VIA_CUTLASS_THRESHOLD = FlashInferNvfp4Sm12xFusedMoE._PREFILL_VIA_CUTLASS_THRESHOLD + + _route_to_cutlass = FlashInferNvfp4Sm12xFusedMoE._route_to_cutlass + + +def test_dispatch_routes_prefill_shape_via_cutlass(): + stub = _RoutePredicateStub() + x = torch.empty(_RoutePredicateStub._PREFILL_VIA_CUTLASS_THRESHOLD, 1024) + assert stub._route_to_cutlass(x) is True + + +def test_dispatch_just_below_threshold_takes_b12x(): + stub = _RoutePredicateStub() + x = torch.empty(_RoutePredicateStub._PREFILL_VIA_CUTLASS_THRESHOLD - 1, 1024) + assert stub._route_to_cutlass(x) is False + + +def test_dispatch_decode_shape_takes_b12x(): + stub = _RoutePredicateStub() + x = torch.empty(1, 1024) + assert stub._route_to_cutlass(x) is False + + +def test_dispatch_rejects_non_tensor(): + """Non-tensor inputs (e.g. 
Fp4QuantizedTensor) stay on the b12x path + so the existing ValueError surfaces in quantize_input.""" + stub = _RoutePredicateStub() + assert stub._route_to_cutlass(object()) is False From c29de6288cc24be92be74f00a772af516a31c522 Mon Sep 17 00:00:00 2001 From: list <58580514+farazkh80@users.noreply.github.com> Date: Tue, 12 May 2026 10:56:37 -0700 Subject: [PATCH 02/10] [None][fix] pin nvidia-cutlass-dsl-libs-cu13 for SM120/SM121 b12x JIT The b12x MoE kernel introduced by PR #13773 (FLASHINFER_NVFP4SM12X) JIT-compiles via nvidia-cutlass-dsl, whose CUDA 13 runtime libraries ship as a separate optional wheel (nvidia-cutlass-dsl-libs-cu13) and are NOT pulled automatically by the main nvidia-cutlass-dsl wheel. Without this wheel, executor initialization on SM120/SM121 hosts dies with ptxas "Unexpected instruction types specified for '_mma'" because the chip->compute_target conversion falls back to a path that strips the 'a' suffix (sm_120a -> sm_120), and ptxas is then invoked with -opt-arch=sm_120 against PTX that has .target sm_120a with sm_120a-only mma instruction forms. The runtime requirement was documented in the PR body but never made binding via requirements.txt. Pin it explicitly at the same version as the main wheel so fresh builds reproduce the same working environment. 
Signed-off-by: list <58580514+farazkh80@users.noreply.github.com> --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index 0282091159d0..b5a8e1f796fe 100644 --- a/requirements.txt +++ b/requirements.txt @@ -72,6 +72,7 @@ tiktoken blobfile openai-harmony==0.0.4 nvidia-cutlass-dsl==4.5.0; python_version >= "3.10" +nvidia-cutlass-dsl-libs-cu13==4.5.0; python_version >= "3.10" plotly numexpr partial_json_parser From a6e389da1c29db36ee6bb147c46ce041caa5530d Mon Sep 17 00:00:00 2001 From: list <58580514+farazkh80@users.noreply.github.com> Date: Tue, 12 May 2026 11:08:32 -0700 Subject: [PATCH 03/10] [None][test] add FLASHINFER_NVFP4SM12X unit tests to l0_rtx_pro_6000 Adds test_flashinfer_nvfp4_sm12x_moe_backend.py (23 tests covering can_implement gating, get_moe_cls error paths, and the hybrid CUTLASS-prefill / b12x-decode dispatch predicate) to the pre-merge 1-GPU PyTorch section of the SM120 (RTX PRO 6000) test list. No GPU required, ~12s run time, so it stays in the pre-merge tier. 
Signed-off-by: list <58580514+farazkh80@users.noreply.github.com> --- tests/integration/test_lists/test-db/l0_rtx_pro_6000.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/integration/test_lists/test-db/l0_rtx_pro_6000.yml b/tests/integration/test_lists/test-db/l0_rtx_pro_6000.yml index 0ceba9c3bc7f..d4ec70fdac04 100644 --- a/tests/integration/test_lists/test-db/l0_rtx_pro_6000.yml +++ b/tests/integration/test_lists/test-db/l0_rtx_pro_6000.yml @@ -15,6 +15,7 @@ l0_rtx_pro_6000: tests: # ------------- PyTorch tests --------------- - unittest/_torch/modeling -k "modeling_out_of_tree" + - unittest/_torch/modules/moe/test_flashinfer_nvfp4_sm12x_moe_backend.py # - unittest/_torch/modeling -k "modeling_qwen" # https://nvbugs/5234573 - unittest/_torch/attention/test_attention_mla.py # SM120 W4A16 / W4A8 mixed-dtype GEMM coverage (paired with FinegrainedMixedDtypeGemm From ff292c23687f692b8c58c6db9dd70d35bf006bba Mon Sep 17 00:00:00 2001 From: list <58580514+farazkh80@users.noreply.github.com> Date: Wed, 13 May 2026 13:16:10 -0700 Subject: [PATCH 04/10] [None][fix] restrict create_moe CutlassFusedMoE branch to exact classes The `elif issubclass(moe_cls, CutlassFusedMoE)` dispatch added with this backend captured `CuteDslFusedMoE` and `DeepGemmFusedMoE` (both subclass `CutlassFusedMoE`) before their dedicated branches at L326 / L343 could match, leaking `swiglu_alpha` / `swiglu_beta` / `swiglu_limit` into constructors that don't accept them. Restrict the branch to an explicit allowlist (`CutlassFusedMoE`, `FlashInferNvfp4Sm12xFusedMoE`) so the existing exact-class branches keep working. 
Signed-off-by: list <58580514+farazkh80@users.noreply.github.com> --- tensorrt_llm/_torch/modules/fused_moe/create_moe.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tensorrt_llm/_torch/modules/fused_moe/create_moe.py b/tensorrt_llm/_torch/modules/fused_moe/create_moe.py index e7cc4ef0d893..9efff071db58 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/create_moe.py +++ b/tensorrt_llm/_torch/modules/fused_moe/create_moe.py @@ -298,9 +298,9 @@ def create_moe_backend( without_comm=without_comm, activation_type=activation_type, ) - elif issubclass(moe_cls, CutlassFusedMoE): - # CutlassFusedMoE and any of its subclasses (e.g. FlashInferNvfp4Sm12xFusedMoE) - # share the same constructor signature. + elif moe_cls in (CutlassFusedMoE, FlashInferNvfp4Sm12xFusedMoE): + # CuteDslFusedMoE and DeepGemmFusedMoE also subclass CutlassFusedMoE but + # have narrower constructors, so they take their own branches below. return moe_cls( routing_method=routing_method, num_experts=num_experts, From 9f3fd2490346cf3e36c82095a58adc3622176bcf Mon Sep 17 00:00:00 2001 From: list <58580514+farazkh80@users.noreply.github.com> Date: Wed, 13 May 2026 13:19:49 -0700 Subject: [PATCH 05/10] [None][test] wire FlashInferNvfp4Sm12xFusedMoE into unified MoE test framework Per review feedback, register the new backend in the shared MoE test framework so it participates in the same parametrized harness as the other backends: - Add `FLASHINFER_NVFP4SM12X` to `MoeBackendType` enum and `get_backend_class()` map. - Add `should_skip_flashinfer_nvfp4_sm12x()` covering the EP / alltoall hard rejects from `__init__` (SM, quant, dtype and gptoss are already handled by `can_implement()`). - Exclude the backend from `supports_autotuner_capture()` (b12x decode does not go through the autotuner). - Hook the new helper into `get_quick_skip_reason()` skip chain. 
- Add the new enum value to `test_moe_backend.py::BACKEND_TYPES_TO_TEST` so the existing parametrization picks it up on SM120/SM121 + NVFP4 configs and skips elsewhere via `can_implement()`. Signed-off-by: list <58580514+farazkh80@users.noreply.github.com> --- .../_torch/modules/moe/moe_test_utils.py | 43 ++++++++++++++++++- .../_torch/modules/moe/test_moe_backend.py | 1 + 2 files changed, 42 insertions(+), 2 deletions(-) diff --git a/tests/unittest/_torch/modules/moe/moe_test_utils.py b/tests/unittest/_torch/modules/moe/moe_test_utils.py index 13f916db7da7..a35d693e4691 100644 --- a/tests/unittest/_torch/modules/moe/moe_test_utils.py +++ b/tests/unittest/_torch/modules/moe/moe_test_utils.py @@ -46,6 +46,9 @@ ) from tensorrt_llm._torch.modules.fused_moe.fused_moe_deepgemm import DeepGemmFusedMoE from tensorrt_llm._torch.modules.fused_moe.fused_moe_densegemm import DenseGEMMFusedMoE +from tensorrt_llm._torch.modules.fused_moe.fused_moe_flashinfer_nvfp4_sm12x import ( + FlashInferNvfp4Sm12xFusedMoE, +) from tensorrt_llm._torch.modules.fused_moe.interface import MoE from tensorrt_llm._torch.modules.fused_moe.mega_moe import MegaMoEDeepGemm from tensorrt_llm._torch.utils import ActivationType, is_gated_activation @@ -66,6 +69,7 @@ class MoeBackendType(str, Enum): DEEPGEMM = "DEEPGEMM" DENSEGEMM = "DENSEGEMM" MEGAMOE = "MEGAMOE_DEEPGEMM" + FLASHINFER_NVFP4SM12X = "FLASHINFER_NVFP4SM12X" def get_backend_class(backend_type: MoeBackendType) -> Type[MoE]: @@ -77,6 +81,7 @@ def get_backend_class(backend_type: MoeBackendType) -> Type[MoE]: MoeBackendType.DEEPGEMM: DeepGemmFusedMoE, MoeBackendType.DENSEGEMM: DenseGEMMFusedMoE, MoeBackendType.MEGAMOE: MegaMoEDeepGemm, + MoeBackendType.FLASHINFER_NVFP4SM12X: FlashInferNvfp4Sm12xFusedMoE, } return backend_class_map[backend_type] @@ -864,6 +869,35 @@ def should_skip_megamoe( return None +def should_skip_flashinfer_nvfp4_sm12x( + backend_type: MoeBackendType, + comm_method: Optional[str] = None, + moe_tp_size: int = 1, + 
parallel_mode: Optional[str] = None, +) -> Optional[str]: + """Check FlashInferNvfp4Sm12xFusedMoE constraints not covered by can_implement(). + + can_implement() already gates SM version, quant_algo, dtype_activation, and + swiglu_gptoss_style. This helper covers the additional EP / alltoall hard + rejects enforced in __init__ (b12x has no expert-parallel dispatch/combine + kernel). + """ + if backend_type != MoeBackendType.FLASHINFER_NVFP4SM12X: + return None + + if comm_method is not None or parallel_mode is not None: + return ( + "FlashInferNvfp4Sm12xFusedMoE rejects expert parallelism / alltoall; " + f"got comm_method={comm_method}, parallel_mode={parallel_mode}." + ) + if moe_tp_size != 1: + return ( + f"FlashInferNvfp4Sm12xFusedMoE requires ep_size=1; " + f"got moe_tp_size={moe_tp_size}." + ) + return None + + def should_skip_multi_gpu( parallel_mode: str, model_config: "MoeModelConfig", @@ -970,8 +1004,12 @@ def supports_autotuner_capture( Returns: True if autotuner capture/replay is supported, False otherwise """ - # DEEPGEMM and MEGAMOE do not support autotuner capture - if backend_type in (MoeBackendType.DEEPGEMM, MoeBackendType.MEGAMOE): + # DEEPGEMM, MEGAMOE, and FLASHINFER_NVFP4SM12X do not support autotuner capture + if backend_type in ( + MoeBackendType.DEEPGEMM, + MoeBackendType.MEGAMOE, + MoeBackendType.FLASHINFER_NVFP4SM12X, + ): return False if use_flashinfer: @@ -1050,6 +1088,7 @@ def get_quick_skip_reason( model_config=model_config, swiglu_gptoss_style=swiglu_gptoss_style, ), + lambda: should_skip_flashinfer_nvfp4_sm12x(backend_type), ] for check in skip_checks: skip_reason = check() diff --git a/tests/unittest/_torch/modules/moe/test_moe_backend.py b/tests/unittest/_torch/modules/moe/test_moe_backend.py index e9957bf4ca89..5ca289048d2f 100644 --- a/tests/unittest/_torch/modules/moe/test_moe_backend.py +++ b/tests/unittest/_torch/modules/moe/test_moe_backend.py @@ -294,6 +294,7 @@ def run_backend_moe( MoeBackendType.DEEPGEMM, 
 MoeBackendType.DENSEGEMM, MoeBackendType.MEGAMOE, + MoeBackendType.FLASHINFER_NVFP4SM12X, ] # Data types to test From 9c5de386c11bfc8cd90447ee041e563e3bb1da36 Mon Sep 17 00:00:00 2001 From: list <58580514+farazkh80@users.noreply.github.com> Date: Thu, 14 May 2026 10:00:25 -0700 Subject: [PATCH 06/10] [None][feat] hide FLASHINFER_NVFP4SM12X behind a CUTLASS heuristic auto-promote MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace the user-facing `moe_backend: FLASHINFER_NVFP4SM12X` knob with transparent heuristic auto-promotion on the `CUTLASS` path. When the user selects `moe_backend: CUTLASS` (the default), `get_moe_cls()` now returns `FlashInferNvfp4Sm12xFusedMoE` automatically when: - quant_config has NVFP4 - SM version is 120 or 121 - `import flashinfer` succeeds Otherwise it returns `CutlassFusedMoE` (the pre-PR behaviour). The class itself, its weight lifecycle, and its hybrid dispatch (CUTLASS prefill at `m >= 64`, b12x decode at `m < 64`) are unchanged — only the selection plumbing moves. This responds to xxi-nv's review comment on PR #13773 asking whether the b12x backend could be selected via a heuristic rather than an explicit name. Mirrors the existing `MEGAMOE_DEEPGEMM` pattern of `can_implement`-gated promotion with a CUTLASS fallback. Drops `"FLASHINFER_NVFP4SM12X"` from `MoeConfig.backend` Literal — the class stays importable as an internal API for tests and for direct construction, but is no longer a valid user-facing config string. Tests in `test_flashinfer_nvfp4_sm12x_moe_backend.py` flipped from "explicit name raises on bad config" to "heuristic auto-promotes vs falls back to CutlassFusedMoE". Internal `MoeBackendType` entry kept so `test_moe_backend.py` parametrization continues to cover the backend; `create_test_backend` routes the enum through `moe_backend="CUTLASS"` to exercise the same code path users hit. 
Signed-off-by: list <58580514+farazkh80@users.noreply.github.com> --- .../_torch/modules/fused_moe/create_moe.py | 38 +++++----- tensorrt_llm/llmapi/llm_args.py | 2 +- ...test_flashinfer_nvfp4_sm12x_moe_backend.py | 69 +++++++++++++------ .../_torch/modules/moe/test_moe_backend.py | 11 ++- 4 files changed, 81 insertions(+), 39 deletions(-) diff --git a/tensorrt_llm/_torch/modules/fused_moe/create_moe.py b/tensorrt_llm/_torch/modules/fused_moe/create_moe.py index 9efff071db58..f9915b7e9106 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/create_moe.py +++ b/tensorrt_llm/_torch/modules/fused_moe/create_moe.py @@ -34,6 +34,27 @@ def get_moe_cls( if override_quant_config is not None: quant_config = override_quant_config if moe_backend.upper() == "CUTLASS": + # Auto-promote to FlashInferNvfp4Sm12xFusedMoE (hybrid CUTLASS-prefill + # / b12x-decode) on SM120 / SM121 + NVFP4 when flashinfer is available. + # Falls back to plain CutlassFusedMoE otherwise. + if quant_config is not None and quant_config.quant_mode.has_nvfp4(): + from tensorrt_llm._utils import get_sm_version + sm_version = get_sm_version() + if sm_version in FlashInferNvfp4Sm12xFusedMoE._SUPPORTED_SM_VERSIONS: + try: + import flashinfer # noqa: F401 + logger.info( + "Auto-selecting FlashInferNvfp4Sm12xFusedMoE for hybrid " + "CUTLASS-prefill / b12x-decode (SM%d + NVFP4).", + sm_version, + ) + return FlashInferNvfp4Sm12xFusedMoE + except ImportError: + logger.warning( + "FlashInferNvfp4Sm12xFusedMoE eligible (SM%d + NVFP4) " + "but flashinfer is not importable; using CutlassFusedMoE.", + sm_version, + ) return CutlassFusedMoE elif moe_backend.upper() == "VANILLA": return VanillaMoE @@ -139,23 +160,6 @@ def get_moe_cls( "Falling back to CutlassFusedMoE.") return CutlassFusedMoE return MegaMoEDeepGemm - elif moe_backend.upper() == "FLASHINFER_NVFP4SM12X": - # FlashInferNvfp4Sm12xFusedMoE is the hybrid CUTLASS-prefill / - # b12x-decode NVFP4 MoE backend for SM120/SM121. 
Hard-error rather - # than silently falling back to CUTLASS so a misconfigured request - # is loud at startup, not a silent perf regression. - if quant_config is None or not quant_config.quant_mode.has_nvfp4(): - raise ValueError( - "FlashInferNvfp4Sm12xFusedMoE requires NVFP4 quantization " - f"(got quant_config={quant_config}).") - from tensorrt_llm._utils import get_sm_version - sm_version = get_sm_version() - if sm_version not in FlashInferNvfp4Sm12xFusedMoE._SUPPORTED_SM_VERSIONS: - sm_list = "/".join(f"SM{v}" for v in sorted( - FlashInferNvfp4Sm12xFusedMoE._SUPPORTED_SM_VERSIONS)) - raise ValueError(f"FlashInferNvfp4Sm12xFusedMoE requires {sm_list} " - f"(got SM{sm_version}).") - return FlashInferNvfp4Sm12xFusedMoE else: raise ValueError(f"Unsupported moe backend: {moe_backend}") diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index 84b76b24e5e0..0a7b2297e3b5 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -594,7 +594,7 @@ class MoeConfig(StrictBaseModel): """ backend: Literal[ "AUTO", "CUTLASS", "CUTEDSL", "WIDEEP", "TRTLLM", "DEEPGEMM", - "DENSEGEMM", "VANILLA", "TRITON", "FLASHINFER_NVFP4SM12X"] = Field( + "DENSEGEMM", "VANILLA", "TRITON"] = Field( default='AUTO', description="MoE backend to use. " "AUTO selects default backend based on model. It currently doesn\'t always give the best choice for all scenarios. The capabilities of auto selection will be improved in future releases." diff --git a/tests/unittest/_torch/modules/moe/test_flashinfer_nvfp4_sm12x_moe_backend.py b/tests/unittest/_torch/modules/moe/test_flashinfer_nvfp4_sm12x_moe_backend.py index ea71d34b9b77..baf3d5390722 100644 --- a/tests/unittest/_torch/modules/moe/test_flashinfer_nvfp4_sm12x_moe_backend.py +++ b/tests/unittest/_torch/modules/moe/test_flashinfer_nvfp4_sm12x_moe_backend.py @@ -15,10 +15,11 @@ """Negative-path + dispatch tests for FlashInferNvfp4Sm12xFusedMoE. 
These checks run without a GPU: they verify the can_implement() gating -matrix, the hard-error policy in create_moe.get_moe_cls, and the -hybrid CUTLASS-prefill / b12x-decode dispatch predicate. Functional -correctness of the b12x kernel is covered by end-to-end model tests on -SM120/SM121 hardware. +matrix, the heuristic auto-promotion in create_moe.get_moe_cls (the +backend is selected transparently from `moe_backend=CUTLASS` on +SM120/SM121 + NVFP4), and the hybrid CUTLASS-prefill / b12x-decode +dispatch predicate. Functional correctness of the b12x kernel is +covered by end-to-end model tests on SM120/SM121 hardware. """ from unittest.mock import patch @@ -28,11 +29,13 @@ from tensorrt_llm._torch.model_config import ModelConfig from tensorrt_llm._torch.modules.fused_moe.create_moe import get_moe_cls +from tensorrt_llm._torch.modules.fused_moe.fused_moe_cutlass import CutlassFusedMoE from tensorrt_llm._torch.modules.fused_moe.fused_moe_flashinfer_nvfp4_sm12x import ( FlashInferNvfp4Sm12xFusedMoE, ) from tensorrt_llm.models.modeling_utils import QuantAlgo, QuantConfig + _FUSED_MOE_MODULE = "tensorrt_llm._torch.modules.fused_moe.fused_moe_flashinfer_nvfp4_sm12x" @@ -90,41 +93,67 @@ def test_can_implement_rejects_unsupported_activation_dtype(dtype): assert reason is not None -def test_get_moe_cls_raises_on_non_nvfp4(): - """create_moe.get_moe_cls must hard-error rather than fall back silently.""" +def test_get_moe_cls_falls_back_to_cutlass_on_non_nvfp4(): + """Heuristic auto-promotion only fires on NVFP4; otherwise CUTLASS path stays.""" cfg = ModelConfig() - cfg.moe_backend = "FLASHINFER_NVFP4SM12X" + cfg.moe_backend = "CUTLASS" cfg.quant_config = QuantConfig(quant_algo=QuantAlgo.FP8) - with pytest.raises(ValueError, match="NVFP4"): - get_moe_cls(cfg) + with patch("tensorrt_llm._utils.get_sm_version", return_value=120): + cls = get_moe_cls(cfg) + assert cls is CutlassFusedMoE -def test_get_moe_cls_raises_on_missing_quant(): +def 
test_get_moe_cls_falls_back_to_cutlass_on_missing_quant(): cfg = ModelConfig() - cfg.moe_backend = "FLASHINFER_NVFP4SM12X" + cfg.moe_backend = "CUTLASS" cfg.quant_config = None - with pytest.raises(ValueError, match="NVFP4"): - get_moe_cls(cfg) + with patch("tensorrt_llm._utils.get_sm_version", return_value=120): + cls = get_moe_cls(cfg) + assert cls is CutlassFusedMoE -def test_get_moe_cls_raises_on_unsupported_sm(): +def test_get_moe_cls_falls_back_to_cutlass_on_unsupported_sm(): + """NVFP4 + non-SM120/121 must not auto-promote.""" cfg = ModelConfig() - cfg.moe_backend = "FLASHINFER_NVFP4SM12X" + cfg.moe_backend = "CUTLASS" cfg.quant_config = QuantConfig(quant_algo=QuantAlgo.NVFP4) with patch("tensorrt_llm._utils.get_sm_version", return_value=100): - with pytest.raises(ValueError, match="SM"): - get_moe_cls(cfg) + cls = get_moe_cls(cfg) + assert cls is CutlassFusedMoE -def test_get_moe_cls_returns_flashinfer_on_supported_sm(): +@pytest.mark.parametrize("sm_version", sorted(FlashInferNvfp4Sm12xFusedMoE._SUPPORTED_SM_VERSIONS)) +def test_get_moe_cls_auto_promotes_on_supported_sm(sm_version): + """CUTLASS + NVFP4 + SM120/121 + flashinfer importable → hybrid backend.""" cfg = ModelConfig() - cfg.moe_backend = "FLASHINFER_NVFP4SM12X" + cfg.moe_backend = "CUTLASS" cfg.quant_config = QuantConfig(quant_algo=QuantAlgo.NVFP4) - with patch("tensorrt_llm._utils.get_sm_version", return_value=120): + with patch("tensorrt_llm._utils.get_sm_version", return_value=sm_version): cls = get_moe_cls(cfg) assert cls is FlashInferNvfp4Sm12xFusedMoE +def test_get_moe_cls_falls_back_when_flashinfer_missing(monkeypatch): + """Eligible hardware but flashinfer not importable → CutlassFusedMoE.""" + import builtins + + cfg = ModelConfig() + cfg.moe_backend = "CUTLASS" + cfg.quant_config = QuantConfig(quant_algo=QuantAlgo.NVFP4) + + real_import = builtins.__import__ + + def _raise_on_flashinfer(name, *args, **kwargs): + if name == "flashinfer": + raise ImportError("flashinfer not installed 
(simulated)") + return real_import(name, *args, **kwargs) + + monkeypatch.setattr(builtins, "__import__", _raise_on_flashinfer) + with patch("tensorrt_llm._utils.get_sm_version", return_value=120): + cls = get_moe_cls(cfg) + assert cls is CutlassFusedMoE + + # -------------------------------------------------------------------------- # Hybrid CUTLASS-prefill / b12x-decode dispatch predicate tests # diff --git a/tests/unittest/_torch/modules/moe/test_moe_backend.py b/tests/unittest/_torch/modules/moe/test_moe_backend.py index 5ca289048d2f..c870c80519ae 100644 --- a/tests/unittest/_torch/modules/moe/test_moe_backend.py +++ b/tests/unittest/_torch/modules/moe/test_moe_backend.py @@ -139,11 +139,20 @@ def create_test_backend( pretrained_config.intermediate_size = intermediate_size pretrained_config.torch_dtype = dtype + # FLASHINFER_NVFP4SM12X is internal-only: the user-facing API selects it + # transparently via the CUTLASS heuristic auto-promotion on SM120/121 + + # NVFP4. Route through "CUTLASS" so the test exercises the same code path + # users hit. + moe_backend_value = ( + "CUTLASS" + if backend_type == MoeBackendType.FLASHINFER_NVFP4SM12X + else backend_type.value + ) model_config = ModelConfig( pretrained_config=pretrained_config, quant_config=quant_config, mapping=mapping, - moe_backend=backend_type.value, + moe_backend=moe_backend_value, ) return create_moe_backend( From 47c889adbf74ca3edb2d2a2fec17b92ef3e8bc29 Mon Sep 17 00:00:00 2001 From: list <58580514+farazkh80@users.noreply.github.com> Date: Thu, 14 May 2026 11:46:19 -0700 Subject: [PATCH 07/10] [None][test][doc] add FlashInferNvfp4Sm12x to test_moe_module + trim MOE guide MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit test_moe_module.py: register MoeBackendType.FLASHINFER_NVFP4SM12X in BACKEND_TYPES so the unified ConfigurableMoE matrix exercises it. 
_create_model_config maps the internal enum value to moe_backend="CUTLASS" before passing into ModelConfig — the enum is internal-only after the heuristic auto-promote landed; users reach the backend via the CUTLASS path. MOE_DEVELOPER_GUIDE.md: remove the dedicated FlashInferNvfp4Sm12xFusedMoE section (composition / dispatch policy / weight-conversion algebra / hard-reject list) and drop the Nvfp4Sm12x matrix column. The class's NVFP4 support on SM120/121 is already covered by the CUTLASS row in the matrix (auto-promote target). Only the single inventory-table entry under "Backends" remains, pointing at the backend file for anyone who wants the details. Both changes respond to xxi-nv's review comments on PR #13773 asking that test_moe_module.py / test_moe_backend.py cover the new backend and that the MoE guide stay high-level. Signed-off-by: list <58580514+farazkh80@users.noreply.github.com> --- tests/unittest/_torch/modules/moe/test_moe_module.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/unittest/_torch/modules/moe/test_moe_module.py b/tests/unittest/_torch/modules/moe/test_moe_module.py index 44ba462af23c..320656727fad 100644 --- a/tests/unittest/_torch/modules/moe/test_moe_module.py +++ b/tests/unittest/_torch/modules/moe/test_moe_module.py @@ -249,6 +249,13 @@ def _create_model_config( else None ) + # FLASHINFER_NVFP4SM12X is an internal-only MoeBackendType — it has no + # corresponding user-facing MoeConfig.backend literal. Route through + # "CUTLASS" so the test exercises the heuristic auto-promotion path that + # users hit on SM120/121 + NVFP4. 
+ if moe_backend == MoeBackendType.FLASHINFER_NVFP4SM12X.value: + moe_backend = MoeBackendType.CUTLASS.value + kwargs = dict( pretrained_config=pretrained_config, mapping=mapping, @@ -817,6 +824,7 @@ def init_worker(custom_paths, comm_method_type, master_port): MoeBackendType.DEEPGEMM, MoeBackendType.DENSEGEMM, MoeBackendType.MEGAMOE, + MoeBackendType.FLASHINFER_NVFP4SM12X, ] # Data types to test From f66f5e7c8c005ad956cf12d5205fdc6606949651 Mon Sep 17 00:00:00 2001 From: list <58580514+farazkh80@users.noreply.github.com> Date: Thu, 14 May 2026 12:55:28 -0700 Subject: [PATCH 08/10] [None][chore] ruff-format fixes for MoE test changes Pre-commit hooks flagged 3 cosmetic formatting tweaks (collapse multi-line ternary/f-string/blank-line) in the MoE test files added/edited earlier in this PR. No behaviour change. Signed-off-by: list <58580514+farazkh80@users.noreply.github.com> --- tests/unittest/_torch/modules/moe/moe_test_utils.py | 5 +---- .../modules/moe/test_flashinfer_nvfp4_sm12x_moe_backend.py | 1 - tests/unittest/_torch/modules/moe/test_moe_backend.py | 4 +--- 3 files changed, 2 insertions(+), 8 deletions(-) diff --git a/tests/unittest/_torch/modules/moe/moe_test_utils.py b/tests/unittest/_torch/modules/moe/moe_test_utils.py index a35d693e4691..aa7e5282069d 100644 --- a/tests/unittest/_torch/modules/moe/moe_test_utils.py +++ b/tests/unittest/_torch/modules/moe/moe_test_utils.py @@ -891,10 +891,7 @@ def should_skip_flashinfer_nvfp4_sm12x( f"got comm_method={comm_method}, parallel_mode={parallel_mode}." ) if moe_tp_size != 1: - return ( - f"FlashInferNvfp4Sm12xFusedMoE requires ep_size=1; " - f"got moe_tp_size={moe_tp_size}." - ) + return f"FlashInferNvfp4Sm12xFusedMoE requires ep_size=1; got moe_tp_size={moe_tp_size}." 
return None diff --git a/tests/unittest/_torch/modules/moe/test_flashinfer_nvfp4_sm12x_moe_backend.py b/tests/unittest/_torch/modules/moe/test_flashinfer_nvfp4_sm12x_moe_backend.py index baf3d5390722..6f2a37bd5ff0 100644 --- a/tests/unittest/_torch/modules/moe/test_flashinfer_nvfp4_sm12x_moe_backend.py +++ b/tests/unittest/_torch/modules/moe/test_flashinfer_nvfp4_sm12x_moe_backend.py @@ -35,7 +35,6 @@ ) from tensorrt_llm.models.modeling_utils import QuantAlgo, QuantConfig - _FUSED_MOE_MODULE = "tensorrt_llm._torch.modules.fused_moe.fused_moe_flashinfer_nvfp4_sm12x" diff --git a/tests/unittest/_torch/modules/moe/test_moe_backend.py b/tests/unittest/_torch/modules/moe/test_moe_backend.py index c870c80519ae..0b5e4750e07a 100644 --- a/tests/unittest/_torch/modules/moe/test_moe_backend.py +++ b/tests/unittest/_torch/modules/moe/test_moe_backend.py @@ -144,9 +144,7 @@ def create_test_backend( # NVFP4. Route through "CUTLASS" so the test exercises the same code path # users hit. moe_backend_value = ( - "CUTLASS" - if backend_type == MoeBackendType.FLASHINFER_NVFP4SM12X - else backend_type.value + "CUTLASS" if backend_type == MoeBackendType.FLASHINFER_NVFP4SM12X else backend_type.value ) model_config = ModelConfig( pretrained_config=pretrained_config, From 58c36d188840b76a030a0b09c56c0f9e8888d009 Mon Sep 17 00:00:00 2001 From: list <58580514+farazkh80@users.noreply.github.com> Date: Thu, 21 May 2026 11:20:25 -0700 Subject: [PATCH 09/10] [None][refactor] move b12x MoE under cuteDSL family + extract quant method MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Addresses three of @xxi-nv's PR #13773 follow-up comments (May 15): - Move post_load_weights into a dedicated quantization_method (NVFP4CuteDslB12xFusedMoEMethod, sibling of NVFP4CuteDslFusedMoEMethod). Backend post_load_weights is now inherited from CutlassFusedMoE and is a thin pass-through to self.quant_method.post_load_weights(self). 
All b12x weight prep (SF un-normalization, convert_sf_to_mma_layout, B12xMoEWrapper instantiation, shared output buffer) lives next to the rest of the NVFP4 quant-method family. - Make the backend a member of the cuteDSL family: switch the parent class to CuteDslFusedMoE. The hybrid prefill path keeps explicit CutlassFusedMoE.method(self, ...) calls so the same C++ CUTLASS NVFP4 GroupGEMM still runs at m>=64 — the MRO change does not affect which kernels execute. create_moe.py constructor call moved into the CuteDslFusedMoE branch (narrower init signature). - Rename file / class / enum / test to match the cuteDSL family: fused_moe_flashinfer_nvfp4_sm12x.py -> fused_moe_cute_dsl_b12x.py FlashInferNvfp4Sm12xFusedMoE -> CuteDslB12xFusedMoE MoeBackendType.FLASHINFER_NVFP4SM12X -> MoeBackendType.CUTE_DSL_B12X test_flashinfer_nvfp4_sm12x_moe_backend.py -> test_cute_dsl_b12x_moe_backend.py Also adds a local output_dtype fallback in the backend's run_moe before delegating to CutlassFusedMoE.run_moe — schedulers that drive run_moe directly (the KV-cache capacity probe) leave it unset, which surfaces as a 'trtllm::fused_moe() Expected ScalarType output_dtype but instead found NoneType' on the prefill probe. Mirrors fused_moe_cutlass.py:691's forward_chunk convention; FP4-packed uint8 falls back to bf16. 
Validation: - 25/25 unit tests pass (test_cute_dsl_b12x_moe_backend.py, --noconftest) - trtllm-bench Nemotron-Super-120B-NVFP4 on SM120: 86.75 tok/s vs 85.92 pre-refactor baseline (HYBRID_RESULTS.md, May 7) — within 1% noise - nsys: CUTLASS sm120 block-scaled NVFP4 GroupGEMM kernels fire on prefill (m=2048, 96 calls); b12x cuteDSL MoEStatic/Dynamic/Micro kernels fire on decode (m=1, 40/80/160 calls); [b12x] quantize_input NVTX ranges present (400 calls) Signed-off-by: list <58580514+farazkh80@users.noreply.github.com> --- .../modules/fused_moe/MOE_DEVELOPER_GUIDE.md | 1 + .../_torch/modules/fused_moe/__init__.py | 4 +- .../_torch/modules/fused_moe/create_moe.py | 23 +- ...p4_sm12x.py => fused_moe_cute_dsl_b12x.py} | 224 ++++++------------ .../_torch/modules/fused_moe/quantization.py | 158 ++++++++++++ .../test_lists/test-db/l0_rtx_pro_6000.yml | 2 +- .../_torch/modules/moe/moe_test_utils.py | 24 +- ...d.py => test_cute_dsl_b12x_moe_backend.py} | 32 +-- .../_torch/modules/moe/test_moe_backend.py | 6 +- .../_torch/modules/moe/test_moe_module.py | 6 +- 10 files changed, 280 insertions(+), 200 deletions(-) rename tensorrt_llm/_torch/modules/fused_moe/{fused_moe_flashinfer_nvfp4_sm12x.py => fused_moe_cute_dsl_b12x.py} (52%) rename tests/unittest/_torch/modules/moe/{test_flashinfer_nvfp4_sm12x_moe_backend.py => test_cute_dsl_b12x_moe_backend.py} (85%) diff --git a/tensorrt_llm/_torch/modules/fused_moe/MOE_DEVELOPER_GUIDE.md b/tensorrt_llm/_torch/modules/fused_moe/MOE_DEVELOPER_GUIDE.md index 06735afe0564..3dba9fbf9ded 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/MOE_DEVELOPER_GUIDE.md +++ b/tensorrt_llm/_torch/modules/fused_moe/MOE_DEVELOPER_GUIDE.md @@ -147,6 +147,7 @@ Still on old path (standalone, with embedded communication): | `fused_moe_deepgemm.py` | `DeepGemmFusedMoE` | SM100/SM103 | FP8 Block Scales on Blackwell | `EXTERNAL_COMM` | | `fused_moe_densegemm.py` | `DenseGEMMFusedMoE` | SM100/SM103 | NVFP4 min-latency; CuTe DSL dense GEMM packs all experts into 
one matrix (vs Cutlass per-expert scatter), efficient for small token counts | `EXTERNAL_COMM` | | `fused_moe_cute_dsl.py` | `CuteDslFusedMoE` | SM100/SM103 | High throughput NVFP4, generally faster than Cutlass | `EXTERNAL_COMM` | +| `fused_moe_cute_dsl_b12x.py` | `CuteDslB12xFusedMoE` | SM120/SM121 | NVFP4; auto-selected on the `CUTLASS` path | `EXTERNAL_COMM` | | `mega_moe/mega_moe_deepgemm.py` | `MegaMoEDeepGemm` | SM100/SM103 | W4A8_MXFP4_MXFP8 via DeepGEMM `fp8_fp4_mega_moe` fused dispatch+GEMM+act+GEMM+combine kernel; requires `hidden_size % 512 == 0` | `FUSED_COMM` | | `fused_moe_triton.py` | `TritonFusedMoE` | SM90 only | GPT-OSS on Hopper (requires `swiglu_gptoss_style=True`) | (legacy path) | | `fused_moe_wide_ep.py` | `WideEPMoE` | All GPUs | Deprecating — use ConfigurableMoE instead | (legacy path) | diff --git a/tensorrt_llm/_torch/modules/fused_moe/__init__.py b/tensorrt_llm/_torch/modules/fused_moe/__init__.py index 105ea3ddb3dd..5cf48dc4fe2c 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/__init__.py +++ b/tensorrt_llm/_torch/modules/fused_moe/__init__.py @@ -1,7 +1,7 @@ from .create_moe import create_moe, get_moe_cls from .fused_moe_cute_dsl import CuteDslFusedMoE +from .fused_moe_cute_dsl_b12x import CuteDslB12xFusedMoE from .fused_moe_cutlass import CutlassFusedMoE -from .fused_moe_flashinfer_nvfp4_sm12x import FlashInferNvfp4Sm12xFusedMoE from .fused_moe_triton import TritonFusedMoE from .fused_moe_trtllm_gen import TRTLLMGenFusedMoE from .fused_moe_vanilla import VanillaMoE @@ -27,11 +27,11 @@ "BaseMoeRoutingMethod", "create_load_balanced_logits", "create_moe", + "CuteDslB12xFusedMoE", "CuteDslFusedMoE", "CutlassFusedMoE", "DeepSeekV3MoeRoutingMethod", "DefaultMoeRoutingMethod", - "FlashInferNvfp4Sm12xFusedMoE", "FusedMoEQuantScalesFP8", "get_moe_cls", "Llama4RenormalizeMoeRoutingMethod", diff --git a/tensorrt_llm/_torch/modules/fused_moe/create_moe.py b/tensorrt_llm/_torch/modules/fused_moe/create_moe.py index f9915b7e9106..69a1e23b8b35 
100644 --- a/tensorrt_llm/_torch/modules/fused_moe/create_moe.py +++ b/tensorrt_llm/_torch/modules/fused_moe/create_moe.py @@ -12,10 +12,10 @@ from ...utils import ActivationType, AuxStreamType from .configurable_moe import ConfigurableMoE from .fused_moe_cute_dsl import CuteDslFusedMoE +from .fused_moe_cute_dsl_b12x import CuteDslB12xFusedMoE from .fused_moe_cutlass import CutlassFusedMoE from .fused_moe_deepgemm import DeepGemmFusedMoE from .fused_moe_densegemm import DenseGEMMFusedMoE -from .fused_moe_flashinfer_nvfp4_sm12x import FlashInferNvfp4Sm12xFusedMoE from .fused_moe_triton import TritonFusedMoE from .fused_moe_trtllm_gen import TRTLLMGenFusedMoE from .fused_moe_vanilla import VanillaMoE @@ -34,24 +34,24 @@ def get_moe_cls( if override_quant_config is not None: quant_config = override_quant_config if moe_backend.upper() == "CUTLASS": - # Auto-promote to FlashInferNvfp4Sm12xFusedMoE (hybrid CUTLASS-prefill + # Auto-promote to CuteDslB12xFusedMoE (hybrid CUTLASS-prefill # / b12x-decode) on SM120 / SM121 + NVFP4 when flashinfer is available. # Falls back to plain CutlassFusedMoE otherwise. 
if quant_config is not None and quant_config.quant_mode.has_nvfp4(): from tensorrt_llm._utils import get_sm_version sm_version = get_sm_version() - if sm_version in FlashInferNvfp4Sm12xFusedMoE._SUPPORTED_SM_VERSIONS: + if sm_version in CuteDslB12xFusedMoE._SUPPORTED_SM_VERSIONS: try: import flashinfer # noqa: F401 logger.info( - "Auto-selecting FlashInferNvfp4Sm12xFusedMoE for hybrid " + "Auto-selecting CuteDslB12xFusedMoE for hybrid " "CUTLASS-prefill / b12x-decode (SM%d + NVFP4).", sm_version, ) - return FlashInferNvfp4Sm12xFusedMoE + return CuteDslB12xFusedMoE except ImportError: logger.warning( - "FlashInferNvfp4Sm12xFusedMoE eligible (SM%d + NVFP4) " + "CuteDslB12xFusedMoE eligible (SM%d + NVFP4) " "but flashinfer is not importable; using CutlassFusedMoE.", sm_version, ) @@ -302,9 +302,10 @@ def create_moe_backend( without_comm=without_comm, activation_type=activation_type, ) - elif moe_cls in (CutlassFusedMoE, FlashInferNvfp4Sm12xFusedMoE): - # CuteDslFusedMoE and DeepGemmFusedMoE also subclass CutlassFusedMoE but - # have narrower constructors, so they take their own branches below. + elif moe_cls is CutlassFusedMoE: + # CuteDslFusedMoE, DeepGemmFusedMoE, and CuteDslB12xFusedMoE + # also subclass CutlassFusedMoE but have narrower constructors, so + # they take their own branches below. return moe_cls( routing_method=routing_method, num_experts=num_experts, @@ -355,7 +356,9 @@ def create_moe_backend( layer_idx=layer_idx, activation_type=activation_type, ) - elif moe_cls == CuteDslFusedMoE: + elif moe_cls in (CuteDslFusedMoE, CuteDslB12xFusedMoE): + # CuteDslB12xFusedMoE subclasses CuteDslFusedMoE and shares + # its narrower constructor (no bias / swiglu_alpha-beta-limit args). 
return moe_cls( routing_method=routing_method, num_experts=num_experts, diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_flashinfer_nvfp4_sm12x.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl_b12x.py similarity index 52% rename from tensorrt_llm/_torch/modules/fused_moe/fused_moe_flashinfer_nvfp4_sm12x.py rename to tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl_b12x.py index 83b6fa62571a..101d3d913ede 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_flashinfer_nvfp4_sm12x.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl_b12x.py @@ -18,10 +18,10 @@ import torch from tensorrt_llm._utils import get_sm_version, nvtx_range -from tensorrt_llm.logger import logger from tensorrt_llm.models.modeling_utils import QuantAlgo from ...utils import ActivationType, Fp4QuantizedTensor +from .fused_moe_cute_dsl import CuteDslFusedMoE from .fused_moe_cutlass import CutlassFusedMoE from .interface import _warn_and_return @@ -43,22 +43,34 @@ } -class FlashInferNvfp4Sm12xFusedMoE(CutlassFusedMoE): +class CuteDslB12xFusedMoE(CuteDslFusedMoE): """Hybrid CUTLASS-prefill / b12x-decode NVFP4 fused-MoE backend for SM120 / SM121. + Member of the cuteDSL backend family: the decode kernel + (``flashinfer.B12xMoEWrapper.run``) is JIT-compiled CuTe DSL, so the + backend slots in next to :class:`CuteDslFusedMoE` (which targets SM100 / + SM103). The hybrid prefill path still routes through the C++ CUTLASS + NVFP4 GroupGEMM via explicit :class:`CutlassFusedMoE` method calls; the + parent class on the MRO does not change which kernels execute, only + where the b12x backend sits in the family. + Composition (see ``MOE_DEVELOPER_GUIDE.md`` for the full explainer): - - **Prefill (``m >= _PREFILL_VIA_CUTLASS_THRESHOLD``)** routes through the - inherited :class:`CutlassFusedMoE` NVFP4 GroupGEMM. The b12x kernel's + - **Prefill (``m >= _PREFILL_VIA_CUTLASS_THRESHOLD``)** explicitly + invokes :class:`CutlassFusedMoE` NVFP4 GroupGEMM. 
The b12x kernel's 12-CTA-per-token MMA pattern is suboptimal at large ``m``. - **Decode (``m < _PREFILL_VIA_CUTLASS_THRESHOLD``)** dispatches to FlashInfer's ``B12xMoEWrapper.run`` — a kernel purpose-built for ``m=1`` / small routed-row counts. - NVFP4 weights are loaded once via the inherited NVFP4 quant method; - ``post_load_weights`` then prepares the b12x-shaped weight tensors - alongside the existing CUTLASS layout. Both layouts coexist in memory - and the dispatcher picks per call based on ``x.shape[0]``. + NVFP4 weights are loaded via :class:`NVFP4CuteDslB12xFusedMoEMethod` + (an :class:`NVFP4CutlassFusedMoEMethod` subclass returned by + ``_get_quant_method``). The inherited CUTLASS NVFP4 layout is finalised + by the base class, and the b12x-shaped tensors (un-normalised FP8 SF, + ``convert_sf_to_mma_layout`` reshape, ``B12xMoEWrapper`` instance) are + materialised on top by the quant method's ``post_load_weights``. Both + layouts coexist in memory and the dispatcher picks per call based on + ``x.shape[0]``. CUDA graph capture only covers decode, so captured graphs always replay the b12x path; eager prefill always runs CUTLASS — there is no graph @@ -67,7 +79,8 @@ class FlashInferNvfp4Sm12xFusedMoE(CutlassFusedMoE): The backend hard-rejects EP (b12x has no dispatch / combine kernel), MoE alltoall, ``Fp4QuantizedTensor`` input, ``swiglu_gptoss_style`` biased SwiGLU, and activations outside ``{Relu2, Swiglu}``. It is - selected via ``moe_config.backend: FLASHINFER_NVFP4SM12X``. + auto-selected on the ``CUTLASS`` MoE path when SM120 / SM121 + NVFP4 + + flashinfer-importable gates pass (see ``create_moe.get_moe_cls``). """ # SM versions on which the FlashInfer b12x NVFP4 MoE kernel is available. 
@@ -91,23 +104,19 @@ def can_implement( sm_version = get_sm_version() if sm_version not in cls._SUPPORTED_SM_VERSIONS: sm_list = "/".join(f"SM{v}" for v in sorted(cls._SUPPORTED_SM_VERSIONS)) - return _warn_and_return( - f"FlashInferNvfp4Sm12xFusedMoE requires {sm_list}, got SM{sm_version}" - ) + return _warn_and_return(f"CuteDslB12xFusedMoE requires {sm_list}, got SM{sm_version}") if quant_algo != QuantAlgo.NVFP4: return _warn_and_return( - f"FlashInferNvfp4Sm12xFusedMoE only supports NVFP4 quantization " + f"CuteDslB12xFusedMoE only supports NVFP4 quantization " f"(got quant_algo={quant_algo})" ) if dtype_activation not in {torch.float16, torch.bfloat16}: return _warn_and_return( - f"FlashInferNvfp4Sm12xFusedMoE NVFP4 requires float16 or bfloat16 " + f"CuteDslB12xFusedMoE NVFP4 requires float16 or bfloat16 " f"activation dtype (got {dtype_activation})" ) if swiglu_gptoss_style: - return _warn_and_return( - "FlashInferNvfp4Sm12xFusedMoE does not support swiglu_gptoss_style" - ) + return _warn_and_return("CuteDslB12xFusedMoE does not support swiglu_gptoss_style") return True, None def __init__(self, *args, **kwargs): @@ -124,17 +133,15 @@ def __init__(self, *args, **kwargs): # disabled. dp_size > 1 implies the alltoall path which b12x can't run. if self.ep_size != 1: raise ValueError( - f"FlashInferNvfp4Sm12xFusedMoE requires ep_size == 1 " + f"CuteDslB12xFusedMoE requires ep_size == 1 " f"(got ep_size={self.ep_size}); use --moe_backend CUTLASS for EP." ) if self.enable_alltoall: - raise ValueError( - "FlashInferNvfp4Sm12xFusedMoE does not support MoE alltoall communication." 
- ) + raise ValueError("CuteDslB12xFusedMoE does not support MoE alltoall communication.") if self.activation_type not in _ACTIVATION_MAP: supported = ", ".join(a.name for a in _ACTIVATION_MAP) raise ValueError( - f"FlashInferNvfp4Sm12xFusedMoE does not support activation " + f"CuteDslB12xFusedMoE does not support activation " f"{ActivationType(self.activation_type).name}; " f"supported: {supported}." ) @@ -142,6 +149,22 @@ def __init__(self, *args, **kwargs): self._b12x_weights: Optional[dict] = None self.b12x_wrapper = None + def _get_quant_method(self): + # Route NVFP4 to the b12x-aware quant method so weight prep + # (SF un-normalization, ``convert_sf_to_mma_layout``, + # ``B12xMoEWrapper`` instantiation) lives next to the rest of the + # NVFP4 quant-method family, while every other quant algo (and the + # unquantized fallback) continues to resolve via the parent. + if ( + self.quant_config is not None + and self.quant_config.layer_quant_mode.has_any_quant(exclude_kv_cache=True) + and self.quant_config.layer_quant_mode.has_nvfp4() + ): + from .quantization import NVFP4CuteDslB12xFusedMoEMethod + + return NVFP4CuteDslB12xFusedMoEMethod() + return super()._get_quant_method() + def _route_to_cutlass(self, x) -> bool: """Return ``True`` iff this call should fall back to the inherited CUTLASS path (prefill chunk). ``Fp4QuantizedTensor`` inputs always @@ -149,127 +172,14 @@ def _route_to_cutlass(self, x) -> bool: message is preserved.""" return isinstance(x, torch.Tensor) and x.shape[0] >= self._PREFILL_VIA_CUTLASS_THRESHOLD - def post_load_weights(self): - """Build the b12x weight dict and instantiate ``B12xMoEWrapper``. - - Called by ``model_loader`` after ``load_weights`` finishes. 
The NVFP4 - quant method's ``process_weights_after_loading`` has already run as - part of ``load_weights``, so the inherited ``w3_w1_weight`` / - ``w2_weight`` / ``*_weight_scale`` / ``*_alpha`` / ``*_input_scale`` - tensors are populated; we just convert them to the layout b12x - expects. - """ - super().post_load_weights() - - try: - from flashinfer import B12xMoEWrapper - from flashinfer.cute_dsl.utils import convert_sf_to_mma_layout - except ImportError as e: - raise RuntimeError( - "FlashInferNvfp4Sm12xFusedMoE requires the `flashinfer` package " - "(B12xMoEWrapper, cute_dsl.utils.convert_sf_to_mma_layout). " - f"Original import error: {e}" - ) from e - - num_local_experts = self.w3_w1_weight.shape[0] - # Tensor shapes use the *padded* per-rank dims because TP partitions - # may pad ``intermediate_size`` up to a kernel-friendly boundary. - # Recover them from the actual stored tensors rather than the logical - # model config so reshapes stay valid under TP > 1. - _, w3w1_out_dim, _ = self.w3_w1_weight.shape # (E, 2*I_pad, H//16) - _, w2_out_dim, w2_in_packed = self.w2_weight.shape # (E, H, I_pad//16) - w3w1_in_dim = self.hidden_size - w2_in_dim = w2_in_packed * 16 - - # b12x reuses the per-expert ``w1_alpha`` tensor as both (a) the - # online activation-quant ``global_scale`` and (b) the FC1 epilogue - # output-dequant multiplier. That dual use is only self-consistent - # when the FP4 weight block scales are stored in their *unnormalized* - # form (raw ``max_block / FP4_MAX``), not divided out by the - # per-tensor ``weight_scale_2``. HF / ModelOpt NVFP4 checkpoints - # store the normalized variant so the FP8 block scales fit in range, - # and TRT-LLM's NVFP4 loader preserves that form. To match b12x's - # convention we recover ``weight_scale_2 = fc_alpha * fc_input_scale`` - # and multiply each expert's FP8 block scales by it before handing - # them to ``convert_sf_to_mma_layout``. 
With the un-normalized scales - # in place we pass ``w1_alpha = w2_alpha = 1 / fc_input_scale`` - # (== ``s_in``) so the kernel's dual-use cancels algebraically and - # the stored input-side block scales remain FP8-representable. - w1_w_scale_2 = (self.fc31_alpha * self.fc31_input_scale).to(torch.float32) - w2_w_scale_2 = (self.fc2_alpha * self.fc2_input_scale).to(torch.float32) - - w1_sf_fp8_norm = self.w3_w1_weight_scale.view(torch.float8_e4m3fn).float() - w2_sf_fp8_norm = self.w2_weight_scale.view(torch.float8_e4m3fn).float() - - # Broadcast per-expert scalar over the trailing dims (E, *). - bcast1 = w1_w_scale_2.view(-1, *([1] * (w1_sf_fp8_norm.dim() - 1))) - bcast2 = w2_w_scale_2.view(-1, *([1] * (w2_sf_fp8_norm.dim() - 1))) - w1_sf_fp8 = (w1_sf_fp8_norm * bcast1).to(torch.float8_e4m3fn) - w2_sf_fp8 = (w2_sf_fp8_norm * bcast2).to(torch.float8_e4m3fn) - - w1_sf_b12x = convert_sf_to_mma_layout( - w1_sf_fp8, m=w3w1_out_dim, k=w3w1_in_dim, num_groups=num_local_experts - ) - w2_sf_b12x = convert_sf_to_mma_layout( - w2_sf_fp8, m=w2_out_dim, k=w2_in_dim, num_groups=num_local_experts - ) - - w1_alpha_b12x = ( - (1.0 / self.fc31_input_scale).expand(self.num_experts).to(torch.float32).contiguous() - ) - w2_alpha_b12x = ( - (1.0 / self.fc2_input_scale).expand(self.num_experts).to(torch.float32).contiguous() - ) - fc2_input_scale_b12x = (1.0 / self.fc2_input_scale).to(torch.float32) - - # TRT-LLM packs 16 FP4 values per int64. flashinfer's internal - # ``view(torch.float4_e2m1fn_x2)`` requires byte-contiguous storage - # (stride[-1] == 1 in bytes); a uint8 view of the int64 tensor - # provides that without copying. 
- self._b12x_weights = dict( - w1_weight=self.w3_w1_weight.view(torch.uint8), - w1_weight_sf=w1_sf_b12x, - w1_alpha=w1_alpha_b12x, - w2_weight=self.w2_weight.view(torch.uint8), - w2_weight_sf=w2_sf_b12x, - w2_alpha=w2_alpha_b12x, - fc2_input_scale=fc2_input_scale_b12x, - ) - - self.b12x_wrapper = B12xMoEWrapper( - num_experts=self.num_experts, - top_k=self.routing_method.experts_per_token, - hidden_size=self.hidden_size, - intermediate_size=self.intermediate_size_per_partition, - use_cuda_graph=self._b12x_use_cuda_graph, - max_num_tokens=self.moe_max_num_tokens, - activation=_ACTIVATION_MAP[self.activation_type], - ) - - # Replace the wrapper's per-instance output buffer with a shared one. - # Layers run sequentially on a single stream, so a single buffer of the - # right shape is correct and saves - # ``(num_moe_layers - 1) * max_num_tokens * hidden_size * 2`` bytes — - # ~2.5 GB on Nemotron-Super-120B with ``max_num_tokens=2048``, - # ``hidden=8192``, bf16, 80 MoE layers. - if self.b12x_wrapper._moe_output is not None: - buf = self.b12x_wrapper._moe_output - key = (buf.shape[0], buf.shape[1], buf.dtype, str(buf.device)) - shared = _SHARED_MOE_OUTPUT_BUF.get(key) - if shared is None: - _SHARED_MOE_OUTPUT_BUF[key] = buf - else: - # Free the freshly allocated buffer; reuse the existing one. 
- self.b12x_wrapper._moe_output = shared - - logger.info_once( - f"FlashInferNvfp4Sm12xFusedMoE active: hidden={self.hidden_size}, " - f"intermediate={self.intermediate_size_per_partition}, " - f"experts={self.num_experts}, top_k=" - f"{self.routing_method.experts_per_token}, " - f"activation={_ACTIVATION_MAP[self.activation_type]}.", - key="flashinfer_nvfp4_sm12x_moe_active", - ) + # ``post_load_weights`` is inherited from ``CutlassFusedMoE`` and + # dispatches to ``self.quant_method.post_load_weights(self)`` — for this + # backend ``self.quant_method`` is ``NVFP4CuteDslB12xFusedMoEMethod`` + # (see ``_get_quant_method`` override), which performs the SF un-normalization, + # ``convert_sf_to_mma_layout`` reshape, ``B12xMoEWrapper`` instantiation, + # and the cross-layer shared output buffer dance. The wrapper and the + # bundled weight dict are attached to this module as ``self.b12x_wrapper`` + # / ``self._b12x_weights``, which the decode path below consumes. @nvtx_range("[b12x] quantize_input") def quantize_input( @@ -293,7 +203,7 @@ def quantize_input( ) if isinstance(x, Fp4QuantizedTensor): raise ValueError( - "FlashInferNvfp4Sm12xFusedMoE does not accept Fp4QuantizedTensor input " + "CuteDslB12xFusedMoE does not accept Fp4QuantizedTensor input " "on the b12x decode path; b12x performs its own input quantization." ) return x, None @@ -313,6 +223,23 @@ def run_moe( enable_alltoall: Optional[bool] = None, ) -> torch.Tensor: if self._route_to_cutlass(x): + # ``CutlassFusedMoE.run_moe`` forwards ``output_dtype`` straight + # into the C++ ``trtllm::fused_moe`` op, which requires a concrete + # high-precision ``ScalarType`` (uint8 / FP4-packed activations are + # rejected at the kernel epilogue with "Invalid output type Byte"). + # Schedulers that drive ``run_moe`` directly (the KV-cache capacity + # probe, for one) leave ``output_dtype`` unset, so fall back to + # ``x.dtype`` if it is a real compute dtype, else bf16. 
Mirrors the + # ``forward_chunk`` convention while staying safe for the FP4 + # quant-input path (``x`` is uint8 after ``quantize_input``). + _HIGH_PRECISION = {torch.float16, torch.bfloat16, torch.float32} + cutlass_output_dtype = output_dtype + if cutlass_output_dtype is None: + cutlass_output_dtype = ( + x.dtype + if isinstance(x, torch.Tensor) and x.dtype in _HIGH_PRECISION + else torch.bfloat16 + ) return CutlassFusedMoE.run_moe( self, x, @@ -320,7 +247,7 @@ def run_moe( token_final_scales=token_final_scales, x_sf=x_sf, is_sf_swizzled=is_sf_swizzled, - output_dtype=output_dtype, + output_dtype=cutlass_output_dtype, tuner_num_tokens=tuner_num_tokens, tuner_top_k=tuner_top_k, moe_output=moe_output, @@ -328,12 +255,11 @@ def run_moe( ) if self.b12x_wrapper is None or self._b12x_weights is None: raise RuntimeError( - "FlashInferNvfp4Sm12xFusedMoE.run_moe called before " - "process_weights_after_loading completed." + "CuteDslB12xFusedMoE.run_moe called before process_weights_after_loading completed." ) if x_sf is not None: raise ValueError( - "FlashInferNvfp4Sm12xFusedMoE expects unquantized input (x_sf=None) " + "CuteDslB12xFusedMoE expects unquantized input (x_sf=None) " "on the b12x decode path; got a precomputed scale factor." ) @@ -349,7 +275,7 @@ def run_moe( # B12xMoEWrapper allocates its own output buffer for CUDA-graph # compatibility. If the caller provided ``moe_output`` (e.g. an - # alltoall workspace tensor), copy into it; FlashInferNvfp4Sm12xFusedMoE + # alltoall workspace tensor), copy into it; CuteDslB12xFusedMoE # currently rejects alltoall in __init__, so this is a defensive # path for future workspace-driven uses. 
if moe_output is not None: diff --git a/tensorrt_llm/_torch/modules/fused_moe/quantization.py b/tensorrt_llm/_torch/modules/fused_moe/quantization.py index da5109091501..e6769b241646 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/quantization.py +++ b/tensorrt_llm/_torch/modules/fused_moe/quantization.py @@ -2974,6 +2974,164 @@ def process_weights_after_loading(self, module: torch.nn.Module): module, module.w3_w1_weight_scale.data[expert_idx]) +class NVFP4CuteDslB12xFusedMoEMethod(NVFP4CutlassFusedMoEMethod): + """NVFP4 quant method for the FlashInfer B12x MoE backend (SM120 / SM121). + + Inherits the full CUTLASS NVFP4 weight pipeline (cat + pad + + block_scale_interleave + setup_quant_scales) so the backend's + hybrid prefill path can continue to consume the standard CUTLASS + NVFP4 GroupGEMM layout via the inherited ``CutlassFusedMoE.run_moe``. + + On top of that base layout, ``post_load_weights`` materialises the + b12x-specific weight tensors: SF un-normalization (multiply per-block + FP8 scales by ``weight_scale_2 = fc_alpha * fc_input_scale``), + ``convert_sf_to_mma_layout`` reshape, per-expert ``w*_alpha = 1 / + fc_input_scale`` vectors, and a ``B12xMoEWrapper`` instance with a + shared cross-layer output buffer. The wrapper and the bundled + weight dict are attached to the MoE module as ``module.b12x_wrapper`` + / ``module._b12x_weights`` so the backend's ``run_moe`` can dispatch + to them on decode-shape inputs without holding any kernel-prep logic + of its own. + """ + + # b12x exposes two activation strings today: ``relu2`` (Nemotron-style + # ``x * relu(x)``) and ``silu`` (SwiGLU-style ``x * silu(gate)``). + _ACTIVATION_MAP = { + ActivationType.Relu2: "relu2", + ActivationType.Swiglu: "silu", + } + + def post_load_weights(self, module: torch.nn.Module): + # Base class handles shared-weight finalize, load-balancer init, + # and setup_quant_scales. Leaves the standard CUTLASS NVFP4 + # weight + SF layout in place for the inherited prefill path. 
+ super().post_load_weights(module) + + try: + from flashinfer import B12xMoEWrapper + from flashinfer.cute_dsl.utils import convert_sf_to_mma_layout + except ImportError as e: + raise RuntimeError( + "NVFP4CuteDslB12xFusedMoEMethod requires the `flashinfer` package " + "(B12xMoEWrapper, cute_dsl.utils.convert_sf_to_mma_layout). " + f"Original import error: {e}") from e + + # ``_SHARED_MOE_OUTPUT_BUF`` lives next to the backend so all MoE + # layers across the model share one wrapper-owned output buffer. + from .fused_moe_cute_dsl_b12x import _SHARED_MOE_OUTPUT_BUF + + num_local_experts = module.w3_w1_weight.shape[0] + # Tensor shapes use the *padded* per-rank dims because TP partitions + # may pad ``intermediate_size`` up to a kernel-friendly boundary. + # Recover them from the actual stored tensors rather than the logical + # model config so reshapes stay valid under TP > 1. + _, w3w1_out_dim, _ = module.w3_w1_weight.shape # (E, 2*I_pad, H//16) + _, w2_out_dim, w2_in_packed = module.w2_weight.shape # (E, H, I_pad//16) + w3w1_in_dim = module.hidden_size + w2_in_dim = w2_in_packed * 16 + + # b12x reuses the per-expert ``w1_alpha`` tensor as both (a) the + # online activation-quant ``global_scale`` and (b) the FC1 epilogue + # output-dequant multiplier. That dual use is only self-consistent + # when the FP4 weight block scales are stored in their *unnormalized* + # form (raw ``max_block / FP4_MAX``), not divided out by the + # per-tensor ``weight_scale_2``. HF / ModelOpt NVFP4 checkpoints + # store the normalized variant so the FP8 block scales fit in range, + # and TRT-LLM's NVFP4 loader preserves that form. To match b12x's + # convention we recover ``weight_scale_2 = fc_alpha * fc_input_scale`` + # and multiply each expert's FP8 block scales by it before handing + # them to ``convert_sf_to_mma_layout``. 
With the un-normalized scales + # in place we pass ``w1_alpha = w2_alpha = 1 / fc_input_scale`` + # (== ``s_in``) so the kernel's dual-use cancels algebraically and + # the stored input-side block scales remain FP8-representable. + w1_w_scale_2 = (module.fc31_alpha * module.fc31_input_scale).to( + torch.float32) + w2_w_scale_2 = (module.fc2_alpha * module.fc2_input_scale).to( + torch.float32) + + w1_sf_fp8_norm = module.w3_w1_weight_scale.view( + torch.float8_e4m3fn).float() + w2_sf_fp8_norm = module.w2_weight_scale.view( + torch.float8_e4m3fn).float() + + # Broadcast per-expert scalar over the trailing dims (E, *). + bcast1 = w1_w_scale_2.view(-1, *([1] * (w1_sf_fp8_norm.dim() - 1))) + bcast2 = w2_w_scale_2.view(-1, *([1] * (w2_sf_fp8_norm.dim() - 1))) + w1_sf_fp8 = (w1_sf_fp8_norm * bcast1).to(torch.float8_e4m3fn) + w2_sf_fp8 = (w2_sf_fp8_norm * bcast2).to(torch.float8_e4m3fn) + + w1_sf_b12x = convert_sf_to_mma_layout(w1_sf_fp8, + m=w3w1_out_dim, + k=w3w1_in_dim, + num_groups=num_local_experts) + w2_sf_b12x = convert_sf_to_mma_layout(w2_sf_fp8, + m=w2_out_dim, + k=w2_in_dim, + num_groups=num_local_experts) + + w1_alpha_b12x = ((1.0 / module.fc31_input_scale).expand( + module.num_experts).to(torch.float32).contiguous()) + w2_alpha_b12x = ((1.0 / module.fc2_input_scale).expand( + module.num_experts).to(torch.float32).contiguous()) + fc2_input_scale_b12x = (1.0 / module.fc2_input_scale).to(torch.float32) + + # TRT-LLM packs 16 FP4 values per int64. flashinfer's internal + # ``view(torch.float4_e2m1fn_x2)`` requires byte-contiguous storage + # (stride[-1] == 1 in bytes); a uint8 view of the int64 tensor + # provides that without copying. 
+ module._b12x_weights = dict( + w1_weight=module.w3_w1_weight.view(torch.uint8), + w1_weight_sf=w1_sf_b12x, + w1_alpha=w1_alpha_b12x, + w2_weight=module.w2_weight.view(torch.uint8), + w2_weight_sf=w2_sf_b12x, + w2_alpha=w2_alpha_b12x, + fc2_input_scale=fc2_input_scale_b12x, + ) + + if module.activation_type not in self._ACTIVATION_MAP: + supported = ", ".join(a.name for a in self._ACTIVATION_MAP) + raise ValueError( + f"NVFP4CuteDslB12xFusedMoEMethod does not support activation " + f"{ActivationType(module.activation_type).name}; " + f"supported: {supported}.") + + module.b12x_wrapper = B12xMoEWrapper( + num_experts=module.num_experts, + top_k=module.routing_method.experts_per_token, + hidden_size=module.hidden_size, + intermediate_size=module.intermediate_size_per_partition, + use_cuda_graph=getattr(module, "_b12x_use_cuda_graph", False), + max_num_tokens=module.moe_max_num_tokens, + activation=self._ACTIVATION_MAP[module.activation_type], + ) + + # Replace the wrapper's per-instance output buffer with a shared one. + # Layers run sequentially on a single stream, so a single buffer of the + # right shape is correct and saves + # ``(num_moe_layers - 1) * max_num_tokens * hidden_size * 2`` bytes — + # ~2.5 GB on Nemotron-Super-120B with ``max_num_tokens=2048``, + # ``hidden=8192``, bf16, 80 MoE layers. + if module.b12x_wrapper._moe_output is not None: + buf = module.b12x_wrapper._moe_output + key = (buf.shape[0], buf.shape[1], buf.dtype, str(buf.device)) + shared = _SHARED_MOE_OUTPUT_BUF.get(key) + if shared is None: + _SHARED_MOE_OUTPUT_BUF[key] = buf + else: + # Free the freshly allocated buffer; reuse the existing one. 
+ module.b12x_wrapper._moe_output = shared + + logger.info_once( + f"NVFP4CuteDslB12xFusedMoEMethod active: hidden={module.hidden_size}, " + f"intermediate={module.intermediate_size_per_partition}, " + f"experts={module.num_experts}, top_k=" + f"{module.routing_method.experts_per_token}, " + f"activation={self._ACTIVATION_MAP[module.activation_type]}.", + key="cute_dsl_b12x_moe_active", + ) + + class NVFP4TRTLLMGenFusedMoEBaseMethod(NVFP4FusedMoEMethod): weight_dtype = float4_sf_dtype block_scales_dtype = torch.float8_e4m3fn diff --git a/tests/integration/test_lists/test-db/l0_rtx_pro_6000.yml b/tests/integration/test_lists/test-db/l0_rtx_pro_6000.yml index d4ec70fdac04..376139457596 100644 --- a/tests/integration/test_lists/test-db/l0_rtx_pro_6000.yml +++ b/tests/integration/test_lists/test-db/l0_rtx_pro_6000.yml @@ -15,7 +15,7 @@ l0_rtx_pro_6000: tests: # ------------- PyTorch tests --------------- - unittest/_torch/modeling -k "modeling_out_of_tree" - - unittest/_torch/modules/moe/test_flashinfer_nvfp4_sm12x_moe_backend.py + - unittest/_torch/modules/moe/test_cute_dsl_b12x_moe_backend.py # - unittest/_torch/modeling -k "modeling_qwen" # https://nvbugs/5234573 - unittest/_torch/attention/test_attention_mla.py # SM120 W4A16 / W4A8 mixed-dtype GEMM coverage (paired with FinegrainedMixedDtypeGemm diff --git a/tests/unittest/_torch/modules/moe/moe_test_utils.py b/tests/unittest/_torch/modules/moe/moe_test_utils.py index aa7e5282069d..abbaaab80236 100644 --- a/tests/unittest/_torch/modules/moe/moe_test_utils.py +++ b/tests/unittest/_torch/modules/moe/moe_test_utils.py @@ -44,11 +44,9 @@ CutlassFusedMoE, TRTLLMGenFusedMoE, ) +from tensorrt_llm._torch.modules.fused_moe.fused_moe_cute_dsl_b12x import CuteDslB12xFusedMoE from tensorrt_llm._torch.modules.fused_moe.fused_moe_deepgemm import DeepGemmFusedMoE from tensorrt_llm._torch.modules.fused_moe.fused_moe_densegemm import DenseGEMMFusedMoE -from tensorrt_llm._torch.modules.fused_moe.fused_moe_flashinfer_nvfp4_sm12x 
import ( - FlashInferNvfp4Sm12xFusedMoE, -) from tensorrt_llm._torch.modules.fused_moe.interface import MoE from tensorrt_llm._torch.modules.fused_moe.mega_moe import MegaMoEDeepGemm from tensorrt_llm._torch.utils import ActivationType, is_gated_activation @@ -69,7 +67,7 @@ class MoeBackendType(str, Enum): DEEPGEMM = "DEEPGEMM" DENSEGEMM = "DENSEGEMM" MEGAMOE = "MEGAMOE_DEEPGEMM" - FLASHINFER_NVFP4SM12X = "FLASHINFER_NVFP4SM12X" + CUTE_DSL_B12X = "CUTE_DSL_B12X" def get_backend_class(backend_type: MoeBackendType) -> Type[MoE]: @@ -81,7 +79,7 @@ def get_backend_class(backend_type: MoeBackendType) -> Type[MoE]: MoeBackendType.DEEPGEMM: DeepGemmFusedMoE, MoeBackendType.DENSEGEMM: DenseGEMMFusedMoE, MoeBackendType.MEGAMOE: MegaMoEDeepGemm, - MoeBackendType.FLASHINFER_NVFP4SM12X: FlashInferNvfp4Sm12xFusedMoE, + MoeBackendType.CUTE_DSL_B12X: CuteDslB12xFusedMoE, } return backend_class_map[backend_type] @@ -869,29 +867,29 @@ def should_skip_megamoe( return None -def should_skip_flashinfer_nvfp4_sm12x( +def should_skip_cute_dsl_b12x( backend_type: MoeBackendType, comm_method: Optional[str] = None, moe_tp_size: int = 1, parallel_mode: Optional[str] = None, ) -> Optional[str]: - """Check FlashInferNvfp4Sm12xFusedMoE constraints not covered by can_implement(). + """Check CuteDslB12xFusedMoE constraints not covered by can_implement(). can_implement() already gates SM version, quant_algo, dtype_activation, and swiglu_gptoss_style. This helper covers the additional EP / alltoall hard rejects enforced in __init__ (b12x has no expert-parallel dispatch/combine kernel). """ - if backend_type != MoeBackendType.FLASHINFER_NVFP4SM12X: + if backend_type != MoeBackendType.CUTE_DSL_B12X: return None if comm_method is not None or parallel_mode is not None: return ( - "FlashInferNvfp4Sm12xFusedMoE rejects expert parallelism / alltoall; " + "CuteDslB12xFusedMoE rejects expert parallelism / alltoall; " f"got comm_method={comm_method}, parallel_mode={parallel_mode}." 
) if moe_tp_size != 1: - return f"FlashInferNvfp4Sm12xFusedMoE requires ep_size=1; got moe_tp_size={moe_tp_size}." + return f"CuteDslB12xFusedMoE requires ep_size=1; got moe_tp_size={moe_tp_size}." return None @@ -1001,11 +999,11 @@ def supports_autotuner_capture( Returns: True if autotuner capture/replay is supported, False otherwise """ - # DEEPGEMM, MEGAMOE, and FLASHINFER_NVFP4SM12X do not support autotuner capture + # DEEPGEMM, MEGAMOE, and CUTE_DSL_B12X do not support autotuner capture if backend_type in ( MoeBackendType.DEEPGEMM, MoeBackendType.MEGAMOE, - MoeBackendType.FLASHINFER_NVFP4SM12X, + MoeBackendType.CUTE_DSL_B12X, ): return False @@ -1085,7 +1083,7 @@ def get_quick_skip_reason( model_config=model_config, swiglu_gptoss_style=swiglu_gptoss_style, ), - lambda: should_skip_flashinfer_nvfp4_sm12x(backend_type), + lambda: should_skip_cute_dsl_b12x(backend_type), ] for check in skip_checks: skip_reason = check() diff --git a/tests/unittest/_torch/modules/moe/test_flashinfer_nvfp4_sm12x_moe_backend.py b/tests/unittest/_torch/modules/moe/test_cute_dsl_b12x_moe_backend.py similarity index 85% rename from tests/unittest/_torch/modules/moe/test_flashinfer_nvfp4_sm12x_moe_backend.py rename to tests/unittest/_torch/modules/moe/test_cute_dsl_b12x_moe_backend.py index 6f2a37bd5ff0..776d4fa1276a 100644 --- a/tests/unittest/_torch/modules/moe/test_flashinfer_nvfp4_sm12x_moe_backend.py +++ b/tests/unittest/_torch/modules/moe/test_cute_dsl_b12x_moe_backend.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Negative-path + dispatch tests for FlashInferNvfp4Sm12xFusedMoE. +"""Negative-path + dispatch tests for CuteDslB12xFusedMoE. 
These checks run without a GPU: they verify the can_implement() gating matrix, the heuristic auto-promotion in create_moe.get_moe_cls (the @@ -29,28 +29,26 @@ from tensorrt_llm._torch.model_config import ModelConfig from tensorrt_llm._torch.modules.fused_moe.create_moe import get_moe_cls +from tensorrt_llm._torch.modules.fused_moe.fused_moe_cute_dsl_b12x import CuteDslB12xFusedMoE from tensorrt_llm._torch.modules.fused_moe.fused_moe_cutlass import CutlassFusedMoE -from tensorrt_llm._torch.modules.fused_moe.fused_moe_flashinfer_nvfp4_sm12x import ( - FlashInferNvfp4Sm12xFusedMoE, -) from tensorrt_llm.models.modeling_utils import QuantAlgo, QuantConfig -_FUSED_MOE_MODULE = "tensorrt_llm._torch.modules.fused_moe.fused_moe_flashinfer_nvfp4_sm12x" +_FUSED_MOE_MODULE = "tensorrt_llm._torch.modules.fused_moe.fused_moe_cute_dsl_b12x" @pytest.mark.parametrize("sm_version", [80, 89, 90, 100, 103]) def test_can_implement_rejects_unsupported_sm(sm_version): """can_implement returns False on every SM outside the supported set.""" with patch(f"{_FUSED_MOE_MODULE}.get_sm_version", return_value=sm_version): - ok, reason = FlashInferNvfp4Sm12xFusedMoE.can_implement(QuantAlgo.NVFP4) + ok, reason = CuteDslB12xFusedMoE.can_implement(QuantAlgo.NVFP4) assert not ok assert reason is not None and f"SM{sm_version}" in reason -@pytest.mark.parametrize("sm_version", sorted(FlashInferNvfp4Sm12xFusedMoE._SUPPORTED_SM_VERSIONS)) +@pytest.mark.parametrize("sm_version", sorted(CuteDslB12xFusedMoE._SUPPORTED_SM_VERSIONS)) def test_can_implement_accepts_supported_sm_with_nvfp4(sm_version): with patch(f"{_FUSED_MOE_MODULE}.get_sm_version", return_value=sm_version): - ok, reason = FlashInferNvfp4Sm12xFusedMoE.can_implement(QuantAlgo.NVFP4) + ok, reason = CuteDslB12xFusedMoE.can_implement(QuantAlgo.NVFP4) assert ok assert reason is None @@ -68,16 +66,14 @@ def test_can_implement_accepts_supported_sm_with_nvfp4(sm_version): def test_can_implement_rejects_non_nvfp4(quant_algo): """Only NVFP4 is 
supported; everything else must be turned away.""" with patch(f"{_FUSED_MOE_MODULE}.get_sm_version", return_value=120): - ok, reason = FlashInferNvfp4Sm12xFusedMoE.can_implement(quant_algo) + ok, reason = CuteDslB12xFusedMoE.can_implement(quant_algo) assert not ok assert reason is not None and "NVFP4" in reason def test_can_implement_rejects_swiglu_gptoss_style(): with patch(f"{_FUSED_MOE_MODULE}.get_sm_version", return_value=120): - ok, reason = FlashInferNvfp4Sm12xFusedMoE.can_implement( - QuantAlgo.NVFP4, swiglu_gptoss_style=True - ) + ok, reason = CuteDslB12xFusedMoE.can_implement(QuantAlgo.NVFP4, swiglu_gptoss_style=True) assert not ok assert reason is not None and "swiglu_gptoss_style" in reason @@ -85,9 +81,7 @@ def test_can_implement_rejects_swiglu_gptoss_style(): @pytest.mark.parametrize("dtype", [torch.float32, torch.float8_e4m3fn]) def test_can_implement_rejects_unsupported_activation_dtype(dtype): with patch(f"{_FUSED_MOE_MODULE}.get_sm_version", return_value=120): - ok, reason = FlashInferNvfp4Sm12xFusedMoE.can_implement( - QuantAlgo.NVFP4, dtype_activation=dtype - ) + ok, reason = CuteDslB12xFusedMoE.can_implement(QuantAlgo.NVFP4, dtype_activation=dtype) assert not ok assert reason is not None @@ -121,7 +115,7 @@ def test_get_moe_cls_falls_back_to_cutlass_on_unsupported_sm(): assert cls is CutlassFusedMoE -@pytest.mark.parametrize("sm_version", sorted(FlashInferNvfp4Sm12xFusedMoE._SUPPORTED_SM_VERSIONS)) +@pytest.mark.parametrize("sm_version", sorted(CuteDslB12xFusedMoE._SUPPORTED_SM_VERSIONS)) def test_get_moe_cls_auto_promotes_on_supported_sm(sm_version): """CUTLASS + NVFP4 + SM120/121 + flashinfer importable → hybrid backend.""" cfg = ModelConfig() @@ -129,7 +123,7 @@ def test_get_moe_cls_auto_promotes_on_supported_sm(sm_version): cfg.quant_config = QuantConfig(quant_algo=QuantAlgo.NVFP4) with patch("tensorrt_llm._utils.get_sm_version", return_value=sm_version): cls = get_moe_cls(cfg) - assert cls is FlashInferNvfp4Sm12xFusedMoE + assert cls is 
CuteDslB12xFusedMoE def test_get_moe_cls_falls_back_when_flashinfer_missing(monkeypatch): @@ -168,9 +162,9 @@ class _RoutePredicateStub: the unbound ``_route_to_cutlass`` without instantiating the whole MoE backend.""" - _PREFILL_VIA_CUTLASS_THRESHOLD = FlashInferNvfp4Sm12xFusedMoE._PREFILL_VIA_CUTLASS_THRESHOLD + _PREFILL_VIA_CUTLASS_THRESHOLD = CuteDslB12xFusedMoE._PREFILL_VIA_CUTLASS_THRESHOLD - _route_to_cutlass = FlashInferNvfp4Sm12xFusedMoE._route_to_cutlass + _route_to_cutlass = CuteDslB12xFusedMoE._route_to_cutlass def test_dispatch_routes_prefill_shape_via_cutlass(): diff --git a/tests/unittest/_torch/modules/moe/test_moe_backend.py b/tests/unittest/_torch/modules/moe/test_moe_backend.py index 0b5e4750e07a..7f47bfdd7bbe 100644 --- a/tests/unittest/_torch/modules/moe/test_moe_backend.py +++ b/tests/unittest/_torch/modules/moe/test_moe_backend.py @@ -139,12 +139,12 @@ def create_test_backend( pretrained_config.intermediate_size = intermediate_size pretrained_config.torch_dtype = dtype - # FLASHINFER_NVFP4SM12X is internal-only: the user-facing API selects it + # CUTE_DSL_B12X is internal-only: the user-facing API selects it # transparently via the CUTLASS heuristic auto-promotion on SM120/121 + # NVFP4. Route through "CUTLASS" so the test exercises the same code path # users hit. 
moe_backend_value = ( - "CUTLASS" if backend_type == MoeBackendType.FLASHINFER_NVFP4SM12X else backend_type.value + "CUTLASS" if backend_type == MoeBackendType.CUTE_DSL_B12X else backend_type.value ) model_config = ModelConfig( pretrained_config=pretrained_config, @@ -301,7 +301,7 @@ def run_backend_moe( MoeBackendType.DEEPGEMM, MoeBackendType.DENSEGEMM, MoeBackendType.MEGAMOE, - MoeBackendType.FLASHINFER_NVFP4SM12X, + MoeBackendType.CUTE_DSL_B12X, ] # Data types to test diff --git a/tests/unittest/_torch/modules/moe/test_moe_module.py b/tests/unittest/_torch/modules/moe/test_moe_module.py index 320656727fad..019cc6756774 100644 --- a/tests/unittest/_torch/modules/moe/test_moe_module.py +++ b/tests/unittest/_torch/modules/moe/test_moe_module.py @@ -249,11 +249,11 @@ def _create_model_config( else None ) - # FLASHINFER_NVFP4SM12X is an internal-only MoeBackendType — it has no + # CUTE_DSL_B12X is an internal-only MoeBackendType — it has no # corresponding user-facing MoeConfig.backend literal. Route through # "CUTLASS" so the test exercises the heuristic auto-promotion path that # users hit on SM120/121 + NVFP4. 
- if moe_backend == MoeBackendType.FLASHINFER_NVFP4SM12X.value: + if moe_backend == MoeBackendType.CUTE_DSL_B12X.value: moe_backend = MoeBackendType.CUTLASS.value kwargs = dict( @@ -824,7 +824,7 @@ def init_worker(custom_paths, comm_method_type, master_port): MoeBackendType.DEEPGEMM, MoeBackendType.DENSEGEMM, MoeBackendType.MEGAMOE, - MoeBackendType.FLASHINFER_NVFP4SM12X, + MoeBackendType.CUTE_DSL_B12X, ] # Data types to test From 88676a380ac41d419c228d153b24fe4eb4671724 Mon Sep 17 00:00:00 2001 From: list <58580514+farazkh80@users.noreply.github.com> Date: Thu, 21 May 2026 18:25:28 -0700 Subject: [PATCH 10/10] [None][fix] move FlashInfer NVFP4 MoE selection from CUTLASS to CUTEDSL backend path The CUTLASS path in get_moe_cls was auto-promoting to CuteDslB12xFusedMoE on SM120/SM121 + NVFP4 when flashinfer was importable, silently overriding explicit moe_backend=CUTLASS requests. On GB10 (DGX Spark, sm_121) this broke L0_Test-SBSA-Single-GPU GB10-PyTorch-1 test_configurable_moe_single_gpu CUTLASS +NVFP4 cases with "'CuteDslB12xFusedMoE' object does not support the context manager protocol". Selection now lives on the CUTEDSL path: CUTEDSL + NVFP4 + SM120/121 + flashinfer importable -> CuteDslB12xFusedMoE; otherwise CuteDslFusedMoE. Explicit CUTLASS always returns CutlassFusedMoE. Test-side: test_moe_module remaps MoeBackendType.CUTE_DSL_B12X -> "CUTEDSL" (was "CUTLASS"), and the get_moe_cls unit tests in test_cute_dsl_b12x_moe_backend.py are rewritten to assert the new ownership (CUTLASS never promotes; CUTEDSL selects b12x when eligible, falls back to CuteDslFusedMoE on unsupported SM or missing flashinfer, and to CutlassFusedMoE on unsupported quant). User-facing log + docs refer to the kernel as "FlashInfer NVFP4 MoE" rather than "b12x". 
Signed-off-by: list <58580514+farazkh80@users.noreply.github.com> --- .../modules/fused_moe/MOE_DEVELOPER_GUIDE.md | 2 +- .../_torch/modules/fused_moe/create_moe.py | 44 ++++++----- .../fused_moe/fused_moe_cute_dsl_b12x.py | 2 +- .../moe/test_cute_dsl_b12x_moe_backend.py | 78 ++++++++++++------- .../_torch/modules/moe/test_moe_module.py | 7 +- 5 files changed, 79 insertions(+), 54 deletions(-) diff --git a/tensorrt_llm/_torch/modules/fused_moe/MOE_DEVELOPER_GUIDE.md b/tensorrt_llm/_torch/modules/fused_moe/MOE_DEVELOPER_GUIDE.md index 3dba9fbf9ded..da4d914e90e7 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/MOE_DEVELOPER_GUIDE.md +++ b/tensorrt_llm/_torch/modules/fused_moe/MOE_DEVELOPER_GUIDE.md @@ -147,7 +147,7 @@ Still on old path (standalone, with embedded communication): | `fused_moe_deepgemm.py` | `DeepGemmFusedMoE` | SM100/SM103 | FP8 Block Scales on Blackwell | `EXTERNAL_COMM` | | `fused_moe_densegemm.py` | `DenseGEMMFusedMoE` | SM100/SM103 | NVFP4 min-latency; CuTe DSL dense GEMM packs all experts into one matrix (vs Cutlass per-expert scatter), efficient for small token counts | `EXTERNAL_COMM` | | `fused_moe_cute_dsl.py` | `CuteDslFusedMoE` | SM100/SM103 | High throughput NVFP4, generally faster than Cutlass | `EXTERNAL_COMM` | -| `fused_moe_cute_dsl_b12x.py` | `CuteDslB12xFusedMoE` | SM120/SM121 | NVFP4; auto-selected on the `CUTLASS` path | `EXTERNAL_COMM` | +| `fused_moe_cute_dsl_b12x.py` | `CuteDslB12xFusedMoE` | SM120/SM121 | NVFP4 hybrid CUTLASS-prefill / FlashInfer NVFP4 MoE decode; selected on the `CUTEDSL` path when flashinfer is importable | `EXTERNAL_COMM` | | `mega_moe/mega_moe_deepgemm.py` | `MegaMoEDeepGemm` | SM100/SM103 | W4A8_MXFP4_MXFP8 via DeepGEMM `fp8_fp4_mega_moe` fused dispatch+GEMM+act+GEMM+combine kernel; requires `hidden_size % 512 == 0` | `FUSED_COMM` | | `fused_moe_triton.py` | `TritonFusedMoE` | SM90 only | GPT-OSS on Hopper (requires `swiglu_gptoss_style=True`) | (legacy path) | | `fused_moe_wide_ep.py` | `WideEPMoE` | All 
GPUs | Deprecating — use ConfigurableMoE instead | (legacy path) | diff --git a/tensorrt_llm/_torch/modules/fused_moe/create_moe.py b/tensorrt_llm/_torch/modules/fused_moe/create_moe.py index 69a1e23b8b35..6b752f90b088 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/create_moe.py +++ b/tensorrt_llm/_torch/modules/fused_moe/create_moe.py @@ -34,27 +34,6 @@ def get_moe_cls( if override_quant_config is not None: quant_config = override_quant_config if moe_backend.upper() == "CUTLASS": - # Auto-promote to CuteDslB12xFusedMoE (hybrid CUTLASS-prefill - # / b12x-decode) on SM120 / SM121 + NVFP4 when flashinfer is available. - # Falls back to plain CutlassFusedMoE otherwise. - if quant_config is not None and quant_config.quant_mode.has_nvfp4(): - from tensorrt_llm._utils import get_sm_version - sm_version = get_sm_version() - if sm_version in CuteDslB12xFusedMoE._SUPPORTED_SM_VERSIONS: - try: - import flashinfer # noqa: F401 - logger.info( - "Auto-selecting CuteDslB12xFusedMoE for hybrid " - "CUTLASS-prefill / b12x-decode (SM%d + NVFP4).", - sm_version, - ) - return CuteDslB12xFusedMoE - except ImportError: - logger.warning( - "CuteDslB12xFusedMoE eligible (SM%d + NVFP4) " - "but flashinfer is not importable; using CutlassFusedMoE.", - sm_version, - ) return CutlassFusedMoE elif moe_backend.upper() == "VANILLA": return VanillaMoE @@ -62,6 +41,29 @@ def get_moe_cls( if quant_config is not None and ( quant_config.quant_mode.has_fp8_block_scales() or quant_config.quant_mode.has_nvfp4()): + # On SM120 / SM121 + NVFP4 the hybrid CUTLASS-prefill / + # FlashInfer NVFP4 MoE decode backend (CuteDslB12xFusedMoE) is + # the optimized cuteDSL-family implementation, so prefer it when + # flashinfer is importable. 
+ if quant_config.quant_mode.has_nvfp4(): + from tensorrt_llm._utils import get_sm_version + sm_version = get_sm_version() + if sm_version in CuteDslB12xFusedMoE._SUPPORTED_SM_VERSIONS: + try: + import flashinfer # noqa: F401 + logger.info( + "Selecting CuteDslB12xFusedMoE for hybrid " + "CUTLASS-prefill / FlashInfer NVFP4 MoE decode " + "(SM%d + NVFP4).", + sm_version, + ) + return CuteDslB12xFusedMoE + except ImportError: + logger.warning( + "CuteDslB12xFusedMoE eligible (SM%d + NVFP4) " + "but flashinfer is not importable; using CuteDslFusedMoE.", + sm_version, + ) return CuteDslFusedMoE else: logger.warning( diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl_b12x.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl_b12x.py index 101d3d913ede..392c076d954f 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl_b12x.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl_b12x.py @@ -79,7 +79,7 @@ class CuteDslB12xFusedMoE(CuteDslFusedMoE): The backend hard-rejects EP (b12x has no dispatch / combine kernel), MoE alltoall, ``Fp4QuantizedTensor`` input, ``swiglu_gptoss_style`` biased SwiGLU, and activations outside ``{Relu2, Swiglu}``. It is - auto-selected on the ``CUTLASS`` MoE path when SM120 / SM121 + NVFP4 + + selected on the ``CUTEDSL`` MoE path when SM120 / SM121 + NVFP4 + flashinfer-importable gates pass (see ``create_moe.get_moe_cls``). """ diff --git a/tests/unittest/_torch/modules/moe/test_cute_dsl_b12x_moe_backend.py b/tests/unittest/_torch/modules/moe/test_cute_dsl_b12x_moe_backend.py index 776d4fa1276a..ec0791e548ec 100644 --- a/tests/unittest/_torch/modules/moe/test_cute_dsl_b12x_moe_backend.py +++ b/tests/unittest/_torch/modules/moe/test_cute_dsl_b12x_moe_backend.py @@ -15,11 +15,12 @@ """Negative-path + dispatch tests for CuteDslB12xFusedMoE. 
These checks run without a GPU: they verify the can_implement() gating -matrix, the heuristic auto-promotion in create_moe.get_moe_cls (the -backend is selected transparently from `moe_backend=CUTLASS` on -SM120/SM121 + NVFP4), and the hybrid CUTLASS-prefill / b12x-decode -dispatch predicate. Functional correctness of the b12x kernel is -covered by end-to-end model tests on SM120/SM121 hardware. +matrix, the SM120/SM121 + NVFP4 selection in create_moe.get_moe_cls (the +backend is selected on the `moe_backend=CUTEDSL` path when flashinfer +is importable, never from `moe_backend=CUTLASS`), and the hybrid +CUTLASS-prefill / b12x-decode dispatch predicate. Functional +correctness of the b12x kernel is covered by end-to-end model tests on +SM120/SM121 hardware. """ from unittest.mock import patch @@ -29,6 +30,7 @@ from tensorrt_llm._torch.model_config import ModelConfig from tensorrt_llm._torch.modules.fused_moe.create_moe import get_moe_cls +from tensorrt_llm._torch.modules.fused_moe.fused_moe_cute_dsl import CuteDslFusedMoE from tensorrt_llm._torch.modules.fused_moe.fused_moe_cute_dsl_b12x import CuteDslB12xFusedMoE from tensorrt_llm._torch.modules.fused_moe.fused_moe_cutlass import CutlassFusedMoE from tensorrt_llm.models.modeling_utils import QuantAlgo, QuantConfig @@ -86,52 +88,50 @@ def test_can_implement_rejects_unsupported_activation_dtype(dtype): assert reason is not None -def test_get_moe_cls_falls_back_to_cutlass_on_non_nvfp4(): - """Heuristic auto-promotion only fires on NVFP4; otherwise CUTLASS path stays.""" - cfg = ModelConfig() - cfg.moe_backend = "CUTLASS" - cfg.quant_config = QuantConfig(quant_algo=QuantAlgo.FP8) - with patch("tensorrt_llm._utils.get_sm_version", return_value=120): - cls = get_moe_cls(cfg) - assert cls is CutlassFusedMoE - - -def test_get_moe_cls_falls_back_to_cutlass_on_missing_quant(): +@pytest.mark.parametrize("sm_version", [100, 103, 120, 121]) +@pytest.mark.parametrize( + "quant_algo", + [None, QuantAlgo.FP8, QuantAlgo.NVFP4, 
QuantAlgo.FP8_BLOCK_SCALES], +) +def test_get_moe_cls_cutlass_path_never_promotes(sm_version, quant_algo): + """Explicit `moe_backend=CUTLASS` must always return CutlassFusedMoE, + regardless of SM or quant. The b12x hybrid lives on the CUTEDSL path.""" cfg = ModelConfig() cfg.moe_backend = "CUTLASS" - cfg.quant_config = None - with patch("tensorrt_llm._utils.get_sm_version", return_value=120): + cfg.quant_config = QuantConfig(quant_algo=quant_algo) if quant_algo is not None else None + with patch("tensorrt_llm._utils.get_sm_version", return_value=sm_version): cls = get_moe_cls(cfg) assert cls is CutlassFusedMoE -def test_get_moe_cls_falls_back_to_cutlass_on_unsupported_sm(): - """NVFP4 + non-SM120/121 must not auto-promote.""" +def test_get_moe_cls_cutedsl_falls_back_to_cute_dsl_on_unsupported_sm(): + """CUTEDSL + NVFP4 + non-SM120/121 → plain CuteDslFusedMoE (no b12x).""" cfg = ModelConfig() - cfg.moe_backend = "CUTLASS" + cfg.moe_backend = "CUTEDSL" cfg.quant_config = QuantConfig(quant_algo=QuantAlgo.NVFP4) with patch("tensorrt_llm._utils.get_sm_version", return_value=100): cls = get_moe_cls(cfg) - assert cls is CutlassFusedMoE + assert cls is CuteDslFusedMoE @pytest.mark.parametrize("sm_version", sorted(CuteDslB12xFusedMoE._SUPPORTED_SM_VERSIONS)) -def test_get_moe_cls_auto_promotes_on_supported_sm(sm_version): - """CUTLASS + NVFP4 + SM120/121 + flashinfer importable → hybrid backend.""" +def test_get_moe_cls_cutedsl_selects_b12x_on_supported_sm(sm_version): + """CUTEDSL + NVFP4 + SM120/121 + flashinfer importable → hybrid backend.""" cfg = ModelConfig() - cfg.moe_backend = "CUTLASS" + cfg.moe_backend = "CUTEDSL" cfg.quant_config = QuantConfig(quant_algo=QuantAlgo.NVFP4) with patch("tensorrt_llm._utils.get_sm_version", return_value=sm_version): cls = get_moe_cls(cfg) assert cls is CuteDslB12xFusedMoE -def test_get_moe_cls_falls_back_when_flashinfer_missing(monkeypatch): - """Eligible hardware but flashinfer not importable → CutlassFusedMoE.""" +def 
test_get_moe_cls_cutedsl_falls_back_when_flashinfer_missing(monkeypatch): + """Eligible hardware on the CUTEDSL path but flashinfer not importable + → falls back to plain CuteDslFusedMoE (still cuteDSL family).""" import builtins cfg = ModelConfig() - cfg.moe_backend = "CUTLASS" + cfg.moe_backend = "CUTEDSL" cfg.quant_config = QuantConfig(quant_algo=QuantAlgo.NVFP4) real_import = builtins.__import__ @@ -142,6 +142,28 @@ def _raise_on_flashinfer(name, *args, **kwargs): return real_import(name, *args, **kwargs) monkeypatch.setattr(builtins, "__import__", _raise_on_flashinfer) + with patch("tensorrt_llm._utils.get_sm_version", return_value=120): + cls = get_moe_cls(cfg) + assert cls is CuteDslFusedMoE + + +def test_get_moe_cls_cutedsl_non_nvfp4_uses_plain_cute_dsl(): + """CUTEDSL + FP8 block scales must use CuteDslFusedMoE, not the b12x hybrid + (b12x is NVFP4-only).""" + cfg = ModelConfig() + cfg.moe_backend = "CUTEDSL" + cfg.quant_config = QuantConfig(quant_algo=QuantAlgo.FP8_BLOCK_SCALES) + with patch("tensorrt_llm._utils.get_sm_version", return_value=120): + cls = get_moe_cls(cfg) + assert cls is CuteDslFusedMoE + + +def test_get_moe_cls_cutedsl_unsupported_quant_falls_back_to_cutlass(): + """CUTEDSL with a quant_config it does not support (e.g. 
plain FP8) + still falls back to CutlassFusedMoE, matching the legacy behavior.""" + cfg = ModelConfig() + cfg.moe_backend = "CUTEDSL" + cfg.quant_config = QuantConfig(quant_algo=QuantAlgo.FP8) with patch("tensorrt_llm._utils.get_sm_version", return_value=120): cls = get_moe_cls(cfg) assert cls is CutlassFusedMoE diff --git a/tests/unittest/_torch/modules/moe/test_moe_module.py b/tests/unittest/_torch/modules/moe/test_moe_module.py index 019cc6756774..2b793a173494 100644 --- a/tests/unittest/_torch/modules/moe/test_moe_module.py +++ b/tests/unittest/_torch/modules/moe/test_moe_module.py @@ -251,10 +251,11 @@ def _create_model_config( # CUTE_DSL_B12X is an internal-only MoeBackendType — it has no # corresponding user-facing MoeConfig.backend literal. Route through - # "CUTLASS" so the test exercises the heuristic auto-promotion path that - # users hit on SM120/121 + NVFP4. + # "CUTEDSL" so the test exercises the cuteDSL-family selection path that + # users hit on SM120/121 + NVFP4 (where get_moe_cls returns the hybrid + # CuteDslB12xFusedMoE backend when flashinfer is importable). if moe_backend == MoeBackendType.CUTE_DSL_B12X.value: - moe_backend = MoeBackendType.CUTLASS.value + moe_backend = MoeBackendType.CUTEDSL.value kwargs = dict( pretrained_config=pretrained_config,