diff --git a/python/sglang/srt/disaggregation/common/conn.py b/python/sglang/srt/disaggregation/common/conn.py index 1a66c83ae0c9..f0856bd48cc4 100644 --- a/python/sglang/srt/disaggregation/common/conn.py +++ b/python/sglang/srt/disaggregation/common/conn.py @@ -337,6 +337,11 @@ def __init__( # FIXME: alias here: target_dp_group -> prefill_dp_rank self.target_dp_group = self.prefill_dp_rank + if self.prefill_pp_size == self.kv_mgr.pp_size: + self.target_pp_ranks = [self.kv_mgr.pp_rank] + else: + self.target_pp_ranks = [rank for rank in range(self.prefill_pp_size)] + self.kv_mgr.required_prefill_response_num_table[self.bootstrap_room] = ( self.required_prefill_response_num ) @@ -348,7 +353,7 @@ def __init__( if bootstrap_key not in self.kv_mgr.connection_pool: bootstrap_infos = [] for target_tp_rank in self.target_tp_ranks: - for target_pp_rank in range(self.prefill_pp_size): + for target_pp_rank in reversed(self.target_pp_ranks): bootstrap_info = self._get_bootstrap_info_from_server( target_tp_rank, self.target_dp_group, target_pp_rank ) @@ -605,4 +610,4 @@ def close(self): self.thread.join(timeout=2) logger.info("Server thread stopped") - def poll(self) -> KVPoll: ... + def poll(self) -> KVPoll: ... \ No newline at end of file diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py index e2a51e712f32..6c3c9fd46c59 100644 --- a/python/sglang/srt/disaggregation/prefill.py +++ b/python/sglang/srt/disaggregation/prefill.py @@ -1,3 +1,4 @@ + """ Life cycle of a request in the prefill server @@ -581,22 +582,37 @@ def get_transferred_rids(self: Scheduler) -> List[str]: return transferred_rids def process_prefill_chunk(self: Scheduler) -> None: - if self.last_batch and self.last_batch.forward_mode.is_extend(): - if self.chunked_req: - # Move the chunked request out of the batch so that we can merge - # only finished requests to running_batch. - self.last_batch.filter_batch(chunked_req_to_exclude=self.chunked_req) - self.tree_cache.cache_unfinished_req(self.chunked_req, chunked=True) - if self.enable_overlap: - # Delay KV transfer to process_batch_result_disagg_prefill when overlap is enabled to ensure results are resolved - self.chunked_req.tmp_end_idx = min( - len(self.chunked_req.fill_ids), - len(self.chunked_req.origin_input_ids), - ) - else: - self.send_kv_chunk(self.chunked_req) - # chunked request keeps its rid but will get a new req_pool_idx + chunked_req_to_exclude = set() + if self.chunked_req: + chunked_req_to_exclude.add(self.chunked_req) + self.tree_cache.cache_unfinished_req(self.chunked_req, chunked=True) + if self.enable_overlap: + # Delay KV transfer to process_batch_result_disagg_prefill when overlap is enabled to ensure results are resolved + self.chunked_req.tmp_end_idx = min( + len(self.chunked_req.fill_ids), + len(self.chunked_req.origin_input_ids), + ) + else: + self.send_kv_chunk(self.chunked_req) + # chunked request keeps its rid but will get a new req_pool_idx + if self.tp_worker.model_runner.mambaish_config is not None: + self.req_to_token_pool.free( + self.chunked_req.req_pool_idx, free_mamba_cache=False + ) + else: self.req_to_token_pool.free(self.chunked_req.req_pool_idx) + self.running_batch.batch_is_full = False + if self.last_batch and self.last_batch.forward_mode.is_extend(): + if self.last_batch.chunked_req: + # In the context pipeline parallelism, after the last chunk, the current microbatch still track outdated chunked_req. + # We need to discard it. + chunked_req_to_exclude.add(self.last_batch.chunked_req) + + last_bs = self.last_batch.batch_size() + self.last_batch.filter_batch( + chunked_req_to_exclude=list(chunked_req_to_exclude) + ) + if self.last_batch.batch_size() < last_bs: self.running_batch.batch_is_full = False def send_kv_chunk(