diff --git a/benchmark/kernels/quantization/bench_fp4_quant.py b/benchmark/kernels/quantization/bench_fp4_quant.py
index 9a5b6946339..afc12dd8d3f 100644
--- a/benchmark/kernels/quantization/bench_fp4_quant.py
+++ b/benchmark/kernels/quantization/bench_fp4_quant.py
@@ -3,7 +3,10 @@
 import torch
 import triton

-from sgl_kernel import scaled_fp4_grouped_quant, silu_and_mul_scaled_fp4_grouped_quant
+from flashinfer import (
+    scaled_fp4_grouped_quantize,
+    silu_and_mul_scaled_nvfp4_experts_quantize,
+)
 from sgl_kernel.elementwise import silu_and_mul

 from sglang.srt.layers import deep_gemm_wrapper
@@ -14,11 +17,11 @@ def _test_accuracy_once(E, M, K, input_dtype, device):
     x = torch.randn(E, M, K, device=device, dtype=input_dtype)
     glb_scales = torch.ones((E,), dtype=torch.float32, device=device)
     masks = torch.full((E,), M, dtype=torch.int32, device=device)
-    out, blk_scales = silu_and_mul_scaled_fp4_grouped_quant(x, glb_scales, masks)
-    out1, blk_scales1 = scaled_fp4_grouped_quant(
+    out, blk_scales = silu_and_mul_scaled_nvfp4_experts_quantize(x, masks, glb_scales)
+    out1, blk_scales1 = scaled_fp4_grouped_quantize(
         silu_and_mul(x),
-        glb_scales,
         masks,
+        glb_scales,
     )

     torch.testing.assert_close(out, out1)
@@ -87,19 +90,19 @@ def benchmark(M, K, provider):
         )
     if provider == "cuda_unfused_fp4":
         ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
-            lambda: scaled_fp4_grouped_quant(
+            lambda: scaled_fp4_grouped_quantize(
                 silu_and_mul(x),
-                glb_scales,
                 masks,
+                glb_scales,
             ),
             quantiles=quantiles,
         )
     if provider == "cuda_fused_fp4":
         ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
-            lambda: silu_and_mul_scaled_fp4_grouped_quant(
+            lambda: silu_and_mul_scaled_nvfp4_experts_quantize(
                 x,
-                glb_scales,
                 masks,
+                glb_scales,
             ),
             quantiles=quantiles,
         )
diff --git a/docs/references/environment_variables.md b/docs/references/environment_variables.md
index 4814748c187..0e530b9a9c9 100644
--- a/docs/references/environment_variables.md
+++ b/docs/references/environment_variables.md
@@ -48,6 +48,7 @@ SGLang supports various environment variables that can be used to configure its

 | Environment Variable | Description | Default Value |
 | `SGLANG_DEEPEP_BF16_DISPATCH` | Use Bfloat16 for dispatch | `"false"` |
+| `SGLANG_CUTEDSL_MOE_NVFP4_DISPATCH` | Use nvfp4 for dispatch | `"false"` |

 ## Memory Management

diff --git a/python/sglang/srt/layers/moe/flashinfer_cutedsl_moe.py b/python/sglang/srt/layers/moe/flashinfer_cutedsl_moe.py
index 504afcdaecd..74455c93121 100644
--- a/python/sglang/srt/layers/moe/flashinfer_cutedsl_moe.py
+++ b/python/sglang/srt/layers/moe/flashinfer_cutedsl_moe.py
@@ -1,11 +1,11 @@
 from typing import Optional

 import torch
-from flashinfer.cute_dsl.blockscaled_gemm import grouped_gemm_nt_masked
-from sgl_kernel.gemm import (
-    scaled_fp4_grouped_quant,
-    silu_and_mul_scaled_fp4_grouped_quant,
+from flashinfer import (
+    scaled_fp4_grouped_quantize,
+    silu_and_mul_scaled_nvfp4_experts_quantize,
 )
+from flashinfer.cute_dsl.blockscaled_gemm import grouped_gemm_nt_masked


 def get_cute_dtype(input: torch.Tensor) -> str:
@@ -97,10 +97,10 @@ def flashinfer_cutedsl_moe_masked(
         num_experts,
     ), f"input_global_scale must be (l,), got {input_global_scale.shape}"

-    a_q, a_q_sf = scaled_fp4_grouped_quant(
+    a_q, a_q_sf = scaled_fp4_grouped_quantize(
         hidden_states[0],
-        input_global_scale,
         masked_m,
+        input_global_scale,
     )

     assert w1.shape[-2] == 2 * n, f"w1 last-2 dim must be 2*n, got {w1.shape}"
@@ -148,10 +148,10 @@ def flashinfer_cutedsl_moe_masked(
     )  # in logical [m, n, l]

     # SILU and quantization
-    diq, diq_sf = silu_and_mul_scaled_fp4_grouped_quant(
+    diq, diq_sf = silu_and_mul_scaled_nvfp4_experts_quantize(
         gateup_output.permute(2, 0, 1),
-        a2_global_scale,
         masked_m,
+        a2_global_scale,
     )

     if down_start_event is not None:
diff --git a/sgl-kernel/tests/test_fp4_quantize.py b/sgl-kernel/tests/test_fp4_quantize.py
index 3e83e47ac67..e29bac2119d 100644
--- a/sgl-kernel/tests/test_fp4_quantize.py
+++ b/sgl-kernel/tests/test_fp4_quantize.py
@@ -1,11 +1,10 @@
 import pytest
 import torch
-from sgl_kernel import (
-    scaled_fp4_grouped_quant,
-    scaled_fp4_quant,
-    silu_and_mul,
-    silu_and_mul_scaled_fp4_grouped_quant,
+from flashinfer import (
+    scaled_fp4_grouped_quantize,
+    silu_and_mul_scaled_nvfp4_experts_quantize,
 )
+from sgl_kernel import scaled_fp4_quant, silu_and_mul

 skip_condition = torch.cuda.get_device_capability() < (10, 0)

@@ -186,10 +185,10 @@ def test_quantize_to_fp4_grouped(shape):
     mask = torch.randint(1, max_m, (l,), dtype=torch.int32)
     tensor_amax = x.abs().amax(dim=(1, 2)).to(torch.float32)
     x_sf_global = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax
-    output, output_scales = scaled_fp4_grouped_quant(
+    output, output_scales = scaled_fp4_grouped_quantize(
         x,
-        x_sf_global,
         mask,
+        x_sf_global,
     )
     # output in logical (m, k, l), but its physical layout is (l, m, k).
     # So permute first to (l, m, k).
@@ -225,15 +224,15 @@ def test_silu_and_mul_quantize_to_fp4_grouped(shape):
     ref_y = silu_and_mul(x)
     tensor_amax = ref_y.abs().amax(dim=(1, 2)).to(torch.float32)
     y_sf_global = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax
-    ref_output, ref_output_scales = scaled_fp4_grouped_quant(
+    ref_output, ref_output_scales = scaled_fp4_grouped_quantize(
         ref_y,
-        y_sf_global,
         mask,
+        y_sf_global,
     )

-    output, output_scales = silu_and_mul_scaled_fp4_grouped_quant(
+    output, output_scales = silu_and_mul_scaled_nvfp4_experts_quantize(
         x,
-        y_sf_global,
         mask,
+        y_sf_global,
     )
     # output in logical (m, k, l), but its physical layout is (l, m, k).
diff --git a/test/srt/test_cutedsl_moe.py b/test/srt/test_cutedsl_moe.py
index e751223571d..d60ce723f8a 100644
--- a/test/srt/test_cutedsl_moe.py
+++ b/test/srt/test_cutedsl_moe.py
@@ -3,8 +3,8 @@
 from typing import Callable

 import torch
-from flashinfer import fp4_quantize
-from sgl_kernel import scaled_fp4_grouped_quant, scaled_fp4_quant
+from flashinfer import fp4_quantize, scaled_fp4_grouped_quantize
+from sgl_kernel import scaled_fp4_quant
 from torch.nn import functional as F

 from sglang.srt.layers.activation import SiluAndMul
@@ -370,18 +370,18 @@ def test_flashinfer_cutedsl_moe_masked(self):
             (num_experts,), dtype=torch.float32, device=hidden_states.device
         )  # assume intermediate scale is 1.0

-        w1_fp4, w1_blockscale = scaled_fp4_grouped_quant(
+        w1_fp4, w1_blockscale = scaled_fp4_grouped_quantize(
             w1,
-            w1_global_scale,
             torch.ones(num_experts, dtype=torch.int32, device=w1.device)
             * 2
             * inter_dim,
+            w1_global_scale,
         )

-        w2_fp4, w2_blockscale = scaled_fp4_grouped_quant(
+        w2_fp4, w2_blockscale = scaled_fp4_grouped_quantize(
             w2,
-            w2_global_scale,
             torch.ones(num_experts, dtype=torch.int32, device=w2.device) * hidden_dim,
+            w2_global_scale,
         )

         w1_alpha = 1.0 / (input_global_scale * w1_global_scale)
diff --git a/test/srt/test_fp4_moe.py b/test/srt/test_fp4_moe.py
index 269f02c66ec..ac42fa8609e 100644
--- a/test/srt/test_fp4_moe.py
+++ b/test/srt/test_fp4_moe.py
@@ -3,9 +3,9 @@

 import pytest
 import torch
-from flashinfer import fp4_quantize
+from flashinfer import fp4_quantize, scaled_fp4_grouped_quantize
 from flashinfer.fused_moe import cutlass_fused_moe as flashinfer_cutlass_fused_moe
-from sgl_kernel import scaled_fp4_grouped_quant, scaled_fp4_quant, silu_and_mul
+from sgl_kernel import scaled_fp4_quant
 from torch.nn import functional as F

 from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
@@ -190,16 +190,16 @@ def flashinfer_cutedsl_grouped_gemm_nt_masked(

     # hidden_states: [l, m, k]
     # weights: [l, n, k]
-    aq, aq_sf = scaled_fp4_grouped_quant(
+    aq, aq_sf = scaled_fp4_grouped_quantize(
         hidden_states,
-        input_global_scale,
         masked_m.to(hidden_states.device),
+        input_global_scale,
     )
     num_experts, n, k = weights.shape
-    bq, bq_sf = scaled_fp4_grouped_quant(
+    bq, bq_sf = scaled_fp4_grouped_quantize(
         weights,
-        w_global_scale,
         torch.ones(num_experts, device=weights.device, dtype=torch.int32) * n,
+        w_global_scale,
     )

     out = torch.zeros(
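
For reference, a minimal usage sketch of the two flashinfer entry points exactly as this patch calls them: the input tensor comes first, then the per-expert masks, then the per-expert global scales. The expert/token/hidden sizes and the bfloat16 dtype below are illustrative assumptions, and an SM100-class GPU is required, matching the tests' `torch.cuda.get_device_capability() < (10, 0)` skip condition.

    import torch
    from flashinfer import (
        scaled_fp4_grouped_quantize,
        silu_and_mul_scaled_nvfp4_experts_quantize,
    )
    from sgl_kernel.elementwise import silu_and_mul

    E, M, K = 8, 128, 512  # illustrative: experts, tokens per expert, hidden size
    x = torch.randn(E, M, K, device="cuda", dtype=torch.bfloat16)
    glb_scales = torch.ones((E,), dtype=torch.float32, device="cuda")
    masks = torch.full((E,), M, dtype=torch.int32, device="cuda")

    # Fused SiLU-and-mul + NVFP4 grouped quantization.
    out, blk_scales = silu_and_mul_scaled_nvfp4_experts_quantize(x, masks, glb_scales)

    # Unfused reference path: activate first, then quantize with the same argument order.
    out1, blk_scales1 = scaled_fp4_grouped_quantize(silu_and_mul(x), masks, glb_scales)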