diff --git a/docs/design/attention_backends.md b/docs/design/attention_backends.md index b343f9277761..81533c29de2f 100644 --- a/docs/design/attention_backends.md +++ b/docs/design/attention_backends.md @@ -173,7 +173,7 @@ Priority is **1 = highest** (tried first). | `FLEX_ATTENTION` | | fp16, bf16, fp32 | `auto`, `bfloat16` | Any | Any | ❌ | ✅ | ❌ | Decoder, Encoder Only | Any | | `ROCM_AITER_FA` | | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32 | 64, 128, 256 | ❌ | ❌ | ❌ | Decoder, Enc-Dec | N/A | | `ROCM_AITER_UNIFIED_ATTN` | | fp16, bf16 | `auto` | %16 | Any | ✅ | ✅ | ❌ | All | N/A | -| `ROCM_ATTN` | | fp16, bf16, fp32 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32, 544 | 32, 64, 80, 96, 128, 160, 192, 224, 256 | ✅ | ✅ | ❌ | All | N/A | +| `ROCM_ATTN` | | fp16, bf16, fp32 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | 32, 64, 80, 96, 128, 160, 192, 224, 256 | ✅ | ✅ | ❌ | All | N/A | | `TREE_ATTN` | | fp16, bf16 | `auto` | %16 | 32, 64, 96, 128, 160, 192, 224, 256 | ❌ | ❌ | ❌ | Decoder | Any | | `TRITON_ATTN` | | fp16, bf16, fp32 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | Any | ✅ | ✅ | ❌ | All | Any | diff --git a/vllm/v1/attention/backends/rocm_attn.py b/vllm/v1/attention/backends/rocm_attn.py index 96c4033d8adc..1d0dc81dc2c5 100644 --- a/vllm/v1/attention/backends/rocm_attn.py +++ b/vllm/v1/attention/backends/rocm_attn.py @@ -174,25 +174,15 @@ class RocmAttentionBackend(AttentionBackend): @staticmethod def get_supported_kernel_block_sizes() -> list[int | MultipleOf]: - # ROCM paged attention kernel only supports block sizes 16 and 32 + # ROCM paged attention native C++ kernel only supports block sizes 16 and 32 # due to shared memory (LDS) constraints on AMD GPUs. # See csrc/rocm/attention.cu CALL_CUSTOM_LAUNCHER_BLK macro. 
- - # However, The limitations in [16, 32] are reasonable for a native C++ kernel, - # but vLLM should allow support for non-standard sizes via the Triton path, - # as addressed in this PR: https://github.com/vllm-project/vllm/pull/31380, - # where the Triton kernel under rocm_atten does not support inference - # for a non-standard qwen3-next model with a block_size of 544. - # We have fixed the Triton kernel so that the standard model uses the original - # bit-addressing logic, while the non-standard model - # uses our optimized kernel logic. - return [16, 32, 544] - - @classmethod - def supports_block_size(cls, block_size: int | None) -> bool: - if block_size is None: - return True - return block_size in (16, 32, 544) + # However, vLLM allows support for any multiple of 16 via the Triton path. + # As addressed in PR: https://github.com/vllm-project/vllm/pull/31380, + # non-standard models (like qwen3-next with block_size 544, or qwen3.5 + # with 784 and 1056) are dynamically routed to our optimized Triton kernel + # in `do_kv_cache_update`. + return [MultipleOf(16)] @classmethod def get_supported_head_sizes(cls) -> list[int]: @@ -463,11 +453,9 @@ def do_kv_cache_update( # Get the actual block_size from value_cache # value_cache shape: [num_blocks, num_heads, head_size, block_size] block_size = value_cache.shape[3] - # Determine if it is a power of 2 - is_pow2 = block_size > 0 and (block_size & (block_size - 1) == 0) - if is_pow2: - # Normal 16, 32, 64, etc., use vLLM native HIP C++ logic + if block_size in (16, 32): + # Normal 16, 32, use vLLM native HIP C++ logic PagedAttention.write_to_paged_cache( key, value, @@ -479,7 +467,7 @@ layer._v_scale, ) else: - # Case B: Non-standard blocks (e.g., 544 in Qwen3), + # Case B: Non-standard blocks (e.g., 64, 128, 544 in Qwen3Next or Qwen3.5), # force using our modified Triton logic triton_reshape_and_cache_flash( key,