diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index 6f6ec68d8eb3..071550154f13 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -77,6 +77,7 @@ is_npu, is_sm90_supported, is_sm100_supported, + is_sm120_supported, log_info_on_rank0, print_warning_once, set_weight_attrs, @@ -683,7 +684,9 @@ def __init__(self, quant_config: Fp8Config): cutlass_fp8_supported() ), "cutlass_fp8 MoE requires CUDA 12.0+ with SM90 or CUDA 12.4+ with SM89" assert self.block_quant, "cutlass_fp8 MoE requires block quantization" - assert is_sm100_supported() or is_sm90_supported() + assert ( + is_sm100_supported() or is_sm90_supported() or is_sm120_supported() + ), "cutlass_fp8 MoE requires SM90, SM100, or SM120 GPUs" @staticmethod def is_deepgemm_moe_runner_backend_enabled() -> bool: diff --git a/python/sglang/srt/layers/quantization/fp8_kernel.py b/python/sglang/srt/layers/quantization/fp8_kernel.py index 4a143d724f95..1466bac6bec4 100644 --- a/python/sglang/srt/layers/quantization/fp8_kernel.py +++ b/python/sglang/srt/layers/quantization/fp8_kernel.py @@ -37,6 +37,8 @@ is_cpu, is_cuda, is_hip, + is_sm100_supported, + is_sm120_supported, log_info_on_rank0, ) from sglang.srt.utils.custom_op import register_custom_op @@ -1286,9 +1288,16 @@ def mxfp8_block_scaled_matmul_triton( block_m: int = 128, block_n: int = 256, block_k: int = 128, - num_stages: int = 4, + num_stages: Optional[int] = None, ) -> torch.Tensor: - """Block-scaled matmul for MXFP8 using Triton dot_scaled.""" + """Block-scaled matmul for MXFP8 using Triton dot_scaled. + + Args: + num_stages: Number of pipeline stages. If None, auto-selects based on GPU: + SM120: 1, SM100: 4. + """ + if num_stages is None: + num_stages = 1 if is_sm120_supported() else (4 if is_sm100_supported() else 1) M, K = a.shape N, K_b = b.shape assert K == K_b diff --git a/python/sglang/srt/layers/quantization/fp8_utils.py b/python/sglang/srt/layers/quantization/fp8_utils.py index 3b9616e2798f..eb841e1edf35 100644 --- a/python/sglang/srt/layers/quantization/fp8_utils.py +++ b/python/sglang/srt/layers/quantization/fp8_utils.py @@ -41,6 +41,7 @@ is_hip, is_sm90_supported, is_sm100_supported, + is_sm120_supported, offloader, ) @@ -662,8 +663,8 @@ def triton_mxfp8_blockscaled_linear( bias: Optional[torch.Tensor] = None, output_dtype: Optional[torch.dtype] = None, ) -> torch.Tensor: - if not (_is_cuda and is_sm100_supported()): - raise RuntimeError("MXFP8 dense linear requires Blackwell GPUs (SM100+).") + if not (_is_cuda and (is_sm100_supported() or is_sm120_supported())): + raise RuntimeError("MXFP8 dense linear requires Blackwell GPUs (SM100/SM120).") input_2d = input.view(-1, input.shape[-1]).contiguous() output_shape = [*input.shape[:-1], weight.shape[0]] @@ -714,6 +715,7 @@ def triton_mxfp8_blockscaled_linear( a_scale_packed = _pack_mxfp8_scales(x_scale_u8) b_scale_packed = _pack_mxfp8_scales(weight_scale) + num_stages = 1 if is_sm120_supported() else (4 if is_sm100_supported() else 1) output = mxfp8_block_scaled_matmul_triton( q_input, a_scale_packed, @@ -723,6 +725,7 @@ def triton_mxfp8_blockscaled_linear( block_m=block_m, block_n=block_n, block_k=block_k, + num_stages=num_stages, ) output = output[:m, :] if bias is not None: diff --git a/python/sglang/test/test_block_fp8.py b/python/sglang/test/test_block_fp8.py index edfd98e42a9c..6b73a736acb3 100644 --- a/python/sglang/test/test_block_fp8.py +++ b/python/sglang/test/test_block_fp8.py @@ -19,7 +19,7 @@ mxfp8_group_quantize, triton_mxfp8_blockscaled_linear, ) -from sglang.srt.utils import is_sm100_supported +from sglang.srt.utils import is_sm100_supported, is_sm120_supported from sglang.test.test_utils import CustomTestCase _is_cuda = torch.cuda.is_available() and torch.version.cuda @@ -452,8 +452,8 @@ class TestMXFP8DenseLinear(CustomTestCase): def setUpClass(cls): if not torch.cuda.is_available(): raise unittest.SkipTest("CUDA is not available") - if not is_sm100_supported(): - raise unittest.SkipTest("MXFP8 requires Blackwell (SM100+)") + if not (is_sm100_supported() or is_sm120_supported()): + raise unittest.SkipTest("MXFP8 requires Blackwell (SM100/SM120)") torch.set_default_device("cuda") def _mxfp8_dense_linear(self, M, NK, dtype, seed):