enable flashinfer fp8 gemm if deepgemm disabled#13608
Conversation
Summary of ChangesHello @b8zhong, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request introduces a performance optimization for FP8 General Matrix Multiply (GEMM) operations. It modifies the activation conditions for Flashinfer's FP8 GEMM implementation, ensuring it is utilized on Blackwell architectures (SM100) when the DeepGEMM library is not enabled. This strategic enablement aims to bypass less performant Triton kernels, thereby enhancing computational efficiency. Highlights
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
There was a problem hiding this comment.
Code Review
This pull request correctly enables Flashinfer FP8 GEMM when DeepGEMM is disabled on Blackwell GPUs. This change aligns with the goal of using more performant kernels, as Flashinfer is preferred over the Triton fallback in this scenario. The introduction of the FLASHINFER_FP8_GEMM_SUPPORTED variable improves code readability by clearly defining the conditions for Flashinfer support. The logic is sound and the implementation is clean.
|
But will this block the usage of other fp8 kernels, like the cutlass one? |
|
@Fridge003 No it still works, you can still enable it through |
But when I only add flag |
|
Err, good point. Here is the dispatching logics now:
|
There was a problem hiding this comment.
Pull Request Overview
This PR enables FlashInfer FP8 GEMM as an automatic fallback when DeepGEMM is disabled on SM100 (GB100) GPUs, providing better performance than Triton for small batch sizes.
Key changes:
- Moves
is_blackwell_supportedfromfp8_utilstoutilsmodule and switches tois_sm100_supportedfor FlashInfer support detection - Auto-enables FlashInfer FP8 GEMM when DeepGEMM is disabled and CUTLASS is not explicitly configured
- Reorganizes dispatch priority to check DeepGEMM first before other backends
Reviewed Changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| python/sglang/srt/layers/quantization/modelopt_quant.py | Updates import to use is_blackwell_supported from utils instead of fp8_utils |
| python/sglang/srt/layers/quantization/fp8_utils.py | Refactors FlashInfer FP8 GEMM enablement logic and dispatch priority, switches from Blackwell to SM100 detection |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
f74e3b7 to
103bfd8
Compare
There was a problem hiding this comment.
Pull request overview
Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8 | ||
| from sglang.srt.layers.quantization.mxfp4_tensor import MXFP4QuantizeUtil | ||
| from sglang.srt.utils import ceil_div, is_blackwell_supported, offloader | ||
| from sglang.srt.utils import ceil_div, offloader |
There was a problem hiding this comment.
[nitpick] The imports from sglang.srt.utils are split across two separate import statements (lines 9 and 30-40). Consider consolidating them into a single import statement for better code organization and consistency with Python style guidelines.
|
Cleaner solution in #14379 |
If DeepGEMM is disabled and on SM100, don't fall back to Triton, we have Flashinfer FP8 GEMM available, which is a lot more performant.
The reasoning is bc personally, I use this for when the BS is small, the speedup is quite noticeable compared to DeepGEMM.