diff --git a/python/sglang/srt/disaggregation/common/conn.py b/python/sglang/srt/disaggregation/common/conn.py index a21f07640db4..d476403130b4 100644 --- a/python/sglang/srt/disaggregation/common/conn.py +++ b/python/sglang/srt/disaggregation/common/conn.py @@ -28,6 +28,8 @@ from sglang.srt.distributed import get_pp_group from sglang.srt.environ import envs from sglang.srt.layers.dp_attention import ( + get_attention_cp_rank, + get_attention_cp_size, get_attention_dp_rank, get_attention_dp_size, get_attention_tp_rank, @@ -83,8 +85,18 @@ def __init__( self.dist_init_addr = server_args.dist_init_addr self.attn_tp_size = get_attention_tp_size() self.attn_tp_rank = get_attention_tp_rank() + self.attn_cp_size = get_attention_cp_size() + self.attn_cp_rank = get_attention_cp_rank() self.attn_dp_size = get_attention_dp_size() self.attn_dp_rank = get_attention_dp_rank() + # For MLA + CP prefill, a single authoritative CP rank per attn-TP rank + # is enough because KV is already rebuilt to full sequence on each CP rank. 
+ self.skip_register_prefill = ( + disaggregation_mode == DisaggregationMode.PREFILL + and is_mla_backend + and self.attn_cp_size > 1 + and self.attn_cp_rank != 0 + ) self.system_dp_size = ( 1 if server_args.enable_dp_attention else server_args.dp_size ) @@ -108,7 +120,8 @@ def __init__( self.failure_lock = threading.Lock() if self.disaggregation_mode == DisaggregationMode.PREFILL: - self.register_to_bootstrap() + if not self.skip_register_prefill: + self.register_to_bootstrap() self.transfer_infos = {} self.decode_kv_args_table = {} self.pp_group = get_pp_group() @@ -258,7 +271,6 @@ def register_to_bootstrap(self): "kv_cache_dtype": self.server_args.kv_cache_dtype, "load_balance_method": self.server_args.load_balance_method, } - try: response = requests.put(url, json=payload, timeout=5) if response.status_code == 200: @@ -342,9 +354,14 @@ def __init__( self.bootstrap_server_url = bootstrap_addr # inner state self.curr_idx = 0 - self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Bootstrapping) + initial_status = KVPoll.Bootstrapping + if self.kv_mgr.skip_register_prefill: + # Non-authoritative CP ranks are dummy participants. 
+ initial_status = KVPoll.WaitingForInput + self.kv_mgr.update_status(self.bootstrap_room, initial_status) if ( - self.kv_mgr.server_args.dp_size > 1 + not self.kv_mgr.skip_register_prefill + and self.kv_mgr.server_args.dp_size > 1 and self.kv_mgr.server_args.load_balance_method != "follow_bootstrap_room" ): self._register_prefill_dp_rank() diff --git a/python/sglang/srt/disaggregation/mooncake/conn.py b/python/sglang/srt/disaggregation/mooncake/conn.py index cea0a822ef82..8b5f8dabcd96 100644 --- a/python/sglang/srt/disaggregation/mooncake/conn.py +++ b/python/sglang/srt/disaggregation/mooncake/conn.py @@ -891,8 +891,11 @@ def transfer_worker( endpoint, dst_port, room, status, local_rank ) else: - # Dummy request means the decode instance is not used, so its status can be marked as success directly - # Dummy request does not need to sync status to decode endpoint + # Dummy request means the decode instance is not used, + # so its status can be marked as success directly. + # Keep the membership guard to avoid resurrecting a + # request that has already been cleared concurrently. + # Dummy request does not need to sync status to decode endpoint. 
if kv_chunk.is_last and req.room in self.request_status: self.update_status(req.room, KVPoll.Success) @@ -1038,12 +1041,12 @@ def add_transfer_request( bootstrap_room: int, kv_indices: npt.NDArray[np.int32], index_slice: slice, - is_last: bool, + is_last_chunk: bool, aux_index: Optional[int] = None, state_indices: Optional[List[int]] = None, ): assert self.disaggregation_mode == DisaggregationMode.PREFILL - assert not is_last or (is_last and aux_index is not None) + assert not is_last_chunk or (is_last_chunk and aux_index is not None) if ( bootstrap_room not in self.request_status @@ -1057,7 +1060,11 @@ def add_transfer_request( if bootstrap_room not in self.transfer_infos: # This means that the current rank is a dummy rank for this request, # and it has already been marked as success, so there is no need to - # add further chunks into the transfer queue. + # add further chunks into the transfer queue. We still guard on + # request_status membership to avoid recreating terminal status for + # requests that have already been cleared on this rank. 
+ if is_last_chunk and bootstrap_room in self.request_status: + self.update_status(bootstrap_room, KVPoll.Success) return # NOTE(shangming): sharding according to the dst_infos to make sure @@ -1072,7 +1079,7 @@ def add_transfer_request( room=bootstrap_room, prefill_kv_indices=kv_indices, index_slice=index_slice, - is_last=is_last, + is_last=is_last_chunk, prefill_aux_index=aux_index, state_indices=state_indices, ) @@ -1134,21 +1141,21 @@ def send( ): index_slice = slice(self.curr_idx, self.curr_idx + len(kv_indices)) self.curr_idx += len(kv_indices) - is_last = self.curr_idx == self.num_kv_indices + is_last_chunk = self.curr_idx == self.num_kv_indices - if not is_last: + if not is_last_chunk: self.kv_mgr.add_transfer_request( self.bootstrap_room, kv_indices, index_slice, - False, + is_last_chunk=False, ) else: self.kv_mgr.add_transfer_request( self.bootstrap_room, kv_indices, index_slice, - True, + is_last_chunk=is_last_chunk, aux_index=self.aux_index, state_indices=state_indices, ) diff --git a/python/sglang/srt/disaggregation/mori/conn.py b/python/sglang/srt/disaggregation/mori/conn.py index 9e0d2d5b91a4..4e8a52d0e613 100644 --- a/python/sglang/srt/disaggregation/mori/conn.py +++ b/python/sglang/srt/disaggregation/mori/conn.py @@ -798,9 +798,16 @@ def add_transfer_request( assert self.disaggregation_mode == DisaggregationMode.PREFILL transfer_infos = self.transfer_infos.get(bootstrap_room) if not transfer_infos: - raise RuntimeError( - f"No transfer info found for bootstrap_room={bootstrap_room}" - ) + if not self.skip_register_prefill: + raise RuntimeError( + f"No transfer info found for bootstrap_room={bootstrap_room}" + ) + # Non-authoritative CP ranks can have no transfer_infos for this room. + # If this is the last chunk, finish the local sender as a no-op. 
+ if is_last and bootstrap_room in self.request_status: + self.update_status(bootstrap_room, KVPoll.Success) + return [], [] + return [], None result_statuses = [] target_infos_snapshot: Optional[List[TransferInfo]] = None with self.transfer_lock: diff --git a/python/sglang/srt/disaggregation/nixl/conn.py b/python/sglang/srt/disaggregation/nixl/conn.py index 279dfab90b73..c0f500e79bfc 100644 --- a/python/sglang/srt/disaggregation/nixl/conn.py +++ b/python/sglang/srt/disaggregation/nixl/conn.py @@ -734,6 +734,17 @@ def add_transfer_request( assert self.disaggregation_mode == DisaggregationMode.PREFILL assert not is_last or (is_last and aux_index is not None) + if bootstrap_room not in self.transfer_infos: + if not self.skip_register_prefill: + raise RuntimeError( + f"No transfer info found for bootstrap_room={bootstrap_room}" + ) + # Non-authoritative CP ranks can have no transfer_infos for this room. + # If this is the last chunk, finish the local sender as a no-op. + if is_last and bootstrap_room in self.request_status: + self.update_status(bootstrap_room, KVPoll.Success) + return [] + reqs_to_be_processed = self.transfer_infos[bootstrap_room].values() handles = [] for req in reqs_to_be_processed: diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py index 0ad5d6f19024..0eb892b087fe 100644 --- a/python/sglang/srt/disaggregation/prefill.py +++ b/python/sglang/srt/disaggregation/prefill.py @@ -25,6 +25,7 @@ from typing import TYPE_CHECKING, List, Optional import torch +import torch.distributed as dist from sglang.srt.disaggregation.base import KVPoll from sglang.srt.disaggregation.common.conn import CommonKVManager @@ -49,14 +50,35 @@ from sglang.srt.observability.req_time_stats import set_schedule_time_batch if TYPE_CHECKING: - from torch.distributed import ProcessGroup - from sglang.srt.managers.scheduler import GenerationBatchResult, Scheduler from sglang.srt.mem_cache.memory_pool import KVCache logger = 
def poll_and_all_reduce_attn_groups(
    pollers,
    attn_cp_cpu_group: dist.ProcessGroup,
    attn_tp_cpu_group: Optional[dist.ProcessGroup] = None,
):
    """Poll the senders and agree on their statuses across CP, then attn-TP.

    The statuses are first synchronized over ``attn_cp_cpu_group`` by
    delegating to ``poll_and_all_reduce``. When an attn-TP group with more
    than one rank is supplied, the per-poller values are additionally
    MIN-reduced over that group, so every rank taking part in both
    collectives ends up with the same list.

    Args:
        pollers: Objects handed unchanged to ``poll_and_all_reduce``
            (the callers in this file pass the ``disagg_kv_sender`` of
            each queued request).
        attn_cp_cpu_group: Process group used for the first (CP) sync.
        attn_tp_cpu_group: Optional process group used for the second
            (attn-TP) sync. ``None`` or a single-rank group skips it.

    Returns:
        The value returned by ``poll_and_all_reduce`` when the attn-TP
        step is skipped; otherwise a plain ``list`` holding the
        element-wise minimum of those values over ``attn_tp_cpu_group``.
    """
    # Stage 1: sync across CP ranks (same attn_tp shard).
    cp_synced = poll_and_all_reduce(pollers, attn_cp_cpu_group)

    # Stage 2 only applies when there is a real attn-TP group to talk to.
    if attn_tp_cpu_group is None:
        return cp_synced
    if dist.get_world_size(group=attn_tp_cpu_group) <= 1:
        return cp_synced

    # NOTE(review): uint8 presumes every status is a small non-negative
    # integer -- confirm against the KVPoll definition.
    status_tensor = torch.tensor(cp_synced, dtype=torch.uint8, device="cpu")
    # MIN keeps, per poller, the lowest status value seen on any attn-TP rank.
    dist.all_reduce(status_tensor, op=dist.ReduceOp.MIN, group=attn_tp_cpu_group)
    return status_tensor.tolist()
If Success, respond to the client and remove it from the queue for req, poll in zip(self.disagg_prefill_inflight_queue, polls): if rids_to_check is not None: @@ -621,8 +645,9 @@ def get_transferred_rids(self: Scheduler) -> List[str]: """ Used by PP, get the transferred rids but **do not pop** """ - polls = poll_and_all_reduce( + polls = poll_and_all_reduce_attn_groups( [req.disagg_kv_sender for req in self.disagg_prefill_inflight_queue], + self.attn_cp_cpu_group, self.attn_tp_cpu_group, ) @@ -740,4 +765,5 @@ def send_kv_chunk( f"Skip sending kv chunk for request {req.rid=} {req.bootstrap_room=} because page_indices is empty" ) return + req.disagg_kv_sender.send(page_indices, state_indices)