Merged
82 changes: 82 additions & 0 deletions examples/awq/qwen3_moe_example.py
@@ -0,0 +1,82 @@
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.awq import AWQModifier

# Select model and load it.
MODEL_ID = "Qwen/Qwen3-30B-A3B"

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, device_map="auto", torch_dtype="auto"
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

# Select calibration dataset.
DATASET_ID = "mit-han-lab/pile-val-backup"
DATASET_SPLIT = "validation"

# Select number of samples. 256 samples is a good place to start.
# Increasing the number of samples can improve accuracy.
NUM_CALIBRATION_SAMPLES = 256
MAX_SEQUENCE_LENGTH = 512

# Load dataset and preprocess.
ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]")
ds = ds.shuffle(seed=42)


def preprocess(example):
    return {
        "text": tokenizer.apply_chat_template(
            [{"role": "user", "content": example["text"]}],
            tokenize=False,
        )
    }


ds = ds.map(preprocess)


# Tokenize inputs.
def tokenize(sample):
    return tokenizer(
        sample["text"],
        padding=False,
        max_length=MAX_SEQUENCE_LENGTH,
        truncation=True,
        add_special_tokens=False,
    )


ds = ds.map(tokenize, remove_columns=ds.column_names)

# Configure the quantization algorithm to run.
# NOTE: vLLM does not currently support asymmetric quantization for MoE models,
# so a symmetric scheme is used here.
recipe = [
    AWQModifier(
        ignore=["lm_head", "re:.*mlp.gate$", "re:.*mlp.shared_expert_gate$"],
        scheme="W4A16",
        targets=["Linear"],
    ),
]

# Apply algorithms.
oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
)

# Confirm generations of the quantized model look sane.
print("\n\n")
print("========== SAMPLE GENERATION ==============")
input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
output = model.generate(input_ids, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")

# Save to disk compressed.
SAVE_DIR = MODEL_ID.split("/")[-1] + "-awq-sym"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
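The saved checkpoint can then be loaded for inference. A minimal sketch, assuming vLLM is installed and supports this compressed W4A16 checkpoint (the path matches SAVE_DIR above):

from vllm import LLM, SamplingParams

# Load the AWQ-quantized checkpoint produced by the example above.
llm = LLM(model="Qwen3-30B-A3B-awq-sym")
outputs = llm.generate(["Hello my name is"], SamplingParams(max_tokens=64))
print(outputs[0].outputs[0].text)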
174 changes: 106 additions & 68 deletions src/llmcompressor/modifiers/awq/base.py
@@ -26,11 +26,7 @@
from llmcompressor.pipelines.cache import IntermediatesCache
from llmcompressor.utils.fsdp.helpers import get_fsdp_parent
from llmcompressor.utils.helpers import calibration_forward_context
from llmcompressor.utils.pytorch.module import (
get_layers,
get_matching_layer,
get_parent_by_name,
)
from llmcompressor.utils.pytorch.module import get_layer_by_name, get_layers

__all__ = ["AWQModifier"]

@@ -307,22 +303,37 @@ def _set_resolved_mappings(self, model: Module) -> None:
repeat for model.layer.1 and so on
"""
resolved_mappings: list[ResolvedMapping] = []
num_skipped_oproj_mappings = 0
for mapping in self.mappings:
to_smooth_layers = get_layers(mapping.smooth_layer, model)
for layer_name, smooth_layer in to_smooth_layers.items():
# always exclude `.weight_observer`, only want `.weight`
if layer_name not in self.ignore and not layer_name.endswith(
"_observer"
):
balance_layers, balance_names = [], []
for balance_suffix in mapping.balance_layers:
# find the submodule that matches the activation layer
balance_name, balance_layer = get_matching_layer(
balance_suffix, layer_name, model
)
if not balance_layer:
continue
for mapping_idx, mapping in enumerate(self.mappings):
smooth_layers = get_layers(mapping.smooth_layer, model)
smooth_names = [
smooth_name
for smooth_name in smooth_layers
if (
smooth_name not in self.ignore
and not smooth_name.endswith("_observer")
)
]

num_skipped_mappings = 0
pbar = tqdm(smooth_names)
for smooth_name in pbar:
pbar.set_description(
f"Resolving mapping {mapping_idx+1}/{len(self.mappings)}"
f" ({num_skipped_mappings} skipped)"
)
smooth_layer = smooth_layers[smooth_name]

smooth_parent_name = ".".join(smooth_name.split(".")[:-1])
smooth_parent = get_layer_by_name(smooth_parent_name, model)

balance_layers, balance_names = [], []
for balance_regex in mapping.balance_layers:
# find the submodules that match the activation layer
for balance_suffix, balance_layer in get_layers(
balance_regex,
smooth_parent,
).items():
balance_name = f"{smooth_parent_name}.{balance_suffix}"

# exclude v_proj->o_proj mappings whose shapes are incompatible
# https://github.com/mit-han-lab/llm-awq/pull/67#issuecomment-1681632777
@@ -332,52 +343,43 @@ def _set_resolved_mappings(self, model: Module) -> None:
and ".o_proj" in balance_name
and (
(
".v_proj" in layer_name
".v_proj" in smooth_name
and smooth_layer.out_features
!= balance_layer.in_features
)
or (
".qkv_proj" in layer_name
".qkv_proj" in smooth_name
and smooth_layer.out_features
!= 3 * balance_layer.in_features
)
)
):
num_skipped_oproj_mappings += 1
num_skipped_mappings += 1
continue

balance_layers.append(balance_layer)
balance_names.append(balance_name)

if len(balance_layers) == 0:
continue

# each mapping can contain multiple layers to balance, but only
# one layer to smooth
if len(balance_layers) == 1:
# for single balance layer, parent is the balance layer
parent_name, parent = balance_name, balance_layer
else:
# for multiple balance layers,
# parent of any balance layer is the parent
parent_name, parent = get_parent_by_name(
layer_name=balance_name, model=model
)
resolved_mappings.append(
ResolvedMapping(
layer_name,
smooth_layer,
balance_layers,
balance_names=balance_names,
parent=parent,
parent_name=parent_name,
)
if len(balance_layers) == 0:
continue

elif len(balance_layers) == 1:
# for single balance layer, parent is the balance layer
parent_name, parent = balance_name, balance_layer
else:
# for multiple balance layers, find lowest common parent
parent_name, parent = get_lowest_common_parent(balance_names, model)

resolved_mappings.append(
ResolvedMapping(
smooth_name,
smooth_layer,
balance_layers,
balance_names=balance_names,
parent=parent,
parent_name=parent_name,
)
if num_skipped_oproj_mappings > 0:
logger.info(
f"Excluded {num_skipped_oproj_mappings} from resolved "
"mappings due to shape mismatch"
)
)
self._resolved_mappings = resolved_mappings
return

Expand All @@ -401,11 +403,9 @@ def cache_smooth_activations_hook(
args: Tuple[torch.Tensor, ...],
_output: torch.Tensor,
):
# Assume that first argument is the input
inp = args[0].cpu().detach().squeeze()

self._smooth_activation_means[smooth_name] = _accumulate_mean(
inp,
# Assume that first argument is the input
args[0].cpu().detach().squeeze(),
self._smooth_activation_means.get(smooth_name, None),
)

@@ -444,12 +444,14 @@ def _apply_smoothing(self, model: Module) -> None:

:param model: model to apply smoothing to
"""
for mapping in tqdm(self._resolved_mappings, desc="Smoothing"):
# NOTE: When using SequentialPipeline, not all the mappings
# will have cached activations in the segment being updated
if mapping.smooth_name not in self._smooth_activation_means:
continue

# NOTE: When using SequentialPipeline, not all the mappings
# will have cached activations in the segment being updated
mappings_to_smooth = [
mapping
for mapping in self._resolved_mappings
if mapping.smooth_name in self._smooth_activation_means
]
for mapping in tqdm(mappings_to_smooth, desc="Smoothing"):
smooth_layer = mapping.smooth_layer
balance_layers = mapping.balance_layers
parent_module = mapping.parent
@@ -473,10 +475,15 @@
# [STEP 3]: Compute output of module
# could cache from hook, rather than recomputing here
fp16_output = self._run_samples(parent_module)
fp16_output = fp16_output.clip(
torch.finfo(fp16_output.dtype).min,
torch.finfo(fp16_output.dtype).max,
)
if fp16_output.shape[0] == 0:
logger.info(
f"Skipping smooth_layer {mapping.smooth_name}, no activations "
"found to scale. This can occasionally occur in MoE models "
"when certain experts are not activated by calibration samples."
)
del self._smooth_activation_means[mapping.smooth_name]
continue

x_mean = self._smooth_activation_means[mapping.smooth_name][0]

# [STEP 4]: Compute loss
@@ -536,10 +543,14 @@ def smooth(module):

def _run_samples(self, module: Module) -> torch.Tensor:
with align_module_device(module):
outputs = [
module(**batch_kwargs)
for batch_kwargs in self._parent_args_cache[module]
]
return torch.cat(
[
module(**batch_kwargs)[0]
for batch_kwargs in self._parent_args_cache[module]
output[0] if isinstance(output, Tuple) else output
for output in outputs
],
dim=0,
)
@@ -736,3 +747,30 @@ def _accumulate_mean(
new_count = prev_count + num_added

return (prev_sum + sum_added) / new_count, new_count


def get_lowest_common_parent(names: List[str], module: Module) -> Tuple[str, Module]:
    """
    Given a list of names, returns the lowest-scope common parent,
    excluding parents of type ModuleList, which don't seem to play
    nicely with hooks.
    Returns name of parent and pointer to parent module

    Implementation is a small alteration of os.path.commonprefix
    https://docs.python.org/3/library/os.path.html#os.path.commonprefix
    """
    s1 = min(names)
    s2 = max(names)
    parent_name = ""
    for i, c in enumerate(s1):
        if c != s2[i]:
            parent_name = s1[:i].rstrip(".")
            break

    while True:
        if parent_name == "":
            return "", module
        parent = get_layer_by_name(parent_name, module)
        if not isinstance(parent, torch.nn.ModuleList):
            return parent_name, parent
        parent_name = ".".join(parent_name.split(".")[:-1])