
[Hybrid] Map multiple FullAttn layers to a single page#35703

Open
peakcrosser7 wants to merge 40 commits intovllm-project:mainfrom
peakcrosser7:feat/multi_attn2mamba

Conversation

@peakcrosser7
Contributor

@peakcrosser7 peakcrosser7 commented Mar 2, 2026

Purpose

As Mamba models grow larger, their Mamba states are becoming increasingly large, which in turn drives up the block_size. Since input tokens are hashed at the granularity of block_size, an excessively large block_size is detrimental to prefix-caching hit rates.
Take Qwen3.5-397B-A17B-FP8 as an example: when deployed with TP2, the block_size can reach 2096, meaning only one hash is generated per 2096 tokens. Both caching and cache-hit lookups operate at this 2096-token granularity, which is highly unfavorable for prefix cache hits.

This PR reduces the effective block_size by packing multiple FullAttn layers onto a single page.
Under the previous logic, a page (or block) was allocated to either one Mamba layer (for storing states) or one FullAttn layer (for storing KV-Cache).
This PR introduces support for assigning a single page to N FullAttn layers, where each FullAttn layer uses 1/N of the page's space, thereby reducing the block_size to 1/N of its original size (consequently, Mamba states will occupy more blocks).
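For instance, with N = 4 the Qwen3.5 example above would go from one prefix hash per 2096 tokens to one per 524 tokens (2096 / 4), i.e. a 4x finer caching granularity.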

Usage: Add the --attn-pack-size N argument to the engine startup parameters.
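For example, appending --attn-pack-size 4 to an existing vllm serve command packs up to 4 FullAttn layers per page (4 here is only an illustrative value; all other serving flags stay unchanged).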

Test Plan

  1. Added three test cases to test_kv_cache_utils.py: test_merge_attn_layers_into_pack, test_split_attn_layers_from_pack, and test_get_kv_cache_configs_with_mamba.
pytest tests/v1/core/test_kv_cache_utils.py::test_merge_attn_layers_into_pack  tests/v1/core/test_kv_cache_utils.py::test_split_attn_layers_from_pack tests/v1/core/test_kv_cache_utils.py::test_get_kv_cache_configs_with_mamba
  2. Added tests/v1/worker/test_hybrid_kv_cache_layout.py to ensure the Attention KV-Cache layout remains consistent when replacing _update_hybrid_attention_mamba_layout() with _get_hybrid_attention_mamba_layout().
pytest tests/v1/worker/test_hybrid_kv_cache_layout.py
  3. Added test_hybrid_attention_mamba_kv_cache_pack_size in test_gpu_model_runner.py to verify the end-to-end logic of get_kv_cache_configs() and initialize_kv_cache() with the new Attention layer packing enabled.
pytest tests/v1/worker/test_gpu_model_runner.py::test_hybrid_attention_mamba_kv_cache_pack_size
  4. Temporarily added the --attn-pack-size 4 argument to tests/evals/gsm8k/configs/Qwen3-Next-FP8-EP2.yaml for local accuracy validation (the --moe-backend=flashinfer_trtllm option was commented out as I am using H20 GPUs).
pytest -sv tests/evals/gsm8k/test_gsm8k_correctness.py -k Qwen3-Next-FP8-EP2  --config-list-file=configs/models-blackwell.txt

Test Result

root@hhy_develop_vllm_opsrc:~/huanghy/vllm_opsrc# pytest tests/v1/core/test_kv_cache_utils.py::test_merge_attn_layers_into_pack  tests/v1/core/test_kv_cache_utils.py::test_split_attn_layers_from_pack tests/v1/core/test_kv_cache_utils.py::test_get_kv_cache_configs_with_mamba
===================================================================== test session starts =====================================================================
platform linux -- Python 3.11.13, pytest-8.4.2, pluggy-1.6.0
rootdir: /root/huanghy/vllm_opsrc
configfile: pyproject.toml
plugins: hypothesis-6.137.1, anyio-4.11.0
collected 3 items                                                                                                                                             

tests/v1/core/test_kv_cache_utils.py ...                                                                                                                [100%]

====================================================================== warnings summary =======================================================================
<frozen importlib._bootstrap>:241
  <frozen importlib._bootstrap>:241: DeprecationWarning: builtin type SwigPyPacked has no __module__ attribute

<frozen importlib._bootstrap>:241
  <frozen importlib._bootstrap>:241: DeprecationWarning: builtin type SwigPyObject has no __module__ attribute

tests/v1/core/test_kv_cache_utils.py: 14 warnings
  /opt/conda/lib/python3.11/site-packages/torch/jit/_script.py:362: DeprecationWarning: `torch.jit.script_method` is deprecated. Please switch to `torch.compile` or `torch.export`.
    warnings.warn(

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
=============================================================== 3 passed, 16 warnings in 3.24s ================================================================
sys:1: DeprecationWarning: builtin type swigvarlink has no __module__ attribute
root@hhy_develop_vllm_opsrc:~/huanghy/vllm_opsrc# 
root@hhy_develop_vllm_opsrc:~/huanghy/vllm_opsrc# pytest tests/v1/worker/test_hybrid_kv_cache_layout.py
===================================================================== test session starts =====================================================================
platform linux -- Python 3.11.13, pytest-8.4.2, pluggy-1.6.0
rootdir: /root/huanghy/vllm_opsrc
configfile: pyproject.toml
plugins: hypothesis-6.137.1, anyio-4.11.0
collected 84 items                                                                                                                                            

tests/v1/worker/test_hybrid_kv_cache_layout.py .......ssssssssssssss............................ssssssssssssss.....................                     [100%]

====================================================================== warnings summary =======================================================================
<frozen importlib._bootstrap>:241
  <frozen importlib._bootstrap>:241: DeprecationWarning: builtin type SwigPyPacked has no __module__ attribute

<frozen importlib._bootstrap>:241
  <frozen importlib._bootstrap>:241: DeprecationWarning: builtin type SwigPyObject has no __module__ attribute

../../../opt/conda/lib/python3.11/site-packages/torch/jit/_script.py:362: 14 warnings
  /opt/conda/lib/python3.11/site-packages/torch/jit/_script.py:362: DeprecationWarning: `torch.jit.script_method` is deprecated. Please switch to `torch.compile` or `torch.export`.
    warnings.warn(

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
======================================================== 56 passed, 28 skipped, 16 warnings in 15.12s =========================================================
sys:1: DeprecationWarning: builtin type swigvarlink has no __module__ attribute
root@hhy_develop_vllm_opsrc:~/huanghy/vllm_opsrc# pytest tests/v1/worker/test_gpu_model_runner.py::test_hybrid_attention_mamba_kv_cache_pack_size
===================================================================== test session starts =====================================================================
platform linux -- Python 3.11.13, pytest-8.4.2, pluggy-1.6.0
rootdir: /root/huanghy/vllm_opsrc
configfile: pyproject.toml
plugins: hypothesis-6.137.1, anyio-4.11.0
collected 3 items                                                                                                                                             

tests/v1/worker/test_gpu_model_runner.py ...                                                                                                            [100%]

====================================================================== warnings summary =======================================================================
<frozen importlib._bootstrap>:241
  <frozen importlib._bootstrap>:241: DeprecationWarning: builtin type SwigPyPacked has no __module__ attribute

<frozen importlib._bootstrap>:241
  <frozen importlib._bootstrap>:241: DeprecationWarning: builtin type SwigPyObject has no __module__ attribute

../../../opt/conda/lib/python3.11/site-packages/torch/jit/_script.py:362: 14 warnings
  /opt/conda/lib/python3.11/site-packages/torch/jit/_script.py:362: DeprecationWarning: `torch.jit.script_method` is deprecated. Please switch to `torch.compile` or `torch.export`.
    warnings.warn(

tests/v1/worker/test_gpu_model_runner.py::test_hybrid_attention_mamba_kv_cache_pack_size[2]
  /root/huanghy/vllm_opsrc/tests/v1/worker/test_gpu_model_runner.py:1501: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly.  To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
    base_ptr = kv_tensors[0].storage().data_ptr()

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
========================================================= 3 passed, 17 warnings in 143.93s (0:02:23) ==========================================================
sys:1: DeprecationWarning: builtin type swigvarlink has no __module__ attribute
GSM8K Results for /mnt/disk0/huanghaoyan.hhy/Qwen3-Next-80B-A3B-Instruct-FP8:
  Measured metric: 0.8643
  Expected metric: 0.8500
  Tolerance: 0.0800
  Questions: 1319
  Invalid rate: 0.001
  Latency: 133.6s
  QPS: 9.9
✅ GSM8K test passed for /mnt/disk0/huanghaoyan.hhy/Qwen3-Next-80B-A3B-Instruct-FP8
[RemoteOpenAIServer] Sent SIGTERM to process 2078273
[RemoteOpenAIServer] Server 2078273 terminated gracefully
[RemoteOpenAIServer] GPU memory released to 2.09 GB (target: 4.24 GB) in 0.0s
PASSED

Full Test Logs:
eval_pack4.log


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Signed-off-by: huanghaoyan.hhy <huanghaoyan.hhy@alibaba-inc.com>
@mergify mergify Bot added the v1 label Mar 2, 2026
@peakcrosser7 peakcrosser7 changed the title [WIP][Hybrid] Map more FullAttn layers to one block [WIP][Hybrid] Map more FullAttn layers to one page Mar 2, 2026
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces a mechanism to group multiple FullAttention layers to share a single KV cache page, which is particularly useful for hybrid Mamba models. The changes are extensive, touching configuration, argument parsing, KV cache utilities, and the GPU model runner. While the overall direction is sound, I've identified a critical bug in the GPU model runner related to incorrect tensor stride calculations that could lead to memory corruption. I've also noted a minor logging inaccuracy in the KV cache utility.

Comment thread vllm/v1/worker/gpu_model_runner.py Outdated
Comment on lines 6004 to 6011
kv_cache.as_strided_(
    size=kv_cache.shape,
-   stride=(hidden_size, 2 * hidden_size, *kv_cache.stride()[2:]),
+   stride=(
+       hidden_size,
+       2 * hidden_size * attn_group_size,
+       *kv_cache.stride()[2:],
+   ),
)
Contributor

critical

The as_strided_ call appears to be incorrect. The kv_cache tensor has a shape of (2, num_blocks, ...) for some attention backends, but the stride tuple (hidden_size, 2 * hidden_size * attn_group_size, ...) seems to be calculated for a tensor with shape (num_blocks, 2, ...). This mismatch between the tensor's shape and the provided strides can lead to incorrect memory access, data corruption in the KV cache, and ultimately incorrect model outputs. This is a critical issue that needs to be addressed.
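As a toy illustration of this kind of mismatch (standalone PyTorch, not the vLLM code): applying strides that were computed for a (num_blocks, 2, ...) layout to a tensor whose shape is (2, num_blocks, ...) silently reads another block's data.

import torch

# Stand-in for a KV cache of shape (2, num_blocks=3, hidden=2); true strides are (6, 2, 1).
x = torch.arange(12).reshape(2, 3, 2)
# Strides computed as if the layout were (num_blocks, 2, hidden).
wrong = x.as_strided(size=x.shape, stride=(2, 4, 1))
print(x[0, 1])      # tensor([2, 3]) -- the data that should be read
print(wrong[0, 1])  # tensor([4, 5]) -- reads a different block's data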

Comment thread vllm/v1/core/kv_cache_utils.py
Signed-off-by: huanghaoyan.hhy <huanghaoyan.hhy@alibaba-inc.com>
@peakcrosser7 peakcrosser7 changed the title [WIP][Hybrid] Map more FullAttn layers to one page [WIP][Hybrid] Map more FullAttn layers to a single page Mar 3, 2026
@peakcrosser7 peakcrosser7 marked this pull request as ready for review March 3, 2026 16:29
@peakcrosser7 peakcrosser7 changed the title [WIP][Hybrid] Map more FullAttn layers to a single page [Hybrid] Map multiple FullAttn layers to a single page Mar 3, 2026
Signed-off-by: huanghaoyan.hhy <huanghaoyan.hhy@alibaba-inc.com>
@heheda12345
Collaborator

@Etelis are there any places that can work automatically if we enable this mode but keep page_size_bytes as its previous definition? @peakcrosser7 has already updated gpu_model_runner, and I think KVConnectors should be aware of this mode when trying to support this model.

@orozery
Collaborator

orozery commented Mar 29, 2026

@peakcrosser7 Can you verify if my understanding below is correct?

  1. This logic will only apply in the case where you have 2 groups: 1 MambaSpec group and 1 FullAttentionSpec group.
  2. After you apply the logic, you will get up to 3 groups with uniform page size: 1 MambaSpec group (possibly padded), 1 fully-packed FullAttentionSpec group (not padded), and optionally 1 partially-packed FullAttentionSpec group (padded)

)

try:
    kv_cache_stride_order = backend.get_kv_cache_stride_order()
Collaborator

@heheda12345 @peakcrosser7 I think it's better to use backend.get_kv_cache_stride_order(include_num_layers_dimension=True) so the layers are packed in a way that is more efficient for KV connectors.
Specifically, this will allow a (num_heads, num_layers, ...) page layout instead of the currently proposed (num_layers, ...) page layout.

Contributor Author

Thanks for the suggestion! I’ll look into the include_num_layers_dimension parameter and see how to adjust the code accordingly.

Contributor Author

kv_cache_shape = (num_layers,) + kv_cache_shape
try:
    kv_cache_stride_order = attn_backend.get_kv_cache_stride_order(
        include_num_layers_dimension=True
    )
    assert len(kv_cache_stride_order) == len(kv_cache_shape)
except (AttributeError, NotImplementedError):
    kv_cache_stride_order = tuple(range(len(kv_cache_shape)))
kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order)
logger.info("Allocating a cross layer KV cache of shape %s", kv_cache_shape)
# allocate one contiguous buffer for all layers
cross_layers_kv_cache = (
    torch.zeros(total_size, dtype=torch.int8, device=device)
    .view(kv_cache_spec.dtype)
    .view(kv_cache_shape)
)
# Maintain original KV shape view.
inv_order = [
    kv_cache_stride_order.index(i) for i in range(len(kv_cache_stride_order))
]
permuted_kv_cache = cross_layers_kv_cache.permute(*inv_order)
kv_caches = {}
for i, kv_cache_tensor in enumerate(kv_cache_config.kv_cache_tensors):
    tensor = permuted_kv_cache[i]
    for layer_name in kv_cache_tensor.shared_by:
        kv_caches[layer_name] = tensor

@orozery Is the logic similar to this? We could add a num_layers dimension to the KV-Cache shape, where num_layers equals pack_size for packed Attention layers. When assigning to a specific layer, we would then index into the corresponding slice based on its attn_pack_idx. Is my understanding correct?

# shape [num_layers=pack_size, 2, num_blocks, block_size, num_heads, head_size]
tensor = torch.as_strided(
    raw_tensor.view(dtype),
    size=kv_cache_shape,
    stride=kv_cache_stride,
    storage_offset=storage_offset,
).permute(*inv_order)
# shape [2, num_blocks, block_size, num_heads, head_size]
kv_caches[layer_name] = tensor[attn_pack_idx]

@Etelis
Contributor

Etelis commented Mar 29, 2026

@Etelis are there any places that can work automatically if we enable this mode but keep page_size_bytes as its previous definition? @peakcrosser7 has already updated gpu_model_runner, and I think KVConnectors should be aware of this mode when trying to support this model.

@heheda12345 Yes — almost everything works automatically if page_size_bytes stays single-layer. The places that would not work automatically and need explicit pack_size handling are:

  1. get_kv_cache_config_from_groups() in kv_cache_utils.py — memory allocation needs pack_size to calculate the correct tensor size (multiply page_size_bytes * pack_size when computing num_blocks and KVCacheTensor.size).
  2. _reshape_kv_cache_tensors() in gpu_model_runner.py — the num_blocks = tensor.numel() // page_size_bytes calculation and stride computation need pack_size.
  3. init_attn_backend() in attn_utils.py — same num_blocks calculation from tensor size.

Everything else — and particularly all KV connector/offloading code (CanonicalKVCacheTensor, nixl_connector, cpu_gpu offload worker, kv_connector_model_runner_mixin) — assumes single-layer page_size_bytes and would work correctly without changes.

Keeping page_size_bytes single-layer and handling pack_size explicitly at those 2-3 sites would make the KV connector integration cleaner. KVConnectors would not need to be "aware" of this mode at all for basic operation — they'd just see standard per-layer pages. They would only need awareness if we want the more efficient (num_heads, num_layers, ...) layout that @orozery is suggesting.
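A rough sketch of the adjustment at the sites listed above (function names and signatures are illustrative, not the actual vLLM APIs): keep page_size_bytes per-layer and fold pack_size in only where the shared tensor size or num_blocks is computed.

def packed_tensor_size(num_blocks: int, page_size_bytes: int, pack_size: int) -> int:
    # Each block of tokens now backs `pack_size` attention layers, so the shared
    # tensor needs pack_size per-layer pages per block.
    return num_blocks * page_size_bytes * pack_size

def packed_num_blocks(tensor_size_bytes: int, page_size_bytes: int, pack_size: int) -> int:
    # Inverse of the above: how many blocks fit into an already-allocated tensor.
    return tensor_size_bytes // (page_size_bytes * pack_size)

print(packed_num_blocks(packed_tensor_size(1024, 4096, 4), 4096, 4))  # 1024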

@peakcrosser7
Contributor Author

@peakcrosser7 Can you verify if my understanding below is correct?

  1. This logic will only apply in the case where you have 2 groups: 1 MambaSpec group and 1 FullAttentionSpec group.
  2. After you apply the logic, you will get up to 3 groups with uniform page size: 1 MambaSpec group (possibly padded), 1 fully-packed FullAttentionSpec group (not padded), and optionally 1 partially-packed FullAttentionSpec group (padded)

Hi @orozery , thanks for the review!

I believe there might be a slight misunderstanding regarding the implementation. Let me clarify the logic with two key points:

  1. KV-Cache Groups: This PR targets hybrid models with both Mamba and Attention layers. While there are 2 types of KV-cache groups, the actual number of groups can be more than two. For instance, in Qwen 3.5 (without packing), there is typically 1 Attention group and 3 Mamba groups.

  2. Impact of Packing: The packing logic mainly affects the number of Mamba groups, increasing it by a factor of pack_size. Taking the 3:1 Mamba-to-Attention ratio in Qwen 3.5 as an example (assuming 15 layers of each "group"):
    • If pack_size = 4, the 15 Attention layers are packed into 4 units (ceil of 15/4). The Attention group now manages these 4 packed units, but the group count remains one.
    • The 45 Mamba layers are reorganized into 12 Mamba groups (ceil of 45/4), with each group containing 4 Mamba layers. In this case, the effective block_size becomes 1/4 of the original.
    • Regarding the "padding" you mentioned: using the same example, the final 3 Attention layers are packed into a single unit, which results in the KV-cache space for 1 layer being "wasted." However, the logic does not distinguish between fully-packed and partially-packed groups; they are treated uniformly.

I hope this clears up your confusion! Let me know if you have further questions.

@peakcrosser7
Contributor Author

Hi @Etelis , thanks for your reply!
You're right—the locations you pointed out are exactly where page_size_bytes should incorporate the pack_size semantics.
Regarding kv-connector and offloading, thanks for the clarification as I wasn't as familiar with those parts. One thing I’d like to double-check: you mentioned that these components only require the per-layer page_size_bytes to function correctly. However, with the packing logic, the KV-cache tensor becomes non-contiguous because the stride for the num_blocks dimension increases. Is the existing logic already compatible with this non-contiguous layout?
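To make the non-contiguity concrete (a toy standalone PyTorch example, not the vLLM code): a per-layer view of a packed buffer keeps the per-layer shape, but its num_blocks stride now skips over the other layers' pages, so the view is no longer contiguous.

import torch

pack_size, num_blocks, page_elems = 2, 3, 4
buf = torch.arange(pack_size * num_blocks * page_elems)  # one shared physical buffer

# Per-layer view: same logical shape as an unpacked cache, but the block stride
# jumps over the other layer's slice within each block.
layer0 = torch.as_strided(buf, size=(num_blocks, page_elems),
                          stride=(pack_size * page_elems, 1))
print(layer0.is_contiguous())  # False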

Signed-off-by: huanghaoyan.hhy <huanghaoyan.hhy@alibaba-inc.com>
@orozery
Collaborator

orozery commented Mar 30, 2026

I hope this clears up your confusion! Let me know if you have further questions.

Thanks! This is great :)
So with your example, we will have 4 KVCacheTensors, right?
And each tensor will be shared by ~12 mamba layers + ~1 pack of 4 attention layers?

@mergify
Contributor

mergify Bot commented Mar 30, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @peakcrosser7.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label Mar 30, 2026
@peakcrosser7
Contributor Author

So with your example, we will have 4 KVCacheTensors, right? And each tensor will be shared by ~12 mamba layers + ~1 pack of 4 attention layers?

@orozery Yes, you're right!

@xhdidi

xhdidi commented Apr 2, 2026

We tested Qwen3.5-35B-A3B with TP2 and found that when mamba-num-attn-pages=16, sending duplicate requests resulted in a normal cache hit rate. However, when mamba-num-attn-pages=8, sending a duplicate request resulted in a prefix cache hit rate of only 1%. What could be the reason for this?

@peakcrosser7
Contributor Author

Hi @xhdidi , thanks for your feedback!
Could you share what the resulting block_size is for each setting, and what the typical range of your request lengths looks like?
Theoretically, a larger mamba-num-attn-pages leads to a smaller block_size, which should generally improve the cache hit rate. If the block_size at 8 is still larger than most of your requests, you might experience a lower hit rate.
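For example (illustrative numbers only): if the block_size at mamba-num-attn-pages=8 ends up longer than most of your requests, no request ever fills a complete block, so almost nothing can be cached or matched; the smaller block_size at 16 falls below the typical request length and hits reappear.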

@peakcrosser7
Contributor Author

Hi @orozery ,

I’ve looked into using backend.get_kv_cache_stride_order(include_num_layers_dimension=True), and the logic seems more complex than initially expected.

Current Observations: It appears that the KV-Cache layout returned when include_num_layers_dimension=True does not always meet the necessary requirements.

Taking FlashAttentionBackend as an example:

 class FlashAttentionBackend(AttentionBackend):
    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        if block_size % 16 != 0:
            raise ValueError("Block size must be a multiple of 16.")
        return (2, num_blocks, block_size, num_kv_heads, head_size)
      
    @staticmethod
    def get_kv_cache_stride_order(
        include_num_layers_dimension: bool = False,
    ) -> tuple[int, ...]:
        # `stride_order` indicates the permutation that gets
        # us from `get_kv_cache_shape` to the actual memory layout we want.
        cache_layout = get_kv_cache_layout()
        if cache_layout == "NHD" and include_num_layers_dimension:
            # (num_blocks, num_layers, 2, block_size, num_kv_heads, head_size)
            return (2, 0, 1, 3, 4, 5)
        elif cache_layout == "NHD":
            # (2, num_blocks, block_size, num_kv_heads, head_size)
            stride_order = (0, 1, 2, 3, 4)
        elif cache_layout == "HND" and include_num_layers_dimension:
            # (num_blocks, num_kv_heads, num_layers, 2, block_size, head_size)
            return (2, 4, 0, 1, 3, 5)
        elif cache_layout == "HND":
            # (2, num_blocks, num_kv_heads, block_size, head_size)
            stride_order = (0, 1, 3, 2, 4)
        else:
            raise ValueError(f"Unknown cache layout format {cache_layout}.")
        return stride_order
  • When include_num_layers_dimension=False: Regardless of whether the layout is NHD or HND, the first two dimensions of the returned stride order are always (2, num_blocks). This aligns with our expectations for physical memory allocation. Crucially, the hybrid Mamba logic (previously in _update_hybrid_attention_mamba_layout(), now moved to _get_hybrid_attention_mamba_layout() in this PR) requires these "2" and "num_blocks" dimensions to be the primary dimensions to adjust strides correctly.
  • When include_num_layers_dimension=True: My expectation was that the new num_layers dimension would appear at index 0 or immediately after num_blocks, keeping other dimensions consistent.
    For NHD, the shape becomes (num_blocks, num_layers, 2, block_size, num_kv_heads, head_size), which works for packing Attention layers.
    However, for HND, the shape becomes (num_blocks, num_kv_heads, num_layers, 2, block_size, head_size). Here, num_kv_heads is shifted ahead of "num_layers" and "2" dimensions.
    While this HND layout is likely intentional, it appears to conflict with the requirements for hybrid Mamba models in _get_hybrid_attention_mamba_layout() and the changes in this PR.

Proposed Alternative:
Instead of relying on the backend's include_num_layers_dimension, what if we stick with include_num_layers_dimension=False but manually prepend a num_layers dimension?
By default, num_layers would be 1. In _get_hybrid_attention_mamba_layout(), we set this dimension to attn_pack_size and adjust the strides accordingly.
Finally, we slice the KV-Cache for each Attention layer using the attn_pack_idx.
The implementation would look something like this:

# (2, num_blocks, block_size, num_kv_heads, head_size)
kv_cache_shape = attn_backend.get_kv_cache_shape(
    kernel_num_blocks,
    kernel_block_size,
    kv_cache_spec.num_kv_heads,
    kv_cache_spec.head_size,
    cache_dtype_str=self.cache_config.cache_dtype,
)
try:
    kv_cache_stride_order = attn_backend.get_kv_cache_stride_order()
    assert len(kv_cache_stride_order) == len(kv_cache_shape)
except (AttributeError, NotImplementedError):
    kv_cache_stride_order = tuple(range(len(kv_cache_shape)))
# (2, num_blocks, num_kv_heads, block_size, head_size)
kv_cache_shape = tuple(
    kv_cache_shape[i] for i in kv_cache_stride_order
)
# Add num_layers dimension (default=1)
# (num_layers=1, 2, num_blocks, num_kv_heads, block_size, head_size)
kv_cache_shape = (1,) + kv_cache_shape
kv_cache_stride = tuple(torch.empty(kv_cache_shape).stride())
num_layers_dim = 0
# Maintain original KV shape view.
inv_order = [
    kv_cache_stride_order.index(i)
    for i in range(len(kv_cache_stride_order))
]
if has_mamba:
    # Set num_layers=attn_pack_size and adjust strides
    # Return new strides and current Attention layer index
    # (num_layers=attn_pack_size, 2, num_blocks, num_kv_heads, block_size, head_size)
    kv_cache_stride, layer_index = (
        mamba_utils.get_hybrid_attention_mamba_layout(...)
    )
kv_caches[layer_name] = torch.as_strided(
    raw_tensor.view(dtype),
    size=kv_cache_shape,
    stride=kv_cache_stride,
)[layer_index].permute(*inv_order)

What are your thoughts on this issue and the proposed design?
cc @heheda12345

Signed-off-by: huanghaoyan.hhy <huanghaoyan.hhy@alibaba-inc.com>
@orozery
Collaborator

orozery commented Apr 5, 2026

  • However, for HND, the shape becomes (num_blocks, num_kv_heads, num_layers, 2, block_size, head_size). Here, num_kv_heads is shifted ahead of "num_layers" and "2" dimensions.
    While this HND layout is likely intentional, it appears to conflict with the requirements for hybrid Mamba models in _get_hybrid_attention_mamba_layout() and the changes in this PR.

This was actually the reason I wanted to use include_layers_dim=True, since it puts the num_heads dimension before num_layers.
This physical ordering helps KV connectors optimize transfers between vLLM instances that use different tensor_parallel_count configurations.
I would prefer the packing layout suggested here to support it as well.
We can also do that in a follow-up, or by introducing a separate allocation path like we do in KVConnectorModelRunnerMixin.allocate_uniform_kv_caches, but that will require more work than doing it upfront in this PR.

@peakcrosser7
Contributor Author

I would prefer the packing layout suggested here to support it as well. We can also do that in a follow-up, or by introducing a separate allocation path like we do in KVConnectorModelRunnerMixin.allocate_uniform_kv_caches, but that will require more work than doing it upfront in this PR.

@orozery Thanks for the response! I agree with your point.

While enabling include_layers_dim=True would likely benefit KV connector transfers, the current logic isn't yet ready for a straightforward switch. This might be better addressed as a dedicated task in a future PR.
That said, I think adding a num_layers dimension is an excellent design choice. Compared to using storage_offsets, it makes the layout changes for packed Attention layers much more intuitive.

@heheda12345, what's your take? Do you think it's necessary to add a num_layers dimension to the KV-Cache layout?

Signed-off-by: huanghaoyan.hhy <huanghaoyan.hhy@alibaba-inc.com>
@mergify mergify Bot removed the needs-rebase label Apr 7, 2026
Signed-off-by: huanghaoyan.hhy <huanghaoyan.hhy@alibaba-inc.com>
@mergify
Contributor

mergify Bot commented Apr 12, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @peakcrosser7.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label Apr 12, 2026