
Commit fbdb439

Enable renormalize(naive) routing for fp8 per-tensor (#2030)
## πŸ“Œ Description

Disable expert-weight scaling in FC1 except for Llama4 routing, which enables Renormalize (naive) routing for the FP8 per-tensor path.

## πŸš€ Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### βœ… Pre-commit Checks

- [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [ ] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## πŸ§ͺ Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Summary by CodeRabbit

* **Bug Fixes**
  * Re-enabled Renormalize routing that was previously blocked.
  * Made token_scales available for Llama4 routing.
  * Corrected the GEMM1 input so the proper data source is used during MoE processing.
* **Tests**
  * Added FP8PerTensorMoe to the test parameterization.
  * Expanded Renormalize and DeepSeekV3 test coverage and removed related skips.

Signed-off-by: siyuanf <[email protected]>
1 parent d42fb90 commit fbdb439

File tree: 4 files changed (+27 βˆ’4 lines)
csrc/trtllm_fused_moe_kernel_launcher.cu

Lines changed: 3 additions & 0 deletions
```diff
@@ -584,6 +584,9 @@ class Fp8PerTensorLauncher : public FusedMoeLauncher {
         alloc_tensor({args->num_tokens, args->top_k}, dl_bfloat16, hidden_states.device());

     workspace.expert_weights = expert_weights.data_ptr();
+    if (static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::Llama4) {
+      workspace.token_scales = expert_weights.data_ptr();  // Consumed by permuteGemm1 kernel
+    }
   }

   void check_moe() const override {
```
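
The routing-dependent aliasing above is the core of the change: `token_scales` points at the same buffer as `expert_weights` only for Llama4 routing, so FC1 skips per-token expert-weight scaling under Renormalize and other routing methods. The following standalone sketch uses simplified, hypothetical types (not the real FlashInfer structures) to illustrate that rule:

```cpp
// Standalone sketch of the gating rule above. Types and names here are
// illustrative stand-ins, not the actual FlashInfer/TRT-LLM structures.
#include <cstdio>

enum class RoutingMethodType { Renormalize, Llama4, DeepSeekV3 };

struct Workspace {
  void* expert_weights = nullptr;  // always consumed by the finalize kernel
  void* token_scales = nullptr;    // consumed by the permuteGemm1 (FC1) kernel
};

// Mirror of the launcher change: expert weights reach GEMM1 only for Llama4.
void bind_workspace(Workspace& ws, void* expert_weights_ptr,
                    RoutingMethodType routing) {
  ws.expert_weights = expert_weights_ptr;
  ws.token_scales =
      (routing == RoutingMethodType::Llama4) ? expert_weights_ptr : nullptr;
}

int main() {
  float dummy[4] = {};
  Workspace ws;

  bind_workspace(ws, dummy, RoutingMethodType::Renormalize);
  std::printf("Renormalize: token_scales is %s\n",
              ws.token_scales ? "set" : "null (FC1 runs unscaled)");

  bind_workspace(ws, dummy, RoutingMethodType::Llama4);
  std::printf("Llama4:      token_scales is %s\n",
              ws.token_scales ? "set (FC1 applies expert weights)" : "null");
  return 0;
}
```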

csrc/trtllm_fused_moe_runner.cu

Lines changed: 1 addition & 1 deletion
```diff
@@ -518,7 +518,7 @@ void Runner::run(MoERunnerArgs const& args, MoEWorkspace const& workspace, int d
   auto const& config = mPassingConfigs[configIndex];

   mPermuteGemm1.run(args.hidden_states, hidden_states_scale_linear, args.gemm1_weights,
-                    args.gemm1_weights_scale, workspace.expert_weights, args.output1_scales_scalar,
+                    args.gemm1_weights_scale, workspace.token_scales, args.output1_scales_scalar,
                     args.output1_scales_gate_scalar, args.gemm1_bias, args.gemm1_alpha,
                     args.gemm1_beta, args.gemm1_clamp_limit, workspace.gemm1_output,
                     workspace.gemm1_output_scale, args.top_k, args.hidden_size,
```
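
On the consumer side, GEMM1 now reads `workspace.token_scales` instead of `workspace.expert_weights`, while the finalize kernel continues to consume `expert_weights`. The minimal mock below (placeholder kernels and types, not the actual TRT-LLM runner) shows why a null `token_scales` leaves FC1 unscaled:

```cpp
// Mock of how the two workspace pointers are consumed downstream.
// Kernel names are placeholders for illustration only.
#include <cassert>
#include <cstdio>

struct MoEWorkspaceSketch {
  void* expert_weights = nullptr;  // consumed by the finalize kernel
  void* token_scales = nullptr;    // consumed by the permuteGemm1 kernel
};

// FC1: applies per-token scales only when token_scales is non-null.
void permute_gemm1_sketch(const MoEWorkspaceSketch& ws) {
  std::printf("permuteGemm1: %s\n",
              ws.token_scales ? "scaling activations by expert weights (Llama4 path)"
                              : "running unscaled (Renormalize path)");
}

// Finalize: always weights the expert outputs by expert_weights.
void finalize_sketch(const MoEWorkspaceSketch& ws) {
  assert(ws.expert_weights != nullptr);
  std::printf("finalize: applying expert_weights\n");
}

int main() {
  float weights[8] = {};
  MoEWorkspaceSketch ws;
  ws.expert_weights = weights;  // Renormalize: token_scales stays null
  permute_gemm1_sketch(ws);
  finalize_sketch(ws);
  return 0;
}
```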

include/flashinfer/trtllm/fused_moe/runner.h

Lines changed: 4 additions & 0 deletions
```diff
@@ -305,7 +305,11 @@ struct MoEWorkspace {
   int32_t* expanded_idx_to_permuted_idx = nullptr;
   int32_t* permuted_idx_to_expanded_idx = nullptr;
   int32_t* permuted_idx_to_token_idx = nullptr;
+
+  // consumed by finalize kernel
   void* expert_weights = nullptr;  // [num_tokens, top_k] in bfloat16 = mDtypeExpW
+  // consumed by permuteGemm1 kernel
+  void* token_scales = nullptr;

   int32_t* cta_idx_xy_to_batch_idx = nullptr;
   int32_t* cta_idx_xy_to_mn_limit = nullptr;
```

tests/moe/test_trtllm_gen_fused_moe.py

Lines changed: 19 additions & 3 deletions
```diff
@@ -2275,6 +2275,7 @@ def run_moe_test(
     [
         pytest.param(BF16Moe(), id="BF16xBF16"),
         pytest.param(FP8BlockScaleMoe(), id="FP8_Block"),
+        pytest.param(FP8PerTensorMoe(), id="FP8_Tensor"),
         pytest.param(FP4Moe(quant_mode=QuantMode.FP4_NVFP4_NVFP4), id="NvFP4xNvFP4"),
         pytest.param(FP4Moe(quant_mode=QuantMode.FP4_MXFP4_MXFP8), id="MxFP4xMxFP8"),
         pytest.param(FP4Moe(quant_mode=QuantMode.FP4_MXFP4_Bf16), id="MxFP4xBf16"),
@@ -2293,7 +2294,12 @@ def run_moe_test(
             "routed_scaling": None,
             "has_routing_bias": False,
             "routing_method_type": RoutingMethodType.Renormalize,
-            "compatible_moe_impls": [FP8BlockScaleMoe, FP4Moe, BF16Moe],
+            "compatible_moe_impls": [
+                FP8PerTensorMoe,
+                FP8BlockScaleMoe,
+                FP4Moe,
+                BF16Moe,
+            ],
             "compatible_intermediate_size": [384, 768, 1024],
         },
         id="Qwen3",
@@ -2308,7 +2314,12 @@ def run_moe_test(
             "routed_scaling": None,
             "has_routing_bias": False,
             "routing_method_type": RoutingMethodType.Renormalize,
-            "compatible_moe_impls": [FP8BlockScaleMoe, FP4Moe, BF16Moe],
+            "compatible_moe_impls": [
+                FP8PerTensorMoe,
+                FP8BlockScaleMoe,
+                FP4Moe,
+                BF16Moe,
+            ],
             "compatible_intermediate_size": [384, 1024],
         },
         id="Renorm",
@@ -2323,7 +2334,12 @@ def run_moe_test(
             "routed_scaling": None,
             "has_routing_bias": False,
             "routing_method_type": RoutingMethodType.Renormalize,
-            "compatible_moe_impls": [FP8BlockScaleMoe, FP4Moe, BF16Moe],
+            "compatible_moe_impls": [
+                FP8PerTensorMoe,
+                FP8BlockScaleMoe,
+                FP4Moe,
+                BF16Moe,
+            ],
             "compatible_intermediate_size": [512],
         },
         id="Qwen3_next",
```
