Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions python/sglang/srt/disaggregation/common/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,11 @@ def __init__(
# FIXME: alias here: target_dp_group -> prefill_dp_rank
self.target_dp_group = self.prefill_dp_rank

if self.prefill_pp_size == self.kv_mgr.pp_size:
self.target_pp_ranks = [self.kv_mgr.pp_rank]
else:
self.target_pp_ranks = [rank for rank in range(self.prefill_pp_size)]

self.kv_mgr.required_prefill_response_num_table[self.bootstrap_room] = (
self.required_prefill_response_num
)
Expand All @@ -348,7 +353,7 @@ def __init__(
if bootstrap_key not in self.kv_mgr.connection_pool:
bootstrap_infos = []
for target_tp_rank in self.target_tp_ranks:
for target_pp_rank in range(self.prefill_pp_size):
for target_pp_rank in reversed(self.target_pp_ranks):
bootstrap_info = self._get_bootstrap_info_from_server(
target_tp_rank, self.target_dp_group, target_pp_rank
)
Expand Down Expand Up @@ -605,4 +610,4 @@ def close(self):
self.thread.join(timeout=2)
logger.info("Server thread stopped")

def poll(self) -> KVPoll: ...
def poll(self) -> KVPoll: ...
46 changes: 31 additions & 15 deletions python/sglang/srt/disaggregation/prefill.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@

"""
Life cycle of a request in the prefill server

Expand Down Expand Up @@ -581,22 +582,37 @@ def get_transferred_rids(self: Scheduler) -> List[str]:
return transferred_rids

def process_prefill_chunk(self: Scheduler) -> None:
if self.last_batch and self.last_batch.forward_mode.is_extend():
if self.chunked_req:
# Move the chunked request out of the batch so that we can merge
# only finished requests to running_batch.
self.last_batch.filter_batch(chunked_req_to_exclude=self.chunked_req)
self.tree_cache.cache_unfinished_req(self.chunked_req, chunked=True)
if self.enable_overlap:
# Delay KV transfer to process_batch_result_disagg_prefill when overlap is enabled to ensure results are resolved
self.chunked_req.tmp_end_idx = min(
len(self.chunked_req.fill_ids),
len(self.chunked_req.origin_input_ids),
)
else:
self.send_kv_chunk(self.chunked_req)
# chunked request keeps its rid but will get a new req_pool_idx
chunked_req_to_exclude = set()
if self.chunked_req:
chunked_req_to_exclude.add(self.chunked_req)
self.tree_cache.cache_unfinished_req(self.chunked_req, chunked=True)
if self.enable_overlap:
# Delay KV transfer to process_batch_result_disagg_prefill when overlap is enabled to ensure results are resolved
self.chunked_req.tmp_end_idx = min(
len(self.chunked_req.fill_ids),
len(self.chunked_req.origin_input_ids),
)
else:
self.send_kv_chunk(self.chunked_req)
# chunked request keeps its rid but will get a new req_pool_idx
if self.tp_worker.model_runner.mambaish_config is not None:
self.req_to_token_pool.free(
self.chunked_req.req_pool_idx, free_mamba_cache=False
)
else:
self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
self.running_batch.batch_is_full = False
if self.last_batch and self.last_batch.forward_mode.is_extend():
if self.last_batch.chunked_req:
# In the context of pipeline parallelism, after the last chunk has been processed, the current microbatch still tracks an outdated chunked_req.
# We need to discard it.
chunked_req_to_exclude.add(self.last_batch.chunked_req)

last_bs = self.last_batch.batch_size()
self.last_batch.filter_batch(
chunked_req_to_exclude=list(chunked_req_to_exclude)
)
if self.last_batch.batch_size() < last_bs:
self.running_batch.batch_is_full = False

def send_kv_chunk(
Expand Down