
[ROCm] Fix cu_seqlens_q off-by-one in AITER FA speculative decode path #39120

Merged
tjtanaa merged 8 commits into vllm-project:main from Bortlesboat:fix-aiter-fa-spec-decode-off-by-one
Apr 19, 2026

Conversation

@Bortlesboat
Contributor

Bortlesboat commented Apr 6, 2026

In the speculative decode path of AiterFlashAttentionImpl (when decode_max_query_len > 1), cu_seqlens_q is sliced as query_start_loc[:num_decodes], giving num_decodes elements. But cu_seqlens_q for unified_attention is a cumulative length array that needs num_seqs + 1 entries (including the leading 0).

The correct pattern already appears in the fallback path 33 lines later (lines 1228-1229), which does [:num_decodes + 1].

Also fixed descale_shape, which computed shape[0] - 1 on the wrong-length slice, producing num_decodes - 1 instead of num_decodes. Simplified it to use num_decodes directly, matching line 1232.
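The two paragraphs above can be sketched in plain Python (the real code operates on torch tensors inside rocm_aiter_fa.py; the sequence lengths here are hypothetical):

```python
# Minimal sketch of the off-by-one. A cu_seqlens-style array for N
# sequences needs N + 1 entries: a leading 0 plus one running total
# per sequence.
from itertools import accumulate

num_decodes = 3
query_lens = [2, 2, 2]  # hypothetical: 2 query tokens per decode (spec decode)

# query_start_loc layout: [0, 2, 4, 6]
query_start_loc = [0] + list(accumulate(query_lens))

buggy = query_start_loc[:num_decodes]      # [0, 2, 4]  last end offset lost
fixed = query_start_loc[:num_decodes + 1]  # [0, 2, 4, 6]  num_seqs + 1 entries

# descale_shape derived as shape[0] - 1 from the buggy slice under-counts:
assert len(buggy) - 1 == num_decodes - 1  # the wrong value the bug produced
assert len(fixed) - 1 == num_decodes      # equivalent to using num_decodes directly
```

With the short slice, the kernel never sees the final cumulative offset, so the last decode sequence's query span is effectively dropped.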

No duplicate PRs found.

Testing:

  • uvx pre-commit run ruff-format --files vllm/v1/attention/backends/rocm_aiter_fa.py (passes locally)
  • Previous CI failure only requested the matching ruff format whitespace change on this line; this follow-up commit applies that exact diff.

AI assistance was used for code search and CI log review; all changes reviewed by the human submitter.

Signed-off-by: Bortlesboat <bortstheboat@gmail.com>

Reference: The correct cu_seqlens_q slicing is confirmed by the upstream AITER unified_attention implementation:
https://github.com/ROCm/aiter/blob/bc5ea32c19c604cda2f4781a97cb04d5fc494543/aiter/ops/triton/_triton_kernels/attention/unified_attention.py#L426

@Bortlesboat Bortlesboat requested a review from tjtanaa as a code owner April 6, 2026 22:16
@mergify mergify Bot added the rocm (Related to AMD ROCm) and v1 labels Apr 6, 2026
@github-project-automation github-project-automation Bot moved this to Todo in AMD Apr 6, 2026
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request modifies the ROCm AITER Flash Attention backend to update the calculation of descale_shape and the slicing of cu_seqlens_q passed to the unified_attention function. These changes ensure the correct number of decodes and sequence lengths are used. I have no feedback to provide.

@AndreasKaratzas
Collaborator

While this sounds theoretically correct, I think there is a difference in the behavior of flash attention on ROCm vs. upstream. I will enable the spec decode tests just to be sure, though they are currently passing.

@AndreasKaratzas AndreasKaratzas added the ready label (ONLY add when PR is ready to merge/full CI is needed) Apr 7, 2026
Collaborator

@tjtanaa tjtanaa left a comment


@tjtanaa tjtanaa enabled auto-merge (squash) April 9, 2026 13:20
@mergify
Contributor

mergify Bot commented Apr 9, 2026

Hi @Bortlesboat, the pre-commit checks have failed. Please run:

uv pip install "pre-commit>=4.5.1"
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

@Bortlesboat
Contributor Author

Added the AITER reference link to the description — that unified_attention.py line makes it clear cu_seqlens_q needs num_seqs + 1 entries. Thanks for the pointer.

auto-merge was automatically disabled April 9, 2026 15:27

Head branch was pushed to by a user without write access

@Bortlesboat Bortlesboat force-pushed the fix-aiter-fa-spec-decode-off-by-one branch from 74e0e09 to 01f4f2b Compare April 9, 2026 15:45
@Bortlesboat
Contributor Author

Remaining red on this PR is buildkite/ci/pr/multi-modal-models-standard-4-other-plus-whisper on commit 01f4f2b.

I re-checked the pipeline config and test layout: that lane is the broad multimodal + Whisper suite gated on any vllm/ change, while this PR only touches vllm/v1/attention/backends/rocm_aiter_fa.py. The targeted spec-decode lane (buildkite/ci/pr/v1-spec-decode) is green on the same commit.

I don't have Buildkite access from here to manually rebuild that step, but this looks unrelated to the speculative-decode fix. Could an authorized maintainer retry the failing Buildkite step?

@Bortlesboat
Contributor Author

Bortlesboat commented Apr 14, 2026

Quick recheck on this one: the branch is still the same approved change, buildkite/amd-ci is green, and buildkite/ci/pr/v1-spec-decode is green. The only remaining red is still buildkite/ci/pr/multi-modal-models-standard-4-other-plus-whisper from build 60608. If that lane still looks unrelated on your side, could someone with Buildkite access rerun it or mark this ready?

@AndreasKaratzas
Collaborator

@Bortlesboat probably rebase the branch. if the issue persists, we will post the PR on slack

In the AITER FA spec decode path (decode_max_query_len > 1), cu_seqlens_q
was sliced as query_start_loc[:num_decodes] but should be
[:num_decodes + 1] since cu_seqlens is a cumulative sum that needs
num_seqs + 1 entries.

The correct pattern is used 33 lines later at lines 1228-1229 for the
fallback path, which does [:num_decodes + 1].

Similarly, descale_shape computed shape[0] - 1 on the wrong slice,
producing num_decodes - 1 instead of num_decodes. Simplified to just
use num_decodes directly, matching the fallback path at line 1232.

Signed-off-by: Bortlesboat <bortstheboat@gmail.com>
@Bortlesboat Bortlesboat force-pushed the fix-aiter-fa-spec-decode-off-by-one branch from 01f4f2b to 7f9f7aa Compare April 14, 2026 16:30
@Bortlesboat
Contributor Author

Bortlesboat commented Apr 14, 2026

Rebased onto current main per the note above. Patch content is still the same narrow rocm_aiter_fa.py fix, and CI is rerunning on 7f9f7aa.

@Bortlesboat
Contributor Author

Bortlesboat commented Apr 17, 2026

Quick update after the rebase rerun on 7f9f7aa: buildkite/ci/pr/v1-spec-decode is green, the earlier multimodal lane is green now too, and the only red in the main PR build is buildkite/ci/pr/cpu-language-generation-and-pooling-model-tests. buildkite/amd-ci is still running.

This patch is still just the small vllm/v1/attention/backends/rocm_aiter_fa.py fix, so that CPU language-generation/pooling failure looks unrelated from my side. If it looks the same with Buildkite access, could someone retry that lane when convenient?

@tjtanaa tjtanaa enabled auto-merge (squash) April 17, 2026 10:07
@Bortlesboat
Contributor Author

Bortlesboat commented Apr 19, 2026

Merged current main to refresh the branch on the latest base. The functional diff is still the same 4-line vllm/v1/attention/backends/rocm_aiter_fa.py off-by-one fix, and CI is rerunning on 4d55a91.

@tjtanaa tjtanaa merged commit f150107 into vllm-project:main Apr 19, 2026
59 of 60 checks passed
@github-project-automation github-project-automation Bot moved this from Todo to Done in AMD Apr 19, 2026
bnellnm pushed a commit to neuralmagic/vllm that referenced this pull request Apr 20, 2026
baonudesifeizhai pushed a commit to baonudesifeizhai/vllm that referenced this pull request Apr 23, 2026
avinashsingh77 pushed a commit to avinashsingh77/vllm that referenced this pull request Apr 27, 2026
[ROCm] Fix cu_seqlens_q off-by-one in AITER FA speculative decode path (vllm-project#39120)

Signed-off-by: Bortlesboat <bortstheboat@gmail.com>
Signed-off-by: Avinash Singh <avinashsingh.rcoem@gmail.com>
Lafunamor pushed a commit to Lafunamor/vllm that referenced this pull request May 1, 2026
[ROCm] Fix cu_seqlens_q off-by-one in AITER FA speculative decode path (vllm-project#39120)

Signed-off-by: Bortlesboat <bortstheboat@gmail.com>
Signed-off-by: Adrian <info@zzit.ch>
mystous pushed a commit to mystous/vllm_hybrid that referenced this pull request May 10, 2026

Labels

ready: ONLY add when PR is ready to merge/full CI is needed
rocm: Related to AMD ROCm
v1

Projects

Status: Done

Development

Successfully merging this pull request may close these issues.

3 participants