Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
183 changes: 117 additions & 66 deletions src/llmcompressor/modifiers/awq/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,23 +92,23 @@ class AWQModifier(Modifier, QuantizationMixin):

- on_initialize
- resolve mappings
- capture kwargs needed for forward passes into modules
- set up hooks to capture forward pass kwargs and FP16 baseline outputs
- on_start
- set up activation cache hooks to capture input activations
to balance layers
- on sequential epoch end
- apply smoothing to each smoothing layer
- consume cached activations across all batches
- clear cached activations as they are used
- find best smoothing scale for each smoothing layer
- apply to model weights
- find best smoothing scale for each smoothing layer via grid search
- apply best scales to model weights
- raise error if any unused activations remain
- on_end
- re-run logic of sequential epoch end (in case of basic pipeline)
- set scales and zero points
- remove activation hooks
- on_finalize
- clear resolved mappings and captured activations
- clear resolved mappings, cached activations, and FP16 baseline cache

:param sequential_targets: list of module names to compress in
the same calibration pass
Expand Down Expand Up @@ -158,6 +158,12 @@ class AWQModifier(Modifier, QuantizationMixin):
_smooth_activation_means: dict[str, tuple[torch.FloatTensor, int]] = PrivateAttr(
default_factory=dict
)
# Cache FP16 baseline outputs for each parent module
# During calibration: list of flattened tensors, one per batch
# After calibration: single concatenated tensor
_fp16_baseline_cache: dict[Module, list[torch.Tensor] | torch.Tensor] = PrivateAttr(
default_factory=dict
)

def on_initialize(self, state: State, **kwargs) -> bool:
"""
Expand Down Expand Up @@ -262,6 +268,7 @@ def on_finalize(self, state: State, **kwargs) -> bool:

self._parent_args_cache.clear()
self._smooth_activation_means.clear()
self._fp16_baseline_cache.clear()
self._resolved_mappings.clear()

return True
Expand Down Expand Up @@ -337,8 +344,10 @@ def _set_resolved_mappings(self, model: Module) -> None:

def _setup_activation_cache_hooks(self) -> None:
"""
Attach a forward hook to each activation we want to smooth. This allows us to
calculate the dynamic range during calibration
Attach forward hooks to cache data needed for AWQ smoothing:
1. Parent module kwargs (for replaying forward passes during grid search)
2. FP16 baseline outputs (flattened, to avoid recomputing during grid search)
3. Smooth layer input activations (to calculate activation means)
"""

def cache_parent_kwargs_hook(
Expand All @@ -349,6 +358,17 @@ def cache_parent_kwargs_hook(
values = inspect.signature(module.forward).bind(*args, **kwargs)
self._parent_args_cache[module].append(values.arguments)

def cache_fp16_baseline_hook(
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if we are now caching output activations for every mapping in a given subgraph, wouldn't this increase memory requirements quite a bit, especially for MoE models? For which model are you seeing the 30% memory increase that you mention in the summary?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hold on i'm rewriting this, i didn't realize by default we don't enable offloading so all my measurements were off. we do need to cache this but not on gpu and we can offload it

module: Module,
args: tuple[torch.Tensor, ...],
output: torch.Tensor,
):
# Extract first element if tuple (same logic as _run_samples)
result = output[0] if isinstance(output, tuple) else output
if module not in self._fp16_baseline_cache:
self._fp16_baseline_cache[module] = []
self._fp16_baseline_cache[module].append(result.detach().flatten())

def create_cache_smooth_activations_hook_fn(smooth_name):
def cache_smooth_activations_hook(
_module: Module,
Expand Down Expand Up @@ -378,6 +398,12 @@ def cache_smooth_activations_hook(
with_kwargs=True,
)

self.register_hook(
mapping.parent,
cache_fp16_baseline_hook,
"forward",
)

# input activations to balance layers needed for loss function
# storing inputs to first balance layer is sufficient
# other balance layers get the same input
Expand All @@ -392,7 +418,7 @@ def _apply_smoothing(self, model: Module) -> None:
"""
Calculate the best scaling factors for each layer to smooth activations and
apply the scaling factors to the weights of the next layer to offset the
smoothing
smoothing.

