
@maleksan85 (Contributor) commented Mar 28, 2025

Adopting the ROCm paged attention kernel for use in V1 FA as an alternative to the Triton kernel. Perf I see:

Baseline (VLLM_ROCM_CUSTOM_PAGED_ATTN=0 and --no-enable-prefix-caching):

============ Serving Benchmark Result ============
Successful requests:                     1000
Benchmark duration (s):                  27.51
Total input tokens:                      215196
Total generated tokens:                  197090
Request throughput (req/s):              36.35
Output token throughput (tok/s):         7164.57
Total Token throughput (tok/s):          14987.32
---------------Time to First Token----------------
Mean TTFT (ms):                          5068.86
Median TTFT (ms):                        4666.36
P99 TTFT (ms):                           10794.20
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          59.98
Median TPOT (ms):                        58.69
P99 TPOT (ms):                           87.44
---------------Inter-token Latency----------------
Mean ITL (ms):                           43.91
Median ITL (ms):                         34.99
P99 ITL (ms):                            95.14
==================================================

With change (--no-enable-prefix-caching):

============ Serving Benchmark Result ============
Successful requests:                     1000
Benchmark duration (s):                  23.97
Total input tokens:                      215196
Total generated tokens:                  197119
Request throughput (req/s):              41.73
Output token throughput (tok/s):         8225.24
Total Token throughput (tok/s):          17204.79
---------------Time to First Token----------------
Mean TTFT (ms):                          4853.87
Median TTFT (ms):                        4488.42
P99 TTFT (ms):                           10301.76
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          56.14
Median TPOT (ms):                        54.89
P99 TPOT (ms):                           82.79
---------------Inter-token Latency----------------
Mean ITL (ms):                           39.92
Median ITL (ms):                         32.00
P99 ITL (ms):                            94.70
==================================================

latency:

python benchmarks/benchmark_latency.py --model /data/models/Llama-3.1-8B-Instruct --input-len 2048 --output-len 2048 --batch-size 64 --num-iters-warmup 1 --num-iters 3

V0
Avg latency: 40.9744499316439 seconds

V1
upstream (VLLM_ROCM_CUSTOM_PAGED_ATTN=0):
Avg latency: 46.44388557784259 seconds

with change:
Avg latency: 34.401718624557056 seconds

correctness:
2025-03-31:21:23:17,108 INFO [lm_eval.loggers.evaluation_tracker:272] Output path not provided, skipping saving results aggregated
vllm (pretrained=/data/models/Llama-3.1-8B-Instruct), gen_kwargs: (None), limit: 500.0, num_fewshot: 5, batch_size: auto

Tasks  Version  Filter            n-shot  Metric       Value  Stderr
gsm8k  3        flexible-extract  5       exact_match  0.800  ± 0.0179
                strict-match      5       exact_match  0.782  ± 0.0185

cc @SageMoore

Signed-off-by: Aleksandr Malyshev <[email protected]>
@github-actions (bot) commented:

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run the fastcheck CI, which runs a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

mergify bot added the v1 label Mar 28, 2025
@maleksan85 marked this pull request as ready for review March 31, 2025 21:25
@maleksan85 changed the title from "rocm paged attention for V1" to "[ROCM][KERNEL] Paged attention for V1" Mar 31, 2025
@maleksan85 (Contributor, Author) commented Mar 31, 2025

Llama 70B, TP2, same test:
upstream (--no-enable-prefix-caching):

============ Serving Benchmark Result ============
Successful requests:                     1000
Benchmark duration (s):                  96.75
Total input tokens:                      215196
Total generated tokens:                  125080
Request throughput (req/s):              10.34
Output token throughput (tok/s):         1292.76
Total Token throughput (tok/s):          3516.92
---------------Time to First Token----------------
Mean TTFT (ms):                          24240.22
Median TTFT (ms):                        22927.67
P99 TTFT (ms):                           51727.34
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          267.65
Median TPOT (ms):                        233.12
P99 TPOT (ms):                           446.70
---------------Inter-token Latency----------------
Mean ITL (ms):                           172.88
Median ITL (ms):                         104.39
P99 ITL (ms):                            476.44
==================================================

this PR (--no-enable-prefix-caching):

============ Serving Benchmark Result ============
Successful requests:                     1000
Benchmark duration (s):                  91.33
Total input tokens:                      215196
Total generated tokens:                  127292
Request throughput (req/s):              10.95
Output token throughput (tok/s):         1393.71
Total Token throughput (tok/s):          3749.88
---------------Time to First Token----------------
Mean TTFT (ms):                          23905.00
Median TTFT (ms):                        22544.27
P99 TTFT (ms):                           50916.09
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          258.73
Median TPOT (ms):                        220.76
P99 TPOT (ms):                           440.69
---------------Inter-token Latency----------------
Mean ITL (ms):                           164.74
Median ITL (ms):                         97.24
P99 ITL (ms):                            470.45
==================================================

Signed-off-by: Aleksandr Malyshev <[email protected]>
@maleksan85 (Contributor, Author) commented Apr 1, 2025

Output-heavy load with Llama 3.1 70B:

python vllm/benchmarks/benchmark_serving.py --backend vllm --model /data/models/Llama-3.1-70B-Instruct --dataset-name random --random-input-len 150 --random-output-len 16384 --max-concurrency 32 --num-prompts 512 --seed 42

upstream (--no-enable-prefix-caching):

============ Serving Benchmark Result ============
Successful requests:                     512
Benchmark duration (s):                  2765.66
Total input tokens:                      76800
Total generated tokens:                  1782512
Request throughput (req/s):              0.19
Output token throughput (tok/s):         644.52
Total Token throughput (tok/s):          672.29
---------------Time to First Token----------------
Mean TTFT (ms):                          340.30
Median TTFT (ms):                        284.30
P99 TTFT (ms):                           1063.48
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          57.06
Median TPOT (ms):                        48.44
P99 TPOT (ms):                           196.84
---------------Inter-token Latency----------------
Mean ITL (ms):                           48.18
Median ITL (ms):                         48.23
P99 ITL (ms):                            60.74
==================================================

this PR (--no-enable-prefix-caching):

============ Serving Benchmark Result ============
Successful requests:                     512
Benchmark duration (s):                  2045.44
Total input tokens:                      76800
Total generated tokens:                  1782991
Request throughput (req/s):              0.25
Output token throughput (tok/s):         871.69
Total Token throughput (tok/s):          909.24
---------------Time to First Token----------------
Mean TTFT (ms):                          298.00
Median TTFT (ms):                        258.42
P99 TTFT (ms):                           1071.87
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          44.30
Median TPOT (ms):                        35.77
P99 TPOT (ms):                           189.64
---------------Inter-token Latency----------------
Mean ITL (ms):                           35.65
Median ITL (ms):                         35.31
P99 ITL (ms):                            39.86
==================================================

@ProExpertProg (Collaborator) left a comment:

So if I understand correctly, this adds support for chunked prefill to the existing ROCm custom Paged Attention, and then for prefill we still use the _fwd_kernel?

@maleksan85 (Contributor, Author) commented Apr 1, 2025

> So if I understand correctly, this adds support for chunked prefill to the existing ROCm custom Paged Attention, and then for prefill we still use the _fwd_kernel?

Correct. We had two kernels: one for the chunked prefills in a query, and another, which used to be Triton-only, for the decodes in a query. This PR adds an alternative to the Triton decode kernel for a set of cases; please see the limitations in the use_rocm_custom_paged_attention function. As for the ROCm paged attention kernel itself (which is in fact three C++ kernels), it now just skips the prefills in a query.

PS: there might of course be a more tailored way to utilize ROCm paged attention for V1, but that will take more time, since it requires updating the kernel itself.
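
To make the gating concrete, here is a minimal C++ sketch of the kind of eligibility check use_rocm_custom_paged_attention performs before the custom decode kernel is chosen over Triton. The specific limits below (head sizes, block sizes, GQA ratio, sequence-length cap) are illustrative placeholder assumptions for the sketch, not the actual values; see the function in the vLLM source for the real limitations.

#include <cstdint>

// Sketch of a decode-kernel eligibility gate. All concrete limits here are
// placeholder assumptions, not vLLM's real thresholds.
bool can_use_custom_paged_attention(int head_size, int block_size,
                                    int gqa_ratio, int max_seq_len) {
  const bool head_ok  = (head_size == 64 || head_size == 128);   // assumed
  const bool block_ok = (block_size == 16 || block_size == 32);  // assumed
  const bool gqa_ok   = (gqa_ratio >= 1 && gqa_ratio <= 16);     // assumed
  const bool len_ok   = (max_seq_len <= 32768);                  // assumed
  return head_ok && block_ok && gqa_ok && len_ok;
}

Whenever such a gate returns false, the backend would simply keep dispatching decodes to the Triton kernel, so the change stays purely additive.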

Aleksandr Malyshev added 2 commits April 1, 2025 22:53
@robertgshaw2-redhat (Collaborator) commented:

Nice. This is consistent with my profiling of where the issues are on V1. Glad to see the update. This generally looks fine to me. I will follow up with a refactor of the attn backend.

@SageMoore @ProExpertProg - can you guys look through the cpp and let me know if it's okay to merge?

@ProExpertProg (Collaborator) left a comment:

CPP looks good to me, very contained and straightforward change 😃

@robertgshaw2-redhat enabled auto-merge (squash) April 2, 2025 04:23
github-actions bot added the ready label (ONLY add when PR is ready to merge/full CI is needed) Apr 2, 2025
@robertgshaw2-redhat (Collaborator) commented:

[image: benchmark chart]

^ Mistral Small, TP=1, 1000 In | 100 Out

Nice step forward, but there's still juice to squeeze.
cc @SageMoore @ProExpertProg

Comment on lines 380 to 389
const int64_t seq_idx64 = static_cast<int64_t>(seq_idx);
const int64_t query_start_off =
query_start_loc_ptr ? query_start_loc_ptr[seq_idx] : seq_idx;
A project member commented:

keep the static cast?

@maleksan85 (Contributor, Author) replied:

Should this be a conversion similar to what static_cast provides? I mean, when the assignment happens to an int64_t from an int?
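
For context on the exchange above: a plain assignment from int to int64_t performs the same value-preserving widening conversion as static_cast<int64_t>, so dropping the cast on a direct assignment is behavior-neutral. The cast only matters inside an expression, where it forces the arithmetic into 64 bits before a 32-bit multiply can overflow. A generic C++ illustration of this point, not code from the PR:

#include <cstdint>
#include <cstdio>

int main() {
  const int seq_idx = 2000000;
  const int stride  = 2000;

  // Plain assignment: implicit widening, identical in effect to the cast.
  const int64_t a = seq_idx;
  const int64_t b = static_cast<int64_t>(seq_idx);

  // In an expression the cast does matter:
  // const int64_t bad = seq_idx * stride;  // 32-bit multiply overflows (UB)
  const int64_t good = static_cast<int64_t>(seq_idx) * stride;  // 64-bit math

  std::printf("%lld %lld %lld\n", static_cast<long long>(a),
              static_cast<long long>(b), static_cast<long long>(good));
  return 0;
}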

auto-merge was automatically disabled April 2, 2025 16:20

Head branch was pushed to by a user without write access

@maleksan85 requested a review from tlrmchlsmth as a code owner April 2, 2025 16:20
mergify bot added the ci/build label Apr 2, 2025
@robertgshaw2-redhat enabled auto-merge (squash) April 3, 2025 02:47
@robertgshaw2-redhat merged commit e73ff24 into vllm-project:main Apr 3, 2025
59 checks passed
@maleksan85 deleted the v1_rocm_paged_attention_integration branch April 3, 2025 02:51
Alex4210987 pushed a commit to LeiWang1999/vllm-bitblas that referenced this pull request Apr 5, 2025
Signed-off-by: Aleksandr Malyshev <[email protected]>
Signed-off-by: root <[email protected]>
Co-authored-by: Aleksandr Malyshev <[email protected]>
Co-authored-by: root <[email protected]>
Signed-off-by: xinyuxiao <[email protected]>
lulmer pushed a commit to lulmer/vllm that referenced this pull request Apr 7, 2025
Signed-off-by: Aleksandr Malyshev <[email protected]>
Signed-off-by: root <[email protected]>
Co-authored-by: Aleksandr Malyshev <[email protected]>
Co-authored-by: root <[email protected]>
Signed-off-by: Louis Ulmer <[email protected]>
lk-chen pushed a commit to lk-chen/vllm that referenced this pull request Apr 29, 2025
Signed-off-by: Aleksandr Malyshev <[email protected]>
Signed-off-by: root <[email protected]>
Co-authored-by: Aleksandr Malyshev <[email protected]>
Co-authored-by: root <[email protected]>
shreyankg pushed a commit to shreyankg/vllm that referenced this pull request May 3, 2025
Signed-off-by: Aleksandr Malyshev <[email protected]>
Signed-off-by: root <[email protected]>
Co-authored-by: Aleksandr Malyshev <[email protected]>
Co-authored-by: root <[email protected]>
RichardoMrMu pushed a commit to RichardoMrMu/vllm that referenced this pull request May 12, 2025
Signed-off-by: Aleksandr Malyshev <[email protected]>
Signed-off-by: root <[email protected]>
Co-authored-by: Aleksandr Malyshev <[email protected]>
Co-authored-by: root <[email protected]>
Signed-off-by: Mu Huai <[email protected]>