Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions python/sglang/srt/disaggregation/base/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,16 @@ def send_metadata(
kv_indices: npt.NDArray[np.int32],
aux_index: Optional[int] = None,
state_indices: Optional[List[int]] = None,
decode_prefix_len: int = 0,
):
"""
Notify the prefill server about the kv indices, aux index, and state_indices.

Args:
decode_prefix_len: Number of tokens already cached on the decode side.
When > 0, kv_indices contains only the incremental portion
(beyond the cached prefix), and the prefill side should skip
transferring the first decode_prefix_len tokens.
"""
...

Expand Down
15 changes: 15 additions & 0 deletions python/sglang/srt/disaggregation/common/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,21 @@ def _resolve_rank_mapping(self, info: PrefillServerInfo) -> None:
info.required_dst_info_num = required_dst_info_num
info.required_prefill_response_num = required_prefill_response_num

def get_decode_prefix_len(self, bootstrap_room: int) -> int:
    """Look up the decode_prefix_len recorded for *bootstrap_room*.

    Scans every session/agent entry registered under the room in
    ``transfer_infos`` and returns the largest ``decode_prefix_len``
    among them. Falls back to 0 when the manager has no
    ``transfer_infos`` table, the room is unknown, or entries predate
    the field (backward compatible).
    """
    # Older manager instances may not carry a transfer_infos table at all.
    if not hasattr(self, "transfer_infos"):
        return 0
    entries = self.transfer_infos.get(bootstrap_room, {})
    # max(..., default=0) covers both a missing room and entries that
    # were created without the decode_prefix_len attribute.
    return max(
        (getattr(entry, "decode_prefix_len", 0) for entry in entries.values()),
        default=0,
    )

def register_to_bootstrap(self):
"""Register prefill server info to bootstrap server via HTTP POST."""
if self.dist_init_addr:
Expand Down
367 changes: 300 additions & 67 deletions python/sglang/srt/disaggregation/decode.py

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,10 @@ def prepare_for_prebuilt(self: ScheduleBatch):
for i, req in enumerate(reqs):
req_pool_indices.append(req.req_pool_idx)

# Read KV indices for the extend portion (after prefix)
pre_len_i = len(req.prefix_indices)
chunk = self.req_to_token_pool.req_to_token[req.req_pool_idx][
: req.extend_input_len
pre_len_i : pre_len_i + req.extend_input_len
]
assert (
offset + req.extend_input_len <= total_size
Expand All @@ -60,7 +62,9 @@ def prepare_for_prebuilt(self: ScheduleBatch):
), f"seq_len={seq_len}, pre_len={pre_len}, req.extend_input_len={req.extend_input_len}"

if not req.retracted_stain:
req.cached_tokens += pre_len - req.already_computed
# In disagg decode, cached_tokens is already set by
# pop_transferred from the prefill side's metadata.
# Don't add decode-side prefix match to avoid double-counting.
req.already_computed = seq_len
req.is_retracted = False
pre_lens.append(pre_len)
Expand Down
3 changes: 2 additions & 1 deletion python/sglang/srt/disaggregation/fake/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,11 @@ def send_metadata(
kv_indices: list[int],
aux_index: Optional[int] = None,
state_indices: Optional[List[int]] = None,
decode_prefix_len: int = 0,
):
self.has_sent_metadata = True
logger.debug(
f"FakeKVReceiver send_metadata with kv_indices: {kv_indices}, aux_index: {aux_index}, state_indices: {state_indices}"
f"FakeKVReceiver send_metadata with kv_indices: {kv_indices}, aux_index: {aux_index}, state_indices: {state_indices}, decode_prefix_len: {decode_prefix_len}"
)

def failure_exception(self):
Expand Down
8 changes: 8 additions & 0 deletions python/sglang/srt/disaggregation/mooncake/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ class TransferInfo:
dst_state_indices: List[int]
required_dst_info_num: int
is_dummy: bool
decode_prefix_len: int = 0
# Note: always put the optional staging field at the final (it will be set through 'STAGING_RSP' pkg when needed)
staging: Optional[StagingTransferInfo] = None

