diff --git a/python/sglang/srt/disaggregation/common/conn.py b/python/sglang/srt/disaggregation/common/conn.py index 3269226459d8..26752d52dd54 100644 --- a/python/sglang/srt/disaggregation/common/conn.py +++ b/python/sglang/srt/disaggregation/common/conn.py @@ -217,11 +217,19 @@ def try_ensure_parallel_info(self, bootstrap_addr: str) -> bool: # Sanity checks if info.page_size is not None and info.page_size != self.kv_args.page_size: - raise RuntimeError( - f"Page size mismatch: prefill server has page_size={info.page_size}, " - f"but decode server has page_size={self.kv_args.page_size}. " - f"Both servers must use the same --page-size value." - ) + if self.server_args.enable_hisparse: + # HiSparse: decode host pool page_size=1, prefill device pool page_size >= 1. + # Transfer will use send_kvcache_hisparse with per-token item_lens. + logger.info( + f"HiSparse PD transfer mode: prefill page_size={info.page_size}, " + f"decode host page_size={self.kv_args.page_size}" + ) + else: + raise RuntimeError( + f"Page size mismatch: prefill server has page_size={info.page_size}, " + f"but decode server has page_size={self.kv_args.page_size}. " + f"Both servers must use the same --page-size value." + ) if ( info.kv_cache_dtype is not None diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py index 1af8afb98e5b..f54c882cc2ec 100644 --- a/python/sglang/srt/disaggregation/decode.py +++ b/python/sglang/srt/disaggregation/decode.py @@ -27,6 +27,7 @@ from http import HTTPStatus from typing import TYPE_CHECKING, Dict, List, Optional, Tuple +import numpy as np import torch from torch.distributed import ProcessGroup @@ -312,9 +313,16 @@ def _init_kv_manager(self) -> CommonKVManager: kv_args.pp_rank = self.pp_rank kv_args.system_dp_rank = self.scheduler.dp_rank - kv_data_ptrs, kv_data_lens, kv_item_lens = ( - self.token_to_kv_pool.get_contiguous_buf_infos() - ) + if self.scheduler.enable_hisparse: + # Direct-to-host: register host pool pointers so P writes to D's host memory + host_pool = self.scheduler.hisparse_coordinator.mem_pool_host + kv_data_ptrs, kv_data_lens, kv_item_lens = ( + host_pool.get_contiguous_buf_infos() + ) + else: + kv_data_ptrs, kv_data_lens, kv_item_lens = ( + self.token_to_kv_pool.get_contiguous_buf_infos() + ) if self.draft_token_to_kv_pool is not None: # We should also transfer draft model kv cache. The indices are # always shared with a target model. @@ -328,7 +336,10 @@ def _init_kv_manager(self) -> CommonKVManager: kv_args.kv_data_ptrs = kv_data_ptrs kv_args.kv_data_lens = kv_data_lens kv_args.kv_item_lens = kv_item_lens - kv_args.page_size = self.token_to_kv_pool.page_size + # HiSparse Host pool has page_size=1; use it when hisparse is enabled + kv_args.page_size = ( + 1 if self.scheduler.enable_hisparse else self.token_to_kv_pool.page_size + ) kv_args.aux_data_ptrs, kv_args.aux_data_lens, kv_args.aux_item_lens = ( self.metadata_buffers.get_buf_infos() @@ -698,16 +709,21 @@ def pop_preallocated( break allocatable_tokens -= required_tokens_for_request - self._pre_alloc(decode_req.req) + dst_kv_indices = self._pre_alloc(decode_req.req) - kv_indices = ( - self.req_to_token_pool.req_to_token[decode_req.req.req_pool_idx][ - : len(decode_req.req.origin_input_ids) - ] - .cpu() - .numpy() - ) - page_size = self.token_to_kv_pool_allocator.page_size + origin_input_len = len(decode_req.req.origin_input_ids) + if self.scheduler.enable_hisparse: + # Must cast to int32 for ZMQ serialization — from_zmq reads np.int32. + kv_indices = ( + dst_kv_indices[:origin_input_len].cpu().numpy().astype(np.int32) + ) + page_size = 1 # host pool page_size + else: + kv_indices_full = self.req_to_token_pool.req_to_token[ + decode_req.req.req_pool_idx + ][:origin_input_len] + kv_indices = kv_indices_full.cpu().numpy() + page_size = self.token_to_kv_pool_allocator.page_size # Prepare extra pool indices for hybrid models if isinstance(self.token_to_kv_pool, HybridLinearKVPool): @@ -744,7 +760,9 @@ def pop_preallocated( decode_req.req.req_pool_idx, :seq_len ] state_indices = kv_indices_full.cpu().numpy() - state_indices = kv_to_page_indices(state_indices, page_size) + # Indexer lives on device pool; always use device page_size + device_page_size = self.token_to_kv_pool.page_size + state_indices = kv_to_page_indices(state_indices, device_page_size) else: state_indices = None @@ -841,7 +859,30 @@ def _pre_alloc(self, req: Req) -> torch.Tensor: fill_len = len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0) req.kv_allocated_len = fill_len req.kv_committed_len = fill_len - if self.token_to_kv_pool_allocator.page_size == 1: + + if self.scheduler.enable_hisparse: + # Direct-to-host path: only allocate logical indices (no hisparse + # device indices) and allocate host indices for RDMA destination. + coordinator = self.scheduler.hisparse_coordinator + device = self.token_to_kv_pool_allocator.device + kv_loc = self.token_to_kv_pool_allocator.alloc_logical_only( + prefix_lens=torch.tensor([0], dtype=torch.int64, device=device), + prefix_lens_cpu=torch.tensor([0], dtype=torch.int64), + seq_lens=torch.tensor([fill_len], dtype=torch.int64, device=device), + seq_lens_cpu=torch.tensor([fill_len], dtype=torch.int64), + last_loc=torch.tensor([-1], dtype=torch.int64, device=device), + extend_num_tokens=fill_len, + ) + # Allocate host indices for the RDMA transfer target + host_indices = coordinator.mem_pool_host.alloc(fill_len) + if host_indices is None: + raise RuntimeError( + f"HiSparse host mem pool alloc failed for {fill_len} tokens " + f"in _pre_alloc (req {req.rid})" + ) + host_indices = host_indices.to(device=coordinator.device) + coordinator.req_to_host_pool[req.req_pool_idx, :fill_len] = host_indices + elif self.token_to_kv_pool_allocator.page_size == 1: kv_loc = self.token_to_kv_pool_allocator.alloc(fill_len) else: device = self.token_to_kv_pool_allocator.device @@ -864,6 +905,9 @@ def _pre_alloc(self, req: Req) -> torch.Tensor: req.fill_ids = req.origin_input_ids + req.output_ids req.set_extend_input_len(len(req.fill_ids)) + # Return the transfer destination indices: + if self.scheduler.enable_hisparse: + return host_indices return kv_loc @@ -1034,6 +1078,8 @@ def pop_transferred(self, rids_to_check: Optional[List[str]] = None) -> List[Req self.scheduler.stream_output( [decode_req.req], decode_req.req.return_logprob ) + if self.scheduler.enable_hisparse: + self.scheduler.hisparse_coordinator.request_finished(decode_req.req) # release pre-allocated kv cache, but don't insert into the tree since it's failed release_kv_cache(decode_req.req, self.tree_cache, is_insert=False) indices_to_remove.add(i) @@ -1049,6 +1095,10 @@ def pop_transferred(self, rids_to_check: Optional[List[str]] = None) -> List[Req self.scheduler.stream_output( [decode_req.req], decode_req.req.return_logprob ) + if self.scheduler.enable_hisparse: + self.scheduler.hisparse_coordinator.request_finished( + decode_req.req + ) release_kv_cache( decode_req.req, self.tree_cache, is_insert=False ) @@ -1172,6 +1222,10 @@ def get_next_disagg_decode_batch_to_run( if not new_prebuilt_batch.is_empty(): if self.running_batch.is_empty(): self.running_batch = new_prebuilt_batch + if self.enable_hisparse: + self.running_batch.hisparse_coordinator = ( + self.hisparse_coordinator + ) else: self.running_batch.merge_batch(new_prebuilt_batch) @@ -1264,4 +1318,10 @@ def process_decode_queue(self: Scheduler): transferred_reqs = ( self.disagg_decode_transfer_queue.pop_transferred() ) # the requests which kv has arrived - self.waiting_queue.extend(transferred_reqs) + if self.enable_hisparse: + for req in transferred_reqs: + # Direct-to-host: KV data already in host pool, skip staging + self.hisparse_coordinator.admit_request_direct(req) + self.waiting_queue.extend(transferred_reqs) + else: + self.waiting_queue.extend(transferred_reqs) diff --git a/python/sglang/srt/disaggregation/mooncake/conn.py b/python/sglang/srt/disaggregation/mooncake/conn.py index e4b23ed81229..fc74d555dbd0 100644 --- a/python/sglang/srt/disaggregation/mooncake/conn.py +++ b/python/sglang/srt/disaggregation/mooncake/conn.py @@ -127,6 +127,8 @@ class KVArgsRegisterInfo: # for mamba state different tp slice transfer dst_state_item_lens: list[int] dst_state_dim_per_tensor: list[int] + # HiSparse: decode host pool stores KV at token granularity + enable_hisparse: bool = False staging: Optional[StagingRegisterInfo] = None @classmethod @@ -152,7 +154,10 @@ def from_zmq(cls, msg: List[bytes]): if len(msg) > 11 and len(msg[11]) > 0 else [] ), - staging=StagingRegisterInfo.from_zmq_fields(msg, 12), + enable_hisparse=( + msg[12].decode("ascii") == "1" if len(msg) > 12 else False + ), + staging=StagingRegisterInfo.from_zmq_fields(msg, 13), ) @@ -694,6 +699,49 @@ def send_kvcache( executor=executor, ) + def send_kvcache_hisparse( + self, + mooncake_session_id: str, + prefill_kv_indices: npt.NDArray[np.int32], + dst_kv_ptrs: list[int], + dst_kv_indices: npt.NDArray[np.int32], + page_index_slice: slice, + executor: concurrent.futures.ThreadPoolExecutor, + ): + """HiSparse transfer: prefill page_size > decode host page_size=1. + + Receives page-level prefill_kv_indices and the full token-level + dst_kv_indices. Expands both to token granularity before transfer. + """ + page_size = self.kv_args.page_size + per_token_item_lens = [il // page_size for il in self.kv_args.kv_item_lens] + + # Expand page-level src indices to token-level + base = np.repeat(prefill_kv_indices * page_size, page_size) + offsets = np.tile(np.arange(page_size, dtype=np.int32), len(prefill_kv_indices)) + expanded_src = base + offsets + + # Expand page-level index_slice to token-level for dst + token_start = page_index_slice.start * page_size + token_end = min(page_index_slice.stop * page_size, len(dst_kv_indices)) + expanded_dst = dst_kv_indices[token_start:token_end] + + # Clip src to match dst length (last page may be partial) + expanded_src = expanded_src[: len(expanded_dst)] + + logger.debug( + f"Send KVCache for hisparse: {expanded_src.shape} -> {expanded_dst.shape}" + ) + return self._send_kvcache_generic( + mooncake_session_id=mooncake_session_id, + src_data_ptrs=self.kv_args.kv_data_ptrs, + dst_data_ptrs=dst_kv_ptrs, + item_lens=per_token_item_lens, + prefill_data_indices=expanded_src, + dst_data_indices=expanded_dst, + executor=executor, + ) + def send_kvcache_slice( self, mooncake_session_id: str, @@ -1165,13 +1213,23 @@ def transfer_worker( self.attn_tp_size == target_rank_registration_info.dst_attn_tp_size ): - ret = self.send_kvcache( - req.mooncake_session_id, - kv_chunk.prefill_kv_indices, - target_rank_registration_info.dst_kv_ptrs, - chunked_dst_kv_indice, - executor, - ) + if target_rank_registration_info.enable_hisparse: + ret = self.send_kvcache_hisparse( + req.mooncake_session_id, + kv_chunk.prefill_kv_indices, + target_rank_registration_info.dst_kv_ptrs, + req.dst_kv_indices, + kv_chunk.index_slice, + executor, + ) + else: + ret = self.send_kvcache( + req.mooncake_session_id, + kv_chunk.prefill_kv_indices, + target_rank_registration_info.dst_kv_ptrs, + chunked_dst_kv_indice, + executor, + ) elif ( self.enable_staging and staging_strategy is not None @@ -1715,6 +1773,7 @@ def _register_kv_args(self): dst_tp_rank = str(tp_rank).encode("ascii") dst_attn_tp_size = str(self.kv_mgr.attn_tp_size).encode("ascii") dst_kv_item_len = str(kv_item_len).encode("ascii") + enable_hisparse = b"1" if self.kv_mgr.server_args.enable_hisparse else b"0" if ( self.kv_mgr.enable_staging @@ -1743,6 +1802,7 @@ def _register_kv_args(self): dst_kv_item_len, packed_state_item_lens, packed_state_dim_per_tensor, + enable_hisparse, packed_staging_base_ptr, staging_total_size_str, ] diff --git a/python/sglang/srt/managers/hisparse_coordinator.py b/python/sglang/srt/managers/hisparse_coordinator.py index eda985d4f80c..89740f73682e 100644 --- a/python/sglang/srt/managers/hisparse_coordinator.py +++ b/python/sglang/srt/managers/hisparse_coordinator.py @@ -56,7 +56,7 @@ def __init__( override_kv_cache_dim=self.mem_pool_device.kv_cache_dim, ) - max_num_reqs = req_to_token_pool.size + max_num_reqs = req_to_token_pool.req_to_token.shape[0] max_context_len = req_to_token_pool.max_context_len # to have an extra page for new tokens @@ -161,6 +161,53 @@ def admit_request_into_staging(self, req: Req) -> None: self.ack_staging_queue.append(HiSparseAct(start_event, finish_event, req)) + def admit_request_direct(self, req: Req) -> None: + """Direct-to-host path: KV data already resides in host pool via RDMA. + + Skips staging DMA entirely. Only allocates a small device buffer + (4KB) for decode-time swap-in, then marks the request as ready. + Host indices were already written to req_to_host_pool. + + Metadata fixups after alloc_device_buffer(): + - alloc_device_buffer() sets device_buffer_tokens = [0, 1, ..., buf_size-1], + which tells the swap-in kernel that those tokens are cached in the device + buffer. In the staging path this is correct (prefill filled the buffer), + but here the buffer is empty. + """ + self.alloc_device_buffer(req) + + if req.kv_allocated_len <= self.device_buffer_size: + # Short sequences (seq_len <= device_buffer_size): the kernel fast path + # returns device_buffer_locs directly without any host loading, so we + # must preload all tokens from host pool into the device buffer + # TODO(hzh0425): Optimize this. + self._preload_to_device_buffer(req) + else: + # Long sequence: reset device_buffer_tokens to -1 so the kernel + # sees all slots as empty → every top-k lookup is a miss → host load. + self.req_device_buffer_tokens[ + :, req.req_pool_idx, : self.device_buffer_size + ] = -1 + + req.staging = False + self._skip_first_backup[req.req_pool_idx] = True + logger.debug("HiSparse: admitting request %s directly", req.rid) + + def _preload_to_device_buffer(self, req: Req) -> None: + """Preload all tokens from host pool into the device buffer.""" + n = req.kv_allocated_len + host_indices = self.req_to_host_pool[req.req_pool_idx, :n] + device_locs = self.req_to_device_buffer[req.req_pool_idx, :n] + + for layer_id in range(self.mem_pool_device.layer_num): + self.mem_pool_host.load_to_device_per_layer( + self.mem_pool_device, + host_indices, + device_locs, + layer_id, + io_backend="kernel", + ) + def alloc_device_buffer(self, req: Req) -> None: allocated_indices = self.req_to_token_pool.req_to_token[ req.req_pool_idx, : req.kv_allocated_len diff --git a/python/sglang/srt/managers/scheduler_runtime_checker_mixin.py b/python/sglang/srt/managers/scheduler_runtime_checker_mixin.py index 8d01f7792583..113073e3bd93 100644 --- a/python/sglang/srt/managers/scheduler_runtime_checker_mixin.py +++ b/python/sglang/srt/managers/scheduler_runtime_checker_mixin.py @@ -358,6 +358,8 @@ def check_tree_cache(self: Scheduler): self.tree_cache.sanity_check() def self_check_during_idle(self: Scheduler): + if self.enable_hisparse and self.hisparse_coordinator.has_ongoing_staging(): + return if self.disaggregation_mode == DisaggregationMode.PREFILL: if len(self.disagg_prefill_inflight_queue) > 0: return @@ -371,9 +373,6 @@ def self_check_during_idle(self: Scheduler): queue_size += len(self.decode_offload_manager.ongoing_offload) if queue_size: return - elif self.enable_hisparse: - if self.hisparse_coordinator.has_ongoing_staging(): - return self.check_memory() self.check_tree_cache() diff --git a/python/sglang/srt/mem_cache/hisparse_memory_pool.py b/python/sglang/srt/mem_cache/hisparse_memory_pool.py index 5af8d257ad6b..0f2a53917175 100644 --- a/python/sglang/srt/mem_cache/hisparse_memory_pool.py +++ b/python/sglang/srt/mem_cache/hisparse_memory_pool.py @@ -193,11 +193,38 @@ def alloc(self, need_size: int): "Page size = 1 is not supported in HiSparse allocator" ) + def alloc_logical_only( + self, + prefix_lens: torch.Tensor, + prefix_lens_cpu: torch.Tensor, + seq_lens: torch.Tensor, + seq_lens_cpu: torch.Tensor, + last_loc: torch.Tensor, + extend_num_tokens: int, + ): + """Allocate only logical indices without hisparse device indices. + + Used in the direct-to-host transfer path where KV data is written + directly to host memory by the prefill node, skipping GPU staging. + """ + return self.logical_attn_allocator.alloc_extend( + prefix_lens, + prefix_lens_cpu, + seq_lens, + seq_lens_cpu, + last_loc, + extend_num_tokens, + ) + def alloc_device_buffer(self, allocated_indices, need_size: int): assert need_size % self.page_size == 0 # clear original reference and isolate the buffer from outside addressing, allocate new buffer if needed hisparse_indices = self.full_to_hisparse_device_index_mapping[allocated_indices] self.full_to_hisparse_device_index_mapping[allocated_indices] = 0 + # Filter valid (non-zero) hisparse indices. + # In the direct-to-host path, mapping is all zeros since no hisparse + # device indices were pre-allocated. + hisparse_indices = hisparse_indices[hisparse_indices > 0] if len(hisparse_indices) >= need_size: buffer_indices = hisparse_indices[:need_size] self.free_hisparse_indices(hisparse_indices[need_size:]) diff --git a/python/sglang/srt/mem_cache/memory_pool_host.py b/python/sglang/srt/mem_cache/memory_pool_host.py index 9666080d3f72..1a9708c41430 100644 --- a/python/sglang/srt/mem_cache/memory_pool_host.py +++ b/python/sglang/srt/mem_cache/memory_pool_host.py @@ -826,6 +826,14 @@ def __init__( device=self.device_pool.device, ) + def get_contiguous_buf_infos(self): + """Return (data_ptrs, data_lens, item_lens) in the same format as device pool, + for registering host memory with the disaggregation transfer engine.""" + data_ptrs = [int(self.data_ptrs[i].item()) for i in range(self.layer_num)] + data_lens = [self.kv_buffer[i].nbytes for i in range(self.layer_num)] + item_lens = [self.token_stride_size] * self.layer_num + return data_ptrs, data_lens, item_lens + def get_size_per_token(self): self.kv_lora_rank = self.device_pool.kv_lora_rank self.qk_rope_head_dim = self.device_pool.qk_rope_head_dim