From 1fa54b73c7fc4913497973c157d1c358faa99527 Mon Sep 17 00:00:00 2001 From: Andreas Karatzas Date: Sun, 29 Mar 2026 23:28:39 -0500 Subject: [PATCH 01/10] Cap Triton paged attention block size to fix ROCm shared memory OOM Signed-off-by: Andreas Karatzas --- vllm/v1/attention/ops/chunked_prefill_paged_decode.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/vllm/v1/attention/ops/chunked_prefill_paged_decode.py b/vllm/v1/attention/ops/chunked_prefill_paged_decode.py index 000fd4d43b93..023f0ee6b0af 100644 --- a/vllm/v1/attention/ops/chunked_prefill_paged_decode.py +++ b/vllm/v1/attention/ops/chunked_prefill_paged_decode.py @@ -402,7 +402,12 @@ def chunked_prefill_paged_decode( real_block_size = value_cache.shape[3] # The standard model directly uses the original block_size. # Non-standard 544 uses 32 to accommodate integer division logic. - TRITON_BLOCK_SIZE = block_size if is_pow2 else 32 + # Cap at 128 to avoid exceeding GPU shared memory limits + # (e.g. hybrid Mamba models inflate block_size to 2048). + # The kernel handles TRITON_BLOCK_SIZE != PHYSICAL_BLOCK_SIZE + # via the l_block_idx/internal_offsets addressing logic. + MAX_TRITON_BLOCK_SIZE = 128 + TRITON_BLOCK_SIZE = min(block_size, MAX_TRITON_BLOCK_SIZE) if is_pow2 else 32 if is_block_table_ptr: # Using the physical base address of tensors kv_element_size = key_cache.element_size() From 9d5b0a0140107014ba581d7c0b824f31583b580b Mon Sep 17 00:00:00 2001 From: Andreas Karatzas Date: Sun, 29 Mar 2026 23:36:25 -0500 Subject: [PATCH 02/10] Cap Triton paged attention block size to fix ROCm shared memory OOM Signed-off-by: Andreas Karatzas --- vllm/v1/attention/ops/chunked_prefill_paged_decode.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/v1/attention/ops/chunked_prefill_paged_decode.py b/vllm/v1/attention/ops/chunked_prefill_paged_decode.py index 023f0ee6b0af..cda2d668287e 100644 --- a/vllm/v1/attention/ops/chunked_prefill_paged_decode.py +++ b/vllm/v1/attention/ops/chunked_prefill_paged_decode.py @@ -406,6 +406,8 @@ def chunked_prefill_paged_decode( # (e.g. hybrid Mamba models inflate block_size to 2048). # The kernel handles TRITON_BLOCK_SIZE != PHYSICAL_BLOCK_SIZE # via the l_block_idx/internal_offsets addressing logic. + # TODO: Remove after upgrading from Triton 3.6 on ROCm + # See: https://github.com/triton-lang/triton/pull/9541 MAX_TRITON_BLOCK_SIZE = 128 TRITON_BLOCK_SIZE = min(block_size, MAX_TRITON_BLOCK_SIZE) if is_pow2 else 32 if is_block_table_ptr: From 3b44ad4c326128c9c6381b087ae09734411f25ab Mon Sep 17 00:00:00 2001 From: Andreas Karatzas Date: Tue, 31 Mar 2026 02:10:48 -0500 Subject: [PATCH 03/10] [ROCm] Fix ROCM_ATTN KV cache write for non-contiguous blocks in hybrid models Signed-off-by: Andreas Karatzas --- vllm/v1/attention/backends/rocm_attn.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/vllm/v1/attention/backends/rocm_attn.py b/vllm/v1/attention/backends/rocm_attn.py index 6afb617f28ed..6d0967116207 100644 --- a/vllm/v1/attention/backends/rocm_attn.py +++ b/vllm/v1/attention/backends/rocm_attn.py @@ -463,8 +463,9 @@ def do_kv_cache_update( # value_cache shape: [num_blocks, num_heads, head_size, block_size] block_size = value_cache.shape[3] - if block_size in (16, 32): - # Normal 16, 32, use vLLM native HIP C++ logic + is_contiguous_blocks = key_cache.stride(0) == key_cache[0].numel() + if block_size in (16, 32) and is_contiguous_blocks: + # Normal 16, 32 with contiguous blocks, use vLLM native HIP C++ logic PagedAttention.write_to_paged_cache( key, value, @@ -476,8 +477,11 @@ def do_kv_cache_update( layer._v_scale, ) else: - # Case B: Non-standard blocks (e.g., 64, 128, 544 in Qwen3Next or Qwen3.5 ), - # force using our modified Triton logic + # Non-standard blocks (e.g., 544 in Qwen3Next) or non-contiguous + # blocks (e.g., hybrid Mamba models where page size is padded to + # align attention and Mamba state pages). The C++ reshape_and_cache + # kernel assumes contiguous block layout, so fall back to Triton + # which uses explicit strides. triton_reshape_and_cache_flash( key, value, From 483debcd8be923d67b803493e598f2d3f09b5f18 Mon Sep 17 00:00:00 2001 From: Andreas Karatzas Date: Tue, 31 Mar 2026 20:18:45 -0500 Subject: [PATCH 04/10] [ROCm][CI] Fix AMD Triton compiler crash in Mamba SSD chunk scan kernel Signed-off-by: Andreas Karatzas --- .../layers/mamba/ops/ssd_chunk_scan.py | 114 +++++++++++++----- 1 file changed, 87 insertions(+), 27 deletions(-) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py index e5e73625f861..a9c1f9a207b5 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py @@ -236,14 +236,21 @@ def _chunk_scan_fwd_kernel( seq_idx_ptr - stride_seq_idx_chunk, mask=pid_c >= 1, other=-1 ) - if HAS_INITSTATES and (seq_idx != seq_idx_prev): - prev_states_ptr = ( + # NOTE: When HAS_INITSTATES is True, we avoid storing the selected + # pointer in a variable because AMD Triton compiler passes + # (CanonicalizePointers, ConvertToBufferOps) crash when an scf.if + # yields pointers with different base addresses. Instead, we compute + # both sets of load pointers and use mutually exclusive masks. + if HAS_INITSTATES: + use_init_states = seq_idx != seq_idx_prev + init_states_base = ( initstates_ptr + seq_idx * stride_init_states_batch + pid_h * stride_init_states_head ) - prev_states_hdim = stride_init_states_hdim - prev_states_dstate = stride_init_states_dstate + chunk_states_base = ( + states_ptr + (pid_c - 1) * stride_states_chunk + pid_h * stride_states_head + ) else: prev_states_ptr = ( states_ptr + (pid_c - 1) * stride_states_chunk + pid_h * stride_states_head @@ -280,8 +287,25 @@ def _chunk_scan_fwd_kernel( other=0.0, ) - if not HAS_INITSTATES and (seq_idx != seq_idx_prev): - # if no init states AND starting a new sequence, we need zeros + if HAS_INITSTATES: + # Load from both sources with mutually exclusive masks + base_mask = (offs_k_dstate[:, None] < dstate) & (offs_n[None, :] < hdim) + init_ptrs = ( + init_states_base + + offs_n[None, :] * stride_init_states_hdim + + offs_k_dstate[:, None] * stride_init_states_dstate + ) + chunk_ptrs = ( + chunk_states_base + + offs_n[None, :] * stride_states_hdim + + offs_k_dstate[:, None] * stride_states_dstate + ) + prev_states = tl.load( + init_ptrs, mask=use_init_states & base_mask, other=0.0 + ) + tl.load(chunk_ptrs, mask=(~use_init_states) & base_mask, other=0.0) + prev_states = prev_states.to(C_ptr.dtype.element_ty) + elif seq_idx != seq_idx_prev: + # no init states AND starting a new sequence: use zeros prev_states = tl.zeros( (BLOCK_SIZE_DSTATE, BLOCK_SIZE_N), dtype=C_ptr.dtype.element_ty ) @@ -302,33 +326,69 @@ def _chunk_scan_fwd_kernel( acc = tl.dot(C, prev_states) * scale_m[:, None] else: - prev_states_ptrs = ( - prev_states_ptr - + offs_n[None, :] * prev_states_hdim - + offs_k_dstate[:, None] * prev_states_dstate - ) - for k in range(0, dstate, BLOCK_SIZE_K): - C = tl.load( - C_ptrs, - mask=(offs_m[:, None] < chunk_size_limit) - & (offs_k_dstate[None, :] < dstate - k), - other=0.0, + if HAS_INITSTATES: + init_state_ptrs = ( + init_states_base + + offs_n[None, :] * stride_init_states_hdim + + offs_k_dstate[:, None] * stride_init_states_dstate ) - if not HAS_INITSTATES and (seq_idx != seq_idx_prev): - prev_states = tl.zeros( - (BLOCK_SIZE_K, BLOCK_SIZE_N), dtype=C_ptr.dtype.element_ty + chunk_state_ptrs = ( + chunk_states_base + + offs_n[None, :] * stride_states_hdim + + offs_k_dstate[:, None] * stride_states_dstate + ) + for k in range(0, dstate, BLOCK_SIZE_K): + C = tl.load( + C_ptrs, + mask=(offs_m[:, None] < chunk_size_limit) + & (offs_k_dstate[None, :] < dstate - k), + other=0.0, + ) + base_mask = (offs_k_dstate[:, None] < dstate - k) & ( + offs_n[None, :] < hdim ) - else: prev_states = tl.load( - prev_states_ptrs, - mask=(offs_k_dstate[:, None] < dstate - k) - & (offs_n[None, :] < hdim), + init_state_ptrs, + mask=use_init_states & base_mask, + other=0.0, + ) + tl.load( + chunk_state_ptrs, + mask=(~use_init_states) & base_mask, other=0.0, ) prev_states = prev_states.to(C_ptr.dtype.element_ty) - acc += tl.dot(C, prev_states) - C_ptrs += BLOCK_SIZE_K - prev_states_ptrs += BLOCK_SIZE_K + acc += tl.dot(C, prev_states) + C_ptrs += BLOCK_SIZE_K + init_state_ptrs += BLOCK_SIZE_K + chunk_state_ptrs += BLOCK_SIZE_K + else: + prev_states_ptrs = ( + prev_states_ptr + + offs_n[None, :] * prev_states_hdim + + offs_k_dstate[:, None] * prev_states_dstate + ) + for k in range(0, dstate, BLOCK_SIZE_K): + C = tl.load( + C_ptrs, + mask=(offs_m[:, None] < chunk_size_limit) + & (offs_k_dstate[None, :] < dstate - k), + other=0.0, + ) + if seq_idx != seq_idx_prev: + prev_states = tl.zeros( + (BLOCK_SIZE_K, BLOCK_SIZE_N), dtype=C_ptr.dtype.element_ty + ) + else: + prev_states = tl.load( + prev_states_ptrs, + mask=(offs_k_dstate[:, None] < dstate - k) + & (offs_n[None, :] < hdim), + other=0.0, + ) + prev_states = prev_states.to(C_ptr.dtype.element_ty) + acc += tl.dot(C, prev_states) + C_ptrs += BLOCK_SIZE_K + prev_states_ptrs += BLOCK_SIZE_K acc *= scale_m[:, None] offs_k = tl.arange(0, BLOCK_SIZE_K) From 3262441a3e890ffe4ba76f878283979c4c77b7c2 Mon Sep 17 00:00:00 2001 From: Andreas Karatzas Date: Wed, 1 Apr 2026 11:57:10 -0500 Subject: [PATCH 05/10] Syncing with upstream states mamba version Signed-off-by: Andreas Karatzas --- .buildkite/test-amd.yaml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.buildkite/test-amd.yaml b/.buildkite/test-amd.yaml index b7254efd2dc2..b65faf5cdd33 100644 --- a/.buildkite/test-amd.yaml +++ b/.buildkite/test-amd.yaml @@ -2196,7 +2196,7 @@ steps: - vllm/ - tests/models/language/generation commands: - - uv pip install --system --no-build-isolation 'git+https://github.com/AndreasKaratzas/mamba@fix-rocm-7.0-warp-size-constexpr' + - uv pip install --system --no-build-isolation 'git+https://github.com/AndreasKaratzas/mamba@rocm-7.0-v2.3.0' - uv pip install --system --no-build-isolation 'git+https://github.com/Dao-AILab/causal-conv1d@v1.5.2' - pytest -v -s models/language/generation -m hybrid_model --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --shard-id=$$BUILDKITE_PARALLEL_JOB @@ -2210,7 +2210,7 @@ steps: - vllm/ - tests/models/language/generation commands: - - uv pip install --system --no-build-isolation 'git+https://github.com/AndreasKaratzas/mamba@fix-rocm-7.0-warp-size-constexpr' + - uv pip install --system --no-build-isolation 'git+https://github.com/AndreasKaratzas/mamba@rocm-7.0-v2.3.0' - uv pip install --system --no-build-isolation 'git+https://github.com/Dao-AILab/causal-conv1d@v1.5.2' - pytest -v -s models/language/generation -m '(not core_model) and (not hybrid_model)' @@ -3348,7 +3348,7 @@ steps: - vllm/ - tests/models/language/generation commands: - - uv pip install --system --no-build-isolation 'git+https://github.com/AndreasKaratzas/mamba@fix-rocm-7.0-warp-size-constexpr' + - uv pip install --system --no-build-isolation 'git+https://github.com/AndreasKaratzas/mamba@rocm-7.0-v2.3.0' - uv pip install --system --no-build-isolation 'git+https://github.com/Dao-AILab/causal-conv1d@v1.5.2' - pytest -v -s models/language/generation -m '(not core_model) and (not hybrid_model)' From 032f1754c9d52d5363cf87ffa560400cdbfb5c82 Mon Sep 17 00:00:00 2001 From: Andreas Karatzas Date: Sun, 19 Apr 2026 19:54:38 -0500 Subject: [PATCH 06/10] Reverted triton block size max amidst merged triton lib fix Signed-off-by: Andreas Karatzas --- vllm/v1/attention/ops/chunked_prefill_paged_decode.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/vllm/v1/attention/ops/chunked_prefill_paged_decode.py b/vllm/v1/attention/ops/chunked_prefill_paged_decode.py index cda2d668287e..000fd4d43b93 100644 --- a/vllm/v1/attention/ops/chunked_prefill_paged_decode.py +++ b/vllm/v1/attention/ops/chunked_prefill_paged_decode.py @@ -402,14 +402,7 @@ def chunked_prefill_paged_decode( real_block_size = value_cache.shape[3] # The standard model directly uses the original block_size. # Non-standard 544 uses 32 to accommodate integer division logic. - # Cap at 128 to avoid exceeding GPU shared memory limits - # (e.g. hybrid Mamba models inflate block_size to 2048). - # The kernel handles TRITON_BLOCK_SIZE != PHYSICAL_BLOCK_SIZE - # via the l_block_idx/internal_offsets addressing logic. - # TODO: Remove after upgrading from Triton 3.6 on ROCm - # See: https://github.com/triton-lang/triton/pull/9541 - MAX_TRITON_BLOCK_SIZE = 128 - TRITON_BLOCK_SIZE = min(block_size, MAX_TRITON_BLOCK_SIZE) if is_pow2 else 32 + TRITON_BLOCK_SIZE = block_size if is_pow2 else 32 if is_block_table_ptr: # Using the physical base address of tensors kv_element_size = key_cache.element_size() From 50ac00fc96c2c65685dc5ee579926e224f450ee7 Mon Sep 17 00:00:00 2001 From: Andreas Karatzas Date: Mon, 20 Apr 2026 12:17:16 -0500 Subject: [PATCH 07/10] Set triton block size max cause triton bug is still there but not evident with rocm attn Signed-off-by: Andreas Karatzas --- vllm/v1/attention/ops/chunked_prefill_paged_decode.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/vllm/v1/attention/ops/chunked_prefill_paged_decode.py b/vllm/v1/attention/ops/chunked_prefill_paged_decode.py index 000fd4d43b93..023f0ee6b0af 100644 --- a/vllm/v1/attention/ops/chunked_prefill_paged_decode.py +++ b/vllm/v1/attention/ops/chunked_prefill_paged_decode.py @@ -402,7 +402,12 @@ def chunked_prefill_paged_decode( real_block_size = value_cache.shape[3] # The standard model directly uses the original block_size. # Non-standard 544 uses 32 to accommodate integer division logic. - TRITON_BLOCK_SIZE = block_size if is_pow2 else 32 + # Cap at 128 to avoid exceeding GPU shared memory limits + # (e.g. hybrid Mamba models inflate block_size to 2048). + # The kernel handles TRITON_BLOCK_SIZE != PHYSICAL_BLOCK_SIZE + # via the l_block_idx/internal_offsets addressing logic. + MAX_TRITON_BLOCK_SIZE = 128 + TRITON_BLOCK_SIZE = min(block_size, MAX_TRITON_BLOCK_SIZE) if is_pow2 else 32 if is_block_table_ptr: # Using the physical base address of tensors kv_element_size = key_cache.element_size() From 979ad99b4d0964c2596f4de800178b5d044498f6 Mon Sep 17 00:00:00 2001 From: Andreas Karatzas Date: Wed, 29 Apr 2026 01:27:54 -0500 Subject: [PATCH 08/10] Restored ssd Signed-off-by: Andreas Karatzas --- .../layers/mamba/ops/ssd_chunk_scan.py | 114 +++++------------- 1 file changed, 27 insertions(+), 87 deletions(-) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py index a9c1f9a207b5..e5e73625f861 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py @@ -236,21 +236,14 @@ def _chunk_scan_fwd_kernel( seq_idx_ptr - stride_seq_idx_chunk, mask=pid_c >= 1, other=-1 ) - # NOTE: When HAS_INITSTATES is True, we avoid storing the selected - # pointer in a variable because AMD Triton compiler passes - # (CanonicalizePointers, ConvertToBufferOps) crash when an scf.if - # yields pointers with different base addresses. Instead, we compute - # both sets of load pointers and use mutually exclusive masks. - if HAS_INITSTATES: - use_init_states = seq_idx != seq_idx_prev - init_states_base = ( + if HAS_INITSTATES and (seq_idx != seq_idx_prev): + prev_states_ptr = ( initstates_ptr + seq_idx * stride_init_states_batch + pid_h * stride_init_states_head ) - chunk_states_base = ( - states_ptr + (pid_c - 1) * stride_states_chunk + pid_h * stride_states_head - ) + prev_states_hdim = stride_init_states_hdim + prev_states_dstate = stride_init_states_dstate else: prev_states_ptr = ( states_ptr + (pid_c - 1) * stride_states_chunk + pid_h * stride_states_head @@ -287,25 +280,8 @@ def _chunk_scan_fwd_kernel( other=0.0, ) - if HAS_INITSTATES: - # Load from both sources with mutually exclusive masks - base_mask = (offs_k_dstate[:, None] < dstate) & (offs_n[None, :] < hdim) - init_ptrs = ( - init_states_base - + offs_n[None, :] * stride_init_states_hdim - + offs_k_dstate[:, None] * stride_init_states_dstate - ) - chunk_ptrs = ( - chunk_states_base - + offs_n[None, :] * stride_states_hdim - + offs_k_dstate[:, None] * stride_states_dstate - ) - prev_states = tl.load( - init_ptrs, mask=use_init_states & base_mask, other=0.0 - ) + tl.load(chunk_ptrs, mask=(~use_init_states) & base_mask, other=0.0) - prev_states = prev_states.to(C_ptr.dtype.element_ty) - elif seq_idx != seq_idx_prev: - # no init states AND starting a new sequence: use zeros + if not HAS_INITSTATES and (seq_idx != seq_idx_prev): + # if no init states AND starting a new sequence, we need zeros prev_states = tl.zeros( (BLOCK_SIZE_DSTATE, BLOCK_SIZE_N), dtype=C_ptr.dtype.element_ty ) @@ -326,69 +302,33 @@ def _chunk_scan_fwd_kernel( acc = tl.dot(C, prev_states) * scale_m[:, None] else: - if HAS_INITSTATES: - init_state_ptrs = ( - init_states_base - + offs_n[None, :] * stride_init_states_hdim - + offs_k_dstate[:, None] * stride_init_states_dstate - ) - chunk_state_ptrs = ( - chunk_states_base - + offs_n[None, :] * stride_states_hdim - + offs_k_dstate[:, None] * stride_states_dstate + prev_states_ptrs = ( + prev_states_ptr + + offs_n[None, :] * prev_states_hdim + + offs_k_dstate[:, None] * prev_states_dstate + ) + for k in range(0, dstate, BLOCK_SIZE_K): + C = tl.load( + C_ptrs, + mask=(offs_m[:, None] < chunk_size_limit) + & (offs_k_dstate[None, :] < dstate - k), + other=0.0, ) - for k in range(0, dstate, BLOCK_SIZE_K): - C = tl.load( - C_ptrs, - mask=(offs_m[:, None] < chunk_size_limit) - & (offs_k_dstate[None, :] < dstate - k), - other=0.0, - ) - base_mask = (offs_k_dstate[:, None] < dstate - k) & ( - offs_n[None, :] < hdim + if not HAS_INITSTATES and (seq_idx != seq_idx_prev): + prev_states = tl.zeros( + (BLOCK_SIZE_K, BLOCK_SIZE_N), dtype=C_ptr.dtype.element_ty ) + else: prev_states = tl.load( - init_state_ptrs, - mask=use_init_states & base_mask, - other=0.0, - ) + tl.load( - chunk_state_ptrs, - mask=(~use_init_states) & base_mask, + prev_states_ptrs, + mask=(offs_k_dstate[:, None] < dstate - k) + & (offs_n[None, :] < hdim), other=0.0, ) prev_states = prev_states.to(C_ptr.dtype.element_ty) - acc += tl.dot(C, prev_states) - C_ptrs += BLOCK_SIZE_K - init_state_ptrs += BLOCK_SIZE_K - chunk_state_ptrs += BLOCK_SIZE_K - else: - prev_states_ptrs = ( - prev_states_ptr - + offs_n[None, :] * prev_states_hdim - + offs_k_dstate[:, None] * prev_states_dstate - ) - for k in range(0, dstate, BLOCK_SIZE_K): - C = tl.load( - C_ptrs, - mask=(offs_m[:, None] < chunk_size_limit) - & (offs_k_dstate[None, :] < dstate - k), - other=0.0, - ) - if seq_idx != seq_idx_prev: - prev_states = tl.zeros( - (BLOCK_SIZE_K, BLOCK_SIZE_N), dtype=C_ptr.dtype.element_ty - ) - else: - prev_states = tl.load( - prev_states_ptrs, - mask=(offs_k_dstate[:, None] < dstate - k) - & (offs_n[None, :] < hdim), - other=0.0, - ) - prev_states = prev_states.to(C_ptr.dtype.element_ty) - acc += tl.dot(C, prev_states) - C_ptrs += BLOCK_SIZE_K - prev_states_ptrs += BLOCK_SIZE_K + acc += tl.dot(C, prev_states) + C_ptrs += BLOCK_SIZE_K + prev_states_ptrs += BLOCK_SIZE_K acc *= scale_m[:, None] offs_k = tl.arange(0, BLOCK_SIZE_K) From 49945d7493d0f13b5b19ac191620a138cfe9f8d4 Mon Sep 17 00:00:00 2001 From: Andreas Karatzas Date: Tue, 5 May 2026 01:25:52 -0500 Subject: [PATCH 09/10] Optimize contiguous block detection Signed-off-by: Andreas Karatzas --- vllm/v1/attention/backends/rocm_attn.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/vllm/v1/attention/backends/rocm_attn.py b/vllm/v1/attention/backends/rocm_attn.py index 2a5789fa6e20..5da82ad637e8 100644 --- a/vllm/v1/attention/backends/rocm_attn.py +++ b/vllm/v1/attention/backends/rocm_attn.py @@ -468,10 +468,15 @@ def do_kv_cache_update( # Get the actual block_size from value_cache # value_cache shape: [num_blocks, num_heads, head_size, block_size] block_size = value_cache.shape[3] + key_block_stride = key_cache.shape[1:].numel() + value_block_stride = value_cache.shape[1:].numel() + has_contiguous_blocks = ( + key_cache.stride(0) == key_block_stride + and value_cache.stride(0) == value_block_stride + ) - is_contiguous_blocks = key_cache.stride(0) == key_cache[0].numel() - if block_size in (16, 32) and is_contiguous_blocks: - # Normal 16, 32 with contiguous blocks, use vLLM native HIP C++ logic + if block_size in (16, 32) and has_contiguous_blocks: + # Normal 16, 32 with contiguous blocks: use vLLM native HIP C++ logic. PagedAttention.write_to_paged_cache( key, value, @@ -483,11 +488,10 @@ def do_kv_cache_update( layer._v_scale, ) else: - # Non-standard blocks (e.g., 544 in Qwen3Next) or non-contiguous - # blocks (e.g., hybrid Mamba models where page size is padded to - # align attention and Mamba state pages). The C++ reshape_and_cache - # kernel assumes contiguous block layout, so fall back to Triton - # which uses explicit strides. + # Non-standard blocks and hybrid attention/Mamba layouts need the + # stride-aware Triton writer. The native reshape_and_cache kernel + # assumes contiguous block storage and writes to the wrong hybrid + # cache blocks. triton_reshape_and_cache_flash( key, value, From 3befaeda88dcb33c08faf8297484c4e28abeaabe Mon Sep 17 00:00:00 2001 From: Andreas Karatzas Date: Thu, 7 May 2026 12:19:59 -0500 Subject: [PATCH 10/10] [ROCm] Updated kernel selection to same native-layout as cache update path Signed-off-by: Andreas Karatzas --- vllm/v1/attention/backends/rocm_attn.py | 10 ++----- .../ops/chunked_prefill_paged_decode.py | 28 ++++++++++++++----- 2 files changed, 24 insertions(+), 14 deletions(-) diff --git a/vllm/v1/attention/backends/rocm_attn.py b/vllm/v1/attention/backends/rocm_attn.py index 5da82ad637e8..d533268e2176 100644 --- a/vllm/v1/attention/backends/rocm_attn.py +++ b/vllm/v1/attention/backends/rocm_attn.py @@ -29,6 +29,7 @@ ) from vllm.v1.attention.ops.chunked_prefill_paged_decode import ( chunked_prefill_paged_decode, + has_native_kv_cache_layout, ) from vllm.v1.attention.ops.paged_attn import PagedAttention from vllm.v1.attention.ops.triton_reshape_and_cache_flash import ( @@ -468,14 +469,9 @@ def do_kv_cache_update( # Get the actual block_size from value_cache # value_cache shape: [num_blocks, num_heads, head_size, block_size] block_size = value_cache.shape[3] - key_block_stride = key_cache.shape[1:].numel() - value_block_stride = value_cache.shape[1:].numel() - has_contiguous_blocks = ( - key_cache.stride(0) == key_block_stride - and value_cache.stride(0) == value_block_stride - ) + has_native_layout = has_native_kv_cache_layout(key_cache, value_cache) - if block_size in (16, 32) and has_contiguous_blocks: + if block_size in (16, 32) and has_native_layout: # Normal 16, 32 with contiguous blocks: use vLLM native HIP C++ logic. PagedAttention.write_to_paged_cache( key, diff --git a/vllm/v1/attention/ops/chunked_prefill_paged_decode.py b/vllm/v1/attention/ops/chunked_prefill_paged_decode.py index 1895de0940bf..77eb3ac60b1f 100644 --- a/vllm/v1/attention/ops/chunked_prefill_paged_decode.py +++ b/vllm/v1/attention/ops/chunked_prefill_paged_decode.py @@ -21,6 +21,22 @@ float8_info = torch.finfo(current_platform.fp8_dtype()) +def has_native_kv_cache_layout( + key_cache: torch.Tensor, + value_cache: torch.Tensor, +) -> bool: + """Return whether KV cache blocks can use the native ROCm pairing. + + The native reshape_and_cache writer assumes packed blocks. If cache update + needs reshape_and_cache_flash for a stride-padded hybrid layout, decode + should use the matching Triton path too. + """ + return ( + key_cache.stride(0) == key_cache.shape[1:].numel() + and value_cache.stride(0) == value_cache.shape[1:].numel() + ) + + @triton.jit def cdiv_fn(x, y): return (x + y - 1) // y @@ -346,14 +362,12 @@ def chunked_prefill_paged_decode( alibi_slopes, sinks, ) - # Triton is only forced when encountering a non-standard block - # like Qwen3 with a size of 544. - # 1. Check if block_size is a power of 2 (16, 32, 64...) - # 2. If it's a power of 2, we trust the vLLM's native use_custom decision. - # 3. If it's not a power of 2 (such as Qwen3's 544), - # then our Triton path is forced. + has_native_layout = has_native_kv_cache_layout(key_cache, value_cache) + # Force Triton for non-standard blocks like Qwen3's 544 and for + # stride-padded hybrid layouts. The latter use reshape_and_cache_flash + # during cache update, so keep decode on the matching stride-aware path. is_pow2 = block_size > 0 and (block_size & (block_size - 1) == 0) - if not is_pow2: + if not is_pow2 or not has_native_layout: use_custom = False if use_custom: