[Quant][Feature] Support online MXFP8 quantization for MoE and dense models#35448

Merged
mgoin merged 4 commits into vllm-project:main from EdalatiAli:fi-moe-mxfp8 on Mar 16, 2026

Conversation

Contributor

@EdalatiAli EdalatiAli commented Feb 26, 2026

Purpose

Add support for online MXFP8 quantization (--quantization mxfp8), enabling BF16/FP16 models to be dynamically quantized to MXFP8 (microscaling FP8 with block-32 scales) at load time — for both linear layers and MoE expert layers.

This path is powered by FlashInfer kernels.

This PR implements part of the online quantization support proposed in #32029 and #32412.
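For intuition, the block-32 microscaling scheme can be sketched in NumPy: each 32-element block shares one power-of-two (E8M0-style) scale, and elements are stored within the FP8 E4M3 range. This is an illustrative sketch of the numerics only, not the FlashInfer kernel; the function names are made up, and real kernels also round mantissas to the 8-bit FP8 grid.

```python
import numpy as np

FP8_E4M3_MAX = 448.0  # largest finite FP8 E4M3 value
BLOCK = 32            # MXFP8: one shared scale per 32 elements


def mxfp8_quantize(x: np.ndarray):
    """Illustrative block-32 quantization: each block gets one
    power-of-two (E8M0-style) scale sized so the block's absolute max
    fits in the E4M3 range. Mantissa rounding is intentionally omitted."""
    blocks = x.reshape(-1, BLOCK).astype(np.float32)
    amax = np.abs(blocks).max(axis=1, keepdims=True)
    # E8M0 scales are pure powers of two
    exp = np.ceil(np.log2(np.maximum(amax, 2.0 ** -126) / FP8_E4M3_MAX))
    scales = np.exp2(exp)
    q = np.clip(blocks / scales, -FP8_E4M3_MAX, FP8_E4M3_MAX)
    return q, scales


def mxfp8_dequantize(q: np.ndarray, scales: np.ndarray) -> np.ndarray:
    """Undo the block scaling and flatten back to 1-D."""
    return (q * scales).reshape(-1)
```

Because the scales are exact powers of two, scaling and unscaling introduce no rounding error on their own; all the quantization error in the real kernel comes from snapping elements onto the FP8 grid.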

Usage

# Serve any BF16 model with online MXFP8 quantization
vllm serve <model> --quantization mxfp8

# Or with layer skipping via config.json:
# "quantization_config": {
#     "quant_method": "mxfp8",
#     "modules_to_not_convert": ["lm_head", "mlp.gate"]
# }

Requires SM 100+ (Blackwell) GPU.

Test Plan

E2E tests for both a dense model (Qwen/Qwen3-0.6B) and a MoE model (Qwen/Qwen3-30B-A3B): logprobs comparison against BF16 baseline + generation smoke test.

python -m pytest tests/models/quantization/test_mxfp8.py -v
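The logprobs-comparison pattern such an E2E test relies on can be sketched as follows; this is a standalone illustration, and the function name and thresholds are invented here, not vLLM's actual test utilities:

```python
def assert_logprobs_match(baseline, candidate, atol=0.4, min_frac=0.9):
    """Require that at least `min_frac` of per-token logprobs from the
    quantized model lie within `atol` nats of the BF16 baseline."""
    assert len(baseline) == len(candidate), "token counts differ"
    close = sum(1 for b, c in zip(baseline, candidate) if abs(b - c) <= atol)
    frac = close / len(baseline)
    assert frac >= min_frac, f"only {frac:.0%} of logprobs within {atol} nats"
    return frac
```

Comparing a fraction of near-matches rather than demanding exact equality tolerates the small numerical drift expected from quantization while still catching gross regressions.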

In addition, we report accuracy on MMLU-Pro and GSM8K using lm-evaluation-harness, as well as performance benchmarks for Qwen/Qwen3-30B-A3B.

Test Result

Accuracy

lm_eval \
  --model vllm \
  --trust_remote_code \
  --model_args pretrained=Qwen/Qwen3-30B-A3B,quantization=mxfp8,enforce_eager=True \
  --tasks gsm8k,mmlu_pro \
  --batch_size auto
| Task | BF16 | MXFP8 |
| --- | --- | --- |
| GSM8K (flexible-extract) | 85.7 | 87.8 |
| GSM8K (strict-match) | 89.7 | 88.2 |
| MMLU-Pro | 69.2 | 68.9 |

Performance

vllm bench throughput --model Qwen/Qwen3-30B-A3B \
--tensor-parallel-size 1 \
--trust-remote-code \
--async-scheduling \
--backend vllm \
--dataset-name random \
--random-prefix-len 0 \
--random-input-len 1024 \
--random-output-len 1024 \
--max-num-seqs 128 \
--num-prompts 512 \
--quantization mxfp8    ## Remove for the bf16 model

BF16 performance

Throughput: 7.73 requests/s, 15833.85 total tokens/s, 7916.92 output tokens/s

MXFP8 performance

Throughput: 10.34 requests/s, 21179.37 total tokens/s, 10589.68 output tokens/s
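From the two runs above, the relative gain can be computed directly (numbers copied from the benchmark output):

```python
bf16_req_s, mxfp8_req_s = 7.73, 10.34
bf16_tok_s, mxfp8_tok_s = 15833.85, 21179.37

# MXFP8 over BF16 speedup on Qwen/Qwen3-30B-A3B
print(f"request speedup: {mxfp8_req_s / bf16_req_s:.2f}x")
print(f"token speedup:   {mxfp8_tok_s / bf16_tok_s:.2f}x")
```

Both ratios work out to roughly 1.34x throughput for MXFP8 over the BF16 baseline.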


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces support for online MXFP8 MoE quantization. The changes are comprehensive, adding a new Mxfp8Config and Mxfp8OnlineMoEMethod, updating backend selection logic, and modifying the flashinfer_fused_moe_blockscale_fp8 custom op. The implementation correctly handles the specifics of MXFP8, such as block shapes and scale types. I've identified a critical bug in the Mxfp8OnlineMoEMethod.create_weights method related to parameter naming that would cause a runtime error. My review includes a suggested fix for this issue. Overall, the changes are well-structured to integrate the new quantization mode.


mergify bot commented Mar 4, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @EdalatiAli.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Mar 4, 2026
@EdalatiAli EdalatiAli changed the title from "[Quant][Feature] Support online MXFP8 MoE quantization using trtllm_fp8_block_scale_moe" to "[Quant][Feature] Support online MXFP8 quantization for MoE and dense models" Mar 5, 2026
@EdalatiAli EdalatiAli marked this pull request as ready for review March 5, 2026 01:34
@EdalatiAli EdalatiAli marked this pull request as draft March 10, 2026 15:36
@EdalatiAli EdalatiAli force-pushed the fi-moe-mxfp8 branch 2 times, most recently from a630758 to 1e9ce28 Compare March 10, 2026 19:10
@mergify mergify bot removed the needs-rebase label Mar 10, 2026
@EdalatiAli EdalatiAli marked this pull request as ready for review March 10, 2026 20:31
@mgoin mgoin added ready ONLY add when PR is ready to merge/full CI is needed quantization labels Mar 11, 2026

mergify bot commented Mar 12, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @EdalatiAli.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Mar 12, 2026
from flashinfer.fused_moe import Fp8QuantizationType

assert not apply_router_weight_on_input
assert activation == MoEActivation.SILU
Member

It seems like you assert SILU but don't restrict selection in _supports_activation, is this necessary?

Contributor Author

Yes, it's necessary — the monolithic class serves both block-scale (SILU only, hardcoded in flashinfer) and per-tensor (SILU + RELU2) paths, and _supports_activation can't distinguish them. Adding quant context to _supports_activation would require changing the abstract interface and every implementation across the codebase.

Comment on lines +226 to +237
# For Blackwell block-FP8 (used by online MXFP8), prefer FlashInfer TRTLLM
# so execution goes through the monolithic blockscale kernel path.
if (
    current_platform.is_cuda()
    and current_platform.is_device_capability_family(100)
    and weight_key == kMxfp8Static
    and activation_key == kMxfp8Dynamic
    and Fp8MoeBackend.FLASHINFER_TRTLLM in AVAILABLE_BACKENDS
):
    AVAILABLE_BACKENDS.remove(Fp8MoeBackend.FLASHINFER_TRTLLM)
    AVAILABLE_BACKENDS.insert(0, Fp8MoeBackend.FLASHINFER_TRTLLM)

Member

We already have an mxfp8 oracle at https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/fused_moe/oracle/mxfp8.py, could we use that rather than overloading fp8?

Contributor Author

Sure, I changed the code to use https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/fused_moe/oracle/mxfp8.py. I had to make a few other minor changes to correctly use it.

Signed-off-by: EdalatiAli <aliedalati@cohere.com>
@EdalatiAli EdalatiAli requested a review from mgoin March 13, 2026 21:07
Member

@mgoin mgoin left a comment

Okay seems reasonable to me to accept. I'm not sure how much we should truly reuse with the fp8 methods but it is fair enough to follow that pattern for now

Comment on lines +18 to +26
_SUPPORTED_BACKENDS: frozenset[Fp8MoeBackend] = frozenset(
    {
        Fp8MoeBackend.FLASHINFER_TRTLLM,
    }
)


class MxFp8MoeBackend(Enum):
    FLASHINFER_TRTLLM = "FLASHINFER_TRTLLM"


_BACKEND_NAME_MAP: dict[str, Fp8MoeBackend] = {
    "flashinfer_trtllm": Fp8MoeBackend.FLASHINFER_TRTLLM,
}
Member

It is a bit confusing to use Fp8MoeBackend here and elsewhere for mxfp8, but I guess it is needed to reuse the moe utils

Contributor Author

Yes, using Fp8MoeBackend was needed to reuse the rest of the FP8 MoE utils.
We can follow a better approach when more MXFP8 backends are available in the future.
Thank you for the feedback!

@github-project-automation github-project-automation bot moved this to Ready in NVIDIA Mar 16, 2026
@mgoin mgoin merged commit e5b8076 into vllm-project:main Mar 16, 2026
65 checks passed
@github-project-automation github-project-automation bot moved this from Ready to Done in NVIDIA Mar 16, 2026
Lucaskabela pushed a commit to Lucaskabela/vllm that referenced this pull request Mar 17, 2026
…models (vllm-project#35448)

Signed-off-by: EdalatiAli <aliedalati@cohere.com>
andylolu2 pushed a commit to andylolu2/vllm that referenced this pull request Mar 18, 2026
…models (vllm-project#35448)

Signed-off-by: EdalatiAli <aliedalati@cohere.com>
wendyliu235 pushed a commit to wendyliu235/vllm-public that referenced this pull request Mar 18, 2026
…models (vllm-project#35448)

Signed-off-by: EdalatiAli <aliedalati@cohere.com>
fxdawnn pushed a commit to fxdawnn/vllm that referenced this pull request Mar 19, 2026
…models (vllm-project#35448)

Signed-off-by: EdalatiAli <aliedalati@cohere.com>
khairulkabir1661 pushed a commit to khairulkabir1661/vllm that referenced this pull request Mar 27, 2026
…models (vllm-project#35448)

Signed-off-by: EdalatiAli <aliedalati@cohere.com>

Labels

nvidia quantization ready ONLY add when PR is ready to merge/full CI is needed

Projects

Status: Done
