[ROCm] Fix cu_seqlens_q off-by-one in AITER FA speculative decode path #39120
Conversation
Code Review
This pull request modifies the ROCm AITER Flash Attention backend to update the calculation of descale_shape and the slicing of cu_seqlens_q passed to the unified_attention function. These changes ensure the correct number of decodes and sequence lengths are used. I have no feedback to provide.
While this sounds theoretically correct, I think there is a difference in the behavior of flash attention on ROCm vs. upstream. I will enable the spec decode tests just to be sure, but they are currently passing.
tjtanaa left a comment
LGTM. Thanks for catching this. Please add this link to the description as evidence for the review.
Hi @Bortlesboat, the pre-commit checks have failed. Please run:

    uv pip install "pre-commit>=4.5.1"
    pre-commit install
    pre-commit run --all-files

Then, commit the changes and push to your branch.
Added the AITER reference link to the description; that `unified_attention.py` line makes it clear `cu_seqlens_q` needs `num_seqs + 1` entries. Thanks for the pointer.
Head branch was pushed to by a user without write access
Force-pushed from 74e0e09 to 01f4f2b
Remaining red on this PR is a single Buildkite lane. I re-checked the pipeline config and test layout: that lane is the broad multimodal + Whisper suite. I don't have Buildkite access from here to manually rebuild that step, but the failure looks unrelated to the speculative-decode fix. Could an authorized maintainer retry the failing Buildkite step?
Quick recheck on this one: the branch is still the same approved change.
@Bortlesboat, please try rebasing the branch. If the issue persists, we will post the PR on Slack.
In the AITER FA spec decode path (`decode_max_query_len > 1`), `cu_seqlens_q` was sliced as `query_start_loc[:num_decodes]` but should be `[:num_decodes + 1]`, since `cu_seqlens` is a cumulative sum that needs `num_seqs + 1` entries. The correct pattern is used 33 lines later, at lines 1228-1229, for the fallback path, which does `[:num_decodes + 1]`. Similarly, `descale_shape` computed `shape[0] - 1` on the wrong slice, producing `num_decodes - 1` instead of `num_decodes`. Simplified to just use `num_decodes` directly, matching the fallback path at line 1232.

Signed-off-by: Bortlesboat <bortstheboat@gmail.com>
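For readers skimming the thread, here is a minimal before/after sketch of the change. `query_start_loc`, `num_decodes`, and `num_kv_heads` are hypothetical stand-ins rather than the exact locals in `rocm_aiter_fa.py`; only the slicing logic comes from the commit message above:

```python
import torch

# Hypothetical stand-ins for the real locals in rocm_aiter_fa.py.
query_start_loc = torch.tensor([0, 1, 2, 3, 4])  # 4 decode seqs, 1 token each
num_decodes = 4
num_kv_heads = 8  # made-up head count for the descale shape

# Before (off-by-one): the slice yields num_decodes entries, so
# shape[0] - 1 evaluates to num_decodes - 1.
cu_seqlens_q = query_start_loc[:num_decodes]
descale_shape = (cu_seqlens_q.shape[0] - 1, num_kv_heads)
assert descale_shape[0] == num_decodes - 1  # the bug

# After: num_decodes + 1 entries (leading 0 included), and descale_shape
# uses num_decodes directly, matching the fallback path.
cu_seqlens_q = query_start_loc[:num_decodes + 1]
descale_shape = (num_decodes, num_kv_heads)
assert cu_seqlens_q.numel() == num_decodes + 1
```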
Force-pushed from 01f4f2b to 7f9f7aa
Rebased onto current main per the note above. Patch content is still the same narrow `rocm_aiter_fa.py` fix, and CI is rerunning on 7f9f7aa.
Quick update after the rebase rerun on 7f9f7aa: this patch is still just the small `rocm_aiter_fa.py` fix.
Merged current main.
In the speculative decode path of `AiterFlashAttentionImpl` (when `decode_max_query_len > 1`), `cu_seqlens_q` is sliced as `query_start_loc[:num_decodes]`, giving `num_decodes` elements. But `cu_seqlens_q` for `unified_attention` is a cumulative length array that needs `num_seqs + 1` entries (including the leading 0). The correct pattern is already used in the fallback path 33 lines later (lines 1228-1229), which correctly does `[:num_decodes + 1]`.

Also fixed `descale_shape`, which computed `shape[0] - 1` on the wrong-length slice, producing `num_decodes - 1` instead of `num_decodes`. Simplified to use `num_decodes` directly, matching line 1232.

No duplicate PRs found.
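To make the invariant concrete, here is a self-contained sketch in plain PyTorch (not vLLM code; `query_lens` is made up for illustration):

```python
import torch

# For num_seqs sequences, the cumulative-start array has num_seqs + 1
# entries, beginning with 0 and ending at the total token count.
query_lens = torch.tensor([3, 1, 2])  # e.g. 3 decode sequences
num_decodes = query_lens.numel()
query_start_loc = torch.cat([
    torch.zeros(1, dtype=torch.int64),
    torch.cumsum(query_lens, dim=0),
])  # tensor([0, 3, 4, 6])

cu_seqlens_q = query_start_loc[:num_decodes + 1]
assert cu_seqlens_q.numel() == num_decodes + 1  # num_seqs + 1 entries

# The buggy slice [:num_decodes] drops the final offset (6), so the
# last sequence's token range can no longer be recovered.
assert query_start_loc[:num_decodes].numel() == num_decodes  # one short
```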
Testing:
`uvx pre-commit run ruff-format --files vllm/v1/attention/backends/rocm_aiter_fa.py` (passes locally). `ruff format` requested a whitespace change on this line; this follow-up commit applies that exact diff.

AI assistance was used for code search and CI log review; all changes were reviewed by the human submitter.
Signed-off-by: Bortlesboat <bortstheboat@gmail.com>
Reference: the correct `cu_seqlens_q` slicing is confirmed by the upstream AITER `unified_attention` implementation: https://github.com/ROCm/aiter/blob/bc5ea32c19c604cda2f4781a97cb04d5fc494543/aiter/ops/triton/_triton_kernels/attention/unified_attention.py#L426
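For context, here is an illustrative reconstruction of how a consumer of `cu_seqlens_q` recovers per-sequence token ranges; this is not the upstream AITER kernel code (see the linked line for the real implementation):

```python
import torch

def seq_token_ranges(cu_seqlens_q: torch.Tensor) -> list[tuple[int, int]]:
    # cu_seqlens_q holds num_seqs + 1 offsets; entries i and i + 1
    # bracket the tokens belonging to sequence i.
    num_seqs = cu_seqlens_q.numel() - 1
    return [(int(cu_seqlens_q[i]), int(cu_seqlens_q[i + 1]))
            for i in range(num_seqs)]

print(seq_token_ranges(torch.tensor([0, 3, 4, 6])))  # [(0, 3), (3, 4), (4, 6)]
# With the off-by-one slice the final offset is missing and the last
# sequence's range is silently lost:
print(seq_token_ranges(torch.tensor([0, 3, 4])))     # [(0, 3), (3, 4)]
```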