Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion python/sglang/srt/layers/quantization/fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@
is_npu,
is_sm90_supported,
is_sm100_supported,
is_sm120_supported,
log_info_on_rank0,
print_warning_once,
set_weight_attrs,
Expand Down Expand Up @@ -683,7 +684,9 @@ def __init__(self, quant_config: Fp8Config):
cutlass_fp8_supported()
), "cutlass_fp8 MoE requires CUDA 12.0+ with SM90 or CUDA 12.4+ with SM89"
assert self.block_quant, "cutlass_fp8 MoE requires block quantization"
assert is_sm100_supported() or is_sm90_supported()
assert (
is_sm100_supported() or is_sm90_supported() or is_sm120_supported()
), "cutlass_fp8 MoE requires SM90, SM100, or SM120 GPUs"

@staticmethod
def is_deepgemm_moe_runner_backend_enabled() -> bool:
Expand Down
13 changes: 11 additions & 2 deletions python/sglang/srt/layers/quantization/fp8_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
is_cpu,
is_cuda,
is_hip,
is_sm100_supported,
is_sm120_supported,
log_info_on_rank0,
)
from sglang.srt.utils.custom_op import register_custom_op
Expand Down Expand Up @@ -1286,9 +1288,16 @@ def mxfp8_block_scaled_matmul_triton(
block_m: int = 128,
block_n: int = 256,
block_k: int = 128,
num_stages: int = 4,
num_stages: Optional[int] = None,
) -> torch.Tensor:
"""Block-scaled matmul for MXFP8 using Triton dot_scaled."""
"""Block-scaled matmul for MXFP8 using Triton dot_scaled.

Args:
num_stages: Number of pipeline stages. If None, auto-selects based on GPU:
SM120: 1, SM100: 4, any other GPU: 1.
"""
if num_stages is None:
num_stages = 1 if is_sm120_supported() else (4 if is_sm100_supported() else 1)
M, K = a.shape
N, K_b = b.shape
assert K == K_b
Expand Down
7 changes: 5 additions & 2 deletions python/sglang/srt/layers/quantization/fp8_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
is_hip,
is_sm90_supported,
is_sm100_supported,
is_sm120_supported,
offloader,
)

Expand Down Expand Up @@ -662,8 +663,8 @@ def triton_mxfp8_blockscaled_linear(
bias: Optional[torch.Tensor] = None,
output_dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
if not (_is_cuda and is_sm100_supported()):
raise RuntimeError("MXFP8 dense linear requires Blackwell GPUs (SM100+).")
if not (_is_cuda and (is_sm100_supported() or is_sm120_supported())):
raise RuntimeError("MXFP8 dense linear requires Blackwell GPUs (SM100/SM120).")

input_2d = input.view(-1, input.shape[-1]).contiguous()
output_shape = [*input.shape[:-1], weight.shape[0]]
Expand Down Expand Up @@ -714,6 +715,7 @@ def triton_mxfp8_blockscaled_linear(
a_scale_packed = _pack_mxfp8_scales(x_scale_u8)
b_scale_packed = _pack_mxfp8_scales(weight_scale)

num_stages = 1 if is_sm120_supported() else (4 if is_sm100_supported() else 1)
output = mxfp8_block_scaled_matmul_triton(
q_input,
a_scale_packed,
Expand All @@ -723,6 +725,7 @@ def triton_mxfp8_blockscaled_linear(
block_m=block_m,
block_n=block_n,
block_k=block_k,
num_stages=num_stages,
)
output = output[:m, :]
if bias is not None:
Expand Down
6 changes: 3 additions & 3 deletions python/sglang/test/test_block_fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
mxfp8_group_quantize,
triton_mxfp8_blockscaled_linear,
)
from sglang.srt.utils import is_sm100_supported
from sglang.srt.utils import is_sm100_supported, is_sm120_supported
from sglang.test.test_utils import CustomTestCase

_is_cuda = torch.cuda.is_available() and torch.version.cuda
Expand Down Expand Up @@ -452,8 +452,8 @@ class TestMXFP8DenseLinear(CustomTestCase):
def setUpClass(cls):
if not torch.cuda.is_available():
raise unittest.SkipTest("CUDA is not available")
if not is_sm100_supported():
raise unittest.SkipTest("MXFP8 requires Blackwell (SM100+)")
if not (is_sm100_supported() or is_sm120_supported()):
raise unittest.SkipTest("MXFP8 requires Blackwell (SM100/SM120)")
torch.set_default_device("cuda")

def _mxfp8_dense_linear(self, M, NK, dtype, seed):
Expand Down
Loading