Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/advanced_features/pd_disaggregation.md
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ export MC_FORCE_MNNVL=True
| **`SGLANG_DISAGGREGATION_THREAD_POOL_SIZE`** | Controls the total number of worker threads for KVCache transfer operations per TP rank | A dynamic value calculated as `int(0.75 * os.cpu_count()) // 8`, clamped to be at least 4 and at most 12 to ensure efficiency and prevent thread race conditions |
| **`SGLANG_DISAGGREGATION_QUEUE_SIZE`** | Sets the number of parallel transfer queues. KVCache transfer requests from multiple decode instances will be sharded into these queues so that they can share the threads and the transfer bandwidth at the same time. If it is set to `1`, requests are transferred one by one according to the FCFS (first-come, first-served) strategy | `4` |
| **`SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT`** | Timeout (seconds) for receiving destination KV indices during request initialization | `300` |
| **`SGLANG_DISAGGREGATION_BOOTSTRAP_ENTRY_CLEANUP_INTERVAL`** | Interval (seconds) between cleanups of bootstrap entries | `120` |

If a greater mean TTFT is acceptable, you can `export SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT=600` (10 minutes) to relax the timeout condition.
Please be aware that this setting will cause prefill instances to take longer to clean up the affected memory resources when a running decode node loses its connection.
Expand Down
198 changes: 179 additions & 19 deletions python/sglang/srt/disaggregation/common/conn.py

Large diffs are not rendered by default.

