-
Notifications
You must be signed in to change notification settings - Fork 185
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refactor QAT to use tensor subclasses
This commit refactors QAT to use tensor subclasses. This is motivated by the general move towards tensor subclasses in torchao for better composability with other subclasses like DTensors. To achieve this, we introduce `AffineFakeQuantizedTensor`, which is analogous to `AffineQuantizedTensor` but applies fake quantization instead and requires gradient updates. `AffineFakeQuantizedTensor` wraps the original weight or input activation tensor and applies fake quantize dynamically only when the linear function is called. Gradients only flow to the outer tensor (`AffineFakeQuantizedTensor`) and never to the inner tensor. For weights, the outer tensor is also a `torch.nn.Parameter`, and gradient updates received by the outer tensor are then passed to the inner tensor through ops like `aten.add_` and `aten.mul_`. An important difference between the PTQ and the QAT flows is how input activation subclasses are inserted. For QAT, we use the nn.module `forward_pre_hook` instead of relying on another subclass `LinearActivationQuantizedTensor` that wraps the weight subclass. The problem with the old PTQ approach is it can create subclasses under `__torch_dispatch__`, which runs below autograd and so the created subclasses cannot have gradients, so it was difficult to get the gradients to flow correctly in such cases. It's also not super intuitive because quantizing input activation needs to go through the weights. In the new approach used by QAT, we instead register a `forward_pre_hook` that wraps the input activations before each call to forward. This approach is also motivated by how [DTensor wraps their subclasses ](https://github.com/pytorch/pytorch/blob/844103197d3e8cf6b4b59176e473365113f4f962/torch/distributed/tensor/parallel/style.py#L521). - [x] Add AffineFakeQuantizedTensor - [x] Add support for int4 weight only fake quantize - [x] Add support for int8 dynamic activations + int4 weight fake quantize (8da4w) - [x] Add prepare and convert path to int4 QAT quantizer - [x] Add prepare and convert path to 8da4w QAT quantizer - [x] Support enabling and disabling fake quant dynamically - [x] Support `__repr__` in AffineFakeQuantizedTensor - [x] Fix backward pass for int4 weight only - [x] Fix backward pass for int8 dynamic activations + int4 weight
- Loading branch information
1 parent
0b66ff0
commit 9932050
Showing
6 changed files
with
613 additions
and
155 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.