Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 59 additions & 2 deletions python/sglang/srt/managers/schedule_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,8 +549,12 @@ def __init__(
# for corss-endoder model
self.token_type_ids = token_type_ids

# The length of KV that have been removed in swa chunk cache
self.evicted_seqlen_local = 0
# The length of KV that have been removed in swa cache.
# SWA KV cache eviction behavior differs by cache type:
# - Radix cache: KV in range [cache_protected_len, swa_evicted_seqlen) is freed manually in
# `ScheduleBatch.maybe_evict_swa`; KV in range [0, cache_protected_len) is freed during radix cache eviction.
# - Chunk cache: KV in range [0, swa_evicted_seqlen) is freed manually in `ScheduleBatch.maybe_evict_swa`.
self.swa_evicted_seqlen = 0

# The index of the extend / decode batch
self.extend_batch_idx = 0
Expand Down Expand Up @@ -2264,6 +2268,59 @@ def copy(self):
dp_cooperation_info=self.dp_cooperation_info,
)

def maybe_evict_swa(self):
if self.tree_cache.supports_swa():
sliding_window_size = self.tree_cache.sliding_window_size
for idx, req in enumerate(self.reqs):
if self.forward_mode.is_decode():
# We set evict_swa condition here with two reasons:
# 1. In overlap scheduler, we cannot evict swa when req.decode_batch_idx == 0 since the prev extend batch is still running.
# 2. Evict swa every window_size tokens to reduce the overhead.
if req.decode_batch_idx % sliding_window_size == 1:
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

neat

self._evict_swa(req, req.seqlen - 1)
elif self.forward_mode.is_extend() and self.tree_cache.is_chunk_cache():
pre_len = self.prefix_lens[idx]
if self.enable_overlap:
# In chunked prefill case, when the second extend batch is scheduling, the first extend batch is still running, so we cannot evict swa tokens
if req.extend_batch_idx < 2:
continue
else:
server_args = get_global_server_args()
pre_len = (
pre_len - server_args.chunked_prefill_size
if server_args.chunked_prefill_size > 0
else pre_len
)
self._evict_swa(req, pre_len)
else:
self._evict_swa(req, pre_len)

def _evict_swa(self, req: Req, pre_len: int):
assert self.tree_cache.supports_swa(), "prefix cache must support swa"
sliding_window_size = self.tree_cache.sliding_window_size

# For swa radix cache, we need to evict the tokens that are not in the tree cache and also not in the sliding window
assert (
req.cache_protected_len % self.tree_cache.page_size == 0
), "cache_protected_len must be page aligned"
req.swa_evicted_seqlen = max(req.swa_evicted_seqlen, req.cache_protected_len)

new_swa_evicted_seqlen = max(
req.swa_evicted_seqlen, pre_len - sliding_window_size
)

if self.tree_cache.page_size > 1:
new_swa_evicted_seqlen = (
new_swa_evicted_seqlen // self.tree_cache.page_size
) * self.tree_cache.page_size

if new_swa_evicted_seqlen > req.swa_evicted_seqlen:
free_slots = self.req_to_token_pool.req_to_token[
req.req_pool_idx, req.swa_evicted_seqlen : new_swa_evicted_seqlen
]
self.token_to_kv_pool_allocator.free_swa(free_slots)
req.swa_evicted_seqlen = new_swa_evicted_seqlen

