Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions csrc/trtllm_fused_moe_kernel_launcher.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1408,9 +1408,9 @@ Tensor trtllm_bf16_moe(TensorView const& routing_logits, Optional<TensorView> co
TensorView const& gemm2_weights, int64_t num_experts, int64_t top_k,
Optional<int64_t> n_group, Optional<int64_t> topk_group,
int64_t intermediate_size, int64_t local_expert_offset,
int64_t local_num_experts, int64_t routing_method_type,
bool use_shuffled_weight, int64_t weight_layout, bool enable_pdl,
Array<int64_t> moe_tactic) {
int64_t local_num_experts, Optional<double> routed_scaling_factor,
int64_t routing_method_type, bool use_shuffled_weight, int64_t weight_layout,
bool enable_pdl, Array<int64_t> moe_tactic) {
// Just some basic type validation first and leave more checks to the launcher
TVM_FFI_ICHECK(routing_logits.dtype() == dl_float32 || routing_logits.dtype() == dl_bfloat16)
<< "BF16 MoE: routing_logits must be bfloat16 or float.";
Expand Down Expand Up @@ -1443,7 +1443,7 @@ Tensor trtllm_bf16_moe(TensorView const& routing_logits, Optional<TensorView> co
args->top_k = top_k;
args->n_group = n_group.value_or(0);
args->topk_group = topk_group.value_or(0);
;
args->routed_scaling_factor = routed_scaling_factor.value_or(1.0);
args->local_expert_offset = local_expert_offset;
args->local_num_experts = local_num_experts;
args->intermediate_size = intermediate_size;
Expand Down Expand Up @@ -1808,7 +1808,12 @@ Array<Tensor> trtllm_mxint4_block_scale_moe(
<< "routing_logits must be float or bfloat16.";
TVM_FFI_ICHECK_EQ(routing_logits.ndim(), 2) << "routing_logits must be 2D.";
TVM_FFI_ICHECK_EQ(routing_logits.size(1), num_experts) << "routing_logits has incorrect shape.";
TVM_FFI_ICHECK(!routing_bias.has_value()) << "routing_bias is not supported for MxInt4 MoE.";
if (routing_bias.has_value()) {
TVM_FFI_ICHECK(routing_bias.value().dtype() == dl_bfloat16) << "routing_bias must be bfloat16.";
TVM_FFI_ICHECK_EQ(routing_bias.value().ndim(), 1) << "routing_bias must be 1D.";
TVM_FFI_ICHECK_EQ(routing_bias.value().size(0), num_experts)
<< "routing_bias has incorrect shape.";
}

// Determine activation type
TVM_FFI_ICHECK(gemm1_weights.dtype() == dl_uint8 && gemm2_weights.dtype() == dl_uint8)
Expand Down
12 changes: 11 additions & 1 deletion flashinfer/fused_moe/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1107,6 +1107,7 @@ def forward(
self.intermediate_size,
kwargs["local_expert_offset"],
self.num_local_experts,
kwargs["routed_scaling_factor"],
kwargs["routing_method_type"],
kwargs["use_shuffled_weight"],
kwargs["weight_layout"],
Expand Down Expand Up @@ -1284,6 +1285,7 @@ def trtllm_bf16_moe_op(
intermediate_size: int,
local_expert_offset: int,
local_num_experts: int,
routed_scaling_factor: Optional[float],
routing_method_type: int,
use_shuffled_weight: bool,
weight_layout: int,
Expand Down Expand Up @@ -1342,6 +1344,7 @@ def trtllm_bf16_moe_op(
topk_group=topk_group,
local_expert_offset=local_expert_offset,
local_num_experts=local_num_experts,
routed_scaling_factor=routed_scaling_factor,
routing_method_type=routing_method_type,
use_shuffled_weight=use_shuffled_weight,
weight_layout=weight_layout,
Expand All @@ -1362,6 +1365,7 @@ def trtllm_bf16_moe_op(
intermediate_size,
local_expert_offset,
local_num_experts,
routed_scaling_factor,
routing_method_type,
use_shuffled_weight,
weight_layout,
Expand Down Expand Up @@ -2091,6 +2095,7 @@ def trtllm_bf16_moe(
intermediate_size: int,
local_expert_offset: int,
local_num_experts: int,
routed_scaling_factor: Optional[float] = None,
routing_method_type: int = 0,
use_shuffled_weight: bool = True,
weight_layout: int = WeightLayout.BlockMajorK,
Expand Down Expand Up @@ -2120,6 +2125,7 @@ def trtllm_bf16_moe(
intermediate_size: Size of intermediate layer.
local_expert_offset: Offset of local experts in global expert space.
local_num_experts: Number of experts handled by this device.
routed_scaling_factor (Optional[float]): Scaling factor for routing (can be None for some routing methods)
routing_method_type: Type of routing method to use (default: 0).
- 0: Default (Softmax -> TopK)
- 1: Renormalize (TopK -> Softmax)
Expand Down Expand Up @@ -2150,6 +2156,7 @@ def trtllm_bf16_moe(
intermediate_size,
local_expert_offset,
local_num_experts,
routed_scaling_factor,
routing_method_type,
use_shuffled_weight,
weight_layout,
Expand Down Expand Up @@ -2575,6 +2582,7 @@ def trtllm_fp4_block_scale_routed_moe(
@flashinfer_api
def trtllm_mxint4_block_scale_moe(
routing_logits: torch.Tensor,
routing_bias: Optional[torch.Tensor],
hidden_states: torch.Tensor,
gemm1_weights: torch.Tensor,
gemm1_weights_scale: torch.Tensor,
Expand All @@ -2601,6 +2609,8 @@ def trtllm_mxint4_block_scale_moe(
Args:
routing_logits (torch.Tensor): shape [seq_len, num_experts]
Input tensor of routing logits. Supports float32, bfloat16.
routing_bias: Optional [num_experts] tensor of routing bias.
Must be bfloat16 if provided.
hidden_states (torch.Tensor): shape [seq_len, hidden_size]
Tensor of input hidden states. Supports bfloat16.
gemm1_weights (torch.Tensor): shape [num_experts, 2 * intermediate_size, hidden_size // 2]
Expand Down Expand Up @@ -2640,7 +2650,7 @@ def trtllm_mxint4_block_scale_moe(
"""
return get_trtllm_moe_sm100_module().trtllm_mxint4_block_scale_moe(
routing_logits,
None,
routing_bias,
hidden_states,
gemm1_weights,
gemm1_weights_scale,
Expand Down
22 changes: 19 additions & 3 deletions tests/moe/test_trtllm_gen_fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -752,18 +752,21 @@ def call_moe(
):
"""Call MoE with runtime input quantization + kernel execution (done at runtime)."""
expert_logits = kwargs["expert_logits"]
routing_bias = kwargs["routing_bias"]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For robustness and consistency with other parts of the code (e.g., how `enable_autotune` is read), it's better to use `kwargs.get("routing_bias")` instead of direct indexing. This prevents a `KeyError` if the key is missing.

Suggested change
- routing_bias = kwargs["routing_bias"]
+ routing_bias = kwargs.get("routing_bias")

num_experts = kwargs["num_experts"]
top_k = kwargs["top_k"]
n_groups = kwargs["n_groups"]
top_k_groups = kwargs["top_k_groups"]
intermediate_size = kwargs["intermediate_size"]
routing_method_type = kwargs["routing_method_type"]
enable_autotune = kwargs.get("enable_autotune", True)
routed_scaling = kwargs.get("routed_scaling", 1.0)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The default value of `1.0` is already handled in the C++ layer. To maintain a single source of truth for default values and for consistency, it's better to remove the default value here: `kwargs.get("routed_scaling")` returns `None` if the key is missing, and the C++ layer then falls back to its own default of `1.0`.

Suggested change
- routed_scaling = kwargs.get("routed_scaling", 1.0)
+ routed_scaling = kwargs.get("routed_scaling")


# Use autotuner for optimal kernel selection
with autotune(enable_autotune):
output = trtllm_mxint4_block_scale_moe(
expert_logits, # float
routing_bias,
hidden_states_orig,
static_data["gemm1_weights"],
static_data["gemm1_scales"],
Expand All @@ -779,7 +782,7 @@ def call_moe(
intermediate_size,
0,
num_experts,
1.0,
routed_scaling,
routing_method_type=routing_method_type,
tune_max_num_tokens=TUNE_MAX_NUM_TOKENS,
)
Expand Down Expand Up @@ -1308,6 +1311,7 @@ def call_moe(
n_groups = kwargs["n_groups"]
top_k_groups = kwargs["top_k_groups"]
intermediate_size = kwargs["intermediate_size"]
routed_scaling = kwargs["routed_scaling"]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For robustness and consistency, it's better to use `kwargs.get("routed_scaling")` instead of direct indexing. This prevents a `KeyError` if the key is missing; the call then returns `None`, which is a valid value for this optional parameter and is handled correctly by the downstream C++ function.

Suggested change
- routed_scaling = kwargs["routed_scaling"]
+ routed_scaling = kwargs.get("routed_scaling")

routing_method_type = kwargs["routing_method_type"]
enable_autotune = kwargs.get("enable_autotune", True)

Expand All @@ -1326,6 +1330,7 @@ def call_moe(
intermediate_size,
0,
num_experts,
routed_scaling,
use_shuffled_weight=static_data["use_shuffled_weight"],
weight_layout=static_data["weight_layout"],
routing_method_type=routing_method_type,
Expand Down Expand Up @@ -2626,6 +2631,8 @@ def test_renormalize_routing(
pytest.param(FP4Moe(quant_mode=QuantMode.FP4_NVFP4_NVFP4), id="NvFP4xNvFP4"),
pytest.param(FP4Moe(quant_mode=QuantMode.FP4_MXFP4_MXFP8), id="MxFP4xMxFP8"),
pytest.param(FP4Moe(quant_mode=QuantMode.FP4_MXFP4_Bf16), id="MxFP4xBf16"),
pytest.param(MxInt4BlockScaleMoe(), id="MxInt4xBf16"),
pytest.param(BF16Moe(), id="Bf16xBf16"),
],
)
@pytest.mark.parametrize(
Expand Down Expand Up @@ -2657,7 +2664,12 @@ def test_renormalize_routing(
"routed_scaling": 2.5,
"has_routing_bias": True,
"routing_method_type": RoutingMethodType.DeepSeekV3,
"compatible_moe_impls": [FP4Moe, FP8BlockScaleMoe],
"compatible_moe_impls": [
FP4Moe,
FP8BlockScaleMoe,
MxInt4BlockScaleMoe,
BF16Moe,
],
"compatible_intermediate_size": [512, 1024, 2048],
"enable_autotune": True,
},
Expand Down Expand Up @@ -2704,7 +2716,11 @@ def test_renormalize_routing(
{
"use_shuffled_weight": True,
"layout": WeightLayout.BlockMajorK,
"compatible_moe_impls": [FP8BlockScaleMoe],
"compatible_moe_impls": [
FP8BlockScaleMoe,
MxInt4BlockScaleMoe,
BF16Moe,
],
},
id="Shuffled_BlockMajorK",
),
Expand Down