Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 23 additions & 14 deletions python/sglang/srt/mem_cache/radix_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,26 @@ def __repr__(self) -> str:
return f"RadixKey(extra_key={self.extra_key!r}, token_ids={preview}{'...' if len(self.token_ids) > 10 else ''})"


def maybe_bigram_convert(
is_eagle: bool,
key: RadixKey,
value: Optional[torch.Tensor] = None,
) -> Tuple[RadixKey, Optional[torch.Tensor]]:
if is_eagle and not key.is_bigram:
key.token_ids = convert_to_bigram_key(key.token_ids)
key.is_bigram = True
if value is not None:
value = value[: len(key)]
return key, value


def page_align_keys(key: list, page_size) -> list:
if page_size == 1:
return key
page_aligned_len = len(key) // page_size * page_size
return key[:page_aligned_len]


class TreeNode:

counter = 0
Expand Down Expand Up @@ -342,12 +362,7 @@ def reset(self):
def maybe_bigram_convert(
self, key: RadixKey, value: Optional[torch.Tensor] = None
) -> Tuple[RadixKey, Optional[torch.Tensor]]:
if self.is_eagle and not key.is_bigram:
key.token_ids = convert_to_bigram_key(key.token_ids)
if value is not None:
value = value[: len(key)]

return key, value
return maybe_bigram_convert(self.is_eagle, key, value)

