@@ -21,21 +21,26 @@
 
 import torch
 from compressed_tensors.config import CompressionFormat
+from compressed_tensors.modeling import (
+    initialize_hooked_attention,
+    initialize_hooked_kv_cache,
+)
 from compressed_tensors.quantization.lifecycle.initialize import (
     initialize_module_for_quantization,
+    is_attention_module,
 )
 from compressed_tensors.quantization.quant_args import QuantizationArgs
 from compressed_tensors.quantization.quant_config import (
     QuantizationConfig,
     QuantizationStatus,
 )
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
-from compressed_tensors.quantization.utils import (
-    KV_CACHE_TARGETS,
-    is_kv_cache_quant_scheme,
-)
 from compressed_tensors.utils.helpers import deprecated, replace_module
-from compressed_tensors.utils.match import match_named_modules, match_targets
+from compressed_tensors.utils.match import (
+    is_narrow_match,
+    match_named_modules,
+    match_targets,
+)
 from compressed_tensors.utils.offload import update_parameter_data
 from compressed_tensors.utils.safetensors_load import get_safetensors_folder
 from safetensors import safe_open
@@ -126,8 +131,24 @@ def apply_quantization_config( |
     if config is None:  # see PR #180
         return dict()
 
-    # preprocess to support kv cache scheme
-    config = process_quantization_config(config)
+    # force zero points during initialization
+    force_zero_point = config.quantization_status != QuantizationStatus.COMPRESSED
+
+    # apply kv cache quantization before any attention quantization
+    # because attention quantization is a superset of kv cache quantization
+    if config.kv_cache_scheme is not None:
+        scheme = QuantizationScheme(
+            targets=".*self_attn$", input_activations=config.kv_cache_scheme
+        )
+        for submodule in model.modules():
+            if is_attention_module(submodule):
+                submodule.quantization_scheme = scheme
+                initialize_hooked_kv_cache(model, submodule)
+                initialize_module_for_quantization(
+                    submodule,
+                    force_zero_point=force_zero_point,
+                )
+                submodule.quantization_status = config.quantization_status
 
     # build mapping of targets to schemes for easier matching
     # use ordered dict to preserve target ordering in config
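
As a rough sketch of what the new kv cache branch consumes (the argument values are illustrative assumptions, and the remaining QuantizationConfig fields are assumed to keep their defaults):

# Illustrative sketch only; field values are assumptions, not taken from this diff.
from compressed_tensors.quantization.quant_args import QuantizationArgs
from compressed_tensors.quantization.quant_config import QuantizationConfig

config = QuantizationConfig(
    config_groups={},  # no ordinary groups are needed to exercise the kv cache path
    kv_cache_scheme=QuantizationArgs(num_bits=8, type="float", symmetric=True),
)
# apply_quantization_config(model, config) now attaches an input-activation
# scheme targeting ".*self_attn$" to every attention module, registers the
# hooked kv cache, and initializes the module for quantization up front,
# instead of rewriting the config through the removed helpers below.
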
@@ -163,51 +184,19 @@ def apply_quantization_config( |
             replace_module(model, name, compressed_linear)
 
         else:
+            if is_attention_module(submodule) and is_narrow_match(
+                model, scheme.targets, name
+            ):
+                initialize_hooked_attention(model, submodule)
+
             initialize_module_for_quantization(
                 submodule,
-                force_zero_point=config.quantization_status
-                != QuantizationStatus.COMPRESSED,
+                force_zero_point=force_zero_point,
             )
 
         submodule.quantization_status = config.quantization_status
 
 
-def process_quantization_config(config: QuantizationConfig) -> QuantizationConfig:
-    """
-    Preprocess the raw QuantizationConfig
-
-    :param config: the raw QuantizationConfig
-    :return: the processed QuantizationConfig
-    """
-    if config.kv_cache_scheme is not None:
-        config = process_kv_cache_config(config)
-
-    return config
-
-
-def process_kv_cache_config(
-    config: QuantizationConfig, targets: Union[List[str], str] = KV_CACHE_TARGETS
-) -> QuantizationConfig:
-    """
-    Reformulate the `config.kv_cache` as a `config_group`
-    and add it to the set of existing `config.groups`
-
-    :param config: the QuantizationConfig
-    :return: the QuantizationConfig with additional "kv_cache" group
-    """
-    if targets == KV_CACHE_TARGETS:
-        _LOGGER.info(f"KV cache targets set to default value of: {KV_CACHE_TARGETS}")
-
-    kv_cache_dict = config.kv_cache_scheme.model_dump()
-    kv_cache_scheme = QuantizationScheme(
-        output_activations=QuantizationArgs(**kv_cache_dict),
-        targets=targets,
-    )
-    kv_cache_group = dict(kv_cache=kv_cache_scheme)
-    config.config_groups.update(kv_cache_group)
-    return config
-
-
 @deprecated(
     message="This function is deprecated and will be removed in a future release."
     "Please use `match_targets` from `compressed_tensors.utils.match` instead."
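
For attention modules that a config group targets directly, the loop's else branch now installs attention hooks before ordinary initialization. A hedged illustration of the kind of group that would take this path (the target pattern and argument values are assumptions; the exact semantics of is_narrow_match live in compressed_tensors.utils.match, not in this hunk):

# Assumed illustration, not taken from this diff.
from compressed_tensors.quantization.quant_args import QuantizationArgs
from compressed_tensors.quantization.quant_scheme import QuantizationScheme

attn_group = QuantizationScheme(
    targets=["re:.*self_attn$"],  # target the attention module itself, not its Linear children
    input_activations=QuantizationArgs(num_bits=8, type="float"),
)
# When such a group narrowly matches a module like "model.layers.0.self_attn",
# initialize_hooked_attention(model, submodule) wraps the attention call with
# quantization hooks before initialize_module_for_quantization runs, replacing
# the removed process_kv_cache_config approach of reformulating the kv cache
# scheme as output_activations on KV_CACHE_TARGETS.
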
@@ -282,60 +271,6 @@ def _scheme_from_targets( |
     targets: List[str],
     name: str,
 ) -> QuantizationScheme:
-    if len(targets) == 1:
-        # if `targets` iterable contains a single element
-        # use it as the key
-        return target_to_scheme[targets[0]]
-
-    # otherwise, we need to merge QuantizationSchemes corresponding
-    # to multiple targets. This is most likely because `name` module
-    # is being target both as an ordinary quantization target, as well
-    # as kv cache quantization target
-    schemes_to_merge = [target_to_scheme[target] for target in targets]
-    return _merge_schemes(schemes_to_merge, name)
-
-
-def _merge_schemes(
-    schemes_to_merge: List[QuantizationScheme], name: str
-) -> QuantizationScheme:
-    kv_cache_quantization_scheme = [
-        scheme for scheme in schemes_to_merge if is_kv_cache_quant_scheme(scheme)
-    ]
-    if not kv_cache_quantization_scheme:
-        # if the schemes_to_merge do not contain any
-        # kv cache QuantizationScheme
-        # return the first scheme (the prioritized one,
-        # since the order of schemes_to_merge matters)
-        return schemes_to_merge[0]
-    else:
-        # fetch the kv cache QuantizationScheme and the highest
-        # priority non-kv cache QuantizationScheme and merge them
-        kv_cache_quantization_scheme = kv_cache_quantization_scheme[0]
-        quantization_scheme = [
-            scheme
-            for scheme in schemes_to_merge
-            if not is_kv_cache_quant_scheme(scheme)
-        ][0]
-        schemes_to_merge = [kv_cache_quantization_scheme, quantization_scheme]
-        merged_scheme = {}
-        for scheme in schemes_to_merge:
-            scheme_dict = {
-                k: v for k, v in scheme.model_dump().items() if v is not None
-            }
-            # when merging multiple schemes, the final target will be
-            # the `name` argument - hence erase the original targets
-            del scheme_dict["targets"]
-            # make sure that schemes do not "clash" with each other
-            overlapping_keys = set(merged_scheme.keys()) & set(scheme_dict.keys())
-            if overlapping_keys:
-                raise ValueError(
-                    f"The module: {name} is being modified by two clashing "
-                    f"quantization schemes, that jointly try to override "
-                    f"properties: {overlapping_keys}. Fix the quantization config "
-                    "so that it is not ambiguous."
-                )
-            merged_scheme.update(scheme_dict)
-
-        merged_scheme.update(targets=[name])
-
-        return QuantizationScheme(**merged_scheme)
+    # return the first scheme (the prioritized one,
+    # since the order of target_to_scheme matters)
+    return target_to_scheme[targets[0]]
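
Because the kv cache scheme is no longer folded into config_groups, a module can no longer be targeted by both a kv cache group and an ordinary group, so the scheme-merging logic is gone and the first matching target simply wins. A hypothetical illustration of the resulting priority rule (module names and schemes are made up):

# Hypothetical example of the simplified priority behavior.
from compressed_tensors.quantization.quant_scheme import QuantizationScheme

scheme_a = QuantizationScheme(targets=["re:.*q_proj$"])
scheme_b = QuantizationScheme(targets=["Linear"])
target_to_scheme = {"re:.*q_proj$": scheme_a, "Linear": scheme_b}

# _scheme_from_targets(target_to_scheme, ["re:.*q_proj$", "Linear"], "layers.0.self_attn.q_proj")
# now returns scheme_a, the scheme of the first (highest-priority) matching target.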