Reapply [Attention][FA3] Update FA3 to include new swizzle optimization#34043
Conversation
Code Review
This pull request updates the flash-attention dependency and adjusts the scheduler_metadata buffer size to accommodate changes in flash-attention and handle edge cases related to CUDA graph capturing. The changes appear correct and necessary. My review focuses on improving the clarity of the comments associated with these critical buffer size calculations. Enhancing these comments will improve maintainability and help prevent future confusion or bugs, especially since this area of the code has been a source of issues in the past.
# Times 4 due to:
# https://github.com/vllm-project/flash-attention/blob/3223650ccabe622a0fcae65eec706a50186a89f7/hopper/flash_api.cpp#L650-L653
# For some tests max_cudagraph_size > max_num_seqs,
# so we need to use the larger one.
The comment explaining the buffer size calculation is confusing: it compares max_cudagraph_size (a number of tokens) with max_num_seqs (a number of sequences) without explaining why the comparison is valid. A more detailed explanation would improve clarity and maintainability, especially given the history of issues with this code block.
During CUDA graph capture for uniform decode batches, the number of requests (num_reqs) is set to be equal to the number of tokens (num_tokens). This num_tokens can be up to max_cudagraph_capture_size. In some configurations (especially in tests), max_cudagraph_capture_size can be larger than max_num_seqs. Since the scheduler_metadata buffer size depends on the number of requests, we must allocate a buffer large enough for this capture-time scenario.
Suggested change:
- # Times 4 due to:
- # https://github.com/vllm-project/flash-attention/blob/3223650ccabe622a0fcae65eec706a50186a89f7/hopper/flash_api.cpp#L650-L653
- # For some tests max_cudagraph_size > max_num_seqs,
- # so we need to use the larger one.
+ # During CUDA graph capture, `num_reqs` can be up to `max_cudagraph_size`.
+ # This can be larger than `max_num_seqs`, so we use the max of the two.
+ # The size is multiplied by 4 due to flash-attention's requirement:
+ # https://github.com/vllm-project/flash-attention/blob/3223650ccabe622a0fcae65eec706a50186a89f7/hopper/flash_api.cpp#L650-L653
…raphs
- Increase scheduler_metadata buffer to 1 + round_up(batch_size, 4) * 4 to account for all num_prepare_batch_vectors slots + semaphore
- Only relax BatchDescriptor for PIECEWISE mode, not FULL mode (FA3's scheduler_metadata computation depends on exact num_reqs)
- Rename relax_for_mixed_batch_cudagraphs -> relax_for_piecewise_cudagraphs
- Update tests to reflect FULL mode using exact keys

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
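To make the sizing logic above concrete, here is a minimal, hypothetical sketch of the buffer-size computation. The helper names (`round_up`, `scheduler_metadata_size`) and parameters are illustrative, not vLLM's actual API; the formula combines the max-of-the-two request count discussed in the review with the `1 + round_up(batch_size, 4) * 4` sizing from this commit message.

```python
def round_up(x: int, multiple: int) -> int:
    """Round x up to the nearest multiple (assumed multiple > 0)."""
    return -(-x // multiple) * multiple


def scheduler_metadata_size(max_num_seqs: int,
                            max_cudagraph_capture_size: int) -> int:
    # During CUDA graph capture of uniform decode batches, num_reqs can be
    # as large as the capture size in tokens, which in some configurations
    # (especially tests) exceeds max_num_seqs. Size for the larger of the two.
    batch_size = max(max_num_seqs, max_cudagraph_capture_size)
    # 1 semaphore slot, plus 4 prepare-batch vectors over the request count
    # rounded up to a multiple of 4 (the "times 4" requirement referenced
    # in flash-attention's flash_api.cpp).
    return 1 + round_up(batch_size, 4) * 4


# Capture size dominates: 1 + 512 * 4 = 2049
print(scheduler_metadata_size(128, 512))
# max_num_seqs dominates: 1 + 256 * 4 = 1025
print(scheduler_metadata_size(256, 64))
```

Taking the max keeps a single allocation valid both for normal serving (bounded by max_num_seqs) and for the capture-time edge case where num_reqs tracks the token count.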
Remove relax_for_piecewise_cudagraphs() method from BatchDescriptor and use NamedTuple._replace() directly for cleaner O(1) set lookups. For pure FULL mode, always relax uniform=False since keys are registered with uniform=False.

Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Reapply #23465 after revert in #33841 but with correct metadata sizes