:param model: model to apply smoothing to
"""
Expand All @@ -413,18 +439,32 @@ def _apply_smoothing(self, model: Module) -> None:
calibration_forward_context(model),
HooksMixin.disable_hooks(),
):
# Compute output of unquantized module
fp16_outputs = self._run_samples(parent_module)
if len(fp16_outputs) == 0 or all(f.numel() == 0 for f in fp16_outputs):
# Retrieve cached FP16 baseline outputs (collected during calibration)
if (
parent_module not in self._fp16_baseline_cache
or len(self._fp16_baseline_cache[parent_module]) == 0
):
logger.info(
f"Skipping smooth_layer {mapping.smooth_name}, no activations "
"found to scale. This can occasionally occur in MoE models "
"when certain experts are not activated by calibration samples."
)
del self._smooth_activation_means[mapping.smooth_name]
continue

# Concatenate FP16 baseline outputs (if still a list from calibration)
# This happens once per parent module; subsequent mappings reuse it
# Store as single-element list to maintain consistent access pattern
if len(self._fp16_baseline_cache[parent_module]) > 1:
self._fp16_baseline_cache[parent_module] = [
torch.cat(self._fp16_baseline_cache[parent_module])
]

if not all(
[fp16_output.isfinite().all() for fp16_output in fp16_outputs]
[
ref.isfinite().all()
for ref in self._fp16_baseline_cache[parent_module]
]
):
logger.warning(
f"Skipping smooth_layer {mapping.smooth_name}, NaN or inf "
Expand All @@ -438,7 +478,14 @@ def _apply_smoothing(self, model: Module) -> None:
del self._smooth_activation_means[mapping.smooth_name]
continue

best_scales = self._compute_best_scale(mapping, fp16_outputs)
orig_layer_weights = {
balance_layer: balance_layer.weight.clone()
for balance_layer in mapping.balance_layers
if hasattr(balance_layer, "quantization_scheme")
and hasattr(balance_layer.quantization_scheme, "weights")
}

best_scales = self._compute_best_scale(mapping, orig_layer_weights)

@torch.no_grad()
def _smooth(module: Module):
Expand All @@ -447,7 +494,7 @@ def _smooth(module: Module):
update_offload_parameter(
module,
"weight",
module.weight.mul_(scales.view(1, -1)),
orig_layer_weights[module] * scales.view(1, -1),
)
elif module == smooth_layer:
if module.weight.ndim == 1:
Expand Down Expand Up @@ -487,6 +534,7 @@ def _smooth(module: Module):

for v in self._parent_args_cache.values():
v.batch_intermediates.clear()
self._fp16_baseline_cache.clear()
self._assert_all_activations_consumed()

def _run_samples(self, module: Module) -> list[torch.Tensor]:
Expand All @@ -499,15 +547,40 @@ def _run_samples(self, module: Module) -> list[torch.Tensor]:
for output in outputs
]

def _run_samples_preallocated(
self, module: Module, fp16_out_concat: torch.Tensor
) -> torch.Tensor:
"""
Run samples through module and accumulate outputs in preallocated tensor.

:param module: Module to run samples through
:param fp16_out_concat: Reference tensor to determine size and device
:return: Preallocated tensor filled with flattened outputs
"""
# Preallocate tensor with same size and device as fp16_out_concat
int_w_concat = torch.empty_like(fp16_out_concat)
offset = 0

for batch_kwargs in self._parent_args_cache[module]:
output = module(**batch_kwargs)
# If tuple, assume that first argument is the output
result = output[0] if isinstance(output, tuple) else output
# Flatten and write to preallocated tensor
flattened = result.flatten()
size = flattened.numel()
int_w_concat[offset : offset + size] = flattened
offset += size

return int_w_concat

def _compute_best_scale(
self,
mapping: ResolvedMapping,
fp16_outputs: list[torch.Tensor],
orig_layer_weights: dict[torch.nn.Module, torch.Tensor],
) -> torch.Tensor:
"""
Select best scales for a given mapping in a grid search
Best scales are those that minimize MSE loss of quantized weight
outputs compared to fp16_outputs
Select best scales for a given mapping via grid search.
Best scales minimize MSE loss between FP16 baseline and quantized outputs.

