-
Notifications
You must be signed in to change notification settings - Fork 400
[AWQ] Support for a module used in an AWQ mapping to be unquantized && other bug fixes #2158
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
ea79421
09d0ecf
df6178c
9b867dd
8978e3b
f58d013
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -120,9 +120,15 @@ class AWQModifier(Modifier, QuantizationMixin): | |||||||||||||||||
| to smoothed) and the second entry is the layer whose output is scaled to | ||||||||||||||||||
| achieve the smoothing. | ||||||||||||||||||
| If regex is used, it matches layers with the largest overlap in module name. | ||||||||||||||||||
| :param ignore: list of layers to ignore, even if they match a regex in mappings. | ||||||||||||||||||
| It should match the name of layers whose outputs are scaled to achieve | ||||||||||||||||||
| smoothing (the second entry of the mappings list). | ||||||||||||||||||
| :param ignore: list of layers to exclude from quantization. By default, layers in | ||||||||||||||||||
| `ignore` are also excluded from AWQ smoothing. Use `force_balance` to override | ||||||||||||||||||
| this and smooth certain layers despite being in `ignore` (e.g., MoE gate layers). | ||||||||||||||||||
| :param force_balance: list of layers to include in AWQ smoothing even if they are in | ||||||||||||||||||
| `ignore`. This allows you to smooth but not quantize specific layers. For example, | ||||||||||||||||||
| if `ignore=["lm_head", "re:.*mlp.gate"]` and `force_balance=["re:.*mlp.gate"]`, | ||||||||||||||||||
| then mlp.gate will be smoothed but not quantized, while lm_head will be neither | ||||||||||||||||||
| smoothed nor quantized. If None, all layers in `ignore` are excluded from smoothing | ||||||||||||||||||
| (default backward-compatible behavior). | ||||||||||||||||||
| :param offload_device: offload cached args to this device, which reduces memory | ||||||||||||||||||
| requirements but requires more time to move data between cpu and execution | ||||||||||||||||||
| device. Defaults to None, so cached args are not offloaded. Consider setting | ||||||||||||||||||
|
|
@@ -145,6 +151,7 @@ class AWQModifier(Modifier, QuantizationMixin): | |||||||||||||||||
| # User-provided vars (in addition to QuantizationMixin args) | ||||||||||||||||||
| sequential_targets: str | list[str] | None = None | ||||||||||||||||||
| mappings: list[AWQMapping] | None = None | ||||||||||||||||||
| force_balance: list[str] | None = None | ||||||||||||||||||
| offload_device: torch.device | None = None | ||||||||||||||||||
| duo_scaling: bool | Literal["both"] = True | ||||||||||||||||||
| n_grid: int = 20 | ||||||||||||||||||
|
|
@@ -281,9 +288,34 @@ def _set_resolved_mappings(self, model: Module) -> None: | |||||||||||||||||
| """ | ||||||||||||||||||
| resolved_mappings: list[ResolvedMapping] = [] | ||||||||||||||||||
| module_to_name = get_module_to_name_dict(model) | ||||||||||||||||||
|
|
||||||||||||||||||
| # Compute the effective ignore list for AWQ smoothing: | ||||||||||||||||||
| # Start with self.ignore, then remove any layers in force_balance | ||||||||||||||||||
| # If force_balance is None, use self.ignore as-is (backward compatible) | ||||||||||||||||||
| if self.force_balance is not None: | ||||||||||||||||||
| # Validate that all force_balance layers are in ignore | ||||||||||||||||||
| ignore_set = set(self.ignore or []) | ||||||||||||||||||
| force_balance_set = set(self.force_balance) | ||||||||||||||||||
| invalid_force_balance = force_balance_set - ignore_set | ||||||||||||||||||
| if invalid_force_balance: | ||||||||||||||||||
| raise ValueError( | ||||||||||||||||||
| f"force_balance contains layers that are not in ignore: " | ||||||||||||||||||
| f"{invalid_force_balance}. force_balance should only contain " | ||||||||||||||||||
| f"layers that are in ignore but you want to smooth anyway." | ||||||||||||||||||
| ) | ||||||||||||||||||
|
|
||||||||||||||||||
| # Remove force_balance layers from ignore list for AWQ matching | ||||||||||||||||||
| awq_ignore = [ | ||||||||||||||||||
| ign for ign in (self.ignore or []) | ||||||||||||||||||
| if ign not in self.force_balance | ||||||||||||||||||
| ] | ||||||||||||||||||
|
Comment on lines
+308
to
+311
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The list comprehension for creating
Suggested change
|
||||||||||||||||||
| else: | ||||||||||||||||||
| # Default: exclude everything in ignore from smoothing | ||||||||||||||||||
| awq_ignore = self.ignore | ||||||||||||||||||
|
|
||||||||||||||||||
| for mapping in self.mappings: | ||||||||||||||||||
| for smooth_layers, *nested_balance_layers in match_modules_set( | ||||||||||||||||||
| model, (mapping.smooth_layer, *mapping.balance_layers), self.ignore | ||||||||||||||||||
| model, (mapping.smooth_layer, *mapping.balance_layers), awq_ignore | ||||||||||||||||||
| ): | ||||||||||||||||||
| if len(smooth_layers) > 1: | ||||||||||||||||||
| raise ValueError( | ||||||||||||||||||
|
|
@@ -356,11 +388,12 @@ def cache_smooth_activations_hook( | |||||||||||||||||
| args: tuple[torch.Tensor, ...], | ||||||||||||||||||
| _output: torch.Tensor, | ||||||||||||||||||
| ): | ||||||||||||||||||
| self._smooth_activation_means[smooth_name] = _accumulate_mean( | ||||||||||||||||||
| # Assume that first argument is the input | ||||||||||||||||||
| args[0].cpu().abs().detach().flatten(0, -2), | ||||||||||||||||||
|
|
||||||||||||||||||
| act_mean, count = _accumulate_mean( | ||||||||||||||||||
| args[0].abs().detach().flatten(0, -2), | ||||||||||||||||||
| self._smooth_activation_means.get(smooth_name, None), | ||||||||||||||||||
| ) | ||||||||||||||||||
| self._smooth_activation_means[smooth_name] = (act_mean.cpu(), count) | ||||||||||||||||||
|
|
||||||||||||||||||
| return cache_smooth_activations_hook | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
@@ -525,6 +558,7 @@ def _compute_best_scale( | |||||||||||||||||
| best_ratio = -1 | ||||||||||||||||||
| best_scales = None | ||||||||||||||||||
| best_error = float("inf") | ||||||||||||||||||
| initial_error = None | ||||||||||||||||||
|
|
||||||||||||||||||
| org_sd = { | ||||||||||||||||||
| k: v.cpu() | ||||||||||||||||||
|
|
@@ -569,7 +603,14 @@ def _compute_best_scale( | |||||||||||||||||
| for balance_layer in balance_layers_to_patch | ||||||||||||||||||
| ], | ||||||||||||||||||
| ): | ||||||||||||||||||
| for grid_idx, use_duo_scaling in product(range(n_grid), duo_scalings): | ||||||||||||||||||
| total_iterations = n_grid * len(duo_scalings) | ||||||||||||||||||
| pbar = tqdm( | ||||||||||||||||||
| product(range(n_grid), duo_scalings), | ||||||||||||||||||
| total=total_iterations, | ||||||||||||||||||
| desc=f"Grid search for {mapping.smooth_name}", | ||||||||||||||||||
| leave=False, | ||||||||||||||||||
| ) | ||||||||||||||||||
| for grid_idx, use_duo_scaling in pbar: | ||||||||||||||||||
| # create new scales | ||||||||||||||||||
| ratio = grid_idx / n_grid | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
@@ -588,14 +629,15 @@ def _compute_best_scale( | |||||||||||||||||
| scales[torch.isnan(scales)] = 1 | ||||||||||||||||||
|
|
||||||||||||||||||
| # Q(W * s) | ||||||||||||||||||
| for balance_layer in balance_layers_to_patch: | ||||||||||||||||||
| for balance_layer in mapping.balance_layers: | ||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Iterating over
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Gemini’s answer seems incorrect. We also need to account for nn.Modules that should not be quantized, so that the model produced by AWQ remains functionally equivalent to the original network. |
||||||||||||||||||
| balance_layer.weight.mul_(_scalesview) | ||||||||||||||||||
|
|
||||||||||||||||||
| if not hasattr(balance_layer, "quantization_scheme") or not hasattr( | ||||||||||||||||||
| balance_layer.quantization_scheme, "weights" | ||||||||||||||||||
| ): | ||||||||||||||||||
| continue | ||||||||||||||||||
|
|
||||||||||||||||||
| w_qscheme = balance_layer.quantization_scheme.weights | ||||||||||||||||||
| balance_layer.weight.mul_(_scalesview) | ||||||||||||||||||
| call_observer( | ||||||||||||||||||
| balance_layer, | ||||||||||||||||||
| "weight", | ||||||||||||||||||
|
|
@@ -620,13 +662,21 @@ def _compute_best_scale( | |||||||||||||||||
| # compute mean squared error (L2 norm) | ||||||||||||||||||
| loss = self._compute_loss(fp16_outputs, int_w_outputs) | ||||||||||||||||||
|
|
||||||||||||||||||
| # Track initial error for logging | ||||||||||||||||||
| if initial_error is None: | ||||||||||||||||||
| initial_error = loss | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
| history.append( | ||||||||||||||||||
| {"ratio": ratio, "duo_scaling": use_duo_scaling, "error": loss} | ||||||||||||||||||
| ) | ||||||||||||||||||
| if loss < best_error: | ||||||||||||||||||
| best_error = loss | ||||||||||||||||||
| best_ratio = ratio | ||||||||||||||||||
| best_scales = scales.clone() | ||||||||||||||||||
|
|
||||||||||||||||||
| # Update progress bar with best error | ||||||||||||||||||
| pbar.set_postfix({"best_error": f"{best_error:.3e}"}) | ||||||||||||||||||
|
|
||||||||||||||||||
| mapping.parent.load_state_dict(org_sd, strict=False) | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
@@ -639,6 +689,16 @@ def _compute_best_scale( | |||||||||||||||||
| "https://github.com/vllm-project/llm-compressor/issues" | ||||||||||||||||||
| ) | ||||||||||||||||||
|
|
||||||||||||||||||
| # Log final results | ||||||||||||||||||
| err_reduction = best_error / initial_error | ||||||||||||||||||
| logger.info( | ||||||||||||||||||
| f"AWQ grid search for {mapping.smooth_name}: " | ||||||||||||||||||
| f"initial error = {initial_error:.3e}, " | ||||||||||||||||||
| f"best error = {best_error:.3e}, " | ||||||||||||||||||
| f"error reduction rate (best/initial) = {err_reduction * 100:.3f}%" | ||||||||||||||||||
| ) | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
| assert ( | ||||||||||||||||||
| torch.isnan(best_scales).sum() == 0 | ||||||||||||||||||
| ), f"Nan found in scales: {best_scales}" | ||||||||||||||||||
|
|
@@ -657,7 +717,7 @@ def _compute_loss( | |||||||||||||||||
| # Compute the MSE loss for each batch | ||||||||||||||||||
| for fp16_batch, int_w_batch in zip(fp16_outputs, int_w_outputs): | ||||||||||||||||||
| loss += torch.nn.functional.mse_loss( | ||||||||||||||||||
| fp16_batch, int_w_batch.to(fp16_batch.device) | ||||||||||||||||||
| fp16_batch, int_w_batch.to(fp16_batch.device), reduction="sum" | ||||||||||||||||||
| ).item() | ||||||||||||||||||
| num_elements += fp16_batch.numel() | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
@@ -863,9 +923,10 @@ def _accumulate_mean( | |||||||||||||||||
| sum_added = inp.sum(dim=0) | ||||||||||||||||||
| num_added = inp.size(0) | ||||||||||||||||||
| if prev_mean_and_count is None: | ||||||||||||||||||
| return sum_added, num_added | ||||||||||||||||||
| return sum_added / num_added, num_added | ||||||||||||||||||
|
|
||||||||||||||||||
| prev_mean, prev_count = prev_mean_and_count | ||||||||||||||||||
| prev_mean = prev_mean.to(inp.device) | ||||||||||||||||||
|
|
||||||||||||||||||
| prev_sum = prev_mean * prev_count | ||||||||||||||||||
| new_count = prev_count + num_added | ||||||||||||||||||
|
|
||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'll revert these changes later