Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/design/attention_backends.md
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,8 @@ Priority is **1 = highest** (tried first).
| `FLASH_ATTN_DIFFKV` | | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ✅ | Decoder | Any |
| `FLEX_ATTENTION` | | fp16, bf16, fp32 | `auto`, `bfloat16` | Any | Any | ❌ | ✅ | ❌ | Decoder, Encoder Only | Any |
| `ROCM_AITER_FA` | | fp16, bf16 | `auto` | 16, 32 | 64, 128, 256 | ❌ | ❌ | ❌ | Decoder | N/A |
| `ROCM_AITER_UNIFIED_ATTN` | | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ❌ | Decoder | N/A |
| `ROCM_ATTN` | | fp16, bf16, fp32 | `auto` | 16, 32, 544 | 32, 64, 96, 128, 160, 192, 224, 256 | ❌ | ❌ | ❌ | Decoder | N/A |
| `ROCM_AITER_UNIFIED_ATTN` | | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ❌ | All | N/A |
| `ROCM_ATTN` | | fp16, bf16, fp32 | `auto` | 16, 32, 544 | 32, 64, 96, 128, 160, 192, 224, 256 | ❌ | ❌ | ❌ | All | N/A |
| `TREE_ATTN` | | fp16, bf16 | `auto` | %16 | 32, 64, 96, 128, 160, 192, 224, 256 | ❌ | ❌ | ❌ | Decoder | Any |
| `TRITON_ATTN` | | fp16, bf16, fp32 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | Any | ✅ | ✅ | ❌ | All | Any |

Expand Down
31 changes: 31 additions & 0 deletions vllm/v1/attention/backends/rocm_aiter_unified_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,16 @@ def use_cascade_attention(*args, **kwargs) -> bool:
def get_builder_cls() -> type["RocmAttentionMetadataBuilder"]:
    # This backend shares its metadata format with the ROCm attention
    # backend, so it reuses that backend's metadata builder directly.
    return RocmAttentionMetadataBuilder

@classmethod
def supports_attn_type(cls, attn_type: str) -> bool:
    """Report whether *attn_type* is handled by this backend.

    Every attention type is accepted: decoder, encoder, encoder-only,
    and encoder-decoder.

    NOTE(review): this is identical to the parent
    RocmAttentionBackend implementation and could be removed in favor
    of plain inheritance.
    """
    supported_types = {
        AttentionType.DECODER,
        AttentionType.ENCODER,
        AttentionType.ENCODER_ONLY,
        AttentionType.ENCODER_DECODER,
    }
    return attn_type in supported_types
Comment on lines +58 to +66
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The RocmAiterUnifiedAttentionBackend class inherits from RocmAttentionBackend, which already defines an identical supports_attn_type method. This reimplementation is redundant and can be removed to rely on the parent's implementation. This improves maintainability by avoiding code duplication and ensuring consistency.



