Add support for int4 weight-only QAT #383
Conversation
🔗 Helpful links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/383
✅ No failures as of commit 2f91cd3 with merge base aef7e09. This comment was automatically generated by Dr. CI and updates every 15 minutes.
Cool! Do you think we could also merge a table showing accuracy and speedup in the QAT folder? I'd like to use that in a new training section this week.
Yeah, setting up the experiments now, hoping to get those numbers soon.
Force-pushed from 6d9a205 to 2ac2250.
What happened with this PR? Are we still planning on merging it?
Yes, Jerry and I discussed offline. I plan to refactor the existing quant primitives to make a common
Force-pushed from d56774b to 6cf40b7.
```python
        else:
            _convert_qat_linear_4w(child)


class Int4WeightOnlyQATLinear(torch.nn.Linear):
```
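For context, the recursive module-swap pattern that `_convert_qat_linear_4w` implements in the diff above looks roughly like the sketch below. This is not the PR's actual code: `FakeQuantLinear` stands in for `Int4WeightOnlyQATLinear`, and the replacement step is simplified to a plain `nn.Linear`.

```python
import torch

class FakeQuantLinear(torch.nn.Linear):
    """Stand-in for Int4WeightOnlyQATLinear in this sketch."""
    pass

def convert_linears(module: torch.nn.Module) -> None:
    # Walk the module tree: swap each fake-quant linear for its lowered
    # replacement, and recurse into everything else -- the same shape as
    # _convert_qat_linear_4w above.
    for name, child in module.named_children():
        if isinstance(child, FakeQuantLinear):
            new = torch.nn.Linear(
                child.in_features, child.out_features, bias=child.bias is not None
            )
            new.weight, new.bias = child.weight, child.bias
            setattr(module, name, new)
        else:
            convert_linears(child)
```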
Are you going to follow up to refactor this to use the `quantize_` API?
Yes, that will come later when I refactor QAT to use tensor subclasses
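For reference, the post-training path through the `quantize_` API that this exchange refers to looks roughly like the sketch below. The import names follow torchao's public quantization API; how the QAT flow itself will eventually plug into `quantize_` after the tensor-subclass refactor is an assumption here, not something this PR implements.

```python
import torch
from torchao.quantization import quantize_, int4_weight_only

# The tinygemm int4 weight-only kernel expects a bf16 model on CUDA.
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to(torch.bfloat16).cuda()

# Post-training int4 weight-only quantization; the QAT "convert" step in this
# PR would route through something like this once it uses tensor subclasses.
quantize_(model, int4_weight_only(group_size=128))
```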
```python
    )

    Args:
        input (torch.Tensor): original float32, float16 or bfloat16 Tensor
```
Please dedup the args; also, I think `Args` should come before `Returns`.
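In other words, the docstring layout being asked for is roughly the following (a sketch only; the argument list is abridged and the exact signature of the new op may differ):

```python
def fake_quantize_affine_cachemask(input, block_size, scale, zero_point, quant_dtype,
                                   quant_min=None, quant_max=None):
    """
    General fake quantize op for quantization-aware training (QAT)
    that also returns a mask of values within the quantized range.

    Args:
        input (torch.Tensor): original float32, float16 or bfloat16 Tensor
        block_size (Tuple[int, ...]): granularity of quantization
        scale (torch.Tensor): quantization scale
        zero_point (torch.Tensor): quantization zero point
        quant_dtype (torch.dtype): dtype used to simulate quantization
        quant_min (Optional[int]): minimum quantized value
        quant_max (Optional[int]): maximum quantized value

    Returns:
        A 2-tuple of (fake-quantized tensor, mask of values that fell
        within [quant_min, quant_max]).
    """
    ...
```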
Add cachemask variant for fake_quantize_affine

Summary: In QAT, we often wish to filter out the gradients corresponding to values outside the expected quantization range, for example:

```
q = _quantize_affine_no_dtype_cast(...)
dq = _dequantize_affine_no_dtype_check(...)
mask = torch.logical_and((q >= quant_min), (q <= quant_max))

grad = grad * mask
```

The existing `fake_quantize_affine` returns the dequantized values only, so callers do not have access to this mask. This commit adds a variant of this op that returns both the dequantized values and the mask, similar to `fake_quantize_per_tensor_affine_cachemask` in core.

Test Plan:
python test/quantization/test_quant_primitives.py -k test_fake_quantize_affine_cachemask
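A minimal sketch of how the returned mask gets used on the backward pass (the gradient filtering described in the summary). The call into `fake_quantize_affine_cachemask` follows the commit description; the exact argument order and the `torchao.quantization.quant_primitives` import path are assumptions.

```python
import torch
from torchao.quantization.quant_primitives import fake_quantize_affine_cachemask

class _FakeQuantizeSTE(torch.autograd.Function):
    """Fake quantize with a straight-through estimator: gradients for
    values that fell outside [quant_min, quant_max] are zeroed out."""

    @staticmethod
    def forward(ctx, x, block_size, scale, zero_point, quant_min, quant_max):
        # Returns both the fake-quantized tensor and a boolean mask of
        # in-range values, so backward does not have to re-quantize.
        dq, mask = fake_quantize_affine_cachemask(
            x, block_size, scale, zero_point, torch.int32, quant_min, quant_max
        )
        ctx.save_for_backward(mask)
        return dq

    @staticmethod
    def backward(ctx, grad_output):
        (mask,) = ctx.saved_tensors
        # grad = grad * mask, as in the snippet above
        return grad_output * mask, None, None, None, None, None
```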
Add support for int4 weight-only QAT

Summary: This commit adds support for int4 weight-only QAT, which simulates the numerics of the existing Int4WeightOnlyQuantizer. The main motivation for this is to provide an end-to-end path for running QAT and lowering to the efficient int4 tinygemm cuda kernel. To enable this, we have to add new fake quantization primitives to match the numerics of the tinygemm kernel, and this required refactoring existing quant primitives to skip dtype casting.

Test Plan:
python test/quantization/test_qat.py -k test_qat_4w_linear

Reviewers: jerryzh168, msaroufim

Subscribers: jerryzh168, msaroufim, HDCharles, supriyar
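End to end, the flow this commit enables looks roughly like the sketch below. The `prototype.qat` import path, the `prepare`/`convert` two-step API, and the `groupsize` argument reflect the torchao QAT API at the time of this PR (the same groupsize the torchtune commands further down use), but treat the exact paths as assumptions rather than a definitive recipe.

```python
import torch
from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer

# tinygemm targets bf16 weights on CUDA
model = torch.nn.Sequential(torch.nn.Linear(256, 256)).to(torch.bfloat16).cuda()

# 1. Prepare: swap nn.Linear for Int4WeightOnlyQATLinear, which fake-quantizes
#    weights with tinygemm-matching numerics during fine-tuning.
qat_quantizer = Int4WeightOnlyQATQuantizer(groupsize=128)
model = qat_quantizer.prepare(model)

# 2. Fine-tune as usual; numerics are simulated, no int4 kernels run yet.
#    ... training loop goes here ...

# 3. Convert: replace the QAT linears with int4 weight-only linears that
#    dispatch to the tinygemm CUDA kernel at inference time.
model = qat_quantizer.convert(model)
```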
looks good to me, thanks!
Summary: This commit adds an int4 weight-only QAT flow targeting the efficient tinygemm kernel. During fine-tuning we only simulate the kernel's numerics in bf16; the kernel itself is only called after the model is quantized. For more detail, see pytorch/ao#383.

Test Plan:

Fine-tune QAT command:
```
tune run --nnodes 1 --nproc_per_node 6 --rdzv_endpoint="localhost:8900" qat_distributed --config llama3/8B_qat_full \
    batch_size=8 \
    fake_quant_after_n_steps=1000 \
    checkpointer.output_dir="/tmp/qat_results" \
    quantizer._component_=torchtune.training.quantization.Int4WeightOnlyQATQuantizer \
    quantizer.groupsize=128
```

Quantize command:
```
tune run quantize --config recipes/configs/quantization.yaml \
    model._component_=torchtune.models.llama3.llama3_8b \
    quantizer._component_=torchtune.training.quantization.Int4WeightOnlyQuantizer \
    quantizer.groupsize=128 \
    checkpointer._component_=torchtune.training.FullModelMetaCheckpointer \
    checkpointer.checkpoint_dir="/tmp/qat_results" \
    checkpointer.output_dir="/tmp/qat_results" \
    checkpointer.checkpoint_files=[meta_model_2.pt] \
    checkpointer.model_type=LLAMA3
```

Eval command:
```
tune run eleuther_eval --config eleuther_evaluation \
    tasks="[hellaswag, wikitext]" \
    model._component_=torchtune.models.llama3.llama3_8b \
    quantizer._component_=torchtune.training.quantization.Int4WeightOnlyQuantizer \
    quantizer.groupsize=128 \
    checkpointer._component_=torchtune.training.FullModelTorchTuneCheckpointer \
    checkpointer.checkpoint_dir="/tmp/qat_results" \
    checkpointer.output_dir="/tmp/qat_results" \
    checkpointer.checkpoint_files=[meta_model_2-4w.pt] \
    checkpointer.model_type=LLAMA3 \
    tokenizer._component_=torchtune.models.llama3.llama3_tokenizer \
    tokenizer.path=/tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model
```

Evaluation results:
```
| Tasks |Version|Filter|n-shot|Metric|Value | |Stderr|
|--------------|------:|------|-----:|------|-----:|---|-----:|
|truthfulqa_mc2| 2|none | 0|acc |0.4806|± |0.0167|

| Tasks |Version|Filter|n-shot|Metric|Value | |Stderr|
|--------------|------:|------|-----:|------|-----:|---|-----:|
|truthfulqa_mc2| 2|none | 0|acc |0.4914|± |0.0164|

| Tasks |Version|Filter|n-shot|Metric|Value | |Stderr|
|--------------|------:|------|-----:|------|-----:|---|-----:|
|truthfulqa_mc2| 2|none | 0|acc |0.4872|± |0.0167|
```