Merged · Changes from 3 commits
19 changes: 11 additions & 8 deletions benchmark/kernels/quantization/bench_fp4_quant.py
@@ -3,7 +3,10 @@
 
 import torch
 import triton
-from sgl_kernel import scaled_fp4_grouped_quant, silu_and_mul_scaled_fp4_grouped_quant
+from flashinfer import (
+    scaled_fp4_grouped_quantize,
+    silu_and_mul_scaled_nvfp4_experts_quantize,
+)
 from sgl_kernel.elementwise import silu_and_mul
 
 from sglang.srt.layers import deep_gemm_wrapper
@@ -14,11 +17,11 @@ def _test_accuracy_once(E, M, K, input_dtype, device):
     x = torch.randn(E, M, K, device=device, dtype=input_dtype)
     glb_scales = torch.ones((E,), dtype=torch.float32, device=device)
     masks = torch.full((E,), M, dtype=torch.int32, device=device)
-    out, blk_scales = silu_and_mul_scaled_fp4_grouped_quant(x, glb_scales, masks)
-    out1, blk_scales1 = scaled_fp4_grouped_quant(
+    out, blk_scales = silu_and_mul_scaled_nvfp4_experts_quantize(x, masks, glb_scales)
+    out1, blk_scales1 = scaled_fp4_grouped_quantize(
         silu_and_mul(x),
-        glb_scales,
         masks,
+        glb_scales,
     )
 
     torch.testing.assert_close(out, out1)
@@ -87,19 +90,19 @@ def benchmark(M, K, provider):
         )
     if provider == "cuda_unfused_fp4":
         ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
-            lambda: scaled_fp4_grouped_quant(
+            lambda: scaled_fp4_grouped_quantize(
                 silu_and_mul(x),
-                glb_scales,
                 masks,
+                glb_scales,
             ),
             quantiles=quantiles,
         )
     if provider == "cuda_fused_fp4":
         ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
-            lambda: silu_and_mul_scaled_fp4_grouped_quant(
+            lambda: silu_and_mul_scaled_nvfp4_experts_quantize(
                 x,
-                glb_scales,
                 masks,
+                glb_scales,
             ),
             quantiles=quantiles,
         )
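For call sites outside this benchmark, the same migration applies: the quantization kernels now come from `flashinfer`, and the expert mask tensor precedes the per-expert global scales. A minimal sketch of the before/after calling convention, mirroring the benchmark above (the function names and argument order come from this diff; `E`, `M`, `K`, the device, and the dtype are illustrative values):

```python
import torch
from flashinfer import (
    scaled_fp4_grouped_quantize,
    silu_and_mul_scaled_nvfp4_experts_quantize,
)
from sgl_kernel.elementwise import silu_and_mul

E, M, K = 8, 128, 4096  # experts, tokens per expert, hidden size (illustrative)
x = torch.randn(E, M, K, device="cuda", dtype=torch.bfloat16)
glb_scales = torch.ones((E,), dtype=torch.float32, device="cuda")
masks = torch.full((E,), M, dtype=torch.int32, device="cuda")

# Before (sgl_kernel): silu_and_mul_scaled_fp4_grouped_quant(x, glb_scales, masks)
# After (flashinfer):  the mask tensor now precedes the per-expert global scales.
out, blk_scales = silu_and_mul_scaled_nvfp4_experts_quantize(x, masks, glb_scales)

# Unfused reference path, matching the benchmark's "cuda_unfused_fp4" provider.
out1, blk_scales1 = scaled_fp4_grouped_quantize(silu_and_mul(x), masks, glb_scales)
```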
1 change: 1 addition & 0 deletions docs/references/environment_variables.md
@@ -48,6 +48,7 @@ SGLang supports various environment variables that can be used to configure its
 
 | Environment Variable | Description | Default Value |
 | `SGLANG_DEEPEP_BF16_DISPATCH` | Use Bfloat16 for dispatch | `"false"` |
+| `SGLANG_CUTEDSL_MOE_NVFP4_DISPATCH` | Use nvfp4 for dispatch | `"false"` |
 
 ## Memory Management
 
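A minimal way to opt in to the new flag, as an illustrative sketch (the variable name and default come from the table row above; whether the nvfp4 dispatch path is actually taken depends on the rest of the MoE configuration):

```python
import os

# Enable nvfp4 dispatch for the CuteDSL MoE path; the default is "false".
# Set this before the SGLang server process starts, or export it in the shell.
os.environ["SGLANG_CUTEDSL_MOE_NVFP4_DISPATCH"] = "true"
```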
5 changes: 3 additions & 2 deletions python/sglang/srt/compilation/weak_ref_tensor.cpp
@@ -1,12 +1,13 @@
 // Adapted from: https://github.com/vllm-project/vllm/blob/main/csrc/ops.h
 
 #include <torch/extension.h>
 
 #include <vector>
 
-static at::Tensor weak_ref_tensor(at::Tensor &tensor) {
+static at::Tensor weak_ref_tensor(at::Tensor& tensor) {
   TORCH_CHECK(tensor.is_cuda(), "weak_ref_tensor expects a CUDA tensor");
 
-  void *data_ptr = tensor.data_ptr();
+  void* data_ptr = tensor.data_ptr();
   std::vector<int64_t> sizes = tensor.sizes().vec();
   std::vector<int64_t> strides = tensor.strides().vec();
16 changes: 8 additions & 8 deletions python/sglang/srt/layers/moe/flashinfer_cutedsl_moe.py
@@ -1,11 +1,11 @@
 from typing import Optional
 
 import torch
-from flashinfer.cute_dsl.blockscaled_gemm import grouped_gemm_nt_masked
-from sgl_kernel.gemm import (
-    scaled_fp4_grouped_quant,
-    silu_and_mul_scaled_fp4_grouped_quant,
+from flashinfer import (
+    scaled_fp4_grouped_quantize,
+    silu_and_mul_scaled_nvfp4_experts_quantize,
 )
+from flashinfer.cute_dsl.blockscaled_gemm import grouped_gemm_nt_masked
 
 
 def get_cute_dtype(input: torch.Tensor) -> str:
@@ -97,10 +97,10 @@ def flashinfer_cutedsl_moe_masked(
         num_experts,
     ), f"input_global_scale must be (l,), got {input_global_scale.shape}"
 
-    a_q, a_q_sf = scaled_fp4_grouped_quant(
+    a_q, a_q_sf = scaled_fp4_grouped_quantize(
         hidden_states[0],
-        input_global_scale,
         masked_m,
+        input_global_scale,
     )
 
     assert w1.shape[-2] == 2 * n, f"w1 last-2 dim must be 2*n, got {w1.shape}"
@@ -148,10 +148,10 @@ def flashinfer_cutedsl_moe_masked(
     ) # in logical [m, n, l]
 
     # SILU and quantization
-    diq, diq_sf = silu_and_mul_scaled_fp4_grouped_quant(
+    diq, diq_sf = silu_and_mul_scaled_nvfp4_experts_quantize(
         gateup_output.permute(2, 0, 1),
-        a2_global_scale,
         masked_m,
+        a2_global_scale,
     )
 
     if down_start_event is not None:
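Taken together, the two changed call sites in `flashinfer_cutedsl_moe_masked` now follow the pattern sketched below (tensor names match the function above; the grouped GEMMs and event handling are elided; this is an illustrative sketch, not library code):

```python
from flashinfer import (
    scaled_fp4_grouped_quantize,
    silu_and_mul_scaled_nvfp4_experts_quantize,
)


def quantize_moe_activations(
    hidden_states, gateup_output, masked_m, input_global_scale, a2_global_scale
):
    """Sketch of the two nvfp4 activation-quantization steps after this change."""
    # 1) Quantize the dispatched hidden states for the gate/up grouped GEMM.
    a_q, a_q_sf = scaled_fp4_grouped_quantize(
        hidden_states[0],
        masked_m,
        input_global_scale,
    )
    # (In the real function, grouped_gemm_nt_masked produces gateup_output here.)
    # 2) Fused SiLU-and-mul plus quantization of the gate/up output for the down GEMM.
    diq, diq_sf = silu_and_mul_scaled_nvfp4_experts_quantize(
        gateup_output.permute(2, 0, 1),
        masked_m,
        a2_global_scale,
    )
    return (a_q, a_q_sf), (diq, diq_sf)
```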
5 changes: 0 additions & 5 deletions sgl-kernel/csrc/common_extension.cc
@@ -176,11 +176,6 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
       "Tensor output_scale_offset_by_experts) -> ()");
   m.impl("scaled_fp4_experts_quant", torch::kCUDA, &scaled_fp4_experts_quant);
 
-  m.def(
-      "silu_and_mul_scaled_fp4_experts_quant(Tensor! output, Tensor! output_scale,"
-      "Tensor input, Tensor input_global_scale, Tensor mask, bool use_silu_and_mul) -> ()");
-  m.impl("silu_and_mul_scaled_fp4_experts_quant", torch::kCUDA, &silu_and_mul_scaled_fp4_experts_quant);
-
   m.def(
       "cutlass_fp4_group_mm(Tensor! output, Tensor a, Tensor b,"
       "Tensor a_blockscale, Tensor b_blockscale, Tensor alphas,"
83 changes: 0 additions & 83 deletions sgl-kernel/csrc/gemm/nvfp4_expert_quant.cu
@@ -643,86 +643,3 @@ void scaled_fp4_experts_quant_sm100a(
     TORCH_CHECK(false, "Expected input data type to be half or bfloat16");
   }
 }
-
-void silu_and_mul_scaled_fp4_experts_quant_sm100a(
-    torch::Tensor& output,
-    torch::Tensor& output_scale,
-    torch::Tensor const& input,
-    torch::Tensor const& input_global_scale,
-    torch::Tensor const& mask,
-    bool use_silu_and_mul) {
-  auto sm_version = getSMVersion();
-  TORCH_CHECK(sm_version >= 100, "fp4_quant is only supported on sm100+");
-
-  CHECK_INPUT(output, "output must be a CUDA tensor");
-  CHECK_INPUT(output_scale, "output_scale must be a CUDA tensor");
-  CHECK_INPUT(input, "input must be a CUDA tensor");
-  CHECK_INPUT(input_global_scale, "input_global_scale must be a CUDA tensor");
-  CHECK_INPUT(mask, "mask must be a CUDA tensor");
-
-  TORCH_CHECK(output.dim() == 2);
-  TORCH_CHECK(output_scale.dim() == 2);
-  TORCH_CHECK(input.dim() == 2);
-  TORCH_CHECK(input_global_scale.dim() == 1);
-
-  TORCH_CHECK(input.scalar_type() == HALF || input.scalar_type() == BF16);
-  TORCH_CHECK(input_global_scale.scalar_type() == FLOAT);
-  TORCH_CHECK(mask.scalar_type() == INT);
-  // output is uint8 (two nvfp4 values are packed into one uint8)
-  // output_scale is int32 (four fp8 values are packed into one int32)
-  TORCH_CHECK(output.scalar_type() == UINT8);
-  TORCH_CHECK(output_scale.scalar_type() == INT);
-
-  const int BLOCK_SIZE = 16;
-  auto m_topk = input.size(0);
-  auto k_by_2 = input.size(1);
-  auto k = k_by_2;
-  if (use_silu_and_mul) {
-    TORCH_CHECK(k_by_2 % 2 == 0, "k must be a multiple of 2");
-    k = k_by_2 / 2;
-  }
-  auto n_experts = input_global_scale.size(0);
-  TORCH_CHECK(mask.size(0) == n_experts);
-  TORCH_CHECK(output.size(0) == m_topk);
-  TORCH_CHECK(output.size(1) == k / 2);
-  int scales_k = k / BLOCK_SIZE;
-  // 4 means the swizzle requirement by nvidia nvfp4.
-  int padded_k = (scales_k + (4 - 1)) / 4 * 4;
-  // 4 means 4 fp8 values are packed into one int32
-  TORCH_CHECK(output_scale.size(1) * 4 == padded_k);
-
-  auto in_dtype = input.dtype();
-  at::cuda::CUDAGuard device_guard{(char)input.get_device()};
-  const cudaStream_t stream = at::cuda::getCurrentCUDAStream(input.get_device());
-  if (in_dtype == at::ScalarType::Half) {
-    quant_impl<half>(
-        output.data_ptr(),
-        output_scale.data_ptr(),
-        input.data_ptr(),
-        input_global_scale.data_ptr(),
-        nullptr, // input_offset_by_experts
-        nullptr, // output_scale_offset_by_experts
-        mask.data_ptr(),
-        use_silu_and_mul,
-        m_topk,
-        k,
-        n_experts,
-        stream);
-  } else if (in_dtype == at::ScalarType::BFloat16) {
-    quant_impl<__nv_bfloat16>(
-        output.data_ptr(),
-        output_scale.data_ptr(),
-        input.data_ptr(),
-        input_global_scale.data_ptr(),
-        nullptr, // input_offset_by_experts
-        nullptr, // output_scale_offset_by_experts
-        mask.data_ptr(),
-        use_silu_and_mul,
-        m_topk,
-        k,
-        n_experts,
-        stream);
-  } else {
-    TORCH_CHECK(false, "Expected input data type to be half or bfloat16");
-  }
-}
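The removed entry point encoded the nvfp4 packing rules in its shape checks. A small illustrative helper (not part of either library) that reproduces those relationships for reference:

```python
def nvfp4_expert_quant_shapes(m_topk: int, k_in: int, use_silu_and_mul: bool = True):
    """Reproduce the shape checks from the removed entry point (illustrative only)."""
    BLOCK_SIZE = 16  # one fp8 block scale per 16 nvfp4 values
    if use_silu_and_mul:
        assert k_in % 2 == 0, "k must be a multiple of 2"
        k = k_in // 2  # silu_and_mul halves the inner dimension
    else:
        k = k_in
    output_shape = (m_topk, k // 2)  # two nvfp4 values are packed into one uint8
    scales_k = k // BLOCK_SIZE
    padded_k = (scales_k + 3) // 4 * 4  # swizzle requires a multiple of 4
    output_scale_shape = (m_topk, padded_k // 4)  # four fp8 scales per int32
    return output_shape, output_scale_shape
```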
22 changes: 0 additions & 22 deletions sgl-kernel/csrc/gemm/nvfp4_quant_entry.cu
@@ -30,14 +30,6 @@ void scaled_fp4_experts_quant_sm100a(
     torch::Tensor const& input_offset_by_experts,
     torch::Tensor const& output_scale_offset_by_experts);
 
-void silu_and_mul_scaled_fp4_experts_quant_sm100a(
-    torch::Tensor& output,
-    torch::Tensor& output_scale,
-    torch::Tensor const& input,
-    torch::Tensor const& input_global_scale,
-    torch::Tensor const& mask,
-    bool use_silu_and_mul);
-
 #endif
 
 void scaled_fp4_quant(
@@ -61,17 +53,3 @@ void scaled_fp4_experts_quant(
 #endif
   TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 experts quantization kernel");
 }
-
-void silu_and_mul_scaled_fp4_experts_quant(
-    torch::Tensor& output,
-    torch::Tensor& output_scale,
-    torch::Tensor const& input,
-    torch::Tensor const& input_global_scale,
-    torch::Tensor const& mask,
-    bool use_silu_and_mul) {
-#if defined ENABLE_NVFP4 && ENABLE_NVFP4
-  return silu_and_mul_scaled_fp4_experts_quant_sm100a(
-      output, output_scale, input, input_global_scale, mask, use_silu_and_mul);
-#endif
-  TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 experts quantization kernel");
-}
9 changes: 0 additions & 9 deletions sgl-kernel/include/sgl_kernel_ops.h
@@ -391,15 +391,6 @@ void scaled_fp4_experts_quant(
     torch::Tensor const& input_global_scale,
     torch::Tensor const& input_offset_by_experts,
     torch::Tensor const& output_scale_offset_by_experts);
-
-void silu_and_mul_scaled_fp4_experts_quant(
-    torch::Tensor& output,
-    torch::Tensor& output_scale,
-    torch::Tensor const& input,
-    torch::Tensor const& input_global_scale,
-    torch::Tensor const& mask,
-    bool use_silu_and_mul);
-
 /*
  * From csrc/moe/cutlass_moe/w4a8
  */
2 changes: 0 additions & 2 deletions sgl-kernel/python/sgl_kernel/__init__.py
@@ -50,15 +50,13 @@
     qserve_w4a8_per_chn_gemm,
     qserve_w4a8_per_group_gemm,
     scaled_fp4_experts_quant,
-    scaled_fp4_grouped_quant,
     scaled_fp4_quant,
     sgl_per_tensor_quant_fp8,
     sgl_per_token_group_quant_8bit,
     sgl_per_token_group_quant_fp8,
     sgl_per_token_group_quant_int8,
     sgl_per_token_quant_fp8,
     shuffle_rows,
-    silu_and_mul_scaled_fp4_grouped_quant,
 )
 from sgl_kernel.grammar import apply_token_bitmask_inplace_cuda
 from sgl_kernel.hadamard import (