-
Notifications
You must be signed in to change notification settings - Fork 3.6k
[NVIDIA] [2/N] Optimize silu_and_mul_scaled_fp4_grouped_quant perf
#9556
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 2 commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
6bdd9ad
Optimize the nvfp4 fused quant perf
kaixih 2331d9b
Add bench
kaixih 02c7580
Use masked fp4 quant
kaixih cbf86a3
Typo
kaixih 8ab5844
Lint
kaixih 407a3ca
Compute offset on the fly instead of loading
kaixih 6df6848
Lint
kaixih a192644
Lint
kaixih File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,131 @@ | ||
| import argparse | ||
| import itertools | ||
|
|
||
| import torch | ||
| import triton | ||
| from sgl_kernel import scaled_fp4_grouped_quant, silu_and_mul_scaled_fp4_grouped_quant | ||
| from sgl_kernel.elementwise import silu_and_mul | ||
|
|
||
| from sglang.srt.layers.moe.ep_moe.kernels import silu_and_mul_masked_post_quant_fwd | ||
| from sglang.srt.layers.quantization import deep_gemm_wrapper | ||
|
|
||
|
|
||
| def _test_accuracy_once(E, M, K, input_dtype, device): | ||
| x = torch.randn(E, M, K, device=device, dtype=input_dtype) | ||
| glb_scales = torch.ones((E,), dtype=torch.float32, device=device) | ||
| masks = torch.full((E,), M, dtype=torch.int32, device=device) | ||
| out, blk_scales = silu_and_mul_scaled_fp4_grouped_quant(x, glb_scales, masks) | ||
| out1, blk_scales1 = scaled_fp4_grouped_quant( | ||
| silu_and_mul(x), | ||
| glb_scales, | ||
| ) | ||
|
|
||
| torch.testing.assert_close(out, out1) | ||
| torch.testing.assert_close(blk_scales, blk_scales1) | ||
| print(f"E: {E}, M: {M}, K: {K}, type: {input_dtype} OK") | ||
|
|
||
|
|
||
| NUM_RANKS = 48 | ||
| M_PER_RANKs = [128, 256, 512, 1024] | ||
| Ms = [M_PER_RANK * NUM_RANKS for M_PER_RANK in M_PER_RANKs] | ||
| Ks = [2048, 4096, 7168] | ||
|
|
||
|
|
||
| @triton.testing.perf_report( | ||
| triton.testing.Benchmark( | ||
| x_names=["M", "K"], | ||
| x_vals=list(itertools.product(Ms, Ks)), | ||
| x_log=False, | ||
| line_arg="provider", | ||
| line_vals=["triton_fp8", "cuda_unfused_fp4", "cuda_fused_fp4"], | ||
| line_names=["triton_fp8", "cuda_unfused_fp4", "cuda_fused_fp4"], | ||
| styles=[("blue", "-"), ("orange", "-"), ("green", "-")], | ||
| ylabel="ms", | ||
| plot_name="fp4 quant", | ||
| args={}, | ||
| ) | ||
| ) | ||
| def benchmark(M, K, provider): | ||
| E = 6 | ||
| device = "cuda" | ||
| x = torch.randn(E, M, K, device=device, dtype=torch.bfloat16) | ||
| glb_scales = torch.ones((E,), dtype=torch.float32, device=device) | ||
| masks = torch.randint(1, 4096, (E,), dtype=torch.int32, device=device) | ||
| fp8_out = torch.empty( | ||
| ( | ||
| x.shape[0], | ||
| x.shape[1], | ||
| x.shape[2] // 2, | ||
| ), | ||
| device=x.device, | ||
| dtype=torch.float8_e4m3fn, | ||
| ) | ||
| scale_block_size = 128 | ||
| fp8_scales = torch.empty( | ||
| ( | ||
| x.shape[0], | ||
| x.shape[1], | ||
| x.shape[2] // 2 // scale_block_size, | ||
| ), | ||
| device=x.device, | ||
| dtype=torch.float32, | ||
| ) | ||
|
|
||
| quantiles = [0.5, 0.2, 0.8] | ||
| if provider == "triton_fp8": | ||
| ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( | ||
| lambda: silu_and_mul_masked_post_quant_fwd( | ||
| x, | ||
| fp8_out, | ||
| fp8_scales, | ||
| scale_block_size, | ||
| masks, | ||
| scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0, | ||
| ), | ||
| quantiles=quantiles, | ||
| ) | ||
| if provider == "cuda_unfused_fp4": | ||
| ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( | ||
| lambda: scaled_fp4_grouped_quant( | ||
| silu_and_mul(x), | ||
| glb_scales, | ||
| ), | ||
| quantiles=quantiles, | ||
| ) | ||
| if provider == "cuda_fused_fp4": | ||
| ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( | ||
| lambda: silu_and_mul_scaled_fp4_grouped_quant( | ||
| x, | ||
| glb_scales, | ||
| masks, | ||
| ), | ||
| quantiles=quantiles, | ||
| ) | ||
|
|
||
| return ms, min_ms, max_ms | ||
|
|
||
|
|
||
| def test_accuracy(): | ||
| E = 6 | ||
| N_RANKS = 48 | ||
| Ms = [128, 256, 512, 1024] | ||
| Ks = [2048, 4096, 7168] | ||
| input_dtype = torch.bfloat16 | ||
| for M in Ms: | ||
| for K in Ks: | ||
| _test_accuracy_once(E, N_RANKS * M, K, input_dtype, "cuda") | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| parser = argparse.ArgumentParser() | ||
| parser.add_argument( | ||
| "--save_path", | ||
| type=str, | ||
| default="./bench_fp4_quant_res", | ||
| help="Path to save fp4 quant benchmark results", | ||
| ) | ||
| args = parser.parse_args() | ||
|
|
||
| test_accuracy() | ||
|
|
||
| benchmark.run(print_data=True, show_plots=True, save_path=args.save_path) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.