Import is_gfx95_supported and select the block-scale FP8 GEMM at import time:
the aiter top-level gemm_a8w8_blockscale is imported by default, and it is
rebound to the aiter Triton implementation when is_gfx95_supported() is true.

diff --git a/python/sglang/srt/layers/quantization/fp8_utils.py b/python/sglang/srt/layers/quantization/fp8_utils.py
index 4dddd407f296..88a1e1df43c7 100644
--- a/python/sglang/srt/layers/quantization/fp8_utils.py
+++ b/python/sglang/srt/layers/quantization/fp8_utils.py
@@ -34,6 +34,7 @@
     get_device_capability,
     is_cuda,
     is_flashinfer_available,
+    is_gfx95_supported,
     is_hip,
 )
 
@@ -45,10 +46,11 @@
 
 if _use_aiter:
     import aiter
+    from aiter import gemm_a8w8_blockscale, gemm_a8w8_bpreshuffle, get_hip_quant
 
-    # from aiter import gemm_a8w8_blockscale, gemm_a8w8_bpreshuffle, get_hip_quant
-    from aiter import gemm_a8w8_bpreshuffle, get_hip_quant
-    from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale
+    if is_gfx95_supported():
+        # Rebind only the block-scale GEMM to the Triton implementation.
+        from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale
 
     aiter_per1x128_quant = get_hip_quant(aiter.QuantType.per_1x128)
 