
Reapply [Attention][FA3] Update FA3 to include new swizzle optimization #34043

Merged

vllm-bot merged 3 commits into main from lwilkinson/fix-fa3-swizzle on Feb 11, 2026

Conversation

@LucasWilkinson (Collaborator) commented Feb 7, 2026

Reapply #23465 after revert in #33841 but with correct metadata sizes

@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request updates the flash-attention dependency and adjusts the scheduler_metadata buffer size to accommodate changes in flash-attention and handle edge cases related to CUDA graph capturing. The changes appear correct and necessary. My review focuses on improving the clarity of the comments associated with these critical buffer size calculations. Enhancing these comments will improve maintainability and help prevent future confusion or bugs, especially since this area of the code has been a source of issues in the past.

Comment on lines +314 to +317
# Times 4 due to:
# https://github.com/vllm-project/flash-attention/blob/3223650ccabe622a0fcae65eec706a50186a89f7/hopper/flash_api.cpp#L650-L653
# For some tests max_cudagraph_size > max_num_seqs,
# so we need to use the larger one.

high

The comment explaining the buffer size calculation is a bit confusing as it compares max_cudagraph_size (number of tokens) with max_num_seqs (number of sequences) without full context. A more detailed explanation would improve clarity and maintainability, especially given the history of issues with this code block.

During CUDA graph capture for uniform decode batches, the number of requests (num_reqs) is set to be equal to the number of tokens (num_tokens). This num_tokens can be up to max_cudagraph_capture_size. In some configurations (especially in tests), max_cudagraph_capture_size can be larger than max_num_seqs. Since the scheduler_metadata buffer size depends on the number of requests, we must allocate a buffer large enough for this capture-time scenario.

Suggested change
# Times 4 due to:
# https://github.com/vllm-project/flash-attention/blob/3223650ccabe622a0fcae65eec706a50186a89f7/hopper/flash_api.cpp#L650-L653
# For some tests max_cudagraph_size > max_num_seqs,
# so we need to use the larger one.
# During CUDA graph capture, `num_reqs` can be up to `max_cudagraph_size`.
# This can be larger than `max_num_seqs`, so we use the max of the two.
# The size is multiplied by 4 due to flash-attention's requirement:
# https://github.com/vllm-project/flash-attention/blob/3223650ccabe622a0fcae65eec706a50186a89f7/hopper/flash_api.cpp#L650-L653
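As an illustrative sketch of the sizing logic this suggestion describes (the function and parameter names here are hypothetical, not vLLM's actual API), the buffer must cover the larger of the two request limits before applying flash-attention's factor of 4:

```python
def scheduler_metadata_num_slots(max_num_seqs: int, max_cudagraph_size: int) -> int:
    # During CUDA graph capture of uniform decode batches, num_reqs can
    # be as large as max_cudagraph_size, which in some test configs
    # exceeds max_num_seqs; size the buffer for the worst case.
    max_num_reqs = max(max_num_seqs, max_cudagraph_size)
    # Factor of 4 per flash-attention's scheduler metadata layout
    # (see the flash_api.cpp lines linked above).
    return max_num_reqs * 4
```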

Comment on lines +133 to +136
# Times 4 due to:
# https://github.com/vllm-project/flash-attention/blob/3223650ccabe622a0fcae65eec706a50186a89f7/hopper/flash_api.cpp#L650-L653
# For some tests max_cudagraph_size > max_num_seqs,
# so we need to use the larger one.

high

The comment explaining the buffer size calculation is a bit confusing as it compares max_cudagraph_size (number of tokens) with max_num_seqs (number of sequences) without full context. A more detailed explanation would improve clarity and maintainability, especially given the history of issues with this code block.

During CUDA graph capture for uniform decode batches, the number of requests (num_reqs) is set to be equal to the number of tokens (num_tokens). This num_tokens can be up to max_cudagraph_capture_size. In some configurations (especially in tests), max_cudagraph_capture_size can be larger than max_num_seqs. Since the scheduler_metadata buffer size depends on the number of requests, we must allocate a buffer large enough for this capture-time scenario.

Suggested change
# Times 4 due to:
# https://github.com/vllm-project/flash-attention/blob/3223650ccabe622a0fcae65eec706a50186a89f7/hopper/flash_api.cpp#L650-L653
# For some tests max_cudagraph_size > max_num_seqs,
# so we need to use the larger one.
# During CUDA graph capture, `num_reqs` can be up to `max_cudagraph_size`.
# This can be larger than `max_num_seqs`, so we use the max of the two.
# The size is multiplied by 4 due to flash-attention's requirement:
# https://github.com/vllm-project/flash-attention/blob/3223650ccabe622a0fcae65eec706a50186a89f7/hopper/flash_api.cpp#L650-L653

@mergify mergify bot added gpt-oss Related to GPT-OSS models rocm Related to AMD ROCm labels Feb 10, 2026
@mergify mergify bot added the cpu Related to CPU backends label Feb 10, 2026
@github-project-automation github-project-automation bot moved this to Todo in AMD Feb 10, 2026
@mergify mergify bot added the kv-connector label Feb 10, 2026
@LucasWilkinson force-pushed the lwilkinson/fix-fa3-swizzle branch 11 times, most recently from 39e4204 to 5797dc2 on February 10, 2026 18:34
@LucasWilkinson LucasWilkinson removed the documentation Improvements or additions to documentation label Feb 10, 2026
…ion" (#33841)

This reverts commit e3bf79f.

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
…raphs

- Increase scheduler_metadata buffer to 1 + round_up(batch_size, 4) * 4
  to account for all num_prepare_batch_vectors slots + semaphore
- Only relax BatchDescriptor for PIECEWISE mode, not FULL mode
  (FA3's scheduler_metadata computation depends on exact num_reqs)
- Rename relax_for_mixed_batch_cudagraphs -> relax_for_piecewise_cudagraphs
- Update tests to reflect FULL mode using exact keys
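A minimal sketch of the buffer arithmetic in the first bullet (the helper names are hypothetical; only the formula `1 + round_up(batch_size, 4) * 4` comes from the commit message):

```python
def round_up(x: int, multiple: int) -> int:
    # Round x up to the nearest multiple of `multiple`.
    return ((x + multiple - 1) // multiple) * multiple

def scheduler_metadata_len(batch_size: int) -> int:
    # 1 slot for the semaphore, plus 4 slots for each of the
    # round_up(batch_size, 4) prepare-batch vector entries.
    return 1 + round_up(batch_size, 4) * 4
```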

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Remove relax_for_piecewise_cudagraphs() method from BatchDescriptor
and use NamedTuple._replace() directly for cleaner O(1) set lookups.
For pure FULL mode, always relax uniform=False since keys are
registered with uniform=False.

Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
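The `_replace()` pattern mentioned in the commit above can be sketched like this (the `BatchDescriptor` fields shown are assumptions for illustration, not the real vLLM class):

```python
from typing import NamedTuple

class BatchDescriptor(NamedTuple):
    # Illustrative stand-in for vLLM's BatchDescriptor.
    num_tokens: int
    uniform: bool = False

# Keys are registered with uniform=False, so an incoming descriptor is
# relaxed the same way; NamedTuples stay hashable, keeping set
# membership checks O(1).
registered = {BatchDescriptor(num_tokens=128)}
incoming = BatchDescriptor(num_tokens=128, uniform=True)
relaxed = incoming._replace(uniform=False)
```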
@ProExpertProg (Collaborator) left a comment

LGTM


Labels

ci/build · nvidia · performance (Performance-related issues) · ready (ONLY add when PR is ready to merge/full CI is needed) · v1

Projects

Status: Done

Development


3 participants