[MoE] Nvfp4 Masked Gemm: Add flashinfer grouped_gemm_nt_masked #25990
Conversation
Force-pushed 3ba900a to e186f4c.
This pull request has merge conflicts that must be resolved before it can be merged.
Force-pushed 3d56913 to 99d4080.
There should be existing utilities for a number of these functions, e.g. test_moe, dequantize_nvfp4_to_dtype, etc. Can you switch over to the existing implementations?
It would also be good to add the FlashInferCuteDSLExperts to the test_modular_kernel_combinations.py test. It should be fairly simple to register them in modular_kernel_tools/mk_objects.py. The test already supports nvfp4 so there should not be much additional work.
Thanks for working on this! I think this will also help enable gpt-oss + DeepEPLowLatency on Blackwell 🙌
Force-pushed 80a4edf to 32dd1a1.
💡 Codex Review
Here are some automated review suggestions for this pull request.
import pytest
import torch
from flashinfer import fp4_quantize
from torch.nn import functional as F

from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe.flashinfer_cutedsl_moe import (
    flashinfer_cutedsl_moe_masked,
    scaled_fp4_grouped_quant,
)
from vllm.utils.flashinfer import (
    flashinfer_cutedsl_grouped_gemm_nt_masked as cutedsl_gmm_masked,
)

if torch.cuda.get_device_capability() < (10, 0):
    pytest.skip(
Guard optional FlashInfer/GPU dependencies in new test
The new CUTEDSL MoE test imports flashinfer and calls torch.cuda.get_device_capability() at module import time. In environments without the optional FlashInfer package or without CUDA support, these imports raise ImportError/RuntimeError before pytest has a chance to apply the skip, causing the entire test suite to fail during collection. Wrap the import with pytest.importorskip("flashinfer") and check torch.cuda.is_available() before calling get_device_capability so the module skips cleanly when the dependency or hardware is absent.
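A minimal sketch of that guard, assuming the module keeps its current imports (one possible way to address the comment, not necessarily how the PR ends up resolving it):

import pytest
import torch

# Skip the whole module at collection time if the optional dependency is
# missing, instead of failing with an ImportError.
flashinfer = pytest.importorskip("flashinfer")

# Only query the device capability once we know CUDA is usable.
if not torch.cuda.is_available() or torch.cuda.get_device_capability() < (10, 0):
    pytest.skip(
        "Requires a CUDA device with compute capability >= 10.0",
        allow_module_level=True,
    )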
Two resolved (outdated) review threads on vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py.
This pull request has merge conflicts that must be resolved before it can be merged.
Force-pushed 32dd1a1 to 8a224da.
if envs.VLLM_FLASHINFER_MOE_BACKEND == "cutedsl":
    logger.info_once(
        "Skip quantization when using FlashInfer CUTEDSL for "
        "ModelOptNvFp4FusedMoE."
    )
    q_dtype = None
Quantization can be skipped if the quant_dtype field is left as None in the quant_config.
I just want to limit the scope of this temporary change to dispatch, since the whole model is still nvfp4. Once fp4 dispatch is supported by DeepEP (it is actually already supported, but not yet in the main branch), we can remove this.
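For reference, a minimal sketch of the pattern suggested above, with hypothetical names rather than the actual vLLM classes: dispatch-side quantization is controlled entirely by whether the quant config's quant_dtype is set.

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class QuantConfigSketch:
    # None means "leave activations unquantized before dispatch".
    quant_dtype: Optional[torch.dtype] = None


def maybe_dispatch_quantize(x: torch.Tensor, cfg: QuantConfigSketch) -> torch.Tensor:
    # No backend-specific branching: the config alone decides.
    # (Real nvfp4 quantization also produces scale factors; this cast only
    # illustrates how quant_dtype=None gates the step.)
    if cfg.quant_dtype is None:
        return x
    return x.to(cfg.quant_dtype)


# For the CUTEDSL backend the config would simply be built with quant_dtype=None,
# so dispatch passes bf16 activations through untouched.
cfg = QuantConfigSketch(quant_dtype=None)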
mgoin left a comment:
Looks reasonable to me overall; it seems we just need to wait for the flashinfer change to get in.
@mgoin flashinfer-ai/flashinfer#1927 is merged. Should unblock this PR.

This pull request has merge conflicts that must be resolved before it can be merged.
bnellnm left a comment:
Overall LGTM. Just had a couple minor comments.
This pull request has merge conflicts that must be resolved before it can be merged.
Force-pushed fc02f8e to b97c80d.
Signed-off-by: Shu Wang. <[email protected]>
Force-pushed b97c80d to 0a83133.
mgoin left a comment:
@wenscarl When I run the test locally, I see a failure for the last case, PTAL
tests/kernels/moe/test_cutedsl_moe.py .......F [100%]
================================================================================================== FAILURES ==================================================================================================
_________________________________________________________________________________ test_grouped_gemm_nt_masked[16-128-512-5] __________________________________________________________________________________
bs = 16, hidden_dim = 128, inter_dim = 512, topk = 5
@pytest.mark.parametrize(
"bs, hidden_dim, inter_dim, topk", [(2, 128, 256, 2), (16, 128, 512, 5)]
)
@torch.inference_mode()
def test_grouped_gemm_nt_masked(
bs: int, hidden_dim: int, inter_dim: int, topk: int
) -> None:
torch.manual_seed(42)
B = bs
D = hidden_dim
N = inter_dim
# CuteDSL group gemm has issue when not all experts are active.
# i.e. masked = [2, 3, 0, 0, 1] where the 2nd and 3rd experts are inactive
# see https://github.com/flashinfer-ai/flashinfer/issues/1856
num_experts = bs
hidden_states = torch.randn(B, D, dtype=torch.bfloat16, device="cuda")
weights = torch.randn(num_experts, N, D, dtype=torch.bfloat16, device="cuda")
router_logits = torch.randn(B, num_experts, dtype=torch.float32)
hidden_states_expanded = (
hidden_states.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
)
hidden_states_3d, masked_m, topk_idx, _ = prepare_inputs(
hidden_states_expanded, router_logits, num_experts, topk
)
a_amax = (
hidden_states_3d.abs()
.amax(dim=(1, 2))
.to(torch.float32)
.to(hidden_states.device)
)
b_amax = weights.abs().amax(dim=(1, 2)).to(torch.float32).to(weights.device)
a_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / a_amax
b_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / b_amax
out_flashinfer = flashinfer_cutedsl_grouped_gemm_nt_masked(
hidden_states_3d.to(hidden_states.device), a_gs, weights, b_gs, masked_m
)
# reference
out_ref = grouped_gemm_ref(
hidden_states_expanded=hidden_states_expanded,
hidden_states_3d=hidden_states_3d,
weights=weights,
topk_idx=topk_idx,
masked_m=masked_m,
B=B,
topk=topk,
num_experts=num_experts,
)
# Note: just to compare the masked position due to cutedsl may write nan
# into unmasked position.
for i in range(num_experts):
> torch.testing.assert_close(
out_flashinfer.permute(2, 0, 1)[i, : masked_m[i]],
out_ref.to(out_flashinfer.device)[i, : masked_m[i]],
atol=1e-1,
rtol=1e-1,
)
E AssertionError: Tensor-likes are not close!
E
E Mismatched elements: 1529 / 1536 (99.5%)
E Greatest absolute difference: 42.5 at index (1, 212) (up to 0.1 allowed)
E Greatest relative difference: 1.0 at index (0, 0) (up to 0.1 allowed)
tests/kernels/moe/test_cutedsl_moe.py:570: AssertionError
Signed-off-by: mgoin <[email protected]>
Signed-off-by: Shu Wang. <[email protected]>
It's because the global scaling factors contain NaN. Fixed by filling them with 1s at initialization.
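For context, a hedged sketch of that kind of fix (variable names are illustrative, not the PR's actual code): the per-expert global scale buffer is created with ones instead of being left uninitialized, so experts that receive no tokens keep a benign scale instead of NaN.

import torch

num_experts = 16

# Buggy pattern: torch.empty leaves whatever was in memory (possibly NaN)
# in slots that are never written, e.g. for experts with zero tokens.
# global_scales = torch.empty(num_experts, dtype=torch.float32, device="cuda")

# Fixed pattern: fill with 1.0 so untouched entries act as an identity scale.
global_scales = torch.ones(num_experts, dtype=torch.float32, device="cuda")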
…project#25990) Signed-off-by: Shu Wang. <[email protected]> Signed-off-by: mgoin <[email protected]> Co-authored-by: Michael Goin <[email protected]>
…project#25990) Signed-off-by: Shu Wang. <[email protected]> Signed-off-by: mgoin <[email protected]> Co-authored-by: Michael Goin <[email protected]> Signed-off-by: LuminolT <[email protected]>
Signed-off-by: Shu Wang. <[email protected]> Signed-off-by: mgoin <[email protected]> Co-authored-by: Michael Goin <[email protected]> Signed-off-by: jiang1.li <[email protected]>
Add grouped_gemm_nt_masked from flashinfer to support nvfp4 MoE.
Depends on the silu_and_mul nvfp4 quantization fusion rework.
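A rough usage sketch, mirroring the unit test quoted earlier in this thread (shapes, constants, and the wrapper name are taken from that test; treat it as illustrative rather than a definitive API reference):

import torch
from vllm.utils.flashinfer import flashinfer_cutedsl_grouped_gemm_nt_masked

FLOAT8_E4M3_MAX = 448.0
FLOAT4_E2M1_MAX = 6.0

E, M, K, N = 8, 16, 128, 256  # experts, max tokens per expert, hidden, inter
# Grouped activations and per-expert weights, laid out as in the test.
a = torch.randn(E, M, K, dtype=torch.bfloat16, device="cuda")
b = torch.randn(E, N, K, dtype=torch.bfloat16, device="cuda")
# Number of valid rows per expert (the dtype here is an assumption).
masked_m = torch.randint(1, M + 1, (E,), dtype=torch.int32, device="cuda")

# Per-expert global scaling factors for nvfp4, computed as in the test.
a_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / a.abs().amax(dim=(1, 2)).to(torch.float32)
b_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / b.abs().amax(dim=(1, 2)).to(torch.float32)

out = flashinfer_cutedsl_grouped_gemm_nt_masked(a, a_gs, b, b_gs, masked_m)
# In the test the result is permuted to [E, M, N] before comparison, and only
# the first masked_m[i] rows of expert i are checked; the kernel may write NaN
# into unmasked positions.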