diff --git a/python/sglang/srt/disaggregation/base/conn.py b/python/sglang/srt/disaggregation/base/conn.py index f7d4092d85dd..20913d325ca4 100644 --- a/python/sglang/srt/disaggregation/base/conn.py +++ b/python/sglang/srt/disaggregation/base/conn.py @@ -136,9 +136,16 @@ def send_metadata( kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None, state_indices: Optional[List[int]] = None, + decode_prefix_len: int = 0, ): """ Notify the prefill server about the kv indices, aux index, and state_indices. + + Args: + decode_prefix_len: Number of tokens already cached on the decode side. + When > 0, kv_indices contains only the incremental portion + (beyond the cached prefix), and the prefill side should skip + transferring the first decode_prefix_len tokens. """ ... diff --git a/python/sglang/srt/disaggregation/common/conn.py b/python/sglang/srt/disaggregation/common/conn.py index 26752d52dd54..7bf609dc5eed 100644 --- a/python/sglang/srt/disaggregation/common/conn.py +++ b/python/sglang/srt/disaggregation/common/conn.py @@ -325,6 +325,21 @@ def _resolve_rank_mapping(self, info: PrefillServerInfo) -> None: info.required_dst_info_num = required_dst_info_num info.required_prefill_response_num = required_prefill_response_num + def get_decode_prefix_len(self, bootstrap_room: int) -> int: + """Get the decode_prefix_len for a given bootstrap_room from transfer_infos. + + Returns the max decode_prefix_len across all session/agent entries for the room. + Returns 0 if not available (backward compatible). + """ + if not hasattr(self, "transfer_infos"): + return 0 + room_infos = self.transfer_infos.get(bootstrap_room, {}) + if not room_infos: + return 0 + return max( + getattr(info, "decode_prefix_len", 0) for info in room_infos.values() + ) + def register_to_bootstrap(self): """Register prefill server info to bootstrap server via HTTP POST.""" if self.dist_init_addr: diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py index ccf842ac0067..0d728f3499f1 100644 --- a/python/sglang/srt/disaggregation/decode.py +++ b/python/sglang/srt/disaggregation/decode.py @@ -29,6 +29,7 @@ import numpy as np import torch +import torch.distributed as dist from torch.distributed import ProcessGroup from sglang.srt.configs.mamba_utils import Mamba2CacheParams @@ -45,6 +46,7 @@ get_kv_class, is_mla_backend, kv_to_page_indices, + kv_to_page_num, poll_and_all_reduce, poll_and_all_reduce_with_staging, prepare_abort, @@ -54,8 +56,9 @@ from sglang.srt.managers.schedule_batch import FINISH_ABORT, ScheduleBatch from sglang.srt.managers.utils import GenerationBatchResult from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator -from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache -from sglang.srt.mem_cache.common import release_kv_cache +from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, EvictParams, MatchPrefixParams +from sglang.srt.mem_cache.radix_cache import RadixKey +from sglang.srt.mem_cache.common import evict_from_tree_cache, release_kv_cache from sglang.srt.mem_cache.memory_pool import ( HybridLinearKVPool, HybridReqToTokenPool, @@ -197,18 +200,29 @@ def __init__( self.mamba_ping_pong_track_buffer_size = 2 if enable_overlap_schedule else 1 self.enable_mamba_extra_buffer = enable_mamba_extra_buffer self.enable_memory_saver = enable_memory_saver + # Each request needs 1 main mamba slot + ping-pong slots when extra_buffer is enabled. + # Cap the pool at max concurrent requests * slots_per_req to avoid OOM. + # Tree cache mamba entries are evicted on demand in _pre_alloc when pool is low. + slots_per_req = 1 + ( + self.mamba_ping_pong_track_buffer_size + if enable_mamba_extra_buffer + else 0 + ) + max_slots_needed = (size + pre_alloc_size) * slots_per_req if mamba_size is not None: - effective_mamba_size = min(mamba_size, size + pre_alloc_size) - if mamba_size > size + pre_alloc_size: + effective_mamba_size = min(mamba_size, max_slots_needed) + if mamba_size > max_slots_needed: logger.warning( - "mamba_size (%d) exceeds size + pre_alloc_size (%d), " + "mamba_size (%d) exceeds max_slots_needed (%d = %d reqs * %d slots/req), " "capping effective_mamba_size to %d", mamba_size, + max_slots_needed, size + pre_alloc_size, + slots_per_req, effective_mamba_size, ) else: - effective_mamba_size = size + pre_alloc_size + effective_mamba_size = max_slots_needed self.start_layer = start_layer if start_layer is not None else 0 self.layer_transfer_counter = None self._init_mamba_pool( @@ -273,7 +287,7 @@ def __init__( self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator self.scheduler = scheduler self.transfer_queue = transfer_queue - self.tree_cache = tree_cache # this is always a chunk cache + self.tree_cache = tree_cache self.gloo_group = gloo_group self.tp_rank = tp_rank self.tp_size = tp_size @@ -520,10 +534,24 @@ def resume_retracted_reqs( def _update_handshake_waiters( self, rids_to_check: Optional[List[str]] = None ) -> None: - if not self.queue: - return + need_poll = ( + len(self.queue) > 0 + and not all( + decode_req.waiting_for_input for decode_req in self.queue + ) + ) + # All TPs must agree on whether to poll and on queue size, otherwise + # poll_and_all_reduce (which sizes its tensor by queue length) hangs. + if dist.get_world_size(self.gloo_group) > 1: + n = len(self.queue) + local = torch.tensor( + [int(need_poll), n, -n], dtype=torch.int64, device="cpu" + ) + dist.all_reduce(local, op=dist.ReduceOp.MIN, group=self.gloo_group) + if local[0].item() == 0 or local[1].item() != -local[2].item(): + return - if all(decode_req.waiting_for_input for decode_req in self.queue): + if not need_poll: return polls = poll_and_all_reduce( @@ -673,6 +701,16 @@ def pop_preallocated( failed_reqs.append(decode_req) indices_to_remove.add(i) + # Pre-compute mamba pool constants (fixed after model startup) + has_mamba = self.tree_cache.supports_mamba() + if has_mamba: + mamba_pool = self.req_to_token_pool.mamba_pool + mamba_slots_per_req = 1 + ( + self.req_to_token_pool.mamba_ping_pong_track_buffer_size + if self.req_to_token_pool.enable_mamba_extra_buffer + else 0 + ) + # Then, preallocate the remaining requests if possible for i, decode_req in enumerate(self.queue): if rids_to_check is not None and decode_req.req.rid not in rids_to_check: @@ -690,6 +728,17 @@ def pop_preallocated( if self.req_to_metadata_buffer_idx_allocator.available_size() <= 0: break + # Mamba pool capacity check: evict tree cache entries first to + # free slots, break only if eviction cannot reclaim enough. + if has_mamba: + deficit = mamba_slots_per_req - mamba_pool.available_size() + if deficit > 0: + result = self.tree_cache.evict( + EvictParams(num_tokens=0, mamba_num=deficit) + ) + if result.mamba_num_evicted < deficit: + break + # Memory estimation: don't add if the projected memory cannot be met # TODO: add new_token ratio origin_input_len = len(decode_req.req.origin_input_ids) @@ -714,41 +763,92 @@ def pop_preallocated( break allocatable_tokens -= required_tokens_for_request - dst_kv_indices = self._pre_alloc(decode_req.req) + req = decode_req.req - origin_input_len = len(decode_req.req.origin_input_ids) if self.scheduler.enable_hisparse: - # Must cast to int32 for ZMQ serialization — from_zmq reads np.int32. + # HiSparse: no prefix matching, use host indices directly + page_size = 1 # host pool page_size + decode_prefix_len = 0 + req.disagg_decode_prefix_len = 0 + dst_kv_indices = self._pre_alloc(req) kv_indices = ( - dst_kv_indices[:origin_input_len].cpu().numpy().astype(np.int32) + dst_kv_indices[:origin_input_len] + .cpu() + .numpy() + .astype(np.int32) ) - page_size = 1 # host pool page_size else: - kv_indices_full = self.req_to_token_pool.req_to_token[ - decode_req.req.req_pool_idx - ][:origin_input_len] - kv_indices = kv_indices_full.cpu().numpy() page_size = self.token_to_kv_pool_allocator.page_size + # Attempt prefix matching against decode-side tree cache + prefix_indices = None + decode_prefix_len = 0 + last_node = None + if self.scheduler.server_args.disaggregation_enable_decode_radix_cache: + token_ids_for_match = req.origin_input_ids[: origin_input_len - 1] + if len(token_ids_for_match) > 0: + match_result = self.tree_cache.match_prefix( + MatchPrefixParams( + key=RadixKey( + token_ids=token_ids_for_match, + extra_key=req.extra_key, + ), + req=req, + # For hybrid models (MambaRadixCache), skip mamba-based + # truncation — SSM state comes from prefill transfer, + # not from decode tree cache. + skip_mamba_truncation=self.tree_cache.supports_mamba(), + ) + ) + raw_prefix_len = len(match_result.device_indices) + # Page-align the prefix length + decode_prefix_len = (raw_prefix_len // page_size) * page_size + if decode_prefix_len > 0: + prefix_indices = match_result.device_indices[ + :decode_prefix_len + ] + last_node = match_result.last_device_node + # Lock the prefix nodes to prevent eviction during transfer + self.tree_cache.inc_lock_ref(last_node) + req.prefix_indices = prefix_indices + req.last_node = last_node + + req.disagg_decode_prefix_len = decode_prefix_len + self._pre_alloc( + req, + prefix_indices=prefix_indices, + prefix_len=decode_prefix_len, + ) + + # Send only incremental KV indices (beyond the cached prefix) + total_len = len(req.origin_input_ids) + kv_indices = ( + self.req_to_token_pool.req_to_token[req.req_pool_idx][ + decode_prefix_len:total_len + ] + .cpu() + .numpy() + ) + # Prepare extra pool indices for hybrid models if isinstance(self.token_to_kv_pool, HybridLinearKVPool): # Mamba hybrid model: single mamba state index state_indices = [ self.req_to_token_pool.req_index_to_mamba_index_mapping[ - decode_req.req.req_pool_idx + req.req_pool_idx ] .cpu() .numpy() ] elif isinstance(self.token_to_kv_pool, SWAKVPool): # SWA hybrid model: send decode-side SWA window indices - seq_len = len(decode_req.req.origin_input_ids) + seq_len = len(req.origin_input_ids) window_size = self.scheduler.sliding_window_size window_start = max(0, seq_len - window_size) window_start = (window_start // page_size) * page_size window_kv_indices_full = self.req_to_token_pool.req_to_token[ - decode_req.req.req_pool_idx, window_start:seq_len + req.req_pool_idx, window_start:seq_len ] # Translate to SWA pool indices @@ -760,9 +860,9 @@ def pop_preallocated( state_indices = window_kv_indices_swa.cpu().numpy() state_indices = kv_to_page_indices(state_indices, page_size) elif isinstance(self.token_to_kv_pool, NSATokenToKVPool): - seq_len = len(decode_req.req.origin_input_ids) + seq_len = len(req.origin_input_ids) kv_indices_full = self.req_to_token_pool.req_to_token[ - decode_req.req.req_pool_idx, :seq_len + req.req_pool_idx, :seq_len ] state_indices = kv_indices_full.cpu().numpy() # Indexer lives on device pool; always use device page_size @@ -777,7 +877,17 @@ def pop_preallocated( assert decode_req.metadata_buffer_index is not None page_indices = kv_to_page_indices(kv_indices, page_size) decode_req.kv_receiver.send_metadata( - page_indices, decode_req.metadata_buffer_index, state_indices + page_indices, + decode_req.metadata_buffer_index, + state_indices, + decode_prefix_len=decode_prefix_len, + ) + logger.info( + f"Decode prealloc for {req.rid}: " + f"decode_prefix_len={decode_prefix_len}, " + f"decode_prefix_pages={kv_to_page_num(decode_prefix_len, page_size)}, " + f"total_pages={kv_to_page_num(total_len, page_size)}, " + f"bootstrap_room={req.bootstrap_room}" ) if ( self.transfer_queue.enable_staging @@ -820,7 +930,11 @@ def _allocatable_tokens( else 0 ) available_size = self.token_to_kv_pool_allocator.available_size() - allocatable_tokens = available_size - max( + if self.tree_cache.supports_mamba(): + evictable_size = self.tree_cache.full_evictable_size() + else: + evictable_size = self.tree_cache.evictable_size() + allocatable_tokens = available_size + evictable_size - max( # preserve some space for future decode self.num_reserved_decode_tokens * ( @@ -853,8 +967,20 @@ def _allocatable_tokens( ) return allocatable_tokens - def _pre_alloc(self, req: Req) -> torch.Tensor: - """Pre-allocate the memory for req_to_token and token_kv_pool""" + def _pre_alloc( + self, + req: Req, + prefix_indices: Optional[torch.Tensor] = None, + prefix_len: int = 0, + ) -> torch.Tensor: + """Pre-allocate the memory for req_to_token and token_kv_pool. + + Args: + prefix_indices: KV cache indices from the decode-side tree cache prefix match. + When provided, these are reused for [0, prefix_len) and only + [prefix_len, fill_len) is freshly allocated. + prefix_len: Page-aligned length of the matched prefix (0 = full alloc). + """ req_pool_indices = self.req_to_token_pool.alloc([req]) assert ( @@ -888,24 +1014,88 @@ def _pre_alloc(self, req: Req) -> torch.Tensor: ) host_indices = host_indices.to(device=coordinator.device) coordinator.req_to_host_pool[req.req_pool_idx, :fill_len] = host_indices - elif self.token_to_kv_pool_allocator.page_size == 1: - kv_loc = self.token_to_kv_pool_allocator.alloc(fill_len) else: - device = self.token_to_kv_pool_allocator.device - kv_loc = self.token_to_kv_pool_allocator.alloc_extend( - prefix_lens=torch.tensor([0], dtype=torch.int64, device=device), - prefix_lens_cpu=torch.tensor([0], dtype=torch.int64), - seq_lens=torch.tensor([fill_len], dtype=torch.int64, device=device), - seq_lens_cpu=torch.tensor([fill_len], dtype=torch.int64), - last_loc=torch.tensor([-1], dtype=torch.int64, device=device), - extend_num_tokens=fill_len, - ) + # Non-hisparse path: evict tree cache if needed, then allocate. + alloc_need = fill_len - prefix_len + if alloc_need > 0 and self.scheduler.server_args.disaggregation_enable_decode_radix_cache: + evict_from_tree_cache(self.tree_cache, alloc_need) + # Merge freed pages back into the sorted free pool so that + # alloc() returns ordered indices, improving RDMA coalescing + # in group_concurrent_contiguous. + if self.token_to_kv_pool_allocator.need_sort: + self.token_to_kv_pool_allocator.merge_and_sort_free() + + if prefix_len > 0 and prefix_indices is not None and len(prefix_indices) > 0: + # Reuse prefix KV from decode-side tree cache + self.req_to_token_pool.write( + (req.req_pool_idx, slice(0, prefix_len)), prefix_indices + ) - assert ( - kv_loc is not None - ), "KV cache is full! There is a bug in memory estimation." + extend_num_tokens = fill_len - prefix_len + if extend_num_tokens > 0: + if self.token_to_kv_pool_allocator.page_size == 1: + kv_loc = self.token_to_kv_pool_allocator.alloc( + extend_num_tokens + ) + else: + device = self.token_to_kv_pool_allocator.device + last_loc = prefix_indices[-1].unsqueeze(0).to(device) + kv_loc = self.token_to_kv_pool_allocator.alloc_extend( + prefix_lens=torch.tensor( + [prefix_len], dtype=torch.int64, device=device + ), + prefix_lens_cpu=torch.tensor( + [prefix_len], dtype=torch.int64 + ), + seq_lens=torch.tensor( + [fill_len], dtype=torch.int64, device=device + ), + seq_lens_cpu=torch.tensor( + [fill_len], dtype=torch.int64 + ), + last_loc=last_loc, + extend_num_tokens=extend_num_tokens, + ) + assert ( + kv_loc is not None + ), "KV cache is full! There is a bug in memory estimation." + self.req_to_token_pool.write( + ( + req.req_pool_idx, + slice(prefix_len, prefix_len + len(kv_loc)), + ), + kv_loc, + ) + else: + kv_loc = torch.empty(0, dtype=torch.int64) + else: + # Full allocation (no prefix reuse) + if self.token_to_kv_pool_allocator.page_size == 1: + kv_loc = self.token_to_kv_pool_allocator.alloc(fill_len) + else: + device = self.token_to_kv_pool_allocator.device + kv_loc = self.token_to_kv_pool_allocator.alloc_extend( + prefix_lens=torch.tensor( + [0], dtype=torch.int64, device=device + ), + prefix_lens_cpu=torch.tensor([0], dtype=torch.int64), + seq_lens=torch.tensor( + [fill_len], dtype=torch.int64, device=device + ), + seq_lens_cpu=torch.tensor([fill_len], dtype=torch.int64), + last_loc=torch.tensor( + [-1], dtype=torch.int64, device=device + ), + extend_num_tokens=fill_len, + ) - self.req_to_token_pool.write((req.req_pool_idx, slice(0, len(kv_loc))), kv_loc) + assert ( + kv_loc is not None + ), "KV cache is full! There is a bug in memory estimation." + + self.req_to_token_pool.write( + (req.req_pool_idx, slice(0, len(kv_loc))), kv_loc + ) # populate metadata req.fill_ids = req.origin_input_ids + req.output_ids @@ -1039,6 +1229,33 @@ def _commit_transfer_to_req(self, decode_req: DecodeRequest) -> bool: decode_req.req.time_stats.set_wait_queue_entry_time() return True + def _abort_and_release(self, decode_req: DecodeRequest): + """Abort a failed/corrupted request and release its KV cache. + + Sets cache_protected_len to protect tree-owned prefix pages from being + freed, then calls release_kv_cache (which internally dec_lock_ref's the + prefix lock acquired in pop_preallocated). + """ + self.scheduler.stream_output( + [decode_req.req], decode_req.req.return_logprob + ) + if decode_req.req.disagg_decode_prefix_len > 0: + decode_req.req.cache_protected_len = ( + decode_req.req.disagg_decode_prefix_len + ) + # When no prefix was matched, last_node was never set (stays None). + # Point it to root_node so dec_lock_ref inside cache_finished_req + # is a no-op rather than crashing on None. + if decode_req.req.last_node is None and hasattr( + self.tree_cache, "root_node" + ): + decode_req.req.last_node = self.tree_cache.root_node + if self.scheduler.enable_hisparse: + self.scheduler.hisparse_coordinator.request_finished(decode_req.req) + release_kv_cache(decode_req.req, self.tree_cache, is_insert=False) + if self.scheduler.enable_metrics: + self.scheduler.metrics_collector.increment_transfer_failed_reqs() + def _poll_with_staging(self) -> list: return poll_and_all_reduce_with_staging( self.queue, self.staging_handler, self.gloo_group @@ -1056,7 +1273,18 @@ def _init_staging_handler(self, kv_manager): kv_manager._staging_handler = self.staging_handler def pop_transferred(self, rids_to_check: Optional[List[str]] = None) -> List[Req]: - if not self.queue: + # Guard: all TPs must agree on queue size before poll_and_all_reduce. + # _resolve_pending_reqs does independent HTTP calls per TP, so queue + # sizes can transiently diverge; a mismatched all_reduce corrupts gloo. + if dist.get_world_size(self.gloo_group) > 1: + n = len(self.queue) + local = torch.tensor([n, -n], dtype=torch.int64, device="cpu") + dist.all_reduce(local, op=dist.ReduceOp.MIN, group=self.gloo_group) + if local[0].item() != -local[1].item(): + return [] + if local[0].item() == 0: + return [] + elif not self.queue: return [] if self.enable_staging: @@ -1084,35 +1312,15 @@ def pop_transferred(self, rids_to_check: Optional[List[str]] = None) -> List[Req error_message, status_code=HTTPStatus.INTERNAL_SERVER_ERROR, ) - self.scheduler.stream_output( - [decode_req.req], decode_req.req.return_logprob - ) - if self.scheduler.enable_hisparse: - self.scheduler.hisparse_coordinator.request_finished(decode_req.req) - # release pre-allocated kv cache, but don't insert into the tree since it's failed - release_kv_cache(decode_req.req, self.tree_cache, is_insert=False) + self._abort_and_release(decode_req) indices_to_remove.add(i) - if self.scheduler.enable_metrics: - self.scheduler.metrics_collector.increment_transfer_failed_reqs() continue elif poll == KVPoll.Success: should_remove = self._commit_transfer_to_req(decode_req) if should_remove: indices_to_remove.add(i) - # Check if request was aborted due to corruption if isinstance(decode_req.req.finished_reason, FINISH_ABORT): - self.scheduler.stream_output( - [decode_req.req], decode_req.req.return_logprob - ) - if self.scheduler.enable_hisparse: - self.scheduler.hisparse_coordinator.request_finished( - decode_req.req - ) - release_kv_cache( - decode_req.req, self.tree_cache, is_insert=False - ) - if self.scheduler.enable_metrics: - self.scheduler.metrics_collector.increment_transfer_failed_reqs() + self._abort_and_release(decode_req) else: transferred_reqs.append(decode_req.req) elif poll in [ @@ -1275,7 +1483,32 @@ def get_new_prebuilt_batch(self: Scheduler) -> Optional[ScheduleBatch]: # we can only add at least `num_not_used_batch` new batch to the running queue if i < num_not_used_batch: can_run_list.append(req) - req.init_next_round_input(self.tree_cache) + if self.server_args.disaggregation_enable_decode_radix_cache: + if req.disagg_decode_prefix_len > 0: + # Prefix info (prefix_indices, last_node) was set and locked + # in pop_preallocated; skip match_prefix to preserve them. + req.init_next_round_input(tree_cache=None) + # Protect tree-owned prefix pages in req_to_token[0:prefix_len]. + req.cache_protected_len = req.disagg_decode_prefix_len + else: + req.init_next_round_input(self.tree_cache) + # Balance the dec_lock_ref in cache_unfinished/finished_req + # (mirrors PrefillAdder._req_inc_lock_ref). + self.tree_cache.inc_lock_ref(req.last_node) + # req_to_token holds _pre_alloc pages, not tree pages; + # allow cache_unfinished_req to free duplicates. + req.cache_protected_len = 0 + + # pop_transferred appended the first output token after + # _pre_alloc, making fill_ids 1 longer than kv_committed_len. + # Truncate to avoid reading uninitialized req_to_token slots. + if len(req.fill_ids) > req.kv_committed_len: + req.fill_ids = req.fill_ids[: req.kv_committed_len] + req.set_extend_input_len( + len(req.fill_ids) - len(req.prefix_indices) + ) + else: + req.init_next_round_input(self.tree_cache) else: waiting_queue.append(req) diff --git a/python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py b/python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py index 34229658598b..4501e80f2c2d 100644 --- a/python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py +++ b/python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py @@ -42,8 +42,10 @@ def prepare_for_prebuilt(self: ScheduleBatch): for i, req in enumerate(reqs): req_pool_indices.append(req.req_pool_idx) + # Read KV indices for the extend portion (after prefix) + pre_len_i = len(req.prefix_indices) chunk = self.req_to_token_pool.req_to_token[req.req_pool_idx][ - : req.extend_input_len + pre_len_i : pre_len_i + req.extend_input_len ] assert ( offset + req.extend_input_len <= total_size @@ -60,7 +62,9 @@ def prepare_for_prebuilt(self: ScheduleBatch): ), f"seq_len={seq_len}, pre_len={pre_len}, req.extend_input_len={req.extend_input_len}" if not req.retracted_stain: - req.cached_tokens += pre_len - req.already_computed + # In disagg decode, cached_tokens is already set by + # pop_transferred from the prefill side's metadata. + # Don't add decode-side prefix match to avoid double-counting. req.already_computed = seq_len req.is_retracted = False pre_lens.append(pre_len) diff --git a/python/sglang/srt/disaggregation/fake/conn.py b/python/sglang/srt/disaggregation/fake/conn.py index 60bf5465ba2e..5ce44e92af03 100644 --- a/python/sglang/srt/disaggregation/fake/conn.py +++ b/python/sglang/srt/disaggregation/fake/conn.py @@ -106,10 +106,11 @@ def send_metadata( kv_indices: list[int], aux_index: Optional[int] = None, state_indices: Optional[List[int]] = None, + decode_prefix_len: int = 0, ): self.has_sent_metadata = True logger.debug( - f"FakeKVReceiver send_metadata with kv_indices: {kv_indices}, aux_index: {aux_index}, state_indices: {state_indices}" + f"FakeKVReceiver send_metadata with kv_indices: {kv_indices}, aux_index: {aux_index}, state_indices: {state_indices}, decode_prefix_len: {decode_prefix_len}" ) def failure_exception(self): diff --git a/python/sglang/srt/disaggregation/mooncake/conn.py b/python/sglang/srt/disaggregation/mooncake/conn.py index 64d97f5c6966..78069504253f 100644 --- a/python/sglang/srt/disaggregation/mooncake/conn.py +++ b/python/sglang/srt/disaggregation/mooncake/conn.py @@ -81,6 +81,7 @@ class TransferInfo: dst_state_indices: List[int] required_dst_info_num: int is_dummy: bool + decode_prefix_len: int = 0 # Note: always put the optional staging field at the final (it will be set through 'STAGING_RSP' pkg when needed) staging: Optional[StagingTransferInfo] = None @@ -99,6 +100,10 @@ def from_zmq(cls, msg: List[bytes]): else: dst_state_indices = list(np.frombuffer(msg[6], dtype=np.int32)) is_dummy = False + # decode_prefix_len: backward compatible, default 0 if not present + decode_prefix_len = ( + int(msg[8].decode("ascii")) if len(msg) > 8 and msg[8] != b"" else 0 + ) return cls( room=int(msg[0].decode("ascii")), endpoint=msg[1].decode("ascii"), @@ -109,6 +114,7 @@ def from_zmq(cls, msg: List[bytes]): dst_state_indices=dst_state_indices, required_dst_info_num=int(msg[7].decode("ascii")), is_dummy=is_dummy, + decode_prefix_len=decode_prefix_len, ) @@ -1822,6 +1828,7 @@ def send_metadata( kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None, state_indices: Optional[List[int]] = None, + decode_prefix_len: int = 0, ): if self.bootstrap_infos is None: self.kv_mgr.record_failure( @@ -1862,6 +1869,7 @@ def send_metadata( else b"" ), str(self.required_dst_info_num).encode("ascii"), + str(decode_prefix_len).encode("ascii"), ] ) self.init_time = time.time() diff --git a/python/sglang/srt/disaggregation/mori/conn.py b/python/sglang/srt/disaggregation/mori/conn.py index 70154f9e981c..eb008e588285 100644 --- a/python/sglang/srt/disaggregation/mori/conn.py +++ b/python/sglang/srt/disaggregation/mori/conn.py @@ -68,6 +68,7 @@ class TransferInfo: dst_aux_index: int required_dst_info_num: int is_dummy: bool + decode_prefix_len: int = 0 @classmethod def from_zmq(cls, payload: List[bytes]) -> TransferInfo: @@ -90,6 +91,12 @@ def from_zmq(cls, payload: List[bytes]) -> TransferInfo: int(payload[7].decode("ascii")) if len(payload) > 7 else 1 ) is_dummy = dst_kv_indices.size == 0 and dst_aux_index < 0 + # decode_prefix_len: backward compatible, default 0 if not present + decode_prefix_len = ( + int(payload[8].decode("ascii")) + if len(payload) > 8 and payload[8] != b"" + else 0 + ) return cls( room=room, endpoint=endpoint, @@ -99,6 +106,7 @@ def from_zmq(cls, payload: List[bytes]) -> TransferInfo: dst_aux_index=dst_aux_index, required_dst_info_num=required_dst_info_num, is_dummy=is_dummy, + decode_prefix_len=decode_prefix_len, ) @@ -1033,6 +1041,7 @@ def send_metadata( kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None, state_indices: Optional[List[int]] = None, + decode_prefix_len: int = 0, ): if self.bootstrap_infos is None or self.bootstrap_room is None: return @@ -1058,6 +1067,7 @@ def send_metadata( aux_bytes if not is_dummy else b"", state_bytes, str(self.required_dst_info_num).encode("ascii"), + str(decode_prefix_len).encode("ascii"), ] ) self.init_time = time.time() diff --git a/python/sglang/srt/disaggregation/nixl/conn.py b/python/sglang/srt/disaggregation/nixl/conn.py index 005d5b05c286..4432c649b2e5 100644 --- a/python/sglang/srt/disaggregation/nixl/conn.py +++ b/python/sglang/srt/disaggregation/nixl/conn.py @@ -44,6 +44,7 @@ class TransferInfo: dst_aux_index: int required_dst_info_num: int dst_state_indices: List[int] + decode_prefix_len: int = 0 def is_dummy(self): return self.dst_kv_indices.size == 0 @@ -56,6 +57,11 @@ def from_zmq(cls, msg: List[bytes]): else: dst_state_indices = [] + # decode_prefix_len: backward compatible, default 0 if not present + decode_prefix_len = ( + int(msg[8].decode("ascii")) if len(msg) > 8 and msg[8] != b"" else 0 + ) + return cls( room=int(msg[0].decode("ascii")), endpoint=msg[1].decode("ascii"), @@ -65,6 +71,7 @@ def from_zmq(cls, msg: List[bytes]): dst_aux_index=int(msg[5].decode("ascii")), required_dst_info_num=int(msg[6].decode("ascii")), dst_state_indices=dst_state_indices, + decode_prefix_len=decode_prefix_len, ) @@ -1113,6 +1120,7 @@ def send_metadata( kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None, state_indices: Optional[List[int]] = None, + decode_prefix_len: int = 0, ): if self.bootstrap_infos is None: logger.error( @@ -1146,6 +1154,7 @@ def send_metadata( if not is_dummy and state_indices is not None else b"" ), + str(decode_prefix_len).encode("ascii"), ] ) diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py index 8e3da245b9e8..1fb291832685 100644 --- a/python/sglang/srt/disaggregation/prefill.py +++ b/python/sglang/srt/disaggregation/prefill.py @@ -335,8 +335,28 @@ def pop_bootstrapped( ) assert req.metadata_buffer_index is not None - num_pages = kv_to_page_num(num_kv_indices, self.token_to_kv_pool.page_size) - req.disagg_kv_sender.init(num_pages, req.metadata_buffer_index) + page_size = self.token_to_kv_pool.page_size + total_pages = kv_to_page_num(num_kv_indices, page_size) + + # Read decode_prefix_len from the bootstrap message to skip + # transferring KV that the decode side already has cached. + decode_prefix_len = self.kv_manager.get_decode_prefix_len( + req.bootstrap_room + ) + logger.info( + f"Prefill bootstrap for {req.rid}: " + f"decode_prefix_len={decode_prefix_len}, total_pages={total_pages}, " + f"bootstrap_room={req.bootstrap_room}" + ) + if decode_prefix_len > 0: + req.start_send_idx = decode_prefix_len + req.disagg_prefill_skip_tokens = decode_prefix_len + decode_prefix_pages = kv_to_page_num(decode_prefix_len, page_size) + incremental_pages = total_pages - decode_prefix_pages + else: + incremental_pages = total_pages + + req.disagg_kv_sender.init(incremental_pages, req.metadata_buffer_index) bootstrapped_reqs.append(req) indices_to_remove.add(i) @@ -658,18 +678,20 @@ def process_disagg_prefill_inflight_queue( req.time_stats.set_completion_time() page_size = self.token_to_kv_pool_allocator.page_size - kv_item_lens = ( - self.disagg_prefill_bootstrap_queue.kv_manager.kv_args.kv_item_lens - ) - bytes_per_page_all_layers = sum(kv_item_lens) + kv_args = self.disagg_prefill_bootstrap_queue.kv_manager.kv_args + bytes_per_page_all_layers = sum(kv_args.kv_item_lens) + state_bytes_per_req = sum(kv_args.state_item_lens) if kv_args.state_item_lens else 0 for req in done_reqs: if isinstance(req.finished_reason, FINISH_ABORT): continue + # Use actual transferred tokens (excluding decode-side cached prefix) + actual_transfer_tokens = len(req.origin_input_ids) - req.disagg_prefill_skip_tokens metrics = req.time_stats.compute_and_observe_kv_transfer_metrics( - num_tokens=len(req.origin_input_ids), + num_tokens=actual_transfer_tokens, page_size=page_size, bytes_per_page_all_layers=bytes_per_page_all_layers, + state_bytes_per_req=state_bytes_per_req, ) if metrics: # Update last-value for REST API @@ -767,7 +789,6 @@ def send_kv_chunk( .cpu() .numpy() ) - req.start_send_idx = end_idx state_indices = None if last_chunk: self.disagg_metadata_buffers.set_buf(req) @@ -819,4 +840,5 @@ def send_kv_chunk( f"Skip sending kv chunk for request {req.rid=} {req.bootstrap_room=} because page_indices is empty" ) return + req.start_send_idx = end_idx req.disagg_kv_sender.send(page_indices, state_indices) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 078d6ab91f18..51b61bd2fe8c 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -863,6 +863,15 @@ def __init__( # start_send_idx = len(req.fill_ids) self.start_send_idx: int = 0 + # For incremental KV transfer in PD disaggregation: + # Page-aligned prefix length matched on the decode side against its local tree cache. + # When > 0, only the incremental KV beyond this prefix is transferred from prefill. + self.disagg_decode_prefix_len: int = 0 + + # Number of tokens skipped on the prefill side due to decode-side prefix caching. + # Used for accurate transfer_total metrics reporting. + self.disagg_prefill_skip_tokens: int = 0 + # For overlap schedule, we delay the kv transfer until `process_batch_result_disagg_prefill` rather than `process_prefill_chunk` in non-overlap # This is because kv is not ready in `process_prefill_chunk`. # We use `tmp_end_idx` to store the end index of the kv cache to send. diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 5cfb32c68057..367508b5fcaf 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -3265,6 +3265,16 @@ def abort_request(self, recv_req: AbortReq): if self.disaggregation_mode == DisaggregationMode.DECODE: if self.enable_hisparse: self.hisparse_coordinator.request_finished(req) + # Protect tree-owned prefix pages from being freed. + if req.disagg_decode_prefix_len > 0: + req.cache_protected_len = req.disagg_decode_prefix_len + # Ensure last_node is valid for dec_lock_ref inside + # cache_finished_req. When no prefix was matched in + # pop_preallocated, last_node was never set (stays None). + if req.last_node is None and hasattr( + self.tree_cache, "root_node" + ): + req.last_node = self.tree_cache.root_node release_kv_cache(req, self.tree_cache) # For disaggregation prefill mode, free the metadata buffer index if self.disaggregation_mode == DisaggregationMode.PREFILL: diff --git a/python/sglang/srt/managers/scheduler_output_processor_mixin.py b/python/sglang/srt/managers/scheduler_output_processor_mixin.py index e135dfb92680..1cb6fc376f31 100644 --- a/python/sglang/srt/managers/scheduler_output_processor_mixin.py +++ b/python/sglang/srt/managers/scheduler_output_processor_mixin.py @@ -580,16 +580,21 @@ def _mamba_prefix_cache_update( ) ) req.mamba_last_track_seqlen = seq_len - elif ( - not batch.spec_algorithm.is_none() - and result.accept_length_per_req_cpu is not None - ): + elif not batch.spec_algorithm.is_none(): # for spec decode, update mamba_last_track_seqlen if this iteration crosses a track interval actual_seq_len = req.seqlen - 1 + # Use mamba_last_track_seqlen as the reference for the previous bucket. + # For the first check (when tracking hasn't started), use the input length. + # This avoids relying on accept_length_per_req_cpu which can be inaccurate: + # V1 has intermediate forward passes with accept_len=0, and V2 may not + # include the bonus token in accept_len, both causing missed boundary crossings. + if req.mamba_last_track_seqlen is not None: + prev_ref = req.mamba_last_track_seqlen + else: + prev_ref = len(req.origin_input_ids) - 1 if ( actual_seq_len // mamba_track_interval - != (actual_seq_len - result.accept_length_per_req_cpu[i]) - // mamba_track_interval + > prev_ref // mamba_track_interval ): req.mamba_next_track_idx = ( batch.req_to_token_pool.get_mamba_ping_pong_other_idx( @@ -599,6 +604,31 @@ def _mamba_prefix_cache_update( req.mamba_last_track_seqlen = ( actual_seq_len // mamba_track_interval * mamba_track_interval ) + elif ( + req.mamba_last_track_seqlen is None + and self.server_args.disaggregation_mode == "decode" + and self.server_args.disaggregation_enable_decode_radix_cache + ): + # First spec-decode call and no boundary crossed. + # For disagg decode, the transferred mamba state has never + # been checkpointed on this node. Force a swap so that the + # current ping_pong buffer (populated by this forward pass) + # is preserved as the initial checkpoint. + # The state is at actual_seq_len rather than exactly at the + # boundary, but on the decode side tree cache mamba_value is + # not used for computation (skip_mamba_truncation=True). + # This is gated to disagg decode only — on single-machine, + # mamba_value IS used for state loading and must be exact. + initial_track = ( + prev_ref // mamba_track_interval * mamba_track_interval + ) + if initial_track > 0: + req.mamba_next_track_idx = ( + batch.req_to_token_pool.get_mamba_ping_pong_other_idx( + req.mamba_next_track_idx + ) + ) + req.mamba_last_track_seqlen = initial_track def _process_input_token_logprobs( self: Scheduler, req: Req, input_token_logprobs: List diff --git a/python/sglang/srt/mem_cache/base_prefix_cache.py b/python/sglang/srt/mem_cache/base_prefix_cache.py index be219339cef6..169d42ad08b2 100644 --- a/python/sglang/srt/mem_cache/base_prefix_cache.py +++ b/python/sglang/srt/mem_cache/base_prefix_cache.py @@ -42,6 +42,12 @@ class MatchPrefixParams: cow_mamba: bool = False req: Optional[Req] = None + # For disagg decode: skip mamba-based prefix truncation. + # When True, match_prefix returns all matched KV indices regardless of + # mamba_value presence on intermediate nodes. SSM state is transferred + # from prefill, so decode tree cache mamba_value is not needed. + skip_mamba_truncation: bool = False + @dataclasses.dataclass class InsertParams: diff --git a/python/sglang/srt/mem_cache/mamba_radix_cache.py b/python/sglang/srt/mem_cache/mamba_radix_cache.py index d07702cf1efd..c55634f8df6b 100644 --- a/python/sglang/srt/mem_cache/mamba_radix_cache.py +++ b/python/sglang/srt/mem_cache/mamba_radix_cache.py @@ -495,8 +495,12 @@ def match_prefix(self, params: MatchPrefixParams) -> MatchResult: last_host_node=self.root_node, ) - value, last_node, best_value_len = self._match_prefix_helper(key) - return self._match_post_processor(params, value, last_node, best_value_len) + value, actual_last_node, best_last_node, best_value_len = ( + self._match_prefix_helper(key) + ) + return self._match_post_processor( + params, value, actual_last_node, best_last_node, best_value_len + ) def insert(self, params: InsertParams) -> InsertResult: if self.disable: @@ -955,12 +959,18 @@ def available_and_evictable_str(self) -> str: def _match_prefix_helper( self, key: RadixKey - ) -> Tuple[List[torch.Tensor], TreeNode, int]: + ) -> Tuple[List[torch.Tensor], TreeNode, TreeNode, int]: """ Mamba prefix matching helper. It factors in the sliding window size such that the matched node is guaranteed to either 1. connected to root without mamba tombstone, or 2. the number of matching tokens from the matched node to the last mamba tombstone node is greater than or equal to the sliding window size. + + Returns: + (value, actual_last_node, best_last_node, best_value_len) + - actual_last_node: the deepest node reached during tree traversal + - best_last_node: the deepest node that has mamba_value (for truncation) + - best_value_len: number of value entries up to best_last_node """ node = self.root_node child_key = self.get_child_key_fn(key) @@ -993,7 +1003,7 @@ def _match_prefix_helper( best_value_len = len(value) best_last_node = node - return value, best_last_node, best_value_len + return value, node, best_last_node, best_value_len def _match_pre_processor(self, params: MatchPrefixParams) -> Optional[RadixKey]: """Preprocess the key before matching.""" @@ -1008,12 +1018,22 @@ def _match_post_processor( self, params: MatchPrefixParams, value: List[torch.Tensor], - last_node: TreeNode, + actual_last_node: TreeNode, + best_last_node: TreeNode, best_value_len: int, ) -> MatchResult: """Post-process the matched result.""" cow_mamba = params.cow_mamba req = params.req + skip_mamba_truncation = params.skip_mamba_truncation + + # When skip_mamba_truncation is set (disagg decode path), use the actual + # last matched node and all matched values. SSM state comes from prefill + # transfer, so we don't need mamba_value from the tree cache. + if skip_mamba_truncation: + last_node = actual_last_node + else: + last_node = best_last_node # update time for matched nodes, and make nodes closer to root to be least recently used # this allows mamba to evict nodes closer to root first @@ -1046,26 +1066,31 @@ def _match_post_processor( mamba_branching_seqlen = None # Copy mamba state to req local space if cow is true - if cow_mamba and last_node.mamba_value is not None: + if cow_mamba and best_last_node.mamba_value is not None: # for reqs without mamba cache if req.mamba_pool_idx is None: dst_index = self.req_to_token_pool.mamba_pool.alloc(1) - # try to alloc again, protect last_node from eviction + # try to alloc again, protect best_last_node from eviction if dst_index is None: - self.inc_lock_ref(last_node) + self.inc_lock_ref(best_last_node) self.evict(EvictParams(num_tokens=0, mamba_num=1)) dst_index = self.req_to_token_pool.mamba_pool.alloc(1) - self.dec_lock_ref(last_node) + self.dec_lock_ref(best_last_node) assert dst_index is not None, "Can not alloc mamba cache" - src_index = last_node.mamba_value + src_index = best_last_node.mamba_value self.req_to_token_pool.mamba_pool.copy_from(src_index, dst_index) req.mamba_pool_idx = dst_index[0] else: - src_index = last_node.mamba_value + src_index = best_last_node.mamba_value dst_index = req.mamba_pool_idx.unsqueeze(0) self.req_to_token_pool.mamba_pool.copy_from(src_index, dst_index) - value = value[:best_value_len] + if skip_mamba_truncation: + # Use all matched values — KV indices are valid regardless of mamba_value + pass + else: + value = value[:best_value_len] + if value: value = torch.cat(value) else: diff --git a/python/sglang/srt/observability/req_time_stats.py b/python/sglang/srt/observability/req_time_stats.py index aa067152fd38..84b81647994b 100644 --- a/python/sglang/srt/observability/req_time_stats.py +++ b/python/sglang/srt/observability/req_time_stats.py @@ -769,6 +769,7 @@ def compute_and_observe_kv_transfer_metrics( num_tokens: int, page_size: int, bytes_per_page_all_layers: int, + state_bytes_per_req: int = 0, ) -> Optional[dict]: """Compute KV transfer metrics and observe them via the metrics collector. @@ -786,7 +787,7 @@ def compute_and_observe_kv_transfer_metrics( latency_ms = transfer_latency_s * 1000 num_pages = kv_to_page_num(num_tokens, page_size) - total_bytes = bytes_per_page_all_layers * num_pages + total_bytes = bytes_per_page_all_layers * num_pages + state_bytes_per_req total_mb = total_bytes / (1024 * 1024) self.transfer_total_mb = total_mb diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index ae4feb3cf62f..f646d88f61e2 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -705,6 +705,7 @@ class ServerArgs: disaggregation_bootstrap_port: int = 8998 disaggregation_ib_device: Optional[str] = None disaggregation_decode_enable_offload_kvcache: bool = False + disaggregation_enable_decode_radix_cache: bool = False num_reserved_decode_tokens: int = 512 # used for decode kv cache offload in PD # FIXME: hack to reduce ITL when decode bs is small disaggregation_decode_polling_interval: int = 1 @@ -3439,8 +3440,21 @@ def _check_format(has_params: bool, has_hf_config: bool) -> bool: def _handle_pd_disaggregation(self): if self.disaggregation_mode == "decode": - self.disable_radix_cache = True - logger.warning("KV cache is forced as chunk cache for decode server") + if self.disaggregation_enable_decode_radix_cache: + if not self.disable_radix_cache: + logger.info( + "[EXPERIMENTAL] Decode server using RadixCache for incremental " + "KV transfer. Prefix matching on the decode side enables transferring " + "only the incremental KV delta from prefill." + ) + else: + logger.warning( + "--disaggregation-enable-decode-radix-cache is set but " + "radix cache is disabled. Incremental transfer will not work." + ) + else: + self.disable_radix_cache = True + logger.info("Decode server using ChunkCache (radix cache disabled).") elif self.disaggregation_mode == "prefill": assert ( @@ -6020,6 +6034,11 @@ def add_cli_args(parser: argparse.ArgumentParser): action="store_true", help="Enable async KV cache offloading on decode server (PD mode).", ) + parser.add_argument( + "--disaggregation-enable-decode-radix-cache", + action="store_true", + help="Enable radix cache prefix matching on decode side for incremental KV transfer in PD disaggregation.", + ) parser.add_argument( "--num-reserved-decode-tokens", type=int,