[NVFP4] Support NVFP4 MOE models on AMD Instinct, Nvidia Ampere, Hopper through NVFP4 MOE emulation#35737

Open
fxmarty-amd wants to merge 47 commits into vllm-project:main from fxmarty-amd:upstream-nvfp4-simulation-support-moe

Conversation

@fxmarty-amd
Contributor

@fxmarty-amd fxmarty-amd commented Mar 2, 2026

This PR depends on #35733 for dense models. Please see the correct diff at: fxmarty-amd/vllm@upstream-nvfp4-simulation-support-rocm...upstream-nvfp4-simulation-support-moe

Purpose

This PR enables running NVFP4 MOE models on AMD Instinct and on Nvidia Ampere and Hopper GPUs.

This is useful for researchers, anybody trying out microscaling formats, and people who would like to run e.g. https://huggingface.co/nvidia/Qwen3-30B-A3B-NVFP4 or https://huggingface.co/RedHatAI/Qwen3-30B-A3B-NVFP4 on non-Blackwell devices.
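For background on what the emulation computes: NVFP4 stores FP4 (E2M1) values with one FP8 (E4M3) scale per 16-element block plus an FP32 global scale, and emulation fake-quantizes tensors through that grid before running standard high-precision kernels. A minimal sketch of the quantize-dequantize round trip, with block scales kept in FP32 for simplicity (the function name is illustrative, not vLLM's API):

```python
import numpy as np

# The 8 non-negative magnitudes representable in FP4 (E2M1).
FP4_LEVELS = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def nvfp4_quant_dequant(x: np.ndarray, block_size: int = 16) -> np.ndarray:
    """Fake-quantize x through the NVFP4 grid, one scale per block.

    Illustrative only: block scales stay in FP32 here, whereas the real
    format stores them in FP8 E4M3 under an FP32 global scale.
    """
    x = x.reshape(-1, block_size)
    # Per-block scale mapping the block absmax to the top FP4 magnitude (6.0).
    amax = np.abs(x).max(axis=1, keepdims=True)
    scale = np.where(amax == 0, 1.0, amax / 6.0)
    scaled = x / scale
    # Round each element to the nearest representable FP4 magnitude, keep sign.
    idx = np.abs(np.abs(scaled)[..., None] - FP4_LEVELS).argmin(axis=-1)
    q = np.sign(scaled) * FP4_LEVELS[idx]
    return (q * scale).reshape(-1)

x = np.array([6.0, 3.0, -1.5, 0.0] * 4)   # all values already on the FP4 grid
print(nvfp4_quant_dequant(x))             # round-trips exactly
```

Values already on the grid survive the round trip; off-grid values are rounded with an error bounded by half the largest grid gap times the block scale.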

Test Plan

Evaluated with lm_eval, first on the non-quantized baseline:

export PRETRAINED_PATH="/shareddata/Qwen/Qwen3-30B-A3B"

CUDA_VISIBLE_DEVICES=4 nohup lm_eval \
  --model vllm \
  --model_args '{"pretrained":"'"${PRETRAINED_PATH}"'","dtype":"auto","tensor_parallel_size":1,"enable_thinking": false,"chat_template_args":{"enable_thinking":false}}' \
  --device "cuda" \
  --tasks wikitext,piqa \
  --batch_size auto &> lm_eval.log &

giving:

| Tasks  |Version|Filter|n-shot|    Metric     |   | Value |   |Stderr|
|--------|------:|------|-----:|---------------|---|------:|---|------|
|piqa    |      1|none  |     0|acc            |↑  | 0.7922|±  |0.0095|
|        |       |none  |     0|acc_norm       |↑  | 0.8030|±  |0.0093|
|wikitext|      2|none  |     0|bits_per_byte  |↓  | 0.6443|±  |   N/A|
|        |       |none  |     0|byte_perplexity|↓  | 1.5630|±  |   N/A|
|        |       |none  |     0|word_perplexity|↓  |10.8936|±  |   N/A|

And with export PRETRAINED_PATH="/shareddata/nvidia/Qwen3-30B-A3B-NVFP4":

(EngineCore_DP0 pid=183384) INFO 03-04 16:36:56 [nvfp4_utils.py:87] Using NvFp4LinearBackend.EMULATION for NVFP4 GEMM
(EngineCore_DP0 pid=183384) INFO 03-04 16:36:56 [rocm.py:464] Using Triton Attention backend.
(EngineCore_DP0 pid=183384) INFO 03-04 16:36:56 [nvfp4.py:266] Using 'EMULATION' NvFp4 MoE backend out of potential backends: ['FLASHINFER_TRTLLM', 'FLASHINFER_CUTEDSL', 'FLASHINFER_CUTLASS', 'VLLM_CUTLASS', 'MARLIN', 'EMULATION'].
(EngineCore_DP0 pid=183384) WARNING 03-04 16:37:12 [quantization_emulation_moe.py:51] Using Nvfp4QuantizationEmulationTritonExperts MOE backend. This will dequantize weights on the fly and may be slower than native quantized MOE. Consider using a device with native quantization support (e.g. Nvidia Blackwell) for better performance.
...
| Tasks  |Version|Filter|n-shot|    Metric     |   | Value |   |Stderr|
|--------|------:|------|-----:|---------------|---|------:|---|------|
|piqa    |      1|none  |     0|acc            |↑  | 0.7867|±  |0.0096|
|        |       |none  |     0|acc_norm       |↑  | 0.7938|±  |0.0094|
|wikitext|      2|none  |     0|bits_per_byte  |↓  | 0.6527|±  |   N/A|
|        |       |none  |     0|byte_perplexity|↓  | 1.5721|±  |   N/A|
|        |       |none  |     0|word_perplexity|↓  |11.2391|±  |   N/A|

And with export PRETRAINED_PATH="/shareddata/RedHatAI/Qwen3-30B-A3B-NVFP4":

(EngineCore_DP0 pid=184824) INFO 03-04 16:39:30 [nvfp4_utils.py:87] Using NvFp4LinearBackend.EMULATION for NVFP4 GEMM
(EngineCore_DP0 pid=184824) INFO 03-04 16:39:30 [rocm.py:464] Using Triton Attention backend.
(EngineCore_DP0 pid=184824) INFO 03-04 16:39:30 [nvfp4.py:266] Using 'EMULATION' NvFp4 MoE backend out of potential backends: ['FLASHINFER_TRTLLM', 'FLASHINFER_CUTEDSL', 'FLASHINFER_CUTLASS', 'VLLM_CUTLASS', 'MARLIN', 'EMULATION'].
(EngineCore_DP0 pid=184824) WARNING 03-04 16:39:36 [quantization_emulation_moe.py:51] Using Nvfp4QuantizationEmulationTritonExperts MOE backend. This will dequantize weights on the fly and may be slower than native quantized MOE. Consider using a device with native quantization support (e.g. Nvidia Blackwell) for better performance.
...
| Tasks  |Version|Filter|n-shot|    Metric     |   | Value |   |Stderr|
|--------|------:|------|-----:|---------------|---|------:|---|------|
|piqa    |      1|none  |     0|acc            |↑  | 0.7867|±  |0.0096|
|        |       |none  |     0|acc_norm       |↑  | 0.7954|±  |0.0094|
|wikitext|      2|none  |     0|bits_per_byte  |↓  | 0.6604|±  |   N/A|
|        |       |none  |     0|byte_perplexity|↓  | 1.5806|±  |   N/A|
|        |       |none  |     0|word_perplexity|↓  |11.5645|±  |   N/A|

Signed-off-by: Felix Marty <Felix.Marty@amd.com>
@mergify mergify bot added nvidia rocm Related to AMD ROCm labels Mar 2, 2026
@github-project-automation github-project-automation bot moved this to Todo in AMD Mar 2, 2026
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces support for NVFP4 MOE models on a wider range of hardware, including AMD Instinct, Nvidia Ampere, and Hopper, through an emulation backend. The changes are extensive, touching quantization layers, model execution, and tests to accommodate this new emulation path. The implementation appears solid and well-integrated. I've found one critical issue that needs to be addressed.

Signed-off-by: Felix Marty <Felix.Marty@amd.com>
Comment on lines +351 to +356
if torch.unique(a13_scale).numel() != 1 or torch.unique(a2_scale).numel() != 1:
    logger.warning_once(
        "In NVFP4 linear, the activation global scale for inputs are different"
        " for MOE w13 (gate_up_proj) layer or MOE w2 (down_proj). Using"
        " a13_scale = a13_scale.max() and a2_scale = a2_scale.max()."
    )
Member


I believe we do have some kernels that support different global scales per expert, for instance see #21408

Contributor Author


@mgoin flashinfer default backends use a single shared global scale across all experts for both gate_up_proj and down_proj, see:

# For some FI kernels, the input scales are shared by all experts.
if is_global_sf_supported_for_nvfp4_backend(backend):
    num_experts = w13.shape[0]
    a13_scale = a13_scale.max().to(torch.float32).expand(num_experts)
    a2_scale = a2_scale.max().to(torch.float32).expand(num_experts)
else:
    a13_scale = a13_scale.max(dim=1).values.to(torch.float32)

This logic similarly uses a single global scale for the gate_up_proj input and the down_proj input in the emulation code path using TritonExperts.

We display a warning because vLLM currently has no logic to recompute the fp8_e4m3 scales after taking this .max(). Fortunately, Model-Optimizer and compressed-tensors produce models that share the same global_scale across gate_proj/up_proj and all experts, so this is not an issue in practice. But if the serialized global scales were to differ, simply taking the .max() as done currently would not be enough.

This may be fixed in another PR.

block_shape: list[int] | None = None,
is_fp4_scale_swizzled: bool = True,
ocp_mx_scheme: str | None = None,
emulation: bool = False,
Member


It is bad practice to add emulation as an argument to this function and only use it for a single quant_dtype case. Why don't you just call ref_nvfp4_quant_dequant(A, A_scale, block_size=16) inline in apply?

Contributor Author

@fxmarty-amd fxmarty-amd Apr 2, 2026


@mgoin Which apply are you talking about? Nvfp4QuantizationEmulationTritonExperts inherits TritonExperts.apply, and I do NOT want to modify TritonExperts itself; QDQ needs to be applied to BOTH a13 and a2.

For example, moe_kernel_quantize_input already handles MXFP4/MXFP6_E3M2/MXFP6_E4M3 fake QDQ through _mxfp4_quantize, _mxfp6_e3m2_quantize, _mxfp6_e2m3_quantize.

I agree this should be clarified. Do you propose to keep moe_kernel_quantize_input for REAL quantization cases, and have another function handling all QDQ cases?

and have in TritonExperts.apply:

if not emulation:
    qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
        intermediate_cache2,
        a2_scale,
        self.quant_dtype,
        self.per_act_token_quant,
        self.block_shape,
    )
else:
    qintermediate_cache2, a2q_scale = moe_kernel_input_fake_quantization(
        intermediate_cache2,
        a2_scale,
        self.quant_dtype,
        self.per_act_token_quant,
        self.block_shape,
    )

? Let me know!

Contributor


I think Michael may be suggesting that the other argument combinations (fp8 + emulation) are not handled and instead silently fall back to real quantization.

Contributor Author


Got it, let me address this properly.

Comment on lines +3 to +12
"""
Quantization Emulation Experts for MoE.

This module provides emulation support for MOE quantization schemes that
don't have native hardware support. It dequantizes weights on the fly
and falls back to calling fused_experts with activation quantization.

Similar to QuarkOCP_MX_MoEMethod's emulation path but abstracted into
a reusable NvFp4MoeBackend.
"""
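The dequantize-on-the-fly step this docstring describes amounts to expanding each 4-bit E2M1 code and multiplying by its block and global scales. A hedged sketch with a simplified layout (real checkpoints pack two codes per byte and store block scales in FP8 E4M3; the function name is illustrative, not vLLM's API):

```python
import numpy as np

# Positive magnitudes of FP4 (E2M1), indexed by the 3 low code bits.
FP4_LEVELS = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def dequant_nvfp4_weight(codes, block_scales, global_scale, block_size=16):
    """Expand 4-bit codes to float: value = fp4(code) * block_scale * global_scale.

    Simplified layout: `codes` holds one unpacked 4-bit code per element
    (bit 3 = sign, bits 0-2 = magnitude index).
    """
    sign = np.where(codes & 0x8, -1.0, 1.0)
    mag = FP4_LEVELS[codes & 0x7]
    vals = (sign * mag).reshape(-1, block_size)
    # One scale per 16-element block, then the tensor-wide global scale.
    return vals * block_scales[:, None] * global_scale

codes = np.array([0x7, 0xF] + [0x0] * 14)                        # +6.0, -6.0, zeros
print(dequant_nvfp4_weight(codes, np.array([2.0]), 0.5)[0, :2])  # [ 6. -6.]
```

The emulation path runs this expansion per forward pass and then calls the regular high-precision experts, which is why it warns about being slower than native quantized MOE.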
Member


Is this meant to be a general emulation MoE or specific to NVFP4? I'm confused about the name vs. the description.

Contributor Author


This is meant for NVFP4 only, if it is okay.

Let me update the name/description accordingly


Comment on lines +358 to +369
# moe_kernel_quantize_input -> ref_nvfp4_quant_dequant use the inverse scale.
# Similar to model_executor/layers/quantization/utils/flashinfer_fp4_moe.py.
# NOTE: at this point `a13_scale` and `a2_scale` are the inverses such that:
# `x_fp8_range = x * 1 / global_scale`, and `global_scale` is small.
# We take the max following e.g. flashinfer_fp4_moe.py, which results in likely
# overflow of the fp8 range, and scale clamping!
# It may be better to use min here.
a13_scale = a13_scale.max().to(torch.float32)
a2_scale = a2_scale.max().to(torch.float32)

a13_scale = 1.0 / a13_scale
a2_scale = 1.0 / a2_scale
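To make the max-vs-min concern above concrete: since these are inverse scales (x_fp8 = x * inv_scale), taking .max() over them picks the largest multiplier, so activations calibrated for a smaller multiplier can land outside the FP8 E4M3 range. A toy numeric illustration (values made up, not from the PR):

```python
import numpy as np

FP8_E4M3_MAX = 448.0  # largest finite e4m3 value

# Hypothetical per-expert *inverse* global scales: x_fp8 = x * inv_scale.
inv_scales = np.array([1.0, 0.5])
shared = inv_scales.max()       # 1.0, the largest multiplier
x = 896.0                       # fits expert 1's own scale: 896 * 0.5 = 448
print(x * shared)               # 896.0, beyond the e4m3 range -> clamped
print(x * inv_scales.min())     # 448.0, using min() instead stays in range
```

This is why the comment suggests min may be the safer reduction here when the serialized scales genuinely differ.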
Member


I think this comment needs to be reworked. Also, you can just do a13_scale = 1.0 / a13_scale.max().to(torch.float32), etc.

Contributor Author


I updated the comment


Signed-off-by: Felix Marty <Felix.Marty@amd.com>

mergify bot commented Mar 30, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @fxmarty-amd.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Mar 30, 2026

mergify bot commented Apr 1, 2026

Hi @fxmarty-amd, the pre-commit checks have failed. Please run:

uv pip install "pre-commit>=4.5.1"
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

@mergify mergify bot removed the needs-rebase label Apr 1, 2026
fxmarty-amd and others added 5 commits April 1, 2026 10:38
…emes/compressed_tensors_w4a4_nvfp4.py

Co-authored-by: Kyle Sayers <kylesayrs@gmail.com>
Signed-off-by: fxmarty-amd <felmarty@amd.com>
…emes/compressed_tensors_w4a4_nvfp4.py

Co-authored-by: Kyle Sayers <kylesayrs@gmail.com>
Signed-off-by: fxmarty-amd <felmarty@amd.com>
Signed-off-by: Felix Marty <Felix.Marty@amd.com>
Signed-off-by: Felix Marty <Felix.Marty@amd.com>

mergify bot commented Apr 1, 2026

Hi @fxmarty-amd, the pre-commit checks have failed. Please run the pre-commit steps above, then commit the changes and push to your branch.

Contributor

@kylesayrs kylesayrs left a comment


Reposting what I commented on the other PR: #35859 (review)

I think as it stands, passing emulation_dequantize_weights creates a lot of branching and modifications in existing quantization schemes. I would strongly consider breaking this out into a separate scheme, similar to Fp8OnlineLinearMethod; otherwise a lot of function contracts/behavior get changed.

I agree that emulation_dequantize_weights=False should be a linear backend, no problem there.

@fxmarty-amd
Contributor Author

@kylesayrs Thanks a lot for reviewing! #35859 was based on #35855, which has been deemed not acceptable, so I will remove the emulation_dequantize_weights logic there. That logic is not in this PR.

Signed-off-by: Felix Marty <Felix.Marty@amd.com>

mergify bot commented Apr 2, 2026

Hi @fxmarty-amd, the pre-commit checks have failed. Please run the pre-commit steps above, then commit the changes and push to your branch.

Signed-off-by: Felix Marty <Felix.Marty@amd.com>