Skip to content

Commit

Permalink
apply FQ in __torch_function__
Browse files Browse the repository at this point in the history
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
  • Loading branch information
andrewor14 committed Aug 12, 2024
1 parent 6288d74 commit 37651b6
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 34 deletions.
85 changes: 57 additions & 28 deletions torchao/quantization/prototype/qat/affine_fake_quantized_tensor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import torch
import torch.utils._pytree as pytree
from typing import Callable, Optional, Tuple
from torchao.quantization.quant_primitives import (
_get_and_check_qmin_qmax,
Expand Down Expand Up @@ -40,21 +41,23 @@ def __new__(
original_tensor: torch.Tensor,
apply_fake_quant_fn: Callable,
fake_quant_enabled: bool = True,
**kwargs,
):
kwargs = {}
kwargs["device"] = original_tensor.device
kwargs["dtype"] = original_tensor.dtype
kwargs["requires_grad"] = True
return torch.Tensor._make_wrapper_subclass(cls, original_tensor.shape, **kwargs) # type: ignore[attr-defined]
kwargs.setdefault("dtype", original_tensor.dtype)
kwargs.setdefault("device", original_tensor.device)
return torch.Tensor._make_wrapper_subclass(
cls,
original_tensor.shape,
**kwargs,
)

def __init__(
self,
original_tensor: torch.Tensor,
apply_fake_quant_fn: Callable,
fake_quant_enabled: bool = True,
**kwargs
):
# TODO: original_tensor is not getting updated!
original_tensor.requires_grad_(self.requires_grad)
self.original_tensor = original_tensor
self.apply_fake_quant_fn = apply_fake_quant_fn
self.fake_quant_enabled = fake_quant_enabled
Expand Down Expand Up @@ -90,9 +93,10 @@ def from_float(
zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
):
def apply_fake_quant_fn(t: torch.Tensor):
assert isinstance(t, cls)
qmin, qmax = _get_and_check_qmin_qmax(target_dtype, quant_min, quant_max)
scale, zero_point = choose_qparams_affine(
t,
t.original_tensor,
mapping_type,
block_size,
target_dtype,
Expand All @@ -114,10 +118,18 @@ def apply_fake_quant_fn(t: torch.Tensor):
zero_point_domain,
)
return fq
return cls(input_float, apply_fake_quant_fn)
return cls(
input_float,
apply_fake_quant_fn,
fake_quant_enabled=True,
requires_grad=True,
)

def to_fake_quantized(self) -> torch.Tensor:
return self.apply_fake_quant_fn(self.original_tensor)
def get_value(self) -> torch.Tensor:
if self.fake_quant_enabled:
return self.apply_fake_quant_fn(self)
else:
return self.original_tensor

def _get_to_kwargs(self, *args, **kwargs):
device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs)
Expand All @@ -130,6 +142,7 @@ def _get_to_kwargs(self, *args, **kwargs):
"device": device,
"dtype": dtype,
"memory_format": memory_format,
"requires_grad": self.requires_grad,
}
return kwargs

Expand All @@ -146,10 +159,15 @@ def to(self, *args, **kwargs):
)

def _apply_fn_to_data(self, fn):
"""
Create a new `AffineFakeQuantizedTensor` with `fn` applied to the
original tensor, to be called within __torch_dispatch__.
"""
return self.__class__(
fn(self.original_tensor),
self.apply_fake_quant_fn,
self.fake_quant_enabled,
requires_grad=False,
)

implements = classmethod(_implements)
Expand All @@ -167,55 +185,66 @@ def _(func, types, *args, **kwargs):
args[2] if len(args) > 2 else None,
)
if isinstance(input_tensor, AffineFakeQuantizedTensor):
input_tensor = input_tensor.to_fake_quantized()
input_tensor = input_tensor.get_value()
if isinstance(weight_tensor, AffineFakeQuantizedTensor):
weight_tensor = weight_tensor.to_fake_quantized()
weight_tensor = weight_tensor.get_value()
return torch.nn.functional.linear(input_tensor, weight_tensor, bias)


