Skip to content

Commit

Permalink
Refactor QAT to use tensor subclasses
Browse files Browse the repository at this point in the history
This commit refactors QAT to use tensor subclasses. This is motivated by the general move towards tensor subclasses in torchao for better composability with other subclasses like DTensors. To achieve this, we introduce `AffineFakeQuantizedTensor`, which is analogous to `AffineQuantizedTensor` but applies fake quantization instead and requires gradient updates.

`AffineFakeQuantizedTensor` wraps the original weight or input activation tensor and applies fake quantize dynamically only when the linear function is called. Gradients only flow to the outer tensor (`AffineFakeQuantizedTensor`) and never to the inner tensor. For weights, the outer tensor is also a `torch.nn.Parameter`, and gradient updates received by the outer tensor are then passed to the inner tensor through ops like `aten.add_` and `aten.mul_`.

An important difference between the PTQ and the QAT flows is how input activation subclasses are inserted. For QAT, we use the nn.module `forward_pre_hook` instead of relying on another subclass `LinearActivationQuantizedTensor` that wraps the weight subclass. The problem with the old PTQ approach is it can create subclasses under `__torch_dispatch__`, which runs below autograd and so the created subclasses cannot have gradients, so it was difficult to get the gradients to flow correctly in such cases. It's also not super intuitive because quantizing input activation needs to go through the weights. In the new approach used by QAT, we instead register a `forward_pre_hook` that wraps the input activations before each call to forward. This approach is also motivated by how [DTensor wraps their subclasses ](https://github.com/pytorch/pytorch/blob/844103197d3e8cf6b4b59176e473365113f4f962/torch/distributed/tensor/parallel/style.py#L521).

- [x] Add AffineFakeQuantizedTensor
- [x] Add support for int4 weight only fake quantize
- [x] Add support for int8 dynamic activations + int4 weight fake quantize (8da4w)
- [x] Add prepare and convert path to int4 QAT quantizer
- [x] Add prepare and convert path to 8da4w QAT quantizer
- [x] Support enabling and disabling fake quant dynamically
- [x] Support `__repr__` in AffineFakeQuantizedTensor
- [x] Fix backward pass for int4 weight only
- [x] Fix backward pass for int8 dynamic activations + int4 weight
  • Loading branch information
andrewor14 committed Aug 16, 2024
1 parent 0b66ff0 commit 9932050
Show file tree
Hide file tree
Showing 6 changed files with 613 additions and 155 deletions.
103 changes: 57 additions & 46 deletions test/quantization/test_qat.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,22 @@

