Canonical KV Cache Allocation for HMA Models #37885

Open
Etelis wants to merge 28 commits into vllm-project:main from Etelis:canonical-kv-caches
Conversation

@Etelis (Contributor) commented Mar 23, 2026

This is the first phase of a multi-phase effort to enable contiguous KV cache allocation for all model architectures. Currently, only single-group (uniform) models benefit from contiguous cross-layer blocks. This PR extends that to HMA models with uniform page sizes. Future phases will broaden support to models with varying page sizes and additional architectures.

The existing allocate_uniform_kv_caches path only supports single-group models (all layers identical). HMA models like Gemma 3 have multiple KV cache groups (full attention + sliding window) with different eviction policies but the same page size. Previously, these models fell back to per-layer allocation, which scatters block data across non-contiguous memory regions, making RDMA transfers inefficient.

This PR extends contiguous KV cache allocation to HMA models where all KV cache groups share the same page size.
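
To make the intent concrete, here is a minimal sketch of the difference between per-layer and contiguous cross-layer allocation. This is illustrative only, not the PR's actual code; all sizes and names are assumptions.

```python
import torch

# Minimal sketch with hypothetical sizes; not the PR's actual code.
num_blocks, num_groups, page_size_bytes = 1024, 2, 4096

# Per-layer allocation: each group gets its own tensor, so block i of
# different groups lands in unrelated memory regions (bad for RDMA).
scattered = [torch.zeros(num_blocks, page_size_bytes, dtype=torch.int8)
             for _ in range(num_groups)]

# Contiguous cross-layer allocation: one buffer with num_blocks leading,
# so all data for block i is physically adjacent across groups.
contiguous = torch.zeros(num_blocks, num_groups, page_size_bytes,
                         dtype=torch.int8)
group_views = [contiguous.select(1, g) for g in range(num_groups)]
assert group_views[0].data_ptr() == contiguous.data_ptr()  # shared storage
```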

Test plan

  • Unit tests: pytest tests/v1/kv_connector/unit/test_canonical_kv_caches.py -v -s
    • Happy path for use_canonical_kv_caches
    • Parametrized rejection cases (single group, no connector, no HMA, mamba layers, no stride order)
    • Allocation correctness: shapes, memory sharing, physical contiguity, group refs, page sizes

Related PRs

  • #34373 — Original KVCacheTopology PR (closed, too complex).
  • #37339 — WorkerConnectorInitializationData pattern; we adopt its interface design. This PR should hopefully be merged after that one.

@mergify (bot) commented Mar 23, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @Etelis.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@Etelis Etelis force-pushed the canonical-kv-caches branch from aa5d414 to 9697d1c on March 23, 2026 at 11:58
@mergify mergify Bot removed the needs-rebase label Mar 23, 2026
@gemini-code-assist (bot) left a comment

Code Review

This pull request introduces canonical KV cache allocation for hybrid memory allocator (HMA) models, specifically targeting those with uniform page sizes. This is a significant improvement, as it enables contiguous cross-layer block allocation, which was previously limited to single-group models. The changes involve new data structures to represent canonical KV caches and their references, along with modifications to the KV cache allocation logic within the gpu_model_runner and kv_connector_model_runner_mixin. A comprehensive unit test suite has been added to validate the new allocation strategy under various conditions, including happy paths and rejection cases. The implementation appears well-considered and robust, addressing the stated goal of improving RDMA transfer efficiency by ensuring memory contiguity for HMA models.

@mergify (bot) commented Mar 23, 2026

Hi @Etelis, the pre-commit checks have failed. Please run:

uv pip install 'pre-commit>=4.5.1'
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

# The connector must support HMA
if not supports_hma(get_kv_transfer_group()):
    return False
if len(kv_cache_config.kv_cache_groups) <= 1:
Collaborator

Suggested change
if len(kv_cache_config.kv_cache_groups) <= 1:
if len(kv_cache_config.kv_cache_groups) < 1:

Contributor Author

Done, makes sense.

if len(kv_cache_config.kv_cache_groups) <= 1:
    return False

# All groups must use AttentionSpec with uniform page size
Collaborator

Suggested change
# All groups must use AttentionSpec with uniform page size
# Currently, all groups must use AttentionSpec with uniform page size
# We plan to gradually relax this requirement to support other cases

Contributor Author

Thanks, sorry.

Comment on lines +352 to +353
spec = kv_cache_config.kv_cache_groups[0].kv_cache_spec
assert isinstance(spec, AttentionSpec)
Collaborator

Can we remove this and use the spec inside the loop per each group?

Contributor Author

sure.

Comment thread vllm/v1/worker/gpu_model_runner.py Outdated
)
self.cross_layers_kv_cache = cross_layers_kv_cache
self.cross_layers_attn_backend = attn_backend
elif self.use_canonical_kv_caches(
Collaborator

Let's move this check before checking use_uniform_kv_cache.

Contributor Author

done.

Comment thread vllm/v1/worker/gpu_model_runner.py
kernel_num_blocks = num_blocks * num_blocks_per_kv_block

# prepend a group_size dimension into the shape
kv_cache_shape = attn_backend.get_kv_cache_shape(
Collaborator

Can we move this logic AFTER we allocate the single tensor?

Then, inside the layer loop, we can reshape?
I think we can also remove assert len(unique_kernel_bs) == 1.

I think it's better to also build the group_data_refs inside the same loop.
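
A rough sketch of the suggested single-pass structure (all names and sizes here are illustrative, not the PR's actual code): allocate the flat tensor once, then derive each layer's shaped view and its group's data ref inside the same loop.

```python
import torch

# Illustrative single-pass sketch; names and shapes are assumptions.
num_blocks, page_elems = 8, 16
layer_to_group_idx = {"layers.0.attn": 0, "layers.1.attn": 1}
layer_elems = num_blocks * page_elems

flat = torch.zeros(layer_elems * len(layer_to_group_idx), dtype=torch.int8)

kv_caches: dict[str, torch.Tensor] = {}
group_data_refs: dict[int, dict] = {}
for pos, (layer_name, gid) in enumerate(layer_to_group_idx.items()):
    chunk = flat[pos * layer_elems:(pos + 1) * layer_elems]
    kv_caches[layer_name] = chunk.view(num_blocks, page_elems)  # reshape here
    group_data_refs[gid] = {"tensor_idx": 0,                    # same loop
                            "page_size_bytes": page_elems}
```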

Comment thread vllm/v1/kv_cache_interface.py Outdated
@property
def needs_kv_cache_zeroing(self) -> bool:
    return self.has_mamba_layers

Collaborator

These classes are currently specific to connector usage.
I think we should move them to base.py.

Contributor Author

done

…nector

Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
@Etelis Etelis force-pushed the canonical-kv-caches branch from db277a8 to 5b2b3bc on March 23, 2026 at 13:54
Comment thread vllm/v1/worker/gpu_model_runner.py Outdated
WorkerConnectorInitializationData,
)

kv_transfer_group.initialize_worker_connector(
Collaborator

Actually, initialize_worker_connector is needed for the CacheBlend use-case.
Let's try to call it exactly as in #37339.
But keep this if here and simply pass, commenting that the canonical kv caches will be registered below.

Contributor Author

I thought they'd add it themselves afterwards; nvm, I'll fix it.

Combine the kv_caches population, block tensor splitting, and
layer-to-position mapping into a single pass over positions.
Remove the unique kernel block size assertion.

Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
Call initialize_worker_connector unconditionally so connectors like
CacheBlend can use it regardless of the allocation path taken.

Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
@Etelis Etelis force-pushed the canonical-kv-caches branch from 30cca9f to 432d002 on March 23, 2026 at 17:27
canonical_kv_caches is the CanonicalKVCaches wrapping
for the connector.
"""
# all tensors have the same size (validated by use_canonical_kv_caches)
Collaborator

Where did we validate this?

Contributor Author

@orozery sharp eye

fixed.

Move the uniform tensor size check into use_canonical_kv_caches
so the precondition is validated before entering the allocation
path, keeping the assert in allocate_canonical_kv_caches as a
safety net.

Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
@Etelis Etelis requested a review from orozery March 26, 2026 14:26
Use contiguous_buffer.select(group_dim, i) to obtain per-position
canonical block tensors where num_blocks is always the leading
dimension. This eliminates the block_dim splitting loop and
multi-dimensional index arithmetic.

Also strengthen use_canonical_kv_caches to explicitly verify
num_blocks is the leading physical dimension, and restore the
single-group rejection (< 2) so single-group models correctly
use the uniform path.

Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
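
For reference, a small sketch of what "num_blocks is the leading physical dimension" means, mirroring the stride-order check quoted below. All shapes here are made up for illustration.

```python
# Hedged sketch with made-up shapes: a logical shape plus a backend stride
# order determine the physical layout; the canonical path requires that
# num_blocks comes first physically, so each block's data across groups
# and layers is one contiguous slab.
kv_cache_shape = (2, 1024, 16, 8, 128)   # (groups, blocks, block_sz, heads, hd)
kv_cache_stride_order = (1, 0, 2, 3, 4)  # blocks first in physical memory

physical_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order)
assert physical_shape[0] == kv_cache_shape[1]  # num_blocks leads physically
```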
Comment on lines +442 to +467
kernel_block_size = kernel_block_sizes[0]
num_blocks_per_kv_block = kv_cache_spec.block_size // kernel_block_size
kernel_num_blocks = num_blocks * num_blocks_per_kv_block

attn_backend = attn_groups[0][0].backend
kv_cache_shape = attn_backend.get_kv_cache_shape(
    kernel_num_blocks,
    kernel_block_size,
    kv_cache_spec.num_kv_heads,
    kv_cache_spec.head_size,
    cache_dtype_str=cache_dtype,
)

# prepend a group_size dimension into the shape
kv_cache_shape = (group_size,) + kv_cache_shape

try:
    kv_cache_stride_order = attn_backend.get_kv_cache_stride_order(
        include_num_layers_dimension=True
    )
    assert len(kv_cache_stride_order) == len(kv_cache_shape)
except (AttributeError, NotImplementedError):
    kv_cache_stride_order = tuple(range(len(kv_cache_shape)))

physical_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order)
assert physical_shape[0] == kernel_num_blocks
Collaborator

Can we move this logic inside the per-layer loop that sets kv_caches?

Contributor Author

Done.

@orozery (Collaborator) left a comment

Thanks @Etelis !
Can you please test this PR on top of this branch?
https://github.com/orozery/vllm/tree/kv-offload-hma
Specifically, verify that test_cpu_offloading.py passes, and check whether we see performance gains.

[] for _ in kv_cache_config.kv_cache_groups
]

kernel_block_size = kernel_block_sizes[0]
Collaborator

Can we initialize kernel_block_size = kernel_block_sizes[gid] inside the loop?

Contributor Author

Done. I hadn't considered that backends could have different kernel block sizes.

block_tensor = typed_buffer.select(group_dim, i)
tensor_idx = len(block_tensors)
page_bytes = block_tensor[0].numel() * block_tensor.element_size()
block_tensors.append(
Collaborator

Aren't we expecting a single cross-layers tensor? With shape (num_blocks, page_size) and dtype int8?

Contributor Author

Yeah, that's dumb.
Fixed.

Etelis and others added 3 commits April 12, 2026 15:56
Replace per-position KVCacheBlockTensor objects with a single
(num_blocks, cross_layer_page_size) int8 tensor. This avoids
recomputing block tensors per position and matches the pattern
used by the offloading connector's register_cross_layers_kv_cache.

Also use per-group kernel_block_sizes[gid] inside the loop instead
of hardcoded kernel_block_sizes[0].

Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
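
The single-tensor form might look roughly like this (a sketch under assumed shapes; `buf` and all sizes are hypothetical, not the PR's actual code):

```python
import torch

# Hedged sketch: expose the whole cross-layer allocation to the connector
# as one (num_blocks, cross_layer_page_size) int8 tensor. Sizes are made up.
num_blocks, group_size, heads, head_size = 1024, 2, 8, 128
buf = torch.zeros(num_blocks, group_size, heads, head_size,
                  dtype=torch.bfloat16)

cross_layer_page_size = buf[0].numel() * buf.element_size()  # bytes per block
int8_view = buf.view(torch.int8).reshape(num_blocks, cross_layer_page_size)
assert int8_view.data_ptr() == buf.data_ptr()  # a view, not a copy
```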
@Etelis (Contributor Author) commented Apr 12, 2026

Thanks @Etelis ! Can you please test this PR on top of this branch? https://github.com/orozery/vllm/tree/kv-offload-hma Specifically, verify test_cpu_offloading.py passes, and whether we see performance gains.

Running on your branch, I hit some issues with the connector not implementing initialize_worker_connector.
I have fixed that here:
orozery#1

Ran on top of that branch with an A100
Gemma 3 (HMA)

| Metric | Baseline (per-layer allocation) | With Canonical Allocation | Improvement |
|---|---|---|---|
| Cold start | 56.12 ms | 51.60 ms | -8.1% |
| GPU hit | 12.65 ms | 12.42 ms | -1.8% |
| CPU hit | 20.68 ms | 18.75 ms | -9.3% |

I'm running other models as well, so I'll update soon.

Comment on lines +500 to +507
for layer_name in kv_cache_tensor.shared_by:
    layer_gid = layer_to_group_idx[layer_name]
    group_data_refs[layer_gid].append(
        KVCacheBlockDataRef(
            tensor_idx=0,
            page_size_bytes=page_size,
        )
    )
Collaborator

We should have a single data reference per group.

Contributor Author

Done

Drop the duplicate KVCacheBlockTensor / KVCacheBlockDataRef /
CanonicalKVCaches dataclasses from kv_connector/v1/base.py and import
the existing types from vllm.v1.kv_offload.spec. Emit a single data
reference per group instead of one per layer.

Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
@Etelis Etelis requested a review from xuechendi as a code owner April 19, 2026 15:31
@Etelis Etelis requested a review from orozery April 19, 2026 15:35
Comment thread vllm/v1/worker/kv_connector_model_runner_mixin.py Outdated
…ge size

Restore KVCacheBlockTensor / KVCacheBlockDataRef / CanonicalKVCaches
in kv_connector/v1/base.py (these types are connector-owned) and fix
KVCacheBlockDataRef.page_size_bytes to cover all layers in the group
(page_size * group_size) now that we emit a single ref per group.

Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
@Etelis Etelis requested a review from orozery April 23, 2026 19:10
Comment thread vllm/v1/worker/kv_connector_model_runner_mixin.py Outdated (3 threads)
EtelisIBM and others added 2 commits April 28, 2026 10:48
- use_canonical_kv_caches: < 2 -> < 1 to allow single-group HMA models
- allocate_canonical_kv_caches: read num_blocks from config and assert
- allocate_canonical_kv_caches: drop unreachable try/except around
  get_kv_cache_stride_order (guard already validates it succeeds)

Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
@orozery (Collaborator) left a comment

LGTM. Thanks @Etelis !

@NickLucche @heheda12345 WDYT?

This PR extends the cross-layers layout to models with multiple KV cache groups that are all attention (e.g. gpt-oss).
More importantly, it defines a generic API (CanonicalKVCaches) for describing the KV caches (cross-layers or not) to the connector.
It is meant to replace register_cross_layers_kv_cache, which is kept for now for backward compatibility.
This API could support, without any extension, models using mamba or hybrid mamba/attention (we plan that as a follow-up to this PR).
Also, this API can later be extended to include striding information for connectors doing hetero-TP transfers.
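
For readers following along, the API's rough shape as inferred from this thread (a hedged sketch; only the class names and the tensor_idx / page_size_bytes fields appear in the PR discussion above, everything else is illustrative):

```python
from dataclasses import dataclass
import torch

# Hedged sketch inferred from this thread; field layout beyond the names
# mentioned above (tensor_idx, page_size_bytes) is an assumption.
@dataclass
class KVCacheBlockTensor:
    tensor: torch.Tensor  # e.g. (num_blocks, page_size_bytes), dtype int8

@dataclass
class KVCacheBlockDataRef:
    tensor_idx: int        # index into block_tensors for this group's data
    page_size_bytes: int   # bytes per block, covering all layers in the group

@dataclass
class CanonicalKVCaches:
    block_tensors: list[KVCacheBlockTensor]
    group_data_refs: list[KVCacheBlockDataRef]  # single ref per KV cache group
```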