@implements([aten.mm.default, aten.addmm.default])
def _(func, types, *args, **kwargs):
if func == aten.addmm.default:
bias = args[0]
input_index = 1
input_tensor = args[1]
weight_tensor = args[2]
else:
bias = None
input_index = 0
input_tensor = args[input_index]
weight_tensor = args[input_index + 1]
input_tensor = args[0]
weight_tensor = args[1]
if isinstance(input_tensor, AffineFakeQuantizedTensor):
input_tensor = input_tensor.to_fake_quantized()
input_tensor = input_tensor.get_value()
if isinstance(weight_tensor, AffineFakeQuantizedTensor):
weight_tensor = weight_tensor.to_fake_quantized()
weight_tensor = weight_tensor.get_value()
if bias is not None:
return func(bias, input_tensor, weight_tensor)
else:
return func(input_tensor, weight_tensor)


@implements([aten.detach.default])
def _(func, types, *args, **kwargs):
return return_and_correct_aliasing(
func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
func, args, kwargs, args[0]._apply_fn_to_data(torch.detach),
)


@implements([aten.clone.default])
def _(func, types, *args, **kwargs):
return return_and_correct_aliasing(
func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
func, args, kwargs, args[0]._apply_fn_to_data(torch.clone),
)

@implements([aten._to_copy.default])

@implements([aten.t.default])
def _(func, types, *args, **kwargs):
return return_and_correct_aliasing(
func,
args,
kwargs,
args[0].to(*args[1:], **kwargs)._apply_fn_to_data(torch.clone),
func, args, kwargs, args[0]._apply_fn_to_data(torch.t),
)

@implements([aten.t.default])

@implements([
aten.add.Tensor,
aten.add_.Tensor,
aten.mul_.Tensor,
])
def _(func, types, *args, **kwargs):
assert len(args) == 2, f"dispatched the wrong op to the binary handler: {func}"
new_args = pytree.tree_map_only(AffineFakeQuantizedTensor, lambda x: x.original_tensor, args)
first_afq_tensor = args[0] if isinstance(args[0], AffineFakeQuantizedTensor) else args[1]
fn = lambda x: func(*new_args, **kwargs)
return return_and_correct_aliasing(
func, args, kwargs, args[0]._apply_fn_to_data(torch.t)
func, args, kwargs, first_afq_tensor._apply_fn_to_data(fn),
)


to_affine_fake_quantized = AffineFakeQuantizedTensor.from_float
19 changes: 13 additions & 6 deletions torchao/quantization/prototype/qat/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,14 @@ def forward(
quant_max: int,
zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
) -> torch.Tensor:
## Note: for bf16 inputs, casting them to fp32 has the unexpected
## side effect of reducing memory footprint significantly, presumably
## because bf16 * fp32 kernels are not as memory efficient
#assert input.dtype == torch.float32
#assert scales.dtype == torch.float32
#assert zero_points.dtype == torch.int32
# avoid circular dependencies
from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import (
AffineFakeQuantizedTensor,
)

original_input = input
if isinstance(original_input, AffineFakeQuantizedTensor):
input = input.original_tensor

(fq, mask) = fake_quantize_affine_cachemask(
input,
Expand Down Expand Up @@ -149,9 +151,14 @@ def _unwrap_affine_fake_quantized_tensor(t: torch.Tensor):
Return the original, non-fake-quantized float tensor from a `AffineFakeQuantizedTensor`.
"""
# avoid circular dependencies
from torchao.quantization.linear_activation_quantized_tensor import (
LinearActivationQuantizedTensor,
)
from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import (
AffineFakeQuantizedTensor,
)
if isinstance(t, LinearActivationQuantizedTensor):
t = t.original_weight_tensor
assert isinstance(t, AffineFakeQuantizedTensor)
return t.original_tensor

Expand Down

0 comments on commit 37651b6

Please sign in to comment.