-
-
Notifications
You must be signed in to change notification settings - Fork 15.6k
[ROCm] Enabling encoder and encoder-decoder on ROCm and AITER unified backends #35334
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -55,6 +55,16 @@ def use_cascade_attention(*args, **kwargs) -> bool: | |
| def get_builder_cls() -> type["RocmAttentionMetadataBuilder"]: | ||
| return RocmAttentionMetadataBuilder | ||
|
|
||
| @classmethod | ||
| def supports_attn_type(cls, attn_type: str) -> bool: | ||
| """RocmAiterUnifiedAttention supports all attention types.""" | ||
| return attn_type in ( | ||
| AttentionType.DECODER, | ||
| AttentionType.ENCODER, | ||
| AttentionType.ENCODER_ONLY, | ||
| AttentionType.ENCODER_DECODER, | ||
| ) | ||
|
|
||
|
|
||
| class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl): | ||
| def fused_output_quant_supported(self, quant_key: QuantKey): | ||
|
|
@@ -143,6 +153,19 @@ def forward( | |
|
|
||
| num_actual_tokens = attn_metadata.num_actual_tokens | ||
|
|
||
| # Handle encoder attention differently - no KV cache needed | ||
| if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER): | ||
| # For encoder attention, | ||
| # we use direct Q, K, V tensors without caching | ||
| return self._forward_encoder_attention( | ||
| query[:num_actual_tokens], | ||
| key[:num_actual_tokens], | ||
| value[:num_actual_tokens], | ||
| output[:num_actual_tokens], | ||
| attn_metadata, | ||
| layer, | ||
| ) | ||
|
Comment on lines
+157
to
+167
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
# For encoder attention,
# we use direct Q, K, V tensors without caching
return self._forward_encoder_attention(
query[:num_actual_tokens],
key[:num_actual_tokens],
value[:num_actual_tokens],
output[:num_actual_tokens],
attn_metadata,
) |
||
|
|
||
| key_cache, value_cache = kv_cache.unbind(0) | ||
|
|
||
| if self.kv_cache_dtype.startswith("fp8"): | ||
|
|
@@ -195,6 +218,10 @@ def do_kv_cache_update( | |
| kv_cache: torch.Tensor, | ||
| slot_mapping: torch.Tensor, | ||
| ): | ||
| if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER): | ||
| # For encoder attention, | ||
| # we use direct Q, K, V tensors without caching | ||
| return | ||
| key_cache, value_cache = kv_cache.unbind(0) | ||
|
|
||
| # Reshape the input keys and values and store them in the cache. | ||
|
|
@@ -224,6 +251,10 @@ def do_rope_and_kv_cache_update( | |
| kv_cache: torch.Tensor, | ||
| layer_slot_mapping: torch.Tensor, | ||
| ): | ||
| if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER): | ||
| # For encoder attention, | ||
| # we use direct Q, K, V tensors without caching | ||
| return | ||
| key_cache, value_cache = kv_cache.unbind(0) | ||
| flash_layout = True | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -205,6 +205,16 @@ def get_name() -> str: | |
| def get_impl_cls() -> type["RocmAttentionImpl"]: | ||
| return RocmAttentionImpl | ||
|
|
||
| @classmethod | ||
| def supports_attn_type(cls, attn_type: str) -> bool: | ||
| """RocmAttention supports all attention types.""" | ||
| return attn_type in ( | ||
| AttentionType.DECODER, | ||
| AttentionType.ENCODER, | ||
| AttentionType.ENCODER_ONLY, | ||
| AttentionType.ENCODER_DECODER, | ||
| ) | ||
|
|
||
| @staticmethod | ||
| def get_kv_cache_shape( | ||
| num_blocks: int, | ||
|
|
@@ -244,6 +254,7 @@ def __init__( | |
| kv_sharing_target_layer_name: int | None = None, | ||
| sinks: torch.Tensor | None = None, | ||
| ) -> None: | ||
| self.attn_type = attn_type | ||
| self.num_heads = num_heads | ||
| self.head_size = head_size | ||
| self.scale = float(scale) | ||
|
|
@@ -266,11 +277,6 @@ def __init__( | |
|
|
||
| RocmAttentionBackend.validate_head_size(head_size) | ||
|
|
||
| if attn_type not in [AttentionType.DECODER, AttentionType.ENCODER_DECODER]: | ||
| raise NotImplementedError( | ||
| "Encoder self-attention is not implemented for RocmAttentionImpl" | ||
| ) | ||
|
|
||
| self.fp8_dtype = current_platform.fp8_dtype() | ||
|
|
||
| self.sinks = sinks | ||
|
|
@@ -281,6 +287,54 @@ def __init__( | |
| f"num_heads: {num_heads}." | ||
| ) | ||
|
|
||
| def _forward_encoder_attention( | ||
| self, | ||
| query: torch.Tensor, | ||
| key: torch.Tensor, | ||
| value: torch.Tensor, | ||
| output: torch.Tensor, | ||
| attn_metadata: FlashAttentionMetadata, | ||
| layer: torch.nn.Module, | ||
| ) -> torch.Tensor: | ||
|
Comment on lines
+290
to
+298
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The def _forward_encoder_attention(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
output: torch.Tensor,
attn_metadata: FlashAttentionMetadata,
) -> torch.Tensor: |
||
| """Forward pass for encoder attention without KV cache. | ||
|
|
||
| Args: | ||
| query: shape = [num_encoder_tokens, num_heads, head_size] | ||
| key: shape = [num_encoder_tokens, num_kv_heads, head_size] | ||
| value: shape = [num_encoder_tokens, num_kv_heads, head_size] | ||
| output: shape = [num_encoder_tokens, num_heads, head_size] | ||
| attn_metadata: Encoder attention metadata | ||
| layer: The attention layer | ||
| """ | ||
| # For encoder attention, process FP8 quantization if needed | ||
| if self.kv_cache_dtype.startswith("fp8"): | ||
| raise NotImplementedError( | ||
| "quantization is not supported for encoder attention" | ||
| ) | ||
|
|
||
| # Use encoder-specific metadata for sequence information | ||
| query_start_loc = attn_metadata.query_start_loc | ||
| seq_lens = attn_metadata.seq_lens | ||
| max_query_len = attn_metadata.max_query_len | ||
|
|
||
| # Call flash attention directly on Q, K, V tensors | ||
| from vllm.v1.attention.ops.triton_prefill_attention import context_attention_fwd | ||
|
|
||
| context_attention_fwd( | ||
| q=query, | ||
| k=key, | ||
| v=value, | ||
| o=output, | ||
| b_start_loc=query_start_loc, | ||
| b_seq_len=seq_lens, | ||
| max_input_len=max_query_len, | ||
| is_causal=False, | ||
| softmax_scale=self.scale, | ||
| sliding_window_q=self.sliding_window[0], | ||
| sliding_window_k=self.sliding_window[1], | ||
| ) | ||
| return output | ||
|
|
||
| def forward( | ||
| self, | ||
| layer: torch.nn.Module, | ||
|
|
@@ -330,6 +384,16 @@ def forward( | |
|
|
||
| num_actual_tokens = attn_metadata.num_actual_tokens | ||
|
|
||
| if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER): | ||
| return self._forward_encoder_attention( | ||
| query[:num_actual_tokens], | ||
| key[:num_actual_tokens], | ||
| value[:num_actual_tokens], | ||
| output[:num_actual_tokens], | ||
| attn_metadata, | ||
| layer, | ||
| ) | ||
|
Comment on lines
+387
to
+395
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
return self._forward_encoder_attention(
query[:num_actual_tokens],
key[:num_actual_tokens],
value[:num_actual_tokens],
output[:num_actual_tokens],
attn_metadata,
) |
||
|
|
||
| key_cache, value_cache = PagedAttention.split_kv_cache( | ||
| kv_cache, self.num_kv_heads, self.head_size | ||
| ) | ||
|
|
@@ -380,6 +444,8 @@ def do_kv_cache_update( | |
| kv_cache: torch.Tensor, | ||
| slot_mapping: torch.Tensor, | ||
| ): | ||
| if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER): | ||
| return | ||
| key_cache, value_cache = PagedAttention.split_kv_cache( | ||
| kv_cache, self.num_kv_heads, self.head_size | ||
| ) | ||
|
|
@@ -432,6 +498,8 @@ def do_rope_and_kv_cache_update( | |
| kv_cache: torch.Tensor, | ||
| layer_slot_mapping: torch.Tensor, | ||
| ): | ||
| if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER): | ||
| return | ||
| key_cache, value_cache = PagedAttention.split_kv_cache( | ||
| kv_cache, | ||
| layer.num_kv_heads, # type: ignore[attr-defined] | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The `RocmAiterUnifiedAttentionBackend` class inherits from `RocmAttentionBackend`, which already defines an identical `supports_attn_type` method. This reimplementation is redundant and can be removed to rely on the parent's implementation. This improves maintainability by avoiding code duplication and ensuring consistency.