diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py
index 227be1497d0e..3a1c3569cdab 100644
--- a/vllm/model_executor/layers/quantization/awq.py
+++ b/vllm/model_executor/layers/quantization/awq.py
@@ -30,7 +30,7 @@ def __init__(
         self.weight_bits = weight_bits
         self.group_size = group_size
         self.zero_point = zero_point
-        self.modules_to_not_convert = modules_to_not_convert or []
+        self.ignored_modules = modules_to_not_convert or []
 
         if self.weight_bits != 4:
             raise ValueError(
@@ -42,7 +42,7 @@ def __repr__(self) -> str:
         return (f"AWQConfig(weight_bits={self.weight_bits}, "
                 f"group_size={self.group_size}, "
                 f"zero_point={self.zero_point}, "
-                f"modules_to_not_convert={self.modules_to_not_convert})")
+                f"modules_to_not_convert={self.ignored_modules})")
 
     def get_name(self) -> str:
         return "awq"
@@ -75,14 +75,14 @@ def from_config(cls, config: Dict[str, Any]) -> "AWQConfig":
     def get_quant_method(self, layer: torch.nn.Module,
                          prefix: str) -> Optional["LinearMethodBase"]:
         if isinstance(layer, LinearBase):
-            if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
+            if is_layer_skipped_awq(prefix, self.ignored_modules):
                 return UnquantizedLinearMethod()
             return AWQLinearMethod(self)
         return None
 
 
-def is_layer_skipped_awq(prefix: str, modules_to_not_convert: List[str]):
-    return any(module_name in prefix for module_name in modules_to_not_convert)
+def is_layer_skipped_awq(prefix: str, ignored_modules: List[str]):
+    return any(module_name in prefix for module_name in ignored_modules)
 
 
 class AWQLinearMethod(LinearMethodBase):
diff --git a/vllm/model_executor/layers/quantization/base_config.py b/vllm/model_executor/layers/quantization/base_config.py
index 5ef11546fd41..779f3fb0d325 100644
--- a/vllm/model_executor/layers/quantization/base_config.py
+++ b/vllm/model_executor/layers/quantization/base_config.py
@@ -62,8 +62,9 @@ class QuantizationConfig(ABC):
 
     def __init__(self):
         super().__init__()
-        # mapping is updated by models as they initialize
+        # These attributes are updated by models as they initialize
         self.packed_modules_mapping: Dict[str, List[str]] = dict()
+        self.ignored_modules: List[str] = list()
 
     @abstractmethod
     def get_name(self) -> str:
diff --git a/vllm/model_executor/layers/quantization/bitsandbytes.py b/vllm/model_executor/layers/quantization/bitsandbytes.py
index 1e8e7aa1b8c1..eb7dca3677e4 100644
--- a/vllm/model_executor/layers/quantization/bitsandbytes.py
+++ b/vllm/model_executor/layers/quantization/bitsandbytes.py
@@ -39,7 +39,7 @@ def __init__(
         self.bnb_4bit_use_double_quant = bnb_4bit_use_double_quant
         self.llm_int8_enable_fp32_cpu_offload = llm_int8_enable_fp32_cpu_offload
         self.llm_int8_has_fp16_weight = llm_int8_has_fp16_weight
-        self.llm_int8_skip_modules = llm_int8_skip_modules or []
+        self.ignored_modules = llm_int8_skip_modules or []
         self.llm_int8_threshold = llm_int8_threshold
 
         if self.bnb_4bit_quant_storage not in ["uint8"]:
@@ -52,7 +52,7 @@ def __repr__(self) -> str:
                 f"bnb_4bit_compute_dtype={self.bnb_4bit_compute_dtype}, "
                 f"bnb_4bit_quant_storage={self.bnb_4bit_quant_storage}, "
                 f"bnb_4bit_quant_type={self.bnb_4bit_quant_type}, "
-                f"llm_int8_skip_modules={self.llm_int8_skip_modules})")
+                f"llm_int8_skip_modules={self.ignored_modules})")
 
     @classmethod
     def get_name(self) -> str:
@@ -122,25 +122,25 @@ def get_safe_value(config, keys, default_value=None):
     def get_quant_method(self, layer: torch.nn.Module,
                          prefix: str) -> Optional["LinearMethodBase"]:
         if isinstance(layer, LinearBase):
-            if is_layer_skipped_bnb(prefix, self.llm_int8_skip_modules):
+            if is_layer_skipped_bnb(prefix, self.ignored_modules):
                 return UnquantizedLinearMethod()
             return BitsAndBytesLinearMethod(self)
         return None
 
 
-def is_layer_skipped_bnb(prefix: str, llm_int8_skip_modules: List[str]):
+def is_layer_skipped_bnb(prefix: str, ignored_modules: List[str]):
     # Split the prefix into its dot-separated components
     components = prefix.split('.')
 
     # Check if any of the skip modules exactly matches any component
     substr_check = any(module_name in components
-                       for module_name in llm_int8_skip_modules)
+                       for module_name in ignored_modules)
 
     # Allow certain layers to not be quantized
     set_components = set(".".join(components[:i + 1])
                          for i in range(len(components)))
-    set_llm_int8_skip_modules = set(llm_int8_skip_modules)
-    prefix_check = len(set_llm_int8_skip_modules & set_components) != 0
+    set_ignored_modules = set(ignored_modules)
+    prefix_check = len(set_ignored_modules & set_components) != 0
 
     return substr_check or prefix_check
diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py
index c77324bab59c..0aa9ecfde19f 100644
--- a/vllm/model_executor/models/interfaces.py
+++ b/vllm/model_executor/models/interfaces.py
@@ -1,7 +1,9 @@
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import (TYPE_CHECKING, ClassVar, Dict, List, Literal, Optional,
-                    Protocol, Type, Union, overload, runtime_checkable)
+import inspect
+from typing import (TYPE_CHECKING, Any, ClassVar, Dict, List, Literal,
+                    Optional, Protocol, Type, Union, overload,
+                    runtime_checkable)
 
 import torch
 from torch import Tensor
@@ -452,28 +454,53 @@ class SupportsQuant:
     quant_config: Optional[QuantizationConfig] = None
 
     def __new__(cls, *args, **kwargs) -> Self:
+        from .utils import WeightsMapper  # avoid circular import
+
         instance = super().__new__(cls)
-        quant_config = cls._find_quant_config(*args, **kwargs)
+        bound_args = inspect.signature(cls.__init__).bind(
+            instance, *args, **kwargs).arguments
+
+        quant_config = cls._find_quant_config(bound_args)
+        prefix = cls._find_prefix(bound_args)
+        packed_modules_mapping = cls.packed_modules_mapping
+        hf_to_vllm_mapper: WeightsMapper = getattr(cls, "hf_to_vllm_mapper",
+                                                   WeightsMapper())
+
         if quant_config is not None:
+            # 1. update qconfig's packed_modules_mapping
+            #    currently takes the union; in the future this could be more
+            #    precise using prefix and hf_to_vllm_mapper
+            quant_config.packed_modules_mapping.update(packed_modules_mapping)
+
+            # 2. update qconfig's ignored modules
+            quant_config.ignored_modules = [
+                prefix + hf_to_vllm_mapper._map_name(module[len(prefix):])
+                if module.startswith(prefix) else module
+                for module in quant_config.ignored_modules
+            ]
+
+            # 3. set module's quantization config
             instance.quant_config = quant_config
-            instance.quant_config.packed_modules_mapping.update(
-                cls.packed_modules_mapping)
+
         return instance
 
     @staticmethod
-    def _find_quant_config(*args, **kwargs) -> Optional[QuantizationConfig]:
+    def _find_quant_config(
+            bound_args: Dict[str, Any]) -> Optional[QuantizationConfig]:
         from vllm.config import VllmConfig  # avoid circular import
 
-        args_values = list(args) + list(kwargs.values())
-        for arg in args_values:
+        for arg in bound_args.values():
             if isinstance(arg, VllmConfig):
                 return arg.quant_config
-
-            if isinstance(arg, QuantizationConfig):
+            elif isinstance(arg, QuantizationConfig):
                 return arg
         return None
 
+    @staticmethod
+    def _find_prefix(bound_args: Dict[str, Any]) -> str:
+        return bound_args.get("prefix", "")
+
 
 @runtime_checkable
 class SupportsTranscription(Protocol):
diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py
index 1e6ff1fec6d5..b91dd91e91eb 100644
--- a/vllm/model_executor/models/qwen2_5_vl.py
+++ b/vllm/model_executor/models/qwen2_5_vl.py
@@ -60,7 +60,7 @@
 from vllm.transformers_utils.config import uses_mrope
 
 from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
-                         SupportsMultiModal, SupportsPP)
+                         SupportsMultiModal, SupportsPP, SupportsQuant)
 from .qwen2_vl import Qwen2VLDummyInputsBuilder as Qwen2_5_VLDummyInputsBuilder
 from .qwen2_vl import (Qwen2VLMultiModalProcessor, Qwen2VLProcessingInfo,
                        apply_rotary_pos_emb_vision)
@@ -764,7 +764,8 @@ def _get_mm_fields_config(
     info=Qwen2_5_VLProcessingInfo,
     dummy_inputs=Qwen2_5_VLDummyInputsBuilder)
 class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
-                                         SupportsLoRA, SupportsPP):
+                                         SupportsLoRA, SupportsPP,
+                                         SupportsQuant):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
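
For context when reviewing: is_layer_skipped_bnb (and its simpler AWQ counterpart) only changed parameter names in this patch, not matching semantics. A layer is skipped either when an ignored entry exactly matches one dot-separated component of its path, or when an entry equals a dot-joined prefix of the path. A minimal standalone sketch of that behavior; the module paths in the asserts are made-up examples, not names taken from this PR:

from typing import List


def is_layer_skipped_bnb(prefix: str, ignored_modules: List[str]) -> bool:
    # Split the module path into its dot-separated components.
    components = prefix.split('.')

    # Exact match against any single component, e.g. "lm_head".
    substr_check = any(module_name in components
                       for module_name in ignored_modules)

    # Match against any dot-joined prefix of the path,
    # e.g. "visual" matches "visual.blocks.0.attn.qkv".
    set_components = set(".".join(components[:i + 1])
                         for i in range(len(components)))
    prefix_check = len(set(ignored_modules) & set_components) != 0

    return substr_check or prefix_check


# Hypothetical module paths, for illustration only:
assert is_layer_skipped_bnb("visual.blocks.0.attn.qkv", ["visual"])
assert is_layer_skipped_bnb("model.lm_head", ["lm_head"])
assert not is_layer_skipped_bnb("model.layers.0.mlp.gate_proj", ["visual"])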
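The subtle step in SupportsQuant.__new__ is (2): checkpoint quant configs list ignored modules under Hugging Face names, while vLLM resolves layers under its own names, so entries under the model's prefix are rewritten through hf_to_vllm_mapper. Below is a rough sketch of just that remapping, using a toy stand-in for WeightsMapper (the real class lives in vllm/model_executor/models/utils.py); the mapping rule and module names are illustrative assumptions, not taken from this diff:

from typing import Dict, List


class ToyWeightsMapper:
    """Stand-in for vLLM's WeightsMapper, reduced to prefix rewriting."""

    def __init__(self, orig_to_new_prefix: Dict[str, str]):
        self.orig_to_new_prefix = orig_to_new_prefix

    def _map_name(self, name: str) -> str:
        for old, new in self.orig_to_new_prefix.items():
            if name.startswith(old):
                return new + name[len(old):]
        return name


def remap_ignored_modules(ignored: List[str], prefix: str,
                          mapper: ToyWeightsMapper) -> List[str]:
    # Mirrors step 2 of SupportsQuant.__new__: only entries under this
    # model's prefix are rewritten; everything else passes through.
    return [
        prefix + mapper._map_name(module[len(prefix):])
        if module.startswith(prefix) else module
        for module in ignored
    ]


# Illustrative assumption: the HF checkpoint names the text tower "model."
# while the vLLM implementation calls it "language_model.".
mapper = ToyWeightsMapper({"model.": "language_model."})
print(remap_ignored_modules(["model.lm_head", "visual.merger"], "", mapper))
# -> ['language_model.lm_head', 'visual.merger']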