[ROCm][AITER] Fix aiter paged_attention_v1 decode for sliding window and head_size < 64#34570
Conversation
…ing window and small head sizes Signed-off-by: Andreas Karatzas <akaratza@amd.com>
Code Review
This pull request addresses a regression in the ROCm AITER backend for models using sliding window attention or having a head size smaller than 64. The fix involves adding a fallback to the unified_attention Triton kernel for these specific cases, which is a sound approach. However, I've identified a critical off-by-one error in the new code path related to sequence length indexing. This appears to be a pre-existing issue that has been propagated, and I've provided a suggestion to correct it to prevent potential crashes or incorrect outputs.
```python
descale_shape = (
    attn_metadata.query_start_loc[:num_decodes].shape[0] - 1,
    key_cache.shape[2],
)
unified_attention(
    q=query[:num_decode_tokens],
    k=key_cache,
    v=value_cache,
    out=output[:num_decode_tokens],
    cu_seqlens_q=attn_metadata.query_start_loc[:num_decodes],
```
There appears to be an off-by-one error in how `cu_seqlens_q` is sliced and how `descale_shape` is calculated for the decode path. `cu_seqlens_q` should have a length of `num_sequences + 1`, so for `num_decodes` sequences it should be sliced as `[:num_decodes + 1]`. Consequently, `descale_shape` should be `(num_decodes, num_kv_heads)`.
This seems to be a pre-existing issue in the `paged_attention_v1` path which is being propagated here. Fixing it will prevent potential indexing errors and crashes.
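For illustration, a minimal sketch of the off-by-one with hypothetical offsets — plain Python lists stand in for the `query_start_loc` tensor; the batch values are made up:

```python
# Hypothetical decode batch: 3 sequences, one query token each.
num_decodes = 3
# query_start_loc is cumulative: entry i is where sequence i's query
# tokens start; the final entry is the total token count, so a batch
# of N sequences needs N + 1 entries.
query_start_loc = [0, 1, 2, 3]

# Slicing [:num_decodes] drops the final offset, leaving boundaries
# for only num_decodes - 1 sequences.
wrong = query_start_loc[:num_decodes]        # [0, 1, 2]
right = query_start_loc[:num_decodes + 1]    # [0, 1, 2, 3]

print(len(wrong) - 1)  # 2 -> one sequence short
print(len(right) - 1)  # 3 -> matches num_decodes
```

The same arithmetic explains the `descale_shape` fix: `len(cu_seqlens_q) - 1` only equals the sequence count when the trailing offset is kept.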
Suggested change:

```diff
-descale_shape = (
-    attn_metadata.query_start_loc[:num_decodes].shape[0] - 1,
-    key_cache.shape[2],
-)
-unified_attention(
-    q=query[:num_decode_tokens],
-    k=key_cache,
-    v=value_cache,
-    out=output[:num_decode_tokens],
-    cu_seqlens_q=attn_metadata.query_start_loc[:num_decodes],
+decode_cu_seqlens_q = attn_metadata.query_start_loc[:num_decodes + 1]
+descale_shape = (
+    num_decodes,
+    key_cache.shape[2],
+)
+unified_attention(
+    q=query[:num_decode_tokens],
+    k=key_cache,
+    v=value_cache,
+    out=output[:num_decode_tokens],
+    cu_seqlens_q=decode_cu_seqlens_q,
```
`cu_seqlens_q` indeed needs `num_decodes + 1` entries and `descale_shape` should be `(num_decodes, num_kv_heads)`. This was a pre-existing bug from the original code before #34378 that I carried over. Fixed in the latest push.
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
I have launched a full CI eval for this PR to catch any potential new bugs: https://buildkite.com/vllm/amd-ci/builds/4787/steps/canvas
```python
_MIN_HEAD_SIZE_FOR_LL4MI = 64
use_unified_attention = (
    self.sliding_window[0] != -1
    or self.head_size < _MIN_HEAD_SIZE_FOR_LL4MI
```
If you are trying to revert the changes, why is there a need for the second condition, `head_size < 64` (below the ll4mi kernel's minimum supported size)? Was the test failing before they removed `unified_attention` (before #34378)?
Is your intention to add a better guard to prevent bugs in future small models?
Does `and` work instead? e.g.

```python
use_unified_attention = (
    self.sliding_window[0] != -1
    and self.head_size < _MIN_HEAD_SIZE_FOR_LL4MI
```

Because I think the original PR #34378 said the full size Mistral model seems to work fine using the `torch.ops.aiter.paged_attention_v1` sliding window feature. Correct me if I am wrong.
> If you are trying to revert the changes, why is there a need for the second condition, `head_size < 64` (below the ll4mi kernel's minimum supported size)? Was the test failing before they removed `unified_attention` (before #34378)?

Yes, this is a new guard. The old code never hit `paged_attention_v1` for small head sizes, but now it does, and it silently produces garbage.
> Is your intention to add a better guard to prevent bugs in future small models?

Yes, exactly :)
> Does `and` work instead? e.g. `use_unified_attention = (self.sliding_window[0] != -1 and self.head_size < _MIN_HEAD_SIZE_FOR_LL4MI` Because I think the original PR #34378 said the full size Mistral model seems to work fine using the `torch.ops.aiter.paged_attention_v1` sliding window feature. Correct me if I am wrong.

No, `and` would miss the exact case we're fixing. The bug manifests with `head_size=32` and `sliding_window=-1`.
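A quick truth-table check of the two guard variants for the configurations discussed here — plain Python, with scalar stand-ins for `self.sliding_window[0]` and `self.head_size`:

```python
_MIN_HEAD_SIZE_FOR_LL4MI = 64  # minimum head size the ll4mi kernel supports

def needs_fallback_or(sliding_window, head_size):
    # The guard in this PR: fall back if EITHER condition holds.
    return sliding_window != -1 or head_size < _MIN_HEAD_SIZE_FOR_LL4MI

def needs_fallback_and(sliding_window, head_size):
    # The proposed alternative: fall back only if BOTH hold.
    return sliding_window != -1 and head_size < _MIN_HEAD_SIZE_FOR_LL4MI

# The bug case: small head size, sliding window disabled.
print(needs_fallback_or(-1, 32))    # True  -> triton fallback, correct output
print(needs_fallback_and(-1, 32))   # False -> ll4mi kernel, garbage output
# A sliding window model with a large head size still falls back with `or`.
print(needs_fallback_or(4096, 128)) # True
```

So `and` would route exactly the failing `head_size=32`, `sliding_window=-1` configuration back to the broken kernel.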
```python
layer._v_scale,
None,
_PARTITION_SIZE_ROCM,
1,
```
With your new condition (`head_size < 64`), can't we enable these changes?
Yes, good catch. I'll restore those args :)
…s for paged_attention_v1 Signed-off-by: Andreas Karatzas <akaratza@amd.com>
Failures are not associated with this PR, since I only merged upstream main.
Model Initialization and V1 Others tests were broken by this PR: #33600
…and head_size < 64 (vllm-project#34570) Signed-off-by: Andreas Karatzas <akaratza@amd.com>
Fixes a regression introduced by #34378 where sliding window models (e.g. Mixtral) and models with `head_size < 64` produce completely wrong decode outputs on the ROCm AITER backend.

Root Cause

PR #34378 removed the `unified_attention` triton kernel path for sliding window decode and routed all decode through `paged_attention_v1`'s ll4mi kernel. This kernel derives its loop bounds from the head size: on ROCm, `NWARPS = 4`, so for `head_size=32` (e.g. `TitanML/tiny-mixtral` with `head_dim = hidden_size / num_attention_heads = 1024 / 32 = 32`), `VHELOOP = 0`. This causes the entire softmax × V multiplication loop to be skipped, producing uninitialized output. The result is completely wrong logprobs from the second token onward.
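To illustrate the failure mode: a loop-trip count derived by integer division silently becomes zero once the head size drops below the divisor. The exact `VHELOOP` expression is not shown in this thread, so the divisor below (`NWARPS * 16`) is a hypothetical stand-in; only `NWARPS = 4` and the `VHELOOP = 0` outcome for `head_size=32` come from the PR description.

```python
NWARPS = 4       # per the PR description, on ROCm
WARP_TILE = 16   # hypothetical per-warp tile width, for illustration only

def vheloop(head_size):
    # Integer (floor) division: trips to zero for small head sizes.
    return head_size // (NWARPS * WARP_TILE)

print(vheloop(128))  # 2 -> softmax x V loop runs
print(vheloop(64))   # 1 -> loop runs once
print(vheloop(32))   # 0 -> loop body never executes; output stays uninitialized
```

Any head size below the divisor yields a zero trip count, which is why the output buffer is never written for `head_size=32`.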
Fix

Restore the `unified_attention` triton kernel fallback for decode when:

- sliding window is enabled (`self.sliding_window[0] != -1`), or
- `head_size < 64` (below the ll4mi kernel's minimum supported size)

Also removes the sliding window args from `paged_attention_v1` since it no longer handles that path.

Testing

Found via `git bisect`. Verified fix with:

Related

The ll4mi kernel should either guard against `head_size < 64` or support smaller head dimensions natively.
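The decode routing described above can be sketched as follows. The guard expression comes from the PR; the two kernel functions are hypothetical stubs standing in for the real triton and ll4mi implementations:

```python
_MIN_HEAD_SIZE_FOR_LL4MI = 64

def run_unified_attention():
    # Stub for the unified_attention triton fallback path.
    return "unified_attention (triton fallback)"

def run_paged_attention_v1():
    # Stub for the ll4mi paged_attention_v1 kernel path.
    return "paged_attention_v1 (ll4mi kernel)"

def decode(sliding_window, head_size):
    use_unified_attention = (
        sliding_window != -1
        or head_size < _MIN_HEAD_SIZE_FOR_LL4MI
    )
    if use_unified_attention:
        return run_unified_attention()
    return run_paged_attention_v1()

print(decode(-1, 32))     # small head size -> triton fallback
print(decode(4096, 128))  # sliding window  -> triton fallback
print(decode(-1, 128))    # ll4mi kernel handles this case
```

The ll4mi kernel only runs when both conditions it cannot handle are absent, which matches the bullet list under Fix.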