Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 90 additions & 0 deletions tests/test_paged_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -725,3 +725,93 @@ def test_clear(self):
stats = cache.get_stats()
# After clear, null block is still allocated (vLLM style)
assert stats["allocated_blocks"] == 1 # only null block

def test_reconstructs_hybrid_cache_from_boundary_snapshot(self):
    """A hybrid (KV + array-state) cache stored block-wise round-trips intact."""
    import mlx.core as mx
    from mlx_lm.models.cache import ArraysCache, KVCache

    from vllm_mlx.paged_cache import PagedCacheManager
    from vllm_mlx.prefix_cache import BlockAwarePrefixCache

    manager = PagedCacheManager(block_size=4, max_blocks=10)
    prefix_cache = BlockAwarePrefixCache(model=None, paged_cache_manager=manager)

    token_ids = list(range(8))
    keys = mx.arange(1 * 2 * 8 * 3).reshape(1, 2, 8, 3)
    values = mx.arange(1000, 1000 + (1 * 2 * 8 * 3)).reshape(1, 2, 8, 3)
    array_state = [
        mx.arange(1 * 3 * 8).reshape(1, 3, 8),
        mx.arange(2000, 2000 + (1 * 2 * 4 * 4)).reshape(1, 2, 4, 4),
    ]
    layer_entries = [
        {
            "state": (keys, values),
            "meta_state": "",
            "class_ref": KVCache,
            "class_name": "KVCache",
        },
        {
            "state": array_state,
            "meta_state": "",
            "class_ref": ArraysCache,
            "class_name": "ArraysCache",
        },
    ]

    table = prefix_cache.store_cache("req-1", token_ids, layer_entries)
    blocks = manager.allocated_blocks
    head = blocks[table.block_ids[0]]
    tail = blocks[table.block_ids[-1]]

    # KV data is stored on every block; the array-state snapshot (slot 1)
    # appears only on the final ("boundary") block.
    assert head.cache_data[0] is not None
    assert head.cache_data[1] is None
    assert tail.cache_data[1] is not None

    rebuilt = prefix_cache.reconstruct_cache(table)

    assert rebuilt is not None
    assert isinstance(rebuilt[0], KVCache)
    assert isinstance(rebuilt[1], ArraysCache)
    assert rebuilt[0].state[0].tolist() == keys.tolist()
    assert rebuilt[0].state[1].tolist() == values.tolist()
    assert rebuilt[1].state[0].tolist() == array_state[0].tolist()
    assert rebuilt[1].state[1].tolist() == array_state[1].tolist()

def test_rejects_hybrid_prefix_without_boundary_snapshot(self):
    """Reconstruction refuses a prefix whose last block lacks the array-state snapshot."""
    import mlx.core as mx
    from mlx_lm.models.cache import ArraysCache, KVCache

    from vllm_mlx.paged_cache import BlockTable, PagedCacheManager
    from vllm_mlx.prefix_cache import BlockAwarePrefixCache

    manager = PagedCacheManager(block_size=4, max_blocks=10)
    prefix_cache = BlockAwarePrefixCache(model=None, paged_cache_manager=manager)

    layer_entries = [
        {
            "state": (
                mx.arange(1 * 2 * 8 * 3).reshape(1, 2, 8, 3),
                mx.arange(1000, 1000 + (1 * 2 * 8 * 3)).reshape(1, 2, 8, 3),
            ),
            "meta_state": "",
            "class_ref": KVCache,
            "class_name": "KVCache",
        },
        {
            "state": [
                mx.arange(1 * 3 * 8).reshape(1, 3, 8),
                mx.arange(2000, 2000 + (1 * 2 * 4 * 4)).reshape(1, 2, 4, 4),
            ],
            "meta_state": "",
            "class_ref": ArraysCache,
            "class_name": "ArraysCache",
        },
    ]

    full_table = prefix_cache.store_cache("req-1", list(range(8)), layer_entries)

    # Build a table covering only the first block: it carries KV data but
    # not the boundary snapshot, which lives on the final block.
    partial_table = BlockTable(
        request_id="req-prefix",
        block_ids=[full_table.block_ids[0]],
        num_tokens=4,
    )

    assert prefix_cache.reconstruct_cache(partial_table) is None
47 changes: 47 additions & 0 deletions tests/test_tokenizer_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# SPDX-License-Identifier: Apache-2.0
"""Tests for tokenizer utility helpers."""

import platform
import sys
from unittest.mock import patch

import pytest

# Skip the whole module unless running on macOS arm64 (Apple Silicon),
# per the condition below — presumably because the code under test
# (vllm_mlx) needs that platform; confirm if the requirement changes.
pytestmark = pytest.mark.skipif(
    sys.platform != "darwin" or platform.machine() != "arm64",
    reason="Requires Apple Silicon",
)


class TestLoadModelWithFallback:
    """Tests for ``load_model_with_fallback`` wrapping ``mlx_lm.load``."""

    def test_returns_successful_load_result(self):
        """When ``mlx_lm.load`` succeeds, its result is passed through unchanged."""
        from vllm_mlx.utils.tokenizer import load_model_with_fallback

        sentinel_model = object()
        sentinel_tokenizer = object()

        load_patch = patch(
            "mlx_lm.load", return_value=(sentinel_model, sentinel_tokenizer)
        )
        with load_patch as load:
            model, tokenizer = load_model_with_fallback("mlx-community/Qwen3.5-4B")

        load.assert_called_once()
        assert model is sentinel_model
        assert tokenizer is sentinel_tokenizer

    def test_uses_tokenizer_fallback_for_tokenizer_errors(self):
        """A tokenizer-class error from ``mlx_lm.load`` routes to the fallback loader."""
        from vllm_mlx.utils.tokenizer import load_model_with_fallback

        sentinel_model = object()
        sentinel_tokenizer = object()

        load_patch = patch(
            "mlx_lm.load",
            side_effect=ValueError("Tokenizer class Foo does not exist"),
        )
        fallback_patch = patch(
            "vllm_mlx.utils.tokenizer._load_with_tokenizer_fallback",
            return_value=(sentinel_model, sentinel_tokenizer),
        )
        with load_patch, fallback_patch as fallback:
            model, tokenizer = load_model_with_fallback("example/model")

        fallback.assert_called_once_with("example/model")
        assert model is sentinel_model
        assert tokenizer is sentinel_tokenizer
Loading