51 changes: 50 additions & 1 deletion python/sglang/srt/disaggregation/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,8 @@ def pop_preallocated(
failed_reqs = []
preallocated_reqs = []
indices_to_remove = set()
bootstrap_table = None
data_parallel_rank = None

# We need to make sure that the sum of inflight tokens and allocatable tokens is greater than maximum input+output length of each inflight request
# Otherwise it is possible for one request running decode out of memory, while all other requests are in the transfer queue that cannot be retracted.
Expand All @@ -475,7 +477,51 @@ def pop_preallocated(
indices_to_remove.add(i)

# Then, preallocate the remaining requests if possible
# Batch fetch all bootstrap_rooms at once to reduce network overhead
if bootstrap_table is None and any(
decode_req.kv_receiver.should_notify_dp_rank for decode_req in self.queue
):
bootstrap_rooms_to_fetch = []
for decode_req in self.queue:
if (
decode_req.kv_receiver.should_notify_dp_rank
and hasattr(
decode_req.kv_receiver, "_get_prefill_dp_rank_from_server"
)
and decode_req.req.bootstrap_host != FAKE_BOOTSTRAP_HOST
):
bootstrap_rooms_to_fetch.append(decode_req.req.bootstrap_room)

if bootstrap_rooms_to_fetch:
# Use the first kv_receiver to batch fetch all bootstrap_rooms
for decode_req in self.queue:
if hasattr(
decode_req.kv_receiver, "_get_prefill_dp_rank_from_server"
):
bootstrap_table = (
decode_req.kv_receiver._get_prefill_dp_rank_from_server(
bootstrap_rooms_to_fetch
)
)
break

for i, decode_req in enumerate(self.queue):
if decode_req.kv_receiver.should_notify_dp_rank:
# Do not check warmup requests with FAKE_BOOTSTRAP_HOST
if decode_req.req.bootstrap_host != FAKE_BOOTSTRAP_HOST:
if (
bootstrap_table
and str(decode_req.req.bootstrap_room) in bootstrap_table
):
data_parallel_rank = bootstrap_table[
str(decode_req.req.bootstrap_room)
]["dp_rank"]
else:
logger.debug(
f"bootstrap info for {decode_req.req.bootstrap_room} {decode_req.req.bootstrap_host} not found"
)
continue

if rids_to_check is not None and decode_req.req.rid not in rids_to_check:
continue

Expand Down Expand Up @@ -571,7 +617,10 @@ def pop_preallocated(
assert decode_req.metadata_buffer_index is not None
page_indices = kv_to_page_indices(kv_indices, page_size)
decode_req.kv_receiver.init(
page_indices, decode_req.metadata_buffer_index, state_indices
page_indices,
decode_req.metadata_buffer_index,
state_indices,
data_parallel_rank,
)
preallocated_reqs.append(decode_req)
indices_to_remove.add(i)
Expand Down
2 changes: 2 additions & 0 deletions python/sglang/srt/disaggregation/fake/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def __init__(
prefill_dp_rank: Optional[int] = None,
):
self.has_init = False
self.should_notify_dp_rank = False

def poll(self) -> KVPoll:
if self.has_init is False:
Expand All @@ -83,6 +84,7 @@ def init(
kv_indices: list[int],
aux_index: Optional[int] = None,
state_indices: Optional[List[int]] = None,
prefill_dp_rank: Optional[int] = None,
):
self.has_init = True
logger.debug(
Expand Down
6 changes: 6 additions & 0 deletions python/sglang/srt/disaggregation/mooncake/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1233,7 +1233,13 @@ def init(
kv_indices: npt.NDArray[np.int32],
aux_index: Optional[int] = None,
state_indices: Optional[List[int]] = None,
prefill_dp_rank: Optional[int] = None,
):
# Dp rank for prefill server is synchronized now.
if self.should_notify_dp_rank:
self.prefill_dp_rank = prefill_dp_rank
self._setup_bootstrap_infos()

if self.bootstrap_infos is None:
self.kv_mgr.record_failure(
self.bootstrap_room,
Expand Down
6 changes: 6 additions & 0 deletions python/sglang/srt/disaggregation/nixl/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -777,7 +777,13 @@ def init(
kv_indices: npt.NDArray[np.int32],
aux_index: Optional[int] = None,
state_indices: Optional[List[int]] = None,
prefill_dp_rank: Optional[int] = None,
):
# Dp rank for prefill server is synchronized now.
if self.should_notify_dp_rank:
self.prefill_dp_rank = prefill_dp_rank
self._setup_bootstrap_infos()

if self.bootstrap_infos is None:
logger.error(
f"Could not fetch prefill parallel info from bootstrap_addr: {self.bootstrap_addr}",
Expand Down
1 change: 1 addition & 0 deletions python/sglang/srt/environ.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,7 @@ class Envs:
SGLANG_PREFILL_DELAYER_TOKEN_USAGE_LOW_WATERMARK = EnvFloat(None)
SGLANG_DATA_PARALLEL_BUDGET_INTERVAL = EnvInt(1)
SGLANG_QUEUED_TIMEOUT_MS = EnvInt(-1)
SGLANG_DISAGGREGATION_BOOTSTRAP_ENTRY_CLEANUP_INTERVAL = EnvInt(120)

# Test: pd-disaggregation
SGLANG_TEST_PD_DISAGG_BACKEND = EnvStr("mooncake")
Expand Down
1 change: 1 addition & 0 deletions python/sglang/srt/managers/disagg_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def start_disagg_service(
bootstrap_server: BaseKVBootstrapServer = kv_bootstrap_server_class(
host=server_args.host,
port=server_args.disaggregation_bootstrap_port,
dp_size=server_args.dp_size,
)
is_create_store = (
server_args.node_rank == 0 and transfer_backend == TransferBackend.ASCEND
Expand Down
12 changes: 0 additions & 12 deletions python/sglang/srt/server_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -787,18 +787,6 @@ def _handle_load_balance_method(self):
)
return

# Backward compat: in PD prefill, legacy "round_robin" means `bootstrap_room` routing.
if (
self.disaggregation_mode == "prefill"
and self.load_balance_method == "round_robin"
):
logger.warning(
"In PD-disaggregation prefill mode, the 'round_robin' load balancing method "
"means `bootstrap_room` routing (use 'follow_bootstrap_room' instead). "
"Falling back to 'follow_bootstrap_room' for backward compatibility."
)
self.load_balance_method = "follow_bootstrap_room"

def _handle_deprecated_args(self):
# Handle deprecated tool call parsers
deprecated_tool_call_parsers = {"qwen25": "qwen", "glm45": "glm"}
Expand Down