Initial support for 8da4w QAT #138
Conversation
Force-pushed from 5b1d404 to 044280a
nit: you don't need the extra underscore in _prototype, but otherwise looks good as a way to get started :)
Force-pushed from 653a07a to d5cd97f
    return (qmin, qmax)


def replace_linear_8da4w_qat(
This seems to be the same as replace_linear_8da4w; maybe we want to abstract out a helper function, to be less error-prone.
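For context, a rough sketch of what such a shared helper might look like, with the concrete linear class to swap in passed as an argument; the signature, the linear_class argument name, and the padding check below are assumptions for illustration, not the code in this PR:

import torch

def _replace_linear_8da4w(
    module: torch.nn.Module,
    groupsize: int,
    padding_allowed: bool,
    precision: torch.dtype,
    scales_precision: torch.dtype,
    linear_class: type,  # PTQ or QAT linear module class to swap in (assumed argument)
):
    """Recursively swap nn.Linear children of `module` for `linear_class` instances."""
    for name, child in module.named_children():
        if isinstance(child, torch.nn.Linear) and (
            child.in_features % groupsize == 0 or padding_allowed
        ):
            setattr(
                module,
                name,
                linear_class(
                    child.in_features,
                    child.out_features,
                    groupsize=groupsize,
                    precision=precision,
                    scales_precision=scales_precision,
                ),
            )
        else:
            # Recurse into submodules, forwarding linear_class so nested linears
            # are also swapped
            _replace_linear_8da4w(
                child,
                groupsize,
                padding_allowed,
                precision,
                scales_precision,
                linear_class,
            )

With this shape, replace_linear_8da4w (PTQ) and replace_linear_8da4w_qat (QAT) could both be thin wrappers that call the helper with their respective linear class.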
looks good, please make sure CI passes before landing
Force-pushed from ae8bd77 to 83dd03a
Summary: This commit adds support for QAT, where linear layers are fake quantized with int8 per token dynamic activations (8da) and int4 grouped per channel weights (4w). This initial implementation uses the same module swap approach as 8da4w PTQ for simplicity and code reuse. In the future, we may wish to consider migrating both flows to use tensor subclasses for better composability with other PyTorch features.

Test Plan:
python test/quantization/test_qat.py -k test_fake_quantize_per_channel_group
python test/quantization/test_qat.py -k test_fake_quantize_per_token
python test/quantization/test_qat.py -k test_qat_8da4w_linear
python test/quantization/test_qat.py -k test_qat_8da4w_quantizer

Reviewers: jerryzh168, cpuhrsch, HDCharles

Subscribers: jerryzh168, cpuhrsch, HDCharles, supriyar

Tasks: #86
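To illustrate the weight side of the scheme described above, here is a minimal sketch of int4 grouped per channel fake quantization, assuming symmetric quantization with one scale per group; this is a simplified illustration rather than the exact math in this PR, and real QAT would additionally use a straight-through estimator so gradients can flow through the rounding:

import torch

def fake_quantize_4w_grouped(weight: torch.Tensor, groupsize: int) -> torch.Tensor:
    # int4 range for symmetric quantization (assumed)
    qmin, qmax = -8, 7
    out_features, in_features = weight.shape
    assert in_features % groupsize == 0
    # one scale per (output channel, group) chunk of the weight
    grouped = weight.reshape(-1, groupsize)
    scales = grouped.abs().amax(dim=1, keepdim=True).clamp(min=1e-9) / qmax
    # quantize, clamp, then immediately dequantize ("fake" quantization)
    q = torch.clamp(torch.round(grouped / scales), qmin, qmax)
    return (q * scales).reshape(out_features, in_features)

The int8 per token dynamic activation side is analogous, except that the scales are computed per token at runtime rather than per weight group.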
@@ -1144,14 +1144,30 @@ def replace_linear_8da4w(
                 ),
             )
         else:
-            replace_linear_8da4w(
+            _replace_linear_8da4w(
                 child,
                 groupsize,
                 padding_allowed,
                 precision,
                 scales_precision,
             )
Does it miss the linear_class? cc @jerryzh168 @andrewor14
oh right
I can fix it in fbcode first I guess
oh my bad, let me submit a PR