diff --git a/python/sglang/srt/mem_cache/radix_cache.py b/python/sglang/srt/mem_cache/radix_cache.py
index 501fc4223eee..79d799f1b9e8 100644
--- a/python/sglang/srt/mem_cache/radix_cache.py
+++ b/python/sglang/srt/mem_cache/radix_cache.py
@@ -63,6 +63,7 @@
     SLRUStrategy,
 )
 from sglang.srt.mem_cache.hicache_storage import get_hash_str, hash_str_to_int64
+from sglang.srt.mem_cache.semantic_prefix import SemanticPrefixProvider
 
 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import Req
@@ -335,6 +336,7 @@ def __init__(self, params: CacheInitParams):
             )
 
         self.evictable_leaves = set()
+        self._semantic_provider: Optional[SemanticPrefixProvider] = None
         self.reset()
 
     @classmethod
@@ -375,6 +377,24 @@ def maybe_bigram_convert(
     ) -> Tuple[RadixKey, Optional[torch.Tensor]]:
         return maybe_bigram_convert(self.is_eagle, key, value)
 
+    def set_semantic_provider(
+        self, provider: Optional[SemanticPrefixProvider]
+    ) -> None:
+        """Register a :class:`~sglang.srt.mem_cache.semantic_prefix.SemanticPrefixProvider`.
+
+        When set, :meth:`match_prefix` will call
+        :meth:`~SemanticPrefixProvider.on_prefix_miss` whenever the exact
+        radix-tree lookup returns zero cached tokens and ``params.req`` is
+        available. Pass ``None`` to unregister a previously registered
+        provider.
+
+        Args:
+            provider: Provider instance, or ``None`` to clear.
+        """
+        self._semantic_provider = provider
+        if provider is not None:
+            provider.on_init()
+
     def match_prefix(self, params: MatchPrefixParams) -> MatchResult:
         """Find the longest cached prefix of ``key`` in the radix tree.
 
@@ -411,6 +431,68 @@ def match_prefix(self, params: MatchPrefixParams) -> MatchResult:
         * If the lookup ends inside a stored segment the node is split once
           to expose a precise boundary; this structural refinement improves
           subsequent match efficiency and does not duplicate data.
+
+        Semantic fallback:
+            When a :class:`~sglang.srt.mem_cache.semantic_prefix.SemanticPrefixProvider`
+            has been registered via :meth:`set_semantic_provider` and the exact
+            lookup returns zero cached tokens, the provider's
+            :meth:`~SemanticPrefixProvider.on_prefix_miss` is called with the
+            request ID and token IDs. If the provider returns an alternate
+            donor token sequence, a second exact lookup is performed against
+            those tokens (capped at the query's own length), allowing the
+            engine to reuse a semantically similar cached prefix without
+            re-computing the full prefill.
+        """
+        result = self._match_prefix_exact(params)
+
+        # Semantic fallback: if no tokens matched and a provider is registered,
+        # ask the provider for an alternate donor whose KV is already cached.
+        if (
+            len(result.device_indices) == 0
+            and self._semantic_provider is not None
+            and params.req is not None
+        ):
+            query_token_ids = list(params.key.token_ids)
+            query_len = len(query_token_ids)
+            semantic_result = self._semantic_provider.on_prefix_miss(
+                rid=params.req.rid,
+                token_ids=query_token_ids,
+            )
+            if semantic_result is not None:
+                if semantic_result.source_id:
+                    logger.debug(
+                        "Semantic KV hit for req %s via donor %s "
+                        "(expected %d cached tokens)",
+                        params.req.rid,
+                        semantic_result.source_id,
+                        semantic_result.num_cached_tokens,
+                    )
+                # The matched indices are handed back as the query's own
+                # prefix, so never look up more donor tokens than the query
+                # holds; a longer donor would otherwise report more cached
+                # tokens than the request has.
+                donor_token_ids = semantic_result.alternate_token_ids[:query_len]
+                if len(donor_token_ids) > 0:
+                    alternate_key = RadixKey(
+                        donor_token_ids,
+                        params.key.extra_key,
+                    )
+                    # NOTE(review): only ``key`` and ``req`` are carried over;
+                    # confirm no other MatchPrefixParams field must be copied.
+                    alternate_params = MatchPrefixParams(
+                        key=alternate_key,
+                        req=params.req,
+                    )
+                    result = self._match_prefix_exact(alternate_params)
+
+        return result
+
+    def _match_prefix_exact(self, params: MatchPrefixParams) -> MatchResult:
+        """Exact radix-tree prefix lookup with no semantic fallback.
+
+        This is the inner implementation called by :meth:`match_prefix`.
+        Callers that need the full semantic-fallback behaviour should use
+        :meth:`match_prefix` instead.
         """
         key = params.key
         key, _ = self.maybe_bigram_convert(key)
