Skip to content

Add support for FlashInfer mxfp8#18945

Draft
zianglih wants to merge 4 commits intosgl-project:mainfrom
zianglih:flashinfer-mxfp8
Draft

Add support for FlashInfer mxfp8#18945
zianglih wants to merge 4 commits intosgl-project:mainfrom
zianglih:flashinfer-mxfp8

Conversation

@zianglih
Copy link
Contributor

@zianglih zianglih commented Feb 18, 2026

Motivation

@HumansAnd

WIP.

FlashInfer has very good mxfp8 kernels.

Modifications

Dependency

Dense linear requires #19005
MoE requires flashinfer-ai/flashinfer#2581

Accuracy Tests

# Preparation

# Install this sglang branch
cd /sgl-workspace/
rm -rf sglang
git clone -b flashinfer-mxfp8 https://github.com/zianglih/sglang.git && \
cd sglang && \
pip install --upgrade pip && \
pip install -e "python"

# Install miles for converting HF models to MXFP8
cd /root/
rm -rf miles
git clone -b main https://github.com/zianglih/miles.git
cd /root/miles
pip install -e . --no-deps

# Install flashinfer from source
pip uninstall -y flashinfer-jit-cache flashinfer-python flashinfer-cubin
cd /root
git clone -b mxfp8 https://github.com/zianglih/flashinfer.git --recursive flashinfer-src
cd flashinfer-src
python -m pip install --no-build-isolation -e . -v


# Download and convert the model to MXFP8
hf download Qwen/Qwen3-30B-A3B-Instruct-2507 --local-dir /data/home/ziangli/models/Qwen3-30B-A3B-Instruct-2507
python /root/miles/tools/convert_hf_to_mxfp8.py --model-dir /data/home/ziangli/models/Qwen3-30B-A3B-Instruct-2507 --save-dir /data/home/ziangli/models/Qwen3-30B-A3B-Instruct-2507-MXFP8

# Previous working baseline
pkill -9 sglang ; pkill -9 python
cd /sgl-workspace/sglang
python -m sglang.launch_server --kv-cache-dtype bf16 --model /data/home/ziangli/models/Qwen3-30B-A3B-Instruct-2507-MXFP8 --fp8-gemm-backend triton --moe-runner-backend cutlass &
python3 benchmark/gsm8k/bench_sglang.py --num-shots 8 --num-questions 1209 --parallel 1209 --platinum
Accuracy: 0.963
Invalid: 0.000
Latency: 16.731 s
Output throughput: 10141.918 token/s
curl -sS http://localhost:30000/update_weights_from_disk \
  -H 'Content-Type: application/json' \
  -d '{
    "model_path": "/data/home/ziangli/models/Qwen3-30B-A3B-Instruct-2507-MXFP8",
    "flush_cache": true,
    "abort_all_requests": false
  }'
python3 benchmark/gsm8k/bench_sglang.py --num-shots 8 --num-questions 1209 --parallel 1209 --platinum
Accuracy: 0.963
Invalid: 0.000
Latency: 17.151 s
Output throughput: 9893.989 token/s
# With FlashInfer MXFP8
pkill -9 sglang ; pkill -9 python
cd /sgl-workspace/sglang
python -m sglang.launch_server --kv-cache-dtype bf16 --model /data/home/ziangli/models/Qwen3-30B-A3B-Instruct-2507-MXFP8 --fp8-gemm-backend flashinfer_trtllm --moe-runner-backend flashinfer_cutlass &
python3 benchmark/gsm8k/bench_sglang.py --num-shots 8 --num-questions 1209 --parallel 1209 --platinum
Accuracy: 0.964
Invalid: 0.000
Latency: 9.902 s
Output throughput: 17213.156 token/s
curl -sS http://localhost:30000/update_weights_from_disk \
  -H 'Content-Type: application/json' \
  -d '{
    "model_path": "/data/home/ziangli/models/Qwen3-30B-A3B-Instruct-2507-MXFP8",
    "flush_cache": true,
    "abort_all_requests": false
  }'
python3 benchmark/gsm8k/bench_sglang.py --num-shots 8 --num-questions 1209 --parallel 1209 --platinum
Accuracy: 0.964
Invalid: 0.000
Latency: 9.740 s
Output throughput: 17493.374 token/s

Benchmarking and Profiling

On B200 TP1 Qwen3-30B-A3B-Instruct-2507-MXFP8, python3 benchmark/gsm8k/bench_sglang.py --num-shots 8 --num-questions 1209 --parallel 1209 --platinum throughput is 1.7x

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

@gemini-code-assist
Copy link
Contributor

Summary of Changes

Hello @zianglih, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the MXFP8 quantization capabilities by integrating FlashInfer's highly optimized kernels. It introduces FlashInfer support for both general linear operations and Mixture of Experts (MoE) layers, aiming to improve performance and efficiency for models utilizing MXFP8 quantization, especially on Blackwell GPUs. The changes involve updating the quantization dispatch logic and server arguments to enable these new backends.

