Refactor QAT to use tensor subclasses #585
Conversation
left a few questions but I'll defer to @jerryzh168 for the review
looks good, thanks!
For reference, just added some end-to-end benchmarks and evaluation results for tensor subclass QAT here: pytorch/torchtune#1330. Looks pretty good so far.
Summary: Recent refactor into tensor subclasses (#585) broke some existing use cases that rely on DDP and FSDP1, since the new flow only supports FSDP2 currently. This commit adds back the module swap API for now to provide a backdoor for these use cases. In the long term, we still plan to deprecate the module swap flow.

Test Plan:
python test/quantization/test_qat.py -k test_qat_8da4w_quantizer_module_swap
python test/quantization/test_qat.py -k test_qat_4w_quantizer_module_swap

Reviewers: jerryzh168, msaroufim

Subscribers: jerryzh168, msaroufim
Overview

This commit refactors QAT to use tensor subclasses. This is motivated by the general move towards tensor subclasses in torchao for better composability with other subclasses like DTensors. To achieve this, we introduce `AffineFakeQuantizedTensor`, which is analogous to `AffineQuantizedTensor` but applies fake quantization instead and requires gradient updates.
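For intuition, here is a minimal sketch of what affine fake quantization means, separate from the actual `AffineFakeQuantizedTensor` implementation: values stay in floating point but are rounded and clamped to the grid that real affine quantization would produce. The function name and the per-tensor granularity below are illustrative assumptions, not torchao APIs.

```python
import torch

def affine_fake_quantize(x: torch.Tensor, scale: float, zero_point: int,
                         quant_min: int, quant_max: int) -> torch.Tensor:
    # Quantize to the integer grid, then immediately dequantize back to float.
    # The output keeps the input dtype but only takes values that real affine
    # quantization could represent.
    q = torch.clamp(torch.round(x / scale) + zero_point, quant_min, quant_max)
    return (q - zero_point) * scale

# Example: symmetric int8-style fake quantization of a weight tensor.
w = torch.randn(16, 16)
scale = w.abs().max().item() / 127
w_fq = affine_fake_quantize(w, scale, zero_point=0, quant_min=-128, quant_max=127)
```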
How training works

`AffineFakeQuantizedTensor` wraps the original weight or input activation tensor and applies fake quantization dynamically, only when the linear function is called. Gradients flow only to the outer tensor (`AffineFakeQuantizedTensor`) and never to the inner tensor. For weights, the outer tensor is also a `torch.nn.Parameter`, and gradient updates received by the outer tensor are then passed to the inner tensor through ops like `aten.add_` and `aten.mul_`.
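As a rough illustration of that flow, a simplified wrapper subclass could look like the sketch below. The class name, attributes, and op coverage are made up for this example; the real `AffineFakeQuantizedTensor` handles many more ops and applies fake quantization on the linear path.

```python
import torch
from torch.utils._pytree import tree_map

class FakeQuantizedWeight(torch.Tensor):
    # Simplified, hypothetical stand-in for AffineFakeQuantizedTensor.

    @staticmethod
    def __new__(cls, inner: torch.Tensor):
        # The outer tensor is what autograd and the optimizer see, so it
        # requires grad; `inner` holds the real values and never gets grads.
        return torch.Tensor._make_wrapper_subclass(
            cls, inner.shape, dtype=inner.dtype, device=inner.device,
            requires_grad=True,
        )

    def __init__(self, inner: torch.Tensor):
        self.inner = inner

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        unwrap = lambda t: t.inner if isinstance(t, cls) else t
        if func in (torch.ops.aten.add_.Tensor, torch.ops.aten.mul_.Tensor):
            # Optimizer updates land on the outer tensor; mirror them onto the
            # inner tensor so the underlying weight actually changes.
            func(args[0].inner, *[unwrap(a) for a in args[1:]], **kwargs)
            return args[0]
        # Everything else (including the linear ops, where the real subclass
        # would fake-quantize first): unwrap the arguments and run the op.
        return func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))

# For weights, the outer tensor would then be registered as a parameter, e.g.
# module.weight = torch.nn.Parameter(FakeQuantizedWeight(plain_weight))
```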
Input activation

An important difference between the PTQ and QAT flows is how input activation subclasses are inserted. For QAT, we use the nn.Module `forward_pre_hook` instead of relying on another subclass, `LinearActivationQuantizedTensor`, that wraps the weight subclass. The problem with the old PTQ approach is that it can create subclasses under `__torch_dispatch__`, which runs below autograd, so the created subclasses cannot have gradients and it was difficult to get gradients to flow correctly. It is also not very intuitive, because quantizing the input activation has to go through the weights. In the new QAT approach, we instead register a `forward_pre_hook` that wraps the input activations before each call to forward. This approach is also motivated by how [DTensor wraps their subclasses](https://github.com/pytorch/pytorch/blob/844103197d3e8cf6b4b59176e473365113f4f962/torch/distributed/tensor/parallel/style.py#L521).
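A minimal sketch of this hook-based wrapping is below. The helper `to_fake_quantized` is a placeholder for whatever constructor wraps a plain tensor in the fake-quantized subclass (here it just fake-quantizes the values directly), so the names should not be read as the actual torchao API.

```python
import torch

def to_fake_quantized(x: torch.Tensor) -> torch.Tensor:
    # Placeholder for the subclass constructor; for illustration we apply
    # per-tensor int8-style fake quantization to the activation values.
    scale = x.detach().abs().amax().item() / 127 + 1e-12
    return torch.fake_quantize_per_tensor_affine(x, scale, 0, -128, 127)

def _fake_quantize_inputs_pre_hook(module: torch.nn.Module, args):
    # Runs before every forward(); the returned tuple replaces the inputs.
    return tuple(
        to_fake_quantized(a) if isinstance(a, torch.Tensor) else a for a in args
    )

linear = torch.nn.Linear(64, 64)
handle = linear.register_forward_pre_hook(_fake_quantize_inputs_pre_hook)
out = linear(torch.randn(8, 64))  # inputs get wrapped/fake-quantized first
```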
Work items

- [x] Add AffineFakeQuantizedTensor
- [x] Add support for int4 weight only fake quantize
- [x] Add support for int8 dynamic activations + int4 weight fake quantize (8da4w)
- [x] Add prepare and convert path to int4 QAT quantizer
- [x] Add prepare and convert path to 8da4w QAT quantizer
- [x] Support enabling and disabling fake quant dynamically (see the sketch after this list)
- [x] Support `__repr__` in AffineFakeQuantizedTensor
- [x] Fix backward pass for int4 weight only
- [x] Fix backward pass for int8 dynamic activations + int4 weight
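For the "Support enabling and disabling fake quant dynamically" item, the intent is roughly the toggle below; the flag name `fake_quant_enabled` and the helper are hypothetical, and the real switch lives on the QAT quantizer and subclass rather than in user code.

```python
import torch

def set_fake_quant_enabled(model: torch.nn.Module, enabled: bool) -> None:
    # Walk the model and flip a (hypothetical) flag on every fake-quantized
    # weight, so the same model can train with or without fake quantization.
    for module in model.modules():
        weight = getattr(module, "weight", None)
        if hasattr(weight, "fake_quant_enabled"):
            weight.fake_quant_enabled = enabled
```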