Add Auto-Round support #581
Conversation
Signed-off-by: yiliu30 <[email protected]>
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/581
Note: links to docs will display an error until the docs builds have completed.
✅ No failures as of commit 96f745d with merge base 05224a9. This comment was automatically generated by Dr. CI and updates every 15 minutes.
Hi @jerryzh168 @msaroufim, I'm reaching out to request a preliminary review for this PR. Although some refactoring is still in progress, I'd like your feedback to make sure we're on the right track before moving forward. This draft PR includes:
Some TODOs:
Regarding 3) GPU memory consumption: in the current flow, I use hooks to capture the inputs and outputs of each block during the calibration stage. This differs from the original flow; it is mainly to align with the static quantization flow.
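The hook-based capture described above can be sketched roughly as follows. This is a minimal illustration, not the PR's actual code: `Block` is a toy stand-in for a decoder block (the real flow targets e.g. `TransformerBlock`), and the hook/dict names are made up.

```python
import torch
import torch.nn as nn

# Toy stand-in for a decoder block.
class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 8)

    def forward(self, x):
        return self.linear(x)

model = nn.Sequential(Block(), Block())

# Capture each block's calibration inputs/outputs via forward hooks.
captured = {}

def make_hook(name):
    def hook(module, inputs, output):
        captured.setdefault(name, []).append((inputs, output))
    return hook

handles = [m.register_forward_hook(make_hook(n))
           for n, m in model.named_children()]

with torch.no_grad():
    model(torch.randn(2, 8))  # one calibration forward pass

for h in handles:
    h.remove()

print(sorted(captured.keys()))  # each block captured once
```

After calibration the hooks are removed, so the captured activations can be fed to the block-wise tuning step without keeping the hooks alive.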
Hi @jerryzh168, for 3), I noticed that GPTQ has a similar complication: #577.
The main difference is that GPTQ handles a single Linear layer. Inspired by HDCharles's proposal, I tried to extend it. I have prepared a full demo here; could you please take a look? Thanks a lot!
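The core idea being extended here — one wrapper carrying several calibration inputs so a single call processes all of them — can be illustrated with a deliberately simplified toy. The real `MultiTensor` in torchao is a `torch.Tensor` subclass that intercepts ops; this sketch only demonstrates the concept, and all names are hypothetical.

```python
import torch

class ToyMultiTensor:
    """Toy stand-in for the MultiTensor idea: hold several calibration
    inputs and run the same module/function over each of them."""
    def __init__(self, tensors):
        self.tensors = list(tensors)

    def apply(self, fn):
        # One logical call fans out over every held input.
        return ToyMultiTensor(fn(t) for t in self.tensors)

linear = torch.nn.Linear(4, 4)
batch = ToyMultiTensor([torch.randn(1, 4) for _ in range(3)])
out = batch.apply(linear)
print(len(out.tensors))  # 3 — one output per calibration sample
```

Applied to a whole decoder block instead of a single Linear, the same pattern lets one traced forward cover every calibration sample.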
@yiliu30 sorry for the late reply, I think using that approach works.
One small nit for the "general_decoder": we can use … instead of looking at …
Also, after this is done, I think we can improve our current utils for operator implementation (Lines 11 to 70 in db345bd).
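One reading of the nit above (the specifics were lost in extraction) is to select target modules with a type-based filter in the `(module, fqn)` shape that `quantize_` accepts, rather than matching on module-name strings. A minimal sketch, with `DecoderBlock` as a hypothetical stand-in for the real block type:

```python
import torch.nn as nn

# Hypothetical stand-in for a decoder block type (e.g. TransformerBlock).
class DecoderBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.attn = nn.Linear(8, 8)

model = nn.Sequential(DecoderBlock(), nn.Linear(8, 2))

# Type-based filter instead of inspecting module-name strings:
is_target_module = lambda mod, fqn: isinstance(mod, DecoderBlock)

targets = [fqn for fqn, mod in model.named_modules()
           if is_target_module(mod, fqn)]
print(targets)  # only the block itself matches, not its children
```

A type check keeps the filter robust to renames and nesting, which string matching on fully-qualified names is not.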
requested some changes
I was curious about the compute dtype supported by the AO kernel. If it only supports FP16, I recommend forcing the dtype to FP16 before passing the model to AutoRound. However, if BF16 is also supported, it would be preferable to set the scale_type in AutoRound to align with the original model. Additionally, the accuracy data differs slightly from the results of our recipe, which may not be due solely to hyperparameter changes; we should investigate further.
It depends on the kernel; int4 weight-only quantization that uses the tinygemm kernel only supports bfloat16, I think.
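Given that constraint, the practical consequence is to align the model dtype before quantizing. A minimal sketch of that step (the quantization call itself is omitted; only the dtype alignment is shown, and the same pattern would apply with `torch.float16` for a kernel that required half precision):

```python
import torch
import torch.nn as nn

model = nn.Linear(16, 16)

# tinygemm-backed int4 weight-only quantization expects bfloat16,
# so cast the model before handing it to the quantization flow.
model = model.to(torch.bfloat16)

x = torch.randn(2, 16, dtype=torch.bfloat16)
y = model(x)
print(y.dtype)  # torch.bfloat16
```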
```python
quantize_(model, apply_auto_round(), is_target_module)
```

## End-to-End Results
so what about performance results?
Code changes look good to me. One comment: include performance data (tokens/s, memory, etc.) in the README as well, similar to https://github.com/pytorch/ao/tree/main/torchao/quantization#benchmarks
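A minimal sketch of how such numbers could be collected — this is not the benchmark harness used in the repo; the model and loop are hypothetical stand-ins (a real run would time `model.generate()` on the quantized model on GPU):

```python
import time
import torch
import torch.nn as nn

# Hypothetical stand-in for one decode step per "token".
model = nn.Linear(64, 64)
x = torch.randn(1, 64)

num_tokens = 100
if torch.cuda.is_available():
    torch.cuda.reset_peak_memory_stats()

start = time.perf_counter()
for _ in range(num_tokens):
    x = model(x)
elapsed = time.perf_counter() - start

tokens_per_sec = num_tokens / elapsed
peak_mem_gb = (torch.cuda.max_memory_allocated() / 1e9
               if torch.cuda.is_available() else float("nan"))
print(f"{tokens_per_sec:.1f} tok/s, peak mem {peak_mem_gb:.2f} GB")
```

Reporting both tokens/s and peak memory makes the README entry directly comparable to the existing benchmark table linked above.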
The benchmark depends on #769.
```python
)
else:
    is_target_module = lambda mod, fqn: isinstance(mod, TransformerBlock)
quantize_model_with_autoround_(
```
Nit: should we just use the same flow everywhere to reduce confusion? The flow in https://github.com/pytorch/ao/pull/581/files#diff-af129d63635a3b5b0a95f1a3831f852fbd7bedfd66b38d41bf4975fb49aad246 would be the recommended one, I think.
Thanks @yiliu30 for addressing all the comments!
@jerryzh168 Thanks for your patient guidance and detailed examples. This joint effort will allow more users to benefit from AO and Auto-Round!
Resolve #533

Description
- Implemented `Auto-Round` with the `quantize_` API using hooks + `MultiTensor`.
- Converted `qweight` to `AffineQuantizedTensor` to leverage the `tinygemm` and `Uintx` kernels.
- Evaluated `Llama2/3/3.1` on 5 popular `lm-eval` tasks (more tests are on the way).
- Added `Auto-Round` to the generation benchmarking for `Llama2/3` (`Llama 3.1` not yet tested, as it landed a few days ago).

Usage
For E2E examples, please refer to README.md.

cc @thuang6 @ftian1 @wenhuach21