
Add support for int4 weight-only QAT #383

Merged
andrewor14 merged 2 commits into main from 4w_qat on Jul 17, 2024

Conversation

andrewor14
Contributor

Summary: This commit adds support for int4 weight-only QAT, which simulates the numerics of the existing
Int4WeightOnlyQuantizer. The main motivation is to provide an end-to-end path for running QAT and then lowering to the efficient int4 tinygemm CUDA kernel. To enable this, we add new fake quantization primitives that match the numerics of the tinygemm kernel; this required refactoring the existing quant primitives to skip dtype casting.

Test Plan:
python test/quantization/test_qat.py -k test_qat_4w_linear

Reviewers: jerryzh168, msaroufim

Subscribers: jerryzh168, msaroufim, HDCharles, supriyar
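
For context, a minimal sketch of the intended prepare/fine-tune/convert flow is shown below. The import path, quantizer name (`Int4WeightOnlyQATQuantizer`), the `groupsize` argument, and the `prepare`/`convert` methods are assumptions based on this PR's prototype QAT module and may differ across torchao versions; treat this as an illustration rather than the definitive API.

```python
import torch
from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer  # assumed path

# toy bf16 model; the int4 tinygemm kernel targets CUDA
model = torch.nn.Sequential(torch.nn.Linear(256, 256)).cuda().to(torch.bfloat16)

# prepare: swaps nn.Linear for Int4WeightOnlyQATLinear, which fake-quantizes
# the weights with tinygemm-compatible numerics during fine-tuning
qat_quantizer = Int4WeightOnlyQATQuantizer(groupsize=128)
model = qat_quantizer.prepare(model)

# ... fine-tune `model` as usual ...

# convert: swaps in real int4 weight-only linears that call the tinygemm kernel
model = qat_quantizer.convert(model)
```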

@andrewor14 andrewor14 requested a review from jerryzh168 June 15, 2024 23:35

pytorch-bot bot commented Jun 15, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/383

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 2f91cd3 with merge base aef7e09:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jun 15, 2024
@msaroufim
Member

Cool! Do you think we could also merge a table showing accuracy and speedup in the QAT folder? I'd like to use that in a new training section this week

@andrewor14
Contributor Author

> Cool! Do you think we could also merge a table showing accuracy and speedup in the QAT folder? I'd like to use that in a new training section this week

Yeah setting up the experiments now, hoping to get those numbers soon

@andrewor14 andrewor14 force-pushed the 4w_qat branch 6 times, most recently from 6d9a205 to 2ac2250 Compare June 16, 2024 22:24
@msaroufim
Member

What happened with this PR? Are we still planning on merging it?

@andrewor14
Contributor Author

> What happened with this PR? Are we still planning on merging it?

Yes, Jerry and I discussed offline. I plan to refactor the existing quant primitives to make a common fake_quantize_affine first, then update this PR to use that

@andrewor14 andrewor14 force-pushed the 4w_qat branch 4 times, most recently from d56774b to 6cf40b7 Compare July 12, 2024 20:48
@andrewor14 andrewor14 requested a review from jerryzh168 July 12, 2024 20:48
    else:
        _convert_qat_linear_4w(child)

class Int4WeightOnlyQATLinear(torch.nn.Linear):
Contributor


Are you going to follow up and refactor this to use the quantize_ API?

Contributor Author


Yes, that will come later when I refactor QAT to use tensor subclasses

)

Args:
    input (torch.Tensor): original float32, float16 or bfloat16 Tensor
Contributor


Please dedup the args; also, Args should come before Returns, I think.

Summary: In QAT, we often wish to filter out the gradients
corresponding to values outside the expected quantization
range, for example:

```
q = _quantize_affine_no_dtype_cast(...)
dq = _dequantize_affine_no_dtype_check(...)
mask = torch.logical_and((q >= quant_min), (q <= quant_max))

grad = grad * mask
```

The existing `fake_quantize_affine` returns the dequantized
values only, so callers do not have access to this mask.
This commit adds a variant of this op that returns both
the dequantized values and the mask, similar to
`fake_quantize_per_tensor_affine_cachemask` in core.

Test Plan:
python test/quantization/test_quant_primitives.py -k test_fake_quantize_affine_cachemask
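
To illustrate how the returned mask from the cachemask commit above can be consumed, here is a minimal, self-contained sketch of a mask-based straight-through estimator. It hand-rolls a per-tensor fake quant instead of calling the new torchao op, since the op's exact signature is not shown here and should be treated as an assumption.

```python
import torch

class _FakeQuantWithMask(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, scale, zero_point, quant_min, quant_max):
        # quantize without casting to an integer dtype
        q = torch.round(x / scale) + zero_point
        # remember which values landed inside the representable range
        mask = torch.logical_and(q >= quant_min, q <= quant_max)
        # clamp and dequantize back to floating point
        dq = (torch.clamp(q, quant_min, quant_max) - zero_point) * scale
        ctx.save_for_backward(mask)
        return dq

    @staticmethod
    def backward(ctx, grad_output):
        (mask,) = ctx.saved_tensors
        # zero out gradients for values outside [quant_min, quant_max]
        return grad_output * mask, None, None, None, None

x = torch.randn(4, 8, requires_grad=True)
y = _FakeQuantWithMask.apply(x, torch.tensor(0.1), 0, -8, 7)  # int4 range
y.sum().backward()
```
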
Summary: This commit adds support for int4 weight-only QAT,
which simulates the numerics of the existing
Int4WeightOnlyQuantizer. The main motivation for this is to
provide an end-to-end path for running QAT and lowering to
the efficient int4 tinygemm CUDA kernel. To enable this,
we add new fake quantization primitives that match
the numerics of the tinygemm kernel; this required
refactoring the existing quant primitives to skip dtype casting.

Test Plan:
python test/quantization/test_qat.py -k test_qat_4w_linear

Reviewers: jerryzh168, msaroufim

Subscribers: jerryzh168, msaroufim, HDCharles, supriyar
Contributor

@jerryzh168 jerryzh168 left a comment


looks good to me, thanks!

@andrewor14 andrewor14 merged commit f8789f7 into main Jul 17, 2024
13 checks passed
dbyoung18 pushed a commit to dbyoung18/ao that referenced this pull request Jul 31, 2024
andrewor14 added a commit to andrewor14/torchtune that referenced this pull request Sep 13, 2024
Summary: This commit adds an int4 weight-only QAT flow targeting
the efficient tinygemm kernel. During fine-tuning we only
simulate the kernel's numerics in bf16; the actual kernel is
only called after the model is quantized. For more
detail, see pytorch/ao#383.

Test Plan:

Fine-tune QAT command:
```
tune run --nnodes 1 --nproc_per_node 6 --rdzv_endpoint="localhost:8900" qat_distributed --config llama3/8B_qat_full \
    batch_size=8 \
    fake_quant_after_n_steps=1000 \
    checkpointer.output_dir="/tmp/qat_results" \
    quantizer._component_=torchtune.training.quantization.Int4WeightOnlyQATQuantizer \
    quantizer.groupsize=128
```

Quantize command:
```
tune run quantize --config recipes/configs/quantization.yaml \
    model._component_=torchtune.models.llama3.llama3_8b \
    quantizer._component_=torchtune.training.quantization.Int4WeightOnlyQuantizer \
    quantizer.groupsize=128 \
    checkpointer._component_=torchtune.training.FullModelMetaCheckpointer \
    checkpointer.checkpoint_dir="/tmp/qat_results" \
    checkpointer.output_dir="/tmp/qat_results" \
    checkpointer.checkpoint_files=[meta_model_2.pt] \
    checkpointer.model_type=LLAMA3
```

Eval command:
```
tune run eleuther_eval --config eleuther_evaluation \
    tasks="[hellaswag, wikitext]" \
    model._component_=torchtune.models.llama3.llama3_8b \
    quantizer._component_=torchtune.training.quantization.Int4WeightOnlyQuantizer \
    quantizer.groupsize=128 \
    checkpointer._component_=torchtune.training.FullModelTorchTuneCheckpointer \
    checkpointer.checkpoint_dir="/tmp/qat_results" \
    checkpointer.output_dir="/tmp/qat_results" \
    checkpointer.checkpoint_files=[meta_model_2-4w.pt] \
    checkpointer.model_type=LLAMA3 \
    tokenizer._component_=torchtune.models.llama3.llama3_tokenizer \
    tokenizer.path=/tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model
```

Evaluation results:
```
|    Tasks     |Version|Filter|n-shot|Metric|Value |   |Stderr|
|--------------|------:|------|-----:|------|-----:|---|-----:|
|truthfulqa_mc2|      2|none  |     0|acc   |0.4806|±  |0.0167|

|    Tasks     |Version|Filter|n-shot|Metric|Value |   |Stderr|
|--------------|------:|------|-----:|------|-----:|---|-----:|
|truthfulqa_mc2|      2|none  |     0|acc   |0.4914|±  |0.0164|

|    Tasks     |Version|Filter|n-shot|Metric|Value |   |Stderr|
|--------------|------:|------|-----:|------|-----:|---|-----:|
|truthfulqa_mc2|      2|none  |     0|acc   |0.4872|±  |0.0167|
```
yanbing-j pushed a commit to yanbing-j/ao that referenced this pull request Dec 9, 2024