diff --git a/src/compressed_tensors/modeling/__init__.py b/src/compressed_tensors/modeling/__init__.py new file mode 100644 index 00000000..97ee4a2c --- /dev/null +++ b/src/compressed_tensors/modeling/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# flake8: noqa +# isort: off +from .kvcache import * +from .attention import * diff --git a/src/compressed_tensors/modeling/attention.py b/src/compressed_tensors/modeling/attention.py new file mode 100644 index 00000000..504d455c --- /dev/null +++ b/src/compressed_tensors/modeling/attention.py @@ -0,0 +1,147 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable, Optional + +from compressed_tensors.modeling.kvcache import initialize_hooked_kv_cache +from compressed_tensors.quantization.lifecycle.forward import forward_quantize +from compressed_tensors.utils import getattr_chain +from compressed_tensors.utils.internal import InternalModule +from torch import Tensor +from torch.nn import Module +from torch.utils.hooks import RemovableHandle +from transformers import PretrainedConfig, PreTrainedModel +from transformers.masking_utils import ALL_MASK_ATTENTION_FUNCTIONS +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS + + +__all__ = [ + "QuantizedAttentionImpl", + "initialize_hooked_attention", + "register_query_hook", + "IMPL_ATTR", +] + + +IMPL_ATTR = "impl" +HOOKED_ATTENTION_NAME = "ct_hooked_attention" + + +class QuantizedAttentionImpl(InternalModule): + """ + QuantizedAttentionImpl module which wraps the functionality of the original + attention implementation. Unlike the original attention function, this + implementation is a `torch.nn.Module` which can be hooked to trigger + transforms and calibration hooks. + + This module works by being registered as a submodule to attention modules via + `initialize_hooked_attention`, registering a new attention implementation function + which calls this module, then setting the model attention implementation to the new + function. After triggering hooks and quantization, this module calls the original + attention implementation function. 
+ """ + + _original_impl = "eager" + + def __init__(self, config: PretrainedConfig): + super().__init__() + self.config = config + + def forward( + self, + module: Module, + query: Tensor, + key: Tensor, + value: Tensor, + *args, + **kwargs, + ): + # quantization + quant_args_attr = "quantization_scheme.input_activations" + quant_args = getattr_chain(module, quant_args_attr, None) + quant_enabled = getattr(module, "quantization_enabled", True) + if quant_args is not None and quant_enabled: + query = forward_quantize(module, query, "q", quant_args) + + # original attention + return ALL_ATTENTION_FUNCTIONS[QuantizedAttentionImpl._original_impl]( + module, + query, + key, + value, + *args, + **kwargs, + ) + + +# ----- initialize ----- # + + +def _hooked_attention(module: Module, *args, **kwargs): + assert hasattr(module, IMPL_ATTR), ( + f"Using {HOOKED_ATTENTION_NAME} attention implementation, " + f"but attention module does not have {IMPL_ATTR} submodule." + ) + + return getattr(module, IMPL_ATTR)(module, *args, **kwargs) + + +def initialize_hooked_attention(model: PreTrainedModel, module: Module): + """ + Initialize `QuantizedAttentionImpl` and `QuantizedKVCache` instances + attached to attention. Assumes that only one model is hooked at a time. + + :param model: parent model of attention module + :param module: attention module to initialize with + """ + if not hasattr(module, IMPL_ATTR): + module.register_module(IMPL_ATTR, QuantizedAttentionImpl(model.config)) + + if model.config._attn_implementation != HOOKED_ATTENTION_NAME: + QuantizedAttentionImpl._original_impl = model.config._attn_implementation + original_mask = ALL_MASK_ATTENTION_FUNCTIONS[model.config._attn_implementation] + + ALL_ATTENTION_FUNCTIONS.register(HOOKED_ATTENTION_NAME, _hooked_attention) + ALL_MASK_ATTENTION_FUNCTIONS.register(HOOKED_ATTENTION_NAME, original_mask) + model.set_attn_implementation(HOOKED_ATTENTION_NAME) + assert model.config._attn_implementation == HOOKED_ATTENTION_NAME + + initialize_hooked_kv_cache(model, module) + + +# ----- hooks ----- # + + +def register_query_hook( + module: Module, hook: Callable[[Module, Tensor], Optional[Tensor]] +) -> RemovableHandle: + """ + Register a hook which takes post-rope query states as an argument and + returns the modified query states or `None` + + :param module: attention module to add hook to + :param hook: query hook function + """ + impl: QuantizedAttentionImpl = getattr(module, IMPL_ATTR) + + def _hook(impl: QuantizedAttentionImpl, args, kwargs): + bound = inspect.signature(impl.forward).bind(*args, **kwargs) + value = hook(module, bound.arguments["query"]) + if value is not None: + bound.arguments["query"] = value + + return bound.args, bound.kwargs + + return impl.register_forward_pre_hook(_hook, with_kwargs=True) diff --git a/src/compressed_tensors/modeling/kvcache.py b/src/compressed_tensors/modeling/kvcache.py new file mode 100644 index 00000000..5d29aa64 --- /dev/null +++ b/src/compressed_tensors/modeling/kvcache.py @@ -0,0 +1,183 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple +from weakref import ReferenceType, ref + +from compressed_tensors.quantization.lifecycle.forward import forward_quantize +from compressed_tensors.utils import getattr_chain +from compressed_tensors.utils.internal import InternalModule +from torch import Tensor +from torch.nn import Module +from torch.utils.hooks import RemovableHandle +from transformers import Cache, PretrainedConfig, PreTrainedModel + + +__all__ = [ + "QuantizedKVCache", + "initialize_hooked_kv_cache", + "register_key_hook", + "register_value_hook", + "KV_CACHE_ATTR", +] + + +KV_CACHE_ATTR = "kv_cache" + + +class QuantizedKVCache(InternalModule): + """ + QuantizedKVCache module which wraps the functionality of any existing kv cache implementation. + Unlike transformers `Cache` instances, this cache is a `torch.nn.Module` which can be + hooked to trigger transforms and calibration hooks. + + This module works by being registered as a submodule to attention modules via + `initialize_hooked_kv_cache`, then adding a hook which replaces the `past_key_values` + kwarg with this module. This module adopts the functionality of the replaced cache, + preserving caching functionality such as sliding window attention, etc. + + :param attn_module: parent attention module + """ + + def __init__(self, config: PretrainedConfig, attn_module: Module): + super().__init__() + self.config = config + self.attn_module = ref(attn_module) # avoid circular reference + self.past_key_values: Optional[ReferenceType[Cache]] = None + + def update(self, *args, **kwargs) -> Tuple[Tensor, Tensor]: + return self(*args, **kwargs) + + def forward( + self, + key_states: Tensor, + value_states: Tensor, + *args, + **kwargs, + ) -> Tuple[Tensor, Tensor]: + # quantization + module = self.attn_module() + quant_args_attr = "quantization_scheme.input_activations" + quant_args = getattr_chain(module, quant_args_attr, None) + quant_enabled = getattr(module, "quantization_enabled", True) + if quant_args is not None and quant_enabled: + key_states = forward_quantize(module, key_states, "k", quant_args) + value_states = forward_quantize(module, value_states, "v", quant_args) + + # original cache + if self.past_key_values is not None: + ret = self.past_key_values().update( + key_states, value_states, *args, **kwargs + ) + else: + ret = (key_states, value_states) + self.past_key_values = None + + return ret + + def add_past_key_values(self, past_key_values: Optional[Cache]): + if past_key_values is not None: + self.past_key_values = ref(past_key_values) + else: + self.past_key_values = None + + +# ----- initialize ----- # + + +def _kv_cache_attention_hook( + module: Module, args: List[Any], kwargs: Dict[str, Any] +) -> Tuple[List[Any], Dict[str, Any]]: + """ + Hook which should be called before each quantized attention forward pass. + This hook dynamically replaces the `past_key_values` kwarg to the attention + forward function.
+ + The original kvcache object is assigned to QuantizedKVCache().past_key_values + as a weakref to maintain original cache functionality and compute savings + """ + _past_kv_name = ( + "past_key_values" # transformers#39956 + if "past_key_values" in inspect.signature(module.forward).parameters + else "past_key_value" + ) + past_key_values: Optional[Cache] = kwargs.get(_past_kv_name, None) + + cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR) + cache.add_past_key_values(past_key_values) + kwargs[_past_kv_name] = cache + + return args, kwargs + + +def initialize_hooked_kv_cache(model: PreTrainedModel, module: Module): + """ + Initialize a `QuantizedKVCache` instance attached to attention + + :param model: parent model of attention module + :param module: attention module to initialize with + """ + if not hasattr(module, KV_CACHE_ATTR): + module.register_module(KV_CACHE_ATTR, QuantizedKVCache(model.config, module)) + module.register_forward_pre_hook(_kv_cache_attention_hook, with_kwargs=True) + + +# ----- hooks ----- # + + +def register_key_hook( + module: Module, hook: Callable[[Module, Tensor], Optional[Tensor]] +) -> RemovableHandle: + """ + Register a hook which takes post-rope key states as an argument and + returns the modified key states or `None` + + :param module: attention module to add hook to + :param hook: key hook function + """ + kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR) + + def _hook(cache: QuantizedKVCache, args, kwargs): + bound = inspect.signature(cache.forward).bind(*args, **kwargs) + value = hook(module, bound.arguments["key_states"]) + if value is not None: + bound.arguments["key_states"] = value + + return bound.args, bound.kwargs + + return kv_cache.register_forward_pre_hook(_hook, with_kwargs=True) + + +def register_value_hook( + module: Module, hook: Callable[[Module, Tensor], Optional[Tensor]] +) -> RemovableHandle: + """ + Register a hook which takes value states as an argument and + returns the modified value states or `None` + + :param module: attention module to add hook to + :param hook: value hook function + """ + kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR) + + def _hook(cache: QuantizedKVCache, args, kwargs): + bound = inspect.signature(cache.forward).bind(*args, **kwargs) + value = hook(module, bound.arguments["value_states"]) + if value is not None: + bound.arguments["value_states"] = value + + return bound.args, bound.kwargs + + return kv_cache.register_forward_pre_hook(_hook, with_kwargs=True) diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py index 21525bd4..28c8a7b9 100644 --- a/src/compressed_tensors/quantization/lifecycle/apply.py +++ b/src/compressed_tensors/quantization/lifecycle/apply.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
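Taken together, the kv cache API above is the calibration entry point: `initialize_hooked_kv_cache` attaches a `QuantizedKVCache` submodule plus a pre-forward hook that swaps it in as `past_key_values`, while `register_key_hook` and `register_value_hook` expose key/value states just before they are quantized and cached. A minimal usage sketch follows (not part of this patch; the min/max tracking is hypothetical calibration logic, and the stories15M checkpoint is borrowed from the tests added later in this diff):

import torch
from transformers import AutoModelForCausalLM

from compressed_tensors.modeling import (
    initialize_hooked_kv_cache,
    register_key_hook,
    register_value_hook,
)

model = AutoModelForCausalLM.from_pretrained("nm-testing/llama2.c-stories15M")
stats = {}  # running (min, max) per layer and state, purely illustrative


def make_minmax_hook(key):
    def _hook(module: torch.nn.Module, states: torch.Tensor):
        lo, hi = stats.get(key, (float("inf"), float("-inf")))
        stats[key] = (min(lo, states.amin().item()), max(hi, states.amax().item()))
        return None  # returning None leaves the states unchanged

    return _hook


for index, layer in enumerate(model.model.layers):
    attn = layer.self_attn
    initialize_hooked_kv_cache(model, attn)  # attaches the `kv_cache` submodule
    register_key_hook(attn, make_minmax_hook((index, "k")))
    register_value_hook(attn, make_minmax_hook((index, "v")))

with torch.no_grad():
    model(**model.dummy_inputs)  # hooks fire on every attention forward

Returning a tensor from a key or value hook (instead of None) replaces the states that get cached, which is how fake-quantization or transforms can be injected without touching the attention implementation itself.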
-import logging from collections import OrderedDict from copy import deepcopy from typing import Dict, List, Optional @@ -21,8 +20,13 @@ import torch from compressed_tensors.config import CompressionFormat +from compressed_tensors.modeling import ( + initialize_hooked_attention, + initialize_hooked_kv_cache, +) from compressed_tensors.quantization.lifecycle.initialize import ( initialize_module_for_quantization, + is_attention_module, ) from compressed_tensors.quantization.quant_args import QuantizationArgs from compressed_tensors.quantization.quant_config import ( @@ -30,14 +34,15 @@ QuantizationStatus, ) from compressed_tensors.quantization.quant_scheme import QuantizationScheme -from compressed_tensors.quantization.utils import ( - KV_CACHE_TARGETS, - is_kv_cache_quant_scheme, -) from compressed_tensors.utils.helpers import replace_module -from compressed_tensors.utils.match import match_named_modules, match_targets +from compressed_tensors.utils.match import ( + is_narrow_match, + match_named_modules, + match_targets, +) from compressed_tensors.utils.offload import update_parameter_data from compressed_tensors.utils.safetensors_load import get_safetensors_folder +from loguru import logger from safetensors import safe_open from torch.nn import Module @@ -53,9 +58,6 @@ ) -_LOGGER = logging.getLogger(__name__) - - def load_pretrained_quantization_parameters( model: Module, model_name_or_path: Optional[str] = None, @@ -125,8 +127,14 @@ def apply_quantization_config( if config is None: # see PR #180 return dict() - # preprocess to support kv cache scheme - config = process_quantization_config(config) + # force zero points during initialization + force_zero_point = config.quantization_status != QuantizationStatus.COMPRESSED + + # apply and initialize kv cache quantization + if config.kv_cache_scheme is not None: + _apply_kv_cache_scheme( + model, config.kv_cache_scheme, config.quantization_status + ) # build mapping of targets to schemes for easier matching # use ordered dict to preserve target ordering in config @@ -162,49 +170,40 @@ def apply_quantization_config( replace_module(model, name, compressed_linear) else: + if is_attention_module(submodule) and is_narrow_match( + model, scheme.targets, name + ): + initialize_hooked_attention(model, submodule) + initialize_module_for_quantization( submodule, - force_zero_point=config.quantization_status - != QuantizationStatus.COMPRESSED, + force_zero_point=force_zero_point, ) submodule.quantization_status = config.quantization_status -def process_quantization_config(config: QuantizationConfig) -> QuantizationConfig: - """ - Preprocess the raw QuantizationConfig - - :param config: the raw QuantizationConfig - :return: the processed QuantizationConfig - """ - if config.kv_cache_scheme is not None: - config = process_kv_cache_config(config) - - return config - - -def process_kv_cache_config( - config: QuantizationConfig, targets: Union[List[str], str] = KV_CACHE_TARGETS -) -> QuantizationConfig: - """ - Reformulate the `config.kv_cache` as a `config_group` - and add it to the set of existing `config.groups` - - :param config: the QuantizationConfig - :return: the QuantizationConfig with additional "kv_cache" group - """ - if targets == KV_CACHE_TARGETS: - _LOGGER.info(f"KV cache targets set to default value of: {KV_CACHE_TARGETS}") - - kv_cache_dict = config.kv_cache_scheme.model_dump() - kv_cache_scheme = QuantizationScheme( - output_activations=QuantizationArgs(**kv_cache_dict), - targets=targets, +def _apply_kv_cache_scheme( + model: 
torch.nn.Module, + kv_cache_scheme: QuantizationArgs, + status: QuantizationStatus, +): + if not kv_cache_scheme.symmetric: + logger.warning("vLLM does not support asymmetric kv cache quantization") + + # applies and initializes kv cache quantization + # this step cannot come after attention apply/initialize + # otherwise it will override the attention qparams + scheme = QuantizationScheme( + targets=[".*self_attn$"], # is never read in practice + input_activations=kv_cache_scheme, ) - kv_cache_group = dict(kv_cache=kv_cache_scheme) - config.config_groups.update(kv_cache_group) - return config + for submodule in model.modules(): + if is_attention_module(submodule): + submodule.quantization_scheme = scheme + initialize_hooked_kv_cache(model, submodule) + initialize_module_for_quantization(submodule, force_zero_point=False) + submodule.quantization_status = status def _load_quant_args_from_mapping( @@ -256,60 +255,6 @@ def _scheme_from_targets( targets: List[str], name: str, ) -> QuantizationScheme: - if len(targets) == 1: - # if `targets` iterable contains a single element - # use it as the key - return target_to_scheme[targets[0]] - - # otherwise, we need to merge QuantizationSchemes corresponding - # to multiple targets. This is most likely because `name` module - # is being target both as an ordinary quantization target, as well - # as kv cache quantization target - schemes_to_merge = [target_to_scheme[target] for target in targets] - return _merge_schemes(schemes_to_merge, name) - - -def _merge_schemes( - schemes_to_merge: List[QuantizationScheme], name: str -) -> QuantizationScheme: - kv_cache_quantization_scheme = [ - scheme for scheme in schemes_to_merge if is_kv_cache_quant_scheme(scheme) - ] - if not kv_cache_quantization_scheme: - # if the schemes_to_merge do not contain any - # kv cache QuantizationScheme - # return the first scheme (the prioritized one, - # since the order of schemes_to_merge matters) - return schemes_to_merge[0] - else: - # fetch the kv cache QuantizationScheme and the highest - # priority non-kv cache QuantizationScheme and merge them - kv_cache_quantization_scheme = kv_cache_quantization_scheme[0] - quantization_scheme = [ - scheme - for scheme in schemes_to_merge - if not is_kv_cache_quant_scheme(scheme) - ][0] - schemes_to_merge = [kv_cache_quantization_scheme, quantization_scheme] - merged_scheme = {} - for scheme in schemes_to_merge: - scheme_dict = { - k: v for k, v in scheme.model_dump().items() if v is not None - } - # when merging multiple schemes, the final target will be - # the `name` argument - hence erase the original targets - del scheme_dict["targets"] - # make sure that schemes do not "clash" with each other - overlapping_keys = set(merged_scheme.keys()) & set(scheme_dict.keys()) - if overlapping_keys: - raise ValueError( - f"The module: {name} is being modified by two clashing " - f"quantization schemes, that jointly try to override " - f"properties: {overlapping_keys}. Fix the quantization config " - "so that it is not ambiguous."
- ) - merged_scheme.update(scheme_dict) - - merged_scheme.update(targets=[name]) - - return QuantizationScheme(**merged_scheme) + # return the first scheme (the prioritized one, + # since the order of target_to_scheme matters) + return target_to_scheme[targets[0]] diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index 50757adc..4bd75a2b 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -17,11 +17,16 @@ from typing import Optional, Tuple, Union import torch +from compressed_tensors.modeling import ( + IMPL_ATTR, + KV_CACHE_ATTR, + QuantizedAttentionImpl, + QuantizedKVCache, +) from compressed_tensors.quantization import ( FP8_E4M3_DATA, ActivationOrdering, DynamicType, - KVCacheScaleType, QuantizationArgs, QuantizationMetadata, QuantizationScheme, @@ -31,14 +36,13 @@ from compressed_tensors.quantization.lifecycle.forward import ( wrap_module_forward_quantized, ) -from compressed_tensors.quantization.utils import ( - is_fp4, - is_kv_cache_quant_scheme, - strategy_cdiv, -) +from compressed_tensors.quantization.utils import is_fp4, strategy_cdiv from compressed_tensors.utils import ( disable_hf_hook, get_execution_device, + get_head_dim, + get_num_attn_heads, + get_num_kv_heads, register_offload_parameter, ) from torch.nn import Module, Parameter @@ -48,6 +52,7 @@ "initialize_module_for_quantization", "is_attention_module", "initialize_qparams", + "initialize_attn_qparams", ] @@ -81,7 +86,7 @@ def initialize_module_for_quantization( if is_attention_module(module): # quantized actions based on calltime status - _initialize_attn_scales(module) + initialize_attn_qparams(module, scheme, force_zero_point) else: if not isinstance(module, torch.nn.Linear): @@ -120,8 +125,7 @@ def initialize_module_for_quantization( force_zero_point=force_zero_point, ) - output_is_kv_cache = is_kv_cache_quant_scheme(scheme) - if scheme.output_activations is not None and not output_is_kv_cache: + if scheme.output_activations is not None: initialize_qparams( module, "output", @@ -131,14 +135,14 @@ def initialize_module_for_quantization( force_zero_point=force_zero_point, ) - module.quantization_scheme = scheme - module.quantization_status = QuantizationStatus.INITIALIZED - with disable_hf_hook(module): # wrap forward call of module to perform # quantized actions based on calltime status wrap_module_forward_quantized(module, scheme) + module.quantization_scheme = scheme + module.quantization_status = QuantizationStatus.INITIALIZED + def is_attention_module(module: Module): return "attention" in module.__class__.__name__.lower() and ( @@ -276,23 +280,74 @@ def initialize_qparams( register_offload_parameter(module, f"{base_name}_zero_point", init_zero_point) -def _initialize_attn_scales(module: Module) -> None: - """Initlaize k_scale, v_scale for self_attn""" +def initialize_attn_qparams( + module: Module, scheme: QuantizationScheme, force_zero_point: bool +): + """Initlaize k_scale, v_scale for self_attn""" - expected_shape = 1 # per tensor + impl: Optional[QuantizedAttentionImpl] = getattr(module, IMPL_ATTR, None) + kv_cache: Optional[QuantizedKVCache] = getattr(module, KV_CACHE_ATTR, None) - param = next(module.parameters()) - scale_dtype = param.dtype - device = param.device + if impl is None and kv_cache is None: + raise ValueError( + f"Attention module has quantization scheme but no {IMPL_ATTR} " + f"or {KV_CACHE_ATTR} attributes. 
Please ensure that these " + "attributes are initialized using `apply_quantization_config`." + ) - init_scale = Parameter( - torch.empty(expected_shape, dtype=scale_dtype, device=device), - requires_grad=False, - ) - register_offload_parameter(module, KVCacheScaleType.KEY.value, init_scale) + _validate_attention_scheme(scheme) + + # extract shapes from config + config = kv_cache.config + num_attn_heads = get_num_attn_heads(config) + num_kv_heads = get_num_kv_heads(config) + head_dim = get_head_dim(config) + + # (batch_size, num_heads, slen, head_dim) + q_observed_shape = (num_attn_heads, None, head_dim) + kv_observed_shape = (num_kv_heads, None, head_dim) + observed_dtype = next(module.parameters()).dtype + + if impl is not None: + initialize_qparams( + module, + "q", + scheme.input_activations, + observed_shape=q_observed_shape, + observed_dtype=observed_dtype, + force_zero_point=force_zero_point, + ) - init_scale = Parameter( - torch.empty(expected_shape, dtype=scale_dtype, device=device), - requires_grad=False, - ) - register_offload_parameter(module, KVCacheScaleType.VALUE.value, init_scale) + if kv_cache is not None: + initialize_qparams( + module, + "k", + scheme.input_activations, + observed_shape=kv_observed_shape, + observed_dtype=observed_dtype, + force_zero_point=force_zero_point, + ) + initialize_qparams( + module, + "v", + scheme.input_activations, + observed_shape=kv_observed_shape, + observed_dtype=observed_dtype, + force_zero_point=force_zero_point, + ) + + +def _validate_attention_scheme(scheme: QuantizationScheme): + if scheme.weights is not None: + raise ValueError( + "Cannot apply weight quantization to attention. " + "Instead, target the (q|k|v)_proj submodule layers of attention" + ) + + if scheme.input_activations is None: + raise ValueError( + "Cannot apply attention quantization without specifying input activations" + ) + + if scheme.output_activations is not None: + raise ValueError("Cannot apply output quantization to attention") diff --git a/src/compressed_tensors/quantization/quant_config.py b/src/compressed_tensors/quantization/quant_config.py index 35dad998..bed6078f 100644 --- a/src/compressed_tensors/quantization/quant_config.py +++ b/src/compressed_tensors/quantization/quant_config.py @@ -12,8 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
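`initialize_attn_qparams` above only accepts schemes that express attention quantization through input activations; weight or output-activation settings on an attention module are rejected by `_validate_attention_scheme`. A short sketch of a scheme that passes validation, together with the observed shapes the new path derives from the model config (illustrative only; the 8-bit float tensor args and the `LlamaAttention` target mirror this PR's tests, and the checkpoint is borrowed from them as well):

from transformers import AutoConfig

from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme
from compressed_tensors.utils import get_head_dim, get_num_attn_heads, get_num_kv_heads

# valid: attention quantization is expressed via input activations only;
# weight/output quantization must instead target the (q|k|v)_proj Linear submodules
scheme = QuantizationScheme(
    targets=["LlamaAttention"],
    input_activations=QuantizationArgs(num_bits=8, type="float", strategy="tensor"),
)

# observed q/k/v shapes are (num_heads, seq_len, head_dim), with seq_len left dynamic
config = AutoConfig.from_pretrained("nm-testing/llama2.c-stories15M")
q_observed_shape = (get_num_attn_heads(config), None, get_head_dim(config))
kv_observed_shape = (get_num_kv_heads(config), None, get_head_dim(config))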
+from collections import defaultdict from enum import Enum -from typing import Annotated, Any, Dict, List, Optional, Union +from typing import Annotated, Any, Dict, List, Optional, Set, Union from compressed_tensors.config import CompressionFormat from compressed_tensors.quantization.quant_args import DynamicType, QuantizationArgs @@ -21,11 +22,7 @@ QuantizationScheme, preset_name_to_scheme, ) -from compressed_tensors.quantization.utils import ( - is_module_quantized, - module_type, - parse_out_kv_cache_args, -) +from compressed_tensors.quantization.utils import is_module_quantized, module_type from pydantic import BaseModel, ConfigDict, Field from torch.nn import Module @@ -174,42 +171,64 @@ def from_pretrained( :param model: model to calculate quantization scheme of :return: filled out QuantizationScheme for the input model """ - quant_scheme_to_layers = [] - quantization_status = None - ignore = {} - quantization_type_names = set() + from compressed_tensors.modeling import IMPL_ATTR + from compressed_tensors.quantization.lifecycle.initialize import ( + is_attention_module, + ) + + # set of all quantization schemes + # TODO: make quant config/scheme/args frozen/hashable and use a set + quantization_schemes: List[QuantizationScheme] = list() + + # use any status from modules (in practice, use the last module) + model_status = None + + # set of all quantized types + # this is later used to create the ignore list + quantization_type_names: Set[str] = set() + + # maps types to names which are not quantized + # this is later used to create the ignore list + ignore: Dict[str, List[str]] = defaultdict(list) + + # this keeps track of any kvcache schemes + kv_cache_scheme: Optional[QuantizationArgs] = None + for name, submodule in model.named_modules(): - layer_type = module_type(submodule) - if not is_module_quantized(submodule): - if layer_type not in ignore: - ignore[layer_type] = [] - ignore[layer_type].append(name) - else: - if hasattr(submodule, "quantization_status"): - quantization_status = submodule.quantization_status - scheme = submodule.quantization_scheme + layer_type: str = module_type(submodule) + + # add config group if quantized non-attention or attention quant + has_config_group = is_module_quantized(submodule) and ( + not is_attention_module(submodule) or hasattr(submodule, IMPL_ATTR) + ) + # only add kvcache if quant attention (which always implies kvcache) + has_kv_cache = is_module_quantized(submodule) and is_attention_module( + submodule + ) + + if has_config_group: + # add to running set of schemes/layer_type_names + model_status = getattr(submodule, "quantization_status", model_status) quantization_type_names.add(layer_type) + if submodule.quantization_scheme not in quantization_schemes: + quantization_schemes.append(submodule.quantization_scheme) - match_found = False - for existing_scheme in quant_scheme_to_layers: - if scheme == existing_scheme: - match_found = True - break - if not match_found: - quant_scheme_to_layers.append(scheme) + if has_kv_cache: + model_status = getattr(submodule, "quantization_status", model_status) + kv_cache_scheme = submodule.quantization_scheme.input_activations - if len(quant_scheme_to_layers) == 0: # No quantized layers - return None + if not has_config_group: + # add non-quantized layers to the ignore list + if layer_type not in ignore: + ignore[layer_type] = [] + ignore[layer_type].append(name) - # kv-cache only, no weight/activation quantization if ( - len(quantization_type_names) == 1 - and "attention" in 
list(quantization_type_names)[0].lower() - ): - quantization_type_names.add("Linear") + len(quantization_schemes) == 0 and kv_cache_scheme is None + ): # No quantized layers + return None - # clean up ignore list, we can leave out layers types if none of the - # instances are quantized + # create ignore list, only include layers whose class has ever been targeted consolidated_ignore = [] for layer_type, ignore_names in ignore.items(): if layer_type in quantization_type_names: @@ -218,20 +237,15 @@ def from_pretrained( # else we leave it off the ignore list, doesn't fall under any of the # existing quantization schemes so it won't be quantized - kv_cache_args, quant_scheme_to_layers = parse_out_kv_cache_args( - quant_scheme_to_layers - ) - kv_cache_scheme = ( - kv_cache_args.model_dump() if kv_cache_args is not None else kv_cache_args - ) - + # create config groups from all unique schemes config_groups = {} - for idx, scheme in enumerate(quant_scheme_to_layers): + for idx, scheme in enumerate(quantization_schemes): group_name = "group_" + str(idx) config_groups[group_name] = scheme + # infer format if format is None: - if quantization_status == QuantizationStatus.COMPRESSED: + if model_status == QuantizationStatus.COMPRESSED: format = CompressionFormat.int_quantized.value else: format = CompressionFormat.dense.value @@ -244,7 +258,7 @@ def from_pretrained( return QuantizationConfig( config_groups=config_groups, - quantization_status=quantization_status, + quantization_status=model_status, kv_cache_scheme=kv_cache_scheme, global_compression_ratio=None, format=format, diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py index fccd677c..0099b088 100644 --- a/src/compressed_tensors/quantization/utils/helpers.py +++ b/src/compressed_tensors/quantization/utils/helpers.py @@ -14,7 +14,7 @@ import logging import math -from typing import Generator, List, Optional, Tuple +from typing import Generator, Optional, Tuple import torch from compressed_tensors.quantization.quant_args import ( @@ -38,7 +38,6 @@ "module_type", "get_torch_bit_depth", "can_quantize", - "parse_out_kv_cache_args", "KV_CACHE_TARGETS", "is_kv_cache_quant_scheme", "iter_named_leaf_modules", @@ -391,6 +390,7 @@ def can_quantize(value: torch.Tensor, quant_args: "QuantizationArgs") -> bool: return bit_depth > quant_args.num_bits +@deprecated() def is_kv_cache_quant_scheme(scheme: QuantizationScheme) -> bool: """ Check whether the QuantizationScheme targets the kv cache. @@ -411,37 +411,6 @@ def is_kv_cache_quant_scheme(scheme: QuantizationScheme) -> bool: return False -def parse_out_kv_cache_args( - quant_scheme_to_layers: List[QuantizationScheme], -) -> Tuple[Optional[QuantizationArgs], List[QuantizationScheme]]: - """ - If possible, parse out the kv cache specific QuantizationArgs - from the list of the QuantizationSchemes. 
If no kv cache - specific QuantizationArgs available, this function acts - as an identity function - - :param quant_scheme_to_layers: list of QuantizationSchemes - :return: kv_cache_args (optional) and the (remaining or original) - list of the QuantizationSchemes - """ - kv_cache_quant_scheme_to_layers = [ - scheme for scheme in quant_scheme_to_layers if is_kv_cache_quant_scheme(scheme) - ] - quant_scheme_to_layers = [ - scheme - for scheme in quant_scheme_to_layers - if not is_kv_cache_quant_scheme(scheme) - ] - - if kv_cache_quant_scheme_to_layers: - kv_cache_quant_scheme_to_layers = kv_cache_quant_scheme_to_layers[0] - kv_cache_args = kv_cache_quant_scheme_to_layers.output_activations - else: - kv_cache_args = None - - return kv_cache_args, quant_scheme_to_layers - - def generate_gparam( updated_min_val: torch.Tensor, updated_max_val: torch.Tensor, diff --git a/src/compressed_tensors/utils/helpers.py b/src/compressed_tensors/utils/helpers.py index bdaa40c0..7649f0d0 100644 --- a/src/compressed_tensors/utils/helpers.py +++ b/src/compressed_tensors/utils/helpers.py @@ -20,7 +20,7 @@ import numpy import torch -from transformers import AutoConfig +from transformers import AutoConfig, PretrainedConfig T = TypeVar("T", bound="Callable") # used by `deprecated` @@ -45,6 +45,9 @@ "unpack_bitmasks", "patch_attr", "ParameterizedDefaultDict", + "get_num_attn_heads", + "get_num_kv_heads", + "get_head_dim", ] FSDP_WRAPPER_NAME = "_fsdp_wrapped_module" @@ -396,3 +399,62 @@ def get(self, *args, factory_kwargs: Mapping = MappingProxyType({})) -> Any: """ with patch_attr(self, "_factory_kwargs", factory_kwargs): return self[args] + + +def get_num_attn_heads(config: PretrainedConfig) -> int: + """ + Get the number of attention heads used by a model + + :param config: model config + :return: num_attention_heads of model + """ + if hasattr(config, "num_attention_heads"): + return config.num_attention_heads + + elif hasattr(config, "hidden_size") and hasattr(config, "head_dim"): + return config.hidden_size // config.head_dim + + else: + raise ValueError( + "Cannot determine num_attention_heads from config. Config must define " + "either `num_attention_heads` or both `hidden_size` and `head_dim`. " + f"{config}" + ) + + +def get_num_kv_heads(config: PretrainedConfig) -> int: + """ + Get the number of key-value attention heads used by a model + + :param config: model config + :return: num_key_value_heads of model + """ + if hasattr(config, "num_key_value_heads"): + return config.num_key_value_heads + + else: + raise ValueError( + "Cannot determine num_key_value_heads from config. Config must define " + f"`num_key_value_heads`. {config}" + ) + + +def get_head_dim(config: PretrainedConfig) -> int: + """ + Get the number of dimensions used by the attention heads of a model + + :param config: model config + :return: head_dim of model + """ + if hasattr(config, "head_dim"): + return config.head_dim + + elif hasattr(config, "hidden_size") and hasattr(config, "num_attention_heads"): + return config.hidden_size // config.num_attention_heads + + else: + raise ValueError( + "Cannot determine head_dim from config. Config must define " + "either `head_dim` or both `hidden_size` and `num_attention_heads`. 
" + f"{config}" + ) diff --git a/src/compressed_tensors/utils/match.py b/src/compressed_tensors/utils/match.py index b96e83d0..f26400b0 100644 --- a/src/compressed_tensors/utils/match.py +++ b/src/compressed_tensors/utils/match.py @@ -30,6 +30,7 @@ "match_targets", "match_modules_set", "is_match", + "is_narrow_match", ] @@ -260,6 +261,34 @@ def is_match( ) +def is_narrow_match( + model: torch.nn.Module, + targets: Union[str, Iterable[str]], + name: str, + module: Optional[torch.nn.Module] = None, +) -> bool: + """ + Checks if any of the targets narrowly match the module. A target narrowly matches + a module if the target matches the module, but does not match the module's parent + + :param model: model containing both module and its parent + :param targets: target strings, potentially containing "re:" prefixes + :param name: name of module to match + :param module: module to match. If none is provided, then get module from model + :return: True if any of the targets narrow match the module + """ + targets = [targets] if isinstance(targets, str) else targets + module = module if module is not None else model.get_submodule(name) + + parent_name = name.rsplit(".", 1)[0] + parent = model.get_submodule(parent_name) + + return any( + is_match(name, module, target) and not is_match(parent_name, parent, target) + for target in targets + ) + + def _match_name(name: str, target: str, fused: Optional[FusedMappping] = None) -> bool: """ Returns true if target string begins with "re:" and regex matches or if target diff --git a/tests/test_modeling/test_attention_and_cache.py b/tests/test_modeling/test_attention_and_cache.py new file mode 100644 index 00000000..230b8f3e --- /dev/null +++ b/tests/test_modeling/test_attention_and_cache.py @@ -0,0 +1,108 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import torch +from compressed_tensors.modeling import ( + IMPL_ATTR, + KV_CACHE_ATTR, + QuantizedAttentionImpl, + QuantizedKVCache, + initialize_hooked_attention, + initialize_hooked_kv_cache, + register_key_hook, + register_query_hook, + register_value_hook, +) +from tests.testing_utils import requires_gpu +from transformers import AutoModelForCausalLM + + +@requires_gpu +def test_attention_cache(): + model = AutoModelForCausalLM.from_pretrained( + "nm-testing/llama2.c-stories15M", device_map="cuda" + ) + inputs = {key: value.to("cuda") for key, value in model.dummy_inputs.items()} + true_outputs = model(**inputs) + layers = model.model.layers + + # check if hooks work + k_called = [False for _ in range(len(layers))] + v_called = [False for _ in range(len(layers))] + + # apply kv cache quantization + _apply_kv_cache(model, layers, k_called, v_called) + + # check kv cache quantization + outputs = model(**inputs) + assert torch.equal(outputs.logits, true_outputs.logits) + assert all(k_called) and all(v_called) + + """ apply attention quantization after kv cache quantization """ + + # check if hooks work + q_called = [False for _ in range(len(layers))] + k_called = [False for _ in range(len(layers))] + v_called = [False for _ in range(len(layers))] + + # apply attention quantization + _apply_attention(model, layers, q_called, k_called, v_called) + + # check attention quantization + outputs = model(**inputs) + assert torch.equal(outputs.logits, true_outputs.logits) + assert all(q_called) and all(k_called) and all(v_called) + + +def _apply_kv_cache(model, layers, k_called, v_called): + for layer_index, layer in enumerate(layers): + module = layer.self_attn + initialize_hooked_kv_cache(model, module) + assert isinstance(getattr(module, KV_CACHE_ATTR), QuantizedKVCache) + + # reapply is no-op + initialize_hooked_kv_cache(model, module) + + def k_hook(_module, _states, layer_index=layer_index): # NOTE: capture by value + k_called[layer_index] = True + + def v_hook(_module, _states, layer_index=layer_index): + my_index = layer_index + v_called[my_index] = True + + register_key_hook(module, k_hook) + register_value_hook(module, v_hook) + + +def _apply_attention(model, layers, q_called, k_called, v_called): + for layer_index, layer in enumerate(layers): + module = layer.self_attn + initialize_hooked_attention(model, module) + assert isinstance(getattr(module, IMPL_ATTR), QuantizedAttentionImpl) + + # reapply is no-op + initialize_hooked_attention(model, module) + + def q_hook(_module, _states, layer_index=layer_index): + q_called[layer_index] = True + + def k_hook(_module, _states, layer_index=layer_index): + k_called[layer_index] = True + + def v_hook(_module, _states, layer_index=layer_index): + v_called[layer_index] = True + + register_query_hook(module, q_hook) + register_key_hook(module, k_hook) + register_value_hook(module, v_hook) diff --git a/tests/test_quantization/lifecycle/test_apply.py b/tests/test_quantization/lifecycle/test_apply.py index ae890820..d56148c5 100644 --- a/tests/test_quantization/lifecycle/test_apply.py +++ b/tests/test_quantization/lifecycle/test_apply.py @@ -131,6 +131,64 @@ def test_apply_quantization_config_tinyllama(): ) +@pytest.mark.parametrize( + "config", + [ + QuantizationConfig( + config_groups={ + "linear": QuantizationScheme( + targets=["Linear"], + input_activations=QuantizationArgs( + num_bits=8, type="float", strategy="tensor" + ), + ) + } + ), + QuantizationConfig( + config_groups={ + "linear": QuantizationScheme( + targets=["Linear"], + 
input_activations=QuantizationArgs( + num_bits=8, type="float", strategy="tensor" + ), + ) + }, + ignore=[ + "model.layers.0.self_attn.q_proj", + "model.layers.1.self_attn.k_proj", + "model.layers.2.self_attn.v_proj", + ], + ), + QuantizationConfig( + config_groups={}, + kv_cache_scheme=QuantizationArgs( + num_bits=8, type="float", strategy="tensor" + ), + ), + QuantizationConfig( + config_groups={ + "attention": QuantizationScheme( + targets=["LlamaAttention"], + input_activations=QuantizationArgs( + num_bits=8, type="float", strategy="tensor" + ), + ) + }, + kv_cache_scheme=QuantizationArgs( + num_bits=8, type="float", strategy="tensor" + ), + ), + ], +) +def test_from_pretrained(config: QuantizationConfig): + model = AutoModelForCausalLM.from_pretrained("nm-testing/llama2.c-stories15M") + apply_quantization_config(model, config) + _config = QuantizationConfig.from_pretrained(model) + assert list(_config.config_groups.values()) == list(config.config_groups.values()) + assert _config.kv_cache_scheme == config.kv_cache_scheme + assert _config.ignore == config.ignore + + def test_serialize_config_tinyllama(): quant_config = get_sample_tinyllama_quant_config() model = get_tinyllama_model() @@ -366,3 +424,43 @@ def test_multi_apply_quantization_config(): weight_zero_point is not None and weight_zero_point.shape == torch.Size([1]) ) + + +@requires_accelerate() +def test_apply_kv_cache(): + from accelerate import init_empty_weights + + with init_empty_weights(): + model = AutoModelForCausalLM.from_pretrained("nm-testing/llama2.c-stories15M") + + args = QuantizationArgs(num_bits=8, type="float", strategy="tensor") + config = QuantizationConfig(config_groups={}, kv_cache_scheme=args) + + apply_quantization_config(model, config) + + for layer in model.model.layers: + assert getattr(layer.self_attn, "quantization_scheme").input_activations == args + assert hasattr(layer.self_attn, "k_scale") + assert hasattr(layer.self_attn, "v_scale") + + +@requires_accelerate() +def test_apply_attention(): + from accelerate import init_empty_weights + + with init_empty_weights(): + model = AutoModelForCausalLM.from_pretrained("nm-testing/llama2.c-stories15M") + + scheme = QuantizationScheme( + targets=["LlamaAttention"], + input_activations=QuantizationArgs(num_bits=8, type="float", strategy="tensor"), + ) + config = QuantizationConfig(config_groups={"attention": scheme}) + + apply_quantization_config(model, config) + + for layer in model.model.layers: + assert getattr(layer.self_attn, "quantization_scheme") == scheme + assert hasattr(layer.self_attn, "q_scale") + assert hasattr(layer.self_attn, "k_scale") + assert hasattr(layer.self_attn, "v_scale") diff --git a/tests/test_utils/test_match.py b/tests/test_utils/test_match.py index c85d2f43..1129120c 100644 --- a/tests/test_utils/test_match.py +++ b/tests/test_utils/test_match.py @@ -21,6 +21,7 @@ from compressed_tensors.utils import ( InternalModule, is_match, + is_narrow_match, match_modules_set, match_named_modules, match_named_parameters, @@ -500,6 +501,86 @@ class InternalLinear(InternalModule, nn.Linear): assert len(matches) == 0 +class TestIsNarrowMatch: + def test_narrow_match_true_child_only(self): + """ + Target matches the child module name but NOT its parent name. + Should return True. 
+ """ + model = DummyModel() + name = "transformer.layers.0.self_attn.q_proj" + # Matches "...q_proj" but not "...self_attn" + target = r"re:.*q_proj$" + + assert is_narrow_match(model, target, name) + + def test_narrow_match_false_when_parent_also_matches(self): + """ + Target matches both the child and its parent name. + Should return False because it's not a 'narrow' match. + """ + model = DummyModel() + name = "transformer.layers.0.self_attn.q_proj" + # Broad target that also matches the parent "transformer.layers.0.self_attn" + target = r"re:transformer\.layers\.0\..*" + + assert not is_narrow_match(model, target, name) + + def test_narrow_match_false_when_neither_matches(self): + """ + Target matches neither the child nor the parent. + Should return False. + """ + model = DummyModel() + name = "transformer.layers.0.self_attn.q_proj" + target = r"re:this_does_not_exist$" + + assert not is_narrow_match(model, target, name) + + def test_narrow_match_iterable_targets_any_true(self): + """ + With multiple targets: if any target narrowly matches the child, + the function should return True. + """ + model = DummyModel() + name = "transformer.layers.0.self_attn.q_proj" + # First target is broad (matches both child & parent -> narrow False), + # second target is narrow (matches child only -> narrow True). + targets = [ + r"re:transformer\.layers\.0\..*", + r"re:.*q_proj$", + ] + + assert is_narrow_match(model, targets, name) + + def test_narrow_match_with_explicit_module_argument(self): + """ + Passing the module explicitly should behave the same as when it's + retrieved from the model by name. + """ + model = DummyModel() + name = "transformer.layers.0.self_attn.q_proj" + module = model.get_submodule(name) + target = r"re:.*q_proj$" + + # Both ways should be True + assert is_narrow_match(model, target, name) + assert is_narrow_match(model, target, name, module=module) + + def test_narrow_match_top_level_behavior_documented(self): + """ + (Behavior check) For a top-level module name without a dot, the current + implementation derives parent_name == name, so parent==child. + Then 'narrow' cannot be True because parent match mirrors child match. + This test documents current behavior to guard against regressions. + """ + model = DummyModel() + name = "layer1" # top-level module in DummyModel + target = r"re:^layer1$" + + assert not is_narrow_match(model, target, name) + + class TestIntegration: """Integration tests combining multiple functions"""
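To close the loop on how the pieces in this patch compose, the sketch below applies a config that combines an attention config group with a `kv_cache_scheme`, then checks the parameters the new initialization path registers. It mirrors `test_apply_attention`, `test_apply_kv_cache`, and `test_from_pretrained` above; the checkpoint and argument values are taken from those tests and are illustrative rather than prescriptive:

from transformers import AutoModelForCausalLM

from compressed_tensors.quantization import (
    QuantizationArgs,
    QuantizationConfig,
    QuantizationScheme,
    apply_quantization_config,
)

model = AutoModelForCausalLM.from_pretrained("nm-testing/llama2.c-stories15M")

args = QuantizationArgs(num_bits=8, type="float", strategy="tensor")
config = QuantizationConfig(
    config_groups={
        "attention": QuantizationScheme(
            targets=["LlamaAttention"], input_activations=args
        )
    },
    kv_cache_scheme=args,
)
apply_quantization_config(model, config)

for layer in model.model.layers:
    attn = layer.self_attn
    assert hasattr(attn, "q_scale")  # registered for the attention config group
    assert hasattr(attn, "k_scale") and hasattr(attn, "v_scale")  # kv cache scheme

# the applied settings round-trip through config serialization
assert QuantizationConfig.from_pretrained(model).kv_cache_scheme == config.kv_cache_scheme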