Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 23 additions & 15 deletions python/sglang/srt/layers/quantization/fp8_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from sglang.srt.layers import deep_gemm_wrapper
from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8
from sglang.srt.layers.quantization.mxfp4_tensor import MXFP4QuantizeUtil
from sglang.srt.utils import ceil_div, is_blackwell_supported, offloader
from sglang.srt.utils import ceil_div, offloader
Copy link

Copilot AI Nov 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nitpick] The imports from sglang.srt.utils are split across two separate import statements (lines 9 and 30-40). Consider consolidating them into a single import statement for better code organization and consistency with Python style guidelines.

Copilot uses AI. Check for mistakes.

try:
from vllm import _custom_ops as ops
Expand All @@ -32,6 +32,8 @@
get_bool_env_var,
get_cuda_version,
get_device_capability,
get_device_sm,
is_blackwell_supported,
is_cuda,
is_flashinfer_available,
is_hip,
Expand Down Expand Up @@ -130,35 +132,41 @@ def cutlass_block_fp8_supported() -> bool:
if not get_bool_env_var("SGLANG_SUPPORT_CUTLASS_BLOCK_FP8"):
return False
if _is_cuda:
major, minor = torch.cuda.get_device_capability()
sm_version = major * 10 + minor
cuda_version = tuple(map(int, torch.version.cuda.split(".")))
if cuda_version >= (12, 0) and sm_version >= 90:
return True
sm_version = get_device_sm()
cuda_version = get_cuda_version()
return cuda_version >= (12, 0) and sm_version >= 90
return False


# Module-level capability flags, computed once at import time.
CUTLASS_BLOCK_FP8_SUPPORTED = cutlass_block_fp8_supported()

# Hardware/runtime capability: FlashInfer FP8 GEMM needs a Blackwell GPU and
# the flashinfer package. Hoisted out so both the explicit opt-in and the
# implicit-fallback paths below share the same capability check.
FLASHINFER_FP8_GEMM_SUPPORTED = is_blackwell_supported() and is_flashinfer_available()

# Enable FlashInfer FP8 GEMM when supported AND either:
#   * the user explicitly opted in via SGLANG_ENABLE_FLASHINFER_FP8_GEMM, or
#   * neither DeepGEMM JIT nor the CUTLASS env flag is configured, so
#     FlashInfer is the implicit default backend.
ENABLE_FLASHINFER_FP8_GEMM = FLASHINFER_FP8_GEMM_SUPPORTED and (
    envs.SGLANG_ENABLE_FLASHINFER_FP8_GEMM.get()
    or (
        not deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
        and not envs.SGLANG_SUPPORT_CUTLASS_BLOCK_FP8.is_set()
    )
)
if ENABLE_FLASHINFER_FP8_GEMM:
    # Imported lazily/conditionally: flashinfer may not be installed when the
    # flag is off, and the import is only needed by the FlashInfer path.
    from flashinfer.gemm import gemm_fp8_nt_groupwise


def dispatch_w8a8_block_fp8_linear() -> Callable:
    """Select the w8a8 block-FP8 linear backend for the current environment.

    Backends are tried in fixed priority order using guard clauses, each gated
    by a module-level capability flag computed at import time:
      1. DeepGEMM JIT (``deepgemm_w8a8_block_fp8_linear_with_fallback``)
      2. FlashInfer FP8 GEMM
      3. CUTLASS block FP8
      4. AITER (ROCm)
      5. Triton — the universal fallback.

    Returns:
        The chosen linear implementation; callers invoke it with the usual
        w8a8 block-FP8 linear arguments.
    """
    if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
        return deepgemm_w8a8_block_fp8_linear_with_fallback

    if ENABLE_FLASHINFER_FP8_GEMM:
        return flashinfer_gemm_w8a8_block_fp8_linear

    if CUTLASS_BLOCK_FP8_SUPPORTED:
        return cutlass_w8a8_block_fp8_linear_with_fallback

    if _use_aiter:
        return aiter_w8a8_block_fp8_linear

    # Triton works everywhere; use it when no accelerated backend applies.
    return triton_w8a8_block_fp8_linear


def flashinfer_gemm_w8a8_block_fp8_linear(
Expand Down
2 changes: 1 addition & 1 deletion python/sglang/srt/layers/quantization/modelopt_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
from sglang.srt.layers.quantization.fp8_utils import (
apply_fp8_linear,
cutlass_fp8_supported,
is_blackwell_supported,
)
from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
Expand All @@ -45,6 +44,7 @@
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.utils.common import (
get_bool_env_var,
is_blackwell_supported,
is_cuda,
is_sm120_supported,
next_power_of_2,
Expand Down
Loading