[NVIDIA] [3/N] Nvfp4 Masked Gemm: Add flashinfer grouped_gemm_nt_masked #9199
Merged: zhyncs merged 33 commits into sgl-project:main from wenscarl:flashinfer_cutedsl_grp_gemm on Sep 12, 2025.
Commits (33):
79ba699 Add flashinfer CuteDSL masked grouped gemm support (wenscarl)
64f72e7 zero to empty init. (wenscarl)
4e03150 Address comment (wenscarl)
0a1d699 Upd (wenscarl)
73aa90a Add masked_m (wenscarl)
b09d92d Merge branch 'main' into flashinfer_cutedsl_grp_gemm (fzyzcjy)
6b96c98 Merge branch 'main' into flashinfer_cutedsl_grp_gemm (fzyzcjy)
73b9605 Update python/sglang/srt/layers/quantization/modelopt_quant.py (fzyzcjy)
f7fc26d fix error (fzyzcjy)
ec2c719 Merge branch 'main' into flashinfer_cutedsl_grp_gemm (fzyzcjy)
dfb3ac3 Merge branch 'main' into flashinfer_cutedsl_grp_gemm (fzyzcjy)
f9bb5bc Merge branch 'sgl-project:main' into flashinfer_cutedsl_grp_gemm (wenscarl)
7773823 Skip fusing scaling factor into router weights for cutedsl backend (wenscarl)
1497ccc Merge branch 'main' into flashinfer_cutedsl_grp_gemm (fzyzcjy)
a88062e Make unittest rigorous (wenscarl)
7c7a6dc Fix lint (wenscarl)
21ff185 Merge branch 'main' into flashinfer_cutedsl_grp_gemm (fzyzcjy)
5950cd6 Merge branch 'main-upstream' into flashinfer_cutedsl_grp_gemm (fzyzcjy)
6bac5df Add comment (wenscarl)
4d8812f Enable fused scaling factor (wenscarl)
4cac99f Address comments (wenscarl)
f387bf0 Add e2e test (wenscarl)
ee04919 Merge remote-tracking branch 'origin/main' into flashinfer_cutedsl_gr… (wenscarl)
cc2a57f add e2e test to test_suite (wenscarl)
023e004 Merge branch 'main' into flashinfer_cutedsl_grp_gemm (fzyzcjy)
73a207c Fix lint (wenscarl)
a2407b5 Merge branch 'main' into flashinfer_cutedsl_grp_gemm (fzyzcjy)
2cd04a1 Remove CI test temporarily (wenscarl)
54b2657 Merge remote-tracking branch 'origin/main' into flashinfer_cutedsl_gr… (wenscarl)
6276e25 Merge branch 'main' into flashinfer_cutedsl_grp_gemm (fzyzcjy)
c4552f1 Merge branch 'main' into flashinfer_cutedsl_grp_gemm (fzyzcjy)
2422f5d Merge branch 'main' into flashinfer_cutedsl_grp_gemm (fzyzcjy)
cd64a56 Merge branch 'main' into flashinfer_cutedsl_grp_gemm (fzyzcjy)
New file added by this PR:

```python
from typing import Any, Dict, Optional

import torch
from flashinfer.cute_dsl.blockscaled_gemm import grouped_gemm_nt_masked
from sgl_kernel.gemm import (
    scaled_fp4_grouped_quant,
    silu_and_mul_scaled_fp4_grouped_quant,
)


def get_cute_dtype(input: torch.Tensor) -> str:
    if input.dtype == torch.bfloat16:
        return "bfloat16"
    elif input.dtype == torch.float16:
        return "float16"
    elif input.dtype == torch.float32:
        return "float32"
    else:
        raise ValueError(f"Unsupported cute dtype {input.dtype}")


def flashinfer_cutedsl_moe_masked(
    hidden_states: torch.Tensor,
    input_global_scale: torch.Tensor,
    w1: torch.Tensor,
    w1_blockscale: torch.Tensor,
    w1_alpha,
    w2: torch.Tensor,
    a2_global_scale: torch.Tensor,
    w2_blockscale: torch.Tensor,
    w2_alpha,
    masked_m: torch.Tensor,
):
    """
    Perform masked Mixture-of-Experts computation with FlashInfer's CuteDSL
    kernels.

    Args:
        hidden_states (torch.Tensor): [num_experts, m, k], bf16
        input_global_scale (torch.Tensor): (l,)
        w1 (torch.Tensor): fp4 weights, [l, 2 * n, k // 2], uint8
        w1_blockscale (torch.Tensor): blockscale factors, e4m3
        w1_alpha (torch.Tensor): (l,)
        w2 (torch.Tensor): fp4 weights, [l, k, n // 2], uint8
        a2_global_scale (torch.Tensor): (l,)
        w2_blockscale (torch.Tensor): blockscale factors, e4m3
        w2_alpha (torch.Tensor): (l,)
        masked_m (torch.Tensor): per-expert count of valid rows in
            hidden_states; rows beyond masked_m[e] are skipped

    Notes:
        - Assumes max(masked_m) <= m.
    """
    # === Assertions on dtypes ===
    assert (
        input_global_scale.dtype == torch.float32
    ), f"input_global_scale must be float32, got {input_global_scale.dtype}"
    assert w1.dtype == torch.uint8, f"w1 must be uint8 (fp4 packed), got {w1.dtype}"
    assert (
        w1_blockscale.dtype == torch.float8_e4m3fn
    ), f"w1_blockscale must be float8_e4m3fn, got {w1_blockscale.dtype}"
    assert (
        w1_alpha.dtype == torch.float32
    ), f"w1_alpha must be float32, got {w1_alpha.dtype}"
    assert w2.dtype == torch.uint8, f"w2 must be uint8 (fp4 packed), got {w2.dtype}"
    assert (
        a2_global_scale.dtype == torch.float32
    ), f"a2_global_scale must be float32, got {a2_global_scale.dtype}"
    assert (
        w2_blockscale.dtype == torch.float8_e4m3fn
    ), f"w2_blockscale must be float8_e4m3fn, got {w2_blockscale.dtype}"
    assert (
        w2_alpha.dtype == torch.float32
    ), f"w2_alpha must be float32, got {w2_alpha.dtype}"

    # === Assertions on shapes ===
    n = w2.shape[-1] * 2  # intermediate dimension
    num_experts, m, k = hidden_states.shape

    assert w1.shape[-2] == 2 * n, f"w1 last-2 dim must be 2*n, got {w1.shape}"
    assert (
        w1.shape[-1] * 2 == k
    ), f"w1 last dim * 2 must equal k, got {w1.shape[-1]} vs k={k}"
    assert w2.shape[-2:] == (
        k,
        n // 2,
    ), f"w2 shape mismatch, got {w2.shape[-2:]}, expected {(k, n//2)}"

    assert input_global_scale.shape == (
        num_experts,
    ), f"input_global_scale must be (l,), got {input_global_scale.shape}"
    assert w1_alpha.shape == (
        num_experts,
    ), f"w1_alpha must be (l,), got {w1_alpha.shape}"
    assert a2_global_scale.shape == (
        num_experts,
    ), f"a2_global_scale must be (l,), got {a2_global_scale.shape}"
    assert w2_alpha.shape == (
        num_experts,
    ), f"w2_alpha must be (l,), got {w2_alpha.shape}"

    aq, aq_sf = scaled_fp4_grouped_quant(
        hidden_states,
        input_global_scale,
        masked_m,
    )
    gateup_output = torch.empty(
        (num_experts, m, n * 2), dtype=hidden_states.dtype, device=aq.device
    )
    gateup_output = gateup_output.permute(1, 2, 0)  # requirement of kernel
    sf_vec_size = 16
    assert aq_sf.dtype == torch.float8_e4m3fn
    assert aq.dtype == torch.uint8
    ab_dtype = "float4_e2m1fn"
    sf_dtype = "float8_e4m3fn"

    c_dtype = get_cute_dtype(hidden_states)

    # Gemm1
    grouped_gemm_nt_masked(
        (aq, aq_sf),
        (w1.permute(1, 2, 0), w1_blockscale),
        gateup_output,
        masked_m,
        ab_dtype=ab_dtype,
        sf_dtype=sf_dtype,
        c_dtype=c_dtype,
        sf_vec_size=sf_vec_size,
        alpha=w1_alpha.view(1, 1, num_experts),
        alpha_dtype=get_cute_dtype(w1_alpha),
    )  # in logical [m, n, l]

    # SILU and quantization
    diq, diq_sf = silu_and_mul_scaled_fp4_grouped_quant(
        gateup_output.permute(2, 0, 1),
        a2_global_scale,
        masked_m,
    )

    # Gemm2
    out = torch.empty_like(hidden_states)
    out = out.permute(1, 2, 0)  # requirement of kernel
    grouped_gemm_nt_masked(
        (diq, diq_sf),
        (w2.permute(1, 2, 0), w2_blockscale),
        out,
        masked_m,
        ab_dtype=ab_dtype,
        sf_dtype=sf_dtype,
        c_dtype=c_dtype,
        sf_vec_size=sf_vec_size,
        alpha=w2_alpha.view(1, 1, num_experts),
        alpha_dtype=get_cute_dtype(w2_alpha),
    )  # in logical [m, k, l]
    return out.permute(2, 0, 1)
```
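The masked semantics of `grouped_gemm_nt_masked` can be illustrated with a tiny pure-Python reference. This is a sketch of the contract only, not FlashInfer's implementation: in NT layout each expert's B matrix is stored row-major as n-by-k, and only the first `masked_m[e]` rows of each expert's A contribute to the output.

```python
def grouped_gemm_nt_masked_ref(a, b, masked_m):
    """Pure-Python reference for a masked grouped GEMM in NT layout.

    a: list of l matrices, each m x k (per-expert activations)
    b: list of l matrices, each n x k (per-expert weights, so C = A @ B^T)
    masked_m: per-expert valid row counts; rows >= masked_m[e] stay zero
    Returns a list of l output matrices, each m x n.
    """
    out = []
    for e in range(len(a)):
        m, k, n = len(a[e]), len(a[e][0]), len(b[e])
        c = [[0.0] * n for _ in range(m)]
        # Only the first masked_m[e] rows are computed; the rest are skipped.
        for i in range(masked_m[e]):
            for j in range(n):
                c[i][j] = sum(a[e][i][t] * b[e][j][t] for t in range(k))
        out.append(c)
    return out


# One expert, two rows, but only the first row is valid (masked_m = [1]):
out = grouped_gemm_nt_masked_ref(
    [[[1.0, 2.0], [3.0, 4.0]]],   # A: 2 x 2
    [[[1.0, 0.0], [0.0, 1.0]]],   # B: identity, so valid rows pass through
    [1],
)
# out[0][0] == [1.0, 2.0]; out[0][1] stays [0.0, 0.0]
```

The real kernel additionally dequantizes fp4 inputs with per-16-element e4m3 blockscales and a per-expert alpha, but the masking behavior is the same: padded rows beyond `masked_m[e]` are never computed.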