Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
bc97d48
update autoround version
yiliu30 Nov 21, 2025
19ab4f2
Merge branch 'main' into autoround-version
yiliu30 Nov 21, 2025
9ba113c
expose bs
yiliu30 Nov 24, 2025
646982a
Merge branch 'autoround-version' of https://github.com/yiliu30/llm-co…
yiliu30 Nov 24, 2025
1050335
use 0.9.1
yiliu30 Nov 26, 2025
50e6682
fix
yiliu30 Nov 27, 2025
d139071
update
yiliu30 Nov 27, 2025
a0affbd
enable auto-dispatch
yiliu30 Dec 2, 2025
17ba9f5
add ds example
yiliu30 Dec 2, 2025
cd943cd
merge main
yiliu30 Dec 15, 2025
8338ed5
pass ignore to ar
yiliu30 Dec 15, 2025
56515af
add qwen example
yiliu30 Dec 15, 2025
ad6c1c0
update example
yiliu30 Dec 15, 2025
09a72c0
format
yiliu30 Dec 15, 2025
af112bd
update
yiliu30 Dec 15, 2025
ec98118
refine suspend hook
yiliu30 Dec 17, 2025
c5eae60
update
yiliu30 Dec 18, 2025
2d482fc
clean code
yiliu30 Dec 18, 2025
17b7e45
add ut
yiliu30 Dec 18, 2025
7a9b3cd
fix
yiliu30 Dec 18, 2025
4f45b17
fix hint
yiliu30 Dec 18, 2025
0fac601
refine
yiliu30 Dec 18, 2025
0f7a990
speedup ut
yiliu30 Dec 18, 2025
58ef017
clean
yiliu30 Dec 19, 2025
c9ea99c
add docstring
yiliu30 Dec 19, 2025
d2a7c92
format
yiliu30 Dec 19, 2025
d48c3d6
Merge branch 'main' into auto-device
yiliu30 Dec 19, 2025
993a68e
Merge branch 'main' into auto-device
yiliu30 Dec 20, 2025
fa8cdcc
Merge branch 'main' into auto-device
yiliu30 Dec 23, 2025
c17e923
rename device_map to device_ids
yiliu30 Dec 24, 2025
1092cde
fix typo
yiliu30 Dec 24, 2025
0734dd5
add docstring
yiliu30 Jan 5, 2026
d3e6da6
Merge branch 'main' into auto-device
yiliu30 Jan 6, 2026
6d4934a
fix format issue
yiliu30 Jan 6, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 69 additions & 0 deletions examples/autoround/qwen3_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from auto_round.calib_dataset import get_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.autoround import AutoRoundModifier
from llmcompressor.utils import dispatch_for_generation

# Load the model and tokenizer to quantize.
model_id = "Qwen/Qwen3-235B-A22B"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Calibration settings.
NUM_CALIBRATION_SAMPLES = 128
MAX_SEQUENCE_LENGTH = 2048
ITERS = 200

# Build the calibration dataset aligned with AutoRound's expectations.
calibration_ds = get_dataset(
    tokenizer=tokenizer,
    seqlen=MAX_SEQUENCE_LENGTH,
    nsamples=NUM_CALIBRATION_SAMPLES,
)


# Quantization recipe:
# * weights are quantized to 4 bit with AutoRound using group size 128
# * the lm_head and MoE router gates stay in full precision
# * For `Qwen/Qwen3-235B-A22B`, it requires about 300 GB memory
#   to run tuning with default settings.
recipe = AutoRoundModifier(
    targets="Linear",
    scheme="W4A16",
    ignore=[
        "lm_head",
        "re:.*mlp.gate$",
    ],
    iters=ITERS,
    enable_torch_compile=False,
    device_ids="0,1,2,3",  # Use 4 A100 GPUs
)


# Run the one-shot quantization pass.
oneshot(
    model=model,
    dataset=calibration_ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    shuffle_calibration_samples=False,
)


# Persist the compressed checkpoint alongside the tokenizer.
SAVE_DIR = model_id.rstrip("/").split("/")[-1] + "-W4A16-G128-AutoRound"
print(f"save to {SAVE_DIR}")
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)


