-
-
Notifications
You must be signed in to change notification settings - Fork 14.7k
[Hardware][AMD] Add fused QK RoPE and reshape & cache flash support for ROCm #28850
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -45,6 +45,13 @@ | |||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| if current_platform.is_rocm(): | ||||||||||||||||||||||||||||||||||||||||||||||
| from vllm.platforms.rocm import on_gfx9 | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| if envs.VLLM_ROCM_USE_AITER: | ||||||||||||||||||||||||||||||||||||||||||||||
| VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = ( | ||||||||||||||||||||||||||||||||||||||||||||||
| envs.VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE | ||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||
| VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = False | ||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||
| on_gfx9 = lambda *args, **kwargs: False | ||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
46
to
56
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The variable
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -235,6 +242,7 @@ | |||||||||||||||||||||||||||||||||||||||||||||
| attn_type: str = AttentionType.DECODER, | ||||||||||||||||||||||||||||||||||||||||||||||
| kv_sharing_target_layer_name: str | None = None, | ||||||||||||||||||||||||||||||||||||||||||||||
| attn_backend: type[AttentionBackend] | None = None, | ||||||||||||||||||||||||||||||||||||||||||||||
| rotary_emb: nn.Module | None = None, | ||||||||||||||||||||||||||||||||||||||||||||||
| **extra_impl_args, | ||||||||||||||||||||||||||||||||||||||||||||||
| ) -> None: | ||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -310,6 +318,7 @@ | |||||||||||||||||||||||||||||||||||||||||||||
| kv_sharing_target_layer_name, | ||||||||||||||||||||||||||||||||||||||||||||||
| **extra_impl_args, | ||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||
| self.impl.rotary_emb = rotary_emb | ||||||||||||||||||||||||||||||||||||||||||||||
|
Check failure on line 321 in vllm/attention/layer.py
|
||||||||||||||||||||||||||||||||||||||||||||||
| self.backend = AttentionBackendEnum[self.attn_backend.get_name()] | ||||||||||||||||||||||||||||||||||||||||||||||
| self.dtype = dtype | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -365,6 +374,7 @@ | |||||||||||||||||||||||||||||||||||||||||||||
| # shape does not match the query shape, so we optionally let the model | ||||||||||||||||||||||||||||||||||||||||||||||
| # definition specify the output tensor shape. | ||||||||||||||||||||||||||||||||||||||||||||||
| output_shape: torch.Size | None = None, | ||||||||||||||||||||||||||||||||||||||||||||||
| positions: torch.Tensor = None, | ||||||||||||||||||||||||||||||||||||||||||||||
| ) -> torch.Tensor: | ||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||
| The KV cache is stored inside this class and is accessed via | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -377,7 +387,6 @@ | |||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||
| if self.calculate_kv_scales: | ||||||||||||||||||||||||||||||||||||||||||||||
| torch.ops.vllm.maybe_calc_kv_scales(query, key, value, self.layer_name) | ||||||||||||||||||||||||||||||||||||||||||||||
| output_dtype = query.dtype | ||||||||||||||||||||||||||||||||||||||||||||||
| if self.query_quant is not None: | ||||||||||||||||||||||||||||||||||||||||||||||
| # quantizing with a simple torch operation enables | ||||||||||||||||||||||||||||||||||||||||||||||
| # torch.compile to fuse this into previous ops | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -392,7 +401,15 @@ | |||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| if self.use_output: | ||||||||||||||||||||||||||||||||||||||||||||||
| output_shape = output_shape if output_shape is not None else query.shape | ||||||||||||||||||||||||||||||||||||||||||||||
| output = torch.empty(output_shape, dtype=output_dtype, device=query.device) | ||||||||||||||||||||||||||||||||||||||||||||||
| if positions is not None: | ||||||||||||||||||||||||||||||||||||||||||||||
| output = torch.empty( | ||||||||||||||||||||||||||||||||||||||||||||||
| output_shape, dtype=query.dtype, device=query.device | ||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||
| output = torch.zeros( | ||||||||||||||||||||||||||||||||||||||||||||||
| output_shape, dtype=query.dtype, device=query.device | ||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
402
to
+410
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
In the attention forward path, the output tensor is now created with Useful? React with 👍 / 👎. |
||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| hidden_size = output_shape[-1] | ||||||||||||||||||||||||||||||||||||||||||||||
| # Reshape the query, key, and value tensors. | ||||||||||||||||||||||||||||||||||||||||||||||
| # NOTE(woosuk): We do this outside the custom op to minimize the | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -414,7 +431,13 @@ | |||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||
| torch.ops.vllm.unified_attention_with_output( | ||||||||||||||||||||||||||||||||||||||||||||||
| query, key, value, output, self.layer_name | ||||||||||||||||||||||||||||||||||||||||||||||
| query, | ||||||||||||||||||||||||||||||||||||||||||||||
| key, | ||||||||||||||||||||||||||||||||||||||||||||||
| value, | ||||||||||||||||||||||||||||||||||||||||||||||
| output, | ||||||||||||||||||||||||||||||||||||||||||||||
| self.layer_name, | ||||||||||||||||||||||||||||||||||||||||||||||
| None, | ||||||||||||||||||||||||||||||||||||||||||||||
| positions=positions, | ||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||
| return output.view(-1, hidden_size) | ||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -941,19 +964,44 @@ | |||||||||||||||||||||||||||||||||||||||||||||
| layer_name: str, | ||||||||||||||||||||||||||||||||||||||||||||||
| output_scale: torch.Tensor | None = None, | ||||||||||||||||||||||||||||||||||||||||||||||
| output_block_scale: torch.Tensor | None = None, | ||||||||||||||||||||||||||||||||||||||||||||||
| positions: torch.Tensor | None = None, | ||||||||||||||||||||||||||||||||||||||||||||||
| ) -> None: | ||||||||||||||||||||||||||||||||||||||||||||||
| attn_metadata, self, kv_cache = get_attention_context(layer_name) | ||||||||||||||||||||||||||||||||||||||||||||||
| self.impl.forward( | ||||||||||||||||||||||||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||||||||||||||||||||||||
| query, | ||||||||||||||||||||||||||||||||||||||||||||||
| key, | ||||||||||||||||||||||||||||||||||||||||||||||
| value, | ||||||||||||||||||||||||||||||||||||||||||||||
| kv_cache, | ||||||||||||||||||||||||||||||||||||||||||||||
| attn_metadata, | ||||||||||||||||||||||||||||||||||||||||||||||
| output=output, | ||||||||||||||||||||||||||||||||||||||||||||||
| output_scale=output_scale, | ||||||||||||||||||||||||||||||||||||||||||||||
| output_block_scale=output_block_scale, | ||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||
| from vllm.v1.attention.backends.rocm_aiter_fa import AiterFlashAttentionImpl | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| if VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE and isinstance( | ||||||||||||||||||||||||||||||||||||||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. AITER flags management are done in the |
||||||||||||||||||||||||||||||||||||||||||||||
| self.impl, AiterFlashAttentionImpl | ||||||||||||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||||||||||||
| # fusing RoPE with flushing kv_cache operation | ||||||||||||||||||||||||||||||||||||||||||||||
| assert ( | ||||||||||||||||||||||||||||||||||||||||||||||
| hasattr(self.impl, "rotary_emb") | ||||||||||||||||||||||||||||||||||||||||||||||
| and self.impl.rotary_emb is not None | ||||||||||||||||||||||||||||||||||||||||||||||
| and positions is not None | ||||||||||||||||||||||||||||||||||||||||||||||
| ), f"rotary_emb not found in {self.impl=} and positions cannot be None" | ||||||||||||||||||||||||||||||||||||||||||||||
| self.impl.forward( | ||||||||||||||||||||||||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||||||||||||||||||||||||
| query, | ||||||||||||||||||||||||||||||||||||||||||||||
| key, | ||||||||||||||||||||||||||||||||||||||||||||||
| value, | ||||||||||||||||||||||||||||||||||||||||||||||
| kv_cache, | ||||||||||||||||||||||||||||||||||||||||||||||
| attn_metadata, | ||||||||||||||||||||||||||||||||||||||||||||||
| output=output, | ||||||||||||||||||||||||||||||||||||||||||||||
| output_scale=output_scale, | ||||||||||||||||||||||||||||||||||||||||||||||
| positions=positions, | ||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||
| assert positions is None, f"positions must be None {positions=}" | ||||||||||||||||||||||||||||||||||||||||||||||
| self.impl.forward( | ||||||||||||||||||||||||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||||||||||||||||||||||||
| query, | ||||||||||||||||||||||||||||||||||||||||||||||
| key, | ||||||||||||||||||||||||||||||||||||||||||||||
| value, | ||||||||||||||||||||||||||||||||||||||||||||||
| kv_cache, | ||||||||||||||||||||||||||||||||||||||||||||||
| attn_metadata, | ||||||||||||||||||||||||||||||||||||||||||||||
| output=output, | ||||||||||||||||||||||||||||||||||||||||||||||
| output_scale=output_scale, | ||||||||||||||||||||||||||||||||||||||||||||||
| output_block_scale=output_block_scale, | ||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| def unified_attention_with_output_fake( | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -964,6 +1012,7 @@ | |||||||||||||||||||||||||||||||||||||||||||||
| layer_name: str, | ||||||||||||||||||||||||||||||||||||||||||||||
| output_scale: torch.Tensor | None = None, | ||||||||||||||||||||||||||||||||||||||||||||||
| output_block_scale: torch.Tensor | None = None, | ||||||||||||||||||||||||||||||||||||||||||||||
| positions: torch.Tensor | None = None, | ||||||||||||||||||||||||||||||||||||||||||||||
| ) -> None: | ||||||||||||||||||||||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -200,6 +200,7 @@ | |
| VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8: bool = False | ||
| VLLM_USE_FLASHINFER_MOE_MXFP4_BF16: bool = False | ||
| VLLM_ROCM_FP8_MFMA_PAGE_ATTN: bool = False | ||
| VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE: bool = True | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I saw that this is enabled default. |
||
| VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS: bool = False | ||
| VLLM_ALLREDUCE_USE_SYMM_MEM: bool = True | ||
| VLLM_TUNED_CONFIG_FOLDER: str | None = None | ||
|
|
@@ -1393,6 +1394,10 @@ def get_vllm_port() -> int | None: | |
| "VLLM_ROCM_FP8_MFMA_PAGE_ATTN": lambda: bool( | ||
| int(os.getenv("VLLM_ROCM_FP8_MFMA_PAGE_ATTN", "0")) | ||
| ), | ||
| # Use AITER Triton fused RoPE, zeros, and reshape_and_cache kernel | ||
| "VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE": lambda: bool( | ||
| int(os.getenv("VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE", "1")) | ||
| ), | ||
| # Whether to use pytorch symmetric memory for allreduce | ||
| "VLLM_ALLREDUCE_USE_SYMM_MEM": lambda: bool( | ||
| int(os.getenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "1")) | ||
|
|
@@ -1615,6 +1620,7 @@ def compute_hash() -> str: | |
| "VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16", | ||
| "VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB", | ||
| "VLLM_ROCM_FP8_MFMA_PAGE_ATTN", | ||
| "VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE", | ||
| "VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE", | ||
| "VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING", | ||
| "VLLM_NVFP4_GEMM_BACKEND", | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -30,6 +30,7 @@ | |
| from torch import nn | ||
| from transformers import Qwen3Config | ||
|
|
||
| import vllm.envs as envs | ||
| from vllm.attention import Attention, AttentionType | ||
| from vllm.compilation.decorators import support_torch_compile | ||
| from vllm.config import CacheConfig, VllmConfig | ||
|
|
@@ -41,6 +42,7 @@ | |
| from vllm.model_executor.layers.quantization import QuantizationConfig | ||
| from vllm.model_executor.layers.rotary_embedding import get_rope | ||
| from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead | ||
| from vllm.platforms import current_platform | ||
| from vllm.sequence import IntermediateTensors | ||
|
|
||
| from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP | ||
|
|
@@ -49,6 +51,12 @@ | |
| from .utils import AutoWeightsLoader, PPMissingLayer, extract_layer_index, maybe_prefix | ||
|
|
||
| logger = init_logger(__name__) | ||
| if current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER: | ||
| VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = ( | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. AITER flags management are done in the |
||
| envs.VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE | ||
| ) | ||
| else: | ||
| VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = False | ||
|
Comment on lines
+54
to
+59
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This logic for setting from vllm.attention.layer import VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE |
||
|
|
||
|
|
||
| class Qwen3Attention(nn.Module): | ||
|
|
@@ -132,6 +140,11 @@ def __init__( | |
| } | ||
| if dual_chunk_attention_config | ||
| else {}, | ||
| rotary_emb=( | ||
| self.rotary_emb | ||
| if VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. likewise |
||
| else None | ||
| ), | ||
| ) | ||
| self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) | ||
| self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) | ||
|
|
@@ -150,8 +163,12 @@ def forward( | |
| k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim) | ||
| k_by_head = self.k_norm(k_by_head) | ||
| k = k_by_head.view(k.shape) | ||
| q, k = self.rotary_emb(positions, q, k) | ||
| attn_output = self.attn(q, k, v) | ||
| if VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE: | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. likewise |
||
| attn_output = self.attn(q, k, v, positions=positions) | ||
| else: | ||
| q, k = self.rotary_emb(positions, q, k) | ||
| attn_output = self.attn(q, k, v) | ||
|
|
||
| output, _ = self.o_proj(attn_output) | ||
| return output | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -31,6 +31,7 @@ | |
| import torch | ||
| from torch import nn | ||
|
|
||
| import vllm.envs as envs | ||
| from vllm.attention import Attention | ||
| from vllm.compilation.decorators import support_torch_compile | ||
| from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config | ||
|
|
@@ -63,6 +64,7 @@ | |
| maybe_remap_kv_scale_name, | ||
| ) | ||
| from vllm.model_executor.models.utils import sequence_parallel_chunk | ||
| from vllm.platforms import current_platform | ||
| from vllm.sequence import IntermediateTensors | ||
|
|
||
| from .interfaces import MixtureOfExperts, SupportsEagle3, SupportsLoRA, SupportsPP | ||
|
|
@@ -77,6 +79,12 @@ | |
| ) | ||
|
|
||
| logger = init_logger(__name__) | ||
| if current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER: | ||
| VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = ( | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. likewise |
||
| envs.VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE | ||
| ) | ||
| else: | ||
| VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = False | ||
|
Comment on lines
+82
to
+87
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This logic for setting from vllm.attention.layer import VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE |
||
|
|
||
|
|
||
| class Qwen3MoeMLP(nn.Module): | ||
|
|
@@ -291,6 +299,11 @@ def __init__( | |
| } | ||
| if dual_chunk_attention_config | ||
| else {}, | ||
| rotary_emb=( | ||
| self.rotary_emb | ||
| if VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. likewise |
||
| else None | ||
| ), | ||
| ) | ||
|
|
||
| self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) | ||
|
|
@@ -311,8 +324,12 @@ def forward( | |
| k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim) | ||
| k_by_head = self.k_norm(k_by_head) | ||
| k = k_by_head.view(k.shape) | ||
| q, k = self.rotary_emb(positions, q, k) | ||
| attn_output = self.attn(q, k, v) | ||
| if VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE: | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. likewise |
||
| attn_output = self.attn(q, k, v, positions=positions) | ||
| else: | ||
| q, k = self.rotary_emb(positions, q, k) | ||
| attn_output = self.attn(q, k, v) | ||
|
|
||
| output, _ = self.o_proj(attn_output) | ||
| return output | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -7,6 +7,7 @@ | |||||
|
|
||||||
| import torch | ||||||
|
|
||||||
| import vllm.envs as envs | ||||||
| from vllm.attention.backends.abstract import ( | ||||||
| AttentionBackend, | ||||||
| AttentionImpl, | ||||||
|
|
@@ -35,6 +36,9 @@ | |||||
|
|
||||||
| from vllm.triton_utils import tl, triton | ||||||
|
|
||||||
| if envs.VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE: | ||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. likewise |
||||||
| from aiter.ops.triton.fused_kv_cache import fused_qk_rope_reshape_and_cache | ||||||
|
|
||||||
| def block_size(x, head_dim): | ||||||
| return min(65536 // x.element_size(), triton.next_power_of_2(head_dim)) | ||||||
|
|
||||||
|
|
@@ -637,6 +641,7 @@ def forward( | |||||
| output: torch.Tensor | None = None, | ||||||
| output_scale: torch.Tensor | None = None, | ||||||
| output_block_scale: torch.Tensor | None = None, | ||||||
| positions: torch.Tensor | None = None, | ||||||
| ) -> torch.Tensor: | ||||||
| """Forward pass with AiterFlashAttention. | ||||||
|
|
||||||
|
|
@@ -675,25 +680,62 @@ def forward( | |||||
| # performance to make sure it does not introduce any overhead. | ||||||
| num_actual_tokens = attn_metadata.num_actual_tokens | ||||||
| key_cache, value_cache = kv_cache.unbind(0) | ||||||
| if self.kv_sharing_target_layer_name is None: | ||||||
| # Reshape the input keys and values and store them in the cache. | ||||||
| # Skip this if sharing KV cache with an earlier attention layer. | ||||||
| # NOTE(woosuk): Here, key and value are padded while slot_mapping | ||||||
| # is not padded. However, we don't need to do | ||||||
| # key[:num_actual_tokens] and value[:num_actual_tokens] because | ||||||
| # the reshape_and_cache_flash op uses the slot_mapping's shape | ||||||
| # to determine the number of actual tokens. | ||||||
|
|
||||||
| torch.ops._C_cache_ops.reshape_and_cache_flash( | ||||||
| key, | ||||||
| value, | ||||||
| key_cache, | ||||||
| value_cache, | ||||||
| attn_metadata.slot_mapping, | ||||||
| self.kv_cache_dtype, | ||||||
| layer._k_scale, | ||||||
| layer._v_scale, | ||||||
| if positions is not None and query.shape[0] <= 256: | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The value
Suggested change
|
||||||
| assert self.kv_sharing_target_layer_name is None, ( | ||||||
| "self.kv_sharing_target_layer_name cannot be None" | ||||||
| ) | ||||||
| assert hasattr(self, "rotary_emb"), f"rotary_emb not found in {self}" | ||||||
| cos, sin = self.rotary_emb.cos_sin_cache.chunk(2, dim=-1) | ||||||
| is_fp8_kv_cache = self.kv_cache_dtype.startswith("fp8") | ||||||
| if is_fp8_kv_cache: | ||||||
| key_cache = key_cache.view(current_platform.fp8_dtype()) | ||||||
| value_cache = value_cache.view(current_platform.fp8_dtype()) | ||||||
|
|
||||||
| query, key, key_cache, value_cache, output = ( | ||||||
| fused_qk_rope_reshape_and_cache( | ||||||
| query, | ||||||
| key, | ||||||
| value, | ||||||
| key_cache, | ||||||
| value_cache, | ||||||
| attn_metadata.slot_mapping, | ||||||
| positions, | ||||||
| cos, | ||||||
| sin, | ||||||
| layer._k_scale, | ||||||
| layer._v_scale, | ||||||
| self.rotary_emb.is_neox_style, | ||||||
| flash_layout=True, | ||||||
| apply_scale=is_fp8_kv_cache, | ||||||
| offs=None, | ||||||
| q_out=query, | ||||||
| k_out=key, | ||||||
| output_zeros=True, | ||||||
| zeros_out=output, | ||||||
| ) | ||||||
| ) | ||||||
| else: | ||||||
| if positions is not None: | ||||||
| query, key = self.rotary_emb(positions, query, key) | ||||||
|
|
||||||
| if self.kv_sharing_target_layer_name is None: | ||||||
| # Reshape the input keys and values and store them in the cache. | ||||||
| # Skip this if sharing KV cache with an earlier attention layer. | ||||||
| # NOTE(woosuk): Here, key and value are padded while slot_mapping is | ||||||
| # not padded. However, we don't need to do key[:num_actual_tokens] | ||||||
| # and value[:num_actual_tokens] because the reshape_and_cache_flash | ||||||
| # op uses the slot_mapping's shape to determine the number of | ||||||
| # actual tokens. | ||||||
| torch.ops._C_cache_ops.reshape_and_cache_flash( | ||||||
| key, | ||||||
| value, | ||||||
| key_cache, | ||||||
| value_cache, | ||||||
| attn_metadata.slot_mapping, | ||||||
| self.kv_cache_dtype, | ||||||
| layer._k_scale, | ||||||
| layer._v_scale, | ||||||
| ) | ||||||
|
|
||||||
| if self.kv_cache_dtype.startswith("fp8"): | ||||||
| key_cache = key_cache.view(current_platform.fp8_dtype()) | ||||||
|
|
||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
AITER flags management are done in the
_aiter_ops.py. Please move all the flags there and userocm_aiter_ops.is_enabled()and some new flags there.