diff --git a/tests/test_v1_worker.py b/tests/test_v1_worker.py
index 7b0185f0..2d24833f 100644
--- a/tests/test_v1_worker.py
+++ b/tests/test_v1_worker.py
@@ -12,6 +12,7 @@
 pytest.importorskip("vllm", reason="vllm not installed")
 
 from vllm_metal.stt.policy import STT_SCHED_AVAILABLE_BYTES  # noqa: E402
+from vllm_metal.v1 import model_runner as mr  # noqa: E402
 from vllm_metal.v1.worker import MetalWorker  # noqa: E402
 
 
@@ -178,6 +179,32 @@ def test_hybrid_adds_linear_state(self) -> None:
         sdpa_bytes = 2 * 8 * 2048 * 4 * 256 * 2
         assert result == sdpa_bytes + linear_bytes
 
+    def test_linear_cache_bytes_uses_float32_recurrent(self) -> None:
+        runner = mr.MetalModelRunner.__new__(mr.MetalModelRunner)
+        runner.model_args = {"full_attention_interval": 2}
+        runner.kv_cache_dtype = mx.float16
+        runner.linear_conv_kernel_dim = 3
+        runner.linear_conv_dim = 5
+        runner.linear_num_v_heads = 2
+        runner.linear_value_head_dim = 7
+        runner.linear_key_head_dim = 11
+        runner.num_linear_layers = 3
+
+        conv_bytes = (
+            (runner.linear_conv_kernel_dim - 1)
+            * runner.linear_conv_dim
+            * mx.float16.size
+        )
+        recurrent_bytes = (
+            runner.linear_num_v_heads
+            * runner.linear_value_head_dim
+            * runner.linear_key_head_dim
+            * mx.float32.size
+        )
+        expected = runner.num_linear_layers * (conv_bytes + recurrent_bytes)
+
+        assert runner.linear_cache_bytes_per_slot() == expected
+
     def test_block_alignment_rounds_up_token_count(self) -> None:
         """When block_size doesn't divide max_model_len evenly, the token
         count must be rounded up to the next block boundary so that the
diff --git a/vllm_metal/v1/model_runner.py b/vllm_metal/v1/model_runner.py
index a7bf3c93..59ec3e4e 100644
--- a/vllm_metal/v1/model_runner.py
+++ b/vllm_metal/v1/model_runner.py
@@ -1176,6 +1176,8 @@ def linear_cache_bytes_per_slot(self) -> int:
         if self.kv_cache_dtype is None:
             raise RuntimeError("KV cache dtype not initialized; load_model() first")
         dtype_size = self.kv_cache_dtype.size
+        # GDN recurrent state is always float32 (see GDNPagedStateCache).
+        recurrent_dtype_size = mx.float32.size
         conv_bytes = (
             (self.linear_conv_kernel_dim - 1) * self.linear_conv_dim * dtype_size
         )
@@ -1183,7 +1185,7 @@
         recurrent_bytes = (
             self.linear_num_v_heads
             * self.linear_value_head_dim
             * self.linear_key_head_dim
-            * dtype_size
+            * recurrent_dtype_size
         )
         return self.num_linear_layers * (conv_bytes + recurrent_bytes)
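Sanity check: with the sizes used in the new test, the patched linear_cache_bytes_per_slot() should reserve 1908 bytes per slot. Below is a minimal standalone sketch of that arithmetic, using plain ints in place of mx.float16.size and mx.float32.size; the variable names are illustrative, not the runner's API.

# Standalone sketch of the per-slot linear-cache byte count (illustrative names).
kv_dtype_size = 2   # stands in for mx.float16.size (kv_cache_dtype)
f32_size = 4        # stands in for mx.float32.size

conv_kernel_dim, conv_dim = 3, 5
num_v_heads, value_head_dim, key_head_dim = 2, 7, 11
num_linear_layers = 3

# Conv state is stored in the KV-cache dtype (float16 here).
conv_bytes = (conv_kernel_dim - 1) * conv_dim * kv_dtype_size            # 2 * 5 * 2 = 20
# GDN recurrent state is always float32, regardless of kv_cache_dtype.
recurrent_bytes = num_v_heads * value_head_dim * key_head_dim * f32_size  # 2 * 7 * 11 * 4 = 616
per_slot = num_linear_layers * (conv_bytes + recurrent_bytes)            # 3 * 636 = 1908
assert per_slot == 1908

Before the fix, the recurrent term used the KV-cache dtype size (2 bytes here), giving 3 * (20 + 308) = 984 bytes per slot and undercounting the float32 recurrent state by half.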