[Kernel] Add MXFP8 to Marlin GEMM/MoE and refactor Mxfp8LinearOp#34664

Open
mgoin wants to merge 15 commits into vllm-project:main from neuralmagic:mxfp8-marlin

Conversation

@mgoin (Member) commented Feb 17, 2026

Purpose

The Marlin kernel already supports FP8 (per-channel/group scales) and MXFP4 (per-32-element e8m0 scales). MXFP8 is a natural combination: FP8 weights (like existing FP8 Marlin) with e8m0 microscaling block scales (like existing MXFP4 Marlin). We just have to wire the kernel building blocks together.

This PR also consolidates GEMM-kernel backend-specific logic into the Mxfp8LinearOp class, which is shared by modelopt.py and mxfp8.py.

Test Plan

The existing online-quantization test for MXFP8 will now run on L4 GPUs in CI: tests/models/quantization/test_mxfp8.py

Test Result

```
vllm serve mgoin/Qwen3-0.6B-MXFP8
vllm serve Qwen/Qwen3-0.6B --quantization mxfp8
vllm serve Qwen/Qwen3-30B-A3B --quantization mxfp8
```

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results.
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request adds support for MXFP8 quantization in the Marlin kernel, providing a faster alternative to the existing emulation path. The changes span across kernel generation, C++ dispatch logic, and Python-level integration. The implementation introduces new utility functions for handling MXFP8-specific weight and scale preparation for Marlin. My review identifies a critical issue in the hardware capability check that could lead to runtime errors on unsupported GPUs.

@mgoin mgoin added performance Performance-related issues quantization ready ONLY add when PR is ready to merge/full CI is needed labels Feb 20, 2026
@danisereb (Contributor) commented Feb 22, 2026

Hey @mgoin,
please also see my PR for the Flashinfer cutlass MXFP8 GEMM:
#35053

The GEMM is available in flashinfer 0.6.4 (recently bumped in vLLM).


mergify bot commented Feb 24, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @mgoin.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Feb 24, 2026
mgoin added 4 commits March 19, 2026 17:49
Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Move backend selection, weight processing, and apply logic from
ModelOptMxFp8LinearMethod into Mxfp8LinearOp so all MXFP8 linear
backends (emulation, flashinfer CUTLASS, Marlin) are managed in
one place.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Signed-off-by: mgoin <mgoin64@gmail.com>
Use select_mxfp8_linear_backend() and delegate weight processing
to Mxfp8LinearOp.process_weights(), enabling Marlin backend support
for online MXFP8 quantization on SM80+ GPUs.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Signed-off-by: mgoin <mgoin64@gmail.com>
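
The backend consolidation described in the two commit messages above can be sketched as a single selection function. This is an illustrative assumption of the dispatch structure, not vLLM's actual API: the enum, the function name, and the FlashInfer SM100 threshold are invented for this example; only Marlin's SM80+ floor is stated in this PR.

```python
# Hedged sketch of one-place MXFP8 backend selection (names illustrative).
from enum import Enum

class Mxfp8Backend(Enum):
    FLASHINFER_CUTLASS = "flashinfer_cutlass"  # native CUTLASS MXFP8 GEMM
    MARLIN = "marlin"                          # FP8 weights + e8m0 block scales
    EMULATION = "emulation"                    # dequantize, then plain GEMM

def select_mxfp8_backend(sm: int, has_flashinfer: bool) -> Mxfp8Backend:
    """Pick one backend per layer so callers never branch themselves."""
    if sm >= 100 and has_flashinfer:           # assumed FlashInfer requirement
        return Mxfp8Backend.FLASHINFER_CUTLASS
    if sm >= 80:                               # Marlin MXFP8 path in this PR
        return Mxfp8Backend.MARLIN
    return Mxfp8Backend.EMULATION
```

Centralizing the choice like this is what lets both modelopt.py and mxfp8.py delegate weight processing and the apply path to Mxfp8LinearOp instead of duplicating per-backend logic.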
mgoin added 5 commits March 19, 2026 18:07
Removed comment about backend-specific weight processing.
Use torch.get_default_dtype() and layer.output_size_per_partition /
layer.input_size_per_partition directly instead of stashing copies
as layer.orig_dtype, layer.marlin_size_n, layer.marlin_size_k.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
@mgoin mgoin changed the title Add MXFP8 to Marlin dense kernel [Kernel] Add MXFP8 to Marlin dense kernel Mar 19, 2026
@mgoin mgoin added the nvidia label Mar 19, 2026
@mgoin mgoin changed the title [Kernel] Add MXFP8 to Marlin dense kernel [Kernel] Add MXFP8 to Marlin dense kernel and refactor Mxfp8LinearOp Mar 19, 2026
mgoin added 2 commits March 30, 2026 16:13
@mgoin mgoin requested a review from WoosukKwon as a code owner March 30, 2026 21:15
@mgoin mgoin changed the title [Kernel] Add MXFP8 to Marlin dense kernel and refactor Mxfp8LinearOp [Kernel] Add MXFP8 to Marlin GEMM/MoE and refactor Mxfp8LinearOp Mar 30, 2026
Signed-off-by: mgoin <mgoin64@gmail.com>

mergify bot commented Mar 31, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @mgoin.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Mar 31, 2026
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

Signed-off-by: mgoin <mgoin64@gmail.com>
@mergify mergify bot removed the needs-rebase label Mar 31, 2026
Signed-off-by: mgoin <mgoin64@gmail.com>