
[Bugfix] Exclude O(1) Mamba groups from hybrid KV cache token capacity #40384

Open

jhsmith409 wants to merge 2 commits into vllm-project:main from jhsmith409:fix/hybrid-max-num-kv-tokens

Conversation

@jhsmith409
Contributor

Summary

On hybrid attention + Mamba models (Qwen3-Next, Qwen3.5/3.6 MoE hybrids, RecurrentGemma, Jamba, Zamba2, Nemotron-H, …), the reported GPU KV cache token capacity and the scheduler's max_num_kv_tokens are deflated by the extra Mamba groups, which under the default mamba_cache_mode='none' (and 'align') pre-reserve a fixed number of blocks and do not scale with sequence length.

Both _report_kv_cache_config() (vllm/v1/core/kv_cache_utils.py) and Scheduler.__init__ (vllm/v1/core/sched/scheduler.py) currently compute per-token capacity as:

num_tokens = num_blocks // len(kv_cache_config.kv_cache_groups) * min_block_size

For a typical hybrid with one attention group and N Mamba groups, that's an understatement by a factor of (1 + N): 2× for the common single-Mamba-group case, 4× for Nemotron-H-style 1 attn + 3 Mamba groups. The max_num_kv_tokens number is what sizes the routed_experts buffer for MoE and what the scheduler believes is its budget; getting this wrong shows up as (a) misleading boot-time logs and (b) over-conservative scheduling of concurrent requests on the very models (hybrid MoE) where extra concurrency is the whole point.
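
To make the factor concrete, a toy calculation with an invented block pool (numbers chosen purely for illustration, not from a real run):

```python
# Invented pool size, purely to illustrate the deflation factor.
num_blocks, min_block_size = 10_000, 16

# 1 attention + 1 Mamba group: a divisor of 2 halves the reported capacity.
assert num_blocks // 2 * min_block_size == 80_000    # what gets reported today
assert num_blocks // 1 * min_block_size == 160_000   # the attention group's real share

# 1 attention + 3 Mamba groups (Nemotron-H shape): 4x understatement.
assert num_blocks // 4 * min_block_size == 40_000
```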

Fix

  • Factor the filter into a tiny helper token_capacity_kv_cache_groups(vllm_config, kv_cache_config) in kv_cache_utils.py that returns only the groups that scale with sequence length (attention always, Mamba only when mamba_cache_mode == 'all'); see the sketch after this list.
  • Use that helper in both _report_kv_cache_config and Scheduler.__init__.
  • Fall back to all groups if the filter would produce an empty list (preserves dense-model and Mamba-only paths).
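
For reviewers skimming, here's a minimal sketch of the helper's shape (not the actual diff; the import path follows vLLM's v1 spec module, and where mamba_cache_mode lives on vllm_config is an assumption for illustration):

```python
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec


def token_capacity_kv_cache_groups(vllm_config, kv_cache_config):
    """Return only the KV cache groups whose state grows with sequence length."""
    # Assumed lookup path for illustration; the real config plumbing may differ.
    mamba_mode = getattr(vllm_config.cache_config, "mamba_cache_mode", "none")
    groups = [
        group
        for group in kv_cache_config.kv_cache_groups
        # Attention KV grows per token; Mamba state is O(1) per request
        # unless mamba_cache_mode == 'all' makes it scale like attention.
        if isinstance(group.kv_cache_spec, AttentionSpec)
        or (isinstance(group.kv_cache_spec, MambaSpec) and mamba_mode == "all")
    ]
    # An empty filter result (e.g. a Mamba-only model under 'none') falls back
    # to all groups, preserving the current dense and Mamba-only paths.
    return groups or kv_cache_config.kv_cache_groups
```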

The helper is exported (no leading underscore) because scheduler.py imports it; if the maintainers would rather keep it scheduler-local or inline it, happy to rewrite.

Why this is not duplicating an existing PR

Checked on 2026-04-20:

Test plan + results

python -m pytest tests/v1/core/test_kv_cache_utils.py -v

No existing test exercises the filter directly; I'll follow up with a small unit test in a separate commit once PR feedback lands (or now, if reviewers prefer). Syntax check (python -m py_compile) is clean; ruff check/ruff format were not available in my local sandbox but the edits follow the surrounding style.

End-to-end verification on our runtime stack (cu130-nightly + TurboQuant hybrid overlay + RedHatAI/Qwen3.6-35B-A3B-NVFP4, turboquant_k8v4, max_model_len=8192, max_num_seqs=8, --gpu-memory-utilization=0.85, torch.compile + cudagraph):

  • Before: INFO kv_cache_utils.py:1363] GPU KV cache size: 143,936 tokens
  • After: I'll reply in a follow-up comment with the log delta from a re-run on this branch; expect a clean 2× jump on 1 attn + 1 mamba group.

AI-assist disclosure (per AGENTS.md)

Change was drafted with help from Claude (Anthropic); the human submitter reviewed every line end-to-end and understands the hybrid KV cache group semantics. Original bug identification and filter design credit to @Sandermage; see his issue #40124 tracking table (patch 9) and the ai-jz/vllm#1 approach he references. Co-authored-by trailers included.


@claude (Bot) left a comment


Claude Code Review

This pull request is from a fork — automated review is disabled. A repository maintainer can comment @claude review to run a one-time review.

@mergify (Bot) added the v1 and bug (Something isn't working) labels on Apr 20, 2026

@gemini-code-assist (Bot) left a comment


Code Review

This pull request introduces the token_capacity_kv_cache_groups utility function to correctly identify KV cache groups that scale with sequence length, specifically handling Mamba and Attention specifications. This function is integrated into the KV cache reporting and scheduler initialization to improve the accuracy of token capacity calculations in hybrid models. I have no feedback to provide.

@jhsmith409
Contributor Author

E2E log delta on a hybrid MoE

Applied this patch on top of our TurboQuant-hybrid test overlay (JartX#10 LCM fallback + #40074 + #39748 + #40092, all on cu130-nightly), model RedHatAI/Qwen3.6-35B-A3B-NVFP4 (40 layers = 30 DeltaNet linear-attention + 10 full-attention), --kv-cache-dtype=turboquant_k8v4, --max-model-len=8192, --max-num-seqs=8, --gpu-memory-utilization=0.85, torch.compile + CUDA graphs.

# before this PR
INFO kv_cache_utils.py:1363] GPU KV cache size: 143,936 tokens

# with this PR applied
INFO kv_cache_utils.py:1386] GPU KV cache size: 362,608 tokens

That's a ~2.5× increase in reported capacity on a 1-attention + Mamba-group(s) hybrid, matching the expected correction (the prior log divided the shared block pool by every group instead of only the groups that scale with sequence length). No startup-time or runtime regressions observed; completions still succeed (verified with the same 28k-token-prompt stress harness I used for #40074 and #39748).

Note: our setup runs with enable_return_routed_experts=False, so the scheduler-side max_num_kv_tokens change didn't fire in this particular run; the reporting-side change in _report_kv_cache_config is what produced the 143,936 → 362,608 delta. On a config that does enable routed-experts capture, the same helper governs the routed_experts buffer sizing.

@Sandermage
Contributor

Sandermage commented Apr 20, 2026

Hi @jhsmith409, great PR — thanks for the clean upstream version!

(Small disclaimer: I'm from Ukraine and my English is still a work in progress, so I'm using AI to help with translation. Hope it reads okay!)

Really nice that you extracted the filter into token_capacity_kv_cache_groups() — much cleaner than my inline version in our monkey-patcher. The empty-fallback guard for mamba-only models is also something I hadn't handled, good catch.

Here's my A5000 TP=2 datapoint on Qwen3.6-35B-A3B-FP8 (turboquant_k8v4 KV, mamba_cache_mode='align', max_model_len=163840, gpu_memory_utilization=0.905):

| Metric | Value |
| --- | --- |
| GPU KV cache size | 946,656 tokens |
| Maximum concurrency @ 163,840 tokens/req | 5.18× |
| Without the filter (pre-fix estimate, same hw) | ~473k tokens / ~2.59× |

This model has 1 attention + 1 Mamba group, so the factor on Ampere comes out to roughly 2× — consistent in direction with the 2.52× you measured on 5090 + NVFP4.

I've been running the same logic as a runtime monkey-patch (Patch 9) since v5.7, based on ai-jz/vllm#1. Happy to drop it in favor of this PR the moment it lands. If it helps reviewers, I can re-run with #40384 cherry-picked cleanly on current main (my patch removed) to give an independent A5000 datapoint — just let me know.

@Sandermage
Contributor

Update — I just switched my local patch to apply your exact architecture (helper function + clean callsites), instead of my previous inline implementation. Full mirror of your PR diff: imports of AttentionSpec/MambaSpec, injection of token_capacity_kv_cache_groups() helper before _report_kv_cache_config, and helper calls from both _report_kv_cache_config and Scheduler.__init__.

A5000 TP=2 clean-room numbers (Qwen3.6-35B-A3B-FP8 + turboquant_k8v4, max_model_len=163840, mamba_cache_mode='align', gpu_memory_utilization=0.905):

| Metric | Value |
| --- | --- |
| GPU KV cache size | 946,656 tokens |
| Max concurrency @ 163,840 tokens/req | 5.18× |
| Stability over 10 requests | 10/10 OK, avg 142.28 t/s (stdev 0.154) |
| Decode @ 32k / 128k / 160k | 106.7 / 56.4 / 49.1 t/s |

Identical behavior to my previous inline patch (which was logically equivalent), with per-context deltas ≤0.6% — well within the ~0.8%/step thermal boost drift we see on workstation A5000s between back-to-back runs.

So the design in this PR works cleanly on Ampere SM 8.6 with the FP8 hybrid model too. Happy to add a formal "Tested on A5000 TP=2" line in the PR body if you'd like, just let me know.

@Sandermage
Contributor

Hi @jhsmith409 — quick cross-reference for completeness:


While auditing this PR's reach across the codebase, I noticed there's a third site with the same bug class that this PR doesn't cover: vllm/v1/worker/gpu_model_runner.py:6823-6836 inside init_routed_experts_capturer(). It uses the same num_blocks // num_groups * min_block_size formula and would benefit from the same token_capacity_kv_cache_groups() filter.

The active path is gated on enable_return_routed_experts=True, so it doesn't fire in default serving; but if the scheduler's max_num_kv_tokens and the worker buffer size diverge, you'd get the same IndexError trace shown in PR #37118.

PR #37118 (@allgather, Nov 2025) already proposes a fix for that worker site, but with a different formula (num_blocks * attn_group.block_size — full address space). For a single-attention-group hybrid like our Qwen3.6-A3B both formulas produce the same value, but on Nemotron-H-style multi-attention-group hybrids they'd diverge.

Probably worth a maintainer aligning the two PRs so we don't end up with inconsistent treatment of the same bug class. Just flagging so it's visible in this thread.

Also filed #40417 about extending the helper to exclude SlidingWindowSpec / ChunkedLocalAttentionSpec — same class, different specs.

@jhsmith409
Contributor Author

Thanks for the cross-reference @Sandermage — and no worries about the translation, the writeup is perfectly clear. I use Claude to review what I write as well.

I dug through #37118 and #40417 before replying. My read is that the worker site you flagged is best left to #37118, because the two PRs are answering structurally different questions even though they both touch a max_num_kv_tokens attribute named the same way:

  • [Bugfix] Exclude O(1) Mamba groups from hybrid KV cache token capacity #40384 (this PR) sizes the scheduler's per-token capacity divisor — "how many groups contribute KV state that grows with token count." O(1) Mamba groups should be excluded from the divisor because they don't scale with sequence length; that's why token_capacity_kv_cache_groups() filters them out.
  • [Bugfix] out-of-bounds error for routed experts capture #37118 sizes the routed-experts side buffer that gets indexed by an attention group's slot_mapping. That buffer needs the full address space of the chosen attention group (num_blocks * attn_group.block_size), not the per-token capacity. Otherwise the slot_mapping indices walk off the end — exactly the IndexError in HollowMan6's trace.

On a single-attention-group hybrid like Qwen3.6-A3B the two formulas collapse to the same number, but on a multi-attention-group hybrid (Nemotron-H-style) they diverge in opposite directions: my helper would under-size the routed-experts buffer (re-introducing the #37118 bug class), and HollowMan6's full-address-space formula would over-bound the scheduler's per-token planning capacity. So I think the right outcome is to land both, scoped to their respective sites, rather than try to share a helper.
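
A toy comparison on an invented multi-attention-group shape, just to pin down the direction of the divergence (numbers are illustrative, not from a real config):

```python
# Hypothetical 2-attention + 2-Mamba hybrid; pool size invented.
num_blocks, block_size = 1_000, 16
scaling_groups = 2  # attention groups surviving token_capacity_kv_cache_groups

# This PR (scheduler planning): shared pool split across scaling groups.
per_token_capacity = num_blocks // scaling_groups * block_size  # 8,000

# #37118 (routed-experts buffer): full address space of one attention group,
# so slot_mapping indices always stay in bounds.
buffer_tokens = num_blocks * block_size  # 16,000

# Using per_token_capacity for the buffer would under-size it (IndexError);
# using buffer_tokens for planning would over-state schedulable capacity.
assert per_token_capacity < buffer_tokens
```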

Flagging this on #37118 as well so HollowMan6 sees the cross-reference.

On #40417 — agreed, the AttentionSpec isinstance check is too broad; SlidingWindowSpec / ChunkedLocalAttentionSpec are bounded just like Mamba (in the relevant sense) and shouldn't be in the per-token divisor. I'd lean toward your option (a) — tighten the positive list to FullAttentionSpec — because it's compositional with the existing class hierarchy (TQFullAttentionSpec / MLAAttentionSpec / SinkFullAttentionSpec all inherit from FullAttentionSpec) and adding a future bounded subtype that inherits from AttentionSpec directly would be auto-handled. Happy to fold that into this PR as a second commit if a maintainer thinks it should ship together; otherwise I'll send it as a follow-up after #40384 lands so the scope here stays minimal.
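
For concreteness, the option-(a) shape would be roughly the following (a sketch against the helper in this PR, not a committed diff; _scales_with_seq_len is a hypothetical name):

```python
from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec


def _scales_with_seq_len(spec, mamba_mode):
    # Positive list on FullAttentionSpec instead of the broad AttentionSpec
    # check: bounded specs (SlidingWindowSpec, ChunkedLocalAttentionSpec)
    # drop out of the per-token divisor without being named explicitly.
    return isinstance(spec, FullAttentionSpec) or (
        isinstance(spec, MambaSpec) and mamba_mode == "all"
    )
```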

Let me know which way you'd prefer to go on the #40417 fold.

@Sandermage
Contributor

Ah, thanks for the detailed breakdown @jhsmith409 — you're right and I was wrong to frame these as "same bug, different sites should align formulas". The scheduler's per-token divisor and the routed-experts side buffer are structurally different concerns: O(1) groups should be excluded from the first (per-token planning capacity) but must be included in the second (index-able address space). On a Nemotron-H-style multi-attention-group hybrid the two formulas correctly diverge.

Updating my mental model. Landing both PRs scoped to their respective sites is clearly the right outcome — I'll add a note on #37118 making the same correction so the record is clean.

On #40417 — great, FullAttentionSpec positive list it is. Fine either as a follow-up or folded here, whatever is easier for review. I'll prep the one-liner change locally and wait for your signal.

@TCRnext

TCRnext commented Apr 28, 2026

Nice work! On 35B-A3B, GPU KV cache size improved from 288,288 tokens to 1,157,376 tokens on 2× 3090. Decode speeds get a little higher under high concurrency with long context lengths. @mgoin Could you please take a look at this PR when you have a moment?

Jim Smith and others added 2 commits April 29, 2026 12:33
On hybrid attention + Mamba models (Qwen3-Next, Qwen3.5/3.6 MoE hybrids,
RecurrentGemma, Jamba, Zamba2, Nemotron-H, …), `kv_cache_config.kv_cache_groups`
contains one attention group plus one (or several) Mamba groups. In the
default `mamba_cache_mode='none'` (and `'align'`) the Mamba state is O(1)
per request and pre-reserves a fixed number of blocks; only the attention
groups scale with sequence length.

Both `_report_kv_cache_config()` and `Scheduler.__init__()` currently
compute per-token capacity as `num_blocks // len(all_groups) * block_size`,
so each extra Mamba group deflates the reported capacity and the scheduler's
`max_num_kv_tokens` (used to size the `routed_experts` buffer for MoE).

A Qwen3.6-35B-A3B hybrid run on cu130-nightly with turboquant_k8v4 KV cache
and max_model_len=8192 reports "GPU KV cache size: 146,704 tokens" today;
with this fix the attention group alone owns the shared block pool, so the
reported count matches what the scheduler can actually allocate.

Factors the filter into a small helper (`token_capacity_kv_cache_groups`)
and uses it in both sites. Falls back to the current behavior if the filter
would produce an empty list, preserving dense-model and Mamba-only paths.

Credit to @Sandermage (ref vllm-project#40124
patch 9, based on ai-jz/vllm#1) for identifying the bug and the filter
design.

Co-authored-by: Sandermage <sandermage@users.noreply.github.com>
Co-authored-by: Claude <noreply@anthropic.com>
Signed-off-by: Jim Smith <jhsmith0@me.com>
Covers the six meaningful shapes of kv_cache_config the helper sees:

- dense (all AttentionSpec groups) → unchanged
- hybrid with mamba_cache_mode='none' or 'align' → Mamba groups dropped
- hybrid with mamba_cache_mode='all' → Mamba kept (scales with seq len)
- Mamba-only model under 'none' → filter would empty, fallback kicks in
- 1 attn + 3 Mamba (Nemotron-H shape) → single-group result
- empty config → empty list (no IndexError)

Signed-off-by: Jim Smith <jhsmith0@me.com>