From b660bf60a50c26a2ff15a812fbcc3082cd192f49 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 11 Feb 2025 12:15:11 -0500 Subject: [PATCH 1/7] add SupportsQuant to phi3 and clip Signed-off-by: Kyle Sayers --- .../layers/quantization/base_config.py | 4 ++-- vllm/model_executor/models/clip.py | 6 ++++-- vllm/model_executor/models/interfaces.py | 14 ++++++++++++++ vllm/model_executor/models/phi3v.py | 11 ++++++----- 4 files changed, 26 insertions(+), 9 deletions(-) diff --git a/vllm/model_executor/layers/quantization/base_config.py b/vllm/model_executor/layers/quantization/base_config.py index c0d8553c0df1..a7917643493b 100644 --- a/vllm/model_executor/layers/quantization/base_config.py +++ b/vllm/model_executor/layers/quantization/base_config.py @@ -2,7 +2,7 @@ import inspect from abc import ABC, abstractmethod -from typing import Any, Dict, List, Mapping, Optional, Type +from typing import Any, Dict, List, Optional, Type import torch from torch import nn @@ -59,7 +59,7 @@ def method_has_implemented_embedding( class QuantizationConfig(ABC): """Base class for quantization configs.""" - packed_modules_mapping: Mapping[str, List[str]] = dict() + packed_modules_mapping: Dict[str, List[str]] = dict() @abstractmethod def get_name(self) -> str: diff --git a/vllm/model_executor/models/clip.py b/vllm/model_executor/models/clip.py index 1e784f5b4172..24513d58ae12 100644 --- a/vllm/model_executor/models/clip.py +++ b/vllm/model_executor/models/clip.py @@ -19,6 +19,7 @@ RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.interfaces import SupportsQuant from vllm.multimodal.utils import (cached_get_tokenizer, consecutive_placeholder_ranges, repeat_and_pad_placeholder_tokens) @@ -468,10 +469,10 @@ def forward( return encoder_outputs -class CLIPVisionModel(nn.Module): - +class CLIPVisionModel(nn.Module, SupportsQuant): config_class = CLIPVisionConfig main_input_name = "pixel_values" + packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]} def __init__( self, @@ -483,6 +484,7 @@ def __init__( prefix: str = "", ) -> None: super().__init__() + SupportsQuant.__init__(self, quant_config) self.vision_model = CLIPVisionTransformer( config=config, quant_config=quant_config, diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index 0fc5c4db179c..0b8ca7ae0865 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -13,6 +13,7 @@ if TYPE_CHECKING: from vllm.attention import AttentionMetadata + from vllm.config import QuantizationConfig from vllm.multimodal.inputs import NestedTensors # noqa: F401 from vllm.sequence import IntermediateTensors @@ -441,3 +442,16 @@ def supports_cross_encoding( model: Union[Type[object], object], ) -> Union[TypeIs[Type[SupportsCrossEncoding]], TypeIs[SupportsCrossEncoding]]: return is_pooling_model(model) and _supports_cross_encoding(model) + + +class SupportsQuant: + """The interface required for all models that support quantization.""" + + packed_modules_mapping: ClassVar[Dict[str, List[str]]] = {} + quant_config: Optional[QuantizationConfig] = None + + def __init__(self, quant_config: "QuantizationConfig"): + super().__init__() + self.quant_config = quant_config + self.quant_config.packed_modules_mapping.update( + self.packed_modules_mapping) diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 053390c521fc..af8bb62063be 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -50,7 +50,7 @@ from vllm.utils import is_list_of from .clip import CLIPVisionModel -from .interfaces import SupportsMultiModal, SupportsPP +from .interfaces import SupportsMultiModal, SupportsPP, SupportsQuant from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, init_vllm_registered_model, maybe_prefix, merge_multimodal_embeddings) @@ -498,7 +498,8 @@ def _apply_prompt_replacements( @MULTIMODAL_REGISTRY.register_processor(Phi3VMultiModalProcessor, info=Phi3VProcessingInfo, dummy_inputs=Phi3VDummyInputsBuilder) -class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): +class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, + SupportsQuant): hf_to_vllm_mapper = WeightsMapper( orig_to_new_prefix={ "model.vision_embed_tokens.wte": "embed_tokens", @@ -509,8 +510,8 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() + SupportsQuant.__init__(self, vllm_config.quant_config) config = vllm_config.model_config.hf_config - quant_config = vllm_config.quant_config multimodal_config = vllm_config.model_config.multimodal_config self.config = config self.multimodal_config = multimodal_config @@ -520,14 +521,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config.vocab_size, config.hidden_size, org_num_embeddings=config.vocab_size, - quant_config=quant_config, + quant_config=self.quant_config, prefix=maybe_prefix(prefix, "model.embed_tokens"), ) # TODO: Optionally initializes this for supporting input embeddings. self.vision_embed_tokens = Phi3HDImageEmbedding( config, - quant_config, + self.quant_config, prefix=maybe_prefix(prefix, "model.vision_embed_tokens")) self.language_model = init_vllm_registered_model( From e945180453434d010591fa32e6c04b01c2cab697 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 11 Feb 2025 13:02:32 -0500 Subject: [PATCH 2/7] fix type hint Signed-off-by: Kyle Sayers --- vllm/model_executor/models/interfaces.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index 0b8ca7ae0865..591aeec74c85 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -448,7 +448,7 @@ class SupportsQuant: """The interface required for all models that support quantization.""" packed_modules_mapping: ClassVar[Dict[str, List[str]]] = {} - quant_config: Optional[QuantizationConfig] = None + quant_config: Optional["QuantizationConfig"] = None def __init__(self, quant_config: "QuantizationConfig"): super().__init__() From 630ef76fe1acfc3f9c386a686ac24313d839d6a1 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 13 Feb 2025 12:13:39 -0500 Subject: [PATCH 3/7] overload __new__ Signed-off-by: Kyle Sayers --- vllm/model_executor/models/clip.py | 1 - vllm/model_executor/models/interfaces.py | 34 ++++++++++++++++++------ vllm/model_executor/models/phi3v.py | 1 - 3 files changed, 26 insertions(+), 10 deletions(-) diff --git a/vllm/model_executor/models/clip.py b/vllm/model_executor/models/clip.py index 1691dbf6b0f2..73c109a27ac7 100644 --- a/vllm/model_executor/models/clip.py +++ b/vllm/model_executor/models/clip.py @@ -351,7 +351,6 @@ def __init__( prefix: str = "", ) -> None: super().__init__() - SupportsQuant.__init__(self, quant_config) self.vision_model = CLIPVisionTransformer( config=config, quant_config=quant_config, diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index 136e8f3d9e8a..bd6661d668d9 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -7,13 +7,14 @@ from typing_extensions import TypeIs, TypeVar from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.utils import supports_kw from .interfaces_base import is_pooling_model if TYPE_CHECKING: from vllm.attention import AttentionMetadata - from vllm.config import QuantizationConfig from vllm.multimodal.inputs import NestedTensors # noqa: F401 from vllm.sequence import IntermediateTensors @@ -448,13 +449,30 @@ class SupportsQuant: """The interface required for all models that support quantization.""" packed_modules_mapping: ClassVar[Dict[str, List[str]]] = {} - quant_config: Optional["QuantizationConfig"] = None - - def __init__(self, quant_config: "QuantizationConfig"): - super().__init__() - self.quant_config = quant_config - self.quant_config.packed_modules_mapping.update( - self.packed_modules_mapping) + quant_config: Optional[QuantizationConfig] = None + + def __new__(cls, *args, **kwargs) -> "SupportsQuant": + instance = super().__new__(cls) + quant_config = cls._find_quant_config(*args, **kwargs) + if quant_config is not None: + instance.quant_config = quant_config + instance.quant_config.packed_modules_mapping.update( + cls.packed_modules_mapping) + return instance + + @staticmethod + def _find_quant_config(*args, **kwargs) -> Optional[QuantizationConfig]: + from vllm.config import VllmConfig # avoid circular import + + args_values = list(args) + list(kwargs.values()) + for arg in args_values: + if isinstance(arg, VllmConfig): + return arg.quant_config + + if isinstance(arg, QuantizationConfig): + return arg + + return None @runtime_checkable diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index af8bb62063be..6bbfa40beed1 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -510,7 +510,6 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() - SupportsQuant.__init__(self, vllm_config.quant_config) config = vllm_config.model_config.hf_config multimodal_config = vllm_config.model_config.multimodal_config self.config = config From d7cf2c603f46b4582bc95318b20bf577f4a443e4 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 14 Feb 2025 13:39:16 -0500 Subject: [PATCH 4/7] Initialize packed_modules_mapping on instance Signed-off-by: Kyle Sayers --- vllm/model_executor/layers/quantization/aqlm.py | 1 + vllm/model_executor/layers/quantization/awq.py | 1 + vllm/model_executor/layers/quantization/awq_marlin.py | 1 + vllm/model_executor/layers/quantization/base_config.py | 5 ++++- vllm/model_executor/layers/quantization/bitsandbytes.py | 2 +- .../quantization/compressed_tensors/compressed_tensors.py | 2 +- vllm/model_executor/layers/quantization/deepspeedfp.py | 1 + vllm/model_executor/layers/quantization/experts_int8.py | 2 +- vllm/model_executor/layers/quantization/fbgemm_fp8.py | 1 + vllm/model_executor/layers/quantization/fp8.py | 1 + vllm/model_executor/layers/quantization/gguf.py | 2 +- vllm/model_executor/layers/quantization/gptq.py | 1 + vllm/model_executor/layers/quantization/gptq_marlin.py | 1 + vllm/model_executor/layers/quantization/gptq_marlin_24.py | 1 + vllm/model_executor/layers/quantization/hqq_marlin.py | 1 + vllm/model_executor/layers/quantization/ipex_quant.py | 1 + vllm/model_executor/layers/quantization/modelopt.py | 1 + vllm/model_executor/layers/quantization/moe_wna16.py | 1 + vllm/model_executor/layers/quantization/neuron_quant.py | 1 + vllm/model_executor/layers/quantization/qqq.py | 1 + vllm/model_executor/layers/quantization/quark/quark.py | 1 + vllm/model_executor/layers/quantization/tpu_int8.py | 1 + 22 files changed, 25 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/layers/quantization/aqlm.py b/vllm/model_executor/layers/quantization/aqlm.py index 6c08d016c0f7..10f5241f9a71 100644 --- a/vllm/model_executor/layers/quantization/aqlm.py +++ b/vllm/model_executor/layers/quantization/aqlm.py @@ -169,6 +169,7 @@ def __init__( num_codebooks: int, out_group_size: int, ) -> None: + super().__init__() self.in_group_size = in_group_size self.nbits_per_codebook = nbits_per_codebook self.num_codebooks = num_codebooks diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py index ff77af44d770..227be1497d0e 100644 --- a/vllm/model_executor/layers/quantization/awq.py +++ b/vllm/model_executor/layers/quantization/awq.py @@ -26,6 +26,7 @@ def __init__( zero_point: bool, modules_to_not_convert: Optional[List[str]] = None, ) -> None: + super().__init__() self.weight_bits = weight_bits self.group_size = group_size self.zero_point = zero_point diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py index a43b2e597c1e..b921f43f9810 100644 --- a/vllm/model_executor/layers/quantization/awq_marlin.py +++ b/vllm/model_executor/layers/quantization/awq_marlin.py @@ -46,6 +46,7 @@ def __init__(self, weight_bits: int, group_size: int, zero_point: bool, lm_head_quantized: bool, modules_to_not_convert: Optional[List[str]], full_config: Dict[str, Any]) -> None: + super().__init__() self.pack_factor = 32 // weight_bits # packed into int32 self.group_size = group_size self.zero_point = zero_point diff --git a/vllm/model_executor/layers/quantization/base_config.py b/vllm/model_executor/layers/quantization/base_config.py index a7917643493b..25710d6dc720 100644 --- a/vllm/model_executor/layers/quantization/base_config.py +++ b/vllm/model_executor/layers/quantization/base_config.py @@ -59,7 +59,10 @@ def method_has_implemented_embedding( class QuantizationConfig(ABC): """Base class for quantization configs.""" - packed_modules_mapping: Dict[str, List[str]] = dict() + + def __init__(self): + super().__init__() + self.packed_modules_mapping = dict() @abstractmethod def get_name(self) -> str: diff --git a/vllm/model_executor/layers/quantization/bitsandbytes.py b/vllm/model_executor/layers/quantization/bitsandbytes.py index 49d992d4cb07..33c2ca93ffa1 100644 --- a/vllm/model_executor/layers/quantization/bitsandbytes.py +++ b/vllm/model_executor/layers/quantization/bitsandbytes.py @@ -30,7 +30,7 @@ def __init__( llm_int8_skip_modules: Optional[List[str]] = None, llm_int8_threshold: float = 6.0, ) -> None: - + super().__init__() self.load_in_8bit = load_in_8bit self.load_in_4bit = load_in_4bit self.bnb_4bit_compute_dtype = bnb_4bit_compute_dtype diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index 6ee3e9362f8d..b6d86a81fcbc 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -51,7 +51,7 @@ def __init__( kv_cache_scheme: Optional[Dict[str, Any]] = None, config: Optional[Dict[str, Any]] = None, ): - + super().__init__() self.ignore = ignore self.quant_format = quant_format # Map from [target -> scheme] diff --git a/vllm/model_executor/layers/quantization/deepspeedfp.py b/vllm/model_executor/layers/quantization/deepspeedfp.py index b4123650149f..67934d37284e 100644 --- a/vllm/model_executor/layers/quantization/deepspeedfp.py +++ b/vllm/model_executor/layers/quantization/deepspeedfp.py @@ -25,6 +25,7 @@ def __init__( weight_bits: int = 8, group_size: int = 512, ) -> None: + super().__init__() self.weight_bits = weight_bits self.group_size = group_size self.valid_types = [torch.bfloat16, torch.float16] diff --git a/vllm/model_executor/layers/quantization/experts_int8.py b/vllm/model_executor/layers/quantization/experts_int8.py index 87fbcf62ac1e..663fb8bf5b8e 100644 --- a/vllm/model_executor/layers/quantization/experts_int8.py +++ b/vllm/model_executor/layers/quantization/experts_int8.py @@ -17,7 +17,7 @@ class ExpertsInt8Config(QuantizationConfig): """Config class for Int8 experts quantization.""" def __init__(self) -> None: - pass + super().__init__() @classmethod def get_name(cls) -> str: diff --git a/vllm/model_executor/layers/quantization/fbgemm_fp8.py b/vllm/model_executor/layers/quantization/fbgemm_fp8.py index da5ef36c5105..3bb8188f725c 100644 --- a/vllm/model_executor/layers/quantization/fbgemm_fp8.py +++ b/vllm/model_executor/layers/quantization/fbgemm_fp8.py @@ -29,6 +29,7 @@ class FBGEMMFp8Config(QuantizationConfig): """Config class for FBGEMM Fp8.""" def __init__(self, ignore_list: List[str], input_scale_ub: float): + super().__init__() self.ignore_list = ignore_list if ignore_list else [] self.input_scale_ub = input_scale_ub diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 86e025310f4e..f928ea7e23ca 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -47,6 +47,7 @@ def __init__( ignored_layers: Optional[List[str]] = None, weight_block_size: Optional[List[int]] = None, ) -> None: + super().__init__() self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized if is_checkpoint_fp8_serialized: logger.warning("Detected fp8 checkpoint. Please note that the " diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py index 86e6dbb5a5fb..b1fecb32f4d8 100644 --- a/vllm/model_executor/layers/quantization/gguf.py +++ b/vllm/model_executor/layers/quantization/gguf.py @@ -20,7 +20,7 @@ class GGUFConfig(QuantizationConfig): """Config class for GGUF.""" def __init__(self, ) -> None: - pass + super().__init__() def __repr__(self) -> str: return ("GGUFConfig()") diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py index 6d1f0cc2eb4d..09291c2bf1f0 100644 --- a/vllm/model_executor/layers/quantization/gptq.py +++ b/vllm/model_executor/layers/quantization/gptq.py @@ -58,6 +58,7 @@ def __init__( # r"+:.*\.(?:1[6-9]|20|21)\..*": {"bits": 8, "group_size": 64,}, # r"-:.*\.moe\..*": {}, # negative match (skip) all `moe` layers # } + super().__init__() self.dynamic = dynamic self.weight_bits = weight_bits diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 0a9d86b008db..2ae30eb29a00 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -53,6 +53,7 @@ def __init__( lm_head_quantized: bool, dynamic: Dict[str, Dict[str, Union[int, bool]]], ) -> None: + super().__init__() if desc_act and group_size == -1: # In this case, act_order == True is the same as act_order == False # (since we have only one group per output channel) diff --git a/vllm/model_executor/layers/quantization/gptq_marlin_24.py b/vllm/model_executor/layers/quantization/gptq_marlin_24.py index cec984483fd8..dd747e182e28 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin_24.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin_24.py @@ -38,6 +38,7 @@ def __init__( weight_bits: int, group_size: int, ) -> None: + super().__init__() quant_type = { 4: scalar_types.uint4b8, 8: scalar_types.uint8b128, diff --git a/vllm/model_executor/layers/quantization/hqq_marlin.py b/vllm/model_executor/layers/quantization/hqq_marlin.py index 432f43688ff5..4edc9aa848a1 100644 --- a/vllm/model_executor/layers/quantization/hqq_marlin.py +++ b/vllm/model_executor/layers/quantization/hqq_marlin.py @@ -33,6 +33,7 @@ def __init__( group_size: int, skip_modules: Optional[List[str]] = None, ) -> None: + super().__init__() assert group_size == 64, ("The only supported HQQ group size is " "currently 64.") assert weight_bits == 4, ("The only supported HQQ quantization " diff --git a/vllm/model_executor/layers/quantization/ipex_quant.py b/vllm/model_executor/layers/quantization/ipex_quant.py index 2531170ececf..c09cc13cb276 100644 --- a/vllm/model_executor/layers/quantization/ipex_quant.py +++ b/vllm/model_executor/layers/quantization/ipex_quant.py @@ -35,6 +35,7 @@ def __init__( desc_act: Optional[bool] = None, lm_head_quantized: Optional[bool] = None, ) -> None: + super().__init__() self.method = method self.weight_bits = weight_bits self.group_size = group_size diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 348e9bccd9b0..050130de1c0f 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -28,6 +28,7 @@ def __init__( self, is_checkpoint_fp8_serialized: bool = False, ) -> None: + super().__init__() self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized if is_checkpoint_fp8_serialized: logger.warning("Detected ModelOpt fp8 checkpoint. Please note that" diff --git a/vllm/model_executor/layers/quantization/moe_wna16.py b/vllm/model_executor/layers/quantization/moe_wna16.py index b9460e7d7985..77c6cab92f1c 100644 --- a/vllm/model_executor/layers/quantization/moe_wna16.py +++ b/vllm/model_executor/layers/quantization/moe_wna16.py @@ -29,6 +29,7 @@ def __init__(self, linear_quant_method: str, weight_bits: int, group_size: int, has_zp: bool, lm_head_quantized: bool, modules_to_not_convert: Optional[List[str]], full_config: Dict[str, Any]) -> None: + super().__init__() self.weight_bits = weight_bits self.group_size = group_size self.has_zp = has_zp diff --git a/vllm/model_executor/layers/quantization/neuron_quant.py b/vllm/model_executor/layers/quantization/neuron_quant.py index a8e8be207fd1..82954612fb2a 100644 --- a/vllm/model_executor/layers/quantization/neuron_quant.py +++ b/vllm/model_executor/layers/quantization/neuron_quant.py @@ -20,6 +20,7 @@ def __init__( dequant_dtype: str = "f16", quantize_method: str = "vector_dynamic", ) -> None: + super().__init__() self.quant_dtype = os.getenv("NEURON_QUANT_DTYPE", "s8") if self.quant_dtype not in SUPPORTED_QUANT_DTYPE_LIST: raise ValueError( diff --git a/vllm/model_executor/layers/quantization/qqq.py b/vllm/model_executor/layers/quantization/qqq.py index 6e9d3dc6cb37..1e05917a5187 100644 --- a/vllm/model_executor/layers/quantization/qqq.py +++ b/vllm/model_executor/layers/quantization/qqq.py @@ -39,6 +39,7 @@ def __init__( group_size: int, is_sym: bool = True, ) -> None: + super().__init__() self.weight_bits = weight_bits self.group_size = group_size self.is_sym = is_sym diff --git a/vllm/model_executor/layers/quantization/quark/quark.py b/vllm/model_executor/layers/quantization/quark/quark.py index ba123565a0ec..ca71da8b736a 100644 --- a/vllm/model_executor/layers/quantization/quark/quark.py +++ b/vllm/model_executor/layers/quantization/quark/quark.py @@ -30,6 +30,7 @@ def __init__(self, kv_cache_group: Optional[List[str]] = None, kv_cache_config: Optional[Dict[str, Any]] = None, pack_method: str = "reorder"): + super().__init__() if kv_cache_group is None: kv_cache_group = [] self.quant_config = quant_config diff --git a/vllm/model_executor/layers/quantization/tpu_int8.py b/vllm/model_executor/layers/quantization/tpu_int8.py index 3234fecaa3b3..14e5bcf6e5bb 100644 --- a/vllm/model_executor/layers/quantization/tpu_int8.py +++ b/vllm/model_executor/layers/quantization/tpu_int8.py @@ -21,6 +21,7 @@ def __init__( self, activation_scheme: str = "none", ) -> None: + super().__init__() if activation_scheme not in ACTIVATION_SCHEMES: raise ValueError( f"Unsupported activation scheme {activation_scheme}") From 697b82c76ab772dc755c63b1ed2457b08e37dee2 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 14 Feb 2025 13:43:52 -0500 Subject: [PATCH 5/7] remove mixin changes Signed-off-by: Kyle Sayers --- vllm/model_executor/models/clip.py | 154 +++++++++++++++++++++-- vllm/model_executor/models/interfaces.py | 59 --------- vllm/model_executor/models/phi3v.py | 10 +- 3 files changed, 148 insertions(+), 75 deletions(-) diff --git a/vllm/model_executor/models/clip.py b/vllm/model_executor/models/clip.py index 73c109a27ac7..1e784f5b4172 100644 --- a/vllm/model_executor/models/clip.py +++ b/vllm/model_executor/models/clip.py @@ -1,25 +1,156 @@ # SPDX-License-Identifier: Apache-2.0 """Minimal implementation of CLIPVisionModel intended to be only used within a vision language model.""" -from typing import Iterable, Optional, Set, Tuple, Union +from typing import Iterable, List, Optional, Set, Tuple, Union +import numpy as np import torch import torch.nn as nn +from PIL import Image from transformers import CLIPVisionConfig from vllm.attention.layer import MultiHeadAttention +from vllm.config import ModelConfig from vllm.distributed import divide, get_tensor_model_parallel_world_size +from vllm.inputs import DecoderOnlyInputs, token_inputs from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.interfaces import SupportsQuant +from vllm.multimodal.utils import (cached_get_tokenizer, + consecutive_placeholder_ranges, + repeat_and_pad_placeholder_tokens) +from vllm.sequence import SequenceData from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs +def get_clip_patch_grid_length(*, image_size: int, patch_size: int) -> int: + assert image_size % patch_size == 0 + return image_size // patch_size + + +def get_clip_num_patches(*, image_size: int, patch_size: int) -> int: + grid_length = get_clip_patch_grid_length(image_size=image_size, + patch_size=patch_size) + return grid_length * grid_length + + +def get_clip_image_feature_size(hf_config: CLIPVisionConfig) -> int: + return get_clip_num_patches(image_size=hf_config.image_size, + patch_size=hf_config.patch_size) + 1 + + +def get_max_clip_image_tokens(hf_config: CLIPVisionConfig) -> int: + return get_clip_image_feature_size(hf_config) + + +def dummy_seq_data_for_clip(hf_config: CLIPVisionConfig, + seq_len: int, + num_images: int, + *, + image_token_id: int, + image_feature_size_override: Optional[int] = None, + mm_key: str = "image"): + if image_feature_size_override is None: + image_feature_size = get_clip_image_feature_size(hf_config) + else: + image_feature_size = image_feature_size_override + + return SequenceData.from_prompt_token_counts( + (image_token_id, image_feature_size * num_images), + (0, seq_len - image_feature_size * num_images), + ), { + mm_key: + consecutive_placeholder_ranges(num_items=num_images, + item_size=image_feature_size) + } + + +def dummy_image_for_clip( + hf_config: CLIPVisionConfig, + num_images: int, + *, + image_width_override: Optional[int] = None, + image_height_override: Optional[int] = None, +): + width = height = hf_config.image_size + if image_width_override is not None: + width = image_width_override + if image_height_override is not None: + height = image_height_override + + image = Image.new("RGB", (width, height), color=0) + return {"image": image if num_images == 1 else [image] * num_images} + + +def dummy_video_for_clip( + hf_config: CLIPVisionConfig, + num_frames: int, + num_videos: int = 1, + *, + image_width_override: Optional[int] = None, + image_height_override: Optional[int] = None, +): + pil_frame = dummy_image_for_clip( + hf_config, + num_images=1, + image_width_override=image_width_override, + image_height_override=image_height_override) + np_frame = np.array(pil_frame["image"]) + mm_data_per_video = np.repeat([np_frame], num_frames, axis=0) + video_data = [mm_data_per_video] * num_videos + mm_data = {"video": video_data} + return mm_data + + +def input_processor_for_clip( + model_config: ModelConfig, + hf_config: CLIPVisionConfig, + inputs: DecoderOnlyInputs, + *, + image_token_id: int, + image_feature_size_override: Optional[Union[int, List[int]]] = None, +): + multi_modal_data = inputs.get("multi_modal_data") + if multi_modal_data is None or "image" not in multi_modal_data: + return inputs + + if "multi_modal_placeholders" in inputs and "image" in inputs[ + "multi_modal_placeholders"]: + # The inputs already have placeholders. + return inputs + + tokenizer = cached_get_tokenizer(model_config.tokenizer) + + if image_feature_size_override is None: + image_data = multi_modal_data["image"] + if isinstance(image_data, Image.Image): + image_feature_size = get_clip_image_feature_size(hf_config) + elif isinstance(image_data, torch.Tensor): + num_images, image_feature_size, hidden_size = image_data.shape + else: + raise TypeError(f"Invalid image type: {type(image_data)}") + else: + image_feature_size = image_feature_size_override + + new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens( + tokenizer, + inputs.get("prompt"), + inputs["prompt_token_ids"], + placeholder_token_id=image_token_id, + repeat_count=image_feature_size, + ) + + # NOTE: Create a defensive copy of the original inputs + return token_inputs(prompt_token_ids=new_token_ids, + prompt=new_prompt, + multi_modal_data=multi_modal_data, + multi_modal_placeholders={"image": ranges}) + + class CLIPEncoderInfo(VisionEncoderInfo[CLIPVisionConfig]): def get_num_image_tokens( @@ -28,10 +159,10 @@ def get_num_image_tokens( image_width: int, image_height: int, ) -> int: - return self.get_patch_grid_length()**2 + 1 + return get_clip_image_feature_size(self.vision_config) def get_max_image_tokens(self) -> int: - return self.get_patch_grid_length()**2 + 1 + return get_max_clip_image_tokens(self.vision_config) def get_image_size(self) -> int: return self.vision_config.image_size @@ -40,9 +171,10 @@ def get_patch_size(self) -> int: return self.vision_config.patch_size def get_patch_grid_length(self) -> int: - image_size, patch_size = self.get_image_size(), self.get_patch_size() - assert image_size % patch_size == 0 - return image_size // patch_size + return get_clip_patch_grid_length( + image_size=self.vision_config.image_size, + patch_size=self.vision_config.patch_size, + ) # Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa @@ -54,7 +186,6 @@ def __init__(self, config: CLIPVisionConfig): self.embed_dim = config.hidden_size self.image_size = config.image_size self.patch_size = config.patch_size - assert self.image_size % self.patch_size == 0 self.class_embedding = nn.Parameter(torch.randn(self.embed_dim)) @@ -66,7 +197,8 @@ def __init__(self, config: CLIPVisionConfig): bias=False, ) - self.num_patches = (self.image_size // self.patch_size)**2 + self.num_patches = get_clip_num_patches(image_size=self.image_size, + patch_size=self.patch_size) self.num_positions = self.num_patches + 1 self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) @@ -336,10 +468,10 @@ def forward( return encoder_outputs -class CLIPVisionModel(nn.Module, SupportsQuant): +class CLIPVisionModel(nn.Module): + config_class = CLIPVisionConfig main_input_name = "pixel_values" - packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]} def __init__( self, diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index bd6661d668d9..0fc5c4db179c 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -7,8 +7,6 @@ from typing_extensions import TypeIs, TypeVar from vllm.logger import init_logger -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) from vllm.utils import supports_kw from .interfaces_base import is_pooling_model @@ -443,60 +441,3 @@ def supports_cross_encoding( model: Union[Type[object], object], ) -> Union[TypeIs[Type[SupportsCrossEncoding]], TypeIs[SupportsCrossEncoding]]: return is_pooling_model(model) and _supports_cross_encoding(model) - - -class SupportsQuant: - """The interface required for all models that support quantization.""" - - packed_modules_mapping: ClassVar[Dict[str, List[str]]] = {} - quant_config: Optional[QuantizationConfig] = None - - def __new__(cls, *args, **kwargs) -> "SupportsQuant": - instance = super().__new__(cls) - quant_config = cls._find_quant_config(*args, **kwargs) - if quant_config is not None: - instance.quant_config = quant_config - instance.quant_config.packed_modules_mapping.update( - cls.packed_modules_mapping) - return instance - - @staticmethod - def _find_quant_config(*args, **kwargs) -> Optional[QuantizationConfig]: - from vllm.config import VllmConfig # avoid circular import - - args_values = list(args) + list(kwargs.values()) - for arg in args_values: - if isinstance(arg, VllmConfig): - return arg.quant_config - - if isinstance(arg, QuantizationConfig): - return arg - - return None - - -@runtime_checkable -class SupportsTranscription(Protocol): - """The interface required for all models that support transcription.""" - - supports_transcription: ClassVar[Literal[True]] = True - - -@overload -def supports_transcription( - model: Type[object]) -> TypeIs[Type[SupportsTranscription]]: - ... - - -@overload -def supports_transcription(model: object) -> TypeIs[SupportsTranscription]: - ... - - -def supports_transcription( - model: Union[Type[object], object], -) -> Union[TypeIs[Type[SupportsTranscription]], TypeIs[SupportsTranscription]]: - if isinstance(model, type): - return isinstance(model, SupportsTranscription) - - return isinstance(model, SupportsTranscription) diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 6bbfa40beed1..053390c521fc 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -50,7 +50,7 @@ from vllm.utils import is_list_of from .clip import CLIPVisionModel -from .interfaces import SupportsMultiModal, SupportsPP, SupportsQuant +from .interfaces import SupportsMultiModal, SupportsPP from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, init_vllm_registered_model, maybe_prefix, merge_multimodal_embeddings) @@ -498,8 +498,7 @@ def _apply_prompt_replacements( @MULTIMODAL_REGISTRY.register_processor(Phi3VMultiModalProcessor, info=Phi3VProcessingInfo, dummy_inputs=Phi3VDummyInputsBuilder) -class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, - SupportsQuant): +class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): hf_to_vllm_mapper = WeightsMapper( orig_to_new_prefix={ "model.vision_embed_tokens.wte": "embed_tokens", @@ -511,6 +510,7 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config multimodal_config = vllm_config.model_config.multimodal_config self.config = config self.multimodal_config = multimodal_config @@ -520,14 +520,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config.vocab_size, config.hidden_size, org_num_embeddings=config.vocab_size, - quant_config=self.quant_config, + quant_config=quant_config, prefix=maybe_prefix(prefix, "model.embed_tokens"), ) # TODO: Optionally initializes this for supporting input embeddings. self.vision_embed_tokens = Phi3HDImageEmbedding( config, - self.quant_config, + quant_config, prefix=maybe_prefix(prefix, "model.vision_embed_tokens")) self.language_model = init_vllm_registered_model( From a2bccf7bf1a7e17b391ac91ed4cc82bf2cf6b762 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 14 Feb 2025 13:46:03 -0500 Subject: [PATCH 6/7] merge issues Signed-off-by: Kyle Sayers --- vllm/model_executor/models/clip.py | 149 ++--------------------- vllm/model_executor/models/interfaces.py | 27 ++++ 2 files changed, 35 insertions(+), 141 deletions(-) diff --git a/vllm/model_executor/models/clip.py b/vllm/model_executor/models/clip.py index 1e784f5b4172..547f62447816 100644 --- a/vllm/model_executor/models/clip.py +++ b/vllm/model_executor/models/clip.py @@ -1,156 +1,24 @@ # SPDX-License-Identifier: Apache-2.0 """Minimal implementation of CLIPVisionModel intended to be only used within a vision language model.""" -from typing import Iterable, List, Optional, Set, Tuple, Union +from typing import Iterable, Optional, Set, Tuple, Union -import numpy as np import torch import torch.nn as nn -from PIL import Image from transformers import CLIPVisionConfig from vllm.attention.layer import MultiHeadAttention -from vllm.config import ModelConfig from vllm.distributed import divide, get_tensor_model_parallel_world_size -from vllm.inputs import DecoderOnlyInputs, token_inputs from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.multimodal.utils import (cached_get_tokenizer, - consecutive_placeholder_ranges, - repeat_and_pad_placeholder_tokens) -from vllm.sequence import SequenceData from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs -def get_clip_patch_grid_length(*, image_size: int, patch_size: int) -> int: - assert image_size % patch_size == 0 - return image_size // patch_size - - -def get_clip_num_patches(*, image_size: int, patch_size: int) -> int: - grid_length = get_clip_patch_grid_length(image_size=image_size, - patch_size=patch_size) - return grid_length * grid_length - - -def get_clip_image_feature_size(hf_config: CLIPVisionConfig) -> int: - return get_clip_num_patches(image_size=hf_config.image_size, - patch_size=hf_config.patch_size) + 1 - - -def get_max_clip_image_tokens(hf_config: CLIPVisionConfig) -> int: - return get_clip_image_feature_size(hf_config) - - -def dummy_seq_data_for_clip(hf_config: CLIPVisionConfig, - seq_len: int, - num_images: int, - *, - image_token_id: int, - image_feature_size_override: Optional[int] = None, - mm_key: str = "image"): - if image_feature_size_override is None: - image_feature_size = get_clip_image_feature_size(hf_config) - else: - image_feature_size = image_feature_size_override - - return SequenceData.from_prompt_token_counts( - (image_token_id, image_feature_size * num_images), - (0, seq_len - image_feature_size * num_images), - ), { - mm_key: - consecutive_placeholder_ranges(num_items=num_images, - item_size=image_feature_size) - } - - -def dummy_image_for_clip( - hf_config: CLIPVisionConfig, - num_images: int, - *, - image_width_override: Optional[int] = None, - image_height_override: Optional[int] = None, -): - width = height = hf_config.image_size - if image_width_override is not None: - width = image_width_override - if image_height_override is not None: - height = image_height_override - - image = Image.new("RGB", (width, height), color=0) - return {"image": image if num_images == 1 else [image] * num_images} - - -def dummy_video_for_clip( - hf_config: CLIPVisionConfig, - num_frames: int, - num_videos: int = 1, - *, - image_width_override: Optional[int] = None, - image_height_override: Optional[int] = None, -): - pil_frame = dummy_image_for_clip( - hf_config, - num_images=1, - image_width_override=image_width_override, - image_height_override=image_height_override) - np_frame = np.array(pil_frame["image"]) - mm_data_per_video = np.repeat([np_frame], num_frames, axis=0) - video_data = [mm_data_per_video] * num_videos - mm_data = {"video": video_data} - return mm_data - - -def input_processor_for_clip( - model_config: ModelConfig, - hf_config: CLIPVisionConfig, - inputs: DecoderOnlyInputs, - *, - image_token_id: int, - image_feature_size_override: Optional[Union[int, List[int]]] = None, -): - multi_modal_data = inputs.get("multi_modal_data") - if multi_modal_data is None or "image" not in multi_modal_data: - return inputs - - if "multi_modal_placeholders" in inputs and "image" in inputs[ - "multi_modal_placeholders"]: - # The inputs already have placeholders. - return inputs - - tokenizer = cached_get_tokenizer(model_config.tokenizer) - - if image_feature_size_override is None: - image_data = multi_modal_data["image"] - if isinstance(image_data, Image.Image): - image_feature_size = get_clip_image_feature_size(hf_config) - elif isinstance(image_data, torch.Tensor): - num_images, image_feature_size, hidden_size = image_data.shape - else: - raise TypeError(f"Invalid image type: {type(image_data)}") - else: - image_feature_size = image_feature_size_override - - new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens( - tokenizer, - inputs.get("prompt"), - inputs["prompt_token_ids"], - placeholder_token_id=image_token_id, - repeat_count=image_feature_size, - ) - - # NOTE: Create a defensive copy of the original inputs - return token_inputs(prompt_token_ids=new_token_ids, - prompt=new_prompt, - multi_modal_data=multi_modal_data, - multi_modal_placeholders={"image": ranges}) - - class CLIPEncoderInfo(VisionEncoderInfo[CLIPVisionConfig]): def get_num_image_tokens( @@ -159,10 +27,10 @@ def get_num_image_tokens( image_width: int, image_height: int, ) -> int: - return get_clip_image_feature_size(self.vision_config) + return self.get_patch_grid_length()**2 + 1 def get_max_image_tokens(self) -> int: - return get_max_clip_image_tokens(self.vision_config) + return self.get_patch_grid_length()**2 + 1 def get_image_size(self) -> int: return self.vision_config.image_size @@ -171,10 +39,9 @@ def get_patch_size(self) -> int: return self.vision_config.patch_size def get_patch_grid_length(self) -> int: - return get_clip_patch_grid_length( - image_size=self.vision_config.image_size, - patch_size=self.vision_config.patch_size, - ) + image_size, patch_size = self.get_image_size(), self.get_patch_size() + assert image_size % patch_size == 0 + return image_size // patch_size # Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa @@ -186,6 +53,7 @@ def __init__(self, config: CLIPVisionConfig): self.embed_dim = config.hidden_size self.image_size = config.image_size self.patch_size = config.patch_size + assert self.image_size % self.patch_size == 0 self.class_embedding = nn.Parameter(torch.randn(self.embed_dim)) @@ -197,8 +65,7 @@ def __init__(self, config: CLIPVisionConfig): bias=False, ) - self.num_patches = get_clip_num_patches(image_size=self.image_size, - patch_size=self.patch_size) + self.num_patches = (self.image_size // self.patch_size)**2 self.num_positions = self.num_patches + 1 self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index 0fc5c4db179c..a0a1b69ad502 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -441,3 +441,30 @@ def supports_cross_encoding( model: Union[Type[object], object], ) -> Union[TypeIs[Type[SupportsCrossEncoding]], TypeIs[SupportsCrossEncoding]]: return is_pooling_model(model) and _supports_cross_encoding(model) + + +@runtime_checkable +class SupportsTranscription(Protocol): + """The interface required for all models that support transcription.""" + + supports_transcription: ClassVar[Literal[True]] = True + + +@overload +def supports_transcription( + model: Type[object]) -> TypeIs[Type[SupportsTranscription]]: + ... + + +@overload +def supports_transcription(model: object) -> TypeIs[SupportsTranscription]: + ... + + +def supports_transcription( + model: Union[Type[object], object], +) -> Union[TypeIs[Type[SupportsTranscription]], TypeIs[SupportsTranscription]]: + if isinstance(model, type): + return isinstance(model, SupportsTranscription) + + return isinstance(model, SupportsTranscription) From e8a5bd4eb721aa180c02358ccd879aeb40864be7 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 14 Feb 2025 13:46:36 -0500 Subject: [PATCH 7/7] add type hint Signed-off-by: Kyle Sayers --- vllm/model_executor/layers/quantization/base_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/base_config.py b/vllm/model_executor/layers/quantization/base_config.py index 25710d6dc720..980be2196918 100644 --- a/vllm/model_executor/layers/quantization/base_config.py +++ b/vllm/model_executor/layers/quantization/base_config.py @@ -62,7 +62,7 @@ class QuantizationConfig(ABC): def __init__(self): super().__init__() - self.packed_modules_mapping = dict() + self.packed_modules_mapping: Dict[str, List[str]] = dict() @abstractmethod def get_name(self) -> str: