Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions tests/test_v1_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
pytest.importorskip("vllm", reason="vllm not installed")

from vllm_metal.stt.policy import STT_SCHED_AVAILABLE_BYTES # noqa: E402
from vllm_metal.v1 import model_runner as mr # noqa: E402
from vllm_metal.v1.worker import MetalWorker # noqa: E402


Expand Down Expand Up @@ -178,6 +179,32 @@ def test_hybrid_adds_linear_state(self) -> None:
sdpa_bytes = 2 * 8 * 2048 * 4 * 256 * 2
assert result == sdpa_bytes + linear_bytes

def test_linear_cache_bytes_uses_float32_recurrent(self) -> None:
runner = mr.MetalModelRunner.__new__(mr.MetalModelRunner)
runner.model_args = {"full_attention_interval": 2}
runner.kv_cache_dtype = mx.float16
runner.linear_conv_kernel_dim = 3
runner.linear_conv_dim = 5
runner.linear_num_v_heads = 2
runner.linear_value_head_dim = 7
runner.linear_key_head_dim = 11
runner.num_linear_layers = 3

conv_bytes = (
(runner.linear_conv_kernel_dim - 1)
* runner.linear_conv_dim
* mx.float16.size
)
recurrent_bytes = (
runner.linear_num_v_heads
* runner.linear_value_head_dim
* runner.linear_key_head_dim
* mx.float32.size
)
expected = runner.num_linear_layers * (conv_bytes + recurrent_bytes)

assert runner.linear_cache_bytes_per_slot() == expected

def test_block_alignment_rounds_up_token_count(self) -> None:
"""When block_size doesn't divide max_model_len evenly, the token
count must be rounded up to the next block boundary so that the
Expand Down
4 changes: 3 additions & 1 deletion vllm_metal/v1/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1176,14 +1176,16 @@ def linear_cache_bytes_per_slot(self) -> int:
if self.kv_cache_dtype is None:
raise RuntimeError("KV cache dtype not initialized; load_model() first")
dtype_size = self.kv_cache_dtype.size
# GDN recurrent state is always float32 (see GDNPagedStateCache).
recurrent_dtype_size = mx.float32.size
conv_bytes = (
(self.linear_conv_kernel_dim - 1) * self.linear_conv_dim * dtype_size
)
recurrent_bytes = (
self.linear_num_v_heads
* self.linear_value_head_dim
* self.linear_key_head_dim
* dtype_size
* recurrent_dtype_size
)
return self.num_linear_layers * (conv_bytes + recurrent_bytes)

Expand Down
Loading