diff --git a/tests/v1/kv_connector/unit/test_canonical_kv_caches.py b/tests/v1/kv_connector/unit/test_canonical_kv_caches.py new file mode 100644 index 000000000000..4bf8307c72c2 --- /dev/null +++ b/tests/v1/kv_connector/unit/test_canonical_kv_caches.py @@ -0,0 +1,302 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Unit tests for CanonicalKVCaches abstraction.""" + +from unittest.mock import patch + +import pytest +import torch + +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + CanonicalKVCaches, + SupportsHMA, +) +from vllm.v1.kv_cache_interface import ( + FullAttentionSpec, + KVCacheConfig, + KVCacheGroupSpec, + KVCacheTensor, + MambaSpec, + SlidingWindowSpec, +) +from vllm.v1.worker.kv_connector_model_runner_mixin import ( + KVConnectorModelRunnerMixin, +) +from vllm.v1.worker.utils import AttentionGroup + +# --------------------------------------------------------------------------- +# Mock backends and connectors +# --------------------------------------------------------------------------- + +BLOCK_SIZE = 16 +NUM_KV_HEADS = 4 +HEAD_SIZE = 8 +NUM_BLOCKS = 10 +DTYPE = torch.float16 + + +class MockFlashAttnBackend: + """Mimics FlashAttention NHD layout.""" + + @staticmethod + def get_kv_cache_shape( + num_blocks, block_size, num_kv_heads, head_size, cache_dtype_str="auto" + ): + return (2, num_blocks, block_size, num_kv_heads, head_size) + + @staticmethod + def get_kv_cache_stride_order(include_num_layers_dimension=False): + if include_num_layers_dimension: + return (2, 0, 1, 3, 4, 5) + return (0, 1, 2, 3, 4) + + +class MockNoStrideOrderBackend: + """Backend that does not support stride order.""" + + @staticmethod + def get_kv_cache_shape( + num_blocks, block_size, num_kv_heads, head_size, cache_dtype_str="auto" + ): + return (2, num_blocks, block_size, num_kv_heads, head_size) + + @staticmethod + def get_kv_cache_stride_order(include_num_layers_dimension=False): + raise NotImplementedError + + +class MockConnector(SupportsHMA): + prefer_cross_layer_blocks = True + + def request_finished_all_groups(self, request, block_ids): + return False, None + + +class MockConnectorNoHMA: + prefer_cross_layer_blocks = True + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_full_attn_spec(): + return FullAttentionSpec( + block_size=BLOCK_SIZE, + num_kv_heads=NUM_KV_HEADS, + head_size=HEAD_SIZE, + dtype=DTYPE, + ) + + +def _make_sw_spec(sliding_window=128): + return SlidingWindowSpec( + block_size=BLOCK_SIZE, + num_kv_heads=NUM_KV_HEADS, + head_size=HEAD_SIZE, + dtype=DTYPE, + sliding_window=sliding_window, + ) + + +def _make_hma_kv_cache_config(): + """HMA config: 3 groups, group_size=2, 2 KVCacheTensors.""" + full_spec = _make_full_attn_spec() + sw_spec = _make_sw_spec() + page_size = full_spec.page_size_bytes + + groups = [ + KVCacheGroupSpec(["full.0", "full.1"], full_spec), + KVCacheGroupSpec(["sw.0", "sw.2"], sw_spec), + KVCacheGroupSpec(["sw.1", "sw.3"], sw_spec), + ] + size = page_size * NUM_BLOCKS + tensors = [ + KVCacheTensor(size=size, shared_by=["full.0", "sw.0", "sw.1"]), + KVCacheTensor(size=size, shared_by=["full.1", "sw.2", "sw.3"]), + ] + return KVCacheConfig( + num_blocks=NUM_BLOCKS, + kv_cache_tensors=tensors, + kv_cache_groups=groups, + ) + + +def _make_attn_groups(backend_cls, kv_cache_config): + attn_groups = [] + for gid, group in 
enumerate(kv_cache_config.kv_cache_groups): + attn_groups.append( + [ + AttentionGroup( + backend=backend_cls, + layer_names=group.layer_names, + kv_cache_spec=group.kv_cache_spec, + kv_cache_group_id=gid, + ) + ] + ) + return attn_groups + + +def _patch_connector(connector): + return ( + patch( + "vllm.v1.worker.kv_connector_model_runner_mixin.has_kv_transfer_group", + return_value=True, + ), + patch( + "vllm.v1.worker.kv_connector_model_runner_mixin.get_kv_transfer_group", + return_value=connector, + ), + ) + + +def _use_canonical(config, attn_groups): + return KVConnectorModelRunnerMixin.use_canonical_kv_caches( + config, attn_groups, "auto" + ) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +@pytest.mark.cpu_test +def test_use_canonical_kv_caches_happy_path(): + """Should return True for a valid HMA model with compatible connector.""" + config = _make_hma_kv_cache_config() + attn_groups = _make_attn_groups(MockFlashAttnBackend, config) + p1, p2 = _patch_connector(MockConnector()) + with p1, p2: + assert _use_canonical(config, attn_groups) is True + + +@pytest.mark.cpu_test +@pytest.mark.parametrize( + "description,config_fn,backend,connector_fn,patch_no_connector", + [ + ( + "single_group", + lambda: KVCacheConfig( + num_blocks=NUM_BLOCKS, + kv_cache_tensors=[ + KVCacheTensor( + size=_make_full_attn_spec().page_size_bytes * NUM_BLOCKS, + shared_by=["layer0"], + ) + ], + kv_cache_groups=[KVCacheGroupSpec(["layer0"], _make_full_attn_spec())], + ), + MockFlashAttnBackend, + MockConnector, + False, + ), + ( + "no_connector", + _make_hma_kv_cache_config, + MockFlashAttnBackend, + None, + True, + ), + ( + "no_hma_support", + _make_hma_kv_cache_config, + MockFlashAttnBackend, + MockConnectorNoHMA, + False, + ), + ( + "mamba_group", + lambda: KVCacheConfig( + num_blocks=NUM_BLOCKS, + kv_cache_tensors=[], + kv_cache_groups=[ + KVCacheGroupSpec(["attn.0"], _make_full_attn_spec()), + KVCacheGroupSpec( + ["mamba.0"], + MambaSpec( + block_size=BLOCK_SIZE, + shapes=((16,), (16,)), + dtypes=(DTYPE,), + ), + ), + ], + ), + MockFlashAttnBackend, + MockConnector, + False, + ), + ( + "no_stride_order", + _make_hma_kv_cache_config, + MockNoStrideOrderBackend, + MockConnector, + False, + ), + ], + ids=lambda x: x if isinstance(x, str) else "", +) +def test_use_canonical_kv_caches_returns_false( + description, config_fn, backend, connector_fn, patch_no_connector +): + """Should return False when any precondition is not met.""" + config = config_fn() + attn_groups = _make_attn_groups(backend, config) + + if patch_no_connector: + with patch( + "vllm.v1.worker.kv_connector_model_runner_mixin.has_kv_transfer_group", + return_value=False, + ): + assert _use_canonical(config, attn_groups) is False + else: + p1, p2 = _patch_connector(connector_fn()) + with p1, p2: + assert _use_canonical(config, attn_groups) is False + + +@pytest.mark.cpu_test +def test_allocate_canonical_kv_caches(): + """Allocation should produce correct kv_caches dict and + CanonicalKVCaches with contiguous per-block data.""" + config = _make_hma_kv_cache_config() + attn_groups = _make_attn_groups(MockFlashAttnBackend, config) + + kv_caches, canonical = KVConnectorModelRunnerMixin.allocate_canonical_kv_caches( + config, attn_groups, "auto", torch.device("cpu"), [BLOCK_SIZE] + ) + + assert isinstance(canonical, CanonicalKVCaches) + + # -- kv_caches dict: all 6 layers present with correct shapes + expected_shape = (2, 
NUM_BLOCKS, BLOCK_SIZE, NUM_KV_HEADS, HEAD_SIZE) + assert len(kv_caches) == 6 + for name in ["full.0", "full.1", "sw.0", "sw.1", "sw.2", "sw.3"]: + assert kv_caches[name].shape == expected_shape + + # layers sharing a position point to the same memory + assert kv_caches["full.0"].data_ptr() == kv_caches["sw.0"].data_ptr() + assert kv_caches["full.1"].data_ptr() == kv_caches["sw.2"].data_ptr() + + # -- single cross-layers tensor: (num_blocks, cross_layer_page_size) int8 + assert len(canonical.tensors) == 1 + bt = canonical.tensors[0] + per_position_page = 2 * BLOCK_SIZE * NUM_KV_HEADS * HEAD_SIZE * DTYPE.itemsize + group_size = len(config.kv_cache_tensors) + cross_layer_page = per_position_page * group_size + assert bt.tensor.shape == (NUM_BLOCKS, cross_layer_page) + assert bt.tensor.dtype == torch.int8 + assert bt.page_size_bytes == cross_layer_page + + # each row is contiguous and covers all positions for one block + assert bt.tensor.is_contiguous() + + # -- group_data_refs: a single data reference per group (3 groups) + assert len(canonical.group_data_refs) == 3 + full_page = config.kv_cache_groups[0].kv_cache_spec.page_size_bytes + for refs in canonical.group_data_refs: + assert len(refs) == 1 + assert refs[0].tensor_idx == 0 + assert refs[0].page_size_bytes == full_page * group_size diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index ef143cba7fb5..a028f68da76d 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -43,6 +43,7 @@ import enum from abc import ABC, abstractmethod from collections.abc import Callable, Iterable +from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, Literal import torch @@ -167,6 +168,68 @@ def aggregate( pass +@dataclass +class KVCacheBlockTensor: + """ + A canonicalized KV cache tensor whose first dimension is num_blocks. + + For attention backends where the raw tensor has num_blocks at a + non-leading physical dimension (e.g. FlashAttention's + (2, num_blocks, ...) layout), the tensor is split so that each + resulting KVCacheBlockTensor starts with (num_blocks, ...). + """ + + # The KV cache tensor with shape (num_blocks, ...) + tensor: torch.Tensor + # The (possibly padded) page size per block in bytes + page_size_bytes: int + + +@dataclass +class KVCacheBlockDataRef: + """ + Per-layer (or group of layers) reference to a specific (by index) + KVCacheBlockTensor and records the un-padded page size used by that layer. + """ + + # Index into the list of KVCacheBlockTensor objects + tensor_idx: int + # The un-padded page size per block in bytes + page_size_bytes: int + + +@dataclass +class CanonicalKVCaches: + """ + Canonicalized block-level representation of the KV caches. + + Composed of: + - Unique list of KV cache data tensors, + each with shape (num_blocks, page_size_in_bytes) and int8 dtype. + - Per-group data references of the tensors. + i.e. how each KV cache group maps to the tensors. + """ + + # Ordered list of unique block tensors, each with shape + # (num_blocks, ...). + tensors: list[KVCacheBlockTensor] + # Per-KV-cache-group list of data references that map each layer + # in the group to the appropriate entry in the tensors list. + group_data_refs: list[list[KVCacheBlockDataRef]] + + +@dataclass +class WorkerConnectorInitializationData: + """Data passed to initialize_worker_connector(). 
+ + Designed to be extended without breaking existing connectors: new optional + fields can be added here and connectors that don't need them simply ignore + the extra data. + """ + + canonical_kv_caches: CanonicalKVCaches | None = field(default=None) + + class KVConnectorBase_V1(ABC): """ Base class for KV connectors. @@ -288,6 +351,24 @@ def set_host_xfer_buffer_ops(self, copy_operation: CopyBlocksOp): """ return + def initialize_worker_connector( + self, + initialization_data: WorkerConnectorInitializationData, + ) -> None: + """ + Initialize per-worker connector state after model loading. + + Called once by the GPU model runner after the model and KV caches + are ready. The default implementation is a no-op; connectors that + need additional initialization should override this method. + + Args: + initialization_data: data bag containing optional fields such + as ``canonical_kv_caches``. New fields may be added in + future versions without breaking existing connectors. + """ + return + def handle_preemptions(self, kv_connector_metadata: KVConnectorMetadata): """ Handle preempted requests or evicted blocks BEFORE they are overwritten. diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py index 4ef8f0ac9c90..e35e7bcb5b29 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py @@ -32,6 +32,9 @@ if TYPE_CHECKING: from vllm.distributed.kv_events import KVCacheEvent + from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + WorkerConnectorInitializationData, + ) from vllm.forward_context import ForwardContext from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.kv_cache_interface import KVCacheConfig @@ -219,6 +222,13 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): for c in self._connectors: c.register_kv_caches(kv_caches) + def initialize_worker_connector( + self, + initialization_data: "WorkerConnectorInitializationData", + ) -> None: + for c in self._connectors: + c.initialize_worker_connector(initialization_data) + # We must override the base class method here because we need to bind # the metadata to each connector in the order of the connectors in the # MultiKVConnectorMetadata. 
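To make the new worker-side hook concrete, the following minimal sketch (not part of this diff) shows a connector overriding initialize_worker_connector; the class name and the attribute it stores are hypothetical, and the remaining abstract methods of KVConnectorBase_V1 are omitted for brevity.

from vllm.distributed.kv_transfer.kv_connector.v1.base import (
    KVConnectorBase_V1,
    WorkerConnectorInitializationData,
)


class BlockwiseKVConnector(KVConnectorBase_V1):
    def initialize_worker_connector(
        self,
        initialization_data: WorkerConnectorInitializationData,
    ) -> None:
        # Keep the canonical per-block view of the KV caches for later
        # block transfers; it is None when the canonical layout was not
        # selected for this model.
        self._canonical_kv_caches = initialization_data.canonical_kv_caches

Because WorkerConnectorInitializationData is a dataclass with optional fields, connectors written against today's signature keep working when new fields are added later.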
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index a0ba47f945a7..094130489c4d 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -37,6 +37,7 @@
 from vllm.distributed.eplb.eplb_state import EplbState
 from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_group
 from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks
+from vllm.distributed.kv_transfer.kv_connector.v1.base import CanonicalKVCaches
 from vllm.distributed.parallel_state import (
     get_dcp_group,
     get_pp_group,
@@ -498,6 +499,7 @@ def __init__(
         # Initialize in initialize_kv_cache_tensors
         self.cross_layers_kv_cache: torch.Tensor | None = None
         self.cross_layers_attn_backend: type[AttentionBackend] | None = None
+        self.canonical_kv_caches: CanonicalKVCaches | None = None
         # indexes: [kv_cache_group_id][attn_group]
         self.attn_groups: list[list[AttentionGroup]] = []
         # self.kv_cache_config: KVCacheConfig
@@ -6738,7 +6740,15 @@ def initialize_kv_cache_tensors(

         # Try creating KV caches optimized for kv-connector transfers
         cache_dtype = self.cache_config.cache_dtype
-        if self.use_uniform_kv_cache(self.attn_groups, cache_dtype):
+        if self.use_canonical_kv_caches(kv_cache_config, self.attn_groups, cache_dtype):
+            kv_caches, self.canonical_kv_caches = self.allocate_canonical_kv_caches(
+                kv_cache_config,
+                self.attn_groups,
+                cache_dtype,
+                self.device,
+                kernel_block_sizes,
+            )
+        elif self.use_uniform_kv_cache(self.attn_groups, cache_dtype):
             kv_caches, cross_layers_kv_cache, attn_backend = (
                 self.allocate_uniform_kv_caches(
                     kv_cache_config,
@@ -6854,13 +6864,27 @@ def initialize_kv_cache(

         if has_kv_transfer_group() and not is_profiling:
             kv_transfer_group = get_kv_transfer_group()
-            if self.cross_layers_kv_cache is not None:
+            if self.canonical_kv_caches is not None:
+                # canonical path: the KV caches are handed to the connector
+                # via initialize_worker_connector below, not register_kv_caches
+                pass
+            elif self.cross_layers_kv_cache is not None:
                 assert self.cross_layers_attn_backend is not None
                 kv_transfer_group.register_cross_layers_kv_cache(
                     self.cross_layers_kv_cache, self.cross_layers_attn_backend
                 )
             else:
                 kv_transfer_group.register_kv_caches(kv_caches)
+
+            from vllm.distributed.kv_transfer.kv_connector.v1.base import (
+                WorkerConnectorInitializationData,
+            )
+
+            kv_transfer_group.initialize_worker_connector(
+                WorkerConnectorInitializationData(
+                    canonical_kv_caches=self.canonical_kv_caches,
+                )
+            )
             kv_transfer_group.set_host_xfer_buffer_ops(copy_kv_blocks)

     def _get_attention_kv_cache_gid(self) -> int:
diff --git a/vllm/v1/worker/kv_connector_model_runner_mixin.py b/vllm/v1/worker/kv_connector_model_runner_mixin.py
index 4fc1aff94fed..58967a8d0006 100644
--- a/vllm/v1/worker/kv_connector_model_runner_mixin.py
+++ b/vllm/v1/worker/kv_connector_model_runner_mixin.py
@@ -15,10 +15,19 @@
 from vllm.config.cache import CacheDType
 from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_group
 from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
+from vllm.distributed.kv_transfer.kv_connector.v1.base import (
+    CanonicalKVCaches,
+    KVCacheBlockDataRef,
+    KVCacheBlockTensor,
+    supports_hma,
+)
 from vllm.forward_context import get_forward_context, set_forward_context
 from vllm.logger import init_logger
 from vllm.v1.attention.backend import AttentionBackend
-from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig
+from vllm.v1.kv_cache_interface import (
+    AttentionSpec,
+    KVCacheConfig,
+)
 from vllm.v1.outputs import (
     EMPTY_MODEL_RUNNER_OUTPUT,
     KVConnectorOutput,
@@ -281,3 +290,217 @@ def allocate_uniform_kv_caches(
             kv_caches[layer_name] = tensor

         return kv_caches, cross_layers_kv_cache, attn_backend
+
+    @staticmethod
+    def use_canonical_kv_caches(
+        kv_cache_config: KVCacheConfig,
+        attn_groups: list[list[AttentionGroup]],
+        cache_dtype: CacheDType,
+    ) -> bool:
+        """
+        Determines whether a contiguous canonical KV cache should be
+        allocated for HMA (hybrid memory allocator) models.
+
+        A canonical layout allocates a single contiguous buffer where,
+        for a given block number, the KV data for all layers is
+        contiguous. This allows efficient KV transfer of per-block data.
+
+        This layout is only applied when all 5 of the following conditions hold:
+        1. A KV connector is configured and prefers cross-layer blocks.
+        2. The connector supports HMA.
+        3. The model has multiple KV cache groups (HMA).
+        4. All groups use AttentionSpec with uniform page size.
+        5. All backends share the same stride order that places
+           num_blocks first in the physical layout.
+
+        Args:
+            kv_cache_config: The KV cache configuration.
+            attn_groups: The attention groups (indexed [group_id][...]).
+            cache_dtype: The KV cache dtype.
+        Returns:
+            True if we should use contiguous canonical allocation.
+        """
+        if not has_kv_transfer_group():
+            return False
+        if not get_kv_transfer_group().prefer_cross_layer_blocks:
+            return False
+
+        # The connector must support HMA
+        if not supports_hma(get_kv_transfer_group()):
+            return False
+        if len(kv_cache_config.kv_cache_groups) <= 1:
+            return False
+
+        # Currently, all groups must use AttentionSpec with uniform page size
+        # We plan to gradually relax this requirement to support other cases
+        page_sizes: set[int] = set()
+        for group in kv_cache_config.kv_cache_groups:
+            if not isinstance(group.kv_cache_spec, AttentionSpec):
+                return False
+            page_sizes.add(group.kv_cache_spec.page_size_bytes)
+        if len(page_sizes) != 1:
+            return False
+
+        # all kv cache tensors must have the same size so that
+        # they can share a single contiguous buffer
+        tensor_sizes = set(t.size for t in kv_cache_config.kv_cache_tensors)
+        if len(tensor_sizes) != 1:
+            return False
+
+        # all backends must agree on the same stride order
+        common_stride_order: tuple[int, ...] | None = None
+        for groups in attn_groups:
+            for attn_group in groups:
+                attn_backend = attn_group.backend
+                spec = attn_group.kv_cache_spec
+                assert isinstance(spec, AttentionSpec)
+                # Probe with a sentinel num_blocks value (1234) so that its
+                # position in the returned shape can be located below.
+                kv_cache_shape = attn_backend.get_kv_cache_shape(
+                    1234,
+                    spec.block_size,
+                    spec.num_kv_heads,
+                    spec.head_size,
+                    cache_dtype_str=cache_dtype,
+                )
+
+                try:
+                    stride_order = attn_backend.get_kv_cache_stride_order(
+                        include_num_layers_dimension=True
+                    )
+                except (AttributeError, NotImplementedError):
+                    return False
+
+                if len(stride_order) != len(kv_cache_shape) + 1:
+                    return False
+
+                # num_blocks must be the leading physical dimension.
+                # +1 accounts for the prepended group_size dimension.
+                if stride_order[0] != kv_cache_shape.index(1234) + 1:
+                    return False
+
+                if common_stride_order is None:
+                    common_stride_order = stride_order
+                elif stride_order != common_stride_order:
+                    return False
+
+        return common_stride_order is not None
+
+    @staticmethod
+    def allocate_canonical_kv_caches(
+        kv_cache_config: KVCacheConfig,
+        attn_groups: list[list[AttentionGroup]],
+        cache_dtype: CacheDType,
+        device: torch.device,
+        kernel_block_sizes: list[int],
+    ) -> tuple[dict[str, torch.Tensor], CanonicalKVCaches]:
+        """
+        Allocates contiguous KV caches for HMA models where all
+        groups share the same page size.
+
+        Follows the same pattern as allocate_uniform_kv_caches: a single
+        flat buffer reshaped per the backend stride order. The physical
+        layout places num_blocks as the leading dimension, giving
+        per-block cross-layer contiguity.
+
+        This function assumes use_canonical_kv_caches() returned True.
+
+        Args:
+            kv_cache_config: The KV cache config.
+            attn_groups: The attention groups (indexed [group_id][...]).
+            cache_dtype: The KV cache dtype.
+            device: The torch device to allocate on.
+            kernel_block_sizes: The kernel block sizes per KV cache group.
+        Returns:
+            A tuple (kv_caches, canonical_kv_caches) where:
+                kv_caches is a dict mapping layer names to their
+                    corresponding KV cache memory buffers.
+                canonical_kv_caches is the CanonicalKVCaches wrapper
+                    handed to the connector.
+        """
+        first_spec = kv_cache_config.kv_cache_groups[0].kv_cache_spec
+        assert isinstance(first_spec, AttentionSpec)
+
+        tensor_sizes = set(t.size for t in kv_cache_config.kv_cache_tensors)
+        assert len(tensor_sizes) == 1
+        tensor_size = tensor_sizes.pop()
+
+        page_size = first_spec.page_size_bytes
+        num_blocks = kv_cache_config.num_blocks
+        assert tensor_size == page_size * num_blocks
+        group_size = len(kv_cache_config.kv_cache_tensors)
+        total_size = tensor_size * group_size
+
+        logger.info("Allocating canonical KV cache: group_size=%d", group_size)
+
+        # allocate one flat contiguous buffer
+        contiguous_buffer_bytes = torch.zeros(
+            total_size, dtype=torch.int8, device=device
+        )
+
+        # build layer_name -> group_idx mapping
+        layer_to_group_idx: dict[str, int] = {}
+        for gid, group in enumerate(kv_cache_config.kv_cache_groups):
+            for layer_name in group.layer_names:
+                layer_to_group_idx[layer_name] = gid
+
+        kv_caches: dict[str, torch.Tensor] = {}
+        group_data_refs: list[list[KVCacheBlockDataRef]] = [
+            [KVCacheBlockDataRef(tensor_idx=0, page_size_bytes=page_size * group_size)]
+            for _ in kv_cache_config.kv_cache_groups
+        ]
+
+        # Single cross-layers canonical tensor: (num_blocks, page_size_bytes)
+        # as raw int8, where page_size_bytes covers all positions.
+        cross_layer_page_size = page_size * group_size
+        cross_layers_tensor = contiguous_buffer_bytes.view(
+            num_blocks, cross_layer_page_size
+        )
+
+        for i, kv_cache_tensor in enumerate(kv_cache_config.kv_cache_tensors):
+            # Per-layer reshape: compute shape from each layer's own
+            # spec and backend, then view the flat buffer accordingly.
+ for layer_name in kv_cache_tensor.shared_by: + gid = layer_to_group_idx[layer_name] + spec = kv_cache_config.kv_cache_groups[gid].kv_cache_spec + assert isinstance(spec, AttentionSpec) + + attn_backend = attn_groups[gid][0].backend + kernel_block_size = kernel_block_sizes[gid] + num_blocks_per_kv_block = spec.block_size // kernel_block_size + kernel_num_blocks = num_blocks * num_blocks_per_kv_block + + kv_cache_shape = attn_backend.get_kv_cache_shape( + kernel_num_blocks, + kernel_block_size, + spec.num_kv_heads, + spec.head_size, + cache_dtype_str=cache_dtype, + ) + + # prepend a group_size dimension into the shape + full_shape = (group_size,) + kv_cache_shape + + kv_cache_stride_order = attn_backend.get_kv_cache_stride_order( + include_num_layers_dimension=True + ) + assert len(kv_cache_stride_order) == len(full_shape) + + physical_shape = tuple(full_shape[j] for j in kv_cache_stride_order) + inv_order = [ + kv_cache_stride_order.index(j) + for j in range(len(kv_cache_stride_order)) + ] + + typed_buffer = contiguous_buffer_bytes.view(spec.dtype).view( + physical_shape + ) + kv_caches[layer_name] = typed_buffer.permute(*inv_order)[i] + + return kv_caches, CanonicalKVCaches( + tensors=[ + KVCacheBlockTensor( + tensor=cross_layers_tensor, + page_size_bytes=cross_layer_page_size, + ) + ], + group_data_refs=group_data_refs, + )
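As a usage illustration (not part of this diff), here is a minimal sketch of how a connector could read the contiguous per-block payload out of the CanonicalKVCaches built above; the helper name is hypothetical and assumes the single-tensor layout produced by allocate_canonical_kv_caches.

import torch

from vllm.distributed.kv_transfer.kv_connector.v1.base import CanonicalKVCaches


def get_block_bytes(
    canonical: CanonicalKVCaches, group_id: int, block_id: int
) -> torch.Tensor:
    # Each group currently carries a single data reference (see
    # allocate_canonical_kv_caches above).
    ref = canonical.group_data_refs[group_id][0]
    block_tensor = canonical.tensors[ref.tensor_idx]
    # Row `block_id` holds the data of every layer for that block; the
    # un-padded page size recorded in the reference trims off any padding.
    return block_tensor.tensor[block_id, : ref.page_size_bytes]

The returned int8 row is contiguous, which is the point of the canonical layout: per-block data can be transferred without gathering per-layer slices.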