Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 78 additions & 0 deletions python/sglang/srt/mem_cache/radix_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
SLRUStrategy,
)
from sglang.srt.mem_cache.hicache_storage import get_hash_str, hash_str_to_int64
from sglang.srt.mem_cache.semantic_prefix import SemanticPrefixProvider

if TYPE_CHECKING:
from sglang.srt.managers.schedule_batch import Req
Expand Down Expand Up @@ -335,6 +336,7 @@ def __init__(self, params: CacheInitParams):
)

self.evictable_leaves = set()
self._semantic_provider: Optional[SemanticPrefixProvider] = None
self.reset()

@classmethod
Expand Down Expand Up @@ -375,6 +377,24 @@ def maybe_bigram_convert(
) -> Tuple[RadixKey, Optional[torch.Tensor]]:
return maybe_bigram_convert(self.is_eagle, key, value)

def set_semantic_provider(
    self, provider: Optional[SemanticPrefixProvider]
) -> None:
    """Register a :class:`~sglang.srt.mem_cache.semantic_prefix.SemanticPrefixProvider`.

    When set, :meth:`match_prefix` will call
    :meth:`~SemanticPrefixProvider.on_prefix_miss` whenever the exact
    radix-tree lookup returns zero cached tokens and ``params.req`` is
    available.  Pass ``None`` to unregister a previously registered
    provider.

    The provider's :meth:`~SemanticPrefixProvider.on_init` hook is run
    *before* the provider is stored on the cache, so a provider whose
    initialisation fails is never consulted by :meth:`match_prefix`.

    Args:
        provider: Provider instance, or ``None`` to clear.

    Raises:
        Exception: Anything raised by ``provider.on_init()`` propagates
            unchanged; the previously registered provider (if any) is
            left in place in that case.
    """
    if provider is not None:
        # Initialise first: publishing the provider before on_init()
        # succeeds would leave a half-initialised object on the lookup path.
        provider.on_init()
    self._semantic_provider = provider

def match_prefix(self, params: MatchPrefixParams) -> MatchResult:
"""Find the longest cached prefix of ``key`` in the radix tree.

Expand Down Expand Up @@ -411,6 +431,57 @@ def match_prefix(self, params: MatchPrefixParams) -> MatchResult:
* If the lookup ends inside a stored segment the node is split once
to expose a precise boundary; this structural refinement improves
subsequent match efficiency and does not duplicate data.

Semantic fallback:
When a :class:`~sglang.srt.mem_cache.semantic_prefix.SemanticPrefixProvider`
has been registered via :meth:`set_semantic_provider` and the exact
lookup returns zero cached tokens, the provider's
:meth:`~SemanticPrefixProvider.on_prefix_miss` is called with the
request ID and token IDs. If the provider returns an alternate
donor token sequence, a second exact lookup is performed against
those tokens, allowing the engine to reuse a semantically similar
cached prefix without re-computing the full prefill.
"""
result = self._match_prefix_exact(params)

# Semantic fallback: if no tokens matched and a provider is registered,
# ask the provider for an alternate donor whose KV is already cached.
if (
len(result.device_indices) == 0
and self._semantic_provider is not None
and params.req is not None
):
semantic_result = self._semantic_provider.on_prefix_miss(
rid=params.req.rid,
token_ids=list(params.key.token_ids),
)
if semantic_result is not None:
if semantic_result.source_id:
logger.debug(
"Semantic KV hit for req %s via donor %s "
"(expected %d cached tokens)",
params.req.rid,
semantic_result.source_id,
semantic_result.num_cached_tokens,
)
alternate_key = RadixKey(
semantic_result.alternate_token_ids,
params.key.extra_key,
)
alternate_params = MatchPrefixParams(
key=alternate_key,
req=params.req,
)
result = self._match_prefix_exact(alternate_params)

return result

