diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 068bd184917b..2eee73ecbd56 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -549,8 +549,12 @@ def __init__( # for corss-endoder model self.token_type_ids = token_type_ids - # The length of KV that have been removed in swa chunk cache - self.evicted_seqlen_local = 0 + # The length of KV that have been removed in swa cache. + # SWA KV cache eviction behavior differs by cache type: + # - Radix cache: KV in range [cache_protected_len, swa_evicted_seqlen) is freed manually in + # `ScheduleBatch.maybe_evict_swa`; KV in range [0, cache_protected_len) is freed during radix cache eviction. + # - Chunk cache: KV in range [0, swa_evicted_seqlen) is freed manually in `ScheduleBatch.maybe_evict_swa`. + self.swa_evicted_seqlen = 0 # The index of the extend / decode batch self.extend_batch_idx = 0 @@ -2264,6 +2268,59 @@ def copy(self): dp_cooperation_info=self.dp_cooperation_info, ) + def maybe_evict_swa(self): + if self.tree_cache.supports_swa(): + sliding_window_size = self.tree_cache.sliding_window_size + for idx, req in enumerate(self.reqs): + if self.forward_mode.is_decode(): + # We set evict_swa condition here with two reasons: + # 1. In overlap scheduler, we cannot evict swa when req.decode_batch_idx == 0 since the prev extend batch is still running. + # 2. Evict swa every window_size tokens to reduce the overhead. 
+ if req.decode_batch_idx % sliding_window_size == 1: + self._evict_swa(req, req.seqlen - 1) + elif self.forward_mode.is_extend() and self.tree_cache.is_chunk_cache(): + pre_len = self.prefix_lens[idx] + if self.enable_overlap: + # In chunked prefill case, when the second extend batch is scheduling, the first extend batch is still running, so we cannot evict swa tokens + if req.extend_batch_idx < 2: + continue + else: + server_args = get_global_server_args() + pre_len = ( + pre_len - server_args.chunked_prefill_size + if server_args.chunked_prefill_size > 0 + else pre_len + ) + self._evict_swa(req, pre_len) + else: + self._evict_swa(req, pre_len) + + def _evict_swa(self, req: Req, pre_len: int): + assert self.tree_cache.supports_swa(), "prefix cache must support swa" + sliding_window_size = self.tree_cache.sliding_window_size + + # For swa radix cache, we need to evict the tokens that are not in the tree cache and also not in the sliding window + assert ( + req.cache_protected_len % self.tree_cache.page_size == 0 + ), "cache_protected_len must be page aligned" + req.swa_evicted_seqlen = max(req.swa_evicted_seqlen, req.cache_protected_len) + + new_swa_evicted_seqlen = max( + req.swa_evicted_seqlen, pre_len - sliding_window_size + ) + + if self.tree_cache.page_size > 1: + new_swa_evicted_seqlen = ( + new_swa_evicted_seqlen // self.tree_cache.page_size + ) * self.tree_cache.page_size + + if new_swa_evicted_seqlen > req.swa_evicted_seqlen: + free_slots = self.req_to_token_pool.req_to_token[ + req.req_pool_idx, req.swa_evicted_seqlen : new_swa_evicted_seqlen + ] + self.token_to_kv_pool_allocator.free_swa(free_slots) + req.swa_evicted_seqlen = new_swa_evicted_seqlen + def _is_available_size_sufficient(self, num_tokens: int) -> bool: if self.is_hybrid_swa: return ( diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 5471bfeed505..67243d5edbb5 100644 --- a/python/sglang/srt/managers/scheduler.py +++ 
b/python/sglang/srt/managers/scheduler.py @@ -638,10 +638,9 @@ def init_cache_with_memory_pool(self): else: from sglang.srt.mem_cache.chunk_cache import SWAChunkCache - params.sliding_window_size = self.model_config.sliding_window_size - params.attention_chunk_size = self.model_config.attention_chunk_size - - self.tree_cache = SWAChunkCache(params) + self.tree_cache = SWAChunkCache( + params, sliding_window_size=self.sliding_window_size + ) else: if envs.SGLANG_EXPERIMENTAL_CPP_RADIX_TREE.get(): diff --git a/python/sglang/srt/mem_cache/cache_init_params.py b/python/sglang/srt/mem_cache/cache_init_params.py index 2ed8803b0bcc..7f6111f8086e 100644 --- a/python/sglang/srt/mem_cache/cache_init_params.py +++ b/python/sglang/srt/mem_cache/cache_init_params.py @@ -27,10 +27,6 @@ class CacheInitParams: enable_mamba_extra_buffer: bool = False - # For SWAChunkCache - sliding_window_size: Optional[int] = None - attention_chunk_size: Optional[int] = None - pp_rank: int = 0 pp_size: int = 1 diff --git a/python/sglang/srt/mem_cache/chunk_cache.py b/python/sglang/srt/mem_cache/chunk_cache.py index 38dbdd6fe192..86df54a7f276 100644 --- a/python/sglang/srt/mem_cache/chunk_cache.py +++ b/python/sglang/srt/mem_cache/chunk_cache.py @@ -84,62 +84,20 @@ def pretty_print(self): class SWAChunkCache(ChunkCache): - """ChunkCache with support for hybrid KV cache operations.""" + """ChunkCache with support for sliding window attention.""" - def __init__(self, params: CacheInitParams): + def __init__(self, params: CacheInitParams, sliding_window_size: int): assert isinstance(params.token_to_kv_pool_allocator, SWATokenToKVPoolAllocator) super().__init__(params) - assert ( - params.sliding_window_size is not None - or params.attention_chunk_size is not None - ), "Sliding window size or attention chunk size must be set for SWAChunkCache" - - if ( - params.sliding_window_size is not None - and params.attention_chunk_size is not None - ): - logger.warning( - "Sliding window size and attention chunk 
size are both set, use sliding window size for chunk cache eviction." - ) - - self.sliding_window_size = params.sliding_window_size - self.attention_chunk_size = params.attention_chunk_size - self.window_size = self.sliding_window_size or self.attention_chunk_size - + self.sliding_window_size = sliding_window_size self.chunked_prefill_size = params.chunked_prefill_size def supports_swa(self) -> bool: + assert ( + self.sliding_window_size is not None + ), "sliding_window_size must be set for SWAChunkCache" return True - def evict_swa( - self, - req: Req, - prelen: int, - ): - if self.sliding_window_size is not None: - # Sliding window attention (e.g. mimo-v2-flash, gpt-oss) - new_evicted_seqlen_local = max( - req.evicted_seqlen_local, prelen - self.sliding_window_size - ) - elif self.attention_chunk_size is not None: - # Local attention (e.g. llama4) - new_evicted_seqlen_local = max( - req.evicted_seqlen_local, - prelen // self.attention_chunk_size * self.attention_chunk_size, - ) - - if self.page_size > 1: - new_evicted_seqlen_local = ( - new_evicted_seqlen_local // self.page_size - ) * self.page_size - - if new_evicted_seqlen_local > req.evicted_seqlen_local: - free_slots = self.req_to_token_pool.req_to_token[ - req.req_pool_idx, req.evicted_seqlen_local : new_evicted_seqlen_local - ] - self.token_to_kv_pool_allocator.free_swa(free_slots) - req.evicted_seqlen_local = new_evicted_seqlen_local - def evict(self, num_tokens: int): pass diff --git a/python/sglang/srt/mem_cache/common.py b/python/sglang/srt/mem_cache/common.py index 3cac7192130d..374614be5de0 100644 --- a/python/sglang/srt/mem_cache/common.py +++ b/python/sglang/srt/mem_cache/common.py @@ -338,18 +338,7 @@ def alloc_for_extend( req_pool_indices: request pool indices as list """ # free out-of-window swa tokens - if batch.tree_cache.supports_swa() and batch.tree_cache.is_chunk_cache(): - for req, pre_len in zip(batch.reqs, batch.prefix_lens): - if batch.enable_overlap: - # In chunked prefill case, when 
the second extend batch is scheduling, the first extend batch is still running, so we cannot evict swa tokens - if req.extend_batch_idx < 2: - continue - else: - batch.tree_cache.evict_swa( - req, pre_len - batch.tree_cache.chunked_prefill_size - ) - else: - batch.tree_cache.evict_swa(req, pre_len) + batch.maybe_evict_swa() bs = len(batch.reqs) prefix_tensors = [r.prefix_indices for r in batch.reqs] @@ -440,13 +429,8 @@ def alloc_for_decode(batch: ScheduleBatch, token_per_req: int) -> torch.Tensor: Returns: out_cache_loc: allocated cache locations """ - if batch.tree_cache.supports_swa() and batch.tree_cache.is_chunk_cache(): - for req in batch.reqs: - # We set evict_swa condition here with two reasons: - # 1. In overlap scheduler, we cannot evict swa when req.decode_batch_idx == 0 since the prev extend batch is still running. - # 2. Evict swa every window_size tokens to reduce the overhead. - if req.decode_batch_idx % batch.tree_cache.window_size == 1: - batch.tree_cache.evict_swa(req, req.seqlen - 1) + + batch.maybe_evict_swa() bs = batch.seq_lens.shape[0] diff --git a/python/sglang/srt/mem_cache/swa_radix_cache.py b/python/sglang/srt/mem_cache/swa_radix_cache.py index 9f2a75fae8dd..33b7a482de4d 100644 --- a/python/sglang/srt/mem_cache/swa_radix_cache.py +++ b/python/sglang/srt/mem_cache/swa_radix_cache.py @@ -363,6 +363,9 @@ def __init__(self, params: CacheInitParams, sliding_window_size: int): ##### Public API ##### def supports_swa(self) -> bool: + assert ( + self.sliding_window_size is not None + ), "sliding_window_size must be set for SWARadixCache" return True def reset(self) -> None: @@ -418,7 +421,13 @@ def match_prefix(self, key: RadixKey, **kwargs) -> MatchResult: last_host_node=last_node, ) - def insert(self, key: RadixKey, value=None, prev_prefix_len: int = 0) -> int: + def insert( + self, + key: RadixKey, + value=None, + prev_prefix_len: int = 0, + swa_evicted_seqlen: int = 0, + ) -> int: if self.disable: return 0 @@ -431,7 +440,9 @@ def insert(self, 
key: RadixKey, value=None, prev_prefix_len: int = 0) -> int: # Make sure the value len equal to the EAGLE bigram key len value = value[: len(key)] - return self._insert_helper(self.root_node, key, value, prev_prefix_len) + return self._insert_helper( + self.root_node, key, value, prev_prefix_len, swa_evicted_seqlen + ) def cache_finished_req(self, req: Req, is_insert: bool = True) -> None: """Cache request when it finishes.""" @@ -481,6 +492,7 @@ def cache_finished_req(self, req: Req, is_insert: bool = True) -> None: RadixKey(token_ids[:page_aligned_token_len], req.extra_key), page_aligned_kv_indices, old_prefix_len, + req.swa_evicted_seqlen, ) else: self.token_to_kv_pool_allocator.free( @@ -878,6 +890,7 @@ def _split_node(self, key: RadixKey, child: TreeNode, split_len: int) -> TreeNod new_node.full_lock_ref = child.full_lock_ref new_node.swa_lock_ref = child.swa_lock_ref new_node.key = child.key[:split_len] + assert len(new_node.key) > 0, f"new_node.key should not be empty" new_node.value = child.value[:split_len].clone() # parent inherits the swa_uuid from child for swa lock ref new_node.swa_uuid = child.swa_uuid @@ -891,6 +904,7 @@ def _split_node(self, key: RadixKey, child: TreeNode, split_len: int) -> TreeNod self.swa_lru_list.remove_node(child) child.parent = new_node child.key = child.key[split_len:] + assert len(child.key) > 0, f"child.key should not be empty" child.value = child.value[split_len:].clone() new_node.parent.children[self.get_child_key_fn(key)] = new_node @@ -904,7 +918,12 @@ def _split_node(self, key: RadixKey, child: TreeNode, split_len: int) -> TreeNod return new_node def _insert_helper( - self, node: TreeNode, key: RadixKey, value, update_kv_after_len: int + self, + node: TreeNode, + key: RadixKey, + value, + update_kv_after_len: int, + swa_evicted_seqlen: int = 0, ) -> int: # Update the last access time from root to leaf, so that # swa will tombstone the node closer to root first @@ -936,23 +955,44 @@ def _insert_helper( # contains 
tombstone. If this is the case and we don't update the kv value, then # the prefill prefix matching will stuck. if update_kv_after_len < total_prefix_length + prefix_len: - first_diff_idx = max(0, update_kv_after_len - total_prefix_length) + # For page_size > 1 and chunked prefill case, update_kv_after_len may not be page-aligned due to a trailing partial page + # (kept in the request but not inserted into the radix tree) appended to prefix_indices. if node.swa_tombstone: assert ( node.swa_lock_ref == 0 ), f"tombstone swa_lock_ref should always be 0, {node.full_lock_ref=}, {node.swa_lock_ref=}, {node.id=}" - self.token_to_kv_pool_allocator.free(node.value[first_diff_idx:]) - node.value = value[:prefix_len] - node.swa_tombstone = False - - # insert the node into the lru lists - self.swa_lru_list.insert_mru(node) - - self.swa_evictable_size_ += len(node.value) + assert ( + swa_evicted_seqlen % self.page_size == 0 + ), f"swa_evicted_seqlen must be page aligned, {swa_evicted_seqlen=}, {self.page_size=}" + if swa_evicted_seqlen <= total_prefix_length: + # Branch 1: all swa tokens of value[:prefix_len] are not evicted, so we can insert it to the tree directly. + # Free full tokens in the original tree node. + self.token_to_kv_pool_allocator.free(node.value[:prefix_len]) + # Overwrite the new value in request to the tree node. + node.value = value[:prefix_len].clone() + node.swa_tombstone = False + self.swa_lru_list.insert_mru(node) + self.swa_evictable_size_ += len(node.value) + elif swa_evicted_seqlen < total_prefix_length + prefix_len: + # Branch 2: part of swa tokens of value[:prefix_len] are evicted, so we need to split the node and only update the non-evicted suffix part. + start_update_idx = swa_evicted_seqlen - total_prefix_length + self.token_to_kv_pool_allocator.free( + node.value[start_update_idx:prefix_len] + ) + self._split_node(node.key, node, start_update_idx) + # After the split, `node` is the suffix child covering [start_update_idx:prefix_len], so we overwrite its value with the request's value.
+ # The prefix part is now the parent node returned by _split_node; it is still swa tombstone and its full tokens are not freed. + node.value = value[start_update_idx:prefix_len].clone() + self.token_to_kv_pool_allocator.free(value[:start_update_idx]) + node.swa_tombstone = False + self.swa_lru_list.insert_mru(node) + self.swa_evictable_size_ += len(node.value) + else: + # Branch 3: all swa tokens of value[:prefix_len] are evicted, so we don't need to update the node. + self.token_to_kv_pool_allocator.free(value[:prefix_len]) else: - self.token_to_kv_pool_allocator.free( - value[first_diff_idx:prefix_len] - ) + # The node is not tombstone, so we don't need to update the node. + self.token_to_kv_pool_allocator.free(value[:prefix_len]) total_prefix_length += prefix_len key = key[prefix_len:] @@ -962,16 +1002,43 @@ def _insert_helper( child_key = self.get_child_key_fn(key) if len(key): - new_node = TreeNode() - new_node.parent = node - new_node.key = key - new_node.value = value - self.full_lru_list.insert_mru(new_node) + if ( + swa_evicted_seqlen > total_prefix_length + and swa_evicted_seqlen < total_prefix_length + len(key) + ): + swa_tombstone_len = swa_evicted_seqlen - total_prefix_length + node = self._add_new_node( + node, + key[:swa_tombstone_len], + value[:swa_tombstone_len], + swa_tombstone=True, + ) + key = key[swa_tombstone_len:] + value = value[swa_tombstone_len:] + + self._add_new_node(node, key, value, swa_tombstone=False) + return total_prefix_length + + def _add_new_node( + self, + parent: TreeNode, + key: RadixKey, + value: torch.Tensor, + swa_tombstone: bool = False, + ) -> TreeNode: + assert len(key) > 0, f"key should not be empty" + new_node = TreeNode() + new_node.parent = parent + new_node.key = key + new_node.value = value.clone() + new_node.swa_tombstone = swa_tombstone + parent.children[self.get_child_key_fn(key)] = new_node + self.full_lru_list.insert_mru(new_node) + self.full_evictable_size_ += len(value) + if not swa_tombstone: self.swa_lru_list.insert_mru(new_node) - node.children[child_key] = 
new_node - self.full_evictable_size_ += len(value) self.swa_evictable_size_ += len(value) - return total_prefix_length + return new_node def _iteratively_delete_tombstone_leaf( self, node: TreeNode diff --git a/python/sglang/srt/speculative/eagle_info_v2.py b/python/sglang/srt/speculative/eagle_info_v2.py index c943a5ec94b0..d9d5147e02bf 100644 --- a/python/sglang/srt/speculative/eagle_info_v2.py +++ b/python/sglang/srt/speculative/eagle_info_v2.py @@ -79,9 +79,7 @@ def assign_draft_cache_locs_page_size_1( @dataclass class EagleDraftInputV2Mixin: def prepare_for_decode(self: EagleDraftInput, batch: ScheduleBatch): - if batch.tree_cache.supports_swa() and batch.tree_cache.is_chunk_cache(): - for req in batch.reqs: - batch.tree_cache.evict_swa(req, req.seqlen - 1) + batch.maybe_evict_swa() from sglang.srt.speculative.spec_utils import assign_req_to_token_pool_func @@ -101,6 +99,7 @@ def prepare_for_decode(self: EagleDraftInput, batch: ScheduleBatch): nxt_kv_lens_cpu.append(r.kv_allocated_len + x) num_needed_tokens += x r.kv_allocated_len += x + r.decode_batch_idx += 1 cur_kv_lens_cpu = torch.tensor(cur_kv_lens_cpu, dtype=torch.int32, device="cpu") nxt_kv_lens_cpu = torch.tensor(nxt_kv_lens_cpu, dtype=torch.int32, device="cpu") diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index 9c29b3b1a7b7..5a6cc4b03526 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -374,9 +374,9 @@ def forward_target_extend( ) def _draft_preprocess_decode(self, batch: ScheduleBatch): - if batch.tree_cache.supports_swa() and batch.tree_cache.is_chunk_cache(): - for req in batch.reqs: - batch.tree_cache.evict_swa(req, req.seqlen - 1) + batch.maybe_evict_swa() + for req in batch.reqs: + req.decode_batch_idx += 1 # Parse args num_seqs = batch.batch_size()