diff --git a/python/sglang/srt/disaggregation/base/conn.py b/python/sglang/srt/disaggregation/base/conn.py index f7d4092d85dd..93fc4b1c5a7e 100644 --- a/python/sglang/srt/disaggregation/base/conn.py +++ b/python/sglang/srt/disaggregation/base/conn.py @@ -136,6 +136,7 @@ def send_metadata( kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None, state_indices: Optional[List[int]] = None, + decode_prefix_len: Optional[int] = None, ): """ Notify the prefill server about the kv indices, aux index, and state_indices. diff --git a/python/sglang/srt/disaggregation/common/conn.py b/python/sglang/srt/disaggregation/common/conn.py index 26752d52dd54..b271d2cfe1f5 100644 --- a/python/sglang/srt/disaggregation/common/conn.py +++ b/python/sglang/srt/disaggregation/common/conn.py @@ -141,6 +141,7 @@ def __init__( ) self.register_to_bootstrap() self.transfer_infos = {} + self.req_to_decode_prefix_len: Dict[int, int] = {} self.decode_kv_args_table = {} self.pp_group = get_pp_group() # If a timeout happens on the prefill side, it means prefill instances @@ -179,6 +180,12 @@ def check_status(self, bootstrap_room: int) -> KVPoll: return self.request_status[bootstrap_room] def update_status(self, bootstrap_room: int, status: KVPoll): + if ( + status == KVPoll.Failed + and self.disaggregation_mode == DisaggregationMode.PREFILL + and hasattr(self, "req_to_decode_prefix_len") + ): + self.req_to_decode_prefix_len.pop(bootstrap_room, None) if bootstrap_room not in self.request_status: self.request_status[bootstrap_room] = status else: diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py index 875900ad8579..9fc8654598bc 100644 --- a/python/sglang/srt/disaggregation/decode.py +++ b/python/sglang/srt/disaggregation/decode.py @@ -45,6 +45,7 @@ get_kv_class, is_mla_backend, kv_to_page_indices, + page_align_floor, poll_and_all_reduce, poll_and_all_reduce_with_staging, prepare_abort, @@ -54,7 +55,11 @@ from 
sglang.srt.managers.schedule_batch import FINISH_ABORT, ScheduleBatch from sglang.srt.managers.utils import GenerationBatchResult from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator -from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache +from sglang.srt.mem_cache.base_prefix_cache import ( + BasePrefixCache, + EvictParams, + MatchPrefixParams, +) from sglang.srt.mem_cache.common import release_kv_cache from sglang.srt.mem_cache.memory_pool import ( HybridLinearKVPool, @@ -63,11 +68,13 @@ NSATokenToKVPool, ReqToTokenPool, ) +from sglang.srt.mem_cache.radix_cache import RadixKey from sglang.srt.mem_cache.swa_memory_pool import SWAKVPool from sglang.srt.observability.req_time_stats import ( set_schedule_time_batch, set_time_batch, ) +from sglang.srt.utils import get_num_new_pages from sglang.srt.utils.network import NetworkAddress from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter @@ -273,7 +280,7 @@ def __init__( self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator self.scheduler = scheduler self.transfer_queue = transfer_queue - self.tree_cache = tree_cache # this is always a chunk cache + self.tree_cache = tree_cache self.gloo_group = gloo_group self.tp_rank = tp_rank self.tp_size = tp_size @@ -427,6 +434,30 @@ def add(self, req: Req, is_retracted: bool = False) -> None: self.pending_reqs.append(decode_req) + def _match_prefix_and_lock(self, req: Req) -> Tuple[torch.Tensor, int]: + """ + Match a request against the decode-side radix cache, lock the matched + node to prevent eviction, and return the matched prefix information. 
+ """ + result = self.tree_cache.match_prefix( + MatchPrefixParams( + key=RadixKey(req.origin_input_ids, extra_key=req.extra_key), + req=req, + cow_mamba=self.tree_cache.supports_mamba(), + ) + ) + prefix_indices = result.device_indices + last_device_node = result.last_device_node + # Always lock to match aggregated scheduling behavior + self.tree_cache.inc_lock_ref(last_device_node) + + # we do this to ensure that whenever dec_lock_ref is called + # on the Req object, we are not dereferencing a `None`. In the + # agg case, the scheduler does this already + req.last_node = last_device_node + + return prefix_indices, len(prefix_indices) + def _resolve_prefill_dp_rank(self, req: Req) -> Optional[int]: if req.disagg_prefill_dp_rank is not None: return req.disagg_prefill_dp_rank @@ -711,14 +742,36 @@ def pop_preallocated( # Memory estimation: don't add if the projected memory cannot be met # TODO: add new_token ratio origin_input_len = len(decode_req.req.origin_input_ids) + if self.scheduler.server_args.disaggregation_decode_enable_radix_cache: + # Match prefix against decode's radix cache. + prefix_indices, prefix_len = self._match_prefix_and_lock(decode_req.req) + # Align prefix_len down to page boundary so both prefill and + # decode agree on the page-aligned split point for KV transfer.
+ page_size = self.token_to_kv_pool_allocator.page_size + if page_size > 1 and prefix_len % page_size != 0: + prefix_len = page_align_floor(prefix_len, page_size) + prefix_indices = prefix_indices[:prefix_len] + + fill_len = origin_input_len + max( + len(decode_req.req.output_ids) - 1, 0 + ) + required_alloc_tokens = self._required_alloc_tokens( + fill_len=fill_len, prefix_len=prefix_len + ) + else: + prefix_indices = None + prefix_len = 0 + required_alloc_tokens = origin_input_len + required_tokens_for_request = ( - origin_input_len + self.num_reserved_decode_tokens + required_alloc_tokens + self.num_reserved_decode_tokens ) if ( max( required_tokens_for_request, origin_input_len + - prefix_len + min( decode_req.req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKEN, @@ -727,26 +780,43 @@ def pop_preallocated( ) > allocatable_tokens ): + if prefix_len > 0: + self.tree_cache.dec_lock_ref(decode_req.req.last_node) break if required_tokens_for_request > allocatable_tokens: + if prefix_len > 0: + self.tree_cache.dec_lock_ref(decode_req.req.last_node) break - allocatable_tokens -= required_tokens_for_request + dst_kv_indices = self._pre_alloc(decode_req.req, prefix_indices, prefix_len) hisparse_req_budget -= 1 - dst_kv_indices = self._pre_alloc(decode_req.req) + # Recompute from actual pool state for the next queue entry. + # This accounts for page rounding and newly locked evictable cache. + allocatable_tokens = self._allocatable_tokens( + retractable_tokens=retractable_tokens, + count_retracted=True, + extra_reserved_reqs=len(preallocated_reqs) + 1, + ) + decode_req.req.cache_protected_len = prefix_len - origin_input_len = len(decode_req.req.origin_input_ids) if self.scheduler.enable_hisparse: - # Must cast to int32 for ZMQ serialization — from_zmq reads np.int32. + # Must cast to int32 for ZMQ serialization -- from_zmq reads np.int32. 
kv_indices = ( - dst_kv_indices[:origin_input_len].cpu().numpy().astype(np.int32) + dst_kv_indices[: origin_input_len - prefix_len] + .cpu() + .numpy() + .astype(np.int32) ) page_size = 1 # host pool page_size else: - kv_indices_full = self.req_to_token_pool.req_to_token[ - decode_req.req.req_pool_idx - ][:origin_input_len] - kv_indices = kv_indices_full.cpu().numpy() + # Only send delta indices (beyond prefix) to prefill. + kv_indices = ( + self.req_to_token_pool.req_to_token[decode_req.req.req_pool_idx][ + prefix_len:origin_input_len + ] + .cpu() + .numpy() + ) page_size = self.token_to_kv_pool_allocator.page_size # Prepare extra pool indices for hybrid models @@ -765,7 +835,7 @@ def pop_preallocated( window_size = self.scheduler.sliding_window_size window_start = max(0, seq_len - window_size) - window_start = (window_start // page_size) * page_size + window_start = page_align_floor(window_start, page_size) window_kv_indices_full = self.req_to_token_pool.req_to_token[ decode_req.req.req_pool_idx, window_start:seq_len ] @@ -796,7 +866,10 @@ def pop_preallocated( assert decode_req.metadata_buffer_index is not None page_indices = kv_to_page_indices(kv_indices, page_size) decode_req.kv_receiver.send_metadata( - page_indices, decode_req.metadata_buffer_index, state_indices + page_indices, + decode_req.metadata_buffer_index, + state_indices, + decode_prefix_len=prefix_len, ) if ( self.transfer_queue.enable_staging @@ -823,7 +896,10 @@ def num_tokens_pre_allocated(self): ) def _allocatable_tokens( - self, retractable_tokens: Optional[int] = None, count_retracted: bool = True + self, + retractable_tokens: Optional[int] = None, + count_retracted: bool = True, + extra_reserved_reqs: int = 0, ) -> int: need_space_for_single_req = ( max( @@ -846,6 +922,10 @@ def _allocatable_tokens( ) else: available_size = self.token_to_kv_pool_allocator.available_size() + # Include evictable decode-radix cache entries in the budget -- they + # can be freed on demand before allocation. 
+ if self.scheduler.server_args.disaggregation_decode_enable_radix_cache: + available_size += self.tree_cache.evictable_size() allocatable_tokens = available_size - max( # preserve some space for future decode self.num_reserved_decode_tokens @@ -853,6 +933,7 @@ def _allocatable_tokens( len(self.scheduler.running_batch.reqs) + len(self.transfer_queue.queue) + len(self.scheduler.waiting_queue) + + extra_reserved_reqs ), # make sure each request can finish if reach max_tokens with all other requests retracted need_space_for_single_req, @@ -879,63 +960,150 @@ def _allocatable_tokens( ) return allocatable_tokens - def _pre_alloc(self, req: Req) -> torch.Tensor: + def _required_alloc_tokens(self, *, fill_len: int, prefix_len: int) -> int: + page_size = self.token_to_kv_pool_allocator.page_size + if page_size == 1: + return fill_len - prefix_len + + num_new_pages = get_num_new_pages( + seq_lens=torch.tensor([fill_len], dtype=torch.int64), + prefix_lens=torch.tensor([prefix_len], dtype=torch.int64), + page_size=page_size, + ) + return num_new_pages * page_size + + def _pre_alloc( + self, + req: Req, + prefix_indices: Optional[torch.Tensor] = None, + prefix_len: Optional[int] = None, + ) -> torch.Tensor: """Pre-allocate the memory for req_to_token and token_kv_pool""" + if prefix_len is None: + prefix_len = 0 + req_pool_indices = self.req_to_token_pool.alloc([req]) assert ( req_pool_indices is not None ), "req_pool_indices is full! There is a bug in memory estimation." 
- # Alloc all tokens for the prebuilt req (except for the reserved input token for decoding) fill_len = len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0) req.kv_allocated_len = fill_len req.kv_committed_len = fill_len + if prefix_len > 0: + self.req_to_token_pool.write( + (req.req_pool_idx, slice(0, prefix_len)), prefix_indices + ) + + # TODO(retraction): when retraction is implemented with radix cache + # awareness, a retracted request should re-match the tree here + # instead of re-allocating from scratch. See resume_retracted_reqs. + delta_len = fill_len - prefix_len + required_alloc_tokens = self._required_alloc_tokens( + fill_len=fill_len, prefix_len=prefix_len + ) + + # Evict cached entries if the pool doesn't have enough free pages. + if ( + self.scheduler.server_args.disaggregation_decode_enable_radix_cache + and self.token_to_kv_pool_allocator.available_size() < required_alloc_tokens + ): + num_to_evict = ( + required_alloc_tokens - self.token_to_kv_pool_allocator.available_size() + ) + result = self.tree_cache.evict(EvictParams(num_tokens=num_to_evict)) + if self.token_to_kv_pool_allocator.available_size() < required_alloc_tokens: + logger.warning( + f"Eviction insufficient: needed {required_alloc_tokens} tokens, " + f"available {self.token_to_kv_pool_allocator.available_size()} " + f"after evicting {result.num_tokens_evicted}/{num_to_evict} tokens. " + f"evictable_size={self.tree_cache.evictable_size()}, " + f"protected_size={self.tree_cache.protected_size()}, " + f"fill_len={fill_len}, prefix_len={prefix_len}, delta_len={delta_len}, " + f"page_size={self.token_to_kv_pool_allocator.page_size}, " + f"req={req.rid}" + ) + if self.scheduler.enable_hisparse: # Direct-to-host path: only allocate logical indices (no hisparse # device indices) and allocate host indices for RDMA destination. 
coordinator = self.scheduler.hisparse_coordinator device = self.token_to_kv_pool_allocator.device + last_loc = ( + prefix_indices[-1:].to(dtype=torch.int64, device=device) + if prefix_len > 0 + else torch.tensor([-1], dtype=torch.int64, device=device) + ) kv_loc = self.token_to_kv_pool_allocator.alloc_logical_only( - prefix_lens=torch.tensor([0], dtype=torch.int64, device=device), - prefix_lens_cpu=torch.tensor([0], dtype=torch.int64), + prefix_lens=torch.tensor( + [prefix_len], dtype=torch.int64, device=device + ), + prefix_lens_cpu=torch.tensor([prefix_len], dtype=torch.int64), seq_lens=torch.tensor([fill_len], dtype=torch.int64, device=device), seq_lens_cpu=torch.tensor([fill_len], dtype=torch.int64), - last_loc=torch.tensor([-1], dtype=torch.int64, device=device), - extend_num_tokens=fill_len, + last_loc=last_loc, + extend_num_tokens=delta_len, ) - # Allocate host indices for the RDMA transfer target - host_indices = coordinator.mem_pool_host.alloc(fill_len) + # Allocate host indices only for the transfer delta. 
+ host_indices = coordinator.mem_pool_host.alloc(delta_len) if host_indices is None: raise RuntimeError( - f"HiSparse host mem pool alloc failed for {fill_len} tokens " + f"HiSparse host mem pool alloc failed for {delta_len} tokens " f"in _pre_alloc (req {req.rid})" ) host_indices = host_indices.to(device=coordinator.device) - coordinator.req_to_host_pool[req.req_pool_idx, :fill_len] = host_indices + coordinator.req_to_host_pool[ + req.req_pool_idx, prefix_len : prefix_len + delta_len + ] = host_indices elif self.token_to_kv_pool_allocator.page_size == 1: - kv_loc = self.token_to_kv_pool_allocator.alloc(fill_len) + kv_loc = self.token_to_kv_pool_allocator.alloc(delta_len) else: device = self.token_to_kv_pool_allocator.device + last_loc = ( + prefix_indices[-1:].to(dtype=torch.int64, device=device) + if prefix_len > 0 + else torch.tensor([-1], dtype=torch.int64, device=device) + ) kv_loc = self.token_to_kv_pool_allocator.alloc_extend( - prefix_lens=torch.tensor([0], dtype=torch.int64, device=device), - prefix_lens_cpu=torch.tensor([0], dtype=torch.int64), + prefix_lens=torch.tensor( + [prefix_len], dtype=torch.int64, device=device + ), + prefix_lens_cpu=torch.tensor([prefix_len], dtype=torch.int64), seq_lens=torch.tensor([fill_len], dtype=torch.int64, device=device), seq_lens_cpu=torch.tensor([fill_len], dtype=torch.int64), - last_loc=torch.tensor([-1], dtype=torch.int64, device=device), - extend_num_tokens=fill_len, + last_loc=last_loc, + extend_num_tokens=delta_len, ) - assert ( - kv_loc is not None - ), "KV cache is full! There is a bug in memory estimation." + assert kv_loc is not None, ( + f"KV cache is full! Bug in memory estimation. 
" + f"available={self.token_to_kv_pool_allocator.available_size()}, " + f"evictable={self.tree_cache.evictable_size()}, " + f"protected={self.tree_cache.protected_size()}, " + f"required_alloc={required_alloc_tokens}, delta={delta_len}, " + f"fill={fill_len}, prefix={prefix_len}, " + f"page_size={self.token_to_kv_pool_allocator.page_size}, " + f"req={req.rid}" + ) - self.req_to_token_pool.write((req.req_pool_idx, slice(0, len(kv_loc))), kv_loc) + self.req_to_token_pool.write( + (req.req_pool_idx, slice(prefix_len, prefix_len + len(kv_loc))), kv_loc + ) - # populate metadata - req.fill_ids = req.origin_input_ids + req.output_ids - req.set_extend_input_len(len(req.fill_ids)) + # Truncate fill_ids to kv_committed_len so cache_unfinished_req only + # inserts committed KV into the radix tree. The last output token + # hasn't had KV committed yet (fill_ids is 1 ahead). + req.fill_ids = (req.origin_input_ids + req.output_ids)[: req.kv_committed_len] + # Set prefix_indices so downstream consumers (init_next_round_input, + # prepare_for_extend) see the correct prefix length. In the agg path + # this is done inside init_next_round_input, but decode-disagg needs + # allocation info before batch assembly so we set it here. + req.prefix_indices = ( + prefix_indices if prefix_len > 0 else torch.empty((0,), dtype=torch.int64) + ) + req.set_extend_input_len(len(req.fill_ids) - prefix_len) # Return the transfer destination indices: if self.scheduler.enable_hisparse: @@ -1303,7 +1471,25 @@ def get_new_prebuilt_batch(self: Scheduler) -> Optional[ScheduleBatch]: # we can only add at least `num_not_used_batch` new batch to the running queue if i < num_not_used_batch: can_run_list.append(req) - req.init_next_round_input(self.tree_cache) + # Decode-radix path: do NOT re-match prefix here. + # `pop_preallocated` already took a tree snapshot and used it + # to (1) pre-allocate KV, (2) choose delta pages for transfer, + # and (3) set cache_protected_len/last_node for correct frees. 
+ # Re-matching now can observe a newer tree (other reqs may have + # inserted the same prefix) and overwrite cache_protected_len, + # making `cache_unfinished_req` free the wrong range (leak). + # Non-radix decode keeps the original behavior. + tree_cache = ( + None + if self.server_args.disaggregation_decode_enable_radix_cache + else self.tree_cache + ) + req.init_next_round_input(tree_cache) + # Truncate fill_ids to kv_committed_len so cache_unfinished_req + # only sees committed KV (fill_ids includes one uncommitted token). + if req.kv_committed_len is not None: + req.fill_ids = req.fill_ids[: req.kv_committed_len] + req.set_extend_input_len(len(req.fill_ids) - len(req.prefix_indices)) else: waiting_queue.append(req) diff --git a/python/sglang/srt/disaggregation/fake/conn.py b/python/sglang/srt/disaggregation/fake/conn.py index 60bf5465ba2e..073faecfb472 100644 --- a/python/sglang/srt/disaggregation/fake/conn.py +++ b/python/sglang/srt/disaggregation/fake/conn.py @@ -27,6 +27,7 @@ def __init__( is_mla_backend: Optional[bool] = False, ): super().__init__(args, disaggregation_mode, server_args, is_mla_backend) + self.req_to_decode_prefix_len = {} def register_to_bootstrap(self): pass @@ -41,6 +42,7 @@ def __init__( dest_tp_ranks: List[int], pp_rank: int, ): + self.kv_mgr = mgr self.has_sent = False def poll(self) -> KVPoll: @@ -106,6 +108,7 @@ def send_metadata( kv_indices: list[int], aux_index: Optional[int] = None, state_indices: Optional[List[int]] = None, + decode_prefix_len: Optional[int] = None, ): self.has_sent_metadata = True logger.debug( diff --git a/python/sglang/srt/disaggregation/mooncake/conn.py b/python/sglang/srt/disaggregation/mooncake/conn.py index 64d97f5c6966..ccd1ad4890e1 100644 --- a/python/sglang/srt/disaggregation/mooncake/conn.py +++ b/python/sglang/srt/disaggregation/mooncake/conn.py @@ -1822,6 +1822,7 @@ def send_metadata( kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None, state_indices: Optional[List[int]] = None, + 
decode_prefix_len: Optional[int] = None, ): if self.bootstrap_infos is None: self.kv_mgr.record_failure( diff --git a/python/sglang/srt/disaggregation/mori/conn.py b/python/sglang/srt/disaggregation/mori/conn.py index 70154f9e981c..6226c19df823 100644 --- a/python/sglang/srt/disaggregation/mori/conn.py +++ b/python/sglang/srt/disaggregation/mori/conn.py @@ -1033,6 +1033,7 @@ def send_metadata( kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None, state_indices: Optional[List[int]] = None, + decode_prefix_len: Optional[int] = None, ): if self.bootstrap_infos is None or self.bootstrap_room is None: return diff --git a/python/sglang/srt/disaggregation/nixl/conn.py b/python/sglang/srt/disaggregation/nixl/conn.py index 005d5b05c286..ac63280ccc9b 100644 --- a/python/sglang/srt/disaggregation/nixl/conn.py +++ b/python/sglang/srt/disaggregation/nixl/conn.py @@ -44,8 +44,15 @@ class TransferInfo: dst_aux_index: int required_dst_info_num: int dst_state_indices: List[int] + decode_prefix_len: Optional[int] = None # for decode radix cache def is_dummy(self): + # A transfer is "dummy" only for CP non-authoritative ranks. + # When dst_kv_indices is empty due to a decode-side radix cache + # full hit (decode_prefix_len > 0), the transfer is NOT dummy -- + # aux/state data still needs to be sent. 
+ if self.dst_kv_indices.size == 0 and self.decode_prefix_len: + return False return self.dst_kv_indices.size == 0 @classmethod @@ -65,6 +72,9 @@ def from_zmq(cls, msg: List[bytes]): dst_aux_index=int(msg[5].decode("ascii")), required_dst_info_num=int(msg[6].decode("ascii")), dst_state_indices=dst_state_indices, + decode_prefix_len=( + int(msg[8].decode("ascii")) if len(msg) > 8 and msg[8] != b"" else None + ), # hacky just add it into the message that will be sent ) @@ -883,39 +893,44 @@ def add_transfer_request( assert len(chunked_dst_kv_indice) == len(kv_indices) assert req.agent_name in self.decode_kv_args_table - notif = ( - f"{req.room}_kv_{chunk_id}_{int(is_last)}_{self.kv_args.engine_rank}" - ) decode_tp_size = self.decode_kv_args_table[req.agent_name].decode_tp_size - if self.is_mla_backend or (decode_tp_size == self.attn_tp_size): - kv_xfer_handle = self.send_kvcache( - req.agent_name, - kv_indices, - self.decode_kv_args_table[req.agent_name].dst_kv_ptrs, - chunked_dst_kv_indice, - self.decode_kv_args_table[req.agent_name].gpu_id, - notif, - ) - else: - kv_xfer_handle = self.send_kvcache_slice( - req.agent_name, - kv_indices, - self.decode_kv_args_table[req.agent_name].dst_kv_ptrs, - chunked_dst_kv_indice, - self.decode_kv_args_table[req.agent_name].gpu_id, - notif, - prefill_tp_size=self.attn_tp_size, - decode_tp_size=decode_tp_size, - decode_tp_rank=self.decode_kv_args_table[ - req.agent_name - ].decode_tp_rank, - dst_kv_item_len=self.decode_kv_args_table[ - req.agent_name - ].dst_kv_item_len, + # Skip KV RDMA transfer when there are no pages to send + # (e.g., decode-side radix cache matched the entire prefix). + # Aux data is still sent below when is_last=True. 
+ if len(kv_indices) > 0: + notif = ( + f"{req.room}_kv_{chunk_id}_{int(is_last)}_{self.kv_args.pp_rank}" ) - handles.append(kv_xfer_handle) + if self.is_mla_backend or (decode_tp_size == self.attn_tp_size): + kv_xfer_handle = self.send_kvcache( + req.agent_name, + kv_indices, + self.decode_kv_args_table[req.agent_name].dst_kv_ptrs, + chunked_dst_kv_indice, + self.decode_kv_args_table[req.agent_name].gpu_id, + notif, + ) + else: + kv_xfer_handle = self.send_kvcache_slice( + req.agent_name, + kv_indices, + self.decode_kv_args_table[req.agent_name].dst_kv_ptrs, + chunked_dst_kv_indice, + self.decode_kv_args_table[req.agent_name].gpu_id, + notif, + prefill_tp_size=self.attn_tp_size, + decode_tp_size=decode_tp_size, + decode_tp_rank=self.decode_kv_args_table[ + req.agent_name + ].decode_tp_rank, + dst_kv_item_len=self.decode_kv_args_table[ + req.agent_name + ].dst_kv_item_len, + ) + + handles.append(kv_xfer_handle) # Only the last chunk we need to send the aux data. if is_last: if state_indices is not None: @@ -936,16 +951,24 @@ def add_transfer_request( handles.append(state_xfer_handle) assert aux_index is not None + # When no KV pages were sent (decode-side cache hit), + # encode pp_rank in aux notif so receiver can mark + # expected_kvs_per_pp[pp_rank] = 0. 
+ if len(kv_indices) == 0: + aux_notif = f"{req.room}_aux_nokv_{self.kv_args.pp_rank}" + else: + aux_notif = f"{req.room}_aux" aux_xfer_handle = self.send_aux( req.agent_name, aux_index, self.decode_kv_args_table[req.agent_name].dst_aux_ptrs, req.dst_aux_index, - f"{req.room}_aux", + aux_notif, ) handles.append(aux_xfer_handle) if is_last: del self.transfer_infos[bootstrap_room] + self.req_to_decode_prefix_len.pop(bootstrap_room, None) return handles def update_transfer_status(self): @@ -978,6 +1001,15 @@ def update_transfer_status(self): ) elif components[1] == "aux": self.transfer_statuses[room].received_aux = True + # Handle "nokv" marker: no KV pages were sent for + # this pp_rank (decode-side radix cache hit). + if len(components) > 3 and components[2] == "nokv": + pp_rank = int(components[3]) + self.transfer_statuses[room].expected_kvs_per_pp[pp_rank] = 0 + if self.transfer_statuses[room].num_pp_ranks_expected is None: + self.transfer_statuses[room].num_pp_ranks_expected = ( + self.required_prefill_response_num_table.get(room, 1) + ) elif components[1] == "state": pp_rank = int(components[2]) if len(components) > 2 else 0 self.transfer_statuses[room].received_state_per_pp.add(pp_rank) @@ -1019,6 +1051,14 @@ def bootstrap_thread(): ].required_dst_info_num logger.debug(f"got info {room=} {agent_name=} {required_dst_info_num=}") if len(self.transfer_infos[room]) == required_dst_info_num: + self.req_to_decode_prefix_len[room] = next( + ( + info.decode_prefix_len + for info in self.transfer_infos[room].values() + if info.decode_prefix_len is not None + ), + 0, + ) logger.debug(f"{room=} is bootstrapped") self.update_status(room, KVPoll.WaitingForInput) @@ -1113,6 +1153,7 @@ def send_metadata( kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None, state_indices: Optional[List[int]] = None, + decode_prefix_len: Optional[int] = None, ): if self.bootstrap_infos is None: logger.error( @@ -1146,6 +1187,7 @@ def send_metadata( if not is_dummy and 
state_indices is not None else b"" ), + str(decode_prefix_len or 0).encode("ascii"), ] ) diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py index d9016d0f94d3..6fa62aa024ca 100644 --- a/python/sglang/srt/disaggregation/prefill.py +++ b/python/sglang/srt/disaggregation/prefill.py @@ -324,7 +324,7 @@ def pop_bootstrapped( self.scheduler.tree_cache.release_aborted_request(req.rid) continue - # KV.WaitingForInput - init here + # KV.WaitingForInput - decode is ready to receive. initialize the kv sender req.time_stats.set_bootstrap_done_time() num_kv_indices = len(req.origin_input_ids) if self.req_to_metadata_buffer_idx_allocator.available_size() == 0: @@ -335,7 +335,19 @@ def pop_bootstrapped( ) assert req.metadata_buffer_index is not None - num_pages = kv_to_page_num(num_kv_indices, self.token_to_kv_pool.page_size) + # Calculate the number of pages to send: + # if decode has a cached prefix, send only the delta indices; + # otherwise, send the entire request. + decode_prefix_len = ( + req.disagg_kv_sender.kv_mgr.req_to_decode_prefix_len.pop( + req.bootstrap_room, 0 + ) + ) + req.start_send_idx = decode_prefix_len + num_kv_indices_to_send = num_kv_indices - decode_prefix_len + num_pages = kv_to_page_num( + num_kv_indices_to_send, self.token_to_kv_pool.page_size + ) req.disagg_kv_sender.init(num_pages, req.metadata_buffer_index) bootstrapped_reqs.append(req) @@ -768,12 +780,20 @@ def send_kv_chunk( # if not the last chunk and the last page is partial, delay the last partial page to the next send end_idx = end_idx - end_idx % page_size + if end_idx < start_idx: + logger.debug( + "send_kv_chunk skip: rid=%s start_send_idx=%s end_idx=%s", + req.rid, + start_idx, + end_idx, + ) + return + kv_indices = ( self.req_to_token_pool.req_to_token[req.req_pool_idx, start_idx:end_idx] .cpu() .numpy() ) - req.start_send_idx = end_idx state_indices = None if last_chunk: self.disagg_metadata_buffers.set_buf(req) @@ -820,9 +840,14 @@ def
send_kv_chunk( state_indices = kv_to_page_indices(state_indices, page_size) page_indices = kv_to_page_indices(kv_indices, page_size) + # Skip empty non-last chunks for all backends. For empty last chunks, + # only NIXL currently defines the aux/state-only completion path used + # by decode-side radix cache; keep a conservative early return for + # other backends until they implement the same semantics. if len(page_indices) == 0: - logger.info( - f"Skip sending kv chunk for request {req.rid=} {req.bootstrap_room=} because page_indices is empty" - ) - return + if not last_chunk: + return + if self.transfer_backend != TransferBackend.NIXL: + return req.disagg_kv_sender.send(page_indices, state_indices) + req.start_send_idx = end_idx diff --git a/python/sglang/srt/disaggregation/utils.py b/python/sglang/srt/disaggregation/utils.py index d7956a60487b..35bf6d5dcc22 100644 --- a/python/sglang/srt/disaggregation/utils.py +++ b/python/sglang/srt/disaggregation/utils.py @@ -451,6 +451,11 @@ def kv_to_page_num(num_kv_indices: int, page_size: int): return (num_kv_indices + page_size - 1) // page_size +def page_align_floor(length: int, page_size: int) -> int: + """Round length down to the nearest page boundary.""" + return (length // page_size) * page_size + + def page_indices_to_cp_rank_page_indices( page_indices: np.ndarray, total_pages: int, diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index aba143c83c16..d9d3c6b81faa 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -384,6 +384,11 @@ async def lifespan(fast_api_app: FastAPI): server_args.warmups.split(","), _global_state.tokenizer_manager, ) + if ( + server_args.disaggregation_mode != "null" + and not server_args.disable_radix_cache + ): + await _global_state.tokenizer_manager.flush_cache() logger.info("Warmup ended") # Execute the general warmup @@ -1973,8 +1978,28 @@ def 
_execute_server_warmup(server_args: ServerArgs): ) if res.status_code == 200: logger.info( - f"End of prefill disaggregation mode warmup with status {res.status_code}, resp: {res.json()}" + f"Disaggregation warmup request completed with status {res.status_code}, resp: {res.json()}" ) + if ( + server_args.disaggregation_mode != "null" + and not server_args.disable_radix_cache + ): + try: + flush_res = requests.post( + url + "/flush_cache", + headers=headers, + timeout=30, + verify=ssl_verify, + ) + if flush_res.status_code == 200: + logger.info("Flushed warmup cache") + else: + logger.warning( + f"Warmup cache flush failed: {flush_res.status_code}" + ) + except Exception as e: + logger.warning(f"Warmup cache flush request failed: {e}") + logger.info("End of disaggregation warmup") _global_state.tokenizer_manager.server_status = ServerStatus.Up else: logger.info( diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index a6172890fcc8..72a21259afe0 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -1201,6 +1201,7 @@ def reset_for_retract(self): self.prefix_indices = torch.empty((0,), dtype=torch.int64) self.routed_experts = None self.last_node = None + self.cache_protected_len = 0 self.swa_uuid_for_lock = None self.extend_input_len = 0 self.is_retracted = True diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 5297f070b09f..103a9d48c448 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -769,6 +769,24 @@ def init_cache_with_memory_pool(self): "Transformers backend to avoid multimodal prefix-cache mismatches." ) + # Decode radix cache is unsupported with hybrid SWA/SSM models — + # these use specialized memory pools incompatible with the + # prefix-match-and-lock allocation path. 
+ if ( + server_args.disaggregation_decode_enable_radix_cache + and server_args.disaggregation_mode == "decode" + ): + if self.is_hybrid_swa: + raise ValueError( + "--disaggregation-decode-enable-radix-cache is incompatible " + "with sliding window attention (SWA) models" + ) + if self.is_hybrid_ssm: + raise ValueError( + "--disaggregation-decode-enable-radix-cache is incompatible " + "with Mamba/SSM models" + ) + effective_chunked_prefill_size = server_args.chunked_prefill_size if self.model_config.is_multimodal and uses_transformers_backend: effective_chunked_prefill_size = None diff --git a/python/sglang/srt/managers/scheduler_output_processor_mixin.py b/python/sglang/srt/managers/scheduler_output_processor_mixin.py index 3bbd3e5f8ee4..0e09eac6d1be 100644 --- a/python/sglang/srt/managers/scheduler_output_processor_mixin.py +++ b/python/sglang/srt/managers/scheduler_output_processor_mixin.py @@ -84,15 +84,22 @@ def _get_cached_tokens_details(self: Scheduler, req: Req) -> Optional[dict]: def process_batch_result_prebuilt(self: Scheduler, batch: ScheduleBatch): assert self.disaggregation_mode == DisaggregationMode.DECODE - for req in batch.reqs: - req.time_stats.set_decode_prebuilt_finish_time() - req.check_finished() - if req.finished(): - req.time_stats.set_quick_finish_time() - release_kv_cache(req, self.tree_cache) - - # Note: Logprobs should be handled on the prefill engine. - self.stream_output(batch.reqs, batch.return_logprob) + use_free_group = self.server_args.disaggregation_decode_enable_radix_cache + if use_free_group: + self.token_to_kv_pool_allocator.free_group_begin() + try: + for req in batch.reqs: + req.time_stats.set_decode_prebuilt_finish_time() + req.check_finished() + if req.finished(): + req.time_stats.set_quick_finish_time() + release_kv_cache(req, self.tree_cache) + + # Note: Logprobs should be handled on the prefill engine. 
+ self.stream_output(batch.reqs, batch.return_logprob) + finally: + if use_free_group: + self.token_to_kv_pool_allocator.free_group_end() def maybe_collect_routed_experts(self: Scheduler, req: Req): """Collect routed experts for a finished request.""" diff --git a/python/sglang/srt/mem_cache/chunk_cache.py b/python/sglang/srt/mem_cache/chunk_cache.py index 5d58bcde530c..6641bcfa0a56 100644 --- a/python/sglang/srt/mem_cache/chunk_cache.py +++ b/python/sglang/srt/mem_cache/chunk_cache.py @@ -30,6 +30,13 @@ class ChunkCache(BasePrefixCache): + """ + ChunkCache is used when radix cache is disabled. + + That includes standard chunked-prefill setups and the decode side of P/D + disaggregation when decode radix cache is not enabled. + """ + def __init__(self, params: CacheInitParams): self.req_to_token_pool = params.req_to_token_pool self.token_to_kv_pool_allocator = params.token_to_kv_pool_allocator diff --git a/python/sglang/srt/mem_cache/radix_cache.py b/python/sglang/srt/mem_cache/radix_cache.py index 3ad6421a1323..fec5d41037f5 100644 --- a/python/sglang/srt/mem_cache/radix_cache.py +++ b/python/sglang/srt/mem_cache/radix_cache.py @@ -395,8 +395,8 @@ def match_prefix(self, params: MatchPrefixParams) -> MatchResult: Returns: MatchResult: ``device_indices`` is a 1-D ``torch.int64`` tensor of the concatenated KV cache indices corresponding to the longest - cached prefix (may be length 0). ``last_device_node`` and - ``last_host_node`` (currently the same) are the tree node objects + cached prefix (may be length 0). + ``last_device_node`` and ``last_host_node`` (currently the same) are the tree node objects representing the terminal node of the matched prefix. This method may mutate internal structure by splitting an existing node if the match ends inside a stored segment. 
@@ -491,10 +491,9 @@ def cache_finished_req(self, req: Req, is_insert: bool = True): result = self.insert( InsertParams(key=radix_key, value=values, priority=priority) ) - new_prefix_len = result.prefix_len # Free the duplicates that were already in the tree self.token_to_kv_pool_allocator.free( - kv_indices[req.cache_protected_len : new_prefix_len] + kv_indices[req.cache_protected_len : result.prefix_len] ) else: self.token_to_kv_pool_allocator.free( @@ -505,7 +504,8 @@ def cache_finished_req(self, req: Req, is_insert: bool = True): self.token_to_kv_pool_allocator.free(kv_indices[len(keys) :]) # Remove req slot release the cache lock - self.dec_lock_ref(req.last_node) + if req.last_node is not None: + self.dec_lock_ref(req.last_node) def cache_unfinished_req(self, req: Req, chunked=False): """Cache request when it is unfinished.""" diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 0a45c1dc6d38..0ec6df5a06a3 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -710,6 +710,7 @@ class ServerArgs: disaggregation_transfer_backend: str = "mooncake" disaggregation_bootstrap_port: int = 8998 disaggregation_ib_device: Optional[str] = None + disaggregation_decode_enable_radix_cache: bool = False disaggregation_decode_enable_offload_kvcache: bool = False num_reserved_decode_tokens: int = 512 # used for decode kv cache offload in PD # FIXME: hack to reduce ITL when decode bs is small @@ -770,6 +771,9 @@ def __post_init__(self): # Validate SSL arguments early (before dummy-model short-circuit). self._handle_ssl_validation() + # Validate PD disaggregation flags early (before dummy-model short-circuit). + self._handle_pd_disaggregation() + if self.model_path.lower() in ["none", "dummy"]: # Skip for dummy models return @@ -848,9 +852,6 @@ def __post_init__(self): # Handle model loading format. self._handle_load_format() - # Handle PD disaggregation. 
- self._handle_pd_disaggregation() - # Handle Encoder disaggregation. self._handle_encoder_disaggregation() @@ -3520,8 +3521,33 @@ def _check_format(has_params: bool, has_hf_config: bool) -> bool: def _handle_pd_disaggregation(self): if self.disaggregation_mode == "decode": - self.disable_radix_cache = True - logger.warning("KV cache is forced as chunk cache for decode server") + if self.disaggregation_decode_enable_radix_cache: + if self.enable_hisparse: + raise ValueError( + "--disaggregation-decode-enable-radix-cache is incompatible " + "with --enable-hisparse" + ) + if self.disaggregation_transfer_backend != "nixl": + raise ValueError( + "--disaggregation-decode-enable-radix-cache currently " + "requires --disaggregation-transfer-backend nixl" + ) + if self.speculative_algorithm is not None: + raise ValueError( + "--disaggregation-decode-enable-radix-cache is incompatible " + "with speculative decoding " + f"(--speculative-algorithm {self.speculative_algorithm})" + ) + if self.enable_dp_attention: + logger.warning( + "EXPERIMENTAL: Decode radix cache with DP attention. " + "Requires prefix-aware DP rank routing for optimal cache hits." + ) + self.disable_radix_cache = False + logger.warning("EXPERIMENTAL: Radix cache is enabled for decode server") + else: + self.disable_radix_cache = True + logger.warning("KV cache is forced as chunk cache for decode server") elif self.disaggregation_mode == "prefill": assert ( @@ -6149,6 +6175,11 @@ def add_cli_args(parser: argparse.ArgumentParser): "or multiple comma-separated devices (e.g., --disaggregation-ib-device mlx5_0,mlx5_1). " "Default is None, which triggers automatic device detection when mooncake backend is enabled.", ) + parser.add_argument( + "--disaggregation-decode-enable-radix-cache", + action="store_true", + help="Enable radix cache on decode server (PD mode). Caches KV prefixes to avoid redundant transfers. 
Requires --disaggregation-transfer-backend nixl and is incompatible with --enable-hisparse.", + ) parser.add_argument( "--disaggregation-decode-enable-offload-kvcache", action="store_true", diff --git a/test/registered/unit/mem_cache/test_decode_radix_lock_ref.py b/test/registered/unit/mem_cache/test_decode_radix_lock_ref.py new file mode 100644 index 000000000000..ca400ed77557 --- /dev/null +++ b/test/registered/unit/mem_cache/test_decode_radix_lock_ref.py @@ -0,0 +1,321 @@ +""" +Unit tests for lock_ref correctness in decode disagg radix cache scenarios. + +Verifies that inc_lock_ref / dec_lock_ref are balanced across the four +transfer scenarios identified in PR #19746: + +1. Incremental transfer & success (prefix match > 0) + inc_lock_ref(pop_preallocated) -> dec+inc(cache_unfinished_req) -> dec(cache_finished_req) + +2. Full transfer & success (prefix match == 0, full KV transferred) + inc_lock_ref(get_new_prebuilt_batch) -> dec+inc(cache_unfinished_req) -> dec(cache_finished_req) + +3. Incremental transfer & failure (prefix match > 0, transfer fails) + inc_lock_ref(pop_preallocated) -> dec(cache_finished_req via release_kv_cache is_insert=False) + +4. 
Full transfer & failure (prefix match == 0, transfer fails) + no inc_lock_ref -> dec(root_node) is no-op since root lock_ref starts at 1 + +Usage: + python -m pytest test/registered/unit/mem_cache/test_decode_radix_lock_ref.py -v +""" + +from sglang.test.ci.ci_register import register_cuda_ci + +register_cuda_ci(est_time=10, suite="stage-b-test-1-gpu-small") + +import unittest +from unittest.mock import MagicMock + +import torch + +from sglang.srt.mem_cache.base_prefix_cache import ( + InsertParams, + MatchPrefixParams, +) +from sglang.srt.mem_cache.radix_cache import RadixCache, RadixKey + + +def _make_cache_with_pools(page_size=1): + """Create a RadixCache with mock pools sufficient for cache_unfinished/finished_req.""" + mock_allocator = MagicMock() + mock_allocator.device = torch.device("cpu") + + # req_to_token pool: stores kv indices per request slot + max_seq_len = 64 + max_batch = 4 + req_to_token = torch.zeros(max_batch, max_seq_len, dtype=torch.int64) + + mock_pool = MagicMock() + mock_pool.req_to_token = req_to_token + mock_pool.write = lambda idx_tuple, values: req_to_token.__setitem__( + idx_tuple, values + ) + + cache = RadixCache.create_simulated( + mock_allocator=mock_allocator, page_size=page_size + ) + cache.req_to_token_pool = mock_pool + return cache, req_to_token + + +class MockReq: + """Minimal mock Req with fields needed by cache_unfinished/finished_req.""" + + def __init__(self, fill_ids, req_pool_idx=0, cache_protected_len=0, last_node=None): + self.fill_ids = list(fill_ids) + self.origin_input_ids = list(fill_ids[:-1]) if len(fill_ids) > 1 else list(fill_ids) + self.output_ids = [fill_ids[-1]] if len(fill_ids) > 1 else [] + self.req_pool_idx = req_pool_idx + self.cache_protected_len = cache_protected_len + self.last_node = last_node + self.extra_key = None + self.prefix_indices = torch.empty(0, dtype=torch.int64) + self.priority = 0 + self.kv_committed_len = len(fill_ids) + self.kv_allocated_len = len(fill_ids) + self.kv_committed_freed = 
False + + def pop_committed_kv_cache(self): + self.kv_committed_freed = True + return self.kv_committed_len + + def pop_overallocated_kv_cache(self): + return (self.kv_committed_len, self.kv_allocated_len) + + +def _make_req(fill_ids, req_pool_idx=0, cache_protected_len=0, last_node=None): + return MockReq(fill_ids, req_pool_idx, cache_protected_len, last_node) + + +class TestDecodeLockRefScenarios(unittest.TestCase): + """Test lock_ref balance across decode transfer scenarios.""" + + def _populate_prefix(self, cache, prefix_ids, prefix_values): + """Insert a prefix into the tree so future requests can match it.""" + cache.insert( + InsertParams( + key=RadixKey(prefix_ids), + value=torch.tensor(prefix_values, dtype=torch.int64), + ) + ) + + def test_incremental_transfer_success(self): + """Scenario 1: prefix match > 0, transfer succeeds. + + Flow: inc_lock_ref(pop_preallocated) + -> dec_lock_ref + inc_lock_ref(cache_unfinished_req) + -> dec_lock_ref(cache_finished_req) + """ + cache, req_to_token = _make_cache_with_pools() + + # Pre-populate a prefix [1,2,3] in the tree + prefix = [1, 2, 3] + prefix_vals = [10, 20, 30] + self._populate_prefix(cache, prefix, prefix_vals) + + # Match prefix (simulates _match_prefix_and_lock in pop_preallocated) + result = cache.match_prefix(MatchPrefixParams(key=RadixKey(prefix))) + matched_node = result.last_device_node + prefix_len = len(result.device_indices) + self.assertEqual(prefix_len, 3) + + # Step 1: inc_lock_ref (pop_preallocated locks the matched node) + cache.inc_lock_ref(matched_node) + self.assertGreater(matched_node.lock_ref, 0) + + # Simulate _pre_alloc: write prefix + new tokens to req_to_token + full_ids = [1, 2, 3, 4, 5] # prefix + 2 new tokens + full_vals = [10, 20, 30, 40, 50] + req_to_token[0, : len(full_vals)] = torch.tensor(full_vals, dtype=torch.int64) + + req = _make_req( + fill_ids=full_ids, + req_pool_idx=0, + cache_protected_len=prefix_len, + last_node=matched_node, + ) + + # Step 2: cache_unfinished_req 
(dec old lock, inc new lock) + cache.cache_unfinished_req(req) + + # Step 3: cache_finished_req with is_insert=True (dec lock) + cache.cache_finished_req(req) + + # Verify: all non-root nodes should have lock_ref == 0 + # (root always has lock_ref == 1) + self.assertEqual(cache.root_node.lock_ref, 1) + self.assertEqual(cache.protected_size(), 0) + # The evictable size should equal total inserted tokens + self.assertEqual(cache.evictable_size(), len(full_ids)) + + def test_full_transfer_success(self): + """Scenario 2: no prefix match, full KV transferred, succeeds. + + Flow: inc_lock_ref(root, via init_next_round_input/get_new_prebuilt_batch) + -> dec_lock_ref + inc_lock_ref(cache_unfinished_req) + -> dec_lock_ref(cache_finished_req) + """ + cache, req_to_token = _make_cache_with_pools() + + # No prefix in tree -- match returns root + full_ids = [10, 20, 30] + result = cache.match_prefix(MatchPrefixParams(key=RadixKey(full_ids))) + matched_node = result.last_device_node + self.assertEqual(len(result.device_indices), 0) # no match + # matched_node is root + + root_lock_before = cache.root_node.lock_ref + # Step 1: inc_lock_ref on root (simulates get_new_prebuilt_batch) + # Note: inc/dec_lock_ref skip the root node (while node != root_node), + # so this is a no-op. Root always keeps lock_ref=1. 
+ cache.inc_lock_ref(matched_node) + self.assertEqual(cache.root_node.lock_ref, root_lock_before) # no-op on root + + # Write full KV to pool + full_vals = [100, 200, 300] + req_to_token[0, : len(full_vals)] = torch.tensor(full_vals, dtype=torch.int64) + + req = _make_req( + fill_ids=full_ids, + req_pool_idx=0, + cache_protected_len=0, + last_node=matched_node, + ) + + # Step 2: cache_unfinished_req (dec root=no-op, inc new leaf) + cache.cache_unfinished_req(req) + + # Step 3: cache_finished_req (dec leaf) + cache.cache_finished_req(req) + + # Root lock unchanged, all nodes unlocked + self.assertEqual(cache.root_node.lock_ref, root_lock_before) + self.assertEqual(cache.protected_size(), 0) + self.assertEqual(cache.evictable_size(), len(full_ids)) + + def test_incremental_transfer_failure(self): + """Scenario 3: prefix match > 0, transfer fails. + + Flow: inc_lock_ref(pop_preallocated) + -> dec_lock_ref(cache_finished_req via release_kv_cache is_insert=False) + """ + cache, req_to_token = _make_cache_with_pools() + + # Pre-populate prefix + prefix = [1, 2, 3] + prefix_vals = [10, 20, 30] + self._populate_prefix(cache, prefix, prefix_vals) + + # Match and lock + result = cache.match_prefix(MatchPrefixParams(key=RadixKey(prefix))) + matched_node = result.last_device_node + prefix_len = len(result.device_indices) + + cache.inc_lock_ref(matched_node) + # Prefix tokens should now be protected (locked) + self.assertGreater(cache.protected_size(), 0) + + # Simulate _pre_alloc with additional tokens + full_ids = [1, 2, 3, 4, 5] + full_vals = [10, 20, 30, 40, 50] + req_to_token[0, : len(full_vals)] = torch.tensor(full_vals, dtype=torch.int64) + + req = _make_req( + fill_ids=full_ids, + req_pool_idx=0, + cache_protected_len=prefix_len, + last_node=matched_node, + ) + + # Transfer fails -> cache_finished_req with is_insert=False + # This frees delta tokens and dec_lock_ref on last_node + cache.cache_finished_req(req, is_insert=False) + + # The prefix node should be unlocked 
(back to evictable) + self.assertEqual(cache.root_node.lock_ref, 1) + self.assertEqual(cache.protected_size(), 0) + # Prefix tokens should still be in tree and evictable + self.assertEqual(cache.evictable_size(), len(prefix)) + + def test_full_transfer_failure(self): + """Scenario 4: no prefix match, transfer fails. + + Flow: _match_prefix_and_lock sets last_node=root and calls + inc_lock_ref(root) which is a no-op. On failure, + cache_finished_req calls dec_lock_ref(root) which is also + a no-op. Net: balanced. + """ + cache, req_to_token = _make_cache_with_pools() + + root_lock_before = cache.root_node.lock_ref + + # No prefix in tree -- match returns root (simulates _match_prefix_and_lock) + full_ids = [10, 20, 30] + result = cache.match_prefix(MatchPrefixParams(key=RadixKey(full_ids))) + matched_node = result.last_device_node + self.assertIs(matched_node, cache.root_node) + + # inc_lock_ref(root) is a no-op + cache.inc_lock_ref(matched_node) + self.assertEqual(cache.root_node.lock_ref, root_lock_before) + + full_vals = [100, 200, 300] + req_to_token[0, : len(full_vals)] = torch.tensor(full_vals, dtype=torch.int64) + + # last_node = root (as set by _match_prefix_and_lock) + req = _make_req( + fill_ids=full_ids, + req_pool_idx=0, + cache_protected_len=0, + last_node=matched_node, + ) + + # Transfer fails -> cache_finished_req with is_insert=False + # dec_lock_ref(root) is a no-op + cache.cache_finished_req(req, is_insert=False) + + # Root lock unchanged, nothing protected or evictable + self.assertEqual(cache.root_node.lock_ref, root_lock_before) + self.assertEqual(cache.protected_size(), 0) + self.assertEqual(cache.evictable_size(), 0) + + def test_repeated_incremental_no_leak(self): + """Multiple incremental transfers shouldn't leak lock_refs.""" + cache, req_to_token = _make_cache_with_pools() + + prefix = [1, 2, 3] + prefix_vals = [10, 20, 30] + self._populate_prefix(cache, prefix, prefix_vals) + + for iteration in range(5): + result = 
cache.match_prefix(MatchPrefixParams(key=RadixKey(prefix))) + matched_node = result.last_device_node + prefix_len = len(result.device_indices) + + cache.inc_lock_ref(matched_node) + + suffix_token = 40 + iteration + full_ids = prefix + [suffix_token] + full_vals = prefix_vals + [100 + iteration] + req_to_token[0, : len(full_vals)] = torch.tensor( + full_vals, dtype=torch.int64 + ) + + req = _make_req( + fill_ids=full_ids, + req_pool_idx=0, + cache_protected_len=prefix_len, + last_node=matched_node, + ) + + cache.cache_unfinished_req(req) + cache.cache_finished_req(req) + + # After all iterations, root lock should be 1, no protected nodes + self.assertEqual(cache.root_node.lock_ref, 1) + self.assertEqual(cache.protected_size(), 0) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/registered/unit/server_args/test_server_args.py b/test/registered/unit/server_args/test_server_args.py index 29b962b74183..dafc6abe5b49 100644 --- a/test/registered/unit/server_args/test_server_args.py +++ b/test/registered/unit/server_args/test_server_args.py @@ -47,6 +47,22 @@ def test_pd_decode_defaults_to_round_robin(self): server_args = ServerArgs(model_path="dummy", disaggregation_mode="decode") self.assertEqual(server_args.load_balance_method, "round_robin") + def test_pd_decode_radix_cache_rejects_hisparse(self): + with self.assertRaises(ValueError) as context: + ServerArgs( + model_path="dummy", + disaggregation_mode="decode", + disaggregation_decode_enable_radix_cache=True, + disaggregation_transfer_backend="nixl", + enable_hisparse=True, + ) + + self.assertIn( + "--disaggregation-decode-enable-radix-cache is incompatible with " + "--enable-hisparse", + str(context.exception), + ) + class TestPortArgs(unittest.TestCase): @patch("sglang.srt.server_args.get_free_port")