
Add SpinQuant to generate.py #1069

Merged: 7 commits merged into pytorch:main from spinquant-mods on Oct 22, 2024

Conversation

tobiasvanderwerff
Contributor

  • Add SpinQuant to torchao/_models/llama/generate.py
  • Only import SpinQuant when necessary in eval.py and generate.py (No need to import the large Hadamard matrices required for SpinQuant otherwise)


pytorch-bot bot commented Oct 14, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1069

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 1543e4f with merge base e7b33bc:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Oct 14, 2024
@jerryzh168
Contributor

thanks, any results we can show?

@tobiasvanderwerff
Contributor Author

@jerryzh168 I'm first fixing a torch.compile issue related to the Hadamard transform used in SpinQuant; after that, I'll post some benchmark results here. If you want, we can keep this PR open and I'll push the changes here.
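For context, SpinQuant's R4 rotation applies a Hadamard transform to activations at runtime. A minimal reference implementation of the orthonormal fast Walsh-Hadamard transform might look like the sketch below (this is an illustration, not torchao's actual kernel, which additionally has to play well with torch.compile):

```python
import torch

def hadamard_transform(x: torch.Tensor) -> torch.Tensor:
    """Orthonormal fast Walsh-Hadamard transform over the last dimension.

    Naive O(n log n) butterfly; the last dimension must be a power of two.
    """
    n = x.shape[-1]
    assert n & (n - 1) == 0, "last dimension must be a power of two"
    y = x.clone()
    h = 1
    while h < n:
        # Group the last dimension into pairs of blocks of size h and
        # apply the butterfly: (a, b) -> (a + b, a - b).
        y = y.view(*x.shape[:-1], n // (2 * h), 2, h)
        a = y[..., 0, :] + y[..., 1, :]
        b = y[..., 0, :] - y[..., 1, :]
        y = torch.stack((a, b), dim=-2).reshape(*x.shape)
        h *= 2
    # Scale by 1/sqrt(n) so the transform is orthonormal (norm-preserving),
    # which is what makes it usable as a rotation before quantization.
    return y / n ** 0.5
```

Because the Sylvester-ordered Hadamard matrix is symmetric and orthonormal here, the transform is its own inverse, and it preserves vector norms, which is the property that makes it a valid rotation for quantization.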

@tobiasvanderwerff
Contributor Author

tobiasvanderwerff commented Oct 15, 2024

SpinQuant now also works with torch.compile. Benchmark results (llama-2-7b, tested on an A100):

Baseline + torch.compile

Average tokens/sec: 114.08
Average Bandwidth: 1507.58 GB/s
Peak Memory Usage: 13.88 GB
Model Size: 13.21 GB

Spinquant (R4) + torch.compile

Average tokens/sec: 109.59
Average Bandwidth: 1448.61 GB/s
Peak Memory Usage: 13.72 GB
Model Size: 13.22 GB

Spinquant (R1+R2+R4) + torch.compile

NB: R1 and R2 are fused into the linear weights before inference takes place, so it is expected that they do not lead to additional overhead at inference time.

Average tokens/sec: 109.64
Average Bandwidth: 1449.28 GB/s
Peak Memory Usage: 14.90 GB
Model Size: 13.22 GB
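The zero-overhead claim for R1/R2 can be sanity-checked in isolation: fusing an orthogonal rotation into a linear layer's weight and counter-rotating the input leaves the output unchanged, so nothing extra runs at inference time. A minimal sketch (hypothetical illustration, not torchao's actual SpinQuant code):

```python
import torch

torch.manual_seed(0)
d = 16
linear = torch.nn.Linear(d, d, bias=False)

# Stand-in for a learned SpinQuant rotation: any orthogonal matrix works here.
R, _ = torch.linalg.qr(torch.randn(d, d))

# Fuse the rotation into the weight ahead of time: W' = W @ R.
fused = torch.nn.Linear(d, d, bias=False)
with torch.no_grad():
    fused.weight.copy_(linear.weight @ R)

# Feeding the counter-rotated input through the fused layer reproduces the
# original output exactly, so the rotation adds no inference-time compute.
x = torch.randn(4, d)
y_ref = linear(x)
y_fused = fused(x @ R)
assert torch.allclose(y_ref, y_fused, atol=1e-5)
```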

@tobiasvanderwerff tobiasvanderwerff mentioned this pull request Oct 15, 2024
@tobiasvanderwerff
Contributor Author

Results without torch.compile:

Baseline

Average tokens/sec: 27.33
Average Bandwidth: 361.21 GB/s
Peak Memory Usage: 13.62 GB
Model Size: 13.21 GB

Spinquant (R4)

Average tokens/sec: 23.01
Average Bandwidth: 304.10 GB/s
Peak Memory Usage: 14.24 GB
Model Size: 13.22 GB
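As a rough sanity check on these numbers: for memory-bound decoding, bandwidth is approximately model size times decode throughput, since each generated token streams the full weights from memory (assuming the benchmark script derives it that way):

```python
# Baseline numbers from the run above (no torch.compile).
model_size_gb = 13.21
tokens_per_sec = 27.33

# Each decoded token reads the whole model once from memory.
approx_bandwidth = model_size_gb * tokens_per_sec
print(f"{approx_bandwidth:.1f} GB/s")  # ~361.0, close to the reported 361.21 GB/s
```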

@yiliu30
Contributor

yiliu30 commented Oct 15, 2024

> SpinQuant now also works with torch.compile. Benchmark results (tested on an A100):
>
> Baseline + torch.compile
>
> Average tokens/sec: 114.31
> Average Bandwidth: 1510.58 GB/s
> Peak Memory Usage: 13.88 GB
> Model Size: 13.21 GB
>
> Spinquant (R4) + torch.compile
>
> Average tokens/sec: 109.00
> Average Bandwidth: 1440.76 GB/s
> Peak Memory Usage: 13.98 GB
> Model Size: 13.22 GB

Thanks @tobiasvanderwerff, may I know which model you tested on, llama-2-7b?

@tobiasvanderwerff
Contributor Author

Yep, llama-2-7b, I'll add that to the benchmark.

@HDCharles
Contributor

Can you add benchmark numbers for R1+R2 as well? I think R4 is only for activation quantization.

@HDCharles
Contributor

would be good to add this info into a readme file inside the spinquant dir

@jerryzh168
Contributor

ready to merge?

@tobiasvanderwerff
Contributor Author

Yep, this is ready @jerryzh168

@HDCharles HDCharles merged commit 3044ee5 into pytorch:main Oct 22, 2024
17 checks passed
@tobiasvanderwerff tobiasvanderwerff deleted the spinquant-mods branch October 22, 2024 19:40