differentiable constructor?
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
andrewor14 committed Aug 9, 2024
1 parent 6288d74 commit b643642
Showing 2 changed files with 122 additions and 24 deletions.
119 changes: 103 additions & 16 deletions torchao/quantization/prototype/qat/affine_fake_quantized_tensor.py
@@ -1,4 +1,5 @@
import torch
import torch.utils._pytree as pytree
from typing import Callable, Optional, Tuple
from torchao.quantization.quant_primitives import (
_get_and_check_qmin_qmax,
@@ -17,6 +18,7 @@

aten = torch.ops.aten


class AffineFakeQuantizedTensor(torch.Tensor):
"""
Affine fake quantized tensor subclass. Affine quantization means we quantize the floating point tensor
@@ -40,38 +42,52 @@ def __new__(
original_tensor: torch.Tensor,
apply_fake_quant_fn: Callable,
fake_quant_enabled: bool = True,
requires_grad: bool = False,
**kwargs,
):
kwargs = {}
kwargs["device"] = original_tensor.device
kwargs["dtype"] = original_tensor.dtype
kwargs["requires_grad"] = True
return torch.Tensor._make_wrapper_subclass(cls, original_tensor.shape, **kwargs) # type: ignore[attr-defined]
kwargs.setdefault("dtype", original_tensor.dtype)
kwargs.setdefault("device", original_tensor.device)
return torch.Tensor._make_wrapper_subclass(
cls,
original_tensor.shape,
requires_grad=requires_grad,
**kwargs,
)

def __init__(
self,
original_tensor: torch.Tensor,
apply_fake_quant_fn: Callable,
fake_quant_enabled: bool = True,
requires_grad: bool = False,
**kwargs,
):
# TODO: original_tensor is not getting updated!
original_tensor.requires_grad_(self.requires_grad)
self.original_tensor = original_tensor
self.apply_fake_quant_fn = apply_fake_quant_fn
self.fake_quant_enabled = fake_quant_enabled
self.requires_grad = requires_grad

def __repr__(self):
return f"AffineFakeQuantizedTensor({self.original_tensor})"

def __tensor_flatten__(self):
return ["original_tensor"], [self.apply_fake_quant_fn, self.fake_quant_enabled]
return ["original_tensor"], [
self.apply_fake_quant_fn,
self.fake_quant_enabled,
self.requires_grad,
]

@classmethod
def __tensor_unflatten__(
cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride,
):
original_tensor = tensor_data_dict["original_tensor"]
(apply_fake_quant_fn, fake_quant_enabled) = tensor_attributes
(apply_fake_quant_fn, fake_quant_enabled, requires_grad) = tensor_attributes
return cls(
original_tensor,
apply_fake_quant_fn,
fake_quant_enabled,
requires_grad,
)

@classmethod
@@ -89,10 +105,10 @@ def from_float(
preserve_zero: bool = True,
zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
):
def apply_fake_quant_fn(t: torch.Tensor):
def apply_fake_quant_fn(t: AffineFakeQuantizedTensor):
qmin, qmax = _get_and_check_qmin_qmax(target_dtype, quant_min, quant_max)
scale, zero_point = choose_qparams_affine(
t,
t.original_tensor,
mapping_type,
block_size,
target_dtype,
@@ -114,10 +130,17 @@ def apply_fake_quant_fn(t: torch.Tensor):
zero_point_domain,
)
return fq
return cls(input_float, apply_fake_quant_fn)
fake_quant_enabled = True
requires_grad = True
return cls(
input_float,
apply_fake_quant_fn,
fake_quant_enabled,
requires_grad,
)

def to_fake_quantized(self) -> torch.Tensor:
return self.apply_fake_quant_fn(self.original_tensor)
return self.apply_fake_quant_fn(self).original_tensor

def _get_to_kwargs(self, *args, **kwargs):
device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs)
@@ -142,10 +165,15 @@ def to(self, *args, **kwargs):
self.original_tensor.to(device),
self.apply_fake_quant_fn,
self.fake_quant_enabled,
self.requires_grad,
**kwargs,
)

def _apply_fn_to_data(self, fn):
"""
Create a new subclass with `fn` applied to the original tensor,
to be called within __torch_dispatch__.
"""
return self.__class__(
fn(self.original_tensor),
self.apply_fake_quant_fn,
@@ -194,13 +222,13 @@ def _(func, types, *args, **kwargs):
@implements([aten.detach.default])
def _(func, types, *args, **kwargs):
return return_and_correct_aliasing(
func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
func, args, kwargs, args[0]._apply_fn_to_data(torch.detach),
)

@implements([aten.clone.default])
def _(func, types, *args, **kwargs):
return return_and_correct_aliasing(
func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
func, args, kwargs, args[0]._apply_fn_to_data(torch.clone),
)

@implements([aten._to_copy.default])
@@ -215,7 +243,66 @@ def _(func, types, *args, **kwargs):
@implements([aten.t.default])
def _(func, types, *args, **kwargs):
return return_and_correct_aliasing(
func, args, kwargs, args[0]._apply_fn_to_data(torch.t)
func, args, kwargs, args[0]._apply_fn_to_data(torch.t),
)

# TODO: merge all these?
#@implements([aten.view.default])
#def _(func, types, *args, **kwargs):
# fn = lambda x: x.view(*args[1:], **kwargs)
# return return_and_correct_aliasing(
# func, args, kwargs, args[0]._apply_fn_to_data(fn),
# )
#
#@implements([aten.amin.default])
#def _(func, types, *args, **kwargs):
# fn = lambda x: x.amin(*args[1:], **kwargs)
# return return_and_correct_aliasing(
# func, args, kwargs, args[0]._apply_fn_to_data(fn),
# )
#
#@implements([aten.amax.default])
#def _(func, types, *args, **kwargs):
# fn = lambda x: x.amax(*args[1:], **kwargs)
# return return_and_correct_aliasing(
# func, args, kwargs, args[0]._apply_fn_to_data(fn),
# )
#
#@implements([aten.clamp.default])
#def _(func, types, *args, **kwargs):
# clamp_fn = lambda x: x.clamp(*args[1:], **kwargs)
# return return_and_correct_aliasing(
# func, args, kwargs, args[0]._apply_fn_to_data(clamp_fn),
# )
#
#@implements([aten.round.default])
#def _(func, types, *args, **kwargs):
# return return_and_correct_aliasing(
# func, args, kwargs, args[0]._apply_fn_to_data(torch.round),
# )

#@implements([
# aten.sub.Tensor,
# aten.add.Tensor,
# aten.add_.Tensor,
# aten.div.Tensor,
# aten.mul.Tensor,
# aten.mul_.Tensor,
# aten.ge.Scalar,
# aten.le.Scalar,
# aten.logical_and.default,
#])
#def _(func, types, *args, **kwargs):
# assert len(args) == 2, f"dispatched the wrong op to the binary handler: {func}"
# new_args = pytree.tree_map_only(AffineFakeQuantizedTensor, lambda x: x.original_tensor, args)
# new_data = func(*new_args, **kwargs)
# first_afqt = args[0] if isinstance(args[0], AffineFakeQuantizedTensor) else args[1]
# out = AffineFakeQuantizedTensor(
# new_data,
# first_afqt.apply_fake_quant_fn,
# first_afqt.fake_quant_enabled,
# )
# return return_and_correct_aliasing(func, args, kwargs, out)


to_affine_fake_quantized = AffineFakeQuantizedTensor.from_float
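
The gist of the first file's change: the subclass constructor now takes a requires_grad argument (plus extra kwargs) instead of hard-coding requires_grad=True, and the flag is carried through __tensor_flatten__ and __tensor_unflatten__ so it survives flattening. Below is a minimal, hedged sketch of that wrapper-subclass pattern only; the class name GradAwareWrapper and the inner attribute are made up for illustration and are not torchao's API.

import torch

class GradAwareWrapper(torch.Tensor):
    # illustrative wrapper subclass, not the torchao class

    def __new__(cls, inner: torch.Tensor, requires_grad: bool = False, **kwargs):
        kwargs.setdefault("dtype", inner.dtype)
        kwargs.setdefault("device", inner.device)
        # requires_grad is a constructor argument instead of always being True
        return torch.Tensor._make_wrapper_subclass(
            cls, inner.shape, requires_grad=requires_grad, **kwargs
        )

    def __init__(self, inner: torch.Tensor, requires_grad: bool = False, **kwargs):
        self.inner = inner

    def __tensor_flatten__(self):
        # the flag rides along with the metadata so unflatten can restore it
        return ["inner"], [self.requires_grad]

    @classmethod
    def __tensor_unflatten__(cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride):
        (requires_grad,) = tensor_attributes
        return cls(tensor_data_dict["inner"], requires_grad)

# the wrapper itself is what reports requires_grad=True to autograd
w = GradAwareWrapper(torch.randn(4, 4), requires_grad=True)
assert w.requires_grad
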
27 changes: 19 additions & 8 deletions torchao/quantization/prototype/qat/utils.py
@@ -36,15 +36,18 @@ def forward(
quant_max: int,
zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
) -> torch.Tensor:
## Note: for bf16 inputs, casting them to fp32 has the unexpected
## side effect of reducing memory footprint significantly, presumably
## because bf16 * fp32 kernels are not as memory efficient
#assert input.dtype == torch.float32
#assert scales.dtype == torch.float32
#assert zero_points.dtype == torch.int32
# avoid circular deps
from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import (
AffineFakeQuantizedTensor,
)

if isinstance(input, AffineFakeQuantizedTensor):
_input = input.original_tensor
else:
_input = input

(fq, mask) = fake_quantize_affine_cachemask(
input,
_input,
block_size,
scales,
zero_points,
Expand All @@ -55,7 +58,15 @@ def forward(
)

ctx.save_for_backward(mask)
return fq

if isinstance(input, AffineFakeQuantizedTensor):
return AffineFakeQuantizedTensor(
fq,
input.apply_fake_quant_fn,
input.fake_quant_enabled,
)
else:
return fq

@staticmethod
def backward(ctx, gy):
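
The gist of the second file's change: the custom autograd function's forward now unwraps an AffineFakeQuantizedTensor to its original_tensor before running the fake-quantize kernel and re-wraps the result so the output stays a subclass instance, while backward remains a straight-through estimator gated by the saved clamp mask. Below is a self-contained sketch of that forward/backward shape; the simplified per-tensor quantization arithmetic stands in for fake_quantize_affine_cachemask, and the hypothetical inner attribute stands in for the real subclass.

import torch

class FakeQuantizeSTE(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, scale, zero_point, quant_min, quant_max):
        # unwrap a wrapper subclass if that is what we were handed
        inner = input.inner if hasattr(input, "inner") else input
        q = torch.round(inner / scale) + zero_point
        mask = (q >= quant_min) & (q <= quant_max)  # True where nothing was clamped
        fq = (torch.clamp(q, quant_min, quant_max) - zero_point) * scale
        ctx.save_for_backward(mask)
        # the commit additionally re-wraps fq into the tensor subclass when the
        # input was wrapped; omitted here to keep the sketch short
        return fq

    @staticmethod
    def backward(ctx, gy):
        (mask,) = ctx.saved_tensors
        # straight-through estimator: gradients flow only where values stayed in range
        return gy * mask, None, None, None, None

x = torch.randn(8, requires_grad=True)
y = FakeQuantizeSTE.apply(x, 0.1, 0, -128, 127)
y.sum().backward()
print(x.grad)  # 1.0 where the rounded value stayed inside [-128, 127], else 0.0
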
