Commit d54399c

updated uint4 and perchannel_symmetricweight based on new API pytorch#391
1 parent c0497b8 commit d54399c

5 files changed, +285 -210 lines


test/dtypes/test_uint4.py (+3 -1)

@@ -1,7 +1,9 @@
 import torch
 from torchao.dtypes.uint4 import (
     UInt4Tensor,
-    PerChannelSymmetricWeightUInt4Tensor,
+)
+from torchao.dtypes import (
+    PerChannelSymmetricWeightUInt4Tensor
 )
 import unittest
 from unittest import TestCase, main

torchao/dtypes/__init__.py (+2)

@@ -1,6 +1,7 @@
 from .nf4tensor import NF4Tensor, to_nf4
 # from ..prototype.dtypes.uint2 import UInt2Tensor, BitnetTensor
 from .uint4 import UInt4Tensor
+from .perchannel_symmetricweight import PerChannelSymmetricWeightUInt4Tensor
 from .affine_quantized_tensor import AffineQuantizedTensor, to_affine_quantized
 
 __all__ = [
@@ -9,4 +10,5 @@
     "UInt4Tensor"
     "AffineQuantizedTensor",
     "to_affine_quantized",
+    "PerChannelSymmetricWeightUInt4Tensor",
 ]
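
Taken together, the two diffs above move PerChannelSymmetricWeightUInt4Tensor out of torchao.dtypes.uint4 into its own module and re-export it from the package root, so callers now import it as sketched below (UInt4Tensor keeps its old location):

# import locations after this commit
from torchao.dtypes.uint4 import UInt4Tensor
from torchao.dtypes import PerChannelSymmetricWeightUInt4Tensor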
torchao/dtypes/perchannel_symmetricweight.py (new file, +144 lines)

import torch
from torchao.dtypes.uint4 import pack_uint4, unpack_uint4
from torchao.dtypes import UInt4Tensor
from typing import Dict, Any
from torchao.dtypes.utils import _implements
from torchao.dtypes.utils import _ATEN_OP_OR_TORCH_FN_TABLE

SYMMETRIC_WEIGHT_OPS_TABLE: Dict[Any, Any] = {}

from torchao.dtypes.utils import _implements

def implements(aten_ops_or_torch_fns):
    return _implements(PerChannelSymmetricWeightUInt4Tensor, aten_ops_or_torch_fns)

def _dynamically_quantize_per_channel_int4(x, quant_min, quant_max, target_dtype):
    # assumes symmetric quantization
    # assumes axis == 0
    # assumes dense memory format
    # TODO(future): relax ^ as needed

    # default setup for affine quantization of activations
    eps = torch.finfo(torch.float32).eps

    # get min and max
    min_val, max_val = torch.aminmax(x, dim=1)

    # calculate scale and zero point based on min and max
    # reference: https://fburl.com/code/srbiybme
    min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
    max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
    device = min_val_neg.device

    # reference: https://fburl.com/code/4wll53rk
    max_val_pos = torch.max(-min_val_neg, max_val_pos)
    scale = max_val_pos / (float(quant_max - quant_min) / 2)
    # ensure scale is the same dtype as the original tensor
    scale = torch.clamp(scale, min=eps).to(x.dtype)
    zero_point = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)

    # quantize based on qmin/qmax/scale/zp
    # reference: torch/ao/quantization/fx/_decomposed.py?lines=63
    x_div = x.transpose(0, 1) / scale
    x_round = torch.round(x_div)
    x_zp = x_round + zero_point
    x_zp = x_zp.transpose(0, 1)
    quant = torch.clamp(x_zp, quant_min, quant_max)

    if target_dtype == torch.uint4:
        # TODO: simplify (maybe implement to)
        quant = PerChannelSymmetricWeightUInt4Tensor.from_unpacked(
            quant.to(torch.uint8), scale
        )
    else:
        quant = quant.to(target_dtype)

    return quant, scale, zero_point

class PerChannelSymmetricWeightUInt4Tensor(UInt4Tensor):
    @staticmethod
    def __new__(cls, elem, scales, **kwargs):
        return super().__new__(cls, elem, **kwargs)

    def __init__(self, elem, scales, **kwargs):
        super().__init__(elem, **kwargs)

        self.scales = scales

    def __tensor_flatten__(self):
        return ["elem", "scales"], None

    @staticmethod
    def __tensor_unflatten__(flattened, meta, outer_size, outer_stride):
        assert meta is None
        elem = flattened["elem"]
        scales = flattened["scales"]
        return PerChannelSymmetricWeightUInt4Tensor(elem, scales)

    @classmethod

    # inconsistently.

    def from_unpacked(cls, unpacked, scales):
        return cls(pack_uint4(unpacked), scales)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = {} if kwargs is None else kwargs

        if func in _ATEN_OP_OR_TORCH_FN_TABLE[cls]:
            return _ATEN_OP_OR_TORCH_FN_TABLE[cls][func](*args, **kwargs)

        with torch._C.DisableTorchFunctionSubclass():
            return func(*args, **kwargs)

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs):
        if func in _ATEN_OP_OR_TORCH_FN_TABLE[cls]:
            return _ATEN_OP_OR_TORCH_FN_TABLE[cls][func](func, *args, **kwargs)

        raise NotImplementedError(
            f"PerChannelSymmetricWeightUInt4Tensor dispatch: attempting to run {func}, this is not supported"
        )


    @classmethod
    def from_float(cls, w_fp32):
        w_int4, scales, _zp = _dynamically_quantize_per_channel_int4(
            w_fp32, 0, 15, torch.uint4
        )
        w_int4 = w_int4.to(device=w_fp32.device)
        return w_int4

@implements([torch.ops.aten.addmm.default])
def _(func, args, kwargs):
    bias, x, weight = args
    x_view = x.view(-1, x.shape[-1])
    y = torch.mm(x_view, weight.to(torch.uint8).to(x.dtype)) * weight.scales
    y = y.reshape(*x.shape[:-1], -1)
    if bias is not None:
        y += bias
    return y

@implements([torch.ops.aten.t.default])
def _(func, args, kwargs):
    # TODO: add proper support for transpose
    (tensor,) = args
    unpacked = unpack_uint4(tensor.elem)
    transposed = torch.ops.aten.t.default(unpacked)
    return PerChannelSymmetricWeightUInt4Tensor.from_unpacked(
        transposed, tensor.scales
    )

@implements([torch.ops.aten.detach.default])
def _(func, args, kwargs):
    (tensor,) = args
    return

if __name__ == "__main__":
    # test
    x = torch.randn(2, 3, 4)
    w = torch.randn(5, 4)
    b = torch.randn(5)
    y = PerChannelSymmetricWeightUInt4Tensor.from_float(w)
    # print(y)
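
For orientation (not part of the commit), a minimal usage sketch of this weight subclass: from_float quantizes a float weight to packed uint4 with one symmetric scale per output channel, and the registered aten.t / aten.addmm overrides handle the subsequent linear math. The F.linear call below is an assumption based on those two overrides; the module's own __main__ block only exercises from_float.

import torch
import torch.nn.functional as F
from torchao.dtypes import PerChannelSymmetricWeightUInt4Tensor

x = torch.randn(2, 3, 4)   # activations
w = torch.randn(5, 4)      # float weight: 5 output channels, 4 input features
b = torch.randn(5)

# per-channel scale: max(|w_row|) / 7.5; values rounded and clamped to [0, 15]
w_uint4 = PerChannelSymmetricWeightUInt4Tensor.from_float(w)

# assumption: linear decomposes into aten.t + aten.addmm, both registered above
y = F.linear(x, w_uint4, b)   # shape (2, 3, 5)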

0 commit comments
