-
Notifications
You must be signed in to change notification settings - Fork 464
[AWQ] speed improvements #2188
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
[AWQ] speed improvements #2188
Changes from all commits
a52b62c
6e7a2ff
3cca978
a5799d7
dc57f97
49c86af
e67d58c
b1ac2e3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -92,23 +92,23 @@ class AWQModifier(Modifier, QuantizationMixin): | |
|
|
||
| - on_initialize | ||
| - resolve mappings | ||
| - capture kwargs needed for forward passes into modules | ||
| - set up hooks to capture forward pass kwargs and FP16 baseline outputs | ||
| - on_start | ||
| - set up activation cache hooks to capture input activations | ||
| to balance layers | ||
| - on sequential epoch end | ||
| - apply smoothing to each smoothing layer | ||
| - consume cached activations across all batches | ||
| - clear cached activations as they are used | ||
| - find best smoothing scale for each smoothing layer | ||
| - apply to model weights | ||
| - find best smoothing scale for each smoothing layer via grid search | ||
| - apply best scales to model weights | ||
| - raise error if any unused activations remain | ||
| - on_end | ||
| - re-run logic of sequential epoch end (in case of basic pipeline) | ||
| - set scales and zero points | ||
| - remove activation hooks | ||
| - on_finalize | ||
| - clear resolved mappings and captured activations | ||
| - clear resolved mappings, cached activations, and FP16 baseline cache | ||
|
|
||
| :param sequential_targets: list of module names to compress in | ||
| the same calibration pass | ||
|
|
@@ -158,6 +158,12 @@ class AWQModifier(Modifier, QuantizationMixin): | |
| _smooth_activation_means: dict[str, tuple[torch.FloatTensor, int]] = PrivateAttr( | ||
| default_factory=dict | ||
| ) | ||
| # Cache FP16 baseline outputs for each parent module | ||
| # During calibration: list of flattened tensors, one per batch | ||
| # After calibration: single concatenated tensor | ||
| _fp16_baseline_cache: dict[Module, list[torch.Tensor] | torch.Tensor] = PrivateAttr( | ||
| default_factory=dict | ||
| ) | ||
|
|
||
| def on_initialize(self, state: State, **kwargs) -> bool: | ||
| """ | ||
|
|
@@ -262,6 +268,7 @@ def on_finalize(self, state: State, **kwargs) -> bool: | |
|
|
||
| self._parent_args_cache.clear() | ||
| self._smooth_activation_means.clear() | ||
| self._fp16_baseline_cache.clear() | ||
| self._resolved_mappings.clear() | ||
|
|
||
| return True | ||
|
|
@@ -337,8 +344,10 @@ def _set_resolved_mappings(self, model: Module) -> None: | |
|
|
||
| def _setup_activation_cache_hooks(self) -> None: | ||
| """ | ||
| Attach a forward hook to each activation we want to smooth. This allows us to | ||
| calculate the dynamic range during calibration | ||
| Attach forward hooks to cache data needed for AWQ smoothing: | ||
| 1. Parent module kwargs (for replaying forward passes during grid search) | ||
| 2. FP16 baseline outputs (flattened, to avoid recomputing during grid search) | ||
| 3. Smooth layer input activations (to calculate activation means) | ||
| """ | ||
|
|
||
| def cache_parent_kwargs_hook( | ||
|
|
@@ -349,6 +358,17 @@ def cache_parent_kwargs_hook( | |
| values = inspect.signature(module.forward).bind(*args, **kwargs) | ||
| self._parent_args_cache[module].append(values.arguments) | ||
|
|
||
| def cache_fp16_baseline_hook( | ||
| module: Module, | ||
| args: tuple[torch.Tensor, ...], | ||
| output: torch.Tensor, | ||
| ): | ||
| # Extract first element if tuple (same logic as _run_samples) | ||
| result = output[0] if isinstance(output, tuple) else output | ||
| if module not in self._fp16_baseline_cache: | ||
| self._fp16_baseline_cache[module] = [] | ||
| self._fp16_baseline_cache[module].append(result.detach().flatten()) | ||
|
|
||
| def create_cache_smooth_activations_hook_fn(smooth_name): | ||
| def cache_smooth_activations_hook( | ||
| _module: Module, | ||
|
|
@@ -378,6 +398,12 @@ def cache_smooth_activations_hook( | |
| with_kwargs=True, | ||
| ) | ||
|
|
||
| self.register_hook( | ||
| mapping.parent, | ||
| cache_fp16_baseline_hook, | ||
| "forward", | ||
| ) | ||
|
|
||
| # input activations to balance layers needed for loss function | ||
| # storing inputs to first balance layer is sufficient | ||
| # other balance layers get the same input | ||
|
|
@@ -392,7 +418,7 @@ def _apply_smoothing(self, model: Module) -> None: | |
| """ | ||
| Calculate the best scaling factors for each layer to smooth activations and | ||
| apply the scaling factors to the weights of the next layer to offset the | ||
| smoothing | ||
| smoothing. | ||
|
|
||
| :param model: model to apply smoothing to | ||
| """ | ||
|
|
@@ -413,18 +439,32 @@ def _apply_smoothing(self, model: Module) -> None: | |
| calibration_forward_context(model), | ||
| HooksMixin.disable_hooks(), | ||
| ): | ||
| # Compute output of unquantized module | ||
| fp16_outputs = self._run_samples(parent_module) | ||
| if len(fp16_outputs) == 0 or all(f.numel() == 0 for f in fp16_outputs): | ||
| # Retrieve cached FP16 baseline outputs (collected during calibration) | ||
| if ( | ||
| parent_module not in self._fp16_baseline_cache | ||
| or len(self._fp16_baseline_cache[parent_module]) == 0 | ||
| ): | ||
| logger.info( | ||
| f"Skipping smooth_layer {mapping.smooth_name}, no activations " | ||
| "found to scale. This can occasionally occur in MoE models " | ||
| "when certain experts are not activated by calibration samples." | ||
| ) | ||
| del self._smooth_activation_means[mapping.smooth_name] | ||
| continue | ||
|
|
||
| # Concatenate FP16 baseline outputs (if still a list from calibration) | ||
| # This happens once per parent module; subsequent mappings reuse it | ||
| # Store as single-element list to maintain consistent access pattern | ||
| if len(self._fp16_baseline_cache[parent_module]) > 1: | ||
| self._fp16_baseline_cache[parent_module] = [ | ||
| torch.cat(self._fp16_baseline_cache[parent_module]) | ||
| ] | ||
|
|
||
| if not all( | ||
| [fp16_output.isfinite().all() for fp16_output in fp16_outputs] | ||
| [ | ||
| ref.isfinite().all() | ||
| for ref in self._fp16_baseline_cache[parent_module] | ||
| ] | ||
| ): | ||
| logger.warning( | ||
| f"Skipping smooth_layer {mapping.smooth_name}, NaN or inf " | ||
|
|
@@ -438,7 +478,14 @@ def _apply_smoothing(self, model: Module) -> None: | |
| del self._smooth_activation_means[mapping.smooth_name] | ||
| continue | ||
|
|
||
| best_scales = self._compute_best_scale(mapping, fp16_outputs) | ||
| orig_layer_weights = { | ||
| balance_layer: balance_layer.weight.clone() | ||
| for balance_layer in mapping.balance_layers | ||
| if hasattr(balance_layer, "quantization_scheme") | ||
| and hasattr(balance_layer.quantization_scheme, "weights") | ||
| } | ||
|
|
||
| best_scales = self._compute_best_scale(mapping, orig_layer_weights) | ||
|
|
||
| @torch.no_grad() | ||
| def _smooth(module: Module): | ||
|
|
@@ -447,7 +494,7 @@ def _smooth(module: Module): | |
| update_offload_parameter( | ||
| module, | ||
| "weight", | ||
| module.weight.mul_(scales.view(1, -1)), | ||
| orig_layer_weights[module] * scales.view(1, -1), | ||
| ) | ||
| elif module == smooth_layer: | ||
| if module.weight.ndim == 1: | ||
|
|
@@ -487,6 +534,7 @@ def _smooth(module: Module): | |
|
|
||
| for v in self._parent_args_cache.values(): | ||
| v.batch_intermediates.clear() | ||
| self._fp16_baseline_cache.clear() | ||
| self._assert_all_activations_consumed() | ||
|
|
||
| def _run_samples(self, module: Module) -> list[torch.Tensor]: | ||
|
|
@@ -499,15 +547,40 @@ def _run_samples(self, module: Module) -> list[torch.Tensor]: | |
| for output in outputs | ||
| ] | ||
|
|
||
| def _run_samples_preallocated( | ||
| self, module: Module, fp16_out_concat: torch.Tensor | ||
| ) -> torch.Tensor: | ||
| """ | ||
| Run samples through module and accumulate outputs in preallocated tensor. | ||
|
|
||
| :param module: Module to run samples through | ||
| :param fp16_out_concat: Reference tensor to determine size and device | ||
| :return: Preallocated tensor filled with flattened outputs | ||
| """ | ||
| # Preallocate tensor with same size and device as fp16_out_concat | ||
| int_w_concat = torch.empty_like(fp16_out_concat) | ||
| offset = 0 | ||
|
|
||
| for batch_kwargs in self._parent_args_cache[module]: | ||
| output = module(**batch_kwargs) | ||
| # If tuple, assume that first argument is the output | ||
| result = output[0] if isinstance(output, tuple) else output | ||
| # Flatten and write to preallocated tensor | ||
| flattened = result.flatten() | ||
| size = flattened.numel() | ||
| int_w_concat[offset : offset + size] = flattened | ||
| offset += size | ||
|
|
||
| return int_w_concat | ||
|
|
||
| def _compute_best_scale( | ||
| self, | ||
| mapping: ResolvedMapping, | ||
| fp16_outputs: list[torch.Tensor], | ||
| orig_layer_weights: dict[torch.nn.Module, torch.Tensor], | ||
| ) -> torch.Tensor: | ||
| """ | ||
| Select best scales for a given mapping in a grid search | ||
| Best scales are those that minimize MSE loss of quantized weight | ||
| outputs compared to fp16_outputs | ||
| Select best scales for a given mapping via grid search. | ||
| Best scales minimize MSE loss between FP16 baseline and quantized outputs. | ||
|
|
||
| L(s) = || Q(W * s) (s^-1 * X) - W * X || | ||
| Q: weight quantization function | _pseudo_quantize_tensor(W * s) | ||
|
|
@@ -516,20 +589,21 @@ def _compute_best_scale( | |
| s: per channel scaling factor | s^-1 * X | ||
|
|
||
| :param mapping: best scales will be found for the ResolvedMapping. | ||
| :param fp16_outputs: output of mapping.parent in unquantized case, | ||
| one tensor for each batch. | ||
| :return: tensor of best scales, one for each channel | ||
| :param orig_layer_weights: Original weights for balance layers | ||
| :return: Tensor of best scales, one per channel (on CPU) | ||
| """ | ||
| history = [] | ||
| best_ratio = -1 | ||
| best_scales = None | ||
| best_error = float("inf") | ||
| fp16_out_concat = self._fp16_baseline_cache[mapping.parent][0] | ||
|
|
||
| org_sd = { | ||
| k: v.cpu() | ||
| for k, v in mapping.parent.state_dict().items() | ||
| if v.device != torch.device("meta") | ||
| } | ||
| balance_layers_to_patch = [ | ||
| balance_layer | ||
| for balance_layer in mapping.balance_layers | ||
| if hasattr(balance_layer, "quantization_scheme") | ||
| and hasattr(balance_layer.quantization_scheme, "weights") | ||
| ] | ||
|
|
||
| device = get_execution_device(mapping.parent) | ||
|
|
||
|
|
@@ -549,12 +623,6 @@ def _compute_best_scale( | |
|
|
||
| # Where appropriate, replace observers with memoryless_minmax | ||
| # for duration of grid search | ||
| balance_layers_to_patch = [ | ||
| balance_layer | ||
| for balance_layer in mapping.balance_layers | ||
| if hasattr(balance_layer, "quantization_scheme") | ||
| and hasattr(balance_layer.quantization_scheme, "weights") | ||
| ] | ||
| with patch_attrs( | ||
| balance_layers_to_patch, | ||
| "weight_observer", | ||
|
|
@@ -586,38 +654,44 @@ def _compute_best_scale( | |
| scales[torch.isinf(scales)] = 1 | ||
| scales[torch.isnan(scales)] = 1 | ||
|
|
||
| # Q(W * s) | ||
| # Q(W * s) - Apply current scale ratio to original weights and quantize | ||
| for balance_layer in balance_layers_to_patch: | ||
| if not hasattr(balance_layer, "quantization_scheme") or not hasattr( | ||
| balance_layer.quantization_scheme, "weights" | ||
| ): | ||
| continue | ||
|
|
||
| w_qscheme = balance_layer.quantization_scheme.weights | ||
| balance_layer.weight.mul_(_scalesview) | ||
|
|
||
| balance_layer.weight.data.copy_( | ||
| orig_layer_weights[balance_layer] * _scalesview | ||
| ) | ||
|
|
||
| call_observer( | ||
| balance_layer, | ||
| "weight", | ||
| balance_layer.weight, | ||
| # TODO test should_calculate_gparam for nvfp4 support | ||
| ) | ||
| update_offload_parameter( | ||
| balance_layer, | ||
| "weight", | ||
| balance_layer.weight.data = ( | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Nice job avoiding writing to the offload
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. So we don't need
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yeah, we can just keep it in memory and not mess with that. |
||
| forward_quantize( | ||
| balance_layer, | ||
| balance_layer.weight.data, | ||
| balance_layer.weight, | ||
| "weight", | ||
| w_qscheme, | ||
| ) | ||
| / _scalesview, | ||
| ) | ||
| / _scalesview | ||
| ).to(balance_layer.weight.dtype) | ||
HDCharles marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| # W * X | ||
| int_w_outputs = self._run_samples(mapping.parent) | ||
| # W * X - Run forward passes and accumulate outputs | ||
| int_w_concat = self._run_samples_preallocated( | ||
| mapping.parent, fp16_out_concat | ||
| ) | ||
|
|
||
| # compute mean squared error (L2 norm) | ||
| loss = self._compute_loss(fp16_outputs, int_w_outputs) | ||
| # Compute mean squared error (L2 norm) | ||
| loss = torch.nn.functional.mse_loss( | ||
| fp16_out_concat, int_w_concat | ||
| ).item() | ||
|
|
||
| history.append( | ||
| {"ratio": ratio, "duo_scaling": use_duo_scaling, "error": loss} | ||
|
|
@@ -627,8 +701,6 @@ def _compute_best_scale( | |
| best_ratio = ratio | ||
| best_scales = scales.clone() | ||
|
|
||
HDCharles marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| mapping.parent.load_state_dict(org_sd, strict=False) | ||
|
|
||
| if best_ratio == -1: | ||
| logger.debug(history) | ||
| raise Exception( | ||
|
|
@@ -644,27 +716,6 @@ def _compute_best_scale( | |
|
|
||
| return best_scales.detach().cpu() | ||
|
|
||
| @torch.no_grad() | ||
| def _compute_loss( | ||
| self, | ||
| fp16_outputs: list[torch.Tensor], | ||
| int_w_outputs: list[torch.Tensor], | ||
| ) -> float: | ||
| loss = 0.0 | ||
| num_elements = 0 | ||
|
|
||
| # Compute the MSE loss for each batch | ||
| for fp16_batch, int_w_batch in zip(fp16_outputs, int_w_outputs): | ||
| loss += torch.nn.functional.mse_loss( | ||
| fp16_batch, int_w_batch.to(fp16_batch.device) | ||
| ).item() | ||
| num_elements += fp16_batch.numel() | ||
|
|
||
| # Normalize the loss by the total number of elements | ||
| loss /= num_elements | ||
|
|
||
| return loss | ||
|
|
||
| def _assert_all_activations_consumed(self): | ||
| """ | ||
| Confirm all activations have been consumed | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
if we are now caching output activations for every mapping in a given subgraph, wouldn't this increase memory requirements quite a bit, especially for MoE models? For which model are you seeing the 30% memory increase that you mention in the summary?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hold on, I'm rewriting this — I didn't realize that by default we don't enable offloading, so all my measurements were off. We do need to cache this, but not on GPU, and we can offload it.