
[Feature] Enable uniform KV cache allocation for multi-group HMA models#34373

Closed
Etelis wants to merge 33 commits into vllm-project:main from Etelis:itay/hma-uniform-kv-cache

Conversation

Contributor

@Etelis Etelis commented Feb 11, 2026

use_uniform_kv_cache() currently rejects any model with more than one KV cache group, which means hybrid-attention models (alternating full + sliding-window layers) cannot use the contiguous cross-layer layout for efficient KV transfers.

This PR relaxes the single-group gate: instead of requiring exactly one group, we loop over all groups and check that they share the same backend shape and stride order.

Test Plan

Unit tests (tests/v1/kv_connector/unit/test_uniform_kv_cache.py) — 4 tests:

test_uniform_kv_cache.py::test_multi_group_compatible PASSED
test_uniform_kv_cache.py::test_multi_group_incompatible PASSED
test_uniform_kv_cache.py::test_allocate_multi_group_shared_tensors PASSED
test_uniform_kv_cache.py::test_allocate_rejects_mismatched_kernel_block_sizes PASSED

Test Result

E2E on an H100 with google/gemma-2-2b:

Single-group regression (HMA disabled, OffloadingConnector):

INFO  Allocating a cross layer KV cache of shape (7842, 26, 2, 16, 4, 256)
                                                         ^^
                                                    all 26 layers in 1 group
Prompt: 'The capital of France is'
Output: ' a city of contrasts. It is a city of art, culture, and history. It is a'

Multi-group (HMA enabled, SupportsHMA test connector):

INFO  Allocating a cross layer KV cache of shape (15685, 13, 2, 16, 4, 256)
                                                          ^^
                                                     13 layers per group (2 groups)
[HMATestConnector] register_cross_layers_kv_cache:
  tensor shape=torch.Size([15685, 13, 2, 16, 4, 256]), dtype=torch.bfloat16
  backend=FlashAttentionBackend

Prompt: 'The capital of France is'
Output: ' a city of contrasts. It is a city of art, culture, and history. It is a'

Relax the single-group constraint in use_uniform_kv_cache() so that
hybrid-attention models (e.g. Gemma 2 with alternating full + sliding-window
layers) can benefit from the contiguous cross-layer KV cache layout used
for efficient KV transfers.

Instead of requiring exactly one group, loop over all groups and verify
they share the same backend shape and stride order. Also relax the
kernel_block_sizes assertion in allocate_uniform_kv_caches() to accept
multiple groups with the same block size.

Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
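For context, a hypothetical before/after of the kernel_block_sizes relaxation mentioned above (variable names are assumptions; the actual code in allocate_uniform_kv_caches() may differ):

    # Before (assumed form): exactly one KV cache group allowed.
    assert len(kernel_block_sizes) == 1

    # After: multiple groups allowed, as long as they agree on one block size.
    assert len(set(kernel_block_sizes)) == 1, (
        "uniform KV cache requires a single kernel block size across groups"
    )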

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request enables uniform KV cache allocation for models with multiple attention groups, such as Hybrid-Attention Models (HMA). This is achieved by relaxing the single-group constraint and instead checking for compatibility (e.g., same shape and stride order) across all groups. The changes in use_uniform_kv_cache are logical and well-supported by a comprehensive new test suite. I have one suggestion to make the compatibility check more robust by also ensuring all attention groups use the same backend, which is implied by the docstring and the subsequent allocation logic.

Comment on lines +153 to +197
         if not attn_groups:
             return False

-        attn_group = attn_groups[0][0]
-        kv_cache_spec = attn_group.kv_cache_spec
-        if not isinstance(kv_cache_spec, AttentionSpec):
-            return False
-
-        attn_backend = attn_group.backend
-        kv_cache_shape = attn_backend.get_kv_cache_shape(
-            1234,
-            kv_cache_spec.block_size,
-            kv_cache_spec.num_kv_heads,
-            kv_cache_spec.head_size,
-            cache_dtype_str=cache_dtype,
-        )
-
-        try:
-            kv_cache_stride_order = attn_backend.get_kv_cache_stride_order(
-                include_num_layers_dimension=True
-            )
-        except (AttributeError, NotImplementedError):
-            return False
-
-        # check that attention backend includes a layers dimension
-        return len(kv_cache_stride_order) == len(kv_cache_shape) + 1
+        reference_shape = None
+        reference_stride_order = None
+
+        for subgroups in attn_groups:
+            if len(subgroups) != 1:
+                return False
+
+            attn_group = subgroups[0]
+            kv_cache_spec = attn_group.kv_cache_spec
+            if not isinstance(kv_cache_spec, AttentionSpec):
+                return False
+
+            attn_backend = attn_group.backend
+            kv_cache_shape = attn_backend.get_kv_cache_shape(
+                1234,
+                kv_cache_spec.block_size,
+                kv_cache_spec.num_kv_heads,
+                kv_cache_spec.head_size,
+                cache_dtype_str=cache_dtype,
+            )
+
+            try:
+                kv_cache_stride_order = attn_backend.get_kv_cache_stride_order(
+                    include_num_layers_dimension=True
+                )
+            except (AttributeError, NotImplementedError):
+                return False
+
+            # check that attention backend includes a layers dimension
+            if len(kv_cache_stride_order) != len(kv_cache_shape) + 1:
+                return False
+
+            if reference_shape is None:
+                reference_shape = kv_cache_shape
+                reference_stride_order = kv_cache_stride_order
+            elif (
+                kv_cache_shape != reference_shape
+                or kv_cache_stride_order != reference_stride_order
+            ):
+                return False
+
+        return True

Severity: high

The docstring for use_uniform_kv_cache states that for a uniform layout, all KV cache groups must have the same backend. However, the current implementation only checks for compatible kv_cache_shape and kv_cache_stride_order, but not that the attn_backend is the same across all groups.

The subsequent allocate_uniform_kv_caches function uses the backend from the first attention group, which could lead to incorrect behavior or runtime errors if other groups use a different backend.

To prevent this potential issue and align with the documentation, I suggest also checking that all attention groups share the same backend instance.

        if not attn_groups:
            return False

        reference_shape = None
        reference_stride_order = None
        reference_backend = None

        for subgroups in attn_groups:
            if len(subgroups) != 1:
                return False

            attn_group = subgroups[0]
            kv_cache_spec = attn_group.kv_cache_spec
            if not isinstance(kv_cache_spec, AttentionSpec):
                return False

            attn_backend = attn_group.backend
            kv_cache_shape = attn_backend.get_kv_cache_shape(
                1234,
                kv_cache_spec.block_size,
                kv_cache_spec.num_kv_heads,
                kv_cache_spec.head_size,
                cache_dtype_str=cache_dtype,
            )

            try:
                kv_cache_stride_order = attn_backend.get_kv_cache_stride_order(
                    include_num_layers_dimension=True
                )
            except (AttributeError, NotImplementedError):
                return False

            # check that attention backend includes a layers dimension
            if len(kv_cache_stride_order) != len(kv_cache_shape) + 1:
                return False

            if reference_backend is None:
                reference_shape = kv_cache_shape
                reference_stride_order = kv_cache_stride_order
                reference_backend = attn_backend
            elif (
                kv_cache_shape != reference_shape
                or kv_cache_stride_order != reference_stride_order
                or attn_backend is not reference_backend
            ):
                return False

        return True

Collaborator

orozery commented Feb 12, 2026

Looks great for a start! Thanks @Etelis !
I also tested this on my side and it seems to work well.
As expected, KV transfer performance is greatly improved!

IIUC right now you handle the case of multiple groups, but requiring:

  1. A single attention group per KV group.
  2. Consistent KV cache shape and stride order across all layers.

I think we can try and relax the constraints for using cross-layer blocks even further.

Take a look at the current options for defining KV cache tensors:
https://github.com/vllm-project/vllm/blob/8a798be929d62a6467fd079c03c83632f8231b11/vllm/v1/core/kv_cache_utils.py#L1095-L1140

There are 2 cases:

  1. A single group containing all layers, but layers can have different page (hidden) size.
  2. Multiple groups but all layers have the same page size.

For case 1 (which is less common, I think), we can group layers by their page size and create a single tensor for each group of layers with the same page size.
A group of layers having the same page size (and the same number of blocks, which is always true) can always share a tensor, even if they have different KV cache shapes / stride orders.
The only requirement in this case is that the cache shape starts with (num_blocks, ...).
In that case, you can set a physical layout of (num_blocks, num_layers, page_size).
Then, each layer will create a different view of it: (num_blocks, num_layers, rest_of_unique_logical_shape...)
And finally, create a new view that fixates the layer dimension to the layer_idx.
Then, each layer can permute according to its stride_order (setting with_layers_dim=False).

Does that make sense?
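
For readers following along, a minimal runnable sketch of that layout (sizes, dtypes, and the stride order below are illustrative assumptions, not vLLM's actual values):

    import torch

    # Shared physical tensor: (num_blocks, num_layers, page_size) bytes.
    num_blocks, num_layers, page_size = 32, 4, 1024
    physical = torch.zeros(num_blocks, num_layers, page_size, dtype=torch.int8)

    # Fixate the layer dimension, then view the page bytes as this layer's
    # logical shape: 2 (K/V) x block_size 16 x 2 heads x head_size 8 in
    # bfloat16 = 1024 bytes, matching page_size.
    layer_idx = 1
    layer_view = (
        physical[:, layer_idx]          # (num_blocks, page_size) bytes
        .view(torch.bfloat16)           # reinterpret as the KV dtype
        .view(num_blocks, 2, 16, 2, 8)  # rest_of_unique_logical_shape
        .permute(1, 0, 2, 3, 4)         # this layer's stride_order
    )

    # Writes through the view land in the shared tensor: zero copy.
    layer_view.fill_(1.0)
    assert physical[:, layer_idx].any() and not physical[:, 0].any()

Each layer gets its own view of the same backing storage, so a connector can transfer one contiguous region per block regardless of per-layer logical shapes.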

@Etelis Etelis requested a review from NickLucche as a code owner February 16, 2026 15:37
@Etelis Etelis force-pushed the itay/hma-uniform-kv-cache branch from 0de79af to 8fab499 Compare February 16, 2026 16:20
…dels

Group layers by page_size_bytes and allocate one contiguous int8 tensor
per group. Each layer gets its own zero-copy view via view->slice->permute
(attention) or as_strided (mamba). This relaxes the previous constraint
that all layers must share identical shapes and stride orders.

Introduces CrossLayerGroup dataclass to bundle backing tensors with
metadata. Supports AttentionSpec (all subclasses) and MambaSpec layers.

Signed-off-by: Itay Etlis <itayetlis@gmail.com>

Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
Replace cross_layers_kv_cache + cross_layers_attn_backend with a
list[CrossLayerGroup]. Single pure-attention groups use the optimized
register_cross_layers_kv_cache path; otherwise fall back to
register_kv_caches.

Signed-off-by: Itay Etlis <itayetlis@gmail.com>

Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
Update existing tests for new CrossLayerGroup return type. Add Mamba
allocation test with shape verification and data isolation. Replace
incompatible-page-size rejection test with acceptance test (different
page sizes now produce separate groups).

Signed-off-by: Itay Etlis <itayetlis@gmail.com>

Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
Contributor Author

Etelis commented Feb 16, 2026


Thanks @orozery! Went ahead and implemented this.

The main idea is what you described — layers are grouped by page_size_bytes, and each group gets a contiguous int8 tensor of shape (num_blocks, num_layers_in_group, page_size_bytes). Both your cases are covered: different hidden sizes within one group produce separate sub-groups by page size, and multiple groups with the same page size share one tensor.

For per-layer views I'm doing the view → slice → permute pipeline you suggested:

(
    raw.view(dtype)
    .view(kernel_num_blocks, num_layers, kernel_page_elements)
    [:, layer_idx]
    .permute(...)  # each layer uses its own stride order
)

I also extended this to handle MambaSpec layers — those use as_strided with cross-layer-aware strides so the per-block contiguity is preserved for transfers. (hopefully I understood it correctly)

Tested on H100 with a hybrid setup (4 attention + 2 Mamba layers, 2 page_size groups), all passing. No connector changes yet — will handle nixl_connector/utils.py for multi-group registration as a follow-up.
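
To make the Mamba remark concrete, a small sketch of the as_strided idea under assumed names and sizes (not the PR's exact code):

    import torch

    # Shared buffer laid out as (num_blocks, num_layers, page_size) bytes.
    num_blocks, num_layers, page_size = 16, 2, 256
    buffer = torch.zeros(num_blocks * num_layers * page_size, dtype=torch.int8)

    # Per-layer state view: stride a full (num_layers * page_size) between
    # blocks, so each block's bytes stay contiguous for transfers.
    layer_idx = 1
    state_view = buffer.as_strided(
        size=(num_blocks, page_size),
        stride=(num_layers * page_size, 1),
        storage_offset=layer_idx * page_size,  # start at this layer's slot
    )
    state_view.fill_(7)  # zero-copy: writes land in the shared buffer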

Etelis and others added 5 commits February 16, 2026 23:50
The backing tensor is always int8 by construction — storing it as a
field adds no information. Remove from dataclass and test assertion.

Signed-off-by: Itay Etlis <itayetlis@gmail.com>

Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
When TP is active, attention layers are laid out as
(num_blocks, num_kv_heads, num_layers, per_head_page_bytes) so
head-based slicing is contiguous for RDMA transfers.

Unifies sentinel probes into _find_kv_cache_dims, generalizes
_per_layer_permutation for any number of extracted dims, and
_create_attention_layer_view handles both layouts via tp flag.

Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
Contributor Author

Etelis commented Feb 19, 2026

Checked with GPT-OSS-20B, alternating 12 SW + 12 FA (layer_types: [sliding_attention, full_attention, ...]).

All 24 layers share (num_kv_heads=8, head_size=64, block_size=16), so they should merge into a single cross-layer group.

H100 80GB — TP=1 (default layout)

Group 0:
  tp_layout:       False
  tensor shape:    (136756, 12, 32768)
  page_size_bytes: 32768
  num layers:      24  (12 sliding + 12 full, paired into 12 shared tensors)
  spec type:       SlidingWindowSpec
  backend:         FlashAttentionBackend
  num_kv_heads:    8
  head_size:       64
  strides:         (393216, 32768, 1)
  per-layer view:  (2, 136756, 16, 8, 64) bfloat16

Model loaded and generated text.

8x RTX 3090 — TP=2 (TP layout)

With tensor_parallel_size=2 (4 heads per GPU), the grouping key becomes ("tp", 4, 4096).

Group 0:
  tp_layout:       True
  tensor shape:    (66511, 4, 12, 4096)
  page_size_bytes: 16384
  num layers:      24  (12 sliding + 12 full, paired into 12 shared tensors)
  spec type:       SlidingWindowSpec
  backend:         TritonAttentionBackend (had an issue with FLASHINFER on this setup)
  num_kv_heads:    4   (8 total / TP=2)
  head_size:       64
  strides:         (196608, 49152, 4096, 1)
  per-layer view:  (66511, 2, 16, 4, 64) bfloat16

… order

Replace the external tp flag with per-layer backend probing.
Each layer is classified into one of three groups:
- ordered: blocks first, heads before layers (e.g. HND backends)
- default: grouped by page_size (e.g. NHD backends, Mamba)
- solo: blocks not outermost (fallback, one layer per group)

This removes the need for callers to know about TP configuration
and lets the allocation follow the backend's preferred physical
layout exactly.

Signed-off-by: Itay Etlis <itayetlis@gmail.com>

Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
@Etelis Etelis force-pushed the itay/hma-uniform-kv-cache branch from e3027ec to d69c498 Compare February 20, 2026 12:09
@Etelis Etelis force-pushed the itay/hma-uniform-kv-cache branch from 4218db1 to 31e13a1 Compare February 23, 2026 14:57
Contributor

mergify Bot commented Feb 23, 2026

Hi @Etelis, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy or markdownlint failing?
mypy and markdownlint are run differently in CI. If the failure is related to either of these checks, please use the following commands to run them locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
# For markdownlint
pre-commit run --hook-stage manual markdownlint

EtelisIBM and others added 2 commits February 23, 2026 17:42
All backends place blocks at physical position 0 in the with-layers
stride order, making the blocks_phys != 0 guard unreachable.  Remove
the solo key, the blocks_phys check, and the now-unused tensor_idx
parameter from _cross_layer_group_key.

Signed-off-by: Itay Etelis <itay.etelis@ibm.com>

Merge register_kv_caches and register_cross_layers_kv_cache into a
single register_kv_caches method that accepts an optional
cross_layer_groups parameter.  This enables connectors to handle
multiple cross-layer groups for HMA models with heterogeneous
attention types (full + sliding-window) and mixed layer kinds
(attention + Mamba) without falling back to the per-layer path.

Signed-off-by: Itay Etelis <itay.etelis@ibm.com>

EtelisIBM and others added 7 commits February 23, 2026 21:31
Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
Add back register_cross_layers_kv_cache to nixl, offloading, and
multi connectors alongside the unified register_kv_caches API.
This restores the legacy single-tensor registration path for
connectors that set prefer_cross_layer_blocks.

Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
Add KVCacheTopology and register_hybrid_kv_caches to the connector
base class for multi-group hybrid attention models.

Dual-path gating in use_uniform_kv_cache:
- Hybrid path (register_hybrid_kv_caches): multi-group, Attention+Mamba
- Legacy path (prefer_cross_layer_blocks): single-group, AttentionSpec only

Restore allocate_uniform_kv_caches (original single-tensor allocation)
and rename multi-group allocation to allocate_hybrid_kv_caches.

Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
…ides

The base class provides a default no-op; connectors will add their
own overrides independently when they adopt the legacy path.

Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
Restore all connector implementations to their main branch state.
The register_kv_caches base class signature is reverted to accept
only kv_caches dict, matching the connector overrides. Cross-layer
registration now uses register_cross_layers_kv_cache (legacy) or
register_hybrid_kv_caches (new) instead.

Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
os.sched_setaffinity is not available on all platforms (e.g. macOS).
Add a hasattr guard to avoid AttributeError at runtime and a clear
NotImplementedError message.

Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
- Explicitly handle MambaSpec in _cross_layer_group_key; isolate unknown
  spec types instead of grouping them with others
- Return isolated key when blocks dim is not first or layers dim is not
  after blocks in the physical stride order
- Validate tensor size and key agreement for all layers sharing a tensor
  in allocate_hybrid_kv_caches
- Fill num_heads_dim and block_size_dim in KVCacheTopology for ordered
  groups using sentinel-value probing
- Set num_layers_dim=None for isolated (non-shared) tensors
- Remove dead fallback in _create_attention_layer_view (layers reaching
  that function are guaranteed to have a valid stride order)

Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
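
As background, "sentinel-value probing" here means asking the backend for a KV cache shape built from recognizable values and then locating those values in the result; a hypothetical illustration (the real helper in the PR may differ):

    _SENTINEL_BLOCKS = 1234
    _SENTINEL_HEADS = 7

    def find_kv_cache_dims(attn_backend, spec, cache_dtype):
        # Probe the backend with distinctive sentinel values, then find
        # which output dimensions carry them.
        shape = attn_backend.get_kv_cache_shape(
            _SENTINEL_BLOCKS,
            spec.block_size,
            _SENTINEL_HEADS,
            spec.head_size,
            cache_dtype_str=cache_dtype,
        )
        return shape.index(_SENTINEL_BLOCKS), shape.index(_SENTINEL_HEADS)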
- Add _MockBlocksNotFirstBackend to verify layers with blocks not in
  physical dim 0 are isolated (no cross-layer sharing)
- Assert num_blocks_dim, num_layers_dim, num_heads_dim in HND topology
- Add test_blocks_not_first_is_isolated covering the isolated path
- Fix group.spec references to use group.page_size_bytes

Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
@Etelis Etelis requested a review from njhill as a code owner March 3, 2026 06:10
EtelisIBM and others added 5 commits March 8, 2026 19:12
…_mixin

Rename cryptic variable names for readability:
- raw -> buffer, spec -> attn_spec/mamba_spec, el -> element_size
- npkb/knb -> kernel_blocks_per_spec_block/kernel_num_blocks
- rep_name/rep_spec -> representative_name/representative_spec
- gid -> group_id, log_to_phys -> logical_to_physical
- _B/_H -> _SENTINEL_BLOCKS/_SENTINEL_HEADS

Trim docstrings to match vLLM conventions: brief descriptions
for private methods, concise Args/Returns for public methods.

Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
Add Args section and trim to match vLLM docstring conventions.

Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
When get_kv_cache_stride_order(include_num_layers_dimension=True)
is not supported, fall back to prepending the layers dimension
to the base stride order instead of using an identity permutation.

Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
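
A sketch of the assumed fallback shape (not the PR's exact code):

    try:
        stride_order = attn_backend.get_kv_cache_stride_order(
            include_num_layers_dimension=True
        )
    except (AttributeError, NotImplementedError):
        # Prepend the layers dimension: layers become physical dim 0 and the
        # base per-layer order shifts by one, instead of assuming an
        # identity permutation.
        base_order = attn_backend.get_kv_cache_stride_order()
        stride_order = (0, *(dim + 1 for dim in base_order))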
Contributor Author

Etelis commented Mar 8, 2026

Resolved all CRs @orozery.
Hope it looks fine now.

Contributor

mergify Bot commented Mar 11, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @Etelis.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label Mar 11, 2026
…nectors

Replace dimension-index-based topology metadata with explicit byte
offset/length references. Connectors now receive KVCacheTensorReference
(physical tensors with page sizes) and KVCacheDataReference (per-group
chunk layout with unpadded sizes and head strides).

Adds build_kv_cache_references to convert CrossLayerGroups into the
new types at registration time. Handles attention chunks, Mamba
per-state chunks (conv/ssm), byte-level padding, and layer-level
padding from uneven HMA groups.

Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
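
Roughly, the new reference types might look like the following; every field name here is a guess from the commit message, not the PR's actual definitions:

    from dataclasses import dataclass

    import torch

    @dataclass
    class KVCacheTensorReference:
        # A physical backing tensor together with its page size.
        tensor: torch.Tensor
        page_size_bytes: int

    @dataclass
    class KVCacheDataReference:
        # Per-group chunk layout inside a physical tensor, in explicit bytes.
        byte_offset: int  # where this group's chunks start
        byte_length: int  # unpadded chunk size
        head_stride: int  # byte stride between heads, for head-sliced transfers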