Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/advanced_features/pd_disaggregation.md
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ The `SGLANG_MOONCAKE_CUSTOM_MEM_POOL` environment variable enables the custom me
| **`SGLANG_DISAGGREGATION_THREAD_POOL_SIZE`** | Controls the total number of worker threads for KVCache transfer operations per TP rank | A dynamic value calculated by `int(0.75 * os.cpu_count()) // 8`, clamped to be no smaller than 4 and no larger than 12 to ensure efficiency and prevent thread race conditions |
| **`SGLANG_DISAGGREGATION_QUEUE_SIZE`** | Sets the number of parallel transfer queues. KVCache transfer requests from multiple decode instances are sharded into these queues so that they can share the threads and the transfer bandwidth at the same time. If it is set to `1`, transfer requests are processed one by one in FCFS (first-come, first-served) order | `4` |
| **`SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT`** | Timeout (seconds) for receiving destination KV indices during request initialization | `300` |
| **`SGLANG_DISAGGREGATION_BOOTSTRAP_ENTRY_CLEANUP_INTERVAL`** | Interval (seconds) between cleanups of bootstrap entries | `120` |

If a greater mean TTFT is acceptable, you can `export SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT=600` (10 minutes) to relax the timeout condition.
Please be aware that this setting will cause prefill instances to take a longer time to clean up the affected memory resources when a running decode node loses connection.
Expand Down
7 changes: 6 additions & 1 deletion python/sglang/srt/disaggregation/base/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,11 @@ def __init__(
is_mla_backend: Optional[bool] = False,
): ...

@abstractmethod
def register_to_bootstrap(self):
    """Register this KV manager with the bootstrap server.

    Abstract hook: each transfer backend provides its own registration
    (the wire format/details are not visible from this interface).
    """
    ...


class BaseKVSender(ABC):

Expand Down Expand Up @@ -158,4 +163,4 @@ def abort(self):

class BaseKVBootstrapServer(ABC):
    """Abstract bootstrap server for PD-disaggregation KV transfer.

    ``dp_size`` (decode data-parallel size) defaults to ``1`` so existing
    single-DP callers remain source-compatible.
    """

    @abstractmethod
    def __init__(self, host: str, port: int, dp_size: int = 1): ...
291 changes: 196 additions & 95 deletions python/sglang/srt/disaggregation/common/conn.py

Large diffs are not rendered by default.

120 changes: 97 additions & 23 deletions python/sglang/srt/disaggregation/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@ def __init__(
# Queue for requests pending pre-allocation
self.queue: List[DecodeRequest] = []
self.retracted_queue: List[Req] = []
self.pending_reqs: List[Req] = []
self.prefill_pp_size = prefill_pp_size
self.kv_manager = self._init_kv_manager()

Expand Down Expand Up @@ -345,32 +346,59 @@ def add(self, req: Req, is_retracted: bool = False) -> None:
req.retraction_mb_id = None
self.retracted_queue.append(req)
else:
# Auto enable FAKE mode if configured
if req.bootstrap_host == FAKE_BOOTSTRAP_HOST or (
req.bootstrap_host is None
and self.scheduler.server_args.disaggregation_transfer_backend == "fake"
):
kv_receiver_class = get_kv_class(
TransferBackend.FAKE, KVClassType.RECEIVER
)
else:
kv_receiver_class = get_kv_class(
self.transfer_backend, KVClassType.RECEIVER
)
dp_rank = self._resolve_dp_rank(req)
if dp_rank is None:
self.pending_reqs.append(req)
return
self._create_receiver_and_enqueue(req, dp_rank)

def _resolve_dp_rank(self, req: Req) -> Optional[int]:
    """Determine the prefill dp_rank for *req*, or ``None`` if unknown yet.

    Resolution order:
    1. An explicitly assigned ``req.data_parallel_rank`` wins.
    2. The FAKE transfer backend always maps to rank 0.
    3. If the prefill side routes by bootstrap room, derive the rank as
       ``bootstrap_room % prefill_dp_size``.
    4. Otherwise return ``None`` so the caller defers the request until
       the prefill parallel info (or a room->rank query) is available.
    """
    if req.data_parallel_rank is not None:
        return req.data_parallel_rank

    # Auto enable FAKE mode if configured; the fake backend has a single
    # (virtual) prefill rank.
    if req.bootstrap_host == FAKE_BOOTSTRAP_HOST or (
        req.bootstrap_host is None
        and self.scheduler.server_args.disaggregation_transfer_backend == "fake"
    ):
        return 0

    bootstrap_addr = f"{req.bootstrap_host}:{req.bootstrap_port}"

    # Parallel info for this prefill cluster has not been fetched yet;
    # the request must stay pending.
    if bootstrap_addr not in self.kv_manager.prefill_dp_size_table:
        return None

    if self.kv_manager.follow_bootstrap_room_table[bootstrap_addr]:
        return (
            req.bootstrap_room
            % self.kv_manager.prefill_dp_size_table[bootstrap_addr]
        )

    return None

def _create_receiver_and_enqueue(self, req: Req, dp_rank: int) -> None:
    """Build a KV receiver for *req* and push it onto the prealloc queue."""
    # Auto enable the FAKE backend when the request (or the server
    # config) asks for it; otherwise use the configured backend.
    wants_fake = req.bootstrap_host == FAKE_BOOTSTRAP_HOST or (
        req.bootstrap_host is None
        and self.scheduler.server_args.disaggregation_transfer_backend == "fake"
    )
    backend = TransferBackend.FAKE if wants_fake else self.transfer_backend
    receiver_cls = get_kv_class(backend, KVClassType.RECEIVER)

    receiver = receiver_cls(
        mgr=self.kv_manager,
        bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}",
        bootstrap_room=req.bootstrap_room,
        prefill_dp_rank=dp_rank,
    )

    req.add_latency(RequestStage.DECODE_PREPARE)
    trace_slice_end(RequestStage.DECODE_PREPARE, req.rid, auto_next_anon=True)
    self.queue.append(
        DecodeRequest(req=req, kv_receiver=receiver, waiting_for_input=False)
    )

