
Conversation

@wenscarl (Contributor) commented Sep 30, 2025

Purpose

Add grouped_gemm_nt_masked from flashinfer to support nvfp4 MoE.

Depends on the silu_and_mul nvfp4 quantization fusion rework.
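
At a high level, grouped_gemm_nt_masked runs one GEMM per expert, but only over the first masked_m[e] rows of that expert's token block; the remaining rows are padding. A minimal PyTorch reference sketch of that contract (names and shapes are illustrative, mirroring the grouped_gemm_ref helper used in the unit test further down):

```python
import torch


def masked_grouped_gemm_ref(
    x: torch.Tensor,         # [num_experts, max_tokens, hidden_dim]
    w: torch.Tensor,         # [num_experts, inter_dim, hidden_dim]
    masked_m: torch.Tensor,  # [num_experts], valid row count per expert
) -> torch.Tensor:
    num_experts, max_tokens, _ = x.shape
    out = torch.zeros(
        num_experts, max_tokens, w.shape[1], dtype=x.dtype, device=x.device
    )
    for e in range(num_experts):
        m = int(masked_m[e])
        if m == 0:
            continue  # inactive expert: its rows stay zero
        # "nt" layout: weights are [N, K], so multiply by the transpose.
        out[e, :m] = x[e, :m] @ w[e].t()
    return out
```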

Test Plan

VLLM_WORKER_MULTIPROC_METHOD="spawn" \
VLLM_ALL2ALL_BACKEND="masked_gemm" \
VLLM_USE_STANDALONE_COMPILE=0 \
VLLM_USE_FLASHINFER_MOE_FP4=1 \
VLLM_FLASHINFER_MOE_BACKEND="cutedsl" \
lm_eval --model vllm --model_args pretrained=/dev/shm/checkpoints/nvidia-DeepSeek-R1-0528-FP4,quantization=modelopt_fp4,data_parallel_size=8,enable_expert_parallel=False,tensor_parallel_size=1,max_model_len=2048 --trust_remote_code --tasks gsm8k --num_fewshot 5 --batch_size auto

Test Result

vllm (pretrained=/dev/shm/checkpoints/nvidia-DeepSeek-R1-0528-FP4,quantization=modelopt_fp4,data_parallel_size=8,enable_expert_parallel=True,tensor_parallel_size=1,max_model_len=2048,enforce_eager=True,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9591|±  |0.0055|
|     |       |strict-match    |     5|exact_match|↑  |0.9538|±  |0.0058|

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing the test command.
  • The test results, such as pasting a before/after results comparison or e2e results.
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

mergify bot commented Oct 6, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @wenscarl.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Oct 6, 2025
@wenscarl wenscarl force-pushed the cutedsl_grp_gemm branch 3 times, most recently from 3d56913 to 99d4080 Compare October 7, 2025 03:55
@bnellnm (Collaborator) commented Oct 7, 2025

There should be existing utilities for a number of these functions, e.g. test_moe, dequantize_nvfp4_to_dtype, etc. Can you switch over to the existing implementations?

It would also be good to add the FlashInferCuteDSLExperts to the test_modular_kernel_combinations.py test. It should be fairly simple to register them in modular_kernel_tools/mk_objects.py. The test already supports nvfp4 so there should not be much additional work.

@varun-sundar-rabindranath (Contributor)

Thanks for working on this! I think this will also help enable gpt-oss + DeepEPLowLatency on Blackwell 🙌

@mergify mergify bot removed the needs-rebase label Oct 10, 2025
@wenscarl wenscarl marked this pull request as ready for review October 10, 2025 03:35
@wenscarl wenscarl requested a review from bnellnm October 10, 2025 03:35
@chatgpt-codex-connector bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.


Comment on lines 4 to 21
import pytest
import torch
from flashinfer import fp4_quantize
from torch.nn import functional as F

from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe.flashinfer_cutedsl_moe import (
    flashinfer_cutedsl_moe_masked,
    scaled_fp4_grouped_quant,
)
from vllm.utils.flashinfer import (
    flashinfer_cutedsl_grouped_gemm_nt_masked as cutedsl_gmm_masked,
)

if torch.cuda.get_device_capability() < (10, 0):
    pytest.skip(


P1: Guard optional FlashInfer/GPU dependencies in new test

The new CUTEDSL MoE test imports flashinfer and calls torch.cuda.get_device_capability() at module import time. In environments without the optional FlashInfer package or without CUDA support, these imports raise ImportError/RuntimeError before pytest has a chance to apply the skip, causing the entire test suite to fail during collection. Wrap the import with pytest.importorskip("flashinfer") and check torch.cuda.is_available() before calling get_device_capability so the module skips cleanly when the dependency or hardware is absent.
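
A minimal sketch of the suggested module-level guard, assuming the test keeps its skip logic at import time (pytest.importorskip and allow_module_level are standard pytest APIs):

```python
import pytest
import torch

# Skip the whole module cleanly if the optional FlashInfer package is missing.
flashinfer = pytest.importorskip("flashinfer")

# Check for a CUDA device before querying its compute capability, so test
# collection does not raise on CPU-only machines.
if not torch.cuda.is_available() or torch.cuda.get_device_capability() < (10, 0):
    pytest.skip(
        "Requires a CUDA device with compute capability 10.0 or newer",
        allow_module_level=True,
    )
```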


mergify bot commented Oct 14, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @wenscarl.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

Comment on lines 122 to 127
if envs.VLLM_FLASHINFER_MOE_BACKEND == "cutedsl":
    logger.info_once(
        "Skip quantization when using FlashInfer CUTEDSL for "
        "ModelOptNvFp4FusedMoE."
    )
    q_dtype = None
Collaborator:

Quantization can be skipped if the quant_dtype field is left as None in the quant_config.

@wenscarl (Author):

I just want to limit the scope of this temporary change to dispatch, since the whole model is still nvfp4. When fp4 dispatch is supported by DeepEP (it is actually already supported, just not in the main branch), we can remove this.

@mgoin (Member) left a comment

Looks reasonable to me overall; it seems we just need to wait for the flashinfer change to land.

@wenscarl (Author)

@mgoin flashinfer-ai/flashinfer#1927 is merged. Should unblock this PR.

mergify bot commented Nov 12, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @wenscarl.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Nov 12, 2025
@wenscarl wenscarl requested review from bnellnm and mgoin November 13, 2025 05:24
@mergify mergify bot removed the needs-rebase label Nov 13, 2025
@bnellnm (Collaborator) left a comment

Overall LGTM. Just had a couple minor comments.

mergify bot commented Nov 14, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @wenscarl.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mgoin (Member) left a comment

@wenscarl When I run the test locally, I see a failure for the last case, PTAL

tests/kernels/moe/test_cutedsl_moe.py .......F                                                                                                                                                         [100%]

================================================================================================== FAILURES ==================================================================================================
_________________________________________________________________________________ test_grouped_gemm_nt_masked[16-128-512-5] __________________________________________________________________________________

bs = 16, hidden_dim = 128, inter_dim = 512, topk = 5

    @pytest.mark.parametrize(
        "bs, hidden_dim, inter_dim, topk", [(2, 128, 256, 2), (16, 128, 512, 5)]
    )
    @torch.inference_mode()
    def test_grouped_gemm_nt_masked(
        bs: int, hidden_dim: int, inter_dim: int, topk: int
    ) -> None:
        torch.manual_seed(42)
        B = bs
        D = hidden_dim
        N = inter_dim
        # CuteDSL group gemm has issue when not all experts are active.
        # i.e. masked = [2, 3, 0, 0, 1] where the 2nd and 3rd experts are inactive
        # see https://github.com/flashinfer-ai/flashinfer/issues/1856
        num_experts = bs
        hidden_states = torch.randn(B, D, dtype=torch.bfloat16, device="cuda")
        weights = torch.randn(num_experts, N, D, dtype=torch.bfloat16, device="cuda")
        router_logits = torch.randn(B, num_experts, dtype=torch.float32)
    
        hidden_states_expanded = (
            hidden_states.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
        )
        hidden_states_3d, masked_m, topk_idx, _ = prepare_inputs(
            hidden_states_expanded, router_logits, num_experts, topk
        )
    
        a_amax = (
            hidden_states_3d.abs()
            .amax(dim=(1, 2))
            .to(torch.float32)
            .to(hidden_states.device)
        )
        b_amax = weights.abs().amax(dim=(1, 2)).to(torch.float32).to(weights.device)
        a_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / a_amax
        b_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / b_amax
        out_flashinfer = flashinfer_cutedsl_grouped_gemm_nt_masked(
            hidden_states_3d.to(hidden_states.device), a_gs, weights, b_gs, masked_m
        )
        # reference
        out_ref = grouped_gemm_ref(
            hidden_states_expanded=hidden_states_expanded,
            hidden_states_3d=hidden_states_3d,
            weights=weights,
            topk_idx=topk_idx,
            masked_m=masked_m,
            B=B,
            topk=topk,
            num_experts=num_experts,
        )
        # Note: just to compare the masked position due to cutedsl may write nan
        # into unmasked position.
        for i in range(num_experts):
>           torch.testing.assert_close(
                out_flashinfer.permute(2, 0, 1)[i, : masked_m[i]],
                out_ref.to(out_flashinfer.device)[i, : masked_m[i]],
                atol=1e-1,
                rtol=1e-1,
            )
E           AssertionError: Tensor-likes are not close!
E           
E           Mismatched elements: 1529 / 1536 (99.5%)
E           Greatest absolute difference: 42.5 at index (1, 212) (up to 0.1 allowed)
E           Greatest relative difference: 1.0 at index (0, 0) (up to 0.1 allowed)

tests/kernels/moe/test_cutedsl_moe.py:570: AssertionError

@mergify mergify bot added the ci/build label Nov 17, 2025
@wenscarl wenscarl requested a review from mgoin November 18, 2025 20:27
@wenscarl (Author)

> @wenscarl When I run the test locally, I see a failure for the last case, PTAL


It's because the global scaling factors contained NaN. Fixed by filling them with 1s at initialization.
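
For illustration, a minimal sketch of the kind of fix described above (names are hypothetical; the point is that scale buffers allocated with torch.empty can hold NaN for experts that are never written, so initializing them to 1 keeps inactive experts benign):

```python
import torch

num_experts = 16

# Before: uninitialized memory may contain NaN for experts that receive no
# tokens, and that NaN propagates into the grouped-GEMM output.
a_gs_uninitialized = torch.empty(num_experts, dtype=torch.float32, device="cuda")

# After: a neutral scale of 1.0 for every expert.
a_gs = torch.ones(num_experts, dtype=torch.float32, device="cuda")
```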

@vllm-bot vllm-bot merged commit 613abb5 into vllm-project:main Nov 19, 2025
52 of 54 checks passed
@github-project-automation github-project-automation bot moved this to Done in NVIDIA Nov 19, 2025
Victor49152 pushed a commit to Victor49152/vllm that referenced this pull request Nov 20, 2025
LuminolT pushed a commit to LuminolT/vllm that referenced this pull request Nov 21, 2025
…project#25990)

Signed-off-by: Shu Wang. <[email protected]>
Signed-off-by: mgoin <[email protected]>
Co-authored-by: Michael Goin <[email protected]>
Signed-off-by: LuminolT <[email protected]>
bigPYJ1151 pushed a commit that referenced this pull request Nov 25, 2025
Signed-off-by: Shu Wang. <[email protected]>
Signed-off-by: mgoin <[email protected]>
Co-authored-by: Michael Goin <[email protected]>
Signed-off-by: jiang1.li <[email protected]>

Labels

ci/build, nvidia, ready

Projects

Status: Done


5 participants