62 commits
73f769c
init
ishandhanani Mar 3, 2026
c316c42
Merge branch 'main' into ishan/add-radix-cache-decode
ishandhanani Mar 3, 2026
0268551
fix stupid mem leak
ishandhanani Mar 3, 2026
5fa0819
fix: restore token_msg in _check_radix_cache_memory
ishandhanani Mar 3, 2026
49cf9e6
fix: set prefix_indices in _pre_alloc and always lock matched node
ishandhanani Mar 4, 2026
95dc9a7
rebase
ishandhanani Mar 5, 2026
c21f1cd
Merge branch 'main' into ishan/add-radix-cache-decode
ishandhanani Mar 5, 2026
aa6a6e6
fix: handle 0-page KV transfer completion in disagg decode radix cache
ishandhanani Mar 6, 2026
5b93682
fix: evict decode radix cache before KV pre-allocation
ishandhanani Mar 6, 2026
64cf052
fix: don't treat radix cache 0-page transfers as CP dummy
ishandhanani Mar 6, 2026
9873033
fix: align decode radix cache prefix_len to page boundary
ishandhanani Mar 6, 2026
c782926
fix: account for locked prefix in pre-allocation budget
ishandhanani Mar 6, 2026
5e8fb59
disagg prefill: keep chunk send cursor monotonic
ishandhanani Mar 6, 2026
eacb121
fix: skip empty non-last chunk sends to prevent chunk_id inflation
ishandhanani Mar 6, 2026
2c3cf34
Merge branch 'main' into ishan/add-radix-cache-decode
ishandhanani Mar 9, 2026
4a30312
feat: add --disaggregation-decode-enable-radix-cache CLI flag
ishandhanani Mar 9, 2026
90a4a8b
fix: use MatchResult fields instead of tuple unpacking for match_prefix
ishandhanani Mar 9, 2026
b990f29
fix: recompute allocatable_tokens from pool state after each pre-alloc
ishandhanani Mar 9, 2026
e141484
feat: soften DP attention + decode radix cache to warning
ishandhanani Mar 9, 2026
018c311
Merge branch 'main' into ishan/add-radix-cache-decode
ishandhanani Mar 11, 2026
6628e96
Merge branch 'main' into ishan/add-radix-cache-decode
ishandhanani Mar 12, 2026
58d52c7
Merge branch 'main' into ishan/add-radix-cache-decode
ishandhanani Mar 13, 2026
05a8e65
fix: pass extra_key to RadixKey in decode radix cache prefix matching
ishandhanani Mar 13, 2026
4cd3ee8
Guard decode radix cache behind nixl backend
ishandhanani Mar 16, 2026
6d1171c
Merge branch 'main' into ishan/add-radix-cache-decode
ishandhanani Mar 17, 2026
9ed1495
cleanup decode radix cache review nits
ishandhanani Mar 17, 2026
6b68b47
clean up decode prefix send cursor state
ishandhanani Mar 17, 2026
d1f3d74
move decode prefix handoff into kv manager state
ishandhanani Mar 17, 2026
824c974
fix: reset cache_protected_len on retract to prevent KV pool leak
ishandhanani Mar 18, 2026
40f8848
fix(radix): cache only committed KV in unfinished requests
ishandhanani Mar 21, 2026
8698b5f
fix(decode): budget preallocation in page units
ishandhanani Mar 21, 2026
10e47e9
Merge branch 'main' into ishan/add-radix-cache-decode
ishandhanani Mar 24, 2026
74ac023
Guard empty final KV chunk on non-NIXL backends
ishandhanani Mar 24, 2026
5187e30
Fix FakeKVSender missing kv_mgr attribute in disagg prefill bootstrap
ishandhanani Mar 24, 2026
3cb66bc
Fix FakeKVReceiver.init() missing decode_prefix_len parameter
ishandhanani Mar 24, 2026
bed7d4a
Add decode_prefix_len param to all KVReceiver.init() implementations
ishandhanani Mar 24, 2026
67bfc8c
Merge branch 'main' into ishan/add-radix-cache-decode
ishandhanani Mar 25, 2026
8bc1310
Merge branch 'main' into ishan/add-radix-cache-decode
ishandhanani Mar 25, 2026
ee7bee3
Merge branch 'main' into ishan/add-radix-cache-decode
ishandhanani Mar 25, 2026
8da42a4
Merge branch 'main' into ishan/add-radix-cache-decode
ishandhanani Apr 1, 2026
963ea7f
Merge branch 'main' into ishan/add-radix-cache-decode
ishandhanani Apr 5, 2026
540ee41
Clarify hisparse vs decode radix cache config
ishandhanani Apr 5, 2026
a34a99d
Fix CI: move PD disaggregation validation before dummy-model short-ci…
ishandhanani Apr 6, 2026
53294d6
Merge radix/chunk-cache prefix match branches in pop_preallocated
ishandhanani Apr 7, 2026
03bf260
Assert unsupported cases for decode radix cache: spec dec, mamba, swa
ishandhanani Apr 7, 2026
9ac9bb7
Add TODO for retraction + radix cache interaction in _pre_alloc
ishandhanani Apr 7, 2026
7856688
Remove unnecessary running_batch guard in decode idle memory check
ishandhanani Apr 7, 2026
4523ebb
Revert "Remove unnecessary running_batch guard in decode idle memory …
ishandhanani Apr 7, 2026
6058b91
Review fixes: guard removal, cursor fix, TP sync, diagnostics
ishandhanani Apr 7, 2026
a6b6e8b
Flush radix cache after disagg warmup to discard garbage KV
ishandhanani Apr 8, 2026
a0ebaad
Add page_align_floor util and use in decode prealloc
ishandhanani Apr 8, 2026
01333d4
Move fill_ids truncation from radix_cache to decode scheduler
ishandhanani Apr 8, 2026
7e5e8a1
Fix Qwen3 rope_theta merge conflict artifact
ishandhanani Apr 8, 2026
83c76ca
Merge remote-tracking branch 'origin/main' into ishan/add-radix-cache…
ishandhanani Apr 8, 2026
f035678
Fix pre-commit: black formatting + move manual test out of registered/
ishandhanani Apr 9, 2026
421bfd6
Remove unrelated test file accidentally committed
ishandhanani Apr 9, 2026
4f2c34e
Add unit tests for lock_ref balance across decode transfer scenarios
ishandhanani Apr 9, 2026
c6e70aa
Merge remote-tracking branch 'origin/main' into ishan/add-radix-cache…
ishandhanani Apr 14, 2026
b24f58d
Restore decode transfer queue all-reduce guards
ishandhanani Apr 14, 2026
8ea6649
Flush prefill radix cache after warmup
ishandhanani Apr 14, 2026
b0bcc35
Restore Qwen3 rope config fallback
ishandhanani Apr 14, 2026
f7327da
Remove decode queue TP sync guards
ishandhanani Apr 14, 2026
1 change: 1 addition & 0 deletions python/sglang/srt/disaggregation/base/conn.py
@@ -136,6 +136,7 @@ def send_metadata(
kv_indices: npt.NDArray[np.int32],
aux_index: Optional[int] = None,
state_indices: Optional[List[int]] = None,
decode_prefix_len: Optional[int] = None,
):
"""
Notify the prefill server about the kv indices, aux index, and state_indices.
7 changes: 7 additions & 0 deletions python/sglang/srt/disaggregation/common/conn.py
@@ -141,6 +141,7 @@ def __init__(
)
self.register_to_bootstrap()
self.transfer_infos = {}
self.req_to_decode_prefix_len: Dict[int, int] = {}
self.decode_kv_args_table = {}
self.pp_group = get_pp_group()
# If a timeout happens on the prefill side, it means prefill instances
@@ -179,6 +180,12 @@ def check_status(self, bootstrap_room: int) -> KVPoll:
return self.request_status[bootstrap_room]

def update_status(self, bootstrap_room: int, status: KVPoll):
if (
status == KVPoll.Failed
and self.disaggregation_mode == DisaggregationMode.PREFILL
and hasattr(self, "req_to_decode_prefix_len")
):
self.req_to_decode_prefix_len.pop(bootstrap_room, None)
Comment on lines +183 to +188
Collaborator:
nit: consider putting this pop logic in the clear() function.

I can help with this when I am doing the follow-up PR; we can keep it this way for now.

if bootstrap_room not in self.request_status:
self.request_status[bootstrap_room] = status
else:
262 changes: 224 additions & 38 deletions python/sglang/srt/disaggregation/decode.py
Collaborator:
@xiezhq-hermann, we need your help. The logic of these tree cache changes looks reasonable, but I am not sure whether it will introduce a potential memory leak or bug when combined with other features, and much of the logic here is not gated behind the decode radix cache server args. We need an expert to check this.

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions python/sglang/srt/disaggregation/fake/conn.py
@@ -27,6 +27,7 @@ def __init__(
is_mla_backend: Optional[bool] = False,
):
super().__init__(args, disaggregation_mode, server_args, is_mla_backend)
self.req_to_decode_prefix_len = {}

def register_to_bootstrap(self):
pass
@@ -41,6 +42,7 @@ def __init__(
dest_tp_ranks: List[int],
pp_rank: int,
):
self.kv_mgr = mgr
self.has_sent = False

def poll(self) -> KVPoll:
@@ -106,6 +108,7 @@ def send_metadata(
kv_indices: list[int],
aux_index: Optional[int] = None,
state_indices: Optional[List[int]] = None,
decode_prefix_len: Optional[int] = None,
):
self.has_sent_metadata = True
logger.debug(
1 change: 1 addition & 0 deletions python/sglang/srt/disaggregation/mooncake/conn.py
@@ -1822,6 +1822,7 @@ def send_metadata(
kv_indices: npt.NDArray[np.int32],
aux_index: Optional[int] = None,
state_indices: Optional[List[int]] = None,
decode_prefix_len: Optional[int] = None,
):
if self.bootstrap_infos is None:
self.kv_mgr.record_failure(
1 change: 1 addition & 0 deletions python/sglang/srt/disaggregation/mori/conn.py
@@ -1033,6 +1033,7 @@ def send_metadata(
kv_indices: npt.NDArray[np.int32],
aux_index: Optional[int] = None,
state_indices: Optional[List[int]] = None,
decode_prefix_len: Optional[int] = None,
):
if self.bootstrap_infos is None or self.bootstrap_room is None:
return
102 changes: 72 additions & 30 deletions python/sglang/srt/disaggregation/nixl/conn.py
@@ -44,8 +44,15 @@ class TransferInfo:
dst_aux_index: int
required_dst_info_num: int
dst_state_indices: List[int]
decode_prefix_len: Optional[int] = None # for decode radix cache

def is_dummy(self):
# A transfer is "dummy" only for CP non-authoritative ranks.
# When dst_kv_indices is empty due to a decode-side radix cache
# full hit (decode_prefix_len > 0), the transfer is NOT dummy --
# aux/state data still needs to be sent.
if self.dst_kv_indices.size == 0 and self.decode_prefix_len:
return False
return self.dst_kv_indices.size == 0

@classmethod
@@ -65,6 +72,9 @@ def from_zmq(cls, msg: List[bytes]):
dst_aux_index=int(msg[5].decode("ascii")),
required_dst_info_num=int(msg[6].decode("ascii")),
dst_state_indices=dst_state_indices,
decode_prefix_len=(
int(msg[8].decode("ascii")) if len(msg) > 8 and msg[8] != b"" else None
),  # hacky: appended as an extra trailing field of the ZMQ message
)


Expand Down Expand Up @@ -883,39 +893,44 @@ def add_transfer_request(
assert len(chunked_dst_kv_indice) == len(kv_indices)
assert req.agent_name in self.decode_kv_args_table

notif = (
f"{req.room}_kv_{chunk_id}_{int(is_last)}_{self.kv_args.engine_rank}"
)
decode_tp_size = self.decode_kv_args_table[req.agent_name].decode_tp_size

if self.is_mla_backend or (decode_tp_size == self.attn_tp_size):
kv_xfer_handle = self.send_kvcache(
req.agent_name,
kv_indices,
self.decode_kv_args_table[req.agent_name].dst_kv_ptrs,
chunked_dst_kv_indice,
self.decode_kv_args_table[req.agent_name].gpu_id,
notif,
)
else:
kv_xfer_handle = self.send_kvcache_slice(
req.agent_name,
kv_indices,
self.decode_kv_args_table[req.agent_name].dst_kv_ptrs,
chunked_dst_kv_indice,
self.decode_kv_args_table[req.agent_name].gpu_id,
notif,
prefill_tp_size=self.attn_tp_size,
decode_tp_size=decode_tp_size,
decode_tp_rank=self.decode_kv_args_table[
req.agent_name
].decode_tp_rank,
dst_kv_item_len=self.decode_kv_args_table[
req.agent_name
].dst_kv_item_len,
# Skip KV RDMA transfer when there are no pages to send
# (e.g., decode-side radix cache matched the entire prefix).
# Aux data is still sent below when is_last=True.
if len(kv_indices) > 0:
notif = (
f"{req.room}_kv_{chunk_id}_{int(is_last)}_{self.kv_args.pp_rank}"
)

handles.append(kv_xfer_handle)
if self.is_mla_backend or (decode_tp_size == self.attn_tp_size):
kv_xfer_handle = self.send_kvcache(
req.agent_name,
kv_indices,
self.decode_kv_args_table[req.agent_name].dst_kv_ptrs,
chunked_dst_kv_indice,
self.decode_kv_args_table[req.agent_name].gpu_id,
notif,
)
else:
kv_xfer_handle = self.send_kvcache_slice(
req.agent_name,
kv_indices,
self.decode_kv_args_table[req.agent_name].dst_kv_ptrs,
chunked_dst_kv_indice,
self.decode_kv_args_table[req.agent_name].gpu_id,
notif,
prefill_tp_size=self.attn_tp_size,
decode_tp_size=decode_tp_size,
decode_tp_rank=self.decode_kv_args_table[
req.agent_name
].decode_tp_rank,
dst_kv_item_len=self.decode_kv_args_table[
req.agent_name
].dst_kv_item_len,
)

handles.append(kv_xfer_handle)
# Only on the last chunk do we need to send the aux data.
if is_last:
if state_indices is not None:
Expand All @@ -936,16 +951,24 @@ def add_transfer_request(
handles.append(state_xfer_handle)

assert aux_index is not None
# When no KV pages were sent (decode-side cache hit),
# encode pp_rank in aux notif so receiver can mark
# expected_kvs_per_pp[pp_rank] = 0.
if len(kv_indices) == 0:
aux_notif = f"{req.room}_aux_nokv_{self.kv_args.pp_rank}"
else:
aux_notif = f"{req.room}_aux"
aux_xfer_handle = self.send_aux(
req.agent_name,
aux_index,
self.decode_kv_args_table[req.agent_name].dst_aux_ptrs,
req.dst_aux_index,
f"{req.room}_aux",
aux_notif,
)
handles.append(aux_xfer_handle)
if is_last:
del self.transfer_infos[bootstrap_room]
self.req_to_decode_prefix_len.pop(bootstrap_room, None)
return handles

def update_transfer_status(self):
@@ -978,6 +1001,15 @@ def update_transfer_status(self):
)
elif components[1] == "aux":
self.transfer_statuses[room].received_aux = True
# Handle "nokv" marker: no KV pages were sent for
# this pp_rank (decode-side radix cache hit).
if len(components) > 3 and components[2] == "nokv":
pp_rank = int(components[3])
self.transfer_statuses[room].expected_kvs_per_pp[pp_rank] = 0
if self.transfer_statuses[room].num_pp_ranks_expected is None:
self.transfer_statuses[room].num_pp_ranks_expected = (
self.required_prefill_response_num_table.get(room, 1)
)
elif components[1] == "state":
pp_rank = int(components[2]) if len(components) > 2 else 0
self.transfer_statuses[room].received_state_per_pp.add(pp_rank)
@@ -1019,6 +1051,14 @@ def bootstrap_thread():
].required_dst_info_num
logger.debug(f"got info {room=} {agent_name=} {required_dst_info_num=}")
if len(self.transfer_infos[room]) == required_dst_info_num:
self.req_to_decode_prefix_len[room] = next(
(
info.decode_prefix_len
for info in self.transfer_infos[room].values()
if info.decode_prefix_len is not None
),
0,
)
logger.debug(f"{room=} is bootstrapped")
self.update_status(room, KVPoll.WaitingForInput)

@@ -1113,6 +1153,7 @@ def send_metadata(
kv_indices: npt.NDArray[np.int32],
aux_index: Optional[int] = None,
state_indices: Optional[List[int]] = None,
decode_prefix_len: Optional[int] = None,
):
if self.bootstrap_infos is None:
logger.error(
Expand Down Expand Up @@ -1146,6 +1187,7 @@ def send_metadata(
if not is_dummy and state_indices is not None
else b""
),
str(decode_prefix_len or 0).encode("ascii"),
]
)

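The `"{room}_aux_nokv_{pp_rank}"` notification convention used in this diff can be sketched as a small parser. `parse_aux_notif` is a hypothetical helper for illustration; in the actual change the parsing is inline in `update_transfer_status`.

```python
from typing import Optional, Tuple


def parse_aux_notif(notif: str) -> Tuple[str, bool, Optional[int]]:
    """Parse an aux notification name into (room, no_kv, pp_rank).

    "{room}_aux"            -> regular aux completion
    "{room}_aux_nokv_{pp}"  -> aux completion where zero KV pages were sent
                               for pipeline-parallel rank {pp} (decode-side
                               radix cache hit)
    """
    components = notif.split("_")
    if len(components) < 2 or components[1] != "aux":
        raise ValueError(f"not an aux notification: {notif!r}")
    if len(components) > 3 and components[2] == "nokv":
        return components[0], True, int(components[3])
    return components[0], False, None
```

On the receiver this distinction is what lets `expected_kvs_per_pp[pp_rank]` be set to 0 for full cache hits instead of waiting for a KV notification that will never arrive.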
39 changes: 32 additions & 7 deletions python/sglang/srt/disaggregation/prefill.py
@@ -324,7 +324,7 @@ def pop_bootstrapped(
self.scheduler.tree_cache.release_aborted_request(req.rid)
continue

# KV.WaitingForInput - init here
# KV.WaitingForInput - decode is ready to receive; initialize the KV sender
req.time_stats.set_bootstrap_done_time()
num_kv_indices = len(req.origin_input_ids)
if self.req_to_metadata_buffer_idx_allocator.available_size() == 0:
@@ -335,7 +335,19 @@
)
assert req.metadata_buffer_index is not None

num_pages = kv_to_page_num(num_kv_indices, self.token_to_kv_pool.page_size)
# Calculate the number of pages to send:
# if decode has a cached prefix, send only the delta indices;
# otherwise, send the entire request.
decode_prefix_len = (
req.disagg_kv_sender.kv_mgr.req_to_decode_prefix_len.pop(
req.bootstrap_room, 0
)
)
req.start_send_idx = decode_prefix_len
num_kv_indices_to_send = num_kv_indices - decode_prefix_len
num_pages = kv_to_page_num(
num_kv_indices_to_send, self.token_to_kv_pool.page_size
)
req.disagg_kv_sender.init(num_pages, req.metadata_buffer_index)

bootstrapped_reqs.append(req)
@@ -768,12 +780,20 @@ def send_kv_chunk(
# if not the last chunk and the last page is partial, delay the last partial page to the next send
end_idx = end_idx - end_idx % page_size

if end_idx < start_idx:
logger.debug(
"send_kv_chunk skip: rid=%s start_send_idx=%s end_idx=%s",
req.rid,
start_idx,
end_idx,
)
return

Comment on lines +783 to +791
Collaborator:
Would it be better if we return early here?

Collaborator (Author):
Done. I changed the end_idx < start_idx case to return early instead of clamping. I kept the end_idx == start_idx final-chunk path intact so NIXL can still send the aux/state-only completion for full decode-cache hits.

kv_indices = (
self.req_to_token_pool.req_to_token[req.req_pool_idx, start_idx:end_idx]
.cpu()
.numpy()
)
req.start_send_idx = end_idx
state_indices = None
if last_chunk:
self.disagg_metadata_buffers.set_buf(req)
@@ -820,9 +840,14 @@
state_indices = kv_to_page_indices(state_indices, page_size)

page_indices = kv_to_page_indices(kv_indices, page_size)
# Skip empty non-last chunks for all backends. For empty last chunks,
# only NIXL currently defines the aux/state-only completion path used
# by decode-side radix cache; keep a conservative early return for
# other backends until they implement the same semantics.
if len(page_indices) == 0:
logger.info(
f"Skip sending kv chunk for request {req.rid=} {req.bootstrap_room=} because page_indices is empty"
)
return
if not last_chunk:
return
if self.transfer_backend != TransferBackend.NIXL:
return
req.disagg_kv_sender.send(page_indices, state_indices)
req.start_send_idx = end_idx
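The cursor logic discussed in the review thread above can be sketched as follows. `next_chunk_bounds` is a hypothetical helper under simplifying assumptions; it only models the page-floor and early-return behavior of `send_kv_chunk`, not the actual transfer.

```python
from typing import Optional, Tuple


def next_chunk_bounds(
    start_send_idx: int, seq_len: int, page_size: int, last_chunk: bool
) -> Optional[Tuple[int, int]]:
    """Sketch of the send-cursor clamping in send_kv_chunk.

    Non-last chunks are floored to a page boundary so a partial trailing
    page is delayed to the next send. If the floored end falls behind the
    cursor, nothing is ready yet: return None (early return) rather than
    clamping, which keeps the cursor monotonic. An equal start/end on the
    last chunk is kept so the aux/state-only completion can still be sent
    on a full decode-cache hit.
    """
    end_idx = seq_len
    if not last_chunk:
        end_idx -= end_idx % page_size  # delay the partial trailing page
    if end_idx < start_send_idx:
        return None  # skip this send entirely; cursor is unchanged
    return start_send_idx, end_idx
```

Returning None here (instead of clamping `end_idx` up to the cursor) is what prevents chunk_id inflation from empty non-last sends.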
5 changes: 5 additions & 0 deletions python/sglang/srt/disaggregation/utils.py
@@ -451,6 +451,11 @@ def kv_to_page_num(num_kv_indices: int, page_size: int):
return (num_kv_indices + page_size - 1) // page_size


def page_align_floor(length: int, page_size: int) -> int:
"""Round length down to the nearest page boundary."""
return (length // page_size) * page_size


def page_indices_to_cp_rank_page_indices(
page_indices: np.ndarray,
total_pages: int,
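As a worked example of how the two page helpers above compose (reproduced here as a self-contained sketch): with `page_size=4`, a 13-token range needs 4 pages while its page-aligned floor is 12.

```python
def kv_to_page_num(num_kv_indices: int, page_size: int) -> int:
    # Ceiling division: number of pages needed to hold the KV indices.
    return (num_kv_indices + page_size - 1) // page_size


def page_align_floor(length: int, page_size: int) -> int:
    """Round length down to the nearest page boundary."""
    return (length // page_size) * page_size
```

`page_align_floor` is what the decode prealloc path uses to truncate a matched prefix to a page boundary, while `kv_to_page_num` sizes the prefill-side transfer.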
27 changes: 26 additions & 1 deletion python/sglang/srt/entrypoints/http_server.py
@@ -384,6 +384,11 @@ async def lifespan(fast_api_app: FastAPI):
server_args.warmups.split(","),
_global_state.tokenizer_manager,
)
if (
server_args.disaggregation_mode != "null"
and not server_args.disable_radix_cache
):
await _global_state.tokenizer_manager.flush_cache()
logger.info("Warmup ended")

# Execute the general warmup
@@ -1973,8 +1978,28 @@ def _execute_server_warmup(server_args: ServerArgs):
)
if res.status_code == 200:
logger.info(
f"End of prefill disaggregation mode warmup with status {res.status_code}, resp: {res.json()}"
f"Disaggregation warmup request completed with status {res.status_code}, resp: {res.json()}"
)
if (
server_args.disaggregation_mode != "null"
and not server_args.disable_radix_cache
):
try:
flush_res = requests.post(
url + "/flush_cache",
headers=headers,
timeout=30,
verify=ssl_verify,
)
if flush_res.status_code == 200:
logger.info("Flushed warmup cache")
else:
logger.warning(
f"Warmup cache flush failed: {flush_res.status_code}"
)
except Exception as e:
logger.warning(f"Warmup cache flush request failed: {e}")
logger.info("End of disaggregation warmup")
_global_state.tokenizer_manager.server_status = ServerStatus.Up
else:
logger.info(
1 change: 1 addition & 0 deletions python/sglang/srt/managers/schedule_batch.py
@@ -1201,6 +1201,7 @@ def reset_for_retract(self):
self.prefix_indices = torch.empty((0,), dtype=torch.int64)
self.routed_experts = None
self.last_node = None
self.cache_protected_len = 0
self.swa_uuid_for_lock = None
self.extend_input_len = 0
self.is_retracted = True
18 changes: 18 additions & 0 deletions python/sglang/srt/managers/scheduler.py
@@ -769,6 +769,24 @@ def init_cache_with_memory_pool(self):
"Transformers backend to avoid multimodal prefix-cache mismatches."
)

# Decode radix cache is unsupported with hybrid SWA/SSM models —
# these use specialized memory pools incompatible with the
# prefix-match-and-lock allocation path.
if (
server_args.disaggregation_decode_enable_radix_cache
and server_args.disaggregation_mode == "decode"
):
if self.is_hybrid_swa:
raise ValueError(
"--disaggregation-decode-enable-radix-cache is incompatible "
"with sliding window attention (SWA) models"
)
if self.is_hybrid_ssm:
raise ValueError(
"--disaggregation-decode-enable-radix-cache is incompatible "
"with Mamba/SSM models"
)

effective_chunked_prefill_size = server_args.chunked_prefill_size
if self.model_config.is_multimodal and uses_transformers_backend:
effective_chunked_prefill_size = None