-
Notifications
You must be signed in to change notification settings - Fork 522
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add torchao quant for mixtral and qwen_moe #1418
Merged
Merged
Conversation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
merrymercy
reviewed
Sep 14, 2024
jerryzh168
changed the title
Add torchao quant for mixtral
Add torchao quant for mixtral and qwen_moe
Sep 14, 2024
merrymercy
requested changes
Sep 14, 2024
Summary: Similar to sgl-project#1341 we add torchao quantization to mixtral model Test Plan: Note: compile is not working yet, and I can't install torchnightly locally and make it work either. I'll wait for pytorch 2.5 release which happens in mid Oct, or check that again later python3 -m sglang.bench_latency --model Qwen/Qwen1.5-MoE-A2.7B --batch-size 1 --input 128 --output 8 Warmup ... Prefill. latency: 0.05532 s, throughput: 2313.73 token/s Decode. latency: 0.00896 s, throughput: 111.65 token/s Decode. latency: 0.00833 s, throughput: 120.04 token/s Decode. latency: 0.00869 s, throughput: 115.06 token/s Decode. latency: 0.00842 s, throughput: 118.79 token/s Decode. median latency: 0.00855 s, median throughput: 116.89 token/s Total. latency: 0.090 s, throughput: 1471.26 token/s Benchmark ... Prefill. latency: 0.04294 s, throughput: 2980.61 token/s Decode. latency: 0.00839 s, throughput: 119.12 token/s Decode. latency: 0.00828 s, throughput: 120.78 token/s Decode. latency: 0.00857 s, throughput: 116.64 token/s Decode. latency: 0.00853 s, throughput: 117.19 token/s Decode. latency: 0.00859 s, throughput: 116.39 token/s Decode. median latency: 0.00853 s, median throughput: 117.17 token/s Total. latency: 0.111 s, throughput: 1226.84 token/s python3 -m sglang.bench_latency --model Qwen/Qwen1.5-MoE-A2.7B --batch-size 1 --input 128 --output 8 --torchao-config int4wo-128 Warmup ... Prefill. latency: 0.06413 s, throughput: 1996.05 token/s Decode. latency: 0.00764 s, throughput: 130.84 token/s Decode. latency: 0.00748 s, throughput: 133.73 token/s Decode. latency: 0.00725 s, throughput: 137.84 token/s Decode. latency: 0.00721 s, throughput: 138.74 token/s Decode. median latency: 0.00737 s, median throughput: 135.76 token/s Total. latency: 0.094 s, throughput: 1408.61 token/s Benchmark ... Prefill. latency: 0.05239 s, throughput: 2443.43 token/s Decode. latency: 0.00739 s, throughput: 135.25 token/s Decode. latency: 0.00720 s, throughput: 138.90 token/s Decode. latency: 0.00718 s, throughput: 139.21 token/s Decode. latency: 0.00722 s, throughput: 138.42 token/s Decode. latency: 0.00745 s, throughput: 134.30 token/s Decode. median latency: 0.00731 s, median throughput: 136.82 token/s Total. latency: 0.111 s, throughput: 1223.51 token/s A100, no compile python3 -m sglang.bench_latency --model Qwen/Qwen1.5-MoE-A2.7B --batch-size 1 --input 128 --output 8 --torchao-config fp8wo max_total_num_tokens=199454 Warmup ... Prefill. latency: 0.06958 s, throughput: 1839.60 token/s Decode. latency: 0.02343 s, throughput: 42.68 token/s Decode. latency: 0.02342 s, throughput: 42.70 token/s Decode. latency: 0.02368 s, throughput: 42.23 token/s Decode. latency: 0.02337 s, throughput: 42.80 token/s Decode. median latency: 0.02342 s, median throughput: 42.69 token/s Total. latency: 0.163 s, throughput: 807.48 token/s Benchmark ... Prefill. latency: 0.05767 s, throughput: 2219.36 token/s Decode. latency: 0.02293 s, throughput: 43.61 token/s Decode. latency: 0.02026 s, throughput: 49.36 token/s Decode. latency: 0.02029 s, throughput: 49.29 token/s Decode. latency: 0.02024 s, throughput: 49.41 token/s Decode. latency: 0.02026 s, throughput: 49.36 token/s Decode. median latency: 0.02025 s, median throughput: 49.39 token/s Total. latency: 0.222 s, throughput: 611.87 token/s Reviewers: Subscribers: Tasks: Tags:
jerryzh168
force-pushed
the
add-mixtral-quant
branch
from
September 14, 2024 05:46
5436560
to
c3b08dd
Compare
merrymercy
approved these changes
Sep 14, 2024
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Summary:
Similar to #1341 we add torchao quantization to mixtral model
Test Plan:
Note: compile is not working yet, and I can't install torchnightly locally and make it work either. I'll wait for pytorch 2.5 release which happens in mid Oct, or check that again later
python3 -m sglang.bench_latency --model Qwen/Qwen1.5-MoE-A2.7B --batch-size 1 --input 128 --output 8 Warmup ...
Prefill. latency: 0.05532 s, throughput: 2313.73 token/s
Decode. latency: 0.00896 s, throughput: 111.65 token/s
Decode. latency: 0.00833 s, throughput: 120.04 token/s
Decode. latency: 0.00869 s, throughput: 115.06 token/s
Decode. latency: 0.00842 s, throughput: 118.79 token/s
Decode. median latency: 0.00855 s, median throughput: 116.89 token/s
Total. latency: 0.090 s, throughput: 1471.26 token/s
Benchmark ...
Prefill. latency: 0.04294 s, throughput: 2980.61 token/s
Decode. latency: 0.00839 s, throughput: 119.12 token/s
Decode. latency: 0.00828 s, throughput: 120.78 token/s
Decode. latency: 0.00857 s, throughput: 116.64 token/s
Decode. latency: 0.00853 s, throughput: 117.19 token/s
Decode. latency: 0.00859 s, throughput: 116.39 token/s
Decode. median latency: 0.00853 s, median throughput: 117.17 token/s
Total. latency: 0.111 s, throughput: 1226.84 token/s
python3 -m sglang.bench_latency --model Qwen/Qwen1.5-MoE-A2.7B --batch-size 1 --input 128 --output 8 --torchao-config int4wo-128
Warmup ...
Prefill. latency: 0.06413 s, throughput: 1996.05 token/s
Decode. latency: 0.00764 s, throughput: 130.84 token/s
Decode. latency: 0.00748 s, throughput: 133.73 token/s
Decode. latency: 0.00725 s, throughput: 137.84 token/s
Decode. latency: 0.00721 s, throughput: 138.74 token/s
Decode. median latency: 0.00737 s, median throughput: 135.76 token/s
Total. latency: 0.094 s, throughput: 1408.61 token/s
Benchmark ...
Prefill. latency: 0.05239 s, throughput: 2443.43 token/s
Decode. latency: 0.00739 s, throughput: 135.25 token/s
Decode. latency: 0.00720 s, throughput: 138.90 token/s
Decode. latency: 0.00718 s, throughput: 139.21 token/s
Decode. latency: 0.00722 s, throughput: 138.42 token/s
Decode. latency: 0.00745 s, throughput: 134.30 token/s
Decode. median latency: 0.00731 s, median throughput: 136.82 token/s
Total. latency: 0.111 s, throughput: 1223.51 token/s
A100, no compile
python3 -m sglang.bench_latency --model Qwen/Qwen1.5-MoE-A2.7B --batch-size 1 --input 128 --output 8 --torchao-config fp8wo max_total_num_tokens=199454
Warmup ...
Prefill. latency: 0.06958 s, throughput: 1839.60 token/s
Decode. latency: 0.02343 s, throughput: 42.68 token/s
Decode. latency: 0.02342 s, throughput: 42.70 token/s
Decode. latency: 0.02368 s, throughput: 42.23 token/s
Decode. latency: 0.02337 s, throughput: 42.80 token/s
Decode. median latency: 0.02342 s, median throughput: 42.69 token/s
Total. latency: 0.163 s, throughput: 807.48 token/s
Benchmark ...
Prefill. latency: 0.05767 s, throughput: 2219.36 token/s
Decode. latency: 0.02293 s, throughput: 43.61 token/s
Decode. latency: 0.02026 s, throughput: 49.36 token/s
Decode. latency: 0.02029 s, throughput: 49.29 token/s
Decode. latency: 0.02024 s, throughput: 49.41 token/s
Decode. latency: 0.02026 s, throughput: 49.36 token/s
Decode. median latency: 0.02025 s, median throughput: 49.39 token/s
Total. latency: 0.222 s, throughput: 611.87 token/s
Reviewers:
Subscribers:
Tasks:
Tags: