Merged
2 changes: 1 addition & 1 deletion docs/design/attention_backends.md
@@ -173,7 +173,7 @@ Priority is **1 = highest** (tried first).
| `FLEX_ATTENTION` | | fp16, bf16, fp32 | `auto`, `bfloat16` | Any | Any | ❌ | ✅ | ❌ | Decoder, Encoder Only | Any |
| `ROCM_AITER_FA` | | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32 | 64, 128, 256 | ❌ | ❌ | ❌ | Decoder, Enc-Dec | N/A |
| `ROCM_AITER_UNIFIED_ATTN` | | fp16, bf16 | `auto` | %16 | Any | ✅ | ✅ | ❌ | All | N/A |
| `ROCM_ATTN` | | fp16, bf16, fp32 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32, 544 | 32, 64, 80, 96, 128, 160, 192, 224, 256 | ✅ | ✅ | ❌ | All | N/A |
| `ROCM_ATTN` | | fp16, bf16, fp32 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | 32, 64, 80, 96, 128, 160, 192, 224, 256 | ✅ | ✅ | ❌ | All | N/A |
| `TREE_ATTN` | | fp16, bf16 | `auto` | %16 | 32, 64, 96, 128, 160, 192, 224, 256 | ❌ | ❌ | ❌ | Decoder | Any |
| `TRITON_ATTN` | | fp16, bf16, fp32 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | Any | ✅ | ✅ | ❌ | All | Any |

32 changes: 10 additions & 22 deletions vllm/v1/attention/backends/rocm_attn.py
@@ -174,25 +174,15 @@ class RocmAttentionBackend(AttentionBackend):

@staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
# ROCM paged attention kernel only supports block sizes 16 and 32
# ROCM paged attention native C++ kernel only supports block sizes 16 and 32
# due to shared memory (LDS) constraints on AMD GPUs.
# See csrc/rocm/attention.cu CALL_CUSTOM_LAUNCHER_BLK macro.

# However, the [16, 32] limitation is reasonable for a native C++ kernel,
# but vLLM should support non-standard sizes via the Triton path,
# as addressed in PR https://github.com/vllm-project/vllm/pull/31380:
# the Triton kernel under rocm_attn did not support inference
# for the non-standard qwen3-next model with a block_size of 544.
# We have fixed the Triton kernel so that standard models keep the original
# bit-addressing logic, while non-standard models
# use our optimized kernel logic.
return [16, 32, 544]

@classmethod
def supports_block_size(cls, block_size: int | None) -> bool:
if block_size is None:
return True
return block_size in (16, 32, 544)
# However, vLLM allows support for any multiple of 16 via the Triton path.
# As addressed in PR: https://github.com/vllm-project/vllm/pull/31380,
# non-standard models (like qwen3-next with block_size 544, or qwen3_5
# with 784 and 1056) are dynamically routed to our optimized Triton kernel
# in `do_kv_cache_update`.
return [MultipleOf(16)]
Comment on lines 176 to +185
Contributor

critical

While this change to allow any block size that is a multiple of 16 is correct for supporting models like Qwen3.5, it introduces a potential failure for other models.

The dispatch logic in do_kv_cache_update (lines 450-480) uses is_pow2 to decide whether to use the native C++ kernel or the Triton fallback. The native C++ kernel, as noted in the comments and confirmed in csrc/rocm/attention.cu, only supports block sizes of 16 and 32.

With this PR, a model using a block size that is a power of two but not 16 or 32 (e.g., 64) will be incorrectly routed to the native C++ kernel, which will then raise an error.

To fix this, the condition in do_kv_cache_update should be changed from if is_pow2: to if block_size in (16, 32):. This will ensure that only the explicitly supported block sizes are routed to the native kernel, and all others (including other powers of two) use the Triton fallback.
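The reviewer's point can be sketched as a standalone dispatch function. This is an illustrative reduction, not vLLM's actual API: `select_kv_cache_writer` is a hypothetical name, and the returned strings merely label which real code path (`PagedAttention.write_to_paged_cache` vs. `triton_reshape_and_cache_flash`) would be taken.

```python
def is_pow2(n: int) -> bool:
    # The pre-fix dispatch condition: true for 64, which would
    # misroute a 64-sized block to the native HIP C++ kernel.
    return n > 0 and (n & (n - 1)) == 0


def select_kv_cache_writer(block_size: int) -> str:
    # The native HIP C++ kernel only supports block sizes 16 and 32
    # (LDS constraints, see csrc/rocm/attention.cu), so the check must
    # be an explicit membership test, not a power-of-two test.
    if block_size in (16, 32):
        return "native_hip"  # PagedAttention.write_to_paged_cache
    return "triton_fallback"  # triton_reshape_and_cache_flash
```

With the membership test, 64 (a power of two) and 544 (not one) both fall through to the Triton fallback, which is the behavior the review asks for.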


@classmethod
def get_supported_head_sizes(cls) -> list[int]:
@@ -463,11 +453,9 @@ def do_kv_cache_update(
# Get the actual block_size from value_cache
# value_cache shape: [num_blocks, num_heads, head_size, block_size]
block_size = value_cache.shape[3]
# Determine if it is a power of 2
is_pow2 = block_size > 0 and (block_size & (block_size - 1) == 0)

if is_pow2:
# Normal 16, 32, 64, etc., use vLLM native HIP C++ logic
if block_size in (16, 32):
# Normal 16, 32, use vLLM native HIP C++ logic
PagedAttention.write_to_paged_cache(
key,
value,
@@ -479,7 +467,7 @@ def do_kv_cache_update(
layer._v_scale,
)
else:
# Case B: Non-standard blocks (e.g., 544 in Qwen3),
# Case B: Non-standard blocks (e.g., 64, 128, 544 in Qwen3Next or Qwen3.5),
# force using our modified Triton logic
triton_reshape_and_cache_flash(
key,