
[Bugfix] Disable cross-layer KV cache for MLA attention backends#37090

Merged
orozery merged 5 commits into vllm-project:main from haosdent:fix-37032
Mar 16, 2026

Conversation

@haosdent
Contributor

@haosdent haosdent commented Mar 15, 2026

Purpose

Fixes #37032

MLA models (e.g., GLM-4.7-Flash) produce garbage output with --kv-offloading-size because cross-layer KV cache allocation creates non-contiguous per-layer views. MLA decode kernels assume a contiguous block layout, so they read the wrong memory for block_id > 0.

Fix (3 parts)

  1. mla_attention.py, indexer.py: MLA backends return the identity permutation (0, 1, 2, 3) from get_kv_cache_stride_order(include_num_layers_dimension=True), keeping num_layers first in the physical layout to signal that cross-layer allocation is unsupported.

  2. kv_connector_model_runner_mixin.py: use_uniform_kv_cache() checks stride_order[0] == 0 (layers dim first) and skips cross-layer allocation, falling back to per-layer contiguous KV caches.

  3. offloading_connector.py: register_kv_caches() uses AttentionLayerBase instead of Attention to match both Attention and MLAAttention layers.
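Parts 1 and 2 can be sketched roughly as follows. The class and function names below are illustrative stand-ins, not the actual vLLM implementation:

```python
# Hypothetical sketch of the fix; names and signatures are illustrative,
# not the real vLLM API.

class MLABackendSketch:
    @staticmethod
    def get_kv_cache_stride_order(include_num_layers_dimension: bool = False):
        # Part 1: return the identity permutation so the num_layers
        # dimension stays first in the physical layout. A leading layers
        # dimension signals that cross-layer allocation is unsupported.
        if include_num_layers_dimension:
            return (0, 1, 2, 3)  # (num_layers, num_blocks, block_size, head_size)
        return (0, 1, 2)


def use_uniform_kv_cache(stride_order) -> bool:
    # Part 2: if the layers dimension comes first, skip cross-layer
    # (uniform) allocation and fall back to per-layer contiguous caches.
    return stride_order[0] != 0


stride_order = MLABackendSketch.get_kv_cache_stride_order(
    include_num_layers_dimension=True)
print(use_uniform_kv_cache(stride_order))  # False -> per-layer caches for MLA
```

A non-MLA backend that puts num_blocks first (e.g. stride order (1, 0, 2, 3)) would still opt into the cross-layer path under this check.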

Test Plan

pytest tests/v1/kv_offload/test_cpu_gpu.py::test_mla_backend_rejects_cross_layer_kv_cache tests/v1/kv_offload/test_cpu_gpu.py::test_deepseek_v32_indexer_rejects_cross_layer_kv_cache -v

End-to-end with deepseek-ai/DeepSeek-V2-Lite-Chat + KV offloading on NVIDIA GB10.

Test Result

Unit tests: 2/2 passed. End-to-end:

  • Without fix: ' 夹缝缝缝缝缝缝缝缝缝缝缝缝缝缝缝缝缝缝' (garbage)
  • With fix: ' 5, 6, 7, 8, 9, 10.\n' (correct)

MLA attention kernels assume contiguous per-layer KV cache views.
When KV offloading enables cross-layer blocks, per-layer views become
non-contiguous (stride(0) includes a num_layers factor), causing decode
kernels to read from wrong memory locations and produce garbage output.
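The non-contiguity can be demonstrated with a small sketch, using NumPy as a stand-in for torch (the shapes here are illustrative, not the actual MLA cache dimensions):

```python
import numpy as np

num_layers, num_blocks, block_size, head_size = 4, 8, 16, 32

# Cross-layer allocation with num_blocks leading: a per-layer slice is a
# strided view whose block stride carries a num_layers factor.
cross = np.zeros((num_blocks, num_layers, block_size, head_size))
per_layer_view = cross[:, 0]
print(per_layer_view.flags["C_CONTIGUOUS"])         # False
print(per_layer_view.strides[0] // cross.itemsize)  # num_layers * block_size * head_size

# Per-layer allocation (or a leading layers dimension): slices are contiguous,
# with the block stride the kernels expect (block_size * head_size).
layers_first = np.zeros((num_layers, num_blocks, block_size, head_size))
print(layers_first[0].flags["C_CONTIGUOUS"])        # True
```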

Have MLA backends raise NotImplementedError from
get_kv_cache_stride_order(include_num_layers_dimension=True) so that
use_uniform_kv_cache() falls back to per-layer registration with
contiguous tensors.

Fixes: vllm-project#37032

Co-authored-by: Claude

Signed-off-by: haosdent <haosdent@gmail.com>
@mergify mergify bot added v1 bug Something isn't working labels Mar 15, 2026
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request addresses a bug causing garbage output with MLA models when KV offloading is enabled. The root cause is the incorrect handling of non-contiguous memory layouts by MLA attention kernels when cross-layer KV cache is used. The fix correctly disables this feature for MLA backends by raising a NotImplementedError, which leverages an existing fallback mechanism to use per-layer contiguous KV caches. The changes are well-targeted, clearly explained, and are accompanied by new tests that verify the corrected behavior. The solution appears robust and effectively resolves the issue.

@mergify mergify bot added the kv-connector label Mar 15, 2026
@haosdent haosdent changed the title [WIP][Bugfix] Disable cross-layer KV cache for MLA attention backends [Bugfix] Disable cross-layer KV cache for MLA attention backends Mar 15, 2026
@haosdent haosdent marked this pull request as ready for review March 15, 2026 15:38
@mergify

mergify bot commented Mar 15, 2026

Hi @haosdent, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

The OffloadingConnectorWorker.register_kv_caches() used `Attention` as
the layer type filter, which excluded MLA attention layers (MLAAttention
is not a subclass of Attention). This caused a KeyError when MLA models
fell back to per-layer KV caches. Use `AttentionLayerBase` (the common
base class) instead to match both Attention and MLAAttention layers.
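The matching issue can be illustrated with a toy hierarchy. The class names mirror vLLM's, but the bodies are empty stand-ins:

```python
class AttentionLayerBase:
    """Common base for all attention layer types (stand-in)."""

class Attention(AttentionLayerBase):
    """Standard attention layer (stand-in)."""

class MLAAttention(AttentionLayerBase):
    """MLA layer: a sibling of Attention, not a subclass (stand-in)."""

layers = {"layers.0.attn": Attention(), "layers.1.attn": MLAAttention()}

# Filtering on Attention silently drops MLA layers, so their KV caches are
# never registered and a later lookup raises KeyError.
old_match = sorted(n for n, l in layers.items() if isinstance(l, Attention))
# Filtering on the common base class matches both layer types.
new_match = sorted(n for n, l in layers.items() if isinstance(l, AttentionLayerBase))

print(old_match)  # ['layers.0.attn']
print(new_match)  # ['layers.0.attn', 'layers.1.attn']
```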

Co-authored-by: Claude

Signed-off-by: haosdent <haosdent@gmail.com>
@orozery
Collaborator

orozery commented Mar 16, 2026

@haosdent thanks for this fix!
I was able to reproduce the issue.
I was not aware that MLA kernels did not support cross layers.

However, instead of returning NotImplementedError, I think we should return the identity permutation (e.g. (0, 1, 2, ...)).
This will yield contiguous per-layer tensors.
For DeepseekV32IndexerBackend this is already the case, so no change is needed there.

@haosdent
Contributor Author

However, instead of returning NotImplementedError, I think we should return the identity permutation (e.g. (0, 1, 2, ...)).
This will yield contiguous per-layer tensors.

Thanks @orozery ! Let me update the pull request

@haosdent
Contributor Author

However, instead of returning NotImplementedError, I think we should return the identity permutation (e.g. (0, 1, 2, ...)).
This will yield contiguous per-layer tensors.

Hi, @orozery, I checked again, and this approach does not seem to work. The identity permutation still allocates a single cross-layer tensor. Slicing tensor[i] gives a non-contiguous view where stride(0) = num_layers * block_size * head_size. MLA kernels assume stride(0) = block_size * head_size, and no permutation fixes that. We need to opt out of cross-layer allocation entirely, which is what NotImplementedError does.

@orozery
Collaborator

orozery commented Mar 16, 2026

Hi, @orozery, I checked again, and this approach does not seem to work. The identity permutation still allocates a single cross-layer tensor. Slicing tensor[i] gives a non-contiguous view where stride(0) = num_layers * block_size * head_size. MLA kernels assume stride(0) = block_size * head_size, and no permutation fixes that. We need to opt out of cross-layer allocation entirely, which is what NotImplementedError does.

It will allocate a single cross-layer tensor, but will create per-layer views (kv_caches: dict[str, torch.Tensor]) which are contiguous.
I tested it now but it actually fails since the offloading connector currently assumes that when using cross layers the first dimension is num_blocks.

I think that if get_kv_cache_stride_order starts with (0, ...), i.e. the layers dimension is first, we can abort creating a cross-layer tensor here:

return len(kv_cache_stride_order) == len(kv_cache_shape) + 1

This also aligns with #34742, aiming to remove all exception throwing by get_kv_cache_stride_order.
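The suggestion amounts to extending the quoted condition with a check on the leading dimension. A minimal sketch, with an illustrative function name rather than the actual vLLM code:

```python
def supports_cross_layer_kv_cache(kv_cache_stride_order, kv_cache_shape) -> bool:
    # The backend opted into a layers dimension only if the returned
    # permutation is one entry longer than the per-layer cache shape...
    if len(kv_cache_stride_order) != len(kv_cache_shape) + 1:
        return False
    # ...and cross-layer allocation is only usable when the layers
    # dimension is NOT first; a leading layers dimension signals that the
    # backend needs contiguous per-layer tensors.
    return kv_cache_stride_order[0] != 0


print(supports_cross_layer_kv_cache((1, 0, 2, 3), (8, 16, 32)))  # True
print(supports_cross_layer_kv_cache((0, 1, 2, 3), (8, 16, 32)))  # False (MLA)
print(supports_cross_layer_kv_cache((0, 1, 2), (8, 16, 32)))     # False (no layers dim)
```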

MLA backends now return identity permutation (0,1,2,3) from
get_kv_cache_stride_order(include_num_layers_dimension=True) instead of
raising NotImplementedError. use_uniform_kv_cache() checks
stride_order[0] == 0 (layers dim first) to skip cross-layer allocation.
This aligns with the direction of removing exceptions from
get_kv_cache_stride_order.

Signed-off-by: haosdent <haosdent@gmail.com>
@haosdent
Contributor Author

Thanks @orozery , I referred to #34742 and updated the PR, can you help review again? Many thanks!

@orozery
Collaborator

orozery commented Mar 16, 2026

@haosdent LGTM!
Can we just move the new unit tests to a new e.g. tests/v1/kv_connector/unit/test_kv_cache_layout.py ?

Signed-off-by: haosdent <haosdent@gmail.com>
@haosdent
Contributor Author

Can we just move the new unit tests to a new e.g. tests/v1/kv_connector/unit/test_kv_cache_layout.py ?

Done, thanks @orozery

Collaborator

@orozery orozery left a comment


Thanks again @haosdent !

@orozery orozery added the ready ONLY add when PR is ready to merge/full CI is needed label Mar 16, 2026
@orozery orozery merged commit ca1954d into vllm-project:main Mar 16, 2026
67 checks passed
Lucaskabela pushed a commit to Lucaskabela/vllm that referenced this pull request Mar 17, 2026
andylolu2 pushed a commit to andylolu2/vllm that referenced this pull request Mar 18, 2026
wendyliu235 pushed a commit to wendyliu235/vllm-public that referenced this pull request Mar 18, 2026
fxdawnn pushed a commit to fxdawnn/vllm that referenced this pull request Mar 19, 2026
khairulkabir1661 pushed a commit to khairulkabir1661/vllm that referenced this pull request Mar 27, 2026
Monishver11 pushed a commit to Monishver11/vllm that referenced this pull request Mar 27, 2026
JiantaoXu pushed a commit to JiantaoXu/vllm that referenced this pull request Mar 28, 2026
vrdn-23 pushed a commit to vrdn-23/vllm that referenced this pull request Mar 30, 2026
EricccYang pushed a commit to EricccYang/vllm that referenced this pull request Apr 1, 2026

Labels

bug (Something isn't working) · kv-connector · ready (ONLY add when PR is ready to merge/full CI is needed) · v1

Development

Successfully merging this pull request may close these issues.

[Bug]: GLM 4.7-flash returns gibberish when native KV cache offloading is on

2 participants