diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 068bd184917b..fd9873ac1d27 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -59,7 +59,7 @@ from sglang.srt.environ import envs from sglang.srt.layers.attention.fla.chunk_delta_h import CHUNK_SIZE as FLA_CHUNK_SIZE from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator -from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache +from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchPrefixParams from sglang.srt.mem_cache.common import ( alloc_for_decode, alloc_for_extend, @@ -864,12 +864,11 @@ def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None): if tree_cache is not None: match_result = tree_cache.match_prefix( - key=RadixKey(token_ids=token_ids, extra_key=self.extra_key), - **( - {"req": self, "cow_mamba": True} - if tree_cache.supports_mamba() - else {} - ), + MatchPrefixParams( + key=RadixKey(token_ids=token_ids, extra_key=self.extra_key), + req=self if tree_cache.supports_mamba() else None, + cow_mamba=tree_cache.supports_mamba(), + ) ) ( self.prefix_indices, diff --git a/python/sglang/srt/managers/schedule_policy.py b/python/sglang/srt/managers/schedule_policy.py index 49d71a27a1a9..57ed79f4e323 100644 --- a/python/sglang/srt/managers/schedule_policy.py +++ b/python/sglang/srt/managers/schedule_policy.py @@ -35,7 +35,7 @@ from sglang.srt.dllm.config import DllmConfig from sglang.srt.layers.attention.nsa.utils import is_nsa_prefill_cp_in_seq_split from sglang.srt.managers.schedule_batch import DllmStagingReqs, Req, ScheduleBatch -from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache +from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchPrefixParams from sglang.srt.mem_cache.radix_cache import RadixCache, RadixKey, TreeNode from sglang.srt.mem_cache.swa_memory_pool import SWATokenToKVPoolAllocator from sglang.srt.server_args import ServerArgs @@ -190,7 +190,9 @@ def _compute_prefix_matches( extra_key = r.extra_key # NOTE: the prefix_indices must always be aligned with last_node match_result = self.tree_cache.match_prefix( - rid=r.rid, key=RadixKey(token_ids=prefix_ids, extra_key=extra_key) + MatchPrefixParams( + key=RadixKey(token_ids=prefix_ids, extra_key=extra_key) + ) ) ( r.prefix_indices, @@ -213,8 +215,9 @@ def _compute_prefix_matches( # It is kind of common when the engine is long running (e.g., imagine the prefix "the"). if len(r.prefix_indices) <= IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD: match_result = self.waiting_queue_radix_tree.match_prefix( - rid=r.rid, - key=RadixKey(token_ids=prefix_ids, extra_key=extra_key), + MatchPrefixParams( + key=RadixKey(token_ids=prefix_ids, extra_key=extra_key) + ) ) in_batch_matching_prefixes = match_result.device_indices if ( diff --git a/python/sglang/srt/mem_cache/base_prefix_cache.py b/python/sglang/srt/mem_cache/base_prefix_cache.py index 8f2560d11625..bdf8730a3cbc 100644 --- a/python/sglang/srt/mem_cache/base_prefix_cache.py +++ b/python/sglang/srt/mem_cache/base_prefix_cache.py @@ -1,5 +1,6 @@ from __future__ import annotations +import dataclasses import time from abc import ABC, abstractmethod from typing import ( @@ -20,6 +21,7 @@ if TYPE_CHECKING: from sglang.srt.managers.schedule_batch import Req + from sglang.srt.mem_cache.radix_cache import RadixKey @runtime_checkable @@ -30,6 +32,17 @@ class PrefixCacheTrait(Protocol): disable: bool +@dataclasses.dataclass +class MatchPrefixParams: + """Unified parameters for match_prefix across different cache types""" + + key: RadixKey + + # Mamba specific + cow_mamba: bool = False + req: Optional[Req] = None + + class MatchResult(NamedTuple): """Result of a prefix match operation. @@ -77,7 +90,7 @@ def reset(self): pass @abstractmethod - def match_prefix(self, key: Any, **kwargs) -> MatchResult: + def match_prefix(self, params: MatchPrefixParams) -> MatchResult: pass @abstractmethod diff --git a/python/sglang/srt/mem_cache/chunk_cache.py b/python/sglang/srt/mem_cache/chunk_cache.py index 38dbdd6fe192..d0861e235206 100644 --- a/python/sglang/srt/mem_cache/chunk_cache.py +++ b/python/sglang/srt/mem_cache/chunk_cache.py @@ -7,7 +7,11 @@ import torch -from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult +from sglang.srt.mem_cache.base_prefix_cache import ( + BasePrefixCache, + MatchPrefixParams, + MatchResult, +) from sglang.srt.mem_cache.swa_memory_pool import SWATokenToKVPoolAllocator if TYPE_CHECKING: @@ -43,7 +47,7 @@ def disable(self): def reset(self): pass - def match_prefix(self, **unused_kwargs) -> MatchResult: + def match_prefix(self, params: MatchPrefixParams) -> MatchResult: return MatchResult( device_indices=torch.empty((0,), dtype=torch.int64), last_device_node=None, diff --git a/python/sglang/srt/mem_cache/hiradix_cache.py b/python/sglang/srt/mem_cache/hiradix_cache.py index b3ea8bdff8a5..f74881f1b638 100644 --- a/python/sglang/srt/mem_cache/hiradix_cache.py +++ b/python/sglang/srt/mem_cache/hiradix_cache.py @@ -10,7 +10,7 @@ import torch from sglang.srt.managers.cache_controller import HiCacheController, PrefetchOperation -from sglang.srt.mem_cache.base_prefix_cache import MatchResult +from sglang.srt.mem_cache.base_prefix_cache import MatchPrefixParams, MatchResult from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MLATokenToKVPool from sglang.srt.mem_cache.memory_pool_host import ( MHATokenToKVPoolHost, @@ -688,7 +688,8 @@ def terminate_prefetch(self, req_id: str): return operation.mark_terminate() - def match_prefix(self, key: RadixKey, **kwargs): + def match_prefix(self, params: MatchPrefixParams): + key = params.key empty_value = torch.empty((0,), dtype=torch.int64, device=self.device) key, _ = self.maybe_bigram_convert(key) if self.disable or len(key) == 0: diff --git a/python/sglang/srt/mem_cache/mamba_radix_cache.py b/python/sglang/srt/mem_cache/mamba_radix_cache.py index acd2fc1cb2b7..24dc81fd3c55 100644 --- a/python/sglang/srt/mem_cache/mamba_radix_cache.py +++ b/python/sglang/srt/mem_cache/mamba_radix_cache.py @@ -33,7 +33,11 @@ PagedTokenToKVPoolAllocator, TokenToKVPoolAllocator, ) -from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult +from sglang.srt.mem_cache.base_prefix_cache import ( + BasePrefixCache, + MatchPrefixParams, + MatchResult, +) from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool from sglang.srt.mem_cache.radix_cache import ( RadixKey, @@ -414,10 +418,10 @@ def reset(self) -> None: self.full_lru_list = LRUList(mamba=False) self.mamba_lru_list = LRUList(mamba=True) - def match_prefix(self, key: RadixKey, **kwargs) -> MatchResult: + def match_prefix(self, params: MatchPrefixParams) -> MatchResult: """Find the matching prefix from the radix tree. Args: - key: A RadixKey contains token IDs to find a matching prefix. + params: MatchPrefixParams containing key and optional Mamba-specific parameters. Returns: A tuple of a tensor of matching prefix token IDs and the last node that contains the prefix values. Note that @@ -425,8 +429,9 @@ def match_prefix(self, key: RadixKey, **kwargs) -> MatchResult: The last node create a new child if the prefix is shorter than the last node's value. """ - cow_mamba: bool = kwargs.get("cow_mamba", False) - req: Req = kwargs.get("req", None) + key = params.key + cow_mamba = params.cow_mamba + req = params.req if self.disable or len(key) == 0: return MatchResult( @@ -658,7 +663,7 @@ def _skip_cache_unfinished_req(req: Req) -> None: # The prefix indices could be updated, reuse it match_result = self.match_prefix( - RadixKey(page_aligned_token_ids, req.extra_key) + MatchPrefixParams(key=RadixKey(page_aligned_token_ids, req.extra_key)) ) (new_indices, new_last_node) = ( match_result.device_indices, diff --git a/python/sglang/srt/mem_cache/radix_cache.py b/python/sglang/srt/mem_cache/radix_cache.py index 24dc1a74171a..2bebf4bf4088 100644 --- a/python/sglang/srt/mem_cache/radix_cache.py +++ b/python/sglang/srt/mem_cache/radix_cache.py @@ -39,7 +39,11 @@ BlockRemoved, BlockStored, ) -from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult +from sglang.srt.mem_cache.base_prefix_cache import ( + BasePrefixCache, + MatchPrefixParams, + MatchResult, +) from sglang.srt.mem_cache.evict_policy import ( EvictionStrategy, FIFOStrategy, @@ -337,7 +341,7 @@ def maybe_bigram_convert( return key, value - def match_prefix(self, key: RadixKey, **kwargs) -> MatchResult: + def match_prefix(self, params: MatchPrefixParams) -> MatchResult: """Find the longest cached prefix of ``key`` in the radix tree. The logical namespace for prefix matching is determined by both the @@ -352,12 +356,11 @@ def match_prefix(self, key: RadixKey, **kwargs) -> MatchResult: context) by supplying a distinct ``extra_key``. Args: - key (RadixKey): The lookup key containing a list of token ids and an - optional ``extra_key`` namespace tag. If ``page_size > 1`` the - length is internally truncated to a multiple of ``page_size`` - before matching. Passing an empty key returns an empty result - with the root as the last node. - **kwargs: Reserved for future extensions (ignored currently). + params (MatchPrefixParams): Parameters containing the lookup key + with a list of token ids and an optional ``extra_key`` namespace tag. + If ``page_size > 1`` the length is internally truncated to a multiple + of ``page_size`` before matching. Passing an empty key returns an + empty result with the root as the last node. Returns: MatchResult: ``device_indices`` is a 1-D ``torch.int64`` tensor of @@ -375,6 +378,7 @@ def match_prefix(self, key: RadixKey, **kwargs) -> MatchResult: to expose a precise boundary; this structural refinement improves subsequent match efficiency and does not duplicate data. """ + key = params.key key, _ = self.maybe_bigram_convert(key) def empty_match_result(): @@ -501,7 +505,7 @@ def cache_unfinished_req(self, req: Req, chunked=False): ) # The prefix indices could be updated, reuse it - match_result = self.match_prefix(radix_key) + match_result = self.match_prefix(MatchPrefixParams(key=radix_key)) (new_indices, new_last_node) = ( match_result.device_indices, match_result.last_device_node, @@ -845,4 +849,8 @@ def take_events(self): tree.insert(RadixKey(token_ids=[8, 9, 10, 11, 12], extra_key=None)) tree.pretty_print() - print(tree.match_prefix(RadixKey(token_ids=[1, 2, 3, 13, 14], extra_key=None))) + print( + tree.match_prefix( + MatchPrefixParams(key=RadixKey(token_ids=[1, 2, 3, 13, 14], extra_key=None)) + ) + ) diff --git a/python/sglang/srt/mem_cache/radix_cache_cpp.py b/python/sglang/srt/mem_cache/radix_cache_cpp.py index 8bb08351d853..c1a55c929e80 100644 --- a/python/sglang/srt/mem_cache/radix_cache_cpp.py +++ b/python/sglang/srt/mem_cache/radix_cache_cpp.py @@ -6,7 +6,11 @@ import torch -from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult +from sglang.srt.mem_cache.base_prefix_cache import ( + BasePrefixCache, + MatchPrefixParams, + MatchResult, +) from sglang.srt.mem_cache.cpp_radix_tree.radix_tree import ( IOHandle, RadixTreeCpp, @@ -89,7 +93,8 @@ def reset(self): raise NotImplementedError("Host cache is not supported yet") self.tree.reset() - def match_prefix(self, key: RadixKey, **kwargs) -> MatchResult: + def match_prefix(self, params: MatchPrefixParams) -> MatchResult: + key = params.key device_indices_vec, host_indices_length, node_gpu, node_cpu = ( self.tree.match_prefix(key.token_ids) ) diff --git a/python/sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py b/python/sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py index 3dd196efd6e7..3b4702943149 100644 --- a/python/sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +++ b/python/sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py @@ -6,7 +6,7 @@ import torch -from sglang.srt.mem_cache.base_prefix_cache import MatchResult +from sglang.srt.mem_cache.base_prefix_cache import MatchPrefixParams, MatchResult from sglang.srt.mem_cache.radix_cache import RadixCache, RadixKey, TreeNode try: @@ -119,7 +119,7 @@ def reset(self): # type: ignore[override] with self._node_lock: self._in_flight_nodes.clear() - def match_prefix(self, key: RadixKey, **kwargs) -> MatchResult: # type: ignore[override] + def match_prefix(self, params: MatchPrefixParams) -> MatchResult: # type: ignore[override] """Match cached prefix; if there's a tail miss, prefetch from LMCache. Reuses the base matching logic to obtain (value, last_node). If there @@ -128,14 +128,15 @@ def match_prefix(self, key: RadixKey, **kwargs) -> MatchResult: # type: ignore[ into those slots, then materialize a new child node for the retrieved chunk. """ + key = params.key if self.disable or not key: - return super().match_prefix(key, **kwargs) + return super().match_prefix(params) if self.page_size != 1: aligned_len = len(key) // self.page_size * self.page_size key = key[:aligned_len] - base_res = super().match_prefix(key, **kwargs) + base_res = super().match_prefix(params) value: torch.Tensor = base_res.device_indices last_node: TreeNode = base_res.last_device_node @@ -229,7 +230,9 @@ def cache_finished_req(self, req: Req, is_insert: bool = True) -> None: # type: req.req_pool_idx, :kv_committed_len ] - match_result = self.match_prefix(RadixKey(token_ids, req.extra_key)) + match_result = self.match_prefix( + MatchPrefixParams(key=RadixKey(token_ids, req.extra_key)) + ) new_last_node = match_result.last_device_node assert new_last_node is not None diff --git a/python/sglang/srt/mem_cache/swa_radix_cache.py b/python/sglang/srt/mem_cache/swa_radix_cache.py index 9f2a75fae8dd..e033bc68df2c 100644 --- a/python/sglang/srt/mem_cache/swa_radix_cache.py +++ b/python/sglang/srt/mem_cache/swa_radix_cache.py @@ -28,7 +28,11 @@ import torch from numpy import float64 -from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult +from sglang.srt.mem_cache.base_prefix_cache import ( + BasePrefixCache, + MatchPrefixParams, + MatchResult, +) from sglang.srt.mem_cache.cache_init_params import CacheInitParams from sglang.srt.mem_cache.radix_cache import ( RadixKey, @@ -379,10 +383,10 @@ def reset(self) -> None: self.full_lru_list = LRUList(is_swa_list=False) self.swa_lru_list = LRUList(is_swa_list=True) - def match_prefix(self, key: RadixKey, **kwargs) -> MatchResult: + def match_prefix(self, params: MatchPrefixParams) -> MatchResult: """Find the matching prefix from the radix tree. Args: - key: A RadixKey contains token IDs to find a matching prefix. + params: MatchPrefixParams containing key. Returns: A tuple of a tensor of matching prefix token IDs and the last node that contains the prefix values. Note that @@ -390,6 +394,7 @@ def match_prefix(self, key: RadixKey, **kwargs) -> MatchResult: The last node create a new child if the prefix is shorter than the last node's value. """ + key = params.key key.token_ids = self.key_convert_fn(key.token_ids) if self.disable or len(key) == 0: @@ -546,7 +551,7 @@ def cache_unfinished_req(self, req: Req, chunked=False) -> None: # The prefix indices could be updated, reuse it match_result = self.match_prefix( - RadixKey(page_aligned_token_ids, req.extra_key) + MatchPrefixParams(key=RadixKey(page_aligned_token_ids, req.extra_key)) ) (new_indices, new_last_node) = ( match_result.device_indices, diff --git a/test/registered/radix_cache/test_mamba_unittest.py b/test/registered/radix_cache/test_mamba_unittest.py index 76c9c7918686..0ad099b6ac7a 100644 --- a/test/registered/radix_cache/test_mamba_unittest.py +++ b/test/registered/radix_cache/test_mamba_unittest.py @@ -6,6 +6,7 @@ from sglang.srt.configs.mamba_utils import Mamba2CacheParams, Mamba2StateShape from sglang.srt.managers.schedule_batch import Req from sglang.srt.mem_cache.allocator import TokenToKVPoolAllocator +from sglang.srt.mem_cache.base_prefix_cache import MatchPrefixParams from sglang.srt.mem_cache.cache_init_params import CacheInitParams from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache from sglang.srt.mem_cache.memory_pool import HybridLinearKVPool, HybridReqToTokenPool @@ -289,7 +290,7 @@ def make_dummy_req(): tree.pretty_print() req5_token_ids = [1, 2, 3, 4, 5] - result = tree.match_prefix(RadixKey(req5_token_ids)) + result = tree.match_prefix(MatchPrefixParams(key=RadixKey(req5_token_ids))) kv_indices, last_node = result.device_indices, result.last_device_node print( f"req5: token_ids: {req5_token_ids}, matched kv_indices: {kv_indices}, last_node.key: {last_node.key}" @@ -297,7 +298,7 @@ def make_dummy_req(): assert len(kv_indices) == 0 req6_token_ids = [1, 2, 3, 4, 5, 60, 70] - result = tree.match_prefix(RadixKey(req6_token_ids)) + result = tree.match_prefix(MatchPrefixParams(key=RadixKey(req6_token_ids))) kv_indices, last_node = result.device_indices, result.last_device_node print( f"req6: token_ids: {req6_token_ids}, matched kv_indices: {kv_indices}, last_node.key: {last_node.key}" @@ -306,7 +307,7 @@ def make_dummy_req(): assert len(last_node.key) == 2 req7_token_ids = [1, 2, 3, 4, 5, 6, 7] - result = tree.match_prefix(RadixKey(req7_token_ids)) + result = tree.match_prefix(MatchPrefixParams(key=RadixKey(req7_token_ids))) kv_indices, last_node = result.device_indices, result.last_device_node print( f"req7: token_ids: {req7_token_ids}, matched kv_indices: {kv_indices}, last_node.key: {last_node.key}" @@ -320,7 +321,7 @@ def make_dummy_req(): tree.pretty_print() req8_token_ids = [1, 2, 3, 4, 5, 60, 70] - result = tree.match_prefix(RadixKey(req8_token_ids)) + result = tree.match_prefix(MatchPrefixParams(key=RadixKey(req8_token_ids))) kv_indices, last_node = result.device_indices, result.last_device_node print( f"req8: token_ids: {req8_token_ids}, matched kv_indices: {kv_indices}, last_node.key: {last_node.key}" @@ -331,7 +332,7 @@ def make_dummy_req(): req9_token_ids = [1, 2, 3, 4, 5, 6, 7] req9 = make_dummy_req() result = tree.match_prefix( - RadixKey(req9_token_ids), **({"req": req9, "cow_mamba": True}) + MatchPrefixParams(key=RadixKey(req9_token_ids), req=req9, cow_mamba=True) ) kv_indices, last_node = result.device_indices, result.last_device_node assert req9.mamba_pool_idx is not None diff --git a/test/registered/radix_cache/test_radix_cache_unit.py b/test/registered/radix_cache/test_radix_cache_unit.py index 4a165cd9031b..d6cac0e9f81f 100644 --- a/test/registered/radix_cache/test_radix_cache_unit.py +++ b/test/registered/radix_cache/test_radix_cache_unit.py @@ -31,6 +31,7 @@ import torch from sglang.srt.disaggregation.kv_events import BlockRemoved, BlockStored +from sglang.srt.mem_cache.base_prefix_cache import MatchPrefixParams from sglang.srt.mem_cache.radix_cache import RadixCache, RadixKey, TreeNode # Test constants @@ -294,12 +295,12 @@ def test_insert_and_match_basic(self): self.assertEqual(cache.evictable_size(), 3) # Test match_prefix - result = cache.match_prefix(RadixKey([1, 2, 3])) + result = cache.match_prefix(MatchPrefixParams(key=RadixKey([1, 2, 3]))) self.assertEqual(len(result.device_indices), 3) torch.testing.assert_close(result.device_indices, value) # Test partial match - result = cache.match_prefix(RadixKey([1, 2])) + result = cache.match_prefix(MatchPrefixParams(key=RadixKey([1, 2]))) self.assertEqual(len(result.device_indices), 2) torch.testing.assert_close( result.device_indices, torch.tensor([10, 20], dtype=torch.int64) @@ -402,10 +403,12 @@ def test_extra_key_isolation(self): ) # Keys with different extra_key should not match each other - result1 = cache.match_prefix(RadixKey([1, 2, 3], "key1")) - result2 = cache.match_prefix(RadixKey([1, 2, 3], "key2")) - result3 = cache.match_prefix(RadixKey([1, 2, 3], None)) - result4 = cache.match_prefix(RadixKey([1, 2, 3], "nonexistent")) + result1 = cache.match_prefix(MatchPrefixParams(key=RadixKey([1, 2, 3], "key1"))) + result2 = cache.match_prefix(MatchPrefixParams(key=RadixKey([1, 2, 3], "key2"))) + result3 = cache.match_prefix(MatchPrefixParams(key=RadixKey([1, 2, 3], None))) + result4 = cache.match_prefix( + MatchPrefixParams(key=RadixKey([1, 2, 3], "nonexistent")) + ) # Each should match only its own data self.assertEqual(len(result1.device_indices), 3) @@ -434,7 +437,7 @@ def test_lock_ref_operations(self): cache.insert(RadixKey([1, 2, 3]), torch.tensor([10, 20, 30], dtype=torch.int64)) # Get node - result = cache.match_prefix(RadixKey([1, 2, 3])) + result = cache.match_prefix(MatchPrefixParams(key=RadixKey([1, 2, 3]))) node = result.last_device_node initial_evictable = cache.evictable_size() @@ -485,7 +488,7 @@ def test_page_alignment_boundary(self): tokens = list(range(sequence_length)) cache.insert(RadixKey(tokens), torch.tensor(tokens, dtype=torch.int64)) - result = cache.match_prefix(RadixKey(tokens)) + result = cache.match_prefix(MatchPrefixParams(key=RadixKey(tokens))) self.assertGreater(len(result.device_indices), 0) # Match length should be page-aligned @@ -541,23 +544,25 @@ def test_advanced_prefix_match_with_node_splits(self): # Match that causes a split inside an existing node: # take first 4 tokens of seq1, then diverge. query1 = [1, 2, 3, 4, 999, 1000] - result1 = cache.match_prefix(RadixKey(query1)) + result1 = cache.match_prefix(MatchPrefixParams(key=RadixKey(query1))) torch.testing.assert_close(result1.device_indices, val1[:4]) # No data change after structural split during matching. self.assertEqual(cache.total_size(), baseline_total) # Full match of the long sequence still returns the full indices. - result_full = cache.match_prefix(RadixKey(seq1)) + result_full = cache.match_prefix(MatchPrefixParams(key=RadixKey(seq1))) torch.testing.assert_close(result_full.device_indices, val1) # Another split deeper on the path (after matching 6 tokens, then diverge). query2 = [1, 2, 3, 4, 5, 6, 777, 888] - result2 = cache.match_prefix(RadixKey(query2)) + result2 = cache.match_prefix(MatchPrefixParams(key=RadixKey(query2))) torch.testing.assert_close(result2.device_indices, val1[:6]) self.assertEqual(cache.total_size(), baseline_total) # Matching the short diverging branch should return exactly its indices. - result_branch = cache.match_prefix(RadixKey(seq2)) + result_branch = cache.match_prefix( + MatchPrefixParams(key=RadixKey(seq2)) + ) torch.testing.assert_close(result_branch.device_indices, val2) def test_hash_value_storage(self): diff --git a/test/registered/radix_cache/test_swa_unittest.py b/test/registered/radix_cache/test_swa_unittest.py index 49891dbd596b..63548401b608 100644 --- a/test/registered/radix_cache/test_swa_unittest.py +++ b/test/registered/radix_cache/test_swa_unittest.py @@ -2,6 +2,7 @@ import torch +from sglang.srt.mem_cache.base_prefix_cache import MatchPrefixParams from sglang.srt.mem_cache.cache_init_params import CacheInitParams from sglang.srt.mem_cache.memory_pool import ReqToTokenPool from sglang.srt.mem_cache.radix_cache import RadixKey @@ -188,7 +189,7 @@ def test_swa_radix_cache_1(self): tree.pretty_print() req5_token_ids = [1, 2, 3, 4, 5] - result = tree.match_prefix(RadixKey(req5_token_ids)) + result = tree.match_prefix(MatchPrefixParams(key=RadixKey(req5_token_ids))) kv_indices, last_node = result.device_indices, result.last_device_node print( f"req5: token_ids: {req5_token_ids}, matched kv_indices: {kv_indices}, last_node.key: {last_node.key}" @@ -196,7 +197,7 @@ def test_swa_radix_cache_1(self): self.assertEqual(len(kv_indices), 0) req6_token_ids = [1, 2, 3, 4, 5, 60, 70] - result = tree.match_prefix(RadixKey(req6_token_ids)) + result = tree.match_prefix(MatchPrefixParams(key=RadixKey(req6_token_ids))) kv_indices, last_node = result.device_indices, result.last_device_node print( f"req6: token_ids: {req6_token_ids}, matched kv_indices: {kv_indices}, last_node.key: {last_node.key}" @@ -329,7 +330,7 @@ def test_swa_radix_cache_eagle(self): tree.pretty_print() req5_token_ids = [1, 2, 3, 4, 5] - result = tree.match_prefix(RadixKey(req5_token_ids)) + result = tree.match_prefix(MatchPrefixParams(key=RadixKey(req5_token_ids))) kv_indices, last_node = result.device_indices, result.last_device_node print( f"req5: token_ids: {req5_token_ids}, matched kv_indices: {kv_indices}, last_node.key: {last_node.key}" @@ -337,7 +338,7 @@ def test_swa_radix_cache_eagle(self): self.assertEqual(len(kv_indices), 0) # no swa prefix matched req6_token_ids = [1, 2, 3, 4, 5, 60, 70] - result = tree.match_prefix(RadixKey(req6_token_ids)) + result = tree.match_prefix(MatchPrefixParams(key=RadixKey(req6_token_ids))) kv_indices, last_node = result.device_indices, result.last_device_node print( f"req6: token_ids: {req6_token_ids}, matched kv_indices: {kv_indices}, last_node.key: {last_node.key}"