Expand All @@ -99,6 +100,10 @@ def from_zmq(cls, msg: List[bytes]):
else:
dst_state_indices = list(np.frombuffer(msg[6], dtype=np.int32))
is_dummy = False
# decode_prefix_len: backward compatible, default 0 if not present
decode_prefix_len = (
int(msg[8].decode("ascii")) if len(msg) > 8 and msg[8] != b"" else 0
)
return cls(
room=int(msg[0].decode("ascii")),
endpoint=msg[1].decode("ascii"),
Expand All @@ -109,6 +114,7 @@ def from_zmq(cls, msg: List[bytes]):
dst_state_indices=dst_state_indices,
required_dst_info_num=int(msg[7].decode("ascii")),
is_dummy=is_dummy,
decode_prefix_len=decode_prefix_len,
)


Expand Down Expand Up @@ -1822,6 +1828,7 @@ def send_metadata(
kv_indices: npt.NDArray[np.int32],
aux_index: Optional[int] = None,
state_indices: Optional[List[int]] = None,
decode_prefix_len: int = 0,
):
if self.bootstrap_infos is None:
self.kv_mgr.record_failure(
Expand Down Expand Up @@ -1862,6 +1869,7 @@ def send_metadata(
else b""
),
str(self.required_dst_info_num).encode("ascii"),
str(decode_prefix_len).encode("ascii"),
]
)
self.init_time = time.time()
Expand Down
10 changes: 10 additions & 0 deletions python/sglang/srt/disaggregation/mori/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ class TransferInfo:
dst_aux_index: int
required_dst_info_num: int
is_dummy: bool
decode_prefix_len: int = 0

@classmethod
def from_zmq(cls, payload: List[bytes]) -> TransferInfo:
Expand All @@ -90,6 +91,12 @@ def from_zmq(cls, payload: List[bytes]) -> TransferInfo:
int(payload[7].decode("ascii")) if len(payload) > 7 else 1
)
is_dummy = dst_kv_indices.size == 0 and dst_aux_index < 0
# decode_prefix_len: backward compatible, default 0 if not present
decode_prefix_len = (
int(payload[8].decode("ascii"))
if len(payload) > 8 and payload[8] != b""
else 0
)
return cls(
room=room,
endpoint=endpoint,
Expand All @@ -99,6 +106,7 @@ def from_zmq(cls, payload: List[bytes]) -> TransferInfo:
dst_aux_index=dst_aux_index,
required_dst_info_num=required_dst_info_num,
is_dummy=is_dummy,
decode_prefix_len=decode_prefix_len,
)


Expand Down Expand Up @@ -1033,6 +1041,7 @@ def send_metadata(
kv_indices: npt.NDArray[np.int32],
aux_index: Optional[int] = None,
state_indices: Optional[List[int]] = None,
decode_prefix_len: int = 0,
):
if self.bootstrap_infos is None or self.bootstrap_room is None:
return
Expand All @@ -1058,6 +1067,7 @@ def send_metadata(
aux_bytes if not is_dummy else b"",
state_bytes,
str(self.required_dst_info_num).encode("ascii"),
str(decode_prefix_len).encode("ascii"),
]
)
self.init_time = time.time()
Expand Down
9 changes: 9 additions & 0 deletions python/sglang/srt/disaggregation/nixl/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ class TransferInfo:
dst_aux_index: int
required_dst_info_num: int
dst_state_indices: List[int]
decode_prefix_len: int = 0

def is_dummy(self) -> bool:
    """A transfer is a dummy when no destination KV indices were provided."""
    return not self.dst_kv_indices.size
Expand All @@ -56,6 +57,11 @@ def from_zmq(cls, msg: List[bytes]):
else:
dst_state_indices = []

# decode_prefix_len: backward compatible, default 0 if not present
decode_prefix_len = (
int(msg[8].decode("ascii")) if len(msg) > 8 and msg[8] != b"" else 0
)

return cls(
room=int(msg[0].decode("ascii")),
endpoint=msg[1].decode("ascii"),
Expand All @@ -65,6 +71,7 @@ def from_zmq(cls, msg: List[bytes]):
dst_aux_index=int(msg[5].decode("ascii")),
required_dst_info_num=int(msg[6].decode("ascii")),
dst_state_indices=dst_state_indices,
decode_prefix_len=decode_prefix_len,
)


