39 changes: 16 additions & 23 deletions src/llmcompressor/modifiers/autoround/base.py
@@ -3,11 +3,12 @@

import torch
import torch.nn as nn
from accelerate.hooks import add_hook_to_module, remove_hook_from_submodules
from auto_round import AutoRound
from auto_round.schemes import PRESET_SCHEMES as AR_PRESET_SCHEMES
from auto_round.schemes import QuantizationScheme as ARQuantizationScheme
from auto_round.wrapper import WrapperWALayer
from compressed_tensors.offload import get_execution_device, get_offloaded_device
from compressed_tensors.offload.module import offload_module, remove_module_offload
from compressed_tensors.quantization import (
QuantizationMetadata,
QuantizationScheme,
@@ -62,30 +63,22 @@ def _wrap_decoding_layer(layer: torch.nn.Module) -> _PretrainModelWrapper:


@contextmanager
def suspend_accelerate_hooks(model: nn.Module):
def suspend_offloading(model: nn.Module):
"""
Temporarily suspend Accelerate hooks from a model.

This context manager detaches all Accelerate hooks (used for device offloading,
dtype casting, etc.) from the model, allowing Autoround to operate without
interference. On exit, the model is restored to its original device
and all hooks are re-attached.
Temporarily suspend offloading, allowing AutoRound to take over device movement
"""
saved_hooks = {}
original_device = next(model.parameters()).device
offloading_info = dict()
for name, module in model.named_modules():
if hasattr(module, "_hf_hook"):
saved_hooks[name] = module._hf_hook

remove_hook_from_submodules(model)
try:
yield
finally:
remove_hook_from_submodules(model)
model.to(original_device)
for name, module in model.named_modules():
if name in saved_hooks:
add_hook_to_module(module, saved_hooks[name], append=True)
offloading_info[name] = (
get_execution_device(module),
get_offloaded_device(module),
)
remove_module_offload(module, onload_tensors=True)

yield

for name, module in model.named_modules():
offload_module(module, *offloading_info[name])


class AutoRoundModifier(Modifier, QuantizationMixin):
Expand Down Expand Up @@ -285,7 +278,7 @@ def apply_autoround(self, state, subgraph):
with (
torch.enable_grad(),
align_module_device(decoding_layer),
suspend_accelerate_hooks(wrapped_model),
suspend_offloading(wrapped_model),
):
ar = AutoRound(
model=wrapped_model,
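For context, a minimal usage sketch of the new context manager follows. It assumes the model has already been offloaded with compressed-tensors (the restore path calls offload_module on every submodule); the helper name and the plain forward pass are illustrative, standing in for the AutoRound call in apply_autoround.

import torch

from llmcompressor.modifiers.autoround.base import suspend_offloading

def run_without_offloading(model: torch.nn.Module, sample: torch.Tensor) -> torch.Tensor:
    # Inside the context, offload hooks are removed and weights are onloaded, so
    # the caller (a plain forward pass here, AutoRound in the PR) manages device
    # placement itself. On exit, each submodule is re-offloaded to the
    # (execution_device, offloaded_device) pair recorded on entry.
    with suspend_offloading(model):
        out = model(sample)
    return out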
39 changes: 16 additions & 23 deletions src/llmcompressor/pipelines/sequential/helpers.py
@@ -4,16 +4,12 @@
from dataclasses import dataclass
from functools import wraps
from types import FunctionType, MethodType
from typing import TYPE_CHECKING, Any, Callable
from typing import TYPE_CHECKING, Any, Callable, Optional

import torch
from accelerate.hooks import remove_hook_from_module
from compressed_tensors.utils import (
has_offloaded_params,
offloaded_dispatch,
patch_attr,
remove_dispatch,
)
from compressed_tensors.offload import disable_onloading, offload_model
from compressed_tensors.utils import patch_attr
from compressed_tensors.utils.match import match_targets
from loguru import logger
from torch.fx import Graph, GraphModule, Node
@@ -26,6 +22,7 @@
from llmcompressor.modifiers import Modifier
from llmcompressor.modifiers.utils.hooks import HooksMixin
from llmcompressor.pipelines.sequential.transformers_helpers import HFTracer
from llmcompressor.utils.dev import get_main_device
from llmcompressor.utils.helpers import calibration_forward_context
from llmcompressor.utils.pytorch.module import get_no_split_params

@@ -106,7 +103,7 @@ def trace_subgraphs(
# find modules
targets = match_modules(model, sequential_targets)
ancestors = get_sequential_ancestors(model, targets)
offloaded = set(m for m in model.modules() if has_offloaded_params(m))
offloaded = set() # TODO: cleanup logic

# initialize arguments
tracer = SequentialTracer(ancestors, offloaded)
@@ -131,6 +128,9 @@
assert isinstance(model.forward, MethodType)
assert isinstance(type(model).forward, FunctionType)

# avoid device movement during tracing
stack.enter_context(disable_onloading())

with append_autowrap_source_on_fail():
graph = GraphModule(
model,
Expand Down Expand Up @@ -529,7 +529,11 @@ def is_ancestor(module: Module) -> bool:
return ancestors


def dispatch_for_sequential(model: PreTrainedModel) -> PreTrainedModel:
def dispatch_for_sequential(
model: PreTrainedModel,
onload_device: Optional[torch.device | str] = None,
offload_device: torch.device | str = torch.device("cpu"),
) -> PreTrainedModel:
"""
Dispatch a model for sequential calibration using a sequential pipeline.
The model will be offloaded to the CPU and dispatched to CUDA/XPU device
@@ -538,20 +542,9 @@ def dispatch_for_sequential(model: PreTrainedModel) -> PreTrainedModel:
:param model: model to dispatch
:return: dispatched model
"""
remove_dispatch(model)

if torch.cuda.is_available():
offloaded_dispatch(model, execution_device=torch.device("cuda:0"))
elif hasattr(torch, "xpu") and torch.xpu.is_available():
offloaded_dispatch(model, execution_device=torch.device("xpu:0"))
elif hasattr(torch, "npu") and torch.npu.is_available():
offloaded_dispatch(model, execution_device=torch.device("npu:0"))
else:
logger.warning(
"CUDA/XPU/NPU is not available! Compressing model on CPU instead"
)

return model
if onload_device is None:
onload_device = get_main_device()
return offload_model(model, onload_device, offload_device)


def _get_autowrap_functions() -> tuple[Callable[[Any], Any], ...]:
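A minimal sketch of how the reworked helper might be called (the checkpoint name is illustrative, borrowed from the test change below; with no onload_device given, get_main_device() resolves to cuda:0, xpu:0, or cpu):

import torch
from transformers import AutoModelForCausalLM

from llmcompressor.pipelines.sequential.helpers import dispatch_for_sequential

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B")

# Default devices: onload to the main accelerator, offload to CPU.
model = dispatch_for_sequential(model)

# Or pin both devices explicitly, mirroring the pipeline change below.
model = dispatch_for_sequential(
    model,
    onload_device=torch.device("cuda:0"),
    offload_device=torch.device("cpu"),
)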
11 changes: 6 additions & 5 deletions src/llmcompressor/pipelines/sequential/pipeline.py
@@ -2,7 +2,7 @@
from typing import TYPE_CHECKING

import torch
from compressed_tensors.utils import disable_offloading, get_execution_device
from compressed_tensors.utils import disable_offloading
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm

@@ -16,6 +16,7 @@
handle_sequential_oom,
trace_subgraphs,
)
from llmcompressor.utils.dev import get_main_device
from llmcompressor.utils.helpers import (
DISABLE_QAC_MODIFIERS,
DisableQuantization,
@@ -62,8 +63,9 @@ def __call__(
session = active_session()

# prepare model for sequential onloading
dispatch_for_sequential(model)
model_device = get_execution_device(model)
onload_device = get_main_device()
offload_device = torch.device(dataset_args.sequential_offload_device)
dispatch_for_sequential(model, onload_device, offload_device)

# prepare to trace subgraphs
modifiers = session.lifecycle.recipe.modifiers
@@ -91,9 +93,8 @@ def __call__(
stack.enter_context(DisableQuantization(model))

# prepare intermediates cache
offload_device = torch.device(dataset_args.sequential_offload_device)
activations = IntermediatesCache.from_dataloader(
dataloader, model_device, offload_device=offload_device
dataloader, onload_device, offload_device
)

for subgraph_index, subgraph in enumerate(subgraphs):
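Taken together, the pipeline's device handling after this change reduces to roughly the sketch below (simplified from the hunks above; dataset_args and dataloader are the pipeline's existing locals):

onload_device = get_main_device()  # cuda:0 / xpu:0 / cpu
offload_device = torch.device(dataset_args.sequential_offload_device)

# Model dispatch and the intermediates cache now share the same device pair,
# replacing the old get_execution_device(model) lookup after dispatch.
dispatch_for_sequential(model, onload_device, offload_device)
activations = IntermediatesCache.from_dataloader(
    dataloader, onload_device, offload_device
)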
50 changes: 27 additions & 23 deletions src/llmcompressor/utils/dev.py
@@ -2,13 +2,14 @@
import logging
import os
import tempfile
from functools import wraps
from typing import Type

import torch
from accelerate import dispatch_model, infer_auto_device_map
from accelerate.utils import get_balanced_memory
from compressed_tensors.utils import patch_attr, remove_dispatch
from compressed_tensors.offload import dispatch_model
from compressed_tensors.utils import patch_attr
from huggingface_hub import snapshot_download
from loguru import logger
from safetensors.torch import save_file
from transformers import AutoModelForCausalLM, PreTrainedModel
from transformers.modeling_utils import TORCH_INIT_FUNCTIONS
@@ -17,6 +18,7 @@
__all__ = [
"skip_weights_download",
"patch_transformers_logger_level",
"get_main_device",
"dispatch_for_generation",
]

@@ -116,28 +118,30 @@ def patch_transformers_logger_level(level: int = logging.ERROR):
transformers_logger.setLevel(level=restore_log_level)


def dispatch_for_generation(model: PreTrainedModel) -> PreTrainedModel:
def get_main_device() -> torch.device:
if torch.cuda.is_available():
return torch.device("cuda:0")
elif hasattr(torch, "xpu") and torch.xpu.is_available():
return torch.device("xpu:0")
else:
logger.warning("CUDA/XPU is not available! Compressing model on CPU instead")
return torch.device("cpu")


@wraps(dispatch_model)
def dispatch_for_generation(*args, **kwargs) -> PreTrainedModel:
"""
Dispatch a model for autoregressive generation. This means that modules are dispatched
evenly across available devices and kept onloaded if possible. Removes any HF hooks
that may have existed previously.
evenly across available devices and kept onloaded if possible.

:param model: model to dispatch
:return: model which is dispatched
:param hint_batch_size: reserve memory for batch size of inputs
:param hint_batch_seq_len: reserve memory for sequence of length of inputs
:param hint_model_dtype: reserve memory for model's dtype.
Will be inferred from model if none is provided
:param hint_extra_memory: extra memory reserved for model serving
:param no_split_modules: names of module classes which should not be split
across multiple devices
:return: dispatched model
"""
remove_dispatch(model)

no_split_module_classes = model._get_no_split_modules("auto")
max_memory = get_balanced_memory(
model,
dtype=model.dtype,
no_split_module_classes=no_split_module_classes,
)
device_map = infer_auto_device_map(
model,
dtype=model.dtype,
max_memory=max_memory,
no_split_module_classes=no_split_module_classes,
)

return dispatch_model(model, device_map=device_map)
return dispatch_model(*args, **kwargs)
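Since dispatch_for_generation is now a thin wrapper around compressed-tensors' dispatch_model, a call might look like the sketch below. The keyword names follow the docstring above; the hint values and checkpoint are illustrative, and it is assumed the wrapped function takes the model as its first positional argument, as the old accelerate-based version did.

from transformers import AutoModelForCausalLM

from llmcompressor.utils.dev import dispatch_for_generation

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B")

# Reserve memory for batches of 4 sequences of up to 2048 tokens during generation.
model = dispatch_for_generation(
    model,
    hint_batch_size=4,
    hint_batch_seq_len=2048,
)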
12 changes: 3 additions & 9 deletions src/llmcompressor/utils/transformers.py
@@ -1,5 +1,4 @@
import torch
from compressed_tensors import has_offloaded_params, register_offload_parameter
from loguru import logger
from torch.nn import Parameter
from transformers import PreTrainedModel
@@ -28,14 +27,9 @@ def untie_word_embeddings(model: PreTrainedModel):

# clone data to untie
for module in (input_embed, output_embed):
if not has_offloaded_params(module):
data = module.weight.data
else:
data = module._hf_hook.weights_map["weight"]

requires_grad = module.weight.requires_grad
untied_param = Parameter(data.clone(), requires_grad=requires_grad)
register_offload_parameter(module, "weight", untied_param)
weight = module.weight
param = Parameter(weight.data.clone(), requires_grad=weight.requires_grad)
module.register_parameter("weight", param)

# modify model config
if hasattr(model.config, "tie_word_embeddings"):
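A quick sanity-check sketch for the simplified untying path (illustrative; it assumes a checkpoint whose config ties the input and output embeddings):

from transformers import AutoModelForCausalLM

from llmcompressor.utils.transformers import untie_word_embeddings

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B")
untie_word_embeddings(model)

# After untying, the input and output embeddings no longer share storage.
in_w = model.get_input_embeddings().weight
out_w = model.get_output_embeddings().weight
assert in_w.data_ptr() != out_w.data_ptr()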
2 changes: 2 additions & 0 deletions tests/llmcompressor/modeling/test_fuse.py
@@ -5,6 +5,7 @@


@pytest.mark.unit
@torch.no_grad()
def test_center_embeddings():
embedding = torch.nn.Embedding(10, 10)
center_embeddings(embedding)
@@ -15,6 +16,7 @@ def test_center_embeddings():


@pytest.mark.unit
@torch.no_grad()
def test_fuse_norm_linears():
norm = torch.nn.LayerNorm((5,))
norm.weight.data = torch.rand(norm.weight.shape)
12 changes: 9 additions & 3 deletions tests/llmcompressor/pipelines/test_model_free_ptq.py
@@ -45,7 +45,7 @@ def _get_tiny_block_quant():
[_get_tiny_w4a16_quant(), "FP8_dynamic", _get_tiny_block_quant(), "NVFP4A16"],
)
def test_model_free_ptq_matches_oneshot(scheme, tmp_path):
model = "nm-testing/tinysmokellama-3.2"
model = "Qwen/Qwen3-0.6B"
ignore = ["model.embed_tokens", "lm_head"]
device = "cuda:0"

@@ -119,8 +119,14 @@ def _assert_config_equal(a_path: str, b_path: str):

a_qconfig = config_a.pop("quantization_config")
b_qconfig = config_b.pop("quantization_config")
config_a.pop("transformers_version")
config_b.pop("transformers_version")
config_a.pop("transformers_version", None)
config_b.pop("transformers_version", None)
config_a.pop("torch_dtype", None)
config_b.pop("torch_dtype", None)
config_a.pop("dtype", None)
config_b.pop("dtype", None)
config_a.pop("layer_types", None)
config_b.pop("layer_types", None)

assert config_a == config_b
