Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/actions/test/action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ runs:
with:
venv: ${{ inputs.venv }}
name: compressed
extra: "[dev,accelerate]"
extra: "[dev]"

- name: clean up
run: |
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/test-check.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ jobs:
- name: Set Env
run: pip3 install --upgrade pip setuptools
- name: "⚙️ Install dependencies"
run: pip3 install .[dev,accelerate]
run: pip3 install .[dev]
- name: clean up
run: |
echo "cleaning up disk space as GHA runner has limited disk size."
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def _setup_install_requires() -> List:

def _setup_extras() -> Dict:
return {
"dev": ["black==22.12.0", "isort==5.8.0", "wheel>=0.36.2", "flake8>=3.8.3", "pytest>=6.0.0", "nbconvert>=7.16.3", "transformers<5.0"],
"dev": ["black==22.12.0", "isort==5.8.0", "wheel>=0.36.2", "flake8>=3.8.3", "pytest>=6.0.0", "nbconvert>=7.16.3", "transformers<5.0", "accelerate"],
"accelerate": ["accelerate"]
}

Expand Down
6 changes: 0 additions & 6 deletions src/compressed_tensors/linear/compressed_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,12 +87,6 @@ def from_linear(
# mark module as compressed
module.quantization_status = QuantizationStatus.COMPRESSED

# handles case where forward is wrapped in new_forward by accelerate hooks
if hasattr(module, "_old_forward"):
module._old_forward = CompressedLinear.forward.__get__(
module, CompressedLinear
)

return module

def forward(self, input: Tensor) -> Tensor:
Expand Down
13 changes: 7 additions & 6 deletions src/compressed_tensors/offload/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,7 @@ def register_offload_module(base: torch.nn.Module, name: str, module: torch.nn.M
"""
cache = base._parameters
if isinstance(cache, OffloadCache):
offload_module(
module, cache.onload_device, cache.offload_device, no_split=False
)
offload_module(module, cache.onload_device, cache.offload_device)

base.register_module(name, module)

Expand Down Expand Up @@ -178,9 +176,12 @@ def align_module_device(
if isinstance(module._parameters, OffloadCache):
assert isinstance(module._buffers, OffloadCache)
with module._parameters.disable_offloading():
with patch_attr(
module._parameters, "onload_device", execution_device
), patch_attr(module._buffers, "onload_device", execution_device):
if execution_device is not None:
with patch_attr(
module._parameters, "onload_device", execution_device
), patch_attr(module._buffers, "onload_device", execution_device):
yield
else:
yield

else:
Expand Down
2 changes: 1 addition & 1 deletion src/compressed_tensors/offload/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
def offload_model(
Comment thread
brian-dellabetta marked this conversation as resolved.
model: ModelType,
onload_device: torch.device | str,
offload_device: Optional[torch.device | str | Literal["disk"]] = None,
offload_device: torch.device | str | Literal["disk"] = torch.device("cpu"),
) -> ModelType:
"""
Offload a model to the `offload_device`. During forward passes, model weights will
Expand Down
4 changes: 2 additions & 2 deletions src/compressed_tensors/quantization/lifecycle/initialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
QuantizedAttentionImpl,
QuantizedKVCache,
)
from compressed_tensors.offload import unwrap_offload_forward
from compressed_tensors.quantization import (
ActivationOrdering,
DynamicType,
Expand All @@ -37,7 +38,6 @@
)
from compressed_tensors.quantization.utils import strategy_cdiv
from compressed_tensors.utils import (
disable_hf_hook,
get_execution_device,
get_head_dim,
get_num_attn_heads,
Expand Down Expand Up @@ -134,7 +134,7 @@ def initialize_module_for_quantization(
force_zero_point=force_zero_point,
)

with disable_hf_hook(module):
with unwrap_offload_forward(module):
# wrap forward call of module to perform
# quantized actions based on calltime status
wrap_module_forward_quantized(module, scheme)
Expand Down
35 changes: 0 additions & 35 deletions src/compressed_tensors/transform/apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict

import torch
from compressed_tensors import TRANSFORM_CONFIG_NAME
from compressed_tensors.transform import TransformConfig, TransformFactory
from compressed_tensors.utils.offload import has_offloaded_params


__all__ = ["apply_transform_config"]
Expand All @@ -37,35 +34,3 @@ def apply_transform_config(model: torch.nn.Module, config: TransformConfig):

# attach config to model for compression/serialization
setattr(model, TRANSFORM_CONFIG_NAME, config)

# ensure that tied weight transforms can be serialized without aliases
# In the future, this could be done by transformers or model compressor
# which would make this more robust to changing dispatches after transforms
_tie_offloaded_tensors(model)


def _tie_offloaded_tensors(model: torch.nn.Module):
    """
    Make parameters which share an offloaded value also share one parameter object.

    For every module reported as offloaded by `has_offloaded_params`, each direct
    parameter is looked up in the module hook's `weights_map` and keyed by the
    `data_ptr()` of that offloaded tensor. The first parameter seen for a given
    pointer becomes the canonical object, and every parameter with the same pointer
    (including the first one) is reassigned to that canonical object.

    As described by the original authors, this exists because the meta tensors left
    behind by accelerate offloading may be distinct objects even when their
    offloaded values are shared, while transformers can only serialize tied weights
    correctly when the meta tensors are identical (see transformers#39263).

    :param model: model potentially containing offloaded meta tensors to fix
    """
    # offloaded data pointer -> first parameter object encountered for that pointer
    canonical_by_ptr: Dict[int, torch.nn.Parameter] = {}

    for submodule in model.modules():
        if not has_offloaded_params(submodule):
            continue

        # only direct parameters; children are visited by `model.modules()`
        for param_name, _ in submodule.named_parameters(recurse=False):
            data_ptr = submodule._hf_hook.weights_map[param_name].data_ptr()

            try:
                canonical = canonical_by_ptr[data_ptr]
            except KeyError:
                # first occurrence: this module's parameter becomes the shared one
                canonical = getattr(submodule, param_name)
                canonical_by_ptr[data_ptr] = canonical

            # assigned unconditionally, so later aliases point at the first instance
            setattr(submodule, param_name, canonical)
14 changes: 3 additions & 11 deletions src/compressed_tensors/transform/factory/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
initialize_hooked_kv_cache,
register_key_hook,
)
from compressed_tensors.offload import OffloadCache
from compressed_tensors.registry.registry import RegistryMixin, T
from compressed_tensors.transform import (
TransformArgs,
Expand All @@ -34,8 +35,6 @@
)
from compressed_tensors.utils import (
align_module_device,
delete_offload_module,
has_offloaded_params,
match_named_modules,
patch_attr,
register_offload_module,
Expand Down Expand Up @@ -116,13 +115,6 @@ def _apply_to_module(self, model: Module, module: Module, args: TransformArgs):
:param module: target module to apply transforms to
:param args: defines how the transform will be applied to the target module
"""
if has_offloaded_params(module):
if module._hf_hook.place_submodules:
raise NotImplementedError(
"Applying transforms to offloaded submodules with "
"`place_submodules=True` is not supported"
)

# create transform as submodule
transform_name = f"{self.name}_{args.location}"
transform = self.create_transform(module, args)
Expand Down Expand Up @@ -150,13 +142,13 @@ def input_hook(_, args):
if self.scheme.requires_grad:
# for training, the weight changes with every forward pass
# so we can leverage parametrization to propagate the gradient
if has_offloaded_params(module):
if isinstance(module._parameters, OffloadCache):
Comment thread
kylesayrs marked this conversation as resolved.
raise ValueError("Offloaded training is not supported")
P.register_parametrization(module, "weight", transform)

else:
# transform is no longer needed (unfusing is not supported)
delete_offload_module(module, transform_name)
delattr(module, transform_name)

# register output transformation hook
elif args.location == TransformLocation.OUTPUT:
Expand Down
Loading