Move some util functions from quantization.utils to torchao.utils
Summary:

Moved
```
TORCH_VERSION_AFTER_2_2
TORCH_VERSION_AFTER_2_3
TORCH_VERSION_AFTER_2_4
get_model_size_in_bytes
unwrap_tensor_subclass
```

from torchao/quantization/utils.py to torchao/utils.py.
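
Call sites now import these helpers from the top-level module. A minimal sketch of the new import path (names taken directly from the updated files below):

```
# new import path after this change
from torchao.utils import (
    TORCH_VERSION_AFTER_2_2,
    TORCH_VERSION_AFTER_2_3,
    TORCH_VERSION_AFTER_2_4,
    get_model_size_in_bytes,
    unwrap_tensor_subclass,
)
```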

Test Plan:
python test/integration/test_integration.py

jerryzh168 committed Jun 7, 2024
1 parent 335171f commit d5d6c05
Showing 16 changed files with 108 additions and 105 deletions.
2 changes: 1 addition & 1 deletion test/integration/test_integration.py
@@ -70,7 +70,7 @@
from parameterized import parameterized
import itertools
import logging
-from torchao.quantization.utils import TORCH_VERSION_AFTER_2_3, TORCH_VERSION_AFTER_2_4
+from torchao.utils import TORCH_VERSION_AFTER_2_3, TORCH_VERSION_AFTER_2_4

logger = logging.getLogger("INFO")

2 changes: 1 addition & 1 deletion test/prototype/mx_formats/test_custom_cast.py
@@ -44,7 +44,7 @@
)

from torchao.prototype.mx_formats.mx_tensor import MXTensor
-from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4
+from torchao.utils import TORCH_VERSION_AFTER_2_4

if not TORCH_VERSION_AFTER_2_4:
    pytest.skip("Unsupported PyTorch version", allow_module_level=True)
3 changes: 2 additions & 1 deletion test/prototype/mx_formats/test_mx_linear.py
@@ -19,7 +19,8 @@
    swap_linear_with_mx_linear,
)

-from torchao.quantization.utils import compute_error, TORCH_VERSION_AFTER_2_4
+from torchao.quantization.utils import compute_error
+from torchao.utils import TORCH_VERSION_AFTER_2_4

# trying to outsmart flake8
__has_cuda = torch.cuda.is_available()
10 changes: 5 additions & 5 deletions test/prototype/test_bitpacking.py
@@ -2,7 +2,7 @@
from torchao.prototype.common.bitpacking import pack, unpack
import pytest
from torch.utils._triton import has_triton
-from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4
+from torchao.utils import TORCH_VERSION_AFTER_2_4

if not TORCH_VERSION_AFTER_2_4:
    pytest.skip("Unsupported PyTorch version", allow_module_level=True)
@@ -20,15 +20,15 @@ def test_uint3_to_int16_col_wise_cpu():
    unpacked = unpack(packed, 3, False, device='cpu')
    unpadded = unpacked[:test_tensor.shape[0], ...]
    assert(unpadded.allclose(test_tensor))

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_uint4_to_uint8():
    test_tensor = torch.randint(0, 15, (4, 4), dtype=torch.uint8).cuda()
    packed = pack(test_tensor, 8, 4)
    unpacked = unpack(packed, 4)
    unpadded = unpacked[:test_tensor.shape[0], ...]
    assert(unpadded.allclose(test_tensor))

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
def test_uint4_to_uint8_compile():
@@ -40,7 +40,7 @@ def test_uint4_to_uint8_compile():
    unpacked = unpack_compiled(packed, 4)
    unpadded = unpacked[:test_tensor.shape[0], ...]
    assert(unpadded.allclose(test_tensor))

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_uint3_to_int16():
    test_tensor = torch.randint(0, 7, (5, 8), dtype=torch.int16).cuda()
@@ -67,4 +67,4 @@ def test_uint3_to_int16_col_wise():
    packed = pack(test_tensor, 16, 3, False)
    unpacked = unpack(packed, 3, False)
    unpadded = unpacked[:test_tensor.shape[0], ...]
-    assert(unpadded.allclose(test_tensor))
\ No newline at end of file
+    assert(unpadded.allclose(test_tensor))
2 changes: 1 addition & 1 deletion test/quantization/test_qat.py
@@ -19,7 +19,7 @@
    fake_quantize_per_token,
)
from torchao.quantization.quant_primitives import get_group_qparams_symmetric
-from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4
+from torchao.utils import TORCH_VERSION_AFTER_2_4


# TODO: put this in a common test utils file
4 changes: 2 additions & 2 deletions test/quantization/test_quant_api.py
@@ -44,7 +44,7 @@
    get_apply_int8wo_quant,
    get_apply_int8dyn_quant,
)
-from torchao.quantization.utils import (
+from torchao.utils import (
    TORCH_VERSION_AFTER_2_3,
    TORCH_VERSION_AFTER_2_4,
)
@@ -556,7 +556,7 @@ def test_quantized_tensor_subclass_int8_dyn_quant(self):
        self.assertTrue(torch.equal(res, ref))

        # workaround for export path
-        from torchao.quantization.utils import unwrap_tensor_subclass
+        from torchao.utils import unwrap_tensor_subclass
        m_unwrapped = unwrap_tensor_subclass(m)

        m = torch.export.export(m_unwrapped, example_inputs).module()
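For reference, the export workaround exercised above follows this shape. A minimal runnable sketch; the toy model and inputs here are illustrative stand-ins for the quantized module in the test, not from the commit (on a plain-tensor model `unwrap_tensor_subclass` is a no-op):

```
import torch
from torchao.utils import unwrap_tensor_subclass

# illustrative stand-ins; the test uses a model whose Linear weights
# are tensor subclasses after quantization
m = torch.nn.Sequential(torch.nn.Linear(16, 16)).eval()
example_inputs = (torch.randn(1, 16),)

# re-parametrize subclass weights as plain tensors so torch.export can trace
m_unwrapped = unwrap_tensor_subclass(m)
m_exported = torch.export.export(m_unwrapped, example_inputs).module()
```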
2 changes: 1 addition & 1 deletion test/quantization/test_quant_primitives.py
@@ -19,7 +19,7 @@
    MappingType,
)

-from torchao.quantization.utils import (
+from torchao.utils import (
    TORCH_VERSION_AFTER_2_3,
    TORCH_VERSION_AFTER_2_4,
)
2 changes: 1 addition & 1 deletion torchao/kernel/intmm.py
@@ -2,7 +2,7 @@
import os
import torch

-from torchao.quantization.utils import TORCH_VERSION_AFTER_2_2
+from torchao.utils import TORCH_VERSION_AFTER_2_2

try:
    # Only works for torch2.2 or newer.
4 changes: 3 additions & 1 deletion torchao/quantization/GPTQ.py
@@ -22,9 +22,11 @@
from .utils import (
    _lm_eval_available,
    _MultiInput,
-    TORCH_VERSION_AFTER_2_3,
)
from torchao.utils import (
    find_multiple,
)
+from torchao.utils import TORCH_VERSION_AFTER_2_3
from typing import Any, Dict, Optional
from .unified import Quantizer

1 change: 0 additions & 1 deletion torchao/quantization/__init__.py
@@ -44,7 +44,6 @@
"Int8WeightOnlyQuantizedLinearWeight",
"Int4WeightOnlyQuantizedLinearWeight",
"compute_error",
"get_model_size_in_bytes",
"WeightOnlyInt8QuantLinear",
"Int4WeightOnlyGPTQQuantizer",
"Int4WeightOnlyQuantizer",
2 changes: 1 addition & 1 deletion torchao/quantization/autoquant.py
@@ -9,7 +9,7 @@
    quantize_activation_per_token_absmax,
    safe_int_mm,
)
-from .utils import TORCH_VERSION_AFTER_2_4
+from torchao.utils import TORCH_VERSION_AFTER_2_4
import torch.nn.functional as F
try:
    from torch._inductor.utils import do_bench
2 changes: 1 addition & 1 deletion torchao/quantization/quant_api.py
@@ -25,7 +25,7 @@
from typing import Any, Callable

from .dynamic_quant import DynamicallyPerAxisQuantizedLinear
-from .utils import (
+from torchao.utils import (
    TORCH_VERSION_AFTER_2_4,
    unwrap_tensor_subclass,
)
2 changes: 1 addition & 1 deletion torchao/quantization/quant_primitives.py
@@ -14,7 +14,7 @@

from torchao.kernel.intmm import int_scaled_matmul
from torchao.kernel.intmm import safe_int_mm
-from .utils import TORCH_VERSION_AFTER_2_3
+from torchao.utils import TORCH_VERSION_AFTER_2_3


__all__ = [
2 changes: 1 addition & 1 deletion torchao/quantization/subclass.py
@@ -18,7 +18,7 @@
    groupwise_affine_quantize_tensor_from_qparams,
    MappingType,
)
-from .utils import find_multiple
+from torchao.utils import find_multiple
from typing import Tuple, Optional, Callable, Dict, Any


86 changes: 0 additions & 86 deletions torchao/quantization/utils.py
@@ -9,18 +9,10 @@
from torch.utils._python_dispatch import TorchDispatchMode
from packaging import version
import torch.nn.utils.parametrize as parametrize
from torchao.utils import find_multiple


__all__ = [
-    "find_multiple",
    "compute_error",
    "_apply_logging_hook",
-    "get_model_size_in_bytes",
-    "unwrap_tensor_subclass",
-    "TORCH_VERSION_AFTER_2_2",
-    "TORCH_VERSION_AFTER_2_3",
-    "TORCH_VERSION_AFTER_2_4",
]

try:
@@ -87,67 +79,6 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None):

        return rs


-class UnwrapTensorSubclass(torch.nn.Module):
-    def forward(self, *tensors):
-        todo = list(tensors)
-        for tp, meta, inner_tensors in reversed(self.rebuild_stack):
-            nb_tensor = len(inner_tensors)
-            inner_tensors = {a: b for a, b in zip(inner_tensors, todo[-nb_tensor:])}
-            todo = todo[nb_tensor:]
-            rebuilt = tp.__tensor_unflatten__(inner_tensors, meta, None, None)
-            todo.append(rebuilt)
-
-        assert len(todo) == 1
-        return todo[0]
-
-    def right_inverse(self, tensor):
-        assert type(tensor) is not torch.Tensor
-        rebuild_stack = []
-        plain_tensors = []
-        todo = [tensor]
-        while todo:
-            obj = todo.pop()
-            inner_tensors, metadata = obj.__tensor_flatten__()
-            rebuild_stack.append((type(obj), metadata, inner_tensors))
-            for attr_name in inner_tensors:
-                val = getattr(obj, attr_name)
-                if type(val) is torch.Tensor:
-                    plain_tensors.append(val)
-                else:
-                    assert isinstance(val, torch.Tensor)
-                    todo.append(val)
-
-        self.rebuild_stack = rebuild_stack
-
-        return plain_tensors
-
-def unwrap_tensor_subclass(model, filter_fn=None):
-    for name, child in model.named_children():
-        # make sure child.weight is a tensor subclass
-        if (
-            isinstance(child, torch.nn.Linear) and
-            hasattr(child, "weight") and
-            type(child.weight) is not torch.Tensor and
-            type(child.weight) is not torch.nn.Parameter and
-            isinstance(child.weight, torch.Tensor) and
-            issubclass(type(child.weight), torch.Tensor)
-        ):
-            parametrize.register_parametrization(child, "weight", UnwrapTensorSubclass())
-        unwrap_tensor_subclass(child)
-    return model
-
-
-# https://discuss.pytorch.org/t/finding-model-size/130275
-def get_model_size_in_bytes(model):
-    s = 0
-    for p in model.parameters():
-        s += p.nelement() * p.element_size()
-    for b in model.buffers():
-        s += b.nelement() * b.element_size()
-    return s


class _MultiInput:

    def __init__(self, inputs):
@@ -165,20 +96,3 @@ def cuda(self):
        self.values = [
            val.cuda() if isinstance(val, torch.Tensor) else val for val in self.values
        ]
-
-
-# TODO: quantization namespace is not the right place ot have this
-if version.parse(torch.__version__) >= version.parse("2.4.0.dev"):
-    TORCH_VERSION_AFTER_2_4 = True
-else:
-    TORCH_VERSION_AFTER_2_4 = False
-
-if version.parse(torch.__version__) >= version.parse("2.3.0.dev"):
-    TORCH_VERSION_AFTER_2_3 = True
-else:
-    TORCH_VERSION_AFTER_2_3 = False
-
-if version.parse(torch.__version__) >= version.parse("2.2.0.dev"):
-    TORCH_VERSION_AFTER_2_2 = True
-else:
-    TORCH_VERSION_AFTER_2_2 = False
87 changes: 87 additions & 0 deletions torchao/utils.py
@@ -4,6 +4,20 @@
from functools import reduce
from math import gcd

+__all__ = [
+    "benchmark_model",
+    "profiler_runner",
+    "get_compute_capability",
+    "skip_if_compute_capability_less_than",
+    "benchmark_torch_function_in_microseconds",
+    "find_multiple",
+    "get_model_size_in_bytes",
+    "unwrap_tensor_subclass",
+    "TORCH_VERSION_AFTER_2_2",
+    "TORCH_VERSION_AFTER_2_3",
+    "TORCH_VERSION_AFTER_2_4",
+]


def benchmark_model(model, num_runs, input_tensor):
    torch.cuda.synchronize()
@@ -65,3 +79,76 @@ def find_multiple(n: int, *args: Tuple[int]) -> int:
    if n % k == 0:
        return n
    return n + k - (n % k)

+# https://discuss.pytorch.org/t/finding-model-size/130275
+def get_model_size_in_bytes(model):
+    s = 0
+    for p in model.parameters():
+        s += p.nelement() * p.element_size()
+    for b in model.buffers():
+        s += b.nelement() * b.element_size()
+    return s
+
+class UnwrapTensorSubclass(torch.nn.Module):
+    def forward(self, *tensors):
+        todo = list(tensors)
+        for tp, meta, inner_tensors in reversed(self.rebuild_stack):
+            nb_tensor = len(inner_tensors)
+            inner_tensors = {a: b for a, b in zip(inner_tensors, todo[-nb_tensor:])}
+            todo = todo[nb_tensor:]
+            rebuilt = tp.__tensor_unflatten__(inner_tensors, meta, None, None)
+            todo.append(rebuilt)
+
+        assert len(todo) == 1
+        return todo[0]
+
+    def right_inverse(self, tensor):
+        assert type(tensor) is not torch.Tensor
+        rebuild_stack = []
+        plain_tensors = []
+        todo = [tensor]
+        while todo:
+            obj = todo.pop()
+            inner_tensors, metadata = obj.__tensor_flatten__()
+            rebuild_stack.append((type(obj), metadata, inner_tensors))
+            for attr_name in inner_tensors:
+                val = getattr(obj, attr_name)
+                if type(val) is torch.Tensor:
+                    plain_tensors.append(val)
+                else:
+                    assert isinstance(val, torch.Tensor)
+                    todo.append(val)
+
+        self.rebuild_stack = rebuild_stack
+
+        return plain_tensors
+
+def unwrap_tensor_subclass(model, filter_fn=None):
+    for name, child in model.named_children():
+        # make sure child.weight is a tensor subclass
+        if (
+            isinstance(child, torch.nn.Linear) and
+            hasattr(child, "weight") and
+            type(child.weight) is not torch.Tensor and
+            type(child.weight) is not torch.nn.Parameter and
+            isinstance(child.weight, torch.Tensor) and
+            issubclass(type(child.weight), torch.Tensor)
+        ):
+            parametrize.register_parametrization(child, "weight", UnwrapTensorSubclass())
+        unwrap_tensor_subclass(child)
+    return model
+
+if version.parse(torch.__version__) >= version.parse("2.4.0.dev"):
+    TORCH_VERSION_AFTER_2_4 = True
+else:
+    TORCH_VERSION_AFTER_2_4 = False
+
+if version.parse(torch.__version__) >= version.parse("2.3.0.dev"):
+    TORCH_VERSION_AFTER_2_3 = True
+else:
+    TORCH_VERSION_AFTER_2_3 = False
+
+if version.parse(torch.__version__) >= version.parse("2.2.0.dev"):
+    TORCH_VERSION_AFTER_2_2 = True
+else:
+    TORCH_VERSION_AFTER_2_2 = False
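
With the move, `get_model_size_in_bytes` is importable from `torchao.utils`. A quick sanity check against a hand-computed size; the `Linear(128, 128)` module is an illustrative example, not part of the commit:

```
import torch
from torchao.utils import get_model_size_in_bytes

model = torch.nn.Linear(128, 128)   # fp32: 128*128 weight + 128 bias elements
expected = (128 * 128 + 128) * 4    # 4 bytes per fp32 element, no buffers
assert get_model_size_in_bytes(model) == expected
```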
