diff --git a/tests/test_simple_engine.py b/tests/test_simple_engine.py
index 1c8ca29b..779d40e4 100644
--- a/tests/test_simple_engine.py
+++ b/tests/test_simple_engine.py
@@ -197,6 +197,721 @@ async def fake_stream_chat(*args, **kwargs):
         assert output.finish_reason == "stop"
         mock_llm_model.chat.assert_not_called()
 
+    @pytest.mark.anyio
+    async def test_stream_chat_cache_path_accepts_pydantic_message_objects(self):
+        """`stream_chat`'s declared signature is ``list[dict]`` but real callers
+        (``server.py``'s streaming endpoint, ``test_server.py``'s direct
+        invocations) pass Pydantic ``Message`` objects. The system-prefix
+        KV-cache eligibility check on this path uses ``.get('role')`` /
+        ``dict(m)`` semantics; without normalisation the iteration raises
+        ``'Message' object has no attribute 'get'`` before the call ever
+        reaches the underlying ``stream_generate``."""
+        from vllm_mlx.api.models import Message
+        from vllm_mlx.engine.simple import SimpleEngine
+
+        # ``apply_chat_template`` returns identical strings for both
+        # probe-divergence renders → no divergence is ever detected → the
+        # cache path is correctly skipped and execution falls through to
+        # ``self.stream_generate``. The test's value is asserting no
+        # ``AttributeError`` leaks out of the message-normalisation step.
+        tokenizer = MagicMock()
+        tokenizer.apply_chat_template.return_value = (
+            "<|im_start|>system\nyou are an assistant<|im_end|>\n"
+            "<|im_start|>user\nhi<|im_end|>\n<|im_start|>assistant\n"
+        )
+        tokenizer.bos_token = None
+        tokenizer.encode = MagicMock(return_value=[1, 2, 3])
+
+        model = MagicMock()
+        model.tokenizer = tokenizer
+
+        captured_stream_generate = []
+
+        async def fake_stream_generate(*, prompt, **kwargs):
+            captured_stream_generate.append({"prompt": prompt, "kwargs": kwargs})
+            out = MagicMock(
+                text="hi back",
+                new_text="hi back",
+                prompt_tokens=3,
+                completion_tokens=1,
+                finished=True,
+                finish_reason="stop",
+            )
+            yield out
+
+        with patch("vllm_mlx.engine.simple.is_mllm_model", return_value=False):
+            engine = SimpleEngine("test-model")
+            engine._model = model
+            engine._loaded = True
+            engine._supports_system_kv_cache = True
+            engine.stream_generate = fake_stream_generate  # type: ignore[method-assign]
+
+            messages = [
+                Message(role="system", content="You are a helpful assistant."),
+                Message(role="user", content="hi"),
+            ]
+
+            chunks = [c async for c in engine.stream_chat(messages=messages)]
+
+        # No AttributeError was raised → normalisation worked.
+        # apply_chat_template was called at least 3 times: once for the
+        # initial ``prompt`` build and twice more for the Alpha/Bravo
+        # probe-divergence renders.
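+        # ``>=`` rather than ``==`` keeps the assertion robust if this path
+        # ever renders the template additional times.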
+ assert tokenizer.apply_chat_template.call_count >= 3 + assert len(captured_stream_generate) == 1 + assert chunks and chunks[0].text == "hi back" + + @pytest.mark.anyio + async def test_stream_chat_skips_cache_path_when_no_system_message(self): + """If the message list has no system role, the cache-eligibility + check must short-circuit ``has_system = False`` without entering the + probe-divergence step or the cache-aware streaming branch.""" + from vllm_mlx.engine.simple import SimpleEngine + + tokenizer = MagicMock() + tokenizer.apply_chat_template.return_value = ( + "<|im_start|>user\nhello<|im_end|>\n<|im_start|>assistant\n" + ) + tokenizer.bos_token = None + tokenizer.encode = MagicMock(return_value=[1, 2, 3]) + + model = MagicMock() + model.tokenizer = tokenizer + + called_stream_generate = [] + + async def fake_stream_generate(*, prompt, **kwargs): + called_stream_generate.append(prompt) + out = MagicMock( + text="hi", + new_text="hi", + prompt_tokens=3, + completion_tokens=1, + finished=True, + finish_reason="stop", + ) + yield out + + with patch("vllm_mlx.engine.simple.is_mllm_model", return_value=False): + engine = SimpleEngine("test-model") + engine._model = model + engine._loaded = True + engine.stream_generate = fake_stream_generate # type: ignore[method-assign] + + chunks = [ + c + async for c in engine.stream_chat( + messages=[{"role": "user", "content": "hello"}], + ) + ] + + # No system → cache path skipped → only the initial apply_chat_template + # call (for ``prompt``) happens, no probe renders. + assert tokenizer.apply_chat_template.call_count == 1 + assert called_stream_generate + assert chunks and chunks[0].text == "hi" + + @pytest.mark.anyio + async def test_stream_chat_cache_path_falls_back_when_mlx_raises(self): + """When the cache-aware ``_run_with_cache`` body raises *before* the + first generated token, the path must surface the failure as a + pre-first-token error and fall back to the uncached + ``self.stream_generate`` instead of bubbling the exception out.""" + from vllm_mlx.engine.simple import SimpleEngine + + # Probe-divergence renders that DO diverge, so the cache path is + # entered. The boundary lies after the system block, well past the + # 16-char minimum. + def apply_chat_template_side_effect(messages, **kwargs): + # Find the last user message content to make probes diverge. + user_content = "" + for m in reversed(messages): + role = ( + m.get("role") if isinstance(m, dict) else getattr(m, "role", None) + ) + if role == "user": + content = ( + m.get("content") + if isinstance(m, dict) + else getattr(m, "content", "") + ) + user_content = content or "" + break + return ( + "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n" + f"<|im_start|>user\n{user_content}<|im_end|>\n" + "<|im_start|>assistant\n" + ) + + tokenizer = MagicMock() + tokenizer.apply_chat_template.side_effect = apply_chat_template_side_effect + tokenizer.bos_token = None + # Return long enough token lists that the system-prefix slice is + # a proper prefix of the full sequence and ``kv_cache_eligible`` + # becomes True. + tokenizer.encode = MagicMock( + side_effect=[ + list(range(50)), # full prompt + list(range(20)), # system prefix (proper prefix of above) + ] + ) + + model = MagicMock() + model.tokenizer = tokenizer + # ``self._model.model`` is dereferenced inside _run_with_cache. 
+        model.model = MagicMock()
+
+        fallback_calls = []
+
+        async def fake_stream_generate(*, prompt, **kwargs):
+            fallback_calls.append(prompt)
+            out = MagicMock(
+                text="fallback-response",
+                new_text="fallback-response",
+                prompt_tokens=50,
+                completion_tokens=1,
+                finished=True,
+                finish_reason="stop",
+            )
+            yield out
+
+        # Force the cache-aware path to raise before the first emit so we
+        # exercise the pre-first-token error → uncached fallback branch.
+        def make_prompt_cache_raises(*args, **kwargs):
+            raise RuntimeError("simulated mlx-lm failure")
+
+        with (
+            patch("vllm_mlx.engine.simple.is_mllm_model", return_value=False),
+            patch(
+                "mlx_lm.models.cache.make_prompt_cache",
+                side_effect=make_prompt_cache_raises,
+            ),
+            patch("mlx_lm.sample_utils.make_sampler", return_value=MagicMock()),
+        ):
+            engine = SimpleEngine("test-model")
+            engine._model = model
+            engine._loaded = True
+            # Without this the snapshot-safety gate blocks the cache branch
+            # and the raising ``make_prompt_cache`` patch is never reached.
+            engine._supports_system_kv_cache = True
+            engine.stream_generate = fake_stream_generate  # type: ignore[method-assign]
+
+            chunks = [
+                c
+                async for c in engine.stream_chat(
+                    messages=[
+                        {"role": "system", "content": "You are a helpful assistant."},
+                        {"role": "user", "content": "hello"},
+                    ],
+                )
+            ]
+
+        # The cache-path failure must NOT propagate; we should see the
+        # fallback's chunk instead.
+        assert fallback_calls, "uncached stream_generate fallback was not invoked"
+        assert chunks and chunks[0].text == "fallback-response"
+
+    @pytest.mark.anyio
+    async def test_stream_chat_skips_cache_path_when_decode_controls_present(self):
+        """If the request carries ``stop`` or ``logits_processors`` (or any non-default
+        ``top_k`` / ``min_p`` / ``presence_penalty`` / ``repetition_penalty``), the
+        cache branch must be skipped so those controls flow through
+        ``self.stream_generate``.
+        The cache branch drives ``mlx_lm.stream_generate`` directly with only
+        prompt/max_tokens/sampler/prompt_cache, silently dropping any other decode
+        controls.
+        Gating here keeps cache-eligible and uncached requests on identical decode
+        semantics."""
+        from vllm_mlx.engine.simple import SimpleEngine
+
+        tokenizer = MagicMock()
+        # Render contains a system block so ``has_system`` would otherwise send us
+        # into the cache branch.
+        tokenizer.apply_chat_template.return_value = (
+            "<|im_start|>system\nYou are helpful.<|im_end|>\n"
+            "<|im_start|>user\nhello<|im_end|>\n"
+            "<|im_start|>assistant\n"
+        )
+        tokenizer.bos_token = None
+        tokenizer.encode = MagicMock(return_value=[1, 2, 3])
+
+        model = MagicMock()
+        model.tokenizer = tokenizer
+
+        fallback_kwargs: list[dict] = []
+
+        async def fake_stream_generate(*, prompt, **kw):
+            fallback_kwargs.append(kw)
+            out = MagicMock(
+                text="ok",
+                new_text="ok",
+                prompt_tokens=3,
+                completion_tokens=1,
+                finished=True,
+                finish_reason="stop",
+            )
+            yield out
+
+        sentinel_processor = MagicMock(name="logits_processor")
+
+        with patch("vllm_mlx.engine.simple.is_mllm_model", return_value=False):
+            engine = SimpleEngine("test-model")
+            engine._model = model
+            engine._loaded = True
+            # Isolate the gate to the decode controls under test.
+            engine._supports_system_kv_cache = True
+            engine.stream_generate = fake_stream_generate  # type: ignore[method-assign]
+
+            chunks = [
+                c
+                async for c in engine.stream_chat(
+                    messages=[
+                        {"role": "system", "content": "You are helpful."},
+                        {"role": "user", "content": "hello"},
+                    ],
+                    stop=["<|im_end|>"],
+                    logits_processors=[sentinel_processor],
+                )
+            ]
+
+        # Cache-path probe (``apply_chat_template`` called a second + third time for
+        # probe-divergence) must NOT happen — only the initial prompt render.
+ assert tokenizer.apply_chat_template.call_count == 1 + # The uncached fallback must have been invoked. + # The decode-control kwargs must have been threaded through. + assert fallback_kwargs, "uncached stream_generate fallback was not invoked" + assert fallback_kwargs[0].get("stop") == ["<|im_end|>"] + assert fallback_kwargs[0].get("logits_processors") == [sentinel_processor] + assert chunks and chunks[0].text == "ok" + + @pytest.mark.anyio + async def test_stream_chat_takes_cache_path_when_decode_controls_are_no_ops(self): + """server.py always sets ``top_k=0``, ``min_p=0.0``, ``presence_penalty=0.0``, + ``repetition_penalty=1.0`` (no-ops) in ``chat_kwargs``. + The gate must compare against those defaults so the common path still hits the + cache. + Only *active* controls should block.""" + from vllm_mlx.engine.simple import SimpleEngine + + def apply_chat_template_side_effect(messages, **kwargs): + user_content = "" + for m in reversed(messages): + role = ( + m.get("role") if isinstance(m, dict) else getattr(m, "role", None) + ) + if role == "user": + content = ( + m.get("content") + if isinstance(m, dict) + else getattr(m, "content", "") + ) + user_content = content or "" + break + return ( + "<|im_start|>system\nYou are helpful.<|im_end|>\n" + f"<|im_start|>user\n{user_content}<|im_end|>\n" + "<|im_start|>assistant\n" + ) + + tokenizer = MagicMock() + tokenizer.apply_chat_template.side_effect = apply_chat_template_side_effect + tokenizer.bos_token = None + tokenizer.encode = MagicMock(side_effect=[list(range(50)), list(range(20))]) + + model = MagicMock() + model.tokenizer = tokenizer + + with patch("vllm_mlx.engine.simple.is_mllm_model", return_value=False): + engine = SimpleEngine("test-model") + engine._model = model + engine._loaded = True + # Probe normally runs in start(); short-circuit it here so the + # gate doesn't trip on the synthetic engine state. + engine._supports_system_kv_cache = True + + # No fallback needed since cache path should be exercised. + # Patch _run_blocking_serialized to short-circuit cache execution cleanly + # without needing a real mlx_lm. + async def short_circuit(func, *args, on_cancel=None, **kw): + # Simulate immediate completion with no responses. + # The producer-task harness will fire _emit_done(). + return None + + engine._run_blocking_serialized = short_circuit # type: ignore[method-assign] + + _ = [ + c + async for c in engine.stream_chat( + messages=[ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "hello"}, + ], + top_k=0, + min_p=0.0, + presence_penalty=0.0, + repetition_penalty=1.0, + ) + ] + + # Probe-divergence ran ⇒ apply_chat_template called for prompt + 2 probes. + assert tokenizer.apply_chat_template.call_count == 3 + + @pytest.mark.anyio + async def test_stream_chat_skips_cache_path_when_mtp_active(self): + """When ``self._mtp`` is configured, the cache branch must be skipped. + The branch calls ``mlx_lm.stream_generate`` directly with no ``mtp`` / + ``num_draft_tokens`` kwargs, while ``MLXLanguageModel.stream_generate`` + attaches them from ``self._mtp`` / ``self._mtp_num_draft_tokens``. 
+ Running the same request through the cache branch would silently drop + speculative decoding for cache-eligible turns while keeping it on + uncached turns — different engine semantics for the same request.""" + from vllm_mlx.engine.simple import SimpleEngine + + tokenizer = MagicMock() + tokenizer.apply_chat_template.return_value = ( + "<|im_start|>system\nYou are helpful.<|im_end|>\n" + "<|im_start|>user\nhello<|im_end|>\n" + "<|im_start|>assistant\n" + ) + tokenizer.bos_token = None + tokenizer.encode = MagicMock(return_value=[1, 2, 3]) + + model = MagicMock() + model.tokenizer = tokenizer + + fallback_kwargs: list[dict] = [] + + async def fake_stream_generate(*, prompt, **kw): + fallback_kwargs.append(kw) + out = MagicMock( + text="ok", + new_text="ok", + prompt_tokens=3, + completion_tokens=1, + finished=True, + finish_reason="stop", + ) + yield out + + with patch("vllm_mlx.engine.simple.is_mllm_model", return_value=False): + engine = SimpleEngine("test-model", mtp=True, mtp_num_draft_tokens=4) + engine._model = model + engine._loaded = True + # Isolate the gate to the feature under test. + engine._supports_system_kv_cache = True + engine.stream_generate = fake_stream_generate # type: ignore[method-assign] + + chunks = [ + c + async for c in engine.stream_chat( + messages=[ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "hello"}, + ], + ) + ] + + # Cache-path probes must NOT run — only the initial prompt render. + assert tokenizer.apply_chat_template.call_count == 1 + # The uncached wrapper must have been invoked. + # MTP kwargs are layered on inside ``MLXLanguageModel.stream_generate``, + # not at this seam, so this test only proves the cache branch was + # bypassed; the wrapper attaches MTP itself when ``self._mtp`` is set. + assert fallback_kwargs, "uncached stream_generate fallback was not invoked" + assert chunks and chunks[0].text == "ok" + + @pytest.mark.anyio + async def test_stream_chat_skips_cache_path_when_specprefill_loaded(self): + """A loaded SpecPrefill draft model (``self._draft_model is not None``) + triggers ``_stream_generate_specprefill`` routing inside the wrapper for + large prompts. 
+ The cache branch has no equivalent routing, so it must be skipped + whenever a draft model is loaded so all requests go through the + wrapper's SpecPrefill decision.""" + from vllm_mlx.engine.simple import SimpleEngine + + tokenizer = MagicMock() + tokenizer.apply_chat_template.return_value = ( + "<|im_start|>system\nYou are helpful.<|im_end|>\n" + "<|im_start|>user\nhello<|im_end|>\n" + "<|im_start|>assistant\n" + ) + tokenizer.bos_token = None + tokenizer.encode = MagicMock(return_value=[1, 2, 3]) + + model = MagicMock() + model.tokenizer = tokenizer + + fallback_kwargs: list[dict] = [] + + async def fake_stream_generate(*, prompt, **kw): + fallback_kwargs.append(kw) + out = MagicMock( + text="ok", + new_text="ok", + prompt_tokens=3, + completion_tokens=1, + finished=True, + finish_reason="stop", + ) + yield out + + with patch("vllm_mlx.engine.simple.is_mllm_model", return_value=False): + engine = SimpleEngine("test-model") + engine._model = model + engine._draft_model = MagicMock(name="specprefill_draft_model") + engine._loaded = True + engine._supports_system_kv_cache = True + engine.stream_generate = fake_stream_generate # type: ignore[method-assign] + + chunks = [ + c + async for c in engine.stream_chat( + messages=[ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "hello"}, + ], + ) + ] + + assert tokenizer.apply_chat_template.call_count == 1 + assert fallback_kwargs, "uncached stream_generate fallback was not invoked" + assert chunks and chunks[0].text == "ok" + + @pytest.mark.anyio + async def test_stream_chat_skips_cache_path_when_max_kv_size_set(self): + """Configured ``max_kv_size`` caps the prompt cache. + The cache branch builds its cache with ``make_prompt_cache(model)`` + without forwarding ``max_kv_size``, so a non-zero engine-level bound + must force the uncached path.""" + from vllm_mlx.engine.simple import SimpleEngine + + tokenizer = MagicMock() + tokenizer.apply_chat_template.return_value = ( + "<|im_start|>system\nYou are helpful.<|im_end|>\n" + "<|im_start|>user\nhello<|im_end|>\n" + "<|im_start|>assistant\n" + ) + tokenizer.bos_token = None + tokenizer.encode = MagicMock(return_value=[1, 2, 3]) + + model = MagicMock() + model.tokenizer = tokenizer + + fallback_kwargs: list[dict] = [] + + async def fake_stream_generate(*, prompt, **kw): + fallback_kwargs.append(kw) + out = MagicMock( + text="ok", + new_text="ok", + prompt_tokens=3, + completion_tokens=1, + finished=True, + finish_reason="stop", + ) + yield out + + with patch("vllm_mlx.engine.simple.is_mllm_model", return_value=False): + engine = SimpleEngine("test-model", max_kv_size=4096) + engine._model = model + engine._loaded = True + engine._supports_system_kv_cache = True + engine.stream_generate = fake_stream_generate # type: ignore[method-assign] + + chunks = [ + c + async for c in engine.stream_chat( + messages=[ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "hello"}, + ], + ) + ] + + assert tokenizer.apply_chat_template.call_count == 1 + assert fallback_kwargs, "uncached stream_generate fallback was not invoked" + assert chunks and chunks[0].text == "ok" + + @pytest.mark.anyio + async def test_stream_chat_skips_cache_path_when_model_has_non_kv_cache(self): + """Models whose ``make_prompt_cache`` returns ``RotatingKVCache`` + (sliding-window models like gemma3_text, olmo3, recurrent_gemma) cannot + be safely snapshotted: ``.state`` aliases buffers that + ``update_and_fetch`` mutates in place, so restoring a captured snapshot + on the 
next turn would silently desynchronize from the running cache. + ``start()`` probes ``make_prompt_cache`` once and sets + ``_supports_system_kv_cache=False`` for those models; the gate must + then skip the cache branch.""" + from vllm_mlx.engine.simple import SimpleEngine + + tokenizer = MagicMock() + tokenizer.apply_chat_template.return_value = ( + "<|im_start|>system\nYou are helpful.<|im_end|>\n" + "<|im_start|>user\nhello<|im_end|>\n" + "<|im_start|>assistant\n" + ) + tokenizer.bos_token = None + tokenizer.encode = MagicMock(return_value=[1, 2, 3]) + + model = MagicMock() + model.tokenizer = tokenizer + + fallback_kwargs: list[dict] = [] + + async def fake_stream_generate(*, prompt, **kw): + fallback_kwargs.append(kw) + out = MagicMock( + text="ok", + new_text="ok", + prompt_tokens=3, + completion_tokens=1, + finished=True, + finish_reason="stop", + ) + yield out + + with patch("vllm_mlx.engine.simple.is_mllm_model", return_value=False): + engine = SimpleEngine("test-model") + engine._model = model + engine._loaded = True + # Simulate the probe finding non-KVCache entries (rotating cache). + engine._supports_system_kv_cache = False + engine.stream_generate = fake_stream_generate # type: ignore[method-assign] + + chunks = [ + c + async for c in engine.stream_chat( + messages=[ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "hello"}, + ], + ) + ] + + # Cache-path probes must NOT run when the model isn't snapshot-safe. + assert tokenizer.apply_chat_template.call_count == 1 + assert fallback_kwargs, "uncached stream_generate fallback was not invoked" + assert chunks and chunks[0].text == "ok" + + @pytest.mark.anyio + async def test_stream_chat_uses_gate_time_snapshot_under_concurrent_mutation( + self, + ): + """A concurrent MISS that reassigns ``self._system_kv_snapshot`` between + the cache-hit gate (which runs outside ``_run_blocking_serialized``) and + the snapshot restore (which runs inside the serialized worker) must not + corrupt the HIT. The restore must use the snapshot reference captured at + gate time, not re-read ``self._system_kv_snapshot`` later — otherwise a + different system prefix's KV would be loaded under the hash that decided + HIT. + + Simulates the race by reassigning ``engine._system_kv_snapshot`` inside + the ``_run_blocking_serialized`` hook (executed after the gate but + before the worker enters the cache branch), then asserts the restore + loop wrote the gate-time entries, not the post-gate intruder.""" + import hashlib + + from vllm_mlx.engine.simple import SimpleEngine + + # Same template the positive test uses: divergence falls at the user + # content so the detected system prefix is the leading frame up through + # ``<|im_start|>user\n``. 
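+        # ``expected_hash`` below mirrors the engine's derivation: sha256 of
+        # that prefix text, truncated to 16 hex characters.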
+ def apply_chat_template_side_effect(messages, **kwargs): + user_content = "" + for m in reversed(messages): + role = ( + m.get("role") if isinstance(m, dict) else getattr(m, "role", None) + ) + if role == "user": + content = ( + m.get("content") + if isinstance(m, dict) + else getattr(m, "content", "") + ) + user_content = content or "" + break + return ( + "<|im_start|>system\nYou are helpful.<|im_end|>\n" + f"<|im_start|>user\n{user_content}<|im_end|>\n" + "<|im_start|>assistant\n" + ) + + expected_prefix = ( + "<|im_start|>system\nYou are helpful.<|im_end|>\n" "<|im_start|>user\n" + ) + expected_hash = hashlib.sha256(expected_prefix.encode()).hexdigest()[:16] + + tokenizer = MagicMock() + tokenizer.apply_chat_template.side_effect = apply_chat_template_side_effect + tokenizer.bos_token = None + # First encode = full prompt tokens; second = system prefix tokens. + # range(20) is a prefix of range(50), so prefix-match validation passes. + tokenizer.encode = MagicMock(side_effect=[list(range(50)), list(range(20))]) + + model = MagicMock() + model.tokenizer = tokenizer + + original_snapshot = [("ORIGINAL_K", "ORIGINAL_V")] + intruder_snapshot = [("INTRUDER_K", "INTRUDER_V")] + + captured_states: list = [] + + class MockCacheEntry: + def __init__(self) -> None: + self._state = None + + @property + def state(self): + return self._state + + @state.setter + def state(self, value) -> None: + captured_states.append(value) + self._state = value + + with patch("vllm_mlx.engine.simple.is_mllm_model", return_value=False): + engine = SimpleEngine("test-model") + engine._model = model + engine._loaded = True + engine._supports_system_kv_cache = True + # Pre-seed HIT state matching the divergence-detected prefix. + engine._system_kv_hash = expected_hash + engine._system_kv_token_count = 20 + engine._system_kv_snapshot = original_snapshot + + async def serialized_with_race(func, *args, on_cancel=None, **kw): + # Simulate a concurrent MISS overwriting the instance attribute + # AFTER the gate's HIT decision but BEFORE the worker restores. + engine._system_kv_snapshot = intruder_snapshot + await asyncio.to_thread(func) + return None + + engine._run_blocking_serialized = ( + serialized_with_race # type: ignore[method-assign] + ) + + with ( + patch("mlx_lm.stream_generate", return_value=iter([])), + patch( + "mlx_lm.models.cache.make_prompt_cache", + return_value=[MockCacheEntry()], + ), + patch("mlx_lm.sample_utils.make_sampler"), + ): + _ = [ + c + async for c in engine.stream_chat( + messages=[ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "hello"}, + ], + ) + ] + + # Restore wrote the gate-time snapshot exactly once. + # If the worker had re-read ``self._system_kv_snapshot`` we would see + # ``("INTRUDER_K", "INTRUDER_V")`` instead — that's the TOCTOU bug. 
+ assert captured_states == [("ORIGINAL_K", "ORIGINAL_V")], ( + "Snapshot restore did not use the gate-time reference; " + f"captured={captured_states}" + ) + @pytest.mark.anyio async def test_lock_serializes_stream_generate(self, mock_model): """Test that stream_generate uses the same lock as other methods.""" diff --git a/vllm_mlx/engine/simple.py b/vllm_mlx/engine/simple.py index 47343892..14c7760b 100644 --- a/vllm_mlx/engine/simple.py +++ b/vllm_mlx/engine/simple.py @@ -7,6 +7,7 @@ """ import asyncio +import hashlib import logging import os import threading @@ -178,6 +179,12 @@ def __init__( self._system_kv_snapshot = None # List of (keys, values) per backbone layer self._system_kv_hash = None # Hash of system prefix text self._system_kv_token_count = 0 # Tokens in cached prefix + # True only when the model's prompt cache is composed entirely of + # plain ``KVCache`` entries. Sliding-window models (gemma3_text, + # olmo3, recurrent_gemma) return ``RotatingKVCache`` whose ``.state`` + # aliases buffers ``update_and_fetch`` mutates in place — snapshot + # restore would silently desynchronize. Probed once in ``start()``. + self._supports_system_kv_cache: bool = False @property def model_name(self) -> str: @@ -253,6 +260,36 @@ async def start(self) -> None: self._mtp_num_draft_tokens, ) + # Probe whether this model's prompt cache is snapshot-safe for the + # stream_chat system-prefix cache branch. Sliding-window models + # (gemma3_text, olmo3, recurrent_gemma) return RotatingKVCache + # entries whose ``.state`` aliases in-place-mutated buffers. + # Only relevant for the LLM path; MLLM never enters the cache + # branch. + if not self._is_mllm and self._model is not None: + try: + from mlx_lm.models.cache import KVCache, make_prompt_cache + + probe_cache = make_prompt_cache(self._model.model) + self._supports_system_kv_cache = bool(probe_cache) and all( + isinstance(c, KVCache) for c in probe_cache + ) + if not self._supports_system_kv_cache: + cache_types = sorted({type(c).__name__ for c in probe_cache}) + logger.info( + "System KV cache snapshot disabled: model returned " + "non-KVCache entries (%s); stream_chat will use the " + "uncached path", + cache_types, + ) + except Exception as e: + logger.debug( + "System KV cache support probe failed (%s); " + "disabling snapshot path", + e, + ) + self._supports_system_kv_cache = False + # Build parallel mlx_lm TextModel for text-only routing. # Even when MTP is disabled, text-only requests should not be trapped # on the slower mlx_vlm multimodal path. @@ -338,6 +375,7 @@ async def stop(self) -> None: self._system_kv_snapshot = None self._system_kv_hash = None self._system_kv_token_count = 0 + self._supports_system_kv_cache = False logger.info("SimpleEngine stopped") async def _run_blocking_serialized(self, func, /, *args, on_cancel=None, **kwargs): @@ -860,7 +898,372 @@ def run_stream(): prompt = "\n".join(f"{m['role']}: {m['content']}" for m in messages) prompt += "\nassistant:" - # Stream generate + # --- System-prompt KV caching on the pure-LLM stream_chat path --- + # Mirrors the cache in _stream_generate_text. Locates the system prefix + # via probe-divergence (cf. prompt_warmup._build_strict_prefix_string): + # render the template with two different user contents and take the + # shared prefix. Works across Qwen/ChatML, Llama, Gemma, and any other + # chat format -- no per-model marker list. Falls back to the original + # uncached self.stream_generate() if the system prefix can't be + # isolated or any step of the cache-aware path raises. 
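+        # (The fallback applies only to failures before the first emitted
+        # token; once partial output has streamed there is no clean restart,
+        # so later failures surface to the caller.)
+        # Illustrative example of probe divergence (ChatML-style rendering,
+        # not tied to any specific model): with user probes "Alpha" and
+        # "Bravo" the two renders look like
+        #   <|im_start|>system\n...<|im_end|>\n<|im_start|>user\nAlpha...
+        #   <|im_start|>system\n...<|im_end|>\n<|im_start|>user\nBravo...
+        # They first differ at "A" vs "B", so everything before that point is
+        # treated as the cacheable system prefix.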
+ cache_hit = False + suffix_tokens = None + system_tokens = None + system_token_count = 0 + full_token_count = 0 + system_hash = None + kv_cache_eligible = False + # Snapshot reference captured at gate time so a concurrent MISS that + # reassigns ``self._system_kv_snapshot`` between the gate and the + # restore (which runs later inside ``_run_blocking_serialized``) + # can't desynchronize the restored KV from the hash that decided HIT. + hit_snapshot: Any = None + + # Decode-control gate. + # The cache branch below drives ``mlx_lm.stream_generate`` directly with only + # ``prompt``, ``max_tokens``, ``sampler`` (built from temperature+top_p), and + # ``prompt_cache``. + # The uncached fallback threads ``**kwargs`` through ``self.stream_generate``, + # which preserves ``stop``, request-local ``logits_processors`` (parser stop + # tokens and JSON-constrained decoding attached by server.py per request), and + # the ``top_k`` / ``min_p`` / ``presence_penalty`` / ``repetition_penalty`` + # sampling controls. + # If the cache branch ran with any of those active, cache-eligible and uncached + # requests would silently decode under different constraints. + # Skip the cache branch in that case so both paths share identical decode + # semantics. + # server.py always supplies the no-op defaults (``top_k=0``, ``min_p=0.0``, + # ``presence_penalty=0.0``, ``repetition_penalty=1.0``); compare against those + # rather than ``key in kwargs`` so the common path still hits the cache. + cache_blocking_controls: list[str] = [] + if kwargs.get("stop"): + cache_blocking_controls.append("stop") + if kwargs.get("logits_processors"): + cache_blocking_controls.append("logits_processors") + if (kwargs.get("top_k") or 0) > 0: + cache_blocking_controls.append("top_k") + if (kwargs.get("min_p") or 0.0) > 0.0: + cache_blocking_controls.append("min_p") + if (kwargs.get("presence_penalty") or 0.0) != 0.0: + cache_blocking_controls.append("presence_penalty") + if (kwargs.get("repetition_penalty") or 1.0) != 1.0: + cache_blocking_controls.append("repetition_penalty") + + # Engine-feature gate. + # The cache branch also bypasses engine-level features that + # ``self.stream_generate`` (and the ``MLXLanguageModel.stream_generate`` + # wrapper underneath it) layer on top of ``mlx_lm.stream_generate``. + # Same correctness reasoning as the decode-control gate: cache-eligible + # and uncached requests must decode under identical engine semantics, so + # skip the cache branch when any of these are active. + # Specifically: + # - ``self._mtp`` injects ``mtp=True`` and ``num_draft_tokens`` into + # the mlx-lm call (see ``MLXLanguageModel.stream_generate``). + # - A loaded SpecPrefill draft model (``self._draft_model is not None``, + # set when ``specprefill_enabled`` + ``specprefill_draft_model`` are + # configured at engine init) routes large prompts through + # ``_stream_generate_specprefill`` instead of the plain stream path. + # - A per-request ``specprefill`` override from ``extra_body`` (popped + # by the wrapper from ``kwargs``) can force or suppress SpecPrefill + # for a single request. + # ``specprefill=False`` is a meaningful suppression signal — gate on + # ``is not None`` rather than truthiness so the wrapper sees it. + # - ``self._max_kv_size`` (when > 0) caps the prompt cache; the cache + # branch builds its cache with ``make_prompt_cache(model)`` and has + # no equivalent bound. 
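+        # The names collected here double as the payload of the SKIP log
+        # below, so operators can see exactly which control or feature kept a
+        # request off the cache branch.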
+ if self._mtp: + cache_blocking_controls.append("mtp") + if self._draft_model is not None: + cache_blocking_controls.append("specprefill_loaded") + if kwargs.get("specprefill") is not None: + cache_blocking_controls.append("specprefill_request_override") + if (self._max_kv_size or 0) > 0: + cache_blocking_controls.append("max_kv_size") + # Sliding-window models build their prompt cache from RotatingKVCache + # entries whose ``.state`` aliases buffers that ``update_and_fetch`` + # mutates in place. Snapshot capture would corrupt the cached prefix + # on the next decode. Probed once at start; ``False`` if the model + # exposes any non-KVCache entries or the probe failed. + if not self._supports_system_kv_cache: + cache_blocking_controls.append("non_kv_cache_class") + + if cache_blocking_controls: + logger.info( + "System KV cache SKIP (stream_chat): request or engine has " + "controls/features the cache branch cannot honor (%s); using " + "uncached path", + cache_blocking_controls, + ) + + # Normalize messages to plain dicts. The public stream_chat signature + # types messages as list[dict], but internal callers (server.py, + # tests) sometimes pass Pydantic Message objects directly; those + # don't expose a dict-style .get() interface. + def _to_msg_dict(m: Any) -> dict[str, Any]: + if isinstance(m, dict): + return m + if hasattr(m, "model_dump"): + return m.model_dump() + if hasattr(m, "dict"): + return m.dict() + return { + "role": getattr(m, "role", None), + "content": getattr(m, "content", ""), + } + + messages_for_cache = [_to_msg_dict(m) for m in messages] + has_system = any(m.get("role") == "system" for m in messages_for_cache) + if ( + has_system + and not cache_blocking_controls + and hasattr(tokenizer, "apply_chat_template") + ): + + def _with_user(user_content: str) -> list[dict[str, Any]]: + msgs = [dict(m) for m in messages_for_cache] + if msgs and msgs[-1].get("role") == "user": + msgs[-1] = {**msgs[-1], "content": user_content} + else: + msgs = [*msgs, {"role": "user", "content": user_content}] + return msgs + + rendered_a: Any = None + rendered_b: Any = None + try: + rendered_a = tokenizer.apply_chat_template( + _with_user("Alpha"), **template_kwargs + ) + rendered_b = tokenizer.apply_chat_template( + _with_user("Bravo"), **template_kwargs + ) + except Exception: + pass + + if isinstance(rendered_a, str) and isinstance(rendered_b, str): + boundary = 0 + diverged = False + for i in range(min(len(rendered_a), len(rendered_b))): + if rendered_a[i] != rendered_b[i]: + diverged = True + break + boundary = i + 1 + + if diverged and boundary >= 16: + system_prefix_text = rendered_a[:boundary] + system_hash = hashlib.sha256( + system_prefix_text.encode() + ).hexdigest()[:16] + + add_special = tokenizer.bos_token is None or not prompt.startswith( + tokenizer.bos_token + ) + full_tokens_list = tokenizer.encode( + prompt, add_special_tokens=add_special + ) + system_tokens_list = tokenizer.encode( + system_prefix_text, add_special_tokens=add_special + ) + full_token_count = len(full_tokens_list) + system_token_count = len(system_tokens_list) + + if ( + len(full_tokens_list) > system_token_count + and full_tokens_list[:system_token_count] == system_tokens_list + ): + system_tokens = system_tokens_list + suffix_tokens = full_tokens_list[system_token_count:] + kv_cache_eligible = True + # Read the snapshot reference once. If we promote to + # HIT, ``hit_snapshot`` is the exact list the hash + # check just validated against. 
A later concurrent + # MISS that reassigns ``self._system_kv_snapshot`` + # before our serialized worker restores it cannot + # alias what we captured here. + candidate_snapshot = self._system_kv_snapshot + if ( + system_hash == self._system_kv_hash + and candidate_snapshot is not None + and system_token_count == self._system_kv_token_count + ): + cache_hit = True + hit_snapshot = candidate_snapshot + logger.info( + "System KV cache HIT (stream_chat): reusing %d " + "tokens, prefilling %d new (hash=%s)", + system_token_count, + len(suffix_tokens), + system_hash, + ) + else: + logger.info( + "System KV cache MISS (stream_chat): will " + "prefill %d system + %d suffix tokens (hash=%s)", + system_token_count, + len(suffix_tokens), + system_hash, + ) + + if kv_cache_eligible: + # Cache-aware path: drive mlx-lm directly with a pre-populated cache. + # Stream chunks back to the caller via an asyncio.Queue (mirrors + # _stream_generate_text) so the client sees tokens as they arrive + # rather than after the full generation finishes. + loop = asyncio.get_running_loop() + response_queue: asyncio.Queue[tuple[str, Any]] = asyncio.Queue() + abort_event = threading.Event() + + def _emit_response(resp: Any) -> None: + if abort_event.is_set(): + return + loop.call_soon_threadsafe(response_queue.put_nowait, ("resp", resp)) + + def _emit_done() -> None: + loop.call_soon_threadsafe(response_queue.put_nowait, ("done", None)) + + def _emit_error(exc: BaseException) -> None: + loop.call_soon_threadsafe(response_queue.put_nowait, ("error", exc)) + + def _run_with_cache() -> None: + from mlx_lm import stream_generate as mlx_stream_generate + from mlx_lm.models.cache import make_prompt_cache + from mlx_lm.sample_utils import make_sampler + + model = self._model.model + sampler = make_sampler(temp=temperature, top_p=top_p) + + if cache_hit: + bc = make_prompt_cache(model) + # Restore from the closure-local reference captured at the + # gate, never from ``self._system_kv_snapshot`` directly: + # a concurrent MISS could have replaced the instance + # attribute with a snapshot for a different system prefix + # between the gate check and this point. + for i, saved_state in enumerate(hit_snapshot): + bc[i].state = saved_state + else: + bc = make_prompt_cache(model) + sys_arr = mx.array(system_tokens) + step = self._prefill_step_size + while sys_arr.size > step: + model(sys_arr[:step][None], cache=bc) + mx.eval([c.state for c in bc]) + sys_arr = sys_arr[step:] + mx.clear_cache() + if sys_arr.size > 0: + model(sys_arr[None], cache=bc) + mx.eval([c.state for c in bc]) + + # Free intermediate prefill activations before snapshotting. + # Intentionally stricter than the MLLM path, which does not + # ``mx.clear_cache()`` between its last prefill chunk and + # the snapshot; here we want the snapshot to reflect only + # the KV state, not residual activations from prefill. 
+ mx.clear_cache() + + snapshot = [c.state for c in bc] + mx.eval([s for pair in snapshot for s in pair]) + self._system_kv_snapshot = snapshot + self._system_kv_hash = system_hash + self._system_kv_token_count = system_token_count + try: + cache_mb = sum(c.nbytes for c in bc) / 1e6 + except Exception: + cache_mb = -1 + logger.info( + "System KV cache STORED (stream_chat): %d tokens " "(%.1f MB)", + system_token_count, + cache_mb, + ) + + prompt_arr = mx.array(suffix_tokens) + for resp in mlx_stream_generate( + model, + tokenizer, + prompt=prompt_arr, + max_tokens=max_tokens, + sampler=sampler, + prompt_cache=bc, + ): + if abort_event.is_set(): + break + _emit_response(resp) + + async def _produce_responses() -> None: + try: + await self._run_blocking_serialized( + _run_with_cache, + on_cancel=abort_event.set, + ) + except asyncio.CancelledError: + raise + except BaseException as exc: + _emit_error(exc) + else: + _emit_done() + + producer_task = asyncio.create_task(_produce_responses()) + + accumulated_text = "" + token_count = 0 + finished = False + cache_path_failed_before_first_token = False + try: + while True: + kind, payload = await response_queue.get() + if kind == "done": + break + if kind == "error": + if token_count == 0: + logger.warning( + "Pure-LLM KV-cache path failed before first " + "token (%s); falling back to uncached " + "stream_generate", + payload, + ) + cache_path_failed_before_first_token = True + break + # Already streamed partial output; can't cleanly + # restart on the uncached path, so surface the error. + raise payload + resp = payload + token_count += 1 + new_text = resp.text if hasattr(resp, "text") else str(resp) + accumulated_text += new_text + finish_reason = getattr(resp, "finish_reason", None) + finished = finish_reason is not None or token_count >= max_tokens + if finish_reason is None and finished: + finish_reason = "stop" + + yield GenerationOutput( + text=accumulated_text, + new_text=new_text, + prompt_tokens=full_token_count, + completion_tokens=token_count, + finished=finished, + finish_reason=finish_reason, + ) + if finished: + break + finally: + if not producer_task.done(): + abort_event.set() + try: + await producer_task + except BaseException: + pass + + if cache_path_failed_before_first_token: + async for output in self.stream_generate( + prompt=prompt, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + **kwargs, + ): + yield output + return + + # Fallback: no system prefix detected -> original uncached path async for output in self.stream_generate( prompt=prompt, max_tokens=max_tokens,