Skip to content

Commit

Permalink
Fix backward disable case
Browse files Browse the repository at this point in the history
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
  • Loading branch information
andrewor14 committed Aug 13, 2024
1 parent d960c6a commit a7a19be
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 36 deletions.
54 changes: 25 additions & 29 deletions test/quantization/test_qat.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ def test_qat_8da4w_linear(self):
ptq_out = ptq_linear(x2)
torch.testing.assert_close(ptq_out, qat_out, atol=0, rtol=0)

# TODO: compare against quantize_ API instead
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
def test_qat_8da4w_quantizer(self):
from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
Expand Down Expand Up @@ -225,14 +226,6 @@ def test_qat_8da4w_quantizer(self):
converted_out = converted_model(*x)
torch.testing.assert_close(ptq_out, converted_out, atol=0, rtol=0)

# TODO: enable this after supporting aten.eq.default in both subclasses
# Compare converted state dict
# ptq_state_dict = ptq_model.state_dict()
# converted_state_dict = converted_model.state_dict()
# self.assertEqual(ptq_state_dict.keys(), converted_state_dict.keys())
# for k in ptq_state_dict.keys():
# torch.testing.assert_close(ptq_state_dict[k], converted_state_dict[k], atol=0, rtol=0)

@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
def test_qat_8da4w_quantizer_meta_weights(self):
from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
Expand All @@ -245,6 +238,20 @@ def test_qat_8da4w_quantizer_meta_weights(self):
qat_model = qat_quantizer.prepare(m)
self.assertTrue(all(v.is_meta for v in qat_model.state_dict().values()))

def _copy_subclass_weights(
self,
nn_linear: torch.nn.Linear,
subclass_linear: AffineFakeQuantizedTensor,
):
nn_linear.weight = torch.nn.Parameter(subclass_linear.weight.original_tensor)

def _assert_matches_subclass_weights(
self,
nn_linear: torch.nn.Linear,
subclass_linear: AffineFakeQuantizedTensor,
):
torch.testing.assert_close(nn_linear.weight, subclass_linear.weight.original_tensor, atol=0, rtol=0)

