diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
index b5e91096df3..9dbce529937 100644
--- a/vllm_ascend/attention/attention_v1.py
+++ b/vllm_ascend/attention/attention_v1.py
@@ -29,7 +29,8 @@
 from vllm.config import VllmConfig, get_current_vllm_config
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.utils.math_utils import cdiv
-from vllm.v1.attention.backends.utils import AttentionCGSupport
+from vllm.v1.attention.backends.utils import (AttentionCGSupport,
+                                              get_kv_cache_layout)
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.kv_cache_interface import AttentionSpec
@@ -78,6 +79,26 @@ def get_kv_cache_shape(
     ) -> Tuple[int, ...]:
         return (2, num_blocks, block_size, num_kv_heads, head_size)
 
+    @staticmethod
+    def get_kv_cache_stride_order(
+            include_num_layers_dimension: bool = False, ) -> tuple[int, ...]:
+        # `stride_order` indicates the permutation that gets
+        # us from `get_kv_cache_shape` to the actual memory layout we want.
+        cache_layout = get_kv_cache_layout()
+        if cache_layout == "NHD" and include_num_layers_dimension:
+            # (num_blocks, num_layers, 2, block_size, num_kv_heads, head_size)
+            return (2, 0, 1, 3, 4, 5)
+        elif cache_layout == "NHD":
+            stride_order = (0, 1, 2, 3, 4)
+        elif cache_layout == "HND" and include_num_layers_dimension:
+            # (num_blocks, num_kv_heads, num_layers, 2, block_size, head_size)
+            return (2, 4, 0, 1, 3, 5)
+        elif cache_layout == "HND":
+            stride_order = (0, 1, 3, 2, 4)
+        else:
+            raise ValueError(f"Unknown cache layout format {cache_layout}.")
+        return stride_order
+
     @staticmethod
     def get_bsh_kv_cache_shape(
         num_blocks: int,
diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index 1674d34697a..a2a76649e43 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -69,6 +69,15 @@ def get_kv_cache_shape(num_blocks: int, block_size: int, num_kv_heads: int,
                             head_size: int) -> tuple[int, ...]:
         return (num_blocks, block_size, num_kv_heads, head_size)
 
+    @staticmethod
+    def get_kv_cache_stride_order(
+            include_num_layers_dimension: bool = False, ) -> tuple[int, ...]:
+        # `stride_order` indicates the permutation that gets
+        # us from `get_kv_cache_shape` to the actual memory layout we want.
+        # (num_blocks, num_layers, block_size, num_kv_heads, head_size)
+        return (1, 0, 2, 3, 4) if include_num_layers_dimension else (0, 1, 2,
+                                                                     3)
+
     @staticmethod
     def get_impl_cls() -> Type["MLAAttentionImpl"]:
         return AscendMLAImpl
diff --git a/vllm_ascend/attention/sfa_v1.py b/vllm_ascend/attention/sfa_v1.py
index bc1de75fc4f..ab5887e58b2 100644
--- a/vllm_ascend/attention/sfa_v1.py
+++ b/vllm_ascend/attention/sfa_v1.py
@@ -55,6 +55,15 @@ def get_kv_cache_shape(num_blocks: int, block_size: int, num_kv_heads: int,
                             head_size: int) -> tuple[int, ...]:
         return (num_blocks, block_size, num_kv_heads, head_size)
 
+    @staticmethod
+    def get_kv_cache_stride_order(
+            include_num_layers_dimension: bool = False, ) -> tuple[int, ...]:
+        # `stride_order` indicates the permutation that gets
+        # us from `get_kv_cache_shape` to the actual memory layout we want.
+        # (num_blocks, num_layers, block_size, num_kv_heads, head_size)
+        return (1, 0, 2, 3, 4) if include_num_layers_dimension else (0, 1, 2,
+                                                                     3)
+
     @staticmethod
     def get_impl_cls() -> Type["AscendSFAImpl"]:
         return AscendSFAImpl
diff --git a/vllm_ascend/kv_offload/npu.py b/vllm_ascend/kv_offload/npu.py
index 9f80237b718..6d4f0c268ba 100644
--- a/vllm_ascend/kv_offload/npu.py
+++ b/vllm_ascend/kv_offload/npu.py
@@ -2,8 +2,8 @@
 from typing import Optional
 
 import torch
-from vllm.config import VllmConfig, get_layers_from_vllm_config
-from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
+from vllm.attention.backends.abstract import AttentionBackend
+from vllm.config import VllmConfig
 from vllm.v1.kv_offload.abstract import LoadStoreSpec, OffloadingManager
 from vllm.v1.kv_offload.backends.cpu import CPUBackend
 from vllm.v1.kv_offload.lru_manager import LRUOffloadingManager
@@ -45,18 +45,12 @@
     def get_manager(self) -> OffloadingManager:
         return self._manager
 
     def get_handlers(
-        self, kv_caches: dict[str, torch.Tensor]
+        self,
+        kv_caches: dict[str, torch.Tensor],
+        attn_backends: dict[str, type[AttentionBackend]],
     ) -> Iterator[tuple[type[LoadStoreSpec], type[LoadStoreSpec],
                         OffloadingHandler]]:
         if not self._handler:
-            layer_names = list(kv_caches.keys())
-            layers = get_layers_from_vllm_config(self.vllm_config,
-                                                 AttentionLayerBase,
-                                                 layer_names)
-            attn_backends = {
-                layer_name: layers[layer_name].get_attn_backend()
-                for layer_name in layer_names
-            }
             self._handler = CpuNpuOffloadingHandler(
                 attn_backends=attn_backends,
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index c6c881ebc30..42eada03c1c 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -40,6 +40,7 @@
 from vllm.compilation.monitor import set_cudagraph_capturing_enabled
 from vllm.config import (CompilationMode, CUDAGraphMode, VllmConfig,
                          get_layers_from_vllm_config)
+from vllm.config.cache import CacheDType
 from vllm.distributed import (get_tensor_model_parallel_world_size,
                               tensor_model_parallel_all_gather)
 from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer
@@ -220,6 +221,9 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         else:
             self.prefetch_stream = None
         self.sampler = AscendSampler()
+
+        self.cross_layers_kv_cache: tuple[torch.Tensor, torch.Tensor] | None = None
+        self.cross_layers_attn_backend: type[AttentionBackend] | None = None
 
         self.attn_mask = None
         self.attn_state = None
@@ -2365,11 +2369,19 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
             for attn_group in self.attn_groups
         ])
 
-        self.may_reinitialize_input_batch(kv_cache_config)
-        kv_caches = self.initialize_kv_cache_tensors(kv_cache_config)
+        kernel_block_sizes = self.may_reinitialize_input_batch(kv_cache_config)
+        kv_caches = self.initialize_kv_cache_tensors(
+            kv_cache_config,
+            kernel_block_sizes=[i[0] for i in kernel_block_sizes])
 
         if has_kv_transfer_group():
-            get_kv_transfer_group().register_kv_caches(kv_caches)
+            kv_transfer_group = get_kv_transfer_group()
+            if self.cross_layers_kv_cache is not None:
+                assert self.cross_layers_attn_backend is not None
+                kv_transfer_group.register_cross_layers_kv_cache(
+                    self.cross_layers_kv_cache, self.cross_layers_attn_backend)
+            else:
+                kv_transfer_group.register_kv_caches(kv_caches)
 
     def _align_memory(self, tensor: torch.Tensor,
                       alignment: int) -> torch.Tensor:
@@ -2379,7 +2391,8 @@ def _align_memory(self, tensor: torch.Tensor,
         return tensor[int(offset):]
 
     def initialize_kv_cache_tensors(
-            self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
+            self, kv_cache_config: KVCacheConfig,
+            kernel_block_sizes) -> dict[str, torch.Tensor]:
         """
         Initialize the memory buffer for KV cache.
 
@@ -2389,11 +2402,27 @@ def initialize_kv_cache_tensors(
         Returns:
             Dict[str, torch.Tensor]: A map between layer names to their
                 corresponding memory buffer for KV cache.
         """
-        # Initialize the memory buffer for KV cache
-        kv_cache_raw_tensors = self._allocate_kv_cache_tensors(kv_cache_config)
-        # Change the memory buffer to the desired shape
-        kv_caches = self._reshape_kv_cache_tensors(kv_cache_config,
-                                                   kv_cache_raw_tensors)
+        # Try creating KV caches optimized for kv-connector transfers
+        cache_dtype = self.cache_config.cache_dtype
+        if self.use_uniform_kv_cache(self.attn_groups, cache_dtype):
+            kv_caches, cross_layers_k_cache, cross_layers_v_cache, attn_backend = (
+                self.allocate_uniform_kv_caches(
+                    kv_cache_config,
+                    self.attn_groups,
+                    cache_dtype,
+                    self.device,
+                    kernel_block_sizes,
+                ))
+            self.cross_layers_kv_cache = (cross_layers_k_cache,
+                                          cross_layers_v_cache)
+            self.cross_layers_attn_backend = attn_backend
+        else:
+            # Initialize the memory buffer for KV cache
+            kv_cache_raw_tensors = self._allocate_kv_cache_tensors(
+                kv_cache_config)
+            # Change the memory buffer to the desired shape
+            kv_caches = self._reshape_kv_cache_tensors(kv_cache_config,
+                                                       kv_cache_raw_tensors)
         from vllm.v1.worker.utils import bind_kv_cache
         bind_kv_cache(kv_caches,
@@ -2681,7 +2710,7 @@ def _reshape_kv_cache_tensors(
         return kv_caches
 
     def may_reinitialize_input_batch(self,
-                                     kv_cache_config: KVCacheConfig) -> None:
+                                     kv_cache_config: KVCacheConfig) -> list:
         """
         Re-initialize the input batch if the block sizes are different from
         `[self.cache_config.block_size]`. This usually happens when there
@@ -2760,6 +2789,8 @@ def may_reinitialize_input_batch(self,
             kernel_block_sizes=kernel_block_sizes,
         )
 
+        return kernel_block_sizes
+
     def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
         """
         Initialize the attention backends and attention metadata builders.
@@ -3454,6 +3485,188 @@ def _generate_pcp_mtp_input(
             non_blocking=True,
         )
 
+    def use_uniform_kv_cache(
+        self,
+        attn_groups: list[list[AttentionGroup]],
+        cache_dtype: CacheDType,
+    ) -> bool:
+        """
+        Determines whether a uniform KV layout should be used.
+        A uniform layout means all layers' KV caches will share the same
+        underlying tensor, where for a given block number, the respective
+        KV data for all layers will be contiguous.
+        This will allow efficient KV transfer of per-block KV data for all
+        layers at once.
+        Note this layout will only be applied given 3 conditions:
+        1. The KV cache config contains just a single group where all layers
+        have the same page size.
+        2. A KV connector is configured, and the KV connector instance prefers
+        to use this layout (prefer_cross_layer_blocks() returns True)
+        3. The attention backend supports this layout
+        (get_kv_cache_stride_order(True) includes a placement for a
+        num_layers dimension)
+
+        Note that the actual placement of the num_layers dimension
+        in the unified layer tensors will be determined by the attention
+        backend.
+        Thus, the layers' KV data may still not be contiguous per block
+        if the attention backend does not support it.
+
+        Args:
+            attn_groups: The list of attention groups for this model
+            cache_dtype: The KV cache dtype
+        Returns:
+            True if we should use a uniform KV cache layout.
+ """ + + if not has_kv_transfer_group(): + return False + if not get_kv_transfer_group().prefer_cross_layer_blocks: + return False + + if len(attn_groups) != 1 or len(attn_groups[0]) != 1: + return False + + attn_group = attn_groups[0][0] + kv_cache_spec = attn_group.kv_cache_spec + if not isinstance(kv_cache_spec, AttentionSpec): + return False + + attn_backend = attn_group.backend + kv_cache_shape = attn_backend.get_kv_cache_shape( + 1234, + kv_cache_spec.block_size, + kv_cache_spec.num_kv_heads, + kv_cache_spec.head_size, + cache_dtype_str=cache_dtype, + ) + + # `kv_cache_stride_order` indicates the permutation that gets + # us from `get_kv_cache_shape` to the actual memory layout we want. + # If an exception occurs, the subsequent allocate_uniform_kv_caches + # cannot be completed smoothly, so return False directly. + try: + kv_cache_stride_order = attn_backend.get_kv_cache_stride_order( + include_num_layers_dimension=True) + except (AttributeError, NotImplementedError): + return False + + # check that attention backend include a layers dimension + return len(kv_cache_stride_order) == len(kv_cache_shape) + 1 + + def allocate_uniform_kv_caches( + self, + kv_cache_config: KVCacheConfig, + attn_groups: list[list[AttentionGroup]], + cache_dtype: CacheDType, + device: torch.device, + kernel_block_sizes: list[int], + ) -> tuple[dict[str, torch.Tensor], torch.Tensor, torch.Tensor, + type[AttentionBackend]]: + """ + Initializes and reshapes KV caches for the simple case where all + layers have the same layout. + + This function assumes use_uniform_kv_cache() returned True. + + Args: + kv_cache_config: The KV cache config + attn_groups: The list of attention groups for this model + cache_dtype: The KV cache dtype + device: The torch device to allocate on. + kernel_block_sizes: The kernel block sizes for each KV cache group. + Returns: + A tuple (kv_caches, cross_layers_kv_cache, attn_backend) where: + kv_caches is a dict mapping between layer names to their + corresponding memory buffer for KV cache. 
+ cross_layers_kv_cache is the cross layers kv cache tensor + attn_backend is the attention backend matching this tensor + """ + attn_group = attn_groups[0][0] + kv_cache_spec = attn_group.kv_cache_spec + assert isinstance(kv_cache_spec, AttentionSpec) + + tensor_sizes = set( + kv_cache_tensor.size + for kv_cache_tensor in kv_cache_config.kv_cache_tensors) + assert len(tensor_sizes) == 1 + tensor_size = tensor_sizes.pop() + + page_size = kv_cache_spec.page_size_bytes + assert tensor_size % page_size == 0 + num_blocks = tensor_size // page_size + num_layers = len(kv_cache_config.kv_cache_tensors) + total_size = tensor_size * num_layers + + assert len(kernel_block_sizes) == 1 + kernel_block_size = kernel_block_sizes[0] + num_blocks_per_kv_block = kv_cache_spec.block_size // kernel_block_size + kernel_num_blocks = num_blocks * num_blocks_per_kv_block + + attn_backend = attn_group.backend + kv_cache_shape = attn_backend.get_kv_cache_shape( + kernel_num_blocks, + kernel_block_size, + kv_cache_spec.num_kv_heads, + kv_cache_spec.head_size, + cache_dtype_str=cache_dtype, + ) + + # prepend a num_layers dimension into the shape + kv_cache_shape = (num_layers, ) + kv_cache_shape + + try: + kv_cache_stride_order = attn_backend.get_kv_cache_stride_order( + include_num_layers_dimension=True) + assert len(kv_cache_stride_order) == len(kv_cache_shape) + except (AttributeError, NotImplementedError): + kv_cache_stride_order = tuple(range(len(kv_cache_shape))) + if not self.model_config.is_deepseek_mla: + new_kv_cache_shape = tuple(kv_cache_shape[i] + for i in kv_cache_stride_order + if kv_cache_shape[i] != 2) + else: + new_kv_cache_shape_list = [ + kv_cache_shape[i] for i in kv_cache_stride_order + ] + new_kv_cache_shape_list[-1] = new_kv_cache_shape_list[-1] // 2 + new_kv_cache_shape = tuple(new_kv_cache_shape_list) + logger.info("Allocating a cross layer KV cache of shape %s", + new_kv_cache_shape) + + # allocate two contiguous buffer for all layers + cross_layers_k_cache = (torch.zeros( + total_size // 2, dtype=torch.int8, + device=device).view(kv_cache_spec.dtype).view(new_kv_cache_shape)) + cross_layers_v_cache = (torch.zeros( + total_size // 2, dtype=torch.int8, + device=device).view(kv_cache_spec.dtype).view(new_kv_cache_shape)) + + if not self.model_config.is_deepseek_mla: + # Maintain original KV shape view. + inv_order = [ + kv_cache_stride_order.index(i) + for i in range(len(kv_cache_stride_order)) + if kv_cache_shape[i] != 2 + ] + if len(new_kv_cache_shape) != len(kv_cache_shape): + inv_order = [i - 1 if i > 1 else i for i in inv_order] + else: + inv_order = [ + kv_cache_stride_order.index(i) + for i in range(len(kv_cache_stride_order)) + ] + + permuted_k_cache = cross_layers_k_cache.permute(*inv_order) + permuted_v_cache = cross_layers_v_cache.permute(*inv_order) + kv_caches = {} + for i, kv_cache_tensor in enumerate(kv_cache_config.kv_cache_tensors): + tensor = (permuted_k_cache[i], permuted_v_cache[i]) + for layer_name in kv_cache_tensor.shared_by: + kv_caches[layer_name] = tensor + + return kv_caches, cross_layers_k_cache, cross_layers_v_cache, attn_backend + @contextmanager def _torch_cuda_wrapper(): @@ -3489,4 +3702,4 @@ def __init__(self, *args, **kwargs) -> None: torch.cuda.Stream = torch.cuda.Stream torch.cuda.default_stream = torch.npu.default_stream torch.cuda.current_stream = torch.npu.current_stream - torch.cuda.stream = torch.npu.stream + torch.cuda.stream = torch.npu.stream \ No newline at end of file
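For reviewers unfamiliar with the stride-order convention added in get_kv_cache_stride_order() above, the following standalone sketch (plain PyTorch, toy dimensions chosen for illustration, not part of the patch) shows the intended mechanics: the stride order describes the physical layout the storage is allocated in, and the inverse permutation recovers a view with the canonical get_kv_cache_shape() shape on top of that storage.

# Illustration only -- not part of the patch. Toy sizes; the (0, 1, 3, 2, 4)
# order is the "HND" permutation returned by get_kv_cache_stride_order().
import torch

canonical_shape = (2, 8, 16, 4, 128)  # (2, num_blocks, block_size, num_kv_heads, head_size)
stride_order = (0, 1, 3, 2, 4)        # "HND": num_kv_heads placed before block_size in memory

# Allocate storage in the permuted (physical) order ...
physical_shape = tuple(canonical_shape[i] for i in stride_order)
buf = torch.zeros(physical_shape)

# ... then invert the permutation so callers still see the canonical shape
# while the underlying strides follow the HND layout.
inv_order = [stride_order.index(i) for i in range(len(stride_order))]
kv_cache = buf.permute(*inv_order)
assert kv_cache.shape == torch.Size(canonical_shape)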
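The cross-layer allocation in allocate_uniform_kv_caches() can be pictured the same way. The toy sketch below (again illustrative only, with made-up dimensions; the patch additionally splits K and V into two buffers and drops the leading 2 dimension for non-MLA backends) allocates one shared buffer whose physical layout keeps each block's data for every layer adjacent, then carves per-layer views out of it by permuting back and indexing the num_layers dimension, which is what the kv_caches dict entries alias.

# Illustration only -- hypothetical toy dimensions.
import torch

num_layers, num_blocks, block_size, heads, head = 2, 4, 16, 1, 8
# canonical shape with a num_layers dimension prepended, as in the patch
canonical = (num_layers, 2, num_blocks, block_size, heads, head)
stride_order = (2, 0, 1, 3, 4, 5)  # physical: (num_blocks, num_layers, 2, block_size, heads, head)

physical = tuple(canonical[i] for i in stride_order)
shared = torch.zeros(physical)     # one buffer backing every layer's cache

inv_order = [stride_order.index(i) for i in range(len(stride_order))]
per_layer = shared.permute(*inv_order)  # canonical shape again
layer0_k = per_layer[0][0]              # layer 0, K: (num_blocks, block_size, heads, head)

layer0_k[3].fill_(1.0)                  # write block 3 of layer 0
# The write lands inside the shared buffer's block-3 slab, which also holds
# block 3 of every other layer -- so a connector can move one block's data
# for all layers at once, which is what register_cross_layers_kv_cache()
# is there to exploit.
assert shared[3, 0, 0].eq(1.0).all()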