Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 31 additions & 6 deletions python/sglang/srt/distributed/parallel_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,7 @@ def __init__(
group_name: Optional[str] = None,
pynccl_use_current_stream: bool = False,
gloo_timeout: timedelta = timedelta(seconds=120 * 60),
extend_group: bool = False,
):
# Set group info
group_name = group_name or "anonymous"
Expand Down Expand Up @@ -282,15 +283,19 @@ def __init__(
if "mooncake" in torch_distributed_backend:
from mooncake.ep import MooncakeBackendOptions

if len(ranks) == 1:
# Bypass the extension procedure
extend_group = False

device_group = torch.distributed.new_group(
ranks,
backend="mooncake",
pg_options=MooncakeBackendOptions(active_ranks),
pg_options=MooncakeBackendOptions(active_ranks, extend_group),
)
cpu_group = torch.distributed.new_group(
ranks,
backend="mooncake-cpu",
pg_options=MooncakeBackendOptions(active_ranks_cpu),
pg_options=MooncakeBackendOptions(active_ranks_cpu, extend_group),
)
else:
pg_options = get_torch_distributed_pg_options(group_name)
Expand Down Expand Up @@ -435,7 +440,11 @@ def __init__(
)

self.mq_broadcaster: Optional[MessageQueue] = None
if use_message_queue_broadcaster and self.world_size > 1:
if (
use_message_queue_broadcaster
and self.world_size > 1
and not extend_group
):
self.mq_broadcaster = MessageQueue.create_from_process_group(
self.cpu_group, 1 << 22, 6
)
Expand Down Expand Up @@ -1413,7 +1422,7 @@ def get_world_group() -> GroupCoordinator:


def init_world_group(
ranks: List[int], local_rank: int, backend: str
ranks: List[int], local_rank: int, backend: str, extend_group: bool = False
) -> GroupCoordinator:
return GroupCoordinator(
group_ranks=[ranks],
Expand All @@ -1427,6 +1436,7 @@ def init_world_group(
use_xpu_communicator=False,
use_npu_communicator=False,
group_name="world",
extend_group=extend_group,
)


Expand All @@ -1441,6 +1451,7 @@ def init_model_parallel_group(
use_mscclpp_allreduce: Optional[bool] = None,
pynccl_use_current_stream: bool = True,
use_torch_symm_mem_allreduce: Optional[bool] = None,
extend_group: bool = False,
) -> GroupCoordinator:
if use_custom_allreduce is None:
use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
Expand All @@ -1466,6 +1477,7 @@ def init_model_parallel_group(
use_message_queue_broadcaster=use_message_queue_broadcaster,
group_name=group_name,
pynccl_use_current_stream=pynccl_use_current_stream,
extend_group=extend_group,
)


Expand Down Expand Up @@ -1675,6 +1687,7 @@ def init_distributed_environment(
backend: str = "nccl",
timeout: Optional[int] = None,
moe_a2a_backend: Optional[str] = None,
extend_group: bool = False,
):
logger.debug(
"world_size=%d rank=%d local_rank=%d " "distributed_init_method=%s backend=%s",
Expand Down Expand Up @@ -1705,7 +1718,10 @@ def init_distributed_environment(
assert timeout > 0, "timeout must be positive"
timeout = timedelta(seconds=timeout)

pg_options = get_torch_distributed_pg_options()
from mooncake.ep import MooncakeBackendOptions

active_ranks = torch.ones(world_size, dtype=torch.int32, device="cuda")
pg_options = MooncakeBackendOptions(active_ranks, extend_group)

# this backend is used for WORLD
torch.distributed.init_process_group(
Expand Down Expand Up @@ -1734,7 +1750,7 @@ def init_distributed_environment(
global _WORLD
if _WORLD is None:
ranks = list(range(torch.distributed.get_world_size()))
_WORLD = init_world_group(ranks, local_rank, backend)
_WORLD = init_world_group(ranks, local_rank, backend, extend_group=extend_group)
else:
assert (
_WORLD.world_size == torch.distributed.get_world_size()
Expand All @@ -1750,6 +1766,7 @@ def initialize_model_parallel(
moe_data_model_parallel_size: int = 1,
backend: Optional[str] = None,
duplicate_tp_group: bool = False,
extend_group: bool = False,
) -> None:
"""
Initialize model parallel groups.
Expand Down Expand Up @@ -1831,6 +1848,7 @@ def initialize_model_parallel(
),
group_name="tp",
pynccl_use_current_stream=duplicate_tp_group,
extend_group=extend_group,
)

if duplicate_tp_group:
Expand All @@ -1847,6 +1865,7 @@ def initialize_model_parallel(
),
group_name="pdmux_prefill_tp",
pynccl_use_current_stream=True,
extend_group=extend_group,
)
if _TP.pynccl_comm:
_TP.pynccl_comm.disabled = False
Expand Down Expand Up @@ -1884,6 +1903,7 @@ def initialize_model_parallel(
get_world_group().local_rank,
backend,
group_name="attn_cp",
extend_group=extend_group,
)

from sglang.srt.layers.sampler import SYNC_TOKEN_IDS_ACROSS_TP
Expand Down Expand Up @@ -1917,6 +1937,7 @@ def initialize_model_parallel(
use_custom_allreduce=False,
use_torch_symm_mem_allreduce=False,
group_name="attention_tp",
extend_group=extend_group,
)

moe_ep_size = expert_model_parallel_size
Expand All @@ -1943,6 +1964,7 @@ def initialize_model_parallel(
get_world_group().local_rank,
backend,
group_name="moe_dp",
extend_group=extend_group,
)

global _MOE_EP
Expand All @@ -1968,6 +1990,7 @@ def initialize_model_parallel(
get_world_group().local_rank,
backend,
group_name="moe_ep",
extend_group=extend_group,
)

global _MOE_TP
Expand All @@ -1994,6 +2017,7 @@ def initialize_model_parallel(
get_world_group().local_rank,
backend,
group_name="moe_tp",
extend_group=extend_group,
)

# Build the pipeline model-parallel groups.
Expand All @@ -2013,6 +2037,7 @@ def initialize_model_parallel(
backend,
use_custom_allreduce=False,
group_name="pp",
extend_group=extend_group,
)


Expand Down
154 changes: 153 additions & 1 deletion python/sglang/srt/elastic_ep/elastic_ep.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
from __future__ import annotations

import logging
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Optional
from typing import List, Optional

import torch

from sglang.srt.distributed import parallel_state
from sglang.srt.layers.dp_attention import get_attention_tp_group
from sglang.srt.managers.schedule_batch import ServerArgs
from sglang.srt.utils import is_cpu, is_cuda

logger = logging.getLogger(__name__)


@dataclass
class ElasticEPState:
Expand Down Expand Up @@ -43,6 +49,18 @@ def init(cls, server_args: ServerArgs):
cls._instance = cls._build_state(ep_size=None, device=None)
return cls._instance

@classmethod
def reset(cls, server_args: ServerArgs):
from sglang.srt.layers.moe import MoeA2ABackend
from sglang.srt.layers.moe.token_dispatcher.mooncake import EPBuffer

moe_a2a_backend = MoeA2ABackend(server_args.moe_a2a_backend)
if moe_a2a_backend.is_mooncake():
EPBuffer.mark_ep_member_refresh_needed()
cls._instance.active_ranks.fill_(1)
cls._instance.snapshot_active_to_last()
cls._instance.sync_active_to_cpu()

@staticmethod
def _select_device() -> torch.device:
if is_cuda():
Expand Down Expand Up @@ -71,3 +89,137 @@ def healthy_rank_state(
dev = device if device is not None else cls._select_device()

return torch.ones(size, dtype=torch.int32, device=dev)


def _check_peer_state_for_group(group, ranks_in_group):
    """Query peer liveness for *ranks_in_group* on both backends of *group*.

    Returns ``None`` when there are no ranks to check; otherwise a dict with
    the mooncake peer state of the CUDA device group under ``"device"`` and
    of the gloo/CPU group under ``"cpu"``.
    """
    from mooncake import ep as mooncake_ep

    if not ranks_in_group:
        return None

    gpu_backend = group.device_group._get_backend(torch.device("cuda"))
    cpu_backend = group.cpu_group._get_backend(torch.device("cpu"))
    return {
        "device": mooncake_ep.get_peer_state(gpu_backend, ranks_in_group),
        "cpu": mooncake_ep.get_peer_state(cpu_backend, ranks_in_group),
    }


def try_recover_ranks(global_ranks: List[int]):
    """Attempt to re-admit *global_ranks* into every registered process group.

    Two phases: first verify the ranks are reachable on the WORLD backend and
    probe each group's device/CPU backends, then recover the ranks on WORLD
    and poll every group until both its device and CPU backends have recovered
    the corresponding group-local ranks.

    Returns False immediately if any target rank is not yet reachable on the
    WORLD CUDA backend; returns True once all groups report recovered.
    """
    from mooncake import ep as mooncake_ep

    logger.info(f"Trying to recover ranks: {global_ranks}")

    # Reachability gate: all target ranks must be alive on the WORLD CUDA
    # backend before any recovery is attempted.
    backend = torch.distributed.group.WORLD._get_backend(torch.device("cuda"))
    peer_state = mooncake_ep.get_peer_state(backend, global_ranks)

    if not all(peer_state):
        logger.info("Early return from try_recover_ranks")
        return False

    # _groups maps group name -> weakref to GroupCoordinator; copy the mapping
    # so iteration is stable while groups may be created/destroyed elsewhere.
    groups = parallel_state._groups
    groups = {group_name: groups[group_name] for group_name in groups}

    # Phase 1: probe every group's backends for the ranks that belong to it.
    for group_name in groups:
        group = groups[group_name]()
        if group is not None:
            group_ranks = group.ranks
            ranks_in_group = [r for r in global_ranks if r in group_ranks]

            if ranks_in_group:
                # Translate global rank ids to this group's local rank ids.
                group_rank_mapping = {
                    global_rank: idx for idx, global_rank in enumerate(group_ranks)
                }
                group_local_ranks = [
                    group_rank_mapping[r]
                    for r in ranks_in_group
                    if r in group_rank_mapping
                ]
                if group_local_ranks:
                    # NOTE(review): both get_peer_state results below are
                    # discarded (the second assignment overwrites the first and
                    # neither is read). This only matters if the call has side
                    # effects inside mooncake — confirm, otherwise this loop is
                    # dead code.
                    device_backend = group.device_group._get_backend(
                        torch.device("cuda")
                    )
                    group_peer_state = mooncake_ep.get_peer_state(
                        device_backend, group_local_ranks
                    )

                    cpu_backend = group.cpu_group._get_backend(torch.device("cpu"))
                    group_peer_state = mooncake_ep.get_peer_state(
                        cpu_backend, group_local_ranks
                    )

    # Re-admit the ranks on the WORLD backend first.
    mooncake_ep.recover_ranks(backend, global_ranks)

    # Phase 2: poll until every group has recovered on both of its backends.
    # NOTE(review): if a group's weakref dies (group is None) after these
    # status dicts are built, its entries stay False and this loop never
    # terminates — confirm whether that situation can occur here.
    device_group_status = {group_name: False for group_name in groups}
    cpu_group_status = {group_name: False for group_name in groups}
    while True:
        for group_name in groups:
            group = groups[group_name]()
            if group is not None:
                group_ranks = group.ranks
                ranks_in_group = [r for r in global_ranks if r in group_ranks]

                if ranks_in_group:
                    # Convert to group-local ranks
                    group_rank_mapping = {
                        global_rank: idx for idx, global_rank in enumerate(group_ranks)
                    }
                    group_local_ranks = [
                        group_rank_mapping[r]
                        for r in ranks_in_group
                        if r in group_rank_mapping
                    ]

                    if group_local_ranks:
                        # Recover the device (CUDA) backend once all peers in
                        # this group report alive.
                        if not device_group_status[group_name]:
                            device_backend = group.device_group._get_backend(
                                torch.device("cuda")
                            )
                            group_peer_state = mooncake_ep.get_peer_state(
                                device_backend, group_local_ranks
                            )
                            if all(group_peer_state):
                                mooncake_ep.recover_ranks(
                                    device_backend, group_local_ranks
                                )
                                device_group_status[group_name] = True

                        # Recover the CPU backend; on success also rebuild the
                        # shared-memory broadcaster tied to the cpu_group.
                        if not cpu_group_status[group_name]:
                            cpu_backend = group.cpu_group._get_backend(
                                torch.device("cpu")
                            )
                            group_peer_state = mooncake_ep.get_peer_state(
                                cpu_backend, group_local_ranks
                            )
                            if all(group_peer_state):
                                mooncake_ep.recover_ranks(
                                    cpu_backend, group_local_ranks
                                )
                                if (
                                    group.use_message_queue_broadcaster
                                    and group.world_size > 1
                                ):
                                    # Create message queue
                                    from sglang.srt.distributed.device_communicators.shm_broadcast import (
                                        MessageQueue,
                                    )

                                    group.mq_broadcaster = (
                                        MessageQueue.create_from_process_group(
                                            group.cpu_group, 1 << 22, 6
                                        )
                                    )
                                cpu_group_status[group_name] = True
                    else:
                        # None of the target ranks map into this group.
                        device_group_status[group_name] = True
                        cpu_group_status[group_name] = True
                else:
                    # This group contains none of the ranks being recovered.
                    device_group_status[group_name] = True
                    cpu_group_status[group_name] = True
        if all(device_group_status.values()) and all(cpu_group_status.values()):
            logger.info(f"Recovered ranks {global_ranks}")
            return True
3 changes: 3 additions & 0 deletions python/sglang/srt/eplb/eplb_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ def __init__(self, model_runner: "ModelRunner"):
def on_forward_pass_end(self):
next(self._main_generator)

def reset_generator(self):
self._main_generator = self._entrypoint()

# can be more complex if needed
def _entrypoint(self):
while True:
Expand Down
Loading
Loading