From a347674e32f4f920be454a8debec07e83cc579e5 Mon Sep 17 00:00:00 2001 From: UNIDY Date: Wed, 17 Sep 2025 14:14:15 +0800 Subject: [PATCH 1/8] Get active_ranks info from Mooncake Backend Co-authored-by: Hank Han --- .../sglang/srt/distributed/parallel_state.py | 29 +++++++-- .../sglang/srt/model_executor/model_runner.py | 21 ++++++ test/srt/ep/test_mooncake_ep_small.py | 64 ++++++++++++++++++- 3 files changed, 105 insertions(+), 9 deletions(-) diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py index 5a7907163f5c..c63973689353 100644 --- a/python/sglang/srt/distributed/parallel_state.py +++ b/python/sglang/srt/distributed/parallel_state.py @@ -239,14 +239,27 @@ def __init__( self.local_size = get_int_env_var("LOCAL_SIZE", 0) for ranks in group_ranks: - device_group = torch.distributed.new_group( - ranks, backend=torch_distributed_backend - ) - # a cpu_group to allow direct coordination between processes through - # the CPU. The backend is chosen based on `torch_distributed_backend` + active_ranks = torch.ones(len(ranks), dtype=torch.int32, device="cuda") + active_ranks_cpu = torch.ones(len(ranks), dtype=torch.int32) if "mooncake" in torch_distributed_backend: - cpu_group = torch.distributed.new_group(ranks, backend="mooncake-cpu") + from mooncake.ep import MooncakeBackendOptions + + device_group = torch.distributed.new_group( + ranks, + backend="mooncake", + pg_options=MooncakeBackendOptions(active_ranks), + ) + cpu_group = torch.distributed.new_group( + ranks, + backend="mooncake-cpu", + pg_options=MooncakeBackendOptions(active_ranks_cpu), + ) else: + device_group = torch.distributed.new_group( + ranks, backend=torch_distributed_backend + ) + # a group with `gloo` backend, to allow direct coordination + # between processes through the CPU. cpu_group = torch.distributed.new_group( ranks, backend="gloo", timeout=gloo_timeout ) @@ -256,6 +269,8 @@ def __init__( self.rank_in_group = ranks.index(self.rank) self.device_group = device_group self.cpu_group = cpu_group + self.active_ranks = active_ranks + self.active_ranks_cpu = active_ranks_cpu assert self.cpu_group is not None assert self.device_group is not None @@ -1343,7 +1358,7 @@ def init_model_parallel_group( group_ranks=group_ranks, local_rank=local_rank, torch_distributed_backend=backend, - use_pynccl=not (_is_npu or _is_xpu), + use_pynccl=not (_is_npu or _is_xpu or backend == "mooncake"), use_pymscclpp=use_mscclpp_allreduce, use_custom_allreduce=use_custom_allreduce, use_torch_symm_mem_all_reduce=use_torch_symm_mem_allreduce, diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 6752ad5f33a0..8c4c46ddb6b9 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -2229,6 +2229,27 @@ def forward( reinit_attn_backend, split_forward_count, ) + elastic_ep_state = ElasticEPStateManager.instance() + if ( + elastic_ep_state is not None + and not elastic_ep_state.is_active_equal_last() + ): + elastic_ep_state.snapshot_active_to_last() + elastic_ep_state.sync_active_to_cpu() + logging.info("EPLB due to rank faults") + gen = self.eplb_manager.rebalance() + while True: + try: + next(gen) + except StopIteration: + break + output = self._forward_raw( + forward_batch, + skip_attn_backend_init, + pp_proxy_tensors, + reinit_attn_backend, + split_forward_count, + ) output.expert_distribution_metrics = recorder_outputs.get("metrics") # Copy cached routing experts' buffers back to CPU cache diff --git a/test/srt/ep/test_mooncake_ep_small.py b/test/srt/ep/test_mooncake_ep_small.py index 4f6d02a39616..aa812750f5ea 100644 --- a/test/srt/ep/test_mooncake_ep_small.py +++ b/test/srt/ep/test_mooncake_ep_small.py @@ -1,3 +1,4 @@ +import os import unittest from types import SimpleNamespace @@ -70,7 +71,6 @@ def test_gsm8k(self): self.assertGreater(metrics["accuracy"], 0.60) -@unittest.skip("covered in TestMooncakeWithEPLB") class TestPureDP(TestTP): extra_args = [ "--tp", @@ -78,8 +78,29 @@ class TestPureDP(TestTP): "--enable-dp-attention", "--dp", "4", + "--moe-dense-tp-size", + "1", + "--enable-dp-lm-head", + "--disable-custom-all-reduce", + "--enable-eplb", + "--ep-num-redundant-experts", + "72", ] + def test_gsm8k_fault_1(self): + """ + Kill one rank and the system should remain operational. + """ + os.system("pkill -f sglang::scheduler_DP1_TP1_EP1") + super().test_gsm8k() + + def test_gsm8k_fault_2(self): + """ + Kill another rank and the system should remain operational. + """ + os.system("pkill -f sglang::scheduler_DP3_TP3_EP3") + super().test_gsm8k() + class TestHybridDPTP(TestTP): extra_args = [ @@ -88,8 +109,29 @@ class TestHybridDPTP(TestTP): "--enable-dp-attention", "--dp", "2", + "--moe-dense-tp-size", + "1", + "--enable-dp-lm-head", + "--disable-custom-all-reduce", + "--enable-eplb", + "--ep-num-redundant-experts", + "72", ] + def test_gsm8k_fault_1(self): + """ + Kill one rank and the system should remain operational. + """ + os.system("pkill -f sglang::scheduler_DP1_TP2_EP2") + super().test_gsm8k() + + def test_gsm8k_fault_2(self): + """ + Kill another rank and the system should remain operational. + """ + os.system("pkill -f sglang::scheduler_DP1_TP3_EP3") + super().test_gsm8k() + @unittest.skip("covered in TestMooncakeWithEPLB") class TestNoGatherdBuffer(TestTP): @@ -104,7 +146,6 @@ class TestNoGatherdBuffer(TestTP): ] -@unittest.skip("covered in TestMooncakeWithEPLB") class TestTBO(TestTP): extra_args = [ "--tp", @@ -115,8 +156,27 @@ class TestTBO(TestTP): "--moe-dense-tp-size", "1", "--enable-two-batch-overlap", + "--enable-dp-lm-head", + "--disable-custom-all-reduce", + "--enable-eplb", + "--ep-num-redundant-experts", + "72", ] + def test_gsm8k_fault_1(self): + """ + Kill one rank and the system should remain operational. + """ + os.system("pkill -f sglang::scheduler_DP1_TP1_EP1") + super().test_gsm8k() + + def test_gsm8k_fault_2(self): + """ + Kill another rank and the system should remain operational. + """ + os.system("pkill -f sglang::scheduler_DP3_TP3_EP3") + super().test_gsm8k() + class TestMooncakeWithEPLB(TestTP): extra_args = [ From 6fcbf3951a3b36e4cb82941dc8ce1f35a21a466d Mon Sep 17 00:00:00 2001 From: ympcMark Date: Thu, 18 Sep 2025 21:20:51 +0800 Subject: [PATCH 2/8] Achieve fault tolerance at the DP level Co-authored-by: UNIDY2002 --- .../srt/managers/data_parallel_controller.py | 24 +++++++++++++++---- python/sglang/srt/managers/io_struct.py | 5 ++++ python/sglang/srt/managers/scheduler.py | 14 +++++++++++ .../srt/managers/scheduler_dp_attn_mixin.py | 12 ++++++++++ .../sglang/srt/managers/tokenizer_manager.py | 5 ++++ 5 files changed, 56 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/managers/data_parallel_controller.py b/python/sglang/srt/managers/data_parallel_controller.py index eb51352fd380..08b63c9c1e3a 100644 --- a/python/sglang/srt/managers/data_parallel_controller.py +++ b/python/sglang/srt/managers/data_parallel_controller.py @@ -29,6 +29,7 @@ from sglang.srt.environ import envs from sglang.srt.layers.dp_attention import compute_dp_attention_world_info from sglang.srt.managers.io_struct import ( + ActiveRanksOutput, BlockReqInput, TokenizedEmbeddingReqInput, TokenizedGenerateReqInput, @@ -158,6 +159,7 @@ def __init__( # Launch data parallel workers self.scheduler_procs = [] self.workers: List[zmq.Socket] = [None] * server_args.dp_size + self.status: List[int] = [1] * server_args.dp_size if server_args.enable_dp_attention: self.launch_dp_attention_schedulers(server_args, port_args) @@ -179,8 +181,9 @@ def __init__( start_cpu_monitor_thread("data_parallel_controller") def send_to_all_workers(self, obj): - for worker in self.workers: - worker.send_pyobj(obj) + for i, worker in enumerate(self.workers): + if self.status[i] == 1: + worker.send_pyobj(obj) def send_control_message(self, obj): # Send control messages to first worker of tp group @@ -190,6 +193,9 @@ def send_control_message(self, obj): def handle_load_update_req(self, obj): self.dp_budget.update_budget(obj) + def update_active_ranks(self, ranks: ActiveRanksOutput): + self.status = ranks.status + def dispatching_with_trace(self, req: Req): if self.server_args.enable_trace: trace_set_proc_propagate_context(req.rid, req.trace_context) @@ -208,6 +214,7 @@ def init_dispatcher(self): (TokenizedEmbeddingReqInput, self.dispatching_with_trace), (BlockReqInput, self.send_to_all_workers), (WatchLoadUpdateReq, self.handle_load_update_req), + (ActiveRanksOutput, self.update_active_ranks), ] ) self._request_dispatcher.add_fallback_fn(self.send_control_message) @@ -479,8 +486,17 @@ def round_robin_scheduler(self, req: Req): if self.maybe_external_dp_rank_routing(req): return - self.workers[self.round_robin_counter].send_pyobj(req) - self.round_robin_counter = (self.round_robin_counter + 1) % len(self.workers) + while True: + if self.status[self.round_robin_counter] == 1: + logger.info(f"Choose worker {self.round_robin_counter}") + self.workers[self.round_robin_counter].send_pyobj(req) + self.round_robin_counter = (self.round_robin_counter + 1) % len( + self.workers + ) + break + self.round_robin_counter = (self.round_robin_counter + 1) % len( + self.workers + ) def follow_bootstrap_room_scheduler(self, req: Req): if self.maybe_external_dp_rank_routing(req): diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 18cc3d2aa636..72045902569d 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -1434,6 +1434,11 @@ def __post_init__(self): self.rid = "" +@dataclass +class ActiveRanksOutput(BaseReq): + status: List[int] + + @dataclass class GetInternalStateReq(BaseReq): pass diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 1bf294973df1..12463710ad67 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -68,6 +68,7 @@ from sglang.srt.layers.quantization.fp8_utils import initialize_fp8_gemm_config from sglang.srt.managers.io_struct import ( AbortReq, + ActiveRanksOutput, BaseBatchReq, BaseReq, BatchTokenizedEmbeddingReqInput, @@ -2273,6 +2274,19 @@ def run_batch( for req in batch.reqs: req.time_stats.prefill_end_time_host = current_time + if ( + self.server_args.enable_dp_attention + and self.server_args.elastic_ep_backend == "mooncake" + ): + # Get the tensors indicating rank activeness + tp_active_ranks = self.tp_group.active_ranks.detach().cpu().numpy() + tp_active_ranks_cpu = self.tp_group.active_ranks_cpu.detach().numpy() + tp_active_ranks &= tp_active_ranks_cpu + dp_active_ranks = tp_active_ranks.reshape(self.dp_size, -1).prod(axis=1) + self.send_to_tokenizer.send_output( + ActiveRanksOutput(status=dp_active_ranks.tolist()) + ) + return ret def launch_batch_sample_if_needed( diff --git a/python/sglang/srt/managers/scheduler_dp_attn_mixin.py b/python/sglang/srt/managers/scheduler_dp_attn_mixin.py index 9c92cd9c383b..b52553e8a0c0 100644 --- a/python/sglang/srt/managers/scheduler_dp_attn_mixin.py +++ b/python/sglang/srt/managers/scheduler_dp_attn_mixin.py @@ -6,9 +6,11 @@ import torch from sglang.srt.batch_overlap.two_batch_overlap import TboDPAttentionPreparer +from sglang.srt.distributed.parallel_state import get_tp_group from sglang.srt.environ import envs from sglang.srt.managers.schedule_batch import ScheduleBatch from sglang.srt.metrics.collector import DPCooperationInfo +from sglang.srt.model_executor.forward_batch_info import ForwardMode from sglang.srt.utils.common import require_mlp_tp_gather if TYPE_CHECKING: @@ -66,6 +68,15 @@ def all_gather(self, device, group: torch.distributed.ProcessGroup): local_info_tensor, group=group, ) + if device == "cpu": + tp_active_ranks = get_tp_group().active_ranks_cpu + else: + tp_active_ranks = get_tp_group().active_ranks + global_info_tensor.view(-1, 6)[tp_active_ranks == 0, :] = torch.tensor( + [0, 1, 0, 0, 1, ForwardMode.IDLE.value], + device=global_info_tensor.device, + dtype=global_info_tensor.dtype, + ) tp0_info = global_info_tensor[:, 0, :] self.tp0_info = tp0_info @@ -149,6 +160,7 @@ def prepare_mlp_sync_batch_raw( if len(offload_tags) == 0 and disable_overlap_schedule: group = tp_group.device_group device = tp_group.device + torch.distributed.barrier(group=tp_group.cpu_group) else: group = tp_group.cpu_group device = "cpu" diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index a433a0597bf1..d117a4208de9 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -47,6 +47,7 @@ from sglang.srt.managers.disagg_service import start_disagg_service from sglang.srt.managers.io_struct import ( AbortReq, + ActiveRanksOutput, BatchEmbeddingOutput, BatchMultimodalOutput, BatchStrOutput, @@ -465,6 +466,7 @@ def init_request_dispatcher(self): (FreezeGCReq, lambda x: None), # For handling case when scheduler skips detokenizer and forwards back to the tokenizer manager, we ignore it. (HealthCheckOutput, lambda x: None), + (ActiveRanksOutput, self.update_active_ranks), ] ) self.init_communicators(self.server_args) @@ -2104,6 +2106,9 @@ def _handle_abort_req(self, recv_obj: AbortReq): state.out_list.append(out) state.event.set() + def update_active_ranks(self, ranks: ActiveRanksOutput): + self.send_to_scheduler.send_pyobj(ranks) + def _handle_open_session_req_output(self, recv_obj): self.session_futures[recv_obj.session_id].set_result( recv_obj.session_id if recv_obj.success else None From 8cf153afebdbe342218d613b32b3bc8e7801fadb Mon Sep 17 00:00:00 2001 From: Mark Date: Tue, 13 Jan 2026 21:05:04 +0800 Subject: [PATCH 3/8] debgug --- python/sglang/srt/managers/data_parallel_controller.py | 6 +++--- python/sglang/srt/managers/io_struct.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/managers/data_parallel_controller.py b/python/sglang/srt/managers/data_parallel_controller.py index 08b63c9c1e3a..d5a2ec3477bb 100644 --- a/python/sglang/srt/managers/data_parallel_controller.py +++ b/python/sglang/srt/managers/data_parallel_controller.py @@ -159,7 +159,7 @@ def __init__( # Launch data parallel workers self.scheduler_procs = [] self.workers: List[zmq.Socket] = [None] * server_args.dp_size - self.status: List[int] = [1] * server_args.dp_size + self.status: List[bool] = [True] * server_args.dp_size if server_args.enable_dp_attention: self.launch_dp_attention_schedulers(server_args, port_args) @@ -182,7 +182,7 @@ def __init__( def send_to_all_workers(self, obj): for i, worker in enumerate(self.workers): - if self.status[i] == 1: + if self.status[i]: worker.send_pyobj(obj) def send_control_message(self, obj): @@ -487,7 +487,7 @@ def round_robin_scheduler(self, req: Req): return while True: - if self.status[self.round_robin_counter] == 1: + if self.status[self.round_robin_counter]: logger.info(f"Choose worker {self.round_robin_counter}") self.workers[self.round_robin_counter].send_pyobj(req) self.round_robin_counter = (self.round_robin_counter + 1) % len( diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 72045902569d..a92fa6b8a98d 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -1436,7 +1436,7 @@ def __post_init__(self): @dataclass class ActiveRanksOutput(BaseReq): - status: List[int] + status: List[bool] @dataclass From 7a8d656ab407b0d0e9b5d244645fcd33ef67333e Mon Sep 17 00:00:00 2001 From: UNIDY Date: Fri, 16 Jan 2026 14:52:30 +0800 Subject: [PATCH 4/8] Improve readability --- .../srt/managers/scheduler_dp_attn_mixin.py | 23 +++++++++++++++---- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/managers/scheduler_dp_attn_mixin.py b/python/sglang/srt/managers/scheduler_dp_attn_mixin.py index b52553e8a0c0..46fbb505b7c5 100644 --- a/python/sglang/srt/managers/scheduler_dp_attn_mixin.py +++ b/python/sglang/srt/managers/scheduler_dp_attn_mixin.py @@ -55,6 +55,20 @@ def _get_local_tensor(self, device, dtype=torch.int64) -> torch.Tensor: dtype=dtype, ) + def _get_fallback_tensor(self, device, dtype=torch.int64) -> torch.Tensor: + return torch.tensor( + [ + 0, # num_tokens + 0, # num_tokens_for_logprob + 1, # can_cuda_graph + 0, # is_extend_in_batch + 1, # local_can_run_tbo + ForwardMode.IDLE.value, # local_forward_mode + ], + device=device, + dtype=dtype, + ) + def all_gather(self, device, group: torch.distributed.ProcessGroup): local_info_tensor = self._get_local_tensor(device=device) global_info_tensor = torch.empty( @@ -72,11 +86,10 @@ def all_gather(self, device, group: torch.distributed.ProcessGroup): tp_active_ranks = get_tp_group().active_ranks_cpu else: tp_active_ranks = get_tp_group().active_ranks - global_info_tensor.view(-1, 6)[tp_active_ranks == 0, :] = torch.tensor( - [0, 1, 0, 0, 1, ForwardMode.IDLE.value], - device=global_info_tensor.device, - dtype=global_info_tensor.dtype, - ) + + # Set fallback values for inactive ranks + tp_info = global_info_tensor.view(self.dp_size * self.tp_size, 6) + tp_info[tp_active_ranks == 0] = self._get_fallback_tensor(device=device) tp0_info = global_info_tensor[:, 0, :] self.tp0_info = tp0_info From b2467260ad75246d77be1d6085f793a0d5eba0f1 Mon Sep 17 00:00:00 2001 From: UNIDY Date: Mon, 19 Jan 2026 19:16:53 +0800 Subject: [PATCH 5/8] Remove duplicated code in test_mooncake_ep_small.py and simplify test cases --- test/srt/ep/test_mooncake_ep_small.py | 117 ++++---------------------- 1 file changed, 15 insertions(+), 102 deletions(-) diff --git a/test/srt/ep/test_mooncake_ep_small.py b/test/srt/ep/test_mooncake_ep_small.py index aa812750f5ea..4a4168f414d4 100644 --- a/test/srt/ep/test_mooncake_ep_small.py +++ b/test/srt/ep/test_mooncake_ep_small.py @@ -39,6 +39,14 @@ def setUpClass(cls): "mooncake", "--deepep-mode", "low_latency", + "--moe-dense-tp-size", + "1", + "--enable-dp-lm-head", + "--enable-two-batch-overlap", + "--disable-custom-all-reduce", + "--enable-eplb", + "--ep-num-redundant-experts", + "72", "--chunked-prefill-size", "512", "--cuda-graph-max-bs", @@ -73,133 +81,38 @@ def test_gsm8k(self): class TestPureDP(TestTP): extra_args = [ - "--tp", - "4", "--enable-dp-attention", "--dp", "4", - "--moe-dense-tp-size", - "1", - "--enable-dp-lm-head", - "--disable-custom-all-reduce", - "--enable-eplb", - "--ep-num-redundant-experts", - "72", ] - def test_gsm8k_fault_1(self): - """ - Kill one rank and the system should remain operational. - """ - os.system("pkill -f sglang::scheduler_DP1_TP1_EP1") - super().test_gsm8k() - - def test_gsm8k_fault_2(self): - """ - Kill another rank and the system should remain operational. - """ - os.system("pkill -f sglang::scheduler_DP3_TP3_EP3") - super().test_gsm8k() - - -class TestHybridDPTP(TestTP): - extra_args = [ - "--tp", - "4", - "--enable-dp-attention", - "--dp", - "2", - "--moe-dense-tp-size", - "1", - "--enable-dp-lm-head", - "--disable-custom-all-reduce", - "--enable-eplb", - "--ep-num-redundant-experts", - "72", - ] + pkill_process_1 = "sglang::scheduler_DP1_TP1_EP1" + pkill_process_2 = "sglang::scheduler_DP3_TP3_EP3" def test_gsm8k_fault_1(self): """ Kill one rank and the system should remain operational. """ - os.system("pkill -f sglang::scheduler_DP1_TP2_EP2") + os.system(f"pkill -f {self.pkill_process_1}") super().test_gsm8k() def test_gsm8k_fault_2(self): """ Kill another rank and the system should remain operational. """ - os.system("pkill -f sglang::scheduler_DP1_TP3_EP3") + os.system(f"pkill -f {self.pkill_process_2}") super().test_gsm8k() -@unittest.skip("covered in TestMooncakeWithEPLB") -class TestNoGatherdBuffer(TestTP): - extra_args = [ - "--tp", - "4", - "--enable-dp-attention", - "--dp", - "4", - "--moe-dense-tp-size", - "1", - ] - - -class TestTBO(TestTP): +class TestHybridDPTP(TestTP): extra_args = [ - "--tp", - "4", "--enable-dp-attention", "--dp", "4", - "--moe-dense-tp-size", - "1", - "--enable-two-batch-overlap", - "--enable-dp-lm-head", - "--disable-custom-all-reduce", - "--enable-eplb", - "--ep-num-redundant-experts", - "72", ] - def test_gsm8k_fault_1(self): - """ - Kill one rank and the system should remain operational. - """ - os.system("pkill -f sglang::scheduler_DP1_TP1_EP1") - super().test_gsm8k() - - def test_gsm8k_fault_2(self): - """ - Kill another rank and the system should remain operational. - """ - os.system("pkill -f sglang::scheduler_DP3_TP3_EP3") - super().test_gsm8k() - - -class TestMooncakeWithEPLB(TestTP): - extra_args = [ - "--tp", - "4", - "--enable-dp-attention", - "--dp", - "4", - "--moe-dense-tp-size", - "1", - "--enable-two-batch-overlap", - "--enable-eplb", - "--ep-num-redundant-experts", - "4", - "--eplb-rebalance-num-iterations", - "50", - "--expert-distribution-recorder-buffer-size", - "50", - "--expert-distribution-recorder-mode", - "stat", - "--ep-dispatch-algorithm", - "static", - ] + pkill_process_1 = "sglang::scheduler_DP1_TP2_EP2" + pkill_process_2 = "sglang::scheduler_DP1_TP3_EP3" if __name__ == "__main__": From 8a76cf4447d5f2712a1bbfedfc62d2a685b25cce Mon Sep 17 00:00:00 2001 From: UNIDY Date: Mon, 19 Jan 2026 19:36:16 +0800 Subject: [PATCH 6/8] Fix log level --- python/sglang/srt/managers/data_parallel_controller.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/data_parallel_controller.py b/python/sglang/srt/managers/data_parallel_controller.py index d5a2ec3477bb..eea20137aaee 100644 --- a/python/sglang/srt/managers/data_parallel_controller.py +++ b/python/sglang/srt/managers/data_parallel_controller.py @@ -488,7 +488,7 @@ def round_robin_scheduler(self, req: Req): while True: if self.status[self.round_robin_counter]: - logger.info(f"Choose worker {self.round_robin_counter}") + logger.debug(f"Choose worker {self.round_robin_counter}") self.workers[self.round_robin_counter].send_pyobj(req) self.round_robin_counter = (self.round_robin_counter + 1) % len( self.workers From 90317826137514ce331718e981c04066d70045cf Mon Sep 17 00:00:00 2001 From: UNIDY Date: Mon, 19 Jan 2026 19:47:19 +0800 Subject: [PATCH 7/8] Minor fix --- python/sglang/srt/managers/scheduler_dp_attn_mixin.py | 1 - test/srt/ep/test_mooncake_ep_small.py | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/managers/scheduler_dp_attn_mixin.py b/python/sglang/srt/managers/scheduler_dp_attn_mixin.py index 46fbb505b7c5..65ac39da7983 100644 --- a/python/sglang/srt/managers/scheduler_dp_attn_mixin.py +++ b/python/sglang/srt/managers/scheduler_dp_attn_mixin.py @@ -173,7 +173,6 @@ def prepare_mlp_sync_batch_raw( if len(offload_tags) == 0 and disable_overlap_schedule: group = tp_group.device_group device = tp_group.device - torch.distributed.barrier(group=tp_group.cpu_group) else: group = tp_group.cpu_group device = "cpu" diff --git a/test/srt/ep/test_mooncake_ep_small.py b/test/srt/ep/test_mooncake_ep_small.py index 4a4168f414d4..6271dfd8e0f8 100644 --- a/test/srt/ep/test_mooncake_ep_small.py +++ b/test/srt/ep/test_mooncake_ep_small.py @@ -104,11 +104,11 @@ def test_gsm8k_fault_2(self): super().test_gsm8k() -class TestHybridDPTP(TestTP): +class TestHybridDPTP(TestPureDP): extra_args = [ "--enable-dp-attention", "--dp", - "4", + "2", ] pkill_process_1 = "sglang::scheduler_DP1_TP2_EP2" From 81e1482a0fd5bf65cc96250e2cd1ee4f3be0ebcd Mon Sep 17 00:00:00 2001 From: UNIDY Date: Tue, 20 Jan 2026 16:14:45 +0800 Subject: [PATCH 8/8] Skip 4 out of 7 test cases in test_mooncake_ep_small.py --- test/srt/ep/test_mooncake_ep_small.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/srt/ep/test_mooncake_ep_small.py b/test/srt/ep/test_mooncake_ep_small.py index 6271dfd8e0f8..ff06e4167469 100644 --- a/test/srt/ep/test_mooncake_ep_small.py +++ b/test/srt/ep/test_mooncake_ep_small.py @@ -10,6 +10,7 @@ DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, CustomTestCase, + is_in_ci, popen_launch_server, ) @@ -96,6 +97,7 @@ def test_gsm8k_fault_1(self): os.system(f"pkill -f {self.pkill_process_1}") super().test_gsm8k() + @unittest.skipIf(is_in_ci(), "To reduce the CI execution time.") def test_gsm8k_fault_2(self): """ Kill another rank and the system should remain operational. @@ -104,6 +106,7 @@ def test_gsm8k_fault_2(self): super().test_gsm8k() +@unittest.skipIf(is_in_ci(), "To reduce the CI execution time.") class TestHybridDPTP(TestPureDP): extra_args = [ "--enable-dp-attention",