Expand Down Expand Up @@ -1113,6 +1120,7 @@ def send_metadata(
kv_indices: npt.NDArray[np.int32],
aux_index: Optional[int] = None,
state_indices: Optional[List[int]] = None,
decode_prefix_len: int = 0,
):
if self.bootstrap_infos is None:
logger.error(
Expand Down Expand Up @@ -1146,6 +1154,7 @@ def send_metadata(
if not is_dummy and state_indices is not None
else b""
),
str(decode_prefix_len).encode("ascii"),
]
)

Expand Down
38 changes: 30 additions & 8 deletions python/sglang/srt/disaggregation/prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,8 +335,28 @@ def pop_bootstrapped(
)
assert req.metadata_buffer_index is not None

num_pages = kv_to_page_num(num_kv_indices, self.token_to_kv_pool.page_size)
req.disagg_kv_sender.init(num_pages, req.metadata_buffer_index)
page_size = self.token_to_kv_pool.page_size
total_pages = kv_to_page_num(num_kv_indices, page_size)

# Read decode_prefix_len from the bootstrap message to skip
# transferring KV that the decode side already has cached.
decode_prefix_len = self.kv_manager.get_decode_prefix_len(
req.bootstrap_room
)
logger.info(
f"Prefill bootstrap for {req.rid}: "
f"decode_prefix_len={decode_prefix_len}, total_pages={total_pages}, "
f"bootstrap_room={req.bootstrap_room}"
)
if decode_prefix_len > 0:
req.start_send_idx = decode_prefix_len
req.disagg_prefill_skip_tokens = decode_prefix_len
decode_prefix_pages = kv_to_page_num(decode_prefix_len, page_size)
incremental_pages = total_pages - decode_prefix_pages
else:
incremental_pages = total_pages

req.disagg_kv_sender.init(incremental_pages, req.metadata_buffer_index)

