Refactor QAT to use tensor subclasses #585

Merged 1 commit into main on Aug 20, 2024

Conversation

@andrewor14 (Contributor) commented Aug 1, 2024

Overview

This commit refactors QAT to use tensor subclasses. This is motivated by the general move towards tensor subclasses in torchao for better composability with other subclasses like `DTensor`. To achieve this, we introduce `AffineFakeQuantizedTensor`, which is analogous to `AffineQuantizedTensor` but applies fake quantization instead and requires gradient updates.
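
For context, fake quantization means quantizing and immediately dequantizing: the tensor stays in floating point but only takes values representable on the integer grid, so training can observe the quantization error. A minimal sketch of the affine scheme (the helper name and the int8 range below are illustrative, not the actual torchao API):

```python
import torch

def affine_fake_quantize(x, scale, zero_point, qmin, qmax):
    # Quantize: snap to the integer grid and clamp to the representable range.
    q = torch.clamp(torch.round(x / scale) + zero_point, qmin, qmax)
    # Dequantize: map back to float so downstream ops still see a float tensor.
    return (q - zero_point) * scale

w = torch.randn(32, 32)
scale = w.abs().max() / 127                      # symmetric int8-style scale
w_fq = affine_fake_quantize(w, scale, 0, -128, 127)
```

In QAT, the non-differentiable rounding step is typically paired with a straight-through estimator so gradients can still flow through it.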

How training works

`AffineFakeQuantizedTensor` wraps the original weight or input activation tensor and applies fake quantization dynamically, only when the linear function is called. Gradients only flow to the outer tensor (`AffineFakeQuantizedTensor`) and never to the inner tensor. For weights, the outer tensor is also a `torch.nn.Parameter`, and gradient updates received by the outer tensor are then passed to the inner tensor through ops like `aten.add_` and `aten.mul_`.
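
As a rough, self-contained sketch of the wrapping idea (the class name and int8 grid are illustrative, not the actual `AffineFakeQuantizedTensor` implementation, which additionally handles autograd and routes in-place updates such as `aten.add_` to the inner tensor), a subclass can intercept `F.linear` and fake-quantize its inner tensor only at that point:

```python
import torch
import torch.nn.functional as F

class FakeQuantizedWeight(torch.Tensor):
    """Illustrative wrapper: holds an inner float weight and fake-quantizes it
    only when F.linear consumes it. Not the actual torchao class."""

    @staticmethod
    def __new__(cls, inner, scale):
        # Build the outer tensor with the same shape/dtype as the inner one.
        return torch.Tensor._make_subclass(cls, inner)

    def __init__(self, inner, scale):
        self.inner = inner   # the wrapped plain float tensor
        self.scale = scale   # quantization scale (a Python float here)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is F.linear:
            x, w = args[0], args[1]
            # Fake-quantize lazily, only when the linear op actually runs
            # (symmetric int8-style grid, purely for illustration).
            q = torch.clamp(torch.round(w.inner / w.scale), -128, 127)
            return F.linear(x, q * w.scale, *args[2:])
        return super().__torch_function__(func, types, args, kwargs)

x = torch.randn(2, 8)
w = torch.randn(4, 8)
w_fq = FakeQuantizedWeight(w, float(w.abs().max()) / 127)
y = F.linear(x, w_fq)   # the subclass intercepts the call and fake-quantizes w
```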

Input activation

An important difference between the PTQ and QAT flows is how the input activation subclasses are inserted. For QAT, we use the `nn.Module` `forward_pre_hook` instead of relying on another subclass, `LinearActivationQuantizedTensor`, that wraps the weight subclass. The problem with the old PTQ approach is that it creates subclasses under `__torch_dispatch__`, which runs below autograd, so the created subclasses cannot carry gradients and it was difficult to get gradients to flow correctly in such cases. It is also not very intuitive, because quantizing the input activation has to go through the weights. In the new QAT approach, we instead register a `forward_pre_hook` that wraps the input activations before each call to `forward`. This approach is also motivated by how [DTensor wraps its subclasses](https://github.com/pytorch/pytorch/blob/844103197d3e8cf6b4b59176e473365113f4f962/torch/distributed/tensor/parallel/style.py#L521).
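
A rough sketch of the hook-based insertion (for illustration the hook fake-quantizes the input directly; the actual flow instead wraps it in `AffineFakeQuantizedTensor`):

```python
import torch

def fake_quantize_input(module, args):
    # forward_pre_hook signature: (module, positional_args) -> modified args.
    x, *rest = args
    scale = x.detach().abs().amax().clamp_min(1e-12) / 127
    x_fq = torch.clamp(torch.round(x / scale), -128, 127) * scale
    return (x_fq, *rest)

linear = torch.nn.Linear(16, 16)
handle = linear.register_forward_pre_hook(fake_quantize_input)
y = linear(torch.randn(4, 16))   # the hook transforms the input before forward
handle.remove()                  # removing the hook disables the fake quant
```

Because the hook runs at the module boundary, above `__torch_dispatch__`, the tensors it creates can participate in autograd normally.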

Work items

- [x] Add AffineFakeQuantizedTensor
- [x] Add support for int4 weight only fake quantize
- [x] Add support for int8 dynamic activations + int4 weight fake quantize (8da4w)
- [x] Add prepare and convert path to int4 QAT quantizer
- [x] Add prepare and convert path to 8da4w QAT quantizer
- [x] Support enabling and disabling fake quant dynamically
- [x] Support `__repr__` in AffineFakeQuantizedTensor
- [x] Fix backward pass for int4 weight only
- [x] Fix backward pass for int8 dynamic activations + int4 weight

pytorch-bot (bot) commented Aug 1, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/585


✅ No Failures

As of commit a4148c3 with merge base b523f9f:
💚 Looks good so far! There are no failures yet. 💚


@andrewor14 marked this pull request as draft on August 1, 2024 14:55
@facebook-github-bot added the CLA Signed label on Aug 1, 2024
@andrewor14 changed the title from "[draft] temp PR, do not review" to "[draft] Refactor QAT to use tensor subclasses" on Aug 2, 2024
@andrewor14 force-pushed the qat-subclass branch 2 times, most recently from b643642 to 353a503 on August 10, 2024 00:09
@andrewor14 changed the title from "[draft] Refactor QAT to use tensor subclasses" to "Refactor QAT to use tensor subclasses" on Aug 10, 2024
@andrewor14 marked this pull request as ready for review on August 10, 2024 00:13
@andrewor14 force-pushed the qat-subclass branch 2 times, most recently from 37651b6 to 2f398db on August 12, 2024 17:15
@andrewor14 requested a review from jerryzh168 on August 12, 2024 23:46
@andrewor14 force-pushed the qat-subclass branch 2 times, most recently from 17ae0c2 to 78ca8d0 on August 13, 2024 00:00
@andrewor14 requested a review from bdhirsh on August 13, 2024 00:02

@bdhirsh (Contributor) left a comment:

left a few questions but I'll defer to @jerryzh168 for the review

@andrewor14 force-pushed the qat-subclass branch 3 times, most recently from d0dd364 to 9932050 on August 16, 2024 02:52

@jerryzh168 (Contributor) left a comment:

looks good, thanks!

@jerryzh168 merged commit 2c8e3f3 into main on Aug 20, 2024
16 checks passed
@jerryzh168 deleted the qat-subclass branch on August 20, 2024 18:33

@andrewor14 (Contributor, Author) commented:

For reference, just added some end-to-end benchmarks and evaluation results for tensor subclass QAT here: pytorch/torchtune#1330. Looks pretty good so far

andrewor14 added a commit that referenced this pull request Aug 27, 2024
Summary: Recent refactor into tensor subclasses (#585) broke
some existing use cases that rely on DDP and FSDP1, since the
new flow only supports FSDP2 currently. This commit adds back
the module swap API for now to provide a backdoor for these
use cases. In the long term, we still plan to deprecate the
module swap flow.

Test Plan:
python test/quantization/test_qat.py -k test_qat_8da4w_quantizer_module_swap
python test/quantization/test_qat.py -k test_qat_4w_quantizer_module_swap

Reviewers: jerryzh168, msaroufim

Subscribers: jerryzh168, msaroufim