diff --git a/vllm/config.py b/vllm/config.py index fcbf962ac685..9635d7fc49a3 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -28,6 +28,7 @@ from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass from vllm.logger import init_logger from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS, + QuantizationMethods, get_quantization_config) from vllm.model_executor.models import ModelRegistry from vllm.platforms import CpuArchEnum, current_platform @@ -765,12 +766,43 @@ def _verify_quantization(self) -> None: if quant_cfg is not None: quant_method = quant_cfg.get("quant_method", "").lower() + # Quantization methods which are overrides (i.e. they have a + # `override_quantization_method` method) must be checked in order + # of preference (this is particularly important for GPTQ). + overrides = [ + "marlin", + "bitblas", + "gptq_marlin_24", + "gptq_marlin", + "gptq_bitblas", + "awq_marlin", + "ipex", + "moe_wna16", + ] + quantization_methods = [ + q for q in supported_quantization if q not in overrides + ] + # Any custom overrides will be in quantization_methods so we place + # them at the start of the list so custom overrides have preference + # over the built in ones. + quantization_methods = quantization_methods + overrides + # Detect which checkpoint is it - for name in QUANTIZATION_METHODS: + for name in quantization_methods: method = get_quantization_config(name) quantization_override = method.override_quantization_method( quant_cfg, self.quantization) - if quantization_override: + if quantization_override is not None: + # Raise error if the override is not custom (custom would + # be in QUANTIZATION_METHODS but not QuantizationMethods) + # and hasn't been added to the overrides list. + if (name in get_args(QuantizationMethods) + and name not in overrides): + raise ValueError( + f"Quantization method {name} is an override but " + "is has not been added to the `overrides` list " + "above. This is necessary to ensure that the " + "overrides are checked in order of preference.") quant_method = quantization_override self.quantization = quantization_override break diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index 9e1bf05dab9e..15e08220b7b5 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -1,11 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Dict, List, Type +from typing import Literal, Type, get_args from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -QUANTIZATION_METHODS: List[str] = [ +QuantizationMethods = Literal[ "aqlm", "awq", "deepspeedfp", @@ -15,8 +15,6 @@ "fbgemm_fp8", "modelopt", "nvfp4", - # The order of gptq methods is important for config.py iteration over - # override_quantization_method(..) "marlin", "bitblas", "gguf", @@ -36,6 +34,7 @@ "moe_wna16", "torchao", ] +QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods)) # The customized quantization methods which will be added to this dict. _CUSTOMIZED_METHOD_TO_QUANT_CONFIG = {} @@ -111,7 +110,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]: from .torchao import TorchAOConfig from .tpu_int8 import Int8TpuConfig - method_to_config: Dict[str, Type[QuantizationConfig]] = { + method_to_config: dict[str, Type[QuantizationConfig]] = { "aqlm": AQLMConfig, "awq": AWQConfig, "deepspeedfp": DeepSpeedFPConfig, @@ -120,8 +119,6 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]: "fbgemm_fp8": FBGEMMFp8Config, "modelopt": ModelOptFp8Config, "nvfp4": ModelOptNvFp4Config, - # The order of gptq methods is important for config.py iteration over - # override_quantization_method(..) "marlin": MarlinConfig, "bitblas": BitBLASConfig, "gguf": GGUFConfig, @@ -150,6 +147,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]: __all__ = [ "QuantizationConfig", + "QuantizationMethods", "get_quantization_config", "QUANTIZATION_METHODS", ] \ No newline at end of file