def _match_prefix_exact(self, params: MatchPrefixParams) -> MatchResult:
"""Exact radix-tree prefix lookup with no semantic fallback.

This is the inner implementation called by :meth:`match_prefix`.
Callers that need the full semantic-fallback behaviour should use
:meth:`match_prefix` instead.
"""
key = params.key
key, _ = self.maybe_bigram_convert(key)
Expand Down Expand Up @@ -500,6 +571,13 @@ def cache_finished_req(self, req: Req, is_insert: bool = True):
self.token_to_kv_pool_allocator.free(
kv_indices[req.cache_protected_len : new_prefix_len]
)
# Notify the semantic provider so it can register this request as
# a potential future donor for approximate KV reuse.
if self._semantic_provider is not None:
self._semantic_provider.on_request_cached(
rid=req.rid,
token_ids=list(token_ids),
)
else:
self.token_to_kv_pool_allocator.free(
kv_indices[req.cache_protected_len : len(keys)]
Expand Down
145 changes: 145 additions & 0 deletions python/sglang/srt/mem_cache/semantic_prefix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
"""SemanticPrefixProvider — interface for approximate KV cache matching.

When an exact radix-tree lookup returns zero cached tokens, the provider
can supply an alternate set of token IDs whose KV is already resident in
the RadixCache. The engine then reuses that donor KV, skipping full
prefill recomputation.

Typical use-cases
-----------------
* Semantic KV sharing (e.g. SemBlend): look up semantically similar
documents already in the cache.
* Fuzzy prefix matching: tolerate small edits at prefix boundaries.
* RAG-aware caching: reuse cached KV for retrieved contexts.
* Topic-based KV sharing: share computation across requests with the
same subject matter.

Usage
-----
Implement :class:`SemanticPrefixProvider` and register it with the
server's prefix cache::

server.prefix_cache.set_semantic_provider(my_provider)

``on_prefix_miss`` is called synchronously inside the scheduler step
(inside ``RadixCache.match_prefix``), so it must be fast. Heavy
embedding or similarity search should be done asynchronously and the
result staged before the call.
"""

from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class SemanticPrefixResult:
    """Result returned by :meth:`SemanticPrefixProvider.on_prefix_miss`.

    Attributes
    ----------
    alternate_token_ids:
        Token IDs of the donor request whose KV is expected to be resident
        in the RadixCache.  ``RadixCache.match_prefix`` builds a new
        ``RadixKey`` from these tokens (keeping the query's ``extra_key``)
        and repeats the exact lookup with it instead of the query's own
        tokens.
    num_cached_tokens:
        Hint for the expected number of cached tokens (used for logging
        only; the actual count is determined by the radix lookup).
    skip_insert:
        When ``True`` (the default) the query result is meant *not* to be
        inserted into the RadixCache under the query's own token IDs
        after the request completes, preventing cache pollution.
        NOTE(review): no code visible next to this definition reads this
        flag -- confirm that the cache actually honours it.
    metadata:
        Arbitrary application-defined data for the provider's own
        bookkeeping.  Must be picklable when used in multi-process
        deployments.
        NOTE(review): :meth:`SemanticPrefixProvider.on_request_cached`
        accepts only ``rid`` and ``token_ids``, so this value is not
        forwarded to it; providers that need it later should keep their
        own state keyed by request ID.
    source_id:
        Optional label used in log messages.  The debug log line in
        ``RadixCache.match_prefix`` is emitted only when this is
        non-empty.
    """

    alternate_token_ids: list[int]
    num_cached_tokens: int
    skip_insert: bool = True
    metadata: Any = None
    source_id: str = ""


class SemanticPrefixProvider(ABC):
    """Extension point for approximate / semantic KV cache matching.

    A concrete provider must implement two callbacks:

    * :meth:`on_prefix_miss` -- propose a donor token sequence when the
      exact radix lookup found zero cached tokens.
    * :meth:`on_request_cached` -- record a request whose KV has just been
      committed to the cache so it can serve as a donor later.

    Two further hooks, :meth:`on_init` and :meth:`on_shutdown`, are
    optional no-ops that subclasses may override to take part in
    SGLang's startup / teardown sequence.

    Thread-safety
    -------------
    :meth:`on_prefix_miss` and :meth:`on_request_cached` are called from
    the scheduler thread.  Implementations are responsible for their own
    locking where necessary.
    """

    # Methods are listed in the order they fire during a provider's life.

    def on_init(self, model_config: Any = None) -> None:  # noqa: B027
        """Optional start-up hook; the default implementation does nothing.

        ``RadixCache.set_semantic_provider`` invokes this hook, without
        arguments, when a provider is registered.

        Parameters
        ----------
        model_config:
            SGLang ``ModelConfig`` instance, or ``None`` when the caller
            does not supply one.
        """

    @abstractmethod
    def on_prefix_miss(
        self, rid: str, token_ids: list[int]
    ) -> Optional[SemanticPrefixResult]:
        """Propose a donor after an exact lookup returned zero hit tokens.

        Return a :class:`SemanticPrefixResult` whose
        ``alternate_token_ids`` are already resident in the RadixCache,
        or ``None`` to fall back to a normal cold prefill.

        Parameters
        ----------
        rid:
            SGLang request ID (unique per request).
        token_ids:
            Full prompt token IDs for the incoming request.

        Returns
        -------
        :class:`SemanticPrefixResult` or ``None``
        """

    @abstractmethod
    def on_request_cached(self, rid: str, token_ids: list[int]) -> None:
        """Record a request whose KV was just committed to the RadixCache.

        Implementations should use this to register the request as a
        potential future donor and to update any per-request state.

        Parameters
        ----------
        rid:
            SGLang request ID of the cached request.
        token_ids:
            Full token IDs (prompt + generated output) of the cached
            request.
        """

    def on_shutdown(self) -> None:  # noqa: B027
        """Optional tear-down hook; the default implementation does nothing.

        Intended to be called once when the server shuts down.
        NOTE(review): no caller of this hook is visible alongside this
        class -- confirm where shutdown is wired in.
        """
49 changes: 49 additions & 0 deletions test/srt/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
"""pytest configuration for test/srt.

When tests are run from a sparse git checkout (e.g. during local development
where only ``python/sglang/srt/mem_cache/`` was fetched), ``sglang.lang`` and
other frontend modules may be missing or the installed sglang version may
differ from the fork being tested.

This conftest ensures the fork's ``python/`` directory takes precedence over
any installed sglang package, and stubs out missing frontend modules so that
tests focusing on the server-side runtime (``sglang.srt.*``) can run without
a full install.

In CI — where SGLang is installed from the complete source tree being tested —
the fork's python/ directory is already the installed package, so these stubs
are never needed.
"""
from __future__ import annotations

import sys
from pathlib import Path
from unittest.mock import MagicMock

# ── Ensure the fork's python/ directory takes precedence ─────────────────────
_FORK_PYTHON = str(Path(__file__).parent.parent.parent / "python")

# Remove any previously-loaded sglang modules so the fork's versions are used.
for _key in list(sys.modules):
    if _key == "sglang" or _key.startswith("sglang."):
        del sys.modules[_key]

# Insert the fork at the very front of sys.path.  Every existing occurrence is
# dropped first (in place, so other references to sys.path stay valid) to
# guarantee exactly one entry.
sys.path[:] = [_entry for _entry in sys.path if _entry != _FORK_PYTHON]
sys.path.insert(0, _FORK_PYTHON)

# ── Stub out frontend modules missing from the sparse checkout ───────────────
_STUB_MODULES = [
    "sglang.lang",
    "sglang.lang.api",
    "sglang.lang.backend",
    "sglang.lang.backend.runtime_endpoint",
    "sglang.lang.backend.anthropic",
    "sglang.lang.backend.litellm",
    "sglang.lang.backend.openai",
    "sglang.lang.backend.vertexai",
    "sglang.lang.choices",
]


def _source_tree_has_module(root, mod_name):
    """Return ``True`` if *mod_name* exists as source under *root*.

    The check is purely file-system based (``<pkg>/__init__.py`` or
    ``<module>.py``) so that nothing is imported: importing ``sglang`` itself
    is exactly what fails in a sparse checkout.

    Args:
        root: Directory that contains the top-level package (str or Path).
        mod_name: Dotted module name, e.g. ``"sglang.lang.api"``.
    """
    target = Path(root, *mod_name.split("."))
    as_package = target / "__init__.py"
    as_module = target.with_name(target.name + ".py")
    return as_package.is_file() or as_module.is_file()


for _mod_name in _STUB_MODULES:
    # Only stub what is genuinely absent.  Stubbing unconditionally would
    # shadow the real frontend modules on a full checkout (e.g. in CI) and
    # silently turn them into MagicMock objects for every test in this dir.
    if not _source_tree_has_module(_FORK_PYTHON, _mod_name):
        sys.modules[_mod_name] = MagicMock()
Loading
Loading