Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions tests/test_cloud_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,37 @@ def test_ignores_unsupported_kwargs(self):

assert "unsupported_param" not in kwargs

def test_passes_through_response_format(self):
    """response_format must be forwarded to litellm (regression: was silently dropped)."""
    from vllm_mlx.cloud_router import CloudRouter

    # A structured-output request of the kind that previously got dropped.
    json_schema_format = {
        "type": "json_schema",
        "json_schema": {"name": "out", "schema": {"type": "object"}},
    }
    router = CloudRouter(cloud_model="test-model", threshold=1000)

    built = router._build_call_kwargs(
        messages=[{"role": "user", "content": "Hello"}],
        stream=False,
        response_format=json_schema_format,
    )

    assert built["response_format"] == json_schema_format

def test_response_format_none_omitted(self):
    """A response_format of None must not appear in the built kwargs at all."""
    from vllm_mlx.cloud_router import CloudRouter

    router = CloudRouter(cloud_model="test-model", threshold=1000)

    built = router._build_call_kwargs(
        messages=[{"role": "user", "content": "Hello"}],
        stream=False,
        response_format=None,
    )

    assert "response_format" not in built


class TestCloudRouterLazyImport:
"""Tests for CloudRouter lazy litellm import."""
Expand Down
87 changes: 87 additions & 0 deletions tests/test_prefix_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,93 @@ def test_trie_structure(self, cache_manager):
assert cache_124 == ["cache_124"]


class TestPrefixCachePinning:
    """Tests for pin/unpin prefix functionality."""

    @pytest.fixture
    def mock_model(self):
        # The cache manager never calls into the model in these tests.
        return MagicMock()

    def test_pin_survives_store(self, mock_model):
        """Pinned entry stays pinned after store_cache re-stores same key (regression)."""
        cache = PrefixCacheManager(mock_model, max_entries=3)
        cache.store_cache([1, 2], ["cache_a"])
        assert cache.pin_prefix([1, 2]) is True

        # Overwriting the pinned key must not quietly undo its pin.
        cache.store_cache([1, 2], ["cache_b"])

        # Push the manager past capacity; only unpinned entries may be evicted.
        cache.store_cache([3], ["cache_c"])
        cache.store_cache([4], ["cache_d"])
        cache.store_cache([5], ["cache_e"])  # would evict [1,2] if it sat in the LRU

        entry, _ = cache.fetch_cache([1, 2])
        assert entry is not None, "Pinned entry was evicted after store_cache"

    def test_pin_survives_touch(self, mock_model):
        """Pinned entry stays pinned after fetch_cache touches it (regression)."""
        cache = PrefixCacheManager(mock_model, max_entries=3)
        cache.store_cache([1, 2], ["cache_a"])
        cache.pin_prefix([1, 2])

        # Reading a pinned entry must not re-register it in the LRU.
        cache.fetch_cache([1, 2])

        # Overflow well past capacity; the pinned entry must remain.
        cache.store_cache([3], ["c"])
        cache.store_cache([4], ["d"])
        cache.store_cache([5], ["e"])
        cache.store_cache([6], ["f"])

        entry, _ = cache.fetch_cache([1, 2])
        assert entry is not None, "Pinned entry was evicted after fetch_cache touch"

    def test_pin_capacity_guard(self, mock_model):
        """pin_prefix rejects when pinned count reaches max_size (regression)."""
        cache = PrefixCacheManager(mock_model, max_entries=2)
        cache.store_cache([1], ["a"])
        cache.store_cache([2], ["b"])
        assert cache.pin_prefix([1]) is True
        assert cache.pin_prefix([2]) is True

        # Every slot is pinned now, so this store cannot land in the LRU.
        cache.store_cache([3], ["c"])
        # Free one pinned slot, repopulate, and re-pin so that exactly
        # max_size entries are pinned again.
        cache.unpin_prefix([2])
        cache.store_cache([3], ["c"])
        assert cache.pin_prefix([3]) is True  # two pinned entries with max_size=2

        cache.store_cache([4], ["d"])
        assert cache.pin_prefix([4]) is False, "Pin should fail when at capacity"

    def test_unpin_restores_lru(self, mock_model):
        """Unpinned entry becomes evictable again."""
        cache = PrefixCacheManager(mock_model, max_entries=2)
        cache.store_cache([1], ["a"])
        cache.store_cache([2], ["b"])
        cache.pin_prefix([1])

        cache.unpin_prefix([1])

        # [1] is back in the LRU, so enough inserts will push it out.
        cache.store_cache([3], ["c"])
        cache.store_cache([4], ["d"])

        entry, _ = cache.fetch_cache([1])
        assert entry is None, "Unpinned entry should be evictable"

    def test_clear_resets_pinned(self, mock_model):
        """clear() removes pinned entries too."""
        cache = PrefixCacheManager(mock_model, max_entries=5)
        cache.store_cache([1], ["a"])
        cache.pin_prefix([1])
        cache.clear()
        assert len(cache) == 0


class TestSchedulerIntegration:
"""Test integration with scheduler."""

Expand Down
23 changes: 23 additions & 0 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -676,6 +676,29 @@ def test_rate_limiter_window_cleanup(self):
allowed, _ = limiter.is_allowed("test_client")
assert allowed is True

