Feat: Trtllm-gen MxFP8 MoE integration #2505
Conversation
Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
📝 Walkthrough

Adds an FP8 quantization enum and MxFP8 support across Python, C++ launchers, benchmarks, and tests.

Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~75 minutes
Summary of Changes

Hello @IwakuraRein, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request enhances the TensorRT-LLM fused Mixture-of-Experts (MoE) implementation by integrating MxFP8 quantization. This provides a new, flexible FP8 quantization option alongside the existing DeepSeek FP8, allowing fine-grained control over mixed-precision computation. The changes span core kernel logic, benchmarking, and testing, ensuring that the new quantization mode is robustly supported and validated across the system.
Activity
Code Review
This pull request integrates mxfp8 support into the trtllm fused MoE kernels. The changes are extensive, touching benchmark scripts, C++ kernel launchers, and Python bindings. The introduction of Fp8QuantizationType is a good refactoring that makes the code more extensible. The tests have also been updated to cover the new quantization modes.
My review focuses on improving code maintainability by reducing duplication in the benchmark scripts and C++ kernel launcher. I've also pointed out some leftover debugging code and minor issues that should be addressed before merging.
```python
print(f"No autotune: {ms:.3f} ms; with autotune: {ms_tuned:.3f} ms")
```

```python
def bench_trtllm_gen_fused_moe_autotuner_mxint4(
```
This function bench_trtllm_gen_fused_moe_autotuner_mxint4 is very similar to bench_trtllm_gen_fused_moe_autotuner_fp8 and bench_trtllm_gen_fused_moe_autotuner_fp4. To improve maintainability and reduce code duplication, consider refactoring these into a more generic benchmark function or a base class. This could accept quantization functions and the specific MoE kernel as parameters, centralizing the common benchmarking logic.
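As a sketch of that suggested refactor (the names `bench_fused_moe_autotuner`, `quantize_fn`, and `moe_fn` are hypothetical placeholders, not existing flashinfer APIs), the shared timing logic could be factored out roughly like this:

```python
# Hypothetical generic benchmark driver: the per-mode quantization setup and
# the MoE kernel are passed in as callables, so the fp8/fp4/mxint4 variants
# only differ in the arguments they supply.
import time
from typing import Any, Callable


def bench_fused_moe_autotuner(
    quantize_fn: Callable[[Any], Any],
    moe_fn: Callable[[Any], Any],
    inputs: Any,
    warmup: int = 3,
    iters: int = 10,
) -> float:
    """Quantize once, warm up, then return mean latency per call in ms."""
    q_inputs = quantize_fn(inputs)
    for _ in range(warmup):
        moe_fn(q_inputs)
    start = time.perf_counter()
    for _ in range(iters):
        moe_fn(q_inputs)
    return (time.perf_counter() - start) * 1e3 / iters
```

Each existing `bench_trtllm_gen_fused_moe_autotuner_*` function would then reduce to one call with its mode-specific callables.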
```diff
   FusedMoeLauncher::check_moe_common();

   TVM_FFI_ICHECK_EQ(hidden_states.dtype(), dl_float8_e4m3fn) << "hidden_states must be fp8.";
-  TVM_FFI_ICHECK_EQ(hidden_states_scale.dtype(), dl_float32)
-      << "hidden_states_scale must be float.";
-  TVM_FFI_ICHECK_EQ(hidden_states_scale.ndim(), 2) << "hidden_states_scale must be 2D.";
-  TVM_FFI_ICHECK_EQ(hidden_states_scale.size(0), hidden_states.size(1) / 128)
-      << "hidden_states_scale dim0 must match hidden_states dim1 / 128.";
-  TVM_FFI_ICHECK_EQ(hidden_states_scale.size(1), args->num_tokens)
-      << "hidden_states_scale dim1 must match num_tokens.";
+  if (quantization_type == Fp8QuantizationType::DeepSeekFp8) {
+    TVM_FFI_ICHECK_EQ(hidden_states_scale.dtype(), dl_float32)
+        << "hidden_states_scale must be float.";
+    TVM_FFI_ICHECK_EQ(hidden_states_scale.ndim(), 2) << "hidden_states_scale must be 2D.";
+    TVM_FFI_ICHECK_EQ(hidden_states_scale.size(0), hidden_states.size(1) / 128)
+        << "hidden_states_scale dim0 must match hidden_states dim1 / 128.";
+    TVM_FFI_ICHECK_EQ(hidden_states_scale.size(1), args->num_tokens)
+        << "hidden_states_scale dim1 must match num_tokens.";
+  } else if (quantization_type == Fp8QuantizationType::MxFp8) {
+    TVM_FFI_ICHECK_EQ(hidden_states_scale.dtype(), dl_uint8);
+  }

   TVM_FFI_ICHECK_EQ(gemm1_weights.dtype(), dl_float8_e4m3fn) << "gemm1_weights must be fp8.";
   TVM_FFI_ICHECK_EQ(gemm2_weights.dtype(), dl_float8_e4m3fn) << "gemm2_weights must be fp8.";

-  TVM_FFI_ICHECK_EQ(gemm1_weights_scale.dtype(), dl_float32)
-      << "gemm1_weights_scale must be float.";
-  TVM_FFI_ICHECK_EQ(gemm1_weights_scale.ndim(), 3) << "gemm1_weights_scale must be 3D.";
-  TVM_FFI_ICHECK_EQ(gemm1_weights_scale.size(0), args->local_num_experts)
-      << "gemm1_weights_scale has incorrect shape.";
-  TVM_FFI_ICHECK_EQ(args->intermediate_size % 128, 0)
-      << "intermediate_size must be a multiple of 128.";
-  TVM_FFI_ICHECK_EQ(gemm1_weights_scale.size(1), 2 * args->intermediate_size / 128)
-      << "gemm1_weights_scale has incorrect shape.";
-  TVM_FFI_ICHECK_EQ(gemm1_weights_scale.size(2), args->hidden_size / 128)
-      << "gemm1_weights_scale has incorrect shape.";
+  if (quantization_type == Fp8QuantizationType::DeepSeekFp8) {
+    TVM_FFI_ICHECK_EQ(gemm1_weights_scale.dtype(), dl_float32)
+        << "gemm1_weights_scale must be float.";
+    TVM_FFI_ICHECK_EQ(gemm1_weights_scale.ndim(), 3) << "gemm1_weights_scale must be 3D.";
+    TVM_FFI_ICHECK_EQ(gemm1_weights_scale.size(0), args->local_num_experts)
+        << "gemm1_weights_scale has incorrect shape.";
+    TVM_FFI_ICHECK_EQ(args->intermediate_size % 128, 0)
+        << "intermediate_size must be a multiple of 128.";
+    TVM_FFI_ICHECK_EQ(gemm1_weights_scale.size(1), 2 * args->intermediate_size / 128)
+        << "gemm1_weights_scale has incorrect shape.";
+    TVM_FFI_ICHECK_EQ(gemm1_weights_scale.size(2), args->hidden_size / 128)
+        << "gemm1_weights_scale has incorrect shape.";
+  } else if (quantization_type == Fp8QuantizationType::MxFp8) {
+    TVM_FFI_ICHECK_EQ(gemm1_weights_scale.dtype(), dl_uint8)
+        << "gemm1_weights_scale must be uint8.";
+  }

-  TVM_FFI_ICHECK_EQ(gemm2_weights_scale.dtype(), dl_float32)
-      << "gemm2_weights_scale must be float.";
-  TVM_FFI_ICHECK_EQ(gemm2_weights_scale.ndim(), 3) << "gemm2_weights_scale must be 3D.";
-  TVM_FFI_ICHECK_EQ(gemm2_weights_scale.size(0), args->local_num_experts)
-      << "gemm2_weights_scale has incorrect shape.";
-  TVM_FFI_ICHECK_EQ(gemm2_weights_scale.size(1), args->hidden_size / 128)
-      << "gemm2_weights_scale has incorrect shape.";
-  TVM_FFI_ICHECK_EQ(gemm2_weights_scale.size(2), args->intermediate_size / 128)
-      << "gemm2_weights_scale has incorrect shape.";
+  if (quantization_type == Fp8QuantizationType::DeepSeekFp8) {
+    TVM_FFI_ICHECK_EQ(gemm2_weights_scale.dtype(), dl_float32)
+        << "gemm2_weights_scale must be float.";
+    TVM_FFI_ICHECK_EQ(gemm2_weights_scale.ndim(), 3) << "gemm2_weights_scale must be 3D.";
+    TVM_FFI_ICHECK_EQ(gemm2_weights_scale.size(0), args->local_num_experts)
+        << "gemm2_weights_scale has incorrect shape.";
+    TVM_FFI_ICHECK_EQ(gemm2_weights_scale.size(1), args->hidden_size / 128)
+        << "gemm2_weights_scale has incorrect shape.";
+    TVM_FFI_ICHECK_EQ(gemm2_weights_scale.size(2), args->intermediate_size / 128)
+        << "gemm2_weights_scale has incorrect shape.";
+  } else if (quantization_type == Fp8QuantizationType::MxFp8) {
+    TVM_FFI_ICHECK_EQ(gemm2_weights_scale.dtype(), dl_uint8)
+        << "gemm2_weights_scale must be uint8.";
+  }

   check_weights_shape("gemm1");
   check_weights_shape("gemm2");
-  TVM_FFI_ICHECK_EQ(args->intermediate_size % 128, 0)
-      << "intermediate_size must be a multiple of 128.";
+  if (quantization_type == Fp8QuantizationType::DeepSeekFp8) {
+    TVM_FFI_ICHECK_EQ(args->intermediate_size % 128, 0)
+        << "intermediate_size must be a multiple of 128.";
+  }
 }
```
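For readers less fluent in the TVM-FFI check macros, here is a hedged Python mirror of the per-mode `hidden_states_scale` validation above (the function name and the plain string/tuple arguments are illustrative; the launcher operates on DLPack tensors, not these values):

```python
def check_hidden_states_scale(dtype, shape, quant_type, hidden_size, num_tokens):
    """Validate the activation-scale tensor layout for each FP8 mode."""
    if quant_type == "DeepSeekFp8":
        # Per-128-channel fp32 scales, laid out [hidden_size / 128, num_tokens].
        assert dtype == "float32", "hidden_states_scale must be float."
        assert len(shape) == 2, "hidden_states_scale must be 2D."
        assert shape[0] == hidden_size // 128, "dim0 must be hidden_size / 128."
        assert shape[1] == num_tokens, "dim1 must match num_tokens."
    elif quant_type == "MxFp8":
        # MxFP8 packs E8M0 scale bytes; only the dtype is checked here.
        assert dtype == "uint8", "hidden_states_scale must be uint8."
    else:
        raise ValueError(f"unknown quantization type: {quant_type}")
```

This makes the asymmetry explicit: DeepSeek FP8 validates dtype and full shape, while the MxFP8 branch only checks the dtype.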
Hi @IwakuraRein. Currently we use this in sglang; however, it seems we are missing a cubin for some dims. I built from source on this branch at commit 1dc688d. Context: we are building the sglang MXFP8 trtllm_moe runner along with the mm_mxfp8 flashinfer modelopt linear, so this would be quite useful. If it turns out that my usage is wrong, that's user error, but even after inspecting the cubins, it seems this shape should be available. Do you have any ideas? Should there be a tileSize=64 cubin?
@vincentzed Hi. There are tile size 64 cubins for mxfp8. I tried your problem shape and cannot reproduce the error. Could you try pulling the latest commit? 1dc688d won't compile due to a typo, so maybe flashinfer is using the old JIT cache.
Force-pushed from 0adc056 to aae1719
/bot run

[CANCELING] Pipeline #43998281: canceled
⚠️ Outside diff range comments (1)
csrc/trtllm_fused_moe_kernel_launcher.cu (1)
1079-1105: ⚠️ Potential issue | 🔴 Critical

`getValidConfigs` uses the wrong Runner constructor for MxFp8, causing a config mismatch with the runtime.

For MxFp8, `prepare_moe_common` (lines 326–335) constructs the Runner with the two-dtype constructor (passing `mDtypeAct`, `mDtypeWeights`, `activation_type`) when the condition `E4m3 && E4m3 && mUseDeepSeekFp8` is false. However, `getValidConfigs` always uses the weights-only constructor (lines 1091–1094), regardless of `quantization_type`. This means config enumeration and the actual kernel runner see different valid config sets, which is the root cause of "No kernel found" errors at runtime.

Proposed fix: branch `getValidConfigs` to match the `prepare_moe_common` logic:
```diff
 for (int32_t tile_N : selected_tile_nums) {
-  auto moe_runner = std::make_unique<tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner>(
-      dtype_weights,  // dtype_weights for DeepSeek FP8
-      quantization_type == Fp8QuantizationType::DeepSeekFp8,  // useDeepSeekFp8
-      tile_N, use_shuffled_weight, static_cast<batchedGemm::gemm::MatrixLayout>(weight_layout));
+  std::unique_ptr<tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner> moe_runner;
+  if (quantization_type == Fp8QuantizationType::DeepSeekFp8) {
+    moe_runner = std::make_unique<tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner>(
+        dtype_weights, true /* useDeepSeekFp8 */, tile_N, use_shuffled_weight,
+        static_cast<batchedGemm::gemm::MatrixLayout>(weight_layout));
+  } else {
+    // MxFp8: match two-dtype constructor from prepare_moe_common
+    moe_runner = std::make_unique<tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner>(
+        dtype_weights, dtype_weights, false /* useDeepSeekFp8 */, tile_N,
+        ActivationType::Swiglu, use_shuffled_weight,
+        static_cast<batchedGemm::gemm::MatrixLayout>(weight_layout));
+  }
   auto cfgs = moe_runner->getValidConfigIndices(top_k, hidden_size, intermediate_size,
                                                 num_local_experts, num_tokens);
```
🧹 Nitpick comments (2)
csrc/trtllm_fused_moe_kernel_launcher.cu (2)
1004-1012: MxFp8 path does not explicitly set `workspace.activation_output` / `workspace.activation_output_scale`.

Only the DeepSeekFp8 branch (lines 1007–1010) assigns these workspace pointers; the MxFp8 path relies on implicit zero-initialization. Consider explicitly setting them to `nullptr` to be safe against future refactors where `prepare_moe` might be re-entered or the workspace partially reused.

Proposed fix:
```diff
 if (quantization_type == Fp8QuantizationType::DeepSeekFp8) {
   workspace.activation_output = activation_output.data_ptr();
   workspace.activation_output_scale = static_cast<float*>(activation_output_scale.data_ptr());
+} else {
+  workspace.activation_output = nullptr;
+  workspace.activation_output_scale = nullptr;
 }
```
1006-1006: `static_cast<float*>` on a `dl_uint8` tensor for MxFp8: type mismatch in the workspace pointer.

For MxFp8, `gemm1_output_scale` is allocated as `dl_uint8` (line 990), but line 1006 unconditionally casts it to `float*`. The kernel likely consumes only the raw address, but the cast is misleading and could mask bugs if the workspace struct gains type safety. Consider a `void*` intermediate or a comment noting the intentional reinterpretation.
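To make the uint8 scale layout concrete: in MxFP8, each block of 32 fp8 (e4m3) elements shares a power-of-two scale stored as a single biased-exponent byte (E8M0), which is why the scale tensors discussed above are `dl_uint8` rather than `float`. A rough reference sketch, following the OCP MX convention (the kernel's exact rounding and saturation behavior may differ, and exponent clamping to the E8M0 range is omitted here):

```python
import math

FP8_E4M3_MAX = 448.0   # largest finite e4m3 magnitude
BLOCK = 32             # MX block size
E8M0_BIAS = 127        # bias for the power-of-two scale byte


def quantize_mxfp8_block(vals):
    """Quantize one 32-element block: returns (scale_byte, scaled values).

    The shared scale is 2^(floor(log2(amax)) - floor(log2(FP8_E4M3_MAX))),
    stored as one biased-exponent byte. Rounding of elements to actual
    e4m3 codes is omitted; out-of-range elements saturate.
    """
    assert len(vals) == BLOCK
    amax = max(abs(v) for v in vals)
    if amax == 0.0:
        return E8M0_BIAS, [0.0] * BLOCK  # scale = 2^0 for an all-zero block
    exp = math.floor(math.log2(amax)) - math.floor(math.log2(FP8_E4M3_MAX))
    scale = 2.0 ** exp
    scaled = [max(-FP8_E4M3_MAX, min(FP8_E4M3_MAX, v / scale)) for v in vals]
    return exp + E8M0_BIAS, scaled


def dequantize_mxfp8_block(scale_byte, scaled):
    """Invert the block quantization (up to e4m3 rounding/saturation)."""
    scale = 2.0 ** (scale_byte - E8M0_BIAS)
    return [v * scale for v in scaled]
```

Since the scale byte is just a biased exponent, passing the scale buffer around as a raw address is workable, but the `float*` cast flagged above hides that it never holds floats in the MxFP8 path.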
Force-pushed from 3e0dbdd to 03cac02
/bot run

[FAILED] Pipeline #44028049: 14/20 passed
Hey @IwakuraRein, we want to use it with Nemotron models.

Hi @danisereb, currently the cubins for Relu2 are not generated yet. We can add them in another PR.
📌 Description

@HumansAnd #2505 implements mxfp8 for the trtllm backend. However, in SGLang, `--moe-runner-backend flashinfer_trtllm` bypasses SGLang's topk implementation and does not work with expert routing replay in MoE RL. We want to implement `mxfp8 x mxfp8` for `cutlass_fused_moe`, which works with MoE RL training. This PR mainly reuses the existing code path for `WMxfp4AMxfp8Quant`: https://github.com/flashinfer-ai/flashinfer/blob/952b6ab2838d676b4257fcc23bb00f67fdd38efc/csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_binding.cu#L1191

🔍 Related Issues

miles MXFP8/NVFP4 RL roadmap: radixark/miles#615
SGLang FlashInfer MXFP8 integration: sgl-project/sglang#18945

Summary by CodeRabbit

New Features
- Toggleable MXFPX/MXFP8 activation-scaling across MoE inference, updating workspace sizing, kernel selection, block-scaling, and dispatch to enable MXFP8-aware execution and validation.
- Added MXFP8×MXFP8 quantization mode and emitted MXFPX-aware GEMM/kernel variants; public APIs now expose an MXFPX/activation-scaling flag.

Tests
- Added unit tests and helpers for MXFP8 quantization, packing/dequantization, and end-to-end MXFP8×MXFP8 MoE inference validation.
📌 Description
Author: @nekorobov
Add the trtllm-gen mxfp8 moe. It uses the existing `trtllm_fp8_block_scale_moe` api and can be selected by setting `fp8_quantization_type`.

🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

Reviewer Notes