Fix dimension issues for int4 weight only quant path (#330)
Summary:
Currently the accepted dimensions of _quantized_linear are not clearly defined; this PR fixes the issue.

Currently the "tensor_core_tiled" layout tensor does not repack its data in the view operation, which is incorrect. This PR removes view support (which is not needed right now), restricts the use case to the transpose op, and records the transpose status of the tensor instead of repacking, for performance.
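
A minimal sketch of the approach (hypothetical, simplified names -- the actual class in this PR is TensorCoreTiledAQTLayout in torchao/dtypes/aqt.py): transpose just flips a recorded flag and relies on the outer shape changing, so the packed int4 data is never repacked:

class PackedInt4Weight:
    # Sketch only: a packed int4 weight that records transpose state instead of repacking.
    def __init__(self, packed_weight, scale_and_zero, transposed=False):
        self.packed_weight = packed_weight    # stays in tensor-core tiled order
        self.scale_and_zero = scale_and_zero
        self.transposed = transposed          # only this flag changes on transpose

    def t(self):
        # O(1) transpose: flip the flag; the consumer (e.g. the int4 linear path)
        # checks it before deciding how to interpret the dimensions.
        return PackedInt4Weight(self.packed_weight, self.scale_and_zero, not self.transposed)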

Test Plan:
python test/quantization/test_quant_api.py
python test/integration/test_integration.py

Reviewers:

Subscribers:

Tasks:

Tags:
jerryzh168 authored Jun 10, 2024
1 parent c38fc4a commit 79f2c7f
Showing 3 changed files with 74 additions and 45 deletions.
29 changes: 29 additions & 0 deletions test/dtypes/test_aq.py
@@ -0,0 +1,29 @@
from torch.testing._internal.common_utils import (
TestCase,
run_tests,
)
from torchao.quantization.quant_api import get_apply_int4wo_quant
import torch
import unittest


class TestAQ(TestCase):
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_tensor_core_layout_transpose(self):
t = torch.rand(128, 256, dtype=torch.bfloat16, device="cuda")
shape = t.shape
apply_int4wo_quant = get_apply_int4wo_quant(groupsize=32)
aqt = apply_int4wo_quant(t)
aqt_shape = aqt.shape
self.assertEqual(aqt_shape, shape)

# transpose shape test
for _ in range(10):
t = t.t()
aqt = aqt.t()
shape = t.shape
aqt_shape = aqt.shape
self.assertEqual(aqt_shape, shape)

if __name__ == "__main__":
run_tests()
6 changes: 3 additions & 3 deletions test/integration/test_integration.py
@@ -217,9 +217,9 @@ def _test_smooth_linear_impl(self, x_shape, lin_shape, device):
# rtol=0.00001), \
# 'y_smooth_fq_only not close to y_dynamic_q'

self.assertTrue(sqnr_smooth_fq.item() >= 40.0)
self.assertTrue(sqnr_dynamic_q.item() >= 40.0)
self.assertTrue(sqnr_fq.item() >= 40.0)
self.assertTrue(sqnr_smooth_fq.item() >= 40.0, f"got: {sqnr_smooth_fq.item()}")
self.assertTrue(sqnr_dynamic_q.item() >= 40.0, f"got: {sqnr_dynamic_q.item()}")
self.assertTrue(sqnr_fq.item() >= 40.0, f"got: {sqnr_fq.item()}")

# Restore backend
torch.backends.quantized.engine = orig_backend
84 changes: 42 additions & 42 deletions torchao/dtypes/aqt.py
@@ -188,11 +188,6 @@ def _apply_fn_to_data(self, fn):
fn(self.zero_point),
)

def _change_shape(self, shape):
return self.__class__(
self.int_data.view(shape), self.scale, self.zero_point
)

@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
kwargs = {} if kwargs is None else kwargs
@@ -202,9 +197,11 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
)

if func is aten.view.default:
assert len(args) == 2
new = args[0]._change_shape(args[1])
if func is aten.t.default:
tensor = args[0]
new = tensor.__class__(
tensor.int_data.view(tensor.shape[::-1]), tensor.scale, tensor.zero_point
)
return return_and_correct_aliasing(func, args, kwargs, new)

raise NotImplementedError(
@@ -241,6 +238,7 @@ def __new__(
cls,
packed_weight: torch.Tensor,
scale_and_zero: torch.Tensor,
transposed: bool,
):
kwargs = {}
kwargs["device"] = packed_weight.device
@@ -256,27 +254,30 @@ def __init__(
self,
packed_weight: torch.Tensor,
scale_and_zero: torch.Tensor,
transposed: bool,
):
self.packed_weight = packed_weight
self.scale_and_zero = scale_and_zero
self.transposed = transposed

def __tensor_flatten__(self):
return ["packed_weight", "scale_and_zero"], []
return ["packed_weight", "scale_and_zero"], [self.transposed]

@classmethod
def __tensor_unflatten__(
cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
):
packed_weight, scale_and_zero = tensor_data_dict["packed_weight"], tensor_data_dict["scale_and_zero"]
return cls(packed_weight, scale_and_zero)
transposed, = tensor_attributes
return cls(packed_weight, scale_and_zero, transposed)

@classmethod
def from_plain(cls, int_data, scale, zero_point, inner_k_tiles=8):
packed_weight = torch.ops.aten._convert_weight_to_int4pack(int_data.to(torch.int32), inner_k_tiles)
scale = scale.reshape(int_data.shape[0], -1)
zero_point = zero_point.reshape(int_data.shape[0], -1)
scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point)
return cls(packed_weight, scale_and_zero)
return cls(packed_weight, scale_and_zero, False)

def to(self, *args, **kwargs):
kwargs = self._get_to_kwargs(*args, **kwargs)
@@ -285,22 +286,15 @@ def to(self, *args, **kwargs):
raise ValueError(f"TensorCoreTiledAQTLayout is only available for cuda device")
return self.__class__(
self.packed_weight.to(kwargs["device"]),
self.scale_and_zero.to(kwargs["device"])
self.scale_and_zero.to(kwargs["device"]),
self.transposed
)

def _apply_fn_to_data(self, fn):
self.packed_weight = fn(self.packed_weight)
self.scale_and_zero = fn(self.scale_and_zero)
return self

def _change_shape(self, shape):
# int_data, scale, zero = self.get_plain()
# int_data = int_data.view(shape)
# changed = self.from_plain(int_data, scale, zero)
# return changed
# TODO: changing shape is no-op for int4 packed weight right now
return self

@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
kwargs = {} if kwargs is None else kwargs
@@ -310,10 +304,12 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
)

if func is aten.view.default:
assert len(args) == 2
new = args[0]._change_shape(args[1])
return return_and_correct_aliasing(func, args, kwargs, new)
if func is aten.t.default:
"""we don't need to repack the weight and just rely on external
shape being changed and record the status of transpose/no-transpose
"""
args[0].transposed = not args[0].transposed
return return_and_correct_aliasing(func, args, kwargs, args[0])

raise NotImplementedError(
f"TensorCoreTiledAQTLayout dispatch: attempting to run {func}, this is not supported"
@@ -329,8 +325,7 @@ def get_plain(self):
)
cur_shape = self.shape
assert len(cur_shape) == 4
# TODO: expose the arg
inner_k_tiles = self.cur_shape[-1] * 2
inner_k_tiles = cur_shape[-1] * 2
original_shape = (cur_shape[0] * 8, cur_shape[1] * (inner_k_tiles * 16))
eye_shape = original_shape[1]
block_size = (1, 32)
@@ -557,11 +552,6 @@ def _apply_fn_to_data(self, fn):
strides=self.stride(),
)

def _change_shape(self, shape, block_size):
return self.__class__(
self.layout_tensor.view(shape), block_size, shape, self.quant_min, self.quant_max, self.zero_point_domain, dtype=self.dtype, strides=self.stride()
)

@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
# Note: we only added cpu path here for 8da4w, this is for executorch, in the future
@@ -583,13 +573,15 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
f"AffineQuantizedTensor dispatch: attempting to run {func}, this is not supported"
)

def _quantized_linear_op(input_tensor, weight_qtensor, bias, _from_flinear=True):
# TODO: the old tensor subclass can use the single implementation for both F.linear dispatch
# and aten.addmm/aten.mm dispatch because `_change_shape` is not implemented correctly (it got ignored
# for the int_data), this makes the dimension for weight_qtensor non-deterministic, we need to fix
# the issue and make sure we have a clear accepted dimension for `_quantized_linear_op`
# after that we can remove the _from_flinear flag
def _quantized_linear_op(input_tensor, weight_qtensor, bias):
"""
Quantized version of F.linear operator
Args:
input_tensor: dimension is (batch_size, in_features)
weight_tensor: dimension is (out_features, in_features)
bias: dimension is (out_features,)
"""
is_cuda = weight_qtensor.is_cuda
is_cpu = weight_qtensor.device == torch.device("cpu")
if isinstance(weight_qtensor, AffineQuantizedTensor):
@@ -649,9 +641,11 @@ def _quantized_linear_op(input_tensor, weight_qtensor, bias, _from_flinear=True)
weight_qtensor.zero_point_domain == ZeroPointDomain.FLOAT and
weight_qtensor.layout == "tensor_core_tiled"
):
if not _from_flinear:
weight_qtensor = weight_qtensor.t()
assert weight_qtensor.block_size[0] == 1, f"Requires groupwise quantization, got block_size: {weight_qtensor.block_size}"
assert input_tensor.shape[-1] == weight_qtensor.shape[1], (
f"need input_tensor shape: {input_tensor.shape} final"
f"dim to match weight_tensor shape: {weight_qtensor.shape} second dim "
)

# TODO: check groupsize quantization
# avoid circular dep, TODO: move this to a common util.py
@@ -747,7 +741,8 @@ def aten_mm(func, *args, **kwargs):
args[0],
)
try:
return _quantized_linear_op(input_tensor, weight_tensor, bias, _from_flinear=False)
weight_tensor = weight_tensor.t()
return _quantized_linear_op(input_tensor, weight_tensor, bias)
except:
if isinstance(input_tensor, AffineQuantizedTensor):
input_tensor = input_tensor.dequantize()
Expand All @@ -761,7 +756,8 @@ def aten_mm(func, *args, **kwargs):
None
)
try:
return _quantized_linear_op(input_tensor, weight_tensor, bias, _from_flinear=False)
weight_tensor = weight_tensor.t()
return _quantized_linear_op(input_tensor, weight_tensor, bias)
except:
if isinstance(input_tensor, AffineQuantizedTensor):
input_tensor = input_tensor.dequantize()
@@ -797,7 +793,11 @@ def t(func, *args, **kwargs):
block_size = args[0].block_size
assert len(block_size) == 2
transposed_block_size = (block_size[1], block_size[0])
new = args[0]._change_shape(args[0].shape[::-1], transposed_block_size)
tensor = args[0]
shape = tensor.shape[::-1]
new = tensor.__class__(
tensor.layout_tensor.t(), transposed_block_size, shape, tensor.quant_min, tensor.quant_max, tensor.zero_point_domain, dtype=tensor.dtype, strides=tensor.stride()
)
return return_and_correct_aliasing(func, args, kwargs, new)

to_aq = AffineQuantizedTensor.from_float
