Skip to content

[ROCm] Cap Triton paged attention block size to fix ROCm shared memory OOM#38502

Merged
tjtanaa merged 29 commits into
vllm-project:mainfrom
ROCm:akaratza_chunked_prefill
May 10, 2026
Merged

[ROCm] Cap Triton paged attention block size to fix ROCm shared memory OOM#38502
tjtanaa merged 29 commits into
vllm-project:mainfrom
ROCm:akaratza_chunked_prefill

Conversation

@AndreasKaratzas

Copy link
Copy Markdown
Member

Hybrid Mamba models (e.g. Jamba) inflate block_size to 2048 to align attention and Mamba page sizes. When the ROCm custom paged attention kernel rejects this (it only supports 16/32), the Triton fallback kernel_paged_attention_2d used 2048 as its tile size, requesting 262144 bytes of shared memory and thus exceeding the MI325X hardware limit of 65536 bytes. Cap TRITON_BLOCK_SIZE at 128. The kernel already decouples tile size from physical block size via l_block_idx/internal_offsets addressing, so this is safe.

Test plan

  • pytest tests/models/language/generation/test_hybrid.py

cc @kenroche

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
@AndreasKaratzas

Copy link
Copy Markdown
Member Author

cc @micah-wil

@AndreasKaratzas AndreasKaratzas added ready ONLY add when PR is ready to merge/full CI is needed rocm Related to AMD ROCm labels Mar 30, 2026
@github-project-automation github-project-automation Bot moved this to Todo in AMD Mar 30, 2026
@mergify mergify Bot added the v1 label Mar 30, 2026

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a cap of 128 on the TRITON_BLOCK_SIZE within the chunked_prefill_paged_decode operation to prevent shared memory OOM errors, particularly for models with large block sizes like hybrid Mamba. Feedback suggests that hardcoding this value is brittle and recommends a more robust approach by dynamically calculating the maximum block size based on the specific device's shared memory capacity to ensure better portability across different hardware.

Comment thread vllm/v1/attention/ops/chunked_prefill_paged_decode.py
@mergify mergify Bot added the ci/build label Apr 1, 2026
@AndreasKaratzas AndreasKaratzas marked this pull request as ready for review April 1, 2026 22:10
@AndreasKaratzas

Copy link
Copy Markdown
Member Author

Comment thread vllm/v1/attention/backends/rocm_attn.py Outdated
Comment on lines +467 to +469
is_contiguous_blocks = key_cache.stride(0) == key_cache[0].numel()
if block_size in (16, 32) and is_contiguous_blocks:
# Normal 16, 32 with contiguous blocks, use vLLM native HIP C++ logic

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is_contiguous() is for tensor level contiguity but reshape_and_cache kernel's assumption is block-level contiguity, so the former would introduce more performance overhead that I think would be unnecessary here.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see so we just want to check that the rows are contiguous?

@AndreasKaratzas AndreasKaratzas Apr 14, 2026

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep :) Ensures that rows in zero dim (aka blocks) are packed with no gaps between them

Comment on lines +411 to +412
MAX_TRITON_BLOCK_SIZE = 128
TRITON_BLOCK_SIZE = min(block_size, MAX_TRITON_BLOCK_SIZE) if is_pow2 else 32

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this a ROCm specific limit?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep :)

At least for now.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this op only used for ROCm? (sorry if that's a dumb question, I'm not familiar with this area of the code)

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't know the answer to that question myself before I attempted to resolve the failure here, so I think it's not a dumb question 😅

Answer: Yep :) It's found only in vllm/v1/attention/backends/rocm_attn.py.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. This op is only used for ROCm.

# (CanonicalizePointers, ConvertToBufferOps) crash when an scf.if
# yields pointers with different base addresses. Instead, we compute
# both sets of load pointers and use mutually exclusive masks.
if HAS_INITSTATES:

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@hmellor do you know who is more familiar with this mamba code?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can review changes to this kernel, but I don't really understand why these changes are related to the rest of the PR?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@AndreasKaratzas can you explain? Thanks

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is a second ROCm/Mamba blocker by the same hybrid-model validation path. The PR is aimed at getting hybrid Mamba models, e.g. Jamba, working on ROCm with chunked prefill. Once the attention path gets past the inflated/padded block-size issue, the same hybrid_model tests run the Mamba2 varlen SSD path with initial_states. In the previous Triton code, prev_states_ptr could come from either initstates_ptr or states_ptr through an if. On AMD Triton this lowers to an scf.if yielding pointers with different base addresses, and the ROCm compiler crashes in CanonicalizePointers / ConvertToBufferOps. This change keeps the same semantics by computing both candidate load pointers and using mutually exclusive masks, so only the selected source contributes. Without this fix we get:

FAILED tests/kernels/mamba/test_mamba_ssm_ssd.py::test_mamba_chunk_scan_cont_batch_prefill_chunking[seqlens0-8]
FAILED tests/kernels/mamba/test_mamba_ssm_ssd.py::test_mamba_chunk_scan_cont_batch_prefill_chunking[seqlens0-256]
FAILED tests/kernels/mamba/test_mamba_ssm_ssd.py::test_mamba_chunk_scan_cont_batch_prefill_chunking[seqlens1-8]
FAILED tests/kernels/mamba/test_mamba_ssm_ssd.py::test_mamba_chunk_scan_cont_batch_prefill_chunking[seqlens1-256]

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@AndreasKaratzas Can you share the error? I think this was fixed on the Triton side already with triton-lang/triton#9541.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After updating to the new base image, I realize that this patch is unnecessary. I restored this file and waiting for the CI eval to confirm that indeed this issue has been solved already elsewhere.

Comment thread vllm/v1/attention/backends/rocm_attn.py Outdated

if block_size in (16, 32):
# Normal 16, 32, use vLLM native HIP C++ logic
is_contiguous_blocks = key_cache.stride(0) == key_cache[0].numel()

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would this logic also apply to the use_custom for the actual kernel selection?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that the custom kernel is stride aware:

int kv_block_stride = key_cache.stride(0);

So that logic is not needed there.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We only can go into that kernel is reshape_and_cache was used, not reshape_and_cache_flash
There is a condition for whether to select the kernel, or go with the triton fallback. It may need to be changed accordingly

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@gshtras I updated the kernel selection to use the same native-layout as the cache update path. If the KV cache blocks are "strided" and the update path uses reshape_and_cache_flash, use_custom is now forced false so decode falls back to the Triton path.

# via the l_block_idx/internal_offsets addressing logic.
# TODO: Remove after upgrading from Triton 3.6 on ROCm
# See: https://github.com/triton-lang/triton/pull/9541
MAX_TRITON_BLOCK_SIZE = 128

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this a hard limit for all ROCm GPUs?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not architectural. The constraint is LDS pressure from the kernel's tile, and 128 is just where this kernel fits without per-arch tuning.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Different platforms do have different LDS size (e.g. it's different between MI300 and MI355), so we could actually query the current platform's LDS size to calculate the max block size here if we wanted to be more precise. 128 does seem like it works universally though.

@mergify

mergify Bot commented Apr 29, 2026

Copy link
Copy Markdown
Contributor

Hi @AndreasKaratzas, the pre-commit checks have failed. Please run:

uv pip install pre-commit>=4.5.1
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

@AndreasKaratzas

Copy link
Copy Markdown
Member Author

@tjtanaa the code changes from ssd has been reverted. Is this PR good to go?

@AndreasKaratzas

Copy link
Copy Markdown
Member Author

Changed is_contiguous_blocks = key_cache.stride(0) == key_cache[0].numel() because key_cache.shape[1:].numel() is pure metadata and avoids creating a key_cache[0] tensor view in a hot path

@mergify

mergify Bot commented May 10, 2026

Copy link
Copy Markdown
Contributor

Hi @AndreasKaratzas, the pre-commit checks have failed. Please run:

uv pip install pre-commit>=4.5.1
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

@tjtanaa tjtanaa left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@tjtanaa

tjtanaa commented May 10, 2026

Copy link
Copy Markdown
Member

@AndreasKaratzas please fix precommit

@AndreasKaratzas

Copy link
Copy Markdown
Member Author

@AndreasKaratzas please fix precommit

Yep, it's broken currently on main, will probably be fixed by: #42197

@tjtanaa tjtanaa enabled auto-merge (squash) May 10, 2026 08:32
@tjtanaa tjtanaa merged commit 0a309b5 into vllm-project:main May 10, 2026
65 of 66 checks passed
@github-project-automation github-project-automation Bot moved this from Todo to Done in AMD May 10, 2026
@AndreasKaratzas AndreasKaratzas deleted the akaratza_chunked_prefill branch May 10, 2026 20:48
yiliu30 pushed a commit to yiliu30/vllm-fork that referenced this pull request May 11, 2026
…y OOM (vllm-project#38502)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
weifang231 pushed a commit to weifang231/eb-vllm that referenced this pull request May 13, 2026
…y OOM (vllm-project#38502)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
mfylcek pushed a commit to mfylcek/vllm that referenced this pull request May 19, 2026
…y OOM (vllm-project#38502)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
jhu960213 pushed a commit to jhu960213/vllm that referenced this pull request May 20, 2026
…y OOM (vllm-project#38502)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
mvanhorn pushed a commit to mvanhorn/vllm that referenced this pull request Jun 4, 2026
…y OOM (vllm-project#38502)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
Signed-off-by: Matt Van Horn <455140+mvanhorn@users.noreply.github.com>
knight0528 pushed a commit to knight0528/vllm that referenced this pull request Jun 8, 2026
…y OOM (vllm-project#38502)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ci/build ready ONLY add when PR is ready to merge/full CI is needed rocm Related to AMD ROCm v1

Projects

Status: Done

Development

Successfully merging this pull request may close these issues.

6 participants