@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
def test_qat_8da4w_quantizer_disable_fake_quant(self):
"""
Expand Down Expand Up @@ -279,9 +286,9 @@ def assert_fake_quant_enabled(m: torch.nn.Linear, enabled: bool):
assert_fake_quant_enabled(qat_model.sub.linear, enabled=False)

# Disabled fake quant is just a normal linear
m2.linear1.weight = torch.nn.Parameter(qat_model.linear1.weight.original_tensor)
m2.linear2.weight = torch.nn.Parameter(qat_model.linear2.weight.original_tensor)
m2.sub.linear.weight = torch.nn.Parameter(qat_model.sub.linear.weight.original_tensor)
self._copy_subclass_weights(m2.linear1, qat_model.linear1)
self._copy_subclass_weights(m2.linear2, qat_model.linear2)
self._copy_subclass_weights(m2.sub.linear, qat_model.sub.linear)
torch.manual_seed(self.SEED)
x = m.example_inputs()
x2 = copy.deepcopy(x)
Expand Down Expand Up @@ -318,20 +325,16 @@ def test_qat_8da4w_quantizer_disable_fake_quant_backward(self):
disable_8da4w_fake_quant,
)

def get_qat_weight(qat_linear: torch.nn.Linear):
assert isinstance(qat_linear.weight, AffineFakeQuantizedTensor)
return qat_linear.weight.original_tensor

group_size = 16
torch.manual_seed(self.SEED)
m = M()
nn_model = copy.deepcopy(m)
quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
qat_model = quantizer.prepare(m)
qat_model.apply(disable_8da4w_fake_quant)
nn_model.linear1.weight = torch.nn.Parameter(get_qat_weight(qat_model.linear1))
nn_model.linear2.weight = torch.nn.Parameter(get_qat_weight(qat_model.linear2))
nn_model.sub.linear.weight = torch.nn.Parameter(get_qat_weight(qat_model.sub.linear))
self._copy_subclass_weights(nn_model.linear1, qat_model.linear1)
self._copy_subclass_weights(nn_model.linear2, qat_model.linear2)
self._copy_subclass_weights(nn_model.sub.linear, qat_model.sub.linear)

# Simulate training for both models
optimizer1 = torch.optim.SGD(nn_model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5)
Expand All @@ -353,9 +356,9 @@ def get_qat_weight(qat_linear: torch.nn.Linear):
optimizer2.step()

# After 1 training step, weights should match exactly
torch.testing.assert_close(nn_model.linear1.weight, get_qat_weight(qat_model.linear1), atol=0, rtol=0)
torch.testing.assert_close(nn_model.linear2.weight, get_qat_weight(qat_model.linear2), atol=0, rtol=0)
torch.testing.assert_close(nn_model.sub.linear.weight, get_qat_weight(qat_model.sub.linear), atol=0, rtol=0)
self._assert_matches_subclass_weights(nn_model.linear1, qat_model.linear1)
self._assert_matches_subclass_weights(nn_model.linear2, qat_model.linear2)
self._assert_matches_subclass_weights(nn_model.sub.linear, qat_model.sub.linear)

@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
def test_qat_generic_fake_quantize(self):
Expand Down Expand Up @@ -396,6 +399,7 @@ def _assert_close_4w(self, val, ref):
print(mean_err)
self.assertTrue(mean_err < 0.05)

# TODO: compare against quantize_ API instead
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
@unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
def test_qat_4w_quantizer(self):
Expand Down Expand Up @@ -429,14 +433,6 @@ def test_qat_4w_quantizer(self):
converted_out = converted_model(*x)
torch.testing.assert_close(converted_out, ptq_out, atol=0, rtol=0)

# TODO: enable this after supporting aten.eq.default in both subclasses
# Compare converted state dict
# ptq_state_dict = ptq_model.state_dict()
# converted_state_dict = converted_model.state_dict()
# self.assertEqual(ptq_state_dict.keys(), converted_state_dict.keys())
# for k in ptq_state_dict.keys():
# torch.testing.assert_close(ptq_state_dict[k], converted_state_dict[k], atol=0, rtol=0)


if __name__ == "__main__":
unittest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@
_dispatch__torch_function__,
_dispatch__torch_dispatch__,
)
from .utils import _GenericFakeQuantize
from .utils import (
_GenericFakeQuantize,
_UnwrapAffineFakeQuantizedTensor,
)

aten = torch.ops.aten

Expand Down Expand Up @@ -129,7 +132,7 @@ def get_value(self) -> torch.Tensor:
if self.fake_quant_enabled:
return self.apply_fake_quant_fn(self)
else:
return self.original_tensor
return _UnwrapAffineFakeQuantizedTensor.apply(self)

def _get_to_kwargs(self, *args, **kwargs):
device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs)
Expand Down
3 changes: 2 additions & 1 deletion torchao/quantization/prototype/qat/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ def convert(
quantize_(model, quantize_fn)
return model

# TODO: deprecate
class Int8DynActInt4WeightQATLinear(torch.nn.Linear):
"""
This module implements a linear layer with int8 dynamic per token fake
Expand Down Expand Up @@ -322,7 +323,7 @@ def convert(
quantize_(model, quantize_fn)
return model


# TODO: deprecate
class Int4WeightOnlyQATLinear(torch.nn.Linear):
"""
This module implements a linear layer with int4 fake quantized grouped
Expand Down
32 changes: 29 additions & 3 deletions torchao/quantization/prototype/qat/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,31 @@ def backward(ctx, gy):
(mask,) = ctx.saved_tensors
return gy * mask, None, None, None, None, None, None


class _UnwrapAffineFakeQuantizedTensor(torch.autograd.Function):
"""
Helper autograd function to unwrap `AffineFakeQuantizedTensor` while ensuring
gradients are still passed to the tensor subclass. This is used in place of
`_GenericFakeQuantize` when fake quant is disabled.
"""

@staticmethod
def forward(
ctx: torch.autograd.function.FunctionCtx,
input: torch.Tensor,
) -> torch.Tensor:
# avoid circular dependencies
from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import (
AffineFakeQuantizedTensor,
)
assert isinstance(input, AffineFakeQuantizedTensor)
return input.original_tensor

@staticmethod
def backward(ctx, gy):
return gy,


def _fake_quantize_per_channel_group(
input: torch.Tensor,
scales: torch.Tensor,
Expand Down Expand Up @@ -206,8 +231,9 @@ def _enable_fake_quant(mod: torch.nn.Module, enable: bool):
if hasattr(mod, _QAT_LINEAR_SUBCLASS_INPUT_PREHOOK):
(prehook, handle) = getattr(mod, _QAT_LINEAR_SUBCLASS_INPUT_PREHOOK)
if enable:
handle = mod.register_forward_pre_hook(prehook)
if handle is None:
handle = mod.register_forward_pre_hook(prehook)
_forward_pre_hook_handler(mod, prehook, handle)
else:
handle.remove()
handle = None
_forward_pre_hook_handler(mod, prehook, handle)
_forward_pre_hook_handler(mod, prehook, None)
2 changes: 1 addition & 1 deletion torchao/quantization/quant_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ def _get_linear_subclass_inserter(
requires_grad: bool = False,
) -> Callable:
"""
Return a functinon that inserts wraps the weight and/or input activation of a
Return a function that wraps the weight and/or input activation of a
linear module in tensor subclasses.
Args:
Expand Down

0 comments on commit a7a19be

Please sign in to comment.