
SpinQuant #983

Merged: 18 commits merged into pytorch:main on Oct 10, 2024
Conversation

@tobiasvanderwerff (Contributor) commented Oct 1, 2024

Corresponding issue: #579

This PR adds SpinQuant integration to pytorch/ao. See the paper for details: https://arxiv.org/abs/2405.16406.

Results on LLaMA are shown below, measured by WikiText word perplexity (lower is better).

| Model | Quantization | Baseline | R4 | R2+R4 | R1+R2+R4 | R1+R2+R4 (pt) | R2+R4 (pt) |
|---|---|---|---|---|---|---|---|
| Llama-2-7B | None (bfloat16) | 12.23 | 12.24 | 12.24 | 12.24 | | |
| | int8dq | 12.35 | 12.35 | 12.35 | | | |
| | int4wo-32 | 12.68 | 12.58 | 12.60 | 13.65 | 13.49 | 12.64 |
| | int4wo-64 | 12.87 | 12.82 | 12.80 | | | |
| | int4wo-64-marlin | 12.87 | 12.82 | 13.65 | | | |
| | uintx-4-32 | 12.81 | 12.53 | 13.63 | | | |
| | uintx-4-64 | 12.89 | 12.80 | | | | |
| | uintx-2-8 | 211 | | | | | |
| Llama-3-8B | None (bfloat16) | 7.44 | 7.44 | 7.44 | | | |
| | int4wo-32 | 8.11 | 8.06 | 8.54 | | | |
| | uintx-4-64 | 8.11 | 8.31 | 9.00 | | | |

For R1 and R2, random Hadamard matrices are used, unless (pt) is present, in which case I use the pretrained weights provided by the SpinQuant authors.
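For reference, a random Hadamard rotation can be built by taking a power-of-two-sized Hadamard matrix and flipping the signs of its rows at random. The snippet below is a minimal sketch of that idea (the function name and construction are illustrative, not the exact code in this PR):

```python
import torch

def random_hadamard_matrix(size: int, seed: int = 0) -> torch.Tensor:
    """Sylvester-construction Hadamard matrix with random row sign flips.

    `size` must be a power of two. The result is orthogonal, so inserting it
    (and its transpose) around a pair of weights leaves the full-precision
    model output unchanged.
    """
    gen = torch.Generator().manual_seed(seed)
    H = torch.ones(1, 1, dtype=torch.float64)
    while H.shape[0] < size:
        H = torch.cat(
            [torch.cat([H, H], dim=1), torch.cat([H, -H], dim=1)], dim=0
        )
    signs = (torch.randint(0, 2, (size,), generator=gen) * 2 - 1).to(torch.float64)
    return (signs.unsqueeze(1) * H) / size**0.5
```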

TODO

  • implement R2
  • implement R4
  • implement layernorm weight fusion into linear layers (footnote 3 in the paper; see the sketch after this list)
  • implement R1
  • implement R3
  • Cayley optimization for R1 and R2 (not sure how feasible this is for inference -- the authors report ~1 hour on 8x A100 GPUs to run Cayley optimization for R1 and R2, using 800 samples of the WikiText2 calibration dataset)
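Regarding the layernorm fusion item above: for pre-norm models like LLaMA, the per-channel norm scale can be folded into the linear layers that consume the norm output, which makes the residual stream invariant to an orthogonal rotation. A rough sketch of what that fusion could look like, assuming the norm module holds a per-channel `weight` (this is illustrative, not the exact torchao code):

```python
import torch

@torch.no_grad()
def fuse_norm_into_linears(norm: torch.nn.Module, linears: list[torch.nn.Linear]) -> None:
    """Fold a per-channel norm scale into the linear layers that consume it.

    After this, `linear(norm(x))` is unchanged, but the norm itself becomes
    scale-free, so a rotation can be applied to the residual stream safely.
    """
    for linear in linears:
        # weight has shape (out_features, in_features); the norm scale multiplies
        # the input channels, i.e. the columns of the weight matrix.
        linear.weight.mul_(norm.weight.to(linear.weight.dtype))
    norm.weight.fill_(1.0)
```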

pytorch-bot (bot) commented Oct 1, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/983

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit fb3882f with merge base 107e378:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Oct 1, 2024
@tobiasvanderwerff tobiasvanderwerff marked this pull request as draft October 2, 2024 09:14
@HDCharles (Contributor)

Hey, this is looking nice so far. Long term we probably want to make these tensor subclasses so that serialization is easier; that way, rather than having to do load model -> convert model -> load checkpoint, you can just do load model -> load checkpoint.

Not absolutely critical, but long term it looks like there may be multiple use cases/APIs for SpinQuant, one explicitly for the Cayley QAT and one not, and unifying them based on serialization will make composability much nicer.

@tobiasvanderwerff (Contributor, Author)

Good to know @HDCharles, I'll keep the tensor subclasses in mind. I was wondering, will the choice to integrate this into torchao depend on the performance delta it produces? Currently, there is some Wikitext perf improvement but it's perhaps not that significant.

@tobiasvanderwerff (Contributor, Author) commented Oct 3, 2024

Update: I'm currently somewhat stuck on this PR. The R2 and R4 matrices are both implemented and show small perplexity improvements for int4wo-64 quantization (not much though, see the table above). I've tried to follow the reference SpinQuant implementation as closely as possible, but these are the best results I can achieve so far (and not quite as good as the results in the paper). What still remains is the R3 rotation and R1 with Cayley optimization.

The R3 rotation is a bit tricky to implement because it requires a modification of Attention.forward() in the middle of the function, after the apply_rotary_emb calls:

```python
def forward(self, x: Tensor, freqs_cis: Tensor, mask: Optional[Tensor], input_pos: Optional[Tensor] = None) -> Tensor:
    bsz, seqlen, _ = x.shape
    kv_size = self.n_local_heads * self.head_dim
    q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)
    q = q.view(bsz, seqlen, self.n_head, self.head_dim)
    k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
    v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)
    q = apply_rotary_emb(q, freqs_cis)
    k = apply_rotary_emb(k, freqs_cis)
    ...
```

In the SpinQuant repo they use a monkeypatch solution, but the code becomes a bit ugly in that case. At the same time, the paper shows that R3 has a minimal effect on performance (Table 3), so I'm not sure how much it's worth implementing.
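For illustration, applying R3 at that point would look roughly like the snippet below; `hadamard_transform` stands in for whatever rotation helper ends up being used, so this is a hypothetical sketch rather than the PR's code:

```python
    # (hypothetical continuation of Attention.forward; hadamard_transform is an
    # assumed helper acting on the last (head_dim) axis)
    q = apply_rotary_emb(q, freqs_cis)
    k = apply_rotary_emb(k, freqs_cis)
    # R3: rotate the head dimension of q and k with the same orthogonal (Hadamard)
    # matrix. Since q @ k^T is invariant to a shared rotation, attention scores are
    # unchanged in exact arithmetic, while the KV cache quantizes better.
    q = hadamard_transform(q)
    k = hadamard_transform(k)
```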

Lastly, I have not added the R1 matrices, which would require adding a Cayley optimization procedure. Currently, the SpinQuant changes are immediately applicable at inference time, but running Cayley optimization would take some time to complete (the authors report ~1 hour on 8x A100 GPUs for R1 and R2, using 800 samples of the WikiText2 calibration dataset). It could also be possible to train these matrices once for a model like Llama-7B and include them as add-on weights.
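For context on what the Cayley optimization involves, a single Cayley-SGD-style update that keeps a rotation matrix orthogonal looks roughly like this (a sketch of the general technique, not the SpinQuant authors' exact optimizer):

```python
import torch

def cayley_sgd_step(R: torch.Tensor, grad: torch.Tensor, lr: float = 0.05) -> torch.Tensor:
    """One Cayley-transform retraction step that keeps R orthogonal."""
    A = grad @ R.T - R @ grad.T  # skew-symmetric update direction from the gradient
    I = torch.eye(R.shape[0], dtype=R.dtype, device=R.device)
    # (I + lr/2 * A)^-1 (I - lr/2 * A) is orthogonal whenever A is skew-symmetric,
    # so the updated R stays on the orthogonal manifold (up to numerical error).
    return torch.linalg.solve(I + (lr / 2) * A, (I - (lr / 2) * A) @ R)
```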

I would very much appreciate some feedback on how to proceed with this.

@tobiasvanderwerff (Contributor, Author) commented Oct 3, 2024

I have unblocked myself somewhat regarding the R1 rotation matrices: the authors provide downloads for the optimized R1/R2 weights. I could try these out to see what kind of performance difference to expect before implementing the Cayley optimization here. My only concern is that their Llama implementation might not be 100% identical to the one in torchao, which could mean the R1 weights don't transfer as well, but it seems worth trying anyway.
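A quick sanity check on the downloaded matrices could be something like the following (the file name and key are hypothetical; the point is just to confirm the matrices are orthogonal before fusing them into the weights):

```python
import torch

R1 = torch.load("R1.bin")["R1"].to(torch.float64)  # hypothetical file name and key
I = torch.eye(R1.shape[0], dtype=R1.dtype)
assert torch.allclose(R1 @ R1.T, I, atol=1e-6), "R1 is not orthogonal"
```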

@HDCharles (Contributor) commented Oct 3, 2024

I think we can merge it and continue working on it regardless; accuracy improvements are definitely a good metric to see how useful it is, though. Even in their paper, for 4-16-16 the improvement of SpinQuant is pretty small, even with Cayley optimization. It's mostly 4-4-16 where it starts to outperform other methods by a significant margin. We're working on getting some kernels for that in the next 1-2 weeks, so it may be more useful for that use case. For now I'd do accuracy benchmarks on groupsize=32 rather than 64/128, since that's the minimum group size.

Yeah, the monkeypatch is pretty messy; it feels like we can do this in a better way with tensor subclasses or something else.

@tobiasvanderwerff tobiasvanderwerff marked this pull request as ready for review October 4, 2024 18:15
@HDCharles HDCharles self-requested a review October 4, 2024 20:20
@HDCharles (Contributor) left a comment

I think this looks good; it seems like a lot of value will be added once activation quantization is used.

It would be good to add groupsize-32 numbers, uintx 2-bit numbers, and Llama-3 numbers to the PR description if you have them.

@tobiasvanderwerff (Contributor, Author)

I'll do a final reformat and add some more results in the next few days @HDCharles

@andrewor14 (Contributor)

Hi @tobiasvanderwerff, do you mind reformatting hadamard_utils.py so we don't end up with a 10k-line file? I feel like you could even move the matrices into a separate file like _hadamard_matrices.py, so it's easier to review the other parts of hadamard_utils.py.

@tobiasvanderwerff tobiasvanderwerff force-pushed the spinquant branch 2 times, most recently from 0bf6a76 to b6ae688 Compare October 9, 2024 08:24
@tobiasvanderwerff tobiasvanderwerff changed the title [wip] SpinQuant SpinQuant Oct 9, 2024
@HDCharles HDCharles merged commit 590f8fb into pytorch:main Oct 10, 2024
17 checks passed
@tobiasvanderwerff tobiasvanderwerff deleted the spinquant branch October 10, 2024 19:14
@wat3rBro

> Hi @tobiasvanderwerff, do you mind reformatting hadamard_utils.py so we don't end up with a 10k-line file? I feel like you could even move the matrices into a separate file like _hadamard_matrices.py, so it's easier to review the other parts of hadamard_utils.py.

Hi @tobiasvanderwerff @andrewor14, could you use this implementation instead: https://fburl.com/code/d3nuagm4? It's faster and much easier to read.

@HDCharles (Contributor)

It's faster? Do you have a link to benchmarks?

@wat3rBro

> It's faster? Do you have a link to benchmarks?

The benchmark is in the summary of D61891002.

@yiliu30 (Contributor) commented Oct 11, 2024

Hi @tobiasvanderwerff, great work! I was wondering if you tested the end-to-end generation performance (tokens/s)?

@tobiasvanderwerff (Contributor, Author)

I have not tested tokens/s generation @yiliu30, but I can test this if you want.

@yiliu30 (Contributor) commented Oct 14, 2024

> I have not tested tokens/s generation @yiliu30, but I can test this if you want.

Thank you, @tobiasvanderwerff! I'm primarily interested in studying the computational overhead introduced by R4, and I was wondering if the hadamard_transform might break torch.compile.

@tobiasvanderwerff (Contributor, Author)

Thanks for bringing this up @yiliu30 -- I tested this and it looks like the custom Hadamard transform kernel indeed breaks torch.compile. I'll investigate this and get back to you.
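For reference, one general way to make a third-party CUDA kernel like this compose with torch.compile is to register it as a custom op with a fake (meta) implementation, so the compiler can trace through it without graph breaks. The sketch below shows that general approach on recent PyTorch (2.4+); it is not necessarily the fix used in this PR:

```python
import torch
from fast_hadamard_transform import hadamard_transform  # third-party CUDA kernel

@torch.library.custom_op("spinquant::hadamard", mutates_args=())
def hadamard(x: torch.Tensor, scale: float) -> torch.Tensor:
    # Wrap the opaque kernel so torch.compile treats it as a single traceable op.
    return hadamard_transform(x, scale)

@hadamard.register_fake
def _(x: torch.Tensor, scale: float) -> torch.Tensor:
    # Shape/dtype propagation for tracing; no real computation happens here.
    return torch.empty_like(x)
```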

jainapurva pushed a commit that referenced this pull request Oct 15, 2024
* SpinQuant using R2 matrices

* Move Hadamard functions and matrices to separate file

* Add R4 rotation

* Reformat

* Do not wrap Linear layers but use nn.Sequential

Wrapping the Linear layers might mess with the quantization of the linear layers, so it's probably better to keep the linear layers the same and insert new layers alongside them

* Add test

* Fix test and do small reformat of Hadamard code

* Fuse Layernorm params into linear layers

This is done for pre-norm LLMs like LLaMa to make them scale-invariant (see footnote 3 in the paper). However, in the current implementation it seems to hurt performance when quantization is used.

* Add R1 rotation

* Add option to load pretrained R1/R2 matrices

* Move Spinquant from `torchao/quantization` to `torchao/prototype/spinquant`

* Move Hadamard matrices to a separate file

* Move test

* Minor changes

* Reformat

* Only enable R4 as default setting

Random R1 and R2 matrices are showing worse results than just using R4, so the latter seems to be a better default option (at least for now).

* Add __init__.py to spinquant folder

* Do not fail if fast_hadamard_transform is not present
@tobiasvanderwerff (Contributor, Author)

@yiliu30 FYI I fixed the issue with torch.compile -- you can see the benchmark results here.

yanbing-j pushed a commit to yanbing-j/ao that referenced this pull request Dec 9, 2024
* Move export_aoti into export + minor tidyness

* Lint

* Remove mismatched arg