
[ROCm][AITER] Fix aiter paged_attention_v1 decode for sliding window and head_size < 64 #34570

Merged
vllm-bot merged 8 commits into vllm-project:main from ROCm:akaratza_fix_aiter_fa
Feb 21, 2026

Conversation

@AndreasKaratzas
Collaborator

@AndreasKaratzas AndreasKaratzas commented Feb 15, 2026

Fixes a regression introduced by #34378 where sliding window models (e.g. Mixtral) and models with head_size < 64 produce completely wrong decode outputs on the ROCm AITER backend.

Root Cause

PR #34378 removed the unified_attention triton kernel path for sliding window decode and routed all decode through paged_attention_v1's ll4mi kernel. This kernel computes:

VHELOOP = HEAD_SIZE / 16 / NWARPS

On ROCm, NWARPS = 4, so for head_size=32 (e.g. TitanML/tiny-mixtral with head_dim = hidden_size / num_attention_heads = 1024 / 32 = 32), VHELOOP = 0. This causes the entire softmax × V multiplication loop to be skipped, producing uninitialized output. The result is completely wrong logprobs from the second token onward.
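The integer division above can be reproduced in isolation. A minimal Python sketch (the constant NWARPS = 4 is the ROCm value cited above; the function name is illustrative, not the kernel's actual code):

```python
# Reproduces the VHELOOP computation from the ll4mi paged_attention_v1 kernel.
NWARPS = 4  # ROCm default cited in the PR description

def vheloop(head_size: int, nwarps: int = NWARPS) -> int:
    # Integer division truncates: HEAD_SIZE / 16 / NWARPS
    return head_size // 16 // nwarps

# head_size=32 (e.g. TitanML/tiny-mixtral): 32 // 16 // 4 == 0, so a loop of
# the form `for i in range(VHELOOP)` runs zero times and the softmax x V
# accumulation is skipped entirely, leaving the output buffer uninitialized.
print(vheloop(32))   # 0
print(vheloop(64))   # 1 -> smallest head size that executes the loop
print(vheloop(128))  # 2
```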

Fix

Restore the unified_attention triton kernel fallback for decode when:

  • Sliding window attention is active (self.sliding_window[0] != -1), or
  • head_size < 64 (below the ll4mi kernel's minimum supported size)

Also removes the sliding window args from paged_attention_v1 since it no longer handles that path.
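The fallback decision can be sketched as a standalone predicate (a sketch only; `_MIN_HEAD_SIZE_FOR_LL4MI` matches the constant in the diff, the function name is hypothetical, and `sliding_window == -1` means sliding window attention is disabled):

```python
_MIN_HEAD_SIZE_FOR_LL4MI = 64  # ll4mi kernel's minimum supported head size

def should_use_unified_attention(sliding_window: int, head_size: int) -> bool:
    """Fall back to the unified_attention triton kernel for decode when the
    ll4mi kernel cannot handle the configuration."""
    return sliding_window != -1 or head_size < _MIN_HEAD_SIZE_FOR_LL4MI

print(should_use_unified_attention(-1, 32))     # True: head size too small
print(should_use_unified_attention(4096, 128))  # True: sliding window active
print(should_use_unified_attention(-1, 128))    # False: ll4mi handles this
```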

Testing

Found via git bisect. Verified fix with:

pytest -s -v tests/models/language/generation/test_common.py::test_models[True-True-5-32-TitanML/tiny-mixtral]

Related

…ing window and small head sizes

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
@mergify mergify bot added rocm Related to AMD ROCm v1 labels Feb 15, 2026
@github-project-automation github-project-automation bot moved this to Todo in AMD Feb 15, 2026
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request addresses a regression in the ROCm AITER backend for models using sliding window attention or having a head size smaller than 64. The fix involves adding a fallback to the unified_attention Triton kernel for these specific cases, which is a sound approach. However, I've identified a critical off-by-one error in the new code path related to sequence length indexing. This appears to be a pre-existing issue that has been propagated, and I've provided a suggestion to correct it to prevent potential crashes or incorrect outputs.

Comment on lines +1104 to +1113
descale_shape = (
attn_metadata.query_start_loc[:num_decodes].shape[0] - 1,
key_cache.shape[2],
)
unified_attention(
q=query[:num_decode_tokens],
k=key_cache,
v=value_cache,
out=output[:num_decode_tokens],
cu_seqlens_q=attn_metadata.query_start_loc[:num_decodes],
Contributor

critical

There appears to be an off-by-one error in how cu_seqlens_q is sliced and how descale_shape is calculated for the decode path. cu_seqlens_q should have a length of num_sequences + 1. For num_decodes sequences, it should be sliced as [:num_decodes + 1]. Consequently, descale_shape should be (num_decodes, num_kv_heads).

This seems to be a pre-existing issue in the paged_attention_v1 path which is being propagated here. Fixing it will prevent potential indexing errors and crashes.

Suggested change
descale_shape = (
attn_metadata.query_start_loc[:num_decodes].shape[0] - 1,
key_cache.shape[2],
)
unified_attention(
q=query[:num_decode_tokens],
k=key_cache,
v=value_cache,
out=output[:num_decode_tokens],
cu_seqlens_q=attn_metadata.query_start_loc[:num_decodes],
decode_cu_seqlens_q = attn_metadata.query_start_loc[:num_decodes + 1]
descale_shape = (
num_decodes,
key_cache.shape[2],
)
unified_attention(
q=query[:num_decode_tokens],
k=key_cache,
v=value_cache,
out=output[:num_decode_tokens],
cu_seqlens_q=decode_cu_seqlens_q,

Collaborator Author

cu_seqlens_q indeed needs num_decodes + 1 entries and descale_shape should be (num_decodes, num_kv_heads). This was a pre-existing bug from the original code before #34378 that I carried over. Fixed in the latest push.
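The shape relationship behind this fix can be shown standalone (the variable names mirror the diff; the data is made up for illustration):

```python
from itertools import accumulate

# query_start_loc is a cumulative prefix sum over per-sequence query lengths,
# so it always has num_sequences + 1 entries (it starts at 0).
query_lens = [1, 1, 1]  # three decode sequences, one token each
query_start_loc = [0] + list(accumulate(query_lens))  # [0, 1, 2, 3]

num_decodes = len(query_lens)

# Buggy slice: drops the final boundary, leaving only num_decodes entries,
# so the kernel effectively sees num_decodes - 1 sequences.
buggy = query_start_loc[:num_decodes]        # [0, 1, 2]

# Correct slice: keeps all num_decodes + 1 boundaries.
fixed = query_start_loc[:num_decodes + 1]    # [0, 1, 2, 3]

print(len(buggy), len(fixed))  # 3 4
```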

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
@AndreasKaratzas
Collaborator Author

AndreasKaratzas commented Feb 15, 2026

I have launched a full CI eval for this PR to catch any potential regressions: https://buildkite.com/vllm/amd-ci/builds/4787/steps/canvas

@AndreasKaratzas AndreasKaratzas added the ready ONLY add when PR is ready to merge/full CI is needed label Feb 15, 2026
@AndreasKaratzas
Collaborator Author

The failure in V1 attention (H100) looks like it was caused by an interrupt signal to the agent pool. It is also unrelated to the changes this PR proposes.

_MIN_HEAD_SIZE_FOR_LL4MI = 64
use_unified_attention = (
self.sliding_window[0] != -1
or self.head_size < _MIN_HEAD_SIZE_FOR_LL4MI
Collaborator

@tjtanaa tjtanaa Feb 15, 2026


If you are trying to revert the changes, why is there a need for the second condition?

head_size < 64 (below the ll4mi kernel's minimum supported size)?

Was the test failing before unified_attention was removed (before #34378)?

Collaborator

Is your intention to add a better guard to prevent bugs in future small models?

Collaborator

Does 'and' work instead? e.g.

use_unified_attention = (
                    self.sliding_window[0] != -1
+                    and self.head_size < _MIN_HEAD_SIZE_FOR_LL4MI

Because I think the original PR #34378 said the full-size Mistral model seems to work fine with the torch.ops.aiter.paged_attention_v1 sliding window feature.

Correct me if I am wrong.

Collaborator Author

@AndreasKaratzas AndreasKaratzas Feb 15, 2026


@tjtanaa

If you are trying to revert the changes, why is there a need for the second condition?

head_size < 64 (below the ll4mi kernel's minimum supported size)?

Was the test failing before unified_attention was removed (before #34378)?

Yes, this is a new guard. The old code never hit paged_attention_v1 for small head sizes, but now it does, and it silently produces garbage.

Is your intention to add a better guard to prevent bugs in future small models?

Yes, exactly :)

Does 'and' work instead? e.g.

use_unified_attention = (
                    self.sliding_window[0] != -1
+                    and self.head_size < _MIN_HEAD_SIZE_FOR_LL4MI

Because I think the original PR #34378 said the full-size Mistral model seems to work fine with the torch.ops.aiter.paged_attention_v1 sliding window feature.

Correct me if I am wrong.

No, 'and' would miss the exact case we're fixing. The bug manifests with head_size=32 and sliding_window=-1 (i.e. with sliding window disabled).

layer._v_scale,
None,
_PARTITION_SIZE_ROCM,
1,
Collaborator

With your new condition (head_size < 64), can't we enable these changes?

Collaborator Author

@AndreasKaratzas AndreasKaratzas Feb 15, 2026


Yes, good catch. I'll restore those args :)

Collaborator Author

@tjtanaa Done :)

@AndreasKaratzas
Collaborator Author

Failure in Entrypoints Integration (Responses API) is being addressed in #33949

@AndreasKaratzas
Collaborator Author

The failures are not associated with this PR, since I only merged upstream main.

@AndreasKaratzas
Collaborator Author

The Model Initialization and V1 Others tests were broken by PR #33600, so they are known failures.

@vllm-bot vllm-bot merged commit cf93c1a into vllm-project:main Feb 21, 2026
52 of 55 checks passed
@github-project-automation github-project-automation bot moved this from Todo to Done in AMD Feb 21, 2026

@AndreasKaratzas AndreasKaratzas deleted the akaratza_fix_aiter_fa branch February 21, 2026 04:26
DarkLight1337 pushed a commit to DarkLight1337/vllm that referenced this pull request Feb 21, 2026
…and head_size < 64 (vllm-project#34570)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
joeqzzuo pushed a commit to joeqzzuo/vllm that referenced this pull request Feb 21, 2026
…and head_size < 64 (vllm-project#34570)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
Signed-off-by: joezuo <qianzhou.zuo@gmail.com>
yugong333 pushed a commit to yugong333/vllm that referenced this pull request Feb 22, 2026
…and head_size < 64 (vllm-project#34570)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
jmamou pushed a commit to jmamou/vllm that referenced this pull request Feb 23, 2026
…and head_size < 64 (vllm-project#34570)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
llsj14 pushed a commit to llsj14/vllm that referenced this pull request Mar 1, 2026
…and head_size < 64 (vllm-project#34570)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
tunglinwood pushed a commit to tunglinwood/vllm that referenced this pull request Mar 4, 2026
…and head_size < 64 (vllm-project#34570)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
askliar pushed a commit to askliar/vllm that referenced this pull request Mar 9, 2026
…and head_size < 64 (vllm-project#34570)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
Signed-off-by: Andrii Skliar <askliar@nvidia.com>
Copilot AI pushed a commit to machov/vllm that referenced this pull request Mar 10, 2026
…and head_size < 64 (vllm-project#34570)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>

Labels

ready (ONLY add when PR is ready to merge/full CI is needed), rocm (Related to AMD ROCm), v1

Projects

Status: Done

Development

Successfully merging this pull request may close these issues.

3 participants