15 changes: 9 additions & 6 deletions examples/awq/qwen3_moe_example.py
@@ -53,6 +53,7 @@ def tokenize(sample):
recipe = [
AWQModifier(
ignore=["lm_head", "re:.*mlp.gate$", "re:.*mlp.shared_expert_gate$"],
force_balance=["re:.*mlp.gate$"],
scheme="W4A16",
targets=["Linear"],
),
@@ -67,6 +68,13 @@ def tokenize(sample):
num_calibration_samples=NUM_CALIBRATION_SAMPLES,
)



# Save to disk compressed.
SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-awq-sym-new"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)

# Confirm generations of the quantized model look sane.
print("\n\n")
print("========== SAMPLE GENERATION ==============")
@@ -76,9 +84,4 @@ def tokenize(sample):
)
output = model.generate(input_ids, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")
Contributor Author commented:
I'll revert these changes later


# Save to disk compressed.
SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-awq-sym"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
print("==========================================\n\n")
85 changes: 73 additions & 12 deletions src/llmcompressor/modifiers/awq/base.py
@@ -120,9 +120,15 @@ class AWQModifier(Modifier, QuantizationMixin):
to smoothed) and the second entry is the layer whose output is scaled to
achieve the smoothing.
If regex is used, it matches layers with the largest overlap in module name.
:param ignore: list of layers to ignore, even if they match a regex in mappings.
It should match the name of layers whose outputs are scaled to achieve
smoothing (the second entry of the mappings list).
:param ignore: list of layers to exclude from quantization. By default, layers in
`ignore` are also excluded from AWQ smoothing. Use `force_balance` to override
this and smooth certain layers even though they appear in `ignore` (e.g., MoE gate layers).
:param force_balance: list of layers to include in AWQ smoothing even if they are in
`ignore`. This allows you to smooth, but not quantize, specific layers. For example,
if `ignore=["lm_head", "re:.*mlp.gate"]` and `force_balance=["re:.*mlp.gate"]`,
then mlp.gate will be smoothed but not quantized, while lm_head will be neither
smoothed nor quantized. If None, all layers in `ignore` are excluded from smoothing
(the default, backward-compatible behavior).
:param offload_device: offload cached args to this device, which reduces memory
requirements but requires more time to move data between cpu and execution
device. Defaults to None, so cached args are not offloaded. Consider setting
@@ -145,6 +151,7 @@ class AWQModifier(Modifier, QuantizationMixin):
# User-provided vars (in addition to QuantizationMixin args)
sequential_targets: str | list[str] | None = None
mappings: list[AWQMapping] | None = None
force_balance: list[str] | None = None
offload_device: torch.device | None = None
duo_scaling: bool | Literal["both"] = True
n_grid: int = 20
@@ -281,9 +288,34 @@ def _set_resolved_mappings(self, model: Module) -> None:
"""
resolved_mappings: list[ResolvedMapping] = []
module_to_name = get_module_to_name_dict(model)

# Compute the effective ignore list for AWQ smoothing:
# Start with self.ignore, then remove any layers in force_balance
# If force_balance is None, use self.ignore as-is (backward compatible)
if self.force_balance is not None:
# Validate that all force_balance layers are in ignore
ignore_set = set(self.ignore or [])
force_balance_set = set(self.force_balance)
invalid_force_balance = force_balance_set - ignore_set
if invalid_force_balance:
raise ValueError(
f"force_balance contains layers that are not in ignore: "
f"{invalid_force_balance}. force_balance should only contain "
f"layers that are in ignore but you want to smooth anyway."
)

# Remove force_balance layers from ignore list for AWQ matching
awq_ignore = [
ign for ign in (self.ignore or [])
if ign not in self.force_balance
]
Contributor commented on lines +308 to +311 (medium):
The list comprehension for creating awq_ignore involves checking for membership in self.force_balance, which is a list. This results in a time complexity of O(N*M), where N is the length of self.ignore and M is the length of self.force_balance. You have already computed force_balance_set on line 298, which allows for O(1) average time complexity for membership checking. Using this set would make the operation more efficient, with a total complexity of O(N+M).

Suggested change
awq_ignore = [
ign for ign in (self.ignore or [])
if ign not in self.force_balance
]
awq_ignore = [
ign for ign in (self.ignore or [])
if ign not in force_balance_set
]

else:
# Default: exclude everything in ignore from smoothing
awq_ignore = self.ignore

for mapping in self.mappings:
for smooth_layers, *nested_balance_layers in match_modules_set(
model, (mapping.smooth_layer, *mapping.balance_layers), self.ignore
model, (mapping.smooth_layer, *mapping.balance_layers), awq_ignore
):
if len(smooth_layers) > 1:
raise ValueError(
@@ -356,11 +388,12 @@ def cache_smooth_activations_hook(
args: tuple[torch.Tensor, ...],
_output: torch.Tensor,
):
self._smooth_activation_means[smooth_name] = _accumulate_mean(
# Assume that first argument is the input
args[0].cpu().abs().detach().flatten(0, -2),

act_mean, count = _accumulate_mean(
args[0].abs().detach().flatten(0, -2),
self._smooth_activation_means.get(smooth_name, None),
)
self._smooth_activation_means[smooth_name] = (act_mean.cpu(), count)

return cache_smooth_activations_hook

@@ -525,6 +558,7 @@ def _compute_best_scale(
best_ratio = -1
best_scales = None
best_error = float("inf")
initial_error = None

org_sd = {
k: v.cpu()
@@ -569,7 +603,14 @@ def _compute_best_scale(
for balance_layer in balance_layers_to_patch
],
):
for grid_idx, use_duo_scaling in product(range(n_grid), duo_scalings):
total_iterations = n_grid * len(duo_scalings)
pbar = tqdm(
product(range(n_grid), duo_scalings),
total=total_iterations,
desc=f"Grid search for {mapping.smooth_name}",
leave=False,
)
for grid_idx, use_duo_scaling in pbar:
# create new scales
ratio = grid_idx / n_grid
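For orientation, each grid point above turns `ratio` (and, when duo scaling is enabled, the weight statistics as well) into a candidate per-channel scale vector. The construction itself is outside this hunk; the sketch below shows the standard AWQ formulation as a reference point, with `x_mean` and `w_mean` assumed to be the cached activation and weight channel means, and is not copied from llm-compressor:

```python
import torch

def candidate_scales(
    x_mean: torch.Tensor, w_mean: torch.Tensor, ratio: float, duo_scaling: bool
) -> torch.Tensor:
    """Standard AWQ-style scale proposal for one grid point (reference sketch)."""
    if duo_scaling:
        scales = x_mean.pow(ratio) / (w_mean.pow(1 - ratio) + 1e-4)
    else:
        scales = x_mean.pow(ratio)
    scales = scales.clamp(min=1e-4)
    # Normalize so the scales neither blow up nor vanish overall.
    scales = scales / (scales.max() * scales.min()).sqrt()
    scales[torch.isnan(scales)] = 1
    return scales

x_mean, w_mean = torch.rand(8) + 0.1, torch.rand(8) + 0.1
print(candidate_scales(x_mean, w_mean, ratio=0.5, duo_scaling=True))
```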

@@ -588,14 +629,15 @@ def _compute_best_scale(
scales[torch.isnan(scales)] = 1

# Q(W * s)
for balance_layer in balance_layers_to_patch:
for balance_layer in mapping.balance_layers:
Contributor commented (critical):
Iterating over mapping.balance_layers here appears to be incorrect. This loop is part of a grid search to find the best scaling factor by minimizing quantization error. By including non-quantized layers (which can be in balance_layers due to force_balance), their weights are modified, and the resulting output distortion is included in the loss calculation. This loss should ideally only reflect quantization error from the quantized layers. Using balance_layers_to_patch, which is defined before this block and contains only the layers to be quantized, would be the correct approach. The influence of force_balance layers is already correctly handled in the computation of w_mean, which contributes to the scales.

Contributor Author replied:
Gemini’s answer seems incorrect. We also need to account for nn.Modules that should not be quantized, so that the model produced by AWQ remains functionally equivalent to the original network.
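
To make the point concrete: AWQ smoothing divides the activation leaving the smooth layer by `s` and multiplies the input-channel weights of every balance layer by `s`, which is an exact identity in full precision. A minimal, hypothetical sketch (not code from this PR) of that identity for an unquantized consumer such as an MoE router:

```python
# Hypothetical sketch: names and shapes are assumptions, not PR code.
import torch

torch.manual_seed(0)
x = torch.randn(4, 8)                         # activation from the smooth layer
gate = torch.nn.Linear(8, 2, bias=False)      # e.g. an unquantized MoE router
s = x.abs().mean(dim=0).clamp(min=1e-4)       # per-input-channel smoothing scale

y_ref = gate(x)                                                  # original output
y_smoothed = torch.nn.functional.linear(x / s, gate.weight * s)  # compensated path
print(torch.allclose(y_ref, y_smoothed, atol=1e-5))              # True
```

If a consumer of the smoothed activation were skipped entirely (neither rescaled nor quantized), its inputs would arrive divided by `s` with nothing compensating on the weight side, which is the functional-equivalence concern raised in the reply above.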

balance_layer.weight.mul_(_scalesview)

if not hasattr(balance_layer, "quantization_scheme") or not hasattr(
balance_layer.quantization_scheme, "weights"
):
continue

w_qscheme = balance_layer.quantization_scheme.weights
balance_layer.weight.mul_(_scalesview)
call_observer(
balance_layer,
"weight",
@@ -620,13 +662,21 @@ def _compute_best_scale(
# compute mean squared error (L2 norm)
loss = self._compute_loss(fp16_outputs, int_w_outputs)

# Track initial error for logging
if initial_error is None:
initial_error = loss


history.append(
{"ratio": ratio, "duo_scaling": use_duo_scaling, "error": loss}
)
if loss < best_error:
best_error = loss
best_ratio = ratio
best_scales = scales.clone()

# Update progress bar with best error
pbar.set_postfix({"best_error": f"{best_error:.3e}"})

mapping.parent.load_state_dict(org_sd, strict=False)

@@ -639,6 +689,16 @@ def _compute_best_scale(
"https://github.com/vllm-project/llm-compressor/issues"
)

# Log final results
err_reduction = best_error / initial_error
logger.info(
f"AWQ grid search for {mapping.smooth_name}: "
f"initial error = {initial_error:.3e}, "
f"best error = {best_error:.3e}, "
f"error reduction rate (best/initial) = {err_reduction * 100:.3f}%"
)


assert (
torch.isnan(best_scales).sum() == 0
), f"Nan found in scales: {best_scales}"
@@ -657,7 +717,7 @@ def _compute_loss(
# Compute the MSE loss for each batch
for fp16_batch, int_w_batch in zip(fp16_outputs, int_w_outputs):
loss += torch.nn.functional.mse_loss(
fp16_batch, int_w_batch.to(fp16_batch.device)
fp16_batch, int_w_batch.to(fp16_batch.device), reduction="sum"
).item()
num_elements += fp16_batch.numel()

@@ -863,9 +923,10 @@ def _accumulate_mean(
sum_added = inp.sum(dim=0)
num_added = inp.size(0)
if prev_mean_and_count is None:
return sum_added, num_added
return sum_added / num_added, num_added

prev_mean, prev_count = prev_mean_and_count
prev_mean = prev_mean.to(inp.device)

prev_sum = prev_mean * prev_count
new_count = prev_count + num_added
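A note on the `_accumulate_mean` hunk above: before the fix, the first call returned the raw per-channel sum while later calls returned a mean, so the accumulated statistic was skewed toward the first batch. A small self-contained sketch of the corrected running-mean update (assumed signature, not the library function itself):

```python
import torch

def running_mean(inp: torch.Tensor, prev: tuple[torch.Tensor, int] | None):
    # Combine a new batch with the previous (mean, count) state.
    batch_sum, batch_count = inp.sum(dim=0), inp.size(0)
    if prev is None:
        return batch_sum / batch_count, batch_count   # return a mean, not a sum
    prev_mean, prev_count = prev
    new_count = prev_count + batch_count
    return (prev_mean * prev_count + batch_sum) / new_count, new_count

x = torch.randn(10, 4)
state = running_mean(x[:6], None)
state = running_mean(x[6:], state)
print(torch.allclose(state[0], x.mean(dim=0), atol=1e-6))  # True
```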
2 changes: 1 addition & 1 deletion src/llmcompressor/modifiers/awq/mappings.py
@@ -46,7 +46,7 @@ class AWQMapping:
AWQMapping("re:.*v_proj$", ["re:.*o_proj$"]),
AWQMapping(
"re:.*post_attention_layernorm$",
["re:.*mlp.experts.*.gate_proj$", "re:.*mlp.experts.*.up_proj$"],
["re:.*mlp.experts.*.gate_proj$", "re:.*mlp.experts.*.up_proj$", "re:.*mlp.gate$"],
),
AWQMapping(
"re:.*up_proj$",