import torch
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401
from torchao.dtypes import (
TensorCoreTiledLayoutType,
)
from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import (
AffineFakeQuantizedTensor,
)
from torchao.quantization.prototype.qat.utils import (
_choose_qparams_per_token_asymmetric,
_fake_quantize_per_channel_group,
_fake_quantize_per_token,
_GenericFakeQuantize,
_QAT_LINEAR_SUBCLASS_INPUT_PREHOOK,
)
from torchao.quantization.quant_api import (
int4_weight_only,
quantize_,
)
from torchao.quantization.quant_primitives import (
fake_quantize_affine,
Expand Down Expand Up @@ -190,6 +201,7 @@ def test_qat_8da4w_linear(self):
ptq_out = ptq_linear(x2)
torch.testing.assert_close(ptq_out, qat_out, atol=0, rtol=0)

# TODO: compare against quantize_ API instead
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
def test_qat_8da4w_quantizer(self):
from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
Expand Down Expand Up @@ -217,13 +229,6 @@ def test_qat_8da4w_quantizer(self):
converted_out = converted_model(*x)
torch.testing.assert_close(ptq_out, converted_out, atol=0, rtol=0)

# Compare converted state dict
ptq_state_dict = ptq_model.state_dict()
converted_state_dict = converted_model.state_dict()
self.assertEqual(ptq_state_dict.keys(), converted_state_dict.keys())
for k in ptq_state_dict.keys():
torch.testing.assert_close(ptq_state_dict[k], converted_state_dict[k], atol=0, rtol=0)

@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
def test_qat_8da4w_quantizer_meta_weights(self):
from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
Expand All @@ -236,6 +241,20 @@ def test_qat_8da4w_quantizer_meta_weights(self):
qat_model = qat_quantizer.prepare(m)
self.assertTrue(all(v.is_meta for v in qat_model.state_dict().values()))

def _copy_subclass_weights(
self,
nn_linear: torch.nn.Linear,
subclass_linear: AffineFakeQuantizedTensor,
):
nn_linear.weight = torch.nn.Parameter(subclass_linear.weight.original_tensor)

def _assert_matches_subclass_weights(
self,
nn_linear: torch.nn.Linear,
subclass_linear: AffineFakeQuantizedTensor,
):
torch.testing.assert_close(nn_linear.weight, subclass_linear.weight.original_tensor, atol=0, rtol=0)

@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
def test_qat_8da4w_quantizer_disable_fake_quant(self):
"""
Expand All @@ -247,6 +266,16 @@ def test_qat_8da4w_quantizer_disable_fake_quant(self):
enable_8da4w_fake_quant,
)

def assert_fake_quant_enabled(m: torch.nn.Linear, enabled: bool):
self.assertTrue(isinstance(m.weight, AffineFakeQuantizedTensor))
self.assertEqual(m.weight.fake_quant_enabled, enabled)
self.assertTrue(hasattr(m, _QAT_LINEAR_SUBCLASS_INPUT_PREHOOK))
(_, handle) = getattr(m, _QAT_LINEAR_SUBCLASS_INPUT_PREHOOK)
if enabled:
self.assertIsNotNone(handle)
else:
self.assertIsNone(handle)

group_size = 16
torch.manual_seed(self.SEED)
m = M()
Expand All @@ -255,14 +284,14 @@ def test_qat_8da4w_quantizer_disable_fake_quant(self):
quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
qat_model = quantizer.prepare(m)
qat_model.apply(disable_8da4w_fake_quant)
self.assertFalse(qat_model.linear1._fake_quant_enabled)
self.assertFalse(qat_model.linear2._fake_quant_enabled)
self.assertFalse(qat_model.sub.linear._fake_quant_enabled)
assert_fake_quant_enabled(qat_model.linear1, enabled=False)
assert_fake_quant_enabled(qat_model.linear2, enabled=False)
assert_fake_quant_enabled(qat_model.sub.linear, enabled=False)

# Disabled fake quant is just a normal linear
m2.linear1.weight = qat_model.linear1.weight
m2.linear2.weight = qat_model.linear2.weight
m2.sub.linear.weight = qat_model.sub.linear.weight
self._copy_subclass_weights(m2.linear1, qat_model.linear1)
self._copy_subclass_weights(m2.linear2, qat_model.linear2)
self._copy_subclass_weights(m2.sub.linear, qat_model.sub.linear)
torch.manual_seed(self.SEED)
x = m.example_inputs()
x2 = copy.deepcopy(x)
Expand All @@ -272,16 +301,16 @@ def test_qat_8da4w_quantizer_disable_fake_quant(self):

# Renable fake quant
qat_model.apply(enable_8da4w_fake_quant)
self.assertTrue(qat_model.linear1._fake_quant_enabled)
self.assertTrue(qat_model.linear2._fake_quant_enabled)
self.assertTrue(qat_model.sub.linear._fake_quant_enabled)
assert_fake_quant_enabled(qat_model.linear1, enabled=True)
assert_fake_quant_enabled(qat_model.linear2, enabled=True)
assert_fake_quant_enabled(qat_model.sub.linear, enabled=True)

# Fake quant should be applied as normal
quantizer2 = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
qat_model2 = quantizer2.prepare(m3)
qat_model2.linear1.weight = qat_model.linear1.weight
qat_model2.linear2.weight = qat_model.linear2.weight
qat_model2.sub.linear.weight = qat_model.sub.linear.weight
qat_model2.linear1.weight.original_tensor = qat_model.linear1.weight.original_tensor
qat_model2.linear2.weight.original_tensor = qat_model.linear2.weight.original_tensor
qat_model2.sub.linear.weight.original_tensor = qat_model.sub.linear.weight.original_tensor
torch.manual_seed(self.SEED)
x = m.example_inputs()
x2 = copy.deepcopy(x)
Expand All @@ -306,9 +335,9 @@ def test_qat_8da4w_quantizer_disable_fake_quant_backward(self):
quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
qat_model = quantizer.prepare(m)
qat_model.apply(disable_8da4w_fake_quant)
nn_model.linear1.weight = qat_model.linear1.weight
nn_model.linear2.weight = qat_model.linear2.weight
nn_model.sub.linear.weight = qat_model.sub.linear.weight
self._copy_subclass_weights(nn_model.linear1, qat_model.linear1)
self._copy_subclass_weights(nn_model.linear2, qat_model.linear2)
self._copy_subclass_weights(nn_model.sub.linear, qat_model.sub.linear)

# Simulate training for both models
optimizer1 = torch.optim.SGD(nn_model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5)
Expand All @@ -330,9 +359,9 @@ def test_qat_8da4w_quantizer_disable_fake_quant_backward(self):
optimizer2.step()

# After 1 training step, weights should match exactly
torch.testing.assert_close(nn_model.linear1.weight, qat_model.linear1.weight, atol=0, rtol=0)
torch.testing.assert_close(nn_model.linear2.weight, qat_model.linear2.weight, atol=0, rtol=0)
torch.testing.assert_close(nn_model.sub.linear.weight, qat_model.sub.linear.weight, atol=0, rtol=0)
self._assert_matches_subclass_weights(nn_model.linear1, qat_model.linear1)
self._assert_matches_subclass_weights(nn_model.linear2, qat_model.linear2)
self._assert_matches_subclass_weights(nn_model.sub.linear, qat_model.sub.linear)

@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
def test_qat_generic_fake_quantize(self):
Expand All @@ -353,7 +382,7 @@ def test_qat_generic_fake_quantize(self):
block_size = (1, ao_input.shape[-1])
ao_s = copy.deepcopy(py_s)
ao_zp = copy.deepcopy(py_zp)
ao_out = _GenericFakeQuantize.apply(ao_input, ao_s, ao_zp, qmin, qmax, block_size)
ao_out = _GenericFakeQuantize.apply(ao_input, block_size, ao_s, ao_zp, qmin, qmax)
ao_out.sum().backward()

torch.testing.assert_close(py_out, ao_out, atol=0, rtol=0)
Expand All @@ -373,10 +402,7 @@ def _assert_close_4w(self, val, ref):
print(mean_err)
self.assertTrue(mean_err < 0.05)

@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
@unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
# TODO: remove once we fix int4 error: https://github.com/pytorch/ao/pull/517
@unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 doesn't work for 2.5+ right now")
def test_qat_4w_primitives(self):
n_bit = 4
group_size = 32
Expand Down Expand Up @@ -422,9 +448,6 @@ def test_qat_4w_primitives(self):

@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
@unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
# TODO: remove once we fix int4 error: https://github.com/pytorch/ao/pull/517
@unittest.skipIf(TORCH_VERSION_AT_LEAST_2_4, "assert input.dtype == torch.float32" )
@unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 doesn't work for 2.5+ right now")
def test_qat_4w_linear(self):
from torchao.quantization.prototype.qat.api import Int4WeightOnlyQATLinear
from torchao.quantization.GPTQ import WeightOnlyInt4Linear
Expand Down Expand Up @@ -453,9 +476,6 @@ def test_qat_4w_linear(self):

@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
@unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
# TODO: remove once we fix int4 error: https://github.com/pytorch/ao/pull/517
@unittest.skipIf(TORCH_VERSION_AT_LEAST_2_4, "assert input.dtype == torch.float32" )
@unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 doesn't work for 2.5+ right now")
def test_qat_4w_quantizer(self):
from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer
from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer
Expand All @@ -470,11 +490,9 @@ def test_qat_4w_quantizer(self):
qat_quantizer = Int4WeightOnlyQATQuantizer(
groupsize=group_size, inner_k_tiles=inner_k_tiles,
)
ptq_quantizer = Int4WeightOnlyQuantizer(
groupsize=group_size, inner_k_tiles=inner_k_tiles,
)
qat_model = qat_quantizer.prepare(m)
ptq_model = ptq_quantizer.quantize(m2)
ptq_model = m2
quantize_(ptq_model, int4_weight_only(group_size, TensorCoreTiledLayoutType(inner_k_tiles)))

# Compare model values
torch.manual_seed(self.SEED)
Expand All @@ -489,13 +507,6 @@ def test_qat_4w_quantizer(self):
converted_out = converted_model(*x)
torch.testing.assert_close(converted_out, ptq_out, atol=0, rtol=0)

# Compare converted state dict
ptq_state_dict = ptq_model.state_dict()
converted_state_dict = converted_model.state_dict()
self.assertEqual(ptq_state_dict.keys(), converted_state_dict.keys())
for k in ptq_state_dict.keys():
torch.testing.assert_close(ptq_state_dict[k], converted_state_dict[k], atol=0, rtol=0)


if __name__ == "__main__":
unittest.main()
4 changes: 4 additions & 0 deletions torchao/quantization/prototype/qat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
disable_8da4w_fake_quant,
enable_4w_fake_quant,
enable_8da4w_fake_quant,
int4_weight_only_fake_quantize,
int8_dynamic_activation_int4_weight_fake_quantize,
Int4WeightOnlyQATQuantizer,
Int8DynActInt4WeightQATQuantizer,
Int8DynActInt4WeightQATLinear,
Expand All @@ -13,6 +15,8 @@
"disable_8da4w_fake_quant",
"enable_4w_fake_quant",
"enable_8da4w_fake_quant",
"int4_weight_only_fake_quantize",
"int8_dynamic_activation_int4_weight_fake_quantize",
"Int4WeightOnlyQATQuantizer",
"Int8DynActInt4WeightQATQuantizer",
"Int8DynActInt4WeightQATLinear",
Expand Down
Loading

0 comments on commit 9932050

Please sign in to comment.