bootstrapped_reqs.append(req)
indices_to_remove.add(i)
Expand Down Expand Up @@ -658,18 +678,20 @@ def process_disagg_prefill_inflight_queue(
req.time_stats.set_completion_time()

page_size = self.token_to_kv_pool_allocator.page_size
kv_item_lens = (
self.disagg_prefill_bootstrap_queue.kv_manager.kv_args.kv_item_lens
)
bytes_per_page_all_layers = sum(kv_item_lens)
kv_args = self.disagg_prefill_bootstrap_queue.kv_manager.kv_args
bytes_per_page_all_layers = sum(kv_args.kv_item_lens)
state_bytes_per_req = sum(kv_args.state_item_lens) if kv_args.state_item_lens else 0

for req in done_reqs:
if isinstance(req.finished_reason, FINISH_ABORT):
continue
# Use actual transferred tokens (excluding decode-side cached prefix)
actual_transfer_tokens = len(req.origin_input_ids) - req.disagg_prefill_skip_tokens
metrics = req.time_stats.compute_and_observe_kv_transfer_metrics(
num_tokens=len(req.origin_input_ids),
num_tokens=actual_transfer_tokens,
page_size=page_size,
bytes_per_page_all_layers=bytes_per_page_all_layers,
state_bytes_per_req=state_bytes_per_req,
)
if metrics:
# Update last-value for REST API
Expand Down Expand Up @@ -767,7 +789,6 @@ def send_kv_chunk(
.cpu()
.numpy()
)
req.start_send_idx = end_idx
state_indices = None
if last_chunk:
self.disagg_metadata_buffers.set_buf(req)
Expand Down Expand Up @@ -819,4 +840,5 @@ def send_kv_chunk(
f"Skip sending kv chunk for request {req.rid=} {req.bootstrap_room=} because page_indices is empty"
)
return
req.start_send_idx = end_idx
req.disagg_kv_sender.send(page_indices, state_indices)
9 changes: 9 additions & 0 deletions python/sglang/srt/managers/schedule_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -863,6 +863,15 @@ def __init__(
# start_send_idx = len(req.fill_ids)
self.start_send_idx: int = 0

# For incremental KV transfer in PD disaggregation:
# Page-aligned prefix length matched on the decode side against its local tree cache.
# When > 0, only the incremental KV beyond this prefix is transferred from prefill.
self.disagg_decode_prefix_len: int = 0

# Number of tokens skipped on the prefill side due to decode-side prefix caching.
# Used for accurate transfer_total metrics reporting.
self.disagg_prefill_skip_tokens: int = 0

# For overlap schedule, we delay the kv transfer until `process_batch_result_disagg_prefill` rather than `process_prefill_chunk` in non-overlap
# This is because kv is not ready in `process_prefill_chunk`.
# We use `tmp_end_idx` to store the end index of the kv cache to send.
Expand Down
10 changes: 10 additions & 0 deletions python/sglang/srt/managers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3265,6 +3265,16 @@ def abort_request(self, recv_req: AbortReq):
if self.disaggregation_mode == DisaggregationMode.DECODE:
if self.enable_hisparse:
self.hisparse_coordinator.request_finished(req)
# Protect tree-owned prefix pages from being freed.
if req.disagg_decode_prefix_len > 0:
req.cache_protected_len = req.disagg_decode_prefix_len
# Ensure last_node is valid for dec_lock_ref inside
# cache_finished_req. When no prefix was matched in
# pop_preallocated, last_node was never set (stays None).
if req.last_node is None and hasattr(
self.tree_cache, "root_node"
):
req.last_node = self.tree_cache.root_node
release_kv_cache(req, self.tree_cache)
# For disaggregation prefill mode, free the metadata buffer index
if self.disaggregation_mode == DisaggregationMode.PREFILL:
Expand Down
42 changes: 36 additions & 6 deletions python/sglang/srt/managers/scheduler_output_processor_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,16 +580,21 @@ def _mamba_prefix_cache_update(
)
)
req.mamba_last_track_seqlen = seq_len
elif (
not batch.spec_algorithm.is_none()
and result.accept_length_per_req_cpu is not None
):
elif not batch.spec_algorithm.is_none():
# for spec decode, update mamba_last_track_seqlen if this iteration crosses a track interval
actual_seq_len = req.seqlen - 1
# Use mamba_last_track_seqlen as the reference for the previous bucket.
# For the first check (when tracking hasn't started), use the input length.
# This avoids relying on accept_length_per_req_cpu which can be inaccurate:
# V1 has intermediate forward passes with accept_len=0, and V2 may not
# include the bonus token in accept_len, both causing missed boundary crossings.
if req.mamba_last_track_seqlen is not None:
prev_ref = req.mamba_last_track_seqlen
else:
prev_ref = len(req.origin_input_ids) - 1
if (
actual_seq_len // mamba_track_interval
!= (actual_seq_len - result.accept_length_per_req_cpu[i])
// mamba_track_interval
> prev_ref // mamba_track_interval
):
req.mamba_next_track_idx = (
batch.req_to_token_pool.get_mamba_ping_pong_other_idx(
Expand All @@ -599,6 +604,31 @@ def _mamba_prefix_cache_update(
req.mamba_last_track_seqlen = (
actual_seq_len // mamba_track_interval * mamba_track_interval
)
elif (
req.mamba_last_track_seqlen is None
and self.server_args.disaggregation_mode == "decode"
and self.server_args.disaggregation_enable_decode_radix_cache
):
# First spec-decode call and no boundary crossed.
# For disagg decode, the transferred mamba state has never
# been checkpointed on this node. Force a swap so that the
# current ping_pong buffer (populated by this forward pass)
# is preserved as the initial checkpoint.
# The state is at actual_seq_len rather than exactly at the
# boundary, but on the decode side tree cache mamba_value is
# not used for computation (skip_mamba_truncation=True).
# This is gated to disagg decode only — on single-machine,
# mamba_value IS used for state loading and must be exact.
initial_track = (
prev_ref // mamba_track_interval * mamba_track_interval
)
if initial_track > 0:
req.mamba_next_track_idx = (
batch.req_to_token_pool.get_mamba_ping_pong_other_idx(
req.mamba_next_track_idx
)
)
req.mamba_last_track_seqlen = initial_track

def _process_input_token_logprobs(
self: Scheduler, req: Req, input_token_logprobs: List
Expand Down
6 changes: 6 additions & 0 deletions python/sglang/srt/mem_cache/base_prefix_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,12 @@ class MatchPrefixParams:
cow_mamba: bool = False
req: Optional[Req] = None

# For disagg decode: skip mamba-based prefix truncation.
# When True, match_prefix returns all matched KV indices regardless of
# mamba_value presence on intermediate nodes. SSM state is transferred
# from prefill, so decode tree cache mamba_value is not needed.
skip_mamba_truncation: bool = False


@dataclasses.dataclass
class InsertParams:
Expand Down
Loading
Loading