Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions vllm/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ def _get_quantization_config(
f"{model_config.dtype} is not supported for quantization "
f"method {model_config.quantization}. Supported dtypes: "
f"{supported_dtypes}")
quant_config.maybe_update_config(model_config.model)
return quant_config
return None

Expand Down
6 changes: 6 additions & 0 deletions vllm/model_executor/layers/quantization/base_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,3 +162,9 @@ def apply_vllm_mapper( # noqa: B027
"""
# TODO (@kylesayrs): add implementations for all subclasses
pass

def maybe_update_config(self, model_name: str): # noqa: B027
"""
Interface to update values after config initialization.
"""
pass
52 changes: 46 additions & 6 deletions vllm/model_executor/layers/quantization/gptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Any, Optional, Union

import torch
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
from torch.nn.parameter import Parameter

from vllm import _custom_ops as ops
Expand All @@ -22,6 +23,8 @@
PackedColumnParameter,
PackedvLLMParameter,
RowvLLMParameter)
from vllm.transformers_utils.config import get_safetensors_params_metadata
from vllm.utils import is_list_of


class GPTQConfig(QuantizationConfig):
Expand All @@ -38,6 +41,7 @@ def __init__(
lm_head_quantized: bool,
dynamic: dict[str, dict[str, Union[int, bool]]],
autoround_version: str = "",
modules_in_block_to_quantize: Optional[list[str]] = None,
) -> None:
# GPTQModel use `dynamic` config property to allow per module
# quantization config so each module can be individually optimized.
Expand Down Expand Up @@ -75,15 +79,20 @@ def __init__(
"Currently, only 2/3/4/8-bit weight quantization is "
f"supported for GPTQ, but got {self.weight_bits} bits.")

self.modules_in_block_to_quantize = modules_in_block_to_quantize or []

# used to identify GPTQ model quantized by autoround
self.autoround_version = autoround_version

def __repr__(self) -> str:
return (f"GPTQConfig(weight_bits={self.weight_bits}, "
f"group_size={self.group_size}, "
f"desc_act={self.desc_act}), "
f"lm_head_quantized={self.lm_head_quantized}), "
f"dynamic={self.dynamic}")
return (
f"GPTQConfig(weight_bits={self.weight_bits}, "
f"group_size={self.group_size}, "
f"desc_act={self.desc_act}), "
f"lm_head_quantized={self.lm_head_quantized}, "
f"dynamic={self.dynamic}, "
f"modules_in_block_to_quantize={self.modules_in_block_to_quantize})"
)

@classmethod
def get_name(cls) -> QuantizationMethods:
Expand Down Expand Up @@ -114,8 +123,10 @@ def from_config(cls, config: dict[str, Any]) -> "GPTQConfig":
default=False)
autoround_version = cls.get_from_keys_or(config, ["autoround_version"],
default="")
modules_in_block_to_quantize = cls.get_from_keys_or(
config, ["modules_in_block_to_quantize"], default=None)
return cls(weight_bits, group_size, desc_act, lm_head_quantized,
dynamic, autoround_version)
dynamic, autoround_version, modules_in_block_to_quantize)

def get_quant_method(
self, layer: torch.nn.Module, prefix: str
Expand All @@ -136,6 +147,35 @@ def get_quant_method(

return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod)

def apply_vllm_mapper(self, hf_to_vllm_mapper):
if self.modules_in_block_to_quantize is not None:
self.modules_in_block_to_quantize = hf_to_vllm_mapper.apply_list(
self.modules_in_block_to_quantize)

def maybe_update_config(self,
model_name: str,
revision: Optional[str] = None):
if self.modules_in_block_to_quantize:
if is_list_of(self.modules_in_block_to_quantize, list):
# original modules_in_block_to_quantize: list[list[str]]
# flatten original modules_in_block_to_quantize
self.modules_in_block_to_quantize = [
item for sublist in self.modules_in_block_to_quantize
for item in sublist
]
return

unquant_dtypes = [torch.float16, torch.bfloat16, torch.float32]
metadata = get_safetensors_params_metadata(model_name,
revision=revision)
quant_layers: set[str] = {
param_name.rsplit(".", 1)[0]
for param_name, info in metadata.items()
if (dtype := info.get('dtype', None))
and _SAFETENSORS_TO_TORCH_DTYPE[dtype] not in unquant_dtypes
}
self.modules_in_block_to_quantize = list(quant_layers)


class ExllamaState(Enum):

Expand Down
65 changes: 55 additions & 10 deletions vllm/model_executor/layers/quantization/gptq_marlin.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Any, Callable, Optional, Union

import torch
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE

import vllm.model_executor.layers.fused_moe # noqa
from vllm import _custom_ops as ops
Expand Down Expand Up @@ -35,6 +36,8 @@
RowvLLMParameter)
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
from vllm.transformers_utils.config import get_safetensors_params_metadata
from vllm.utils import is_list_of

logger = init_logger(__name__)

Expand Down Expand Up @@ -71,10 +74,16 @@ class GPTQMarlinConfig(QuantizationConfig):
(8, True): scalar_types.uint8b128,
}

def __init__(self, weight_bits: int, group_size: int, desc_act: bool,
is_sym: bool, lm_head_quantized: bool,
dynamic: dict[str, dict[str, Union[int, bool]]],
full_config: dict[str, Any]) -> None:
def __init__(
self,
weight_bits: int,
group_size: int,
desc_act: bool,
is_sym: bool,
lm_head_quantized: bool,
dynamic: dict[str, dict[str, Union[int, bool]]],
full_config: dict[str, Any],
modules_in_block_to_quantize: Optional[list[str]] = None) -> None:
super().__init__()
if desc_act and group_size == -1:
# In this case, act_order == True is the same as act_order == False
Expand Down Expand Up @@ -121,15 +130,19 @@ def __init__(self, weight_bits: int, group_size: int, desc_act: bool,

self.quant_type = self.TYPE_MAP[(weight_bits, is_sym)]

self.modules_in_block_to_quantize = modules_in_block_to_quantize or []
# used to identify GPTQ model quantized by autoround
self.autoround_version = full_config.get("autoround_version", "")

def __repr__(self) -> str:
return (f"GPTQMarlinConfig(quant_type={self.quant_type}, "
f"group_size={self.group_size}, "
f"desc_act={self.desc_act}, "
f"lm_head_quantized={self.lm_head_quantized}), "
f"dynamic={self.dynamic}")
return (
f"GPTQMarlinConfig(quant_type={self.quant_type}, "
f"group_size={self.group_size}, "
f"desc_act={self.desc_act}, "
f"lm_head_quantized={self.lm_head_quantized}, "
f"dynamic={self.dynamic}, "
f"modules_in_block_to_quantize={self.modules_in_block_to_quantize})"
)

@classmethod
def get_name(cls) -> QuantizationMethods:
Expand Down Expand Up @@ -158,8 +171,11 @@ def from_config(cls, config: dict[str, Any]) -> "GPTQMarlinConfig":
is_sym = cls.get_from_keys(config, ["sym"])
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
default=False)
modules_in_block_to_quantize = cls.get_from_keys_or(
config, ["modules_in_block_to_quantize"], default=None)
return cls(weight_bits, group_size, desc_act, is_sym,
lm_head_quantized, dynamic, config)
lm_head_quantized, dynamic, config,
modules_in_block_to_quantize)

@classmethod
def override_quantization_method(
Expand Down Expand Up @@ -223,6 +239,35 @@ def is_gptq_marlin_compatible(cls, quant_config: dict[str, Any]):
return check_marlin_supported(quant_type=cls.TYPE_MAP[(num_bits, sym)],
group_size=group_size)

def apply_vllm_mapper(self, hf_to_vllm_mapper):
if self.modules_in_block_to_quantize is not None:
self.modules_in_block_to_quantize = hf_to_vllm_mapper.apply_list(
self.modules_in_block_to_quantize)

def maybe_update_config(self,
model_name: str,
revision: Optional[str] = None):
if self.modules_in_block_to_quantize:
if is_list_of(self.modules_in_block_to_quantize, list):
# original modules_in_block_to_quantize: list[list[str]]
# flatten original modules_in_block_to_quantize
self.modules_in_block_to_quantize = [
item for sublist in self.modules_in_block_to_quantize
for item in sublist
]
return

unquant_dtypes = [torch.float16, torch.bfloat16, torch.float32]
metadata = get_safetensors_params_metadata(model_name,
revision=revision)
quant_layers: set[str] = {
param_name.rsplit(".", 1)[0]
for param_name, info in metadata.items()
if (dtype := info.get('dtype', None))
and _SAFETENSORS_TO_TORCH_DTYPE[dtype] not in unquant_dtypes
}
self.modules_in_block_to_quantize = list(quant_layers)
Comment on lines +242 to +269
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The methods apply_vllm_mapper and maybe_update_config are identical to the ones in vllm/model_executor/layers/quantization/gptq.py. This code duplication can lead to maintenance issues, where a change in one file might not be propagated to the other, potentially causing bugs.

To improve maintainability, I recommend refactoring this shared logic into a common base class or a mixin. For example, you could create a GPTQConfigMixin class to house these methods, and then have both GPTQConfig and GPTQMarlinConfig inherit from it. This would ensure that the logic for handling modules_in_block_to_quantize is defined in a single place.



class GPTQMarlinLinearMethod(LinearMethodBase):
"""Linear method for GPTQ Marlin.
Expand Down
52 changes: 51 additions & 1 deletion vllm/model_executor/layers/quantization/utils/gptq_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Mapping
from copy import deepcopy
from fractions import Fraction
from types import MappingProxyType
from typing import Optional, Union

import regex as re
Expand Down Expand Up @@ -70,6 +72,49 @@ def get_dynamic_override(
return default_value


def is_layer_gptq_quantized(
prefix: str,
quantized_layers: list[str],
fused_mapping: Mapping[str, list[str]] = MappingProxyType({})
) -> bool:
# prefix: model.layers.0.self_attn.q_proj
# proj_name: q_proj

# GPTQ's `modules_in_block_to_quantize`:
# Substr: ["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj"]
# Full prefix ["model.layers.0.self_attn.q_proj"]

proj_name = prefix.split(".")[-1]

# Fused layers like gate_up_proj or qkv_proj will not be fused
# in the safetensors checkpoint. So, we convert the name
# from the fused version to unfused + check to make sure that
# each shard of the fused layer has the same scheme.
if proj_name in fused_mapping:
shard_prefixes = [
prefix.replace(proj_name, shard_proj_name)
for shard_proj_name in fused_mapping[proj_name]
]

is_quantized = None
for shard_prefix in shard_prefixes:
is_shard_quantized = any(layer in shard_prefix
for layer in quantized_layers)

if is_quantized is None:
is_quantized = is_shard_quantized
elif is_shard_quantized != is_quantized:
raise ValueError(
f"Detected some but not all shards of {prefix} "
"are quantized. All shards of fused layers "
"to have the same precision.")
else:
is_quantized = any(layer in prefix for layer in quantized_layers)

assert is_quantized is not None
return is_quantized


def get_linear_quant_method(
config: QuantizationConfig,
layer: torch.nn.Module,
Expand All @@ -80,10 +125,15 @@ def get_linear_quant_method(
parallel_lm_head_quantized = isinstance(
layer, ParallelLMHead) and cloned_config.lm_head_quantized
if isinstance(layer, LinearBase) or parallel_lm_head_quantized:
is_layer_quantized = is_layer_gptq_quantized(
prefix=prefix,
quantized_layers=cloned_config.modules_in_block_to_quantize,
fused_mapping=cloned_config.packed_modules_mapping)
# False = skip module, None = no override, else = Positive match
if get_dynamic_override( # noqa: E712
cloned_config, # noqa: E712
layer_name=prefix) == False: # noqa: E712
layer_name=prefix) == False or (
not is_layer_quantized): # noqa: E712
if parallel_lm_head_quantized:
return UnquantizedEmbeddingMethod()
return UnquantizedLinearMethod()
Expand Down
12 changes: 2 additions & 10 deletions vllm/model_executor/models/keye.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,6 @@
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
from vllm.model_executor.layers.quantization.gptq_marlin import (
GPTQMarlinConfig)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.models.module_mapping import MultiModelKeys
Expand Down Expand Up @@ -1281,11 +1278,6 @@ def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:

raise ValueError("Only image or video modality is supported")

def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)):
return None
return quant_config

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config: PretrainedConfig = vllm_config.model_config.hf_config
Expand All @@ -1297,14 +1289,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

self.visual = KeyeSiglipVisionModel(
config.vision_config,
quant_config=self._maybe_ignore_quant_config(quant_config),
quant_config=quant_config,
prefix=maybe_prefix(prefix, "visual"),
)

self.mlp_AR = self._build_projector(
config,
config.vision_config,
quant_config=self._maybe_ignore_quant_config(quant_config),
quant_config=quant_config,
prefix=maybe_prefix(prefix, "mlp_AR"),
)

Expand Down
Loading