Changes from all commits
33 commits
7dc93ee
[Feature] Enable uniform KV cache allocation for multi-group HMA models
EtelisIBM Feb 11, 2026
7b28e7f
Merge branch 'main' into itay/hma-uniform-kv-cache
Etelis Feb 11, 2026
8fab499
Merge branch 'main' into itay/hma-uniform-kv-cache
Etelis Feb 15, 2026
6c280d3
[Feature] Generalize uniform KV cache allocation for heterogeneous mo…
EtelisIBM Feb 16, 2026
bc8cf4c
[Feature] Use CrossLayerGroup in GPU model runner
EtelisIBM Feb 16, 2026
0af5ee3
[Test] Update uniform KV cache tests for multi-group and Mamba support
EtelisIBM Feb 16, 2026
6340355
Merge branch 'main' into itay/hma-uniform-kv-cache
Etelis Feb 16, 2026
4616657
[Cleanup] Remove dead dtype field from CrossLayerGroup
EtelisIBM Feb 17, 2026
1d538ac
[Feature] Add TP-aware cross-layer KV cache layout
EtelisIBM Feb 19, 2026
1f66cf8
[Feature] Pass TP flag to uniform KV cache allocation
EtelisIBM Feb 19, 2026
c904e6f
[Test] Add TP layout tests for uniform KV cache
EtelisIBM Feb 19, 2026
d69c498
[Feature] Auto-detect cross-layer KV cache layout from backend stride…
EtelisIBM Feb 20, 2026
2a97165
Merge branch 'main' into itay/hma-uniform-kv-cache
Etelis Feb 23, 2026
31e13a1
[Refactor] Remove unused spec and backend attributes from CrossLayerG…
EtelisIBM Feb 23, 2026
e26e7af
[Refactor] Remove dead solo grouping key from cross-layer allocation
EtelisIBM Feb 23, 2026
33fe8ad
Merge branch 'main' into itay/hma-uniform-kv-cache
Etelis Feb 23, 2026
76eeb84
[Feature] Unified multi-group cross-layer KV cache connector API
EtelisIBM Feb 23, 2026
dccb08e
[Bugfix] Fix register_kv_caches override signatures for mypy compliance
EtelisIBM Feb 23, 2026
621745e
Merge branch 'main' into itay/hma-uniform-kv-cache
Etelis Feb 23, 2026
5c41e0d
[Feature] Restore register_cross_layers_kv_cache on connectors
EtelisIBM Feb 24, 2026
b986a58
[Feature] Add hybrid KV cache registration API with dual-path allocation
EtelisIBM Feb 24, 2026
2e5da0b
[Refactor] Remove connector-side register_cross_layers_kv_cache overr…
EtelisIBM Feb 24, 2026
62594d2
[Refactor] Revert connector changes to match main
EtelisIBM Feb 24, 2026
ac609ee
[Bugfix] Guard os.sched_setaffinity with hasattr check
EtelisIBM Feb 24, 2026
3ee1121
[Fix] Make KVCacheTopology.num_layers_dim optional for isolated tensors
EtelisIBM Mar 3, 2026
e0fc044
[Fix] Address review feedback on cross-layer grouping and topology
EtelisIBM Mar 3, 2026
325c041
[Test] Extend uniform KV cache tests for topology and isolation
EtelisIBM Mar 3, 2026
789c162
[Refactor] Improve naming and docstrings in kv_connector_model_runner…
EtelisIBM Mar 8, 2026
8d79bde
[Refactor] Improve register_hybrid_kv_caches docstring in base.py
EtelisIBM Mar 8, 2026
25c82c1
Merge branch 'main' into itay/hma-uniform-kv-cache
Etelis Mar 8, 2026
c4121cd
[Fix] Use base stride order as fallback in legacy uniform allocation
EtelisIBM Mar 8, 2026
321acf8
Merge branch 'main' into itay/hma-uniform-kv-cache
Etelis Mar 8, 2026
e0e33bf
[Feature] Add KVCacheTensorReference/KVCacheDataReference API for con…
EtelisIBM Mar 15, 2026
7 changes: 1 addition & 6 deletions tests/v1/kv_connector/unit/test_offloading_connector.py
@@ -7,7 +7,6 @@
from unittest.mock import MagicMock

import pytest
import torch

from vllm import SamplingParams
from vllm.config import KVTransferConfig, VllmConfig
@@ -20,7 +19,6 @@
)
from vllm.forward_context import ForwardContext
from vllm.utils.hashing import sha256
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
from vllm.v1.core.kv_cache_utils import (
BlockHash,
get_request_block_hasher,
@@ -175,10 +173,7 @@ def __init__(
self.worker_connector = OffloadingConnector(vllm_config, KVConnectorRole.WORKER)

# register worker kv_caches to enable OffloadingWorker creations
self.worker_connector.register_cross_layers_kv_cache(
kv_cache=torch.empty(0),
attn_backend=FlashAttentionBackend,
)
self.worker_connector.register_kv_caches(kv_caches={})

# extract connector of scheduler
scheduler_connector = self.scheduler.connector
332 changes: 332 additions & 0 deletions tests/v1/kv_connector/unit/test_uniform_kv_cache.py
@@ -0,0 +1,332 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

"""Tests for uniform cross-layer KV cache allocation."""

from unittest.mock import MagicMock, patch

import pytest
import torch

from vllm.v1.kv_cache_interface import (
FullAttentionSpec,
KVCacheConfig,
KVCacheTensor,
MambaSpec,
SlidingWindowSpec,
)
from vllm.v1.worker.kv_connector_model_runner_mixin import (
KVConnectorModelRunnerMixin,
)
from vllm.v1.worker.utils import AttentionGroup

pytestmark = pytest.mark.cpu_test

MODULE = "vllm.v1.worker.kv_connector_model_runner_mixin"

BLOCK_SIZE = 16
NUM_KV_HEADS = 4
HEAD_SIZE = 64


class _MockBackend:
"""NHD backend: layers dim sits right after blocks."""

@staticmethod
def get_kv_cache_shape(
num_blocks, block_size, num_kv_heads, head_size, cache_dtype_str=None
):
return (num_blocks, 2, num_kv_heads, block_size, head_size)

@staticmethod
def get_kv_cache_stride_order(include_num_layers_dimension=False):
if include_num_layers_dimension:
# logical_with_layers: (L, B, 2, H, bs, d)
# physical: (B, L, 2, H, bs, d)
return (1, 0, 2, 3, 4, 5)
return (0, 1, 2, 3, 4)


class _MockHNDBackend(_MockBackend):
"""HND backend: heads come before layers in physical order."""

@staticmethod
def get_kv_cache_stride_order(include_num_layers_dimension=False):
if include_num_layers_dimension:
# logical_with_layers: (L, B, 2, H, bs, d)
# physical: (B, H, L, 2, bs, d)
return (1, 3, 0, 2, 4, 5)
return (0, 1, 2, 3, 4)


class _MockBlocksNotFirstBackend(_MockBackend):
@staticmethod
def get_kv_cache_stride_order(include_num_layers_dimension=False):
if include_num_layers_dimension:
return (3, 1, 0, 2, 4, 5)
return (0, 1, 2, 3, 4)


class _MockMambaBackend:
@staticmethod
def get_kv_cache_shape(*a, **kw):
raise NotImplementedError

@staticmethod
def get_kv_cache_stride_order(*a, **kw):
raise NotImplementedError


def _make_group(
group_id=0,
spec_cls=FullAttentionSpec,
layer_names=None,
num_kv_heads=NUM_KV_HEADS,
backend=_MockBackend,
):
kwargs = dict(
block_size=BLOCK_SIZE,
num_kv_heads=num_kv_heads,
head_size=HEAD_SIZE,
dtype=torch.float16,
)
if spec_cls is SlidingWindowSpec:
kwargs["sliding_window"] = 128
return [
AttentionGroup(
backend=backend,
layer_names=layer_names or [],
kv_cache_spec=spec_cls(**kwargs),
kv_cache_group_id=group_id,
)
]


def _allocate(
num_blocks,
num_layers,
backend=_MockBackend,
prefix="l",
kernel_block_sizes=None,
attn_groups=None,
):
"""Shorthand for allocate_hybrid_kv_caches with FullAttentionSpec."""
spec = FullAttentionSpec(
block_size=BLOCK_SIZE,
num_kv_heads=NUM_KV_HEADS,
head_size=HEAD_SIZE,
dtype=torch.float16,
)
names = [f"{prefix}.{i}" for i in range(num_layers)]
if attn_groups is None:
attn_groups = [_make_group(group_id=0, layer_names=names, backend=backend)]
if kernel_block_sizes is None:
kernel_block_sizes = [BLOCK_SIZE] * len(attn_groups)
return KVConnectorModelRunnerMixin.allocate_hybrid_kv_caches(
kv_cache_config=KVCacheConfig(
num_blocks=num_blocks,
kv_cache_tensors=[
KVCacheTensor(
size=spec.page_size_bytes * num_blocks,
shared_by=[n],
)
for n in names
],
kv_cache_groups=[],
),
attn_groups=attn_groups,
cache_dtype="auto",
device=torch.device("cpu"),
kernel_block_sizes=kernel_block_sizes,
)


def _use_uniform(attn_groups):
mock = MagicMock()
mock.prefer_cross_layer_blocks = True
with (
patch(f"{MODULE}.has_kv_transfer_group", return_value=True),
patch(f"{MODULE}.get_kv_transfer_group", return_value=mock),
):
return KVConnectorModelRunnerMixin.use_uniform_kv_cache(attn_groups, "auto")


def test_multi_group_compatible():
"""Two groups (full + sliding window) with same shape are compatible."""
assert _use_uniform(
[
_make_group(group_id=0, spec_cls=FullAttentionSpec),
_make_group(group_id=1, spec_cls=SlidingWindowSpec),
]
)


def test_different_page_sizes_accepted():
"""Groups with different page_size_bytes are accepted (separate groups)."""
assert _use_uniform(
[
_make_group(num_kv_heads=4, group_id=0),
_make_group(num_kv_heads=8, group_id=1),
]
)


def test_allocate_multi_group_shared_tensors():
"""Allocation shares memory across groups at each position."""
num_positions = 4
spec = FullAttentionSpec(
block_size=BLOCK_SIZE,
num_kv_heads=NUM_KV_HEADS,
head_size=HEAD_SIZE,
dtype=torch.float16,
)

kv_cache_config = KVCacheConfig(
num_blocks=4,
kv_cache_tensors=[
KVCacheTensor(
size=spec.page_size_bytes * 4, shared_by=[f"full.{i}", f"sw.{i}"]
)
for i in range(num_positions)
],
kv_cache_groups=[],
)

kv_caches, cross_layer_groups = (
KVConnectorModelRunnerMixin.allocate_hybrid_kv_caches(
kv_cache_config=kv_cache_config,
attn_groups=[
_make_group(group_id=0, layer_names=[f"full.{i}" for i in range(4)]),
_make_group(
group_id=1,
spec_cls=SlidingWindowSpec,
layer_names=[f"sw.{i}" for i in range(4)],
),
],
cache_dtype="auto",
device=torch.device("cpu"),
kernel_block_sizes=[BLOCK_SIZE, BLOCK_SIZE],
)
)

assert len(kv_caches) == 8
assert len(cross_layer_groups) == 1
# NHD backend -- default layout (blocks, layers, page_size)
assert cross_layer_groups[0].tensor.ndim == 3
for i in range(num_positions):
assert kv_caches[f"full.{i}"].data_ptr() == kv_caches[f"sw.{i}"].data_ptr()


def test_mamba_allocation():
"""Mamba layers produce list[Tensor] views with data isolation."""
spec = MambaSpec(
block_size=BLOCK_SIZE,
shapes=((4, 2), (8,)),
dtypes=(torch.float32, torch.float32),
)
nb = 2

kv, groups = KVConnectorModelRunnerMixin.allocate_hybrid_kv_caches(
kv_cache_config=KVCacheConfig(
num_blocks=nb,
kv_cache_tensors=[
KVCacheTensor(size=spec.page_size_bytes * nb, shared_by=[f"m.{i}"])
for i in range(2)
],
kv_cache_groups=[],
),
attn_groups=[
[
AttentionGroup(
backend=_MockMambaBackend,
layer_names=["m.0", "m.1"],
kv_cache_spec=spec,
kv_cache_group_id=0,
)
]
],
cache_dtype="auto",
device=torch.device("cpu"),
kernel_block_sizes=[BLOCK_SIZE],
)

assert len(groups) == 1
# Mamba -- default layout (blocks, layers, page_size)
assert groups[0].tensor.ndim == 3
for n in ["m.0", "m.1"]:
assert isinstance(kv[n], list) and len(kv[n]) == 2
assert kv[n][0].shape == (nb, 4, 2)
assert kv[n][1].shape == (nb, 8)

# Data isolation: writing to one layer shouldn't affect the other
kv["m.0"][0][0].fill_(42.0)
kv["m.1"][0][1].fill_(99.0)
assert torch.all(kv["m.0"][0][0] == 42.0)
assert torch.all(kv["m.1"][0][1] == 99.0)
assert torch.all(kv["m.1"][0][0] == 0.0)
assert torch.all(kv["m.0"][0][1] == 0.0)


def test_hnd_backend_extracts_heads():
"""HND backend: heads before layers in physical order."""
nb, num_layers = 4, 3
kv_caches, groups = _allocate(
nb, num_layers, backend=_MockHNDBackend, prefix="layer"
)

assert len(groups) == 1
group = groups[0]

# HND backend -- heads-first layout (blocks, heads, layers, per_head_page)
per_head_page = group.page_size_bytes // NUM_KV_HEADS
assert group.tensor.shape == (nb, NUM_KV_HEADS, num_layers, per_head_page)

expected = (nb, 2, NUM_KV_HEADS, BLOCK_SIZE, HEAD_SIZE)
for i in range(num_layers):
assert kv_caches[f"layer.{i}"].shape == expected

topo = group.topologies[0]
assert topo.num_blocks_dim == 0
assert topo.num_layers_dim == 2
assert topo.num_heads_dim == 1


def test_hnd_head_contiguity():
"""One block + one head across all layers is contiguous in HND layout."""
_, groups = _allocate(4, 2, backend=_MockHNDBackend)

group = groups[0]
block_head = group.tensor[0, 0] # (layers, per_head_page)
assert block_head.is_contiguous()
# H varies slower than layers
assert group.tensor.stride(1) > group.tensor.stride(2)


def test_nhd_backend_uses_default_layout():
"""NHD backend places layers right after blocks -- default layout."""
nb, num_layers = 4, 2
kv_caches, groups = _allocate(nb, num_layers, backend=_MockBackend)

assert len(groups) == 1
group = groups[0]
# NHD backend -- default layout (blocks, layers, page_size)
assert group.tensor.shape == (nb, num_layers, group.page_size_bytes)

expected = (nb, 2, NUM_KV_HEADS, BLOCK_SIZE, HEAD_SIZE)
for i in range(num_layers):
assert kv_caches[f"l.{i}"].shape == expected


def test_blocks_not_first_is_isolated():
nb, num_layers = 4, 2
_, groups = _allocate(
nb,
num_layers,
backend=_MockBlocksNotFirstBackend,
prefix="nf",
)

assert len(groups) == num_layers
for group in groups:
topo = group.topologies[0]
assert topo.num_layers_dim is None
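
The topology assertions in the tests above touch only three fields (num_blocks_dim, num_layers_dim, num_heads_dim). Read on their own, the record they describe is roughly the following; this is a sketch inferred from the test assertions, not the actual KVCacheTopology definition.

from dataclasses import dataclass

@dataclass
class SketchTopology:  # hypothetical stand-in for KVCacheTopology
    # Physical dimensions of the shared cross-layer tensor that hold blocks,
    # layers and heads; num_layers_dim is None for isolated per-layer tensors
    # (see test_blocks_not_first_is_isolated).
    num_blocks_dim: int
    num_layers_dim: int | None = None
    num_heads_dim: int | None = None

# HND case from test_hnd_backend_extracts_heads: (blocks, heads, layers, page).
hnd = SketchTopology(num_blocks_dim=0, num_layers_dim=2, num_heads_dim=1)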
2 changes: 1 addition & 1 deletion vllm/distributed/kv_transfer/kv_connector/utils.py
@@ -360,7 +360,7 @@ def __post_init__(self):
# In the default non-cross layers layout the block_size position
# is logical while in the cross layers case it is the physical
# position. This matches the shape of the actual kv cache tensors
# passed at register_kv_caches()/register_cross_layers_kv_cache()
# passed at register_kv_caches()
block_size_position = kv_cache_shape.index(_MOCK_BLOCK_SIZE)

assert block_size_position is not None
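
A simplified sketch of the probe used in this hunk: build the backend's KV cache shape with a sentinel block size and locate the sentinel to learn which dimension carries block_size. Helper and argument names below are illustrative, not the actual utils.py code.

SENTINEL_BLOCK_SIZE = 1337  # unlikely to collide with any other dimension

def find_block_size_dim(get_kv_cache_shape):
    shape = get_kv_cache_shape(
        num_blocks=1, block_size=SENTINEL_BLOCK_SIZE, num_kv_heads=2, head_size=8)
    return shape.index(SENTINEL_BLOCK_SIZE)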