def test_rate_limiter_stale_key_purge(self):
    """Stale client keys are purged when dict exceeds 100 entries (regression)."""
    from vllm_mlx.server import RateLimiter
    import time

    limiter = RateLimiter(requests_per_minute=10, enabled=True)

    # Seed 101 clients whose only timestamp sits well outside the window.
    expired = time.time() - 120  # 2 minutes ago
    with limiter._lock:
        for i in range(101):
            limiter._requests[f"stale_client_{i}"] = [expired]

    assert len(limiter._requests) == 101

    # Crossing the 100-entry threshold on the next call must trigger a purge.
    limiter.is_allowed("new_client")

    assert len(limiter._requests) < 101
    # The live client must not be swept away with the stale ones.
    assert "new_client" in limiter._requests


# =============================================================================
# Integration Tests (require running server)
Expand Down
37 changes: 37 additions & 0 deletions tests/test_server_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,3 +321,40 @@ class ToolDef(BaseModel):
self._make_tool_calls(),
tools,
)


# ---------------------------------------------------------------------------
# extract_json_from_response guard tests
# ---------------------------------------------------------------------------


class TestExtractJsonFromResponse:
    """Tests for extract_json_from_response showing why guard is needed."""

    def test_extracts_json_from_reasoning_text(self):
        """A reasoning prefix is stripped and only the JSON payload survives."""
        from vllm_mlx.api.utils import extract_json_from_response

        extracted = extract_json_from_response('Let me think... {"result": 42}')
        assert extracted == '{"result": 42}'

    def test_corrupts_plain_text_ending_with_json(self):
        """Without guard, plain text ending with JSON-like braces gets corrupted.

        This documents why server.py wraps the call with `if response_format`.
        """
        from vllm_mlx.api.utils import extract_json_from_response

        # Ordinary prose that merely happens to end with balanced braces.
        plain = 'The config looks like {"debug": true}'
        result = extract_json_from_response(plain)
        # The brace suffix is extracted on its own — the prefix is lost.
        assert result == '{"debug": true}'
        assert result != plain

    def test_returns_original_when_no_json(self):
        """Input with no JSON structure comes back untouched."""
        from vllm_mlx.api.utils import extract_json_from_response

        assert extract_json_from_response("Hello, world!") == "Hello, world!"
42 changes: 42 additions & 0 deletions tests/test_simple_engine_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,3 +323,45 @@ def test_set_false(self):
engine.preserve_native_tool_format = True
engine.preserve_native_tool_format = False
assert engine.preserve_native_tool_format is False


# ---------------------------------------------------------------------------
# _inject_shared_model config propagation
# ---------------------------------------------------------------------------


class TestInjectSharedModelConfig:
    """Tests for _inject_shared_model using engine config values."""

    @pytest.mark.asyncio
    async def test_inject_propagates_engine_config(self):
        """Injected model uses engine's config, not hardcoded defaults (regression)."""
        engine = SimpleEngine(
            model_name="test",
            prefill_step_size=4096,
            kv_bits=4,
            kv_group_size=128,
        )

        # Neither the model nor the tokenizer is exercised here.
        await engine._inject_shared_model(MagicMock(), MagicMock())

        injected = engine._model
        assert injected.prefill_step_size == 4096
        assert injected.kv_bits == 4
        assert injected.kv_group_size == 128

    @pytest.mark.asyncio
    async def test_inject_default_config(self):
        """Injected model uses default config when engine uses defaults."""
        engine = SimpleEngine(model_name="test")

        await engine._inject_shared_model(MagicMock(), MagicMock())

        injected = engine._model
        assert injected.prefill_step_size == 2048
        assert injected.kv_bits is None
        assert injected.kv_group_size == 64
8 changes: 4 additions & 4 deletions vllm_mlx/api/guided.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def generate_json(
json_schema: dict[str, Any],
max_tokens: int = 256,
temperature: float = 0.7,
) -> str:
) -> str | None:
"""
Generate JSON output constrained to a schema.

Expand All @@ -139,7 +139,7 @@ def generate_json(
temperature: Sampling temperature

Returns:
JSON string matching the schema
JSON string matching the schema, or None on failure
"""
# Convert schema to Pydantic model
pydantic_model = json_schema_to_pydantic(json_schema)
Expand Down Expand Up @@ -172,7 +172,7 @@ def generate_json_object(
prompt: str,
max_tokens: int = 256,
temperature: float = 0.7,
) -> str:
) -> str | None:
"""
Generate any valid JSON object.

Expand All @@ -182,7 +182,7 @@ def generate_json_object(
temperature: Sampling temperature

Returns:
JSON string
JSON string, or None on failure
"""
try:
from outlines import generate
Expand Down
2 changes: 1 addition & 1 deletion vllm_mlx/cloud_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def _build_call_kwargs(
for key in (
"temperature", "max_tokens", "top_p", "stop",
"frequency_penalty", "presence_penalty",
"tools", "tool_choice",
"tools", "tool_choice", "response_format",
):
if key in kwargs and kwargs[key] is not None:
call_kwargs[key] = kwargs[key]
Expand Down
6 changes: 6 additions & 0 deletions vllm_mlx/engine/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,6 +593,12 @@ async def _inject_shared_model(
self._model.trust_remote_code = self._trust_remote_code
self._model.draft_model_name = self._draft_model_name
self._model.num_draft_tokens = self._num_draft_tokens
self._model.prefill_step_size = self._prefill_step_size
self._model.kv_bits = self._kv_bits
self._model.kv_group_size = self._kv_group_size
self._model._prompt_cache = None
self._model._cached_token_ids = []
self._model._cache_lock = False
self._model.model = model
self._model.tokenizer = tokenizer
self._model.draft_model = None
Expand Down
Loading