Add AWQ support #530
Comments
If this is something that we want to add, I can take a stab at integrating it.
Thanks @vayuda! I feel it might be useful to integrate at the quant_primitive ops / qmodule level at least, similar to #533 (comment), while keeping their special quantization logic (our current … we are still discussing how we can support it), but please feel free to let us know your thoughts on this as well.
Is there already an up-to-date example of a quantization workflow that uses calibration data to tune certain parameters (the scale factor, in the case of AWQ)?
I think you could start with https://github.com/pytorch/ao/blob/main/tutorials/calibration_flow/static_quant.py. Also, I was talking to @HDCharles about this, and it seems sufficient to implement AWQ just for linear layers and ignore the complicated case that I linked in the AWQ code for now.
BTW, in quite a few popular LLMs, such …
Yeah, I talked about this in the … we also have some evidence that … (see ao/torchao/quantization/smoothquant.py, lines 150 to 152 in afde175).
AWQ seems popular: a Hugging Face model search for AWQ returns 3000 results (https://huggingface.co/models?sort=trending&search=AWQ), similar to GPTQ. Maybe we can add it to torchao as well.
Overview
At a high level, AWQ scales the weight by some power of the average per-channel activation magnitude, Sx^alpha, as described in the paper, where Sx is the per-channel average magnitude of the activation.
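A minimal sketch of why this scaling is safe, assuming a plain torch.nn.Linear and a hand-picked alpha (the variable names are illustrative, not torchao API): the per-input-channel scale s = Sx^alpha is multiplied into the weight columns and divided out of the activation, so the full-precision output is unchanged and only the quantization error of the scaled weight is affected.

```python
import torch

torch.manual_seed(0)

# Toy linear layer: y = x @ W.T, where W has shape (out_features, in_features)
linear = torch.nn.Linear(in_features=8, out_features=4, bias=False)
x = torch.randn(16, 8)  # calibration activations: (num_tokens, in_features)

# Sx: average activation magnitude per input channel
s_x = x.abs().mean(dim=0)   # shape (in_features,)
alpha = 0.5                 # illustrative value; AWQ searches over alpha
s = s_x.pow(alpha)          # equalization scale, shape (in_features,)

# Scale the weight columns up by s and the activation down by s
w_scaled = linear.weight * s  # broadcasts over the in_features dimension
x_scaled = x / s

y_ref = linear(x)
y_eq = x_scaled @ w_scaled.T

# Without quantization the two computations agree up to float rounding
print(torch.allclose(y_ref, y_eq, atol=1e-5))
```

In a real flow only w_scaled would be quantized; channels with large Sx get larger weight values, so they lose less relative precision.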
Implementation in the original AWQ repo
The main steps are finding the scale and applying the scale to the weights.
Note: In the original AWQ implementation, the logic for finding the scale is a bit complicated, but that is mainly to deal with the separate qkv modules. We could start by implementing AWQ for simple linear layers and worry about the more complicated model structures later.
For applying the scales, the original implementation requires manually specifying the prev_module. We could do the same, or we could symbolically trace the model (preserving all call_module nodes) to figure out the relationships between modules programmatically.
How to implement it in torchao
First, I think we can focus on implementing AWQ for linear modules only: we can collect the activation stats using observers, search for the alpha parameter based on the output of the quantized linear module, and reuse the existing quant_primitives for affine quantization in torchao.
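A minimal sketch of the alpha search, assuming a simple group-wise asymmetric int4 fake quantizer written here as a stand-in for torchao's affine quant_primitives. fake_quantize_int4 and search_alpha are hypothetical helpers, and the 20-point grid over [0, 1) is an illustrative choice. Each candidate alpha gives a scale s = Sx^alpha; the scaled weight is fake-quantized, and the alpha whose quantized output is closest (in MSE) to the full-precision output is kept.

```python
import torch


def fake_quantize_int4(w: torch.Tensor, group_size: int = 32) -> torch.Tensor:
    """Hypothetical stand-in for torchao's affine quant primitives:
    asymmetric min/max int4 quantize + dequantize per group of input channels."""
    out_features, in_features = w.shape
    wg = w.reshape(out_features, in_features // group_size, group_size)
    w_min = wg.amin(dim=-1, keepdim=True)
    w_max = wg.amax(dim=-1, keepdim=True)
    scale = (w_max - w_min).clamp(min=1e-6) / 15.0  # 4 bits -> levels 0..15
    zero_point = torch.round(-w_min / scale)
    q = torch.clamp(torch.round(wg / scale) + zero_point, 0, 15)
    return ((q - zero_point) * scale).reshape(out_features, in_features)


@torch.no_grad()
def search_alpha(weight: torch.Tensor, x: torch.Tensor, n_grid: int = 20, group_size: int = 32):
    """Hypothetical helper: grid-search alpha minimizing the quantized output MSE."""
    y_ref = x @ weight.T
    s_x = x.abs().mean(dim=0)  # per-input-channel activation magnitude
    best_alpha, best_err, best_scale = None, float("inf"), None
    for i in range(n_grid):
        alpha = i / n_grid
        s = s_x.pow(alpha).clamp(min=1e-4)
        w_q = fake_quantize_int4(weight * s, group_size)  # quantize the scaled weight
        y_q = (x / s) @ w_q.T                             # undo the scale on the input
        err = (y_ref - y_q).pow(2).mean().item()
        if err < best_err:
            best_alpha, best_err, best_scale = alpha, err, s
    return best_alpha, best_err, best_scale


# Usage: random weight, activations with a few large-magnitude ("salient") channels
torch.manual_seed(0)
weight = torch.randn(64, 128)
x = torch.randn(256, 128)
x[:, :4] *= 20.0
alpha, err, s = search_alpha(weight, x)
print(f"best alpha={alpha:.2f}, mse={err:.4f}")
```

Since alpha = 0 gives s = 1 (plain quantization), the searched result is never worse than the unscaled baseline on the calibration data.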
Step 1. Collecting Observer Stats
In terms of collecting activation stats, we could follow what we did in ao/tutorials/calibration_flow/static_quant.py (lines 19 to 35 in afde175), i.e. an ObservedLinear with an observer (or just a logger) to log the activation(s). We can then create a function insert_awq_observers_ similar to ao/tutorials/calibration_flow/static_quant.py (line 37 in afde175).
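A minimal sketch of Step 1, assuming we only need the running mean of |x| per input channel rather than a full observer. AWQObservedLinear is a hypothetical class (not the ObservedLinear from static_quant.py), and this insert_awq_observers_ is one possible shape for the proposed function, not an existing torchao API.

```python
import torch
import torch.nn as nn


class AWQObservedLinear(nn.Linear):
    """Hypothetical nn.Linear that logs the mean |x| per input channel."""

    def __init__(self, in_features, out_features, bias=True):
        super().__init__(in_features, out_features, bias=bias)
        self.register_buffer("act_abs_sum", torch.zeros(in_features))
        self.register_buffer("num_tokens", torch.zeros((), dtype=torch.long))

    def forward(self, x):
        with torch.no_grad():
            flat = x.reshape(-1, self.in_features)  # collapse batch/sequence dims
            self.act_abs_sum += flat.abs().sum(dim=0)
            self.num_tokens += flat.shape[0]
        return super().forward(x)

    def average_activation_magnitude(self):
        return self.act_abs_sum / self.num_tokens.clamp(min=1)

    @classmethod
    def from_float(cls, linear: nn.Linear):
        observed = cls(linear.in_features, linear.out_features, bias=linear.bias is not None)
        observed.weight = linear.weight  # share the original parameters
        observed.bias = linear.bias
        return observed


def insert_awq_observers_(model: nn.Module) -> None:
    """Hypothetical: recursively swap every nn.Linear for an AWQObservedLinear, in place."""
    for name, child in model.named_children():
        if isinstance(child, nn.Linear) and not isinstance(child, AWQObservedLinear):
            setattr(model, name, AWQObservedLinear.from_float(child))
        else:
            insert_awq_observers_(child)


# Usage: insert observers, then run a few calibration batches
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
insert_awq_observers_(model)
for _ in range(4):
    model(torch.randn(2, 5, 8))  # (batch, seq_len, hidden)
print(model[0].average_activation_magnitude())
```

The per-channel averages collected here are the Sx values used when searching for alpha.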
Step 2. Integrate with AffineQuantizedTensor
Calculating the per-channel scale can happen when we apply quantization to the weights, similar to ao/tutorials/calibration_flow/static_quant.py (lines 49 to 63 in afde175).
As discussed with @vayuda in CUDA_MODE, I think we could implement a new LayoutType and AQTLayout that scale the weight with equalization_scale before quantization and apply the equalization_scale tensor to the input activation tensor in the linear operator. (Note: I think we should call this equalization_scale because it is not AWQ-only; SmoothQuant can reuse it.)
In terms of API, we can implement a helper function like ao/torchao/quantization/quant_api.py (line 363 in afde175).
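To illustrate what such a layout would need to hold, here is a minimal sketch written as a plain nn.Module rather than a LayoutType/AQTLayout, because it does not use the AffineQuantizedTensor extension points. EqualizedInt4Linear is hypothetical, and the symmetric per-output-channel int4 scheme is a simplifying assumption. It stores int data, quantization scales, and equalization_scale; the forward divides the input by equalization_scale and dequantizes the weight (an optimized kernel would consume the int data directly).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class EqualizedInt4Linear(nn.Module):
    """Hypothetical sketch: quantize (weight * equalization_scale) to int4 values
    stored in int8, and divide the input activation by equalization_scale."""

    def __init__(self, linear: nn.Linear, equalization_scale: torch.Tensor):
        super().__init__()
        w = linear.weight.detach() * equalization_scale  # fold the scale into weight columns
        # Symmetric per-output-channel int4 quantization, values in [-8, 7]
        qscale = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 7.0
        int_data = torch.clamp(torch.round(w / qscale), -8, 7).to(torch.int8)
        self.register_buffer("int_data", int_data)
        self.register_buffer("qscale", qscale)
        self.register_buffer("equalization_scale", equalization_scale.detach().clone())
        self.bias = None if linear.bias is None else nn.Parameter(linear.bias.detach().clone())

    def forward(self, x):
        x = x / self.equalization_scale  # apply the inverse scale to the activation
        w = self.int_data.to(x.dtype) * self.qscale.to(x.dtype)  # dequantize for the reference path
        return F.linear(x, w, self.bias)


# Usage: equalization_scale from calibration stats with an illustrative alpha = 0.5
torch.manual_seed(0)
linear = nn.Linear(32, 8)
x = torch.randn(4, 32)
s = x.abs().mean(dim=0).pow(0.5)
qlinear = EqualizedInt4Linear(linear, s)
y_ref = linear(x)
y_q = qlinear(x)
```

The same module works for SmoothQuant-style scales, which matches the naming argument for equalization_scale.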
Note: We may be able to fuse equalization_scale into the kernel as well, but our current A16W4 kernel is implemented in tinygemm, so we would need to modify the tinygemm kernels; if we rely on torch.compile, it would be easy to do.
Additional Optimizations
Turning Input-Weight Equalization into Cross-Layer Equalization
As we can see from the original implementation, when applying the scale to linear weights, we apply it both to the current linear weight and to the weight of the previous module. This is only applicable if the operation f between the two satisfies positive scaling equivariance, f(s·x) = s·f(x) for s > 0 (e.g. ReLU);
see Section 4.1 of https://arxiv.org/pdf/1906.04721 for more details.
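A minimal sketch of the cross-layer case, assuming the Linear → ReLU → Linear pattern from the DFQ paper rather than the norm/linear pairs AWQ usually handles: because relu(z / s) = relu(z) / s for s > 0, 1/s can be folded into the output channels of the previous linear layer and s into the input channels of the current one, so no runtime scaling of the activation is needed.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

prev = nn.Linear(8, 16)
act = nn.ReLU()
curr = nn.Linear(16, 4)
model = nn.Sequential(prev, act, curr)

x = torch.randn(10, 8)
y_before = model(x)

# Positive per-channel scale for the 16 channels between prev and curr
s = torch.rand(16) + 0.5

# ReLU is positively scale-equivariant: relu(z / s) == relu(z) / s for s > 0
z = torch.randn(10, 16)
assert torch.allclose(act(z / s), act(z) / s)

with torch.no_grad():
    # Fold 1/s into the output channels (rows) of the previous linear layer...
    prev.weight.div_(s.unsqueeze(1))
    prev.bias.div_(s)
    # ...and s into the input channels (columns) of the current linear layer
    curr.weight.mul_(s)

y_after = model(x)
print(torch.allclose(y_before, y_after, atol=1e-5))
```

After folding, curr.weight is the scaled weight that would be quantized, exactly as in input-weight equalization.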
This condition may hold in many use cases. To apply this optimization safely, we could do the following:
See https://pytorch.org/docs/stable/fx.html for docs related to torch.fx.
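A minimal sketch of the torch.fx side, assuming a conservative rule: folding is only considered when a Linear's input comes directly from a module in a small allowlist and that module has no other consumers. FOLDABLE_PREV_TYPES and find_prev_modules are hypothetical; note that this rule would reject the shared-norm q/k/v case, which the original AWQ code handles by scaling all three projections together.

```python
import torch
import torch.nn as nn
from torch.fx import symbolic_trace

# Hypothetical allowlist of producers whose output scale can be folded into their parameters
FOLDABLE_PREV_TYPES = (nn.Linear, nn.LayerNorm)


def find_prev_modules(model: nn.Module) -> dict:
    """Hypothetical helper: map each nn.Linear's qualified name to the module that
    directly produces its input, or None if folding a scale there looks unsafe."""
    gm = symbolic_trace(model)  # leaf nn modules stay as call_module nodes
    modules = dict(gm.named_modules())
    result = {}
    for node in gm.graph.nodes:
        if node.op != "call_module" or not isinstance(modules[node.target], nn.Linear):
            continue
        producer = node.args[0]
        safe = (
            isinstance(producer, torch.fx.Node)
            and producer.op == "call_module"
            and isinstance(modules[producer.target], FOLDABLE_PREV_TYPES)
            and len(producer.users) == 1  # rescaling must not affect other consumers
        )
        result[node.target] = producer.target if safe else None
    return result


class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.norm = nn.LayerNorm(16)
        self.fc1 = nn.Linear(16, 32)
        self.fc2 = nn.Linear(32, 16)

    def forward(self, x):
        h = self.fc1(self.norm(x))
        return self.fc2(torch.relu(h))  # torch.relu is traced as call_function


prev_map = find_prev_modules(Block())
print(prev_map)
```

Here fc1 maps to norm, while fc2 maps to None because its input comes from a functional op, so only input-weight equalization would apply to it.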
Logistics (Code Location, Test and Benchmarks)
Please create an awq folder under https://github.com/pytorch/ao/tree/main/torchao/prototype.
The flow and layout implementations can be in separate files, e.g. flow.py and layout.py (there might be some missing extension points in AffineQuantizedTensor, but we'll work on these at the same time).
For testing, please create a test_awq.py in https://github.com/pytorch/ao/tree/main/test/prototype; we can test the basic insert_awq_observers_ flow, the layout creation, etc.
For an e2e flow demo, please add an awq.py in https://github.com/pytorch/ao/tree/main/tutorials/calibration_flow following the static quant example. Please show the benchmarking results as well (since we are using an optimized kernel), following https://github.com/pytorch/ao/tree/main/torchao/quantization#quantization-flow-example.
The last step is to test this with llama2/llama3 following the instructions in https://github.com/pytorch/ao/tree/main/torchao/_models/llama and measure the metrics in https://github.com/pytorch/ao/tree/main/torchao/quantization#benchmarks if you have GPU machines.