@@ -500,6 +582,13 @@ def cache_finished_req(self, req: Req, is_insert: bool = True):
             self.token_to_kv_pool_allocator.free(
                 kv_indices[req.cache_protected_len : new_prefix_len]
             )
+            # Notify the semantic provider so it can register this request as
+            # a potential future donor for approximate KV reuse.
+            if self._semantic_provider is not None:
+                self._semantic_provider.on_request_cached(
+                    rid=req.rid,
+                    token_ids=list(token_ids),
+                )
         else:
             self.token_to_kv_pool_allocator.free(
                 kv_indices[req.cache_protected_len : len(keys)]
diff --git a/python/sglang/srt/mem_cache/semantic_prefix.py b/python/sglang/srt/mem_cache/semantic_prefix.py
new file mode 100644
index 000000000000..c39efa53df17
--- /dev/null
+++ b/python/sglang/srt/mem_cache/semantic_prefix.py
@@ -0,0 +1,145 @@
+"""SemanticPrefixProvider — interface for approximate KV cache matching.
+
+When an exact radix-tree lookup returns zero cached tokens, the provider
+can supply an alternate set of token IDs whose KV is already resident in
+the RadixCache. The engine then reuses that donor KV, skipping full
+prefill recomputation.
+
+Typical use-cases
+-----------------
+* Semantic KV sharing (e.g. SemBlend): look up semantically similar
+  documents already in the cache.
+* Fuzzy prefix matching: tolerate small edits at prefix boundaries.
+* RAG-aware caching: reuse cached KV for retrieved contexts.
+* Topic-based KV sharing: share computation across requests with the
+  same subject matter.
+
+Usage
+-----
+Implement :class:`SemanticPrefixProvider` and register it with the
+:class:`RadixCache` instance that serves the engine::
+
+    radix_cache.set_semantic_provider(my_provider)
+
+``on_prefix_miss`` is called synchronously inside the scheduler step
+(inside ``RadixCache.match_prefix``), so it must be fast. Heavy
+embedding or similarity search should be done asynchronously and the
+result staged before the call.
+"""
+
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import Any, Optional
+
+
+@dataclass
+class SemanticPrefixResult:
+    """Result returned by :meth:`SemanticPrefixProvider.on_prefix_miss`.
+
+    Attributes
+    ----------
+    alternate_token_ids:
+        Token IDs of the donor request whose KV is already resident in
+        the RadixCache. The cache will be queried with these tokens
+        instead of the query's own tokens.
+    num_cached_tokens:
+        Hint for the expected number of cached tokens (used for logging
+        only; the actual count is determined by the radix lookup).
+    skip_insert:
+        Request that the query not be inserted into the RadixCache under
+        its own token IDs once it completes (default ``True``).
+        NOTE(review): no reader of this flag is visible here - confirm.
+    metadata:
+        Arbitrary application-defined data for the provider's own
+        bookkeeping; ``on_request_cached`` does not receive it. Must be
+        picklable when used in multi-process deployments.
+    source_id:
+        Optional label used in log messages.
+    """
+
+    alternate_token_ids: list[int]
+    num_cached_tokens: int
+    skip_insert: bool = True
+    metadata: Any = None
+    source_id: str = ""
+
+
+class SemanticPrefixProvider(ABC):
+    """Abstract base class for approximate / semantic KV cache matching.
+
+    Subclasses implement :meth:`on_prefix_miss` to supply a donor request
+    whenever the standard exact-match radix lookup returns zero hit tokens,
+    and :meth:`on_request_cached` to update internal state after each
+    request's KV is committed to the cache.
+
+    The two optional lifecycle hooks, :meth:`on_init` and
+    :meth:`on_shutdown`, default to no-ops so that subclasses only
+    override the ones they need.
+
+    Thread-safety
+    -------------
+    :meth:`on_prefix_miss` and :meth:`on_request_cached` are called from
+    the scheduler thread. Implementations are responsible for their own
+    locking where necessary.
+    """
+
+    @abstractmethod
+    def on_prefix_miss(
+        self,
+        rid: str,
+        token_ids: list[int],
+    ) -> Optional[SemanticPrefixResult]:
+        """Called when the exact radix-tree lookup returns zero hit tokens.
+
+        The implementation should return a :class:`SemanticPrefixResult`
+        whose ``alternate_token_ids`` are already resident in the
+        RadixCache, or ``None`` to fall back to a normal cold prefill.
+
+        Parameters
+        ----------
+        rid:
+            SGLang request ID (unique per request).
+        token_ids:
+            Token IDs of the lookup key for the incoming request.
+
+        Returns
+        -------
+        :class:`SemanticPrefixResult` or ``None``
+        """
+        ...
+
+    @abstractmethod
+    def on_request_cached(
+        self,
+        rid: str,
+        token_ids: list[int],
+    ) -> None:
+        """Called after a request's KV is committed to the RadixCache.
+
+        Implementations should use this to register the request as a
+        potential future donor and update any per-request state.
+
+        Parameters
+        ----------
+        rid:
+            SGLang request ID of the cached request.
+        token_ids:
+            Full token IDs (prompt + generated output) of the cached
+            request.
+        """
+        ...
+
+    def on_init(self, model_config: Any = None) -> None:  # noqa: B027
+        """Called by ``RadixCache.set_semantic_provider`` on registration.
+
+        Parameters
+        ----------
+        model_config:
+            SGLang ``ModelConfig`` instance, or ``None`` when the caller
+            does not supply one (the default).
+        """
+
+    def on_shutdown(self) -> None:  # noqa: B027
+        """Teardown hook intended for server shutdown; no-op by default."""
diff --git a/test/srt/conftest.py b/test/srt/conftest.py
new file mode 100644
index 000000000000..32ec3da5e345
--- /dev/null
+++ b/test/srt/conftest.py
@@ -0,0 +1,60 @@
+"""pytest configuration for test/srt.
+
+When tests are run from a sparse git checkout (e.g. during local development
+where only ``python/sglang/srt/mem_cache/`` was fetched), ``sglang.lang`` and
+other frontend modules may be missing or the installed sglang version may
+differ from the fork being tested.
+
+This conftest ensures the fork's ``python/`` directory takes precedence over
+any installed sglang package, and stubs out missing frontend modules so that
+tests focusing on the server-side runtime (``sglang.srt.*``) can run without
+a full install.
+
+In CI — where SGLang is installed from the complete source tree being tested —
+every module listed below exists on disk, so no stub is installed and tests
+that use the real frontend are unaffected.
+"""
+from __future__ import annotations
+
+import sys
+from pathlib import Path
+from unittest.mock import MagicMock
+
+# ── Ensure the fork's python/ directory takes precedence ─────────────────────
+_FORK_PYTHON = str(Path(__file__).parent.parent.parent / "python")
+
+# Remove any previously-loaded sglang modules so the fork's versions are used.
+for _key in list(sys.modules):
+    if _key == "sglang" or _key.startswith("sglang."):
+        del sys.modules[_key]
+
+# Insert the fork at the very front of sys.path.
+if _FORK_PYTHON in sys.path:
+    sys.path.remove(_FORK_PYTHON)
+sys.path.insert(0, _FORK_PYTHON)
+
+# ── Stub out frontend modules missing from the sparse checkout ───────────────
+_STUB_MODULES = [
+    "sglang.lang",
+    "sglang.lang.api",
+    "sglang.lang.backend",
+    "sglang.lang.backend.runtime_endpoint",
+    "sglang.lang.backend.anthropic",
+    "sglang.lang.backend.litellm",
+    "sglang.lang.backend.openai",
+    "sglang.lang.backend.vertexai",
+    "sglang.lang.choices",
+]
+
+
+def _fork_has_module(mod_name: str) -> bool:
+    """Return True when *mod_name* exists as a module or package in the fork."""
+    target = Path(_FORK_PYTHON).joinpath(*mod_name.split("."))
+    return target.is_dir() or target.with_suffix(".py").is_file()
+
+
+# Stub only what is genuinely absent; replacing a module that exists on disk
+# would hide the real frontend from every test collected under test/srt.
+for _mod_name in _STUB_MODULES:
+    if not _fork_has_module(_mod_name):
+        sys.modules[_mod_name] = MagicMock()
diff --git a/test/srt/test_semantic_prefix_provider.py b/test/srt/test_semantic_prefix_provider.py
new file mode 100644
index 000000000000..1b7e7b768d77
--- /dev/null
+++ b/test/srt/test_semantic_prefix_provider.py
@@ -0,0 +1,507 @@
+"""Tests for SemanticPrefixProvider integration in RadixCache.
+
+All tests run without a GPU — they use :meth:`RadixCache.create_simulated`
+which requires no physical memory pools.
+"""
+
+from __future__ import annotations
+
+import unittest
+from typing import Optional
+from unittest.mock import MagicMock
+
+import torch
+
+from sglang.srt.mem_cache.base_prefix_cache import InsertParams, MatchPrefixParams
+from sglang.srt.mem_cache.radix_cache import RadixCache, RadixKey
+from sglang.srt.mem_cache.semantic_prefix import (
+    SemanticPrefixProvider,
+    SemanticPrefixResult,
+)
+
+
+# ──────────────────────────────────────────────────────────────
+# Helper provider implementations
+# ──────────────────────────────────────────────────────────────
+
+
+class _NullProvider(SemanticPrefixProvider):
+    """Provider that never returns a donor."""
+
+    def on_prefix_miss(self, rid: str, token_ids: list[int]):
+        return None
+
+    def on_request_cached(self, rid: str, token_ids: list[int]) -> None:
+        pass
+
+
+class _FixedDonorProvider(SemanticPrefixProvider):
+    """Provider that always returns a fixed donor sequence."""
+
+    def __init__(self, donor_ids: list[int], source_id: str = "") -> None:
+        self.donor_ids = donor_ids
+        self.source_id = source_id
+        self.miss_calls: list[tuple] = []
+        self.cached_calls: list[tuple] = []
+
+    def on_prefix_miss(
+        self, rid: str, token_ids: list[int]
+    ) -> Optional[SemanticPrefixResult]:
+        self.miss_calls.append((rid, token_ids))
+        return SemanticPrefixResult(
+            alternate_token_ids=self.donor_ids,
+            num_cached_tokens=len(self.donor_ids),
+            source_id=self.source_id,
+        )
+
+    def on_request_cached(self, rid: str, token_ids: list[int]) -> None:
+        self.cached_calls.append((rid, token_ids))
+
+
+class _RaisingProvider(SemanticPrefixProvider):
+    """Provider that always raises (used to test exception propagation)."""
+
+    def on_prefix_miss(self, rid: str, token_ids: list[int]):
+        raise RuntimeError("deliberate test error")
+
+    def on_request_cached(self, rid: str, token_ids: list[int]) -> None:
+        pass
+
+
+# ──────────────────────────────────────────────────────────────
+# Shared helpers
+# ──────────────────────────────────────────────────────────────
+
+
+def _make_cache() -> RadixCache:
+    return RadixCache.create_simulated()
+
+
+def _insert(cache: RadixCache, token_ids: list[int]) -> None:
+    cache.insert(
+        InsertParams(
+            key=RadixKey(token_ids),
+            value=torch.tensor(token_ids, dtype=torch.int64),
+        )
+    )
+
+
+def _mock_req(rid: str = "r1") -> MagicMock:
+    req = MagicMock()
+    req.rid = rid
+    return req
+
+
+def _make_req_for_finished(
+    cache: RadixCache, rid: str, token_ids: list[int]
+) -> MagicMock:
+    """Return a Req-like mock suitable for cache_finished_req."""
+    req = MagicMock()
+    req.rid = rid
+    req.origin_input_ids = token_ids
+    req.output_ids = []
+    req.extra_key = None
+    req.req_pool_idx = 0
+    req.last_node = cache.root_node  # avoids dec_lock_ref tree traversal
+    req.cache_protected_len = 0
+    req.priority = 0
+    req.pop_committed_kv_cache.return_value = len(token_ids)
+    return req
+
+
+def _attach_pool(cache: RadixCache, num_tokens: int) -> None:
+    """Wire minimal mock memory pools onto a simulated cache."""
+    kv_pool = torch.arange(num_tokens * 2, dtype=torch.int64).reshape(2, num_tokens)
+    req_to_token = MagicMock()
+    req_to_token.req_to_token = kv_pool
+    cache.req_to_token_pool = req_to_token
+    allocator = MagicMock()
+    allocator.free = MagicMock()
+    cache.token_to_kv_pool_allocator = allocator
+
+
+# ──────────────────────────────────────────────────────────────
+# Test suites
+# ──────────────────────────────────────────────────────────────
+
+
+class TestSetSemanticProvider(unittest.TestCase):
+    """Tests for set_semantic_provider()."""
+
+    def test_stores_provider(self):
+        cache = _make_cache()
+        p = _NullProvider()
+        cache.set_semantic_provider(p)
+        self.assertIs(cache._semantic_provider, p)
+
+    def test_on_init_called_once(self):
+        cache = _make_cache()
+        p = MagicMock(spec=SemanticPrefixProvider)
+        cache.set_semantic_provider(p)
+        p.on_init.assert_called_once_with()
+
+    def test_set_none_clears_provider(self):
+        cache = _make_cache()
+        p = _NullProvider()
+        cache.set_semantic_provider(p)
+        cache.set_semantic_provider(None)
+        self.assertIsNone(cache._semantic_provider)
+
+    def test_set_none_does_not_call_on_init(self):
+        """Passing None should not raise and should not call on_init."""
+        cache = _make_cache()
+        cache.set_semantic_provider(None)  # must not raise: no on_init to call
+        self.assertIsNone(cache._semantic_provider)
+
+    def test_replace_provider(self):
+        """Each provider is initialised exactly once; the newest one is kept."""
+        cache = _make_cache()
+        p1 = MagicMock(spec=SemanticPrefixProvider)
+        p2 = MagicMock(spec=SemanticPrefixProvider)
+        cache.set_semantic_provider(p1)
+        cache.set_semantic_provider(p2)
+        p1.on_init.assert_called_once()
+        p2.on_init.assert_called_once()
+        self.assertIs(cache._semantic_provider, p2)
+
+    def test_provider_none_by_default(self):
+        cache = _make_cache()
+        self.assertIsNone(cache._semantic_provider)
+
+
+class TestMatchPrefixNoProvider(unittest.TestCase):
+    """Baseline: no semantic provider — exact-match behaviour unchanged."""
+
+    def test_miss_returns_empty(self):
+        cache = _make_cache()
+        result = cache.match_prefix(MatchPrefixParams(key=RadixKey([1, 2, 3])))
+        self.assertEqual(len(result.device_indices), 0)
+
+    def test_hit_returns_indices(self):
+        cache = _make_cache()
+        ids = [10, 20, 30]
+        _insert(cache, ids)
+        result = cache.match_prefix(MatchPrefixParams(key=RadixKey(ids)))
+        self.assertEqual(len(result.device_indices), len(ids))
+
+    def test_partial_hit(self):
+        cache = _make_cache()
+        _insert(cache, [1, 2, 3, 4, 5])
+        result = cache.match_prefix(MatchPrefixParams(key=RadixKey([1, 2, 3])))
+        self.assertGreater(len(result.device_indices), 0)
+
+
+class TestMatchPrefixSemanticFallback(unittest.TestCase):
+    """Tests for the semantic fallback path in match_prefix."""
+
+    # ── provider not triggered on exact hit ─────────────────────────────────
+
+    def test_provider_not_called_on_exact_hit(self):
+        ids = [10, 20, 30, 40, 50]
+        cache = _make_cache()
+        _insert(cache, ids)
+        p = MagicMock(spec=SemanticPrefixProvider)
+        cache.set_semantic_provider(p)
+
+        req = _mock_req()
+        result = cache.match_prefix(MatchPrefixParams(key=RadixKey(ids), req=req))
+
+        self.assertGreater(len(result.device_indices), 0)
+        p.on_prefix_miss.assert_not_called()
+
+    # ── provider not triggered without params.req ────────────────────────────
+
+    def test_provider_not_called_without_req(self):
+        cache = _make_cache()
+        p = _FixedDonorProvider(donor_ids=[1, 2, 3])
+        cache.set_semantic_provider(p)
+
+        # params.req is None (default) → fallback must not activate
+        cache.match_prefix(MatchPrefixParams(key=RadixKey([99, 88, 77])))
+
+        self.assertEqual(len(p.miss_calls), 0)
+
+    # ── provider called on exact miss with req ───────────────────────────────
+
+    def test_provider_called_on_miss(self):
+        cache = _make_cache()
+        p = _NullProvider()
+        p.on_prefix_miss = MagicMock(return_value=None)
+        cache.set_semantic_provider(p)
+
+        req = _mock_req(rid="req-1")
+        query = [5, 6, 7]
+        cache.match_prefix(MatchPrefixParams(key=RadixKey(query), req=req))
+
+        p.on_prefix_miss.assert_called_once_with(rid="req-1", token_ids=query)
+
+    def test_provider_receives_original_token_ids(self):
+        """Provider gets the original (pre-page-alignment) token IDs."""
+        received: list[list[int]] = []
+
+        class _Recorder(SemanticPrefixProvider):
+            def on_prefix_miss(self, rid, token_ids):
+                received.append(list(token_ids))
+                return None
+
+            def on_request_cached(self, rid, token_ids):
+                pass
+
+        cache = _make_cache()
+        cache.set_semantic_provider(_Recorder())
+        query = [7, 8, 9, 11]
+        cache.match_prefix(MatchPrefixParams(key=RadixKey(query), req=_mock_req()))
+        self.assertEqual(received, [query])
+
+    # ── provider returns None → cold prefill ────────────────────────────────
+
+    def test_provider_returns_none_result_stays_empty(self):
+        cache = _make_cache()
+        cache.set_semantic_provider(_NullProvider())
+        result = cache.match_prefix(
+            MatchPrefixParams(key=RadixKey([1, 2, 3]), req=_mock_req())
+        )
+        self.assertEqual(len(result.device_indices), 0)
+
+    # ── provider returns alternate donor IDs ────────────────────────────────
+
+    def test_alternate_tokens_cached_gives_hit(self):
+        donor_ids = [10, 20, 30, 40, 50]
+        cache = _make_cache()
+        _insert(cache, donor_ids)
+
+        p = _FixedDonorProvider(donor_ids=donor_ids, source_id="doc-42")
+        cache.set_semantic_provider(p)
+
+        req = _mock_req(rid="req-2")
+        result = cache.match_prefix(
+            MatchPrefixParams(key=RadixKey([99, 88, 77]), req=req)
+        )
+
+        self.assertGreater(len(result.device_indices), 0)
+        self.assertEqual(len(p.miss_calls), 1)
+        self.assertEqual(p.miss_calls[0][0], "req-2")
+
+    def test_alternate_tokens_not_cached_stays_empty(self):
+        cache = _make_cache()  # empty — donor IDs not inserted
+        p = _FixedDonorProvider(donor_ids=[55, 66, 77])
+        cache.set_semantic_provider(p)
+
+        result = cache.match_prefix(
+            MatchPrefixParams(key=RadixKey([1, 2, 3]), req=_mock_req())
+        )
+        self.assertEqual(len(result.device_indices), 0)
+
+    def test_extra_key_preserved_in_alternate_lookup(self):
+        """extra_key from the original query is used when looking up alternate tokens."""
+        donor_ids = [10, 20, 30]
+        extra_key = "lora-7"
+        cache = _make_cache()
+        cache.insert(
+            InsertParams(
+                key=RadixKey(donor_ids, extra_key=extra_key),
+                value=torch.tensor(donor_ids, dtype=torch.int64),
+            )
+        )
+
+        p = _FixedDonorProvider(donor_ids=donor_ids)
+        cache.set_semantic_provider(p)
+
+        req = _mock_req()
+        result = cache.match_prefix(
+            MatchPrefixParams(
+                key=RadixKey([99, 88], extra_key=extra_key), req=req
+            )
+        )
+        self.assertGreater(len(result.device_indices), 0)
+
+    # ── exception propagation ────────────────────────────────────────────────
+
+    def test_provider_exception_propagates(self):
+        """Exceptions from the provider are not silently swallowed."""
+        cache = _make_cache()
+        cache.set_semantic_provider(_RaisingProvider())
+
+        with self.assertRaisesRegex(RuntimeError, "deliberate test error"):
+            cache.match_prefix(
+                MatchPrefixParams(key=RadixKey([1, 2, 3]), req=_mock_req())
+            )
+
+    # ── logging ─────────────────────────────────────────────────────────────
+
+    def test_source_id_logged_on_semantic_hit(self):
+        """A non-empty source_id triggers a DEBUG log entry."""
+        donor_ids = [5, 6, 7, 8]
+        cache = _make_cache()
+        _insert(cache, donor_ids)
+
+        p = _FixedDonorProvider(donor_ids=donor_ids, source_id="my-donor")
+        cache.set_semantic_provider(p)
+
+        req = _mock_req()
+        with self.assertLogs(
+            "sglang.srt.mem_cache.radix_cache", level="DEBUG"
+        ) as cm:
+            cache.match_prefix(MatchPrefixParams(key=RadixKey([1, 2, 3]), req=req))
+
+        self.assertTrue(
+            any("my-donor" in line for line in cm.output),
+            f"'my-donor' not found in log output: {cm.output}",
+        )
+
+    def test_no_log_when_source_id_empty(self):
+        """An empty source_id must not produce a log entry at DEBUG level."""
+        donor_ids = [5, 6, 7, 8]
+        cache = _make_cache()
+        _insert(cache, donor_ids)
+
+        p = _FixedDonorProvider(donor_ids=donor_ids, source_id="")  # empty
+        cache.set_semantic_provider(p)
+
+        req = _mock_req()
+        import logging
+
+        with self.assertLogs(
+            "sglang.srt.mem_cache.radix_cache", level="DEBUG"
+        ) as cm:
+            # Force at least one log entry so assertLogs doesn't fail on empty
+            logging.getLogger("sglang.srt.mem_cache.radix_cache").debug("sentinel")
+            cache.match_prefix(MatchPrefixParams(key=RadixKey([1, 2, 3]), req=req))
+
+        semantic_logs = [
+            line
+            for line in cm.output
+            if "Semantic KV hit" in line
+        ]
+        self.assertEqual(len(semantic_logs), 0)
+
+
+class TestMatchPrefixExact(unittest.TestCase):
+    """_match_prefix_exact never invokes the semantic provider."""
+
+    def test_exact_hit(self):
+        cache = _make_cache()
+        ids = [1, 2, 3]
+        _insert(cache, ids)
+        result = cache._match_prefix_exact(MatchPrefixParams(key=RadixKey(ids)))
+        self.assertEqual(len(result.device_indices), len(ids))
+
+    def test_exact_miss_returns_empty(self):
+        cache = _make_cache()
+        result = cache._match_prefix_exact(
+            MatchPrefixParams(key=RadixKey([99, 88, 77]))
+        )
+        self.assertEqual(len(result.device_indices), 0)
+
+    def test_does_not_call_provider(self):
+        cache = _make_cache()
+        p = MagicMock(spec=SemanticPrefixProvider)
+        cache.set_semantic_provider(p)
+
+        cache._match_prefix_exact(
+            MatchPrefixParams(key=RadixKey([1, 2, 3]), req=_mock_req())
+        )
+        p.on_prefix_miss.assert_not_called()
+
+
+class TestOnRequestCachedHook(unittest.TestCase):
+    """cache_finished_req must notify the provider after a successful insert."""
+
+    def _make_cache_with_pool(
+        self, provider: SemanticPrefixProvider, num_tokens: int = 10
+    ) -> RadixCache:
+        cache = _make_cache()
+        cache.set_semantic_provider(provider)
+        _attach_pool(cache, num_tokens)
+        return cache
+
+    def test_called_on_insert(self):
+        p = MagicMock(spec=SemanticPrefixProvider)
+        p.on_init = MagicMock()
+        p.on_request_cached = MagicMock()
+        cache = self._make_cache_with_pool(p, num_tokens=8)
+
+        token_ids = [1, 2, 3, 4]
+        req = _make_req_for_finished(cache, "req-a", token_ids)
+
+        cache.cache_finished_req(req, is_insert=True)
+
+        p.on_request_cached.assert_called_once()
+        kwargs = p.on_request_cached.call_args.kwargs
+        self.assertEqual(kwargs["rid"], "req-a")
+        self.assertEqual(kwargs["token_ids"], token_ids)
+
+    def test_not_called_when_no_insert(self):
+        p = MagicMock(spec=SemanticPrefixProvider)
+        p.on_init = MagicMock()
+        p.on_request_cached = MagicMock()
+        cache = self._make_cache_with_pool(p, num_tokens=8)
+
+        token_ids = [5, 6, 7, 8]
+        req = _make_req_for_finished(cache, "req-b", token_ids)
+
+        cache.cache_finished_req(req, is_insert=False)
+
+        p.on_request_cached.assert_not_called()
+
+    def test_not_called_without_provider(self):
+        """No provider → cache_finished_req must not raise."""
+        cache = _make_cache()
+        _attach_pool(cache, num_tokens=8)
+
+        token_ids = [1, 2, 3]
+        req = _make_req_for_finished(cache, "req-c", token_ids)
+
+        # Must not raise AttributeError
+        cache.cache_finished_req(req, is_insert=True)
+
+    def test_token_ids_match_committed_range(self):
+        """on_request_cached receives exactly the committed token IDs."""
+        p = MagicMock(spec=SemanticPrefixProvider)
+        p.on_init = MagicMock()
+        p.on_request_cached = MagicMock()
+        cache = self._make_cache_with_pool(p, num_tokens=10)
+
+        full_ids = list(range(6))
+        committed_len = 4  # only first 4 tokens are "committed"
+        req = _make_req_for_finished(cache, "req-d", full_ids)
+        req.pop_committed_kv_cache.return_value = committed_len
+
+        cache.cache_finished_req(req, is_insert=True)
+
+        kwargs = p.on_request_cached.call_args.kwargs
+        self.assertEqual(kwargs["token_ids"], full_ids[:committed_len])
+
+
+class TestMultipleRequests(unittest.TestCase):
+    """Semantic provider is called independently for each request."""
+
+    def test_each_miss_calls_provider(self):
+        cache = _make_cache()
+        p = _NullProvider()
+        p.on_prefix_miss = MagicMock(return_value=None)
+        cache.set_semantic_provider(p)
+
+        for i in range(3):
+            req = _mock_req(rid=f"req-{i}")
+            cache.match_prefix(MatchPrefixParams(key=RadixKey([i * 10]), req=req))
+
+        self.assertEqual(p.on_prefix_miss.call_count, 3)
+
+    def test_second_provider_replaces_first(self):
+        cache = _make_cache()
+        p1 = MagicMock(spec=SemanticPrefixProvider)
+        p2 = MagicMock(spec=SemanticPrefixProvider)
+        p2.on_prefix_miss = MagicMock(return_value=None)
+        cache.set_semantic_provider(p1)
+        cache.set_semantic_provider(p2)
+
+        req = _mock_req()
+        cache.match_prefix(MatchPrefixParams(key=RadixKey([1, 2, 3]), req=req))
+
+        p1.on_prefix_miss.assert_not_called()
+        p2.on_prefix_miss.assert_called_once()
+
+
+if __name__ == "__main__":
+    unittest.main()