From cda787ccc79664ccdc0ec10cf6ba06e87b9a079c Mon Sep 17 00:00:00 2001
From: Jerry Zhang <jerryzh168@gmail.com>
Date: Wed, 15 May 2024 17:45:51 -0700
Subject: [PATCH] Remove input_quant_func from AffineQuantizedTensor subclass
 (#243)

* Remove input_quant_func from AffineQuantizedTensor subclass

Summary:
Currently we have an input_quant_func in AffineQuantizedTensor, which is a bit
convoluted; we want to use a separate LinearActQuantizedTensor subclass for
activation quantization (dynamic quantization) instead.
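
A minimal sketch of the intended wrapping order, mirroring the updated test
below (the block size and other quantization settings are illustrative):

    import torch
    from torchao.quantization.subclass import (
        AffineQuantizedTensor,
        LinearActQuantizedTensor,
    )
    from torchao.quantization.quant_primitives import MappingType

    def dynamic_quant(linear, input_quant_func):
        # note: order is important -- quantize the weight first, then wrap it
        # together with the activation (input) quantization function
        linear.weight = torch.nn.Parameter(
            AffineQuantizedTensor.from_float(
                linear.weight,
                MappingType.SYMMETRIC,
                (1, linear.weight.shape[1]),  # per-channel block size
                torch.int8,
            ),
            requires_grad=False,
        )
        linear.weight = torch.nn.Parameter(
            LinearActQuantizedTensor.from_float(linear.weight, input_quant_func),
            requires_grad=False,
        )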

Test Plan:
python test/quantization/test_quant_api.py -k test_quantized_tensor_subclass_8da4w

Reviewers:

Subscribers:

Tasks:

Tags:

* Add dispatch for dynamic quantization in `AffineQuantizedTensor`

Summary:
This PR adds dispatch for int8 activation / int8 weight dynamic quantization that ends up calling the `int_scaled_matmul` kernel
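
Roughly, the new dispatch path computes the output as in the sketch below
(the function name and parameters are illustrative; the real code lives in
AffineQuantizedTensor.__torch_function__):

    import torch
    from torchao.kernel.intmm import int_scaled_matmul

    def int8_dynamic_linear(x_int8, x_scales, w_int8, w_scales, out_dtype):
        # x_int8 / x_scales: per-token quantized activation and its scales
        # w_int8 / w_scales: per-channel quantized weight and its scales
        tmp = x_int8.reshape(-1, x_int8.shape[-1])
        # int8 matmul; the activation scales are applied by int_scaled_matmul
        y = int_scaled_matmul(tmp, w_int8.contiguous().t(), x_scales.reshape(-1, 1))
        # apply the weight scales, restore leading dims, downcast only at the end
        y = (y * w_scales).reshape(*x_int8.shape[:-1], y.shape[-1])
        return y.to(out_dtype)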

Test Plan:
python test/quantization/test_quant_api.py -k test_quantized_tensor_subclass_int8_dyn_quant

Reviewers:

Subscribers:

Tasks:

Tags:

* Fix test
---
 test/quantization/test_quant_api.py |  86 ++++++++--
 torchao/quantization/subclass.py    | 257 +++++++++++++++++++++++-----
 2 files changed, 286 insertions(+), 57 deletions(-)

diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py
index cea659e61d..fcab07c913 100644
--- a/test/quantization/test_quant_api.py
+++ b/test/quantization/test_quant_api.py
@@ -395,7 +395,10 @@ def test_eval_wrapper(self):
     # TODO: move to a separate test file
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
     def test_quantized_tensor_subclass_8da4w(self):
-        from torchao.quantization.subclass import AffineQuantizedTensor
+        from torchao.quantization.subclass import (
+            AffineQuantizedTensor,
+            LinearActQuantizedTensor,
+        )
         from torchao.quantization.quant_primitives import MappingType
         import copy
 
@@ -409,6 +412,7 @@ def test_quantized_tensor_subclass_8da4w(self):
         quant_max = 7
 
         # TODO: make a general helper function?
+        # input settings
         def get_per_token_block_size(x):
             block_size = []
             for i in range(len(x.shape)-1):
@@ -421,13 +425,18 @@ def get_per_token_block_size(x):
         input_target_dtype = torch.int8
         input_quant_func = lambda x: AffineQuantizedTensor.from_float(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype)
 
+        def dynamic_quant(linear):
+            # note: order is important
+            linear.weight = torch.nn.Parameter(AffineQuantizedTensor.from_float(linear.weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps), requires_grad=False)
+            linear.weight = torch.nn.Parameter(LinearActQuantizedTensor.from_float(linear.weight, input_quant_func), requires_grad=False)
+
         m = ToyLinearModel().eval()
         m_copy = copy.deepcopy(m)
         example_inputs = m.example_inputs()
-        m.linear1.weight = torch.nn.Parameter(AffineQuantizedTensor.from_float(m.linear1.weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, input_quant_func=input_quant_func), requires_grad=False)
-        m.linear2.weight = torch.nn.Parameter(AffineQuantizedTensor.from_float(m.linear2.weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, input_quant_func=input_quant_func), requires_grad=False)
-        assert isinstance(m.linear1.weight, AffineQuantizedTensor)
-        assert isinstance(m.linear2.weight, AffineQuantizedTensor)
+        dynamic_quant(m.linear1)
+        dynamic_quant(m.linear2)
+        assert isinstance(m.linear1.weight, LinearActQuantizedTensor)
+        assert isinstance(m.linear2.weight, LinearActQuantizedTensor)
 
         # reference
         from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
@@ -461,9 +470,6 @@ def test_quantized_tensor_subclass_int4(self):
         preserve_zero = False
         zero_point_dtype = torch.bfloat16
 
-        # weight only quantization
-        input_quant_func = None
-
         # use 1024 so that we don't need padding
         m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")
         m_copy = copy.deepcopy(m)
@@ -475,7 +481,6 @@ def to_quantized(weight):
                 zero_point_dtype=zero_point_dtype,
                 preserve_zero=preserve_zero,
                 zero_point_domain=ZeroPointDomain.FLOAT,
-                input_quant_func=input_quant_func,
             )
 
         m.linear1.weight = torch.nn.Parameter(to_quantized(m.linear1.weight), requires_grad=False)
@@ -506,16 +511,13 @@ def test_quantized_tensor_subclass_int8(self):
         eps = torch.finfo(torch.float32).eps
         zero_point_dtype = torch.int64
 
-        # weight only quantization
-        input_quant_func = None
-
         m = ToyLinearModel().eval().to(torch.bfloat16)
         m_copy = copy.deepcopy(m)
         example_inputs = tuple(map(lambda x: x.to(torch.bfloat16), m.example_inputs()))
 
         def to_quantized(weight):
             block_size = (1, weight.shape[1])
-            return AffineQuantizedTensor.from_float(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype, input_quant_func=input_quant_func)
+            return AffineQuantizedTensor.from_float(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)
 
         m.linear1.weight = torch.nn.Parameter(to_quantized(m.linear1.weight), requires_grad=False)
         m.linear2.weight = torch.nn.Parameter(to_quantized(m.linear2.weight), requires_grad=False)
@@ -532,5 +534,63 @@ def to_quantized(weight):
         torch.testing.assert_close(res, ref, rtol=0.00001, atol=1e-2)
 
 
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    def test_quantized_tensor_subclass_int8_dyn_quant(self):
+        from torchao.quantization.subclass import AffineQuantizedTensor
+        from torchao.quantization.subclass import LinearActQuantizedTensor
+        from torchao.quantization.quant_primitives import MappingType
+        from torchao.quantization.quant_primitives import ZeroPointDomain
+        import copy
+
+        # weight settings
+        mapping_type = MappingType.SYMMETRIC
+        def get_weight_block_size(x):
+            return (1, x.shape[1])
+        target_dtype = torch.int8
+        eps = torch.finfo(torch.float32).eps
+        zero_point_dtype = torch.int64
+
+        # input settings
+        def get_per_token_block_size(x):
+            block_size = list(x.shape)
+            for i in range(len(block_size)-1):
+                block_size[i] = 1
+            return block_size
+
+        input_mapping_type = MappingType.SYMMETRIC
+        input_target_dtype = torch.int8
+        input_eps = 1e-5
+        input_quant_min = -127
+        input_quant_max = 127
+        input_quant_func = lambda x: AffineQuantizedTensor.from_float(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype, eps=input_eps, quant_min=input_quant_min, quant_max=input_quant_max, scale_dtype=torch.float)
+
+        # use 1024 so that we don't need padding
+        m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")
+        m_copy = copy.deepcopy(m)
+        example_inputs = tuple(map(lambda x: x.to(torch.bfloat16).to("cuda"), m.example_inputs()))
+
+        def dynamic_quant(linear):
+            # note: order is important
+            linear.weight = torch.nn.Parameter(AffineQuantizedTensor.from_float(linear.weight, mapping_type, get_weight_block_size(linear.weight), target_dtype, eps=eps, zero_point_dtype=zero_point_dtype), requires_grad=False)
+            linear.weight = torch.nn.Parameter(LinearActQuantizedTensor.from_float(linear.weight, input_quant_func), requires_grad=False)
+
+        dynamic_quant(m.linear1)
+        dynamic_quant(m.linear2)
+        assert isinstance(m.linear1.weight, LinearActQuantizedTensor)
+        assert isinstance(m.linear2.weight, LinearActQuantizedTensor)
+        assert isinstance(m.linear1.weight.original_weight_tensor, AffineQuantizedTensor)
+        assert isinstance(m.linear2.weight.original_weight_tensor, AffineQuantizedTensor)
+
+        # reference
+        from torchao.quantization.quant_api import change_linear_weights_to_int8_dqtensors
+        change_linear_weights_to_int8_dqtensors(m_copy)
+
+        res = m(*example_inputs)
+        ref = m_copy(*example_inputs)
+
+        self.assertTrue(torch.equal(res, ref))
+
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/torchao/quantization/subclass.py b/torchao/quantization/subclass.py
index 607cb77766..bc40ffeaff 100644
--- a/torchao/quantization/subclass.py
+++ b/torchao/quantization/subclass.py
@@ -21,7 +21,9 @@
     quantize_affine,
     dequantize_affine,
     ZeroPointDomain,
+    MappingType,
 )
+from torchao.kernel.intmm import int_scaled_matmul
 from .utils import find_multiple
 from typing import Tuple, Optional, Callable
 
@@ -36,6 +38,30 @@
 
 aten = torch.ops.aten
 
+def _aqt_is_int8(aqt):
+    """Check if an AffineQuantizedTensor is an int8 quantized Tensor"""
+    return (
+        aqt.int_data.dtype == torch.int8 and
+        (aqt.quant_min is None or aqt.quant_min == -128) and
+        (aqt.quant_max is None or aqt.quant_max == 127)
+    )
+
+def _aqt_is_int8_reduced_range(aqt):
+    """Check if an AffineQuantizedTensor is int8 quantized with reduced range (quant_min == -127)"""
+    return (
+        aqt.int_data.dtype == torch.int8 and
+        aqt.quant_min == -127 and
+        (aqt.quant_max is None or aqt.quant_max == 127)
+    )
+
+def _aqt_is_uint4(aqt):
+    """Check if an AffineQuantizedTensor is a uint4 quantized Tensor"""
+    # TODO: use torch.uint4
+    return (
+        aqt.int_data.dtype == torch.int32 and
+        (aqt.quant_min is None or aqt.quant_min == 0) and
+        (aqt.quant_max is None or aqt.quant_max == 15)
+    )
+
 
 class QuantizedLinearWeightBase(torch.Tensor):
     """
@@ -643,7 +669,6 @@ def __new__(
         quant_min: Optional[int] = None,
         quant_max: Optional[int] = None,
         zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
-        input_quant_func: Optional[Callable] = None,
         dtype=None,
         # TODO: remove args and kwargs
         *args,
@@ -670,7 +695,6 @@ def __init__(
         quant_min: Optional[int] = None,
         quant_max: Optional[int] = None,
         zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
-        input_quant_func: Optional[Callable] = None,
         dtype=None,
         *args,
         **kwargs
@@ -682,12 +706,11 @@ def __init__(
         self.quant_min = quant_min
         self.quant_max = quant_max
         self.zero_point_domain = zero_point_domain
-        self.input_quant_func = input_quant_func
 
     def __repr__(self):
         return (
             f"{self.__class__.__name__}(data={self.dequantize()}, shape={self.shape}, "
-            f"device={self.device}, dtype={self.dtype}, input_quant_func={self.input_quant_func}, requires_grad={self.requires_grad})"
+            f"device={self.device}, dtype={self.dtype}, requires_grad={self.requires_grad})"
         )
 
     def dequantize(self, output_dtype=None):
@@ -696,14 +719,14 @@ def dequantize(self, output_dtype=None):
         return dequantize_affine(self.int_data, self.block_size, self.scale, self.zero_point, self.int_data.dtype, self.quant_min, self.quant_max, self.zero_point_domain, output_dtype=output_dtype)
 
     def __tensor_flatten__(self):
-        return ["int_data", "scales", "zero_point"], [self.block_size, self.shape, self.quant_min, self.quant_max, self.zero_point_domain, self.input_quant_func, self.dtype]
+        return ["int_data", "scales", "zero_point"], [self.block_size, self.shape, self.quant_min, self.quant_max, self.zero_point_domain, self.dtype]
 
     @classmethod
     def __tensor_unflatten__(
         cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
     ):
         int_data, scale, zero_point = tensor_data_dict["int_data"], tensor_data_dict["scale"], tensor_data_dict["zero_point"]
-        block_size, shape, quant_min, quant_max, zero_point_domain, input_quant_func, dtype = tensor_attributes
+        block_size, shape, quant_min, quant_max, zero_point_domain, dtype = tensor_attributes
         return cls(
             int_data,
             scale,
@@ -713,7 +736,6 @@ def __tensor_unflatten__(
             quant_min,
             quant_max,
             zero_point_domain,
-            input_quant_func=input_quant_func,
             dtype=dtype,
             strides=outer_stride,
         )
@@ -730,7 +752,6 @@ def from_float(
         eps = None,
         scale_dtype = None,
         zero_point_dtype = None,
-        input_quant_func = None,
         preserve_zero = True,
         zero_point_domain = ZeroPointDomain.INT,
     ):
@@ -745,7 +766,6 @@ def from_float(
             quant_min,
             quant_max,
             zero_point_domain,
-            input_quant_func=input_quant_func,
             dtype=input_float.dtype
         )
 
@@ -759,27 +779,63 @@ def __torch_function__(cls, func, types, args=(), kwargs=None):
                 args[1],
                 args[2] if len(args) > 2 else None,
             )
-            if weight_qtensor.input_quant_func is None:
-                is_cuda = args[0].is_cuda
-                is_cpu = args[0].device == torch.device("cpu")
-                # weight only quantization
-                is_int8 = (
-                    weight_qtensor.int_data.dtype == torch.int8 and
-                    weight_qtensor.quant_min is None or weight_qtensor.quant_min == -128 and
-                    weight_qtensor.quant_max is None or weight_qtensor.quant_max == 127
-                )
-                is_uint4 = (
-                    weight_qtensor.int_data.dtype == torch.int32 and
-                    weight_qtensor.quant_min == 0 and
-                    weight_qtensor.quant_max == 15
-                )
+            is_cuda = weight_qtensor.is_cuda
+            is_cpu = weight_qtensor.device == torch.device("cpu")
+            if isinstance(weight_qtensor, AffineQuantizedTensor):
+                weight_is_int8 = _aqt_is_int8(weight_qtensor)
+                weight_is_uint4 = _aqt_is_uint4(weight_qtensor)
+
+                if isinstance(input_tensor, AffineQuantizedTensor):
+                    # if input tensor is quantized, either dispatch to the int8 mm kernel
+                    # or just dequantize the input tensor
+                    input_is_int8 = _aqt_is_int8_reduced_range(input_tensor)
+                    input_tensor_dtype_is_expected = input_tensor.dtype in [
+                        torch.float,
+                        torch.bfloat16
+                    ]
+                    if (
+                        is_cuda and
+                        input_is_int8 and
+                        input_tensor_dtype_is_expected
+                    ):
+                        #
+                        # 1. do the matrix form of dot(X_i, W_j) on the int8 data
+                        #
+                        # 2. rescale the output
+                        #
+                        # in cases with large matrices, the int32 accumulator can grow
+                        # large enough that multiplying it by a float16 scale overflows
+                        # to inf (even if multiplying by the other scale first would
+                        # bring it back within range), so the activation scales are
+                        # applied inside int_scaled_matmul and the weight scales after
+
+                        x_vals_int8 = input_tensor.int_data
+                        x_scales = input_tensor.scale
+                        w_vals_int8_t = weight_qtensor.int_data.contiguous().t()
+                        w_scales = weight_qtensor.scale
+                        tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1])
+                        y_dot_scaled = int_scaled_matmul(tmp, w_vals_int8_t, x_scales.reshape(-1, 1))
+
+                        y = (y_dot_scaled * w_scales).reshape(
+                            *x_vals_int8.shape[:-1], y_dot_scaled.shape[-1]
+                        )
+
+                        # can downcast only at the very end
+                        output_dtype = input_tensor.dtype
+                        y = y.to(output_dtype)
+                        if bias is not None:
+                            y += bias
+                        return y
+                    else:
+                        input_tensor = input_tensor.dequantize()
 
+                # weight only quantization
                 # TODO: enable cpu and mps path as well
                 # TODO: make sure weight dimension matches the expectation of the int4mm kernel
                 # TODO: move this to TinygemmAffineQuantizedTensor
                 if (
                     is_cuda and
-                    is_uint4 and
+                    weight_is_uint4 and
                     weight_qtensor.dtype == torch.bfloat16 and
                     len(weight_qtensor.shape) == 2 and
                     weight_qtensor.block_size[0] == 1 and
@@ -796,7 +852,7 @@ def __torch_function__(cls, func, types, args=(), kwargs=None):
                     return torch.ops.aten._weight_int4pack_mm(input_tensor.contiguous(), packed_weight, groupsize, scales_and_zeros)
                 elif (
                     is_cpu and
-                    is_int8 and
+                    weight_is_int8 and
                     len(weight_qtensor.shape) == 2 and
                     len(weight_qtensor.block_size) == 2 and
                     weight_qtensor.block_size[0] == 1 and
@@ -805,18 +861,16 @@ def __torch_function__(cls, func, types, args=(), kwargs=None):
                     # TODO: enable mps path as well
                     # per channel int8 weight only quantizated mm
                     return torch.ops.aten._weight_int8pack_mm(input_tensor.contiguous(), weight_qtensor.int_data, weight_qtensor.scale)
+                else:
+                    weight_tensor = weight_qtensor.dequantize()
+                    return torch.nn.functional.linear(input_tensor, weight_tensor, bias)
             else:
-                # dynamic quantization
-                input_tensor = weight_qtensor.input_quant_func(input_tensor)
-                input_tensor = input_tensor.dequantize()
-            weight_tensor = weight_qtensor.dequantize()
-            return torch.nn.functional.linear(input_tensor, weight_tensor, bias)
+                if isinstance(input_tensor, AffineQuantizedTensor):
+                    input_tensor = input_tensor.dequantize()
+                return torch.nn.functional.linear(input_tensor, weight_qtensor, bias)
 
-        try:
-            with torch._C.DisableTorchFunctionSubclass():
-                return func(*args, **kwargs)
-        except:
-            print(f"ERR: subclass doesn't implement {func}")
+        with torch._C.DisableTorchFunctionSubclass():
+            return func(*args, **kwargs)
 
 
     def _get_to_kwargs(self, *args, **kwargs):
@@ -844,7 +898,6 @@ def to(self, *args, **kwargs):
             self.quant_min,
             self.quant_max,
             self.zero_point_domain,
-            self.input_quant_func,
             **kwargs,
         )
 
@@ -858,7 +911,6 @@ def _apply_fn_to_data(self, fn):
             self.quant_min,
             self.quant_max,
             self.zero_point_domain,
-            self.input_quant_func,
             dtype=self.dtype,
         )
 
@@ -900,16 +952,10 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
                     args[1],
                     None if len(args) == 2 else args[2],
                 )
-            if weight_qtensor.input_quant_func is not None:
-                # dynamic quantization
-                input_tensor = weight_qtensor.input_quant_func(input_tensor)
-                input_tensor = input_tensor.dequantize()
             weight_tensor = weight_qtensor.dequantize()
             return func(input_tensor, weight_tensor, bias)
 
-        if (func is aten.detach.default or
-            func is aten.clone.default or
-            func is aten._to_copy.default):
+        if func is aten.detach.default:
             return return_and_correct_aliasing(
                 func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
             )
@@ -933,3 +979,126 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
                 kwargs,
                 args[0].to(*args[1:], **kwargs)._apply_fn_to_data(torch.clone),
             )
+
+        raise NotImplementedError(
+            f"AffineQuantizedTensor dispatch: attempting to run {func}, this is not supported"
+        )
+
+
+class LinearActQuantizedTensor(torch.Tensor):
+    """
+    Applies activation quantization for linear operator
+    """
+    def __new__(
+        cls,
+        original_weight_tensor: torch.Tensor,
+        input_quant_func: Callable,
+    ):
+        kwargs = {}
+        dtype = original_weight_tensor.dtype
+        kwargs["dtype"] = dtype
+        kwargs["requires_grad"] = False
+        shape = original_weight_tensor.shape
+        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
+
+    def __init__(
+        self,
+        original_weight_tensor: torch.Tensor,
+        input_quant_func: Callable,
+    ):
+        self.original_weight_tensor = original_weight_tensor
+        self.input_quant_func = input_quant_func
+
+    def __tensor_flatten__(self):
+        return ["original_weight_tensor"], [self.input_quant_func]
+
+    @classmethod
+    def __tensor_unflatten__(
+        cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
+    ):
+        original_weight_tensor = tensor_data_dict["original_weight_tensor"]
+        input_quant_func = tensor_attributes[0]
+        return cls(
+            original_weight_tensor,
+            input_quant_func,
+        )
+
+    @classmethod
+    def from_float(
+        cls,
+        input_float,
+        input_quant_func,
+    ):
+        return cls(
+            input_float,
+            input_quant_func,
+        )
+
+    @classmethod
+    def __torch_function__(cls, func, types, args=(), kwargs=None):
+        kwargs = {} if kwargs is None else kwargs
+
+        if func is torch.nn.functional.linear:
+            input_tensor, weight_tensor, bias = (
+                args[0],
+                args[1],
+                args[2] if len(args) > 2 else None,
+            )
+            if isinstance(weight_tensor, LinearActQuantizedTensor):
+                input_quant_func = weight_tensor.input_quant_func
+                original_weight_tensor = weight_tensor.original_weight_tensor
+                aqt = input_quant_func(input_tensor)
+                return torch.nn.functional.linear(aqt, original_weight_tensor, bias)
+
+        with torch._C.DisableTorchFunctionSubclass():
+            return func(*args, **kwargs)
+
+    def _apply_fn_to_data(self, fn):
+        return self.__class__(
+            fn(self.original_weight_tensor),
+            self.input_quant_func,
+        )
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args, kwargs):
+        if (
+            func in [aten.mm.default, aten.addmm.default]
+            and args[0].is_floating_point()
+        ):
+            if func == aten.addmm.default:
+                assert args[1].shape[-1] == args[2].shape[0], (
+                    f"need mat1 shape: {args[1].shape} final"
+                    f"dim to match mat2 shape: {args[2].shape} first dim "
+                )
+                input_tensor, weight_qtensor, bias = (
+                    args[1],
+                    args[2],
+                    args[0],
+                )
+                aqt = weight_qtensor.input_quant_func(input_tensor)
+                return func(bias, aqt, weight_qtensor.original_weight_tensor)
+            else:
+                assert args[0].shape[-1] == args[1].shape[0], (
+                    f"need mat1 shape: {args[0].shape} final dim"
+                    f"to match mat2 shape: {args[1].shape} first dim"
+                )
+                input_tensor, weight_qtensor, bias = (
+                    args[0],
+                    args[1],
+                    None if len(args) == 2 else args[2],
+                )
+                aqt = weight_qtensor.input_quant_func(input_tensor)
+                return func(aqt, weight_qtensor.original_weight_tensor, bias)
+
+        if func is aten.detach.default:
+            return return_and_correct_aliasing(
+                func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
+            )
+
+        if func is aten.clone.default:
+            return return_and_correct_aliasing(
+                func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
+            )
+
+        raise NotImplementedError(
+            f"LinearActQuantizedTensor dispatch: attempting to run {func}, this is not supported"
+        )