Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion docs/advanced_features/quantization.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,12 @@ Backend selection is supported only for **blockwise FP8** and **NVFP4** GEMM. Wh
| Backend | Hardware | Description |
|---------|----------|-------------|
| `auto` | SM100/120 | Auto-selects: `flashinfer_cudnn` on SM120; `flashinfer_cutlass` on SM100 |
| `cutlass` | SM100/120 | SGLang CUTLASS kernel |
| `flashinfer_cutlass` | SM100/120 | FlashInfer CUTLASS backend |
| `flashinfer_cudnn` | SM100/120 (CUDA 13+, cuDNN 9.15+) | FlashInfer cuDNN backend; used on SM120 for performance |
| `flashinfer_trtllm` | SM100 | FlashInfer TensorRT-LLM backend |

When FlashInfer is unavailable for NVFP4, sgl-kernel CUTLASS is used as an automatic fallback.
When FlashInfer is unavailable for NVFP4, the SGLang CUTLASS kernel is used as an automatic fallback.

## Offline Quantization

Expand Down
6 changes: 3 additions & 3 deletions python/sglang/jit_kernel/benchmark/bench_nvfp4_scaled_mm.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,15 @@

from sglang.jit_kernel.benchmark.utils import get_benchmark_range, run_benchmark
from sglang.jit_kernel.nvfp4 import cutlass_scaled_fp4_mm, scaled_fp4_quant
from sglang.srt.utils import is_sm100_supported
from sglang.srt.utils import is_sm100_supported, is_sm120_supported
from sglang.test.ci.ci_register import register_cuda_ci

register_cuda_ci(est_time=5, suite="stage-b-kernel-benchmark-1-gpu-large")

FLOAT4_E2M1_MAX = 6.0
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
BLOCK_SIZE = 16
_NVFP4_SUPPORTED = is_sm100_supported()
_NVFP4_SUPPORTED = is_sm100_supported() or is_sm120_supported()

K_E2M1_TO_FLOAT = [
0.0,
Expand Down Expand Up @@ -178,7 +178,7 @@ def benchmark(m, n, k, provider):

if __name__ == "__main__":
if not _NVFP4_SUPPORTED:
print("[skip] NVFP4 scaled_mm benchmark requires sm100+ with CUDA 12.8+.")
print("[skip] NVFP4 scaled_mm benchmark requires sm100/sm120 with CUDA 12.8+.")
sys.exit(0)
if not _AOT_SCALED_MM_AVAILABLE:
print(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
/* Copyright 2026 SGLang Team. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#pragma once

#include <sgl_kernel/ffi.h>
#include <sgl_kernel/tensor.h>
#include <sgl_kernel/utils.h>

#include <sgl_kernel/runtime.cuh>
#include <sgl_kernel/utils.cuh>

#include <cstddef>
#include <cstdint>
#include <cuda_runtime.h>

using namespace host;

// clang-format off
#include "cutlass/cutlass.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/util/packed_stride.hpp"
// clang-format on

// Validate a CUTLASS status; on failure, fail via RuntimeCheck with the
// human-readable CUTLASS error string. Wrapped in do/while(0) so the macro
// expands to a single statement and is safe in unbraced if/else bodies
// (a bare `{ ... }` followed by the caller's `;` breaks `if (c) CHECK(x); else ...`).
// The argument is parenthesized so expression arguments evaluate as intended.
#define CUTLASS_CHECK(status)                                                        \
  do {                                                                               \
    cutlass::Status error = (status);                                                \
    RuntimeCheck(error == cutlass::Status::kSuccess, cutlassGetStatusString(error)); \
  } while (0)

using namespace cute;

// Round `x` up to the nearest power of two. Returns 1 for x <= 1.
//
// For x > 2^31 the mathematical next power of two (2^32) does not fit in
// uint32_t; the unguarded `1u << (32 - __builtin_clz(x - 1))` would shift by
// 32, which is undefined behavior. We return 0 in that case, matching the
// wrap-around result of the classic bit-smearing round-up idiom.
inline uint32_t next_pow_2(uint32_t x) noexcept {
  if (x <= 1) return 1;
  // __builtin_clz is well-defined here: x - 1 >= 1, so the argument is nonzero.
  const int shift = 32 - __builtin_clz(x - 1);
  return (shift < 32) ? (1u << shift) : 0u;
}

// Allocate a 1-D uint8 workspace tensor holding `required_bytes` bytes on
// `device`. Returns a default-constructed (empty) tensor when no workspace
// is required, so callers can request zero bytes without allocating.
inline auto alloc_workspace_tensor(size_t required_bytes, DLDevice device) -> tvm::ffi::Tensor {
  if (required_bytes == 0) {
    return {};
  }
  const DLDataType byte_dtype{kDLUInt, 8, 1};
  const int64_t extent[] = {static_cast<int64_t>(required_bytes)};
  return ffi::empty(tvm::ffi::ShapeView(extent, 1), byte_dtype, device);
}

// Query the compute capability (SM version) of CUDA device `device_id`,
// encoded as major * 10 + minor (e.g. 90 for SM90, 120 for SM120).
// Any CUDA runtime error is surfaced through RuntimeDeviceCheck.
inline int getSMVersion(int device_id) {
  int major = 0;
  int minor = 0;
  RuntimeDeviceCheck(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device_id));
  RuntimeDeviceCheck(cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device_id));
  return major * 10 + minor;
}
Loading
Loading