
Add quickreduce as alternative to custom allreduce #16804

Closed

ilmarkov wants to merge 28 commits into vllm-project:main from neuralmagic:experimental/quick_reduce

Conversation

@ilmarkov
Contributor

@ilmarkov ilmarkov commented Apr 17, 2025

Add quickreduce as an alternative to custom allreduce.

The collective is only enabled on AMD MI300, for fp16/bf16 inputs, and only when custom allreduce is enabled. The kernels implement both full-precision and quantized (symmetric quantization with group size 32) allreduce.
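As a rough illustration of the quantized scheme, group-wise symmetric quantization with group size 32 (shown here for int4) can be modeled in NumPy. This is a sketch of the numerics only, not the PR's actual HIP kernels:

```python
import numpy as np

def quantize_int4_symmetric(x: np.ndarray, group_size: int = 32):
    """Group-wise symmetric int4 quantization: one fp scale per group."""
    groups = x.astype(np.float32).reshape(-1, group_size)
    # Symmetric scheme: map the largest |value| in each group to the int4
    # limit 7, so zero stays exactly representable.
    scales = np.abs(groups).max(axis=1, keepdims=True) / 7.0
    scales[scales == 0] = 1.0  # avoid division by zero for all-zero groups
    q = np.clip(np.rint(groups / scales), -8, 7).astype(np.int8)
    return q, scales

def dequantize_int4_symmetric(q: np.ndarray, scales: np.ndarray) -> np.ndarray:
    """Reverse the quantization: scale each group back to floats."""
    return (q.astype(np.float32) * scales).reshape(-1)
```

Each rank would quantize its shard before the exchange and dequantize after, trading a small rounding error (bounded by half a scale step per element) for roughly 4x less traffic versus fp16.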

quickreduce can be enabled by setting the VLLM_ROCM_QR_QUANT_REGIME=[NONE|FP|INT8|INT6|INT4] environment variable. quickreduce supports int8, int6, and int4 quantization.

The PR provides both fp16 and bf16 kernels, but given the lack of intrinsics for bf16 math operations, the bf16 kernels perform worse (see the kernel benchmark results below), so by default we convert bf16 allreduce inputs to fp16. This behavior can be disabled by setting the VLLM_ROCM_QR_CAST_BF16_TO_FP16=0 environment variable.
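For example, a server could be launched with quickreduce in int4 mode as follows (the model name and tensor-parallel size are placeholders; the cast flag is shown at its default value of 1):

```shell
# Enable quickreduce with int4 symmetric quantization on MI300.
# VLLM_ROCM_QR_CAST_BF16_TO_FP16=1 is the default: bf16 inputs are cast
# to fp16 before the allreduce for better kernel performance.
VLLM_ROCM_QR_QUANT_REGIME=INT4 \
VLLM_ROCM_QR_CAST_BF16_TO_FP16=1 \
vllm serve meta-llama/Llama-3.1-70B-Instruct -tp 8 --dtype bfloat16
```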

Since quickreduce only shows performance benefits at medium and large input sizes (see the kernel benchmarks), vLLM keeps using custom allreduce for small inputs. The lower bounds for enabling quickreduce were chosen empirically.

The maximal input size for quickreduce is 2 GB.
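The size-based dispatch described above can be sketched as follows. This is illustrative, not the actual vLLM implementation: the minimum-size thresholds below are hypothetical placeholders (the PR says the real lower bounds were chosen empirically), while the 2 GB cap is stated in the PR:

```python
# 2 GB cap on quickreduce input size, as stated in the PR.
QR_MAX_SIZE = 2 * 1024**3

# Hypothetical per-regime lower bounds, in bytes (placeholders only).
QR_MIN_SIZE = {
    "FP": 1 * 1024**2,
    "INT8": 2 * 1024**2,
    "INT6": 2 * 1024**2,
    "INT4": 2 * 1024**2,
}

def should_use_quickreduce(num_bytes: int, regime: str) -> bool:
    """Return True if quickreduce should handle this allreduce call."""
    if regime == "NONE":
        # quickreduce disabled: custom allreduce handles everything.
        return False
    return QR_MIN_SIZE[regime] <= num_bytes <= QR_MAX_SIZE
```

Inputs below the regime's threshold or above the 2 GB cap fall back to custom allreduce (or NCCL), which is better suited to the low-latency small-message case.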

Benchmark results

(float16):

Server:

```shell
VLLM_USE_V1=1 VLLM_USE_TRITON_FLASH_ATTN=0 vllm serve meta-llama/Llama-3.1-70B-Instruct --block_size=32 --disable-log-requests --no-enable-prefix-caching -tp $tp --dtype float16
```

Client:

```shell
python benchmarks/benchmark_serving.py --model meta-llama/Llama-3.1-70B-Instruct --dataset-name sonnet --dataset-path benchmarks/sonnet.txt --num-prompts 500 --request-rate 10 --ignore-eos
```

TP=8

| | TTFT, ms | Speedup | TPOT, ms | Speedup |
|---|---|---|---|---|
| baseline | 145 | 1x | 51 | 1x |
| QR, fp | 106 | 1.37x | 38 | 1.34x |
| QR, int4 | 90 | 1.61x | 35 | 1.46x |

TP=4

| | TTFT, ms | Speedup | TPOT, ms | Speedup |
|---|---|---|---|---|
| baseline | 316 | 1x | 89 | 1x |
| QR, fp | 280 | 1.09x | 89 | 1x |
| QR, int8 | 270 | 1.17x | 85 | 1.05x |
| QR, int6 | 222 | 1.42x | 70 | 1.27x |
| QR, int4 | 138 | 2.2x | 45 | 2x |

bfloat16 kernels (--dtype bfloat16; the fp16-kernel rows in the table below are run with VLLM_ROCM_QR_CAST_BF16_TO_FP16=1):
TP=4

| | TTFT, ms | Speedup | TPOT, ms | Speedup |
|---|---|---|---|---|
| baseline | 316 | 1x | 89 | 1x |
| QR, fp | 324 | 0.9x | 87 | 1.01x |
| QR, int8 | 261 | 1x | 84 | 1x |
| QR, int6 | 266 | 1.18x | 85 | 1.05x |
| QR, int4 | 232 | 1.36x | 76 | 1.17x |
| QR, fp, fp16 kernels | 290 | 1.08x | 95 | 1x |
| QR, int8, fp16 kernels | 268 | 1.18x | 85 | 1.05x |
| QR, int6, fp16 kernels | 236 | 1.34x | 75 | 1.18x |
| QR, int4, fp16 kernels | 154 | 2.05x | 50 | 1.8x |

Kernel benchmarks
TP=2

| msg size | baseline | QR FP | QR int8 | QR int6 | QR int4 | QR FP bf16 | QR int8 bf16 | QR int6 bf16 | QR int4 bf16 |
|---|---|---|---|---|---|---|---|---|---|
| 2.0KB | 7.26 | 10.75 | 22.05 | 23.79 | 19.04 | 13.11 | 92.69 | 94.29 | 92.05 |
| 32.0KB | 7.40 | 10.88 | 22.09 | 24.18 | 19.16 | 13.23 | 93.39 | 95.16 | 92.63 |
| 256.0KB | 11.87 | 14.91 | 23.89 | 25.54 | 20.35 | 17.56 | 95.58 | 96.54 | 94.54 |
| 512.0KB | 18.08 | 19.93 | 25.62 | 25.88 | 21.17 | 22.09 | 95.30 | 96.18 | 94.66 |
| 1.0MB | 30.07 | 30.57 | 31.19 | 29.81 | 23.57 | 32.55 | 96.10 | 97.14 | 95.30 |
| 2.0MB | 53.76 | 51.96 | 43.50 | 39.25 | 30.03 | 54.08 | 99.33 | 99.15 | 96.80 |
| 4.0MB | 102.19 | 97.90 | 66.63 | 57.63 | 42.74 | 98.62 | 125.23 | 116.04 | 105.32 |
| 8.0MB | 199.01 | 190.19 | 115.36 | 95.66 | 68.28 | 191.92 | 178.99 | 158.96 | 138.31 |
| 16.0MB | 391.43 | 378.94 | 219.60 | 174.70 | 125.08 | 378.34 | 287.39 | 244.03 | 201.06 |
| 32.0MB | 892.99 | 739.28 | 425.61 | 339.69 | 243.53 | 741.36 | 517.28 | 437.13 | 351.26 |
| 64.0MB | 1465.11 | 1466.43 | 828.58 | 650.82 | 464.56 | 1470.50 | 953.49 | 801.14 | 635.16 |
| 128.0MB | 2912.45 | 2917.14 | 1634.59 | 1277.78 | 898.42 | 2935.37 | 1777.66 | 1450.88 | 1153.58 |
| 256.0MB | 5927.01 | 5822.34 | 3252.23 | 2534.88 | 1772.40 | 5866.49 | 3433.40 | 2804.97 | 2192.04 |
| 512.0MB | 11575.91 | 11639.26 | 6491.12 | 5058.28 | 3522.78 | 11727.95 | 6809.77 | 5505.87 | 4262.04 |
| 1GB | 23223.61 | 23255.00 | 12971.35 | 10106.85 | 7023.81 | 23435.05 | 13586.55 | 10945.95 | 8392.96 |
| 2GB | 45968.99 | 46101.59 | 26021.00 | 20228.00 | 14164.00 | 47084.30 | 27227.85 | 21836.00 | 16624.25 |
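As a quick sanity check on the TP=2 table above, the headline speedup of quickreduce int4 over the baseline at the largest (2 GB) message size works out to roughly 3.2x. The times are assumed to share one unit (likely microseconds), so the unit cancels in the ratio:

```python
# Largest-message speedup of QR int4 over the baseline, TP=2 table.
baseline = 45968.99  # 2 GB row, baseline column
qr_int4 = 14164.00   # 2 GB row, QR int4 column
speedup = baseline / qr_int4
print(f"{speedup:.2f}x")  # ~3.25x
```

This is somewhat below the ideal 4x from shrinking fp16 payloads to int4, which is plausible given the per-group scales and the quantize/dequantize work on each rank.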

TP=4

| msg size | baseline | QR FP | QR int8 | QR int6 | QR int4 | QR FP bf16 | QR int8 bf16 | QR int6 bf16 | QR int4 bf16 |
|---|---|---|---|---|---|---|---|---|---|
| 2.0KB | 7.18 | 12.67 | 23.89 | 25.62 | 21.33 | 16.14 | 79.68 | 93.95 | 84.15 |
| 16.0KB | 7.26 | 12.79 | 23.69 | 25.84 | 21.29 | 16.12 | 80.08 | 93.89 | 84.31 |
| 32.0KB | 7.58 | 12.87 | 23.81 | 25.94 | 21.27 | 16.20 | 80.72 | 94.33 | 84.87 |
| 256.0KB | 14.39 | 15.88 | 26.78 | 27.48 | 22.85 | 19.18 | 82.15 | 94.74 | 86.22 |
| 512.0KB | 14.65 | 17.76 | 27.12 | 27.98 | 23.65 | 21.01 | 81.73 | 94.75 | 86.30 |
| 1.0MB | 22.91 | 22.49 | 29.43 | 29.85 | 24.72 | 26.34 | 82.55 | 95.92 | 81.59 |
| 2.0MB | 36.64 | 36.08 | 40.97 | 39.27 | 28.49 | 38.45 | 86.48 | 97.98 | 82.35 |
| 4.0MB | 63.86 | 63.28 | 66.95 | 54.81 | 37.49 | 72.12 | 109.41 | 114.94 | 90.81 |
| 8.0MB | 118.31 | 126.69 | 126.43 | 99.41 | 64.23 | 137.45 | 168.64 | 157.50 | 116.16 |
| 16.0MB | 230.48 | 237.08 | 204.46 | 167.08 | 109.77 | 237.42 | 290.08 | 256.74 | 183.08 |
| 32.0MB | 389.03 | 439.12 | 390.55 | 307.40 | 217.65 | 441.30 | 470.49 | 440.30 | 304.83 |
| 64.0MB | 1017.56 | 825.53 | 654.79 | 509.82 | 364.77 | 837.36 | 803.16 | 731.90 | 522.51 |
| 128.0MB | 1910.37 | 1587.00 | 1090.06 | 848.89 | 596.27 | 1606.67 | 1307.73 | 1220.87 | 886.31 |
| 256.0MB | 3542.03 | 3082.80 | 1970.84 | 1535.23 | 1078.49 | 3135.44 | 2281.91 | 2180.19 | 1613.20 |
| 512.0MB | 6560.81 | 6098.23 | 3735.02 | 2892.65 | 2015.72 | 6185.95 | 4282.83 | 4096.10 | 3154.83 |
| 1GB | 12582.56 | 12105.15 | 7275.68 | 5618.14 | 3895.45 | 12288.60 | 8317.19 | 7991.48 | 6231.14 |
| 2GB | 24453.95 | 24570.59 | 14636.20 | 11087.40 | 7685.00 | 24529.95 | 16488.65 | 15956.70 | 12265.90 |

Evaluation results on MMLU benchmark (LLaMa 3.1 70B, TP=8)

| | MMLU, STEM | MMLU, human | MMLU, social |
|---|---|---|---|
| baseline | 0.76 | 0.81 | 0.88 |
| QR, int4 | 0.76 | 0.81 | 0.88 |

@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only fastcheck CI runs, covering a small and essential subset of tests to quickly catch errors. You can run additional CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the ci/build label Apr 17, 2025
@lixixicommute
Contributor

Hi @ilmarkov,
Thank you for your great work.
Does this feature have any test data, and how well does it work?
It looks like it's still a draft at the moment; will it be refined afterward?

@ilmarkov ilmarkov force-pushed the experimental/quick_reduce branch from 5b81d85 to 96e1a3e Compare May 13, 2025 13:27
@mergify

mergify bot commented May 14, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @ilmarkov.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label May 14, 2025
@ilmarkov ilmarkov force-pushed the experimental/quick_reduce branch from 96e1a3e to d92ccc8 Compare May 19, 2025 09:42
@ilmarkov ilmarkov force-pushed the experimental/quick_reduce branch from d92ccc8 to 6f17424 Compare May 20, 2025 13:41
@ilmarkov ilmarkov marked this pull request as ready for review May 20, 2025 13:42
@youkaichao
Member

youkaichao commented May 20, 2025

On AMD we first check if it is profitable to use quickreduce, otherwise we fall back to custom allreduce.

What are the cases where custom allreduce performs better than quickreduce? It would be better if quickreduce could surpass custom allreduce in all cases; then we could use quickreduce as a drop-in replacement for custom allreduce without a new user-facing flag.

@ilmarkov
Contributor Author

@youkaichao It is slower for smaller input sizes. We could take the same approach as custom allreduce: use one-shot for small buffers and two-shot for larger ones.

@youkaichao
Member

@youkaichao It is slower for smaller input sizes. We could take the same approach as custom allreduce: use one-shot for small buffers and two-shot for larger ones.

that would be great, can you implement it? we can use either quickreduce or custom allreduce at the engine level, instead of dynamically switching based on the input size.

@ilmarkov
Contributor Author

Yes, we can try to implement this approach.
That said, custom allreduce's setup and implementation are better suited to low-latency, small input sizes, whereas quickreduce performs well for bandwidth-bottlenecked workloads. At the moment, we use custom allreduce or NCCL based on input size.
Also, we will still need the new user-facing flag, since we need to provide a switch between quantization regimes that lets the user trade off accuracy against performance.

@youkaichao
Member

Also, we will still need the new user-facing flag, since we need to provide a switch between quantization regimes that lets the user trade off accuracy against performance.

you can use an environment variable, like VLLM_ROCM_CA_BACKEND.

@mergify

mergify bot commented May 23, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @ilmarkov.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label May 23, 2025
@ilmarkov ilmarkov force-pushed the experimental/quick_reduce branch from 6f17424 to ad731a5 Compare June 3, 2025 15:54
ilmarkov and others added 24 commits June 17, 2025 16:37
@ilmarkov ilmarkov force-pushed the experimental/quick_reduce branch from 380c1b1 to f314fe4 Compare June 17, 2025 16:39
@mergify mergify bot removed the needs-rebase label Jun 17, 2025
@mergify

mergify bot commented Jun 22, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @ilmarkov.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jun 22, 2025
@tlrmchlsmth
Member

Closing in favor of #19744.
