diff --git a/docs/design/attention_backends.md b/docs/design/attention_backends.md index 3244ce7cc501..3d0fcd6c7716 100644 --- a/docs/design/attention_backends.md +++ b/docs/design/attention_backends.md @@ -171,8 +171,8 @@ Priority is **1 = highest** (tried first). | `FLASH_ATTN_DIFFKV` | | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ✅ | Decoder | Any | | `FLEX_ATTENTION` | | fp16, bf16, fp32 | `auto`, `bfloat16` | Any | Any | ❌ | ✅ | ❌ | Decoder, Encoder Only | Any | | `ROCM_AITER_FA` | | fp16, bf16 | `auto` | 16, 32 | 64, 128, 256 | ❌ | ❌ | ❌ | Decoder | N/A | -| `ROCM_AITER_UNIFIED_ATTN` | | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ❌ | Decoder | N/A | -| `ROCM_ATTN` | | fp16, bf16, fp32 | `auto` | 16, 32, 544 | 32, 64, 96, 128, 160, 192, 224, 256 | ❌ | ❌ | ❌ | Decoder | N/A | +| `ROCM_AITER_UNIFIED_ATTN` | | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ❌ | All | N/A | +| `ROCM_ATTN` | | fp16, bf16, fp32 | `auto` | 16, 32, 544 | 32, 64, 96, 128, 160, 192, 224, 256 | ❌ | ❌ | ❌ | All | N/A | | `TREE_ATTN` | | fp16, bf16 | `auto` | %16 | 32, 64, 96, 128, 160, 192, 224, 256 | ❌ | ❌ | ❌ | Decoder | Any | | `TRITON_ATTN` | | fp16, bf16, fp32 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | Any | ✅ | ✅ | ❌ | All | Any | diff --git a/vllm/v1/attention/backends/rocm_aiter_unified_attn.py b/vllm/v1/attention/backends/rocm_aiter_unified_attn.py index db6fd97c9dd9..130ccaa2d6fd 100644 --- a/vllm/v1/attention/backends/rocm_aiter_unified_attn.py +++ b/vllm/v1/attention/backends/rocm_aiter_unified_attn.py @@ -55,6 +55,16 @@ def use_cascade_attention(*args, **kwargs) -> bool: def get_builder_cls() -> type["RocmAttentionMetadataBuilder"]: return RocmAttentionMetadataBuilder + @classmethod + def supports_attn_type(cls, attn_type: str) -> bool: + """RocmAiterUnifiedAttention supports all attention types.""" + return attn_type in ( + AttentionType.DECODER, + AttentionType.ENCODER, + AttentionType.ENCODER_ONLY, + AttentionType.ENCODER_DECODER, + ) + class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl): def fused_output_quant_supported(self, quant_key: QuantKey): @@ -143,6 +153,19 @@ def forward( num_actual_tokens = attn_metadata.num_actual_tokens + # Handle encoder attention differently - no KV cache needed + if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER): + # For encoder attention, + # we use direct Q, K, V tensors without caching + return self._forward_encoder_attention( + query[:num_actual_tokens], + key[:num_actual_tokens], + value[:num_actual_tokens], + output[:num_actual_tokens], + attn_metadata, + layer, + ) + key_cache, value_cache = kv_cache.unbind(0) if self.kv_cache_dtype.startswith("fp8"): @@ -195,6 +218,10 @@ def do_kv_cache_update( kv_cache: torch.Tensor, slot_mapping: torch.Tensor, ): + if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER): + # For encoder attention, + # we use direct Q, K, V tensors without caching + return key_cache, value_cache = kv_cache.unbind(0) # Reshape the input keys and values and store them in the cache. @@ -224,6 +251,10 @@ def do_rope_and_kv_cache_update( kv_cache: torch.Tensor, layer_slot_mapping: torch.Tensor, ): + if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER): + # For encoder attention, + # we use direct Q, K, V tensors without caching + return key_cache, value_cache = kv_cache.unbind(0) flash_layout = True diff --git a/vllm/v1/attention/backends/rocm_attn.py b/vllm/v1/attention/backends/rocm_attn.py index d72293dec250..d4bfa764febe 100644 --- a/vllm/v1/attention/backends/rocm_attn.py +++ b/vllm/v1/attention/backends/rocm_attn.py @@ -205,6 +205,16 @@ def get_name() -> str: def get_impl_cls() -> type["RocmAttentionImpl"]: return RocmAttentionImpl + @classmethod + def supports_attn_type(cls, attn_type: str) -> bool: + """RocmAttention supports all attention types.""" + return attn_type in ( + AttentionType.DECODER, + AttentionType.ENCODER, + AttentionType.ENCODER_ONLY, + AttentionType.ENCODER_DECODER, + ) + @staticmethod def get_kv_cache_shape( num_blocks: int, @@ -244,6 +254,7 @@ def __init__( kv_sharing_target_layer_name: int | None = None, sinks: torch.Tensor | None = None, ) -> None: + self.attn_type = attn_type self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) @@ -266,11 +277,6 @@ def __init__( RocmAttentionBackend.validate_head_size(head_size) - if attn_type not in [AttentionType.DECODER, AttentionType.ENCODER_DECODER]: - raise NotImplementedError( - "Encoder self-attention is not implemented for RocmAttentionImpl" - ) - self.fp8_dtype = current_platform.fp8_dtype() self.sinks = sinks @@ -281,6 +287,54 @@ def __init__( f"num_heads: {num_heads}." ) + def _forward_encoder_attention( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + output: torch.Tensor, + attn_metadata: FlashAttentionMetadata, + layer: torch.nn.Module, + ) -> torch.Tensor: + """Forward pass for encoder attention without KV cache. + + Args: + query: shape = [num_encoder_tokens, num_heads, head_size] + key: shape = [num_encoder_tokens, num_kv_heads, head_size] + value: shape = [num_encoder_tokens, num_kv_heads, head_size] + output: shape = [num_encoder_tokens, num_heads, head_size] + attn_metadata: Encoder attention metadata + layer: The attention layer + """ + # For encoder attention, process FP8 quantization if needed + if self.kv_cache_dtype.startswith("fp8"): + raise NotImplementedError( + "quantization is not supported for encoder attention" + ) + + # Use encoder-specific metadata for sequence information + query_start_loc = attn_metadata.query_start_loc + seq_lens = attn_metadata.seq_lens + max_query_len = attn_metadata.max_query_len + + # Call flash attention directly on Q, K, V tensors + from vllm.v1.attention.ops.triton_prefill_attention import context_attention_fwd + + context_attention_fwd( + q=query, + k=key, + v=value, + o=output, + b_start_loc=query_start_loc, + b_seq_len=seq_lens, + max_input_len=max_query_len, + is_causal=False, + softmax_scale=self.scale, + sliding_window_q=self.sliding_window[0], + sliding_window_k=self.sliding_window[1], + ) + return output + def forward( self, layer: torch.nn.Module, @@ -330,6 +384,16 @@ def forward( num_actual_tokens = attn_metadata.num_actual_tokens + if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER): + return self._forward_encoder_attention( + query[:num_actual_tokens], + key[:num_actual_tokens], + value[:num_actual_tokens], + output[:num_actual_tokens], + attn_metadata, + layer, + ) + key_cache, value_cache = PagedAttention.split_kv_cache( kv_cache, self.num_kv_heads, self.head_size ) @@ -380,6 +444,8 @@ def do_kv_cache_update( kv_cache: torch.Tensor, slot_mapping: torch.Tensor, ): + if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER): + return key_cache, value_cache = PagedAttention.split_kv_cache( kv_cache, self.num_kv_heads, self.head_size ) @@ -432,6 +498,8 @@ def do_rope_and_kv_cache_update( kv_cache: torch.Tensor, layer_slot_mapping: torch.Tensor, ): + if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER): + return key_cache, value_cache = PagedAttention.split_kv_cache( kv_cache, layer.num_kv_heads, # type: ignore[attr-defined]