diff --git a/src/llmcompressor/modifiers/autoround/base.py b/src/llmcompressor/modifiers/autoround/base.py
index f8ffd73123..f994ebd985 100644
--- a/src/llmcompressor/modifiers/autoround/base.py
+++ b/src/llmcompressor/modifiers/autoround/base.py
@@ -3,11 +3,12 @@
 import torch
 import torch.nn as nn
-from accelerate.hooks import add_hook_to_module, remove_hook_from_submodules
 from auto_round import AutoRound
 from auto_round.schemes import PRESET_SCHEMES as AR_PRESET_SCHEMES
 from auto_round.schemes import QuantizationScheme as ARQuantizationScheme
 from auto_round.wrapper import WrapperWALayer
+from compressed_tensors.offload import get_execution_device, get_offloaded_device
+from compressed_tensors.offload.module import offload_module, remove_module_offload
 from compressed_tensors.quantization import (
     QuantizationMetadata,
     QuantizationScheme,
@@ -62,30 +63,22 @@ def _wrap_decoding_layer(layer: torch.nn.Module) -> _PretrainModelWrapper:


 @contextmanager
-def suspend_accelerate_hooks(model: nn.Module):
+def suspend_offloading(model: nn.Module):
     """
-    Temporarily suspend Accelerate hooks from a model.
-
-    This context manager detaches all Accelerate hooks (used for device offloading,
-    dtype casting, etc.) from the model, allowing Autoround to operate without
-    interference. On exit, the model is restored to its original device
-    and all hooks are re-attached.
+    Temporarily suspend offloading, allowing AutoRound to take over device movement.
     """
-    saved_hooks = {}
-    original_device = next(model.parameters()).device
+    offloading_info = dict()
     for name, module in model.named_modules():
-        if hasattr(module, "_hf_hook"):
-            saved_hooks[name] = module._hf_hook
-
-    remove_hook_from_submodules(model)
-    try:
-        yield
-    finally:
-        remove_hook_from_submodules(model)
-        model.to(original_device)
-        for name, module in model.named_modules():
-            if name in saved_hooks:
-                add_hook_to_module(module, saved_hooks[name], append=True)
+        offloading_info[name] = (
+            get_execution_device(module),
+            get_offloaded_device(module),
+        )
+        remove_module_offload(module, onload_tensors=True)
+
+    yield
+
+    for name, module in model.named_modules():
+        offload_module(module, *offloading_info[name])


 class AutoRoundModifier(Modifier, QuantizationMixin):
@@ -285,7 +278,7 @@ def apply_autoround(self, state, subgraph):
         with (
             torch.enable_grad(),
             align_module_device(decoding_layer),
-            suspend_accelerate_hooks(wrapped_model),
+            suspend_offloading(wrapped_model),
         ):
             ar = AutoRound(
                 model=wrapped_model,
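
Note: a minimal usage sketch of the new suspend_offloading context manager, mirroring the
apply_autoround hunk above; wrapped_model and the AutoRound keyword arguments are illustrative
placeholders, and the compressed_tensors.offload behavior is assumed from this patch rather
than verified:

    import torch
    from auto_round import AutoRound

    from llmcompressor.modifiers.autoround.base import suspend_offloading

    # wrapped_model: the _PretrainModelWrapper built around a decoding layer (see above)
    with torch.enable_grad(), suspend_offloading(wrapped_model):
        # offloading is detached here, so AutoRound controls device placement itself;
        # on exit each module is re-offloaded to its recorded devices
        ar = AutoRound(model=wrapped_model, iters=200)  # kwargs illustrative
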
diff --git a/src/llmcompressor/pipelines/sequential/helpers.py b/src/llmcompressor/pipelines/sequential/helpers.py
index 53e5ee1e74..0fafebbc49 100644
--- a/src/llmcompressor/pipelines/sequential/helpers.py
+++ b/src/llmcompressor/pipelines/sequential/helpers.py
@@ -4,16 +4,12 @@
 from dataclasses import dataclass
 from functools import wraps
 from types import FunctionType, MethodType
-from typing import TYPE_CHECKING, Any, Callable
+from typing import TYPE_CHECKING, Any, Callable, Optional

 import torch
 from accelerate.hooks import remove_hook_from_module
-from compressed_tensors.utils import (
-    has_offloaded_params,
-    offloaded_dispatch,
-    patch_attr,
-    remove_dispatch,
-)
+from compressed_tensors.offload import disable_onloading, offload_model
+from compressed_tensors.utils import patch_attr
 from compressed_tensors.utils.match import match_targets
 from loguru import logger
 from torch.fx import Graph, GraphModule, Node
@@ -26,6 +22,7 @@
 from llmcompressor.modifiers import Modifier
 from llmcompressor.modifiers.utils.hooks import HooksMixin
 from llmcompressor.pipelines.sequential.transformers_helpers import HFTracer
+from llmcompressor.utils.dev import get_main_device
 from llmcompressor.utils.helpers import calibration_forward_context
 from llmcompressor.utils.pytorch.module import get_no_split_params
@@ -106,7 +103,7 @@ def trace_subgraphs(
     # find modules
     targets = match_modules(model, sequential_targets)
     ancestors = get_sequential_ancestors(model, targets)
-    offloaded = set(m for m in model.modules() if has_offloaded_params(m))
+    offloaded = set()  # TODO: cleanup logic

     # initialize arguments
     tracer = SequentialTracer(ancestors, offloaded)
@@ -131,6 +128,9 @@ def trace_subgraphs(
         assert isinstance(model.forward, MethodType)
         assert isinstance(type(model).forward, FunctionType)

+        # avoid device movement during tracing
+        stack.enter_context(disable_onloading())
+
         with append_autowrap_source_on_fail():
             graph = GraphModule(
                 model,
@@ -529,7 +529,11 @@ def is_ancestor(module: Module) -> bool:
     return ancestors


-def dispatch_for_sequential(model: PreTrainedModel) -> PreTrainedModel:
+def dispatch_for_sequential(
+    model: PreTrainedModel,
+    onload_device: Optional[torch.device | str] = None,
+    offload_device: torch.device | str = torch.device("cpu"),
+) -> PreTrainedModel:
     """
     Dispatch a model for sequential calibration using a sequential pipeline.
     The model will be offloaded to the CPU and dispatched to CUDA/XPU device
@@ -538,20 +542,9 @@
     :param model: model to dispatch
     :return: dispatched model
     """
-    remove_dispatch(model)
-
-    if torch.cuda.is_available():
-        offloaded_dispatch(model, execution_device=torch.device("cuda:0"))
-    elif hasattr(torch, "xpu") and torch.xpu.is_available():
-        offloaded_dispatch(model, execution_device=torch.device("xpu:0"))
-    elif hasattr(torch, "npu") and torch.npu.is_available():
-        offloaded_dispatch(model, execution_device=torch.device("npu:0"))
-    else:
-        logger.warning(
-            "CUDA/XPU/NPU is not available! Compressing model on CPU instead"
-        )
-
-    return model
+    if onload_device is None:
+        onload_device = get_main_device()
+    return offload_model(model, onload_device, offload_device)


 def _get_autowrap_functions() -> tuple[Callable[[Any], Any], ...]:
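
Note: a sketch of calling the reworked dispatch_for_sequential signature directly; the model
ID is illustrative, and offload_model from compressed_tensors.offload is assumed to accept
(model, onload_device, offload_device) positionally, as used above:

    import torch
    from transformers import AutoModelForCausalLM

    from llmcompressor.pipelines.sequential.helpers import dispatch_for_sequential

    model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
    # onload_device=None falls back to get_main_device() (cuda:0, xpu:0, or cpu)
    model = dispatch_for_sequential(
        model, onload_device=None, offload_device=torch.device("cpu")
    )
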
diff --git a/src/llmcompressor/pipelines/sequential/pipeline.py b/src/llmcompressor/pipelines/sequential/pipeline.py
index d207b00f6a..1c707d8e3d 100644
--- a/src/llmcompressor/pipelines/sequential/pipeline.py
+++ b/src/llmcompressor/pipelines/sequential/pipeline.py
@@ -2,7 +2,7 @@
 from typing import TYPE_CHECKING

 import torch
-from compressed_tensors.utils import disable_offloading, get_execution_device
+from compressed_tensors.utils import disable_offloading
 from torch.utils.data.dataloader import DataLoader
 from tqdm import tqdm

@@ -16,6 +16,7 @@
     handle_sequential_oom,
     trace_subgraphs,
 )
+from llmcompressor.utils.dev import get_main_device
 from llmcompressor.utils.helpers import (
     DISABLE_QAC_MODIFIERS,
     DisableQuantization,
@@ -62,8 +63,9 @@ def __call__(
         session = active_session()

         # prepare model for sequential onloading
-        dispatch_for_sequential(model)
-        model_device = get_execution_device(model)
+        onload_device = get_main_device()
+        offload_device = torch.device(dataset_args.sequential_offload_device)
+        dispatch_for_sequential(model, onload_device, offload_device)

         # prepare to trace subgraphs
         modifiers = session.lifecycle.recipe.modifiers
@@ -91,9 +93,8 @@ def __call__(
                 stack.enter_context(DisableQuantization(model))

             # prepare intermediates cache
-            offload_device = torch.device(dataset_args.sequential_offload_device)
             activations = IntermediatesCache.from_dataloader(
-                dataloader, model_device, offload_device=offload_device
+                dataloader, onload_device, offload_device
             )

             for subgraph_index, subgraph in enumerate(subgraphs):
diff --git a/src/llmcompressor/utils/dev.py b/src/llmcompressor/utils/dev.py
index 8f256ce805..9349371581 100644
--- a/src/llmcompressor/utils/dev.py
+++ b/src/llmcompressor/utils/dev.py
@@ -2,13 +2,14 @@
 import logging
 import os
 import tempfile
+from functools import wraps
 from typing import Type

 import torch
-from accelerate import dispatch_model, infer_auto_device_map
-from accelerate.utils import get_balanced_memory
-from compressed_tensors.utils import patch_attr, remove_dispatch
+from compressed_tensors.offload import dispatch_model
+from compressed_tensors.utils import patch_attr
 from huggingface_hub import snapshot_download
+from loguru import logger
 from safetensors.torch import save_file
 from transformers import AutoModelForCausalLM, PreTrainedModel
 from transformers.modeling_utils import TORCH_INIT_FUNCTIONS
@@ -17,6 +18,7 @@
 __all__ = [
     "skip_weights_download",
     "patch_transformers_logger_level",
+    "get_main_device",
     "dispatch_for_generation",
 ]
@@ -116,28 +118,30 @@ def patch_transformers_logger_level(level: int = logging.ERROR):
         transformers_logger.setLevel(level=restore_log_level)


-def dispatch_for_generation(model: PreTrainedModel) -> PreTrainedModel:
+def get_main_device() -> torch.device:
+    if torch.cuda.is_available():
+        return torch.device("cuda:0")
+    elif hasattr(torch, "xpu") and torch.xpu.is_available():
+        return torch.device("xpu:0")
+    else:
+        logger.warning("CUDA/XPU is not available! Compressing model on CPU instead")
+        return torch.device("cpu")
+
+
+@wraps(dispatch_model)
+def dispatch_for_generation(*args, **kwargs) -> PreTrainedModel:
     """
     Dispatch a model autoregressive generation. This means that modules are dispatched
-    evenly across avaiable devices and kept onloaded if possible. Removes any HF hooks
-    that may have existed previously.
+    evenly across available devices and kept onloaded if possible.

     :param model: model to dispatch
-    :return: model which is dispatched
+    :param hint_batch_size: reserve memory for batch size of inputs
+    :param hint_batch_seq_len: reserve memory for sequence length of inputs
+    :param hint_model_dtype: reserve memory for model's dtype.
+        Will be inferred from model if none is provided
+    :param hint_extra_memory: extra memory reserved for model serving
+    :param no_split_modules: names of module classes which should not be split
+        across multiple devices
+    :return: dispatched model
     """
-    remove_dispatch(model)
-
-    no_split_module_classes = model._get_no_split_modules("auto")
-    max_memory = get_balanced_memory(
-        model,
-        dtype=model.dtype,
-        no_split_module_classes=no_split_module_classes,
-    )
-    device_map = infer_auto_device_map(
-        model,
-        dtype=model.dtype,
-        max_memory=max_memory,
-        no_split_module_classes=no_split_module_classes,
-    )
-
-    return dispatch_model(model, device_map=device_map)
+    return dispatch_model(*args, **kwargs)
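
Note: a sketch of the new get_main_device helper and the pass-through dispatch_for_generation
wrapper; the hint_batch_size keyword is taken from the docstring above and is assumed to be
accepted by compressed_tensors.offload.dispatch_model, not verified:

    from transformers import AutoModelForCausalLM

    from llmcompressor.utils.dev import dispatch_for_generation, get_main_device

    device = get_main_device()  # cuda:0 if available, else xpu:0, else cpu (with a warning)
    model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
    model = dispatch_for_generation(model, hint_batch_size=1)
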
diff --git a/src/llmcompressor/utils/transformers.py b/src/llmcompressor/utils/transformers.py
index 9b4831a290..b8e61f725e 100644
--- a/src/llmcompressor/utils/transformers.py
+++ b/src/llmcompressor/utils/transformers.py
@@ -1,5 +1,4 @@
 import torch
-from compressed_tensors import has_offloaded_params, register_offload_parameter
 from loguru import logger
 from torch.nn import Parameter
 from transformers import PreTrainedModel
@@ -28,14 +27,9 @@ def untie_word_embeddings(model: PreTrainedModel):

     # clone data to untie
     for module in (input_embed, output_embed):
-        if not has_offloaded_params(module):
-            data = module.weight.data
-        else:
-            data = module._hf_hook.weights_map["weight"]
-
-        requires_grad = module.weight.requires_grad
-        untied_param = Parameter(data.clone(), requires_grad=requires_grad)
-        register_offload_parameter(module, "weight", untied_param)
+        weight = module.weight
+        param = Parameter(weight.data.clone(), requires_grad=weight.requires_grad)
+        module.register_parameter("weight", param)

     # modify model config
     if hasattr(model.config, "tie_word_embeddings"):
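
Note: a self-contained illustration of the untying pattern now used above, where the shared
weight is replaced with an independent clone on each module so the two no longer alias the
same storage (toy module sizes, not the real embedding shapes):

    import torch
    from torch.nn import Parameter

    embed = torch.nn.Embedding(16, 8)
    lm_head = torch.nn.Linear(8, 16, bias=False)
    lm_head.weight = embed.weight  # tied: both modules share one Parameter

    for module in (embed, lm_head):
        weight = module.weight
        param = Parameter(weight.data.clone(), requires_grad=weight.requires_grad)
        module.register_parameter("weight", param)

    assert embed.weight.data_ptr() != lm_head.weight.data_ptr()
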
diff --git a/tests/llmcompressor/modeling/test_fuse.py b/tests/llmcompressor/modeling/test_fuse.py
index 005d89f99b..21f8707be0 100644
--- a/tests/llmcompressor/modeling/test_fuse.py
+++ b/tests/llmcompressor/modeling/test_fuse.py
@@ -5,6 +5,7 @@


 @pytest.mark.unit
+@torch.no_grad()
 def test_center_embeddings():
     embedding = torch.nn.Embedding(10, 10)
     center_embeddings(embedding)
@@ -15,6 +16,7 @@ def test_center_embeddings():


 @pytest.mark.unit
+@torch.no_grad()
 def test_fuse_norm_linears():
     norm = torch.nn.LayerNorm((5,))
     norm.weight.data = torch.rand(norm.weight.shape)
diff --git a/tests/llmcompressor/pipelines/test_model_free_ptq.py b/tests/llmcompressor/pipelines/test_model_free_ptq.py
index c0c2aec116..8decda7c05 100644
--- a/tests/llmcompressor/pipelines/test_model_free_ptq.py
+++ b/tests/llmcompressor/pipelines/test_model_free_ptq.py
@@ -45,7 +45,7 @@ def _get_tiny_block_quant():
     [_get_tiny_w4a16_quant(), "FP8_dynamic", _get_tiny_block_quant(), "NVFP4A16"],
 )
 def test_model_free_ptq_matches_oneshot(scheme, tmp_path):
-    model = "nm-testing/tinysmokellama-3.2"
+    model = "Qwen/Qwen3-0.6B"
     ignore = ["model.embed_tokens", "lm_head"]
     device = "cuda:0"

@@ -119,8 +119,14 @@ def _assert_config_equal(a_path: str, b_path: str):
     a_qconfig = config_a.pop("quantization_config")
     b_qconfig = config_b.pop("quantization_config")

-    config_a.pop("transformers_version")
-    config_b.pop("transformers_version")
+    config_a.pop("transformers_version", None)
+    config_b.pop("transformers_version", None)
+    config_a.pop("torch_dtype", None)
+    config_b.pop("torch_dtype", None)
+    config_a.pop("dtype", None)
+    config_b.pop("dtype", None)
+    config_a.pop("layer_types", None)
+    config_b.pop("layer_types", None)

     assert config_a == config_b
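
Note: the test change above drops config keys that vary across transformers versions before
comparing; a minimal sketch of the same pattern, with the key names copied from the hunk:

    def _strip_volatile_keys(config: dict) -> dict:
        # keys that differ between transformers releases should not fail the comparison
        for key in ("transformers_version", "torch_dtype", "dtype", "layer_types"):
            config.pop(key, None)
        return config

    a = _strip_volatile_keys({"dtype": "bfloat16", "vocab_size": 32})
    b = _strip_volatile_keys({"vocab_size": 32})
    assert a == b
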