Changes from all commits (17 commits)
26cbb1e  Add CanonicalKVCaches data classes for HMA KV cache representation (EtelisIBM, Mar 23, 2026)
544b689  Add WorkerConnectorInitializationData and initialize_worker_connector (EtelisIBM, Mar 23, 2026)
2268cf4  Add canonical KV cache allocation for HMA models (EtelisIBM, Mar 23, 2026)
7b6771b  Wire up canonical KV cache allocation in gpu_model_runner (EtelisIBM, Mar 23, 2026)
ac68b33  Add unit tests for canonical KV cache allocation (EtelisIBM, Mar 23, 2026)
ccdeea4  Fix mypy error: rename shadowed variable in use_canonical_kv_caches (EtelisIBM, Mar 23, 2026)
77f75b6  Move canonical KV cache dataclasses to connector base (EtelisIBM, Mar 23, 2026)
69d011a  Address CR: relax group count check and use per-group spec (EtelisIBM, Mar 23, 2026)
ab66101  Address CR: prioritize canonical path and scope initialize_worker_con… (EtelisIBM, Mar 23, 2026)
cc3f747  Address CR: merge loops in allocate_canonical_kv_caches (EtelisIBM, Mar 23, 2026)
3cc7485  Address CR: always call initialize_worker_connector (EtelisIBM, Mar 23, 2026)
6f84560  Validate tensor sizes in use_canonical_kv_caches (EtelisIBM, Mar 25, 2026)
0c05217  Refactor canonical KV cache allocation into single-pass loop (EtelisIBM, Mar 30, 2026)
a5282d9  Simplify canonical KV cache allocation using physical buffer (EtelisIBM, Apr 2, 2026)
d257b23  Move per-layer reshape logic into inner loop (EtelisIBM, Apr 6, 2026)
4be3987  Address CR: use single cross-layers int8 tensor for canonical KV caches (EtelisIBM, Apr 12, 2026)
a17dae7  Integrate canonical KV caches with offloading connector (EtelisIBM, Apr 12, 2026)
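
For orientation, here is a minimal sketch of the containers the new tests exercise. The field names (tensors, group_data_refs, tensor_idx, page_size_bytes) are taken from the test assertions below; the class shapes themselves are assumptions for illustration, not the actual definitions in vllm.distributed.kv_transfer.kv_connector.v1.base:

from dataclasses import dataclass

import torch


@dataclass
class CanonicalKVCacheTensor:
    # Hypothetical: one cross-layer int8 buffer of shape
    # (num_blocks, page_size_bytes); each row holds every layer's K/V
    # data for a single block, as the allocation test checks.
    tensor: torch.Tensor
    page_size_bytes: int


@dataclass
class CanonicalGroupDataRef:
    # Hypothetical: maps one layer of one KV cache group back into a
    # canonical tensor by index.
    tensor_idx: int
    page_size_bytes: int


@dataclass
class CanonicalKVCaches:
    tensors: list[CanonicalKVCacheTensor]
    # group_data_refs[group_id] holds one ref per layer in that group.
    group_data_refs: list[list[CanonicalGroupDataRef]]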
tests/v1/kv_connector/unit/test_canonical_kv_caches.py (new file, 303 additions, 0 deletions)
@@ -0,0 +1,303 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Unit tests for CanonicalKVCaches abstraction."""

from unittest.mock import patch

import pytest
import torch

from vllm.distributed.kv_transfer.kv_connector.v1.base import (
    CanonicalKVCaches,
    SupportsHMA,
)
from vllm.v1.kv_cache_interface import (
    FullAttentionSpec,
    KVCacheConfig,
    KVCacheGroupSpec,
    KVCacheTensor,
    MambaSpec,
    SlidingWindowSpec,
)
from vllm.v1.worker.kv_connector_model_runner_mixin import (
    KVConnectorModelRunnerMixin,
)
from vllm.v1.worker.utils import AttentionGroup

# ---------------------------------------------------------------------------
# Mock backends and connectors
# ---------------------------------------------------------------------------

BLOCK_SIZE = 16
NUM_KV_HEADS = 4
HEAD_SIZE = 8
NUM_BLOCKS = 10
DTYPE = torch.float16


class MockFlashAttnBackend:
    """Mimics FlashAttention NHD layout."""

    @staticmethod
    def get_kv_cache_shape(
        num_blocks, block_size, num_kv_heads, head_size, cache_dtype_str="auto"
    ):
        return (2, num_blocks, block_size, num_kv_heads, head_size)

    @staticmethod
    def get_kv_cache_stride_order(include_num_layers_dimension=False):
        if include_num_layers_dimension:
            return (2, 0, 1, 3, 4, 5)
        return (0, 1, 2, 3, 4)


class MockNoStrideOrderBackend:
    """Backend that does not support stride order."""

    @staticmethod
    def get_kv_cache_shape(
        num_blocks, block_size, num_kv_heads, head_size, cache_dtype_str="auto"
    ):
        return (2, num_blocks, block_size, num_kv_heads, head_size)

    @staticmethod
    def get_kv_cache_stride_order(include_num_layers_dimension=False):
        raise NotImplementedError
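
# A backend with no cross-layer stride order makes the canonical layout
# impossible, so use_canonical_kv_caches is expected to return False for it
# (see the "no_stride_order" case below).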


class MockConnector(SupportsHMA):
    prefer_cross_layer_blocks = True

    def request_finished_all_groups(self, request, block_ids):
        return False, None


class MockConnectorNoHMA:
    prefer_cross_layer_blocks = True
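
# MockConnectorNoHMA intentionally does not subclass SupportsHMA: a connector
# without HMA support must cause the canonical path to be rejected (see the
# "no_hma_support" case below).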


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _make_full_attn_spec():
    return FullAttentionSpec(
        block_size=BLOCK_SIZE,
        num_kv_heads=NUM_KV_HEADS,
        head_size=HEAD_SIZE,
        dtype=DTYPE,
    )


def _make_sw_spec(sliding_window=128):
    return SlidingWindowSpec(
        block_size=BLOCK_SIZE,
        num_kv_heads=NUM_KV_HEADS,
        head_size=HEAD_SIZE,
        dtype=DTYPE,
        sliding_window=sliding_window,
    )


def _make_hma_kv_cache_config():
    """HMA config: 3 groups, group_size=2, 2 KVCacheTensors."""
    full_spec = _make_full_attn_spec()
    sw_spec = _make_sw_spec()
    page_size = full_spec.page_size_bytes

    groups = [
        KVCacheGroupSpec(["full.0", "full.1"], full_spec),
        KVCacheGroupSpec(["sw.0", "sw.2"], sw_spec),
        KVCacheGroupSpec(["sw.1", "sw.3"], sw_spec),
    ]
    size = page_size * NUM_BLOCKS
    tensors = [
        KVCacheTensor(size=size, shared_by=["full.0", "sw.0", "sw.1"]),
        KVCacheTensor(size=size, shared_by=["full.1", "sw.2", "sw.3"]),
    ]
    return KVCacheConfig(
        num_blocks=NUM_BLOCKS,
        kv_cache_tensors=tensors,
        kv_cache_groups=groups,
    )
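
# Note: the full-attention and sliding-window specs above have identical
# page_size_bytes (same block_size, num_kv_heads, head_size, and dtype), so
# each KVCacheTensor holds one layer's worth of pages and is shared by one
# layer from each of the three groups; the HMA groups time-share that memory.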


def _make_attn_groups(backend_cls, kv_cache_config):
    attn_groups = []
    for gid, group in enumerate(kv_cache_config.kv_cache_groups):
        attn_groups.append(
            [
                AttentionGroup(
                    backend=backend_cls,
                    layer_names=group.layer_names,
                    kv_cache_spec=group.kv_cache_spec,
                    kv_cache_group_id=gid,
                )
            ]
        )
    return attn_groups


def _patch_connector(connector):
    return (
        patch(
            "vllm.v1.worker.kv_connector_model_runner_mixin.has_kv_transfer_group",
            return_value=True,
        ),
        patch(
            "vllm.v1.worker.kv_connector_model_runner_mixin.get_kv_transfer_group",
            return_value=connector,
        ),
    )


def _use_canonical(config, attn_groups):
    return KVConnectorModelRunnerMixin.use_canonical_kv_caches(
        config, attn_groups, "auto"
    )


# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------


@pytest.mark.cpu_test
def test_use_canonical_kv_caches_happy_path():
    """Should return True for a valid HMA model with compatible connector."""
    config = _make_hma_kv_cache_config()
    attn_groups = _make_attn_groups(MockFlashAttnBackend, config)
    p1, p2 = _patch_connector(MockConnector())
    with p1, p2:
        assert _use_canonical(config, attn_groups) is True


@pytest.mark.cpu_test
@pytest.mark.parametrize(
    "description,config_fn,backend,connector_fn,patch_no_connector",
    [
        (
            "single_group",
            lambda: KVCacheConfig(
                num_blocks=NUM_BLOCKS,
                kv_cache_tensors=[
                    KVCacheTensor(
                        size=_make_full_attn_spec().page_size_bytes * NUM_BLOCKS,
                        shared_by=["layer0"],
                    )
                ],
                kv_cache_groups=[KVCacheGroupSpec(["layer0"], _make_full_attn_spec())],
            ),
            MockFlashAttnBackend,
            MockConnector,
            False,
        ),
        (
            "no_connector",
            _make_hma_kv_cache_config,
            MockFlashAttnBackend,
            None,
            True,
        ),
        (
            "no_hma_support",
            _make_hma_kv_cache_config,
            MockFlashAttnBackend,
            MockConnectorNoHMA,
            False,
        ),
        (
            "mamba_group",
            lambda: KVCacheConfig(
                num_blocks=NUM_BLOCKS,
                kv_cache_tensors=[],
                kv_cache_groups=[
                    KVCacheGroupSpec(["attn.0"], _make_full_attn_spec()),
                    KVCacheGroupSpec(
                        ["mamba.0"],
                        MambaSpec(
                            block_size=BLOCK_SIZE,
                            shapes=((16,), (16,)),
                            dtypes=(DTYPE,),
                        ),
                    ),
                ],
            ),
            MockFlashAttnBackend,
            MockConnector,
            False,
        ),
        (
            "no_stride_order",
            _make_hma_kv_cache_config,
            MockNoStrideOrderBackend,
            MockConnector,
            False,
        ),
    ],
    ids=lambda x: x if isinstance(x, str) else "",
)
def test_use_canonical_kv_caches_returns_false(
    description, config_fn, backend, connector_fn, patch_no_connector
):
    """Should return False when any precondition is not met."""
    config = config_fn()
    attn_groups = _make_attn_groups(backend, config)

    if patch_no_connector:
        with patch(
            "vllm.v1.worker.kv_connector_model_runner_mixin.has_kv_transfer_group",
            return_value=False,
        ):
            assert _use_canonical(config, attn_groups) is False
    else:
        p1, p2 = _patch_connector(connector_fn())
        with p1, p2:
            assert _use_canonical(config, attn_groups) is False


@pytest.mark.cpu_test
def test_allocate_canonical_kv_caches():
    """Allocation should produce correct kv_caches dict and
    CanonicalKVCaches with contiguous per-block data."""
    config = _make_hma_kv_cache_config()
    attn_groups = _make_attn_groups(MockFlashAttnBackend, config)

    kv_caches, canonical = KVConnectorModelRunnerMixin.allocate_canonical_kv_caches(
        config, attn_groups, "auto", torch.device("cpu"), [BLOCK_SIZE]
    )

    assert isinstance(canonical, CanonicalKVCaches)

    # -- kv_caches dict: all 6 layers present with correct shapes
    expected_shape = (2, NUM_BLOCKS, BLOCK_SIZE, NUM_KV_HEADS, HEAD_SIZE)
    assert len(kv_caches) == 6
    for name in ["full.0", "full.1", "sw.0", "sw.1", "sw.2", "sw.3"]:
        assert kv_caches[name].shape == expected_shape

    # layers sharing a position point to the same memory
    assert kv_caches["full.0"].data_ptr() == kv_caches["sw.0"].data_ptr()
    assert kv_caches["full.1"].data_ptr() == kv_caches["sw.2"].data_ptr()

    # -- single cross-layers tensor: (num_blocks, cross_layer_page_size) int8
    assert len(canonical.tensors) == 1
    bt = canonical.tensors[0]
    per_position_page = 2 * BLOCK_SIZE * NUM_KV_HEADS * HEAD_SIZE * DTYPE.itemsize
    group_size = len(config.kv_cache_tensors)
    cross_layer_page = per_position_page * group_size
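    # With the constants above: per_position_page = 2 * 16 * 4 * 8 * 2 = 2048
    # bytes and group_size = 2, so cross_layer_page = 4096 and the canonical
    # buffer is expected to be a (10, 4096) int8 tensor.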
    assert bt.tensor.shape == (NUM_BLOCKS, cross_layer_page)
    assert bt.tensor.dtype == torch.int8
    assert bt.page_size_bytes == cross_layer_page

    # each row is contiguous and covers all positions for one block
    assert bt.tensor.is_contiguous()

    # -- group_data_refs: 3 groups, each with 2 layers
    assert len(canonical.group_data_refs) == 3
    full_page = config.kv_cache_groups[0].kv_cache_spec.page_size_bytes
    for refs in canonical.group_data_refs:
        assert len(refs) == 2
        assert [r.tensor_idx for r in refs] == [0, 0]
        for ref in refs:
            assert ref.page_size_bytes == full_page
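
All of the above tests are marked cpu_test and need no GPU, so they should run
with a plain pytest invocation, e.g.
pytest tests/v1/kv_connector/unit/test_canonical_kv_caches.py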