Highlights

  • FlashInfer MXFP8 Linear Operations: Added support for FlashInfer's MXFP8 kernels for linear operations, enhancing performance for mixed-precision computations, especially on Blackwell GPUs.
  • FlashInfer CUTLASS MoE Integration: Integrated FlashInfer CUTLASS as a new backend for MXFP8 Mixture of Experts (MoE) operations, providing an optimized path for MoE layers.
  • Server Configuration Updates: Updated server configuration to allow flashinfer_cutlass as a valid MoE runner backend when using MXFP8 quantization, and adjusted the default moe_runner_backend for MXFP8.
  • Dynamic Kernel Selection: Refactored the Fp8Quantizer to dynamically select between Triton and FlashInfer MXFP8 linear kernels based on the configured backend.
  • MoE Weight Loading Property: Introduced a new property load_up_proj_weight_first in Fp8Quantizer to handle specific weight loading orders required by FlashInfer CUTLASS MoE.

🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console.

Changelog
  • python/sglang/srt/layers/quantization/fp8.py
    • Imported dispatch_w8a8_mxfp8_linear and removed triton_mxfp8_blockscaled_linear from imports.
    • Added imports for is_flashinfer_available and next_power_of_2.
    • Introduced conditional imports for flashinfer.fused_moe components.
    • Modified the Fp8Quantizer's __init__ method to initialize w8a8_mxfp8_linear and w8a8_block_fp8_linear conditionally based on use_mxfp8.
    • Updated the Fp8Quantizer.apply method to utilize self.w8a8_mxfp8_linear for MXFP8 linear operations.
    • Added a load_up_proj_weight_first property to Fp8Quantizer to manage weight loading order for FlashInfer CUTLASS MoE.
    • Implemented the _apply_flashinfer_cutlass_mxfp8 method to handle MoE processing with FlashInfer CUTLASS.
    • Integrated the new _apply_flashinfer_cutlass_mxfp8 method into the main apply logic when the FlashInfer CUTLASS MoE backend is active.
  • python/sglang/srt/layers/quantization/fp8_utils.py
    • Imported flashinfer.mm_mxfp8 and set it to None if FlashInfer is unavailable.
    • Added the dispatch_w8a8_mxfp8_linear function to select the appropriate MXFP8 linear kernel (Triton or FlashInfer).
    • Implemented the flashinfer_mxfp8_blockscaled_linear function, providing the FlashInfer-specific implementation for MXFP8 dense linear operations.
  • python/sglang/srt/server_args.py
    • Modified the _handle_moe_kernel_config method to include flashinfer_cutlass as a supported moe_runner_backend for mxfp8 quantization.
    • Adjusted the default moe_runner_backend for mxfp8 from a forced cutlass to auto, which will then default to cutlass if not explicitly set to flashinfer_cutlass.
Activity
  • No human activity (comments, reviews, etc.) has been recorded for this pull request yet. The description includes instructions for accuracy tests and benchmarking.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request adds support for FlashInfer mxfp8, including both linear layers and MoE layers. The changes are well-structured, introducing new dispatch logic and kernels for the mxfp8 backend. The code includes necessary validations and updates to server arguments to accommodate the new feature. My review includes a couple of suggestions to improve error messaging for better user experience and to enhance code clarity for future maintenance. Overall, the implementation appears solid.

Comment on lines +1390 to +1391
w13_scale_block = layer.w13_weight_scale_inv.contiguous().view(torch.int32)
w2_scale_block = layer.w2_weight_scale_inv.contiguous().view(torch.int32)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The conversion from uint8 (the dtype of weight_scale_inv for mxfp8) to int32 using .view() is implicit and relies on the memory layout. For better code clarity and maintainability, it would be beneficial to add a comment explaining this conversion, especially since it's part of a low-level kernel integration.

        # The flashinfer kernel expects scales to be packed as int32, where each int32 contains four uint8 scales.
        # The shape check `shape[-1] % 4 == 0` above ensures this view is safe.
        w13_scale_block = layer.w13_weight_scale_inv.contiguous().view(torch.int32)
        w2_scale_block = layer.w2_weight_scale_inv.contiguous().view(torch.int32)

