@@ -3,7 +3,7 @@
 from typing import Dict, List, Optional, Tuple, Union
 
 import torch
-from compressed_tensors.quantization import disable_quantization
+from compressed_tensors.quantization import QuantizationType, disable_quantization
 from compressed_tensors.utils import (
     align_modules,
     get_execution_device,
@@ -126,6 +126,7 @@ class AWQModifier(Modifier, QuantizationMixin):
 
     # Private vars set during validation
     _num_bits: Optional[int] = PrivateAttr(default=None)
+    _activation_bits: int = PrivateAttr(default=16)
     _symmetric: Optional[bool] = PrivateAttr(default=None)
     _group_size: Optional[int] = PrivateAttr(default=None)
 
@@ -189,6 +190,18 @@ def validate_model_after(model: "AWQModifier") -> "AWQModifier":
             if act is not None
         }
         if not (len(num_bits_set) == 0 or num_bits_set == {16}):
+            num_bits_type = {
+                act.type
+                for group in config.config_groups.values()
+                for act in (group.input_activations, group.output_activations)
+                if act is not None
+            }
+            assert (
+                next(iter(num_bits_type)) == QuantizationType.FLOAT
+            ), "In AWQ, lower-precision activation quantization must be float"
+
+            model._activation_bits = next(iter(num_bits_set))
+
            warnings.warn(
                "A strategy including activation quantization was detected. "
                "AWQ was originally intended for weight-only quantization. "
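Note: the guard added above only accepts sub-16-bit activation schemes whose quantization type is float (e.g. fp8); integer activation quantization would trip the assertion. A minimal standalone sketch of the same check, using illustrative stand-in values rather than a real QuantizationConfig (names below are not from the PR):

from compressed_tensors.quantization import QuantizationType

# Illustrative (num_bits, type) pairs, standing in for the values collected
# from a config's input/output activation args.
activation_schemes = [(8, QuantizationType.FLOAT), (8, QuantizationType.FLOAT)]

num_bits_set = {bits for bits, _ in activation_schemes}
num_bits_type = {qtype for _, qtype in activation_schemes}

if not (len(num_bits_set) == 0 or num_bits_set == {16}):
    # Mirrors the new guard: lower-precision activations must be float (fp8).
    assert next(iter(num_bits_type)) == QuantizationType.FLOAT
    activation_bits = next(iter(num_bits_set))  # 8 in this example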
@@ -612,16 +625,26 @@ def _compute_best_scale(
             # Q(W * s)
             for linear in linears2scale:
                 linear.weight.mul_(_scalesview)
-                update_offload_parameter(
-                    linear,
-                    "weight",
+                scaled_weight = (
                     _pseudo_quantize_tensor(
                         w=linear.weight.data,
                         symmetric=self._symmetric,
                         bit_width=self._num_bits,
                         group_size=self._group_size,
                     )[0]
-                    / _scalesview,
+                    / _scalesview
+                )
+
+                # fp8 activation simulation
+                if self._activation_bits == 8:
+                    scaled_weight = scaled_weight.to(torch.float8_e4m3fn).to(
+                        torch.float16
+                    )
+
+                update_offload_parameter(
+                    linear,
+                    "weight",
+                    scaled_weight,
                 )
 
             # W * X
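The float8_e4m3fn round-trip added above is what folds fp8 rounding error into the scale search. A minimal standalone sketch of that effect, assuming a recent PyTorch with torch.float8_e4m3fn support (tensor shape and variable names are illustrative, not from the PR):

import torch

# Any float16 weight slice behaves the same way; values are illustrative.
w = torch.randn(4, 8, dtype=torch.float16)

# Cast to fp8 and back, as in the hunk above, so the reconstruction loss
# measured during the scale search reflects fp8 precision.
w_fp8_sim = w.to(torch.float8_e4m3fn).to(torch.float16)

print((w - w_fp8_sim).abs().max())  # nonzero: the fp8 rounding error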
 | 