From 56a9e5696a89caedde20d18e7a91d511568c9c35 Mon Sep 17 00:00:00 2001 From: ispobock Date: Wed, 14 Jan 2026 14:18:41 +0000 Subject: [PATCH 01/12] tmp --- python/sglang/srt/mem_cache/common.py | 5 +- .../sglang/srt/mem_cache/swa_radix_cache.py | 83 ++++++++++++++++--- 2 files changed, 75 insertions(+), 13 deletions(-) diff --git a/python/sglang/srt/mem_cache/common.py b/python/sglang/srt/mem_cache/common.py index f5b001a8852c..b2e7eb1a7534 100644 --- a/python/sglang/srt/mem_cache/common.py +++ b/python/sglang/srt/mem_cache/common.py @@ -12,6 +12,7 @@ from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool from sglang.srt.mem_cache.swa_memory_pool import SWATokenToKVPoolAllocator +from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache from sglang.srt.server_args import get_global_server_args from sglang.srt.utils import support_triton from sglang.srt.utils.common import ceil_align @@ -442,7 +443,9 @@ def alloc_for_decode(batch: ScheduleBatch, token_per_req: int) -> torch.Tensor: Returns: out_cache_loc: allocated cache locations """ - if isinstance(batch.tree_cache, SWAChunkCache): + if isinstance(batch.tree_cache, SWAChunkCache) or isinstance( + batch.tree_cache, SWARadixCache + ): for req in batch.reqs: # We set evict_swa condition here with two reasons: # 1. In overlap scheduler, we cannot evict swa when req.decode_batch_idx == 0 since the prev extend batch is still running. 
diff --git a/python/sglang/srt/mem_cache/swa_radix_cache.py b/python/sglang/srt/mem_cache/swa_radix_cache.py index 625a8a166d7a..ff96380d4dbc 100644 --- a/python/sglang/srt/mem_cache/swa_radix_cache.py +++ b/python/sglang/srt/mem_cache/swa_radix_cache.py @@ -358,6 +358,7 @@ def __init__(self, params: CacheInitParams, sliding_window_size: int): self.init_metrics_collector() self.sliding_window_size = sliding_window_size + self.window_size = self.sliding_window_size self.reset() ##### Public API ##### @@ -415,7 +416,13 @@ def match_prefix(self, key: RadixKey, **kwargs) -> MatchResult: last_host_node=last_node, ) - def insert(self, key: RadixKey, value=None, prev_prefix_len: int = 0) -> int: + def insert( + self, + key: RadixKey, + value=None, + prev_prefix_len: int = 0, + evicted_seqlen: int = 0, + ) -> int: if self.disable: return 0 @@ -428,7 +435,9 @@ def insert(self, key: RadixKey, value=None, prev_prefix_len: int = 0) -> int: # Make sure the value len equal to the EAGLE bigram key len value = value[: len(key)] - return self._insert_helper(self.root_node, key, value, prev_prefix_len) + return self._insert_helper( + self.root_node, key, value, prev_prefix_len, evicted_seqlen + ) def cache_finished_req(self, req: Req, is_insert: bool = True) -> None: """Cache request when it finishes.""" @@ -478,6 +487,7 @@ def cache_finished_req(self, req: Req, is_insert: bool = True) -> None: RadixKey(token_ids[:page_aligned_token_len], req.extra_key), page_aligned_kv_indices, old_prefix_len, + req.evicted_seqlen_local, ) else: self.token_to_kv_pool_allocator.free( @@ -673,6 +683,27 @@ def evict(self, full_num_tokens: int, swa_num_tokens: int = 0) -> None: self.update_eviction_metrics(full_num_evicted + swa_num_evicted, start_time) + def evict_swa(self, req: Req, pre_len: int) -> None: + # evict the swa tokens that not in the tree cache and also not in the sliding window + req.evicted_seqlen_local = max( + req.evicted_seqlen_local, req.cache_protected_len + ) + 
new_evicted_seqlen_local = max( + req.evicted_seqlen_local, pre_len - self.sliding_window_size + ) + + if self.page_size > 1: + new_evicted_seqlen_local = ( + new_evicted_seqlen_local // self.page_size + ) * self.page_size + + if new_evicted_seqlen_local > req.evicted_seqlen_local: + free_slots = self.req_to_token_pool.req_to_token[ + req.req_pool_idx, req.evicted_seqlen_local : new_evicted_seqlen_local + ] + self.token_to_kv_pool_allocator.free_swa(free_slots) + req.evicted_seqlen_local = new_evicted_seqlen_local + def inc_lock_ref(self, node: TreeNode) -> Optional[int]: """ Increment the lock reference count for the node. Returns the swa_uuid_for_lock, which needs @@ -902,7 +933,12 @@ def _split_node(self, key: RadixKey, child: TreeNode, split_len: int) -> TreeNod return new_node def _insert_helper( - self, node: TreeNode, key: RadixKey, value, update_kv_after_len: int + self, + node: TreeNode, + key: RadixKey, + value, + update_kv_after_len: int, + evicted_seqlen: int = 0, ) -> int: # Update the last access time from root to leaf, so that # swa will tombstone the node closer to root first @@ -935,7 +971,7 @@ def _insert_helper( # the prefill prefix matching will stuck. 
if update_kv_after_len < total_prefix_length + prefix_len: first_diff_idx = max(0, update_kv_after_len - total_prefix_length) - if node.swa_tombstone: + if node.swa_tombstone and evicted_seqlen < total_prefix_length: assert ( node.swa_lock_ref == 0 ), f"tombstone swa_lock_ref should always be 0, {node.full_lock_ref=}, {node.swa_lock_ref=}, {node.id=}" @@ -960,16 +996,39 @@ def _insert_helper( child_key = self.get_child_key_fn(key) if len(key): - new_node = TreeNode() - new_node.parent = node - new_node.key = key - new_node.value = value - self.full_lru_list.insert_mru(new_node) + if ( + evicted_seqlen > total_prefix_length + and evicted_seqlen < total_prefix_length + len(key) + ): + swa_evicted_len = evicted_seqlen - total_prefix_length + node = self._add_new_node( + node, key[:swa_evicted_len], value[:swa_evicted_len], True + ) + key = key[swa_evicted_len:] + value = value[swa_evicted_len:] + + self._add_new_node(node, key, value, False) + return total_prefix_length + + def _add_new_node( + self, + node: TreeNode, + key: RadixKey, + value: torch.Tensor, + swa_tombstone: bool = False, + ) -> TreeNode: + new_node = TreeNode() + new_node.parent = node + new_node.key = key + new_node.value = value + new_node.swa_tombstone = swa_tombstone + node.children[self.get_child_key_fn(key)] = new_node + self.full_lru_list.insert_mru(new_node) + self.full_evictable_size_ += len(value) + if not swa_tombstone: self.swa_lru_list.insert_mru(new_node) - node.children[child_key] = new_node - self.full_evictable_size_ += len(value) self.swa_evictable_size_ += len(value) - return total_prefix_length + return new_node def _iteratively_delete_tombstone_leaf( self, node: TreeNode From 1aab9279222314df7e978233f231f02ead31d3d9 Mon Sep 17 00:00:00 2001 From: ispobock Date: Thu, 15 Jan 2026 08:57:16 +0000 Subject: [PATCH 02/12] update --- python/sglang/srt/managers/schedule_batch.py | 56 +++++++++++++++++++ python/sglang/srt/managers/scheduler.py | 7 +-- 
.../sglang/srt/mem_cache/cache_init_params.py | 4 -- python/sglang/srt/mem_cache/chunk_cache.py | 56 ++----------------- python/sglang/srt/mem_cache/common.py | 21 +------ .../sglang/srt/mem_cache/swa_radix_cache.py | 24 +------- .../sglang/srt/speculative/eagle_info_v2.py | 4 +- python/sglang/srt/speculative/eagle_worker.py | 4 +- 8 files changed, 72 insertions(+), 104 deletions(-) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 5fae27b642e2..3fbc9de6595e 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -2205,6 +2205,62 @@ def copy(self): dp_cooperation_info=self.dp_cooperation_info, ) + def maybe_evict_swa(self): + if self.tree_cache.supports_swa(): + sliding_window_size = self.tree_cache.sliding_window_size + for idx, req in enumerate(self.reqs): + # TODO(ispobock): handle spec batch idx update + if self.forward_mode.is_decode(): + # We set evict_swa condition here with two reasons: + # 1. In overlap scheduler, we cannot evict swa when req.decode_batch_idx == 0 since the prev extend batch is still running. + # 2. Evict swa every window_size tokens to reduce the overhead. 
+ if req.decode_batch_idx % sliding_window_size == 1: + self._evict_swa(req, req.seqlen - 1) + elif self.forward_mode.is_extend() and self.tree_cache.is_chunk_cache(): + pre_len = self.prefix_lens[idx] + if self.enable_overlap: + # In chunked prefill case, when the second extend batch is scheduling, the first extend batch is still running, so we cannot evict swa tokens + if req.extend_batch_idx < 2: + continue + else: + server_args = get_global_server_args() + pre_len = ( + pre_len - server_args.chunked_prefill_size + if server_args.chunked_prefill_size > 0 + else pre_len + ) + self._evict_swa(req, pre_len) + else: + self._evict_swa(req, pre_len) + + def _evict_swa(self, req: Req, pre_len: int): + assert self.tree_cache.supports_swa(), "prefix cache must support swa" + sliding_window_size = self.tree_cache.sliding_window_size + + # For swa radix cache, we need to evict the tokens that are not in the tree cache and also not in the sliding window + assert ( + req.cache_protected_len % self.tree_cache.page_size == 0 + ), "cache_protected_len must be page aligned" + req.evicted_seqlen_local = max( + req.evicted_seqlen_local, req.cache_protected_len + ) + + new_evicted_seqlen_local = max( + req.evicted_seqlen_local, pre_len - sliding_window_size + ) + + if self.tree_cache.page_size > 1: + new_evicted_seqlen_local = ( + new_evicted_seqlen_local // self.tree_cache.page_size + ) * self.tree_cache.page_size + + if new_evicted_seqlen_local > req.evicted_seqlen_local: + free_slots = self.req_to_token_pool.req_to_token[ + req.req_pool_idx, req.evicted_seqlen_local : new_evicted_seqlen_local + ] + self.token_to_kv_pool_allocator.free_swa(free_slots) + req.evicted_seqlen_local = new_evicted_seqlen_local + def _is_available_size_sufficient(self, num_tokens: int) -> bool: if self.is_hybrid_swa: return ( diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 4ff5845f5c94..dc4ba521c104 100644 --- a/python/sglang/srt/managers/scheduler.py 
+++ b/python/sglang/srt/managers/scheduler.py @@ -634,10 +634,9 @@ def init_cache_with_memory_pool(self): else: from sglang.srt.mem_cache.chunk_cache import SWAChunkCache - params.sliding_window_size = self.model_config.sliding_window_size - params.attention_chunk_size = self.model_config.attention_chunk_size - - self.tree_cache = SWAChunkCache(params) + self.tree_cache = SWAChunkCache( + params, sliding_window_size=self.sliding_window_size + ) else: if envs.SGLANG_EXPERIMENTAL_CPP_RADIX_TREE.get(): diff --git a/python/sglang/srt/mem_cache/cache_init_params.py b/python/sglang/srt/mem_cache/cache_init_params.py index 2ed8803b0bcc..7f6111f8086e 100644 --- a/python/sglang/srt/mem_cache/cache_init_params.py +++ b/python/sglang/srt/mem_cache/cache_init_params.py @@ -27,10 +27,6 @@ class CacheInitParams: enable_mamba_extra_buffer: bool = False - # For SWAChunkCache - sliding_window_size: Optional[int] = None - attention_chunk_size: Optional[int] = None - pp_rank: int = 0 pp_size: int = 1 diff --git a/python/sglang/srt/mem_cache/chunk_cache.py b/python/sglang/srt/mem_cache/chunk_cache.py index 38dbdd6fe192..9aa18da140f1 100644 --- a/python/sglang/srt/mem_cache/chunk_cache.py +++ b/python/sglang/srt/mem_cache/chunk_cache.py @@ -8,7 +8,6 @@ import torch from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult -from sglang.srt.mem_cache.swa_memory_pool import SWATokenToKVPoolAllocator if TYPE_CHECKING: from sglang.srt.managers.schedule_batch import Req @@ -84,62 +83,19 @@ def pretty_print(self): class SWAChunkCache(ChunkCache): - """ChunkCache with support for hybrid KV cache operations.""" + """ChunkCache with support for sliding window attention.""" - def __init__(self, params: CacheInitParams): - assert isinstance(params.token_to_kv_pool_allocator, SWATokenToKVPoolAllocator) + def __init__(self, params: CacheInitParams, sliding_window_size: int): super().__init__(params) - assert ( - params.sliding_window_size is not None - or 
params.attention_chunk_size is not None - ), "Sliding window size or attention chunk size must be set for SWAChunkCache" - - if ( - params.sliding_window_size is not None - and params.attention_chunk_size is not None - ): - logger.warning( - "Sliding window size and attention chunk size are both set, use sliding window size for chunk cache eviction." - ) - - self.sliding_window_size = params.sliding_window_size - self.attention_chunk_size = params.attention_chunk_size - self.window_size = self.sliding_window_size or self.attention_chunk_size - + self.sliding_window_size = sliding_window_size self.chunked_prefill_size = params.chunked_prefill_size def supports_swa(self) -> bool: + assert ( + self.sliding_window_size is not None + ), "sliding_window_size must be set for SWAChunkCache" return True - def evict_swa( - self, - req: Req, - prelen: int, - ): - if self.sliding_window_size is not None: - # Sliding window attention (e.g. mimo-v2-flash, gpt-oss) - new_evicted_seqlen_local = max( - req.evicted_seqlen_local, prelen - self.sliding_window_size - ) - elif self.attention_chunk_size is not None: - # Local attention (e.g. 
llama4) - new_evicted_seqlen_local = max( - req.evicted_seqlen_local, - prelen // self.attention_chunk_size * self.attention_chunk_size, - ) - - if self.page_size > 1: - new_evicted_seqlen_local = ( - new_evicted_seqlen_local // self.page_size - ) * self.page_size - - if new_evicted_seqlen_local > req.evicted_seqlen_local: - free_slots = self.req_to_token_pool.req_to_token[ - req.req_pool_idx, req.evicted_seqlen_local : new_evicted_seqlen_local - ] - self.token_to_kv_pool_allocator.free_swa(free_slots) - req.evicted_seqlen_local = new_evicted_seqlen_local - def evict(self, num_tokens: int): pass diff --git a/python/sglang/srt/mem_cache/common.py b/python/sglang/srt/mem_cache/common.py index 1380aed6095f..374614be5de0 100644 --- a/python/sglang/srt/mem_cache/common.py +++ b/python/sglang/srt/mem_cache/common.py @@ -338,18 +338,7 @@ def alloc_for_extend( req_pool_indices: request pool indices as list """ # free out-of-window swa tokens - if batch.tree_cache.supports_swa() and batch.tree_cache.is_chunk_cache(): - for req, pre_len in zip(batch.reqs, batch.prefix_lens): - if batch.enable_overlap: - # In chunked prefill case, when the second extend batch is scheduling, the first extend batch is still running, so we cannot evict swa tokens - if req.extend_batch_idx < 2: - continue - else: - batch.tree_cache.evict_swa( - req, pre_len - batch.tree_cache.chunked_prefill_size - ) - else: - batch.tree_cache.evict_swa(req, pre_len) + batch.maybe_evict_swa() bs = len(batch.reqs) prefix_tensors = [r.prefix_indices for r in batch.reqs] @@ -441,13 +430,7 @@ def alloc_for_decode(batch: ScheduleBatch, token_per_req: int) -> torch.Tensor: out_cache_loc: allocated cache locations """ - if batch.tree_cache.supports_swa(): - for req in batch.reqs: - # We set evict_swa condition here with two reasons: - # 1. In overlap scheduler, we cannot evict swa when req.decode_batch_idx == 0 since the prev extend batch is still running. - # 2. 
Evict swa every window_size tokens to reduce the overhead. - if req.decode_batch_idx % batch.tree_cache.window_size == 1: - batch.tree_cache.evict_swa(req, req.seqlen - 1) + batch.maybe_evict_swa() bs = batch.seq_lens.shape[0] diff --git a/python/sglang/srt/mem_cache/swa_radix_cache.py b/python/sglang/srt/mem_cache/swa_radix_cache.py index da84f74a4c74..fdc07af5bdaf 100644 --- a/python/sglang/srt/mem_cache/swa_radix_cache.py +++ b/python/sglang/srt/mem_cache/swa_radix_cache.py @@ -364,6 +364,9 @@ def __init__(self, params: CacheInitParams, sliding_window_size: int): ##### Public API ##### def supports_swa(self) -> bool: + assert ( + self.sliding_window_size is not None + ), "sliding_window_size must be set for SWARadixCache" return True def reset(self) -> None: @@ -686,27 +689,6 @@ def evict(self, full_num_tokens: int, swa_num_tokens: int = 0) -> None: self.update_eviction_metrics(full_num_evicted + swa_num_evicted, start_time) - def evict_swa(self, req: Req, pre_len: int) -> None: - # evict the swa tokens that not in the tree cache and also not in the sliding window - req.evicted_seqlen_local = max( - req.evicted_seqlen_local, req.cache_protected_len - ) - new_evicted_seqlen_local = max( - req.evicted_seqlen_local, pre_len - self.sliding_window_size - ) - - if self.page_size > 1: - new_evicted_seqlen_local = ( - new_evicted_seqlen_local // self.page_size - ) * self.page_size - - if new_evicted_seqlen_local > req.evicted_seqlen_local: - free_slots = self.req_to_token_pool.req_to_token[ - req.req_pool_idx, req.evicted_seqlen_local : new_evicted_seqlen_local - ] - self.token_to_kv_pool_allocator.free_swa(free_slots) - req.evicted_seqlen_local = new_evicted_seqlen_local - def inc_lock_ref(self, node: TreeNode) -> Optional[int]: """ Increment the lock reference count for the node. 
Returns the swa_uuid_for_lock, which needs diff --git a/python/sglang/srt/speculative/eagle_info_v2.py b/python/sglang/srt/speculative/eagle_info_v2.py index c943a5ec94b0..b12a3a296d87 100644 --- a/python/sglang/srt/speculative/eagle_info_v2.py +++ b/python/sglang/srt/speculative/eagle_info_v2.py @@ -79,9 +79,7 @@ def assign_draft_cache_locs_page_size_1( @dataclass class EagleDraftInputV2Mixin: def prepare_for_decode(self: EagleDraftInput, batch: ScheduleBatch): - if batch.tree_cache.supports_swa() and batch.tree_cache.is_chunk_cache(): - for req in batch.reqs: - batch.tree_cache.evict_swa(req, req.seqlen - 1) + batch.maybe_evict_swa() from sglang.srt.speculative.spec_utils import assign_req_to_token_pool_func diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index 9c29b3b1a7b7..21496d8a1a67 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -374,9 +374,7 @@ def forward_target_extend( ) def _draft_preprocess_decode(self, batch: ScheduleBatch): - if batch.tree_cache.supports_swa() and batch.tree_cache.is_chunk_cache(): - for req in batch.reqs: - batch.tree_cache.evict_swa(req, req.seqlen - 1) + batch.maybe_evict_swa() # Parse args num_seqs = batch.batch_size() From a982ada9f6c9765e4bab629e2724e80ff7e737ab Mon Sep 17 00:00:00 2001 From: ispobock Date: Thu, 15 Jan 2026 12:51:59 +0000 Subject: [PATCH 03/12] fix insert --- .../sglang/srt/mem_cache/swa_radix_cache.py | 59 +++++++++++++------ 1 file changed, 42 insertions(+), 17 deletions(-) diff --git a/python/sglang/srt/mem_cache/swa_radix_cache.py b/python/sglang/srt/mem_cache/swa_radix_cache.py index fdc07af5bdaf..84fc6f3e7ca0 100644 --- a/python/sglang/srt/mem_cache/swa_radix_cache.py +++ b/python/sglang/srt/mem_cache/swa_radix_cache.py @@ -427,7 +427,7 @@ def insert( key: RadixKey, value=None, prev_prefix_len: int = 0, - evicted_seqlen: int = 0, + swa_evicted_seqlen: int = 0, ) -> int: if 
self.disable: return 0 @@ -442,7 +442,7 @@ def insert( value = value[: len(key)] return self._insert_helper( - self.root_node, key, value, prev_prefix_len, evicted_seqlen + self.root_node, key, value, prev_prefix_len, swa_evicted_seqlen ) def cache_finished_req(self, req: Req, is_insert: bool = True) -> None: @@ -922,7 +922,7 @@ def _insert_helper( key: RadixKey, value, update_kv_after_len: int, - evicted_seqlen: int = 0, + swa_evicted_seqlen: int = 0, ) -> int: # Update the last access time from root to leaf, so that # swa will tombstone the node closer to root first @@ -955,18 +955,40 @@ def _insert_helper( # the prefill prefix matching will stuck. if update_kv_after_len < total_prefix_length + prefix_len: first_diff_idx = max(0, update_kv_after_len - total_prefix_length) - if node.swa_tombstone and evicted_seqlen < total_prefix_length: + assert ( + first_diff_idx == 0 + ), f"first_diff_idx should be 0, {first_diff_idx=}, {update_kv_after_len=}, {total_prefix_length=}, {prefix_len=}" + if node.swa_tombstone: assert ( node.swa_lock_ref == 0 ), f"tombstone swa_lock_ref should always be 0, {node.full_lock_ref=}, {node.swa_lock_ref=}, {node.id=}" - self.token_to_kv_pool_allocator.free(node.value[first_diff_idx:]) - node.value = value[:prefix_len] - node.swa_tombstone = False - - # insert the node into the lru lists - self.swa_lru_list.insert_mru(node) - - self.swa_evictable_size_ += len(node.value) + assert ( + swa_evicted_seqlen % self.page_size == 0 + ), f"swa_evicted_seqlen must be page aligned, {swa_evicted_seqlen=}, {self.page_size=}" + if swa_evicted_seqlen < total_prefix_length: + self.token_to_kv_pool_allocator.free( + node.value[first_diff_idx:] + ) + node.value = value[:prefix_len] + node.swa_tombstone = False + # insert the node into the lru lists + self.swa_lru_list.insert_mru(node) + self.swa_evictable_size_ += len(node.value) + elif swa_evicted_seqlen < total_prefix_length + prefix_len: + swa_evicted_len = swa_evicted_seqlen - total_prefix_length + 
self.token_to_kv_pool_allocator.free( + node.value[first_diff_idx:] + ) + self._split_node(node.key, node, swa_evicted_len) + node.value = value[swa_evicted_len:prefix_len] + node.swa_tombstone = False + # insert the node into the lru lists + self.swa_lru_list.insert_mru(node) + self.swa_evictable_size_ += len(node.value) + else: + self.token_to_kv_pool_allocator.free( + node.value[first_diff_idx:prefix_len] + ) else: self.token_to_kv_pool_allocator.free( value[first_diff_idx:prefix_len] @@ -981,17 +1003,20 @@ def _insert_helper( if len(key): if ( - evicted_seqlen > total_prefix_length - and evicted_seqlen < total_prefix_length + len(key) + swa_evicted_seqlen > total_prefix_length + and swa_evicted_seqlen < total_prefix_length + len(key) ): - swa_evicted_len = evicted_seqlen - total_prefix_length + swa_evicted_len = swa_evicted_seqlen - total_prefix_length node = self._add_new_node( - node, key[:swa_evicted_len], value[:swa_evicted_len], True + node, + key[:swa_evicted_len], + value[:swa_evicted_len], + swa_tombstone=True, ) key = key[swa_evicted_len:] value = value[swa_evicted_len:] - self._add_new_node(node, key, value, False) + self._add_new_node(node, key, value, swa_tombstone=False) return total_prefix_length def _add_new_node( From 3d35138af8e0db51ec0876a9b5566441fe366cd6 Mon Sep 17 00:00:00 2001 From: ispobock Date: Thu, 15 Jan 2026 13:19:58 +0000 Subject: [PATCH 04/12] update spec --- python/sglang/srt/managers/schedule_batch.py | 1 - python/sglang/srt/speculative/eagle_info_v2.py | 2 -- python/sglang/srt/speculative/eagle_worker.py | 2 -- 3 files changed, 5 deletions(-) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 3fbc9de6595e..fef2b8a71309 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -2209,7 +2209,6 @@ def maybe_evict_swa(self): if self.tree_cache.supports_swa(): sliding_window_size = self.tree_cache.sliding_window_size for 
idx, req in enumerate(self.reqs): - # TODO(ispobock): handle spec batch idx update if self.forward_mode.is_decode(): # We set evict_swa condition here with two reasons: # 1. In overlap scheduler, we cannot evict swa when req.decode_batch_idx == 0 since the prev extend batch is still running. diff --git a/python/sglang/srt/speculative/eagle_info_v2.py b/python/sglang/srt/speculative/eagle_info_v2.py index b12a3a296d87..3894c2176829 100644 --- a/python/sglang/srt/speculative/eagle_info_v2.py +++ b/python/sglang/srt/speculative/eagle_info_v2.py @@ -79,8 +79,6 @@ def assign_draft_cache_locs_page_size_1( @dataclass class EagleDraftInputV2Mixin: def prepare_for_decode(self: EagleDraftInput, batch: ScheduleBatch): - batch.maybe_evict_swa() - from sglang.srt.speculative.spec_utils import assign_req_to_token_pool_func bs = batch.batch_size() diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index 21496d8a1a67..d5791b928a79 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -374,8 +374,6 @@ def forward_target_extend( ) def _draft_preprocess_decode(self, batch: ScheduleBatch): - batch.maybe_evict_swa() - # Parse args num_seqs = batch.batch_size() spec_info = batch.spec_info From 829c29926089b315c712496efa06b8578a042a1f Mon Sep 17 00:00:00 2001 From: ispobock Date: Thu, 15 Jan 2026 16:06:37 +0000 Subject: [PATCH 05/12] update --- python/sglang/srt/mem_cache/swa_radix_cache.py | 5 ++++- python/sglang/srt/speculative/eagle_info_v2.py | 3 +++ python/sglang/srt/speculative/eagle_worker.py | 4 ++++ 3 files changed, 11 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/mem_cache/swa_radix_cache.py b/python/sglang/srt/mem_cache/swa_radix_cache.py index 84fc6f3e7ca0..d12a1c271e23 100644 --- a/python/sglang/srt/mem_cache/swa_radix_cache.py +++ b/python/sglang/srt/mem_cache/swa_radix_cache.py @@ -891,6 +891,7 @@ def _split_node(self, key: RadixKey, child: 
TreeNode, split_len: int) -> TreeNod new_node.full_lock_ref = child.full_lock_ref new_node.swa_lock_ref = child.swa_lock_ref new_node.key = child.key[:split_len] + assert len(new_node.key) > 0, f"new_node.key should not be empty" new_node.value = child.value[:split_len] # parent inherits the swa_uuid from child for swa lock ref new_node.swa_uuid = child.swa_uuid @@ -904,6 +905,7 @@ def _split_node(self, key: RadixKey, child: TreeNode, split_len: int) -> TreeNod self.swa_lru_list.remove_node(child) child.parent = new_node child.key = child.key[split_len:] + assert len(child.key) > 0, f"child.key should not be empty" child.value = child.value[split_len:] new_node.parent.children[self.get_child_key_fn(key)] = new_node @@ -965,7 +967,7 @@ def _insert_helper( assert ( swa_evicted_seqlen % self.page_size == 0 ), f"swa_evicted_seqlen must be page aligned, {swa_evicted_seqlen=}, {self.page_size=}" - if swa_evicted_seqlen < total_prefix_length: + if swa_evicted_seqlen <= total_prefix_length: self.token_to_kv_pool_allocator.free( node.value[first_diff_idx:] ) @@ -1026,6 +1028,7 @@ def _add_new_node( value: torch.Tensor, swa_tombstone: bool = False, ) -> TreeNode: + assert len(key) > 0, f"key should not be empty" new_node = TreeNode() new_node.parent = node new_node.key = key diff --git a/python/sglang/srt/speculative/eagle_info_v2.py b/python/sglang/srt/speculative/eagle_info_v2.py index 3894c2176829..d9d5147e02bf 100644 --- a/python/sglang/srt/speculative/eagle_info_v2.py +++ b/python/sglang/srt/speculative/eagle_info_v2.py @@ -79,6 +79,8 @@ def assign_draft_cache_locs_page_size_1( @dataclass class EagleDraftInputV2Mixin: def prepare_for_decode(self: EagleDraftInput, batch: ScheduleBatch): + batch.maybe_evict_swa() + from sglang.srt.speculative.spec_utils import assign_req_to_token_pool_func bs = batch.batch_size() @@ -97,6 +99,7 @@ def prepare_for_decode(self: EagleDraftInput, batch: ScheduleBatch): nxt_kv_lens_cpu.append(r.kv_allocated_len + x) num_needed_tokens += x 
r.kv_allocated_len += x + r.decode_batch_idx += 1 cur_kv_lens_cpu = torch.tensor(cur_kv_lens_cpu, dtype=torch.int32, device="cpu") nxt_kv_lens_cpu = torch.tensor(nxt_kv_lens_cpu, dtype=torch.int32, device="cpu") diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index d5791b928a79..5a6cc4b03526 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -374,6 +374,10 @@ def forward_target_extend( ) def _draft_preprocess_decode(self, batch: ScheduleBatch): + batch.maybe_evict_swa() + for req in batch.reqs: + req.decode_batch_idx += 1 + # Parse args num_seqs = batch.batch_size() spec_info = batch.spec_info From e1bd8ce0f3478506958c4cc79ee29ba4eb95805e Mon Sep 17 00:00:00 2001 From: ispobock Date: Fri, 16 Jan 2026 03:47:28 +0000 Subject: [PATCH 06/12] fix memory leak --- python/sglang/srt/mem_cache/swa_radix_cache.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/mem_cache/swa_radix_cache.py b/python/sglang/srt/mem_cache/swa_radix_cache.py index d12a1c271e23..f28054a1cd11 100644 --- a/python/sglang/srt/mem_cache/swa_radix_cache.py +++ b/python/sglang/srt/mem_cache/swa_radix_cache.py @@ -979,17 +979,20 @@ def _insert_helper( elif swa_evicted_seqlen < total_prefix_length + prefix_len: swa_evicted_len = swa_evicted_seqlen - total_prefix_length self.token_to_kv_pool_allocator.free( - node.value[first_diff_idx:] + node.value[swa_evicted_len:] ) self._split_node(node.key, node, swa_evicted_len) node.value = value[swa_evicted_len:prefix_len] + self.token_to_kv_pool_allocator.free( + value[first_diff_idx:swa_evicted_len] + ) node.swa_tombstone = False # insert the node into the lru lists self.swa_lru_list.insert_mru(node) self.swa_evictable_size_ += len(node.value) else: self.token_to_kv_pool_allocator.free( - node.value[first_diff_idx:prefix_len] + value[first_diff_idx:prefix_len] ) else: 
self.token_to_kv_pool_allocator.free( From 46915dcd5f0545d198058b1b7178a7251897d3a1 Mon Sep 17 00:00:00 2001 From: ispobock Date: Fri, 16 Jan 2026 10:31:04 +0000 Subject: [PATCH 07/12] fix --- python/sglang/srt/mem_cache/swa_radix_cache.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/python/sglang/srt/mem_cache/swa_radix_cache.py b/python/sglang/srt/mem_cache/swa_radix_cache.py index f28054a1cd11..aed4472502fb 100644 --- a/python/sglang/srt/mem_cache/swa_radix_cache.py +++ b/python/sglang/srt/mem_cache/swa_radix_cache.py @@ -957,9 +957,6 @@ def _insert_helper( # the prefill prefix matching will stuck. if update_kv_after_len < total_prefix_length + prefix_len: first_diff_idx = max(0, update_kv_after_len - total_prefix_length) - assert ( - first_diff_idx == 0 - ), f"first_diff_idx should be 0, {first_diff_idx=}, {update_kv_after_len=}, {total_prefix_length=}, {prefix_len=}" if node.swa_tombstone: assert ( node.swa_lock_ref == 0 @@ -977,14 +974,16 @@ def _insert_helper( self.swa_lru_list.insert_mru(node) self.swa_evictable_size_ += len(node.value) elif swa_evicted_seqlen < total_prefix_length + prefix_len: - swa_evicted_len = swa_evicted_seqlen - total_prefix_length + start_update_idx = max( + first_diff_idx, swa_evicted_seqlen - total_prefix_length + ) self.token_to_kv_pool_allocator.free( - node.value[swa_evicted_len:] + node.value[start_update_idx:] ) - self._split_node(node.key, node, swa_evicted_len) - node.value = value[swa_evicted_len:prefix_len] + self._split_node(node.key, node, start_update_idx) + node.value = value[start_update_idx:prefix_len] self.token_to_kv_pool_allocator.free( - value[first_diff_idx:swa_evicted_len] + value[first_diff_idx:start_update_idx] ) node.swa_tombstone = False # insert the node into the lru lists From 08f03a556fba94f6b7e73f17b101136ec0b7b28c Mon Sep 17 00:00:00 2001 From: ispobock Date: Sat, 17 Jan 2026 06:26:24 +0000 Subject: [PATCH 08/12] update --- 
python/sglang/srt/mem_cache/chunk_cache.py | 2 ++ python/sglang/srt/mem_cache/swa_radix_cache.py | 1 - 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/mem_cache/chunk_cache.py b/python/sglang/srt/mem_cache/chunk_cache.py index 9aa18da140f1..86df54a7f276 100644 --- a/python/sglang/srt/mem_cache/chunk_cache.py +++ b/python/sglang/srt/mem_cache/chunk_cache.py @@ -8,6 +8,7 @@ import torch from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult +from sglang.srt.mem_cache.swa_memory_pool import SWATokenToKVPoolAllocator if TYPE_CHECKING: from sglang.srt.managers.schedule_batch import Req @@ -86,6 +87,7 @@ class SWAChunkCache(ChunkCache): """ChunkCache with support for sliding window attention.""" def __init__(self, params: CacheInitParams, sliding_window_size: int): + assert isinstance(params.token_to_kv_pool_allocator, SWATokenToKVPoolAllocator) super().__init__(params) self.sliding_window_size = sliding_window_size diff --git a/python/sglang/srt/mem_cache/swa_radix_cache.py b/python/sglang/srt/mem_cache/swa_radix_cache.py index aed4472502fb..9dc1a3b3689e 100644 --- a/python/sglang/srt/mem_cache/swa_radix_cache.py +++ b/python/sglang/srt/mem_cache/swa_radix_cache.py @@ -358,7 +358,6 @@ def __init__(self, params: CacheInitParams, sliding_window_size: int): self.init_metrics_collector() self.sliding_window_size = sliding_window_size - self.window_size = self.sliding_window_size self.reset() ##### Public API ##### From 822c5253661dff886e393c6e1c39502b60b2a8b6 Mon Sep 17 00:00:00 2001 From: ispobock Date: Sat, 17 Jan 2026 06:31:34 +0000 Subject: [PATCH 09/12] rename --- python/sglang/srt/managers/schedule_batch.py | 26 ++++++++++--------- .../sglang/srt/mem_cache/swa_radix_cache.py | 2 +- 2 files changed, 15 insertions(+), 13 deletions(-) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index fef2b8a71309..d61b9e3e004e 100644 --- 
a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -549,8 +549,12 @@ def __init__( # for corss-endoder model self.token_type_ids = token_type_ids - # The length of KV that have been removed in swa chunk cache - self.evicted_seqlen_local = 0 + # The length of KV that has been removed in swa cache. + # SWA KV cache eviction behavior differs by cache type: + # - Radix cache: KV in range [cache_protected_len, swa_evicted_seqlen) is freed manually in + # `ScheduleBatch.maybe_evict_swa`; KV in range [0, cache_protected_len) is freed during radix cache eviction. + # - Chunk cache: KV in range [0, swa_evicted_seqlen) is freed manually in `ScheduleBatch.maybe_evict_swa`. + self.swa_evicted_seqlen = 0 # The index of the extend / decode batch self.extend_batch_idx = 0 @@ -2240,25 +2244,23 @@ def _evict_swa(self, req: Req, pre_len: int): assert ( req.cache_protected_len % self.tree_cache.page_size == 0 ), "cache_protected_len must be page aligned" - req.evicted_seqlen_local = max( - req.evicted_seqlen_local, req.cache_protected_len - ) + req.swa_evicted_seqlen = max(req.swa_evicted_seqlen, req.cache_protected_len) - new_evicted_seqlen_local = max( - req.evicted_seqlen_local, pre_len - sliding_window_size + new_swa_evicted_seqlen = max( + req.swa_evicted_seqlen, pre_len - sliding_window_size ) if self.tree_cache.page_size > 1: - new_evicted_seqlen_local = ( - new_evicted_seqlen_local // self.tree_cache.page_size + new_swa_evicted_seqlen = ( + new_swa_evicted_seqlen // self.tree_cache.page_size ) * self.tree_cache.page_size - if new_evicted_seqlen_local > req.evicted_seqlen_local: + if new_swa_evicted_seqlen > req.swa_evicted_seqlen: free_slots = self.req_to_token_pool.req_to_token[ - req.req_pool_idx, req.evicted_seqlen_local : new_evicted_seqlen_local + req.req_pool_idx, req.swa_evicted_seqlen : new_swa_evicted_seqlen ] self.token_to_kv_pool_allocator.free_swa(free_slots) - req.evicted_seqlen_local = new_evicted_seqlen_local + 
req.swa_evicted_seqlen = new_swa_evicted_seqlen def _is_available_size_sufficient(self, num_tokens: int) -> bool: if self.is_hybrid_swa: diff --git a/python/sglang/srt/mem_cache/swa_radix_cache.py b/python/sglang/srt/mem_cache/swa_radix_cache.py index 9dc1a3b3689e..2f0f3eaaf818 100644 --- a/python/sglang/srt/mem_cache/swa_radix_cache.py +++ b/python/sglang/srt/mem_cache/swa_radix_cache.py @@ -492,7 +492,7 @@ def cache_finished_req(self, req: Req, is_insert: bool = True) -> None: RadixKey(token_ids[:page_aligned_token_len], req.extra_key), page_aligned_kv_indices, old_prefix_len, - req.evicted_seqlen_local, + req.swa_evicted_seqlen, ) else: self.token_to_kv_pool_allocator.free( From 943d6f6babb5b00e7690bbe2d297afbd653dcb0d Mon Sep 17 00:00:00 2001 From: ispobock Date: Sat, 17 Jan 2026 07:54:04 +0000 Subject: [PATCH 10/12] rename --- python/sglang/srt/mem_cache/swa_radix_cache.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/python/sglang/srt/mem_cache/swa_radix_cache.py b/python/sglang/srt/mem_cache/swa_radix_cache.py index 2f0f3eaaf818..bba9932ae40a 100644 --- a/python/sglang/srt/mem_cache/swa_radix_cache.py +++ b/python/sglang/srt/mem_cache/swa_radix_cache.py @@ -1009,33 +1009,33 @@ def _insert_helper( swa_evicted_seqlen > total_prefix_length and swa_evicted_seqlen < total_prefix_length + len(key) ): - swa_evicted_len = swa_evicted_seqlen - total_prefix_length + swa_tombstone_len = swa_evicted_seqlen - total_prefix_length node = self._add_new_node( node, - key[:swa_evicted_len], - value[:swa_evicted_len], + key[:swa_tombstone_len], + value[:swa_tombstone_len], swa_tombstone=True, ) - key = key[swa_evicted_len:] - value = value[swa_evicted_len:] + key = key[swa_tombstone_len:] + value = value[swa_tombstone_len:] self._add_new_node(node, key, value, swa_tombstone=False) return total_prefix_length def _add_new_node( self, - node: TreeNode, + parent: TreeNode, key: RadixKey, value: torch.Tensor, swa_tombstone: bool = False, ) -> 
TreeNode: assert len(key) > 0, f"key should not be empty" new_node = TreeNode() - new_node.parent = node + new_node.parent = parent new_node.key = key new_node.value = value new_node.swa_tombstone = swa_tombstone - node.children[self.get_child_key_fn(key)] = new_node + parent.children[self.get_child_key_fn(key)] = new_node self.full_lru_list.insert_mru(new_node) self.full_evictable_size_ += len(value) if not swa_tombstone: From b05fd782494a76a814486c1d63ad5da75c40a78c Mon Sep 17 00:00:00 2001 From: ispobock Date: Sun, 18 Jan 2026 17:24:20 +0000 Subject: [PATCH 11/12] fix and comment --- .../sglang/srt/mem_cache/swa_radix_cache.py | 35 +++++++++---------- 1 file changed, 16 insertions(+), 19 deletions(-) diff --git a/python/sglang/srt/mem_cache/swa_radix_cache.py b/python/sglang/srt/mem_cache/swa_radix_cache.py index 2d779aa595ae..a7d38ba5c9d3 100644 --- a/python/sglang/srt/mem_cache/swa_radix_cache.py +++ b/python/sglang/srt/mem_cache/swa_radix_cache.py @@ -955,7 +955,8 @@ def _insert_helper( # contains tombstone. If this is the case and we don't update the kv value, then # the prefill prefix matching will stuck. if update_kv_after_len < total_prefix_length + prefix_len: - first_diff_idx = max(0, update_kv_after_len - total_prefix_length) + # For page_size > 1 and chunked prefill case, update_kv_after_len may not be page-aligned due to a trailing partial page + # (kept in the request but not inserted into the radix tree) appended to prefix_indices. if node.swa_tombstone: assert ( node.swa_lock_ref == 0 @@ -964,38 +965,34 @@ def _insert_helper( swa_evicted_seqlen % self.page_size == 0 ), f"swa_evicted_seqlen must be page aligned, {swa_evicted_seqlen=}, {self.page_size=}" if swa_evicted_seqlen <= total_prefix_length: - self.token_to_kv_pool_allocator.free( - node.value[first_diff_idx:] - ) + # Branch 1: all swa tokens of value[:prefix_len] are not evicted, so we can insert it into the tree directly. + # Free full tokens in the original tree node. 
+ self.token_to_kv_pool_allocator.free(node.value[:prefix_len]) + # Overwrite the new value in request to the tree node. node.value = value[:prefix_len] node.swa_tombstone = False - # insert the node into the lru lists self.swa_lru_list.insert_mru(node) self.swa_evictable_size_ += len(node.value) elif swa_evicted_seqlen < total_prefix_length + prefix_len: - start_update_idx = max( - first_diff_idx, swa_evicted_seqlen - total_prefix_length - ) + # Branch 2: part of swa tokens of value[:prefix_len] are evicted, so we need to split the node and insert the value to new node. + start_update_idx = swa_evicted_seqlen - total_prefix_length self.token_to_kv_pool_allocator.free( - node.value[start_update_idx:] + node.value[start_update_idx:prefix_len] ) self._split_node(node.key, node, start_update_idx) + # Here node is the new node after split, so we can overwrite the value to the new node. + # The old node is still swa tombstone and the full token is not freed. node.value = value[start_update_idx:prefix_len] - self.token_to_kv_pool_allocator.free( - value[first_diff_idx:start_update_idx] - ) + self.token_to_kv_pool_allocator.free(value[:start_update_idx]) node.swa_tombstone = False - # insert the node into the lru lists self.swa_lru_list.insert_mru(node) self.swa_evictable_size_ += len(node.value) else: - self.token_to_kv_pool_allocator.free( - value[first_diff_idx:prefix_len] - ) + # Branch 3: all swa tokens of value[:prefix_len] are evicted, so we don't need to update the node. + self.token_to_kv_pool_allocator.free(value[:prefix_len]) else: - self.token_to_kv_pool_allocator.free( - value[first_diff_idx:prefix_len] - ) + # The node is not tombstone, so we don't need to update the node. 
+ self.token_to_kv_pool_allocator.free(value[:prefix_len]) total_prefix_length += prefix_len key = key[prefix_len:] From 61d78ff172979d1baef3addf0929fe357ba88494 Mon Sep 17 00:00:00 2001 From: ispobock Date: Mon, 19 Jan 2026 12:06:38 +0000 Subject: [PATCH 12/12] value clone --- python/sglang/srt/mem_cache/swa_radix_cache.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/mem_cache/swa_radix_cache.py b/python/sglang/srt/mem_cache/swa_radix_cache.py index a7d38ba5c9d3..33b7a482de4d 100644 --- a/python/sglang/srt/mem_cache/swa_radix_cache.py +++ b/python/sglang/srt/mem_cache/swa_radix_cache.py @@ -969,7 +969,7 @@ def _insert_helper( # Free full tokens in the original tree node. self.token_to_kv_pool_allocator.free(node.value[:prefix_len]) # Overwrite the new value in request to the tree node. - node.value = value[:prefix_len] + node.value = value[:prefix_len].clone() node.swa_tombstone = False self.swa_lru_list.insert_mru(node) self.swa_evictable_size_ += len(node.value) @@ -982,7 +982,7 @@ def _insert_helper( self._split_node(node.key, node, start_update_idx) # Here node is the new node after split, so we can overwrite the value to the new node. # The old node is still swa tombstone and the full token is not freed. - node.value = value[start_update_idx:prefix_len] + node.value = value[start_update_idx:prefix_len].clone() self.token_to_kv_pool_allocator.free(value[:start_update_idx]) node.swa_tombstone = False self.swa_lru_list.insert_mru(node) @@ -1030,7 +1030,7 @@ def _add_new_node( new_node = TreeNode() new_node.parent = parent new_node.key = key - new_node.value = value + new_node.value = value.clone() new_node.swa_tombstone = swa_tombstone parent.children[self.get_child_key_fn(key)] = new_node self.full_lru_list.insert_mru(new_node)