def match_prefix(self, params: MatchPrefixParams) -> MatchResult:
"""Find the longest cached prefix of ``key`` in the radix tree.
Expand Down Expand Up @@ -437,12 +452,6 @@ def insert(self, params: InsertParams) -> InsertResult:
prefix_len = self._insert_helper(self.root_node, key, value, priority)
return InsertResult(prefix_len=prefix_len)

def _page_align_keys(self, key: list) -> list:
if self.page_size == 1:
return key
page_aligned_len = len(key) // self.page_size * self.page_size
return key[:page_aligned_len]

def cache_finished_req(self, req: Req, is_insert: bool = True):
"""Cache request when it finishes."""
# In deterministic mode, disable finished request insertion to radix cache
Expand All @@ -464,7 +473,7 @@ def cache_finished_req(self, req: Req, is_insert: bool = True):

# Maybe convert to bigram keys for EAGLE
keys = convert_to_bigram_key(token_ids) if self.is_eagle else token_ids
keys = self._page_align_keys(keys)
keys = page_align_keys(keys, self.page_size)
values = kv_indices[: len(keys)].to(dtype=torch.int64, copy=True)
radix_key = RadixKey(keys, req.extra_key, is_bigram=self.is_eagle)

Expand Down Expand Up @@ -502,7 +511,7 @@ def cache_unfinished_req(self, req: Req, chunked=False):

# Maybe convert to bigram keys for EAGLE
keys = convert_to_bigram_key(token_ids) if self.is_eagle else token_ids
keys = self._page_align_keys(keys)
keys = page_align_keys(keys, self.page_size)
values = kv_indices[: len(keys)].to(dtype=torch.int64, copy=True)
radix_key = RadixKey(keys, req.extra_key, is_bigram=self.is_eagle)

Expand Down
98 changes: 28 additions & 70 deletions python/sglang/srt/mem_cache/swa_radix_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@
_key_match_page_size1,
_key_match_paged,
get_child_key,
maybe_bigram_convert,
page_align_keys,
)
from sglang.srt.mem_cache.swa_memory_pool import SWATokenToKVPoolAllocator
from sglang.srt.mem_cache.utils import convert_to_bigram_key
Expand Down Expand Up @@ -426,14 +428,10 @@ def insert(self, params: InsertParams) -> InsertResult:
prev_prefix_len = params.prev_prefix_len
swa_evicted_seqlen = params.swa_evicted_seqlen

key.token_ids = self.key_convert_fn(key.token_ids)

if value is None:
value = torch.tensor([x for x in key.token_ids], dtype=torch.int64)

if self.is_eagle:
# Make sure the value len equal to the EAGLE bigram key len
value = value[: len(key)]
key, value = maybe_bigram_convert(self.is_eagle, key, value)

prefix_len = self._insert_helper(
self.root_node, key, value, prev_prefix_len, swa_evicted_seqlen
Expand All @@ -451,40 +449,29 @@ def cache_finished_req(self, req: Req, is_insert: bool = True) -> None:
return

token_ids = (req.origin_input_ids + req.output_ids)[:kv_committed_len]
# For EAGLE radix cache, we will convert the key to bigram key, e.g. [1,2,3,4] -> [(1,2), (2,3), (3,4)], the length will -1. ((len([(1,2), (2,3), (3,4)]) = len([1,2,3,4]) - 1))
# So for the corresponding kv length should also -1. Then we get the actual_kv_len, and use it to do later calculation and slicing.
actual_kv_len = kv_committed_len - 1 if self.is_eagle else kv_committed_len
kv_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, :kv_committed_len
]

if self.page_size != 1:
page_aligned_len = actual_kv_len // self.page_size * self.page_size
page_aligned_kv_indices = kv_indices[:page_aligned_len].to(
dtype=torch.int64, copy=True
)
else:
page_aligned_len = actual_kv_len
page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)

page_aligned_token_len = (
page_aligned_len + 1 if self.is_eagle else page_aligned_len
# Maybe convert to bigram keys for EAGLE
keys = self.key_convert_fn(token_ids)
keys = page_align_keys(keys, self.page_size)
page_aligned_len = len(keys)
values = kv_indices[:page_aligned_len].to(dtype=torch.int64, copy=True)
radix_key = RadixKey(
keys[:page_aligned_len],
req.extra_key,
is_bigram=self.is_eagle,
)

old_prefix_len = len(req.prefix_indices)
if self.is_eagle and old_prefix_len > req.cache_protected_len:
# In EAGLE chunked prefill case, the prefix_indices included one unmatched token (kv_indices[actual_kv_len:])
# Here we -1 to make sure the kv of the unmatched token can be freed correctly to avoid memory leak
old_prefix_len -= 1
old_prefix_len = req.cache_protected_len

# Radix Cache takes one ref in memory pool
# insert the token_ids and kv_indices into the radix tree
# Note: the insert function already frees the overlapped kv_indices
if is_insert:
self.insert(
InsertParams(
key=RadixKey(token_ids[:page_aligned_token_len], req.extra_key),
value=page_aligned_kv_indices,
key=radix_key,
value=values,
prev_prefix_len=old_prefix_len,
swa_evicted_seqlen=req.swa_evicted_seqlen,
)
Expand Down Expand Up @@ -512,58 +499,35 @@ def cache_unfinished_req(self, req: Req, chunked=False) -> None:
return

token_ids = req.fill_ids
all_token_len = len(token_ids)
# For EAGLE radix cache, we will convert the key to bigram key, e.g. [1,2,3,4] -> [(1,2), (2,3), (3,4)], the length will -1. ((len([(1,2), (2,3), (3,4)]) = len([1,2,3,4]) - 1))
# So for the corresponding kv length should also -1. Then we get the actual_kv_len, and use it to do later calculation and slicing.
actual_kv_len = all_token_len - 1 if self.is_eagle else all_token_len
kv_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, :all_token_len
req.req_pool_idx, : len(token_ids)
]

if self.page_size != 1:
page_aligned_len = actual_kv_len // self.page_size * self.page_size
page_aligned_kv_indices = kv_indices[:page_aligned_len].to(
dtype=torch.int64, copy=True
)
else:
page_aligned_len = actual_kv_len
page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)

# For EAGLE, the page_aligned_len is for the bigram key, the normal key len should +1
page_aligned_token_len = (
page_aligned_len + 1 if self.is_eagle else page_aligned_len
)
page_aligned_token_ids = token_ids[:page_aligned_token_len]

old_prefix_len = len(req.prefix_indices)
if self.is_eagle and old_prefix_len > req.cache_protected_len:
# In EAGLE chunked prefill case, the prefix_indices included one unmatched token (kv_indices[actual_kv_len:])
# Here we -1 to make sure the kv of the unmatched token can be freed correctly to avoid memory leak
old_prefix_len -= 1
keys = self.key_convert_fn(token_ids)
keys = page_align_keys(keys, self.page_size)
values = kv_indices[: len(keys)].to(dtype=torch.int64, copy=True)
radix_key = RadixKey(keys, req.extra_key, is_bigram=self.is_eagle)
old_prefix_len = req.cache_protected_len

# Radix Cache takes one ref in memory pool
# Note: the insert function already frees the overlapped kv_indices
result = self.insert(
InsertParams(
key=RadixKey(page_aligned_token_ids, req.extra_key),
value=page_aligned_kv_indices,
key=radix_key,
value=values,
prev_prefix_len=old_prefix_len,
)
)
new_prefix_len = result.prefix_len

# The prefix indices could be updated, reuse it
match_result = self.match_prefix(
MatchPrefixParams(key=RadixKey(page_aligned_token_ids, req.extra_key))
)
match_result = self.match_prefix(MatchPrefixParams(key=radix_key))
new_indices, new_last_node = (
match_result.device_indices,
match_result.last_device_node,
)

assert old_prefix_len <= len(
new_indices
), f"{req.prefix_indices=}, {new_indices=}"
assert old_prefix_len <= len(new_indices), f"{old_prefix_len=}, {new_indices=}"
assert new_prefix_len <= len(new_indices), f"{new_prefix_len=}, {new_indices=}"
self.req_to_token_pool.write(
(req.req_pool_idx, slice(old_prefix_len, len(new_indices))),
Expand All @@ -576,18 +540,12 @@ def cache_unfinished_req(self, req: Req, chunked=False) -> None:
swa_uuid_for_lock = self.inc_lock_ref(new_last_node)

# `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later
if self.page_size != 1:
if len(new_indices) < len(kv_indices):
req.prefix_indices = torch.cat(
[new_indices, kv_indices[len(new_indices) :]]
)
else:
if self.is_eagle:
# Attach the kv index of the last token for EAGLE, it can be used in chunked prefill
req.prefix_indices = torch.cat(
[new_indices, kv_indices[actual_kv_len:]]
)
else:
req.prefix_indices = new_indices
req.prefix_indices = new_indices
req.last_node = new_last_node
req.swa_uuid_for_lock = swa_uuid_for_lock

Expand Down Expand Up @@ -874,7 +832,7 @@ def _match_prefix_helper(
def _match_pre_processor(self, params: MatchPrefixParams) -> Optional[RadixKey]:
"""Preprocess the key before matching."""
key = params.key
key.token_ids = self.key_convert_fn(key.token_ids)
key, _ = maybe_bigram_convert(self.is_eagle, key)

if self.disable or len(key) == 0:
return None
Expand Down
Loading
Loading