diff --git a/python/sglang/srt/layers/quantization/__init__.py b/python/sglang/srt/layers/quantization/__init__.py index da6f2931219c..6f86b7ca197a 100644 --- a/python/sglang/srt/layers/quantization/__init__.py +++ b/python/sglang/srt/layers/quantization/__init__.py @@ -1,4 +1,5 @@ # Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/quantization/__init__.py +import logging from typing import Callable, Dict, Optional, Type import torch @@ -28,6 +29,8 @@ from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp8Config from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config +logger = logging.getLogger(__name__) + QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = { "aqlm": AQLMConfig, "awq": AWQConfig, @@ -81,6 +84,9 @@ def awq_get_quant_method(self, layer, prefix): AWQMarlinLinearMethod, AWQMoEMethod, ) + from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + check_marlin_supports_layer, + ) from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE @@ -91,6 +97,14 @@ def awq_get_quant_method(self, layer, prefix): ): if is_layer_skipped_awq(prefix, self.modules_to_not_convert): return UnquantizedLinearMethod() + if not check_marlin_supports_layer(layer, self.group_size): + logger.warning_once( + f"Layer '{prefix}' is not supported by AWQMarlin. " + "Falling back to unoptimized AWQ kernels." + ) + return AWQConfig.from_config(self.full_config).get_quant_method( + layer, prefix + ) return AWQMarlinLinearMethod(self) elif isinstance(layer, FusedMoE): return AWQMoEMethod(self)