diff --git a/csrc/trtllm_batched_gemm_runner.cu b/csrc/trtllm_batched_gemm_runner.cu index f4cb825d36..258c34ecc1 100644 --- a/csrc/trtllm_batched_gemm_runner.cu +++ b/csrc/trtllm_batched_gemm_runner.cu @@ -260,7 +260,7 @@ void TrtllmGenBatchedGemmRunner::run( auto const err = bmm.run(config, workspace, gemmData, static_cast(stream), multiProcessorCount, - enable_pdl, /*pinnedHostBuffer=*/nullptr, globalTrtllmGenBatchedGemmModuleCache); + enable_pdl, nullptr, globalTrtllmGenBatchedGemmModuleCache); FLASHINFER_CHECK(err == 0, "Error occurred when running GEMM!" diff --git a/csrc/trtllm_fused_moe_kernel_launcher.cu b/csrc/trtllm_fused_moe_kernel_launcher.cu index 8f7f157fbb..1271fdc2b2 100644 --- a/csrc/trtllm_fused_moe_kernel_launcher.cu +++ b/csrc/trtllm_fused_moe_kernel_launcher.cu @@ -1246,19 +1246,22 @@ class MxInt4BlockScaleLauncher : public FusedMoeLauncher { public: static constexpr std::array mSupportedTileNums = {8, 16, 32, 64, 128}; - MxInt4BlockScaleLauncher(TensorView const& routing_logits, + MxInt4BlockScaleLauncher(Optional const& routing_logits, Optional const& routing_bias, TensorView const& hidden_states, TensorView const& gemm1_weights, TensorView const& gemm1_weights_scale, Optional const& gemm1_alpha, Optional const& gemm1_beta, Optional const& gemm1_clamp_limit, - TensorView const& gemm2_weights, TensorView const& gemm2_weights_scale) - : FusedMoeLauncher(Optional(routing_logits), routing_bias, hidden_states, - gemm1_weights, Optional(), Optional(), - gemm2_weights, Optional()), + TensorView const& gemm2_weights, TensorView const& gemm2_weights_scale, + TensorView const& expert_indices, TensorView const& expert_weights_in) + : FusedMoeLauncher(routing_logits, routing_bias, hidden_states, gemm1_weights, + Optional(), Optional(), gemm2_weights, + Optional()), gemm1_weights_scale(gemm1_weights_scale), - gemm2_weights_scale(gemm2_weights_scale) {} + gemm2_weights_scale(gemm2_weights_scale), + expert_indices(expert_indices), + expert_weights_in(expert_weights_in) {} void init(std::unique_ptr&& args, int64_t tile_tokens_dim, int64_t routing_method_type) { @@ -1280,7 +1283,29 @@ class MxInt4BlockScaleLauncher : public FusedMoeLauncher { static_cast(batchedGemm::gemm::MatrixLayout::BlockMajorK), ActivationType::Swiglu); } - void check_routing() const override { FusedMoeLauncher::check_routing_common(); } + void check_routing() const override { + if (has_precomputed_indices()) { + // Pre-computed routing: expert_indices is a packed tensor + // Format: (expert_id << 16) | (weight_bf16.view(int16)) + TVM_FFI_ICHECK_EQ(expert_indices.size(0), hidden_states.size(0)) + << "expert_indices and hidden_states must have same number of tokens."; + TVM_FFI_ICHECK_EQ(expert_indices.size(1), args->top_k) + << "expert_indices dim1 must match top_k."; + TVM_FFI_ICHECK_EQ(expert_indices.dtype(), dl_int32) << "expert_indices must be int32."; + } + + if (has_precomputed_weights()) { + // Pre-computed expert weights: validate shape and dtype + TVM_FFI_ICHECK_EQ(expert_weights_in.size(0), hidden_states.size(0)) + << "expert_weights_in and hidden_states must have same number of tokens."; + TVM_FFI_ICHECK_EQ(expert_weights_in.size(1), args->top_k) + << "expert_weights_in dim1 must match top_k."; + TVM_FFI_ICHECK_EQ(expert_weights_in.dtype(), dl_bfloat16) + << "expert_weights_in must be bfloat16."; + } + + FusedMoeLauncher::check_routing_common(); + } void prepare_routing() override { FusedMoeLauncher::prepare_routing_common(); @@ -1292,11 +1317,24 @@ class MxInt4BlockScaleLauncher : public FusedMoeLauncher { routing_bias.has_value() ? routing_bias.value().dtype() : dl_bfloat16; mRoutingBiasDtype = routing_bias_dtype == dl_bfloat16 ? btg::Dtype::Bfloat16 : btg::Dtype::Fp32; - auto expert_weights_dtype = mDtypeScore == btg::Dtype::Fp32 ? dl_float32 : dl_bfloat16; - expert_weights = - alloc_tensor({args->num_tokens, args->top_k}, expert_weights_dtype, hidden_states.device()); + if (has_precomputed_indices()) { + // Use expert_indices directly + workspace.routing_expert_indexes = + static_cast(const_cast(expert_indices.data_ptr())); + } else { + // Use routing_logits directly + args->routing_logits = static_cast(routing_logits.value().data_ptr()); + } - workspace.expert_weights = expert_weights.data_ptr(); + if (has_precomputed_weights()) { + workspace.expert_weights = const_cast(expert_weights_in.data_ptr()); + } else { + // Allocate expert_weights buffer for routing output + auto expert_weights_dtype = mDtypeScore == btg::Dtype::Fp32 ? dl_float32 : dl_bfloat16; + expert_weights = alloc_tensor({args->num_tokens, args->top_k}, expert_weights_dtype, + hidden_states.device()); + workspace.expert_weights = expert_weights.data_ptr(); + } } void check_moe() const override { @@ -1364,10 +1402,64 @@ class MxInt4BlockScaleLauncher : public FusedMoeLauncher { Optional gemm1_beta; Optional gemm1_clamp_limit; TensorView gemm2_weights_scale; + TensorView expert_indices; + TensorView expert_weights_in; int32_t max_num_padded_tokens_gemm1{}; int32_t max_num_padded_tokens_gemm2{}; + // Helper to check if pre-computed routing indices are provided + // Check ndim==2 and size>0 because empty placeholder tensors may have non-null data_ptr + bool has_precomputed_indices() const { + return expert_indices.ndim() == 2 && expert_indices.size(0) > 0; + } + + // Helper to check if pre-computed routing weights are provided + bool has_precomputed_weights() const { + return expert_weights_in.ndim() == 2 && expert_weights_in.size(0) > 0; + } + public: + // Override to handle pre-computed routing + Array run(int64_t moe_tactic, bool enable_pdl = true, + bool use_routing_scales_on_input = false, + bool use_deep_seek_fp8 = false) override { + check_routing(); + prepare_routing(); + + cudaStream_t routing_stream = get_stream(hidden_states.device()); + tensorrt_llm::kernels::trtllmgen_moe::Routing::Runner routing_runner(tile_tokens_dim); + + // When using pre-computed routing, pass nullptr as routing_logits to tell the + // routing runner to use the pre-computed expert indices from workspace.routing_expert_indexes + routing_runner.run( + has_precomputed_indices() ? nullptr : args->routing_logits, args->routing_bias, + args->num_tokens, args->num_experts, args->top_k, args->n_group, args->topk_group, + args->local_expert_offset, args->local_num_experts, args->routed_scaling_factor, + workspace.routing_expert_indexes, static_cast(expert_count_histogram.data_ptr()), + static_cast(total_num_padded_tokens.data_ptr()), + static_cast(expanded_idx_to_permuted_idx.data_ptr()), + nullptr /*permuted_idx_to_expanded_idx.data_ptr()*/, + static_cast(permuted_idx_to_token_idx.data_ptr()), workspace.expert_weights, + static_cast(num_tokens_per_expert.data_ptr()), + static_cast(cta_idx_xy_to_batch_idx.data_ptr()), + static_cast(cta_idx_xy_to_mn_limit.data_ptr()), + static_cast(num_non_exiting_ctas.data_ptr()), mDtypeScore, args->mDtypeElt, + mRoutingBiasDtype, use_routing_scales_on_input, use_deep_seek_fp8, + static_cast(routing_method_type), routing_stream); + + check_moe(); + prepare_moe(moe_tactic); + + cudaStream_t moe_stream = get_stream(hidden_states.device()); + moe_runner->run(*args, workspace, hidden_states.device().device_id, moe_stream, moe_tactic, + enable_pdl); + + if (args->do_finalize) { + return {output}; + } + return {gemm2_output, FusedMoeLauncher::expert_weights, expanded_idx_to_permuted_idx}; + } + static Array> getValidConfigs(int64_t top_k, int64_t hidden_size, int64_t intermediate_size, int64_t num_local_experts, int64_t num_tokens) { @@ -2114,8 +2206,9 @@ Array trtllm_fp4_block_scale_moe( } Array trtllm_mxint4_block_scale_moe( - TensorView routing_logits, Optional routing_bias, TensorView hidden_states, - TensorView gemm1_weights, TensorView gemm1_weights_scale, Optional gemm1_alpha, + Optional routing_logits, TensorView expert_indices, TensorView expert_weights, + Optional routing_bias, TensorView hidden_states, TensorView gemm1_weights, + TensorView gemm1_weights_scale, Optional gemm1_alpha, Optional gemm1_beta, Optional gemm1_clamp_limit, TensorView gemm2_weights, TensorView gemm2_weights_scale, int64_t num_experts, int64_t top_k, Optional n_group, Optional topk_group, int64_t intermediate_size, @@ -2132,10 +2225,23 @@ Array trtllm_mxint4_block_scale_moe( TVM_FFI_ICHECK(weight_scale_vec_size == 32) << "unsupported weight_scale_vec_size."; - TVM_FFI_ICHECK(routing_logits.dtype() == dl_float32 || routing_logits.dtype() == dl_bfloat16) - << "routing_logits must be float or bfloat16."; - TVM_FFI_ICHECK_EQ(routing_logits.ndim(), 2) << "routing_logits must be 2D."; - TVM_FFI_ICHECK_EQ(routing_logits.size(1), num_experts) << "routing_logits has incorrect shape."; + // Either routing_logits or expert_indices must be provided + // expert_indices is a packed tensor: (expert_id << 16) | (weight_bf16.view(int16)) + bool use_routing_logits = routing_logits.has_value(); + // Check ndim==2 and size>0 because empty placeholder tensors may have non-null data_ptr + bool use_precomputed_routing = expert_indices.ndim() == 2 && expert_indices.size(0) > 0; + + TVM_FFI_ICHECK(use_routing_logits || use_precomputed_routing) + << "Either routing_logits or expert_indices must be provided."; + + if (use_routing_logits) { + TVM_FFI_ICHECK(routing_logits.value().dtype() == dl_float32 || + routing_logits.value().dtype() == dl_bfloat16) + << "routing_logits must be float or bfloat16."; + TVM_FFI_ICHECK_EQ(routing_logits.value().ndim(), 2) << "routing_logits must be 2D."; + TVM_FFI_ICHECK_EQ(routing_logits.value().size(1), num_experts) + << "routing_logits has incorrect shape."; + } if (routing_bias.has_value()) { TVM_FFI_ICHECK(routing_bias.value().dtype() == dl_bfloat16) << "routing_bias must be bfloat16."; TVM_FFI_ICHECK_EQ(routing_bias.value().ndim(), 1) << "routing_bias must be 1D."; @@ -2178,7 +2284,8 @@ Array trtllm_mxint4_block_scale_moe( // Create and initialize launcher for this tile size auto launcher = std::make_unique( routing_logits, routing_bias, hidden_states, gemm1_weights, gemm1_weights_scale, - gemm1_alpha, gemm1_beta, gemm1_clamp_limit, gemm2_weights, gemm2_weights_scale); + gemm1_alpha, gemm1_beta, gemm1_clamp_limit, gemm2_weights, gemm2_weights_scale, + expert_indices, expert_weights); launcher->init(std::move(args), curr_tile_N, routing_method_type); launchers_map[curr_tile_N] = std::move(launcher); diff --git a/flashinfer/fused_moe/__init__.py b/flashinfer/fused_moe/__init__.py index e2b4cab3d6..9c52168f1d 100644 --- a/flashinfer/fused_moe/__init__.py +++ b/flashinfer/fused_moe/__init__.py @@ -35,6 +35,7 @@ trtllm_bf16_moe, trtllm_bf16_routed_moe, trtllm_mxint4_block_scale_moe, + trtllm_mxint4_block_scale_routed_moe, ) from .fused_routing_dsv3 import ( # noqa: F401 @@ -73,6 +74,7 @@ "trtllm_fp8_block_scale_routed_moe", "trtllm_fp8_per_tensor_scale_moe", "trtllm_mxint4_block_scale_moe", + "trtllm_mxint4_block_scale_routed_moe", "fused_topk_deepseek", ] diff --git a/flashinfer/fused_moe/core.py b/flashinfer/fused_moe/core.py index 8bcbe36381..fdba456712 100644 --- a/flashinfer/fused_moe/core.py +++ b/flashinfer/fused_moe/core.py @@ -1109,6 +1109,8 @@ def forward( ): moe_op.trtllm_mxint4_block_scale_moe( routing_logits, + topk_ids, + expert_weights, kwargs["routing_bias"], hidden_states, kwargs["gemm1_weights"], @@ -1972,7 +1974,9 @@ def _fake_trtllm_fp4_block_scale_moe( mutates_args=(""), ) def trtllm_mxint4_block_scale_moe_op( - routing_logits: torch.Tensor, + routing_logits: Optional[torch.Tensor], + topk_ids: Optional[torch.Tensor], + expert_weights: Optional[torch.Tensor], routing_bias: Optional[torch.Tensor], hidden_states: torch.Tensor, gemm1_weights: torch.Tensor, @@ -1982,6 +1986,7 @@ def trtllm_mxint4_block_scale_moe_op( gemm1_clamp_limit: Optional[torch.Tensor], gemm2_weights: torch.Tensor, gemm2_weights_scale: torch.Tensor, + output: Optional[torch.Tensor], num_experts: int, top_k: int, n_group: Optional[int], @@ -1993,24 +1998,27 @@ def trtllm_mxint4_block_scale_moe_op( routing_method_type: int, do_finalize: bool = True, enable_pdl: Optional[bool] = None, - output: Optional[torch.Tensor] = None, tune_max_num_tokens: int = 8192, ) -> List[torch.Tensor]: - routing_dtype = routing_logits.dtype + # Determine routing mode: compute from logits or use pre-computed + if routing_logits is None: + assert topk_ids is not None, ( + "either topk_ids or routing_logits must be provided." + ) + assert topk_ids.dtype == torch.int32, "topk_ids must be an int32 tensor." + routing_dtype = torch.bfloat16 + else: + routing_dtype = routing_logits.dtype + + if enable_pdl is None: + enable_pdl = device_support_pdl(hidden_states.device) + hidden_size = hidden_states.shape[-1] if hidden_states.dtype == torch.uint8: hidden_size = hidden_size * 2 num_tokens = hidden_states.shape[0] - # workspace buffers required by trtllm-gen - topk_ids = torch.empty( - num_tokens, top_k, dtype=torch.int32, device=hidden_states.device - ) - expert_weights = torch.empty( - num_tokens, top_k, dtype=routing_dtype, device=hidden_states.device - ) - if enable_pdl is None: - enable_pdl = device_support_pdl(hidden_states.device) + # Create output buffer if not provided if output is None: output = torch.empty( num_tokens, @@ -2018,6 +2026,25 @@ def trtllm_mxint4_block_scale_moe_op( dtype=torch.bfloat16, device=hidden_states.device, ) + else: + check_shape_dtype_device( + output, None, torch.bfloat16, hidden_states.device, "output" + ) + if routing_logits is not None: + # When routing_logits is provided, we must pass topk_ids/expert_weights with no allocation + topk_ids = torch.empty(0, dtype=torch.int32, device=hidden_states.device) + expert_weights = torch.empty( + 0, dtype=routing_dtype, device=hidden_states.device + ) + else: + # When routing_logits is None, we either have topk_ids/expert_weights + # packed into a single tensor as topk_ids, + # or have them individually as topk_ids and expert_weights respectively + expert_weights = ( + expert_weights + if expert_weights is not None + else torch.empty(0, dtype=routing_dtype, device=hidden_states.device) + ) tuner = AutoTuner.get() MoERunner.refine_tuning_config(tune_max_num_tokens) @@ -2036,9 +2063,17 @@ def trtllm_mxint4_block_scale_moe_op( use_shuffled_weight=True, ) tunning_config = MoERunner.tuning_config_no_hidden_states_scales + # Create placeholder for tuning when routing_logits is None (routed mode) + routing_logits_for_tuning = ( + routing_logits + if routing_logits is not None + else torch.empty( + num_tokens, num_experts, dtype=routing_dtype, device="meta" + ) + ) inputs = [ output, - routing_logits, + routing_logits_for_tuning, topk_ids, expert_weights, hidden_states, @@ -2070,6 +2105,8 @@ def trtllm_mxint4_block_scale_moe_op( # Call the C++ function for block scale MoE intermediate_output = moe_op.trtllm_mxint4_block_scale_moe( routing_logits, + topk_ids, + expert_weights, routing_bias, hidden_states, gemm1_weights, @@ -2098,13 +2135,19 @@ def trtllm_mxint4_block_scale_moe_op( else: return [ torch.from_dlpack(intermediate_output[0]), - torch.from_dlpack(intermediate_output[1]), + ( + torch.from_dlpack(intermediate_output[1]) + if routing_logits is not None or expert_weights.numel() == 0 + else expert_weights + ), torch.from_dlpack(intermediate_output[2]), ] @register_fake_op("flashinfer::trtllm_mxint4_block_scale_moe") def _fake_trtllm_mxint4_block_scale_moe( - routing_logits: torch.Tensor, + routing_logits: Optional[torch.Tensor], + topk_ids: Optional[torch.Tensor], + expert_weights: Optional[torch.Tensor], routing_bias: Optional[torch.Tensor], hidden_states: torch.Tensor, gemm1_weights: torch.Tensor, @@ -2114,6 +2157,7 @@ def _fake_trtllm_mxint4_block_scale_moe( gemm1_clamp_limit: Optional[torch.Tensor], gemm2_weights: torch.Tensor, gemm2_weights_scale: torch.Tensor, + output: Optional[torch.Tensor], num_experts: int, top_k: int, n_group: Optional[int], @@ -2123,9 +2167,9 @@ def _fake_trtllm_mxint4_block_scale_moe( local_num_experts: int, routed_scaling_factor: Optional[float], routing_method_type: int, - enable_pdl: bool, - output: Optional[torch.Tensor], - tune_max_num_tokens: int, + do_finalize: bool = True, + enable_pdl: Optional[bool] = None, + tune_max_num_tokens: int = 8192, ): seq_len = hidden_states.shape[0] hidden_size = hidden_states.shape[1] @@ -2989,6 +3033,8 @@ def trtllm_mxint4_block_scale_moe( """ return get_trtllm_moe_sm100_module().trtllm_mxint4_block_scale_moe( routing_logits, + None, # topk_ids + None, # expert_weights routing_bias, hidden_states, gemm1_weights, @@ -2998,6 +3044,7 @@ def trtllm_mxint4_block_scale_moe( gemm1_clamp_limit, gemm2_weights, gemm2_weights_scale, + output, num_experts, top_k, n_group, @@ -3009,6 +3056,103 @@ def trtllm_mxint4_block_scale_moe( routing_method_type, do_finalize, enable_pdl, + tune_max_num_tokens, + ) + + +@flashinfer_api +def trtllm_mxint4_block_scale_routed_moe( + topk_ids: torch.Tensor, + routing_bias: Optional[torch.Tensor], + hidden_states: torch.Tensor, + gemm1_weights: torch.Tensor, + gemm1_weights_scale: torch.Tensor, + gemm1_alpha: Optional[torch.Tensor], + gemm1_beta: Optional[torch.Tensor], + gemm1_clamp_limit: Optional[torch.Tensor], + gemm2_weights: torch.Tensor, + gemm2_weights_scale: torch.Tensor, + num_experts: int, + top_k: int, + n_group: Optional[int], + topk_group: Optional[int], + intermediate_size: int, + local_expert_offset: int, + local_num_experts: int, + routed_scaling_factor: Optional[float], + routing_method_type: int = 0, + do_finalize: bool = True, + enable_pdl: Optional[bool] = None, + output: Optional[torch.Tensor] = None, + tune_max_num_tokens: int = 8192, +) -> List[torch.Tensor]: + """MxInt4 block scale MoE operation with pre-computed routing (packed format). + + This function is used when routing decisions have already been computed + and packed into a single tensor. This is useful for: + - CUDA Graph capture (avoids CPU-GPU sync from routing_logits processing) + - Distributed MoE where routing is computed elsewhere + + Args: + topk_ids: [seq_len, top_k] tensor of packed expert indices and weights (int32). + Format: (expert_id << 16) | (weight_bf16.view(int16)) + Can be created as: (topk_ids.int32 << 16) | expert_weights.bfloat16.view(int16) + routing_bias: [num_experts] tensor of routing bias (can be None) + hidden_states: [seq_len, hidden_size] tensor of input hidden states. Must be bfloat16. + gemm1_weights: [num_experts, 2 * intermediate_size, hidden_size // 2] tensor of FC1 weights. + Dtype must be uint8 (packed mxint4). + gemm1_weights_scale: [num_experts, 2 * intermediate_size, hidden_size // 32] tensor of FC1 scales. + Dtype must be bfloat16. + gemm1_alpha: Optional [num_experts] tensor of swiglu alpha. Dtype is float32. + gemm1_beta: Optional [num_experts] tensor of swiglu beta. Dtype is float32. + gemm1_clamp_limit: Optional [num_experts] tensor of swiglu clamp limit. Dtype is float32. + gemm2_weights: [num_experts, hidden_size, intermediate_size // 2] tensor of FC2 weights. + Dtype must be uint8 (packed mxint4). + gemm2_weights_scale: [num_experts, hidden_size, intermediate_size // 32] tensor of FC2 scales. + Dtype must be bfloat16. + num_experts: Total number of experts + top_k: Number of experts to route to per token + n_group: Number of expert groups + topk_group: Number of groups to consider for top-k routing + intermediate_size: Size of intermediate layer + local_expert_offset: Offset of local experts in global expert space + local_num_experts: Number of experts handled by this device + routed_scaling_factor: Scaling factor for routing + routing_method_type: Type of routing method to use (default: 0) + do_finalize: Whether to finalize the output (default: True) + enable_pdl: Whether to enable Programmatic Dependent Launch (PDL). Auto-enabled for >= sm90. + output: Optional [seq_len, hidden_size] inplace output tensor. + tune_max_num_tokens: Maximum number of tokens for tuning. (default: 8192) + + Returns: + when do_finalize=True, returns the final MoE output. + otherwise, returns the intermediate results (gemm2_output, undefined, expanded_idx_to_permuted_idx) + that need further processing. + """ + return get_trtllm_moe_sm100_module().trtllm_mxint4_block_scale_moe( + None, # routing_logits + topk_ids, + None, # expert_weights + routing_bias, + hidden_states, + gemm1_weights, + gemm1_weights_scale, + gemm1_alpha, + gemm1_beta, + gemm1_clamp_limit, + gemm2_weights, + gemm2_weights_scale, output, + num_experts, + top_k, + n_group, + topk_group, + intermediate_size, + local_expert_offset, + local_num_experts, + routed_scaling_factor, + routing_method_type, + do_finalize, + enable_pdl, tune_max_num_tokens, ) diff --git a/tests/moe/test_trtllm_gen_routed_fused_moe.py b/tests/moe/test_trtllm_gen_routed_fused_moe.py index f5fd5cc263..a045e52830 100644 --- a/tests/moe/test_trtllm_gen_routed_fused_moe.py +++ b/tests/moe/test_trtllm_gen_routed_fused_moe.py @@ -36,6 +36,8 @@ trtllm_fp8_block_scale_moe, trtllm_fp8_block_scale_routed_moe, WeightLayout, + trtllm_mxint4_block_scale_moe, + trtllm_mxint4_block_scale_routed_moe, ) from flashinfer.fused_moe.core import Fp8QuantizationType from flashinfer.utils import device_support_pdl @@ -46,6 +48,10 @@ routing_reference_renormalize, routing_reference_renormalize_naive, routing_reference_topk, + mxint4_quantize, + block_scale_interleave, + _maybe_get_cached_w3_w1_permute_indices, + get_w2_permute_indices_with_cache, ) from flashinfer.utils import get_compute_capability @@ -705,3 +711,213 @@ def test_trtllm_gen_fp8_mxfp8_routed_activation_parity(activation_type: int): close = torch.isclose(output_ref, output_routed, atol=1e-2, rtol=1e-2) mismatch_pct = (~close).float().mean().item() * 100 assert mismatch_pct < 10, f"Mismatch percentage is {mismatch_pct:.2f}%" + + +@pytest.mark.parametrize("num_tokens", [8, 64]) +@pytest.mark.parametrize("hidden_size", [1024, 2048]) +@pytest.mark.parametrize("intermediate_size", [1024, 2048]) +@pytest.mark.parametrize("num_experts", [8, 16]) +@pytest.mark.parametrize("top_k", [2, 4]) +@pytest.mark.parametrize( + "routing_method_type", + [ + RoutingMethodType.Renormalize, + ], +) +def test_trtllm_gen_mxint4_routed_fused_moe( + num_tokens: int, + hidden_size: int, + intermediate_size: int, + top_k: int, + num_experts: int, + routing_method_type: RoutingMethodType, +): + """Test MxInt4 block scale routed MoE matches standard routing.""" + compute_capability = get_compute_capability(torch.device(device="cuda")) + if compute_capability[0] not in [10]: + pytest.skip("These tests are only guaranteed to work on SM100 and SM103 GPUs.") + torch.manual_seed(42) + device = torch.device("cuda:0") + enable_pdl = device_support_pdl(device) + + # Generate random routing logits for reference + routing_logits = torch.rand(num_tokens, num_experts, device=device).to( + torch.bfloat16 + ) + + # Generate random hidden states in BF16 + hidden_states = ( + torch.randn(num_tokens, hidden_size, device=device).to(torch.bfloat16) * 0.1 + ) + + # Generate BF16 weights and quantize to MxInt4 + gemm1_weights_bf16 = ( + torch.randn(num_experts, 2 * intermediate_size, hidden_size, device=device).to( + torch.bfloat16 + ) + * 0.1 + ) + gemm2_weights_bf16 = ( + torch.randn(num_experts, hidden_size, intermediate_size, device=device).to( + torch.bfloat16 + ) + * 0.1 + ) + + # Quantize weights to MxInt4 + sf_vec_size = 32 + gemm1_weights_int4, gemm1_scales = mxint4_quantize(gemm1_weights_bf16, sf_vec_size) + gemm2_weights_int4, gemm2_scales = mxint4_quantize(gemm2_weights_bf16, sf_vec_size) + gemm1_scales = gemm1_scales.to(torch.bfloat16).reshape( + num_experts, 2 * intermediate_size, hidden_size // sf_vec_size + ) + gemm2_scales = gemm2_scales.to(torch.bfloat16).reshape( + num_experts, hidden_size, intermediate_size // sf_vec_size + ) + + # Prepare shuffled weights for kernel + epilogue_tile_m = 128 + gemm1_weights_mxint4_shuffled = [] + gemm1_scales_shuffled = [] + gemm2_weights_mxint4_shuffled = [] + gemm2_scales_shuffled = [] + cache_permute_indices = {} + + for i in range(num_experts): + # Calculate the permute indices for the following: + # 1. Reorder rows of W1 and scales for fused gated activation + # 2. Shuffle weights and scaling factors for transposed mma output + permute_indices = _maybe_get_cached_w3_w1_permute_indices( + cache_permute_indices, + gemm1_weights_int4[i].view(torch.uint8), + epilogue_tile_m, + ) + gemm1_weights_shuffled = ( + gemm1_weights_int4[i] + .view(torch.uint8)[permute_indices.to(gemm1_weights_int4.device)] + .contiguous() + ) + permute_sf_indices = _maybe_get_cached_w3_w1_permute_indices( + cache_permute_indices, + gemm1_scales[i].view(torch.bfloat16), + epilogue_tile_m, + num_elts_per_sf=32, + ) + gemm1_scales_shuffled.append( + block_scale_interleave( + gemm1_scales[i] + .view(torch.bfloat16)[permute_sf_indices.to(gemm1_scales.device)] + .contiguous() + ) + ) + + permute_indices = get_w2_permute_indices_with_cache( + cache_permute_indices, + gemm2_weights_int4[i].view(torch.uint8), + epilogue_tile_m, + ) + gemm2_weights_shuffled = ( + gemm2_weights_int4[i] + .view(torch.uint8)[permute_indices.to(gemm2_weights_int4.device)] + .contiguous() + ) + + permute_sf_indices = get_w2_permute_indices_with_cache( + cache_permute_indices, + gemm2_scales[i].view(torch.bfloat16), + epilogue_tile_m, + num_elts_per_sf=16, + ) + gemm2_scales_shuffled.append( + block_scale_interleave( + gemm2_scales[i] + .view(torch.bfloat16)[permute_sf_indices.to(gemm2_scales.device)] + .contiguous() + ) + ) + + block_k = 128 + gemm1_weights_shuffled = convert_to_block_layout( + gemm1_weights_shuffled, block_k + ) + gemm2_weights_shuffled = convert_to_block_layout( + gemm2_weights_shuffled.view(torch.uint8), block_k + ) + + gemm1_weights_mxint4_shuffled.append(gemm1_weights_shuffled) + gemm2_weights_mxint4_shuffled.append(gemm2_weights_shuffled) + + gemm1_weights_mxint4_shuffled = torch.stack(gemm1_weights_mxint4_shuffled) + gemm2_weights_mxint4_shuffled = torch.stack(gemm2_weights_mxint4_shuffled) + gemm1_scales_shuffled = torch.stack(gemm1_scales_shuffled).view(torch.bfloat16) + gemm2_scales_shuffled = torch.stack(gemm2_scales_shuffled).view(torch.bfloat16) + + # Run reference with routing_logits + reference_output = trtllm_mxint4_block_scale_moe( + routing_logits, + None, # routing_bias + hidden_states, + gemm1_weights_mxint4_shuffled, + gemm1_scales_shuffled, + None, # gemm1_alpha + None, # gemm1_beta + None, # gemm1_clamp_limit + gemm2_weights_mxint4_shuffled, + gemm2_scales_shuffled, + num_experts, + top_k, + None, # n_group + None, # topk_group + intermediate_size, + 0, # local_expert_offset + num_experts, + None, # routed_scaling_factor + routing_method_type.value, + True, # do_finalize + enable_pdl, + )[0].to(torch.float) + + # Compute routing using reference implementation + permute_info, expert_weights_ref = routing_reference_renormalize( + routing_logits, top_k, num_experts, 8 + ) + topk_ids = permute_info["topKIndices"].to(torch.int32) + expert_weights = expert_weights_ref.view(num_tokens, num_experts)[ + torch.arange(num_tokens, device=device).unsqueeze(1), topk_ids + ].to(torch.bfloat16) + + # Pack topk_ids and expert_weights into single tensor + # Format: (expert_id << 16) | (weight_bf16.view(int16)) + packed_topk_ids = (topk_ids << 16) | expert_weights.view(torch.int16).to( + torch.int32 + ) + + # Run with pre-computed routing (packed format) + output = trtllm_mxint4_block_scale_routed_moe( + topk_ids=packed_topk_ids, + routing_bias=None, + hidden_states=hidden_states, + gemm1_weights=gemm1_weights_mxint4_shuffled, + gemm1_weights_scale=gemm1_scales_shuffled, + gemm1_alpha=None, + gemm1_beta=None, + gemm1_clamp_limit=None, + gemm2_weights=gemm2_weights_mxint4_shuffled, + gemm2_weights_scale=gemm2_scales_shuffled, + num_experts=num_experts, + top_k=top_k, + n_group=None, + topk_group=None, + intermediate_size=intermediate_size, + local_expert_offset=0, + local_num_experts=num_experts, + routed_scaling_factor=None, + routing_method_type=routing_method_type.value, + enable_pdl=enable_pdl, + )[0].to(torch.float) + + mask = torch.isclose(output, reference_output, rtol=1e-2, atol=1e-2) + + # mismatch percentage + mismatch_pct = (~mask).float().mean().item() * 100 + assert mismatch_pct < 10, f"Mismatch percentage is {mismatch_pct:.2f}%"