From 1c9bf4527387a465652a48df6b56c3ca7e7be4a0 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 14 Oct 2025 00:21:03 -0400 Subject: [PATCH 1/7] attention quant Signed-off-by: Kyle Sayers --- src/compressed_tensors/modeling/__init__.py | 18 ++ src/compressed_tensors/modeling/attention.py | 146 ++++++++++++++++ src/compressed_tensors/modeling/kvcache.py | 163 ++++++++++++++++++ .../quantization/lifecycle/apply.py | 150 +++++----------- .../quantization/lifecycle/initialize.py | 109 ++++++++---- .../quantization/quant_config.py | 92 +++++----- .../quantization/utils/helpers.py | 35 +--- src/compressed_tensors/utils/helpers.py | 64 ++++++- src/compressed_tensors/utils/match.py | 29 ++++ .../test_modeling/test_attention_and_cache.py | 105 +++++++++++ .../test_quantization/lifecycle/test_apply.py | 40 +++++ tests/test_utils/test_match.py | 81 +++++++++ 12 files changed, 821 insertions(+), 211 deletions(-) create mode 100644 src/compressed_tensors/modeling/__init__.py create mode 100644 src/compressed_tensors/modeling/attention.py create mode 100644 src/compressed_tensors/modeling/kvcache.py create mode 100644 tests/test_modeling/test_attention_and_cache.py diff --git a/src/compressed_tensors/modeling/__init__.py b/src/compressed_tensors/modeling/__init__.py new file mode 100644 index 000000000..97ee4a2ca --- /dev/null +++ b/src/compressed_tensors/modeling/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# flake8: noqa +# isort: off +from .kvcache import * +from .attention import * diff --git a/src/compressed_tensors/modeling/attention.py b/src/compressed_tensors/modeling/attention.py new file mode 100644 index 000000000..0f41dd139 --- /dev/null +++ b/src/compressed_tensors/modeling/attention.py @@ -0,0 +1,146 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import inspect +from typing import Callable, Optional +from weakref import ref + +from compressed_tensors.modeling.kvcache import initialize_hooked_kv_cache +from compressed_tensors.quantization.lifecycle.forward import forward_quantize +from compressed_tensors.utils import getattr_chain +from compressed_tensors.utils.internal import InternalModule +from torch import Tensor +from torch.nn import Module +from torch.utils.hooks import RemovableHandle +from transformers import AttentionInterface, PretrainedConfig, PreTrainedModel +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS + + +__all__ = [ + "QuantizedAttentionImpl", + "initialize_hooked_attention", + "register_query_hook", + "IMPL_ATTR", +] + + +IMPL_ATTR = "impl" +HOOKED_ATTENTION_NAME = "ct_hooked_attention" + + +class QuantizedAttentionImpl(InternalModule): + """ + QuantizedAttentionImpl module which wraps the functionality of the original + attention implementation. Unlike the original attention function, this + implementation is a `torch.nn.Module` which can be hooked to trigger + transforms and calibration hooks. + + This module works by being registered as a submodule to attention modules via + `initialize_hooked_attention`, registering a new attention implementation function + which calls this module, then setting the model attention implementation to the new + function. After triggering hooks and quantization, this module calls the original + attention implementation function. + + :param attn_module: parent attention module + """ + + _original_impl = "eager" + + def __init__(self, config: PretrainedConfig, attn_module: Module): + super().__init__() + self.config = config + self.attn_module = ref(attn_module) # avoid circular references + + def forward( + self, + module: Module, + query: Tensor, + key: Tensor, + value: Tensor, + *args, + **kwargs, + ): + # quantization + quant_args_attr = "quantization_scheme.input_activations" + quant_args = getattr_chain(module, quant_args_attr, None) + quant_enabled = getattr(module, "quantization_enabled", True) + if quant_args is not None and quant_enabled: + query = forward_quantize(module, query, "q", quant_args) + + # original attention + return ALL_ATTENTION_FUNCTIONS[_original_impl]( + module, + query, + key, + value, + *args, + **kwargs, + ) + + +# ----- initialize ----- # + + +def _ct_hooked_attention(module: Module, *args, **kwargs): + if hasattr(module, IMPL_ATTR): + return module.impl(module, *args, **kwargs) + else: + return ALL_ATTENTION_FUNCTIONS[_original_impl](module, *args, **kwargs) + + +def initialize_hooked_attention(model: PreTrainedModel, module: Module): + """ + Initialize `QuantizedAttentionImpl` and `QuantizedKVCache` instances + attached to attention + + :param model: parent model of attention module + :param module: attention module to initialize with + """ + if not hasattr(module, IMPL_ATTR): + module.register_module(IMPL_ATTR, QuantizedAttentionImpl(model.config, module)) + if model.config._attn_implementation != HOOKED_ATTENTION_NAME: + # assumes only one model at a time + global _original_impl + _original_impl = model.config._attn_implementation + + AttentionInterface.register(HOOKED_ATTENTION_NAME, _ct_hooked_attention) + model.config._attn_implementation = HOOKED_ATTENTION_NAME + + initialize_hooked_kv_cache(model, module) + + +# ----- hooks ----- # + + +def register_query_hook( + module: Module, hook: Callable[[Module, Tensor], Optional[Tensor]] +) -> RemovableHandle: + """ + Register a hook which takes post-rope query states as an 
argument and + returns the modified query states or `None` + + :param module: attention module to add hook to + :param hook: query hook function + """ + impl = getattr(module, IMPL_ATTR) + + def _hook(impl: QuantizedAttentionImpl, args, kwargs): + bound = inspect.signature(impl.forward).bind(*args, **kwargs) + value = hook(module, bound.arguments["query"]) + if value is not None: + bound.arguments["query"] = value + + return bound.args, bound.kwargs + + return impl.register_forward_pre_hook(_hook, with_kwargs=True) diff --git a/src/compressed_tensors/modeling/kvcache.py b/src/compressed_tensors/modeling/kvcache.py new file mode 100644 index 000000000..4727c9fab --- /dev/null +++ b/src/compressed_tensors/modeling/kvcache.py @@ -0,0 +1,163 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable, Optional, Tuple +from weakref import ref + +from compressed_tensors.quantization.lifecycle.forward import forward_quantize +from compressed_tensors.utils import getattr_chain +from compressed_tensors.utils.internal import InternalModule +from torch import Tensor +from torch.nn import Module +from torch.utils.hooks import RemovableHandle +from transformers import Cache, PretrainedConfig, PreTrainedModel + + +__all__ = [ + "QuantizedKVCache", + "initialize_hooked_kv_cache", + "register_key_hook", + "register_value_hook", + "KV_CACHE_ATTR", +] + + +KV_CACHE_ATTR = "kv_cache" + + +class QuantizedKVCache(InternalModule): + """ + QuantizedKVCache module which wraps the functionality of any existing kvcache args. + Unlike transform Cache instances, this cache is a `torch.nn.Module` which can be + hooked to trigger transforms and calibration hooks. + + This module works by being registered as a submodule to attention modules via + `initialize_hooked_kv_cache`, then adding a hook which replaces `past_key_values` + kwargs with this module. This module adopts the functionality of the replaced cache, + preserving caching functionality such as sliding window attention, ect. 
+ + :param attn_module: parent attention module + """ + + def __init__(self, config: PretrainedConfig, attn_module: Module): + super().__init__() + self.config = config + self.attn_module = ref(attn_module) # avoid circular reference + self.past_key_values: Optional[Cache] = None + + def update(self, *args, **kwargs) -> Tuple[Tensor, Tensor]: + return self(*args, **kwargs) + + def forward( + self, + key_states: Tensor, + value_states: Tensor, + *args, + **kwargs, + ) -> Tuple[Tensor, Tensor]: + # quantization + module = self.attn_module() + quant_args_attr = "quantization_scheme.input_activations" + quant_args = getattr_chain(module, quant_args_attr, None) + quant_enabled = getattr(module, "quantization_enabled", True) + if quant_args is not None and quant_enabled: + key_states = forward_quantize(module, key_states, "k", quant_args) + value_states = forward_quantize(module, value_states, "v", quant_args) + + # original cache + if self.past_key_values is not None: + ret = self.past_key_values.update(key_states, value_states, *args, **kwargs) + else: + ret = (key_states, value_states) + + self.past_key_values = None + return ret + + +# ----- initialize ----- # + + +def _kv_cache_attention_hook(module: Module, args, kwargs): + kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR) + _past_kv_name = ( + "past_key_values" # transformers#39956 + if "past_key_values" in inspect.signature(module.forward).parameters + else "past_key_value" + ) + kv_cache.past_key_values = kwargs.get(_past_kv_name, None) + kwargs[_past_kv_name] = kv_cache + + return args, kwargs + + +def initialize_hooked_kv_cache(model: PreTrainedModel, module: Module): + """ + Initialize a `QuantizedKVCache` instance attached to attention + + :param model: parent model of attention module + :param module: attention module to initialize with + """ + if not hasattr(module, KV_CACHE_ATTR): + module.register_module(KV_CACHE_ATTR, QuantizedKVCache(model.config, module)) + module.register_forward_pre_hook(_kv_cache_attention_hook, with_kwargs=True) + + +# ----- hooks ----- # + + +def register_key_hook( + module: Module, hook: Callable[[Module, Tensor], Optional[Tensor]] +) -> RemovableHandle: + """ + Register a hook which takes post-rope key states as an argument and + returns the modified key states or `None` + + :param module: attention module to add hook to + :param hook: key hook function + """ + kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR) + + def _hook(cache: QuantizedKVCache, args, kwargs): + bound = inspect.signature(cache.forward).bind(*args, **kwargs) + value = hook(module, bound.arguments["key_states"]) + if value is not None: + bound.arguments["key_states"] = value + + return bound.args, bound.kwargs + + return kv_cache.register_forward_pre_hook(_hook, with_kwargs=True) + + +def register_value_hook( + module: Module, hook: Callable[[Module, Tensor], Optional[Tensor]] +) -> RemovableHandle: + """ + Register a hook which takes value states as an argument and + returns the modified value states or `None` + + :param module: attention module to add hook to + :param hook: value hook function + """ + kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR) + + def _hook(cache: QuantizedKVCache, args, kwargs): + bound = inspect.signature(cache.forward).bind(*args, **kwargs) + value = hook(module, bound.arguments["value_states"]) + if value is not None: + bound.arguments["value_states"] = value + + return bound.args, bound.kwargs + + return kv_cache.register_forward_pre_hook(_hook, with_kwargs=True) diff --git 
a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py index 21525bd4f..1bea5a4a9 100644 --- a/src/compressed_tensors/quantization/lifecycle/apply.py +++ b/src/compressed_tensors/quantization/lifecycle/apply.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import logging from collections import OrderedDict from copy import deepcopy from typing import Dict, List, Optional @@ -21,8 +20,13 @@ import torch from compressed_tensors.config import CompressionFormat +from compressed_tensors.modeling import ( + initialize_hooked_attention, + initialize_hooked_kv_cache, +) from compressed_tensors.quantization.lifecycle.initialize import ( initialize_module_for_quantization, + is_attention_module, ) from compressed_tensors.quantization.quant_args import QuantizationArgs from compressed_tensors.quantization.quant_config import ( @@ -30,12 +34,12 @@ QuantizationStatus, ) from compressed_tensors.quantization.quant_scheme import QuantizationScheme -from compressed_tensors.quantization.utils import ( - KV_CACHE_TARGETS, - is_kv_cache_quant_scheme, -) from compressed_tensors.utils.helpers import replace_module -from compressed_tensors.utils.match import match_named_modules, match_targets +from compressed_tensors.utils.match import ( + is_narrow_match, + match_named_modules, + match_targets, +) from compressed_tensors.utils.offload import update_parameter_data from compressed_tensors.utils.safetensors_load import get_safetensors_folder from safetensors import safe_open @@ -53,9 +57,6 @@ ) -_LOGGER = logging.getLogger(__name__) - - def load_pretrained_quantization_parameters( model: Module, model_name_or_path: Optional[str] = None, @@ -125,8 +126,14 @@ def apply_quantization_config( if config is None: # see PR #180 return dict() - # preprocess to support kv cache scheme - config = process_quantization_config(config) + # force zero points during initialization + force_zero_point = config.quantization_status != QuantizationStatus.COMPRESSED + + # apply and initialize kv cache quantization + if config.kv_cache_scheme is not None: + _apply_kv_cache_scheme( + model, config.kv_cache_scheme, config.quantization_status, force_zero_point + ) # build mapping of targets to schemes for easier matching # use ordered dict to preserve target ordering in config @@ -162,49 +169,40 @@ def apply_quantization_config( replace_module(model, name, compressed_linear) else: + if is_attention_module(submodule) and is_narrow_match( + model, scheme.targets, name + ): + initialize_hooked_attention(model, submodule) + initialize_module_for_quantization( submodule, - force_zero_point=config.quantization_status - != QuantizationStatus.COMPRESSED, + force_zero_point=force_zero_point, ) submodule.quantization_status = config.quantization_status -def process_quantization_config(config: QuantizationConfig) -> QuantizationConfig: - """ - Preprocess the raw QuantizationConfig - - :param config: the raw QuantizationConfig - :return: the processed QuantizationConfig - """ - if config.kv_cache_scheme is not None: - config = process_kv_cache_config(config) - - return config - - -def process_kv_cache_config( - config: QuantizationConfig, targets: Union[List[str], str] = KV_CACHE_TARGETS -) -> QuantizationConfig: - """ - Reformulate the `config.kv_cache` as a `config_group` - and add it to the set of existing `config.groups` - - :param config: the QuantizationConfig - :return: the QuantizationConfig with additional 
"kv_cache" group - """ - if targets == KV_CACHE_TARGETS: - _LOGGER.info(f"KV cache targets set to default value of: {KV_CACHE_TARGETS}") - - kv_cache_dict = config.kv_cache_scheme.model_dump() - kv_cache_scheme = QuantizationScheme( - output_activations=QuantizationArgs(**kv_cache_dict), - targets=targets, +def _apply_kv_cache_scheme( + model: torch.nn.Module, + kv_cache_scheme: QuantizationArgs, + status: QuantizationStatus, + force_zero_point: bool, +): + # applies and initializes kv cache quantization + # this step cannot come after attention apply/initialize + # otherwise it will override the attention qparams + scheme = QuantizationScheme( + targets=[".*self_attn$"], input_activations=kv_cache_scheme ) - kv_cache_group = dict(kv_cache=kv_cache_scheme) - config.config_groups.update(kv_cache_group) - return config + for submodule in model.modules(): + if is_attention_module(submodule): + submodule.quantization_scheme = scheme + initialize_hooked_kv_cache(model, submodule) + initialize_module_for_quantization( + submodule, + force_zero_point=force_zero_point, + ) + submodule.quantization_status = status def _load_quant_args_from_mapping( @@ -256,60 +254,6 @@ def _scheme_from_targets( targets: List[str], name: str, ) -> QuantizationScheme: - if len(targets) == 1: - # if `targets` iterable contains a single element - # use it as the key - return target_to_scheme[targets[0]] - - # otherwise, we need to merge QuantizationSchemes corresponding - # to multiple targets. This is most likely because `name` module - # is being target both as an ordinary quantization target, as well - # as kv cache quantization target - schemes_to_merge = [target_to_scheme[target] for target in targets] - return _merge_schemes(schemes_to_merge, name) - - -def _merge_schemes( - schemes_to_merge: List[QuantizationScheme], name: str -) -> QuantizationScheme: - kv_cache_quantization_scheme = [ - scheme for scheme in schemes_to_merge if is_kv_cache_quant_scheme(scheme) - ] - if not kv_cache_quantization_scheme: - # if the schemes_to_merge do not contain any - # kv cache QuantizationScheme - # return the first scheme (the prioritized one, - # since the order of schemes_to_merge matters) - return schemes_to_merge[0] - else: - # fetch the kv cache QuantizationScheme and the highest - # priority non-kv cache QuantizationScheme and merge them - kv_cache_quantization_scheme = kv_cache_quantization_scheme[0] - quantization_scheme = [ - scheme - for scheme in schemes_to_merge - if not is_kv_cache_quant_scheme(scheme) - ][0] - schemes_to_merge = [kv_cache_quantization_scheme, quantization_scheme] - merged_scheme = {} - for scheme in schemes_to_merge: - scheme_dict = { - k: v for k, v in scheme.model_dump().items() if v is not None - } - # when merging multiple schemes, the final target will be - # the `name` argument - hence erase the original targets - del scheme_dict["targets"] - # make sure that schemes do not "clash" with each other - overlapping_keys = set(merged_scheme.keys()) & set(scheme_dict.keys()) - if overlapping_keys: - raise ValueError( - f"The module: {name} is being modified by two clashing " - f"quantization schemes, that jointly try to override " - f"properties: {overlapping_keys}. Fix the quantization config " - "so that it is not ambiguous." 
- ) - merged_scheme.update(scheme_dict) - - merged_scheme.update(targets=[name]) - - return QuantizationScheme(**merged_scheme) + # return the first scheme (the prioritized one, + # since the order of target_to_scheme matters) + return target_to_scheme[targets[0]] diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index 50757adc3..ef6d7957f 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -17,11 +17,16 @@ from typing import Optional, Tuple, Union import torch +from compressed_tensors.modeling import ( + IMPL_ATTR, + KV_CACHE_ATTR, + QuantizedAttentionImpl, + QuantizedKVCache, +) from compressed_tensors.quantization import ( FP8_E4M3_DATA, ActivationOrdering, DynamicType, - KVCacheScaleType, QuantizationArgs, QuantizationMetadata, QuantizationScheme, @@ -31,14 +36,13 @@ from compressed_tensors.quantization.lifecycle.forward import ( wrap_module_forward_quantized, ) -from compressed_tensors.quantization.utils import ( - is_fp4, - is_kv_cache_quant_scheme, - strategy_cdiv, -) +from compressed_tensors.quantization.utils import is_fp4, strategy_cdiv from compressed_tensors.utils import ( disable_hf_hook, get_execution_device, + get_head_dim, + get_num_attn_heads, + get_num_kv_heads, register_offload_parameter, ) from torch.nn import Module, Parameter @@ -48,6 +52,7 @@ "initialize_module_for_quantization", "is_attention_module", "initialize_qparams", + "initialize_attn_qparams", ] @@ -81,7 +86,7 @@ def initialize_module_for_quantization( if is_attention_module(module): # quantized actions based on calltime status - _initialize_attn_scales(module) + initialize_attn_qparams(module, scheme, force_zero_point) else: if not isinstance(module, torch.nn.Linear): @@ -120,8 +125,7 @@ def initialize_module_for_quantization( force_zero_point=force_zero_point, ) - output_is_kv_cache = is_kv_cache_quant_scheme(scheme) - if scheme.output_activations is not None and not output_is_kv_cache: + if scheme.output_activations is not None: initialize_qparams( module, "output", @@ -131,14 +135,14 @@ def initialize_module_for_quantization( force_zero_point=force_zero_point, ) - module.quantization_scheme = scheme - module.quantization_status = QuantizationStatus.INITIALIZED - with disable_hf_hook(module): # wrap forward call of module to perform # quantized actions based on calltime status wrap_module_forward_quantized(module, scheme) + module.quantization_scheme = scheme + module.quantization_status = QuantizationStatus.INITIALIZED + def is_attention_module(module: Module): return "attention" in module.__class__.__name__.lower() and ( @@ -199,7 +203,7 @@ def initialize_qparams( expected_shape = (1,) elif strategy == QuantizationStrategy.TOKEN: - raise ValueError("Cannot perform static token quantization") + expected_shape = (1, 1) elif strategy == QuantizationStrategy.CHANNEL: if len(observed_shape) < 2: @@ -276,23 +280,70 @@ def initialize_qparams( register_offload_parameter(module, f"{base_name}_zero_point", init_zero_point) -def _initialize_attn_scales(module: Module) -> None: - """Initlaize k_scale, v_scale for self_attn""" +def initialize_attn_qparams( + module: Module, scheme: QuantizationScheme, force_zero_point: bool +): + """Initlaize k_scale, v_scale for self_attn""" + + impl: Optional[QuantizedAttentionImpl] = getattr(module, IMPL_ATTR, None) + kv_cache: Optional[QuantizedKVCache] = getattr(module, KV_CACHE_ATTR, None) + + if impl is 
None and kv_cache is None: + raise ValueError("Attention module has quantization scheme but no attached") + + _validate_attention_scheme(scheme) + + # extract shapes from config + config = kv_cache.config + num_attn_heads = get_num_attn_heads(config) + num_kv_heads = get_num_kv_heads(config) + head_dim = get_head_dim(config) + + # (batch_size, num_heads, slen, head_dim) + q_observed_shape = (num_attn_heads, None, head_dim) + kv_observed_shape = (num_kv_heads, None, head_dim) + observed_dtype = next(module.parameters()).dtype + + if impl is not None: + initialize_qparams( + module, + "q", + scheme.input_activations, + observed_shape=q_observed_shape, + observed_dtype=observed_dtype, + force_zero_point=force_zero_point, + ) - expected_shape = 1 # per tensor + if kv_cache is not None: + initialize_qparams( + module, + "k", + scheme.input_activations, + observed_shape=kv_observed_shape, + observed_dtype=observed_dtype, + force_zero_point=force_zero_point, + ) + initialize_qparams( + module, + "v", + scheme.input_activations, + observed_shape=kv_observed_shape, + observed_dtype=observed_dtype, + force_zero_point=force_zero_point, + ) - param = next(module.parameters()) - scale_dtype = param.dtype - device = param.device - init_scale = Parameter( - torch.empty(expected_shape, dtype=scale_dtype, device=device), - requires_grad=False, - ) - register_offload_parameter(module, KVCacheScaleType.KEY.value, init_scale) +def _validate_attention_scheme(scheme: QuantizationScheme): + if scheme.weights is not None: + raise ValueError( + "Cannot apply weight quantization to attention. " + "Instead, target (q|k|v)_proj" + ) - init_scale = Parameter( - torch.empty(expected_shape, dtype=scale_dtype, device=device), - requires_grad=False, - ) - register_offload_parameter(module, KVCacheScaleType.VALUE.value, init_scale) + if scheme.input_activations is None: + raise ValueError( + "Cannot apply attention quantization without specifying input activations" + ) + + if scheme.output_activations is not None: + raise ValueError("Cannot apply output quantization to attention") diff --git a/src/compressed_tensors/quantization/quant_config.py b/src/compressed_tensors/quantization/quant_config.py index 35dad9981..bb96aeb94 100644 --- a/src/compressed_tensors/quantization/quant_config.py +++ b/src/compressed_tensors/quantization/quant_config.py @@ -12,8 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from collections import defaultdict from enum import Enum -from typing import Annotated, Any, Dict, List, Optional, Union +from typing import Annotated, Any, Dict, List, Optional, Set, Union from compressed_tensors.config import CompressionFormat from compressed_tensors.quantization.quant_args import DynamicType, QuantizationArgs @@ -21,11 +22,7 @@ QuantizationScheme, preset_name_to_scheme, ) -from compressed_tensors.quantization.utils import ( - is_module_quantized, - module_type, - parse_out_kv_cache_args, -) +from compressed_tensors.quantization.utils import is_module_quantized, module_type from pydantic import BaseModel, ConfigDict, Field from torch.nn import Module @@ -174,42 +171,52 @@ def from_pretrained( :param model: model to calculate quantization scheme of :return: filled out QuantizationScheme for the input model """ - quant_scheme_to_layers = [] - quantization_status = None - ignore = {} - quantization_type_names = set() + from compressed_tensors.quantization.lifecycle.initialize import ( + is_attention_module, + ) + + # set of all quantization schemes + # TODO: make quant config/scheme/args frozen/hashable and use a set + quantization_schemes: List[QuantizationScheme] = list() + + # use any status from modules (in practice, use the last module) + model_status = None + + # set of all quantized types + # this is later used to create the ignore list + quantization_type_names: Set[str] = set() + + # maps types to names which are not quantized + # this is later used to create the ignore list + ignore: Dict[str, List[str]] = defaultdict(list) + + # this keeps track of any kvcache schemes + kv_cache_scheme: Optional[QuantizationArgs] = None + for name, submodule in model.named_modules(): - layer_type = module_type(submodule) - if not is_module_quantized(submodule): + layer_type: str = module_type(submodule) + + if is_module_quantized(submodule): + # add to running set of schemes/layer_type_names + model_status = getattr(submodule, "quantization_status", model_status) + quantization_type_names.add(layer_type) + if submodule.quantization_scheme not in quantization_schemes: + quantization_schemes.append(submodule.quantization_scheme) + + # attention quantization implies kv cache quantization + if is_attention_module(submodule): + kv_cache_scheme = submodule.quantization_scheme.input_activations + + else: + # add non-quantized layers to the ignore list if layer_type not in ignore: ignore[layer_type] = [] ignore[layer_type].append(name) - else: - if hasattr(submodule, "quantization_status"): - quantization_status = submodule.quantization_status - scheme = submodule.quantization_scheme - quantization_type_names.add(layer_type) - match_found = False - for existing_scheme in quant_scheme_to_layers: - if scheme == existing_scheme: - match_found = True - break - if not match_found: - quant_scheme_to_layers.append(scheme) - - if len(quant_scheme_to_layers) == 0: # No quantized layers + if len(quantization_schemes) == 0: # No quantized layers return None - # kv-cache only, no weight/activation quantization - if ( - len(quantization_type_names) == 1 - and "attention" in list(quantization_type_names)[0].lower() - ): - quantization_type_names.add("Linear") - - # clean up ignore list, we can leave out layers types if none of the - # instances are quantized + # create ignore list, only include layers whose class has ever been targeted consolidated_ignore = [] for layer_type, ignore_names in ignore.items(): if layer_type in quantization_type_names: @@ -218,20 +225,15 @@ def from_pretrained( # else 
we leave it off the ignore list, doesn't fall under any of the # existing quantization schemes so it won't be quantized - kv_cache_args, quant_scheme_to_layers = parse_out_kv_cache_args( - quant_scheme_to_layers - ) - kv_cache_scheme = ( - kv_cache_args.model_dump() if kv_cache_args is not None else kv_cache_args - ) - + # create config groups from all unique schemes config_groups = {} - for idx, scheme in enumerate(quant_scheme_to_layers): + for idx, scheme in enumerate(quantization_schemes): group_name = "group_" + str(idx) config_groups[group_name] = scheme + # infer format if format is None: - if quantization_status == QuantizationStatus.COMPRESSED: + if model_status == QuantizationStatus.COMPRESSED: format = CompressionFormat.int_quantized.value else: format = CompressionFormat.dense.value @@ -244,7 +246,7 @@ def from_pretrained( return QuantizationConfig( config_groups=config_groups, - quantization_status=quantization_status, + quantization_status=model_status, kv_cache_scheme=kv_cache_scheme, global_compression_ratio=None, format=format, diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py index fccd677c0..0099b088b 100644 --- a/src/compressed_tensors/quantization/utils/helpers.py +++ b/src/compressed_tensors/quantization/utils/helpers.py @@ -14,7 +14,7 @@ import logging import math -from typing import Generator, List, Optional, Tuple +from typing import Generator, Optional, Tuple import torch from compressed_tensors.quantization.quant_args import ( @@ -38,7 +38,6 @@ "module_type", "get_torch_bit_depth", "can_quantize", - "parse_out_kv_cache_args", "KV_CACHE_TARGETS", "is_kv_cache_quant_scheme", "iter_named_leaf_modules", @@ -391,6 +390,7 @@ def can_quantize(value: torch.Tensor, quant_args: "QuantizationArgs") -> bool: return bit_depth > quant_args.num_bits +@deprecated() def is_kv_cache_quant_scheme(scheme: QuantizationScheme) -> bool: """ Check whether the QuantizationScheme targets the kv cache. @@ -411,37 +411,6 @@ def is_kv_cache_quant_scheme(scheme: QuantizationScheme) -> bool: return False -def parse_out_kv_cache_args( - quant_scheme_to_layers: List[QuantizationScheme], -) -> Tuple[Optional[QuantizationArgs], List[QuantizationScheme]]: - """ - If possible, parse out the kv cache specific QuantizationArgs - from the list of the QuantizationSchemes. 
If no kv cache - specific QuantizationArgs available, this function acts - as an identity function - - :param quant_scheme_to_layers: list of QuantizationSchemes - :return: kv_cache_args (optional) and the (remaining or original) - list of the QuantizationSchemes - """ - kv_cache_quant_scheme_to_layers = [ - scheme for scheme in quant_scheme_to_layers if is_kv_cache_quant_scheme(scheme) - ] - quant_scheme_to_layers = [ - scheme - for scheme in quant_scheme_to_layers - if not is_kv_cache_quant_scheme(scheme) - ] - - if kv_cache_quant_scheme_to_layers: - kv_cache_quant_scheme_to_layers = kv_cache_quant_scheme_to_layers[0] - kv_cache_args = kv_cache_quant_scheme_to_layers.output_activations - else: - kv_cache_args = None - - return kv_cache_args, quant_scheme_to_layers - - def generate_gparam( updated_min_val: torch.Tensor, updated_max_val: torch.Tensor, diff --git a/src/compressed_tensors/utils/helpers.py b/src/compressed_tensors/utils/helpers.py index bdaa40c05..7649f0d00 100644 --- a/src/compressed_tensors/utils/helpers.py +++ b/src/compressed_tensors/utils/helpers.py @@ -20,7 +20,7 @@ import numpy import torch -from transformers import AutoConfig +from transformers import AutoConfig, PretrainedConfig T = TypeVar("T", bound="Callable") # used by `deprecated` @@ -45,6 +45,9 @@ "unpack_bitmasks", "patch_attr", "ParameterizedDefaultDict", + "get_num_attn_heads", + "get_num_kv_heads", + "get_head_dim", ] FSDP_WRAPPER_NAME = "_fsdp_wrapped_module" @@ -396,3 +399,62 @@ def get(self, *args, factory_kwargs: Mapping = MappingProxyType({})) -> Any: """ with patch_attr(self, "_factory_kwargs", factory_kwargs): return self[args] + + +def get_num_attn_heads(config: PretrainedConfig) -> int: + """ + Get the number of attention heads used by a model + + :param config: model config + :return: num_attention_heads of model + """ + if hasattr(config, "num_attention_heads"): + return config.num_attention_heads + + elif hasattr(config, "hidden_size") and hasattr(config, "head_dim"): + return config.hidden_size // config.head_dim + + else: + raise ValueError( + "Cannot determine num_attention_heads from config. Config must define " + "either `num_attention_heads` or both `hidden_size` and `head_dim`. " + f"{config}" + ) + + +def get_num_kv_heads(config: PretrainedConfig) -> int: + """ + Get the number of key-value attention heads used by a model + + :param config: model config + :return: num_key_value_heads of model + """ + if hasattr(config, "num_key_value_heads"): + return config.num_key_value_heads + + else: + raise ValueError( + "Cannot determine num_key_value_heads from config. Config must define " + f"`num_key_value_heads`. {config}" + ) + + +def get_head_dim(config: PretrainedConfig) -> int: + """ + Get the number of dimensions used by the attention heads of a model + + :param config: model config + :return: head_dim of model + """ + if hasattr(config, "head_dim"): + return config.head_dim + + elif hasattr(config, "hidden_size") and hasattr(config, "num_attention_heads"): + return config.hidden_size // config.num_attention_heads + + else: + raise ValueError( + "Cannot determine head_dim from config. Config must define " + "either `head_dim` or both `hidden_size` and `num_attention_heads`. 
" + f"{config}" + ) diff --git a/src/compressed_tensors/utils/match.py b/src/compressed_tensors/utils/match.py index b96e83d04..f26400b0b 100644 --- a/src/compressed_tensors/utils/match.py +++ b/src/compressed_tensors/utils/match.py @@ -30,6 +30,7 @@ "match_targets", "match_modules_set", "is_match", + "is_narrow_match", ] @@ -260,6 +261,34 @@ def is_match( ) +def is_narrow_match( + model: torch.nn.Module, + targets: Union[str, Iterable[str]], + name: str, + module: Optional[torch.nn.Module] = None, +) -> bool: + """ + Checks if any of the targets narrowly match the module. A target narrowly matches + a module if the target matches the module, but does not match the module's parent + + :param model: model containing both module and its parent + :param targets: target strings, potentially containing "re:" prefixes + :param name: name of module to match + :param module: module to match. If none is provided, then get module from model + :return: True if any of the targets narrow match the module + """ + targets = [targets] if isinstance(targets, str) else targets + module = module if module is not None else model.get_submodule(name) + + parent_name = name.rsplit(".", 1)[0] + parent = model.get_submodule(parent_name) + + return any( + is_match(name, module, target) and not is_match(parent_name, parent, target) + for target in targets + ) + + def _match_name(name: str, target: str, fused: Optional[FusedMappping] = None) -> bool: """ Returns true if target string begins with "re:" and regex matches or if target diff --git a/tests/test_modeling/test_attention_and_cache.py b/tests/test_modeling/test_attention_and_cache.py new file mode 100644 index 000000000..9a7a6e1e5 --- /dev/null +++ b/tests/test_modeling/test_attention_and_cache.py @@ -0,0 +1,105 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import torch +from compressed_tensors.modeling import ( + IMPL_ATTR, + KV_CACHE_ATTR, + QuantizedAttentionImpl, + QuantizedKVCache, + initialize_hooked_attention, + initialize_hooked_kv_cache, + register_key_hook, + register_query_hook, + register_value_hook, +) +from tests.testing_utils import requires_gpu +from transformers import AutoModelForCausalLM + + +@requires_gpu +def test_attention_cache(): + model = AutoModelForCausalLM.from_pretrained( + "nm-testing/llama2.c-stories15M", device_map="cuda" + ) + inputs = {key: value.to("cuda") for key, value in model.dummy_inputs.items()} + true_outputs = model(**inputs) + layers = model.model.layers + + # check if hooks work + k_called = [False for _ in range(len(layers))] + v_called = [False for _ in range(len(layers))] + + # apply kv cache quantization + _apply_kv_cache(model, layers, k_called, v_called) + + # check kv cache quantization + outputs = model(**inputs) + assert torch.equal(outputs.logits, true_outputs.logits) + assert all(k_called) and all(v_called) + + """ apply attention quantization after kv cache quantization """ + + # check if hooks work + q_called = [False for _ in range(len(layers))] + k_called = [False for _ in range(len(layers))] + v_called = [False for _ in range(len(layers))] + + _apply_attention(model, layers, q_called, k_called, v_called) + outputs = model(**inputs) + assert torch.equal(outputs.logits, true_outputs.logits) + assert all(q_called) and all(k_called) and all(v_called) + + +def _apply_kv_cache(model, layers, k_called, v_called): + for layer_index, layer in enumerate(layers): + module = layer.self_attn + initialize_hooked_kv_cache(model, module) + assert isinstance(getattr(module, KV_CACHE_ATTR), QuantizedKVCache) + + # reapply is no-op + initialize_hooked_kv_cache(model, module) + + def k_hook(_module, _states, layer_index=layer_index): # NOTE: capture by value + k_called[layer_index] = True + + def v_hook(_module, _states, layer_index=layer_index): + my_index = layer_index + v_called[my_index] = True + + register_key_hook(module, k_hook) + register_value_hook(module, v_hook) + + +def _apply_attention(model, layers, q_called, k_called, v_called): + for layer_index, layer in enumerate(layers): + module = layer.self_attn + initialize_hooked_attention(model, module) + assert isinstance(getattr(module, IMPL_ATTR), QuantizedAttentionImpl) + + # reapply is no-op + initialize_hooked_attention(model, module) + + def q_hook(_module, _states, layer_index=layer_index): + q_called[layer_index] = True + + def k_hook(_module, _states, layer_index=layer_index): + k_called[layer_index] = True + + def v_hook(_module, _states, layer_index=layer_index): + v_called[layer_index] = True + + register_query_hook(module, q_hook) + register_key_hook(module, k_hook) + register_value_hook(module, v_hook) diff --git a/tests/test_quantization/lifecycle/test_apply.py b/tests/test_quantization/lifecycle/test_apply.py index ae8908202..510619313 100644 --- a/tests/test_quantization/lifecycle/test_apply.py +++ b/tests/test_quantization/lifecycle/test_apply.py @@ -366,3 +366,43 @@ def test_multi_apply_quantization_config(): weight_zero_point is not None and weight_zero_point.shape == torch.Size([1]) ) + + +@requires_accelerate() +def test_apply_kv_cache(): + from accelerate import init_empty_weights + + with init_empty_weights(): + model = AutoModelForCausalLM.from_pretrained("nm-testing/llama2.c-stories15M") + + args = QuantizationArgs(num_bits=8, type="float", strategy="tensor") + config = QuantizationConfig(config_groups={}, 
kv_cache_scheme=args) + + apply_quantization_config(model, config) + + for layer in model.model.layers: + assert getattr(layer.self_attn, "quantization_scheme").input_activations == args + assert hasattr(layer.self_attn, "k_scale") + assert hasattr(layer.self_attn, "v_scale") + + +@requires_accelerate() +def test_apply_attention(): + from accelerate import init_empty_weights + + with init_empty_weights(): + model = AutoModelForCausalLM.from_pretrained("nm-testing/llama2.c-stories15M") + + scheme = QuantizationScheme( + targets=["LlamaAttention"], + input_activations=QuantizationArgs(num_bits=8, type="float", strategy="tensor"), + ) + config = QuantizationConfig(config_groups={"attention": scheme}) + + apply_quantization_config(model, config) + + for layer in model.model.layers: + assert getattr(layer.self_attn, "quantization_scheme") == scheme + assert hasattr(layer.self_attn, "q_scale") + assert hasattr(layer.self_attn, "k_scale") + assert hasattr(layer.self_attn, "v_scale") diff --git a/tests/test_utils/test_match.py b/tests/test_utils/test_match.py index c85d2f434..1129120c6 100644 --- a/tests/test_utils/test_match.py +++ b/tests/test_utils/test_match.py @@ -21,6 +21,7 @@ from compressed_tensors.utils import ( InternalModule, is_match, + is_narrow_match, match_modules_set, match_named_modules, match_named_parameters, @@ -500,6 +501,86 @@ class InternalLinear(InternalModule, nn.Linear): assert len(matches) == 0 +class TestIsNarrowMatch: + def test_narrow_match_true_child_only(self): + """ + Target matches the child module name but NOT its parent name. + Should return True. + """ + model = DummyModel() + name = "transformer.layers.0.self_attn.q_proj" + # Matches "...q_proj" but not "...self_attn" + target = r"re:.*q_proj$" + + assert is_narrow_match(model, target, name) + + def test_narrow_match_false_when_parent_also_matches(self): + """ + Target matches both the child and its parent name. + Should return False because it's not a 'narrow' match. + """ + model = DummyModel() + name = "transformer.layers.0.self_attn.q_proj" + # Broad target that also matches the parent "transformer.layers.0.self_attn" + target = r"re:transformer\.layers\.0\..*" + + assert not is_narrow_match(model, target, name) + + def test_narrow_match_false_when_neither_matches(self): + """ + Target matches neither the child nor the parent. + Should return False. + """ + model = DummyModel() + name = "transformer.layers.0.self_attn.q_proj" + target = r"re:this_does_not_exist$" + + assert not is_narrow_match(model, target, name) + + def test_narrow_match_iterable_targets_any_true(self): + """ + With multiple targets: if any target narrowly matches the child, + the function should return True. + """ + model = DummyModel() + name = "transformer.layers.0.self_attn.q_proj" + # First target is broad (matches both child & parent -> narrow False), + # second target is narrow (matches child only -> narrow True). + targets = [ + r"re:transformer\.layers\.0\..*", + r"re:.*q_proj$", + ] + + assert is_narrow_match(model, targets, name) + + def test_narrow_match_with_explicit_module_argument(self): + """ + Passing the module explicitly should behave the same as when it's + retrieved from the model by name. 
+ """ + model = DummyModel() + name = "transformer.layers.0.self_attn.q_proj" + module = model.get_submodule(name) + target = r"re:.*q_proj$" + + # Both ways should be True + assert is_narrow_match(model, target, name) + assert is_narrow_match(model, target, name, module=module) + + def test_narrow_match_top_level_behavior_documented(self): + """ + (Behavior check) For a top-level module name without a dot, the current + implementation derives parent_name == name, so parent==child. + Then 'narrow' cannot be True because parent match mirrors child match. + This test documents current behavior to guard against regressions. + """ + model = DummyModel() + name = "layer1" # top-level module in DummyModel + target = r"re:^layer1$" + + assert not is_narrow_match(model, target, name) + + class TestIntegration: """Integration tests combining multiple functions""" From 35acc55f9a5c9f171cb0861738a76c892d75d5f8 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 15 Oct 2025 13:14:38 -0400 Subject: [PATCH 2/7] reduce diff Signed-off-by: Kyle Sayers --- src/compressed_tensors/quantization/lifecycle/initialize.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index ef6d7957f..396c2a17f 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -203,7 +203,7 @@ def initialize_qparams( expected_shape = (1,) elif strategy == QuantizationStrategy.TOKEN: - expected_shape = (1, 1) + raise ValueError("Cannot perform static token quantization") elif strategy == QuantizationStrategy.CHANNEL: if len(observed_shape) < 2: From a9f6e1fefb021ffe478e72bff57d350041bec6b1 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 15 Oct 2025 17:06:30 -0400 Subject: [PATCH 3/7] address nits Signed-off-by: Kyle Sayers --- src/compressed_tensors/modeling/kvcache.py | 38 ++++++++++++++----- .../quantization/lifecycle/initialize.py | 8 +++- 2 files changed, 35 insertions(+), 11 deletions(-) diff --git a/src/compressed_tensors/modeling/kvcache.py b/src/compressed_tensors/modeling/kvcache.py index 4727c9fab..5d29aa645 100644 --- a/src/compressed_tensors/modeling/kvcache.py +++ b/src/compressed_tensors/modeling/kvcache.py @@ -13,8 +13,8 @@ # limitations under the License. 
import inspect -from typing import Callable, Optional, Tuple -from weakref import ref +from typing import Any, Callable, Dict, List, Optional, Tuple +from weakref import ReferenceType, ref from compressed_tensors.quantization.lifecycle.forward import forward_quantize from compressed_tensors.utils import getattr_chain @@ -55,7 +55,7 @@ def __init__(self, config: PretrainedConfig, attn_module: Module): super().__init__() self.config = config self.attn_module = ref(attn_module) # avoid circular reference - self.past_key_values: Optional[Cache] = None + self.past_key_values: Optional[ReferenceType[Cache]] = None def update(self, *args, **kwargs) -> Tuple[Tensor, Tensor]: return self(*args, **kwargs) @@ -78,26 +78,46 @@ def forward( # original cache if self.past_key_values is not None: - ret = self.past_key_values.update(key_states, value_states, *args, **kwargs) + ret = self.past_key_values().update( + key_states, value_states, *args, **kwargs + ) else: ret = (key_states, value_states) - self.past_key_values = None + return ret + def add_past_key_values(self, past_key_values: Optional[Cache]): + if past_key_values is not None: + self.past_key_values = ref(past_key_values) + else: + self.past_key_values = None + # ----- initialize ----- # -def _kv_cache_attention_hook(module: Module, args, kwargs): - kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR) +def _kv_cache_attention_hook( + module: Module, args: List[Any], kwargs: Dict[str, Any] +) -> Tuple[List[Any], Dict[str, Any]]: + """ + Hook which should be called before each quantized attention forward pass. + This hook dynamically replaces the `past_key_values` kwarg to the attention + forward function. + + The original kvcache object is assigned to QuantizedKVCache().past_key_values + as a weakref to maintain original cache functionality and compute savings + """ _past_kv_name = ( "past_key_values" # transformers#39956 if "past_key_values" in inspect.signature(module.forward).parameters else "past_key_value" ) - kv_cache.past_key_values = kwargs.get(_past_kv_name, None) - kwargs[_past_kv_name] = kv_cache + past_key_values: Optional[Cache] = kwargs.get(_past_kv_name, None) + + cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR) + cache.add_past_key_values(past_key_values) + kwargs[_past_kv_name] = cache return args, kwargs diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index 396c2a17f..4bd75a2b3 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -289,7 +289,11 @@ def initialize_attn_qparams( kv_cache: Optional[QuantizedKVCache] = getattr(module, KV_CACHE_ATTR, None) if impl is None and kv_cache is None: - raise ValueError("Attention module has quantization scheme but no attached") + raise ValueError( + f"Attention module has quantization scheme but no {IMPL_ATTR} " + f"or {KV_CACHE_ATTR} attributes. Please ensure that these " + "attributes are initialized using `apply_quantization_config`." + ) _validate_attention_scheme(scheme) @@ -337,7 +341,7 @@ def _validate_attention_scheme(scheme: QuantizationScheme): if scheme.weights is not None: raise ValueError( "Cannot apply weight quantization to attention. 
" - "Instead, target (q|k|v)_proj" + "Instead, target the (q|k|v)_proj submodule layers of attention" ) if scheme.input_activations is None: From 311a9ab81e088a246381f40579ac7ce297568ed6 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 15 Oct 2025 23:22:06 -0400 Subject: [PATCH 4/7] fix kv cache serialization, add tests Signed-off-by: Kyle Sayers --- .../quantization/quant_config.py | 24 ++++++-- .../test_modeling/test_attention_and_cache.py | 3 + .../test_quantization/lifecycle/test_apply.py | 58 +++++++++++++++++++ tests/test_quantization/test_quant_config.py | 3 + 4 files changed, 82 insertions(+), 6 deletions(-) diff --git a/src/compressed_tensors/quantization/quant_config.py b/src/compressed_tensors/quantization/quant_config.py index bb96aeb94..a50721caf 100644 --- a/src/compressed_tensors/quantization/quant_config.py +++ b/src/compressed_tensors/quantization/quant_config.py @@ -171,6 +171,7 @@ def from_pretrained( :param model: model to calculate quantization scheme of :return: filled out QuantizationScheme for the input model """ + from compressed_tensors.modeling import IMPL_ATTR, KV_CACHE_ATTR from compressed_tensors.quantization.lifecycle.initialize import ( is_attention_module, ) @@ -196,24 +197,35 @@ def from_pretrained( for name, submodule in model.named_modules(): layer_type: str = module_type(submodule) - if is_module_quantized(submodule): + # add config group if quantized non-attention or attention quant + has_config_group = is_module_quantized(submodule) and ( + not is_attention_module(submodule) or hasattr(submodule, IMPL_ATTR) + ) + # only add kvcache if quant attention (which always implies kvcache) + has_kv_cache = is_module_quantized(submodule) and is_attention_module( + submodule + ) + + if has_config_group: # add to running set of schemes/layer_type_names model_status = getattr(submodule, "quantization_status", model_status) quantization_type_names.add(layer_type) if submodule.quantization_scheme not in quantization_schemes: quantization_schemes.append(submodule.quantization_scheme) - # attention quantization implies kv cache quantization - if is_attention_module(submodule): - kv_cache_scheme = submodule.quantization_scheme.input_activations + if has_kv_cache: + model_status = getattr(submodule, "quantization_status", model_status) + kv_cache_scheme = submodule.quantization_scheme.input_activations - else: + if not has_config_group: # add non-quantized layers to the ignore list if layer_type not in ignore: ignore[layer_type] = [] ignore[layer_type].append(name) - if len(quantization_schemes) == 0: # No quantized layers + if ( + len(quantization_schemes) == 0 and kv_cache_scheme is None + ): # No quantized layers return None # create ignore list, only include layers whose class has ever been targeted diff --git a/tests/test_modeling/test_attention_and_cache.py b/tests/test_modeling/test_attention_and_cache.py index 9a7a6e1e5..230b8f3ee 100644 --- a/tests/test_modeling/test_attention_and_cache.py +++ b/tests/test_modeling/test_attention_and_cache.py @@ -56,7 +56,10 @@ def test_attention_cache(): k_called = [False for _ in range(len(layers))] v_called = [False for _ in range(len(layers))] + # apply attention quantization _apply_attention(model, layers, q_called, k_called, v_called) + + # check attention quantization outputs = model(**inputs) assert torch.equal(outputs.logits, true_outputs.logits) assert all(q_called) and all(k_called) and all(v_called) diff --git a/tests/test_quantization/lifecycle/test_apply.py b/tests/test_quantization/lifecycle/test_apply.py 
index 510619313..d56148c51 100644 --- a/tests/test_quantization/lifecycle/test_apply.py +++ b/tests/test_quantization/lifecycle/test_apply.py @@ -131,6 +131,64 @@ def test_apply_quantization_config_tinyllama(): ) +@pytest.mark.parametrize( + "config", + [ + QuantizationConfig( + config_groups={ + "linear": QuantizationScheme( + targets=["Linear"], + input_activations=QuantizationArgs( + num_bits=8, type="float", strategy="tensor" + ), + ) + } + ), + QuantizationConfig( + config_groups={ + "linear": QuantizationScheme( + targets=["Linear"], + input_activations=QuantizationArgs( + num_bits=8, type="float", strategy="tensor" + ), + ) + }, + ignore=[ + "model.layers.0.self_attn.q_proj", + "model.layers.1.self_attn.k_proj", + "model.layers.2.self_attn.v_proj", + ], + ), + QuantizationConfig( + config_groups={}, + kv_cache_scheme=QuantizationArgs( + num_bits=8, type="float", strategy="tensor" + ), + ), + QuantizationConfig( + config_groups={ + "attention": QuantizationScheme( + targets=["LlamaAttention"], + input_activations=QuantizationArgs( + num_bits=8, type="float", strategy="tensor" + ), + ) + }, + kv_cache_scheme=QuantizationArgs( + num_bits=8, type="float", strategy="tensor" + ), + ), + ], +) +def test_from_pretrained(config: QuantizationConfig): + model = AutoModelForCausalLM.from_pretrained("nm-testing/llama2.c-stories15M") + apply_quantization_config(model, config) + _config = QuantizationConfig.from_pretrained(model) + assert list(_config.config_groups.values()) == list(config.config_groups.values()) + assert _config.kv_cache_scheme == config.kv_cache_scheme + assert _config.ignore == config.ignore + + def test_serialize_config_tinyllama(): quant_config = get_sample_tinyllama_quant_config() model = get_tinyllama_model() diff --git a/tests/test_quantization/test_quant_config.py b/tests/test_quantization/test_quant_config.py index c3830a02d..bdc6f0235 100644 --- a/tests/test_quantization/test_quant_config.py +++ b/tests/test_quantization/test_quant_config.py @@ -16,11 +16,14 @@ from compressed_tensors.quantization import ( DEFAULT_QUANTIZATION_FORMAT, DEFAULT_QUANTIZATION_METHOD, + QuantizationArgs, QuantizationConfig, QuantizationScheme, QuantizationStatus, + apply_quantization_config, ) from pydantic import ValidationError +from transformers import AutoModelForCausalLM def test_basic_config(): From 8c99f633f47073b835e39c149519f2eaee13733e Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 15 Oct 2025 23:28:41 -0400 Subject: [PATCH 5/7] fix style Signed-off-by: Kyle Sayers --- src/compressed_tensors/quantization/quant_config.py | 2 +- tests/test_quantization/test_quant_config.py | 3 --- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/src/compressed_tensors/quantization/quant_config.py b/src/compressed_tensors/quantization/quant_config.py index a50721caf..bed6078fa 100644 --- a/src/compressed_tensors/quantization/quant_config.py +++ b/src/compressed_tensors/quantization/quant_config.py @@ -171,7 +171,7 @@ def from_pretrained( :param model: model to calculate quantization scheme of :return: filled out QuantizationScheme for the input model """ - from compressed_tensors.modeling import IMPL_ATTR, KV_CACHE_ATTR + from compressed_tensors.modeling import IMPL_ATTR from compressed_tensors.quantization.lifecycle.initialize import ( is_attention_module, ) diff --git a/tests/test_quantization/test_quant_config.py b/tests/test_quantization/test_quant_config.py index bdc6f0235..c3830a02d 100644 --- a/tests/test_quantization/test_quant_config.py +++ 
b/tests/test_quantization/test_quant_config.py @@ -16,14 +16,11 @@ from compressed_tensors.quantization import ( DEFAULT_QUANTIZATION_FORMAT, DEFAULT_QUANTIZATION_METHOD, - QuantizationArgs, QuantizationConfig, QuantizationScheme, QuantizationStatus, - apply_quantization_config, ) from pydantic import ValidationError -from transformers import AutoModelForCausalLM def test_basic_config(): From 5225515fc458c188968cda4994a599274db6fe73 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 20 Oct 2025 15:44:25 -0400 Subject: [PATCH 6/7] do not force zp for attention Signed-off-by: Kyle Sayers --- .../quantization/lifecycle/apply.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py index 1bea5a4a9..1bcbfa133 100644 --- a/src/compressed_tensors/quantization/lifecycle/apply.py +++ b/src/compressed_tensors/quantization/lifecycle/apply.py @@ -45,6 +45,8 @@ from safetensors import safe_open from torch.nn import Module +from loguru import logger + __all__ = [ "load_pretrained_quantization_parameters", @@ -132,7 +134,7 @@ def apply_quantization_config( # apply and initialize kv cache quantization if config.kv_cache_scheme is not None: _apply_kv_cache_scheme( - model, config.kv_cache_scheme, config.quantization_status, force_zero_point + model, config.kv_cache_scheme, config.quantization_status ) # build mapping of targets to schemes for easier matching @@ -186,22 +188,22 @@ def _apply_kv_cache_scheme( model: torch.nn.Module, kv_cache_scheme: QuantizationArgs, status: QuantizationStatus, - force_zero_point: bool, ): + if not kv_cache_scheme.symmetric: + raise logger.warning("vLLM does not support asymmetric kv cache quantization") + # applies and initializes kv cache quantization # this step cannot come after attention apply/initialize # otherwise it will override the attention qparams scheme = QuantizationScheme( - targets=[".*self_attn$"], input_activations=kv_cache_scheme + targets=[".*self_attn$"], # is never read in practice + input_activations=kv_cache_scheme ) for submodule in model.modules(): if is_attention_module(submodule): submodule.quantization_scheme = scheme initialize_hooked_kv_cache(model, submodule) - initialize_module_for_quantization( - submodule, - force_zero_point=force_zero_point, - ) + initialize_module_for_quantization(submodule, force_zero_point=False) submodule.quantization_status = status From a67737262eb595b1c44e624f6d0bea24e2685130 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 21 Oct 2025 15:49:00 -0400 Subject: [PATCH 7/7] populate ALL_MASK_ATTENTION_FUNCTIONS Signed-off-by: Kyle Sayers --- src/compressed_tensors/modeling/attention.py | 43 ++++++++++--------- .../quantization/lifecycle/apply.py | 7 ++- 2 files changed, 25 insertions(+), 25 deletions(-) diff --git a/src/compressed_tensors/modeling/attention.py b/src/compressed_tensors/modeling/attention.py index 0f41dd139..504d455ce 100644 --- a/src/compressed_tensors/modeling/attention.py +++ b/src/compressed_tensors/modeling/attention.py @@ -14,7 +14,6 @@ import inspect from typing import Callable, Optional -from weakref import ref from compressed_tensors.modeling.kvcache import initialize_hooked_kv_cache from compressed_tensors.quantization.lifecycle.forward import forward_quantize @@ -23,7 +22,8 @@ from torch import Tensor from torch.nn import Module from torch.utils.hooks import RemovableHandle -from transformers import AttentionInterface, PretrainedConfig, 
From a67737262eb595b1c44e624f6d0bea24e2685130 Mon Sep 17 00:00:00 2001
From: Kyle Sayers
Date: Tue, 21 Oct 2025 15:49:00 -0400
Subject: [PATCH 7/7] populate ALL_MASK_ATTENTION_FUNCTIONS

Signed-off-by: Kyle Sayers
---
 src/compressed_tensors/modeling/attention.py  | 43 ++++++++++---------
 .../quantization/lifecycle/apply.py           |  7 ++-
 2 files changed, 25 insertions(+), 25 deletions(-)

diff --git a/src/compressed_tensors/modeling/attention.py b/src/compressed_tensors/modeling/attention.py
index 0f41dd139..504d455ce 100644
--- a/src/compressed_tensors/modeling/attention.py
+++ b/src/compressed_tensors/modeling/attention.py
@@ -14,7 +14,6 @@
 
 import inspect
 from typing import Callable, Optional
-from weakref import ref
 
 from compressed_tensors.modeling.kvcache import initialize_hooked_kv_cache
 from compressed_tensors.quantization.lifecycle.forward import forward_quantize
@@ -23,7 +22,8 @@
 from torch import Tensor
 from torch.nn import Module
 from torch.utils.hooks import RemovableHandle
-from transformers import AttentionInterface, PretrainedConfig, PreTrainedModel
+from transformers import PretrainedConfig, PreTrainedModel
+from transformers.masking_utils import ALL_MASK_ATTENTION_FUNCTIONS
 from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
 
 
@@ -51,16 +51,13 @@ class QuantizedAttentionImpl(InternalModule):
     which calls this module, then setting the model attention implementation to the new
     function. After triggering hooks and quantization, this module calls the original
     attention implementation function.
-
-    :param attn_module: parent attention module
     """
 
     _original_impl = "eager"
 
-    def __init__(self, config: PretrainedConfig, attn_module: Module):
+    def __init__(self, config: PretrainedConfig):
         super().__init__()
         self.config = config
-        self.attn_module = ref(attn_module)  # avoid circular references
 
     def forward(
         self,
@@ -79,7 +76,7 @@ def forward(
             query = forward_quantize(module, query, "q", quant_args)
 
         # original attention
-        return ALL_ATTENTION_FUNCTIONS[_original_impl](
+        return ALL_ATTENTION_FUNCTIONS[QuantizedAttentionImpl._original_impl](
             module,
             query,
             key,
@@ -92,30 +89,34 @@ def forward(
 # ----- initialize ----- #
 
 
-def _ct_hooked_attention(module: Module, *args, **kwargs):
-    if hasattr(module, IMPL_ATTR):
-        return module.impl(module, *args, **kwargs)
-    else:
-        return ALL_ATTENTION_FUNCTIONS[_original_impl](module, *args, **kwargs)
+def _hooked_attention(module: Module, *args, **kwargs):
+    assert hasattr(module, IMPL_ATTR), (
+        f"Using {HOOKED_ATTENTION_NAME} attention implementation, "
+        f"but attention module does not have {IMPL_ATTR} submodule."
+    )
+
+    return getattr(module, IMPL_ATTR)(module, *args, **kwargs)
 
 
 def initialize_hooked_attention(model: PreTrainedModel, module: Module):
     """
     Initialize `QuantizedAttentionImpl` and `QuantizedKVCache` instances
-    attached to attention
+    attached to attention. Assumes that only one model is hooked at a time.
 
     :param model: parent model of attention module
     :param module: attention module to initialize with
     """
     if not hasattr(module, IMPL_ATTR):
-        module.register_module(IMPL_ATTR, QuantizedAttentionImpl(model.config, module))
-        if model.config._attn_implementation != HOOKED_ATTENTION_NAME:
-            # assumes only one model at a time
-            global _original_impl
-            _original_impl = model.config._attn_implementation
+        module.register_module(IMPL_ATTR, QuantizedAttentionImpl(model.config))
+
+    if model.config._attn_implementation != HOOKED_ATTENTION_NAME:
+        QuantizedAttentionImpl._original_impl = model.config._attn_implementation
+        original_mask = ALL_MASK_ATTENTION_FUNCTIONS[model.config._attn_implementation]
 
-            AttentionInterface.register(HOOKED_ATTENTION_NAME, _ct_hooked_attention)
-            model.config._attn_implementation = HOOKED_ATTENTION_NAME
+        ALL_ATTENTION_FUNCTIONS.register(HOOKED_ATTENTION_NAME, _hooked_attention)
+        ALL_MASK_ATTENTION_FUNCTIONS.register(HOOKED_ATTENTION_NAME, original_mask)
+        model.set_attn_implementation(HOOKED_ATTENTION_NAME)
+        assert model.config._attn_implementation == HOOKED_ATTENTION_NAME
 
     initialize_hooked_kv_cache(model, module)
 
@@ -133,7 +134,7 @@ def register_query_hook(
     :param module: attention module to add hook to
     :param hook: query hook function
     """
-    impl = getattr(module, IMPL_ATTR)
+    impl: QuantizedAttentionImpl = getattr(module, IMPL_ATTR)
 
     def _hook(impl: QuantizedAttentionImpl, args, kwargs):
         bound = inspect.signature(impl.forward).bind(*args, **kwargs)
diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py
index 1bcbfa133..28c8a7b97 100644
--- a/src/compressed_tensors/quantization/lifecycle/apply.py
+++ b/src/compressed_tensors/quantization/lifecycle/apply.py
@@ -42,11 +42,10 @@
 )
 from compressed_tensors.utils.offload import update_parameter_data
 from compressed_tensors.utils.safetensors_load import get_safetensors_folder
+from loguru import logger
 from safetensors import safe_open
 from torch.nn import Module
 
-from loguru import logger
-
 
 __all__ = [
     "load_pretrained_quantization_parameters",
@@ -191,13 +190,13 @@ def _apply_kv_cache_scheme(
 ):
     if not kv_cache_scheme.symmetric:
         logger.warning("vLLM does not support asymmetric kv cache quantization")
-    
+
     # applies and initializes kv cache quantization
     # this step cannot come after attention apply/initialize
     # otherwise it will override the attention qparams
     scheme = QuantizationScheme(
         targets=[".*self_attn$"],  # is never read in practice
-        input_activations=kv_cache_scheme
+        input_activations=kv_cache_scheme,
     )
     for submodule in model.modules():
         if is_attention_module(submodule):
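A minimal calibration-style sketch of the hooked attention entry points touched above. It assumes the query hook signature is `hook(module, query) -> Optional[Tensor]` and uses a module-name check in place of the library's attention-module detection; returning `None` from the hook leaves the query states unchanged.

import torch
from compressed_tensors.modeling import (
    initialize_hooked_attention,
    register_query_hook,
)
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("nm-testing/llama2.c-stories15M")
observed_absmax = {}


def make_hook(name):
    def hook(module, query):
        # record the absolute max of the post-rope query states for calibration
        previous = observed_absmax.get(name, 0.0)
        observed_absmax[name] = max(previous, query.abs().max().item())

    return hook


for name, submodule in model.named_modules():
    if name.endswith("self_attn"):
        # attaches the QuantizedAttentionImpl/QuantizedKVCache submodules and swaps
        # the model onto the hooked attention implementation registered above
        initialize_hooked_attention(model, submodule)
        register_query_hook(submodule, make_hook(name))

with torch.no_grad():
    model(input_ids=torch.tensor([[1, 2, 3]]))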