def _is_available_size_sufficient(self, num_tokens: int) -> bool:
if self.is_hybrid_swa:
return (
Expand Down
7 changes: 3 additions & 4 deletions python/sglang/srt/managers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,10 +638,9 @@ def init_cache_with_memory_pool(self):
else:
from sglang.srt.mem_cache.chunk_cache import SWAChunkCache

params.sliding_window_size = self.model_config.sliding_window_size
params.attention_chunk_size = self.model_config.attention_chunk_size

self.tree_cache = SWAChunkCache(params)
self.tree_cache = SWAChunkCache(
params, sliding_window_size=self.sliding_window_size
)
else:

if envs.SGLANG_EXPERIMENTAL_CPP_RADIX_TREE.get():
Expand Down
4 changes: 0 additions & 4 deletions python/sglang/srt/mem_cache/cache_init_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,6 @@ class CacheInitParams:

enable_mamba_extra_buffer: bool = False

# For SWAChunkCache
sliding_window_size: Optional[int] = None
attention_chunk_size: Optional[int] = None

pp_rank: int = 0
pp_size: int = 1

Expand Down
54 changes: 6 additions & 48 deletions python/sglang/srt/mem_cache/chunk_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,62 +84,20 @@ def pretty_print(self):


class SWAChunkCache(ChunkCache):
"""ChunkCache with support for hybrid KV cache operations."""
"""ChunkCache with support for sliding window attention."""

def __init__(self, params: CacheInitParams):
def __init__(self, params: CacheInitParams, sliding_window_size: int):
assert isinstance(params.token_to_kv_pool_allocator, SWATokenToKVPoolAllocator)
super().__init__(params)

assert (
params.sliding_window_size is not None
or params.attention_chunk_size is not None
), "Sliding window size or attention chunk size must be set for SWAChunkCache"

if (
params.sliding_window_size is not None
and params.attention_chunk_size is not None
):
logger.warning(
"Sliding window size and attention chunk size are both set, use sliding window size for chunk cache eviction."
)

self.sliding_window_size = params.sliding_window_size
self.attention_chunk_size = params.attention_chunk_size
self.window_size = self.sliding_window_size or self.attention_chunk_size

self.sliding_window_size = sliding_window_size
self.chunked_prefill_size = params.chunked_prefill_size

def supports_swa(self) -> bool:
assert (
self.sliding_window_size is not None
), "sliding_window_size must be set for SWAChunkCache"
return True

def evict_swa(
self,
req: Req,
prelen: int,
):
if self.sliding_window_size is not None:
# Sliding window attention (e.g. mimo-v2-flash, gpt-oss)
new_evicted_seqlen_local = max(
req.evicted_seqlen_local, prelen - self.sliding_window_size
)
elif self.attention_chunk_size is not None:
# Local attention (e.g. llama4)
new_evicted_seqlen_local = max(
req.evicted_seqlen_local,
prelen // self.attention_chunk_size * self.attention_chunk_size,
)

if self.page_size > 1:
new_evicted_seqlen_local = (
new_evicted_seqlen_local // self.page_size
) * self.page_size

if new_evicted_seqlen_local > req.evicted_seqlen_local:
free_slots = self.req_to_token_pool.req_to_token[
req.req_pool_idx, req.evicted_seqlen_local : new_evicted_seqlen_local
]
self.token_to_kv_pool_allocator.free_swa(free_slots)
req.evicted_seqlen_local = new_evicted_seqlen_local

def evict(self, num_tokens: int):
pass
22 changes: 3 additions & 19 deletions python/sglang/srt/mem_cache/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,18 +338,7 @@ def alloc_for_extend(
req_pool_indices: request pool indices as list
"""
# free out-of-window swa tokens
if batch.tree_cache.supports_swa() and batch.tree_cache.is_chunk_cache():
for req, pre_len in zip(batch.reqs, batch.prefix_lens):
if batch.enable_overlap:
# In chunked prefill case, when the second extend batch is scheduling, the first extend batch is still running, so we cannot evict swa tokens
if req.extend_batch_idx < 2:
continue
else:
batch.tree_cache.evict_swa(
req, pre_len - batch.tree_cache.chunked_prefill_size
)
else:
batch.tree_cache.evict_swa(req, pre_len)
batch.maybe_evict_swa()

bs = len(batch.reqs)
prefix_tensors = [r.prefix_indices for r in batch.reqs]
Expand Down Expand Up @@ -440,13 +429,8 @@ def alloc_for_decode(batch: ScheduleBatch, token_per_req: int) -> torch.Tensor:
Returns:
out_cache_loc: allocated cache locations
"""
if batch.tree_cache.supports_swa() and batch.tree_cache.is_chunk_cache():
for req in batch.reqs:
# We set evict_swa condition here with two reasons:
# 1. In overlap scheduler, we cannot evict swa when req.decode_batch_idx == 0 since the prev extend batch is still running.
# 2. Evict swa every window_size tokens to reduce the overhead.
if req.decode_batch_idx % batch.tree_cache.window_size == 1:
batch.tree_cache.evict_swa(req, req.seqlen - 1)

batch.maybe_evict_swa()

bs = batch.seq_lens.shape[0]

Expand Down
113 changes: 90 additions & 23 deletions python/sglang/srt/mem_cache/swa_radix_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,9 @@ def __init__(self, params: CacheInitParams, sliding_window_size: int):
##### Public API #####

def supports_swa(self) -> bool:
assert (
self.sliding_window_size is not None
), "sliding_window_size must be set for SWARadixCache"
return True

def reset(self) -> None:
Expand Down Expand Up @@ -418,7 +421,13 @@ def match_prefix(self, key: RadixKey, **kwargs) -> MatchResult:
last_host_node=last_node,
)

def insert(self, key: RadixKey, value=None, prev_prefix_len: int = 0) -> int:
def insert(
self,
key: RadixKey,
value=None,
prev_prefix_len: int = 0,
swa_evicted_seqlen: int = 0,
Copy link
Collaborator

@hanming-lu hanming-lu Jan 16, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. you are actually evicting them instead of unlocking

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, the generated tokens are not in the tree until the sequence finished. So we cannot unlock it.

) -> int:
if self.disable:
return 0

Expand All @@ -431,7 +440,9 @@ def insert(self, key: RadixKey, value=None, prev_prefix_len: int = 0) -> int:
# Make sure the value len equal to the EAGLE bigram key len
value = value[: len(key)]

return self._insert_helper(self.root_node, key, value, prev_prefix_len)
return self._insert_helper(
self.root_node, key, value, prev_prefix_len, swa_evicted_seqlen
)

def cache_finished_req(self, req: Req, is_insert: bool = True) -> None:
"""Cache request when it finishes."""
Expand Down Expand Up @@ -481,6 +492,7 @@ def cache_finished_req(self, req: Req, is_insert: bool = True) -> None:
RadixKey(token_ids[:page_aligned_token_len], req.extra_key),
page_aligned_kv_indices,
old_prefix_len,
req.swa_evicted_seqlen,
)
else:
self.token_to_kv_pool_allocator.free(
Expand Down Expand Up @@ -878,6 +890,7 @@ def _split_node(self, key: RadixKey, child: TreeNode, split_len: int) -> TreeNod
new_node.full_lock_ref = child.full_lock_ref
new_node.swa_lock_ref = child.swa_lock_ref
new_node.key = child.key[:split_len]
assert len(new_node.key) > 0, f"new_node.key should not be empty"
new_node.value = child.value[:split_len].clone()
# parent inherits the swa_uuid from child for swa lock ref
new_node.swa_uuid = child.swa_uuid
Expand All @@ -891,6 +904,7 @@ def _split_node(self, key: RadixKey, child: TreeNode, split_len: int) -> TreeNod
self.swa_lru_list.remove_node(child)
child.parent = new_node
child.key = child.key[split_len:]
assert len(child.key) > 0, f"child.key should not be empty"
child.value = child.value[split_len:].clone()
new_node.parent.children[self.get_child_key_fn(key)] = new_node

Expand All @@ -904,7 +918,12 @@ def _split_node(self, key: RadixKey, child: TreeNode, split_len: int) -> TreeNod
return new_node

def _insert_helper(
self, node: TreeNode, key: RadixKey, value, update_kv_after_len: int
self,
node: TreeNode,
key: RadixKey,
value,
update_kv_after_len: int,
swa_evicted_seqlen: int = 0,
) -> int:
# Update the last access time from root to leaf, so that
# swa will tombstone the node closer to root first
Expand Down Expand Up @@ -936,23 +955,44 @@ def _insert_helper(
# contains tombstone. If this is the case and we don't update the kv value, then
# the prefill prefix matching will stuck.
if update_kv_after_len < total_prefix_length + prefix_len:
first_diff_idx = max(0, update_kv_after_len - total_prefix_length)
# For page_size > 1 and chunked prefill case, update_kv_after_len may be not page-aligned due to a trailing partial page
# (kept in the request but not inserted into the radix tree) appended to prefix_indices.
if node.swa_tombstone:
assert (
node.swa_lock_ref == 0
), f"tombstone swa_lock_ref should always be 0, {node.full_lock_ref=}, {node.swa_lock_ref=}, {node.id=}"
self.token_to_kv_pool_allocator.free(node.value[first_diff_idx:])
node.value = value[:prefix_len]
node.swa_tombstone = False

# insert the node into the lru lists
self.swa_lru_list.insert_mru(node)

self.swa_evictable_size_ += len(node.value)
assert (
swa_evicted_seqlen % self.page_size == 0
), f"swa_evicted_seqlen must be page aligned, {swa_evicted_seqlen=}, {self.page_size=}"
if swa_evicted_seqlen <= total_prefix_length:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add some comments for each branch please

# Branch 1: all swa tokens of value[:prefix_len] are not evicted, so we can insert it to the tree directly.
# Free full tokens in the original tree node.
self.token_to_kv_pool_allocator.free(node.value[:prefix_len])
# Overwrite the new value in request to the tree node.
node.value = value[:prefix_len].clone()
node.swa_tombstone = False
self.swa_lru_list.insert_mru(node)
self.swa_evictable_size_ += len(node.value)
elif swa_evicted_seqlen < total_prefix_length + prefix_len:
# Branch 2: part of swa tokens of value[:prefix_len] are evicted, so we need to split the node and insert the value to new node.
start_update_idx = swa_evicted_seqlen - total_prefix_length
self.token_to_kv_pool_allocator.free(
node.value[start_update_idx:prefix_len]
)
self._split_node(node.key, node, start_update_idx)
# Here node is the new node after split, so we can overwrite the value to the new node.
# The old node is still swa tombstone and the full token is not freed.
node.value = value[start_update_idx:prefix_len].clone()
self.token_to_kv_pool_allocator.free(value[:start_update_idx])
node.swa_tombstone = False
self.swa_lru_list.insert_mru(node)
self.swa_evictable_size_ += len(node.value)
else:
# Branch 3: all swa tokens of value[:prefix_len] are evicted, so we don't need to update the node.
self.token_to_kv_pool_allocator.free(value[:prefix_len])
else:
self.token_to_kv_pool_allocator.free(
value[first_diff_idx:prefix_len]
)
# The node is not tombstone, so we don't need to update the node.
self.token_to_kv_pool_allocator.free(value[:prefix_len])

total_prefix_length += prefix_len
key = key[prefix_len:]
Expand All @@ -962,16 +1002,43 @@ def _insert_helper(
child_key = self.get_child_key_fn(key)

if len(key):
new_node = TreeNode()
new_node.parent = node
new_node.key = key
new_node.value = value
self.full_lru_list.insert_mru(new_node)
if (
swa_evicted_seqlen > total_prefix_length
and swa_evicted_seqlen < total_prefix_length + len(key)
):
swa_tombstone_len = swa_evicted_seqlen - total_prefix_length
node = self._add_new_node(
node,
key[:swa_tombstone_len],
value[:swa_tombstone_len],
swa_tombstone=True,
)
key = key[swa_tombstone_len:]
value = value[swa_tombstone_len:]

self._add_new_node(node, key, value, swa_tombstone=False)
return total_prefix_length

def _add_new_node(
self,
parent: TreeNode,
key: RadixKey,
value: torch.Tensor,
swa_tombstone: bool = False,
) -> TreeNode:
assert len(key) > 0, f"key should not be empty"
new_node = TreeNode()
new_node.parent = parent
new_node.key = key
new_node.value = value.clone()
new_node.swa_tombstone = swa_tombstone
parent.children[self.get_child_key_fn(key)] = new_node
self.full_lru_list.insert_mru(new_node)
self.full_evictable_size_ += len(value)
if not swa_tombstone:
self.swa_lru_list.insert_mru(new_node)
node.children[child_key] = new_node
self.full_evictable_size_ += len(value)
self.swa_evictable_size_ += len(value)
return total_prefix_length
return new_node

def _iteratively_delete_tombstone_leaf(
self, node: TreeNode
Expand Down
Loading
Loading