class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
def fused_output_quant_supported(self, quant_key: QuantKey):
Expand Down Expand Up @@ -143,6 +153,19 @@ def forward(

num_actual_tokens = attn_metadata.num_actual_tokens

# Handle encoder attention differently - no KV cache needed
if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
# For encoder attention,
# we use direct Q, K, V tensors without caching
return self._forward_encoder_attention(
query[:num_actual_tokens],
key[:num_actual_tokens],
value[:num_actual_tokens],
output[:num_actual_tokens],
attn_metadata,
layer,
)
Comment on lines +157 to +167
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The layer parameter is unused in _forward_encoder_attention. This call should be updated to not pass the argument.

        if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
            # For encoder attention,
            # we use direct Q, K, V tensors without caching
            return self._forward_encoder_attention(
                query[:num_actual_tokens],
                key[:num_actual_tokens],
                value[:num_actual_tokens],
                output[:num_actual_tokens],
                attn_metadata,
            )


key_cache, value_cache = kv_cache.unbind(0)

if self.kv_cache_dtype.startswith("fp8"):
Expand Down Expand Up @@ -195,6 +218,10 @@ def do_kv_cache_update(
kv_cache: torch.Tensor,
slot_mapping: torch.Tensor,
):
if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
# For encoder attention,
# we use direct Q, K, V tensors without caching
return
key_cache, value_cache = kv_cache.unbind(0)

# Reshape the input keys and values and store them in the cache.
Expand Down Expand Up @@ -224,6 +251,10 @@ def do_rope_and_kv_cache_update(
kv_cache: torch.Tensor,
layer_slot_mapping: torch.Tensor,
):
if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
# For encoder attention,
# we use direct Q, K, V tensors without caching
return
key_cache, value_cache = kv_cache.unbind(0)
flash_layout = True

Expand Down
78 changes: 73 additions & 5 deletions vllm/v1/attention/backends/rocm_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,16 @@ def get_name() -> str:
def get_impl_cls() -> type["RocmAttentionImpl"]:
    # Implementation class used by the attention layer for this backend.
    return RocmAttentionImpl

@classmethod
def supports_attn_type(cls, attn_type: str) -> bool:
    """Return True for every attention type.

    Decoder, encoder, encoder-only, and encoder-decoder attention are
    all handled by RocmAttention.
    """
    for supported in (
        AttentionType.DECODER,
        AttentionType.ENCODER,
        AttentionType.ENCODER_ONLY,
        AttentionType.ENCODER_DECODER,
    ):
        if attn_type == supported:
            return True
    return False

@staticmethod
def get_kv_cache_shape(
num_blocks: int,
Expand Down Expand Up @@ -244,6 +254,7 @@ def __init__(
kv_sharing_target_layer_name: int | None = None,
sinks: torch.Tensor | None = None,
) -> None:
self.attn_type = attn_type
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
Expand All @@ -266,11 +277,6 @@ def __init__(

RocmAttentionBackend.validate_head_size(head_size)

if attn_type not in [AttentionType.DECODER, AttentionType.ENCODER_DECODER]:
raise NotImplementedError(
"Encoder self-attention is not implemented for RocmAttentionImpl"
)

self.fp8_dtype = current_platform.fp8_dtype()

self.sinks = sinks
Expand All @@ -281,6 +287,54 @@ def __init__(
f"num_heads: {num_heads}."
)

def _forward_encoder_attention(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    attn_metadata: FlashAttentionMetadata,
    layer: torch.nn.Module,
) -> torch.Tensor:
    """Forward pass for encoder attention without a KV cache.

    Runs the Triton prefill (context) attention kernel directly on the
    packed Q/K/V tensors; nothing is read from or written to a paged
    KV cache.

    Args:
        query: shape = [num_encoder_tokens, num_heads, head_size]
        key: shape = [num_encoder_tokens, num_kv_heads, head_size]
        value: shape = [num_encoder_tokens, num_kv_heads, head_size]
        output: shape = [num_encoder_tokens, num_heads, head_size];
            written in place and also returned.
        attn_metadata: Encoder attention metadata providing
            query_start_loc, seq_lens and max_query_len.
        layer: The attention layer. Currently unused; kept so the
            signature stays compatible with existing call sites.

    Raises:
        NotImplementedError: if an fp8 KV-cache dtype is configured;
            quantized encoder attention is not implemented.
    """
    # There is no KV cache here, so fp8 KV-cache quantization cannot
    # be applied; reject it explicitly instead of silently ignoring it.
    if self.kv_cache_dtype.startswith("fp8"):
        raise NotImplementedError(
            "quantization is not supported for encoder attention"
        )

    # Sequence bookkeeping comes straight from the encoder metadata.
    query_start_loc = attn_metadata.query_start_loc
    seq_lens = attn_metadata.seq_lens
    max_query_len = attn_metadata.max_query_len

    # Local import keeps the Triton kernel off the module import path.
    from vllm.v1.attention.ops.triton_prefill_attention import context_attention_fwd

    # Encoder self-attention is bidirectional, hence is_causal=False.
    context_attention_fwd(
        q=query,
        k=key,
        v=value,
        o=output,
        b_start_loc=query_start_loc,
        b_seq_len=seq_lens,
        max_input_len=max_query_len,
        is_causal=False,
        softmax_scale=self.scale,
        sliding_window_q=self.sliding_window[0],
        sliding_window_k=self.sliding_window[1],
    )
    return output

def forward(
self,
layer: torch.nn.Module,
Expand Down Expand Up @@ -330,6 +384,16 @@ def forward(

num_actual_tokens = attn_metadata.num_actual_tokens

if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
return self._forward_encoder_attention(
query[:num_actual_tokens],
key[:num_actual_tokens],
value[:num_actual_tokens],
output[:num_actual_tokens],
attn_metadata,
layer,
)
Comment on lines +387 to +395
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The layer parameter is unused in _forward_encoder_attention. This call should be updated to not pass the argument.

        if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
            return self._forward_encoder_attention(
                query[:num_actual_tokens],
                key[:num_actual_tokens],
                value[:num_actual_tokens],
                output[:num_actual_tokens],
                attn_metadata,
            )


key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size
)
Expand Down Expand Up @@ -380,6 +444,8 @@ def do_kv_cache_update(
kv_cache: torch.Tensor,
slot_mapping: torch.Tensor,
):
if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
return
key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size
)
Expand Down Expand Up @@ -432,6 +498,8 @@ def do_rope_and_kv_cache_update(
kv_cache: torch.Tensor,
layer_slot_mapping: torch.Tensor,
):
if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
return
key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache,
layer.num_kv_heads, # type: ignore[attr-defined]
Expand Down