44 changes: 43 additions & 1 deletion tests/test_v1_worker.py
@@ -114,7 +114,7 @@ def test_get_supported_tasks_delegates_to_runner_capability(self) -> None:


class TestOneSequenceKvBytes:
"""_one_sequence_kv_bytes must account for hybrid linear state."""
"""_one_sequence_kv_bytes must account for hybrid linear state and block alignment."""

def test_non_hybrid_counts_all_layers(self) -> None:
# Arrange
@@ -129,6 +129,10 @@ def test_non_hybrid_counts_all_layers(self) -> None:
)
worker = _make_worker(model_runner, use_paged_attention=False)
worker.model_config = SimpleNamespace(max_model_len=2048)
# block_size=16 divides 2048 evenly, so no padding
worker.vllm_config = SimpleNamespace(
cache_config=SimpleNamespace(block_size=16)
)

# Act
result = MetalWorker._one_sequence_kv_bytes(worker)
@@ -151,10 +155,48 @@ def test_hybrid_adds_linear_state(self) -> None:
)
worker = _make_worker(model_runner, use_paged_attention=False)
worker.model_config = SimpleNamespace(max_model_len=2048)
worker.vllm_config = SimpleNamespace(
cache_config=SimpleNamespace(block_size=16)
)

# Act
result = MetalWorker._one_sequence_kv_bytes(worker)

# Assert — SDPA bytes + linear state
sdpa_bytes = 2 * 8 * 2048 * 4 * 256 * 2
assert result == sdpa_bytes + linear_bytes

def test_block_alignment_rounds_up_token_count(self) -> None:
"""When block_size doesn't divide max_model_len evenly, the token
count must be rounded up to the next block boundary so that the
reported bytes match the scheduler's block-aligned accounting.

This reproduces the KV cache startup failure seen with Mamba-hybrid
models (e.g. Granite 4.0-H) where the attention block_size is padded
to 400 to match the mamba page size.
"""
import mlx.core as mx

model_runner = SimpleNamespace(
is_hybrid=False,
num_layers=4,
num_kv_heads=4,
head_dim=64,
kv_cache_dtype=mx.float16,
)
worker = _make_worker(model_runner, use_paged_attention=False)
worker.model_config = SimpleNamespace(max_model_len=2048)
# block_size=400 (Mamba-hybrid): ceil(2048/400)=6, 6*400=2400 tokens
worker.vllm_config = SimpleNamespace(
cache_config=SimpleNamespace(block_size=400)
)

result = MetalWorker._one_sequence_kv_bytes(worker)

# Should use 2400 tokens (block-aligned), not 2048
aligned_tokens = 2400 # ceil(2048/400) * 400
expected = 2 * 4 * aligned_tokens * 4 * 64 * 2
assert result == expected
# Verify this is strictly more than the unaligned calculation
unaligned = 2 * 4 * 2048 * 4 * 64 * 2
assert result > unaligned
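For reference, a minimal standalone sketch of the block-alignment arithmetic this test asserts; aligned_kv_bytes is a hypothetical helper for illustration, not part of vllm_metal:

import math

def aligned_kv_bytes(max_model_len, block_size, num_kv_layers,
                     num_kv_heads, head_dim, dtype_size):
    # Round the token count up to the next block boundary.
    max_tokens = math.ceil(max_model_len / block_size) * block_size
    # 2x for the K and V tensors per layer.
    return 2 * num_kv_layers * max_tokens * num_kv_heads * head_dim * dtype_size

# Granite-like case from the new test: block_size=400 pads 2048 -> 2400 tokens.
assert aligned_kv_bytes(2048, 400, 4, 4, 64, 2) == 2 * 4 * 2400 * 4 * 64 * 2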
16 changes: 14 additions & 2 deletions vllm_metal/v1/worker.py
@@ -358,7 +358,13 @@ def _get_model_memory_usage(self) -> int:
return 0

def _one_sequence_kv_bytes(self) -> int:
"""Bytes for one max-length sequence of cache state."""
"""Bytes for one max-length sequence of cache state.

Uses block-aligned token count so the estimate matches the upstream
``_check_enough_kv_cache_memory`` calculation, which rounds
``max_model_len`` up to the nearest ``block_size`` boundary via
``cdiv(max_model_len, block_size) * page_size_bytes``.
"""
runner = self.model_runner
if runner.kv_cache_dtype is None:
raise RuntimeError("KV cache dtype not initialized; call runner.load_model() first")
@@ -367,10 +373,16 @@ def _one_sequence_kv_bytes(self) -> int:
num_kv_layers = (
runner.num_sdpa_layers if runner.is_hybrid else runner.num_layers
)

# Round token count up to block boundary to match the scheduler's
# block-aligned memory accounting.
block_size = self.vllm_config.cache_config.block_size
max_tokens = -(-self.model_config.max_model_len // block_size) * block_size

sdpa_kv_bytes = (
2
* num_kv_layers
* self.model_config.max_model_len
* max_tokens
* runner.num_kv_heads
* runner.head_dim
* dtype_size
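A quick note on the -(-n // b) idiom used for max_tokens above: it is pure-Python integer ceiling division, equivalent to upstream vLLM's cdiv. A minimal standard-library check:

import math

# -(-n // b) is ceiling division: floor division rounds toward negative
# infinity, so negating the numerator and then the result rounds up instead.
for n, b in [(2048, 16), (2048, 400), (1, 400), (400, 400)]:
    assert -(-n // b) == math.ceil(n / b)

# The Granite-hybrid case from the tests: ceil(2048 / 400) = 6 blocks -> 2400 tokens.
assert -(-2048 // 400) * 400 == 2400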