Merged

62 commits
ff376e0
add offloading logic
kylesayrs Dec 17, 2025
bb07f1c
remove tracing requirement
kylesayrs Dec 17, 2025
362c115
fix name change
kylesayrs Dec 17, 2025
59904c1
simplify utils
kylesayrs Dec 17, 2025
32ab4e2
rename
kylesayrs Dec 18, 2025
9a7794a
add all necessary functions
kylesayrs Dec 18, 2025
9cdd4be
fix rebase
kylesayrs Dec 18, 2025
6f41d1d
add test_forward_call
kylesayrs Dec 18, 2025
8cb0eee
remove accelerate dep
kylesayrs Dec 18, 2025
53b9561
docstring
kylesayrs Dec 18, 2025
82e0111
move around
kylesayrs Dec 18, 2025
4f5afb0
fix global access
kylesayrs Dec 18, 2025
6d26e83
add back module disable_offloading
kylesayrs Dec 18, 2025
e830d85
remove disk cache
kylesayrs Dec 18, 2025
7ba45b2
docstrings, todos
kylesayrs Dec 18, 2025
48d40ef
cleanup, add tests
kylesayrs Dec 18, 2025
60f5cb5
allow direct parameter assignment in disable_onloading
kylesayrs Dec 18, 2025
6587de9
WIP: dispatch
kylesayrs Dec 18, 2025
9831f20
WIP
kylesayrs Dec 18, 2025
cab01ea
add buffers
kylesayrs Dec 18, 2025
1cbcd8e
WIP: dispatch
kylesayrs Dec 18, 2025
9a2e319
fix module, tested dispatch
kylesayrs Dec 18, 2025
38dac49
move helpers
kylesayrs Dec 18, 2025
2830fb2
match signature
kylesayrs Dec 19, 2025
611c8cf
use setitem to invalidate cache when updating
kylesayrs Dec 19, 2025
8607704
fix dispatch typo, add tests
kylesayrs Dec 19, 2025
f923d23
remove shared module logic
kylesayrs Dec 19, 2025
16cca67
docstrings
kylesayrs Dec 19, 2025
6037bed
docstrings
kylesayrs Dec 19, 2025
3a4b045
specialize to cpu cache
kylesayrs Dec 22, 2025
7de029b
move most logic to base
kylesayrs Dec 23, 2025
cca3b87
docstrings
kylesayrs Dec 23, 2025
3aa8400
WIP
kylesayrs Dec 31, 2025
fa1eae4
works, is simpler and likely better, tracing will have to be done wit…
kylesayrs Jan 1, 2026
af56a02
fully adopt
kylesayrs Jan 2, 2026
d64a1ab
update tests
kylesayrs Jan 2, 2026
9204038
fix update_offload_parameter, cleanup
kylesayrs Jan 2, 2026
a5dac7a
add special case for parameter moving
kylesayrs Jan 2, 2026
4fb40c2
only unwrap forward
kylesayrs Jan 2, 2026
d02790d
fix resolving
kylesayrs Jan 2, 2026
bbef1f9
share classvars across all subclasses, docstrings
kylesayrs Jan 2, 2026
ee4dd11
add interface tests
kylesayrs Jan 2, 2026
7decf7e
remove excess code from dispatch, docstrings
kylesayrs Jan 2, 2026
9079361
fix typo for dispatch_model, add default hints
kylesayrs Jan 2, 2026
81620ba
remove global access
kylesayrs Jan 3, 2026
672918f
better remove dispatch
kylesayrs Jan 3, 2026
64ddcf5
fix typo
kylesayrs Jan 3, 2026
fd6acef
remove weakref dict, through experimentation it was found that it was…
kylesayrs Jan 3, 2026
2015ff8
simplify global flags
kylesayrs Jan 3, 2026
4ed7871
guarantee avoid excess device movement
kylesayrs Jan 5, 2026
3509b00
balanced dispatch, more intelligent offload_module
kylesayrs Jan 5, 2026
99374e3
fix docstring
kylesayrs Jan 5, 2026
c59d23c
add offloading, binary search
kylesayrs Jan 7, 2026
16adf7a
simplify
kylesayrs Jan 7, 2026
b34e744
cleanup utils, fix typo with disable_onloading
kylesayrs Jan 7, 2026
aabad78
fix typo
kylesayrs Jan 7, 2026
a42bc5a
no split does not control disabling offloading
kylesayrs Jan 8, 2026
1c4ac41
suggestions: fix offloading dispatch bug, rename things
kylesayrs Jan 9, 2026
b9cf3a4
fix typo
kylesayrs Jan 9, 2026
3d30bdf
remove no_split arguments where applicable
kylesayrs Jan 12, 2026
4432f13
Merge branch 'main' into kylesayrs/torch_offloader
dsikka Jan 19, 2026
cc3dbca
Merge branch 'main' into kylesayrs/torch_offloader
kylesayrs Jan 21, 2026
4 changes: 2 additions & 2 deletions Makefile

@@ -8,7 +8,7 @@ quality:
 	@echo "Running copyright checks";
 	python utils/copyright.py quality $(PYCHECKGLOBS)
 	@echo "Running python quality checks";
-	black --check $(PYCHECKDIRS);
+	black --target-version py310 --check $(PYCHECKDIRS);
 	isort --check-only $(PYCHECKDIRS);
 	flake8 $(PYCHECKDIRS);
@@ -17,7 +17,7 @@ style:
 	@echo "Running copyright style";
 	python utils/copyright.py style $(PYCHECKGLOBS)
 	@echo "Running python styling";
-	black $(PYCHECKDIRS);
+	black --target-version py310 $(PYCHECKDIRS);
 	isort $(PYCHECKDIRS);
11 changes: 10 additions & 1 deletion src/compressed_tensors/__init__.py

@@ -20,5 +20,14 @@
 from .compressors import *
 from .config import *
 from .quantization import QuantizationConfig, QuantizationStatus
-from .utils import *
+
+# avoid resolving compressed_tensors.offload as compressed_tensors.utils.offload
+from .utils.offload import *
+from .utils.helpers import *
+from .utils.internal import *
+from .utils.match import *
+from .utils.permutations_24 import *
+from .utils.safetensors_load import *
+from .utils.semi_structured_conversions import *
+from .utils.type import *
 from .version import *
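The comment in the hunk above is doing real work: since `compressed_tensors/utils/offload.py` exists as a module, a wildcard `from .utils import *` could re-export the bound submodule name `offload` into the top-level namespace, shadowing the new `compressed_tensors.offload` package. A minimal sketch of that reading (my interpretation of the comment, not spelled out in the PR):

```python
# Sketch of the shadowing risk this comment guards against (my reading of the
# comment; not spelled out in the PR).
#
# With the old `from .utils import *`, the parent package could end up with:
#   compressed_tensors.offload -> compressed_tensors.utils.offload  (a module)
# instead of:
#   compressed_tensors.offload -> the new offload package
#
# Importing each utils submodule explicitly never binds the name `offload`:
import compressed_tensors.offload as offload_pkg

assert offload_pkg.__name__ == "compressed_tensors.offload"
```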
197 changes: 197 additions & 0 deletions src/compressed_tensors/offload/__init__.py

@@ -0,0 +1,197 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
from typing import Iterable, Optional

import torch
from compressed_tensors.offload.cache import OffloadCache
from compressed_tensors.offload.dispatch import (  # noqa: F401
    dispatch_model,
    offload_model,
    remove_dispatch,
)
from compressed_tensors.offload.module import offload_module, unwrap_offload_forward
from compressed_tensors.offload.utils import get_module_device, move_module_tensor
from compressed_tensors.utils.helpers import patch_attr


__all__ = [
    # dispatch models
    "offload_model",
    "dispatch_model",
    "remove_dispatch",
    # control movement
    "disable_onloading",
    "disable_offloading",
    # manipulate parameters
    "update_offload_parameter",
    "get_execution_device",
    "get_offloaded_device",
    "register_offload_module",
    # manipulate forward
    "unwrap_offload_forward",
    # backwards compatibility: should be deprecated
    "align_modules",
    "align_module_device",
]


@contextlib.contextmanager
def disable_offloading():
    """
    When offloading is disabled, onloaded tensors remain onloaded in memory until exit

    ```
    with OffloadCache.disable_offloading():
        ... = cache["weight"]
        ... = cache["weight"]  # cache hit
        ... = cache["weight"]  # cache hit

        # upon exit, all onloaded weights are released
    ```
    """
    with OffloadCache.disable_offloading():
        yield


@contextlib.contextmanager
def disable_onloading():
    """
    When onloading is disabled, tensors are not onloaded on access, and assignments
    do not trigger offloading. This is mostly used to disable device movement for
    debugging

    ```
    with OffloadCache.disable_onloading():
        tensor = ...
        cache["weight"] = tensor  # assignments do not trigger offloading
        cache["weight"] is tensor  # tensor remains offloaded
    ```
    """
    with OffloadCache.disable_onloading():
        yield


def update_offload_parameter(module: torch.nn.Module, name: str, data: torch.Tensor):
    """
    Update the data of an existing parameter and its offload dict. Supports both
    parameters of offloaded modules and non-offloaded modules

    :param module: module containing the parameter to update
    :param name: name of module parameter to update
    :param data: tensor to update parameter with
    """
    if isinstance(module._parameters, OffloadCache):
        with module._parameters.disable_onloading():
            value = getattr(module, name)
            value.copy_(module._parameters.offload(data))
            setattr(module, name, value)

    else:
        getattr(module, name).copy_(data)


def get_execution_device(module: torch.nn.Module) -> torch.device | str:
    """
    Get the device which inputs should be moved to before module execution.

    :param module: module to check, may be offloaded
    :return: onload device of module
    """
    if isinstance(module._parameters, OffloadCache):
        return module._parameters.onload_device

    else:
        return get_module_device(module)


def get_offloaded_device(module: torch.nn.Module) -> torch.device:
    """
    :param module: module to check
    :return: device the module is offloaded to after a forward pass
    """
    with disable_onloading():
        return get_module_device(module)


def register_offload_module(base: torch.nn.Module, name: str, module: torch.nn.Module):
    """
    Register a submodule with offloading if the parent module is offloaded

    :param base: module to attach submodule to
    :param name: name of submodule
    :param module: submodule to attach
    """
    cache = base._parameters
    if isinstance(cache, OffloadCache):
        offload_module(
            module, cache.onload_device, cache.offload_device, no_split=False
        )

    base.register_module(name, module)


""" Implemented for backwards compatibility """


@contextlib.contextmanager
def align_modules(
    modules: torch.nn.Module | Iterable[torch.nn.Module],
    execution_device: Optional[torch.device] = None,
):
    """
    Context manager for onloading modules to a device, and disabling onload and
    offload attempts triggered by forward calls. Used for sequential onloading of
    layers

    :param modules: `torch.nn.Module` or iterable of `torch.nn.Module`s to onload
    :param execution_device: device to onload to
    """
    # accept a single module as well as an iterable, per the annotation above
    if isinstance(modules, torch.nn.Module):
        modules = (modules,)

    with contextlib.ExitStack() as stack:
        for module in modules:
            stack.enter_context(align_module_device(module, execution_device))
        yield


@contextlib.contextmanager
def align_module_device(
    module: torch.nn.Module, execution_device: Optional[torch.device] = None
):
    """
    Context manager that moves a module's parameters to the specified execution device.

    :param module: Module with parameters to align
    :param execution_device: If provided, overrides the module's execution device
        within the context. Otherwise, the module's existing onload device is used
    """
    if isinstance(module._parameters, OffloadCache):
        assert isinstance(module._buffers, OffloadCache)
        # fall back to the cache's own onload device when no override is given
        device = (
            execution_device
            if execution_device is not None
            else module._parameters.onload_device
        )
        with module._parameters.disable_offloading():
            with patch_attr(module._parameters, "onload_device", device), patch_attr(
                module._buffers, "onload_device", device
            ):
                yield

    else:
        original_device = {}
        if execution_device is not None:
            for name, param in module.named_parameters(recurse=False):
                original_device[name] = param.device
                move_module_tensor(module, name, execution_device)

        try:
            yield
        finally:
            for name, device in original_device.items():
                move_module_tensor(module, name, device)
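Since this `__init__` is the package's public surface, a usage sketch may help tie the exports together. It uses only the signatures visible in this diff; how `offload_model` picks devices, and whether it takes a single argument, are assumptions rather than documented behavior:

```python
# Usage sketch only: offload_model's signature and defaults are assumptions
# based on its name and export grouping, not on code shown in this diff.
import torch

from compressed_tensors.offload import (
    align_module_device,
    disable_offloading,
    get_execution_device,
    offload_model,
    update_offload_parameter,
)

model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.Linear(16, 16))
offload_model(model)  # assumed: offloads weights, onloading per forward call

# keep weights onloaded across repeated forward calls, release them on exit
with disable_offloading():
    device = get_execution_device(model[0])
    for _ in range(3):
        model(torch.randn(1, 16, device=device))

# update a parameter while keeping its offloaded copy in sync
update_offload_parameter(model[0], "weight", torch.zeros(16, 16))

# temporarily onload a single module (backwards-compatible API)
with align_module_device(model[0], torch.device("cpu")):
    assert model[0].weight.device == torch.device("cpu")
```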
17 changes: 17 additions & 0 deletions src/compressed_tensors/offload/cache/__init__.py

@@ -0,0 +1,17 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# flake8: noqa

from .base import OffloadCache
from .cpu import CPUCache
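From the call sites elsewhere in this PR, the cache contract can be inferred: `OffloadCache` stands in for a module's `_parameters` mapping, carries `onload_device`/`offload_device` attributes and an `offload()` method, and exposes class-level `disable_offloading`/`disable_onloading` context managers shared across subclasses, while `CPUCache` specializes the offload side to CPU. A sketch of that contract, with the constructor arguments being guesses:

```python
# Contract sketch inferred from call sites elsewhere in this PR; the CPUCache
# constructor arguments here are guesses, not the real signature.
import torch

from compressed_tensors.offload.cache import CPUCache, OffloadCache

cache = CPUCache(onload_device=torch.device("cuda:0"))  # hypothetical ctor

# setitem offloads: the stored copy lives on cache.offload_device (CPU here)
cache["weight"] = torch.nn.Parameter(torch.randn(4, 4))

# getitem onloads; disable_offloading keeps the onloaded copy resident
with OffloadCache.disable_offloading():
    w1 = cache["weight"]
    w2 = cache["weight"]  # cache hit: no second host-to-device transfer

# disable_onloading passes values through without any device movement
with OffloadCache.disable_onloading():
    t = torch.zeros(4, 4)
    cache["weight"] = t
    assert cache["weight"] is t
```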