diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index cb7111c106..a17c3131a2 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -39,8 +39,10 @@ from llmcompressor.modifiers.quantization.quantization import QuantizationMixin from llmcompressor.modifiers.utils import update_fused_layer_weight_global_scales from llmcompressor.modifiers.utils.hooks import HooksMixin +from llmcompressor.modifiers.utils.pytorch_helpers import is_moe_model from llmcompressor.observers.base import Observer from llmcompressor.pipelines.cache import IntermediatesCache +from llmcompressor.sentinel import Sentinel from llmcompressor.utils.helpers import calibration_forward_context from llmcompressor.utils.pytorch.module import ( get_module_to_name_dict, @@ -101,8 +103,8 @@ class AWQModifier(Modifier, QuantizationMixin): - apply smoothing to each smoothing layer - consume cached activations across all batches - clear cached activations as they are used - - find best smoothing scale for each smoothing layer - - apply to model weights + - find best smoothing scale for each smoothing layer via grid search + - apply best scales to model weights - raise error if any unused activations remain - on_end - re-run logic of sequential epoch end (in case of basic pipeline) @@ -145,7 +147,7 @@ class AWQModifier(Modifier, QuantizationMixin): # User-provided vars (in addition to QuantizationMixin args) sequential_targets: str | list[str] | None = None mappings: list[AWQMapping] | None = None - offload_device: torch.device | None = None + offload_device: torch.device | None | Sentinel = Sentinel("not_provided") duo_scaling: bool | Literal["both"] = True n_grid: int = 20 @@ -161,6 +163,10 @@ class AWQModifier(Modifier, QuantizationMixin): ) # List to store error metrics for each layer _error_metrics: list[dict] = PrivateAttr(default_factory=list) + # Cache FP16 baseline outputs for each parent module, one list of tensors per batch + _fp16_baseline_cache: dict[Module, IntermediatesCache] = PrivateAttr( + default_factory=dict + ) def on_initialize(self, state: State, **kwargs) -> bool: """ @@ -199,6 +205,21 @@ def on_initialize(self, state: State, **kwargs) -> bool: architecture=state.model.__class__.__name__ ) + # Set default offload_device + if self.offload_device == Sentinel("not_provided"): + # Check if we have a MoE model + if is_moe_model(state.model): + self.offload_device = torch.device("cpu") + logger.info( + "MoE model detected: setting offload_device to 'cpu' by default " + "to reduce memory usage. You can override this by explicitly " + "setting offload_device in your recipe." + ) + else: + # For non-MoE models, convert sentinel to None + # (no offloading by default) + self.offload_device = None + self._set_resolved_mappings(state.model) return True @@ -473,16 +494,28 @@ def _apply_smoothing(self, model: Module) -> None: del self._smooth_activation_means[mapping.smooth_name] continue - best_scales = self._compute_best_scale(mapping, fp16_outputs) + orig_layer_weights = { + balance_layer: balance_layer.weight.clone() + for balance_layer in mapping.balance_layers + if hasattr(balance_layer, "quantization_scheme") + and hasattr(balance_layer.quantization_scheme, "weights") + } + + best_scales = self._compute_best_scale( + mapping, fp16_outputs, orig_layer_weights + ) @torch.no_grad() - def _smooth(module: Module): + def _smooth( + module: Module, orig_layer_weights: dict[Module, torch.Tensor] + ): scales = best_scales.to(module.weight.device) if module in balance_layers: update_offload_parameter( module, "weight", - module.weight.mul_(scales.view(1, -1)), + orig_layer_weights[module].to(module.weight.device) + * scales.view(1, -1), ) elif module == smooth_layer: if module.weight.ndim == 1: @@ -509,16 +542,18 @@ def _smooth(module: Module): ) for layer in balance_layers: - _smooth(layer) - _smooth(smooth_layer) + _smooth(layer, orig_layer_weights) + _smooth(smooth_layer, orig_layer_weights) # remove caches needed to smooth this mapping del self._smooth_activation_means[mapping.smooth_name] + del orig_layer_weights for v in self._parent_args_cache.values(): v.batch_intermediates.clear() self._assert_all_activations_consumed() + @torch.no_grad() def _run_samples(self, module: Module) -> list[torch.Tensor]: outputs = [ module(**batch_kwargs) for batch_kwargs in self._parent_args_cache[module] @@ -533,6 +568,7 @@ def _compute_best_scale( self, mapping: ResolvedMapping, fp16_outputs: list[torch.Tensor], + orig_layer_weights: dict[torch.nn.Module, torch.Tensor], ) -> torch.Tensor: """ Select best scales for a given mapping in a grid search @@ -556,12 +592,6 @@ def _compute_best_scale( best_error = float("inf") initial_error = None - org_sd = { - k: v.cpu() - for k, v in mapping.parent.state_dict().items() - if v.device != torch.device("meta") - } - device = get_execution_device(mapping.parent) x_mean = self._smooth_activation_means[mapping.smooth_name][0].to(device) @@ -632,8 +662,11 @@ def _compute_best_scale( continue w_qscheme = balance_layer.quantization_scheme.weights - balance_layer.weight.mul_(_scalesview) - # For TENSOR_GROUP (nvfp4), need to calculate global scale + balance_layer.weight.data.copy_( + orig_layer_weights[balance_layer].to(_scalesview.device) + * _scalesview + ) + should_calculate_gparam = ( w_qscheme.strategy == QuantizationStrategy.TENSOR_GROUP ) @@ -643,17 +676,15 @@ def _compute_best_scale( balance_layer.weight, should_calculate_gparam=should_calculate_gparam, ) - update_offload_parameter( - balance_layer, - "weight", + balance_layer.weight.data = ( forward_quantize( balance_layer, - balance_layer.weight.data, + balance_layer.weight, "weight", w_qscheme, ) - / _scalesview, - ) + / _scalesview + ).to(balance_layer.weight.dtype) # Apply fused global scales for TENSOR_GROUP during grid search # to match inference behavior @@ -669,6 +700,7 @@ def _compute_best_scale( # compute mean squared error (L2 norm) loss = self._compute_loss(fp16_outputs, int_w_outputs) + del int_w_outputs if initial_error is None: initial_error = loss @@ -682,8 +714,6 @@ def _compute_best_scale( best_scales = scales.clone() pbar.set_postfix({"best_error": f"{best_error:.3e}"}) - mapping.parent.load_state_dict(org_sd, strict=False) - if best_ratio == -1: logger.debug(history) raise Exception( @@ -731,13 +761,11 @@ def _compute_loss( for fp16_batch, int_w_batch in zip(fp16_outputs, int_w_outputs): loss += torch.nn.functional.mse_loss( fp16_batch, int_w_batch.to(fp16_batch.device), reduction="sum" - ).item() + ) num_elements += fp16_batch.numel() # Normalize the loss by the total number of elements - loss /= num_elements - - return loss + return (loss / num_elements).item() def _log_error_metrics(self): """