[Feature] support torchao for qwen2 models #2219
I believe this issue is related to the implementation of TorchAO. For example, with [...]. To achieve a speedup, you might consider using [...].

#1341

There is no speedup with [...].
It seems torchao has not been applied to Qwen. Can you copy this line from llama.py? (sglang/python/sglang/srt/models/llama.py, line 426 at commit 906d795)
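For reference, a rough sketch of what that change to qwen2.py could look like. The helper name, import path, and arguments below are assumptions taken from the llama.py reference above; copy the exact line from llama.py line 426 rather than trusting this sketch.

# Sketch for python/sglang/srt/models/qwen2.py (names are assumptions; see llama.py line 426).
from sglang.srt.layers.torchao_utils import apply_torchao_config_  # assumed import path

class Qwen2ForCausalLM(nn.Module):
    ...
    def load_weights(self, weights):
        params_dict = dict(self.named_parameters())
        ...  # existing weight-loading loop unchanged
        # Apply the server's --torchao-config to this model, mirroring llama.py.
        apply_torchao_config_(self, params_dict, set(["proj.weight"]))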
Please send a pull request after you find it works!
Yes, it seems the torchao config works, but int8wo's throughput decreases while int4wo-128's increases.

python3 -m sglang.bench_latency --model ../Qwen2-7B-Instruct --batch-size 1 --input-len 200 --output-len 100 --torchao-config int4wo-128

By the way, when adding --enable-torch-compile, it seems to error (truncated traceback):

[rank0]: from user code:
[rank0]: Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
[rank0]: You can suppress this exception and fall back to eager by setting:
OK, and when I use --enable-torch-compile with int8wo or int4wo-128, it errors (truncated traceback):

[rank0]: from user code:
[rank0]: Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
[rank0]: You can suppress this exception and fall back to eager by setting:
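For what it's worth, the standard PyTorch Dynamo message that is cut off here ends by suggesting a fallback to eager mode; this only silences the compile failure rather than fixing it:

# Fall back to eager execution when torch.compile (Dynamo) hits an error.
import torch._dynamo
torch._dynamo.config.suppress_errors = True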
@HandH1998
@HandH1998 @merrymercy
[rank0]: File "/home/service/var/sglang/python/sglang/srt/model_executor/cuda_graph_runner.py", line 273, in capture_one_batch_size
Please use torch 2.5 by [...]
@merrymercy @HandH1998
Hi @tricky61 @merrymercy, do you guys have any ideas about the above error? I have tried to upgrade torch to 2.5 by [...]. My torch info:

Name: torch
Version: 2.5.1
Summary: Tensors and Dynamic neural networks in Python with strong GPU acceleration
Home-page: https://pytorch.org/
Author: PyTorch Team

My vllm info:

Name: vllm
Version: 0.6.4.post1
Summary: A high-throughput and memory-efficient inference and serving engine for LLMs
Home-page: https://github.com/vllm-project/vllm
Author: vLLM Team

My script to reproduce the error is:

python3 -m sglang.launch_server \
    --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \
    --torchao-config int8dq \
    --port 30000 --host 0.0.0.0

When I disable CUDA graph capture, the server can run, but it seems extremely slow:

python3 -m sglang.launch_server \
    --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \
    --torchao-config int8dq \
    --disable-cuda-graph \
    --port 30000 --host 0.0.0.0
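Once the server is up, a quick way to check that it actually serves requests is sglang's native /generate endpoint. A minimal smoke test, assuming the default host and port from the command above:

# Minimal smoke test against the sglang server started above (assumes port 30000).
import requests

resp = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {"max_new_tokens": 16, "temperature": 0},
    },
)
print(resp.json())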
I used one A30 card with Qwen2-7B-Instruct; the speed with quantization seems no different:
python3 -m sglang.bench_latency --model ../Qwen2-7B-Instruct --batch-size 1 --input-len 200 --output-len 100
Benchmark ...
Prefill. latency: 0.03508 s, throughput: 5700.84 token/s
Decode. latency: 0.01952 s, throughput: 51.23 token/s
Decode. latency: 0.01947 s, throughput: 51.37 token/s
Decode. latency: 0.01939 s, throughput: 51.58 token/s
Decode. latency: 0.01933 s, throughput: 51.74 token/s
Decode. latency: 0.01928 s, throughput: 51.87 token/s
Decode. median latency: 0.01924 s, median throughput: 51.98 token/s
Total. latency: 1.942 s, throughput: 154.52 token/s
python3 -m sglang.bench_latency --model ../Qwen2-7B-Instruct --batch-size 1 --input-len 200 --output-len 100 --enable-torch-compile
Benchmark ...
Prefill. latency: 0.03655 s, throughput: 5471.84 token/s
Decode. latency: 0.01852 s, throughput: 54.00 token/s
Decode. latency: 0.01847 s, throughput: 54.14 token/s
Decode. latency: 0.01845 s, throughput: 54.21 token/s
Decode. latency: 0.01843 s, throughput: 54.26 token/s
Decode. latency: 0.01838 s, throughput: 54.39 token/s
Decode. median latency: 0.01836 s, median throughput: 54.46 token/s
Total. latency: 1.855 s, throughput: 161.71 token/s
python3 -m sglang.bench_latency --model ../Qwen2-7B-Instruct --batch-size 1 --input-len 200 --output-len 100 --enable-torch-compile --torchao-config int8wo
Benchmark ...
Prefill. latency: 0.04469 s, throughput: 4475.31 token/s
Decode. latency: 0.01860 s, throughput: 53.77 token/s
Decode. latency: 0.01849 s, throughput: 54.09 token/s
Decode. latency: 0.01844 s, throughput: 54.24 token/s
Decode. latency: 0.01841 s, throughput: 54.32 token/s
Decode. latency: 0.01837 s, throughput: 54.45 token/s
Decode. median latency: 0.01836 s, median throughput: 54.46 token/s
Total. latency: 1.863 s, throughput: 160.99 token/s
python3 -m sglang.bench_latency --model ../Qwen2-7B-Instruct --batch-size 1 --input-len 200 --output-len 100 --enable-torch-compile --torchao-config int4wo
Benchmark ...
Prefill. latency: 0.03558 s, throughput: 5621.52 token/s
Decode. latency: 0.01855 s, throughput: 53.91 token/s
Decode. latency: 0.01852 s, throughput: 54.01 token/s
Decode. latency: 0.01845 s, throughput: 54.20 token/s
Decode. latency: 0.01842 s, throughput: 54.28 token/s
Decode. latency: 0.01841 s, throughput: 54.33 token/s
Decode. median latency: 0.01837 s, median throughput: 54.44 token/s
Total. latency: 1.855 s, throughput: 161.72 token/s
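Putting numbers on "seems no different": using the median decode throughputs above, torch.compile alone gives roughly a 5% decode speedup, while adding int8wo or int4wo on top changes essentially nothing:

# Median decode throughputs (token/s) copied from the benchmark runs above.
baseline       = 51.98  # no compile, no quantization
compile_only   = 54.46  # --enable-torch-compile
compile_int8wo = 54.46  # --enable-torch-compile --torchao-config int8wo
compile_int4wo = 54.44  # --enable-torch-compile --torchao-config int4wo

print(f"compile vs baseline:    {compile_only / baseline:.3f}x")        # ~1.048x
print(f"int8wo vs compile only: {compile_int8wo / compile_only:.3f}x")  # ~1.000x
print(f"int4wo vs compile only: {compile_int4wo / compile_only:.3f}x")  # ~1.000x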