[feat] trtllm-gen mxfp8 gemm #2653
Conversation
📝 Walkthrough

Adds a TRTLLM backend and runners for quantized GEMM (MXFP8/FP4/FP8) and centralizes TRTLLM-related enums into `flashinfer/tllm_enums.py`.
Sequence Diagram

```mermaid
sequenceDiagram
    participant Client as Python Client
    participant API as flashinfer API
    participant Dispatcher as Backend Dispatcher
    participant TrtRunner as TRTLLM Runner
    participant KernelMod as TRTLLM Kernel Module
    participant GPU as CUDA Device
    Client->>API: call mm_mxfp8(..., backend="trtllm", sf_layout=...)
    API->>Dispatcher: select backend and prepare per-backend inputs
    Dispatcher->>TrtRunner: pass inputs, sf_layout, dtypes, options
    TrtRunner->>KernelMod: request tactics/configs (metadata)
    KernelMod->>TrtRunner: return tactics/configs
    loop try tactics
        TrtRunner->>GPU: load/instantiate kernel with chosen tactic
    end
    TrtRunner->>GPU: execute trtllm_gemm with inputs, scales, ptrScaleAct, sparsity info
    GPU->>Client: return output
```
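The tactic-trial loop in the sequence diagram can be sketched in plain Python. This is an illustrative stand-in, not FlashInfer's actual API: the class and function names (`KernelModule`, `select_kernel`, the tactic dictionaries) are hypothetical, and the "instantiation" failure simply simulates a tactic that is unsupported on the current device.

```python
# Illustrative sketch (not FlashInfer's real API) of the tactic-trial loop:
# the runner asks the kernel module for candidate tactics and keeps the
# first one that instantiates successfully.

class KernelModule:
    """Stand-in for the TRTLLM kernel module: exposes candidate tactics."""

    def tactics(self):
        # Hypothetical tactic metadata; real tactics carry tile sizes,
        # cluster dims, stage counts, etc.
        return [{"tile": (128, 256), "ok": False},
                {"tile": (128, 128), "ok": True},
                {"tile": (64, 128), "ok": True}]

    def instantiate(self, tactic):
        # Pretend instantiation fails for unsupported tactics.
        if not tactic["ok"]:
            raise RuntimeError("tactic unsupported on this device")
        # Return a trivial matmul as the "kernel".
        return lambda a, b: [[sum(x * y for x, y in zip(row, col))
                              for col in zip(*b)] for row in a]

def select_kernel(module):
    """Try tactics in order; return the first (tactic, kernel) that works."""
    last_err = None
    for tactic in module.tactics():
        try:
            return tactic, module.instantiate(tactic)
        except RuntimeError as err:
            last_err = err
    raise RuntimeError("no viable tactic") from last_err

tactic, kernel = select_kernel(KernelModule())
print(tactic["tile"])                # first viable tactic: (128, 128)
print(kernel([[1, 2]], [[3], [4]]))  # 1*3 + 2*4 = [[11]]
```

In the real runner the loop also measures each viable tactic so autotuning can pick the fastest, not just the first, candidate.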
Estimated code review effort: 🎯 5 (Critical) | ⏱️ ~120 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
Summary of Changes

This pull request enables MXFP8 GEMM operations, expanding the mixed-precision capabilities of the system. The changes span the C++ and Python components and focus on improving the flexibility and robustness of GEMM configurations. Key updates include explicit handling of input/output data types, support for validating problem dimensions, and the introduction of sparsity and CUDA architecture-specific optimizations. The refactoring also centralizes enum definitions for better code maintainability.
Code Review
This pull request introduces support for mxfp8 GEMM using trtllm-gen kernels, which involves a substantial refactoring of both Python and C++ code to create a more generalized GEMM infrastructure. The Python code has been improved by centralizing enums and creating a generic GEMM runner factory. The C++ side sees extensive updates to support new hardware features like sparsity and flexible cluster dimensions. The changes are well-structured and enhance the project's capabilities. I have one suggestion to improve code clarity and remove a redundancy in the newly added enum definitions.
Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
/bot run

/bot run

[FAILED] Pipeline #46298268: 10/20 passed

/bot run

[FAILED] Pipeline #46494877: 6/20 passed

/bot run

[CANCELING] Pipeline #46541481: canceled

/bot run

[SUCCESS] Pipeline #46551653: 14/20 passed
📌 Description

- Add `flashinfer/tllm_enums.py` for storing trtllm-gen related enums.
- Add a `trtllm` backend to `mm_mxfp8`, taking `use_8x4_sf_layout` as the last argument.
- Add `get_trtllm_gemm_module()`.
- Update `fp8Quantize.cpp` to support either the 128x4 or the 8x4 swizzle layout for scale factors. This is needed since the first matrix of the trtllm-gen mxfp8 GEMM can use the 8x4 swizzle layout.

🔍 Related Issues
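To make the 128x4 vs 8x4 distinction concrete, here is a simplified sketch of tiled ("swizzled") scale-factor storage: scale factors for a `rows x cols` grid are stored tile-by-tile, each tile holding `tile_r x tile_c` entries contiguously. This is an assumption-laden illustration only; the real trtllm-gen layouts interleave further inside each tile, and `sf_offset` is a hypothetical helper, not a FlashInfer function. It only shows why the tile shape changes where a given scale factor lands in memory.

```python
# Simplified sketch of a tiled scale-factor layout (illustrative only;
# the actual trtllm-gen 128x4 / 8x4 layouts add further interleaving).

def sf_offset(r, c, cols, tile_r, tile_c):
    """Linear offset of scale factor (r, c) under row-major tiles."""
    tiles_per_row = (cols + tile_c - 1) // tile_c   # tiles along the columns
    tile_idx = (r // tile_r) * tiles_per_row + (c // tile_c)
    within = (r % tile_r) * tile_c + (c % tile_c)   # position inside the tile
    return tile_idx * (tile_r * tile_c) + within

# The same element lands at different offsets under 128x4 vs 8x4 tiling:
print(sf_offset(9, 1, cols=8, tile_r=128, tile_c=4))  # 128x4 tiles -> 37
print(sf_offset(9, 1, cols=8, tile_r=8, tile_c=4))    # 8x4 tiles   -> 69
```

The smaller 8x4 tile lets a short first matrix (few rows) avoid padding its scale factors up to 128 rows per tile.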
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- [x] Installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] Installed the hooks with `pre-commit install`.
- [x] Ran `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
Unit tests: `pytest tests/gemm/test_mm_mxfp8.py -k trtllm`

Benchmark: `python benchmarks/flashinfer_benchmark.py --routine mm_mxfp8 --backend trtllm [--use_128x4_sf_layout]`

Reviewer Notes
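For reviewers unfamiliar with the MX format, the kind of invariant the unit tests exercise can be illustrated with a pure-Python sketch of block scale-factor computation. This is not `fp8Quantize.cpp`: the block size (32), the E4M3 max-normal value (~448), and the scale-exponent formula are standard MX assumptions, and rounding elements to the actual E4M3 grid is omitted. Each 32-value block shares one power-of-two scale chosen so the block's largest value fits the E4M3 range.

```python
import math

# Illustrative MX-style shared-scale computation (not the real quantizer):
# one power-of-two (E8M0-style) scale per 32-value block.

BLOCK = 32
E4M3_MAX = 448.0  # largest normal E4M3 magnitude

def block_scales(values):
    """Return one power-of-two scale per 32-value block."""
    scales = []
    for i in range(0, len(values), BLOCK):
        amax = max(abs(v) for v in values[i:i + BLOCK]) or 1.0
        # Shift the block's max exponent down to the element format's range.
        exp = math.floor(math.log2(amax)) - math.floor(math.log2(E4M3_MAX))
        scales.append(2.0 ** exp)
    return scales

vals = [float(i) / 7.0 for i in range(64)]
scales = block_scales(vals)
assert len(scales) == 2
# Scaled values must fit the representable E4M3 magnitude range.
for i, s in enumerate(scales):
    block = vals[i * BLOCK:(i + 1) * BLOCK]
    assert all(abs(v) / s <= E4M3_MAX for v in block)
```

A reference check in a unit test can dequantize with these scales and bound the reconstruction error, which is roughly what the `trtllm`-backend tests compare against a high-precision GEMM.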