Use torch.uint1 to torch.uint7 for Uintx tensor subclass
Summary:
Previously we used an integer `bit_width` argument for uintx quantization, but we can pass a `torch.uintx` `dtype` directly instead.

There are still some workarounds that convert the torch dtype back to a bit width internally. Removing all of these hacks would require supporting the Uintx tensor subclass properly and having the `torch.uintx` dtypes dispatch to the tensor subclass. That is probably not the highest priority for now, since good performance is more important.
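
A minimal usage sketch of the new dtype-based API (assuming torch 2.3+ and a CUDA device; shapes are illustrative):

    import torch
    from torchao.quantization.quant_api import quantize_, uintx_weight_only

    model = torch.nn.Sequential(
        torch.nn.Linear(128, 256, bias=False, dtype=torch.bfloat16, device="cuda")
    )
    # previously: quantize_(model, uintx_weight_only(4, group_size=64))
    quantize_(model, uintx_weight_only(torch.uint4, group_size=64))
    out = model(torch.randn(1, 128, dtype=torch.bfloat16, device="cuda"))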

Test Plan:
python test/dtypes/test_affine_quantized.py
pytest test/dtypes/test_uintx.py

Reviewers:

Subscribers:

Tasks:

Tags:
jerryzh168 committed Aug 14, 2024
1 parent 174e630 commit 54b5569
Showing 6 changed files with 122 additions and 45 deletions.
45 changes: 44 additions & 1 deletion test/dtypes/test_affine_quantized.py
@@ -9,13 +9,17 @@
int8_dynamic_activation_int8_weight,
int8_dynamic_activation_int8_semi_sparse_weight,
)
from torchao.dtypes import (
to_affine_quantized,
)
import torch
import unittest
import tempfile
from torchao.utils import (
TORCH_VERSION_AFTER_2_3,
TORCH_VERSION_AFTER_2_5,
)

from torchao.dtypes.uintx.Uintx import _DTYPE_TO_BIT_WIDTH

class TestAffineQuantized(TestCase):
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@@ -51,6 +55,45 @@ def test_weights_only(self):
else:
_ = torch.load(f, weights_only=False)

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "sub byte dtype requires torch 2.3+")
def test_uintx_target_dtype(self):
from torchao.quantization.quant_api import uintx_weight_only
for dtype in _DTYPE_TO_BIT_WIDTH.keys():
l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
# make sure it runs
uintx_weight_only(dtype)(l)
l = torch.compile(l)
l(torch.randn(1, 128, dtype=torch.bfloat16, device="cuda"))


@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "sub byte dtype requires torch 2.3+")
def test_uintx_model_size(self):
from torchao.quantization.quant_api import uintx_weight_only
from torchao.utils import get_model_size_in_bytes
# scale size = 1/64 * 2 bytes = 1/32 bytes
# zero_point size = 1/64 * 4 bytes = 1/16 bytes
# dtype data size = 1 * bit_width/8 = bit_width/8 bytes
_dtype_to_ratio = {
torch.uint1: (1/8 + 1/16 + 1/32) / 2,
torch.uint2: (2/8 + 1/16 + 1/32) / 2,
torch.uint3: (3/8 + 1/16 + 1/32) / 2,
torch.uint4: (4/8 + 1/16 + 1/32) / 2,
torch.uint5: (5/8 + 1/16 + 1/32) / 2,
torch.uint6: (6/8 + 1/16 + 1/32) / 2,
torch.uint7: (7/8 + 1/16 + 1/32) / 2,
}
for dtype in _DTYPE_TO_BIT_WIDTH.keys():
l = torch.nn.Sequential(
torch.nn.Linear(128, 256, bias=False, dtype=torch.bfloat16, device="cuda")
)
bf16_size = get_model_size_in_bytes(l)
# make sure it runs
uintx_weight_only(dtype)(l[0])
quantized_size = get_model_size_in_bytes(l)
self.assertTrue(bf16_size * _dtype_to_ratio[dtype] == quantized_size)

if __name__ == "__main__":
run_tests()
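
As a sanity check on the size-ratio arithmetic in test_uintx_model_size above, a small sketch (not part of the diff) working through the torch.uint4 case with group_size=64:

    # per bf16 weight element (2 bytes), with group_size = 64:
    #   packed data: 4 bits                       = 4/8  byte
    #   scale:       2 bytes per group of 64      = 1/32 byte
    #   zero_point:  4 bytes (int32) per group    = 1/16 byte
    expected_ratio = (4/8 + 1/16 + 1/32) / 2  # = 0.296875, the torch.uint4 entry
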
44 changes: 23 additions & 21 deletions test/dtypes/test_uintx.py
@@ -6,7 +6,10 @@

from torchao.dtypes.uintx.Uintx import to_uintx
from torchao.quantization.quant_api import quantize_, uintx_weight_only
from torchao.utils import TORCH_VERSION_AFTER_2_5
from torchao.utils import (
TORCH_VERSION_AFTER_2_3,
TORCH_VERSION_AFTER_2_5,
)

from torchao.quantization.quant_primitives import (
MappingType,
@@ -16,7 +19,12 @@
dequantize_affine,
)

bit_widths = (1, 2, 3, 4, 5, 6, 7)
# torch.uintx dtypes are introduced in 2.3
if TORCH_VERSION_AFTER_2_3:
dtypes = (torch.uint1, torch.uint2, torch.uint3, torch.uint4, torch.uint5, torch.uint6, torch.uint7)
else:
dtypes = ()

group_sizes = [32, 64, 128]
devices = ["cpu", "cuda"]
@pytest.fixture(autouse=True)
@@ -36,57 +44,51 @@ def __init__(self, scale, device):
def forward(self, x):
return self.net(x)

@pytest.mark.parametrize("bit_width", bit_widths)
@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("group_size", group_sizes)
@pytest.mark.parametrize("device", devices)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not TORCH_VERSION_AFTER_2_5, reason="only works with fix in the nightly build")
def test_uintx_weight_only_model_quant(bit_width, group_size, device):
def test_uintx_weight_only_model_quant(dtype, group_size, device):
scale = 512
fp16 = Linear16(scale, device)
quantize_(fp16, uintx_weight_only(bit_width, group_size=group_size))
quantize_(fp16, uintx_weight_only(dtype, group_size=group_size))
uintx = torch.compile(fp16, fullgraph=True)
test_input = torch.randn(scale*2, dtype=torch.float16, device=device)
output = uintx.forward(test_input)
assert output != None, "model quantization failed"

@pytest.mark.parametrize("bit_width", bit_widths)
@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("group_size", group_sizes)
@pytest.mark.parametrize("device", devices)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not TORCH_VERSION_AFTER_2_5, reason="only works with fix in the nightly build")
def test_uintx_weight_only_quant(bit_width, group_size, device):
def test_uintx_weight_only_quant(dtype, group_size, device):
input_float = torch.randn((1, 256), dtype=torch.float16, device = device)
mapping_type = MappingType.SYMMETRIC
quant_min = 0
quant_max = 2 ** bit_width - 1
eps = torch.finfo(torch.float32).eps
zero_point_dtype = torch.int32
zero_point_domain = ZeroPointDomain.INT
target_dtype = torch.uint8
block_size = (1, group_size)

scale, zero_point = choose_qparams_affine(
input_float, mapping_type, block_size,
target_dtype, quant_min, quant_max, eps, torch.float32,
zero_point_dtype, True, zero_point_domain
dtype, eps=eps, scale_dtype=torch.float32,
zero_point_dtype=zero_point_dtype, preserve_zero=True, zero_point_domain=zero_point_domain
)

aqt = quantize_affine(
input_float, block_size, scale,
zero_point, target_dtype,
quant_min = quant_min,
quant_max = quant_max,
zero_point_domain = zero_point_domain
zero_point, dtype,
zero_point_domain=zero_point_domain
)
# Note: output will be uint8 tensor for sub byte tensors for now

q = to_uintx(aqt, bit_width, -1)
q = to_uintx(aqt, dtype, -1)
assert q != None, "quantization failed"
deqaunt = dequantize_affine(
q, block_size, scale,
zero_point, target_dtype,
quant_min = quant_min,
quant_max = quant_max,
zero_point_domain = zero_point_domain
zero_point, dtype,
zero_point_domain=zero_point_domain
)
assert deqaunt != None, "deqauntization failed"
2 changes: 2 additions & 0 deletions torchao/dtypes/affine_quantized_tensor.py
@@ -30,6 +30,7 @@

aten = torch.ops.aten


###############################
# Base Layout Tensor Subclass #
###############################
@@ -200,6 +201,7 @@ def from_float(

scale, zero_point = choose_qparams_affine(input_float, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, scale_dtype, zero_point_dtype, preserve_zero, zero_point_domain)
int_data = quantize_affine(input_float, block_size, scale, zero_point, target_dtype, quant_min, quant_max, zero_point_domain)
# Note: output will be uint8 tensor for sub byte tensors for now
int_data = layout_type.post_process(int_data)

layout_tensor_ctr = get_layout_tensor_constructor(type(layout_type))
35 changes: 29 additions & 6 deletions torchao/dtypes/uintx/Uintx.py
@@ -11,10 +11,30 @@
_dispatch__torch_dispatch__,
)
from torchao.dtypes.affine_quantized_tensor import PlainAQTLayout, register_layout_cls

from torchao.utils import TORCH_VERSION_AFTER_2_3

aten = torch.ops.aten

# Note: Uintx does not work for torch 2.3 and below
_DTYPE_TO_BIT_WIDTH = {}
_BIT_WIDTH_TO_DTYPE = {}

if TORCH_VERSION_AFTER_2_3:
_DTYPE_TO_BIT_WIDTH = {
torch.uint1: 1,
torch.uint2: 2,
torch.uint3: 3,
torch.uint4: 4,
torch.uint5: 5,
torch.uint6: 6,
torch.uint7: 7,
}

_BIT_WIDTH_TO_DTYPE = {v: k for k, v in _DTYPE_TO_BIT_WIDTH.items()}
else:
print("uintx feature need torch 2.3+, please upgrade pytorch")


class UintxTensor(torch.Tensor):
"""
Splits int data into packed shards based on bit size
@@ -90,15 +110,18 @@ def get_plain(self):
def apply_transformation(self, fn):
og = self.get_plain()
new = fn(og)
return self.from_uint8(new, self.bit_width, self.pack_dim)
dtype = _BIT_WIDTH_TO_DTYPE[self.bit_width]
return self.from_uint8(new, dtype, self.pack_dim)

# temporary until kernels on packed tensors are created
def apply_fn_to_shards(self, fn):
new_shards = [fn(shard) for shard in self.get_shards()]
return self.__class__(new_shards, self.packed_shape, self.bit_width, self.pack_dim)

@classmethod
def from_uint8(cls, int_data: torch.Tensor, bit_width, pack_dim: int = -1):
def from_uint8(cls, int_data: torch.Tensor, dtype: torch.dtype, pack_dim: int = -1):
assert dtype in _DTYPE_TO_BIT_WIDTH.keys(), "Expected dtype to be one of {_DTYPE_TO_BITWIDTH.keys()}"
bit_width = _DTYPE_TO_BIT_WIDTH[dtype]
shards = pack(int_data, bit_width, dim=pack_dim)
shape = list(int_data.shape)
shape[pack_dim] = shape[pack_dim] * bit_width // 8
@@ -107,7 +130,6 @@ def from_uint8(cls, int_data: torch.Tensor, bit_width, pack_dim: int = -1):

implements = UintxTensor.implements


@implements(aten.detach.default)
def _(func, types, args, kwargs):
return return_and_correct_aliasing(
@@ -137,16 +159,17 @@ def _(func, types, args, kwargs):
return return_and_correct_aliasing(
func, args, kwargs, args[0].apply_transformation(lambda x: (x * args[1]).to(torch.uint8))
)

# quantization api integrations
to_uintx = UintxTensor.from_uint8

@dataclass(frozen=True)
class UintxLayoutType(LayoutType):
bit_width: int
dtype: torch.dtype
pack_dim: int = -1

def post_process(self, input: torch.Tensor) -> torch.Tensor:
return to_uintx(input, self.bit_width, self.pack_dim)
return to_uintx(input, self.dtype, self.pack_dim)

@register_layout_cls(UintxLayoutType)
class UintxAQTLayout(PlainAQTLayout):
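
To illustrate the updated from_uint8/to_uintx signature, which now takes a torch dtype instead of an integer bit width, a minimal sketch (torch 2.3+ assumed; shapes are illustrative):

    import torch
    from torchao.dtypes.uintx.Uintx import to_uintx, _DTYPE_TO_BIT_WIDTH

    x = torch.randint(0, 2**3, (4, 64), dtype=torch.uint8)  # values that fit in 3 bits
    packed = to_uintx(x, torch.uint3, -1)   # previously: to_uintx(x, 3, -1)
    assert packed.bit_width == _DTYPE_TO_BIT_WIDTH[torch.uint3]
    unpacked = packed.get_plain()           # back to a plain uint8 tensor
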
24 changes: 12 additions & 12 deletions torchao/quantization/quant_api.py
@@ -483,36 +483,36 @@ def int8_dynamic_activation_int8_semi_sparse_weight():
return int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayoutType())


def uintx_weight_only(bit_width, group_size=64, pack_dim=-1):
def uintx_weight_only(dtype, group_size=64, pack_dim=-1):
"""
Applies uintx weight-only asymmetric per-group quantization to linear layers, using uintx quantization where
x is the number of bits specified by the `bit_width` argument
x is the number of bits specified by `dtype`
Args:
`dtype`: torch.uint1 to torch.uint7 sub byte dtypes
`group_size`: parameter for quantization, controls the granularity of quantization, smaller
size is more fine grained, defaults to 64
`pack_dim`: the dimension we use for packing, defaults to -1
"""
from torchao.quantization.quant_primitives import (
MappingType,
ZeroPointDomain,
choose_qparams_affine,
quantize_affine,
dequantize_affine,
)
from torchao.dtypes.uintx.Uintx import UintxLayoutType
from torchao.dtypes import to_affine_quantized
from torchao.quantization.quant_api import _get_linear_subclass_inserter
def apply_uintx_weight_only_quant(weight):

layout_type = UintxLayoutType(bit_width=bit_width, pack_dim=pack_dim)
def apply_uintx_weight_only_quant(weight):
layout_type = UintxLayoutType(dtype=dtype, pack_dim=pack_dim)
mapping_type = MappingType.ASYMMETRIC
block_size = (1, group_size)
quant_min = 0
quant_max = 2**bit_width - 1
eps = torch.finfo(torch.float32).eps
zero_point_dtype = torch.int32
zero_point_domain = ZeroPointDomain.INT

return to_affine_quantized(
weight, mapping_type, block_size, torch.uint8,
quant_min = quant_min, quant_max = quant_max,
eps = eps, zero_point_dtype=zero_point_dtype,
weight, mapping_type, block_size, dtype,
eps=eps, zero_point_dtype=zero_point_dtype,
zero_point_domain=zero_point_domain,
layout_type=layout_type,
)
17 changes: 12 additions & 5 deletions torchao/quantization/quant_primitives.py
@@ -65,17 +65,21 @@ class ZeroPointDomain(Enum):
torch.int16: (-(2**15), 2**15 - 1),
torch.int32: (-(2**31), 2**31 - 1),
}
_SUB_BYTE_DTYPE_BOUNDS: Dict[torch.dtype, Tuple[int, int]] = {}

if TORCH_VERSION_AFTER_2_3:
_DTYPE_TO_QVALUE_BOUNDS.update({
_SUB_BYTE_DTYPE_BOUNDS = {
torch.uint1: (0, 2**1-1),
torch.uint2: (0, 2**2-1),
torch.uint3: (0, 2**3-1),
torch.uint4: (0, 2**4-1),
torch.uint5: (0, 2**5-1),
torch.uint6: (0, 2**6-1),
torch.uint7: (0, 2**7-1),
})
}
_DTYPE_TO_QVALUE_BOUNDS.update(
_SUB_BYTE_DTYPE_BOUNDS
)


quant_lib = torch.library.Library("quant", "FRAGMENT")
@@ -213,6 +217,10 @@ def _quantize_affine(
"""op definition that has compatible signatures with custom op library
"""
quant_min, quant_max = _get_and_check_qmin_qmax(output_dtype, quant_min, quant_max)
# workaround for uintx dtypes, since we don't have native Uintx dtype connected with
# torch.uintx dtypes yet
if output_dtype in _SUB_BYTE_DTYPE_BOUNDS:
output_dtype = torch.uint8
return _quantize_affine_no_dtype_cast(
input,
block_size,
@@ -325,10 +333,9 @@ def _dequantize_affine(
) -> torch.Tensor:
"""op definition that has compatible signatures with custom op library
"""

# TODO: validations
# TODO: validate scale/zero_point dimensions are compatible with block_size
assert input.dtype == input_dtype, f"Expected: {input_dtype}, got: {input.dtype}"
if input_dtype not in _SUB_BYTE_DTYPE_BOUNDS:
assert input.dtype == input_dtype, f"Expected: {input_dtype}, got: {input.dtype}"
assert output_dtype in [torch.float32, torch.float16, torch.bfloat16], f"Unsupported output dtype: {output_dtype}"
quant_min, quant_max = _get_and_check_qmin_qmax(input_dtype, quant_min, quant_max)
return _dequantize_affine_no_dtype_check(
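
Putting the quant_primitives changes together, a hedged sketch of quantizing to a sub-byte dtype (mirroring the updated test_uintx.py; torch.uint3 is chosen arbitrarily and torch 2.3+ is assumed):

    import torch
    from torchao.quantization.quant_primitives import (
        MappingType, ZeroPointDomain, choose_qparams_affine, quantize_affine,
    )

    x = torch.randn(1, 256, dtype=torch.float16)
    block_size = (1, 64)
    scale, zero_point = choose_qparams_affine(
        x, MappingType.SYMMETRIC, block_size, torch.uint3,
        eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float32,
        zero_point_dtype=torch.int32, preserve_zero=True,
        zero_point_domain=ZeroPointDomain.INT,
    )
    q = quantize_affine(x, block_size, scale, zero_point, torch.uint3,
                        zero_point_domain=ZeroPointDomain.INT)
    # per the workaround above, q is stored as a torch.uint8 tensor,
    # with values bounded by _SUB_BYTE_DTYPE_BOUNDS[torch.uint3] == (0, 7)
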