# Smoke-test generation to confirm the quantized model still produces text.
print("\n\n")
print("========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
sample = tokenizer("Hello my name is", return_tensors="pt")
sample = {key: value.to(model.device) for key, value in sample.items()}
output = model.generate(**sample, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")
79 changes: 73 additions & 6 deletions src/llmcompressor/modifiers/autoround/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from contextlib import contextmanager
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from accelerate.hooks import add_hook_to_module, remove_hook_from_submodules
from auto_round import AutoRound
from auto_round.schemes import QuantizationScheme as ARQuantizationScheme
from compressed_tensors.quantization import (
Expand Down Expand Up @@ -54,6 +57,34 @@ def _wrap_decoding_layer(layer: torch.nn.Module) -> _PretrainModelWrapper:
return wrapped_model


@contextmanager
def suspend_accelerate_hooks(model: nn.Module):
    """
    Temporarily detach Accelerate hooks from *model* for the duration of the
    ``with`` block.

    Accelerate attaches ``_hf_hook`` objects to submodules to implement device
    offloading, dtype casting, etc.; these interfere with AutoRound's tuning
    loop. All hooks are stripped on entry. On exit the model is moved back to
    the device it started on and every saved hook is re-attached.
    """
    # Stash each submodule's hook by qualified name before stripping them.
    stashed_hooks = {
        name: module._hf_hook
        for name, module in model.named_modules()
        if hasattr(module, "_hf_hook")
    }
    entry_device = next(model.parameters()).device

    remove_hook_from_submodules(model)
    try:
        yield
    finally:
        # Strip any hooks added inside the block, restore placement, then
        # re-attach the original hooks.
        remove_hook_from_submodules(model)
        model.to(entry_device)
        for name, module in model.named_modules():
            if name in stashed_hooks:
                logger.info("Restoring Accelerate hook for module: {}", name)
                add_hook_to_module(module, stashed_hooks[name], append=True)


class AutoRoundModifier(Modifier, QuantizationMixin):
"""
Implements the AutoRound algorithm from https://aclanthology.org/2024.findings-emnlp.662.pdf.
Expand Down Expand Up @@ -103,13 +134,29 @@ class AutoRoundModifier(Modifier, QuantizationMixin):
:param scheme: a single quantization scheme to apply to the model. This is a
dictionary that supports all keys from QuantizationScheme except targets, which
will be set to the targets parameter set at the modifier level.
:param sequential_targets: class names of decoding layers to tune sequentially. If
None, targets are inferred via `get_no_split_params()` to respect no-split
constraints for large models. Defaults to None.
:param iters: number of tuning iterations per block (decoding layer). Higher values
typically improve accuracy at the cost of longer tuning time. Defaults to 200.
:param enable_torch_compile: whether to enable `torch.compile` to accelerate the
tuning loop. Disable if your environment or model encounters compilation issues.
Defaults to True.
:param batch_size: calibration/tuning batch size used by AutoRound when optimizing
rounding/clipping parameters. Larger values can improve stability but require
more memory. Defaults to 8.
:param device_ids: optional device map string for layer dispatch during tuning.
Examples: "0,1" for cuda:0 and cuda:1, or "auto" to use all available GPUs.
When None, no dispatching occurs and the model remains on its current device.
Defaults to None.
"""

sequential_targets: Union[str, List[str], None] = None
# AutoRound modifier arguments
iters: int = 200
enable_torch_compile: bool = True
batch_size: int = 8
device_ids: Optional[str] = None

# private variables
_all_module_input: Dict[str, List[Tuple]] = PrivateAttr(default_factory=dict)
Expand Down Expand Up @@ -215,15 +262,20 @@ def apply_autoround(self, state, subgraph):
wrapped_model = _wrap_decoding_layer(decoding_layer)
wrapped_model.name_or_path = state.model.name_or_path

with torch.enable_grad(), align_module_device(decoding_layer):
with torch.enable_grad(), align_module_device(
decoding_layer
), suspend_accelerate_hooks(wrapped_model):
ar_quant_scheme = self._mapping_config_to_autoround()
fp_layers = self.get_unquantized_layer_names(decoding_layer)
ar = AutoRound(
model=wrapped_model,
tokenizer="",
scheme=ar_quant_scheme,
iters=self.iters,
enable_torch_compile=self.enable_torch_compile,
batch_size=self.batch_size,
device_map=self.device_ids,
fp_layers=",".join(fp_layers) if fp_layers else "",
)
# TODO: configure layer-wise config based on self.resolved_config
ar.configure_layer_config(enable_gguf_official_mixed=False)
Expand All @@ -232,21 +284,25 @@ def apply_autoround(self, state, subgraph):
device = first_param.device
cur_inputs = self._all_module_input[decoding_layer._tmp_name]
decoding_layer.tuning_device = device
# Leave offload for LLMC to handle if `device_ids` is not set
auto_offload = False
if self.device_ids is not None:
# When device_ids is set, we move decoding layer to CPU first,
# then the submodules will be re-dispatched by AutoRound.
decoding_layer.to("cpu")
auto_offload = True

q_input, _ = ar.quantize_block(
block=decoding_layer,
inputs=cur_inputs,
q_input=self._q_input,
device=str(device),
# Leave offload for LLMC
auto_offload=False,
auto_offload=auto_offload,
)
self._q_input = q_input
# Update offload parameters and remove temporary attributes
for _, module in decoding_layer.named_modules():
if hasattr(module, "weight_scale") and hasattr(
module, "weight_zero_point"
):
if hasattr(module, "scale") and hasattr(module, "weight_zero_point"):
# Note: The model's weight is already q-dq in-place by auto-round.
weight_scale = module.scale
del module.scale
Expand Down Expand Up @@ -278,6 +334,17 @@ def on_finalize(self, state: State, **kwargs) -> bool:

return True

def get_unquantized_layer_names(self, wrapped_model: torch.nn.Module) -> List[str]:
    """
    Collect the qualified names of modules that match this modifier's resolved
    targets but have no quantization scheme attached.

    These names are handed to AutoRound as ``fp_layers`` so the corresponding
    layers are kept in full precision during tuning.

    :param wrapped_model: module tree to scan (typically a single decoding layer)
    :return: names of target-type modules that should remain unquantized
    """
    # A target module is "unquantized" when no quantization_scheme was
    # attached during initialization (e.g. it matched an ignore pattern).
    return [
        name
        for name, module in wrapped_model.named_modules()
        if module.__class__.__name__ in self.resolved_targets
        and getattr(module, "quantization_scheme", None) is None
    ]

def _add_temporary_names(self, model: torch.nn.Module):
for name, mod in model.named_modules():
mod._tmp_name = name
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,64 @@ def test_oneshot_application(recipe, tmp_path):
# Check lm-head is not quantized
not_targetted = model_loaded.lm_head
assert not hasattr(not_targetted, "quantization_scheme")


@requires_gpu(2)
def test_oneshot_with_device_ids(tmp_path):
    """End-to-end oneshot run where AutoRound dispatches layers across 2 GPUs."""
    save_dir = tmp_path / "oneshot_output"
    model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    tok = AutoTokenizer.from_pretrained(model_id)
    calib_ds = get_dataset(
        tokenizer=tok,
        seqlen=512,
        nsamples=4,
    )

    load_device = "cuda:0" if torch.cuda.is_available() else "cpu"

    modifier = AutoRoundModifier(
        ignore=["lm_head"],
        iters=10,
        config_groups={
            "group_0": QuantizationScheme(
                targets=["Linear"],
                weights=QuantizationArgs(num_bits=4, strategy="group", group_size=128),
            )
        },
        device_ids="0,1",
    )

    oneshot(
        model=model_id,
        dataset=calib_ds,
        output_dir=save_dir,
        recipe=modifier,
    )
    reloaded = AutoModelForCausalLM.from_pretrained(save_dir, device_map=load_device)

    # Check that the model is quantized.
    # compression_config path: decompress() attaches a quantization_config to
    # the model as we decompress right away.
    # quantization_config path: CompressedLinear only decompresses on forward
    # and never calls decompress(), yielding a slightly different parameter
    # tree to reach the quant config.
    quantization_config = reloaded.config.quantization_config.quantization_config
    assert quantization_config is not None

    # The serialized config mirrors the recipe.
    assert "lm_head" in quantization_config.ignore
    assert len(quantization_config.config_groups) == 1
    quant_scheme = quantization_config.config_groups["group_0"]
    assert isinstance(quant_scheme, QuantizationScheme)

    weight_args = quantization_config.config_groups["group_0"].weights
    assert isinstance(weight_args, QuantizationArgs)
    assert weight_args.num_bits == 4

    # A targeted Linear layer carries a quantization scheme...
    quantized_layer = reloaded.model.layers[2].self_attn.q_proj
    assert hasattr(quantized_layer, "quantization_scheme")

    # ...while the ignored lm_head does not.
    ignored_layer = reloaded.lm_head
    assert not hasattr(ignored_layer, "quantization_scheme")