diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 6a7f44a310e4..45451c92e495 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -762,12 +762,7 @@ def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None): token_ids = self.fill_ids[:max_prefix_len] if tree_cache is not None: - ( - self.prefix_indices, - self.last_node, - self.last_host_node, - self.host_hit_length, - ) = tree_cache.match_prefix( + match_result = tree_cache.match_prefix( key=RadixKey(token_ids=token_ids, extra_key=self.extra_key), **( {"req": self, "cow_mamba": True} @@ -775,6 +770,17 @@ def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None): else {} ), ) + ( + self.prefix_indices, + self.last_node, + self.last_host_node, + self.host_hit_length, + ) = ( + match_result.device_indices, + match_result.last_device_node, + match_result.last_host_node, + match_result.host_hit_length, + ) self.cache_protected_len = len(self.prefix_indices) self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices) diff --git a/python/sglang/srt/managers/schedule_policy.py b/python/sglang/srt/managers/schedule_policy.py index d9af0c88337e..7196c8dbb3a5 100644 --- a/python/sglang/srt/managers/schedule_policy.py +++ b/python/sglang/srt/managers/schedule_policy.py @@ -180,10 +180,19 @@ def _compute_prefix_matches( extra_key = r.extra_key # NOTE: the prefix_indices must always be aligned with last_node - r.prefix_indices, r.last_node, r.last_host_node, r.host_hit_length = ( - self.tree_cache.match_prefix( - rid=r.rid, key=RadixKey(token_ids=prefix_ids, extra_key=extra_key) - ) + match_result = self.tree_cache.match_prefix( + rid=r.rid, key=RadixKey(token_ids=prefix_ids, extra_key=extra_key) + ) + ( + r.prefix_indices, + r.last_node, + r.last_host_node, + r.host_hit_length, + ) = ( + match_result.device_indices, + match_result.last_device_node, + match_result.last_host_node, + match_result.host_hit_length, ) # NOTE(sang): This logic is for in-batch prefix caching; @@ -194,12 +203,11 @@ def _compute_prefix_matches( # threshold means we cannot use in-batch prefix caching for short prefixes. # It is kind of common when the engine is long running (e.g., imagine the prefix "the"). if len(r.prefix_indices) <= IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD: - in_batch_matching_prefixes, _, _, _ = ( - self.waiting_queue_radix_tree.match_prefix( - rid=r.rid, - key=RadixKey(token_ids=prefix_ids, extra_key=extra_key), - ) + match_result = self.waiting_queue_radix_tree.match_prefix( + rid=r.rid, + key=RadixKey(token_ids=prefix_ids, extra_key=extra_key), ) + in_batch_matching_prefixes = match_result.device_indices if ( len(in_batch_matching_prefixes) >= IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD diff --git a/python/sglang/srt/mem_cache/base_prefix_cache.py b/python/sglang/srt/mem_cache/base_prefix_cache.py index 9c16b639b9f9..826bfa6c54e5 100644 --- a/python/sglang/srt/mem_cache/base_prefix_cache.py +++ b/python/sglang/srt/mem_cache/base_prefix_cache.py @@ -41,12 +41,16 @@ class MatchResult(NamedTuple): this **must** be the same as `last_device_node`. host_hit_length : Length of the KV cache hit on the host, if applicable. 0 if HiCache is not enabled. + mamba_branching_seqlen: The mamba radix cache branching point, which is the longest + page-aligned position that could've been cache hit if there + exists a mamba state. """ device_indices: torch.Tensor last_device_node: Any last_host_node: Any host_hit_length: int = 0 + mamba_branching_seqlen: Optional[int] = None class BasePrefixCache(ABC, PrefixCacheTrait): diff --git a/python/sglang/srt/mem_cache/mamba_radix_cache.py b/python/sglang/srt/mem_cache/mamba_radix_cache.py index 7a9b4c0d0926..48d085c5844d 100644 --- a/python/sglang/srt/mem_cache/mamba_radix_cache.py +++ b/python/sglang/srt/mem_cache/mamba_radix_cache.py @@ -520,9 +520,13 @@ def cache_unfinished_req(self, req: Req, chunked=False) -> None: self.req_to_token_pool.mamba_pool.free(mamba_value_forked) # The prefix indices could be updated, reuse it - new_indices, new_last_node, _, _ = self.match_prefix( + match_result = self.match_prefix( RadixKey(page_aligned_token_ids, req.extra_key) ) + (new_indices, new_last_node) = ( + match_result.device_indices, + match_result.last_device_node, + ) if not mamba_exist: assert torch.equal(new_last_node.mamba_value, mamba_value_forked) diff --git a/python/sglang/srt/mem_cache/radix_cache.py b/python/sglang/srt/mem_cache/radix_cache.py index e528cf116bac..398754fd57c6 100644 --- a/python/sglang/srt/mem_cache/radix_cache.py +++ b/python/sglang/srt/mem_cache/radix_cache.py @@ -428,7 +428,11 @@ def cache_unfinished_req(self, req: Req, chunked=False): ) # The prefix indices could be updated, reuse it - new_indices, new_last_node, _, _ = self.match_prefix(radix_key) + match_result = self.match_prefix(radix_key) + (new_indices, new_last_node) = ( + match_result.device_indices, + match_result.last_device_node, + ) assert len(new_indices) == len(keys), f"{len(new_indices)=}, {len(keys)=}" self.req_to_token_pool.write( diff --git a/python/sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py b/python/sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py index 820653b4ba8d..792fd0de2bcd 100644 --- a/python/sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +++ b/python/sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py @@ -232,7 +232,8 @@ def cache_finished_req(self, req: Req, is_insert: bool = True) -> None: # type: req.req_pool_idx, :kv_committed_len ] - _, new_last_node, _, _ = self.match_prefix(RadixKey(token_ids, req.extra_key)) + match_result = self.match_prefix(RadixKey(token_ids, req.extra_key)) + new_last_node = match_result.last_device_node assert new_last_node is not None self.inc_lock_ref(new_last_node) diff --git a/python/sglang/srt/mem_cache/swa_radix_cache.py b/python/sglang/srt/mem_cache/swa_radix_cache.py index 2e97aca51c7b..5a253020106a 100644 --- a/python/sglang/srt/mem_cache/swa_radix_cache.py +++ b/python/sglang/srt/mem_cache/swa_radix_cache.py @@ -550,9 +550,14 @@ def cache_unfinished_req(self, req: Req, chunked=False) -> None: ) # The prefix indices could be updated, reuse it - new_indices, new_last_node, _, _ = self.match_prefix( + match_result = self.match_prefix( RadixKey(page_aligned_token_ids, req.extra_key) ) + (new_indices, new_last_node) = ( + match_result.device_indices, + match_result.last_device_node, + ) + assert old_prefix_len <= len( new_indices ), f"{req.prefix_indices=}, {new_indices=}"