23 changes: 22 additions & 1 deletion vllm_ascend/attention/attention_v1.py
@@ -29,7 +29,8 @@
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm.v1.attention.backends.utils import (AttentionCGSupport,
get_kv_cache_layout)
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import AttentionSpec

@@ -78,6 +79,26 @@ def get_kv_cache_shape(
) -> Tuple[int, ...]:
return (2, num_blocks, block_size, num_kv_heads, head_size)

@staticmethod
def get_kv_cache_stride_order(
include_num_layers_dimension: bool = False, ) -> tuple[int, ...]:
# `stride_order` indicates the permutation that gets
# us from `get_kv_cache_shape` to the actual memory layout we want.
cache_layout = get_kv_cache_layout()
if cache_layout == "NHD" and include_num_layers_dimension:
# (num_blocks, num_layers, 2, block_size, num_kv_heads, head_size)
return (2, 0, 1, 3, 4, 5)
elif cache_layout == "NHD":
stride_order = (0, 1, 2, 3, 4)
elif cache_layout == "HND" and include_num_layers_dimension:
# (num_blocks, num_kv_heads, num_layers, 2, block_size, head_size)
return (2, 4, 0, 1, 3, 5)
elif cache_layout == "HND":
stride_order = (0, 1, 3, 2, 4)
else:
raise ValueError(f"Unknown cache layout format {cache_layout}.")
return stride_order

@staticmethod
def get_bsh_kv_cache_shape(
num_blocks: int,
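For illustration (not part of the diff): a minimal sketch of what the new get_kv_cache_stride_order() expresses for the NHD layout with include_num_layers_dimension=True, matching the comment in the hunk above. The concrete sizes are arbitrary and chosen only for the example.

# Illustrative only: how the stride order maps the logical KV-cache shape
# onto the physical layout once a num_layers dimension is prepended.
num_layers, num_blocks, block_size, num_kv_heads, head_size = 4, 8, 128, 2, 64

# Logical shape from get_kv_cache_shape(), with num_layers prepended
# (as allocate_uniform_kv_caches() in model_runner_v1.py does below):
logical = (num_layers, 2, num_blocks, block_size, num_kv_heads, head_size)

# NHD layout plus a layers dimension -> stride order (2, 0, 1, 3, 4, 5)
stride_order = (2, 0, 1, 3, 4, 5)
physical = tuple(logical[i] for i in stride_order)
assert physical == (num_blocks, num_layers, 2, block_size,
                    num_kv_heads, head_size)
# With num_blocks outermost, the K/V data of every layer for a given block
# is contiguous, which is what per-block cross-layer KV transfer relies on.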
9 changes: 9 additions & 0 deletions vllm_ascend/attention/mla_v1.py
@@ -69,6 +69,15 @@ def get_kv_cache_shape(num_blocks: int, block_size: int, num_kv_heads: int,
head_size: int) -> tuple[int, ...]:
return (num_blocks, block_size, num_kv_heads, head_size)

@staticmethod
def get_kv_cache_stride_order(
include_num_layers_dimension: bool = False, ) -> tuple[int, ...]:
# `stride_order` indicates the permutation that gets
# us from `get_kv_cache_shape` to the actual memory layout we want.
# (num_blocks, num_layers, block_size, num_kv_heads, head_size)
return (1, 0, 2, 3, 4) if include_num_layers_dimension else (0, 1, 2, 3)

@staticmethod
def get_impl_cls() -> Type["MLAAttentionImpl"]:
return AscendMLAImpl
9 changes: 9 additions & 0 deletions vllm_ascend/attention/sfa_v1.py
@@ -55,6 +55,15 @@ def get_kv_cache_shape(num_blocks: int, block_size: int, num_kv_heads: int,
head_size: int) -> tuple[int, ...]:
return (num_blocks, block_size, num_kv_heads, head_size)

@staticmethod
def get_kv_cache_stride_order(
include_num_layers_dimension: bool = False, ) -> tuple[int, ...]:
# `stride_order` indicates the permutation that gets
# us from `get_kv_cache_shape` to the actual memory layout we want.
# (num_blocks, num_layers, block_size, num_kv_heads, head_size)
return (1, 0, 2, 3, 4) if include_num_layers_dimension else (0, 1, 2, 3)

@staticmethod
def get_impl_cls() -> Type["AscendSFAImpl"]:
return AscendSFAImpl
16 changes: 5 additions & 11 deletions vllm_ascend/kv_offload/npu.py
@@ -2,8 +2,8 @@
from typing import Optional

import torch
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.attention.backends.abstract import AttentionBackend
from vllm.config import VllmConfig
from vllm.v1.kv_offload.abstract import LoadStoreSpec, OffloadingManager
from vllm.v1.kv_offload.backends.cpu import CPUBackend
from vllm.v1.kv_offload.lru_manager import LRUOffloadingManager
@@ -45,18 +45,12 @@ def get_manager(self) -> OffloadingManager:
return self._manager

def get_handlers(
self, kv_caches: dict[str, torch.Tensor]
self,
kv_caches: dict[str, torch.Tensor],
attn_backends: dict[str, type[AttentionBackend]],
) -> Iterator[tuple[type[LoadStoreSpec], type[LoadStoreSpec],
OffloadingHandler]]:
if not self._handler:
layer_names = list(kv_caches.keys())
layers = get_layers_from_vllm_config(self.vllm_config,
AttentionLayerBase,
layer_names)
attn_backends = {
layer_name: layers[layer_name].get_attn_backend()
for layer_name in layer_names
}

self._handler = CpuNpuOffloadingHandler(
attn_backends=attn_backends,
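Note (outside the diff): get_handlers() now expects the caller to pass the per-layer attn_backends mapping instead of deriving it here. A hedged sketch of how such a mapping could be assembled, mirroring the lookup removed above; the helper name build_attn_backends is hypothetical.

# Sketch only: rebuild the per-layer attention-backend mapping that
# get_handlers() previously derived internally (see the removed lines above).
from vllm.config import get_layers_from_vllm_config
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase

def build_attn_backends(vllm_config, kv_caches):  # hypothetical helper
    layer_names = list(kv_caches.keys())
    layers = get_layers_from_vllm_config(vllm_config, AttentionLayerBase,
                                         layer_names)
    return {
        name: layers[name].get_attn_backend()
        for name in layer_names
    }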
235 changes: 224 additions & 11 deletions vllm_ascend/worker/model_runner_v1.py
@@ -40,6 +40,7 @@
from vllm.compilation.monitor import set_cudagraph_capturing_enabled
from vllm.config import (CompilationMode, CUDAGraphMode, VllmConfig,
get_layers_from_vllm_config)
from vllm.config.cache import CacheDType
from vllm.distributed import (get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather)
from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer
@@ -220,6 +221,9 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
else:
self.prefetch_stream = None
self.sampler = AscendSampler()

self.cross_layers_kv_cache: tuple[torch.Tensor, torch.Tensor] | None = None
self.cross_layers_attn_backend: type[AttentionBackend] | None = None
self.attn_mask = None
self.attn_state = None

@@ -2365,11 +2369,19 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
for attn_group in self.attn_groups
])

self.may_reinitialize_input_batch(kv_cache_config)
kv_caches = self.initialize_kv_cache_tensors(kv_cache_config)
kernel_block_sizes = self.may_reinitialize_input_batch(kv_cache_config)
kv_caches = self.initialize_kv_cache_tensors(
kv_cache_config,
kernel_block_sizes=[i[0] for i in kernel_block_sizes])

if has_kv_transfer_group():
get_kv_transfer_group().register_kv_caches(kv_caches)
kv_transfer_group = get_kv_transfer_group()
if self.cross_layers_kv_cache is not None:
assert self.cross_layers_attn_backend is not None
kv_transfer_group.register_cross_layers_kv_cache(
self.cross_layers_kv_cache, self.cross_layers_attn_backend)
else:
kv_transfer_group.register_kv_caches(kv_caches)

def _align_memory(self, tensor: torch.Tensor,
alignment: int) -> torch.Tensor:
@@ -2379,7 +2391,8 @@ def _align_memory(self, tensor: torch.Tensor,
return tensor[int(offset):]

def initialize_kv_cache_tensors(
self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
self, kv_cache_config: KVCacheConfig,
kernel_block_sizes) -> dict[str, torch.Tensor]:
"""
Initialize the memory buffer for KV cache.

@@ -2389,11 +2402,27 @@ def initialize_kv_cache_tensors(
Dict[str, torch.Tensor]: A map between layer names to their
corresponding memory buffer for KV cache.
"""
# Initialize the memory buffer for KV cache
kv_cache_raw_tensors = self._allocate_kv_cache_tensors(kv_cache_config)
# Change the memory buffer to the desired shape
kv_caches = self._reshape_kv_cache_tensors(kv_cache_config,
kv_cache_raw_tensors)
# Try creating KV caches optimized for kv-connector transfers
cache_dtype = self.cache_config.cache_dtype
if self.use_uniform_kv_cache(self.attn_groups, cache_dtype):
kv_caches, cross_layers_k_cache, cross_layers_v_cache, attn_backend = (
self.allocate_uniform_kv_caches(
kv_cache_config,
self.attn_groups,
cache_dtype,
self.device,
kernel_block_sizes,
))
self.cross_layers_kv_cache = (cross_layers_k_cache,
cross_layers_v_cache)
self.cross_layers_attn_backend = attn_backend
else:
# Initialize the memory buffer for KV cache
kv_cache_raw_tensors = self._allocate_kv_cache_tensors(
kv_cache_config)
# Change the memory buffer to the desired shape
kv_caches = self._reshape_kv_cache_tensors(kv_cache_config,
kv_cache_raw_tensors)

from vllm.v1.worker.utils import bind_kv_cache
bind_kv_cache(kv_caches,
@@ -2681,7 +2710,7 @@ def _reshape_kv_cache_tensors(
return kv_caches

def may_reinitialize_input_batch(self,
kv_cache_config: KVCacheConfig) -> None:
kv_cache_config: KVCacheConfig) -> list:
"""
Re-initialize the input batch if the block sizes are different from
`[self.cache_config.block_size]`. This usually happens when there
@@ -2760,6 +2789,8 @@ def may_reinitialize_input_batch(self,
kernel_block_sizes=kernel_block_sizes,
)

return kernel_block_sizes

def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
"""
Initialize the attention backends and attention metadata builders.
@@ -3454,6 +3485,188 @@ def _generate_pcp_mtp_input(
non_blocking=True,
)

def use_uniform_kv_cache(
self,
attn_groups: list[list[AttentionGroup]],
cache_dtype: CacheDType,
) -> bool:
"""
Determines whether a uniform KV layout should be used.
A uniform layout means all layers' KV caches share the same
underlying tensor, where for a given block number, the respective
KV data of all layers is contiguous.
This allows efficient KV transfer of per-block KV data for all
layers at once.
Note this layout is only applied when three conditions hold:
1. The KV cache config contains just a single group where all layers
have the same page size.
2. A KV connector is configured, and the KV connector instance prefers
to use this layout (prefer_cross_layer_blocks is True).
3. The attention backend supports this layout
(get_kv_cache_stride_order(True) includes a placement for a
num_layers dimension).

Note that the actual placement of the num_layers dimension
in the unified layer tensor is determined by the attention
backend.
Thus, the per-layer KV data may still not be contiguous per block
if the attention backend does not support it.

Args:
attn_groups: The list of attention groups for this model
cache_dtype: The KV cache dtype
Returns:
True if we should use a uniform KV cache layout.
"""

if not has_kv_transfer_group():
return False
if not get_kv_transfer_group().prefer_cross_layer_blocks:
return False

if len(attn_groups) != 1 or len(attn_groups[0]) != 1:
return False

attn_group = attn_groups[0][0]
kv_cache_spec = attn_group.kv_cache_spec
if not isinstance(kv_cache_spec, AttentionSpec):
return False

attn_backend = attn_group.backend
kv_cache_shape = attn_backend.get_kv_cache_shape(
1234,
kv_cache_spec.block_size,
kv_cache_spec.num_kv_heads,
kv_cache_spec.head_size,
cache_dtype_str=cache_dtype,
)

# `kv_cache_stride_order` indicates the permutation that gets
# us from `get_kv_cache_shape` to the actual memory layout we want.
# If an exception occurs, the subsequent allocate_uniform_kv_caches
# cannot be completed smoothly, so return False directly.
try:
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order(
include_num_layers_dimension=True)
except (AttributeError, NotImplementedError):
return False

# check that the attention backend's stride order includes a layers dimension
return len(kv_cache_stride_order) == len(kv_cache_shape) + 1

def allocate_uniform_kv_caches(
self,
kv_cache_config: KVCacheConfig,
attn_groups: list[list[AttentionGroup]],
cache_dtype: CacheDType,
device: torch.device,
kernel_block_sizes: list[int],
) -> tuple[dict[str, torch.Tensor], torch.Tensor, torch.Tensor,
type[AttentionBackend]]:
"""
Initializes and reshapes KV caches for the simple case where all
layers have the same layout.

This function assumes use_uniform_kv_cache() returned True.

Args:
kv_cache_config: The KV cache config
attn_groups: The list of attention groups for this model
cache_dtype: The KV cache dtype
device: The torch device to allocate on.
kernel_block_sizes: The kernel block sizes for each KV cache group.
Returns:
A tuple (kv_caches, cross_layers_k_cache, cross_layers_v_cache,
attn_backend) where:
kv_caches is a dict mapping layer names to their corresponding
memory buffer for KV cache.
cross_layers_k_cache and cross_layers_v_cache are the cross-layer
K and V cache tensors.
attn_backend is the attention backend matching these tensors.
"""
attn_group = attn_groups[0][0]
kv_cache_spec = attn_group.kv_cache_spec
assert isinstance(kv_cache_spec, AttentionSpec)

tensor_sizes = set(
kv_cache_tensor.size
for kv_cache_tensor in kv_cache_config.kv_cache_tensors)
assert len(tensor_sizes) == 1
tensor_size = tensor_sizes.pop()

page_size = kv_cache_spec.page_size_bytes
assert tensor_size % page_size == 0
num_blocks = tensor_size // page_size
num_layers = len(kv_cache_config.kv_cache_tensors)
total_size = tensor_size * num_layers

assert len(kernel_block_sizes) == 1
kernel_block_size = kernel_block_sizes[0]
num_blocks_per_kv_block = kv_cache_spec.block_size // kernel_block_size
kernel_num_blocks = num_blocks * num_blocks_per_kv_block

attn_backend = attn_group.backend
kv_cache_shape = attn_backend.get_kv_cache_shape(
kernel_num_blocks,
kernel_block_size,
kv_cache_spec.num_kv_heads,
kv_cache_spec.head_size,
cache_dtype_str=cache_dtype,
)

# prepend a num_layers dimension into the shape
kv_cache_shape = (num_layers, ) + kv_cache_shape

try:
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order(
include_num_layers_dimension=True)
assert len(kv_cache_stride_order) == len(kv_cache_shape)
except (AttributeError, NotImplementedError):
kv_cache_stride_order = tuple(range(len(kv_cache_shape)))
if not self.model_config.is_deepseek_mla:
new_kv_cache_shape = tuple(kv_cache_shape[i]
for i in kv_cache_stride_order
if kv_cache_shape[i] != 2)
else:
new_kv_cache_shape_list = [
kv_cache_shape[i] for i in kv_cache_stride_order
]
new_kv_cache_shape_list[-1] = new_kv_cache_shape_list[-1] // 2
new_kv_cache_shape = tuple(new_kv_cache_shape_list)
logger.info("Allocating a cross layer KV cache of shape %s",
new_kv_cache_shape)

# allocate two contiguous buffers (K and V) spanning all layers
cross_layers_k_cache = (torch.zeros(
total_size // 2, dtype=torch.int8,
device=device).view(kv_cache_spec.dtype).view(new_kv_cache_shape))
cross_layers_v_cache = (torch.zeros(
total_size // 2, dtype=torch.int8,
device=device).view(kv_cache_spec.dtype).view(new_kv_cache_shape))

if not self.model_config.is_deepseek_mla:
# Maintain original KV shape view.
inv_order = [
kv_cache_stride_order.index(i)
for i in range(len(kv_cache_stride_order))
if kv_cache_shape[i] != 2
]
if len(new_kv_cache_shape) != len(kv_cache_shape):
inv_order = [i - 1 if i > 1 else i for i in inv_order]
else:
inv_order = [
kv_cache_stride_order.index(i)
for i in range(len(kv_cache_stride_order))
]

permuted_k_cache = cross_layers_k_cache.permute(*inv_order)
permuted_v_cache = cross_layers_v_cache.permute(*inv_order)
kv_caches = {}
for i, kv_cache_tensor in enumerate(kv_cache_config.kv_cache_tensors):
tensor = (permuted_k_cache[i], permuted_v_cache[i])
for layer_name in kv_cache_tensor.shared_by:
kv_caches[layer_name] = tensor

return kv_caches, cross_layers_k_cache, cross_layers_v_cache, attn_backend
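For illustration (not part of the diff): a minimal sketch of the inverse-permutation step used above. Permuting the physically laid-out buffer by inv_order restores the logical (num_layers, ...) view, so indexing by layer yields that layer's cache while storage stays shared. Shapes are arbitrary and use the MLA/SFA stride order (1, 0, 2, 3, 4) for simplicity.

import torch

# Physical layout puts num_blocks outermost so each block's data for all
# layers is contiguous.
num_layers, num_blocks, block_size, heads, head_size = 4, 8, 16, 1, 32
stride_order = (1, 0, 2, 3, 4)

logical_shape = (num_layers, num_blocks, block_size, heads, head_size)
physical_shape = tuple(logical_shape[i] for i in stride_order)
buf = torch.zeros(physical_shape)

# Inverse permutation: where each logical dimension sits in the physical layout.
inv_order = [stride_order.index(i) for i in range(len(stride_order))]
logical_view = buf.permute(*inv_order)
assert logical_view.shape == logical_shape

# Indexing the permuted view by layer gives a per-layer cache backed by the
# shared cross-layer buffer (as done for permuted_k_cache[i] above).
layer0 = logical_view[0]
assert layer0.shape == (num_blocks, block_size, heads, head_size)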


@contextmanager
def _torch_cuda_wrapper():
@@ -3489,4 +3702,4 @@ def __init__(self, *args, **kwargs) -> None:
torch.cuda.Stream = torch.cuda.Stream
torch.cuda.default_stream = torch.npu.default_stream
torch.cuda.current_stream = torch.npu.current_stream
torch.cuda.stream = torch.npu.stream
torch.cuda.stream = torch.npu.stream