[ROCm] Cap Triton paged attention block size to fix ROCm shared memory OOM #38502
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
cc @micah-wil
Code Review
This pull request introduces a cap of 128 on the TRITON_BLOCK_SIZE within the chunked_prefill_paged_decode operation to prevent shared memory OOM errors, particularly for models with large block sizes like hybrid Mamba. Feedback suggests that hardcoding this value is brittle and recommends a more robust approach by dynamically calculating the maximum block size based on the specific device's shared memory capacity to ensure better portability across different hardware.
…id models Signed-off-by: Andreas Karatzas <akaratza@amd.com>
| is_contiguous_blocks = key_cache.stride(0) == key_cache[0].numel() | ||
| if block_size in (16, 32) and is_contiguous_blocks: | ||
| # Normal 16, 32 with contiguous blocks, use vLLM native HIP C++ logic |
is_contiguous() checks tensor-level contiguity, but the reshape_and_cache kernel only assumes block-level contiguity, so the former would introduce performance overhead that I think is unnecessary here.
I see, so we just want to check that the rows are contiguous?
Yep :) It ensures that the rows along dim 0 (i.e. the blocks) are packed with no gaps between them.
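A minimal sketch of the block-level check described above, in plain Python so it runs without a GPU. It assumes strides are counted in elements, as torch.Tensor.stride() reports them. The helper name blocks_are_packed and the example shapes are hypothetical and not taken from the PR.

```python
from math import prod


def blocks_are_packed(shape, strides):
    """Return True when consecutive blocks (dim 0) have no gap between them.

    Mirrors `key_cache.stride(0) == key_cache[0].numel()`: the step between
    two blocks must equal the number of elements in one block.
    """
    elements_per_block = prod(shape[1:])
    return strides[0] == elements_per_block


# Hypothetical cache of 4 blocks, each holding 16 x 8 elements, densely packed.
packed = blocks_are_packed((4, 16, 8), (128, 8, 1))

# Same logical shape, but each block is padded to 160 elements in memory.
padded = blocks_are_packed((4, 16, 8), (160, 8, 1))

# Inner dims are permuted inside each block, yet blocks are still back to back.
permuted_inner = blocks_are_packed((4, 16, 8), (128, 1, 16))
```

The last case is why the check is weaker than a full tensor-level contiguity test: it only looks at the spacing between blocks, not at the layout inside each block.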
| MAX_TRITON_BLOCK_SIZE = 128 | ||
| TRITON_BLOCK_SIZE = min(block_size, MAX_TRITON_BLOCK_SIZE) if is_pow2 else 32 |
Yep :)
At least for now.
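For readers skimming the thread, here is a small standalone sketch of the tile-size selection from the diff above. MAX_TRITON_BLOCK_SIZE and the fallback of 32 come from the diff; how is_pow2 is computed is not shown there, so the bit trick below and the function names are assumptions.

```python
MAX_TRITON_BLOCK_SIZE = 128  # cap taken from the diff above


def is_power_of_two(n):
    # Assumption: the PR's is_pow2 flag is computed in some equivalent way.
    return n > 0 and (n & (n - 1)) == 0


def pick_triton_block_size(block_size):
    """Tile size the Triton fallback would use for a given KV block size."""
    if not is_power_of_two(block_size):
        # Non-power-of-two block sizes fall back to a fixed tile of 32.
        return 32
    # Power-of-two block sizes are used directly, but never above the cap.
    return min(block_size, MAX_TRITON_BLOCK_SIZE)


small = pick_triton_block_size(16)      # below the cap, unchanged
hybrid = pick_triton_block_size(2048)   # inflated hybrid block size, capped
odd = pick_triton_block_size(48)        # not a power of two
```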
Is this op only used for ROCm? (sorry if that's a dumb question, I'm not familiar with this area of the code)
I didn't know the answer to that question myself before I attempted to resolve the failure here, so I think it's not a dumb question 😅
Answer: Yep :) It's found only in vllm/v1/attention/backends/rocm_attn.py.
Yes. This op is only used for ROCm.
| # (CanonicalizePointers, ConvertToBufferOps) crash when an scf.if | ||
| # yields pointers with different base addresses. Instead, we compute | ||
| # both sets of load pointers and use mutually exclusive masks. | ||
| if HAS_INITSTATES: |
@hmellor do you know who is more familiar with this mamba code?
I can review changes to this kernel, but I don't really understand how these changes are related to the rest of the PR?
It is a second ROCm/Mamba blocker exposed by the same hybrid-model validation path. The PR is aimed at getting hybrid Mamba models, e.g. Jamba, working on ROCm with chunked prefill. Once the attention path gets past the inflated/padded block-size issue, the same hybrid_model tests run the Mamba2 varlen SSD path with initial_states. In the previous Triton code, prev_states_ptr could come from either initstates_ptr or states_ptr through an if. On AMD Triton this lowers to an scf.if yielding pointers with different base addresses, and the ROCm compiler crashes in CanonicalizePointers / ConvertToBufferOps. This change keeps the same semantics by computing both candidate load pointers and using mutually exclusive masks, so only the selected source contributes. Without this fix we get:
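To illustrate the idea of "both loads, mutually exclusive masks", here is a plain-Python emulation. It is not Triton code and not the PR's kernel: masked_load and load_prev_state are hypothetical helpers, and combining the two results by addition is an assumption about how such masked loads could be merged.

```python
def masked_load(buffer, index, mask, other=0.0):
    """Emulate a masked load: read buffer[index] only when mask is True."""
    return buffer[index] if mask else other


def load_prev_state(initstates, states, index, use_initstates):
    # Both candidate loads are always issued, so no branch has to yield
    # a pointer. The masks are mutually exclusive, so only one reads data.
    from_init = masked_load(initstates, index, use_initstates)
    from_states = masked_load(states, index, not use_initstates)
    # The masked-out load returned 0.0, so the sum is the selected value.
    return from_init + from_states


initstates = [1.0, 2.0]
states = [10.0, 20.0]
picked_init = load_prev_state(initstates, states, 1, use_initstates=True)
picked_state = load_prev_state(initstates, states, 1, use_initstates=False)
```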
FAILED tests/kernels/mamba/test_mamba_ssm_ssd.py::test_mamba_chunk_scan_cont_batch_prefill_chunking[seqlens0-8]
FAILED tests/kernels/mamba/test_mamba_ssm_ssd.py::test_mamba_chunk_scan_cont_batch_prefill_chunking[seqlens0-256]
FAILED tests/kernels/mamba/test_mamba_ssm_ssd.py::test_mamba_chunk_scan_cont_batch_prefill_chunking[seqlens1-8]
FAILED tests/kernels/mamba/test_mamba_ssm_ssd.py::test_mamba_chunk_scan_cont_batch_prefill_chunking[seqlens1-256]
@AndreasKaratzas Can you share the error? I think this was fixed on the Triton side already with triton-lang/triton#9541.
After updating to the new base image, I realized that this patch is unnecessary. I restored this file and am waiting for the CI run to confirm that this issue has indeed already been solved elsewhere.
| if block_size in (16, 32): | ||
| # Normal 16, 32, use vLLM native HIP C++ logic | ||
| is_contiguous_blocks = key_cache.stride(0) == key_cache[0].numel() |
Would this logic also apply to the use_custom for the actual kernel selection?
I think that the custom kernel is stride aware:
Line 3244 in 617d1c2
So that logic is not needed there.
We can only go into that kernel if reshape_and_cache was used, not reshape_and_cache_flash.
There is a condition for whether to select the kernel or go with the Triton fallback. It may need to be changed accordingly.
@gshtras I updated the kernel selection to use the same native-layout check as the cache update path. If the KV cache blocks are "strided" and the update path uses reshape_and_cache_flash, use_custom is now forced to false so decode falls back to the Triton path.
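A simplified sketch of how the two decisions line up after this change. The function select_paths and the string labels are hypothetical. The real use_custom condition in vLLM has further requirements that are not shown in this thread, so the sketch covers only the native-layout part.

```python
NATIVE_BLOCK_SIZES = (16, 32)


def select_paths(block_size, is_contiguous_blocks):
    """Return (cache_update_op, decode_kernel) for a KV cache layout."""
    native_layout = block_size in NATIVE_BLOCK_SIZES and is_contiguous_blocks
    cache_update_op = (
        "reshape_and_cache" if native_layout else "reshape_and_cache_flash"
    )
    # Simplification: only the layout requirement is modelled here. A cache
    # written by reshape_and_cache_flash never reaches the custom kernel.
    decode_kernel = "custom" if native_layout else "triton"
    return cache_update_op, decode_kernel


native = select_paths(16, True)
strided = select_paths(16, False)
hybrid = select_paths(2048, True)
```

The point is that both choices derive from the same native_layout flag, so the decode kernel can no longer disagree with the layout the cache update path produced.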
| # via the l_block_idx/internal_offsets addressing logic. | ||
| # TODO: Remove after upgrading from Triton 3.6 on ROCm | ||
| # See: https://github.com/triton-lang/triton/pull/9541 | ||
| MAX_TRITON_BLOCK_SIZE = 128 |
Is this a hard limit for all ROCm GPUs?
It's not architectural. The constraint is LDS pressure from the kernel's tile, and 128 is just where this kernel fits without per-arch tuning.
Different platforms do have different LDS sizes (e.g. it differs between MI300 and MI355), so we could actually query the current platform's LDS size to calculate the max block size here if we wanted to be more precise. 128 does seem to work universally though.
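A rough sketch of what such a calculation could look like, assuming the LDS budget and the kernel's per-token footprint are already known. How to query the LDS size is not shown in this thread, so both values are plain arguments. The function name max_pow2_block_size is hypothetical, and the per-token byte counts below are illustrative, not measured.

```python
def max_pow2_block_size(lds_bytes, bytes_per_token):
    """Largest power-of-two tile whose footprint fits in the LDS budget."""
    if bytes_per_token <= 0:
        raise ValueError("bytes_per_token must be positive")
    if lds_bytes < bytes_per_token:
        raise ValueError("LDS budget cannot hold even a single token")
    max_tokens = lds_bytes // bytes_per_token
    # Round down to a power of two by keeping only the highest set bit.
    return 1 << (max_tokens.bit_length() - 1)


# 65536 is the MI325X limit quoted in the PR description. The per-token
# footprints are made-up inputs to show the rounding behaviour.
exact_fit = max_pow2_block_size(65536, 512)
rounded_down = max_pow2_block_size(65536, 300)
```

A fixed cap avoids needing the per-token footprint at all, which is presumably part of why a constant is simpler here.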
Hi @AndreasKaratzas, the pre-commit checks have failed. Please run:
uv pip install pre-commit>=4.5.1
pre-commit install
pre-commit run --all-files
Then, commit the changes and push to your branch.
@tjtanaa the code changes to ssd have been reverted. Is this PR good to go?
… path Signed-off-by: Andreas Karatzas <akaratza@amd.com>
@AndreasKaratzas please fix precommit
Yep, it's broken currently on main, will probably be fixed by: #42197
…y OOM (vllm-project#38502) Signed-off-by: Andreas Karatzas <akaratza@amd.com>
…y OOM (vllm-project#38502) Signed-off-by: Andreas Karatzas <akaratza@amd.com> Signed-off-by: Matt Van Horn <455140+mvanhorn@users.noreply.github.com>
Hybrid Mamba models (e.g. Jamba) inflate block_size to 2048 to align attention and Mamba page sizes. When the ROCm custom paged attention kernel rejects this (it only supports 16/32), the Triton fallback kernel_paged_attention_2d used 2048 as its tile size, requesting 262144 bytes of shared memory and thus exceeding the MI325X hardware limit of 65536 bytes. Cap TRITON_BLOCK_SIZE at 128. The kernel already decouples tile size from physical block size via l_block_idx/internal_offsets addressing, so this is safe.
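The numbers above can be checked with a few lines of arithmetic. The sketch assumes the shared memory request scales linearly with the tile size, which the description does not state, so the capped figure is an estimate rather than a measured value.

```python
REPORTED_BLOCK_SIZE = 2048        # inflated block size for hybrid models
REPORTED_REQUEST_BYTES = 262144   # shared memory requested at that tile size
MI325X_LIMIT_BYTES = 65536        # hardware limit quoted above
CAPPED_BLOCK_SIZE = 128           # new TRITON_BLOCK_SIZE cap

# Assumption: the request grows linearly with the tile size.
bytes_per_tile_row = REPORTED_REQUEST_BYTES // REPORTED_BLOCK_SIZE
capped_request_bytes = CAPPED_BLOCK_SIZE * bytes_per_tile_row

# How far over the limit the uncapped request was.
overshoot_factor = REPORTED_REQUEST_BYTES // MI325X_LIMIT_BYTES
fits_after_cap = capped_request_bytes <= MI325X_LIMIT_BYTES
```

Under that assumption the uncapped request is 4x the limit, while the capped tile needs about a quarter of it.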
Test plan
pytest tests/models/language/generation/test_hybrid.py
cc @kenroche