diff --git a/python/sglang/srt/disaggregation/common/conn.py b/python/sglang/srt/disaggregation/common/conn.py index a21f07640db4..3569e68054c6 100644 --- a/python/sglang/srt/disaggregation/common/conn.py +++ b/python/sglang/srt/disaggregation/common/conn.py @@ -28,6 +28,8 @@ from sglang.srt.distributed import get_pp_group from sglang.srt.environ import envs from sglang.srt.layers.dp_attention import ( + get_attention_cp_rank, + get_attention_cp_size, get_attention_dp_rank, get_attention_dp_size, get_attention_tp_rank, @@ -48,6 +50,7 @@ @dataclasses.dataclass class PrefillServerInfo: attn_tp_size: int + attn_cp_size: int dp_size: int pp_size: int page_size: Optional[int] @@ -56,6 +59,7 @@ class PrefillServerInfo: def __post_init__(self): self.attn_tp_size = int(self.attn_tp_size) + self.attn_cp_size = int(self.attn_cp_size) self.dp_size = int(self.dp_size) self.pp_size = int(self.pp_size) self.page_size = int(self.page_size) if self.page_size is not None else None @@ -65,6 +69,16 @@ def __post_init__(self): self.follow_bootstrap_room = bool(self.follow_bootstrap_room) +@dataclasses.dataclass +class PrefillRankInfo: + rank_ip: str + rank_port: int + + def __post_init__(self): + self.rank_ip = str(self.rank_ip) + self.rank_port = int(self.rank_port) + + class CommonKVManager(BaseKVManager): def __init__( self, @@ -83,6 +97,8 @@ def __init__( self.dist_init_addr = server_args.dist_init_addr self.attn_tp_size = get_attention_tp_size() self.attn_tp_rank = get_attention_tp_rank() + self.attn_cp_size = get_attention_cp_size() + self.attn_cp_rank = get_attention_cp_rank() self.attn_dp_size = get_attention_dp_size() self.attn_dp_rank = get_attention_dp_rank() self.system_dp_size = ( @@ -108,7 +124,12 @@ def __init__( self.failure_lock = threading.Lock() if self.disaggregation_mode == DisaggregationMode.PREFILL: - self.register_to_bootstrap() + # TODO(shangming): Fix me when we support MHA/GQA + CP, or when we utilize all cp ranks for KV transfer in CP mode. + self.is_dummy_cp_rank = ( + is_mla_backend and self.attn_cp_size > 1 and self.attn_cp_rank != 0 + ) + if not self.is_dummy_cp_rank: + self.register_to_bootstrap() self.transfer_infos = {} self.decode_kv_args_table = {} self.pp_group = get_pp_group() @@ -211,7 +232,7 @@ def _fetch_prefill_server_info( ) -> Optional[PrefillServerInfo]: """Fetch the prefill server info from the bootstrap server.""" try: - url = f"http://{bootstrap_addr}/route?engine_rank={-1}&prefill_dp_rank={-1}&target_pp_rank={-1}" + url = f"http://{bootstrap_addr}/route?prefill_dp_rank={-1}&prefill_cp_rank={-1}&target_tp_rank={-1}&target_pp_rank={-1}" response = requests.get(url, timeout=5) if response.status_code == 200: data = response.json() @@ -246,6 +267,8 @@ def register_to_bootstrap(self): payload = { "attn_tp_size": self.attn_tp_size, "attn_tp_rank": self.attn_tp_rank, + "attn_cp_size": self.attn_cp_size, + "attn_cp_rank": self.attn_cp_rank, "attn_dp_size": self.attn_dp_size, "attn_dp_rank": self.attn_dp_rank, "pp_size": self.pp_size, @@ -342,6 +365,11 @@ def __init__( self.bootstrap_server_url = bootstrap_addr # inner state self.curr_idx = 0 + if self.kv_mgr.is_dummy_cp_rank: + # Non-authoritative CP ranks are dummy participants. + self.kv_mgr.update_status(self.bootstrap_room, KVPoll.WaitingForInput) + return + self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Bootstrapping) if ( self.kv_mgr.server_args.dp_size > 1 @@ -412,15 +440,13 @@ def __init__( self.prefill_info = self.kv_mgr.prefill_info_table[self.bootstrap_addr] - # Handling for PD with different TP sizes per DP rank + # Rank mapping for PD with different TP sizes per rank for target DP/CP group if self.kv_mgr.attn_tp_size == self.prefill_info.attn_tp_size: self.target_tp_rank = ( self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size ) self.required_dst_info_num = 1 - self.required_prefill_response_num = 1 * ( - self.prefill_info.pp_size // self.kv_mgr.pp_size - ) + self.required_prefill_response_num = 1 self.target_tp_ranks = [self.target_tp_rank] elif self.kv_mgr.attn_tp_size > self.prefill_info.attn_tp_size: if not self.kv_mgr.is_mla_backend: @@ -433,9 +459,7 @@ def __init__( self.required_dst_info_num = ( self.kv_mgr.attn_tp_size // self.prefill_info.attn_tp_size ) - self.required_prefill_response_num = 1 * ( - self.prefill_info.pp_size // self.kv_mgr.pp_size - ) + self.required_prefill_response_num = 1 self.target_tp_ranks = [self.target_tp_rank] else: if not self.kv_mgr.is_mla_backend: @@ -459,13 +483,35 @@ def __init__( self.target_tp_rank = self.target_tp_ranks[0] self.required_dst_info_num = 1 if self.kv_mgr.is_mla_backend: - self.required_prefill_response_num = ( - self.prefill_info.pp_size // self.kv_mgr.pp_size - ) + self.required_prefill_response_num = 1 else: self.required_prefill_response_num = ( self.prefill_info.attn_tp_size // self.kv_mgr.attn_tp_size - ) * (self.prefill_info.pp_size // self.kv_mgr.pp_size) + ) + + # Decode cp size should be equal to prefill cp size or 1 + assert ( + self.kv_mgr.attn_cp_size == self.prefill_info.attn_cp_size + or self.kv_mgr.attn_cp_size == 1 + ), ( + f"Decode cp size ({self.kv_mgr.attn_cp_size}) should be equal to prefill cp size ({self.prefill_info.attn_cp_size}) or 1", + ) + if self.kv_mgr.attn_cp_size == self.prefill_info.attn_cp_size: + self.target_cp_ranks = [self.kv_mgr.attn_cp_rank] + else: + self.target_cp_ranks = [ + rank for rank in range(self.prefill_info.attn_cp_size) + ] + # TODO(shangming): Support KVCache transfer for multiple prefill cp ranks -> 1 decode cp rank + # For now, we handle the control plane in advance, but we need to support the data plane in the future. + if self.kv_mgr.is_mla_backend: + # For MLA: we only need to retrieve KVCache from the first CP rank now + self.target_cp_ranks = self.target_cp_ranks[:1] + self.required_prefill_response_num *= 1 + else: + self.required_prefill_response_num *= ( + self.prefill_info.attn_cp_size // self.kv_mgr.attn_cp_size + ) # Decode pp size should be equal to prefill pp size or 1 assert ( @@ -477,6 +523,9 @@ def __init__( self.target_pp_ranks = [self.kv_mgr.pp_rank] else: self.target_pp_ranks = [rank for rank in range(self.prefill_info.pp_size)] + self.required_prefill_response_num *= ( + self.prefill_info.pp_size // self.kv_mgr.pp_size + ) self.kv_mgr.required_prefill_response_num_table[self.bootstrap_room] = ( self.required_prefill_response_num @@ -489,57 +538,65 @@ def __init__( self._setup_bootstrap_infos() def _setup_bootstrap_infos(self): - # NOTE: key distinguished by bootstrap_addr, prefill_dp_rank, and target_tp_rank - bootstrap_key = ( - f"{self.bootstrap_addr}_{self.prefill_dp_rank}_{self.target_tp_rank}" - ) - - if bootstrap_key not in self.kv_mgr.connection_pool: - bootstrap_infos = [] - for target_tp_rank in self.target_tp_ranks: - # Enable higher PP ranks to be bootstrapped earlier to make PP PD requests bootstrap more robust - for target_pp_rank in reversed(self.target_pp_ranks): - bootstrap_info = self._get_bootstrap_info_from_server( - target_tp_rank, self.prefill_dp_rank, target_pp_rank - ) - if bootstrap_info is not None: - if self.kv_mgr.is_mla_backend: - # For MLA: target_tp_rank is the selected real rank, others are dummy ranks - bootstrap_info["is_dummy"] = not bool( - target_tp_rank == self.target_tp_rank - or self.target_tp_rank is None + all_bootstrap_infos = [] + # NOTE: key distinguished by bootstrap_addr, prefill_dp_rank, prefill_cp_rank, and target_tp_rank + for target_cp_rank in self.target_cp_ranks: + bootstrap_key = f"{self.bootstrap_addr}_{self.prefill_dp_rank}_{target_cp_rank}_{self.target_tp_rank}" + + if bootstrap_key not in self.kv_mgr.connection_pool: + bootstrap_infos = [] + for target_tp_rank in self.target_tp_ranks: + # Enable higher PP ranks to be bootstrapped earlier to make PP PD requests bootstrap more robust + for target_pp_rank in reversed(self.target_pp_ranks): + bootstrap_info = self._get_bootstrap_info_from_server( + self.prefill_dp_rank, + target_cp_rank, + target_tp_rank, + target_pp_rank, + ) + if bootstrap_info is not None: + if self.kv_mgr.is_mla_backend: + # For MLA: target_tp_rank is the selected real rank, others are dummy ranks + bootstrap_info["is_dummy"] = not bool( + target_tp_rank == self.target_tp_rank + or self.target_tp_rank is None + ) + else: + # For non-MLA: all target_tp_ranks are selected real ranks + bootstrap_info["is_dummy"] = False + logger.debug( + f"Fetched bootstrap info: {bootstrap_info} for DP {self.prefill_dp_rank} CP {target_cp_rank} TP {target_tp_rank} PP {target_pp_rank}" ) + bootstrap_infos.append(bootstrap_info) else: - # For non-MLA: all target_tp_ranks are selected real ranks - bootstrap_info["is_dummy"] = False - logger.debug( - f"Fetched bootstrap info: {bootstrap_info} for DP {self.prefill_dp_rank} TP {target_tp_rank} PP {target_pp_rank}" - ) - bootstrap_infos.append(bootstrap_info) - else: - self.kv_mgr.record_failure( - self.bootstrap_room, - f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank} and prefill_dp_rank: {self.prefill_dp_rank} and target_pp_rank {target_pp_rank}", - ) - self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Failed) - return + self.kv_mgr.record_failure( + self.bootstrap_room, + f"Could not fetch bootstrap info for: prefill_dp_rank: {self.prefill_dp_rank} prefill_cp_rank: {target_cp_rank} target_tp_rank: {target_tp_rank} and target_pp_rank {target_pp_rank}", + ) + self.kv_mgr.update_status( + self.bootstrap_room, KVPoll.Failed + ) + return - self.bootstrap_infos = bootstrap_infos - self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_infos + self.bootstrap_infos = bootstrap_infos + self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_infos - # Register kv_args only once to prefill KVManager according to the info fetched from the bootstrap server - self._register_kv_args() - else: - self.bootstrap_infos = self.kv_mgr.connection_pool[bootstrap_key] + # Register kv_args only once to prefill KVManager according to the info fetched from the bootstrap server + self._register_kv_args() + else: + self.bootstrap_infos = self.kv_mgr.connection_pool[bootstrap_key] + + assert len(self.bootstrap_infos) > 0 + all_bootstrap_infos.extend(self.bootstrap_infos) - assert len(self.bootstrap_infos) > 0 + self.bootstrap_infos = all_bootstrap_infos def _get_bootstrap_info_from_server( - self, engine_rank, prefill_dp_rank, target_pp_rank + self, prefill_dp_rank, prefill_cp_rank, target_tp_rank, target_pp_rank ): """Fetch the bootstrap info from the bootstrap server.""" try: - url = f"http://{self.bootstrap_addr}/route?engine_rank={engine_rank}&prefill_dp_rank={prefill_dp_rank}&target_pp_rank={target_pp_rank}" + url = f"http://{self.bootstrap_addr}/route?prefill_dp_rank={prefill_dp_rank}&prefill_cp_rank={prefill_cp_rank}&target_tp_rank={target_tp_rank}&target_pp_rank={target_pp_rank}" response = requests.get(url, timeout=5) if response.status_code == 200: bootstrap_info = response.json() @@ -615,12 +672,13 @@ def __init__(self, host: str, port: int, dp_size: int = 1): self._setup_routes() self.pp_size = None self.attn_tp_size = None + self.attn_cp_size = None self.dp_size = dp_size self.page_size = None self.kv_cache_dtype: Optional[str] = None self.follow_bootstrap_room: Optional[bool] = None self.prefill_port_table: Dict[ - int, Dict[int, Dict[int, Dict[str, Union[str, int]]]] + int, Dict[int, Dict[int, Dict[int, PrefillRankInfo]]] ] = {} self.room_to_dp_rank: Dict[int, Dict[str, Union[int, float]]] = {} self._registered_count = 0 @@ -667,6 +725,8 @@ async def _handle_route_put(self, request: web.Request): data = await request.json() attn_tp_size = data["attn_tp_size"] attn_tp_rank = data["attn_tp_rank"] + attn_cp_size = data["attn_cp_size"] + attn_cp_rank = data["attn_cp_rank"] attn_dp_size = data["attn_dp_size"] attn_dp_rank = data["attn_dp_rank"] pp_size = data["pp_size"] @@ -681,6 +741,9 @@ async def _handle_route_put(self, request: web.Request): if self.attn_tp_size is None: self.attn_tp_size = attn_tp_size + if self.attn_cp_size is None: + self.attn_cp_size = attn_cp_size + if self.dp_size is None: self.dp_size = attn_dp_size if system_dp_size == 1 else system_dp_size @@ -706,35 +769,42 @@ async def _handle_route_put(self, request: web.Request): # Add lock to make sure thread-safe async with self.lock: - if dp_group not in self.prefill_port_table: - self.prefill_port_table[dp_group] = {} - if attn_tp_rank not in self.prefill_port_table[dp_group]: - self.prefill_port_table[dp_group][attn_tp_rank] = {} - - self.prefill_port_table[dp_group][attn_tp_rank][pp_rank] = { - "rank_ip": rank_ip, - "rank_port": rank_port, - } + dp_group_table = self.prefill_port_table.setdefault(dp_group, {}) + cp_group_table = dp_group_table.setdefault(attn_cp_rank, {}) + tp_group_table = cp_group_table.setdefault(attn_tp_rank, {}) + + tp_group_table[pp_rank] = PrefillRankInfo( + rank_ip=rank_ip, + rank_port=rank_port, + ) + self._registered_count += 1 - expected = self.dp_size * self.attn_tp_size * self.pp_size + expected = self.dp_size * self.attn_cp_size * self.attn_tp_size * self.pp_size logger.debug( - f"Register prefill bootstrap: DP{dp_group} TP{attn_tp_rank} PP{pp_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}" + f"Register prefill bootstrap: DP{dp_group} CP{attn_cp_rank} TP{attn_tp_rank} PP{pp_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}" f" ({self._registered_count}/{expected} registered)" ) return web.Response(text="OK", status=200) async def _handle_route_get(self, request: web.Request): - engine_rank = request.query.get("engine_rank") prefill_dp_rank = request.query.get("prefill_dp_rank") + prefill_cp_rank = request.query.get("prefill_cp_rank") + target_tp_rank = request.query.get("target_tp_rank") target_pp_rank = request.query.get("target_pp_rank") - if not engine_rank or not prefill_dp_rank or not target_pp_rank: + if ( + not prefill_dp_rank + or not prefill_cp_rank + or not target_tp_rank + or not target_pp_rank + ): return web.Response(text="Missing inputs for bootstrap server.", status=400) if ( - int(engine_rank) == -1 - and int(prefill_dp_rank) == -1 + int(prefill_dp_rank) == -1 + and int(prefill_cp_rank) == -1 + and int(target_tp_rank) == -1 and int(target_pp_rank) == -1 ): if not self._is_ready(): @@ -745,6 +815,7 @@ async def _handle_route_get(self, request: web.Request): ) info = PrefillServerInfo( attn_tp_size=self.attn_tp_size, + attn_cp_size=self.attn_cp_size, dp_size=self.dp_size, pp_size=self.pp_size, page_size=self.page_size, @@ -768,16 +839,16 @@ async def _handle_route_get(self, request: web.Request): try: async with self.lock: bootstrap_info = self.prefill_port_table[int(prefill_dp_rank)][ - int(engine_rank) - ][int(target_pp_rank)] + int(prefill_cp_rank) + ][int(target_tp_rank)][int(target_pp_rank)] except KeyError: return web.Response( - text=f"Bootstrap info not found for dp_rank={prefill_dp_rank} " - f"engine_rank={engine_rank} pp_rank={target_pp_rank}", + text=f"Bootstrap info not found for dp_rank={prefill_dp_rank} cp_rank={prefill_cp_rank} " + f"tp_rank={target_tp_rank} pp_rank={target_pp_rank}", status=404, ) - return web.json_response(bootstrap_info, status=200) + return web.json_response(dataclasses.asdict(bootstrap_info), status=200) async def _handle_register_dp_rank(self, request: web.Request): data = await request.json() diff --git a/python/sglang/srt/disaggregation/mooncake/conn.py b/python/sglang/srt/disaggregation/mooncake/conn.py index cea0a822ef82..a223a2578ea2 100644 --- a/python/sglang/srt/disaggregation/mooncake/conn.py +++ b/python/sglang/srt/disaggregation/mooncake/conn.py @@ -53,7 +53,7 @@ class TransferKVChunk: room: int prefill_kv_indices: npt.NDArray[np.int32] index_slice: slice - is_last: bool + is_last_chunk: bool prefill_aux_index: Optional[int] state_indices: Optional[List[int]] @@ -861,7 +861,7 @@ def transfer_worker( ) break - if kv_chunk.is_last: + if kv_chunk.is_last_chunk: if kv_chunk.state_indices is not None: self.maybe_send_extra( req, @@ -893,7 +893,7 @@ def transfer_worker( else: # Dummy request means the decode instance is not used, so its status can be marked as success directly # Dummy request does not need to sync status to decode endpoint - if kv_chunk.is_last and req.room in self.request_status: + if kv_chunk.is_last_chunk and req.room in self.request_status: self.update_status(req.room, KVPoll.Success) if ( @@ -1038,12 +1038,12 @@ def add_transfer_request( bootstrap_room: int, kv_indices: npt.NDArray[np.int32], index_slice: slice, - is_last: bool, + is_last_chunk: bool, aux_index: Optional[int] = None, state_indices: Optional[List[int]] = None, ): assert self.disaggregation_mode == DisaggregationMode.PREFILL - assert not is_last or (is_last and aux_index is not None) + assert not is_last_chunk or (is_last_chunk and aux_index is not None) if ( bootstrap_room not in self.request_status @@ -1072,7 +1072,7 @@ def add_transfer_request( room=bootstrap_room, prefill_kv_indices=kv_indices, index_slice=index_slice, - is_last=is_last, + is_last_chunk=is_last_chunk, prefill_aux_index=aux_index, state_indices=state_indices, ) @@ -1134,9 +1134,16 @@ def send( ): index_slice = slice(self.curr_idx, self.curr_idx + len(kv_indices)) self.curr_idx += len(kv_indices) - is_last = self.curr_idx == self.num_kv_indices + is_last_chunk = self.curr_idx == self.num_kv_indices - if not is_last: + if self.kv_mgr.is_dummy_cp_rank: + if not is_last_chunk: + return + else: + self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Success) + return + + if not is_last_chunk: self.kv_mgr.add_transfer_request( self.bootstrap_room, kv_indices, diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py index 0ad5d6f19024..e8595ad6c469 100644 --- a/python/sglang/srt/disaggregation/prefill.py +++ b/python/sglang/srt/disaggregation/prefill.py @@ -39,7 +39,7 @@ is_mla_backend, kv_to_page_indices, kv_to_page_num, - poll_and_all_reduce, + poll_and_all_reduce_attn_cp_tp_group, prepare_abort, ) from sglang.srt.managers.schedule_batch import FINISH_LENGTH, Req, ScheduleBatch @@ -269,8 +269,10 @@ def pop_bootstrapped( else: return [], [] - polls = poll_and_all_reduce( - [req.disagg_kv_sender for req in self.queue], self.gloo_group + polls = poll_and_all_reduce_attn_cp_tp_group( + [req.disagg_kv_sender for req in self.queue], + self.scheduler.attn_cp_cpu_group, + self.scheduler.attn_tp_cpu_group, ) for i, (req, poll) in enumerate(zip(self.queue, polls)): @@ -553,8 +555,9 @@ def process_disagg_prefill_inflight_queue( done_reqs = [] - polls = poll_and_all_reduce( + polls = poll_and_all_reduce_attn_cp_tp_group( [req.disagg_kv_sender for req in self.disagg_prefill_inflight_queue], + self.attn_cp_cpu_group, self.attn_tp_cpu_group, ) @@ -621,8 +624,9 @@ def get_transferred_rids(self: Scheduler) -> List[str]: """ Used by PP, get the transferred rids but **do not pop** """ - polls = poll_and_all_reduce( + polls = poll_and_all_reduce_attn_cp_tp_group( [req.disagg_kv_sender for req in self.disagg_prefill_inflight_queue], + self.attn_cp_cpu_group, self.attn_tp_cpu_group, ) diff --git a/python/sglang/srt/disaggregation/utils.py b/python/sglang/srt/disaggregation/utils.py index 0e6b12f312a7..fe1c71294572 100644 --- a/python/sglang/srt/disaggregation/utils.py +++ b/python/sglang/srt/disaggregation/utils.py @@ -44,7 +44,7 @@ class DisaggregationMode(Enum): FAILURE_PROB = float(os.getenv("DISAGGREGATION_TEST_FAILURE_PROB", 0)) -def poll_and_all_reduce(pollers, gloo_group): +def poll_and_all_reduce(pollers, gloo_group: dist.ProcessGroup): # at a certain prob, the poll is failed to simulate failure if FAILURE_PROB > 0: from sglang.srt.disaggregation.base import KVPoll @@ -60,6 +60,26 @@ def poll_and_all_reduce(pollers, gloo_group): return tensor_to_reduce.tolist() +def poll_and_all_reduce_attn_cp_tp_group( + pollers, + attn_cp_cpu_group: dist.ProcessGroup, + attn_tp_cpu_group: dist.ProcessGroup, +): + # First sync across attn-tp ranks so all TP participants for a given (dp, cp) + # shard observe the same status transitions. + polls = poll_and_all_reduce(pollers, attn_tp_cpu_group) + + # Then sync across attn-cp ranks, so all TPxCP participants in one DP shard + # converge to the same global status. + tensor_to_reduce = torch.tensor(polls, dtype=torch.uint8, device="cpu") + dist.all_reduce( + tensor_to_reduce, + op=dist.ReduceOp.MIN, + group=attn_cp_cpu_group, + ) + return tensor_to_reduce.tolist() + + ######################### # Metadata Buffers ######################### diff --git a/python/sglang/srt/managers/scheduler_pp_mixin.py b/python/sglang/srt/managers/scheduler_pp_mixin.py index 9a4396f2f8b8..90b13576b30b 100644 --- a/python/sglang/srt/managers/scheduler_pp_mixin.py +++ b/python/sglang/srt/managers/scheduler_pp_mixin.py @@ -13,7 +13,7 @@ from tqdm import tqdm from sglang.srt.disaggregation.base.conn import KVPoll -from sglang.srt.disaggregation.utils import poll_and_all_reduce +from sglang.srt.disaggregation.utils import poll_and_all_reduce_attn_cp_tp_group from sglang.srt.distributed.parallel_state import P2PWork from sglang.srt.environ import envs from sglang.srt.layers.dp_attention import ( @@ -1082,8 +1082,9 @@ def get_rids( """ Used by PP, get the required rids with the given poll statuses. """ - polls = poll_and_all_reduce( + polls = poll_and_all_reduce_attn_cp_tp_group( [req.disagg_kv_sender if is_send else req.kv_receiver for req in req_queue], + self.attn_cp_cpu_group, self.attn_tp_cpu_group, ) rids: List = []