diff --git a/examples/awq/qwen3_moe_example.py b/examples/awq/qwen3_moe_example.py
index 4c9644998f..b85789b419 100644
--- a/examples/awq/qwen3_moe_example.py
+++ b/examples/awq/qwen3_moe_example.py
@@ -53,6 +53,7 @@ def tokenize(sample):
 recipe = [
     AWQModifier(
         ignore=["lm_head", "re:.*mlp.gate$", "re:.*mlp.shared_expert_gate$"],
+        force_balance=["re:.*mlp.gate$"],
         scheme="W4A16",
         targets=["Linear"],
     ),
@@ -67,6 +68,13 @@ def tokenize(sample):
     num_calibration_samples=NUM_CALIBRATION_SAMPLES,
 )
 
+
+# Save to disk compressed.
+SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-awq-sym-new"
+model.save_pretrained(SAVE_DIR, save_compressed=True)
+tokenizer.save_pretrained(SAVE_DIR)
+
 # Confirm generations of the quantized model look sane.
 print("\n\n")
 print("========== SAMPLE GENERATION ==============")
@@ -76,9 +84,4 @@ def tokenize(sample):
 )
 output = model.generate(input_ids, max_new_tokens=100)
 print(tokenizer.decode(output[0]))
-print("==========================================\n\n")
-
-# Save to disk compressed.
-SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-awq-sym"
-model.save_pretrained(SAVE_DIR, save_compressed=True)
-tokenizer.save_pretrained(SAVE_DIR)
+print("==========================================\n\n")
\ No newline at end of file
diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py
index 0e5ca71b3b..bc6c35d9b5 100644
--- a/src/llmcompressor/modifiers/awq/base.py
+++ b/src/llmcompressor/modifiers/awq/base.py
@@ -120,9 +120,15 @@ class AWQModifier(Modifier, QuantizationMixin):
         to smoothed) and the second entry is the layer whose output is scaled
         to achieve the smoothing.
         If regex is used, it matches layers with the largest overlap in module name.
-    :param ignore: list of layers to ignore, even if they match a regex in mappings.
-        It should match the name of layers whose outputs are scaled to achieve
-        smoothing (the second entry of the mappings list).
+    :param ignore: list of layers to exclude from quantization. By default, layers in
+        `ignore` are also excluded from AWQ smoothing. Use `force_balance` to override
+        this and smooth certain layers despite being in `ignore` (e.g., MoE gate layers).
+    :param force_balance: list of layers to include in AWQ smoothing even if they are in
+        `ignore`. This allows you to smooth but not quantize specific layers. For example,
+        if `ignore=["lm_head", "re:.*mlp.gate"]` and `force_balance=["re:.*mlp.gate"]`,
+        then mlp.gate will be smoothed but not quantized, while lm_head will be neither
+        smoothed nor quantized. If None, all layers in `ignore` are excluded from smoothing
+        (default backward-compatible behavior).
     :param offload_device: offload cached args to this device, which reduces memory
         requirements but requires more time to move data between cpu and execution
         device. Defaults to None, so cached args are not offloaded. Consider setting
@@ -145,6 +151,7 @@ class AWQModifier(Modifier, QuantizationMixin):
     # User-provided vars (in addition to QuantizationMixin args)
     sequential_targets: str | list[str] | None = None
     mappings: list[AWQMapping] | None = None
+    force_balance: list[str] | None = None
    offload_device: torch.device | None = None
    duo_scaling: bool | Literal["both"] = True
    n_grid: int = 20
