
[Bugfix] Fix KV cache overestimation for hybrid Mamba/attention model…#37124

Closed
swtb3 wants to merge 7 commits into vllm-project:main from swtb3:fix/hybrid-mamba-kv-cache-reporting

Conversation

@swtb3

@swtb3 swtb3 commented Mar 15, 2026

Purpose

Qwen3.5 mixes 24 GatedDeltaNet layers (O(1) state) with 8 full attention layers (O(n) KV per token). vLLM treated all layers uniformly, causing ~7x memory overestimation (7.57 GiB allocated, ~1 GiB used).

Reporting fixes:

  • get_max_concurrency_for_kv_cache_config: sum per-group costs independently instead of multiplying the largest cost by the largest group count (see the sketch after this list)
  • _report_kv_cache_config: count tokens from attention groups only
  • _max_memory_usage_bytes_from_groups: sum actual per-group memory usage instead of calling get_uniform_page_size (which crashes with non-uniform sizes)
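
A minimal sketch of the per-group summation behind the get_max_concurrency_for_kv_cache_config change referenced above. The names here (GroupCost, pages_per_request, page_size_bytes, max_concurrency) are illustrative assumptions, not the actual vLLM data structures or the real implementation:

from dataclasses import dataclass

@dataclass
class GroupCost:
    pages_per_request: int  # pages one request needs in this KV cache group
    page_size_bytes: int    # page size of this group (tiny for Mamba state)

def max_concurrency(available_bytes: int, groups: list[GroupCost]) -> float:
    # The old uniform estimate charged every group at the largest page size and
    # largest page count, overestimating hybrid Mamba/attention models.
    # Here each group contributes its own per-request cost instead.
    bytes_per_request = sum(g.pages_per_request * g.page_size_bytes for g in groups)
    return available_bytes / bytes_per_request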

Allocation fix:

  • New elif branch in get_kv_cache_config_from_groups for mixed Mamba+attention: gives each layer its own tensor at its natural page size (sketched below)
  • Skip page size unification in get_kv_cache_groups so Mamba layers keep their small page size instead of being padded to match attention
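
A rough sketch of the per-layer allocation idea: every layer gets a tensor sized by its own natural page size instead of one unified page size. The dict-based signature below is an illustrative assumption, not the actual get_kv_cache_config_from_groups code:

import torch

def allocate_per_layer_tensors(page_size_bytes: dict[str, int],
                               num_blocks: dict[str, int]) -> dict[str, torch.Tensor]:
    # page_size_bytes maps layer name -> that layer's natural page size:
    # a small constant for Mamba/GatedDeltaNet state, much larger for attention.
    tensors = {}
    for name, page_bytes in page_size_bytes.items():
        tensors[name] = torch.empty(num_blocks[name] * page_bytes, dtype=torch.int8)
    return tensors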

All changes gated behind _has_mixed_mamba_attention() — no impact on pure-attention, pure-Mamba, or attention+sliding_window models.
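
The gate itself could be as simple as the following sketch; the spec class names are assumptions based on vLLM's v1 KV cache interface, and the helper in this PR may differ:

from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheSpec, MambaSpec

def has_mixed_mamba_attention(kv_cache_spec: dict[str, KVCacheSpec]) -> bool:
    # True only when the model declares both Mamba-style and attention layers.
    has_attn = any(isinstance(s, AttentionSpec) for s in kv_cache_spec.values())
    has_mamba = any(isinstance(s, MambaSpec) for s in kv_cache_spec.values())
    return has_attn and has_mamba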

Test plan

  • pytest tests/v1/core/test_kv_cache_utils.py -v -s — 57/57 passed
  • 9 new tests covering Qwen3.5 architecture across bf16/fp16/fp8-kv
  • Pre-commit hooks all passed (ruff check, ruff format, mypy, typos)

Test Result

pass

Notes


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only fastcheck CI runs, covering a small and essential subset of CI tests to quickly catch errors.

You can ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

@mergify mergify Bot added the v1 and bug labels Mar 15, 2026
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request effectively addresses the KV cache memory overestimation for hybrid Mamba/attention models like Qwen3.5. The changes correctly introduce specialized logic to handle the different memory requirements of Mamba and attention layers independently, which resolves the reported issue. The modifications to memory allocation, concurrency estimation, and reporting are well-implemented and gated behind a check for mixed-architecture models, minimizing the risk of regressions for other model types. The accompanying tests are thorough and provide good coverage for the new logic, ensuring the fix is robust. Overall, this is a high-quality contribution that significantly improves memory efficiency for this class of models.

Comment thread on vllm/v1/core/kv_cache_utils.py (outdated)
@github-project-automation github-project-automation Bot moved this to Todo in AMD Mar 16, 2026
@mergify mergify Bot added the kv-connector label Mar 16, 2026
@swtb3 swtb3 force-pushed the fix/hybrid-mamba-kv-cache-reporting branch from 2264e4a to 2f538a8 Compare March 16, 2026 23:50
@mergify
Contributor

mergify Bot commented Mar 17, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @swtb3.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label Mar 17, 2026
@repne

repne commented Mar 17, 2026

@swtb3 happy to test this after the conflicts are solved, thank you!

swtb3-ryder and others added 3 commits March 17, 2026 15:11
vllm-project#37121)

Qwen3.5 mixes 24 GatedDeltaNet layers (O(1) state) with 8 full attention
layers (O(n) KV per token). vLLM treated all layers uniformly, causing ~7x
memory overestimation (7.57 GiB allocated, ~1 GiB used).

Reporting fixes:
- get_max_concurrency_for_kv_cache_config: sum per-group costs independently
  instead of multiplying the largest cost by the largest group count
- _report_kv_cache_config: count tokens from attention groups only
- _max_memory_usage_bytes_from_groups: sum actual per-group memory usage
  instead of calling get_uniform_page_size (which crashes with non-uniform sizes)

Allocation fix:
- New elif branch in get_kv_cache_config_from_groups for mixed Mamba+attention:
  gives each layer its own tensor at its natural page size
- Skip page size unification in get_kv_cache_groups so Mamba layers keep their
  small page size instead of being padded to match attention

All changes gated behind _has_mixed_mamba_attention() — no impact on
pure-attention, pure-Mamba, or attention+sliding_window models.

Co-authored-by: Claude <noreply@anthropic.com>
Signed-off-by: swtb <135991636+swtb3@users.noreply.github.com>

Signed-off-by: swtb-ryder <sbayly@ryderarchitecture.com>
…all"

  When prefix caching is enabled (mamba_cache_mode="all"), Mamba states
  are cached per-token and scale with sequence length. Only exclude
  Mamba groups from token capacity reporting in "none" and "align" modes.
  Co-authored-by: Claude <noreply@anthropic.com>
  Signed-off-by: swtb <135991636+swtb3@users.noreply.github.com>

Signed-off-by: swtb-ryder <sbayly@ryderarchitecture.com>
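
A hedged sketch of the mode-dependent rule described in the commit message above; mamba_cache_mode follows the commit wording and the helper name is invented for illustration, not necessarily a real vLLM config field or function:

def group_counts_toward_token_capacity(is_mamba_group: bool, mamba_cache_mode: str) -> bool:
    # In "all" mode, Mamba states are cached per token and scale with sequence
    # length, so Mamba groups count toward token capacity; in "none" and
    # "align" modes they hold O(1) state and are excluded.
    return (not is_mamba_group) or mamba_cache_mode == "all"
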
…rid models

  Hybrid Mamba/attention models (e.g., Qwen3.5) suffered OOM and massive
  memory waste because Mamba layers shared the attention BlockPool. Each
  Mamba layer's tensor was sized for N blocks (the full pool) but only
  used 1 block per request, wasting ~399 MB per layer.

  Decouple Mamba from the shared BlockPool by giving MambaManager a
  self-managed compact block space (0..C-1 where C = max concurrent
  requests). Freed memory goes to attention, yielding ~47x token capacity
  improvement on Qwen3.5-4B.

  - Add `mamba_num_blocks` field to `KVCacheConfig`
  - Implement compact allocation branch in `get_kv_cache_config_from_groups`
  - Add compact mode to `MambaManager` with self-managed block lifecycle
  - Update cross-worker tensor scaling for separate Mamba/attention pools
  - Update concurrency calculation for compact allocation
  - Preserve shared-pool behavior for `mamba_cache_mode="all"`
  - Add 13 new tests covering allocation, manager, and edge cases

  Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Signed-off-by: swtb3 <135991636+swtb3@users.noreply.github.com>
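
A rough sketch of the self-managed compact block space from the last commit message: block ids 0..C-1 with one block per in-flight request. Illustrative only, assuming this simplified lifecycle rather than the actual MambaManager code:

class CompactMambaBlockSpace:
    def __init__(self, max_concurrent_requests: int):
        self._free = list(range(max_concurrent_requests))  # block ids 0..C-1
        self._owned: dict[str, int] = {}                    # request id -> block id

    def allocate(self, request_id: str) -> int:
        # Each request needs exactly one Mamba state block for its whole lifetime.
        if request_id in self._owned:
            return self._owned[request_id]
        block_id = self._free.pop()
        self._owned[request_id] = block_id
        return block_id

    def free(self, request_id: str) -> None:
        # Returning the block makes it immediately reusable by the next request.
        self._free.append(self._owned.pop(request_id))
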
@swtb3 swtb3 force-pushed the fix/hybrid-mamba-kv-cache-reporting branch from 2f538a8 to a3ad054 Compare March 17, 2026 15:13
@swtb3
Author

swtb3 commented Mar 17, 2026

@swtb3 happy to test this after the conflicts are solved, thank you!

done

@mergify mergify Bot removed the needs-rebase label Mar 17, 2026
@repne

repne commented Mar 18, 2026

Without PR: GPU KV cache size: 101,600 tokens
With PR: GPU KV cache size: 416,000 tokens
...and no OOM this time.

I am however experiencing a dramatic drop in performance:

vllm bench serve \
  --backend vllm \
  --model Qwen/Qwen3.5-27B-FP8 \
  --endpoint /v1/completions \
  --dataset-name sharegpt \
  --dataset-path ~/datasets/ShareGPT_V3_unfiltered_cleaned_split.json \
  --num-warmups 10 \
  --num-prompts 100 \
  --seed 42

Before PR

============ Serving Benchmark Result ============
Successful requests:                     100
Failed requests:                         0
Benchmark duration (s):                  92.89
Total input tokens:                      22053
Total generated tokens:                  25102
Request throughput (req/s):              1.08
Output token throughput (tok/s):         270.23
Peak output token throughput (tok/s):    192.00
Peak concurrent requests:                100.00
Total token throughput (tok/s):          507.64
---------------Time to First Token----------------
Mean TTFT (ms):                          24907.82
Median TTFT (ms):                        19508.09
P99 TTFT (ms):                           62480.57
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          100.69
Median TPOT (ms):                        97.66
P99 TPOT (ms):                           205.47
---------------Inter-token Latency----------------
Mean ITL (ms):                           205.14
Median ITL (ms):                         175.41
P99 ITL (ms):                            978.14
---------------Speculative Decoding---------------
Acceptance rate (%):                     68.03
Acceptance length:                       2.36
Drafts:                                  10627
Draft tokens:                            21254
Accepted tokens:                         14460
Per-position acceptance (%):
  Position 0:                            78.74
  Position 1:                            57.33
==================================================

After PR

============ Serving Benchmark Result ============
Successful requests:                     100
Failed requests:                         0
Benchmark duration (s):                  380.51
Total input tokens:                      22053
Total generated tokens:                  25336
Request throughput (req/s):              0.26
Output token throughput (tok/s):         66.58
Peak output token throughput (tok/s):    31.00
Peak concurrent requests:                100.00
Total token throughput (tok/s):          124.54
---------------Time to First Token----------------
Mean TTFT (ms):                          185376.52
Median TTFT (ms):                        191196.04
P99 TTFT (ms):                           379407.33
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          14.60
Median TPOT (ms):                        14.10
P99 TPOT (ms):                           24.51
---------------Inter-token Latency----------------
Mean ITL (ms):                           32.60
Median ITL (ms):                         32.56
P99 ITL (ms):                            33.32
---------------Speculative Decoding---------------
Acceptance rate (%):                     68.79
Acceptance length:                       2.38
Drafts:                                  10658
Draft tokens:                            21316
Accepted tokens:                         14664
Per-position acceptance (%):
  Position 0:                            79.40
  Position 1:                            58.19
==================================================

This is what I am running:

NCCL_P2P_LEVEL=SYS \
NCCL_IB_DISABLE=1 \
NCCL_NET_GDR_LEVEL=SYS \
NCCL_MIN_NCHANNELS=4 \
NCCL_ALLOC_P2P_NET_LL_BUFFERS=1 \
VLLM_ENABLE_FLA_PACKED_RECURRENT_DECODE=1 \
VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS=1 \
vllm serve \
    Qwen/Qwen3.5-27B-FP8 \
    --tensor-parallel-size 2 \
    --gpu-memory-utilization 0.94 \
    --max-model-len 262144 \
    --max-num-seqs 32 \
    --max-num-batched-tokens 8192 \
    --block-size 32 \
    --language-model-only \
    -O3 \
    --enable-auto-tool-choice \
    --reasoning-parser qwen3 \
    --tool-call-parser qwen3_coder \
    --attention-backend TRITON_ATTN \
    --enable-prefix-caching \
    --speculative-config.method mtp \
    --speculative-config.num_speculative_tokens 2 \
    --speculative-config.rejection_sample_method probabilistic \
    --load-format instanttensor

Tested against main 09e4576

@swtb3
Author

swtb3 commented Mar 18, 2026

Before

---------------Time to First Token----------------
Mean TTFT (ms): 24907.82
Median TTFT (ms): 19508.09
P99 TTFT (ms): 62480.57
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 100.69
Median TPOT (ms): 97.66
P99 TPOT (ms): 205.47
---------------Inter-token Latency----------------
Mean ITL (ms): 205.14
Median ITL (ms): 175.41
P99 ITL (ms): 978.14

After

---------------Time to First Token----------------
Mean TTFT (ms): 185376.52
Median TTFT (ms): 191196.04
P99 TTFT (ms): 379407.33
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 14.60
Median TPOT (ms): 14.10
P99 TPOT (ms): 24.51
---------------Inter-token Latency----------------
Mean ITL (ms): 32.60
Median ITL (ms): 32.56
P99 ITL (ms): 33.32

So TTFT explodes while TPOT/ITL come down, suspiciously by a factor of ~7 in each case. I wonder if my changes have caused some sequential queuing.

@tdoublep
Member

Due to a rebase error, every update on this PR is pinging all vLLM maintainers. Could you please close it and open a new one so that the right subset of reviewers gets notified? Thanks

1 similar comment

@swtb3
Author

swtb3 commented Mar 18, 2026

Superseded by: #37429

@repne

@swtb3 swtb3 closed this Mar 18, 2026
@github-project-automation github-project-automation Bot moved this from To Triage to Done in gpt-oss Issues & Enhancements Mar 18, 2026
@github-project-automation github-project-automation Bot moved this to Done in NVIDIA Mar 18, 2026
@github-project-automation github-project-automation Bot moved this from Todo to Done in AMD Mar 18, 2026

Labels

bug, ci/build, documentation, frontend, gpt-oss, kv-connector, llama, multi-modality, nvidia, qwen, rocm, speculative-decoding, structured-output, tool-calling, v1
