Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions python/sglang/srt/managers/schedule_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -762,19 +762,25 @@ def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None):
token_ids = self.fill_ids[:max_prefix_len]

if tree_cache is not None:
(
self.prefix_indices,
self.last_node,
self.last_host_node,
self.host_hit_length,
) = tree_cache.match_prefix(
match_result = tree_cache.match_prefix(
key=RadixKey(token_ids=token_ids, extra_key=self.extra_key),
**(
{"req": self, "cow_mamba": True}
if isinstance(tree_cache, MambaRadixCache)
else {}
),
)
(
self.prefix_indices,
self.last_node,
self.last_host_node,
self.host_hit_length,
) = (
match_result.device_indices,
match_result.last_device_node,
match_result.last_host_node,
match_result.host_hit_length,
)
Comment on lines +773 to +783
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The tuple unpacking here is a bit verbose. You can assign the attributes of `match_result` directly to the corresponding `self` attributes, which is more readable and concise.

            self.prefix_indices = match_result.device_indices
            self.last_node = match_result.last_device_node
            self.last_host_node = match_result.last_host_node
            self.host_hit_length = match_result.host_hit_length

self.cache_protected_len = len(self.prefix_indices)
self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)

Expand Down
26 changes: 17 additions & 9 deletions python/sglang/srt/managers/schedule_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,10 +180,19 @@ def _compute_prefix_matches(
extra_key = r.extra_key

# NOTE: the prefix_indices must always be aligned with last_node
r.prefix_indices, r.last_node, r.last_host_node, r.host_hit_length = (
self.tree_cache.match_prefix(
rid=r.rid, key=RadixKey(token_ids=prefix_ids, extra_key=extra_key)
)
match_result = self.tree_cache.match_prefix(
rid=r.rid, key=RadixKey(token_ids=prefix_ids, extra_key=extra_key)
)
(
r.prefix_indices,
r.last_node,
r.last_host_node,
r.host_hit_length,
) = (
match_result.device_indices,
match_result.last_device_node,
match_result.last_host_node,
match_result.host_hit_length,
)
Comment on lines +186 to 196
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The tuple unpacking here is a bit verbose. You can assign the attributes of `match_result` directly to the corresponding attributes of `r`, which is more readable and concise.

            r.prefix_indices = match_result.device_indices
            r.last_node = match_result.last_device_node
            r.last_host_node = match_result.last_host_node
            r.host_hit_length = match_result.host_hit_length


# NOTE(sang): This logic is for in-batch prefix caching;
Expand All @@ -194,12 +203,11 @@ def _compute_prefix_matches(
# threshold means we cannot use in-batch prefix caching for short prefixes.
# It is kind of common when the engine is long running (e.g., imagine the prefix "the").
if len(r.prefix_indices) <= IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD:
in_batch_matching_prefixes, _, _, _ = (
self.waiting_queue_radix_tree.match_prefix(
rid=r.rid,
key=RadixKey(token_ids=prefix_ids, extra_key=extra_key),
)
match_result = self.waiting_queue_radix_tree.match_prefix(
rid=r.rid,
key=RadixKey(token_ids=prefix_ids, extra_key=extra_key),
)
in_batch_matching_prefixes = match_result.device_indices
if (
len(in_batch_matching_prefixes)
>= IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD
Expand Down
4 changes: 4 additions & 0 deletions python/sglang/srt/mem_cache/base_prefix_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,16 @@ class MatchResult(NamedTuple):
this **must** be the same as `last_device_node`.
host_hit_length : Length of the KV cache hit on the host, if applicable.
0 if HiCache is not enabled.
mamba_branching_seqlen: The mamba radix cache branching point, i.e. the longest
page-aligned position that could have been a cache hit if a
mamba state existed.
"""

device_indices: torch.Tensor
last_device_node: Any
last_host_node: Any
host_hit_length: int = 0
mamba_branching_seqlen: Optional[int] = None


class BasePrefixCache(ABC, PrefixCacheTrait):
Expand Down
6 changes: 5 additions & 1 deletion python/sglang/srt/mem_cache/mamba_radix_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,9 +520,13 @@ def cache_unfinished_req(self, req: Req, chunked=False) -> None:
self.req_to_token_pool.mamba_pool.free(mamba_value_forked)

# The prefix indices could be updated, reuse it
new_indices, new_last_node, _, _ = self.match_prefix(
match_result = self.match_prefix(
RadixKey(page_aligned_token_ids, req.extra_key)
)
(new_indices, new_last_node) = (
match_result.device_indices,
match_result.last_device_node,
)
Comment on lines +526 to +529
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The tuple unpacking here is a bit verbose. You can assign the attributes of `match_result` directly to `new_indices` and `new_last_node`, which is more readable and concise.

        new_indices = match_result.device_indices
        new_last_node = match_result.last_device_node


if not mamba_exist:
assert torch.equal(new_last_node.mamba_value, mamba_value_forked)
Expand Down
6 changes: 5 additions & 1 deletion python/sglang/srt/mem_cache/radix_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,11 @@ def cache_unfinished_req(self, req: Req, chunked=False):
)

# The prefix indices could be updated, reuse it
new_indices, new_last_node, _, _ = self.match_prefix(radix_key)
match_result = self.match_prefix(radix_key)
(new_indices, new_last_node) = (
match_result.device_indices,
match_result.last_device_node,
)
Comment on lines +432 to +435
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The tuple unpacking here is a bit verbose. You can assign the attributes of `match_result` directly to `new_indices` and `new_last_node`, which is more readable and concise.

        new_indices = match_result.device_indices
        new_last_node = match_result.last_device_node

assert len(new_indices) == len(keys), f"{len(new_indices)=}, {len(keys)=}"

self.req_to_token_pool.write(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,8 @@ def cache_finished_req(self, req: Req, is_insert: bool = True) -> None: # type:
req.req_pool_idx, :kv_committed_len
]

_, new_last_node, _, _ = self.match_prefix(RadixKey(token_ids, req.extra_key))
match_result = self.match_prefix(RadixKey(token_ids, req.extra_key))
new_last_node = match_result.last_device_node
assert new_last_node is not None

self.inc_lock_ref(new_last_node)
Expand Down
7 changes: 6 additions & 1 deletion python/sglang/srt/mem_cache/swa_radix_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,9 +550,14 @@ def cache_unfinished_req(self, req: Req, chunked=False) -> None:
)

# The prefix indices could be updated, reuse it
new_indices, new_last_node, _, _ = self.match_prefix(
match_result = self.match_prefix(
RadixKey(page_aligned_token_ids, req.extra_key)
)
(new_indices, new_last_node) = (
match_result.device_indices,
match_result.last_device_node,
)
Comment on lines +556 to +559
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The tuple unpacking here is a bit verbose. You can assign the attributes of `match_result` directly to `new_indices` and `new_last_node`, which is more readable and concise.

        new_indices = match_result.device_indices
        new_last_node = match_result.last_device_node


assert old_prefix_len <= len(
new_indices
), f"{req.prefix_indices=}, {new_indices=}"
Expand Down
Loading