L(s) = || Q(W * s) (s^-1 * X) - W * X ||
Q: weight quantization function | _pseudo_quantize_tensor(W * s)
Expand All @@ -516,20 +589,21 @@ def _compute_best_scale(
s: per channel scaling factor | s^-1 * X

:param mapping: best scales will be found for the ResolvedMapping.
:param fp16_outputs: output of mapping.parent in unquantized case,
one tensor for each batch.
:return: tensor of best scales, one for each channel
:param orig_layer_weights: Original weights for balance layers
:return: Tensor of best scales, one per channel (on CPU)
"""
history = []
best_ratio = -1
best_scales = None
best_error = float("inf")
fp16_out_concat = self._fp16_baseline_cache[mapping.parent][0]

org_sd = {
k: v.cpu()
for k, v in mapping.parent.state_dict().items()
if v.device != torch.device("meta")
}
balance_layers_to_patch = [
balance_layer
for balance_layer in mapping.balance_layers
if hasattr(balance_layer, "quantization_scheme")
and hasattr(balance_layer.quantization_scheme, "weights")
]

device = get_execution_device(mapping.parent)

Expand All @@ -549,12 +623,6 @@ def _compute_best_scale(

# Where appropriate, replace observers with memoryless_minmax
# for duration of grid search
balance_layers_to_patch = [
balance_layer
for balance_layer in mapping.balance_layers
if hasattr(balance_layer, "quantization_scheme")
and hasattr(balance_layer.quantization_scheme, "weights")
]
with patch_attrs(
balance_layers_to_patch,
"weight_observer",
Expand Down Expand Up @@ -586,38 +654,44 @@ def _compute_best_scale(
scales[torch.isinf(scales)] = 1
scales[torch.isnan(scales)] = 1

# Q(W * s)
# Q(W * s) - Apply current scale ratio to original weights and quantize
for balance_layer in balance_layers_to_patch:
if not hasattr(balance_layer, "quantization_scheme") or not hasattr(
balance_layer.quantization_scheme, "weights"
):
continue

w_qscheme = balance_layer.quantization_scheme.weights
balance_layer.weight.mul_(_scalesview)

balance_layer.weight.data.copy_(
orig_layer_weights[balance_layer] * _scalesview
)

call_observer(
balance_layer,
"weight",
balance_layer.weight,
# TODO test should_calculate_gparam for nvfp4 support
)
update_offload_parameter(
balance_layer,
"weight",
balance_layer.weight.data = (
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice job avoiding writing to the offload

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So we don't need update_offload_parameter here because it all happens on the exec device, and the smooth function is done elsewhere after best_scales are calculated?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah we can just keep in memory and not mess with that.

forward_quantize(
balance_layer,
balance_layer.weight.data,
balance_layer.weight,
"weight",
w_qscheme,
)
/ _scalesview,
)
/ _scalesview
).to(balance_layer.weight.dtype)

# W * X
int_w_outputs = self._run_samples(mapping.parent)
# W * X - Run forward passes and accumulate outputs
int_w_concat = self._run_samples_preallocated(
mapping.parent, fp16_out_concat
)

# compute mean squared error (L2 norm)
loss = self._compute_loss(fp16_outputs, int_w_outputs)
# Compute mean squared error (L2 norm)
loss = torch.nn.functional.mse_loss(
fp16_out_concat, int_w_concat
).item()

history.append(
{"ratio": ratio, "duo_scaling": use_duo_scaling, "error": loss}
Expand All @@ -627,8 +701,6 @@ def _compute_best_scale(
best_ratio = ratio
best_scales = scales.clone()

mapping.parent.load_state_dict(org_sd, strict=False)

if best_ratio == -1:
logger.debug(history)
raise Exception(
Expand All @@ -644,27 +716,6 @@ def _compute_best_scale(

return best_scales.detach().cpu()

@torch.no_grad()
def _compute_loss(
self,
fp16_outputs: list[torch.Tensor],
int_w_outputs: list[torch.Tensor],
) -> float:
loss = 0.0
num_elements = 0

# Compute the MSE loss for each batch
for fp16_batch, int_w_batch in zip(fp16_outputs, int_w_outputs):
loss += torch.nn.functional.mse_loss(
fp16_batch, int_w_batch.to(fp16_batch.device)
).item()
num_elements += fp16_batch.numel()

# Normalize the loss by the total number of elements
loss /= num_elements

return loss

def _assert_all_activations_consumed(self):
"""
Confirm all activations have been consumed
Expand Down