Skip to content
Merged
84 changes: 56 additions & 28 deletions src/llmcompressor/modifiers/awq/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,10 @@
from llmcompressor.modifiers.quantization.quantization import QuantizationMixin
from llmcompressor.modifiers.utils import update_fused_layer_weight_global_scales
from llmcompressor.modifiers.utils.hooks import HooksMixin
from llmcompressor.modifiers.utils.pytorch_helpers import is_moe_model
from llmcompressor.observers.base import Observer
from llmcompressor.pipelines.cache import IntermediatesCache
from llmcompressor.sentinel import Sentinel
from llmcompressor.utils.helpers import calibration_forward_context
from llmcompressor.utils.pytorch.module import (
get_module_to_name_dict,
Expand Down Expand Up @@ -101,8 +103,8 @@ class AWQModifier(Modifier, QuantizationMixin):
- apply smoothing to each smoothing layer
- consume cached activations across all batches
- clear cached activations as they are used
- find best smoothing scale for each smoothing layer
- apply to model weights
- find best smoothing scale for each smoothing layer via grid search
- apply best scales to model weights
- raise error if any unused activations remain
- on_end
- re-run logic of sequential epoch end (in case of basic pipeline)
Expand Down Expand Up @@ -145,7 +147,7 @@ class AWQModifier(Modifier, QuantizationMixin):
# User-provided vars (in addition to QuantizationMixin args)
sequential_targets: str | list[str] | None = None
mappings: list[AWQMapping] | None = None
offload_device: torch.device | None = None
offload_device: torch.device | None | Sentinel = Sentinel("not_provided")
duo_scaling: bool | Literal["both"] = True
n_grid: int = 20

Expand All @@ -161,6 +163,10 @@ class AWQModifier(Modifier, QuantizationMixin):
)
# List to store error metrics for each layer
_error_metrics: list[dict] = PrivateAttr(default_factory=list)
# Cache FP16 baseline outputs for each parent module, one list of tensors per batch
_fp16_baseline_cache: dict[Module, IntermediatesCache] = PrivateAttr(
default_factory=dict
)

def on_initialize(self, state: State, **kwargs) -> bool:
"""
Expand Down Expand Up @@ -199,6 +205,21 @@ def on_initialize(self, state: State, **kwargs) -> bool:
architecture=state.model.__class__.__name__
)

# Set default offload_device
if self.offload_device == Sentinel("not_provided"):
# Check if we have a MoE model
if is_moe_model(state.model):
self.offload_device = torch.device("cpu")
logger.info(
"MoE model detected: setting offload_device to 'cpu' by default "
"to reduce memory usage. You can override this by explicitly "
"setting offload_device in your recipe."
)
else:
# For non-MoE models, convert sentinel to None
# (no offloading by default)
self.offload_device = None

self._set_resolved_mappings(state.model)

return True
Expand Down Expand Up @@ -473,16 +494,28 @@ def _apply_smoothing(self, model: Module) -> None:
del self._smooth_activation_means[mapping.smooth_name]
continue

best_scales = self._compute_best_scale(mapping, fp16_outputs)
orig_layer_weights = {
balance_layer: balance_layer.weight.clone()
for balance_layer in mapping.balance_layers
if hasattr(balance_layer, "quantization_scheme")
and hasattr(balance_layer.quantization_scheme, "weights")
}

best_scales = self._compute_best_scale(
mapping, fp16_outputs, orig_layer_weights
)

@torch.no_grad()
def _smooth(module: Module):
def _smooth(
module: Module, orig_layer_weights: dict[Module, torch.Tensor]
):
scales = best_scales.to(module.weight.device)
if module in balance_layers:
update_offload_parameter(
module,
"weight",
module.weight.mul_(scales.view(1, -1)),
orig_layer_weights[module].to(module.weight.device)
* scales.view(1, -1),
)
elif module == smooth_layer:
if module.weight.ndim == 1:
Expand All @@ -509,16 +542,18 @@ def _smooth(module: Module):
)

for layer in balance_layers:
_smooth(layer)
_smooth(smooth_layer)
_smooth(layer, orig_layer_weights)
_smooth(smooth_layer, orig_layer_weights)

# remove caches needed to smooth this mapping
del self._smooth_activation_means[mapping.smooth_name]
del orig_layer_weights

for v in self._parent_args_cache.values():
v.batch_intermediates.clear()
self._assert_all_activations_consumed()

@torch.no_grad()
def _run_samples(self, module: Module) -> list[torch.Tensor]:
outputs = [
module(**batch_kwargs) for batch_kwargs in self._parent_args_cache[module]
Expand All @@ -533,6 +568,7 @@ def _compute_best_scale(
self,
mapping: ResolvedMapping,
fp16_outputs: list[torch.Tensor],
orig_layer_weights: dict[torch.nn.Module, torch.Tensor],
) -> torch.Tensor:
"""
Select best scales for a given mapping in a grid search
Expand All @@ -556,12 +592,6 @@ def _compute_best_scale(
best_error = float("inf")
initial_error = None

org_sd = {
k: v.cpu()
for k, v in mapping.parent.state_dict().items()
if v.device != torch.device("meta")
}

device = get_execution_device(mapping.parent)

x_mean = self._smooth_activation_means[mapping.smooth_name][0].to(device)
Expand Down Expand Up @@ -632,8 +662,11 @@ def _compute_best_scale(
continue

w_qscheme = balance_layer.quantization_scheme.weights
balance_layer.weight.mul_(_scalesview)
# For TENSOR_GROUP (nvfp4), need to calculate global scale
balance_layer.weight.data.copy_(
orig_layer_weights[balance_layer].to(_scalesview.device)
* _scalesview
)

should_calculate_gparam = (
w_qscheme.strategy == QuantizationStrategy.TENSOR_GROUP
)
Expand All @@ -643,17 +676,15 @@ def _compute_best_scale(
balance_layer.weight,
should_calculate_gparam=should_calculate_gparam,
)
update_offload_parameter(
balance_layer,
"weight",
balance_layer.weight.data = (
forward_quantize(
balance_layer,
balance_layer.weight.data,
balance_layer.weight,
"weight",
w_qscheme,
)
/ _scalesview,
)
/ _scalesview
).to(balance_layer.weight.dtype)

# Apply fused global scales for TENSOR_GROUP during grid search
# to match inference behavior
Expand All @@ -669,6 +700,7 @@ def _compute_best_scale(

# compute mean squared error (L2 norm)
loss = self._compute_loss(fp16_outputs, int_w_outputs)
del int_w_outputs

if initial_error is None:
initial_error = loss
Expand All @@ -682,8 +714,6 @@ def _compute_best_scale(
best_scales = scales.clone()
pbar.set_postfix({"best_error": f"{best_error:.3e}"})

mapping.parent.load_state_dict(org_sd, strict=False)

if best_ratio == -1:
logger.debug(history)
raise Exception(
Expand Down Expand Up @@ -731,13 +761,11 @@ def _compute_loss(
for fp16_batch, int_w_batch in zip(fp16_outputs, int_w_outputs):
loss += torch.nn.functional.mse_loss(
fp16_batch, int_w_batch.to(fp16_batch.device), reduction="sum"
).item()
)
num_elements += fp16_batch.numel()

# Normalize the loss by the total number of elements
loss /= num_elements

return loss
return (loss / num_elements).item()

def _log_error_metrics(self):
"""
Expand Down