[ROCm] Restore 16-wide fast path in Triton unified attention #30582
hyoon1 wants to merge 1 commit into vllm-project:main
Conversation
Code Review
This pull request introduces a performance optimization for AMD Navi GPUs by adding a fast path to Triton attention kernels. This is achieved by creating a special case for when TILE_SIZE equals BLOCK_SIZE, which avoids expensive division and modulo operations. The changes are effective, as shown by the performance benchmarks. However, the implementation introduces significant code duplication in both kernel_unified_attention_2d and kernel_unified_attention_3d kernels. This duplication makes the code harder to maintain and reason about. My review focuses on refactoring these kernels to eliminate the code duplication while preserving the performance benefits.
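To illustrate the optimization the review describes, here is a hypothetical plain-Python sketch (not the actual kernel code) of the index arithmetic involved: in the general case, every position inside a tile must be mapped to a (cache block, in-block offset) pair via a division and a modulo, while in the `TILE_SIZE == BLOCK_SIZE` special case the tile coincides with the block and both operations disappear. The function names are invented for this example.

```python
def kv_offsets_general(tile_id, tile_size, block_size):
    """General path: per-element div/mod to locate each KV cache block."""
    offsets = []
    for i in range(tile_size):
        pos = tile_id * tile_size + i
        block = pos // block_size   # integer division, costly per-lane on GPU
        in_block = pos % block_size # modulo, likewise costly
        offsets.append((block, in_block))
    return offsets

def kv_offsets_fast(tile_id, tile_size):
    """Fast path (TILE_SIZE == BLOCK_SIZE): the tile id *is* the block id."""
    return [(tile_id, i) for i in range(tile_size)]
```

When `tile_size == block_size`, both functions produce identical offsets, which is why the special case can skip the div/mod entirely.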
💡 Codex Review
Here are some automated review suggestions for this pull request.
```python
# V : (BLOCK_SIZE, HEAD_SIZE)
V_load = tl.load(
    value_cache_ptr + v_offset, mask=dim_mask[None, :], other=0.0
)
```
Mask padded V tiles in block-aligned fast path
When TILE_SIZE == BLOCK_SIZE, the new fast path loads the entire cache block of V with only a head-dimension mask (tl.load(..., mask=dim_mask[None, :])). For sequences whose final KV block is only partially filled, the padding entries remain whatever was previously in that page; if any of those bytes happen to be NaN, the subsequent tl.dot(P, V) will multiply zeros by NaNs and propagate NaNs into acc even though the positions are softmax-masked. The general path used tile_mask to zero out those columns, avoiding reads of uninitialized padding. The fast path needs the same masking (likewise in the new 3D fast path) to prevent corrupted outputs on partially filled blocks.
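The hazard described above can be reproduced with a small NumPy sketch (illustrative only; the variable names mirror the kernel but are otherwise invented). Even though the softmax weights `P` are exactly zero at padded key positions, `0 * NaN == NaN` under IEEE 754, so unmasked NaN padding in `V` corrupts the dot product; zeroing the padded rows first, as the general path's `tile_mask` does, avoids this.

```python
import numpy as np

# Softmax weights: the last two key positions are masked (weight 0).
P = np.array([[0.5, 0.5, 0.0, 0.0]])
# V block whose padding entries happen to contain stale NaN bytes.
V = np.array([[1.0], [3.0], [np.nan], [np.nan]])

out_unmasked = P @ V            # 0 * NaN == NaN, so NaN leaks into the output
print(np.isnan(out_unmasked).any())

# Emulate the general path's tile_mask: zero padded rows before the dot.
tile_mask = np.array([[True], [True], [False], [False]])
V_masked = np.where(tile_mask, V, 0.0)
out_masked = P @ V_masked       # clean result: 0.5*1.0 + 0.5*3.0 = 2.0
print(out_masked)
```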
Hey @tdoublep, can you help review this PR? Thanks!
Signed-off-by: Hosang Yoon <hosang.yoon@amd.com>
Purpose
Test Plan
Run meta-llama/Llama-3.1-8B-Instruct and check the benchmark results.
Run lm_eval to verify correctness.
Test Result
Performance & Correctness
AMD Radeon Pro W7900 (RDNA3)
Original Triton unified attention kernel:
Updated version:
AMD MI308X
Original Triton unified attention kernel:
Updated version:
Essential Elements of an Effective PR Description Checklist
Update `supported_models.md` and `examples` for a new model.