
[ROCm][perf] Shuffle KV cache to use paged_attention_common#32914

Open
samutamm wants to merge 18 commits into vllm-project:main from samutamm:pa_common_shuffle_kv_cache

Conversation

@samutamm
Contributor

@samutamm samutamm commented Jan 23, 2026

Purpose

For the Qwen/Qwen3-235B-A22B-Instruct-2507-FP8 model, VLLM_ROCM_SHUFFLE_KV_CACHE_LAYOUT=1 currently performs worse at small concurrencies than VLLM_ROCM_SHUFFLE_KV_CACHE_LAYOUT=0. This PR fixes the issue by using paged_attention_common from aiter (see ROCm/aiter#1821).

Test Plan

For input and output lengths of 1k and 8k, and concurrencies of 8, 18, 32, 64, and 128, compare the current main branch with and without VLLM_ROCM_SHUFFLE_KV_CACHE_LAYOUT (_vllm_main_shuffle1 and _vllm_main_shuffle0, respectively) against the changes of this PR (_pr_shuffle1).

Also verified on MI355.
Also verified for Qwen/Qwen3-235B-A22B-Instruct-2507.
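
The sweep above can be sketched as a small shell loop. The echo line is a stand-in for the actual serve/bench invocation, which is not shown in this PR; the model name and the 1k/8k lengths mirror the description, everything else is illustrative.

```shell
# Illustrative benchmark matrix for the test plan above; the echo is a
# placeholder for the real bench command (an assumption, not the harness used).
MODEL=Qwen/Qwen3-235B-A22B-Instruct-2507-FP8
print_grid() {
  # Three input/output length combinations from the description.
  for lens in "1024:1024" "8192:1024" "1024:8192"; do
    in_len=${lens%%:*}; out_len=${lens##*:}
    # Concurrency sweep from the description.
    for conc in 8 18 32 64 128; do
      echo "bench model=$MODEL input_len=$in_len output_len=$out_len concurrency=$conc"
    done
  done
}
print_grid
```

Each printed line corresponds to one benchmark run, repeated for each of the three branch/flag configurations being compared.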

Test Result

(Figure: Qwen_Qwen3-235B-A22B-Instruct-2507-FP8_combined — throughput comparison across the three configurations)

For input length 8k and output length 1k (green lines), the changes of this PR (_pr_shuffle1, the solid line) outperform the main branch, with or without shuffled KV cache.
For input length 1k and output length 8k (orange lines), the PR likewise outperforms the main branch in both configurations.
For input length 1k and output length 1k (blue lines), the PR is very close to the main branch. This might require further adjustment in aiter paged_attention_common.


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results.
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which executes a small, essential subset of CI tests to quickly catch errors.

You can ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

@mergify mergify bot added ci/build rocm Related to AMD ROCm v1 labels Jan 23, 2026
@samutamm samutamm changed the title Pa common shuffle kv cache [ROCm][perf] Shuffle KV cache to use paged_attention_common Jan 23, 2026
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

The pull request updates the AITER branch in the Dockerfile and integrates aiter.paged_attention_common for shuffle KV cache handling in rocm_aiter_fa.py. This change aims to fix performance issues with small concurrencies for specific Qwen models. The introduction of temporary tensors (tmp_out, exp_sums, max_logits) and new scaling parameters (K_QScale_hip, V_QScale_hip, K_QScale_asm, V_QScale_asm) to the paged_attention_common function is a significant update to the attention mechanism. I've identified a couple of issues related to variable redefinition and unreachable code that should be addressed.
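
For context on the temporary tensors mentioned above: split-KV paged attention kernels of this kind typically allocate per-partition scratch buffers and reduce them in a second pass. A minimal sketch of the shape bookkeeping follows; the partition size and the exact shapes are assumptions based on common vLLM/aiter conventions, not the authoritative paged_attention_common signature.

```python
# Sketch of the scratch-buffer shapes (tmp_out, exp_sums, max_logits) that
# split-KV paged attention kernels commonly need. PARTITION_SIZE and the
# shape layout are assumptions, not aiter's actual definitions.
PARTITION_SIZE = 256

def scratch_shapes(num_seqs: int, num_heads: int, head_size: int, max_seq_len: int):
    # Each sequence is split into ceil(max_seq_len / PARTITION_SIZE) partitions;
    # per-partition partial outputs are combined in a reduction pass using the
    # per-partition softmax statistics (exp_sums, max_logits).
    max_num_partitions = (max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE
    tmp_out = (num_seqs, num_heads, max_num_partitions, head_size)
    exp_sums = (num_seqs, num_heads, max_num_partitions)
    max_logits = (num_seqs, num_heads, max_num_partitions)
    return tmp_out, exp_sums, max_logits
```

In the real backend these shapes would be used to allocate device tensors once per forward call (or reused from a pool) before invoking the kernel.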

@samutamm samutamm force-pushed the pa_common_shuffle_kv_cache branch from ea196ed to 6cf3af5 Compare January 23, 2026 06:42
@tjtanaa
Collaborator

tjtanaa commented Jan 23, 2026

@gshtras what do you think about this aiter commit?

@samutamm before upgrading the aiter commit, we also need to determine that all other popular models and settings do not have issues. So we will need time to review this PR.

E.g.
  • Llama4 TP, EP
  • DeepSeek TP, EP, MTP
  • Qwen3 MoE models TP, EP
  • GPT-OSS

@tjtanaa
Collaborator

tjtanaa commented Jan 23, 2026

So we would usually split this PR in two:

upgrade the aiter version first, and only then introduce the new kernel.

@tjtanaa
Collaborator

tjtanaa commented Jan 23, 2026

We will keep tracking this PR; once the AITER commit version is upgraded, and if it contains the kernel, we will continue with this PR.

Signed-off-by: Samu Tamminen <stammine@amd.com>
@samutamm samutamm force-pushed the pa_common_shuffle_kv_cache branch from 778460c to 3d36878 Compare February 11, 2026 08:32
@mergify

mergify bot commented Feb 11, 2026

Hi @samutamm, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy or markdownlint failing?
mypy and markdownlint are run differently in CI. If the failure is related to either of these checks, please use the following commands to run them locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
# For markdownlint
pre-commit run --hook-stage manual markdownlint

Signed-off-by: Samu Tamminen <stammine@amd.com>
@mergify

mergify bot commented Feb 11, 2026 (pre-commit checks failed again; same instructions as above)

Signed-off-by: Samu Tamminen <stammine@amd.com>
@mergify

mergify bot commented Feb 12, 2026 (pre-commit checks failed again; same instructions as above)


_, num_heads, head_size = query.shape
num_seqs = attn_metadata.seq_lens.shape[0]

if rocm_aiter_ops.is_shuffle_kv_cache_enabled():
Collaborator

@samutamm are you confident that we can remove this is_shuffle_kv_cache_enabled and envs.VLLM_ROCM_SHUFFLE_KV_CACHE_LAYOUT condition and always use the rocm_aiter_ops.paged_attention_common function?

paged_attention_common invokes torch.ops.aiter.paged_attention_v1.

Since rocm_aiter_ops.paged_attention_common does not expose the sliding_window parameter, I think we still need to keep the invocation of torch.ops.aiter.paged_attention_v1: we call torch.ops.aiter.paged_attention_v1 if self.sliding_window[0] != -1.
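
The fallback described above can be sketched as a tiny dispatch helper. The function name and string return values are illustrative stand-ins for the real vLLM/aiter code paths; the -1 sentinel for "no sliding window" follows the condition quoted in the comment.

```python
# Illustrative dispatch, not the actual backend code: keep the v1 kernel
# for sliding-window attention, since paged_attention_common does not
# expose a sliding_window parameter.
def select_paged_attention(sliding_window: tuple[int, int],
                           shuffle_kv_cache_enabled: bool) -> str:
    if sliding_window[0] != -1:
        # Sliding window configured: only paged_attention_v1 supports it.
        return "paged_attention_v1"
    if shuffle_kv_cache_enabled:
        # Shuffled KV cache layout: use the common kernel from this PR.
        return "paged_attention_common"
    return "paged_attention_v1"
```

This mirrors the "preserve two paths" conclusion of the thread: the environment flag picks the common kernel, but sliding-window models always take the v1 path.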

Contributor Author

@samutamm samutamm Feb 16, 2026

So far VLLM_ROCM_SHUFFLE_KV_CACHE_LAYOUT seems useful in the cases we've seen. Right, paged_attention_common does not have sliding_window, so preserving two paths makes sense.

Collaborator

So far VLLM_ROCM_SHUFFLE_KV_CACHE_LAYOUT seems useful in cases we've seen.

When is it useful?

Collaborator

I just found one case where we might need to keep the flag for now: trying to run Qwen/Qwen3.5-397B-A17B-FP8 with kv-cache-dtype=fp8.

MODEL=Qwen/Qwen3.5-397B-A17B-FP8

VLLM_ROCM_USE_AITER=1 \
VLLM_ROCM_SHUFFLE_KV_CACHE_LAYOUT=1 \
vllm serve $MODEL \
--tensor-parallel-size 8 \
--max-num-batched-tokens 32768 \
--disable-log-requests \
--kv-cache-dtype fp8 \
--compilation-config '{"cudagraph_mode": "FULL"}' \
--trust-remote-code \
--enable_expert_parallel \
--port 6789
(Worker_TP4_EP4 pid=3455) ERROR 02-27 09:59:53 [multiproc_executor.py:863]     self.impl.forward(
(Worker_TP4_EP4 pid=3455) ERROR 02-27 09:59:53 [multiproc_executor.py:863]   File "/app/reviewpr5/pa_common_shuffle_kv_cache/vllm/v1/attention/backends/rocm_aiter_fa.py", line 1248, in forward
(Worker_TP4_EP4 pid=3455) ERROR 02-27 09:59:53 [multiproc_executor.py:863]     rocm_aiter_ops.paged_attention_common(
(Worker_TP4_EP4 pid=3455) ERROR 02-27 09:59:53 [multiproc_executor.py:863]   File "/app/reviewpr5/pa_common_shuffle_kv_cache/vllm/_aiter_ops.py", line 1836, in paged_attention_common
(Worker_TP4_EP4 pid=3455) ERROR 02-27 09:59:53 [multiproc_executor.py:863]     return paged_attention_common(
(Worker_TP4_EP4 pid=3455) ERROR 02-27 09:59:53 [multiproc_executor.py:863]            ^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP4_EP4 pid=3455) ERROR 02-27 09:59:53 [multiproc_executor.py:863]   File "/usr/local/lib/python3.12/dist-packages/aiter/ops/attention.py", line 189, in paged_attention_common
(Worker_TP4_EP4 pid=3455) ERROR 02-27 09:59:53 [multiproc_executor.py:863]     output = pa_fwd_asm(
(Worker_TP4_EP4 pid=3455) ERROR 02-27 09:59:53 [multiproc_executor.py:863]              ^^^^^^^^^^^
(Worker_TP4_EP4 pid=3455) ERROR 02-27 09:59:53 [multiproc_executor.py:863]   File "/usr/local/lib/python3.12/dist-packages/aiter/jit/utils/torch_guard.py", line 278, in wrapper_custom
(Worker_TP4_EP4 pid=3455) ERROR 02-27 09:59:53 [multiproc_executor.py:863]     getattr(torch.ops.aiter, f"{loadName}")(*args, **kwargs)
(Worker_TP4_EP4 pid=3455) ERROR 02-27 09:59:53 [multiproc_executor.py:863]   File "/usr/local/lib/python3.12/dist-packages/torch/_ops.py", line 1255, in __call__
(Worker_TP4_EP4 pid=3455) ERROR 02-27 09:59:53 [multiproc_executor.py:863]     return self._op(*args, **kwargs)
(Worker_TP4_EP4 pid=3455) ERROR 02-27 09:59:53 [multiproc_executor.py:863]            ^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP4_EP4 pid=3455) ERROR 02-27 09:59:53 [multiproc_executor.py:863]   File "/usr/local/lib/python3.12/dist-packages/aiter/jit/utils/torch_guard.py", line 301, in outer_wrapper
(Worker_TP4_EP4 pid=3455) ERROR 02-27 09:59:53 [multiproc_executor.py:863]     wrapper(*args, **kwargs)
(Worker_TP4_EP4 pid=3455) ERROR 02-27 09:59:53 [multiproc_executor.py:863]   File "/usr/local/lib/python3.12/dist-packages/aiter/jit/utils/torch_guard.py", line 196, in wrapper
(Worker_TP4_EP4 pid=3455) ERROR 02-27 09:59:53 [multiproc_executor.py:863]     return func(*args, **kwargs)
(Worker_TP4_EP4 pid=3455) ERROR 02-27 09:59:53 [multiproc_executor.py:863]            ^^^^^^^^^^^^^^^^^^^^^
(Worker_TP4_EP4 pid=3455) ERROR 02-27 09:59:53 [multiproc_executor.py:863]   File "/usr/local/lib/python3.12/dist-packages/aiter/jit/core.py", line 970, in custom_wrapper
(Worker_TP4_EP4 pid=3455) ERROR 02-27 09:59:53 [multiproc_executor.py:863]     return wrapper(*args, **kwargs)
(Worker_TP4_EP4 pid=3455) ERROR 02-27 09:59:53 [multiproc_executor.py:863]            ^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP4_EP4 pid=3455) ERROR 02-27 09:59:53 [multiproc_executor.py:863]   File "/usr/local/lib/python3.12/dist-packages/aiter/jit/core.py", line 966, in wrapper
(Worker_TP4_EP4 pid=3455) ERROR 02-27 09:59:53 [multiproc_executor.py:863]     return op(*args, **kwargs)
(Worker_TP4_EP4 pid=3455) ERROR 02-27 09:59:53 [multiproc_executor.py:863]            ^^^^^^^^^^^^^^^^^^^
(Worker_TP4_EP4 pid=3455) ERROR 02-27 09:59:53 [multiproc_executor.py:863] RuntimeError: get_heuristic_kernel: cannot get heuristic kernel! q_type:bf16 kv_type:fp8 gqa:4 mtp:0 msk:0 hp:1 block_size:32 ps:0 qTile:0

Contributor Author

So far VLLM_ROCM_SHUFFLE_KV_CACHE_LAYOUT seems useful in cases we've seen.

For Qwen/Qwen3-235B-A22B-Instruct-2507 (bf16), VLLM_ROCM_SHUFFLE_KV_CACHE_LAYOUT=1 brought an 8% uplift at larger concurrencies (> 128). For Qwen/Qwen3-235B-A22B-Instruct-2507-FP8 the uplift is smaller; see the table in the PR description. For Qwen/Qwen3-VL-235B-A22B-Instruct-FP8 the uplift was 1.7%.

@tuukkjs tuukkjs Mar 12, 2026

As discussed above, it seems to me that sliding window is not supported in paged_attention_common either, so in that case we should call paged_attention_v1, I guess.


@Rohan138 your PR fuses rope+kvcache but it doesn’t support shuffled layout. Any chance of us also supporting the shuffled layout in that fusion (in another PR possibly)? AFAIU, for some models (e.g., llama 3s – I think) one would benefit from both (the shuffled layout and rope+kvcache fusion) but absent support one needs to pick one which in terms of perf is not optimal. Any thoughts?

Collaborator

@tuukkjs yeah, paged_attention_common is still not compatible with sliding windows. We have to make sure the regular PA path works and doesn't have accuracy issues.

Collaborator

@Rohan138 ok. Then in this case, we still need to preserve the environment variable flag VLLM_ROCM_SHUFFLE_KV_CACHE_LAYOUT. But we will keep VLLM_ROCM_SHUFFLE_KV_CACHE_LAYOUT=0 as the default, because we will have more optimizations compatible with the non-shuffled KV cache layout.

@tuukkjs tuukkjs Mar 17, 2026

Short update: the fix to restrict the pa_fwd_asm kernel to head_size 128 in aiter has been merged. I've been running lm_eval for shuffle=0 (main), shuffle=1 (main), and shuffle=1 (PR) with three models with head size 128: Qwen/Qwen3-235B-A22B-Instruct-2507-FP8, amd/Llama-3.3-70B-Instruct-FP8-KV, and amd/Llama-3.1-405B-Instruct-FP8-KV. There may be a correctness issue for the Llamas on the paged_attention_rocm path of paged_attention_common; we are investigating it. Perf seems better for paged_attention_common. If we can resolve the issue, I will run perf benchmarks for the three models in the three cases. If they look good, I guess we can proceed with the merge?
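
A minimal sketch of the kind of guard discussed in this thread, assuming the head_size-128 restriction mentioned above and, conservatively, matching query/KV dtypes based on the "cannot get heuristic kernel" error earlier (q_type:bf16 kv_type:fp8). Both the function name and the dtype rule are assumptions, not aiter's actual eligibility logic.

```python
# Hypothetical eligibility check for routing to the ASM paged-attention
# kernel; the real check lives inside aiter and may differ.
def can_use_pa_fwd_asm(head_size: int, q_dtype: str, kv_dtype: str) -> bool:
    # The aiter fix referenced above restricts pa_fwd_asm to head_size 128.
    if head_size != 128:
        return False
    # bf16 query with fp8 KV cache hit "cannot get heuristic kernel" in the
    # traceback earlier in this thread, so require matching dtypes here
    # (a conservative assumption, not a documented rule).
    return q_dtype == kv_dtype
```

Callers failing this check would fall back to the existing paged_attention_v1 path.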

@mergify

mergify bot commented Feb 16, 2026 (pre-commit checks failed again; same instructions as above)

Signed-off-by: Samu Tamminen <stammine@amd.com>
@samutamm samutamm requested a review from tjtanaa February 16, 2026 11:38
@mergify

mergify bot commented Feb 21, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @samutamm.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

Signed-off-by: Samu Tamminen <stammine@amd.com>
@mergify mergify bot removed the needs-rebase label Feb 24, 2026
@mergify

mergify bot commented Feb 24, 2026 (pre-commit checks failed again; same instructions as above)

Signed-off-by: Samu Tamminen <stammine@amd.com>
Signed-off-by: Samu Tamminen <stammine@amd.com>
Signed-off-by: Samu Tamminen <stammine@amd.com>
@samutamm
Contributor Author

@tjtanaa I updated the branch after #34570 and verified that I still get the same results. Can we get this merged?

@tjtanaa tjtanaa added the ready ONLY add when PR is ready to merge/full CI is needed label Feb 27, 2026
@tjtanaa
Collaborator

tjtanaa commented Feb 27, 2026

The gsm8k accuracy is 0 for Qwen/Qwen3.5-397B-A17B-FP8

#!/bin/bash

rm -rf /root/.cache/vllm

MODEL=Qwen/Qwen3.5-397B-A17B-FP8

VLLM_ROCM_USE_AITER=1 \
VLLM_ROCM_SHUFFLE_KV_CACHE_LAYOUT=1 \
vllm serve $MODEL \
--tensor-parallel-size 8 \
--max-num-batched-tokens 32768 \
--disable-log-requests \
--trust-remote-code \
--enable_expert_parallel \
--reasoning-parser qwen3 \
--port 6789

Benchmark command:

#!/bin/bash

lm_eval \
--model local-completions \
--tasks gsm8k \
--model_args model=Qwen/Qwen3.5-397B-A17B-FP8,base_url=http://127.0.0.1:6789/v1/completions \
--batch_size 100 \
> lmeval_server-Qwen_Qwen3.5-397B-A17B-FP8-aiter-v1-fp8-cudagraph_FULL-shuffle_kv_cache_layout_1-bf16kvcache.log 2>&1

Without preshuffle enabled, the accuracy is as follows:

2026-02-27:10:17:05 INFO     [loggers.evaluation_tracker:316] Output path not provided, skipping saving results aggregated
local-completions ({'model': 'Qwen/Qwen3.5-397B-A17B-FP8', 'base_url': 'http://127.0.0.1:6789/v1/completions'}), gen_kwargs: ({}), limit: None, num_fewshot: None, batch_size: 100
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.8484|±  |0.0099|
|     |       |strict-match    |     5|exact_match|↑  |0.8340|±  |0.0102|
