[Bugfix] Fix KV cache undercount in MLX path for large block sizes#229

Merged
ericcurtin merged 1 commit into vllm-project:main from samwarren:fix/block-aligned-kv-estimate
Apr 5, 2026
Conversation

@samwarren
Contributor

Summary

  • Fix _one_sequence_kv_bytes in MetalWorker to use block-aligned token counts, matching the upstream scheduler's cdiv(max_model_len, block_size) * page_size_bytes accounting
  • Prevents server startup failure on Mamba-hybrid models (e.g. Granite 4.0-H) where block_size is padded to 400 to match the mamba page size

Problem

_one_sequence_kv_bytes computes KV cache bytes using max_model_len directly (e.g. 2048 tokens). But the upstream _check_enough_kv_cache_memory in vLLM core uses block-aligned sizes: cdiv(2048, 400) = 6 blocks = 2400 tokens. This causes "needed > available" even though the intent is to report exactly enough for one sequence:

ValueError: To serve at least one request with the models's max seq len (2048),
(0.18 GiB KV cache is needed, which is larger than the available KV cache memory
(0.16 GiB).

For the default block_size=16, cdiv(2048, 16) * 16 = 2048 — no padding, so this never triggers. It only manifests with large block sizes like 400, which occurs on Mamba-hybrid models (GraniteMoeHybridForCausalLM) where the attention block size is padded to match the mamba page size.
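The mismatch is easiest to see with the numbers plugged in. A minimal sketch (not vLLM code; `cdiv` here is plain ceiling division, as used by vLLM core's accounting):

```python
# Illustrative sketch of how block alignment inflates the scheduler's
# token count relative to a naive max_model_len-based estimate.
def cdiv(a: int, b: int) -> int:
    # Ceiling division via negated floor division.
    return -(-a // b)

max_model_len = 2048

# Default block size 16: 2048 is already block-aligned, both sides agree.
assert cdiv(max_model_len, 16) * 16 == 2048

# Mamba-hybrid block size 400: 6 blocks hold 2400 tokens, 352 more than
# max_model_len, so the scheduler's "needed" exceeds a per-token estimate.
assert cdiv(max_model_len, 400) * 400 == 2400
```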

Fix

Round max_model_len up to the nearest block_size boundary in _one_sequence_kv_bytes:

block_size = self.vllm_config.cache_config.block_size
max_tokens = -(-self.model_config.max_model_len // block_size) * block_size
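The `-(-n // block)` idiom is integer ceiling division, which avoids the float rounding risk of `math.ceil(n / block)`. A standalone sketch of the rounding the fix performs:

```python
import math

def round_up(n: int, block: int) -> int:
    # Same idiom as the fix: negate, floor-divide, negate = ceiling division.
    return -(-n // block) * block

# Agrees with math.ceil for representative cases, all in exact int arithmetic.
for n, block in [(2048, 400), (2048, 16), (1, 400), (400, 400)]:
    assert round_up(n, block) == math.ceil(n / block) * block

assert round_up(2048, 400) == 2400  # the Granite 4.0-H case: 6 blocks of 400
```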

Reproduction

vllm serve mlx-community/granite-4.0-h-tiny-3bit-MLX --max-model-len 2048 --enforce-eager
# Fails with KV cache memory error

# After fix:
# Server starts successfully

Test plan

  • Added test_block_alignment_rounds_up_token_count — verifies block-aligned calculation with block_size=400
  • Updated existing test_non_hybrid_counts_all_layers and test_hybrid_adds_linear_state to include vllm_config.cache_config.block_size in mocks
  • All 10 tests in test_v1_worker.py pass
  • Verified vllm serve mlx-community/granite-4.0-h-tiny-3bit-MLX --max-model-len 4096 --enforce-eager starts and serves requests on M4 Pro 48GB
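A test pinning down the rounding could look like the sketch below. This is illustrative only: the helper and the mock attribute shapes (`cache_config.block_size`, `model_config.max_model_len`) are assumptions modeled on the fields the fix reads, not the actual `test_v1_worker.py` code:

```python
from types import SimpleNamespace

def block_aligned_tokens(max_model_len: int, block_size: int) -> int:
    # The rounding introduced by the fix, extracted as a standalone helper.
    return -(-max_model_len // block_size) * block_size

def test_block_alignment_rounds_up_token_count():
    # Mock config shaped like the attributes the worker reads (names assumed).
    cfg = SimpleNamespace(
        cache_config=SimpleNamespace(block_size=400),
        model_config=SimpleNamespace(max_model_len=2048),
    )
    tokens = block_aligned_tokens(
        cfg.model_config.max_model_len, cfg.cache_config.block_size
    )
    assert tokens == 2400  # cdiv(2048, 400) = 6 blocks * 400 tokens
```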

`_one_sequence_kv_bytes` used `max_model_len` directly as the token count,
but the upstream `_check_enough_kv_cache_memory` uses block-aligned sizes
via `cdiv(max_model_len, block_size) * page_size_bytes`. When `block_size`
is large (e.g. 400 for Mamba-hybrid models where the attention block size
is padded to match the mamba page size), the rounding overhead causes the
needed memory to exceed the reported available memory, failing server
startup with:

  ValueError: 0.34 GiB KV cache is needed, which is larger than the
  available KV cache memory (0.31 GiB)

This affects models like Granite 4.0-H (GraniteMoeHybridForCausalLM) which
mix Mamba and attention layers, triggering block_size=400 alignment.

Fix: round `max_model_len` up to the nearest `block_size` boundary in
`_one_sequence_kv_bytes` so both sides use the same token count.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Samuel Warren <samuel@sketchpro.ai>
@samwarren force-pushed the fix/block-aligned-kv-estimate branch from c919906 to d0086fd on April 5, 2026 04:14
@ericcurtin ericcurtin merged commit 89f9ce3 into vllm-project:main Apr 5, 2026
5 checks passed
Alex-ai-future pushed a commit to Alex-ai-future/vllm-metal that referenced this pull request Apr 8, 2026
…llm-project#229)

Alex-ai-future added a commit to Alex-ai-future/vllm-metal that referenced this pull request Apr 8, 2026
Reverts to the PR vllm-project#229 design: report one max-length sequence of KV cache
for the MLX path, instead of a fraction of total Metal memory.

Rationale (from LxYuan0420's review):
- The previous change (gpu_memory_utilization * total_memory) altered
  scheduler semantics without explicit policy discussion.
- PR vllm-project#229's one-sequence estimate ensures conservative admission control.
- MLX's make_prompt_cache() dynamically allocates per request, so we only
  need to report enough for one sequence.

This keeps the scheduler behavior consistent with upstream expectations
and avoids over-committing memory.

Co-authored-by: Qwen-Coder <qwen-coder@alibabacloud.com>
Signed-off-by: Alex <alex.tech.lab@outlook.com>
Alex-ai-future added a commit to Alex-ai-future/vllm-metal that referenced this pull request Apr 8, 2026
…ence estimate

Updates test expectations to match the implementation changes:
- test_hybrid_with_paged_attention_logs_warning: Verify warning is logged
  instead of ValueError (PR vllm-project#235 made hybrid + paged attention supported)
- test_determine_available_memory_single_sequence_mode: Restore to test
  one-sequence estimate (PR vllm-project#229 design) instead of 80% memory fraction

Also fixes test fixtures to include required vllm_config attribute.

Co-authored-by: Qwen-Coder <qwen-coder@alibabacloud.com>
Signed-off-by: Alex <alex.tech.lab@outlook.com>