Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions python/sglang/srt/managers/schedule_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@
from sglang.srt.environ import envs
from sglang.srt.layers.attention.fla.chunk_delta_h import CHUNK_SIZE as FLA_CHUNK_SIZE
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchPrefixParams
from sglang.srt.mem_cache.common import (
alloc_for_decode,
alloc_for_extend,
Expand Down Expand Up @@ -864,12 +864,11 @@ def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None):

if tree_cache is not None:
match_result = tree_cache.match_prefix(
key=RadixKey(token_ids=token_ids, extra_key=self.extra_key),
**(
{"req": self, "cow_mamba": True}
if tree_cache.supports_mamba()
else {}
),
MatchPrefixParams(
key=RadixKey(token_ids=token_ids, extra_key=self.extra_key),
req=self if tree_cache.supports_mamba() else None,
cow_mamba=tree_cache.supports_mamba(),
)
)
(
self.prefix_indices,
Expand Down
11 changes: 7 additions & 4 deletions python/sglang/srt/managers/schedule_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from sglang.srt.dllm.config import DllmConfig
from sglang.srt.layers.attention.nsa.utils import is_nsa_prefill_cp_in_seq_split
from sglang.srt.managers.schedule_batch import DllmStagingReqs, Req, ScheduleBatch
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchPrefixParams
from sglang.srt.mem_cache.radix_cache import RadixCache, RadixKey, TreeNode
from sglang.srt.mem_cache.swa_memory_pool import SWATokenToKVPoolAllocator
from sglang.srt.server_args import ServerArgs
Expand Down Expand Up @@ -190,7 +190,9 @@ def _compute_prefix_matches(
extra_key = r.extra_key
# NOTE: the prefix_indices must always be aligned with last_node
match_result = self.tree_cache.match_prefix(
rid=r.rid, key=RadixKey(token_ids=prefix_ids, extra_key=extra_key)
MatchPrefixParams(
key=RadixKey(token_ids=prefix_ids, extra_key=extra_key)
)
)
(
r.prefix_indices,
Expand All @@ -213,8 +215,9 @@ def _compute_prefix_matches(
# It is kind of common when the engine is long running (e.g., imagine the prefix "the").
if len(r.prefix_indices) <= IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD:
match_result = self.waiting_queue_radix_tree.match_prefix(
rid=r.rid,
key=RadixKey(token_ids=prefix_ids, extra_key=extra_key),
MatchPrefixParams(
key=RadixKey(token_ids=prefix_ids, extra_key=extra_key)
)
)
in_batch_matching_prefixes = match_result.device_indices
if (
Expand Down
15 changes: 14 additions & 1 deletion python/sglang/srt/mem_cache/base_prefix_cache.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import dataclasses
import time
from abc import ABC, abstractmethod
from typing import (
Expand All @@ -20,6 +21,7 @@

if TYPE_CHECKING:
from sglang.srt.managers.schedule_batch import Req
from sglang.srt.mem_cache.radix_cache import RadixKey


@runtime_checkable
Expand All @@ -30,6 +32,17 @@ class PrefixCacheTrait(Protocol):
disable: bool


@dataclasses.dataclass
class MatchPrefixParams:
    """Single argument bundle passed to every cache's ``match_prefix``.

    Each prefix-cache implementation receives one instance of this class
    and reads only the fields it needs, so all implementations share one
    signature instead of per-cache keyword arguments.
    """

    # The lookup key whose cached prefix is being searched for.
    key: RadixKey

    # --- Mamba-specific options (left at their defaults otherwise) ---
    # NOTE(review): "cow" presumably means copy-on-write of the mamba
    # state -- confirm against MambaRadixCache.match_prefix.
    cow_mamba: bool = False
    # The request on whose behalf the match runs; None when not supplied.
    req: Optional[Req] = None


class MatchResult(NamedTuple):
"""Result of a prefix match operation.

Expand Down Expand Up @@ -77,7 +90,7 @@ def reset(self):
pass

@abstractmethod
def match_prefix(self, key: Any, **kwargs) -> MatchResult:
def match_prefix(self, params: MatchPrefixParams) -> MatchResult:
pass

@abstractmethod
Expand Down
8 changes: 6 additions & 2 deletions python/sglang/srt/mem_cache/chunk_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@

import torch

from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
from sglang.srt.mem_cache.base_prefix_cache import (
BasePrefixCache,
MatchPrefixParams,
MatchResult,
)
from sglang.srt.mem_cache.swa_memory_pool import SWATokenToKVPoolAllocator

if TYPE_CHECKING:
Expand Down Expand Up @@ -43,7 +47,7 @@ def disable(self):
def reset(self):
pass

def match_prefix(self, **unused_kwargs) -> MatchResult:
def match_prefix(self, params: MatchPrefixParams) -> MatchResult:
return MatchResult(
device_indices=torch.empty((0,), dtype=torch.int64),
last_device_node=None,
Expand Down
5 changes: 3 additions & 2 deletions python/sglang/srt/mem_cache/hiradix_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import torch

from sglang.srt.managers.cache_controller import HiCacheController, PrefetchOperation
from sglang.srt.mem_cache.base_prefix_cache import MatchResult
from sglang.srt.mem_cache.base_prefix_cache import MatchPrefixParams, MatchResult
from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MLATokenToKVPool
from sglang.srt.mem_cache.memory_pool_host import (
MHATokenToKVPoolHost,
Expand Down Expand Up @@ -688,7 +688,8 @@ def terminate_prefetch(self, req_id: str):
return
operation.mark_terminate()

def match_prefix(self, key: RadixKey, **kwargs):
def match_prefix(self, params: MatchPrefixParams):
key = params.key
empty_value = torch.empty((0,), dtype=torch.int64, device=self.device)
key, _ = self.maybe_bigram_convert(key)
if self.disable or len(key) == 0:
Expand Down
17 changes: 11 additions & 6 deletions python/sglang/srt/mem_cache/mamba_radix_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,11 @@
PagedTokenToKVPoolAllocator,
TokenToKVPoolAllocator,
)
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
from sglang.srt.mem_cache.base_prefix_cache import (
BasePrefixCache,
MatchPrefixParams,
MatchResult,
)
from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool
from sglang.srt.mem_cache.radix_cache import (
RadixKey,
Expand Down Expand Up @@ -414,19 +418,20 @@ def reset(self) -> None:
self.full_lru_list = LRUList(mamba=False)
self.mamba_lru_list = LRUList(mamba=True)

def match_prefix(self, key: RadixKey, **kwargs) -> MatchResult:
def match_prefix(self, params: MatchPrefixParams) -> MatchResult:
"""Find the matching prefix from the radix tree.
Args:
key: A RadixKey contains token IDs to find a matching prefix.
params: MatchPrefixParams containing key and optional Mamba-specific parameters.
Returns:
A tuple of a tensor of matching prefix token IDs and
the last node that contains the prefix values. Note that
this API can modify the internal state of the Radix tree.
The last node create a new child if the prefix is shorter
than the last node's value.
"""
cow_mamba: bool = kwargs.get("cow_mamba", False)
req: Req = kwargs.get("req", None)
key = params.key
cow_mamba = params.cow_mamba
req = params.req

if self.disable or len(key) == 0:
return MatchResult(
Expand Down Expand Up @@ -658,7 +663,7 @@ def _skip_cache_unfinished_req(req: Req) -> None:

# The prefix indices could be updated, reuse it
match_result = self.match_prefix(
RadixKey(page_aligned_token_ids, req.extra_key)
MatchPrefixParams(key=RadixKey(page_aligned_token_ids, req.extra_key))
)
(new_indices, new_last_node) = (
match_result.device_indices,
Expand Down
28 changes: 18 additions & 10 deletions python/sglang/srt/mem_cache/radix_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,11 @@
BlockRemoved,
BlockStored,
)
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
from sglang.srt.mem_cache.base_prefix_cache import (
BasePrefixCache,
MatchPrefixParams,
MatchResult,
)
from sglang.srt.mem_cache.evict_policy import (
EvictionStrategy,
FIFOStrategy,
Expand Down Expand Up @@ -337,7 +341,7 @@ def maybe_bigram_convert(

return key, value

def match_prefix(self, key: RadixKey, **kwargs) -> MatchResult:
def match_prefix(self, params: MatchPrefixParams) -> MatchResult:
"""Find the longest cached prefix of ``key`` in the radix tree.

The logical namespace for prefix matching is determined by both the
Expand All @@ -352,12 +356,11 @@ def match_prefix(self, key: RadixKey, **kwargs) -> MatchResult:
context) by supplying a distinct ``extra_key``.

Args:
key (RadixKey): The lookup key containing a list of token ids and an
optional ``extra_key`` namespace tag. If ``page_size > 1`` the
length is internally truncated to a multiple of ``page_size``
before matching. Passing an empty key returns an empty result
with the root as the last node.
**kwargs: Reserved for future extensions (ignored currently).
params (MatchPrefixParams): Parameters containing the lookup key
with a list of token ids and an optional ``extra_key`` namespace tag.
If ``page_size > 1`` the length is internally truncated to a multiple
of ``page_size`` before matching. Passing an empty key returns an
empty result with the root as the last node.

Returns:
MatchResult: ``device_indices`` is a 1-D ``torch.int64`` tensor of
Expand All @@ -375,6 +378,7 @@ def match_prefix(self, key: RadixKey, **kwargs) -> MatchResult:
to expose a precise boundary; this structural refinement improves
subsequent match efficiency and does not duplicate data.
"""
key = params.key
key, _ = self.maybe_bigram_convert(key)

def empty_match_result():
Expand Down Expand Up @@ -501,7 +505,7 @@ def cache_unfinished_req(self, req: Req, chunked=False):
)

# The prefix indices could be updated, reuse it
match_result = self.match_prefix(radix_key)
match_result = self.match_prefix(MatchPrefixParams(key=radix_key))
(new_indices, new_last_node) = (
match_result.device_indices,
match_result.last_device_node,
Expand Down Expand Up @@ -845,4 +849,8 @@ def take_events(self):
tree.insert(RadixKey(token_ids=[8, 9, 10, 11, 12], extra_key=None))
tree.pretty_print()

print(tree.match_prefix(RadixKey(token_ids=[1, 2, 3, 13, 14], extra_key=None)))
print(
tree.match_prefix(
MatchPrefixParams(key=RadixKey(token_ids=[1, 2, 3, 13, 14], extra_key=None))
)
)
9 changes: 7 additions & 2 deletions python/sglang/srt/mem_cache/radix_cache_cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@

import torch

from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
from sglang.srt.mem_cache.base_prefix_cache import (
BasePrefixCache,
MatchPrefixParams,
MatchResult,
)
from sglang.srt.mem_cache.cpp_radix_tree.radix_tree import (
IOHandle,
RadixTreeCpp,
Expand Down Expand Up @@ -89,7 +93,8 @@ def reset(self):
raise NotImplementedError("Host cache is not supported yet")
self.tree.reset()

def match_prefix(self, key: RadixKey, **kwargs) -> MatchResult:
def match_prefix(self, params: MatchPrefixParams) -> MatchResult:
key = params.key
device_indices_vec, host_indices_length, node_gpu, node_cpu = (
self.tree.match_prefix(key.token_ids)
)
Expand Down
13 changes: 8 additions & 5 deletions python/sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import torch

from sglang.srt.mem_cache.base_prefix_cache import MatchResult
from sglang.srt.mem_cache.base_prefix_cache import MatchPrefixParams, MatchResult
from sglang.srt.mem_cache.radix_cache import RadixCache, RadixKey, TreeNode

try:
Expand Down Expand Up @@ -119,7 +119,7 @@ def reset(self): # type: ignore[override]
with self._node_lock:
self._in_flight_nodes.clear()

def match_prefix(self, key: RadixKey, **kwargs) -> MatchResult: # type: ignore[override]
def match_prefix(self, params: MatchPrefixParams) -> MatchResult: # type: ignore[override]
"""Match cached prefix; if there's a tail miss, prefetch from LMCache.

Reuses the base matching logic to obtain (value, last_node). If there
Expand All @@ -128,14 +128,15 @@ def match_prefix(self, key: RadixKey, **kwargs) -> MatchResult: # type: ignore[
into those slots, then materialize a new child node for the retrieved
chunk.
"""
key = params.key
if self.disable or not key:
return super().match_prefix(key, **kwargs)
return super().match_prefix(params)

if self.page_size != 1:
aligned_len = len(key) // self.page_size * self.page_size
key = key[:aligned_len]

base_res = super().match_prefix(key, **kwargs)
base_res = super().match_prefix(params)
value: torch.Tensor = base_res.device_indices
last_node: TreeNode = base_res.last_device_node

Expand Down Expand Up @@ -229,7 +230,9 @@ def cache_finished_req(self, req: Req, is_insert: bool = True) -> None: # type:
req.req_pool_idx, :kv_committed_len
]

match_result = self.match_prefix(RadixKey(token_ids, req.extra_key))
match_result = self.match_prefix(
MatchPrefixParams(key=RadixKey(token_ids, req.extra_key))
)
new_last_node = match_result.last_device_node
assert new_last_node is not None

Expand Down
13 changes: 9 additions & 4 deletions python/sglang/srt/mem_cache/swa_radix_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,11 @@
import torch
from numpy import float64

from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
from sglang.srt.mem_cache.base_prefix_cache import (
BasePrefixCache,
MatchPrefixParams,
MatchResult,
)
from sglang.srt.mem_cache.cache_init_params import CacheInitParams
from sglang.srt.mem_cache.radix_cache import (
RadixKey,
Expand Down Expand Up @@ -379,17 +383,18 @@ def reset(self) -> None:
self.full_lru_list = LRUList(is_swa_list=False)
self.swa_lru_list = LRUList(is_swa_list=True)

def match_prefix(self, key: RadixKey, **kwargs) -> MatchResult:
def match_prefix(self, params: MatchPrefixParams) -> MatchResult:
"""Find the matching prefix from the radix tree.
Args:
key: A RadixKey contains token IDs to find a matching prefix.
params: MatchPrefixParams containing key.
Returns:
A tuple of a tensor of matching prefix token IDs and
the last node that contains the prefix values. Note that
this API can modify the internal state of the Radix tree.
The last node create a new child if the prefix is shorter
than the last node's value.
"""
key = params.key
key.token_ids = self.key_convert_fn(key.token_ids)

if self.disable or len(key) == 0:
Expand Down Expand Up @@ -546,7 +551,7 @@ def cache_unfinished_req(self, req: Req, chunked=False) -> None:

# The prefix indices could be updated, reuse it
match_result = self.match_prefix(
RadixKey(page_aligned_token_ids, req.extra_key)
MatchPrefixParams(key=RadixKey(page_aligned_token_ids, req.extra_key))
)
(new_indices, new_last_node) = (
match_result.device_indices,
Expand Down
Loading
Loading