Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 14 additions & 8 deletions python/sglang/srt/mem_cache/radix_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ def match_prefix(self, key: RadixKey, **kwargs) -> MatchResult:
"""
key.token_ids = self.key_convert_fn(key.token_ids)

if self.disable or len(key) == 0:
def empty_match_result():
return MatchResult(
device_indices=torch.empty(
(0,),
Expand All @@ -278,10 +278,16 @@ def match_prefix(self, key: RadixKey, **kwargs) -> MatchResult:
last_host_node=self.root_node,
)

if self.disable or len(key) == 0:
return empty_match_result()

if self.page_size != 1:
page_aligned_len = len(key) // self.page_size * self.page_size
key = key[:page_aligned_len]

if len(key) == 0:
return empty_match_result()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's good that we add a guardrail here, but is this actually causing a bug? I believe there is protection within `_match_prefix_helper` as well.

Copy link
Contributor Author

@skyzh skyzh Oct 1, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Kind of. This comes from my experiment to enable a deterministic radix cache by simply setting `page_size_` in the radix cache to the split size of prefill, which yields several issues. So this might just work fine with the current code.

My other argument is that we are going to return an empty result anyway, so why not skip a bunch of code here and take a short-circuit path that returns empty directly?


value, last_node = self._match_prefix_helper(self.root_node, key)
if value:
value = torch.cat(value)
Expand Down Expand Up @@ -475,9 +481,9 @@ def inc_lock_ref(self, node: TreeNode):
delta = 0
while node != self.root_node:
if node.lock_ref == 0:
self.evictable_size_ -= len(node.value)
self.protected_size_ += len(node.value)
delta -= len(node.value)
self.evictable_size_ -= len(node.key)
self.protected_size_ += len(node.key)
delta -= len(node.key)
node.lock_ref += 1
node = node.parent
return delta
Expand All @@ -489,9 +495,9 @@ def dec_lock_ref(self, node: TreeNode):
delta = 0
while node != self.root_node:
if node.lock_ref == 1:
self.evictable_size_ += len(node.value)
self.protected_size_ -= len(node.value)
delta += len(node.value)
self.evictable_size_ += len(node.key)
self.protected_size_ -= len(node.key)
delta += len(node.key)
node.lock_ref -= 1
node = node.parent
return delta
Expand Down Expand Up @@ -589,7 +595,7 @@ def _insert_helper(self, node: TreeNode, key: RadixKey, value):
new_node.key = key
new_node.value = value
node.children[child_key] = new_node
self.evictable_size_ += len(value)
self.evictable_size_ += len(key)
self._record_store_event(new_node)
return total_prefix_length

Expand Down
Loading