Comment on lines +211 to +215
if not (is_blackwell_supported() and is_flashinfer_available()):
raise RuntimeError(
"MXFP8 FlashInfer GEMM requested via --fp8-gemm-backend=flashinfer_trtllm, "
"but FlashInfer is unavailable or unsupported on this hardware."
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The error message is a bit generic. It could be more specific about the requirements to help users debug, similar to other error messages in this file. For example, it could mention that Blackwell GPUs and FlashInfer installation are required.

Suggested change
if not (is_blackwell_supported() and is_flashinfer_available()):
raise RuntimeError(
"MXFP8 FlashInfer GEMM requested via --fp8-gemm-backend=flashinfer_trtllm, "
"but FlashInfer is unavailable or unsupported on this hardware."
)
if not (is_blackwell_supported() and is_flashinfer_available()):
raise RuntimeError(
"MXFP8 FlashInfer GEMM requested via --fp8-gemm-backend=flashinfer_trtllm, "
"but it is not available or supported on this hardware. "
"This backend requires Blackwell (SM100+) GPUs and FlashInfer to be installed."
)

@zianglih zianglih marked this pull request as ready for review February 18, 2026 09:42
@zianglih zianglih marked this pull request as draft February 26, 2026 19:47
aleozlx pushed a commit to flashinfer-ai/flashinfer that referenced this pull request Mar 7, 2026
<!-- .github/pull_request_template.md -->

## 📌 Description

@HumansAnd

<!-- What does this PR do? Briefly describe the changes and why they’re
needed. -->

#2505 implements mxfp8
for trtllm backend.

However, in SGLang, `--moe-runner-backend flashinfer_trtllm` bypasses
SGLang topk implementation and does not work with expert routing replay
in MoE RL.

We want to implement `mxfp8 x mxfp8` for `cutlass_fused_moe` which works
with MoE RL training.

This PR mainly reuses existing code path for `WMxfp4AMxfp8Quant`:

https://github.com/flashinfer-ai/flashinfer/blob/952b6ab2838d676b4257fcc23bb00f67fdd38efc/csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_binding.cu#L1191

## 🔍 Related Issues

<!-- Link any related issues here -->
miles MXFP8/NVFP4 RL roadmap:
radixark/miles#615
SGLang FlashInfer MXFP8 integration:
sgl-project/sglang#18945

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->



<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **New Features**
* Toggleable MXFPX/MXFP8 activation-scaling across MOE inference,
updating workspace sizing, kernel selection, block-scaling and dispatch
to enable MXFP8-aware execution and validation.
* Added MXFP8×MXFP8 quantization mode and emitted MXFPX-aware
GEMM/kernel variants; public APIs now expose an MXFPX/activation-scaling
flag.

* **Tests**
* Added unit tests and helpers for MXFP8 quantization,
packing/dequantization, and end-to-end MXFP8×MXFP8 MOE inference
validation.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
frankwang28 pushed a commit to frankwang28/flashinfer that referenced this pull request Mar 18, 2026
<!-- .github/pull_request_template.md -->

## 📌 Description

@HumansAnd

<!-- What does this PR do? Briefly describe the changes and why they’re
needed. -->

flashinfer-ai#2505 implements mxfp8
for trtllm backend.

However, in SGLang, `--moe-runner-backend flashinfer_trtllm` bypasses
SGLang topk implementation and does not work with expert routing replay
in MoE RL.

We want to implement `mxfp8 x mxfp8` for `cutlass_fused_moe` which works
with MoE RL training.

This PR mainly reuses existing code path for `WMxfp4AMxfp8Quant`:

https://github.com/flashinfer-ai/flashinfer/blob/952b6ab2838d676b4257fcc23bb00f67fdd38efc/csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_binding.cu#L1191

## 🔍 Related Issues

<!-- Link any related issues here -->
miles MXFP8/NVFP4 RL roadmap:
radixark/miles#615
SGLang FlashInfer MXFP8 integration:
sgl-project/sglang#18945

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->



<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **New Features**
* Toggleable MXFPX/MXFP8 activation-scaling across MOE inference,
updating workspace sizing, kernel selection, block-scaling and dispatch
to enable MXFP8-aware execution and validation.
* Added MXFP8×MXFP8 quantization mode and emitted MXFPX-aware
GEMM/kernel variants; public APIs now expose an MXFPX/activation-scaling
flag.

* **Tests**
* Added unit tests and helpers for MXFP8 quantization,
packing/dequantization, and end-to-end MXFP8×MXFP8 MOE inference
validation.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
ameynaik-hub pushed a commit to ameynaik-hub/flashinfer that referenced this pull request Mar 18, 2026
<!-- .github/pull_request_template.md -->

## 📌 Description

@HumansAnd

<!-- What does this PR do? Briefly describe the changes and why they’re
needed. -->

flashinfer-ai#2505 implements mxfp8
for trtllm backend.

However, in SGLang, `--moe-runner-backend flashinfer_trtllm` bypasses
SGLang topk implementation and does not work with expert routing replay
in MoE RL.

We want to implement `mxfp8 x mxfp8` for `cutlass_fused_moe` which works
with MoE RL training.

This PR mainly reuses existing code path for `WMxfp4AMxfp8Quant`:

https://github.com/flashinfer-ai/flashinfer/blob/952b6ab2838d676b4257fcc23bb00f67fdd38efc/csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_binding.cu#L1191

## 🔍 Related Issues

<!-- Link any related issues here -->
miles MXFP8/NVFP4 RL roadmap:
radixark/miles#615
SGLang FlashInfer MXFP8 integration:
sgl-project/sglang#18945

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **New Features**
* Toggleable MXFPX/MXFP8 activation-scaling across MOE inference,
updating workspace sizing, kernel selection, block-scaling and dispatch
to enable MXFP8-aware execution and validation.
* Added MXFP8×MXFP8 quantization mode and emitted MXFPX-aware
GEMM/kernel variants; public APIs now expose an MXFPX/activation-scaling
flag.

* **Tests**
* Added unit tests and helpers for MXFP8 quantization,
packing/dequantization, and end-to-end MXFP8×MXFP8 MOE inference
validation.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant