SpinQuant #983
Conversation
Hey, this is looking nice so far. Long term we probably want to make these tensor subclasses so that we can make serialization easier: that way, rather than having to do load model -> convert model -> load checkpoint, you can just do load model -> load checkpoint. Not absolutely critical, but long term it looks like there may be multiple use cases/APIs for SpinQuant, one explicitly for the Cayley QAT and one not, and unifying them based on serialization will make composability much nicer.
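(For illustration, a rough sketch of the tensor-subclass idea; `SpinQuantRotatedWeight` and its structure are hypothetical and not part of torchao. A real version would also need the flatten/unflatten hooks that torchao's other subclasses use for state_dict round-tripping.)

```python
import torch
from torch.utils._pytree import tree_map

class SpinQuantRotatedWeight(torch.Tensor):
    """Hypothetical subclass that carries its rotation, so a rotated checkpoint
    can be loaded into a fresh model without re-running the convert step."""

    @staticmethod
    def __new__(cls, weight: torch.Tensor, rotation: torch.Tensor):
        return torch.Tensor._make_wrapper_subclass(
            cls, weight.shape, dtype=weight.dtype, device=weight.device
        )

    def __init__(self, weight: torch.Tensor, rotation: torch.Tensor):
        self.weight = weight      # already-rotated weight
        self.rotation = rotation  # kept around for serialization / inspection

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # Delegate every op to the plain rotated weight.
        unwrap = lambda t: t.weight if isinstance(t, cls) else t
        return func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {}))
```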
Good to know @HDCharles, I'll keep the tensor subclasses in mind. I was wondering: will the choice to integrate this into torchao depend on the performance delta it produces? Currently there is some Wikitext performance improvement, but it's perhaps not that significant.
Update: I'm currently somewhat stuck on this PR. The R2 and R4 matrices are both implemented and show small perplexity improvements for int4wo-64 quantization (not much though, see the table above). I've tried to implement it as closely as possible to the reference SpinQuant implementation, but these are the best results I can achieve thus far (and not quite as good as the results in the paper). What still remains is the R3 rotation and R1 using Cayley optimization. The R3 rotation is a bit tricky to implement because it requires a modification of torchao/_models/llama/model.py, lines 290 to 302 (at commit 09b8b3c).
In the SpinQuant repo they use a monkeypatch solution, but the code becomes a bit ugly in that case. At the same time, they show in the paper that R3 has a minimal effect on performance (table 3), so I'm also not sure how much it's worth implementing. Lastly, I have not added the R1 matrices, which would require adding a Cayley optimization procedure. Currently, the SpinQuant changes are immediately applicable at inference time, but running Cayley optimization would require some time to complete (they report ~1 hr to run Cayley optimization on 8x A100 GPUs for R1 and R2, using 800 samples of the WikiText2 calibration dataset). I guess it could also be possible to train these matrices once for a model like Llama-7B and include them as add-on weights. I would very much appreciate some feedback on how to proceed with this.
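(For reference, a rough sketch of what a non-monkeypatched version might look like, based on my reading of the paper that R3 is an online rotation of the queries and keys after RoPE; the function and variable names below are illustrative, not the actual code in model.py.)

```python
import torch
import torch.nn.functional as F

def attention_with_r3(q, k, v, r3: torch.Tensor):
    """q, k, v: (batch, n_heads, seq_len, head_dim), with RoPE already applied.

    r3 is an orthogonal (head_dim, head_dim) rotation, e.g. a Hadamard matrix.
    Since r3 is orthogonal, (q @ r3) @ (k @ r3).transpose(-1, -2) equals
    q @ k.transpose(-1, -2) in exact arithmetic, so attention scores are
    unchanged; the benefit only appears once the KV cache is quantized."""
    q = q @ r3
    k = k @ r3
    return F.scaled_dot_product_attention(q, k, v)
```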
I have unblocked myself somewhat regarding the R1 rotation matrices: the authors provide downloads for the optimized R1/R2 weights. I could try these out to see what kind of performance difference can be expected before implementing the Cayley optimization here. My only concern is that their Llama implementation might not be 100% identical to the one in torchao, which could mean that the R1 weights might not work as well, but it seems worth trying out anyway.
I think we can merge it and continue working on it regardless; accuracy improvements are definitely a good metric to see how useful it is, though. Even in their paper, for 4-16-16 the improvement of SpinQuant is pretty small even with Cayley optimization. It's mostly 4-4-16 where it starts to outperform other methods by a significant margin. We're working on getting some kernels for that in the next 1-2 weeks, so it may be more useful for that use case. For now I'd do accuracy benchmarks on groupsize=32 rather than 64/128, since that's the minimum group size. Yeah, the monkeypatch is pretty messy; it feels like we can do this in a better way with either tensor subclasses or something else.
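(For the group-size 32 numbers, something along these lines should work with torchao's `quantize_` API; a sketch, assuming the int4 weight-only path used by the llama eval scripts.)

```python
import torch
from torchao.quantization import quantize_, int4_weight_only

def quantize_rotated_model(model: torch.nn.Module) -> torch.nn.Module:
    """Int4 weight-only quantization at the smallest supported group size.

    Assumes the SpinQuant rotations have already been fused/inserted,
    so the rotated weights are what get quantized."""
    model = model.to(device="cuda", dtype=torch.bfloat16)
    quantize_(model, int4_weight_only(group_size=32))
    return model
```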
(force-pushed from 6f6445d to 2c8acdd)
I think this looks good; it seems like a lot of value will be added once activation quantization is used.
It would be good to add groupsize-32 numbers, uintx 2-bit numbers, and Llama 3 numbers to the PR description if you have them.
I'll do a final reformat and add some more results in the next few days, @HDCharles.
Hi @tobiasvanderwerff, do you mind reformatting?
(force-pushed from 0bf6a76 to b6ae688)
Wrapping the Linear layers might mess with their quantization, so it's probably better to keep the Linear layers unchanged and insert new layers alongside them.
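(A sketch of the "insert alongside" approach; the module and helper names below are made up for illustration.)

```python
import torch
import torch.nn as nn

class RotateActivations(nn.Module):
    """Applies a fixed orthogonal rotation to the activations feeding a linear
    layer (e.g. the R4 online rotation in front of down_proj)."""

    def __init__(self, rotation: torch.Tensor):
        super().__init__()
        self.register_buffer("rotation", rotation)

    def forward(self, x):
        return x @ self.rotation

def insert_rotation_before(parent: nn.Module, name: str, rotation: torch.Tensor) -> None:
    # Swap `parent.<name>` (an nn.Linear) for Sequential(rotation, original linear).
    # The nn.Linear itself is untouched, so weight-only quantization still finds
    # a plain nn.Linear when it walks the module tree.
    linear = getattr(parent, name)
    setattr(parent, name, nn.Sequential(RotateActivations(rotation), linear))
```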
This is done for pre-norm LLMs like LLaMa to make them scale-invariant (see footnote 3 in the paper). However, in the current implementation it seems to hurt performance when quantization is used.
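(The fusion amounts to folding the per-channel norm scale into the input columns of the following linear layers; a sketch for RMSNorm-style norms as used in Llama, ignoring biases.)

```python
import torch
import torch.nn as nn

@torch.no_grad()
def fuse_norm_scale_into_linears(norm: nn.Module, linears: list[nn.Linear]) -> None:
    """In a pre-norm block, linear(rmsnorm(x) * gamma) == x_normed @ (W * gamma).T,
    so diag(gamma) can be absorbed into each linear's weight and the norm weight
    reset to ones, making the norm a pure (scale-free) normalization."""
    gamma = norm.weight.data  # shape: (hidden_dim,)
    for linear in linears:
        # weight has shape (out_features, in_features); gamma scales input features
        linear.weight.data = linear.weight.data * gamma
    norm.weight.data = torch.ones_like(gamma)
```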
Random R1 and R2 matrices are showing worse results than just using R4, so the latter seems to be a better default option (at least for now).
(force-pushed from a8807f1 to fb3882f)
Hi @tobiasvanderwerff @andrewor14, could you use this implementation: https://fburl.com/code/d3nuagm4? It's faster and much easier to read.
It's faster? Do you have a link to benchmarks?
The benchmark is in the summary of D61891002.
Hi @tobiasvanderwerff, great work! I am wondering, have you tested the end-to-end generation performance (tokens/s)?
I have not tested tokens/s generation @yiliu30, but I can test this if you want.
Thank you, @tobiasvanderwerff! I'm primarily interested in studying the computational overhead introduced by R4, and I was wondering if the custom Hadamard transform kernel is compatible with torch.compile.
Thanks for bringing this up @yiliu30 -- I tested this, and it looks like the custom Hadamard transform kernel indeed breaks torch.compile. I'll investigate this and get back to you.
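(If the custom kernel turns out to be incompatible with torch.compile, one possible fallback, sketched below and not benchmarked, is a pure-PyTorch fast Walsh-Hadamard transform, which compiles but will be slower than a fused CUDA kernel.)

```python
import math
import torch

def walsh_hadamard_transform(x: torch.Tensor) -> torch.Tensor:
    """Orthonormal fast Walsh-Hadamard transform over the last dimension.

    Requires the last dimension to be a power of two. Pure PyTorch, so it can
    be traced by torch.compile, unlike a custom CUDA extension."""
    d = x.shape[-1]
    assert d & (d - 1) == 0, "last dimension must be a power of two"
    orig_shape = x.shape
    x = x.reshape(-1, d)
    h = 1
    while h < d:
        # Pair up elements that are h apart and replace them with (a+b, a-b).
        x = x.reshape(-1, d // (2 * h), 2, h)
        a, b = x[:, :, 0, :], x[:, :, 1, :]
        x = torch.stack((a + b, a - b), dim=2)
        h *= 2
    return x.reshape(orig_shape) / math.sqrt(d)

# Usage: x = walsh_hadamard_transform(x), in place of the kernel call.
```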
* SpinQuant using R2 matrices
* Move Hadamard functions and matrices to separate file
* Add R4 rotation
* Reformat
* Do not wrap Linear layers but use nn.Sequential

  Wrapping the Linear layers might mess with the quantization of the linear layers, so it's probably better to keep the linear layers the same and insert new layers alongside them
* Add test
* Fix test and do small reformat of Hadamard code
* Fuse Layernorm params into linear layers

  This is done for pre-norm LLMs like LLaMa to make them scale-invariant (see footnote 3 in the paper). However, in the current implementation it seems to hurt performance when quantization is used.
* Add R1 rotation
* Add option to load pretrained R1/R2 matrices
* Move Spinquant from `torchao/quantization` to `torchao/prototype/spinquant`
* Move Hadamard matrices to a separate file
* Move test
* Minor changes
* Reformat
* Only enable R4 as default setting

  Random R1 and R2 matrices are showing worse results than just using R4, so the latter seems to be a better default option (at least for now).
* Add `__init__.py` to spinquant folder
* Do not fail if fast_hadamard_transform is not present
* Move export_aoti into export + minor tidiness
* Lint
* Remove mismatched arg
Corresponding issue: #579
This PR adds SpinQuant integration to pytorch/ao. See the paper for details: https://arxiv.org/abs/2405.16406.

Results on LLaMA are shown below, measured by Wikitext word perplexity.
For R1 and R2, random Hadamard matrices are used, unless (pt) is present, in which case I use the pretrained weights provided by the SpinQuant authors.
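(For completeness, a sketch of how a random Hadamard matrix can be constructed: a Sylvester Hadamard matrix with randomized column signs, which stays orthonormal. The helper name is illustrative, not necessarily what this PR uses.)

```python
import math
import torch

def random_hadamard_matrix(n: int, seed: int = 0, dtype=torch.float32) -> torch.Tensor:
    """H @ diag(s), with s random +/-1 signs.

    Both H / sqrt(n) and diag(s) are orthogonal, so the product is too.
    Sylvester construction, so n must be a power of two."""
    assert n & (n - 1) == 0, "n must be a power of two"
    H = torch.ones(1, 1, dtype=dtype)
    while H.shape[0] < n:
        H = torch.cat([torch.cat([H, H], dim=1), torch.cat([H, -H], dim=1)], dim=0)
    gen = torch.Generator().manual_seed(seed)
    signs = torch.randint(0, 2, (n,), generator=gen).to(dtype) * 2 - 1
    return (H * signs) / math.sqrt(n)  # multiplies column j of H by signs[j]
```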
TODO

- implement R3
- Cayley optimization for R1 and R2 (not sure how feasible this is for inference -- it takes them 1 hr to run Cayley optimization on 8x A100 GPUs for R1 and R2 using 800 samples of the WikiText2 calibration dataset)