Merged
29 commits
1fa54b7
Cap Triton paged attention block size to fix ROCm shared memory OOM
AndreasKaratzas Mar 30, 2026
9d5b0a0
Cap Triton paged attention block size to fix ROCm shared memory OOM
AndreasKaratzas Mar 30, 2026
ca6e2df
Merge remote-tracking branch 'origin/main' into akaratza_chunked_prefill
AndreasKaratzas Mar 31, 2026
3b44ad4
[ROCm] Fix ROCM_ATTN KV cache write for non-contiguous blocks in hybr…
AndreasKaratzas Mar 31, 2026
70a327c
Merge remote-tracking branch 'origin/main' into akaratza_chunked_prefill
AndreasKaratzas Apr 1, 2026
483debc
[ROCm][CI] Fix AMD Triton compiler crash in Mamba SSD chunk scan kernel
AndreasKaratzas Apr 1, 2026
f3e5e4e
Merge remote-tracking branch 'origin/main' into akaratza_chunked_prefill
AndreasKaratzas Apr 1, 2026
3262441
Syncing with upstream states mamba version
AndreasKaratzas Apr 1, 2026
513ada7
Merge remote-tracking branch 'origin/main' into akaratza_chunked_prefill
AndreasKaratzas Apr 1, 2026
311039c
Merge remote-tracking branch 'origin/main' into akaratza_chunked_prefill
AndreasKaratzas Apr 3, 2026
bf8a6f5
Merge remote-tracking branch 'origin/main' into akaratza_chunked_prefill
AndreasKaratzas Apr 15, 2026
5784e80
Merge remote-tracking branch 'origin/main' into akaratza_chunked_prefill
AndreasKaratzas Apr 20, 2026
032f175
Reverted triton block size max amidst merged triton lib fix
AndreasKaratzas Apr 20, 2026
f2cfbd4
Merge remote-tracking branch 'origin/main' into akaratza_chunked_prefill
AndreasKaratzas Apr 20, 2026
50ac00f
Set triton block size max cause triton bug is still there but not evi…
AndreasKaratzas Apr 20, 2026
1d5f15d
Merge remote-tracking branch 'origin/main' into akaratza_chunked_prefill
AndreasKaratzas Apr 23, 2026
1510cc4
Merge remote-tracking branch 'origin/main' into akaratza_chunked_prefill
AndreasKaratzas Apr 24, 2026
2ee08e4
Merge remote-tracking branch 'origin/main' into akaratza_chunked_prefill
AndreasKaratzas Apr 24, 2026
b038408
Merge remote-tracking branch 'origin/main' into akaratza_chunked_prefill
AndreasKaratzas Apr 28, 2026
979ad99
Restored ssd
AndreasKaratzas Apr 29, 2026
e7eb924
Merge remote-tracking branch 'origin/main' into akaratza_chunked_prefill
AndreasKaratzas Apr 29, 2026
e97ad4a
Merge remote-tracking branch 'origin/main' into akaratza_chunked_prefill
AndreasKaratzas Apr 29, 2026
2cf7d11
Merge remote-tracking branch 'origin/main' into akaratza_chunked_prefill
AndreasKaratzas May 3, 2026
e0a7d20
Merge remote-tracking branch 'origin/main' into akaratza_chunked_prefill
AndreasKaratzas May 5, 2026
49945d7
Optimize contiguous block detection
AndreasKaratzas May 5, 2026
0fcf335
Merge remote-tracking branch 'origin/main' into akaratza_chunked_prefill
AndreasKaratzas May 7, 2026
3befaed
[ROCm] Updated kernel selection to same native-layout as cache update…
AndreasKaratzas May 7, 2026
4743215
Merge remote-tracking branch 'origin/main' into akaratza_chunked_prefill
AndreasKaratzas May 10, 2026
6f9f1ea
Merge remote-tracking branch 'origin/main' into akaratza_chunked_prefill
AndreasKaratzas May 10, 2026
15 changes: 9 additions & 6 deletions .buildkite/test-amd.yaml
@@ -1803,9 +1803,10 @@ steps:
   - tests/models/multimodal/generation
   - tests/models/multimodal/test_mapping.py
   commands:
-  - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
-  - pytest -v -s models/multimodal/generation -m 'not core_model' --ignore models/multimodal/generation/test_common.py
-  - pytest -v -s models/multimodal/test_mapping.py
+  - uv pip install --system --no-build-isolation 'git+https://github.com/AndreasKaratzas/mamba@rocm-7.0-v2.3.0'

Member

Will this be upstreamed?

Member Author

I don't know tbh. For now I'm keeping it synced with upstream. This branch is basically upstream + my original fix.

+  - uv pip install --system --no-build-isolation 'git+https://github.com/Dao-AILab/causal-conv1d@v1.6.0'
+  - pytest -v -s models/language/generation -m hybrid_model --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --shard-id=$$BUILDKITE_PARALLEL_JOB
+

 - label: Multi-Modal Models (Extended Generation 2) # TBD
   timeout_in_minutes: 180
@@ -1817,8 +1818,10 @@ steps:
   - vllm/
   - tests/models/multimodal/generation
   commands:
-  - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
-  - pytest -v -s models/multimodal/generation/test_common.py -m 'split(group=0) and not core_model'
+  - uv pip install --system --no-build-isolation 'git+https://github.com/AndreasKaratzas/mamba@rocm-7.0-v2.3.0'
+  - uv pip install --system --no-build-isolation 'git+https://github.com/Dao-AILab/causal-conv1d@v1.6.0'
+  - pytest -v -s models/language/generation -m '(not core_model) and (not hybrid_model)'
+

 - label: Multi-Modal Models (Extended Generation 3) # TBD
   timeout_in_minutes: 180
@@ -3043,7 +3046,7 @@ steps:
   - vllm/
   - tests/models/language/generation
   commands:
-  - uv pip install --system --no-build-isolation 'git+https://github.com/AndreasKaratzas/mamba@fix-rocm-7.0-warp-size-constexpr'
+  - uv pip install --system --no-build-isolation 'git+https://github.com/AndreasKaratzas/mamba@rocm-7.0-v2.3.0'

Member

Will this be upstreamed?

Member Author

I don't know tbh. For now I'm keeping it synced with upstream. This branch is basically upstream + my original fix.

   - uv pip install --system --no-build-isolation 'git+https://github.com/Dao-AILab/causal-conv1d@v1.6.0'
   - pytest -v -s models/language/generation -m '(not core_model) and (not hybrid_model)'

12 changes: 8 additions & 4 deletions vllm/v1/attention/backends/rocm_attn.py
@@ -29,6 +29,7 @@
 )
 from vllm.v1.attention.ops.chunked_prefill_paged_decode import (
     chunked_prefill_paged_decode,
+    has_native_kv_cache_layout,
 )
 from vllm.v1.attention.ops.paged_attn import PagedAttention
 from vllm.v1.attention.ops.triton_reshape_and_cache_flash import (
@@ -468,9 +469,10 @@ def do_kv_cache_update(
         # Get the actual block_size from value_cache
         # value_cache shape: [num_blocks, num_heads, head_size, block_size]
         block_size = value_cache.shape[3]
+        has_native_layout = has_native_kv_cache_layout(key_cache, value_cache)

-        if block_size in (16, 32):
-            # Normal 16, 32, use vLLM native HIP C++ logic
+        if block_size in (16, 32) and has_native_layout:
+            # Normal 16, 32 with contiguous blocks: use vLLM native HIP C++ logic.
             PagedAttention.write_to_paged_cache(
                 key,
                 value,
@@ -482,8 +484,10 @@ def do_kv_cache_update(
                 layer._v_scale,
             )
         else:
-            # Case B: Non-standard blocks (e.g., 64, 128, 544 in Qwen3Next or Qwen3.5 ),
-            # force using our modified Triton logic
+            # Non-standard blocks and hybrid attention/Mamba layouts need the
+            # stride-aware Triton writer. The native reshape_and_cache kernel
+            # assumes contiguous block storage and writes to the wrong hybrid
+            # cache blocks.
             triton_reshape_and_cache_flash(
                 key,
                 value,
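
The branch condition in `do_kv_cache_update` can be read as a small decision function. The sketch below is illustrative only: `select_kv_cache_writer` and its string labels are hypothetical names that do not exist in vLLM, and only the condition itself is taken from the diff.

```python
def select_kv_cache_writer(block_size: int, has_native_layout: bool) -> str:
    """Return a label for the writer the updated branch would pick.

    The labels are illustrative strings, not vLLM identifiers.
    """
    # Native HIP writer: only for block sizes 16/32 AND packed block storage.
    if block_size in (16, 32) and has_native_layout:
        return "native_hip"
    # Everything else (other block sizes, stride-padded hybrid layouts)
    # goes through the stride-aware Triton writer.
    return "triton_reshape_and_cache_flash"


for block_size, packed in [(16, True), (32, False), (544, True)]:
    print(block_size, packed, select_kv_cache_writer(block_size, packed))
```

Before this change a block size of 16 or 32 was enough to take the native path; the added `has_native_layout` term is what sends stride-padded hybrid caches to the Triton writer.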
35 changes: 27 additions & 8 deletions vllm/v1/attention/ops/chunked_prefill_paged_decode.py
@@ -21,6 +21,22 @@
 float8_info = torch.finfo(current_platform.fp8_dtype())


+def has_native_kv_cache_layout(
+    key_cache: torch.Tensor,
+    value_cache: torch.Tensor,
+) -> bool:
+    """Return whether KV cache blocks can use the native ROCm pairing.
+
+    The native reshape_and_cache writer assumes packed blocks. If cache update
+    needs reshape_and_cache_flash for a stride-padded hybrid layout, decode
+    should use the matching Triton path too.
+    """
+    return (
+        key_cache.stride(0) == key_cache.shape[1:].numel()
+        and value_cache.stride(0) == value_cache.shape[1:].numel()
+    )
+
+
 @triton.jit
 def cdiv_fn(x, y):
     return (x + y - 1) // y
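
To make the stride check in `has_native_kv_cache_layout` concrete, here is a torch-free sketch. `is_packed`, the example shape, and the padded stride are assumptions made up for illustration; in the diff the same comparison is written as `cache.stride(0) == cache.shape[1:].numel()`.

```python
from math import prod


def is_packed(shape: tuple[int, ...], stride0: int) -> bool:
    """True when consecutive blocks sit back to back in memory.

    Mirrors `t.stride(0) == t.shape[1:].numel()` from the diff, with the
    shape and the leading stride passed in as plain integers.
    """
    return stride0 == prod(shape[1:])


# Hypothetical value-cache shape: [num_blocks, num_heads, head_size, block_size].
value_shape = (4, 8, 128, 16)
packed_stride = 8 * 128 * 16  # elements per block, no gap between blocks

print(is_packed(value_shape, packed_stride))         # packed layout
print(is_packed(value_shape, packed_stride + 4096))  # assumed stride-padded layout
```

A leading stride larger than the per-block element count means each block is followed by padding (or by another layer's data), which is the case the docstring calls a stride-padded hybrid layout.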
@@ -346,14 +362,12 @@ def chunked_prefill_paged_decode(
             alibi_slopes,
             sinks,
         )
-    # Triton is only forced when encountering a non-standard block
-    # like Qwen3 with a size of 544.
-    # 1. Check if block_size is a power of 2 (16, 32, 64...)
-    # 2. If it's a power of 2, we trust the vLLM's native use_custom decision.
-    # 3. If it's not a power of 2 (such as Qwen3's 544),
-    # then our Triton path is forced.
+    has_native_layout = has_native_kv_cache_layout(key_cache, value_cache)
+    # Force Triton for non-standard blocks like Qwen3's 544 and for
+    # stride-padded hybrid layouts. The latter use reshape_and_cache_flash
+    # during cache update, so keep decode on the matching stride-aware path.
     is_pow2 = block_size > 0 and (block_size & (block_size - 1) == 0)
-    if not is_pow2:
+    if not is_pow2 or not has_native_layout:
         use_custom = False

     if use_custom:
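
The override above combines a power-of-two bit test with the layout check. A minimal standalone sketch follows; `is_power_of_two` and `resolve_use_custom` are hypothetical helper names, and the sketch assumes `use_custom` arrives as a plain boolean decided earlier in the function.

```python
def is_power_of_two(n: int) -> bool:
    # A power of two has exactly one set bit, so n & (n - 1) clears it to zero.
    return n > 0 and (n & (n - 1)) == 0


def resolve_use_custom(
    use_custom: bool, block_size: int, has_native_layout: bool
) -> bool:
    """Hypothetical wrapper around the override shown in the diff."""
    # Either condition alone is enough to force the Triton path.
    if not is_power_of_two(block_size) or not has_native_layout:
        return False
    # Otherwise the earlier decision is left untouched.
    return use_custom


print(is_power_of_two(544))                # False
print(resolve_use_custom(True, 16, True))  # True
print(resolve_use_custom(True, 16, False)) # False
```

Note that the override only ever turns the custom kernel off: a `False` that was decided earlier stays `False`.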
@@ -404,7 +418,12 @@ def chunked_prefill_paged_decode(
         real_block_size = value_cache.shape[3]
         # The standard model directly uses the original block_size.
         # Non-standard 544 uses 32 to accommodate integer division logic.
-        TRITON_BLOCK_SIZE = block_size if is_pow2 else 32
+        # Cap at 128 to avoid exceeding GPU shared memory limits
+        # (e.g. hybrid Mamba models inflate block_size to 2048).
+        # The kernel handles TRITON_BLOCK_SIZE != PHYSICAL_BLOCK_SIZE
+        # via the l_block_idx/internal_offsets addressing logic.
+        MAX_TRITON_BLOCK_SIZE = 128
+        TRITON_BLOCK_SIZE = min(block_size, MAX_TRITON_BLOCK_SIZE) if is_pow2 else 32
         if is_block_table_ptr:
             # Using the physical base address of tensors
             kv_element_size = key_cache.element_size()
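
The tile-size rule in this hunk can be checked in isolation. In the sketch below, `pick_triton_block_size` restates the updated expression from the diff, while `locate_token` is a hypothetical helper: the kernel's real `l_block_idx`/`internal_offsets` code is not part of this diff, so the helper only illustrates the general idea of addressing a token when the tile size differs from the physical block size.

```python
MAX_TRITON_BLOCK_SIZE = 128  # cap taken from the diff


def pick_triton_block_size(block_size: int) -> int:
    """Tile size per the updated expression in the diff."""
    is_pow2 = block_size > 0 and (block_size & (block_size - 1)) == 0
    return min(block_size, MAX_TRITON_BLOCK_SIZE) if is_pow2 else 32


def locate_token(token_idx: int, physical_block_size: int) -> tuple[int, int]:
    """Hypothetical helper: (logical block index, offset inside that block)."""
    return divmod(token_idx, physical_block_size)


print(pick_triton_block_size(16))    # 16
print(pick_triton_block_size(2048))  # 128
print(pick_triton_block_size(544))   # 32

# 544 is an exact multiple of 32, so a 32-token tile never straddles two blocks.
print(544 % 32)  # 0

# First token of tile 17 when tiles hold 128 tokens and blocks hold 2048.
print(locate_token(17 * 128, 2048))  # (1, 128)
```

Because 128 divides 2048 and 32 divides 544, every tile in these cases falls entirely inside one physical block, which is consistent with the comment about integer division in the hunk.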