[Bugfix] Fix ~7x KV cache memory overestimation for hybrid Mamba+Attention models#1
[Bugfix] Fix ~7x KV cache memory overestimation for hybrid Mamba+Attention models#1justtestingthingsx wants to merge 2 commits into
Conversation
…ntion models For hybrid models like Qwen3.5 (24 GDN + 8 attention layers), the KV cache profiler treats all layers uniformly, padding Mamba's small O(1) state to match attention's O(n) KV page size. This wastes ~85% of memory per Mamba block, causing a ~7x overestimation of required KV cache memory. On a 10GB RTX 3080 with Qwen3.5-9B GPTQ, this overestimation leaves no room for KV cache after model weights are loaded (8.19 GiB model + overestimated KV = -0.67 GiB available). Fix: - Detect hybrid Mamba+Attention models and skip page size unification - Allocate per-group tensors at natural page sizes instead of padding - Report token capacity from attention groups only (Mamba is O(1)) - Fix max_concurrency calculation for non-uniform page sizes - Fix max_memory_usage estimation for hybrid groups All changes are additive code paths gated behind hybrid model detection. Non-hybrid models are completely unaffected. Fixes: vllm-project#37121 Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
|
Hi! I ran into similar issue, and tried your fix, but we had some issues. Report from Claude Opus:
disable_hybrid_kv_cache_manager → unify_hybrid_kv_cache_specs() converts MambaSpecs in-place
1 The allocation logic in get_kv_cache_config_from_groups (the elif _is_hybrid_kv_cache_groups branch) looks correct in principle — it sums natural page sizes per block and allocates per-group tensors. But it never gets to show its effect because the groups arrive pre-padded. Happy to test again if you update the grouping path. note from me: i am also happy to test again :) Have a nice day! |
|
Hey, thanks for testing and the detailed report! You're right on both issues — the detection fires too late and the page size unification undoes the fix. Honestly we haven't gotten it working on our end either (RTX 3080 10GB). Since we opened this, vLLM v0.17.0/v0.17.1 landed with Qwen3.5 support and some KV cache changes, so the code around Will ping you when there's something to test again. If you beat us to a fix feel free to open a PR upstream — the more people poking at this the better. |
…+Attention KV cache
Fix two bugs in the hybrid Mamba+Attention KV cache handling:
1. Detection ordering: Move _is_hybrid_mamba_attention() check to the top
of get_kv_cache_groups(), before unify_hybrid_kv_cache_specs() runs.
Previously, unification could modify specs in-place or raise ValueError
on Mamba+Attention combos before hybrid detection had a chance to run.
2. Per-layer tensor allocation: Replace the uniform page size grouping
(_get_kv_cache_groups_uniform_page_size) with a new dedicated function
(_get_kv_cache_groups_hybrid_mamba_attention) that:
- Groups layers by spec type (one group per distinct KVCacheSpec)
- Preserves each group's natural page size
- Allocates per-layer tensors in get_kv_cache_config_from_groups so
each layer gets its own memory (critical because layers in the same
group share a block table but need independent state)
The old code routed hybrid models through the uniform page size path,
which padded Mamba groups (~1.1 MiB natural) to match attention groups
(~3.2 MiB), inflating memory allocation by ~3x per layer (~7x total
for models like Qwen3.5 with 24 Mamba + 8 attention layers).
Also updates _max_memory_usage_bytes_from_groups and
get_max_concurrency_for_kv_cache_config to correctly account for
per-layer tensor costs in hybrid models.
Non-hybrid models (pure attention, pure Mamba, attention+sliding window)
are completely unaffected -- all changes are gated behind hybrid
Mamba+Attention detection.
See vllm-project#37121
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
|
Update: v2 fix pushed to this branch — addresses both issues you found. Changes in v2:
Test results:
Note: we also found that vLLM v0.17 V1 engine needs We also rebased onto upstream main on branch Also worth noting: upstream PR vllm-project#37429 by @swtb3 takes a different approach (compact Mamba allocation) and showed +27% KV tokens. Both approaches solve the core issue differently — ours preserves per-group natural page sizes, theirs does dedicated Mamba block pool management. Let us know if you can retest! |
Summary
Impact
Test plan
Fixes: vllm-project#37121
🤖 Generated with Claude Code