diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py
index a2ffcbcc83a0..04bfa28e22ed 100644
--- a/python/sglang/srt/distributed/parallel_state.py
+++ b/python/sglang/srt/distributed/parallel_state.py
@@ -248,6 +248,7 @@ def __init__(
         group_name: Optional[str] = None,
         pynccl_use_current_stream: bool = False,
         gloo_timeout: timedelta = timedelta(seconds=120 * 60),
+        extend_group: bool = False,
     ):
         # Set group info
         group_name = group_name or "anonymous"
@@ -282,15 +283,19 @@ def __init__(
         if "mooncake" in torch_distributed_backend:
             from mooncake.ep import MooncakeBackendOptions

+            if len(ranks) == 1:
+                # Bypass the extension procedure
+                extend_group = False
+
             device_group = torch.distributed.new_group(
                 ranks,
                 backend="mooncake",
-                pg_options=MooncakeBackendOptions(active_ranks),
+                pg_options=MooncakeBackendOptions(active_ranks, extend_group),
             )
             cpu_group = torch.distributed.new_group(
                 ranks,
                 backend="mooncake-cpu",
-                pg_options=MooncakeBackendOptions(active_ranks_cpu),
+                pg_options=MooncakeBackendOptions(active_ranks_cpu, extend_group),
             )
         else:
             pg_options = get_torch_distributed_pg_options(group_name)
@@ -435,7 +440,11 @@ def __init__(
         )

         self.mq_broadcaster: Optional[MessageQueue] = None
-        if use_message_queue_broadcaster and self.world_size > 1:
+        if (
+            use_message_queue_broadcaster
+            and self.world_size > 1
+            and not extend_group
+        ):
             self.mq_broadcaster = MessageQueue.create_from_process_group(
                 self.cpu_group, 1 << 22, 6
             )
@@ -1413,7 +1422,7 @@ def get_world_group() -> GroupCoordinator:


 def init_world_group(
-    ranks: List[int], local_rank: int, backend: str
+    ranks: List[int], local_rank: int, backend: str, extend_group: bool = False
 ) -> GroupCoordinator:
     return GroupCoordinator(
         group_ranks=[ranks],
@@ -1427,6 +1436,7 @@ def init_world_group(
         use_xpu_communicator=False,
         use_npu_communicator=False,
         group_name="world",
+        extend_group=extend_group,
     )


@@ -1441,6 +1451,7 @@ def init_model_parallel_group(
     use_mscclpp_allreduce: Optional[bool] = None,
     pynccl_use_current_stream: bool = True,
     use_torch_symm_mem_allreduce: Optional[bool] = None,
+    extend_group: bool = False,
 ) -> GroupCoordinator:
     if use_custom_allreduce is None:
         use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
@@ -1466,6 +1477,7 @@ def init_model_parallel_group(
         use_message_queue_broadcaster=use_message_queue_broadcaster,
         group_name=group_name,
         pynccl_use_current_stream=pynccl_use_current_stream,
+        extend_group=extend_group,
     )


@@ -1675,6 +1687,7 @@ def init_distributed_environment(
     backend: str = "nccl",
     timeout: Optional[int] = None,
     moe_a2a_backend: Optional[str] = None,
+    extend_group: bool = False,
 ):
     logger.debug(
         "world_size=%d rank=%d local_rank=%d " "distributed_init_method=%s backend=%s",
@@ -1705,7 +1718,10 @@ def init_distributed_environment(
         assert timeout > 0, "timeout must be positive"
         timeout = timedelta(seconds=timeout)

-    pg_options = get_torch_distributed_pg_options()
+    from mooncake.ep import MooncakeBackendOptions
+
+    active_ranks = torch.ones(world_size, dtype=torch.int32, device="cuda")
+    pg_options = MooncakeBackendOptions(active_ranks, extend_group)

     # this backend is used for WORLD
     torch.distributed.init_process_group(
@@ -1734,7 +1750,7 @@ def init_distributed_environment(
     global _WORLD
     if _WORLD is None:
         ranks = list(range(torch.distributed.get_world_size()))
-        _WORLD = init_world_group(ranks, local_rank, backend)
+        _WORLD = init_world_group(ranks, local_rank, backend, extend_group=extend_group)
     else:
         assert (
             _WORLD.world_size == torch.distributed.get_world_size()
@@ -1750,6 +1766,7 @@ def initialize_model_parallel(
     moe_data_model_parallel_size: int = 1,
     backend: Optional[str] = None,
     duplicate_tp_group: bool = False,
+    extend_group: bool = False,
 ) -> None:
     """
     Initialize model parallel groups.
@@ -1831,6 +1848,7 @@ def initialize_model_parallel(
         ),
         group_name="tp",
         pynccl_use_current_stream=duplicate_tp_group,
+        extend_group=extend_group,
     )

     if duplicate_tp_group:
@@ -1847,6 +1865,7 @@ def initialize_model_parallel(
             ),
             group_name="pdmux_prefill_tp",
             pynccl_use_current_stream=True,
+            extend_group=extend_group,
         )
         if _TP.pynccl_comm:
             _TP.pynccl_comm.disabled = False
@@ -1884,6 +1903,7 @@ def initialize_model_parallel(
         get_world_group().local_rank,
         backend,
         group_name="attn_cp",
+        extend_group=extend_group,
     )

     from sglang.srt.layers.sampler import SYNC_TOKEN_IDS_ACROSS_TP
@@ -1917,6 +1937,7 @@ def initialize_model_parallel(
         use_custom_allreduce=False,
         use_torch_symm_mem_allreduce=False,
         group_name="attention_tp",
+        extend_group=extend_group,
     )

     moe_ep_size = expert_model_parallel_size
@@ -1943,6 +1964,7 @@ def initialize_model_parallel(
         get_world_group().local_rank,
         backend,
         group_name="moe_dp",
+        extend_group=extend_group,
     )

     global _MOE_EP
@@ -1968,6 +1990,7 @@ def initialize_model_parallel(
         get_world_group().local_rank,
         backend,
         group_name="moe_ep",
+        extend_group=extend_group,
     )

     global _MOE_TP
@@ -1994,6 +2017,7 @@ def initialize_model_parallel(
         get_world_group().local_rank,
         backend,
         group_name="moe_tp",
+        extend_group=extend_group,
     )

     # Build the pipeline model-parallel groups.
@@ -2013,6 +2037,7 @@ def initialize_model_parallel(
         backend,
         use_custom_allreduce=False,
         group_name="pp",
+        extend_group=extend_group,
     )


diff --git a/python/sglang/srt/elastic_ep/elastic_ep.py b/python/sglang/srt/elastic_ep/elastic_ep.py
index 8f31fe4c7a3d..db7efdd02385 100644
--- a/python/sglang/srt/elastic_ep/elastic_ep.py
+++ b/python/sglang/srt/elastic_ep/elastic_ep.py
@@ -1,13 +1,19 @@
 from __future__ import annotations

+import logging
+from contextlib import contextmanager
 from dataclasses import dataclass
-from typing import Optional
+from typing import List, Optional

 import torch

+from sglang.srt.distributed import parallel_state
+from sglang.srt.layers.dp_attention import get_attention_tp_group
 from sglang.srt.managers.schedule_batch import ServerArgs
 from sglang.srt.utils import is_cpu, is_cuda

+logger = logging.getLogger(__name__)
+

 @dataclass
 class ElasticEPState:
@@ -43,6 +49,18 @@ def init(cls, server_args: ServerArgs):
         cls._instance = cls._build_state(ep_size=None, device=None)
         return cls._instance

+    @classmethod
+    def reset(cls, server_args: ServerArgs):
+        from sglang.srt.layers.moe import MoeA2ABackend
+        from sglang.srt.layers.moe.token_dispatcher.mooncake import EPBuffer
+
+        moe_a2a_backend = MoeA2ABackend(server_args.moe_a2a_backend)
+        if moe_a2a_backend.is_mooncake():
+            EPBuffer.mark_ep_member_refresh_needed()
+        cls._instance.active_ranks.fill_(1)
+        cls._instance.snapshot_active_to_last()
+        cls._instance.sync_active_to_cpu()
+
     @staticmethod
     def _select_device() -> torch.device:
         if is_cuda():
@@ -71,3 +89,137 @@ def healthy_rank_state(
         dev = device if device is not None else cls._select_device()

         return torch.ones(size, dtype=torch.int32, device=dev)
+
+
+def _check_peer_state_for_group(group, ranks_in_group):
+    """Check peer state for a specific group."""
+    from mooncake import ep as mooncake_ep
+
+    if not ranks_in_group:
+        return None
+
+    # Check device group
+    device_backend = group.device_group._get_backend(torch.device("cuda"))
+    device_peer_state = mooncake_ep.get_peer_state(device_backend, ranks_in_group)
+
+    # Check cpu group
+    cpu_backend = group.cpu_group._get_backend(torch.device("cpu"))
+    cpu_peer_state = mooncake_ep.get_peer_state(cpu_backend, ranks_in_group)
+
+    return {"device": device_peer_state, "cpu": cpu_peer_state}
+
+
+def try_recover_ranks(global_ranks: List[int]):
+    from mooncake import ep as mooncake_ep
+
+    logger.info(f"Trying to recover ranks: {global_ranks}")
+
+    backend = torch.distributed.group.WORLD._get_backend(torch.device("cuda"))
+    peer_state = mooncake_ep.get_peer_state(backend, global_ranks)
+
+    if not all(peer_state):
+        logger.info("Early return from try_recover_ranks")
+        return False
+
+    groups = parallel_state._groups
+    groups = {group_name: groups[group_name] for group_name in groups}
+
+    for group_name in groups:
+        group = groups[group_name]()
+        if group is not None:
+            group_ranks = group.ranks
+            ranks_in_group = [r for r in global_ranks if r in group_ranks]
+
+            if ranks_in_group:
+                group_rank_mapping = {
+                    global_rank: idx for idx, global_rank in enumerate(group_ranks)
+                }
+                group_local_ranks = [
+                    group_rank_mapping[r]
+                    for r in ranks_in_group
+                    if r in group_rank_mapping
+                ]
+                if group_local_ranks:
+                    device_backend = group.device_group._get_backend(
+                        torch.device("cuda")
+                    )
+                    group_peer_state = mooncake_ep.get_peer_state(
+                        device_backend, group_local_ranks
+                    )
+
+                    cpu_backend = group.cpu_group._get_backend(torch.device("cpu"))
+                    group_peer_state = mooncake_ep.get_peer_state(
+                        cpu_backend, group_local_ranks
+                    )
+
+    mooncake_ep.recover_ranks(backend, global_ranks)
+
+    device_group_status = {group_name: False for group_name in groups}
+    cpu_group_status = {group_name: False for group_name in groups}
+    while True:
+        for group_name in groups:
+            group = groups[group_name]()
+            if group is not None:
+                group_ranks = group.ranks
+                ranks_in_group = [r for r in global_ranks if r in group_ranks]
+
+                if ranks_in_group:
+                    # Convert to group-local ranks
+                    group_rank_mapping = {
+                        global_rank: idx for idx, global_rank in enumerate(group_ranks)
+                    }
+                    group_local_ranks = [
+                        group_rank_mapping[r]
+                        for r in ranks_in_group
+                        if r in group_rank_mapping
+                    ]
+
+                    if group_local_ranks:
+                        if not device_group_status[group_name]:
+                            device_backend = group.device_group._get_backend(
+                                torch.device("cuda")
+                            )
+                            group_peer_state = mooncake_ep.get_peer_state(
+                                device_backend, group_local_ranks
+                            )
+                            if all(group_peer_state):
+                                mooncake_ep.recover_ranks(
+                                    device_backend, group_local_ranks
+                                )
+                                device_group_status[group_name] = True
+
+                        if not cpu_group_status[group_name]:
+                            cpu_backend = group.cpu_group._get_backend(
+                                torch.device("cpu")
+                            )
+                            group_peer_state = mooncake_ep.get_peer_state(
+                                cpu_backend, group_local_ranks
+                            )
+                            if all(group_peer_state):
+                                mooncake_ep.recover_ranks(
+                                    cpu_backend, group_local_ranks
+                                )
+                                if (
+                                    group.use_message_queue_broadcaster
+                                    and group.world_size > 1
+                                ):
+                                    # Create message queue
+                                    from sglang.srt.distributed.device_communicators.shm_broadcast import (
+                                        MessageQueue,
+                                    )
+
+                                    group.mq_broadcaster = (
+                                        MessageQueue.create_from_process_group(
+                                            group.cpu_group, 1 << 22, 6
+                                        )
+                                    )
+                                cpu_group_status[group_name] = True
+                else:
+                    device_group_status[group_name] = True
+                    cpu_group_status[group_name] = True
+            else:
+                device_group_status[group_name] = True
+                cpu_group_status[group_name] = True
+        if all(device_group_status.values()) and all(cpu_group_status.values()):
+            logger.info(f"Recovered ranks {global_ranks}")
+            return True
diff --git a/python/sglang/srt/eplb/eplb_manager.py b/python/sglang/srt/eplb/eplb_manager.py
index e88a3d28e0f3..38f8b07d29da 100644
--- a/python/sglang/srt/eplb/eplb_manager.py
+++ b/python/sglang/srt/eplb/eplb_manager.py
@@ -41,6 +41,9 @@ def __init__(self, model_runner: "ModelRunner"):
     def on_forward_pass_end(self):
         next(self._main_generator)

+    def reset_generator(self):
+        self._main_generator = self._entrypoint()
+
     # can be more complex if needed
     def _entrypoint(self):
         while True:
diff --git a/python/sglang/srt/eplb/expert_location.py b/python/sglang/srt/eplb/expert_location.py
index 7bd0254baa5a..adc7c170ab5c 100644
--- a/python/sglang/srt/eplb/expert_location.py
+++ b/python/sglang/srt/eplb/expert_location.py
@@ -319,6 +319,59 @@ def set_global_expert_location_metadata(value):
     _global_expert_location_metadata = value


+def broadcast_global_expert_location_metadata(
+    src_rank: int = 0, group: Optional[torch.distributed.ProcessGroup] = None
+):
+    """Broadcast the global ExpertLocationMetadata from src_rank to all ranks.
+
+    This is used in Elastic EP rank recovery to ensure that all ranks (including
+    newly recovered ones) share exactly the same expert location metadata,
+    especially physical_to_logical_map and related CPU / GPU copies.
+    """
+    if not torch.distributed.is_available() or not torch.distributed.is_initialized():
+        logger.warning(
+            "broadcast_global_expert_location_metadata: torch.distributed not initialized, skip broadcast"
+        )
+        return
+
+    metadata = get_global_expert_location_metadata()
+    if metadata is None:
+        logger.warning(
+            "broadcast_global_expert_location_metadata: no global metadata found, skip broadcast"
+        )
+        return
+
+    # Ensure device tensors are contiguous before broadcasting in-place
+    metadata.physical_to_logical_map = metadata.physical_to_logical_map.contiguous()
+    metadata.logical_to_all_physical_map = (
+        metadata.logical_to_all_physical_map.contiguous()
+    )
+    metadata.logical_to_all_physical_map_num_valid = (
+        metadata.logical_to_all_physical_map_num_valid.contiguous()
+    )
+    if metadata.logical_to_rank_dispatch_physical_map is not None:
+        metadata.logical_to_rank_dispatch_physical_map = (
+            metadata.logical_to_rank_dispatch_physical_map.contiguous()
+        )
+
+    device_tensors = [
+        metadata.physical_to_logical_map,
+        metadata.logical_to_all_physical_map,
+        metadata.logical_to_all_physical_map_num_valid,
+    ]
+    if metadata.logical_to_rank_dispatch_physical_map is not None:
+        device_tensors.append(metadata.logical_to_rank_dispatch_physical_map)
+
+    for tensor in device_tensors:
+        torch.distributed.broadcast(tensor, src=src_rank, group=group)
+
+    # After broadcasting device tensors, refresh corresponding CPU copies
+    metadata.physical_to_logical_map_cpu = metadata.physical_to_logical_map.cpu()
+    metadata.logical_to_all_physical_map_cpu = (
+        metadata.logical_to_all_physical_map.cpu()
+    )
+
+
 def _compute_logical_to_all_physical_map(
     server_args: ServerArgs,
     physical_to_logical_map: torch.Tensor,
diff --git a/python/sglang/srt/layers/moe/token_dispatcher/mooncake.py b/python/sglang/srt/layers/moe/token_dispatcher/mooncake.py
index f475d69d2dd1..ca33cce75d41 100644
--- a/python/sglang/srt/layers/moe/token_dispatcher/mooncake.py
+++ b/python/sglang/srt/layers/moe/token_dispatcher/mooncake.py
@@ -57,11 +57,16 @@ def format(self) -> CombineInputFormat:


 class EPBuffer:
+    _needs_ep_member_refresh = False
     _buffer = None
     _hidden_size: Optional[int] = None
     _num_max_dispatch_tokens_per_rank: Optional[int] = None
     _num_experts: Optional[int] = None

+    @classmethod
+    def mark_ep_member_refresh_needed(cls):
+        cls._needs_ep_member_refresh = True
+
     @classmethod
     def get_ep_buffer(
         cls,
@@ -73,6 +78,9 @@ def get_ep_buffer(
         num_experts: int = -1,
     ):
         if cls._buffer is not None:
+            if cls._needs_ep_member_refresh:
+                cls._buffer.update_ep_member()
+                cls._needs_ep_member_refresh = False
             return cls._buffer

         # Lazy import Buffer to avoid creating CUDA context at module import time
diff --git a/python/sglang/srt/managers/data_parallel_controller.py b/python/sglang/srt/managers/data_parallel_controller.py
index 7aeaafdbbf24..28085242ec67 100644
--- a/python/sglang/srt/managers/data_parallel_controller.py
+++ b/python/sglang/srt/managers/data_parallel_controller.py
@@ -13,6 +13,7 @@
 # ==============================================================================
 """A controller that dispatches requests to multiple data parallel workers."""

+import dataclasses
 import faulthandler
 import logging
 import multiprocessing as mp
@@ -20,7 +21,7 @@
 import threading
 import time
 from enum import Enum, auto
-from typing import Callable, List, Optional
+from typing import Callable, Dict, List, Optional, Set

 import psutil
 import setproctitle
@@ -150,7 +151,10 @@ def __init__(

         # Launch data parallel workers
         self.scheduler_procs = []
+        self.scheduler_procs_by_global_rank: Dict[int, mp.Process] = {}
+        self.recovering_global_ranks: Set[int] = set()
         self.workers: List[zmq.Socket] = [None] * server_args.dp_size
+        self.worker_ports = []
         self.status: List[bool] = [True] * server_args.dp_size

         if server_args.enable_dp_attention:
@@ -182,11 +186,209 @@ def send_control_message(self, obj):
         for worker in self.workers[:: self.control_message_step]:
             worker.send_pyobj(obj)

+    def _global_rank_to_dp_rank(self, global_rank: int) -> int:
+        tp_rank = global_rank % self.server_args.tp_size
+        _, _, dp_rank = compute_dp_attention_world_info(
+            self.server_args.enable_dp_attention,
+            tp_rank,
+            self.server_args.tp_size,
+            self.server_args.dp_size,
+            self.server_args.attn_cp_size,
+        )
+        return dp_rank
+
+    def _register_scheduler_proc(self, global_rank: int, proc: mp.Process) -> None:
+        self.scheduler_procs.append(proc)
+        self.scheduler_procs_by_global_rank[global_rank] = proc
+
+    def maybe_recover_dead_schedulers(self) -> None:
+        if not self.server_args.enable_dp_attention:
+            return
+
+        dead_ranks = []
+        for global_rank, proc in list(self.scheduler_procs_by_global_rank.items()):
+            if proc.exitcode is None or global_rank in self.recovering_global_ranks:
+                continue
+
+            dp_rank = self._global_rank_to_dp_rank(global_rank)
+            if self.status[dp_rank]:
+                logger.error(
+                    "Detected dead scheduler for global_rank=%s pid=%s exitcode=%s",
+                    global_rank,
+                    proc.pid,
+                    proc.exitcode,
+                )
+                self.status[dp_rank] = False
+                dead_ranks.append(global_rank)
+
+        if not dead_ranks:
+            return
+
+        self.recovering_global_ranks.update(dead_ranks)
+        try:
+            self.handle_recover_ranks_req(dead_ranks)
+        finally:
+            self.recovering_global_ranks.difference_update(dead_ranks)
+
+    def handle_recover_ranks_req(self, ranks_to_recover: List[int]):
+        """Handle request to recover specific ranks.
+
+        This method:
+        1. Launches new processes for the specified global ranks
+        2. Sends a message to all workers to call try_recover_ranks
+        """
+        logger.info(f"Recovering ranks: {ranks_to_recover}")
+        assert self.server_args.enable_dp_attention
+
+        # Get worker ports if needed
+        if self.server_args.node_rank == 0:
+            for rank in ranks_to_recover:
+                port_and_socket = get_zmq_socket(self.context, zmq.PUSH)
+                self.worker_ports[rank] = port_and_socket[0]
+                self.workers[rank] = port_and_socket[1]
+
+        broadcasted_ports = self._broadcast_worker_ports(
+            self.server_args, self.worker_ports if self.worker_ports else None
+        )
+
+        # Launch processes for each rank that needs to be recovered
+        # We need to determine which ranks belong to this node
+        pp_size_per_node = max(self.server_args.pp_size // self.server_args.nnodes, 1)
+        nnodes_per_pp_rank = max(self.server_args.nnodes // self.server_args.pp_size, 1)
+        pp_rank_range = range(
+            pp_size_per_node * (self.server_args.node_rank // nnodes_per_pp_rank),
+            pp_size_per_node * (self.server_args.node_rank // nnodes_per_pp_rank + 1),
+        )
+
+        nnodes_per_tp_group = nnodes_per_pp_rank
+        tp_size_per_node = self.server_args.tp_size // nnodes_per_tp_group
+        tp_rank_range = range(
+            tp_size_per_node * (self.server_args.node_rank % nnodes_per_tp_group),
+            tp_size_per_node * (self.server_args.node_rank % nnodes_per_tp_group + 1),
+        )
+
+        # Launch processes for ranks that belong to this node
+        for global_rank in ranks_to_recover:
+            # Compute tp_rank and pp_rank from global_rank
+            # Assuming pp_size=1 for dp_attention: global_rank = tp_rank
+            tp_rank = global_rank
+            pp_rank = 0  # Assuming pp_size=1
+
+            # Check if this rank belongs to this node
+            if pp_rank in pp_rank_range and tp_rank in tp_rank_range:
+                # Compute dp_rank from tp_rank
+                _, _, dp_rank = compute_dp_attention_world_info(
+                    self.server_args.enable_dp_attention,
+                    tp_rank,
+                    self.server_args.tp_size,
+                    self.server_args.dp_size,
+                    self.server_args.attn_cp_size,
+                )
+
+                # Use helper method to compute parallelism info
+                gpu_id, attn_cp_rank, moe_dp_rank, moe_ep_rank = (
+                    self._compute_rank_parallel_info(
+                        tp_rank, pp_rank, pp_size_per_node, tp_size_per_node
+                    )
+                )
+
+                # Create port args for this rank
+                rank_port_args = PortArgs.init_new(
+                    self.server_args, dp_rank, broadcasted_ports
+                )
+                rank_port_args.nccl_port = self.port_args.nccl_port
+
+                # Launch the process
+                reader, writer = mp.Pipe(duplex=False)
+                memory_saver_adapter = TorchMemorySaverAdapter.create(
+                    enable=self.server_args.enable_memory_saver
+                )
+
+                rank_server_args = dataclasses.replace(
+                    self.server_args, mooncake_extend_group=True
+                )
+
+                with self.env_lock, maybe_reindex_device_id(gpu_id) as gpu_id:
+                    proc = mp.Process(
+                        target=self.run_scheduler_process_func,
+                        args=(
+                            rank_server_args,
+                            rank_port_args,
+                            gpu_id,
+                            tp_rank,
+                            attn_cp_rank,
+                            moe_dp_rank,
+                            moe_ep_rank,
+                            pp_rank,
+                            dp_rank,
+                            writer,
+                        ),
+                    )
+                    with memory_saver_adapter.configure_subprocess(), numa_utils.configure_subprocess(
+                        self.server_args, gpu_id
+                    ):
+                        proc.start()
+
+                self._register_scheduler_proc(global_rank, proc)
+                logger.info(
+                    f"Launched process for global_rank={global_rank}, tp_rank={tp_rank}, dp_rank={dp_rank}"
+                )
+
+    def _compute_rank_parallel_info(
+        self,
+        tp_rank: int,
+        pp_rank: int,
+        pp_size_per_node: int,
+        tp_size_per_node: int,
+        server_args: Optional[ServerArgs] = None,
+    ):
+        """Compute parallelism info for a given tp_rank and pp_rank.
+
+        Args:
+            tp_rank: Tensor parallelism rank
+            pp_rank: Pipeline parallelism rank
+            pp_size_per_node: PP size per node
+            tp_size_per_node: TP size per node
+            server_args: ServerArgs to use (defaults to self.server_args)
+
+        Returns:
+            Tuple of (gpu_id, attn_cp_rank, moe_dp_rank, moe_ep_rank)
+        """
+        args = server_args if server_args is not None else self.server_args
+
+        attn_dp_size = args.dp_size
+
+        # Parallelism hierarchy (outermost to innermost):
+        # - Attention: Global(TP) -> DP -> ATTN_CP -> ATTN_TP (innermost)
+        # - MoE: Global(TP) -> MOE_DP -> EP -> MOE_TP (innermost)
+        attn_tp_size = args.tp_size // attn_dp_size // args.attn_cp_size
+        attn_cp_rank = (tp_rank // attn_tp_size) % args.attn_cp_size
+        moe_dp_rank = tp_rank // (args.tp_size // args.moe_dp_size)
+        moe_ep_rank = (
+            tp_rank
+            % (args.tp_size // args.moe_dp_size)
+            // (args.tp_size // args.moe_dp_size // args.ep_size)
+        )
+
+        gpu_id = (
+            args.base_gpu_id
+            + ((pp_rank % pp_size_per_node) * tp_size_per_node)
+            + (tp_rank % tp_size_per_node) * args.gpu_id_step
+        )
+
+        return gpu_id, attn_cp_rank, moe_dp_rank, moe_ep_rank
+
     def handle_load_update_req(self, obj):
         self.dp_budget.update_budget(obj)

     def update_active_ranks(self, ranks: ActiveRanksOutput):
-        self.status = ranks.status
+        if self.status != ranks.status:
+            self.status = ranks.status
+            if self.server_args.enable_dp_attention and not all(ranks.status):
+                ranks_to_recover = [
+                    i for i in range(len(ranks.status)) if not ranks.status[i]
+                ]
+                self.handle_recover_ranks_req(ranks_to_recover)

     def dispatching_with_trace(self, req: Req):
         req.time_stats = DPControllerReqTimeStats.new_from_obj(req.time_stats)
@@ -360,16 +562,15 @@ def launch_dp_attention_schedulers(
         self, server_args: ServerArgs, port_args: PortArgs
     ):
         # Pre-allocate worker ports on node 0 to avoid conflicts
-        worker_ports = []
         if server_args.node_rank == 0:
             for dp_rank in range(server_args.dp_size):
                 port_and_socket = get_zmq_socket(self.context, zmq.PUSH)
-                worker_ports.append(port_and_socket[0])
+                self.worker_ports.append(port_and_socket[0])
                 self.workers[dp_rank] = port_and_socket[1]
                 logger.debug(f"Assigned port {port_and_socket[0]} to worker {dp_rank}")
         broadcasted_ports = self._broadcast_worker_ports(
-            server_args, worker_ports if worker_ports else None
+            server_args, self.worker_ports if self.worker_ports else None
         )
         self.launch_tensor_parallel_group(
             server_args, port_args, 0, None, broadcasted_ports
         )
@@ -480,7 +681,8 @@ def launch_tensor_parallel_group(
                         server_args, gpu_id
                     ):
                         proc.start()
-            self.scheduler_procs.append(proc)
+            global_rank = pp_rank * server_args.tp_size + tp_rank
+            self._register_scheduler_proc(global_rank, proc)
             scheduler_pipe_readers.append(reader)

         # Wait for model to finish loading
@@ -549,6 +751,7 @@ def total_tokens_scheduler(self, req: Req):

     def event_loop(self):
         while True:
+            self.maybe_recover_dead_schedulers()
             while True:
                 self.soft_watchdog.feed()
                 try:
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 1ed4fd07bc49..b328285bb15a 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -3370,8 +3370,8 @@ def run_scheduler_process(
         dp_rank,
     )

-    # Send initialization info back to the parent process
-    pipe_writer.send(scheduler.get_init_info())
+    if not server_args.mooncake_extend_group:
+        pipe_writer.send(scheduler.get_init_info())

     # Run the event loop (blocks until shutdown)
     scheduler.run_event_loop()
diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py
index 8f55006399e1..7cdf9d435f4d 100644
--- a/python/sglang/srt/managers/tp_worker.py
+++ b/python/sglang/srt/managers/tp_worker.py
@@ -307,13 +307,19 @@ def __init__(
             self.max_req_len > 0 and self.max_req_input_len > 0
         ), "Memory pool size is too small"

-        # Sync random seed across TP workers
-        self.random_seed = broadcast_pyobj(
-            [server_args.random_seed],
-            self.tp_size * self.pp_rank + tp_rank,
-            self.world_group.cpu_group,
-            src=self.world_group.ranks[0],
-        )[0]
+        # Recovered ranks initialize in an isolated extension group before join.
+        # World-rank broadcasts sourced from rank 0 are invalid in that phase,
+        # so use the configured seed locally and let the normal group sync resume
+        # after join.
+        if server_args.mooncake_extend_group:
+            self.random_seed = server_args.random_seed
+        else:
+            self.random_seed = broadcast_pyobj(
+                [server_args.random_seed],
+                self.tp_size * self.pp_rank + tp_rank,
+                self.world_group.cpu_group,
+                src=self.world_group.ranks[0],
+            )[0]
         set_random_seed(self.random_seed)

         self.enable_overlap = not server_args.disable_overlap_schedule
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index d85a7d1bf9eb..af6551fc2eca 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -72,7 +72,7 @@
     use_symmetric_memory,
 )
 from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
-from sglang.srt.elastic_ep.elastic_ep import ElasticEPStateManager
+from sglang.srt.elastic_ep.elastic_ep import ElasticEPStateManager, try_recover_ranks
 from sglang.srt.elastic_ep.expert_backup_client import ExpertBackupClient
 from sglang.srt.environ import envs
 from sglang.srt.eplb.eplb_manager import EPLBManager
@@ -86,6 +86,7 @@
     ExpertLocationMetadata,
     compute_initial_expert_location_metadata,
     get_global_expert_location_metadata,
+    broadcast_global_expert_location_metadata,
     set_global_expert_location_metadata,
 )
 from sglang.srt.eplb.expert_location_updater import ExpertLocationUpdater
@@ -111,6 +112,7 @@
     get_global_experts_capturer,
     set_global_experts_capturer,
 )
+from sglang.srt.layers.moe.token_dispatcher import MooncakeEPDispatcher
 from sglang.srt.layers.pooler import EmbeddingPoolerOutput
 from sglang.srt.layers.quantization.fp8_kernel import fp8_dtype
 from sglang.srt.layers.sampler import create_sampler
@@ -157,6 +159,7 @@
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.utils import (
     MultiprocessingSerializer,
+    broadcast_pyobj,
     cpu_has_amx_support,
     dynamic_import,
     empty_context,
@@ -422,6 +425,14 @@ def __init__(
         self.initialize(pre_model_load_memory)
         self.check_quantized_moe_compatibility()

+        if (
+            self.server_args.elastic_ep_backend == "mooncake"
+            and self.server_args.mooncake_extend_group
+        ):
+            self._join_deferred_mooncake_groups()
+            self._finish_deferred_mooncake_recovery()
+            ElasticEPStateManager.reset(self.server_args)
+
         if self.is_multimodal:
             sanity_check_mm_pad_shift_value(self.model_config.vocab_size)

@@ -614,6 +625,14 @@ def initialize(self, pre_model_load_memory: float):
         # Init routed experts capturer
         self.init_routed_experts_capturer()

+        if self.server_args.mooncake_extend_group and self.device in ["cuda", "musa"]:
+            active_ranks = torch.zeros(self.tp_size, dtype=torch.int32)
+            active_ranks[self.tp_rank] = 1
+            elastic_ep_state = ElasticEPStateManager.instance()
+            elastic_ep_state.active_ranks = active_ranks.cuda()
+            elastic_ep_state.snapshot_active_to_last()
+            elastic_ep_state.sync_active_to_cpu()
+
         if self.device == "cuda" or self.device == "musa":
             self.init_cublas()
             self.init_attention_backend()
@@ -640,6 +659,54 @@ def initialize(self, pre_model_load_memory: float):

         self.prealloc_symmetric_memory_pool()

+    def _join_deferred_mooncake_groups(self):
+        from mooncake import ep as mooncake_ep
+
+        from sglang.srt.distributed import parallel_state
+
+        joined_backends = []
+
+        def join_backend(label: str, backend):
+            if any(backend is joined_backend for joined_backend in joined_backends):
+                return
+            logger.info("Recovered rank joining deferred Mooncake backend %s", label)
+            mooncake_ep.join_group(backend)
+            joined_backends.append(backend)
+
+        join_backend(
+            "default_world",
+            dist.group.WORLD._get_backend(torch.device("cuda")),
+        )
+
+        for group_ref in parallel_state._groups.values():
+            group = group_ref()
+            if group is None or group.world_size <= 1:
+                continue
+
+            join_backend(
+                f"{group.unique_name}:device",
+                group.device_group._get_backend(torch.device("cuda")),
+            )
+            join_backend(
+                f"{group.unique_name}:cpu",
+                group.cpu_group._get_backend(torch.device("cpu")),
+            )
+
+            if (
+                group.use_message_queue_broadcaster
+                and group.world_size > 1
+            ):
+                # Create message queue
+                from sglang.srt.distributed.device_communicators.shm_broadcast import (
+                    MessageQueue,
+                )
+
+                group.mq_broadcaster = (
+                    MessageQueue.create_from_process_group(
+                        group.cpu_group, 1 << 22, 6
+                    )
+                )
+
     def init_routed_experts_capturer(self):
         if not self.server_args.disable_shared_experts_fusion and hasattr(
             self.model, "num_fused_shared_experts"
@@ -809,6 +876,7 @@ def _(data, dim):
             distributed_init_method=dist_init_method,
             timeout=self.server_args.dist_timeout,
             moe_a2a_backend=self.server_args.moe_a2a_backend,
+            extend_group=self.server_args.mooncake_extend_group,
         )
         initialize_model_parallel(
             tensor_model_parallel_size=self.tp_size,
@@ -818,6 +886,7 @@ def _(data, dim):
             attention_context_model_parallel_size=self.attn_cp_size,
             moe_data_model_parallel_size=self.moe_dp_size,
             duplicate_tp_group=self.server_args.enable_pdmux,
+            extend_group=self.server_args.mooncake_extend_group,
         )
         initialize_dp_attention(
             server_args=self.server_args,
@@ -829,7 +898,8 @@ def _(data, dim):
         pre_model_load_memory = get_available_gpu_memory(
             self.device,
             self.gpu_id,
-            distributed=get_world_group().world_size > 1,
+            distributed=get_world_group().world_size > 1
+            and not self.server_args.mooncake_extend_group,
             cpu_group=get_world_group().cpu_group,
         )
         self.tp_group = get_tp_group()
@@ -1137,6 +1207,117 @@ def update_expert_location(
             weight_name_filter=weight_name_filter,
         )

+    def _sync_expert_weights_to_recovered_ranks(
+        self,
+        ranks_to_recover: List[int],
+    ):
+        # First, ensure that expert location metadata is fully synchronized across
+        # all ranks (including newly recovered ranks). This guarantees that
+        # old_physical_to_logical_map and related structures used by the
+        # ExpertLocationUpdater are identical everywhere, avoiding asymmetric
+        # P2P patterns after recovery.
+        metadata = get_global_expert_location_metadata()
+        if metadata is None:
+            logger.warning(
+                "No expert location metadata found, skipping expert weight / metadata sync"
+            )
+            return
+
+        world_group = get_world_group()
+        src_rank = world_group.ranks[0]
+
+        broadcast_global_expert_location_metadata(
+            src_rank=src_rank, group=world_group.device_group
+        )
+
+        # For now, we still rely on Mooncake EP and subsequent EPLB runs to move
+        # the actual expert weights as needed. The important part here is that
+        # metadata (physical_to_logical_map, etc.) is now consistent across ranks.
+
+    def _sync_first_execution_flags(self):
+        """Sync first_execution flags across all ranks after rank recovery."""
+        self.expert_location_updater._first_execution = False
+        for module in self.model.modules():
+            if isinstance(module, MooncakeEPDispatcher):
+                module._get_impl().first_execution = False
+
+    def _handle_rank_recovery(self, ranks_to_recover: List[int]):
+        """Handle rank recovery logic after try_recover_ranks returns successfully."""
+        import torch.distributed as dist
+
+        # To sync with the new ranks' load_weight barrier
+        dist.barrier(group=get_tp_group().cpu_group)
+
+        ElasticEPStateManager.reset(self.server_args)
+
+        # To sync with the new ranks' graph capture barrier
+        dist.all_reduce(
+            torch.tensor([0], dtype=torch.int32, device="cuda"),
+            group=get_tp_group().device_group,
+        )
+        torch.cuda.synchronize()
+
+        self.forward_pass_id = 0
+
+        # Skip EPLB rebalance after rank recovery.
+        # Instead of using EPLB (which relies on old_expert_location_metadata
+        # that is incorrect for recovered ranks), we directly sync expert
+        # weights from healthy ranks to recovered ranks.
+        # See: https://github.com/sgl-project/sglang/pull/15771
+        if self.eplb_manager is not None:
+            self.eplb_manager.reset_generator()
+        # Directly sync expert weights to recovered ranks instead of
+        # running EPLB rebalance which causes asymmetric P2P operations
+        # due to incorrect old metadata for recovered ranks.
+        self._sync_expert_weights_to_recovered_ranks(ranks_to_recover)
+
+        # Sync first_execution flags so recovered ranks and healthy
+        # ranks behave consistently in dispatch/combine.
+        self._sync_first_execution_flags()
+
+        # Force sync elastic_ep_state to prevent forward pass from
+        # triggering EPLB again (is_active_equal_last should return True)
+        elastic_ep_state = ElasticEPStateManager.instance()
+        if elastic_ep_state is not None:
+            elastic_ep_state.snapshot_active_to_last()
+            elastic_ep_state.sync_active_to_cpu()
+
+        # To sync with the new ranks' random seed setup
+        self.random_seed = broadcast_pyobj(
+            [self.server_args.random_seed],
+            self.tp_size * self.pp_rank + self.tp_rank,
+            get_world_group().cpu_group,
+            src=get_world_group().ranks[0],
+        )[0]
+
+    def _get_tp_ranks_to_recover(self) -> List[int]:
+        tp_active_ranks = self.tp_group.active_ranks.detach().cpu().numpy()
+        tp_active_ranks_cpu = self.tp_group.active_ranks_cpu.detach().numpy()
+        tp_active_ranks &= tp_active_ranks_cpu
+        return [i for i in range(len(tp_active_ranks)) if not tp_active_ranks[i]]
+
+    def maybe_recover_tp_ranks(self) -> List[int]:
+        ranks_to_recover = self._get_tp_ranks_to_recover()
+        if not ranks_to_recover:
+            return []
+
+        logger.info("Begin to recover ranks %s", ranks_to_recover)
+        if not try_recover_ranks(ranks_to_recover):
+            return ranks_to_recover
+
+        self._handle_rank_recovery(ranks_to_recover)
+        logger.info("recover ranks %s done", ranks_to_recover)
+
+        return self._get_tp_ranks_to_recover()
+
+    def _finish_deferred_mooncake_recovery(self) -> None:
+        ranks_to_recover = [self.tp_rank]
+        logger.info(
+            "Recovered rank waiting for healthy ranks to finish recovery for %s",
+            ranks_to_recover,
+        )
+        self._handle_rank_recovery(ranks_to_recover)
+
     def update_weights_from_disk(
         self,
         model_path: str,
@@ -2523,6 +2704,8 @@ def forward(
         if dumper.may_enable:
             dumper.step()

+        self.maybe_recover_tp_ranks()
+
         return output

     def _forward_raw(
diff --git a/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py b/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py
index 588da80342ca..48a46f5ab380 100644
--- a/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py
+++ b/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py
@@ -138,7 +138,8 @@ def profile_max_num_token(self: ModelRunner, pre_model_load_memory: int):
     post_model_load_memory = get_available_gpu_memory(
         self.device,
         self.gpu_id,
-        distributed=get_world_group().world_size > 1,
+        distributed=get_world_group().world_size > 1
+        and not self.server_args.mooncake_extend_group,
         cpu_group=get_world_group().cpu_group,
     )

diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index ef7df1cfa6dd..047b691a973d 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -533,6 +533,7 @@ class ServerArgs:
     elastic_ep_backend: Literal[None, "mooncake", "nixl"] = None
     enable_elastic_expert_backup: bool = False
     mooncake_ib_device: Optional[str] = None
+    mooncake_extend_group: bool = False

     # Mamba cache
     max_mamba_cache_size: Optional[int] = None
@@ -4714,6 +4715,12 @@ def add_cli_args(parser: argparse.ArgumentParser):
             "(e.g., --mooncake-ib-device mlx5_0,mlx5_1). "
            "Default is None, which triggers automatic device detection when Mooncake Backend is enabled.",
         )
+        parser.add_argument(
+            "--mooncake-extend-group",
+            action="store_true",
+            default=ServerArgs.mooncake_extend_group,
+            help=argparse.SUPPRESS,
+        )

         # Mamba Cache
         parser.add_argument(
diff --git a/test/registered/ep/test_mooncake_ep_small.py b/test/registered/ep/test_mooncake_ep_small.py
index 16cc4622c494..adcd3c519b53 100644
--- a/test/registered/ep/test_mooncake_ep_small.py
+++ b/test/registered/ep/test_mooncake_ep_small.py
@@ -72,7 +72,7 @@ def test_gsm8k(self):
         args = SimpleNamespace(
             num_shots=5,
             data_path=None,
-            num_questions=200,
+            num_questions=1200,
             max_new_tokens=512,
             parallel=128,
             host="http://127.0.0.1",
@@ -102,14 +102,6 @@ def test_gsm8k_fault_1(self):
         os.system(f"pkill -f {self.pkill_process_1}")
         super().test_gsm8k()

-    @unittest.skipIf(is_in_ci(), "To reduce the CI execution time.")
-    def test_gsm8k_fault_2(self):
-        """
-        Kill another rank and the system should remain operational.
-        """
-        os.system(f"pkill -f {self.pkill_process_2}")
-        super().test_gsm8k()
-

 @unittest.skipIf(is_in_ci(), "To reduce the CI execution time.")
 class TestHybridDPTP(TestPureDP):