 from typing import Dict, List, Optional, Tuple, Union

 import torch
-from compressed_tensors.quantization import disable_quantization
+from compressed_tensors.quantization import QuantizationType, disable_quantization
 from compressed_tensors.utils import (
     align_modules,
     get_execution_device,
@@ -126,6 +126,7 @@ class AWQModifier(Modifier, QuantizationMixin):

     # Private vars set during validation
     _num_bits: Optional[int] = PrivateAttr(default=None)
+    _activation_bits: int = PrivateAttr(default=16)
     _symmetric: Optional[bool] = PrivateAttr(default=None)
     _group_size: Optional[int] = PrivateAttr(default=None)

@@ -189,6 +190,18 @@ def validate_model_after(model: "AWQModifier") -> "AWQModifier":
             if act is not None
         }
         if not (len(num_bits_set) == 0 or num_bits_set == {16}):
+            num_bits_type = {
+                act.type
+                for group in config.config_groups.values()
+                for act in (group.input_activations, group.output_activations)
+                if act is not None
+            }
+            assert (
+                next(iter(num_bits_type)) == QuantizationType.FLOAT
+            ), "In AWQ, lower-precision activation quantization must be float"
+
+            model._activation_bits = next(iter(num_bits_set))
+
             warnings.warn(
                 "A strategy including activation quantization was detected. "
                 "AWQ was originally intended for weight-only quantization. "
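
For readers less familiar with the config plumbing: the hunks above wire a new `_activation_bits` field through validation. The validator gathers the activation bit-widths and types across all config groups and only accepts a sub-16-bit setting when those activations are float-typed. Below is a minimal, self-contained sketch of that resolution logic, using hypothetical stand-in classes (`ActArgs`, `Group`) and a made-up helper `resolve_activation_bits` rather than the real compressed-tensors config objects:

    from dataclasses import dataclass
    from enum import Enum
    from typing import Dict, Optional


    class QuantizationType(str, Enum):
        INT = "int"
        FLOAT = "float"


    @dataclass
    class ActArgs:  # stand-in for the activation QuantizationArgs
        num_bits: int
        type: QuantizationType


    @dataclass
    class Group:  # stand-in for a quantization config group
        input_activations: Optional[ActArgs] = None
        output_activations: Optional[ActArgs] = None


    def resolve_activation_bits(groups: Dict[str, Group]) -> int:
        """Return the activation bit-width the AWQ search should simulate (16 = none)."""
        acts = [
            act
            for group in groups.values()
            for act in (group.input_activations, group.output_activations)
            if act is not None
        ]
        num_bits_set = {act.num_bits for act in acts}
        if not num_bits_set or num_bits_set == {16}:
            return 16  # weight-only scheme: nothing extra to simulate
        # lower-precision activations are only accepted when they are float (e.g. FP8)
        assert all(act.type == QuantizationType.FLOAT for act in acts), (
            "In AWQ, lower-precision activation quantization must be float"
        )
        return next(iter(num_bits_set))


    # A W4A8-FP8 style scheme resolves to 8; a weight-only scheme resolves to 16.
    print(resolve_activation_bits({"g0": Group(ActArgs(8, QuantizationType.FLOAT))}))  # 8
    print(resolve_activation_bits({"g0": Group()}))  # 16

The sketch checks every activation entry, while the patch samples one element of the type set; the two agree whenever all groups share a single activation scheme.
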
@@ -612,16 +625,26 @@ def _compute_best_scale(
             # Q(W * s)
             for linear in linears2scale:
                 linear.weight.mul_(_scalesview)
-                update_offload_parameter(
-                    linear,
-                    "weight",
+                scaled_weight = (
                     _pseudo_quantize_tensor(
                         w=linear.weight.data,
                         symmetric=self._symmetric,
                         bit_width=self._num_bits,
                         group_size=self._group_size,
                     )[0]
-                    / _scalesview,
+                    / _scalesview
+                )
+
+                # fp8 activation simulation
+                if self._activation_bits == 8:
+                    scaled_weight = scaled_weight.to(torch.float8_e4m3fn).to(
+                        torch.float16
+                    )
+
+                update_offload_parameter(
+                    linear,
+                    "weight",
+                    scaled_weight,
                 )

             # W * X
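
In the `_compute_best_scale` hunk, the candidate weight produced by the usual integer fake-quantization (and re-division by the scales) is additionally round-tripped through `torch.float8_e4m3fn` when `_activation_bits == 8`, so the reconstruction error that drives the grid search also reflects FP8 numerics. A rough standalone illustration of that round-trip, assuming a PyTorch build that exposes `torch.float8_e4m3fn` (2.1+); the helper name `fp8_roundtrip` is made up for this example:

    import torch


    def fp8_roundtrip(w: torch.Tensor) -> torch.Tensor:
        """Simulate FP8 (e4m3) precision by casting down and back up.

        The cast rounds each in-range value to the nearest representable
        float8_e4m3fn number; returning float16 keeps the rest of the
        scale search unchanged.
        """
        return w.to(torch.float8_e4m3fn).to(torch.float16)


    w = torch.randn(4, 8, dtype=torch.float16)
    err = (w - fp8_roundtrip(w)).abs().max()
    print(err)  # small but non-zero rounding error

Only the induced rounding error matters for picking the per-channel scales, so a dtype round-trip is sufficient here; no FP8 kernels need to run during calibration.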