def _check_if_req_exceed_kv_capacity(self, req: Req) -> bool:
if len(req.origin_input_ids) > self.max_total_num_tokens:
message = f"Request {req.rid} exceeds the maximum number of tokens: {len(req.origin_input_ids)} > {self.max_total_num_tokens}"
Expand Down Expand Up @@ -465,10 +493,56 @@ def _update_handshake_waiters(
else:
raise ValueError(f"Unexpected poll case: {poll}")

def _resolve_pending_reqs(self) -> None:
    """Batch-resolve dp_ranks for pending requests and create receivers.

    Requests are grouped by bootstrap address so that (a) an unreachable
    prefill cluster does not head-of-line-block requests for other
    clusters, and (b) room->rank queries are only sent to the bootstrap
    server that actually owns those rooms. The previous implementation
    used ``pending_reqs[0]``'s address for every pending request.
    """
    if not self.pending_reqs:
        return

    # Group pending requests by their prefill bootstrap address.
    groups: dict = {}
    for req in self.pending_reqs:
        addr = f"{req.bootstrap_host}:{req.bootstrap_port}"
        groups.setdefault(addr, []).append(req)

    resolved = []
    remaining = []
    for bootstrap_addr, reqs in groups.items():
        # If a request is following the bootstrap room, we need the
        # prefill parallel info before resolving the dp_rank, which
        # conflicts with the lazy resolve logic in CommonKVReceiver,
        # so ensure the parallel info first.
        if not self.kv_manager.ensure_parallel_info(bootstrap_addr):
            remaining.extend(reqs)
            continue

        need_query = []
        for req in reqs:
            # NOTE: resolve again because the parallel info may have
            # just been ensured above.
            dp_rank = self._resolve_dp_rank(req)
            if dp_rank is not None:
                resolved.append((req, dp_rank))
            else:
                need_query.append(req)

        if need_query:
            from sglang.srt.disaggregation.common.conn import CommonKVReceiver

            rooms = [req.bootstrap_room for req in need_query]
            room_to_rank = CommonKVReceiver.query_prefill_dp_ranks(
                bootstrap_addr, rooms
            )
            for req in need_query:
                room_key = str(req.bootstrap_room)
                if room_key in room_to_rank:
                    resolved.append((req, int(room_to_rank[room_key])))
                else:
                    # Rank still unknown; retry on a later call.
                    remaining.append(req)

    self.pending_reqs = remaining

    for req, dp_rank in resolved:
        self._create_receiver_and_enqueue(req, dp_rank)

def pop_preallocated(
self, rids_to_check: Optional[List[str]] = None
) -> Tuple[List[DecodeRequest], List[DecodeRequest]]:
"""Pop the preallocated requests from the pending queue (FIFO)."""
self._resolve_pending_reqs()
self._update_handshake_waiters(rids_to_check)

failed_reqs = []
Expand Down Expand Up @@ -1086,7 +1160,7 @@ def process_decode_queue(self: Scheduler):
if self.polling_count % self.polling_interval == 0:
req_conns, _ = self.disagg_decode_prealloc_queue.pop_preallocated()
self.disagg_decode_transfer_queue.extend(req_conns)
alloc_reqs = (
transferred_reqs = (
self.disagg_decode_transfer_queue.pop_transferred()
) # the requests which kv has arrived
self.waiting_queue.extend(alloc_reqs)
self.waiting_queue.extend(transferred_reqs)
18 changes: 10 additions & 8 deletions python/sglang/srt/disaggregation/mooncake/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1124,18 +1124,20 @@ def _handle_node_failure(self, failed_bootstrap_addr):
]
for k in keys_to_remove:
del self.connection_pool[k]
if failed_bootstrap_addr in self.prefill_attn_tp_size_table:
del self.prefill_attn_tp_size_table[failed_bootstrap_addr]
if failed_bootstrap_addr in self.prefill_dp_size_table:
del self.prefill_dp_size_table[failed_bootstrap_addr]
if failed_bootstrap_addr in self.prefill_pp_size_table:
del self.prefill_pp_size_table[failed_bootstrap_addr]

possible_affected_rooms = self.addr_to_rooms_tracker.get(
failed_bootstrap_addr, []
)
if failed_bootstrap_addr in self.addr_to_rooms_tracker:
del self.addr_to_rooms_tracker[failed_bootstrap_addr]
keys_to_remove = [
self.prefill_attn_tp_size_table,
self.prefill_dp_size_table,
self.prefill_pp_size_table,
self.follow_bootstrap_room_table,
self.addr_to_rooms_tracker,
]
for k in keys_to_remove:
if failed_bootstrap_addr in k:
del k[failed_bootstrap_addr]

# Report the requests associated with the failed bootstrap addr and mark their status as KVPoll.Failed
affected_rooms = []
Expand Down
1 change: 1 addition & 0 deletions python/sglang/srt/environ.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,7 @@ class Envs:
SGLANG_REQ_WAITING_TIMEOUT = EnvFloat(-1) # in seconds
SGLANG_NCCL_ALL_GATHER_IN_OVERLAP_SCHEDULER_SYNC_BATCH = EnvBool(False)
SGLANG_REQ_RUNNING_TIMEOUT = EnvFloat(-1) # in seconds
SGLANG_DISAGGREGATION_BOOTSTRAP_ENTRY_CLEANUP_INTERVAL = EnvInt(120)

# Test: pd-disaggregation
SGLANG_TEST_PD_DISAGG_BACKEND = EnvStr("mooncake")
Expand Down
1 change: 1 addition & 0 deletions python/sglang/srt/managers/disagg_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def start_disagg_service(
bootstrap_server: BaseKVBootstrapServer = kv_bootstrap_server_class(
host=server_args.host,
port=server_args.disaggregation_bootstrap_port,
dp_size=server_args.dp_size,
)
is_create_store = (
server_args.node_rank == 0 and transfer_backend == TransferBackend.ASCEND
Expand Down
12 changes: 0 additions & 12 deletions python/sglang/srt/server_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -814,18 +814,6 @@ def _handle_load_balance_method(self):
)
return

# Backward compat: in PD prefill, legacy "round_robin" means `bootstrap_room` routing.
if (
self.disaggregation_mode == "prefill"
and self.load_balance_method == "round_robin"
):
logger.warning(
"In PD-disaggregation prefill mode, the 'round_robin' load balancing method "
"means `bootstrap_room` routing (use 'follow_bootstrap_room' instead). "
"Falling back to 'follow_bootstrap_room' for backward compatibility."
)
self.load_balance_method = "follow_bootstrap_room"

def _handle_deprecated_args(self):
# Handle deprecated tool call parsers
deprecated_tool_call_parsers = {"qwen25": "qwen", "glm45": "glm"}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
class TestDisaggregationDPAttention(PDDisaggregationServerBase):
PREFILL_DP_SIZE = 4
DECODE_DP_SIZE = 4
LOAD_BALANCE_METHOD = "auto"

@classmethod
def setUpClass(cls):
Expand Down Expand Up @@ -50,6 +51,8 @@ def start_prefill(cls):
"--dp",
str(cls.PREFILL_DP_SIZE),
"--enable-dp-attention",
"--load-balance-method",
cls.LOAD_BALANCE_METHOD,
]
prefill_args += cls.transfer_backend + cls.rdma_devices
cls.process_prefill = popen_launch_pd_server(
Expand All @@ -72,6 +75,8 @@ def start_decode(cls):
"--enable-dp-attention",
"--base-gpu-id",
str(cls.PREFILL_DP_SIZE),
"--load-balance-method",
cls.LOAD_BALANCE_METHOD,
]
decode_args += cls.transfer_backend + cls.rdma_devices
cls.process_decode = popen_launch_pd_server(
Expand All @@ -97,5 +102,9 @@ def test_gsm8k(self):
self.assertGreater(metrics["accuracy"], 0.60)


class TestDisaggregationDPAttentionRoundRobin(TestDisaggregationDPAttention):
    # Re-run the full DP-attention PD test suite with the "round_robin"
    # load-balance method instead of the parent's "auto".
    LOAD_BALANCE_METHOD = "round_robin"


if __name__ == "__main__":
unittest.main()
Loading