diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index 6bc97b446..bc86ba25f 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -182,6 +182,8 @@ def validate_awq_after(model: "AWQModifier") -> "AWQModifier": ), "In AWQ, all config groups must use the same configuration for group_size" model._group_size = next(iter(group_size_set)) + if model._group_size is None: + model._group_size = -1 in_num_bits_set = set( group.input_activations.num_bits @@ -460,14 +462,17 @@ def _apply_smoothing(self, model: Module) -> None: weight = torch.cat([bl.weight for bl in balance_layers], dim=0) org_shape = weight.shape # The weights are reshaped to be organised by quantization group - weight = weight.view(-1, self._group_size) + if self._group_size > 0: + weight = weight.view(-1, self._group_size) # Calculates the relative magnitude of the weights within # each of the quantization groups, and rescales each group # individually so that each group has weights on a 0-1 scale. weight.abs_() weight.div_(weight.amax(dim=1, keepdim=True) + 1e-6) - # Resizes the rescaled weight matrix back up to its original dimensions - weight = weight.view(org_shape) + if self._group_size > 0: + # Resizes the rescaled weight matrix back up to + # its original dimensions + weight = weight.view(org_shape) # Gets the average rescaled magnitude for each output channel w_mean = weight.mean(0) del weight