diff --git a/Makefile b/Makefile
index fee3283c1..0d4085600 100644
--- a/Makefile
+++ b/Makefile
@@ -8,7 +8,7 @@ quality:
 	@echo "Running copyright checks";
 	python utils/copyright.py quality $(PYCHECKGLOBS)
 	@echo "Running python quality checks";
-	black --check $(PYCHECKDIRS);
+	black --target-version py310 --check $(PYCHECKDIRS);
 	isort --check-only $(PYCHECKDIRS);
 	flake8 $(PYCHECKDIRS);
@@ -17,7 +17,7 @@ style:
 	@echo "Running copyright style";
 	python utils/copyright.py style $(PYCHECKGLOBS)
 	@echo "Running python styling";
-	black $(PYCHECKDIRS);
+	black --target-version py310 $(PYCHECKDIRS);
 	isort $(PYCHECKDIRS);
 
 # run tests for the repo
diff --git a/src/compressed_tensors/__init__.py b/src/compressed_tensors/__init__.py
index c892e81a9..5c703e6b4 100644
--- a/src/compressed_tensors/__init__.py
+++ b/src/compressed_tensors/__init__.py
@@ -20,5 +20,14 @@
 from .compressors import *
 from .config import *
 from .quantization import QuantizationConfig, QuantizationStatus
-from .utils import *
+
+# avoid resolving compressed_tensors.offload as compressed_tensors.utils.offload
+from .utils.offload import *
+from .utils.helpers import *
+from .utils.internal import *
+from .utils.match import *
+from .utils.permutations_24 import *
+from .utils.safetensors_load import *
+from .utils.semi_structured_conversions import *
+from .utils.type import *
 from .version import *
diff --git a/src/compressed_tensors/offload/__init__.py b/src/compressed_tensors/offload/__init__.py
new file mode 100644
index 000000000..072dbdf7a
--- /dev/null
+++ b/src/compressed_tensors/offload/__init__.py
@@ -0,0 +1,197 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import contextlib
+from typing import Iterable, Optional
+
+import torch
+from compressed_tensors.offload.cache import OffloadCache
+from compressed_tensors.offload.dispatch import (  # noqa: F401
+    dispatch_model,
+    offload_model,
+    remove_dispatch,
+)
+from compressed_tensors.offload.module import offload_module, unwrap_offload_forward
+from compressed_tensors.offload.utils import get_module_device, move_module_tensor
+from compressed_tensors.utils.helpers import patch_attr
+
+
+__all__ = [
+    # dispatch models
+    "offload_model",
+    "dispatch_model",
+    "remove_dispatch",
+    # control movement
+    "disable_onloading",
+    "disable_offloading",
+    # manipulate parameters
+    "update_offload_parameter",
+    "get_execution_device",
+    "get_offloaded_device",
+    "register_offload_module",
+    # manipulate forward
+    "unwrap_offload_forward",
+    # backwards compatibility: should be deprecated
+    "align_modules",
+    "align_module_device",
+]
+
+
+@contextlib.contextmanager
+def disable_offloading():
+    """
+    When offloading is disabled, onloaded tensors remain onloaded in memory until exit
+
+    ```
+    with OffloadCache.disable_offloading():
+        ... = cache["weight"]
+        ... = cache["weight"]  # cache hit
+        ... = cache["weight"]  # cache hit
+
+    # upon exit, all onloaded weights are released
+    ```
+    """
+    with OffloadCache.disable_offloading():
+        yield
+
+
+@contextlib.contextmanager
+def disable_onloading():
+    """
+    When onloading is disabled, tensors are not onloaded on access, and assignments
+    do not trigger offloading. This is mostly used to disable device movement for
+    debugging
+
+    ```
+    with OffloadCache.disable_onloading():
+        tensor = ...
+        cache["weight"] = tensor  # assignments do not trigger offloading
+        cache["weight"] is tensor  # tensor remains offloaded
+    ```
+    """
+    with OffloadCache.disable_onloading():
+        yield
+
+
+def update_offload_parameter(module: torch.nn.Module, name: str, data: torch.Tensor):
+    """
+    Update the data of an existing parameter and its offload dict. Supports both
+    parameters of offloaded modules and non-offloaded modules
+
+    :param module: module containing the parameter to update
+    :param name: name of module parameter to update
+    :param data: tensor to update parameter with
+    """
+    if isinstance(module._parameters, OffloadCache):
+        with module._parameters.disable_onloading():
+            value = getattr(module, name)
+            value.copy_(module._parameters.offload(data))
+            setattr(module, name, value)
+
+    else:
+        getattr(module, name).copy_(data)
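+
+
+# A minimal usage sketch for `update_offload_parameter`; the module, devices, and
+# tensor values below are illustrative assumptions, not part of the API:
+#
+#     linear = torch.nn.Linear(5, 5)
+#     offload_module(linear, "cuda:0", "cpu")
+#     update_offload_parameter(linear, "weight", torch.zeros(5, 5))
+#     # the offloaded cpu copy now holds zeros; accessing `linear.weight`
+#     # onloads the updated tensor to cuda:0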
+
+
+def get_execution_device(module: torch.nn.Module) -> torch.device | str:
+    """
+    Get the device which inputs should be moved to before module execution.
+
+    :param module: module to check, may be offloaded
+    :return: onload device of module
+    """
+    if isinstance(module._parameters, OffloadCache):
+        return module._parameters.onload_device
+
+    else:
+        return get_module_device(module)
+
+
+def get_offloaded_device(module: torch.nn.Module) -> torch.device:
+    """
+    :param module: module to check
+    :return: device the module's tensors are offloaded to after the forward pass
+    """
+    with disable_onloading():
+        return get_module_device(module)
+
+
+def register_offload_module(base: torch.nn.Module, name: str, module: torch.nn.Module):
+    """
+    Register a submodule with offloading if the parent module is offloaded
+
+    :param base: module to attach submodule to
+    :param name: name of submodule
+    :param module: submodule to attach
+    """
+    cache = base._parameters
+    if isinstance(cache, OffloadCache):
+        offload_module(module, cache.onload_device, cache.offload_device)
+
+    base.register_module(name, module)
+
+
+""" Implemented for backwards compatibility """
+
+
+@contextlib.contextmanager
+def align_modules(
+    modules: torch.nn.Module | Iterable[torch.nn.Module],
+    execution_device: Optional[torch.device] = None,
+):
+    """
+    Context manager for onloading modules to a device, and disabling onload and offload
+    attempts triggered by forward calls. Used for sequential onloading of layers
+
+    :param modules: `torch.nn.Module` or iterable of `torch.nn.Module`s to onload
+    :param execution_device: device to onload to
+    """
+    with contextlib.ExitStack() as stack:
+        for module in modules:
+            stack.enter_context(align_module_device(module, execution_device))
+        yield
+
+
+@contextlib.contextmanager
+def align_module_device(
+    module: torch.nn.Module, execution_device: Optional[torch.device] = None
+):
+    """
+    Context manager that moves a module's parameters to the specified execution device.
+
+    :param module: Module with parameters to align
+    :param execution_device: If provided, overrides the module's onload device
+        within the context. Otherwise, tensors are left on their current devices
+    """
+
+    if isinstance(module._parameters, OffloadCache):
+        assert isinstance(module._buffers, OffloadCache)
+        with module._parameters.disable_offloading():
+            with patch_attr(
+                module._parameters, "onload_device", execution_device
+            ), patch_attr(module._buffers, "onload_device", execution_device):
+                yield
+
+    else:
+        original_device = {}
+        for name, param in module.named_parameters(recurse=False):
+            original_device[name] = param.device
+            move_module_tensor(module, name, execution_device)
+
+        try:
+            yield
+        finally:
+            for name, param in module.named_parameters(recurse=False):
+                device = original_device[name]
+                move_module_tensor(module, name, device)
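+
+
+# Illustrative sketch of sequential onloading with `align_modules`; `model.layers`
+# and `hidden_states` are hypothetical names for this example:
+#
+#     for layer in model.layers:
+#         with align_modules([layer], torch.device("cuda:0")):
+#             hidden_states = layer(hidden_states)  # weights stay onloaded here
+#     # on exit, weights return to their offload devices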
diff --git a/src/compressed_tensors/offload/cache/__init__.py b/src/compressed_tensors/offload/cache/__init__.py
new file mode 100644
index 000000000..e840b8e1a
--- /dev/null
+++ b/src/compressed_tensors/offload/cache/__init__.py
@@ -0,0 +1,17 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# flake8: noqa
+
+from .base import OffloadCache
+from .cpu import CPUCache
diff --git a/src/compressed_tensors/offload/cache/base.py b/src/compressed_tensors/offload/cache/base.py
new file mode 100644
index 000000000..55fc1c2dc
--- /dev/null
+++ b/src/compressed_tensors/offload/cache/base.py
@@ -0,0 +1,231 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import contextlib
+from abc import ABC, abstractmethod
+from collections.abc import MutableMapping
+from typing import ClassVar, Literal, Optional
+
+import torch
+import torch.distributed as dist
+
+
+class OffloadCache(MutableMapping, ABC):
+    """
+    Base class for offload caches. Subclasses must implement `offload` and `onload`.
+    Instances have similar behavior to dicts, except that tensors are offloaded when
+    assigned and onloaded when accessed.
+
+    Typical usage:
+    ```
+    module._parameters = cache_cls.from_mapping(module._parameters, onload_device)
+    tensor = ...
+    module._parameters["name"] = tensor  # tensor is offloaded
+    onloaded_tensor = module._parameters["name"]  # tensor is onloaded
+    ```
+
+    This class implements two contexts for more fine-grained control of device
+    movement: `OffloadCache.disable_offloading` and `OffloadCache.disable_onloading`.
+    For more info, see
+    `compressed_tensors.offload::(disable_offloading|disable_onloading)`
+    """
+
+    onload_device: torch.device | str
+    offload_device: Optional[torch.device | str]
+
+    # global flags for disabling
+    offloading_disabled: ClassVar[bool] = False
+    onloading_disabled: ClassVar[bool] = False
+
+    # names -> offloaded tensors (populated from _parameters or _buffers)
+    offloaded_values: dict[str, torch.Tensor]
+
+    # offloaded tensors -> onloaded tensors (only when offloading is disabled)
+    keep_onloaded_values: ClassVar[dict[torch.Tensor, torch.Tensor]] = dict()
+
+    @classmethod
+    def cls_from_device(
+        cls,
+        device: Optional[torch.device | str | Literal["disk"]] = None,
+    ) -> type["OffloadCache"]:
+        """
+        Get the subclass which implements offloading for the given `offload_device`.
+        Use `torch.distributed` to detect if the environment is distributed
+
+        :param device: offload device used to find subclass, defaults to cpu
+        :return: subclass of `OffloadCache`
+        """
+        from compressed_tensors.offload.cache.cpu import CPUCache
+        from compressed_tensors.offload.cache.device import DeviceCache
+
+        # `torch.device(None)` raises, so treat a missing device as cpu offload
+        if device is None:
+            device = "cpu"
+
+        device_type = torch.device(device).type if device != "disk" else "disk"
+        distributed = dist.is_available() and dist.is_initialized()
+
+        match (device_type, distributed):
+            case ("cpu", False):
+                return CPUCache
+            case ("cuda", False):
+                return DeviceCache
+            case _:
+                raise NotImplementedError(
+                    f"Offload of type {device} and "
+                    f"distributed={distributed} has not been implemented"
+                )
+
+    @classmethod
+    def from_mapping(
+        cls,
+        mapping: MutableMapping[str, torch.Tensor | None],
+        onload_device: torch.device | str,
+    ):
+        """
+        Initialize an instance from a given mapping, typically `Module._parameters` or
+        `Module._buffers`. Mapping values will be offloaded
+
+        :param mapping: mapping used to populate cache
+        :param onload_device: device which tensors will be onloaded to
+        """
+        instance = cls(onload_device=onload_device)
+        instance.offloaded_values = {
+            name: instance.offload(tensor) for name, tensor in mapping.items()
+        }
+
+        return instance
+
+    def __init__(self, onload_device: torch.device | str):
+        super().__init__()
+        self.onload_device = onload_device
+        self.offloaded_values = dict()
+
+    @abstractmethod
+    def onload(self, offloaded: torch.Tensor | None) -> torch.Tensor:
+        """
+        Given an offloaded tensor, returns that tensor after onloading
+
+        :param offloaded: offloaded tensor
+        :return: onloaded tensor
+        """
+        raise NotImplementedError()
+
+    @abstractmethod
+    def offload(self, tensor: torch.Tensor | None) -> torch.Tensor:
+        """
+        Given a tensor, returns that tensor after offloading
+
+        :param tensor: tensor to offload
+        :return: offloaded tensor
+        """
+        raise NotImplementedError()
+
+    def __getitem__(self, key: str) -> torch.Tensor:
+        """
+        Onload a tensor
+
+        If called within the `disable_offloading` context, a strong reference to the
+        onloaded tensor is kept so that future accesses will not require device
+        movement
+
+        :param key: name of tensor to access
+        :return: onloaded tensor
+        """
+        offloaded = self.offloaded_values[key]
+
+        # when onloading is disabled, offloaded tensors can be accessed directly
+        if offloaded is None or self.onloading_disabled:
+            return offloaded
+
+        # check for cache hit
+        if offloaded in self.keep_onloaded_values:
+            return self.keep_onloaded_values[offloaded]
+
+        # onload value
+        onloaded = self.onload(offloaded)
+
+        # when offloading is disabled, populate cache
+        if self.offloading_disabled:
+            self.keep_onloaded_values[offloaded] = onloaded
+
+        return onloaded
+
+    def __setitem__(self, key: str, value: torch.Tensor | None):
+        """
+        Offload a tensor and add it to the cache.
+
+        If called within the `disable_onloading` context, the tensor is not offloaded
+        and is instead assigned directly
+
+        :param key: name of tensor
+        :param value: tensor value to offload
+        """
+        if key in self:
+            del self[key]
+
+        # when onloading is disabled, parameters can be accessed and assigned directly
+        if self.onloading_disabled:
+            self.offloaded_values[key] = value
+            return
+
+        self.offloaded_values[key] = self.offload(value)
+
+    def __delitem__(self, key: str):
+        """
+        Remove the offloaded tensor associated with `key`. Any references to its
+        onloaded tensors held by this class are invalidated.
+
+        :param key: name of tensor to invalidate
+        """
+        offloaded = self.offloaded_values[key]
+        del self.offloaded_values[key]
+
+        # remove strong ref
+        if offloaded in self.keep_onloaded_values:
+            del self.keep_onloaded_values[offloaded]
+
+    def __contains__(self, key) -> bool:
+        return key in self.offloaded_values
+
+    def __iter__(self):
+        return iter(self.offloaded_values)
+
+    def __len__(self):
+        return len(self.offloaded_values)
+
+    @classmethod
+    @contextlib.contextmanager
+    def disable_offloading(cls):
+        """
+        Context to disable all offloading for offloaded modules which share this cache.
+        After a weight has been fetched once, that onloaded value is cached and
+        subsequent fetches will leverage the cache, reducing device movement
+        """
+        if not OffloadCache.offloading_disabled:
+            OffloadCache.offloading_disabled = True
+            try:
+                yield
+            finally:
+                OffloadCache.offloading_disabled = False
+                OffloadCache.keep_onloaded_values.clear()
+        else:
+            yield
+
+    @classmethod
+    @contextlib.contextmanager
+    def disable_onloading(cls):
+        """
+        Context to disable all onloading for offloaded modules which share this cache.
+        This is mostly used for debugging purposes, and allows the caller to directly
+        inspect offloaded tensors and directly assign offloaded tensors without copying
+        """
+        if not OffloadCache.onloading_disabled:
+            OffloadCache.onloading_disabled = True
+            try:
+                yield
+            finally:
+                OffloadCache.onloading_disabled = False
+        else:
+            yield
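+
+
+# Sketch of extending the cache API: a subclass only needs `onload` and `offload`.
+# This pinned-memory variant is a hypothetical illustration, not part of this PR:
+#
+#     class PinnedCPUCache(OffloadCache):
+#         offload_device = torch.device("cpu")
+#
+#         def offload(self, tensor: torch.Tensor | None) -> torch.Tensor:
+#             # pinned host memory enables faster, async host-to-device copies
+#             return tensor.to("cpu").pin_memory() if tensor is not None else tensor
+#
+#         def onload(self, offloaded: torch.Tensor | None) -> torch.Tensor:
+#             return offloaded.to(self.onload_device, non_blocking=True)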
diff --git a/src/compressed_tensors/offload/cache/cpu.py b/src/compressed_tensors/offload/cache/cpu.py
new file mode 100644
index 000000000..5f98886a8
--- /dev/null
+++ b/src/compressed_tensors/offload/cache/cpu.py
@@ -0,0 +1,43 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from compressed_tensors.offload.cache.base import OffloadCache
+from compressed_tensors.offload.utils import send_tensors
+
+
+class CPUCache(OffloadCache):
+    """
+    Handles offloading and onloading tensors from/to cpu memory
+    """
+
+    offload_device = torch.device("cpu")
+
+    def onload(self, offloaded: torch.Tensor | None) -> torch.Tensor:
+        """
+        Onload a tensor from cpu to device
+
+        :param offloaded: cpu tensor to onload
+        :return: device tensor
+        """
+        return send_tensors(offloaded, device=self.onload_device, copy=False)
+
+    def offload(self, tensor: torch.Tensor | None) -> torch.Tensor:
+        """
+        Offload a tensor from any device to cpu
+
+        :param tensor: tensor on any device
+        :return: tensor on cpu
+        """
+        return send_tensors(tensor, device=self.offload_device, copy=False)
diff --git a/src/compressed_tensors/offload/cache/device.py b/src/compressed_tensors/offload/cache/device.py
new file mode 100644
index 000000000..3afd9eb06
--- /dev/null
+++ b/src/compressed_tensors/offload/cache/device.py
@@ -0,0 +1,48 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from compressed_tensors.offload.cache.base import OffloadCache
+from compressed_tensors.offload.utils import send_tensors
+
+
+class DeviceCache(OffloadCache):
+    """
+    Handles offloading and onloading tensors from/to device memory. Onloading
+    tensors is a no-op.
+    """
+
+    def __init__(self, onload_device: torch.device | str):
+        self.onload_device = onload_device
+        self.offload_device = onload_device
+        self.offloaded_values = dict()
+
+    def onload(self, offloaded: torch.Tensor | None) -> torch.Tensor:
+        """
+        No op, offloaded tensors are already on device
+
+        :param offloaded: tensor already on the onload device
+        :return: device tensor
+        """
+        assert offloaded.device == self.onload_device
+        return offloaded
+
+    def offload(self, tensor: torch.Tensor | None) -> torch.Tensor:
+        """
+        Offload a tensor from any device to this cache's device, which doubles as
+        both the onload and offload device
+
+        :param tensor: tensor on any device
+        :return: tensor on the offload device
+        """
+        return send_tensors(tensor, device=self.offload_device, copy=False)
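+
+
+# Illustrative round trip: `OffloadCache.cls_from_device` selects between the two
+# implementations above. The devices here are assumptions for the example:
+#
+#     cache_cls = OffloadCache.cls_from_device("cpu")  # -> CPUCache
+#     cache = cache_cls(onload_device=torch.device("cuda:0"))
+#     cache["weight"] = torch.ones(4)  # stored on cpu
+#     cache["weight"].device           # cuda:0 after onloading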
diff --git a/src/compressed_tensors/offload/dispatch.py b/src/compressed_tensors/offload/dispatch.py
new file mode 100644
index 000000000..5206e10f4
--- /dev/null
+++ b/src/compressed_tensors/offload/dispatch.py
@@ -0,0 +1,228 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections.abc import Container
+from copy import deepcopy
+from functools import partial
+from typing import Literal, Optional, TypeVar
+
+import torch
+from compressed_tensors.offload.module import offload_module, remove_module_offload
+from compressed_tensors.offload.utils import get_module_sizes
+from compressed_tensors.utils import getattr_chain
+from compressed_tensors.utils.binary_search import SearchFailureError, max_binary_search
+from loguru import logger
+from transformers import PreTrainedModel
+
+
+__all__ = [
+    "offload_model",
+    "dispatch_model",
+    "remove_dispatch",
+    "get_device_memory",
+]
+
+ModelType = TypeVar("ModelType", bound=torch.nn.Module)
+
+
+def offload_model(
+    model: ModelType,
+    onload_device: torch.device | str,
+    offload_device: Optional[torch.device | str | Literal["disk"]] = None,
+) -> ModelType:
+    """
+    Offload a model to the `offload_device`. During forward passes, model weights will
+    be onloaded to the `onload_device`
+
+    :param model: model to dispatch
+    :param onload_device: device to move weights to during forward pass
+    :param offload_device: device to offload weights to
+    :return: dispatched model
+    """
+    # remove any previous dispatches
+    remove_dispatch(model)
+
+    # offload modules in place
+    for module in model.modules():
+        offload_module(module, onload_device, offload_device)
+
+    return model
+
+
+def dispatch_model(
+    model: ModelType,
+    device_memory: Optional[dict[torch.device, int]] = None,
+    extra_memory: Optional[int] = None,
+    no_split_modules: Optional[Container[str]] = None,
+) -> ModelType:
+    """
+    Dispatch a model for autoregressive generation. This means that modules are
+    dispatched evenly across available devices and kept onloaded if possible. If
+    onloading the entire model is not possible, some modules may be offloaded.
+
+    Disclaimers:
+    * Optimal runtime assumes that modules are called in order of `model.modules()`
+
+    :param model: model to dispatch
+    :param device_memory: optional dictionary mapping torch device to available
+        memory. If none is provided, all available devices will be used
+    :param extra_memory: the amount of memory to be reserved for activations
+    :param no_split_modules: names of module classes which should not be split
+        across multiple devices
+    :return: dispatched model
+    """
+    # remove previous dispatches
+    remove_dispatch(model)
+
+    # infer no_split_modules
+    if no_split_modules is None:
+        no_split_modules = getattr(model, "_no_split_modules", tuple())
+
+    # estimate activations memory requirement
+    if extra_memory is None:
+        if isinstance(model, PreTrainedModel):
+            extra_memory = (
+                1  # batch_size
+                * 2048  # seq_len
+                * getattr_chain(model, "config.hidden_size", 256)
+                * getattr(model, "dtype", torch.bfloat16).itemsize
+            )
+        else:
+            extra_memory = 0
+
+    # collect devices
+    if device_memory is None:
+        device_memory: dict[torch.device, int] = get_device_memory()
+    if len(device_memory) <= 0:
+        raise MemoryError("Did not find any devices to dispatch model to")
+
+    # collect module sizes
+    sizes = get_module_sizes(model, no_split_modules)
+    if len(sizes) <= 0:
+        raise ValueError("Model does not have any modules")
+
+    # search for the best dispatch which maximizes extra memory across devices
+    try:
+        max_extra_memory = min(device_memory.values())
+        extra_memory, (dispatch, _) = max_binary_search(
+            fn=partial(_get_greedy_dispatch, sizes, device_memory),
+            cond=(lambda result: len(result[0]) == len(sizes)),
+            start=extra_memory,
+            end=max_extra_memory,
+        )
+
+    # fallback: create a cpu dispatch
+    except SearchFailureError:
+        dispatch, device_memory = _get_greedy_dispatch(
+            sizes, device_memory, extra_memory
+        )
+        assert len(dispatch) < len(sizes)
+
+        last_device = dispatch[-1][1] if len(dispatch) else list(device_memory)[0]
+        sizes_dict = {module: size for module, size in sizes}
+        largest_offloaded_module = max(size for _, size in sizes[len(dispatch) :])
+
+        # pop off modules until all offloaded modules can fit in last device
+        while largest_offloaded_module > device_memory[last_device] - extra_memory:
+            if len(dispatch) <= 0:
+                raise ValueError(
+                    f"Cannot fit no_split module of size {largest_offloaded_module} "
+                    f"bytes into any device: {device_memory}"
+                )
+
+            module, last_device, _ = dispatch.pop(-1)
+            device_memory[last_device] += sizes_dict[module]
+            largest_offloaded_module = max(largest_offloaded_module, sizes_dict[module])
+
+        # fill dispatch back with cpu offloading
+        for module, _ in list(sizes[len(dispatch) :]):
+            dispatch.append((module, last_device, "cpu"))
+
+        extra_memory = 0
+        logger.warning("Forced to offload modules due to insufficient gpu resources")
+
+    # dispatch
+    finally:
+        assert len(dispatch) == len(sizes)
+        for module, onload, offload in dispatch:
+            for submodule in module.modules():
+                offload_module(submodule, onload, offload)
+
+        logger.debug(f"Dispatched model with {extra_memory} bytes of extra memory")
+        return model
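+
+
+# Example call; the reserved memory figure is an illustrative assumption:
+#
+#     model = dispatch_model(model, extra_memory=1 << 30)  # reserve ~1 GiB
+#     # modules are placed greedily across cuda devices; if the model does not
+#     # fit, trailing modules fall back to cpu offload with gpu onloading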
+
+
+def get_device_memory() -> dict[torch.device, int]:
+    """
+    Get the total memory of all available cuda devices
+
+    :return: mapping from each available cuda device to its total memory in bytes
+    """
+    if not torch.cuda.is_available():
+        return dict()
+
+    return {
+        # TODO: extend to xpu, etc.
+        torch.device(f"cuda:{idx}"): torch.cuda.get_device_properties(idx).total_memory
+        for idx in range(torch.cuda.device_count())
+    }
+
+
+def remove_dispatch(
+    module: torch.nn.Module, onload_tensors: bool = False
+) -> torch.nn.Module:
+    """
+    Remove any existing dispatches from module
+
+    :param module: module to remove dispatches from
+    :param onload_tensors: Whether to move tensors to the onloaded device, or keep them
+        on the offload device. Defaults to False.
+    :return: module with offloading functionality removed
+    """
+    for submodule in module.modules():
+        remove_module_offload(submodule, onload_tensors)
+
+    return module
+
+
+def _get_greedy_dispatch(
+    sizes: list[tuple[torch.nn.Module, int]],
+    device_memory: dict[torch.device, int],
+    extra_memory: int = 0,
+) -> tuple[
+    list[tuple[torch.nn.Module, torch.device, torch.device]], dict[torch.device, int]
+]:
+    dispatch = list()
+    memory_remaining = deepcopy(device_memory)
+
+    device_index = 0
+    devices = list(memory_remaining.keys())
+
+    if len(devices) <= 0:
+        raise ValueError("Cannot compute a greedy dispatch without any devices")
+
+    for module, size in sizes:
+        while True:
+            if device_index >= len(devices):
+                return dispatch, memory_remaining
+
+            device = devices[device_index]
+            if size > memory_remaining[device] - extra_memory:
+                device_index += 1
+                continue
+
+            dispatch.append((module, device, device))
+            memory_remaining[device] -= size
+            break
+
+    return dispatch, memory_remaining
diff --git a/src/compressed_tensors/offload/module.py b/src/compressed_tensors/offload/module.py
new file mode 100644
index 000000000..0798b3484
--- /dev/null
+++ b/src/compressed_tensors/offload/module.py
@@ -0,0 +1,103 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import contextlib
+from functools import wraps
+
+import torch
+from compressed_tensors.offload.cache.base import OffloadCache
+from compressed_tensors.offload.utils import send_tensors
+
+
+def offload_module(
+    module: torch.nn.Module,
+    onload_device: torch.device | str,
+    offload_device: torch.device | str,
+):
+    """
+    Offload a module. Any existing parameters or buffers will be offloaded to the
+    `offload_device`. Accessing module parameters or buffers will cause them to be
+    onloaded to the `onload_device`.
+
+    Calling `forward` will result in input tensors being moved to the
+    `onload_device` before the wrapped forward function runs.
+
+    :param module: module to offload
+    :param onload_device: device used to onload parameters and buffers
+    :param offload_device: device used to offload parameters and buffers
+    :return: offloaded module
+    """
+    cache_cls = OffloadCache.cls_from_device(offload_device)
+    module._parameters = cache_cls.from_mapping(module._parameters, onload_device)
+    module._buffers = cache_cls.from_mapping(module._buffers, onload_device)
+
+    original_forward_func = module.forward.__func__
+    module._original_forward_func = original_forward_func
+
+    @wraps(original_forward_func)
+    def forward(self, *args, **kwargs):
+        if not OffloadCache.onloading_disabled:
+            args = send_tensors(args, device=onload_device)
+            kwargs = send_tensors(kwargs, device=onload_device)
+
+        return self._original_forward_func(self, *args, **kwargs)
+
+    module.forward = forward.__get__(module)
+
+    return module
+
+
+def remove_module_offload(module: torch.nn.Module, onload_tensors: bool = False):
+    """
+    Remove any offloading applied to the module
+
+    :param module: module to remove offloading from
+    :param onload_tensors: Whether to move tensors to the onloaded device, or keep them
+        on the offload device. Defaults to False.
+    """
+    if isinstance(module._parameters, OffloadCache):
+        assert isinstance(module._buffers, OffloadCache)
+
+        if onload_tensors:
+            module._parameters = {
+                name: module._parameters.onload(param)
+                for name, param in module._parameters.offloaded_values.items()
+            }
+            module._buffers = {
+                name: module._buffers.onload(param)
+                for name, param in module._buffers.offloaded_values.items()
+            }
+        else:
+            module._parameters = module._parameters.offloaded_values
+            module._buffers = module._buffers.offloaded_values
+
+        module.forward = module._original_forward_func.__get__(module)
+        del module._original_forward_func
+
+
+@contextlib.contextmanager
+def unwrap_offload_forward(module: torch.nn.Module):
+    """
+    Upon entering, module forward function is unwrapped. Upon exiting, the offloading
+    wrapper is added again. Any modifications made to the forward function while within
+    the context will be reflected upon exiting.
+    """
+    if hasattr(module, "_original_forward_func"):
+        offload_forward = module.forward
+        module.forward = module._original_forward_func.__get__(module)
+        try:
+            yield
+        finally:
+            module._original_forward_func = module.forward.__func__
+            module.forward = offload_forward
+
+    else:
+        yield
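+
+
+# Sketch of patching a wrapped forward in place; `traced_forward` is a hypothetical
+# replacement used only to illustrate the context manager:
+#
+#     with unwrap_offload_forward(module):
+#         original = module.forward
+#
+#         def traced_forward(self, *args, **kwargs):
+#             print("forward called")  # e.g. add tracing
+#             return original(*args, **kwargs)
+#
+#         module.forward = traced_forward.__get__(module)
+#     # on exit, the offloading wrapper is re-applied around `traced_forward`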
diff --git a/src/compressed_tensors/offload/utils.py b/src/compressed_tensors/offload/utils.py
new file mode 100644
index 000000000..06c18df70
--- /dev/null
+++ b/src/compressed_tensors/offload/utils.py
@@ -0,0 +1,158 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections.abc import Container
+from dataclasses import fields, is_dataclass
+from itertools import chain
+from typing import Optional, TypeVar
+
+import torch
+from loguru import logger
+
+
+__all__ = [
+    "send_tensors",
+    "get_module_device",
+    "move_module_tensor",
+    "module_size",
+]
+
+T = TypeVar("T")
+
+
+def send_tensors(value: T, *args, **kwargs) -> T:
+    """
+    Recursively identify and move tensors using `torch.Tensor.to`
+
+    :param value: value containing tensors to move
+    :param args: arguments to `to`
+    :param kwargs: keyword arguments to `to`
+    :return: value with moved tensors
+    """
+    match value:
+        case torch.nn.Parameter():
+            data = value.to(*args, **kwargs)
+            # special case: avoid changing param pointer when possible
+            if data.data_ptr() == value.data_ptr():
+                return value
+            return value.__class__(data, requires_grad=value.requires_grad)
+        case torch.Tensor():
+            return value.to(*args, **kwargs)
+        case list():
+            return [send_tensors(v, *args, **kwargs) for v in value]
+        case tuple():
+            return tuple(send_tensors(v, *args, **kwargs) for v in value)
+        case dict():
+            return {k: send_tensors(v, *args, **kwargs) for k, v in value.items()}
+        case _ if is_dataclass(value):
+            return type(value)(
+                **{
+                    f.name: send_tensors(getattr(value, f.name), *args, **kwargs)
+                    for f in fields(value)
+                }
+            )
+        case _:
+            return value
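+
+
+# `send_tensors` recurses through lists, tuples, dicts, and dataclasses, e.g.
+# (tensor shapes and the device are illustrative):
+#
+#     batch = {"input_ids": torch.zeros(1, 8), "masks": [torch.ones(1, 8)]}
+#     batch = send_tensors(batch, device="cuda:0")  # both tensors are moved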
+
+
+def get_module_device(
+    module: torch.nn.Module, default: Optional[torch.device] = None
+) -> torch.device:
+    """
+    Infer the device of a module using the first
+    parameter or buffer registered to the module
+
+    :param module: module to check
+    :param default: default device if module does not have tensors or buffers
+    :return: device of module
+    """
+    tensor = next(module.parameters(), next(module.buffers(), None))
+    if tensor is not None:
+        return tensor.device
+    elif default is not None:
+        return default
+    else:
+        logger.warning(
+            f"Unable to get execution device of {module}, falling back to CPU"
+        )
+        return torch.device("cpu")
+
+
+def move_module_tensor(
+    module: torch.nn.Module,
+    name: str,
+    device: int | str | torch.device,
+):
+    """
+    Move a module's tensor to a new device
+
+    :param module: module containing tensors to move
+    :param name: name of tensor to move
+    :param device: new device
+    """
+    if name in module._parameters:
+        module._parameters[name] = send_tensors(
+            module._parameters[name], device=device
+        )
+
+    elif name in module._buffers:
+        module._buffers[name] = send_tensors(module._buffers[name], device=device)
+
+
+def get_module_sizes(
+    model: torch.nn.Module, no_split_modules: Container[str] = tuple()
+) -> list[tuple[torch.nn.Module, int]]:
+    """
+    Returns a list of modules and their sizes. Only non-splittable modules are
+    returned. Non-splittable modules are modules specified by `no_split_modules`
+    or modules with direct parameters.
+
+    :param model: model to get sizes from
+    :param no_split_modules: module class names which cannot be split
+    :return: list of modules and their sizes
+    """
+    module_sizes = []
+
+    def dfs(module: torch.nn.Module):
+        # modules with direct parameters cannot be split
+        # otherwise, submodules could return a device that is different than params
+        direct_size = module_size(module, recurse=False)
+        no_split = module.__class__.__name__ in no_split_modules or direct_size > 0
+
+        total_size = module_size(module, recurse=no_split)
+        if total_size > 0:
+            module_sizes.append((module, total_size))
+
+        if not no_split:
+            for child in module.children():
+                dfs(child)
+
+    dfs(model)
+
+    return module_sizes
+
+
+def module_size(module: torch.nn.Module, recurse: bool = True) -> int:
+    """
+    Get the size of the module's parameters and buffers in bytes
+
+    :param module: module to check
+    :param recurse: whether to calculate the recursive size, or only direct tensors
+    :return: total size of module parameters and buffers
+    """
+    from compressed_tensors.offload import disable_onloading
+
+    with disable_onloading():
+        tensors = chain(
+            module.parameters(recurse=recurse), module.buffers(recurse=recurse)
+        )
+        return sum((tensor.nbytes for tensor in tensors), 0)
diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py
index 573826e18..ca060cf91 100644
--- a/src/compressed_tensors/quantization/lifecycle/forward.py
+++ b/src/compressed_tensors/quantization/lifecycle/forward.py
@@ -363,28 +363,28 @@ def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
 
     @wraps(forward_func_orig)  # ensures docstring, names, etc are propagated
     def wrapped_forward(self, *args, **kwargs):
-        if not getattr(module, "quantization_enabled", True):
+        if not getattr(self, "quantization_enabled", True):
             # quantization is disabled on forward passes, return baseline
             # forward call
-            return forward_func_orig.__get__(module, module.__class__)(*args, **kwargs)
+            return forward_func_orig.__get__(self, self.__class__)(*args, **kwargs)
 
         input_ = args[0]
 
-        compressed = module.quantization_status == QuantizationStatus.COMPRESSED
+        compressed = self.quantization_status == QuantizationStatus.COMPRESSED
 
         if scheme.input_activations is not None:
             # prehook should calibrate activations before forward call
-            input_ = forward_quantize(module, input_, "input", scheme.input_activations)
+            input_ = forward_quantize(self, input_, "input", scheme.input_activations)
 
         if scheme.weights is not None and not compressed:
             # calibrate and (fake) quantize weights when applicable
             unquantized_weight = self.weight.data.clone()
             self.weight.data = forward_quantize(
-                module, self.weight, "weight", scheme.weights
+                self, self.weight, "weight", scheme.weights
             )
 
         # perform wrapped forward call
-        output = forward_func_orig.__get__(module, module.__class__)(
+        output = forward_func_orig.__get__(self, self.__class__)(
             input_, *args[1:], **kwargs
         )
 
@@ -395,14 +395,12 @@ def wrapped_forward(self, *args, **kwargs):
         if scheme.output_activations is not None:
             # forward-hook should calibrate/forward_quantize
             if (
-                module.quantization_status == QuantizationStatus.CALIBRATION
+                self.quantization_status == QuantizationStatus.CALIBRATION
                 and not scheme.output_activations.dynamic
             ):
                 return output
 
-            output = forward_quantize(
-                module, output, "output", scheme.output_activations
-            )
+            output = forward_quantize(self, output, "output", scheme.output_activations)
 
         return output
 
     # bind wrapped forward to module class so reference to `self` is correct
diff --git a/src/compressed_tensors/utils/binary_search.py b/src/compressed_tensors/utils/binary_search.py
new file mode 100644
index 000000000..e9a7a4022
--- /dev/null
+++ b/src/compressed_tensors/utils/binary_search.py
@@ -0,0 +1,52 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Callable, TypeVar
+
+
+T = TypeVar("T")
+
+
+__all__ = ["SearchFailureError", "max_binary_search"]
+
+
+class SearchFailureError(ValueError):
+    pass
+
+
+def max_binary_search(
+    fn: Callable[[int], T],
+    cond: Callable[[T], bool],
+    start: int,
+    end: int,
+) -> tuple[int, T]:
+    """
+    Binary search for the largest integer in [start, end] whose `fn` result
+    satisfies `cond`. Assumes `cond(fn(i))` is monotonically non-increasing in `i`
+
+    :param fn: function to evaluate at each searched index
+    :param cond: predicate which the returned value must satisfy
+    :param start: lowest candidate index
+    :param end: highest candidate index
+    :return: tuple of the largest satisfying index and its `fn` result
+    :raises SearchFailureError: if no index in range satisfies `cond`
+    """
+    best_idx = None
+    best_val = None
+
+    while start <= end:
+        mid = (start + end) // 2
+        val = fn(mid)
+
+        if cond(val):
+            # condition is true, search higher
+            best_idx, best_val = mid, val
+            start = mid + 1
+        else:
+            # condition is false, search lower
+            end = mid - 1
+
+    if best_idx is None:
+        raise SearchFailureError()
+
+    return best_idx, best_val
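+
+
+# Example: the largest integer in [0, 100] whose square stays at or under 1000:
+#
+#     idx, val = max_binary_search(
+#         fn=lambda i: i * i, cond=lambda v: v <= 1000, start=0, end=100
+#     )
+#     assert (idx, val) == (31, 961)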
diff --git a/tests/test_offload/cache/test_cpu.py b/tests/test_offload/cache/test_cpu.py
new file mode 100644
index 000000000..fb87e5332
--- /dev/null
+++ b/tests/test_offload/cache/test_cpu.py
@@ -0,0 +1,138 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+from weakref import ref
+
+import pytest
+import torch
+from compressed_tensors.offload.cache.cpu import CPUCache
+from tests.testing_utils import requires_gpu
+
+
+ONLOAD_DEVICE = torch.device("cuda:0")
+OFFLOAD_DEVICE = torch.device("cpu")
+
+
+@pytest.fixture(scope="function")
+def cache():
+    return CPUCache(ONLOAD_DEVICE)
+
+
+@pytest.mark.unit
+@requires_gpu
+def test_onloading(cache: CPUCache):
+    tensor = torch.ones(10)
+    cache["weight"] = tensor
+    onloaded = cache["weight"]
+
+    assert type(onloaded) is type(tensor)
+    assert torch.equal(onloaded.to(tensor.device), tensor)
+
+
+@pytest.mark.unit
+@requires_gpu
+def test_garbage_collect(cache: CPUCache):
+    cache["weight"] = torch.ones(10)
+    onloaded = cache["weight"]
+
+    onloaded_ref = ref(onloaded)
+    del onloaded
+    gc.collect()
+    assert onloaded_ref() is None
+
+
+@pytest.mark.unit
+@requires_gpu
+def test_offload(cache: CPUCache):
+    tensor = torch.ones(10, device=ONLOAD_DEVICE)
+    offloaded = cache.offload(tensor)
+    assert offloaded.device == OFFLOAD_DEVICE
+    assert torch.equal(offloaded.to(ONLOAD_DEVICE), tensor)
+
+
+@pytest.mark.unit
+@requires_gpu
+def test_onload(cache: CPUCache):
+    tensor = torch.ones(10, device=ONLOAD_DEVICE)
+    onloaded = cache.onload(cache.offload(tensor))
+    assert onloaded.device == ONLOAD_DEVICE
+    assert torch.equal(onloaded, tensor)
+
+
+@pytest.mark.unit
+@requires_gpu
+def test_disable_offloading(cache: CPUCache):
+    cache["weight"] = torch.ones(10)
+
+    outside_onloaded = cache["weight"]
+    outside_onloaded_ref = ref(outside_onloaded)
+    assert outside_onloaded.device == ONLOAD_DEVICE
+
+    with cache.disable_offloading():
+        inside_onloaded = cache["weight"]
+        inside_onloaded_ref = ref(inside_onloaded)
+        assert inside_onloaded.device == ONLOAD_DEVICE
+
+        del outside_onloaded
+        del inside_onloaded
+        gc.collect()
+
+        assert outside_onloaded_ref() is None
+        assert inside_onloaded_ref() is not None
+
+    assert outside_onloaded_ref() is None
+    assert inside_onloaded_ref() is None
+
+
+@pytest.mark.unit
+@requires_gpu
+def test_disable_onloading(cache: CPUCache):
+    tensor = torch.ones(10)
+    cache.offloaded_values["weight"] = tensor
+
+    with cache.disable_onloading():
+        onloaded = cache["weight"]
+        assert onloaded is tensor
+
+    assert onloaded is tensor
+
+
+@pytest.mark.unit
+@requires_gpu
+def test_delete(cache: CPUCache):
+    cache["weight"] = torch.ones(10)
+    onloaded = cache["weight"]
+    onloaded_ref = ref(onloaded)
+
+    with cache.disable_offloading():
+        del cache["weight"]
+        del onloaded
+        gc.collect()
+
+        assert onloaded_ref() is None
+
+    assert onloaded_ref() is None
+
+
+@pytest.mark.unit
+@requires_gpu
+def test_shared_attributes(cache: CPUCache):
+    assert cache.offload_device is CPUCache.offload_device
+    assert cache.offloading_disabled is CPUCache.offloading_disabled
+    assert cache.onloading_disabled is CPUCache.onloading_disabled
+    assert cache.keep_onloaded_values is CPUCache.keep_onloaded_values
+
+    assert not hasattr(CPUCache, "onload_device")
+    assert not hasattr(CPUCache, "offloaded_values")
diff --git a/tests/test_offload/test_dispatch.py b/tests/test_offload/test_dispatch.py
new file mode 100644
index 000000000..5667fff98
--- /dev/null
+++ b/tests/test_offload/test_dispatch.py
@@ -0,0 +1,215 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from unittest.mock import patch
+
+import pytest
+import torch
+from compressed_tensors.offload.cache import CPUCache, OffloadCache
+from compressed_tensors.offload.dispatch import (
+    dispatch_model,
+    get_device_memory,
+    offload_model,
+)
+from compressed_tensors.offload.utils import module_size
+from tests.testing_utils import requires_gpu
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+
+class Decoder(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.linear0 = torch.nn.Linear(5, 5)
+        self.linear1 = torch.nn.Linear(5, 5)
+
+    def forward(self, input):
+        return self.linear1(self.linear0(input))
+
+
+class Model(torch.nn.Module):
+    _no_split_modules = ["Decoder"]
+
+    def __init__(self):
+        super().__init__()
+        self.decoder0 = Decoder()
+        self.decoder1 = Decoder()
+
+    def forward(self, input):
+        return self.decoder1(self.decoder0(input))
+
+
+def assert_module_on_device(module: torch.nn.Module, device: torch.device | str):
+    assert not isinstance(module._parameters, CPUCache)
+    for name, param in module.named_parameters():
+        assert torch.device(param.device) == torch.device(device), name
+
+
+def assert_module_offloaded(
+    module: torch.nn.Module,
+    onload_device: torch.device | str,
+    offload_device: torch.device | str,
+    req_params: bool = False,
+):
+    for name, submodule in module.named_modules():
+        if isinstance(submodule, torch.nn.ModuleList):
+            continue
+        if req_params and module_size(submodule) <= 0:
+            continue
+
+        assert isinstance(submodule._parameters, OffloadCache), name
+        assert torch.device(submodule._parameters.onload_device) == torch.device(
+            onload_device
+        )
+        assert torch.device(submodule._parameters.offload_device) == torch.device(
+            offload_device
+        )
+
+
+def has_memory_requirements(device_memory: dict[torch.device, int]):
+    real_device_memory = get_device_memory()
+    for key, req in device_memory.items():
+        if key not in real_device_memory or real_device_memory[key] < req:
+            return False
+
+    return True
+
+
+@pytest.mark.unit
+@requires_gpu
+def test_dispatch_one_device():
+    model = Model()
+    device_memory = {torch.device("cuda:0"): module_size(model)}
+    if not has_memory_requirements(device_memory):
+        pytest.skip("Cannot perform one device dispatch test, not enough device memory")
+
+    dispatch_model(model, device_memory=device_memory)
+    assert_module_on_device(model, "cuda:0")
+
+
+@pytest.mark.unit
+@requires_gpu
+def test_dispatch_two_devices():
+    model = Model()
+    device_memory = {
+        torch.device("cuda:0"): module_size(model.decoder0),
+        torch.device("cuda:1"): module_size(model) - module_size(model.decoder0),
+    }
+    if not has_memory_requirements(device_memory):
+        pytest.skip("Cannot perform split dispatch test: not enough devices or memory")
+
+    # first decoder on first device, rest on second device
+    dispatch_model(model, device_memory=device_memory)
+    assert_module_on_device(model.decoder0, "cuda:0")
+    assert_module_on_device(model.decoder1, "cuda:1")
+
+
+@pytest.mark.unit
+@requires_gpu
+def test_dispatch_no_split():
+    model = Model()
+    device_memory = {
+        torch.device("cuda:0"): module_size(model.decoder0.linear0),
+        torch.device("cuda:1"): module_size(model),
+    }
+    if not has_memory_requirements(device_memory):
+        pytest.skip("Cannot perform split dispatch test: not enough devices or memory")
+
+    # first device is skipped: all ends up on second device
+    dispatch_model(model, device_memory=device_memory)
+    assert_module_on_device(model, "cuda:1")
+
+
+@pytest.mark.unit
+@requires_gpu
+def test_dispatch_split():
+    model = Model()
+    first_linear = model.decoder0.linear0
+    device_memory = {
+        torch.device("cuda:0"): module_size(first_linear),
+        torch.device("cuda:1"): module_size(model) - module_size(first_linear),
+    }
+    if not has_memory_requirements(device_memory):
+        pytest.skip("Cannot perform split dispatch test: not enough devices or memory")
+
+    # first linear on first device, rest on second device
+    dispatch_model(model, device_memory=device_memory, no_split_modules=tuple())
+    assert_module_on_device(model.decoder0.linear0, "cuda:0")
+    assert_module_on_device(model.decoder0.linear1, "cuda:1")
+    assert_module_on_device(model.decoder1, "cuda:1")
+
+
+@pytest.mark.unit
+@requires_gpu
+def test_dispatch_offloaded():
+    model = Model()
+    device_memory = {
+        torch.device("cuda:0"): (
+            module_size(model.decoder0.linear0) + module_size(model.decoder1)
+        ),
+    }
+    if not has_memory_requirements(device_memory):
+        pytest.skip("Cannot perform split dispatch test: not enough devices or memory")
+
+    with patch("compressed_tensors.offload.dispatch.get_module_sizes") as mock_sizes:
+        # first two linears are disjoint, but not enough memory to fit decoder1
+        mock_sizes.return_value = [
+            (model.decoder0.linear0, module_size(model.decoder0.linear0)),
+            (model.decoder0.linear1, module_size(model.decoder0.linear1)),
+            (model.decoder1, module_size(model.decoder1)),
+        ]
+
+        # first linear stays onloaded
+        # second linear is popped off to fit offloaded decoder1
+        dispatch_model(model, device_memory=device_memory, no_split_modules=tuple())
+        assert_module_on_device(model.decoder0.linear0, "cuda:0")
+        assert_module_offloaded(model.decoder0.linear1, "cuda:0", "cpu")
+        assert_module_offloaded(model.decoder1, "cuda:0", "cpu")
+
+
+@pytest.mark.integration
+@requires_gpu
+@pytest.mark.parametrize("model_id", ["nm-testing/tinysmokellama-3.2"])
+@torch.inference_mode()
+def test_offload_and_dispatch_model(model_id):
+    model = AutoModelForCausalLM.from_pretrained(model_id).eval()
+    tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+    device_memory = {torch.device("cuda:0"): module_size(model)}
+    if not has_memory_requirements(device_memory):
+        pytest.skip("Cannot perform split dispatch test: not enough devices or memory")
+
+    model.to("cuda:0")
+    sample = tokenizer("Hello my name is", return_tensors="pt")
+    sample = {k: v.to("cuda:0") for k, v in sample.items()}
+    true_logits = model(**sample).logits
+
+    # offload entire model
+    model = offload_model(model, "cuda:0", "cpu")
+    offloaded_logits = model(**sample).logits
+    for child in model.children():
+        assert_module_offloaded(child, "cuda:0", torch.device("cpu"))
+    assert torch.allclose(offloaded_logits, true_logits)
+
+    # dispatch model; everything fits
+    model = dispatch_model(model, device_memory=device_memory, extra_memory=0)
+    dispatched_logits = model(**sample).logits
+    assert_module_on_device(model, "cuda:0")
+    assert torch.allclose(dispatched_logits, true_logits)
+
+    # dispatch model with offload
+    device_memory[torch.device("cuda:0")] = device_memory[torch.device("cuda:0")] // 2
+    model = dispatch_model(model, device_memory=device_memory, extra_memory=0)
+    dispatched_logits = model(**sample).logits
+    assert_module_on_device(model, "cuda:0")
+    assert torch.allclose(dispatched_logits, true_logits)
diff --git a/tests/test_offload/test_interface.py b/tests/test_offload/test_interface.py
new file mode 100644
index 000000000..6cbf58658
--- /dev/null
+++ b/tests/test_offload/test_interface.py
@@ -0,0 +1,174 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+import torch
+from compressed_tensors.offload import (
+    align_module_device,
+    align_modules,
+    disable_offloading,
+    disable_onloading,
+    get_execution_device,
+    get_offloaded_device,
+    register_offload_module,
+    update_offload_parameter,
+)
+from compressed_tensors.offload.cache import CPUCache
+from compressed_tensors.offload.module import offload_module
+from tests.testing_utils import requires_gpu
+
+
+ONLOAD_DEVICE = torch.device("cuda:0")
+OFFLOAD_DEVICE = torch.device("cpu")
+
+
+@pytest.fixture(scope="function")
+def cache():
+    return CPUCache(ONLOAD_DEVICE)
+
+
+@pytest.fixture(scope="function")
+def linear():
+    return torch.nn.Linear(5, 5, bias=True, device=OFFLOAD_DEVICE)
+
+
+@pytest.fixture(scope="function")
+def offloaded_linear(linear, cache):
+    offload_module(linear, ONLOAD_DEVICE, OFFLOAD_DEVICE)
+    return linear
+
+
+@pytest.mark.unit
+@requires_gpu
+def test_disable_offloading():
+    cache1 = CPUCache(ONLOAD_DEVICE)
+    cache2 = CPUCache(ONLOAD_DEVICE)
+
+    cache1["weight"] = torch.tensor(0, device=OFFLOAD_DEVICE)
+    cache2["weight"] = torch.tensor(1, device=OFFLOAD_DEVICE)
+
+    with disable_offloading():
+        assert cache1["weight"] in cache1.keep_onloaded_values.values()
+        assert cache2["weight"] in cache2.keep_onloaded_values.values()
+
+
+@pytest.mark.unit
+@requires_gpu
+def test_disable_onloading():
+    cache1 = CPUCache(ONLOAD_DEVICE)
+    cache2 = CPUCache(ONLOAD_DEVICE)
+
+    cache1["weight"] = torch.tensor(0, device=OFFLOAD_DEVICE)
+    cache2["weight"] = torch.tensor(1, device=OFFLOAD_DEVICE)
+
+    with disable_onloading():
+        assert cache1["weight"].device == OFFLOAD_DEVICE
+        assert cache2["weight"].device == OFFLOAD_DEVICE
+
+
+@pytest.mark.unit
+@requires_gpu
+@pytest.mark.parametrize("offload", (True, False))
+def test_update_offload_parameter(linear: torch.nn.Linear, cache, offload):
+    init_data = torch.tensor(0.0, device=OFFLOAD_DEVICE)
+    linear.weight = torch.nn.Parameter(init_data, requires_grad=False)
+    if offload:
+        offload_module(linear, ONLOAD_DEVICE, OFFLOAD_DEVICE)
+
+    assert linear.weight == 0
+
+    update_offload_parameter(linear, "weight", 1)
+    assert linear.weight == 1
+
+    with disable_offloading():
+        update_offload_parameter(linear, "weight", 2)
+        assert linear.weight == 2
+    assert linear.weight == 2
+
+    with disable_onloading():
+        update_offload_parameter(linear, "weight", 3)
+        assert linear.weight == 3
+    assert linear.weight == 3
+
+
+@pytest.mark.unit
+@requires_gpu
+def test_get_execution_device(linear: torch.nn.Linear, cache):
+    assert get_execution_device(linear) == OFFLOAD_DEVICE
+    linear.to(ONLOAD_DEVICE)
+    assert get_execution_device(linear) == ONLOAD_DEVICE
+
+    linear.to(OFFLOAD_DEVICE)
+    offload_module(linear, ONLOAD_DEVICE, OFFLOAD_DEVICE)
+    assert get_execution_device(linear) == ONLOAD_DEVICE
+
+    with disable_onloading():
+        assert get_execution_device(linear) == ONLOAD_DEVICE
+
+    with disable_offloading():
+        assert get_execution_device(linear) == ONLOAD_DEVICE
+
+
+@pytest.mark.unit
+@requires_gpu
+def test_get_offloaded_device(linear: torch.nn.Linear, cache):
+    assert get_offloaded_device(linear) == OFFLOAD_DEVICE
+    linear.to(ONLOAD_DEVICE)
+    assert get_offloaded_device(linear) == ONLOAD_DEVICE
+
+    linear.to(OFFLOAD_DEVICE)
+    offload_module(linear, ONLOAD_DEVICE, OFFLOAD_DEVICE)
+    assert get_offloaded_device(linear) == OFFLOAD_DEVICE
+
+    with disable_onloading():
+        assert get_offloaded_device(linear) == OFFLOAD_DEVICE
+
+    with disable_offloading():
+        assert get_offloaded_device(linear) == OFFLOAD_DEVICE
+
+
+@pytest.mark.unit
+@requires_gpu
+def test_register_offload_module(linear: torch.nn.Linear, cache):
+    sub1 = torch.nn.Linear(1, 1)
+    register_offload_module(linear, "sub1", sub1)
+    assert linear.sub1 is sub1
+
+    offload_module(linear, ONLOAD_DEVICE, OFFLOAD_DEVICE)
+    sub2 = torch.nn.Linear(1, 1)
+    register_offload_module(linear, "sub2", sub2)
+    assert linear.sub2 is sub2
+    assert sub2.weight.device == ONLOAD_DEVICE
+
+
+@pytest.mark.unit
+@requires_gpu
+def test_align_modules(offloaded_linear: torch.nn.Linear):
+    linear = torch.nn.Linear(1, 1, device=ONLOAD_DEVICE)
+
+    with align_modules((linear, offloaded_linear), OFFLOAD_DEVICE):
+        assert linear.weight.device == OFFLOAD_DEVICE
+        assert offloaded_linear.weight.device == OFFLOAD_DEVICE
+
+
+@pytest.mark.unit
+@requires_gpu
+@pytest.mark.parametrize("offload", (True, False))
+def test_align_module_device(linear: torch.nn.Linear, cache, offload):
+    if offload:
+        offload_module(linear, ONLOAD_DEVICE, OFFLOAD_DEVICE)
+    else:
+        linear.to(ONLOAD_DEVICE)
+
+    with align_module_device(linear, OFFLOAD_DEVICE):
+        assert linear.weight.device == OFFLOAD_DEVICE
diff --git a/tests/test_offload/test_module.py b/tests/test_offload/test_module.py
new file mode 100644
index 000000000..afafa36cc
--- /dev/null
+++ b/tests/test_offload/test_module.py
@@ -0,0 +1,213 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import inspect
+from weakref import ref
+
+import pytest
+import torch
+from compressed_tensors.offload import disable_offloading, disable_onloading
+from compressed_tensors.offload.cache.cpu import CPUCache
+from compressed_tensors.offload.module import offload_module
+from tests.testing_utils import requires_gpu
+
+
+ONLOAD_DEVICE = torch.device("cuda:0")
+OFFLOAD_DEVICE = torch.device("cpu")
+
+
+@pytest.fixture(scope="function")
+def cache():
+    return CPUCache(ONLOAD_DEVICE)
+
+
+@pytest.fixture(scope="function")
+def linear():
+    return torch.nn.Linear(5, 5, bias=True, device=OFFLOAD_DEVICE)
+
+
+@pytest.fixture(scope="function")
+def offloaded_linear(linear, cache):
+    offload_module(linear, ONLOAD_DEVICE, OFFLOAD_DEVICE)
+    return linear
+
+
+@pytest.fixture(scope="function")
+def input():
+    return torch.zeros(6, device=OFFLOAD_DEVICE)
+
+
+@pytest.mark.unit
+@requires_gpu
+def test_onloading(linear: torch.nn.Linear, cache):
+    weight = linear.weight
+    bias = linear.bias
+
+    offload_module(linear, ONLOAD_DEVICE, OFFLOAD_DEVICE)
+    onloaded_weight = linear.weight
+    onloaded_bias = linear.bias
+
+    assert onloaded_weight.device == ONLOAD_DEVICE
+    assert onloaded_bias.device == ONLOAD_DEVICE
+
+    assert type(onloaded_weight) is type(weight)
+    assert type(onloaded_bias) is type(bias)
+    assert torch.equal(onloaded_weight.to(weight.device), weight)
+    assert torch.equal(onloaded_bias.to(bias.device), bias)
+
+
+@pytest.mark.unit
+@requires_gpu
+def test_garbage_collect(offloaded_linear: torch.nn.Linear):
+    weight_ref = ref(offloaded_linear.weight)
+    bias_ref = ref(offloaded_linear.bias)
+
+    del offloaded_linear
+    gc.collect()
+
+    assert weight_ref() is None
+    assert bias_ref() is None
+
+
+@pytest.mark.unit
+@requires_gpu
+def test_disable_offloading(offloaded_linear: torch.nn.Linear):
+    outside_onloaded = offloaded_linear.weight
+    outside_onloaded_ref = ref(outside_onloaded)
+    assert outside_onloaded.device == ONLOAD_DEVICE
+
+    with disable_offloading():
+        inside_onloaded = offloaded_linear.weight
+        inside_onloaded_ref = ref(inside_onloaded)
+        assert inside_onloaded.device == ONLOAD_DEVICE
+
+        del outside_onloaded
+        del inside_onloaded
+        gc.collect()
+
+        assert outside_onloaded_ref() is None
+        assert inside_onloaded_ref() is not None
+
+    assert outside_onloaded_ref() is None
+    assert inside_onloaded_ref() is None
+
+
+@pytest.mark.unit
+@requires_gpu
+def test_disable_onloading(linear: torch.nn.Linear, cache):
+    offloaded_weight = linear.weight
+
+    offload_module(linear, ONLOAD_DEVICE, OFFLOAD_DEVICE)
+
+    with disable_onloading():
+        weight = linear.weight
+        assert weight is offloaded_weight
+
+        # new parameter assignments are direct
+        new_param = torch.nn.Parameter(torch.ones(5, device=ONLOAD_DEVICE))
+        linear.new_param = new_param
+        assert linear.new_param is new_param
+
+    assert weight is offloaded_weight
+
+
+@pytest.mark.unit
+@requires_gpu
+def test_delete(offloaded_linear: torch.nn.Linear):
+    weight_ref = ref(offloaded_linear.weight)
+    bias_ref = ref(offloaded_linear.bias)
+
+    del offloaded_linear.weight
+    del offloaded_linear.bias
+    gc.collect()
+
+    assert weight_ref() is None
+    assert bias_ref() is None
+
+
+@pytest.mark.unit
+@requires_gpu
+def test_forward_call(linear: torch.nn.Linear, cache):
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        assert input.device == ONLOAD_DEVICE
+        return torch.nn.functional.linear(input, linear.weight, linear.bias)
+
+    linear.forward = forward.__get__(linear)
+
+    offload_module(linear, ONLOAD_DEVICE, OFFLOAD_DEVICE)
+
+    with torch.no_grad():
+        input = torch.zeros(5, device=OFFLOAD_DEVICE)
+        output = linear.forward(input)
+        assert output.device == ONLOAD_DEVICE
+
+
+@pytest.mark.unit
+@requires_gpu
+@pytest.mark.parametrize("param_device", (ONLOAD_DEVICE, OFFLOAD_DEVICE))
+@pytest.mark.parametrize("use_register_parameter", (True, False))
+@pytest.mark.parametrize("requires_grad", (True, False))
+def test_register_parameter(
+    offloaded_linear: torch.nn.Linear,
+    param_device,
+    use_register_parameter,
+    requires_grad,
+):
+    # register param
+    data = torch.ones(5, device=param_device)
+    param = torch.nn.Parameter(data, requires_grad=requires_grad)
+    if use_register_parameter:
+        offloaded_linear.register_parameter("param_name", param)
+    else:
+        offloaded_linear.param_name = param
+
+    # new param is correctly onloaded
+    assert offloaded_linear.param_name.device == ONLOAD_DEVICE
+    assert torch.equal(offloaded_linear.param_name.to(param_device), param)
+
+
+@pytest.mark.unit
+@requires_gpu
+@pytest.mark.parametrize("param_device", (ONLOAD_DEVICE, OFFLOAD_DEVICE))
+@pytest.mark.parametrize("use_register_parameter", (True, False))
+@pytest.mark.parametrize("requires_grad", (True, False))
def test_register_parameter_invalidates(
+    offloaded_linear: torch.nn.Linear,
+    param_device,
+    use_register_parameter,
+    requires_grad,
+):
+    with disable_offloading():
+        # original weight is kept onloaded
+        onloaded_weight = offloaded_linear.weight
+        assert onloaded_weight in set(CPUCache.keep_onloaded_values.values())
+
+        # add new param
+        data = torch.ones(5, device=param_device)
+        param = torch.nn.Parameter(data, requires_grad=requires_grad)
+        if use_register_parameter:
+            offloaded_linear.register_parameter("weight", param)
+        else:
+            offloaded_linear.weight = param
+
+        # new param is correct
+        assert offloaded_linear.weight.device == ONLOAD_DEVICE
+        assert torch.equal(offloaded_linear.weight.to(param_device), param)
+
+        # original weight is invalidated
+        assert onloaded_weight not in set(CPUCache.keep_onloaded_values.values())
+
+
+@pytest.mark.unit
+def test_forward_signature(linear: torch.nn.Linear, cache):
+    original_signature = inspect.signature(linear.forward)
+
+    offload_module(linear, ONLOAD_DEVICE, OFFLOAD_DEVICE)
+    assert inspect.signature(linear.forward) == original_signature