@@ -281,9 +288,34 @@ def _set_resolved_mappings(self, model: Module) -> None:
         """
         resolved_mappings: list[ResolvedMapping] = []
         module_to_name = get_module_to_name_dict(model)
+
+        # Compute the effective ignore list for AWQ smoothing:
+        # Start with self.ignore, then remove any layers in force_balance.
+        # If force_balance is None, use self.ignore as-is (backward compatible).
+        if self.force_balance is not None:
+            # Validate that all force_balance layers are in ignore
+            ignore_set = set(self.ignore or [])
+            force_balance_set = set(self.force_balance)
+            invalid_force_balance = force_balance_set - ignore_set
+            if invalid_force_balance:
+                raise ValueError(
+                    f"force_balance contains layers that are not in ignore: "
+                    f"{invalid_force_balance}. force_balance should only contain "
+                    f"layers that are in ignore but you want to smooth anyway."
+                )
+
+            # Remove force_balance layers from the ignore list for AWQ matching
+            awq_ignore = [
+                ign for ign in (self.ignore or [])
+                if ign not in self.force_balance
+            ]
+        else:
+            # Default: exclude everything in ignore from smoothing
+            awq_ignore = self.ignore
+
         for mapping in self.mappings:
             for smooth_layers, *nested_balance_layers in match_modules_set(
-                model, (mapping.smooth_layer, *mapping.balance_layers), self.ignore
+                model, (mapping.smooth_layer, *mapping.balance_layers), awq_ignore
             ):
                 if len(smooth_layers) > 1:
                     raise ValueError(
@@ -356,11 +388,12 @@ def cache_smooth_activations_hook(
             args: tuple[torch.Tensor, ...],
             _output: torch.Tensor,
         ):
-            self._smooth_activation_means[smooth_name] = _accumulate_mean(
-                # Assume that first argument is the input
-                args[0].cpu().abs().detach().flatten(0, -2),
+
+            act_mean, count = _accumulate_mean(
+                args[0].abs().detach().flatten(0, -2),
                 self._smooth_activation_means.get(smooth_name, None),
             )
+            self._smooth_activation_means[smooth_name] = (act_mean.cpu(), count)
 
         return cache_smooth_activations_hook
 
@@ -525,6 +558,7 @@ def _compute_best_scale(
         best_ratio = -1
         best_scales = None
         best_error = float("inf")
+        initial_error = None
 
         org_sd = {
             k: v.cpu()
@@ -569,7 +603,14 @@ def _compute_best_scale(
                 for balance_layer in balance_layers_to_patch
             ],
         ):
-            for grid_idx, use_duo_scaling in product(range(n_grid), duo_scalings):
+            total_iterations = n_grid * len(duo_scalings)
+            pbar = tqdm(
+                product(range(n_grid), duo_scalings),
+                total=total_iterations,
+                desc=f"Grid search for {mapping.smooth_name}",
+                leave=False,
+            )
+            for grid_idx, use_duo_scaling in pbar:
                 # create new scales
                 ratio = grid_idx / n_grid
 
@@ -588,14 +629,15 @@ def _compute_best_scale(
                 scales[torch.isnan(scales)] = 1
 
                 # Q(W * s)
-                for balance_layer in balance_layers_to_patch:
+                for balance_layer in mapping.balance_layers:
+                    balance_layer.weight.mul_(_scalesview)
+
                     if not hasattr(balance_layer, "quantization_scheme") or not hasattr(
                         balance_layer.quantization_scheme, "weights"
                     ):
                         continue
                     w_qscheme = balance_layer.quantization_scheme.weights
-                    balance_layer.weight.mul_(_scalesview)
                     call_observer(
                         balance_layer,
                         "weight",
@@ -620,6 +662,11 @@ def _compute_best_scale(
                 # compute mean squared error (L2 norm)
                 loss = self._compute_loss(fp16_outputs, int_w_outputs)
 
+                # Track initial error for logging
+                if initial_error is None:
+                    initial_error = loss
+
                 history.append(
                     {"ratio": ratio, "duo_scaling": use_duo_scaling, "error": loss}
                 )
@@ -627,6 +674,9 @@ def _compute_best_scale(
                     best_error = loss
                     best_ratio = ratio
                     best_scales = scales.clone()
+
+                    # Update progress bar with best error
+                    pbar.set_postfix({"best_error": f"{best_error:.3e}"})
 
             mapping.parent.load_state_dict(org_sd, strict=False)
 
@@ -639,6 +689,16 @@ def _compute_best_scale(
                 "https://github.com/vllm-project/llm-compressor/issues"
             )
 
+        # Log final results
+        err_reduction = best_error / initial_error
+        logger.info(
+            f"AWQ grid search for {mapping.smooth_name}: "
+            f"initial error = {initial_error:.3e}, "
+            f"best error = {best_error:.3e}, "
+            f"error reduction rate (best/initial) = {err_reduction * 100:.3f}%"
+        )
+
+
         assert (
             torch.isnan(best_scales).sum() == 0
         ), f"Nan found in scales: {best_scales}"
@@ -657,7 +717,7 @@ def _compute_loss(
         # Compute the MSE loss for each batch
         for fp16_batch, int_w_batch in zip(fp16_outputs, int_w_outputs):
             loss += torch.nn.functional.mse_loss(
-                fp16_batch, int_w_batch.to(fp16_batch.device)
+                fp16_batch, int_w_batch.to(fp16_batch.device), reduction="sum"
             ).item()
             num_elements += fp16_batch.numel()
 
@@ -863,9 +923,10 @@ def _accumulate_mean(
     sum_added = inp.sum(dim=0)
     num_added = inp.size(0)
     if prev_mean_and_count is None:
-        return sum_added, num_added
+        return sum_added / num_added, num_added
 
     prev_mean, prev_count = prev_mean_and_count
+    prev_mean = prev_mean.to(inp.device)
     prev_sum = prev_mean * prev_count
     new_count = prev_count + num_added
 
diff --git a/src/llmcompressor/modifiers/awq/mappings.py b/src/llmcompressor/modifiers/awq/mappings.py
index 9e50979c9d..8d6288f0f2 100644
--- a/src/llmcompressor/modifiers/awq/mappings.py
+++ b/src/llmcompressor/modifiers/awq/mappings.py
@@ -46,7 +46,7 @@ class AWQMapping:
     AWQMapping("re:.*v_proj$", ["re:.*o_proj$"]),
     AWQMapping(
         "re:.*post_attention_layernorm$",
-        ["re:.*mlp.experts.*.gate_proj$", "re:.*mlp.experts.*.up_proj$"],
+        ["re:.*mlp.experts.*.gate_proj$", "re:.*mlp.experts.*.up_proj$", "re:.*mlp.gate$"],
     ),
     AWQMapping(
         "re:.*up_proj$",