19 changes: 11 additions & 8 deletions benchmark/kernels/quantization/bench_fp4_quant.py
@@ -3,7 +3,10 @@

 import torch
 import triton
-from sgl_kernel import scaled_fp4_grouped_quant, silu_and_mul_scaled_fp4_grouped_quant
+from flashinfer import (
+    scaled_fp4_grouped_quantize,
+    silu_and_mul_scaled_nvfp4_experts_quantize,
+)
 from sgl_kernel.elementwise import silu_and_mul

 from sglang.srt.layers import deep_gemm_wrapper
@@ -14,11 +17,11 @@ def _test_accuracy_once(E, M, K, input_dtype, device):
     x = torch.randn(E, M, K, device=device, dtype=input_dtype)
     glb_scales = torch.ones((E,), dtype=torch.float32, device=device)
     masks = torch.full((E,), M, dtype=torch.int32, device=device)
-    out, blk_scales = silu_and_mul_scaled_fp4_grouped_quant(x, glb_scales, masks)
-    out1, blk_scales1 = scaled_fp4_grouped_quant(
+    out, blk_scales = silu_and_mul_scaled_nvfp4_experts_quantize(x, masks, glb_scales)
+    out1, blk_scales1 = scaled_fp4_grouped_quantize(
         silu_and_mul(x),
-        glb_scales,
         masks,
+        glb_scales,
     )

     torch.testing.assert_close(out, out1)
@@ -87,19 +90,19 @@ def benchmark(M, K, provider):
     )
     if provider == "cuda_unfused_fp4":
         ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
-            lambda: scaled_fp4_grouped_quant(
+            lambda: scaled_fp4_grouped_quantize(
                 silu_and_mul(x),
-                glb_scales,
                 masks,
+                glb_scales,
             ),
             quantiles=quantiles,
         )
     if provider == "cuda_fused_fp4":
         ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
-            lambda: silu_and_mul_scaled_fp4_grouped_quant(
+            lambda: silu_and_mul_scaled_nvfp4_experts_quantize(
                 x,
-                glb_scales,
                 masks,
+                glb_scales,
             ),
             quantiles=quantiles,
         )
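For reference, a minimal sketch of the call-site change this benchmark now exercises: the flashinfer kernels take the per-expert masks before the global scales, the reverse of the old sgl_kernel order. The argument order is inferred from the calls in this diff; the shapes and bfloat16 dtype are illustrative, and running it assumes flashinfer is installed on an NVFP4-capable (SM100-class) GPU.

```python
import torch
from flashinfer import scaled_fp4_grouped_quantize

E, M, K = 8, 128, 4096  # experts, tokens per expert, feature dim (illustrative)
x = torch.randn(E, M, K, device="cuda", dtype=torch.bfloat16)
glb_scales = torch.ones((E,), dtype=torch.float32, device="cuda")
masks = torch.full((E,), M, dtype=torch.int32, device="cuda")

# Old sgl_kernel call: scaled_fp4_grouped_quant(x, glb_scales, masks)
# New flashinfer call: the per-expert masks come before the global scales.
out, blk_scales = scaled_fp4_grouped_quantize(x, masks, glb_scales)
```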
1 change: 1 addition & 0 deletions docs/references/environment_variables.md
@@ -48,6 +48,7 @@ SGLang supports various environment variables that can be used to configure its

 | Environment Variable | Description | Default Value |
 | `SGLANG_DEEPEP_BF16_DISPATCH` | Use Bfloat16 for dispatch | `"false"` |
+| `SGLANG_CUTEDSL_MOE_NVFP4_DISPATCH` | Use nvfp4 for dispatch | `"false"` |

 ## Memory Management

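For context, boolean SGLang flags like this one are read from the process environment; a minimal sketch of how such a flag is typically consumed, assuming the common truthy spellings (the `env_flag` helper below is hypothetical, not SGLang's actual parser):

```python
import os


def env_flag(name: str, default: str = "false") -> bool:
    # Hypothetical helper: accept the usual truthy spellings for a boolean flag.
    return os.environ.get(name, default).strip().lower() in ("1", "true", "yes", "on")


ENABLE_CUTEDSL_MOE_NVFP4_DISPATCH = env_flag("SGLANG_CUTEDSL_MOE_NVFP4_DISPATCH")
```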
16 changes: 8 additions & 8 deletions python/sglang/srt/layers/moe/flashinfer_cutedsl_moe.py
@@ -1,11 +1,11 @@
 from typing import Optional

 import torch
-from flashinfer.cute_dsl.blockscaled_gemm import grouped_gemm_nt_masked
-from sgl_kernel.gemm import (
-    scaled_fp4_grouped_quant,
-    silu_and_mul_scaled_fp4_grouped_quant,
+from flashinfer import (
+    scaled_fp4_grouped_quantize,
+    silu_and_mul_scaled_nvfp4_experts_quantize,
 )
+from flashinfer.cute_dsl.blockscaled_gemm import grouped_gemm_nt_masked


 def get_cute_dtype(input: torch.Tensor) -> str:
@@ -97,10 +97,10 @@ def flashinfer_cutedsl_moe_masked(
         num_experts,
     ), f"input_global_scale must be (l,), got {input_global_scale.shape}"

-    a_q, a_q_sf = scaled_fp4_grouped_quant(
+    a_q, a_q_sf = scaled_fp4_grouped_quantize(
         hidden_states[0],
-        input_global_scale,
         masked_m,
+        input_global_scale,
     )

     assert w1.shape[-2] == 2 * n, f"w1 last-2 dim must be 2*n, got {w1.shape}"
@@ -148,10 +148,10 @@
     )  # in logical [m, n, l]

     # SILU and quantization
-    diq, diq_sf = silu_and_mul_scaled_fp4_grouped_quant(
+    diq, diq_sf = silu_and_mul_scaled_nvfp4_experts_quantize(
         gateup_output.permute(2, 0, 1),
-        a2_global_scale,
         masked_m,
+        a2_global_scale,
     )

     if down_start_event is not None:
21 changes: 10 additions & 11 deletions sgl-kernel/tests/test_fp4_quantize.py
@@ -1,11 +1,10 @@
 import pytest
 import torch
-from sgl_kernel import (
-    scaled_fp4_grouped_quant,
-    scaled_fp4_quant,
-    silu_and_mul,
-    silu_and_mul_scaled_fp4_grouped_quant,
+from flashinfer import (
+    scaled_fp4_grouped_quantize,
+    silu_and_mul_scaled_nvfp4_experts_quantize,
 )
+from sgl_kernel import scaled_fp4_quant, silu_and_mul

 skip_condition = torch.cuda.get_device_capability() < (10, 0)

@@ -186,10 +185,10 @@ def test_quantize_to_fp4_grouped(shape):
     mask = torch.randint(1, max_m, (l,), dtype=torch.int32)
     tensor_amax = x.abs().amax(dim=(1, 2)).to(torch.float32)
     x_sf_global = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax
-    output, output_scales = scaled_fp4_grouped_quant(
+    output, output_scales = scaled_fp4_grouped_quantize(
         x,
-        x_sf_global,
         mask,
+        x_sf_global,
     )
     # output in logical (m, k, l), but its physical layout is (l, m, k).
     # So permute first to (l, m, k).
@@ -225,15 +224,15 @@ def test_silu_and_mul_quantize_to_fp4_grouped(shape):
     ref_y = silu_and_mul(x)
     tensor_amax = ref_y.abs().amax(dim=(1, 2)).to(torch.float32)
     y_sf_global = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax
-    ref_output, ref_output_scales = scaled_fp4_grouped_quant(
+    ref_output, ref_output_scales = scaled_fp4_grouped_quantize(
         ref_y,
-        y_sf_global,
         mask,
+        y_sf_global,
     )
-    output, output_scales = silu_and_mul_scaled_fp4_grouped_quant(
+    output, output_scales = silu_and_mul_scaled_nvfp4_experts_quantize(
         x,
-        y_sf_global,
         mask,
+        y_sf_global,
     )

     # output in logical (m, k, l), but its physical layout is (l, m, k).
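In miniature, what the test above verifies: the fused kernel should agree with running the activation and the grouped NVFP4 quantization separately. A sketch under the same assumptions as the test (flashinfer and sgl_kernel installed, a compute-capability-10.0 GPU); the shape constants and bfloat16 dtype here are illustrative:

```python
import torch
from flashinfer import (
    scaled_fp4_grouped_quantize,
    silu_and_mul_scaled_nvfp4_experts_quantize,
)
from sgl_kernel import silu_and_mul

FLOAT4_E2M1_MAX = 6.0
FLOAT8_E4M3_MAX = 448.0

l, m, k = 4, 64, 512  # experts, tokens per expert, gate+up width (illustrative)
x = torch.randn(l, m, k, device="cuda", dtype=torch.bfloat16)
mask = torch.full((l,), m, dtype=torch.int32, device="cuda")

# Global scale derived from the amax of the activated tensor, as in the test.
ref_y = silu_and_mul(x)
y_sf_global = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / ref_y.abs().amax(dim=(1, 2)).to(torch.float32)

ref_out, ref_scales = scaled_fp4_grouped_quantize(ref_y, mask, y_sf_global)
out, scales = silu_and_mul_scaled_nvfp4_experts_quantize(x, mask, y_sf_global)

torch.testing.assert_close(out, ref_out)  # packed FP4 payloads should match
```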
12 changes: 6 additions & 6 deletions test/srt/test_cutedsl_moe.py
@@ -3,8 +3,8 @@
 from typing import Callable

 import torch
-from flashinfer import fp4_quantize
-from sgl_kernel import scaled_fp4_grouped_quant, scaled_fp4_quant
+from flashinfer import fp4_quantize, scaled_fp4_grouped_quantize
+from sgl_kernel import scaled_fp4_quant
 from torch.nn import functional as F

 from sglang.srt.layers.activation import SiluAndMul
@@ -370,18 +370,18 @@ def test_flashinfer_cutedsl_moe_masked(self):
             (num_experts,), dtype=torch.float32, device=hidden_states.device
         )  # assume intermediate scale is 1.0

-        w1_fp4, w1_blockscale = scaled_fp4_grouped_quant(
+        w1_fp4, w1_blockscale = scaled_fp4_grouped_quantize(
             w1,
-            w1_global_scale,
             torch.ones(num_experts, dtype=torch.int32, device=w1.device)
             * 2
             * inter_dim,
+            w1_global_scale,
         )
-        w2_fp4, w2_blockscale = scaled_fp4_grouped_quant(
+        w2_fp4, w2_blockscale = scaled_fp4_grouped_quantize(
             w2,
-            w2_global_scale,
             torch.ones(num_experts, dtype=torch.int32, device=w2.device)
             * hidden_dim,
+            w2_global_scale,
         )

         w1_alpha = 1.0 / (input_global_scale * w1_global_scale)
12 changes: 6 additions & 6 deletions test/srt/test_fp4_moe.py
@@ -3,9 +3,9 @@

 import pytest
 import torch
-from flashinfer import fp4_quantize
+from flashinfer import fp4_quantize, scaled_fp4_grouped_quantize
 from flashinfer.fused_moe import cutlass_fused_moe as flashinfer_cutlass_fused_moe
-from sgl_kernel import scaled_fp4_grouped_quant, scaled_fp4_quant, silu_and_mul
+from sgl_kernel import scaled_fp4_quant
 from torch.nn import functional as F

 from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
@@ -190,16 +190,16 @@ def flashinfer_cutedsl_grouped_gemm_nt_masked(

     # hidden_states: [l, m, k]
     # weights: [l, n, k]
-    aq, aq_sf = scaled_fp4_grouped_quant(
+    aq, aq_sf = scaled_fp4_grouped_quantize(
         hidden_states,
-        input_global_scale,
         masked_m.to(hidden_states.device),
+        input_global_scale,
     )
     num_experts, n, k = weights.shape
-    bq, bq_sf = scaled_fp4_grouped_quant(
+    bq, bq_sf = scaled_fp4_grouped_quantize(
         weights,
-        w_global_scale,
         torch.ones(num_experts, device=weights.device, dtype=torch.int32) * n,
+        w_global_scale,
     )

     out = torch.zeros(