Skip to content
18 changes: 13 additions & 5 deletions python/sglang/srt/disaggregation/common/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,11 +217,19 @@ def try_ensure_parallel_info(self, bootstrap_addr: str) -> bool:

# Sanity checks
if info.page_size is not None and info.page_size != self.kv_args.page_size:
raise RuntimeError(
f"Page size mismatch: prefill server has page_size={info.page_size}, "
f"but decode server has page_size={self.kv_args.page_size}. "
f"Both servers must use the same --page-size value."
)
if self.server_args.enable_hisparse:
# HiSparse: decode host pool page_size=1, prefill device pool page_size >= 1.
# Transfer will use send_kvcache_hisparse with per-token item_lens.
logger.info(
f"HiSparse PD transfer mode: prefill page_size={info.page_size}, "
f"decode host page_size={self.kv_args.page_size}"
)
else:
raise RuntimeError(
f"Page size mismatch: prefill server has page_size={info.page_size}, "
f"but decode server has page_size={self.kv_args.page_size}. "
f"Both servers must use the same --page-size value."
)

if (
info.kv_cache_dtype is not None
Expand Down
92 changes: 76 additions & 16 deletions python/sglang/srt/disaggregation/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from http import HTTPStatus
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple

import numpy as np
import torch
from torch.distributed import ProcessGroup

Expand Down Expand Up @@ -312,9 +313,16 @@ def _init_kv_manager(self) -> CommonKVManager:

kv_args.pp_rank = self.pp_rank
kv_args.system_dp_rank = self.scheduler.dp_rank
kv_data_ptrs, kv_data_lens, kv_item_lens = (
self.token_to_kv_pool.get_contiguous_buf_infos()
)
if self.scheduler.enable_hisparse:
# Direct-to-host: register host pool pointers so P writes to D's host memory
host_pool = self.scheduler.hisparse_coordinator.mem_pool_host
kv_data_ptrs, kv_data_lens, kv_item_lens = (
host_pool.get_contiguous_buf_infos()
)
else:
kv_data_ptrs, kv_data_lens, kv_item_lens = (
self.token_to_kv_pool.get_contiguous_buf_infos()
)
if self.draft_token_to_kv_pool is not None:
# We should also transfer draft model kv cache. The indices are
# always shared with a target model.
Expand All @@ -328,7 +336,10 @@ def _init_kv_manager(self) -> CommonKVManager:
kv_args.kv_data_ptrs = kv_data_ptrs
kv_args.kv_data_lens = kv_data_lens
kv_args.kv_item_lens = kv_item_lens
kv_args.page_size = self.token_to_kv_pool.page_size
# HiSparse Host pool has page_size=1; use it when hisparse is enabled
kv_args.page_size = (
1 if self.scheduler.enable_hisparse else self.token_to_kv_pool.page_size
)

kv_args.aux_data_ptrs, kv_args.aux_data_lens, kv_args.aux_item_lens = (
self.metadata_buffers.get_buf_infos()
Expand Down Expand Up @@ -698,16 +709,21 @@ def pop_preallocated(
break

allocatable_tokens -= required_tokens_for_request
self._pre_alloc(decode_req.req)
dst_kv_indices = self._pre_alloc(decode_req.req)

kv_indices = (
self.req_to_token_pool.req_to_token[decode_req.req.req_pool_idx][
: len(decode_req.req.origin_input_ids)
]
.cpu()
.numpy()
)
page_size = self.token_to_kv_pool_allocator.page_size
origin_input_len = len(decode_req.req.origin_input_ids)
if self.scheduler.enable_hisparse:
# Must cast to int32 for ZMQ serialization — from_zmq reads np.int32.
kv_indices = (
dst_kv_indices[:origin_input_len].cpu().numpy().astype(np.int32)
)
page_size = 1 # host pool page_size
else:
kv_indices_full = self.req_to_token_pool.req_to_token[
decode_req.req.req_pool_idx
][:origin_input_len]
kv_indices = kv_indices_full.cpu().numpy()
page_size = self.token_to_kv_pool_allocator.page_size

# Prepare extra pool indices for hybrid models
if isinstance(self.token_to_kv_pool, HybridLinearKVPool):
Expand Down Expand Up @@ -744,7 +760,9 @@ def pop_preallocated(
decode_req.req.req_pool_idx, :seq_len
]
state_indices = kv_indices_full.cpu().numpy()
state_indices = kv_to_page_indices(state_indices, page_size)
# Indexer lives on device pool; always use device page_size
device_page_size = self.token_to_kv_pool.page_size
state_indices = kv_to_page_indices(state_indices, device_page_size)
else:
state_indices = None

Expand Down Expand Up @@ -841,7 +859,30 @@ def _pre_alloc(self, req: Req) -> torch.Tensor:
fill_len = len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0)
req.kv_allocated_len = fill_len
req.kv_committed_len = fill_len
if self.token_to_kv_pool_allocator.page_size == 1:

if self.scheduler.enable_hisparse:
# Direct-to-host path: only allocate logical indices (no hisparse
# device indices) and allocate host indices for RDMA destination.
coordinator = self.scheduler.hisparse_coordinator
device = self.token_to_kv_pool_allocator.device
kv_loc = self.token_to_kv_pool_allocator.alloc_logical_only(
prefix_lens=torch.tensor([0], dtype=torch.int64, device=device),
prefix_lens_cpu=torch.tensor([0], dtype=torch.int64),
seq_lens=torch.tensor([fill_len], dtype=torch.int64, device=device),
seq_lens_cpu=torch.tensor([fill_len], dtype=torch.int64),
last_loc=torch.tensor([-1], dtype=torch.int64, device=device),
extend_num_tokens=fill_len,
)
# Allocate host indices for the RDMA transfer target
host_indices = coordinator.mem_pool_host.alloc(fill_len)
if host_indices is None:
raise RuntimeError(
f"HiSparse host mem pool alloc failed for {fill_len} tokens "
f"in _pre_alloc (req {req.rid})"
)
host_indices = host_indices.to(device=coordinator.device)
coordinator.req_to_host_pool[req.req_pool_idx, :fill_len] = host_indices
elif self.token_to_kv_pool_allocator.page_size == 1:
kv_loc = self.token_to_kv_pool_allocator.alloc(fill_len)
else:
device = self.token_to_kv_pool_allocator.device
Expand All @@ -864,6 +905,9 @@ def _pre_alloc(self, req: Req) -> torch.Tensor:
req.fill_ids = req.origin_input_ids + req.output_ids
req.set_extend_input_len(len(req.fill_ids))

# Return the transfer destination indices:
if self.scheduler.enable_hisparse:
return host_indices
return kv_loc


Expand Down Expand Up @@ -1034,6 +1078,8 @@ def pop_transferred(self, rids_to_check: Optional[List[str]] = None) -> List[Req
self.scheduler.stream_output(
[decode_req.req], decode_req.req.return_logprob
)
if self.scheduler.enable_hisparse:
self.scheduler.hisparse_coordinator.request_finished(decode_req.req)
# release pre-allocated kv cache, but don't insert into the tree since it's failed
release_kv_cache(decode_req.req, self.tree_cache, is_insert=False)
indices_to_remove.add(i)
Expand All @@ -1049,6 +1095,10 @@ def pop_transferred(self, rids_to_check: Optional[List[str]] = None) -> List[Req
self.scheduler.stream_output(
[decode_req.req], decode_req.req.return_logprob
)
if self.scheduler.enable_hisparse:
self.scheduler.hisparse_coordinator.request_finished(
decode_req.req
)
release_kv_cache(
decode_req.req, self.tree_cache, is_insert=False
)
Expand Down Expand Up @@ -1172,6 +1222,10 @@ def get_next_disagg_decode_batch_to_run(
if not new_prebuilt_batch.is_empty():
if self.running_batch.is_empty():
self.running_batch = new_prebuilt_batch
if self.enable_hisparse:
self.running_batch.hisparse_coordinator = (
self.hisparse_coordinator
)
else:
self.running_batch.merge_batch(new_prebuilt_batch)

Expand Down Expand Up @@ -1264,4 +1318,10 @@ def process_decode_queue(self: Scheduler):
transferred_reqs = (
self.disagg_decode_transfer_queue.pop_transferred()
) # the requests which kv has arrived
self.waiting_queue.extend(transferred_reqs)
if self.enable_hisparse:
for req in transferred_reqs:
# Direct-to-host: KV data already in host pool, skip staging
self.hisparse_coordinator.admit_request_direct(req)
self.waiting_queue.extend(transferred_reqs)
else:
self.waiting_queue.extend(transferred_reqs)
76 changes: 68 additions & 8 deletions python/sglang/srt/disaggregation/mooncake/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,8 @@ class KVArgsRegisterInfo:
# for mamba state different tp slice transfer
dst_state_item_lens: list[int]
dst_state_dim_per_tensor: list[int]
# HiSparse: decode host pool stores KV at token granularity
enable_hisparse: bool = False
Comment on lines +130 to +131
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can be updated too. We keep the Optional params at the end

staging: Optional[StagingRegisterInfo] = None

@classmethod
Expand All @@ -152,7 +154,10 @@ def from_zmq(cls, msg: List[bytes]):
if len(msg) > 11 and len(msg[11]) > 0
else []
),
staging=StagingRegisterInfo.from_zmq_fields(msg, 12),
enable_hisparse=(
msg[12].decode("ascii") == "1" if len(msg) > 12 else False
),
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can move enable_hisparse before staging. We leave the large one (or maybe the optional one) at the end; the flags can be put at the front.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also please double check on this line:

struct.unpack("Q", msg[i])[0] if len(msg) > i and len(msg[i]) == 8 else 0

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CC: @YAMY1234

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

okko

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ShangmingCai updated

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

staging=StagingRegisterInfo.from_zmq_fields(msg, 13),
)


Expand Down Expand Up @@ -694,6 +699,49 @@ def send_kvcache(
executor=executor,
)

def send_kvcache_hisparse(
    self,
    mooncake_session_id: str,
    prefill_kv_indices: npt.NDArray[np.int32],
    dst_kv_ptrs: list[int],
    dst_kv_indices: npt.NDArray[np.int32],
    page_index_slice: slice,
    executor: concurrent.futures.ThreadPoolExecutor,
):
    """HiSparse KV transfer (prefill page_size > decode host page_size of 1).

    Takes page-granular ``prefill_kv_indices`` plus the full token-granular
    ``dst_kv_indices`` and expands the source side to token granularity so
    both sides line up one token per item before handing off to the generic
    sender.
    """
    page_size = self.kv_args.page_size
    # kv_item_lens are per-page on this (prefill) side; scale them down to
    # per-token sizes to match the token-granular host pool on decode.
    per_token_item_lens = [length // page_size for length in self.kv_args.kv_item_lens]

    # Expand page-level src indices to token level via broadcasting:
    # row i enumerates the token slots of page prefill_kv_indices[i].
    token_offsets = np.arange(page_size, dtype=np.int32)
    expanded_src = (prefill_kv_indices[:, None] * page_size + token_offsets).reshape(-1)

    # Translate the page-level chunk slice into a token-level window over
    # the destination index array (clamped to its actual length).
    first_token = page_index_slice.start * page_size
    last_token = min(page_index_slice.stop * page_size, len(dst_kv_indices))
    expanded_dst = dst_kv_indices[first_token:last_token]

    # The chunk's final page may be only partially filled; trim src to match.
    expanded_src = expanded_src[: len(expanded_dst)]

    logger.debug(
        f"Send KVCache for hisparse: {expanded_src.shape} -> {expanded_dst.shape}"
    )
    return self._send_kvcache_generic(
        mooncake_session_id=mooncake_session_id,
        src_data_ptrs=self.kv_args.kv_data_ptrs,
        dst_data_ptrs=dst_kv_ptrs,
        item_lens=per_token_item_lens,
        prefill_data_indices=expanded_src,
        dst_data_indices=expanded_dst,
        executor=executor,
    )

def send_kvcache_slice(
self,
mooncake_session_id: str,
Expand Down Expand Up @@ -1165,13 +1213,23 @@ def transfer_worker(
self.attn_tp_size
== target_rank_registration_info.dst_attn_tp_size
):
ret = self.send_kvcache(
req.mooncake_session_id,
kv_chunk.prefill_kv_indices,
target_rank_registration_info.dst_kv_ptrs,
chunked_dst_kv_indice,
executor,
)
if target_rank_registration_info.enable_hisparse:
ret = self.send_kvcache_hisparse(
req.mooncake_session_id,
kv_chunk.prefill_kv_indices,
target_rank_registration_info.dst_kv_ptrs,
req.dst_kv_indices,
kv_chunk.index_slice,
executor,
)
else:
ret = self.send_kvcache(
req.mooncake_session_id,
kv_chunk.prefill_kv_indices,
target_rank_registration_info.dst_kv_ptrs,
chunked_dst_kv_indice,
executor,
)
elif (
self.enable_staging
and staging_strategy is not None
Expand Down Expand Up @@ -1715,6 +1773,7 @@ def _register_kv_args(self):
dst_tp_rank = str(tp_rank).encode("ascii")
dst_attn_tp_size = str(self.kv_mgr.attn_tp_size).encode("ascii")
dst_kv_item_len = str(kv_item_len).encode("ascii")
enable_hisparse = b"1" if self.kv_mgr.server_args.enable_hisparse else b"0"

if (
self.kv_mgr.enable_staging
Expand Down Expand Up @@ -1743,6 +1802,7 @@ def _register_kv_args(self):
dst_kv_item_len,
packed_state_item_lens,
packed_state_dim_per_tensor,
enable_hisparse,
packed_staging_base_ptr,
staging_total_size_str,
]
Expand Down
49 changes: 48 additions & 1 deletion python/sglang/srt/managers/hisparse_coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __init__(
override_kv_cache_dim=self.mem_pool_device.kv_cache_dim,
)

max_num_reqs = req_to_token_pool.size
max_num_reqs = req_to_token_pool.req_to_token.shape[0]
max_context_len = req_to_token_pool.max_context_len

# to have an extra page for new tokens
Expand Down Expand Up @@ -161,6 +161,53 @@ def admit_request_into_staging(self, req: Req) -> None:

self.ack_staging_queue.append(HiSparseAct(start_event, finish_event, req))

def admit_request_direct(self, req: Req) -> None:
    """Admit a request whose KV data already landed in the host pool via RDMA.

    The staging DMA step is skipped entirely: we only reserve the small
    decode-time device swap-in buffer and then mark the request ready.
    The host indices were written into ``req_to_host_pool`` beforehand.

    Bookkeeping note: ``alloc_device_buffer()`` initializes
    ``device_buffer_tokens`` to ``[0, 1, ..., buf_size-1]``, i.e. "these
    tokens are already cached on device". That is true on the staging path
    (prefill filled the buffer) but false here, where the buffer starts
    empty — so the metadata is fixed up below depending on sequence length.
    """
    self.alloc_device_buffer(req)

    seq_len = req.kv_allocated_len
    if seq_len > self.device_buffer_size:
        # Long sequence: mark every device-buffer slot empty (-1) so the
        # swap-in kernel treats each top-k lookup as a miss and falls back
        # to loading from the host pool.
        self.req_device_buffer_tokens[
            :, req.req_pool_idx, : self.device_buffer_size
        ] = -1
    else:
        # Short sequence (seq_len <= device_buffer_size): the kernel fast
        # path returns device_buffer_locs directly, never consulting the
        # host pool, so all tokens must be preloaded into the device buffer.
        # TODO(hzh0425): Optimize this.
        self._preload_to_device_buffer(req)

    req.staging = False
    self._skip_first_backup[req.req_pool_idx] = True
    logger.debug("HiSparse: admitting request %s directly", req.rid)

def _preload_to_device_buffer(self, req: Req) -> None:
    """Copy every allocated token of *req* from the host pool into its device buffer."""
    num_tokens = req.kv_allocated_len
    src_host = self.req_to_host_pool[req.req_pool_idx, :num_tokens]
    dst_device = self.req_to_device_buffer[req.req_pool_idx, :num_tokens]

    # Transfers go layer by layer through the host pool's per-layer loader.
    for layer in range(self.mem_pool_device.layer_num):
        self.mem_pool_host.load_to_device_per_layer(
            self.mem_pool_device,
            src_host,
            dst_device,
            layer,
            io_backend="kernel",
        )

def alloc_device_buffer(self, req: Req) -> None:
allocated_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, : req.kv_allocated_len
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,8 @@ def check_tree_cache(self: Scheduler):
self.tree_cache.sanity_check()

def self_check_during_idle(self: Scheduler):
if self.enable_hisparse and self.hisparse_coordinator.has_ongoing_staging():
return
if self.disaggregation_mode == DisaggregationMode.PREFILL:
if len(self.disagg_prefill_inflight_queue) > 0:
return
Expand All @@ -371,9 +373,6 @@ def self_check_during_idle(self: Scheduler):
queue_size += len(self.decode_offload_manager.ongoing_offload)
if queue_size:
return
elif self.enable_hisparse:
if self.hisparse_coordinator.has_ongoing_staging():
return

self.check_memory()
self.check_tree_cache()
Expand Down
Loading
Loading