diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py index 5a7907163f5c..c63973689353 100644 --- a/python/sglang/srt/distributed/parallel_state.py +++ b/python/sglang/srt/distributed/parallel_state.py @@ -239,14 +239,27 @@ def __init__( self.local_size = get_int_env_var("LOCAL_SIZE", 0) for ranks in group_ranks: - device_group = torch.distributed.new_group( - ranks, backend=torch_distributed_backend - ) - # a cpu_group to allow direct coordination between processes through - # the CPU. The backend is chosen based on `torch_distributed_backend` + active_ranks = torch.ones(len(ranks), dtype=torch.int32, device="cuda") + active_ranks_cpu = torch.ones(len(ranks), dtype=torch.int32) if "mooncake" in torch_distributed_backend: - cpu_group = torch.distributed.new_group(ranks, backend="mooncake-cpu") + from mooncake.ep import MooncakeBackendOptions + + device_group = torch.distributed.new_group( + ranks, + backend="mooncake", + pg_options=MooncakeBackendOptions(active_ranks), + ) + cpu_group = torch.distributed.new_group( + ranks, + backend="mooncake-cpu", + pg_options=MooncakeBackendOptions(active_ranks_cpu), + ) else: + device_group = torch.distributed.new_group( + ranks, backend=torch_distributed_backend + ) + # a group with `gloo` backend, to allow direct coordination + # between processes through the CPU. 
cpu_group = torch.distributed.new_group( ranks, backend="gloo", timeout=gloo_timeout ) @@ -256,6 +269,8 @@ def __init__( self.rank_in_group = ranks.index(self.rank) self.device_group = device_group self.cpu_group = cpu_group + self.active_ranks = active_ranks + self.active_ranks_cpu = active_ranks_cpu assert self.cpu_group is not None assert self.device_group is not None @@ -1343,7 +1358,7 @@ def init_model_parallel_group( group_ranks=group_ranks, local_rank=local_rank, torch_distributed_backend=backend, - use_pynccl=not (_is_npu or _is_xpu), + use_pynccl=not (_is_npu or _is_xpu or backend == "mooncake"), use_pymscclpp=use_mscclpp_allreduce, use_custom_allreduce=use_custom_allreduce, use_torch_symm_mem_all_reduce=use_torch_symm_mem_allreduce, diff --git a/python/sglang/srt/managers/data_parallel_controller.py b/python/sglang/srt/managers/data_parallel_controller.py index eb51352fd380..eea20137aaee 100644 --- a/python/sglang/srt/managers/data_parallel_controller.py +++ b/python/sglang/srt/managers/data_parallel_controller.py @@ -29,6 +29,7 @@ from sglang.srt.environ import envs from sglang.srt.layers.dp_attention import compute_dp_attention_world_info from sglang.srt.managers.io_struct import ( + ActiveRanksOutput, BlockReqInput, TokenizedEmbeddingReqInput, TokenizedGenerateReqInput, @@ -158,6 +159,7 @@ def __init__( # Launch data parallel workers self.scheduler_procs = [] self.workers: List[zmq.Socket] = [None] * server_args.dp_size + self.status: List[bool] = [True] * server_args.dp_size if server_args.enable_dp_attention: self.launch_dp_attention_schedulers(server_args, port_args) @@ -179,8 +181,9 @@ def __init__( start_cpu_monitor_thread("data_parallel_controller") def send_to_all_workers(self, obj): - for worker in self.workers: - worker.send_pyobj(obj) + for i, worker in enumerate(self.workers): + if self.status[i]: + worker.send_pyobj(obj) def send_control_message(self, obj): # Send control messages to first worker of tp group @@ -190,6 +193,9 @@ def 
send_control_message(self, obj): def handle_load_update_req(self, obj): self.dp_budget.update_budget(obj) + def update_active_ranks(self, ranks: ActiveRanksOutput): + self.status = ranks.status + def dispatching_with_trace(self, req: Req): if self.server_args.enable_trace: trace_set_proc_propagate_context(req.rid, req.trace_context) @@ -208,6 +214,7 @@ def init_dispatcher(self): (TokenizedEmbeddingReqInput, self.dispatching_with_trace), (BlockReqInput, self.send_to_all_workers), (WatchLoadUpdateReq, self.handle_load_update_req), + (ActiveRanksOutput, self.update_active_ranks), ] ) self._request_dispatcher.add_fallback_fn(self.send_control_message) @@ -479,8 +486,16 @@ def round_robin_scheduler(self, req: Req): if self.maybe_external_dp_rank_routing(req): return - self.workers[self.round_robin_counter].send_pyobj(req) - self.round_robin_counter = (self.round_robin_counter + 1) % len(self.workers) + num_workers = len(self.workers) + # Probe each rank at most once so an all-inactive status cannot spin forever. + for _ in range(num_workers): + rank = self.round_robin_counter + self.round_robin_counter = (rank + 1) % num_workers + if self.status[rank]: + logger.debug(f"Choose worker {rank}") + self.workers[rank].send_pyobj(req) + return + raise RuntimeError("No active DP worker is available to dispatch the request") def follow_bootstrap_room_scheduler(self, req: Req): if self.maybe_external_dp_rank_routing(req): diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 3df09555949f..04e3b6c5b521 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -1434,6 +1434,11 @@ def __post_init__(self): self.rid = "" +@dataclass +class ActiveRanksOutput(BaseReq): + status: List[bool] + + @dataclass class GetInternalStateReq(BaseReq): pass diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index dd585fc8bce4..69165360f444 100644 --- a/python/sglang/srt/managers/scheduler.py +++ 
b/python/sglang/srt/managers/scheduler.py @@ -69,6 +69,7 @@ from sglang.srt.layers.quantization.fp8_utils import initialize_fp8_gemm_config from sglang.srt.managers.io_struct import ( AbortReq, + ActiveRanksOutput, BaseBatchReq, BaseReq, BatchTokenizedEmbeddingReqInput, @@ -2350,6 +2351,19 @@ def run_batch( for req in batch.reqs: req.time_stats.prefill_end_time_host = current_time + if ( + self.server_args.enable_dp_attention + and self.server_args.elastic_ep_backend == "mooncake" + ): + # Get the tensors indicating rank activeness + tp_active_ranks = self.tp_group.active_ranks.detach().cpu().numpy() + tp_active_ranks_cpu = self.tp_group.active_ranks_cpu.detach().numpy() + tp_active_ranks &= tp_active_ranks_cpu + dp_active_ranks = tp_active_ranks.reshape(self.dp_size, -1).prod(axis=1) + self.send_to_tokenizer.send_output( + ActiveRanksOutput(status=dp_active_ranks.tolist()) + ) + return ret def launch_batch_sample_if_needed( diff --git a/python/sglang/srt/managers/scheduler_dp_attn_mixin.py b/python/sglang/srt/managers/scheduler_dp_attn_mixin.py index 9c92cd9c383b..65ac39da7983 100644 --- a/python/sglang/srt/managers/scheduler_dp_attn_mixin.py +++ b/python/sglang/srt/managers/scheduler_dp_attn_mixin.py @@ -6,9 +6,11 @@ import torch from sglang.srt.batch_overlap.two_batch_overlap import TboDPAttentionPreparer +from sglang.srt.distributed.parallel_state import get_tp_group from sglang.srt.environ import envs from sglang.srt.managers.schedule_batch import ScheduleBatch from sglang.srt.metrics.collector import DPCooperationInfo +from sglang.srt.model_executor.forward_batch_info import ForwardMode from sglang.srt.utils.common import require_mlp_tp_gather if TYPE_CHECKING: @@ -53,6 +55,20 @@ def _get_local_tensor(self, device, dtype=torch.int64) -> torch.Tensor: dtype=dtype, ) + def _get_fallback_tensor(self, device, dtype=torch.int64) -> torch.Tensor: + return torch.tensor( + [ + 0, # num_tokens + 0, # num_tokens_for_logprob + 1, # can_cuda_graph + 0, # 
is_extend_in_batch + 1, # local_can_run_tbo + ForwardMode.IDLE.value, # local_forward_mode + ], + device=device, + dtype=dtype, + ) + def all_gather(self, device, group: torch.distributed.ProcessGroup): local_info_tensor = self._get_local_tensor(device=device) global_info_tensor = torch.empty( @@ -66,6 +82,14 @@ def all_gather(self, device, group: torch.distributed.ProcessGroup): local_info_tensor, group=group, ) + if device == "cpu": + tp_active_ranks = get_tp_group().active_ranks_cpu + else: + tp_active_ranks = get_tp_group().active_ranks + + # Set fallback values for inactive ranks + tp_info = global_info_tensor.view(self.dp_size * self.tp_size, 6) + tp_info[tp_active_ranks == 0] = self._get_fallback_tensor(device=device) tp0_info = global_info_tensor[:, 0, :] self.tp0_info = tp0_info diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index a3c5001e87a1..1f532d510c01 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -47,6 +47,7 @@ from sglang.srt.managers.disagg_service import start_disagg_service from sglang.srt.managers.io_struct import ( AbortReq, + ActiveRanksOutput, BatchEmbeddingOutput, BatchMultimodalOutput, BatchStrOutput, @@ -473,6 +474,7 @@ def init_request_dispatcher(self): (FreezeGCReq, lambda x: None), # For handling case when scheduler skips detokenizer and forwards back to the tokenizer manager, we ignore it. 
(HealthCheckOutput, lambda x: None), + (ActiveRanksOutput, self.update_active_ranks), ] ) self.init_communicators(self.server_args) @@ -2153,6 +2155,9 @@ def _handle_abort_req(self, recv_obj: AbortReq): state.out_list.append(out) state.event.set() + def update_active_ranks(self, ranks: ActiveRanksOutput): + self.send_to_scheduler.send_pyobj(ranks) + def _handle_open_session_req_output(self, recv_obj): self.session_futures[recv_obj.session_id].set_result( recv_obj.session_id if recv_obj.success else None diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 7731c1e4b92d..d15a2f61125d 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -2238,6 +2238,24 @@ def forward( reinit_attn_backend, split_forward_count, ) + elastic_ep_state = ElasticEPStateManager.instance() + if ( + elastic_ep_state is not None + and not elastic_ep_state.is_active_equal_last() + ): + elastic_ep_state.snapshot_active_to_last() + elastic_ep_state.sync_active_to_cpu() + logging.info("EPLB due to rank faults") + # Drain the rebalance iterator to completion so every step runs. + for _ in self.eplb_manager.rebalance(): + pass + output = self._forward_raw( + forward_batch, + skip_attn_backend_init, + pp_proxy_tensors, + reinit_attn_backend, + split_forward_count, + ) output.expert_distribution_metrics = recorder_outputs.get("metrics") # Copy cached routing experts' buffers back to CPU cache diff --git a/test/srt/ep/test_mooncake_ep_small.py b/test/srt/ep/test_mooncake_ep_small.py index 4f6d02a39616..ff06e4167469 100644 --- a/test/srt/ep/test_mooncake_ep_small.py +++ b/test/srt/ep/test_mooncake_ep_small.py @@ -1,3 +1,4 @@ +import os import unittest from types import SimpleNamespace @@ -9,6 +10,7 @@ DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, CustomTestCase, + is_in_ci, popen_launch_server, ) @@ -38,6 +40,14 @@ def setUpClass(cls): "mooncake", "--deepep-mode", 
"low_latency", + "--moe-dense-tp-size", + "1", + "--enable-dp-lm-head", + "--enable-two-batch-overlap", + "--disable-custom-all-reduce", + "--enable-eplb", + "--ep-num-redundant-experts", + "72", "--chunked-prefill-size", "512", "--cuda-graph-max-bs", @@ -70,76 +80,42 @@ def test_gsm8k(self): self.assertGreater(metrics["accuracy"], 0.60) -@unittest.skip("covered in TestMooncakeWithEPLB") class TestPureDP(TestTP): extra_args = [ - "--tp", - "4", "--enable-dp-attention", "--dp", "4", ] + pkill_process_1 = "sglang::scheduler_DP1_TP1_EP1" + pkill_process_2 = "sglang::scheduler_DP3_TP3_EP3" -class TestHybridDPTP(TestTP): - extra_args = [ - "--tp", - "4", - "--enable-dp-attention", - "--dp", - "2", - ] + def test_gsm8k_fault_1(self): + """ + Kill one rank and the system should remain operational. + """ + os.system(f"pkill -f {self.pkill_process_1}") + super().test_gsm8k() + @unittest.skipIf(is_in_ci(), "To reduce the CI execution time.") + def test_gsm8k_fault_2(self): + """ + Kill another rank and the system should remain operational. 
+ """ + os.system(f"pkill -f {self.pkill_process_2}") + super().test_gsm8k() -@unittest.skip("covered in TestMooncakeWithEPLB") -class TestNoGatherdBuffer(TestTP): - extra_args = [ - "--tp", - "4", - "--enable-dp-attention", - "--dp", - "4", - "--moe-dense-tp-size", - "1", - ] - -@unittest.skip("covered in TestMooncakeWithEPLB") -class TestTBO(TestTP): +@unittest.skipIf(is_in_ci(), "To reduce the CI execution time.") +class TestHybridDPTP(TestPureDP): extra_args = [ - "--tp", - "4", "--enable-dp-attention", "--dp", - "4", - "--moe-dense-tp-size", - "1", - "--enable-two-batch-overlap", + "2", ] - -class TestMooncakeWithEPLB(TestTP): - extra_args = [ - "--tp", - "4", - "--enable-dp-attention", - "--dp", - "4", - "--moe-dense-tp-size", - "1", - "--enable-two-batch-overlap", - "--enable-eplb", - "--ep-num-redundant-experts", - "4", - "--eplb-rebalance-num-iterations", - "50", - "--expert-distribution-recorder-buffer-size", - "50", - "--expert-distribution-recorder-mode", - "stat", - "--ep-dispatch-algorithm", - "static", - ] + pkill_process_1 = "sglang::scheduler_DP1_TP2_EP2" + pkill_process_2 = "sglang::scheduler_DP1_TP3_EP3" if __name__ == "__main__":