Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion examples/awq/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,6 @@ In order to target weight and activation scaling locations within the model, the
]
```

To support other model families, you can add supply your own mappings via the `mappings` argument with instantiating the `AWQModifier`, or you can add them to the registry [here](/src/llmcompressor/modifiers/awq/mappings.py) (contributions are welcome!)
Note: the mappings define which layers get smoothed, whereas `targets` and `ignore` define which layers get quantized. So if you add a layer to the ignore list that is still matched by the included mappings, it will be smoothed but not quantized.

To support other model families, you can supply your own mappings via the `mappings` argument when instantiating the `AWQModifier`, or you can add them to the registry [here](/src/llmcompressor/modifiers/awq/mappings.py) (contributions are welcome!)
43 changes: 33 additions & 10 deletions src/llmcompressor/modifiers/awq/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ class AWQModifier(Modifier, QuantizationMixin):
to smoothed) and the second entry is the layer whose output is scaled to
achieve the smoothing.
If regex is used, it matches layers with the largest overlap in module name.
:param ignore: list of layers to ignore, even if they match a regex in mappings.
:param ignore: list of layers to ignore during quantization (not smoothed).
It should match the name of layers whose outputs are scaled to achieve
smoothing (the second entry of the mappings list).
:param offload_device: offload cached args to this device, which reduces memory
Expand Down Expand Up @@ -280,9 +280,19 @@ def _set_resolved_mappings(self, model: Module) -> None:
"""
resolved_mappings: list[ResolvedMapping] = []
module_to_name = get_module_to_name_dict(model)
# Get names of modules targeted for quantization (excludes ignored)
targeted_names = set(
name
for name, _ in match_named_modules(
model, self.resolved_targets, self.ignore
)
)
for mapping in self.mappings:
# we deliberately don't use the ignore list when matching mappings,
# so that we can handle layers that need smoothing but not quantization
# we only skip if no layers in mapping are targeted for quantization.
for smooth_layers, *nested_balance_layers in match_modules_set(
model, (mapping.smooth_layer, *mapping.balance_layers), self.ignore
model, (mapping.smooth_layer, *mapping.balance_layers)
):
if len(smooth_layers) > 1:
raise ValueError(
Expand All @@ -301,19 +311,27 @@ def _set_resolved_mappings(self, model: Module) -> None:
for balance_layer in balance_layers
]

# Check if at least one layer is targeted for quantization
any_targeted = smooth_name in targeted_names or any(
bn in targeted_names for bn in balance_names
)

all_compatible = _check_layers_are_compatible(
smooth_layer, smooth_name, balance_layers, balance_names
)

# skip mapping if any of the balance layers are incompatible
if not all_compatible or len(balance_layers) == 0:
skip_message: str | None = None
if not all_compatible:
skip_message = " because found incompatible balance layers"
elif not any_targeted:
skip_message = " because no layers are targeted for quantization"
elif len(balance_layers) == 0:
skip_message = " because no balance layers were found"

if skip_message:
logger.warning(
f"skipping AWQ for {smooth_name} for mapping {mapping}"
+ (
" because found incompatible balance layers"
if not all_compatible
else " because no balance layers were found"
)
+ skip_message
)

continue
Expand Down Expand Up @@ -568,7 +586,12 @@ def _compute_best_scale(
for balance_layer in balance_layers_to_patch
],
):
for grid_idx, use_duo_scaling in product(range(n_grid), duo_scalings):
total_iterations = n_grid * len(duo_scalings)
for grid_idx, use_duo_scaling in tqdm(
product(range(n_grid), duo_scalings),
total=total_iterations,
desc="Grid search",
):
# create new scales
ratio = grid_idx / n_grid

Expand Down
49 changes: 44 additions & 5 deletions src/llmcompressor/modifiers/smoothquant/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@
from typing import Callable

import torch
from compressed_tensors.utils import align_module_device, match_modules_set
from compressed_tensors.utils import (
align_module_device,
match_modules_set,
match_named_modules,
)
from loguru import logger
from pydantic import ConfigDict, Field
from torch.nn import Module
Expand Down Expand Up @@ -86,9 +90,10 @@ class SmoothQuantModifier(Modifier):
achieve the smoothing. If regex is used, it matches layers with the largest
overlap in module name. If not supplied the argument will be inferred from the
model architecture.
:param ignore: list of layers to ignore, even if they match a regex in mappings.
It should match the name of layers whose outputs are scaled to achieve
smoothing (the second entry of the mappings list).
:param ignore: list of layers to ignore during smoothing.
Mappings are only skipped if all layers in the mapping are ignored.
It should match the name of layers whose outputs are scaled to
achieve smoothing (the second entry of the mappings list).
:param num_calibration_steps: number of samples to use for calibration, or None to
use the whole dataset
:param calibration_function: optional function to use for the forward pass, or None
Expand Down Expand Up @@ -200,9 +205,19 @@ def _resolve_mappings(self, model: Module) -> list[SmoothQuantMapping]:
"""
resolved_mappings = []
module_to_name = get_module_to_name_dict(model)
# Get names of modules that are not ignored
ignored_names = set()
if self.ignore:
ignored_names = set(
name for name, _ in match_named_modules(model, self.ignore)
)

for mapping in self.mappings:
# we deliberately don't use the ignore list when matching mappings
# so that we can handle layers that need smoothing but not all operations
# we only skip if no layers in mapping would be smoothed.
for *nested_balance_layers, smooth_layers in match_modules_set(
model, tree_leaves(mapping), self.ignore
model, tree_leaves(mapping)
):
if len(smooth_layers) > 1:
raise ValueError(
Expand All @@ -213,6 +228,30 @@ def _resolve_mappings(self, model: Module) -> list[SmoothQuantMapping]:
smooth_layer = smooth_layers[0]
smooth_name = module_to_name.get(smooth_layers[0])
balance_layers = tree_leaves(nested_balance_layers)
balance_names = [
module_to_name.get(balance_layer)
for balance_layer in balance_layers
]

# Check if at least one layer would be smoothed (not ignored)
any_not_ignored = smooth_name not in ignored_names or any(
bn not in ignored_names for bn in balance_names
)

if not any_not_ignored:
logger.warning(
f"Skipping SmoothQuant for {smooth_name} because all layers "
"in the mapping are in the ignore list"
)
continue

if len(balance_layers) == 0:
logger.warning(
f"Skipping SmoothQuant for {smooth_name} because no balance "
"layers were found"
)
continue

resolved_mappings.append(
SmoothQuantMapping(smooth_name, smooth_layer, balance_layers)
)
Expand Down
59 changes: 59 additions & 0 deletions tests/llmcompressor/modifiers/awq/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,65 @@ def test_validate():
AWQModifier(scheme="W4A16", duo_scaling="x")


@pytest.mark.unit
def test_ignore_behavior():
    """Test that mapping is skipped when NO layers are targeted for quantization"""
    # Minimal decoder-style module tree: a layernorm feeding three
    # attention projections (the mapping's balance layers).
    model = torch.nn.ModuleDict(
        {
            "decoder": torch.nn.ModuleDict(
                {
                    "self_attn": torch.nn.ModuleDict(
                        {
                            "q_proj": Linear(4, 4),
                            "k_proj": Linear(4, 4),
                            "v_proj": Linear(4, 4),
                        }
                    ),
                    "input_layernorm": torch.nn.LayerNorm(4),
                }
            )
        }
    )

    # Case 1: only 2 of the 3 balance layers are ignored, so v_proj is
    # still targeted for quantization and the mapping must be resolved.
    partially_ignored = AWQModifier(
        mappings=[
            AWQMapping(
                "re:.*input_layernorm",
                ["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"],
            ),
        ],
        ignore=["re:.*q_proj", "re:.*k_proj"],
        scheme="W4A16_ASYM",
    )
    partially_ignored._set_resolved_mappings(model)
    assert len(partially_ignored._resolved_mappings) == 1

    # Case 2: every Linear balance layer is ignored (input_layernorm is a
    # LayerNorm, so it is not a quantization target either) — with nothing
    # left to quantize, the mapping must be dropped.
    fully_ignored = AWQModifier(
        mappings=[
            AWQMapping(
                "re:.*input_layernorm",
                ["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"],
            ),
        ],
        ignore=["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"],
        scheme="W4A16_ASYM",
    )
    fully_ignored._set_resolved_mappings(model)
    assert len(fully_ignored._resolved_mappings) == 0


@pytest.mark.unit
def test_moe_multiple_balance_layers():
"""Test AWQ mapping with multiple balance layers in MoE architecture"""
Expand Down
54 changes: 54 additions & 0 deletions tests/llmcompressor/modifiers/smoothquant/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,3 +170,57 @@ def create_moe_layer():
# Verify all balance layers are unique
balance_layer_ids = [id(layer) for layer in mapping.balance_layers]
assert len(balance_layer_ids) == len(set(balance_layer_ids))


@pytest.mark.unit
def test_ignore_behavior():
    """Test that mapping is skipped when ALL layers are in ignore list"""
    dim = 64

    # Minimal decoder-style module tree: a layernorm whose output is
    # smoothed into three attention projections.
    model = torch.nn.ModuleDict(
        {
            "decoder": torch.nn.ModuleDict(
                {
                    "input_layernorm": torch.nn.LayerNorm(dim),
                    "self_attn": torch.nn.ModuleDict(
                        {
                            "q_proj": torch.nn.Linear(dim, dim),
                            "k_proj": torch.nn.Linear(dim, dim),
                            "v_proj": torch.nn.Linear(dim, dim),
                        }
                    ),
                }
            )
        }
    )

    # Case 1: only part of the mapping is ignored — v_proj survives, so
    # the mapping should still resolve.
    partly_ignored = SmoothQuantModifier(
        smoothing_strength=0.5,
        mappings=[
            (["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"], "re:.*input_layernorm")
        ],
        ignore=["re:.*q_proj", "re:.*k_proj"],
    )
    assert len(partly_ignored._resolve_mappings(model)) == 1

    # Case 2: the smooth layer and every balance layer are ignored, so
    # the mapping should be skipped entirely.
    fully_ignored = SmoothQuantModifier(
        smoothing_strength=0.5,
        mappings=[
            (["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"], "re:.*input_layernorm")
        ],
        ignore=[
            "re:.*input_layernorm",
            "re:.*q_proj",
            "re:.*k_proj",
            "re:.*v_proj",
        ],
    )
    assert len(fully_ignored._resolve_mappings(model)) == 0