Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions tests/v1/kv_connector/unit/test_kv_cache_layout.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


def test_mla_backend_rejects_cross_layer_kv_cache():
    """MLA backends return identity permutation (layers dim first)
    to signal cross-layer KV cache is unsupported."""
    from vllm.model_executor.layers.attention.mla_attention import (
        MLACommonBackend,
    )

    order_with_layers = MLACommonBackend.get_kv_cache_stride_order(
        include_num_layers_dimension=True
    )
    # Identity permutation: the num_layers axis stays first in the
    # physical layout, signaling no cross-layer allocation support.
    assert order_with_layers == (0, 1, 2, 3)
    assert order_with_layers[0] == 0  # layers dim first => no cross-layer

    order_without_layers = MLACommonBackend.get_kv_cache_stride_order(
        include_num_layers_dimension=False
    )
    assert order_without_layers == (0, 1, 2)


def test_deepseek_v32_indexer_rejects_cross_layer_kv_cache():
    """DeepseekV32Indexer returns identity permutation (layers dim first)
    to signal cross-layer KV cache is unsupported."""
    from vllm.v1.attention.backends.mla.indexer import (
        DeepseekV32IndexerBackend,
    )

    order_with_layers = DeepseekV32IndexerBackend.get_kv_cache_stride_order(
        include_num_layers_dimension=True
    )
    # Identity permutation: the num_layers axis stays first in the
    # physical layout, signaling no cross-layer allocation support.
    assert order_with_layers == (0, 1, 2, 3)
    assert order_with_layers[0] == 0  # layers dim first => no cross-layer

    order_without_layers = DeepseekV32IndexerBackend.get_kv_cache_stride_order(
        include_num_layers_dimension=False
    )
    assert order_without_layers == (0, 1, 2)
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
)
from vllm.forward_context import ForwardContext
from vllm.logger import init_logger
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.v1.attention.backend import AttentionBackend, AttentionMetadata
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.core.kv_cache_utils import BlockHash
Expand Down Expand Up @@ -601,7 +601,9 @@ def _register_handlers(
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
layer_names = list(kv_caches.keys())
layers = get_layers_from_vllm_config(
self.spec.vllm_config, Attention, layer_names
self.spec.vllm_config,
AttentionLayerBase, # type: ignore[type-abstract]
layer_names,
)
attn_backends = {
layer_name: layers[layer_name].get_attn_backend()
Expand Down
10 changes: 6 additions & 4 deletions vllm/model_executor/layers/attention/mla_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -1142,10 +1142,12 @@ def get_kv_cache_shape(
def get_kv_cache_stride_order(
    include_num_layers_dimension: bool = False,
) -> tuple[int, ...]:
    """Return the permutation from the `get_kv_cache_shape` ordering
    to the desired physical memory layout.

    MLA kernels require contiguous per-layer KV cache views, so this
    always yields the identity permutation: when the num_layers
    dimension is included it stays first in the physical layout,
    which signals that cross-layer allocation is unsupported.
    """
    # Identity over however many axes the shape has: 4 with the
    # optional num_layers dimension, 3 without.
    num_dims = 4 if include_num_layers_dimension else 3
    return tuple(range(num_dims))

@classmethod
def get_supported_head_sizes(cls) -> list[int]:
Expand Down
3 changes: 3 additions & 0 deletions vllm/v1/attention/backends/mla/indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ def get_kv_cache_stride_order(
include_num_layers_dimension: bool = False,
) -> tuple[int, ...]:
if include_num_layers_dimension:
# DeepseekV32Indexer kernels do not support cross-layer
# KV cache layout. Identity permutation keeps num_layers
# first, signaling incompatibility.
return (0, 1, 2, 3)
return (0, 1, 2)

Expand Down
9 changes: 7 additions & 2 deletions vllm/v1/worker/kv_connector_model_runner_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,8 +191,13 @@ def use_uniform_kv_cache(
except (AttributeError, NotImplementedError):
return False

# check that attention backend include a layers dimension
return len(kv_cache_stride_order) == len(kv_cache_shape) + 1
# check that attention backend includes a layers dimension
if len(kv_cache_stride_order) != len(kv_cache_shape) + 1:
return False

# stride_order[0] == 0 means num_layers stays first in physical
# layout (identity permutation), so cross-layer is unsupported.
return kv_cache_stride_order[0] != 0

@staticmethod
def allocate_uniform_kv_caches(
Expand Down
Loading