-
Notifications
You must be signed in to change notification settings - Fork 828
fix: add DeepSeek routing for Bf16xBf16 and MxIntxBf16 TRT-LLM Gen MoE #2234
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. Weβll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -752,18 +752,21 @@ def call_moe( | |||||
| ): | ||||||
| """Call MoE with runtime input quantization + kernel execution (done at runtime).""" | ||||||
| expert_logits = kwargs["expert_logits"] | ||||||
| routing_bias = kwargs["routing_bias"] | ||||||
| num_experts = kwargs["num_experts"] | ||||||
| top_k = kwargs["top_k"] | ||||||
| n_groups = kwargs["n_groups"] | ||||||
| top_k_groups = kwargs["top_k_groups"] | ||||||
| intermediate_size = kwargs["intermediate_size"] | ||||||
| routing_method_type = kwargs["routing_method_type"] | ||||||
| enable_autotune = kwargs.get("enable_autotune", True) | ||||||
| routed_scaling = kwargs.get("routed_scaling", 1.0) | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The default value of
Suggested change
|
||||||
|
|
||||||
| # Use autotuner for optimal kernel selection | ||||||
| with autotune(enable_autotune): | ||||||
| output = trtllm_mxint4_block_scale_moe( | ||||||
| expert_logits, # float | ||||||
| routing_bias, | ||||||
| hidden_states_orig, | ||||||
| static_data["gemm1_weights"], | ||||||
| static_data["gemm1_scales"], | ||||||
|
|
@@ -779,7 +782,7 @@ def call_moe( | |||||
| intermediate_size, | ||||||
| 0, | ||||||
| num_experts, | ||||||
| 1.0, | ||||||
| routed_scaling, | ||||||
| routing_method_type=routing_method_type, | ||||||
| tune_max_num_tokens=TUNE_MAX_NUM_TOKENS, | ||||||
| ) | ||||||
|
|
@@ -1308,6 +1311,7 @@ def call_moe( | |||||
| n_groups = kwargs["n_groups"] | ||||||
| top_k_groups = kwargs["top_k_groups"] | ||||||
| intermediate_size = kwargs["intermediate_size"] | ||||||
| routed_scaling = kwargs["routed_scaling"] | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For robustness and consistency, it's better to use
Suggested change
|
||||||
| routing_method_type = kwargs["routing_method_type"] | ||||||
| enable_autotune = kwargs.get("enable_autotune", True) | ||||||
|
|
||||||
|
|
@@ -1326,6 +1330,7 @@ def call_moe( | |||||
| intermediate_size, | ||||||
| 0, | ||||||
| num_experts, | ||||||
| routed_scaling, | ||||||
| use_shuffled_weight=static_data["use_shuffled_weight"], | ||||||
| weight_layout=static_data["weight_layout"], | ||||||
| routing_method_type=routing_method_type, | ||||||
|
|
@@ -2626,6 +2631,8 @@ def test_renormalize_routing( | |||||
| pytest.param(FP4Moe(quant_mode=QuantMode.FP4_NVFP4_NVFP4), id="NvFP4xNvFP4"), | ||||||
| pytest.param(FP4Moe(quant_mode=QuantMode.FP4_MXFP4_MXFP8), id="MxFP4xMxFP8"), | ||||||
| pytest.param(FP4Moe(quant_mode=QuantMode.FP4_MXFP4_Bf16), id="MxFP4xBf16"), | ||||||
| pytest.param(MxInt4BlockScaleMoe(), id="MxInt4xBf16"), | ||||||
| pytest.param(BF16Moe(), id="Bf16xBf16"), | ||||||
| ], | ||||||
| ) | ||||||
| @pytest.mark.parametrize( | ||||||
|
|
@@ -2657,7 +2664,12 @@ def test_renormalize_routing( | |||||
| "routed_scaling": 2.5, | ||||||
| "has_routing_bias": True, | ||||||
| "routing_method_type": RoutingMethodType.DeepSeekV3, | ||||||
| "compatible_moe_impls": [FP4Moe, FP8BlockScaleMoe], | ||||||
| "compatible_moe_impls": [ | ||||||
| FP4Moe, | ||||||
| FP8BlockScaleMoe, | ||||||
| MxInt4BlockScaleMoe, | ||||||
| BF16Moe, | ||||||
| ], | ||||||
| "compatible_intermediate_size": [512, 1024, 2048], | ||||||
| "enable_autotune": True, | ||||||
| }, | ||||||
|
|
@@ -2704,7 +2716,11 @@ def test_renormalize_routing( | |||||
| { | ||||||
| "use_shuffled_weight": True, | ||||||
| "layout": WeightLayout.BlockMajorK, | ||||||
| "compatible_moe_impls": [FP8BlockScaleMoe], | ||||||
| "compatible_moe_impls": [ | ||||||
| FP8BlockScaleMoe, | ||||||
| MxInt4BlockScaleMoe, | ||||||
| BF16Moe, | ||||||
| ], | ||||||
| }, | ||||||
| id="Shuffled_BlockMajorK", | ||||||
| ), | ||||||
|
|
||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For robustness and consistency with other parts of the code (e.g.,
enable_autotune), it's better to usekwargs.get("routing_bias")instead of direct access. This will prevent aKeyErrorif the key is missing.