[Kernel] Support Flashinfer trtllm fused MoE non gated FP8 & NVFP4#33506

Merged
vllm-bot merged 15 commits into vllm-project:main from amitz-nv:support-fi-fused-moe-non-gated-fp8-nvfp4
Feb 12, 2026

Conversation


@amitz-nv amitz-nv commented Feb 1, 2026

Purpose

Add support for Flashinfer trtllm fused MoE non-gated activation for FP8 and for NVFP4.

Changes:

  • Pass activation_type argument to FlashInfer trtllm fused MoE FP8 and NVFP4.
  • Add DeepSeek routing to the supported routing list of Flashinfer trtllm fused MoE FP8.
  • Add support for the non-gated flow in Flashinfer trtllm fused MoE NVFP4.
  • Use min_alignment=128 (padding) for non-gated activation in Flashinfer trtllm fused MoE.
  • Fix tests/kernels/moe/test_flashinfer.py and expand it to also test relu2_no_mul activation for both cutlass and trtllm kernels.
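
For context, a gated MoE FFN (e.g. SwiGLU) splits the fused w13 projection into gate and up halves and multiplies them, while a non-gated one such as relu2_no_mul applies a squared ReLU directly. A minimal NumPy sketch of the two expert-FFN variants (illustrative shapes and names only, not vLLM's actual kernels):

```python
import numpy as np

rng = np.random.default_rng(0)
hidden, inter = 8, 16

def silu(x):
    return x / (1.0 + np.exp(-x))

def gated_ffn(x, w13, w2):
    # w13 fuses gate and up projections: shape (2 * inter, hidden)
    gate, up = np.split(x @ w13.T, 2, axis=-1)
    return (silu(gate) * up) @ w2.T          # SwiGLU: silu(gate) * up

def non_gated_ffn(x, w13, w2):
    # w13 is a single up projection: shape (inter, hidden), no gate half
    h = x @ w13.T
    return np.maximum(h, 0.0) ** 2 @ w2.T    # relu2_no_mul: relu(h) ** 2

x = rng.standard_normal((2, hidden))
y_gated = gated_ffn(x, rng.standard_normal((2 * inter, hidden)),
                    rng.standard_normal((hidden, inter)))
y_plain = non_gated_ffn(x, rng.standard_normal((inter, hidden)),
                        rng.standard_normal((hidden, inter)))
assert y_gated.shape == y_plain.shape == (2, hidden)
```

Note the weight-shape difference: for the same config `intermediate_size`, the fused w13 of a gated model is twice as wide, which is why the kernels below must distinguish the two flows.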

lm_eval on Nemotron 3 Nano FP8:

export VLLM_USE_FLASHINFER_MOE_FP8=1
export VLLM_FLASHINFER_MOE_BACKEND=latency

lm_eval --model vllm \
  --model_args pretrained=nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-FP8,tensor_parallel_size=1,max_model_len=2048,kv_cache_dtype=auto \
  --gen_kwargs temperature=0.0 --limit 500 --trust_remote_code \
  --tasks gsm8k --num_fewshot 5 --batch_size 200

Outputs:

vllm ({'pretrained': 'nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-FP8', 'tensor_parallel_size': 1, 'max_model_len': 2048, 'kv_cache_dtype': 'auto'}), gen_kwargs: ({'temperature': 0.0}), limit: 500.0, num_fewshot: 5, batch_size: 200
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.568|±  |0.0222|
|     |       |strict-match    |     5|exact_match|↑  |0.848|±  |0.0161|

Test Plan

Test Result


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@mergify mergify bot added the nvidia label Feb 1, 2026
@mergify

mergify bot commented Feb 1, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @amitz-nv.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request adds support for non-gated Mixture of Experts (MoE) models using FlashInfer with FP8 and NVFP4 quantization. The changes are comprehensive, including updates to tests, support checks, activation handling, and weight preparation logic. Overall, the changes are well-aligned with the PR's objective. However, I've identified a critical bug in the FP4 MoE weight preparation logic that incorrectly calculates shapes for gated activations, which could lead to runtime errors or incorrect results. I have provided specific suggestions to address this issue.

Comment on lines +198 to +206

  gemm2_weights_fp4 = gemm2_weights.view(torch.float8_e4m3fn).reshape(
-     num_experts, hidden_size, intermediate_size // 2
+     num_experts, hidden_size, actual_intermediate_size // 2
  )  # packed fp4
  gemm2_scales_linear_fp4 = gemm2_scales_linear_fp4_bytes.view(
      torch.float8_e4m3fn
- ).reshape(num_experts, hidden_size, intermediate_size // 16)  # fp8 scaling factors
+ ).reshape(
+     num_experts, hidden_size, actual_intermediate_size // 16
+ )  # fp8 scaling factors

critical

The calculation for gemm2_weights_fp4 and gemm2_scales_linear_fp4 shapes is incorrect for gated activations. actual_intermediate_size is derived from w13's shape, which differs for gated and non-gated models. However, the down-projection (gemm2) should have a consistent intermediate dimension. This change introduces mlp_ffn_dim to correctly calculate the shapes for both gated and non-gated cases, fixing a bug for gated activations.

Suggested change
- gemm2_weights_fp4 = gemm2_weights.view(torch.float8_e4m3fn).reshape(
-     num_experts, hidden_size, actual_intermediate_size // 2
- )  # packed fp4
- gemm2_scales_linear_fp4 = gemm2_scales_linear_fp4_bytes.view(
-     torch.float8_e4m3fn
- ).reshape(
-     num_experts, hidden_size, actual_intermediate_size // 16
- )  # fp8 scaling factors
+ mlp_ffn_dim = intermediate_size if is_gated_activation else 2 * intermediate_size
+ gemm2_weights_fp4 = gemm2_weights.view(torch.float8_e4m3fn).reshape(
+     num_experts, hidden_size, mlp_ffn_dim // 2
+ )  # packed fp4
+ gemm2_scales_linear_fp4 = gemm2_scales_linear_fp4_bytes.view(
+     torch.float8_e4m3fn
+ ).reshape(
+     num_experts, hidden_size, mlp_ffn_dim // 16
+ )  # fp8 scaling factors
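
To make the shape bookkeeping concrete, here is a hedged sketch of the suggested arithmetic (the helper name gemm2_reshape_dims is hypothetical; the // 2 and // 16 divisors come from fp4 packing, two values per byte, and one fp8 scale per 16 values):

```python
def gemm2_reshape_dims(num_experts, hidden_size, intermediate_size, is_gated_activation):
    # Hypothetical helper mirroring the suggestion above: gemm2's input
    # dimension depends only on whether the activation is gated, never on
    # w13's activation-dependent fan-out.
    mlp_ffn_dim = intermediate_size if is_gated_activation else 2 * intermediate_size
    weights_shape = (num_experts, hidden_size, mlp_ffn_dim // 2)   # 2 packed fp4 values per byte
    scales_shape = (num_experts, hidden_size, mlp_ffn_dim // 16)   # one fp8 scale per 16 values
    return weights_shape, scales_shape

# e.g. 8 experts, hidden_size=1024, intermediate_size=2048, gated activation:
print(gemm2_reshape_dims(8, 1024, 2048, True))   # → ((8, 1024, 1024), (8, 1024, 128))
```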

Comment on lines 282 to 286
  gemm2_scales_fp4_shuffled = (
      torch.stack(gemm2_scales_fp4_shuffled)
      .view(torch.float8_e4m3fn)
-     .reshape(num_experts, hidden_size, intermediate_size // 16)
+     .reshape(num_experts, hidden_size, actual_intermediate_size // 16)
  )

critical

Similar to the previous comment, the reshape dimension for gemm2_scales_fp4_shuffled is incorrect for gated activations. It should use the mlp_ffn_dim variable (defined in the suggested fix for the previous issue) to ensure the correct shape.

Suggested change
- gemm2_scales_fp4_shuffled = (
-     torch.stack(gemm2_scales_fp4_shuffled)
-     .view(torch.float8_e4m3fn)
-     .reshape(num_experts, hidden_size, actual_intermediate_size // 16)
- )
+ gemm2_scales_fp4_shuffled = (
+     torch.stack(gemm2_scales_fp4_shuffled)
+     .view(torch.float8_e4m3fn)
+     .reshape(num_experts, hidden_size, mlp_ffn_dim // 16)
+ )

@amitz-nv amitz-nv force-pushed the support-fi-fused-moe-non-gated-fp8-nvfp4 branch from 2d08ea2 to 5e81d21 Compare February 1, 2026 16:51
@mergify mergify bot removed the needs-rebase label Feb 1, 2026
@amitz-nv amitz-nv force-pushed the support-fi-fused-moe-non-gated-fp8-nvfp4 branch from 500f8e3 to e1b1314 Compare February 2, 2026 15:31
@mergify
Copy link

mergify bot commented Feb 8, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @amitz-nv.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Feb 8, 2026
@amitz-nv amitz-nv force-pushed the support-fi-fused-moe-non-gated-fp8-nvfp4 branch 2 times, most recently from 5956b26 to 0e40b63 Compare February 10, 2026 12:00
@mergify mergify bot removed the needs-rebase label Feb 10, 2026
@amitz-nv amitz-nv force-pushed the support-fi-fused-moe-non-gated-fp8-nvfp4 branch from 3e18b3d to 4550510 Compare February 10, 2026 12:35
@amitz-nv amitz-nv changed the title Support FI fused MoE non gated FP8 & NVFP4 [Kernel] Support Flashinfer fused MoE non gated FP8 & NVFP4 Feb 10, 2026
@amitz-nv amitz-nv changed the title [Kernel] Support Flashinfer fused MoE non gated FP8 & NVFP4 [Kernel] Support Flashinfer trtllm-gen fused MoE non gated FP8 & NVFP4 Feb 10, 2026
@amitz-nv amitz-nv changed the title [Kernel] Support Flashinfer trtllm-gen fused MoE non gated FP8 & NVFP4 [Kernel] Support Flashinfer trtllm fused MoE non gated FP8 & NVFP4 Feb 10, 2026
@amitz-nv amitz-nv marked this pull request as ready for review February 10, 2026 15:21
    use_routing_scales_on_input: bool,
    routing_method_type: int,
    routed_scaling_factor: float = 1.0,
    activation_type: int = 3,  # Swiglu

Let's remove the default value to always be explicit

Comment on lines +21 to +36
def is_gated_activation(activation: str) -> bool:
    return not activation.lower().endswith("_no_mul")


def activation_str_to_int(activation: str) -> int:
    from flashinfer.fused_moe.core import ActivationType

    # silu and gelu are mapped to their gated versions SwiGLU and GeGLU respectively
    ACTIVATION_TO_FI_ACTIVATION = {
        "silu_no_mul": ActivationType.Silu,
        "gelu_no_mul": ActivationType.Gelu,
        "silu": ActivationType.Swiglu,
        "gelu": ActivationType.Geglu,
        "relu2_no_mul": ActivationType.Relu2,
    }
    return ACTIVATION_TO_FI_ACTIVATION[activation.lower()].value
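
As a quick illustration of the naming convention the mapping relies on, is_gated_activation is re-declared below without the flashinfer import so the snippet runs standalone:

```python
def is_gated_activation(activation: str) -> bool:
    # Non-gated variants carry a "_no_mul" suffix; everything else
    # ("silu", "gelu") is treated as its gated form (SwiGLU / GeGLU).
    return not activation.lower().endswith("_no_mul")

print(is_gated_activation("silu"))          # → True  (mapped to SwiGLU)
print(is_gated_activation("relu2_no_mul"))  # → False (squared ReLU, non-gated)
```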

Would be nice if we could have this use the MoEActivation refactor, hopefully landing soon #33843

@amitz-nv amitz-nv (Contributor Author) Feb 11, 2026

Nice, I definitely agree that refactor is necessary!
Regarding the order, I think it depends on when the refactor PR is merged


Done

# for the gate-up proj. Pad the weights to respect this.
is_gated = is_gated_activation(layer.activation)
if not block_quant:
    min_alignment = 16 if is_gated else 128
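
The padding itself is just rounding the gate-up intermediate dimension up to a multiple of min_alignment; a minimal sketch (the helper name pad_to_alignment is hypothetical):

```python
def pad_to_alignment(dim: int, min_alignment: int) -> int:
    # Round dim up to the nearest multiple of min_alignment; the
    # double-negation performs integer ceiling division.
    return -(-dim // min_alignment) * min_alignment

print(pad_to_alignment(3000, 128))  # → 3072 (non-gated: 128-element alignment)
print(pad_to_alignment(3000, 16))   # → 3008 (gated: 16-element alignment)
```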

Is there some justification for 128 we can reference?


That's what the current Flashinfer kernels require, otherwise it doesn't find a suitable kernel.

For example, Nemotron 3 Nano TP=1 would fail unless it's set to 128 here:

(EngineCore_DP0 pid=3184059)   File "/usr/local/lib/python3.12/dist-packages/flashinfer/fused_moe/core.py", line 2258, in trtllm_fp8_per_tensor_scale_moe
(EngineCore_DP0 pid=3184059)     return get_trtllm_moe_sm100_module().trtllm_fp8_per_tensor_scale_moe(
(EngineCore_DP0 pid=3184059)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=3184059)   File "/usr/local/lib/python3.12/dist-packages/flashinfer/fused_moe/core.py", line 1488, in trtllm_fp8_per_tensor_scale_moe_op
(EngineCore_DP0 pid=3184059)     result = moe_op.trtllm_fp8_per_tensor_scale_moe(
(EngineCore_DP0 pid=3184059)              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=3184059)   File "python/tvm_ffi/cython/function.pxi", line 923, in tvm_ffi.core.Function.__call__
(EngineCore_DP0 pid=3184059) RuntimeError: Error in function 'getValidConfigIndices' at /usr/local/lib/python3.12/dist-packages/flashinfer/data/csrc/trtllm_batched_gemm_runner.cu:416: No valid config found for the given problem shape

Comment on lines +511 to +513
block_quant = (
    hasattr(layer, "weight_block_size") and layer.weight_block_size is not None
)

If we are in NVFP4, why would we expect weight_block_size in any case?


It was copied from the FP8 flow; removing it.

Signed-off-by: amitz-nv <203509407+amitz-nv@users.noreply.github.com>

…ted MoE and rel2_no_mul activation, support DeepSeek routing in FP8 per-tensor, fix prepare_static_weights_for_trtllm_fp4_moe

…m/utils/flashinfer.py

…gated, otherwise use min_alignment=16
@amitz-nv amitz-nv force-pushed the support-fi-fused-moe-non-gated-fp8-nvfp4 branch from 5d43e07 to ea22768 Compare February 12, 2026 10:11
@mergify

mergify bot commented Feb 12, 2026

Hi @amitz-nv, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy or markdownlint failing?
mypy and markdownlint are run differently in CI. If the failure is related to either of these checks, please use the following commands to run them locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
# For markdownlint
pre-commit run --hook-stage manual markdownlint

Signed-off-by: amitz-nv <203509407+amitz-nv@users.noreply.github.com>
@mgoin mgoin added ready ONLY add when PR is ready to merge/full CI is needed performance Performance-related issues labels Feb 12, 2026
@mgoin mgoin (Member) left a comment

LGTM nice work! Will manually trigger MoE refactor tests

Comment on lines 939 to -943
  # time in the oracle rather than here.
- assert layer.activation == MoEActivation.SILU, (
-     f"Expected 'silu' activation but got {layer.activation}"
+ SUPPORTED_ACTIVATIONS = [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
+ assert layer.activation in SUPPORTED_ACTIVATIONS, (
+     f"Only {SUPPORTED_ACTIVATIONS} activations are supported for FlashInfer "
+     f"TRTLLM FP4 MoE, {layer.activation} found instead."
  )
  assert not layer.renormalize

Note: we need to update the compressed tensors side too, can do in followup PR

@github-project-automation github-project-automation bot moved this to Ready in NVIDIA Feb 12, 2026
@vllm-bot vllm-bot merged commit f120bd4 into vllm-project:main Feb 12, 2026
61 of 67 checks passed
@github-project-automation github-project-automation bot moved this from Ready to Done in NVIDIA Feb 12, 2026
eldarkurtic pushed a commit to eldarkurtic/vllm that referenced this pull request Feb 19, 2026
…llm-project#33506)

Signed-off-by: amitz-nv <203509407+amitz-nv@users.noreply.github.com>
Signed-off-by: Eldar Kurtic <research@neuralmagic.com>
llsj14 pushed a commit to llsj14/vllm that referenced this pull request Mar 1, 2026
…llm-project#33506)

Signed-off-by: amitz-nv <203509407+amitz-nv@users.noreply.github.com>
tunglinwood pushed a commit to tunglinwood/vllm that referenced this pull request Mar 4, 2026
…llm-project#33506)

Signed-off-by: amitz-nv <203509407+amitz-nv@users.noreply.github.com>

Labels

nvidia performance Performance-related issues ready ONLY add when PR is ready to merge/full CI is needed

Projects

Status: Done


3 participants