Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 22 additions & 7 deletions python/sglang/srt/distributed/parallel_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,14 +239,27 @@ def __init__(
self.local_size = get_int_env_var("LOCAL_SIZE", 0)

for ranks in group_ranks:
device_group = torch.distributed.new_group(
ranks, backend=torch_distributed_backend
)
# a cpu_group to allow direct coordination between processes through
# the CPU. The backend is chosen based on `torch_distributed_backend`
active_ranks = torch.ones(len(ranks), dtype=torch.int32, device="cuda")
active_ranks_cpu = torch.ones(len(ranks), dtype=torch.int32)
if "mooncake" in torch_distributed_backend:
cpu_group = torch.distributed.new_group(ranks, backend="mooncake-cpu")
from mooncake.ep import MooncakeBackendOptions

device_group = torch.distributed.new_group(
ranks,
backend="mooncake",
pg_options=MooncakeBackendOptions(active_ranks),
)
cpu_group = torch.distributed.new_group(
ranks,
backend="mooncake-cpu",
pg_options=MooncakeBackendOptions(active_ranks_cpu),
)
else:
device_group = torch.distributed.new_group(
ranks, backend=torch_distributed_backend
)
# a group with `gloo` backend, to allow direct coordination
# between processes through the CPU.
cpu_group = torch.distributed.new_group(
ranks, backend="gloo", timeout=gloo_timeout
)
Expand All @@ -256,6 +269,8 @@ def __init__(
self.rank_in_group = ranks.index(self.rank)
self.device_group = device_group
self.cpu_group = cpu_group
self.active_ranks = active_ranks
self.active_ranks_cpu = active_ranks_cpu

assert self.cpu_group is not None
assert self.device_group is not None
Expand Down Expand Up @@ -1343,7 +1358,7 @@ def init_model_parallel_group(
group_ranks=group_ranks,
local_rank=local_rank,
torch_distributed_backend=backend,
use_pynccl=not (_is_npu or _is_xpu),
use_pynccl=not (_is_npu or _is_xpu or backend == "mooncake"),
use_pymscclpp=use_mscclpp_allreduce,
use_custom_allreduce=use_custom_allreduce,
use_torch_symm_mem_all_reduce=use_torch_symm_mem_allreduce,
Expand Down
24 changes: 20 additions & 4 deletions python/sglang/srt/managers/data_parallel_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from sglang.srt.environ import envs
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
from sglang.srt.managers.io_struct import (
ActiveRanksOutput,
BlockReqInput,
TokenizedEmbeddingReqInput,
TokenizedGenerateReqInput,
Expand Down Expand Up @@ -158,6 +159,7 @@ def __init__(
# Launch data parallel workers
self.scheduler_procs = []
self.workers: List[zmq.Socket] = [None] * server_args.dp_size
self.status: List[bool] = [True] * server_args.dp_size

if server_args.enable_dp_attention:
self.launch_dp_attention_schedulers(server_args, port_args)
Expand All @@ -179,8 +181,9 @@ def __init__(
start_cpu_monitor_thread("data_parallel_controller")

def send_to_all_workers(self, obj):
    """Broadcast ``obj`` to every data-parallel worker that is still active.

    Workers whose entry in ``self.status`` is False (ranks reported
    inactive via ActiveRanksOutput) are skipped so we never write to a
    dead worker's socket.
    """
    # Note: the diff rendering duplicated the old unconditional loop above
    # the filtered one; only the status-filtered loop belongs here.
    for i, worker in enumerate(self.workers):
        if self.status[i]:
            worker.send_pyobj(obj)

def send_control_message(self, obj):
# Send control messages to first worker of tp group
Expand All @@ -190,6 +193,9 @@ def send_control_message(self, obj):
def handle_load_update_req(self, obj):
    # Registered for WatchLoadUpdateReq in init_dispatcher: feed the latest
    # per-worker load report into the DP budget tracker.
    self.dp_budget.update_budget(obj)

def update_active_ranks(self, ranks: ActiveRanksOutput):
    # Replace the per-worker liveness flags wholesale. Inactive workers are
    # subsequently skipped by send_to_all_workers and round_robin_scheduler.
    self.status = ranks.status

def dispatching_with_trace(self, req: Req):
if self.server_args.enable_trace:
trace_set_proc_propagate_context(req.rid, req.trace_context)
Expand All @@ -208,6 +214,7 @@ def init_dispatcher(self):
(TokenizedEmbeddingReqInput, self.dispatching_with_trace),
(BlockReqInput, self.send_to_all_workers),
(WatchLoadUpdateReq, self.handle_load_update_req),
(ActiveRanksOutput, self.update_active_ranks),
]
)
self._request_dispatcher.add_fallback_fn(self.send_control_message)
Expand Down Expand Up @@ -479,8 +486,17 @@ def round_robin_scheduler(self, req: Req):
if self.maybe_external_dp_rank_routing(req):
return

self.workers[self.round_robin_counter].send_pyobj(req)
self.round_robin_counter = (self.round_robin_counter + 1) % len(self.workers)
while True:
if self.status[self.round_robin_counter]:
logger.debug(f"Choose worker {self.round_robin_counter}")
self.workers[self.round_robin_counter].send_pyobj(req)
self.round_robin_counter = (self.round_robin_counter + 1) % len(
self.workers
)
break
self.round_robin_counter = (self.round_robin_counter + 1) % len(
self.workers
)

def follow_bootstrap_room_scheduler(self, req: Req):
if self.maybe_external_dp_rank_routing(req):
Expand Down
5 changes: 5 additions & 0 deletions python/sglang/srt/managers/io_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -1434,6 +1434,11 @@ def __post_init__(self):
self.rid = ""


@dataclass
class ActiveRanksOutput(BaseReq):
    # Per-DP-rank liveness flags: status[i] is True while DP rank i is
    # considered active (derived from the group's active_ranks tensors).
    status: List[bool]


@dataclass
class GetInternalStateReq(BaseReq):
pass
Expand Down
14 changes: 14 additions & 0 deletions python/sglang/srt/managers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
from sglang.srt.layers.quantization.fp8_utils import initialize_fp8_gemm_config
from sglang.srt.managers.io_struct import (
AbortReq,
ActiveRanksOutput,
BaseBatchReq,
BaseReq,
BatchTokenizedEmbeddingReqInput,
Expand Down Expand Up @@ -2350,6 +2351,19 @@ def run_batch(
for req in batch.reqs:
req.time_stats.prefill_end_time_host = current_time

if (
self.server_args.enable_dp_attention
and self.server_args.elastic_ep_backend == "mooncake"
):
# Get the tensors indicating rank activeness
tp_active_ranks = self.tp_group.active_ranks.detach().cpu().numpy()
tp_active_ranks_cpu = self.tp_group.active_ranks_cpu.detach().numpy()
tp_active_ranks &= tp_active_ranks_cpu
dp_active_ranks = tp_active_ranks.reshape(self.dp_size, -1).prod(axis=1)
self.send_to_tokenizer.send_output(
ActiveRanksOutput(status=dp_active_ranks.tolist())
)

return ret

def launch_batch_sample_if_needed(
Expand Down
24 changes: 24 additions & 0 deletions python/sglang/srt/managers/scheduler_dp_attn_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@
import torch

from sglang.srt.batch_overlap.two_batch_overlap import TboDPAttentionPreparer
from sglang.srt.distributed.parallel_state import get_tp_group
from sglang.srt.environ import envs
from sglang.srt.managers.schedule_batch import ScheduleBatch
from sglang.srt.metrics.collector import DPCooperationInfo
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.utils.common import require_mlp_tp_gather

if TYPE_CHECKING:
Expand Down Expand Up @@ -53,6 +55,20 @@ def _get_local_tensor(self, device, dtype=torch.int64) -> torch.Tensor:
dtype=dtype,
)

def _get_fallback_tensor(self, device, dtype=torch.int64) -> torch.Tensor:
    """Return the placeholder per-rank info entry used for inactive ranks.

    The entry describes an idle rank: zero tokens, cuda-graph capable,
    not extending, TBO-capable, forward mode IDLE.
    """
    fallback_values = [
        0,  # num_tokens
        0,  # num_tokens_for_logprob
        1,  # can_cuda_graph
        0,  # is_extend_in_batch
        1,  # local_can_run_tbo
        ForwardMode.IDLE.value,  # local_forward_mode
    ]
    return torch.tensor(fallback_values, dtype=dtype, device=device)

def all_gather(self, device, group: torch.distributed.ProcessGroup):
local_info_tensor = self._get_local_tensor(device=device)
global_info_tensor = torch.empty(
Expand All @@ -66,6 +82,14 @@ def all_gather(self, device, group: torch.distributed.ProcessGroup):
local_info_tensor,
group=group,
)
if device == "cpu":
tp_active_ranks = get_tp_group().active_ranks_cpu
else:
tp_active_ranks = get_tp_group().active_ranks

# Set fallback values for inactive ranks
tp_info = global_info_tensor.view(self.dp_size * self.tp_size, 6)
tp_info[tp_active_ranks == 0] = self._get_fallback_tensor(device=device)

tp0_info = global_info_tensor[:, 0, :]
self.tp0_info = tp0_info
Expand Down
5 changes: 5 additions & 0 deletions python/sglang/srt/managers/tokenizer_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
from sglang.srt.managers.disagg_service import start_disagg_service
from sglang.srt.managers.io_struct import (
AbortReq,
ActiveRanksOutput,
BatchEmbeddingOutput,
BatchMultimodalOutput,
BatchStrOutput,
Expand Down Expand Up @@ -473,6 +474,7 @@ def init_request_dispatcher(self):
(FreezeGCReq, lambda x: None),
# For handling case when scheduler skips detokenizer and forwards back to the tokenizer manager, we ignore it.
(HealthCheckOutput, lambda x: None),
(ActiveRanksOutput, self.update_active_ranks),
]
)
self.init_communicators(self.server_args)
Expand Down Expand Up @@ -2153,6 +2155,9 @@ def _handle_abort_req(self, recv_obj: AbortReq):
state.out_list.append(out)
state.event.set()

def update_active_ranks(self, ranks: ActiveRanksOutput):
    # Relay the liveness update onward via send_to_scheduler; the tokenizer
    # manager keeps no copy of the status itself. NOTE(review): presumably
    # this routes through the data-parallel controller when DP is enabled —
    # confirm against the process topology.
    self.send_to_scheduler.send_pyobj(ranks)

def _handle_open_session_req_output(self, recv_obj):
self.session_futures[recv_obj.session_id].set_result(
recv_obj.session_id if recv_obj.success else None
Expand Down
21 changes: 21 additions & 0 deletions python/sglang/srt/model_executor/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2238,6 +2238,27 @@ def forward(
reinit_attn_backend,
split_forward_count,
)
elastic_ep_state = ElasticEPStateManager.instance()
if (
elastic_ep_state is not None
and not elastic_ep_state.is_active_equal_last()
):
elastic_ep_state.snapshot_active_to_last()
elastic_ep_state.sync_active_to_cpu()
logging.info("EPLB due to rank faults")
gen = self.eplb_manager.rebalance()
while True:
try:
next(gen)
except StopIteration:
break
output = self._forward_raw(
forward_batch,
skip_attn_backend_init,
pp_proxy_tensors,
reinit_attn_backend,
split_forward_count,
)
output.expert_distribution_metrics = recorder_outputs.get("metrics")

# Copy cached routing experts' buffers back to CPU cache
Expand Down
84 changes: 30 additions & 54 deletions test/srt/ep/test_mooncake_ep_small.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import unittest
from types import SimpleNamespace

Expand All @@ -9,6 +10,7 @@
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
is_in_ci,
popen_launch_server,
)

Expand Down Expand Up @@ -38,6 +40,14 @@ def setUpClass(cls):
"mooncake",
"--deepep-mode",
"low_latency",
"--moe-dense-tp-size",
"1",
"--enable-dp-lm-head",
"--enable-two-batch-overlap",
"--disable-custom-all-reduce",
"--enable-eplb",
"--ep-num-redundant-experts",
"72",
"--chunked-prefill-size",
"512",
"--cuda-graph-max-bs",
Expand Down Expand Up @@ -70,76 +80,42 @@ def test_gsm8k(self):
self.assertGreater(metrics["accuracy"], 0.60)


@unittest.skip("covered in TestMooncakeWithEPLB")
class TestPureDP(TestTP):
extra_args = [
"--tp",
"4",
"--enable-dp-attention",
"--dp",
"4",
]

pkill_process_1 = "sglang::scheduler_DP1_TP1_EP1"
pkill_process_2 = "sglang::scheduler_DP3_TP3_EP3"

class TestHybridDPTP(TestTP):
extra_args = [
"--tp",
"4",
"--enable-dp-attention",
"--dp",
"2",
]
def test_gsm8k_fault_1(self):
"""
Kill one rank and the system should remain operational.
"""
os.system(f"pkill -f {self.pkill_process_1}")
super().test_gsm8k()

@unittest.skipIf(is_in_ci(), "To reduce the CI execution time.")
def test_gsm8k_fault_2(self):
"""
Kill another rank and the system should remain operational.
"""
os.system(f"pkill -f {self.pkill_process_2}")
super().test_gsm8k()

@unittest.skip("covered in TestMooncakeWithEPLB")
class TestNoGatherdBuffer(TestTP):
extra_args = [
"--tp",
"4",
"--enable-dp-attention",
"--dp",
"4",
"--moe-dense-tp-size",
"1",
]


@unittest.skip("covered in TestMooncakeWithEPLB")
class TestTBO(TestTP):
@unittest.skipIf(is_in_ci(), "To reduce the CI execution time.")
class TestHybridDPTP(TestPureDP):
extra_args = [
"--tp",
"4",
"--enable-dp-attention",
"--dp",
"4",
"--moe-dense-tp-size",
"1",
"--enable-two-batch-overlap",
"2",
]


class TestMooncakeWithEPLB(TestTP):
extra_args = [
"--tp",
"4",
"--enable-dp-attention",
"--dp",
"4",
"--moe-dense-tp-size",
"1",
"--enable-two-batch-overlap",
"--enable-eplb",
"--ep-num-redundant-experts",
"4",
"--eplb-rebalance-num-iterations",
"50",
"--expert-distribution-recorder-buffer-size",
"50",
"--expert-distribution-recorder-mode",
"stat",
"--ep-dispatch-algorithm",
"static",
]
pkill_process_1 = "sglang::scheduler_DP1_TP2_EP2"
pkill_process_2 = "sglang::scheduler_DP1_TP3_EP3"


if __name__ == "__main__":
Expand Down
Loading