[ROCm][perf] Shuffle KV cache to use paged_attention_common #32914
samutamm wants to merge 18 commits into vllm-project:main
Conversation
|
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default; only a fast subset of checks runs, and you can ask your reviewers to trigger select CI tests on top of those. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. If you have any questions, please reach out to us on Slack at https://slack.vllm.ai. 🚀 |
Code Review
The pull request updates the AITER branch in the Dockerfile and integrates aiter.paged_attention_common for shuffle KV cache handling in rocm_aiter_fa.py. This change aims to fix performance issues with small concurrencies for specific Qwen models. The introduction of temporary tensors (tmp_out, exp_sums, max_logits) and new scaling parameters (K_QScale_hip, V_QScale_hip, K_QScale_asm, V_QScale_asm) to the paged_attention_common function is a significant update to the attention mechanism. I've identified a couple of issues related to variable redefinition and unreachable code that should be addressed.
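For context, the split-K buffers named above (tmp_out, exp_sums, max_logits) typically hold per-partition partial attention results and softmax statistics that a final reduction step combines. The sketch below computes their conventional shapes; the partition size and exact layout are assumptions based on vLLM's usual paged attention convention, not taken from this PR.

```python
# Conventional shapes for split-K paged attention scratch buffers.
# PARTITION_SIZE and the layouts are illustrative assumptions, not
# values read from this PR.

PARTITION_SIZE = 256  # illustrative partition length in tokens


def partition_buffer_shapes(num_seqs, num_heads, head_size, max_seq_len,
                            partition_size=PARTITION_SIZE):
    """Shapes of the per-partition scratch buffers for a split-K pass."""
    max_num_partitions = -(-max_seq_len // partition_size)  # ceil division
    # Per-partition partial outputs, reduced into the final output later.
    tmp_out = (num_seqs, num_heads, max_num_partitions, head_size)
    # Per-partition softmax statistics used by the reduction.
    exp_sums = (num_seqs, num_heads, max_num_partitions)
    max_logits = (num_seqs, num_heads, max_num_partitions)
    return tmp_out, exp_sums, max_logits
```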
Force-pushed ea196ed to 6cf3af5
|
We would usually split this PR: upgrade the AITER version first, then introduce the new kernel separately. |
|
We will keep this PR on hold; once the AITER commit version is upgraded and contains the kernel, we will continue with this PR. |
Signed-off-by: Samu Tamminen <stammine@amd.com>
Signed-off-by: Samu Tamminen <stammine@amd.com>
Signed-off-by: Samu Tamminen <stammine@amd.com>
Signed-off-by: Samu Tamminen <stammine@amd.com>
Signed-off-by: Samu Tamminen <stammine@amd.com>
Signed-off-by: Samu Tamminen <stammine@amd.com>
Force-pushed 778460c to 3d36878
|
Hi @samutamm, the pre-commit checks have failed. Please run:
uv pip install pre-commit
pre-commit install
pre-commit run --all-files
Then, commit the changes and push to your branch.
|
Signed-off-by: Samu Tamminen <stammine@amd.com>
|
|
Signed-off-by: Samu Tamminen <stammine@amd.com>
|
|
|
|
_, num_heads, head_size = query.shape
num_seqs = attn_metadata.seq_lens.shape[0]

if rocm_aiter_ops.is_shuffle_kv_cache_enabled():
@samutamm are you confident that we can remove this is_shuffle_kv_cache_enabled and envs.VLLM_ROCM_SHUFFLE_KV_CACHE_LAYOUT condition and always use this rocm_aiter_ops.paged_attention_common function?
paged_attention_common invokes torch.ops.aiter.paged_attention_v1.
Since rocm_aiter_ops.paged_attention_common does not expose a sliding_window parameter, I think we still need to keep the invocation of torch.ops.aiter.paged_attention_v1: we call torch.ops.aiter.paged_attention_v1 if self.sliding_window[0] != -1.
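The two-path dispatch discussed in this thread can be sketched as follows. All function names and signatures below are illustrative stand-ins for torch.ops.aiter.paged_attention_v1 and rocm_aiter_ops.paged_attention_common, not their actual APIs.

```python
# Illustrative sketch of the dispatch: keep the paged_attention_v1 path
# for sliding-window models, route everything else through
# paged_attention_common. The stubs below are NOT the real signatures.

def paged_attention_v1_stub(query, kv_cache, sliding_window):
    # Stand-in for the kernel that does expose a sliding_window parameter.
    return ("v1", sliding_window)


def paged_attention_common_stub(query, kv_cache):
    # Stand-in for the kernel that does not expose sliding_window.
    return ("common", None)


def dispatch_paged_attention(query, kv_cache, sliding_window=(-1, -1)):
    """Pick a kernel based on whether a sliding window is configured."""
    if sliding_window[0] != -1:
        # paged_attention_common has no sliding_window parameter, so the
        # v1 invocation must be preserved for these models.
        return paged_attention_v1_stub(query, kv_cache, sliding_window)
    return paged_attention_common_stub(query, kv_cache)
```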
So far VLLM_ROCM_SHUFFLE_KV_CACHE_LAYOUT seems useful in cases we've seen. Right, paged_attention_common does not have sliding_window; preserving the two paths makes sense.
So far VLLM_ROCM_SHUFFLE_KV_CACHE_LAYOUT seems useful in cases we've seen.

When is it useful?
I just found one case where we might need to keep the flag for now: trying to run Qwen/Qwen3.5-397B-A17B-FP8 with kv-cache-dtype=fp8.
MODEL=Qwen/Qwen3.5-397B-A17B-FP8
VLLM_ROCM_USE_AITER=1 \
VLLM_ROCM_SHUFFLE_KV_CACHE_LAYOUT=1 \
vllm serve $MODEL \
--tensor-parallel-size 8 \
--max-num-batched-tokens 32768 \
--disable-log-requests \
--kv-cache-dtype fp8 \
--compilation-config '{"cudagraph_mode": "FULL"}' \
--trust-remote-code \
--enable_expert_parallel \
--port 6789
(Worker_TP4_EP4 pid=3455) ERROR 02-27 09:59:53 [multiproc_executor.py:863] self.impl.forward(
(Worker_TP4_EP4 pid=3455) ERROR 02-27 09:59:53 [multiproc_executor.py:863] File "/app/reviewpr5/pa_common_shuffle_kv_cache/vllm/v1/attention/backends/rocm_aiter_fa.py", line 1248, in forward
(Worker_TP4_EP4 pid=3455) ERROR 02-27 09:59:53 [multiproc_executor.py:863] rocm_aiter_ops.paged_attention_common(
(Worker_TP4_EP4 pid=3455) ERROR 02-27 09:59:53 [multiproc_executor.py:863] File "/app/reviewpr5/pa_common_shuffle_kv_cache/vllm/_aiter_ops.py", line 1836, in paged_attention_common
(Worker_TP4_EP4 pid=3455) ERROR 02-27 09:59:53 [multiproc_executor.py:863] return paged_attention_common(
(Worker_TP4_EP4 pid=3455) ERROR 02-27 09:59:53 [multiproc_executor.py:863] ^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP4_EP4 pid=3455) ERROR 02-27 09:59:53 [multiproc_executor.py:863] File "/usr/local/lib/python3.12/dist-packages/aiter/ops/attention.py", line 189, in paged_attention_common
(Worker_TP4_EP4 pid=3455) ERROR 02-27 09:59:53 [multiproc_executor.py:863] output = pa_fwd_asm(
(Worker_TP4_EP4 pid=3455) ERROR 02-27 09:59:53 [multiproc_executor.py:863] ^^^^^^^^^^^
(Worker_TP4_EP4 pid=3455) ERROR 02-27 09:59:53 [multiproc_executor.py:863] File "/usr/local/lib/python3.12/dist-packages/aiter/jit/utils/torch_guard.py", line 278, in wrapper_custom
(Worker_TP4_EP4 pid=3455) ERROR 02-27 09:59:53 [multiproc_executor.py:863] getattr(torch.ops.aiter, f"{loadName}")(*args, **kwargs)
(Worker_TP4_EP4 pid=3455) ERROR 02-27 09:59:53 [multiproc_executor.py:863] File "/usr/local/lib/python3.12/dist-packages/torch/_ops.py", line 1255, in __call__
(Worker_TP4_EP4 pid=3455) ERROR 02-27 09:59:53 [multiproc_executor.py:863] return self._op(*args, **kwargs)
(Worker_TP4_EP4 pid=3455) ERROR 02-27 09:59:53 [multiproc_executor.py:863] ^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP4_EP4 pid=3455) ERROR 02-27 09:59:53 [multiproc_executor.py:863] File "/usr/local/lib/python3.12/dist-packages/aiter/jit/utils/torch_guard.py", line 301, in outer_wrapper
(Worker_TP4_EP4 pid=3455) ERROR 02-27 09:59:53 [multiproc_executor.py:863] wrapper(*args, **kwargs)
(Worker_TP4_EP4 pid=3455) ERROR 02-27 09:59:53 [multiproc_executor.py:863] File "/usr/local/lib/python3.12/dist-packages/aiter/jit/utils/torch_guard.py", line 196, in wrapper
(Worker_TP4_EP4 pid=3455) ERROR 02-27 09:59:53 [multiproc_executor.py:863] return func(*args, **kwargs)
(Worker_TP4_EP4 pid=3455) ERROR 02-27 09:59:53 [multiproc_executor.py:863] ^^^^^^^^^^^^^^^^^^^^^
(Worker_TP4_EP4 pid=3455) ERROR 02-27 09:59:53 [multiproc_executor.py:863] File "/usr/local/lib/python3.12/dist-packages/aiter/jit/core.py", line 970, in custom_wrapper
(Worker_TP4_EP4 pid=3455) ERROR 02-27 09:59:53 [multiproc_executor.py:863] return wrapper(*args, **kwargs)
(Worker_TP4_EP4 pid=3455) ERROR 02-27 09:59:53 [multiproc_executor.py:863] ^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP4_EP4 pid=3455) ERROR 02-27 09:59:53 [multiproc_executor.py:863] File "/usr/local/lib/python3.12/dist-packages/aiter/jit/core.py", line 966, in wrapper
(Worker_TP4_EP4 pid=3455) ERROR 02-27 09:59:53 [multiproc_executor.py:863] return op(*args, **kwargs)
(Worker_TP4_EP4 pid=3455) ERROR 02-27 09:59:53 [multiproc_executor.py:863] ^^^^^^^^^^^^^^^^^^^
(Worker_TP4_EP4 pid=3455) ERROR 02-27 09:59:53 [multiproc_executor.py:863] RuntimeError: get_heuristic_kernel: cannot get heuristic kernel! q_type:bf16 kv_type:fp8 gqa:4 mtp:0 msk:0 hp:1 block_size:32 ps:0 qTile:0
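One way to fail fast on the unsupported combination shown in this traceback would be a pre-flight dtype check before calling into the ASM path. The support table below is a made-up illustration for the sketch, not aiter's actual heuristic-kernel matrix.

```python
# Hypothetical pre-flight check motivated by the RuntimeError above:
# get_heuristic_kernel has no kernel for a bf16 query with an fp8 KV
# cache in this configuration. SUPPORTED_ASM_COMBOS is illustrative
# only; it is not aiter's real support matrix.

SUPPORTED_ASM_COMBOS = {
    ("bf16", "bf16"),
    ("fp16", "fp16"),
    # ("bf16", "fp8") is exactly the pair the traceback reports failing.
}


def can_use_asm_path(q_dtype: str, kv_dtype: str) -> bool:
    """Return True if the (query, kv-cache) dtype pair has an ASM kernel."""
    return (q_dtype, kv_dtype) in SUPPORTED_ASM_COMBOS
```

With a check like this, the caller could fall back to another kernel path instead of hitting the runtime error inside aiter.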
So far VLLM_ROCM_SHUFFLE_KV_CACHE_LAYOUT seems useful in cases we've seen.
For Qwen/Qwen3-235B-A22B-Instruct-2507 (bf16) VLLM_ROCM_SHUFFLE_KV_CACHE_LAYOUT=1 brought 8% uplift for larger concurrencies (> 128). For Qwen/Qwen3-235B-A22B-Instruct-2507-FP8 the uplift is smaller, see the table in the PR description. For Qwen/Qwen3-VL-235B-A22B-Instruct-FP8 uplift was 1.7%.
As discussed above, it seems that sliding window is not supported in paged_attention_common either, so in that case we should call paged_attention_v1 I guess.
@Rohan138 your PR fuses rope+kvcache but it doesn’t support shuffled layout. Any chance of us also supporting the shuffled layout in that fusion (in another PR possibly)? AFAIU, for some models (e.g., llama 3s – I think) one would benefit from both (the shuffled layout and rope+kvcache fusion) but absent support one needs to pick one which in terms of perf is not optimal. Any thoughts?
@tuukkjs yes. paged_attention_common is still not compatible with sliding windows. We have to make sure it works in regular PA and doesn't have accuracy issues.
@Rohan138 ok. Then in this case, we still need to preserve the environment variable flag VLLM_ROCM_SHUFFLE_KV_CACHE_LAYOUT. But we will keep VLLM_ROCM_SHUFFLE_KV_CACHE_LAYOUT=0 as the default, because we will have more optimizations compatible with the non-shuffled KV cache layout.
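A minimal sketch of the opt-in gating described here, assuming the flag keeps its current name and a 0 default; the helper function name is hypothetical.

```python
import os

# Sketch of the env-flag gating: the shuffled KV cache layout stays
# opt-in, defaulting to disabled. The flag name matches the PR; the
# helper name is a hypothetical stand-in for vLLM's actual env handling.


def shuffle_kv_cache_enabled() -> bool:
    """Return True only when the user explicitly opts in with =1."""
    return os.environ.get("VLLM_ROCM_SHUFFLE_KV_CACHE_LAYOUT", "0") == "1"
```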
Short update. The fix to restrict the pa_fwd_asm kernel to head_size 128 in aiter has been merged. I've been running lm_eval for shuffle=0 (main), shuffle=1 (main), and shuffle=1 (PR) with three models with head size 128: Qwen/Qwen3-235B-A22B-Instruct-2507-FP8, amd/Llama-3.3-70B-Instruct-FP8-KV, and amd/Llama-3.1-405B-Instruct-FP8-KV. It seems there may be a correctness issue for the Llamas on the paged_attention_rocm path of paged_attention_common. We are investigating it. Perf looks better for paged_attention_common. If we can resolve the issue, I will run perf benchmarks for the three models in the three cases. If they look good, I guess we can proceed with the merge?
|
|
Signed-off-by: Samu Tamminen <stammine@amd.com>
|
This pull request has merge conflicts that must be resolved before it can be merged. |
Signed-off-by: Samu Tamminen <stammine@amd.com>
|
|
Signed-off-by: Samu Tamminen <stammine@amd.com>
|
The gsm8k accuracy is:
Benchmark command:
Without preshuffle enabled, the accuracy is as follows: |
Purpose
For the Qwen/Qwen3-235B-A22B-Instruct-2507-FP8 model, VLLM_ROCM_SHUFFLE_KV_CACHE_LAYOUT=1 currently performs worse at small concurrencies than VLLM_ROCM_SHUFFLE_KV_CACHE_LAYOUT=0. This PR fixes the issue using paged_attention_common from aiter (see ROCm/aiter#1821).
Test Plan
For input and output lengths of 1k and 8k, and concurrencies of 8, 18, 32, 64, and 128, compare the current main branch with and without VLLM_ROCM_SHUFFLE_KV_CACHE_LAYOUT (_vllm_main_shuffle1 and _vllm_main_shuffle0, respectively) against the changes of this PR (_pr_shuffle1).
Also verified on MI355.
Also verified for Qwen/Qwen3-235B-A22B-Instruct-2507.
Test Result
For input length 8k and output length 1k (green lines), the changes of this PR (_pr_shuffle1, the solid line) outperform main branch, with or without shuffle kv cache.
For input length 1k and output length 8k (orange lines), the changes of this PR (_pr_shuffle1, the solid line) outperform main branch, with or without shuffle kv cache.
For input length 1k and output length 1k (blue lines), the changes of this PR (_pr_shuffle1, the solid line) are very close to the main branch. This might require further adjustment in aiter paged_attention_common.
Essential Elements of an Effective PR Description Checklist
supported_models.md and examples for a new model.