From 103bfd889c30e2d4e3862e0026ba764b9d0f92ba Mon Sep 17 00:00:00 2001 From: Brayden Zhong Date: Fri, 21 Nov 2025 13:52:49 -0800 Subject: [PATCH] more --- .../srt/layers/quantization/fp8_utils.py | 38 +++++++++++-------- .../srt/layers/quantization/modelopt_quant.py | 2 +- 2 files changed, 24 insertions(+), 16 deletions(-) diff --git a/python/sglang/srt/layers/quantization/fp8_utils.py b/python/sglang/srt/layers/quantization/fp8_utils.py index 4dddd407f296..0d1551fa9242 100644 --- a/python/sglang/srt/layers/quantization/fp8_utils.py +++ b/python/sglang/srt/layers/quantization/fp8_utils.py @@ -6,7 +6,7 @@ from sglang.srt.layers import deep_gemm_wrapper from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8 from sglang.srt.layers.quantization.mxfp4_tensor import MXFP4QuantizeUtil -from sglang.srt.utils import ceil_div, is_blackwell_supported, offloader +from sglang.srt.utils import ceil_div, offloader try: from vllm import _custom_ops as ops @@ -32,6 +32,8 @@ get_bool_env_var, get_cuda_version, get_device_capability, + get_device_sm, + is_blackwell_supported, is_cuda, is_flashinfer_available, is_hip, @@ -130,35 +132,41 @@ def cutlass_block_fp8_supported() -> bool: if not get_bool_env_var("SGLANG_SUPPORT_CUTLASS_BLOCK_FP8"): return False if _is_cuda: - major, minor = torch.cuda.get_device_capability() - sm_version = major * 10 + minor - cuda_version = tuple(map(int, torch.version.cuda.split("."))) - if cuda_version >= (12, 0) and sm_version >= 90: - return True + sm_version = get_device_sm() + cuda_version = get_cuda_version() + return cuda_version >= (12, 0) and sm_version >= 90 return False CUTLASS_BLOCK_FP8_SUPPORTED = cutlass_block_fp8_supported() -ENABLE_FLASHINFER_FP8_GEMM = ( + +FLASHINFER_FP8_GEMM_SUPPORTED = is_blackwell_supported() and is_flashinfer_available() + +ENABLE_FLASHINFER_FP8_GEMM = FLASHINFER_FP8_GEMM_SUPPORTED and ( envs.SGLANG_ENABLE_FLASHINFER_FP8_GEMM.get() - and is_blackwell_supported() - and is_flashinfer_available() + or ( + not deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM + and not envs.SGLANG_SUPPORT_CUTLASS_BLOCK_FP8.is_set() + ) ) if ENABLE_FLASHINFER_FP8_GEMM: from flashinfer.gemm import gemm_fp8_nt_groupwise def dispatch_w8a8_block_fp8_linear() -> Callable: + if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM: + return deepgemm_w8a8_block_fp8_linear_with_fallback + if ENABLE_FLASHINFER_FP8_GEMM: return flashinfer_gemm_w8a8_block_fp8_linear - elif CUTLASS_BLOCK_FP8_SUPPORTED: + + if CUTLASS_BLOCK_FP8_SUPPORTED: return cutlass_w8a8_block_fp8_linear_with_fallback - elif _use_aiter: + + if _use_aiter: return aiter_w8a8_block_fp8_linear - elif deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM: - return deepgemm_w8a8_block_fp8_linear_with_fallback - else: - return triton_w8a8_block_fp8_linear + + return triton_w8a8_block_fp8_linear def flashinfer_gemm_w8a8_block_fp8_linear( diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py index e627a3151587..fa64463bda49 100755 --- a/python/sglang/srt/layers/quantization/modelopt_quant.py +++ b/python/sglang/srt/layers/quantization/modelopt_quant.py @@ -32,7 +32,6 @@ from sglang.srt.layers.quantization.fp8_utils import ( apply_fp8_linear, cutlass_fp8_supported, - is_blackwell_supported, ) from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod @@ -45,6 +44,7 @@ from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.utils.common import ( get_bool_env_var, + is_blackwell_supported, is_cuda, is_sm120_supported, next_power_of_2,