
Commit fbeaeba

apply
Signed-off-by: Kyle Sayers <[email protected]>
1 parent caa1ecf commit fbeaeba

File tree

7 files changed: +102 −200 lines changed


src/compressed_tensors/modeling/attention.py

Lines changed: 4 additions & 8 deletions
@@ -17,12 +17,7 @@
 from weakref import ref
 
 from compressed_tensors.modeling.kvcache import initialize_hooked_kv_cache
-from compressed_tensors.quantization import (
-    QuantizationArgs,
-    QuantizationScheme,
-    QuantizationStrategy,
-    forward_quantize,
-)
+from compressed_tensors.quantization.lifecycle.forward import forward_quantize
 from compressed_tensors.utils import getattr_chain
 from compressed_tensors.utils.internal import InternalModule
 from torch import Tensor
@@ -60,11 +55,12 @@ class QuantizedAttentionImpl(InternalModule):
     :param attn_module: parent attention module
     """
 
+    _original_impl = "eager"
+
     def __init__(self, config: PretrainedConfig, attn_module: Module):
         super().__init__()
         self.config = config
         self.attn_module = ref(attn_module)  # avoid circular references
-        self._qparams_initialized = False
 
     def forward(
         self,
@@ -79,7 +75,7 @@ def forward(
         quant_args_attr = "quantization_scheme.input_activations"
         quant_args = getattr_chain(module, quant_args_attr, None)
         quant_enabled = getattr(module, "quantization_enabled", True)
-        if quant_args is not None and quant_enabled and self._qparams_initialized:
+        if quant_args is not None and quant_enabled:
             query = forward_quantize(module, query, "q", quant_args)
 
         # original attention
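
Note on this change: attention input quantization is now gated only on whether the module carries an input-activation scheme and `quantization_enabled` has not been set to False; the per-impl `_qparams_initialized` flag is removed. A minimal, self-contained sketch of that gating (toy objects, not the library's `getattr_chain`):

from types import SimpleNamespace

def should_quantize(module) -> bool:
    # stand-in for getattr_chain(module, "quantization_scheme.input_activations", None)
    scheme = getattr(module, "quantization_scheme", None)
    quant_args = getattr(scheme, "input_activations", None) if scheme is not None else None
    quant_enabled = getattr(module, "quantization_enabled", True)
    # no _qparams_initialized check anymore
    return quant_args is not None and quant_enabled

attn = SimpleNamespace(quantization_scheme=SimpleNamespace(input_activations="int8 args"))
print(should_quantize(attn))   # True: scheme attached, not disabled
attn.quantization_enabled = False
print(should_quantize(attn))   # False: explicitly disabled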

src/compressed_tensors/modeling/kvcache.py

Lines changed: 2 additions & 6 deletions
@@ -16,10 +16,7 @@
 from typing import Callable, Optional, Tuple
 from weakref import ref
 
-# from compressed_tensors.quantization import QuantizationStrategy, forward_quantize
-# from compressed_tensors.quantization.lifecycle.initialize import (
-#     _initialize_scale_zero_point,
-# )
+from compressed_tensors.quantization.lifecycle.forward import forward_quantize
 from compressed_tensors.utils import getattr_chain
 from compressed_tensors.utils.internal import InternalModule
 from torch import Tensor
@@ -59,7 +56,6 @@ def __init__(self, config: PretrainedConfig, attn_module: Module):
         self.config = config
         self.attn_module = ref(attn_module)  # avoid circular reference
         self.past_key_values: Optional[Cache] = None
-        self._qparams_initialized = False
 
     def update(self, *args, **kwargs) -> Tuple[Tensor, Tensor]:
         return self(*args, **kwargs)
@@ -76,7 +72,7 @@ def forward(
         quant_args_attr = "quantization_scheme.input_activations"
         quant_args = getattr_chain(module, quant_args_attr, None)
         quant_enabled = getattr(module, "quantization_enabled", True)
-        if quant_args is not None and quant_enabled and self._qparams_initialized:
+        if quant_args is not None and quant_enabled:
             key_states = forward_quantize(module, key_states, "k", quant_args)
             value_states = forward_quantize(module, value_states, "v", quant_args)
 
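
The hooked KV cache now applies the same simplified gate to both key and value states. As a rough illustration of what quantizing cached key/value tensors amounts to, here is a toy per-tensor int8 fake-quantize (an illustrative stand-in only, not the library's `forward_quantize`, which uses the module's `QuantizationArgs` and its own scales):

import torch

def fake_quantize_int8(x: torch.Tensor) -> torch.Tensor:
    # quantize to the int8 range with a per-tensor scale, then dequantize
    scale = x.abs().amax().clamp(min=1e-8) / 127.0
    return (x / scale).round().clamp(-127, 127) * scale

key_states = torch.randn(1, 8, 16, 64)    # (batch, heads, seq_len, head_dim)
value_states = torch.randn(1, 8, 16, 64)
key_states = fake_quantize_int8(key_states)
value_states = fake_quantize_int8(value_states)
print(key_states.shape, value_states.shape)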

src/compressed_tensors/quantization/lifecycle/apply.py

Lines changed: 37 additions & 102 deletions
@@ -21,21 +21,26 @@
 
 import torch
 from compressed_tensors.config import CompressionFormat
+from compressed_tensors.modeling import (
+    initialize_hooked_attention,
+    initialize_hooked_kv_cache,
+)
 from compressed_tensors.quantization.lifecycle.initialize import (
     initialize_module_for_quantization,
+    is_attention_module,
 )
 from compressed_tensors.quantization.quant_args import QuantizationArgs
 from compressed_tensors.quantization.quant_config import (
     QuantizationConfig,
     QuantizationStatus,
 )
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
-from compressed_tensors.quantization.utils import (
-    KV_CACHE_TARGETS,
-    is_kv_cache_quant_scheme,
-)
 from compressed_tensors.utils.helpers import deprecated, replace_module
-from compressed_tensors.utils.match import match_named_modules, match_targets
+from compressed_tensors.utils.match import (
+    is_narrow_match,
+    match_named_modules,
+    match_targets,
+)
 from compressed_tensors.utils.offload import update_parameter_data
 from compressed_tensors.utils.safetensors_load import get_safetensors_folder
 from safetensors import safe_open
@@ -126,8 +131,24 @@ def apply_quantization_config(
     if config is None:  # see PR #180
         return dict()
 
-    # preprocess to support kv cache scheme
-    config = process_quantization_config(config)
+    # force zero points during initialization
+    force_zero_point = config.quantization_status != QuantizationStatus.COMPRESSED
+
+    # apply kv cache quantization before any attention quantization
+    # because attention quantization is a superset of kv cache quantization
+    if config.kv_cache_scheme is not None:
+        scheme = QuantizationScheme(
+            targets=".*self_attn$", input_activations=config.kv_cache_scheme
+        )
+        for submodule in model.modules():
+            if is_attention_module(submodule):
+                submodule.quantization_scheme = scheme
+                initialize_hooked_kv_cache(model, submodule)
+                initialize_module_for_quantization(
+                    submodule,
+                    force_zero_point=force_zero_point,
+                )
+                submodule.quantization_status = config.quantization_status
 
     # build mapping of targets to schemes for easier matching
     # use ordered dict to preserve target ordering in config
@@ -163,51 +184,19 @@ def apply_quantization_config(
             replace_module(model, name, compressed_linear)
 
         else:
+            if is_attention_module(submodule) and is_narrow_match(
+                model, scheme.targets, name
+            ):
+                initialize_hooked_attention(model, submodule)
+
             initialize_module_for_quantization(
                 submodule,
-                force_zero_point=config.quantization_status
-                != QuantizationStatus.COMPRESSED,
+                force_zero_point=force_zero_point,
             )
 
         submodule.quantization_status = config.quantization_status
 
 
-def process_quantization_config(config: QuantizationConfig) -> QuantizationConfig:
-    """
-    Preprocess the raw QuantizationConfig
-
-    :param config: the raw QuantizationConfig
-    :return: the processed QuantizationConfig
-    """
-    if config.kv_cache_scheme is not None:
-        config = process_kv_cache_config(config)
-
-    return config
-
-
-def process_kv_cache_config(
-    config: QuantizationConfig, targets: Union[List[str], str] = KV_CACHE_TARGETS
-) -> QuantizationConfig:
-    """
-    Reformulate the `config.kv_cache` as a `config_group`
-    and add it to the set of existing `config.groups`
-
-    :param config: the QuantizationConfig
-    :return: the QuantizationConfig with additional "kv_cache" group
-    """
-    if targets == KV_CACHE_TARGETS:
-        _LOGGER.info(f"KV cache targets set to default value of: {KV_CACHE_TARGETS}")
-
-    kv_cache_dict = config.kv_cache_scheme.model_dump()
-    kv_cache_scheme = QuantizationScheme(
-        output_activations=QuantizationArgs(**kv_cache_dict),
-        targets=targets,
-    )
-    kv_cache_group = dict(kv_cache=kv_cache_scheme)
-    config.config_groups.update(kv_cache_group)
-    return config
-
-
 @deprecated(
     message="This function is deprecated and will be removed in a future release."
     "Please use `match_targets` from `compressed_tensors.utils.match` instead."
@@ -282,60 +271,6 @@ def _scheme_from_targets(
     targets: List[str],
     name: str,
 ) -> QuantizationScheme:
-    if len(targets) == 1:
-        # if `targets` iterable contains a single element
-        # use it as the key
-        return target_to_scheme[targets[0]]
-
-    # otherwise, we need to merge QuantizationSchemes corresponding
-    # to multiple targets. This is most likely because `name` module
-    # is being target both as an ordinary quantization target, as well
-    # as kv cache quantization target
-    schemes_to_merge = [target_to_scheme[target] for target in targets]
-    return _merge_schemes(schemes_to_merge, name)
-
-
-def _merge_schemes(
-    schemes_to_merge: List[QuantizationScheme], name: str
-) -> QuantizationScheme:
-    kv_cache_quantization_scheme = [
-        scheme for scheme in schemes_to_merge if is_kv_cache_quant_scheme(scheme)
-    ]
-    if not kv_cache_quantization_scheme:
-        # if the schemes_to_merge do not contain any
-        # kv cache QuantizationScheme
-        # return the first scheme (the prioritized one,
-        # since the order of schemes_to_merge matters)
-        return schemes_to_merge[0]
-    else:
-        # fetch the kv cache QuantizationScheme and the highest
-        # priority non-kv cache QuantizationScheme and merge them
-        kv_cache_quantization_scheme = kv_cache_quantization_scheme[0]
-        quantization_scheme = [
-            scheme
-            for scheme in schemes_to_merge
-            if not is_kv_cache_quant_scheme(scheme)
-        ][0]
-        schemes_to_merge = [kv_cache_quantization_scheme, quantization_scheme]
-    merged_scheme = {}
-    for scheme in schemes_to_merge:
-        scheme_dict = {
-            k: v for k, v in scheme.model_dump().items() if v is not None
-        }
-        # when merging multiple schemes, the final target will be
-        # the `name` argument - hence erase the original targets
-        del scheme_dict["targets"]
-        # make sure that schemes do not "clash" with each other
-        overlapping_keys = set(merged_scheme.keys()) & set(scheme_dict.keys())
-        if overlapping_keys:
-            raise ValueError(
-                f"The module: {name} is being modified by two clashing "
-                f"quantization schemes, that jointly try to override "
-                f"properties: {overlapping_keys}. Fix the quantization config "
-                "so that it is not ambiguous."
-            )
-        merged_scheme.update(scheme_dict)
-
-    merged_scheme.update(targets=[name])
-
-    return QuantizationScheme(**merged_scheme)
+    # return the first scheme (the prioritized one,
+    # since the order of target_to_scheme matters)
+    return target_to_scheme[targets[0]]
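
The net effect in apply.py: a `kv_cache_scheme` is no longer rewritten into an output-activation "kv_cache" config group and merged per module; it is attached directly to attention modules as an input-activation scheme before the grouped pass. A self-contained sketch of that preprocessing step (toy classes and a regex stand-in for `is_attention_module`; not the library's code):

import re
from dataclasses import dataclass
from typing import Optional

@dataclass
class ToyScheme:
    targets: str
    input_activations: Optional[dict] = None

@dataclass
class ToyAttention:
    name: str
    quantization_scheme: Optional[ToyScheme] = None

def apply_kv_cache_scheme(modules, kv_cache_scheme: dict) -> None:
    # build one input-activation scheme targeting attention modules
    scheme = ToyScheme(targets=".*self_attn$", input_activations=kv_cache_scheme)
    for module in modules:
        if re.search(scheme.targets, module.name):  # stand-in for is_attention_module
            module.quantization_scheme = scheme     # hooks/qparams would be initialized here

attns = [ToyAttention("model.layers.0.self_attn"), ToyAttention("model.layers.1.self_attn")]
apply_kv_cache_scheme(attns, {"num_bits": 8, "type": "int", "symmetric": True})
print(attns[0].quantization_scheme.input_activations)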

src/compressed_tensors/quantization/lifecycle/initialize.py

Lines changed: 2 additions & 7 deletions
@@ -37,11 +37,7 @@
 from compressed_tensors.quantization.lifecycle.forward import (
     wrap_module_forward_quantized,
 )
-from compressed_tensors.quantization.utils import (
-    is_fp4,
-    is_kv_cache_quant_scheme,
-    strategy_cdiv,
-)
+from compressed_tensors.quantization.utils import is_fp4, strategy_cdiv
 from compressed_tensors.utils import (
     disable_hf_hook,
     get_execution_device,
@@ -129,8 +125,7 @@ def initialize_module_for_quantization(
             force_zero_point=force_zero_point,
         )
 
-    output_is_kv_cache = is_kv_cache_quant_scheme(scheme)
-    if scheme.output_activations is not None and not output_is_kv_cache:
+    if scheme.output_activations is not None:
         initialize_qparams(
             module,
             "output",
