Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
8ae4347
pr2 eplb
HanHan009527 Sep 15, 2025
f01ba58
Let `token_dispatcher/mooncake.py` use the `global_elastic_ep_metadat…
UNIDY2002 Sep 17, 2025
cd65b69
fix
HanHan009527 Sep 17, 2025
3e38276
some fix
HanHan009527 Oct 2, 2025
484058c
fxi
HanHan009527 Oct 2, 2025
cb41b43
lint
HanHan009527 Oct 15, 2025
560c595
fix
HanHan009527 Oct 15, 2025
e6875fc
test
HanHan009527 Oct 15, 2025
63d7b0e
add ut
HanHan009527 Oct 15, 2025
8d8aca9
test
HanHan009527 Oct 15, 2025
6d36b5b
test
HanHan009527 Oct 15, 2025
56fb09c
fix
HanHan009527 Oct 15, 2025
2434821
fix
HanHan009527 Oct 15, 2025
8c0e187
fix
HanHan009527 Oct 15, 2025
37eeaab
fix
HanHan009527 Oct 15, 2025
2200898
lint
HanHan009527 Oct 15, 2025
9808c8d
fix
HanHan009527 Oct 15, 2025
cb54875
fix
HanHan009527 Oct 15, 2025
7b1bd4e
test
HanHan009527 Oct 15, 2025
2606322
t
HanHan009527 Oct 15, 2025
642fa37
Introduce Mooncake Backend and Mooncake EP
UNIDY2002 Sep 11, 2025
f50b9c3
tiny fix mooncake pr (#12)
HanHan009527 Sep 11, 2025
7a96c6a
Fix for more readable code
UNIDY2002 Sep 15, 2025
d71258c
pr2 eplb
HanHan009527 Sep 15, 2025
6620278
fix
HanHan009527 Sep 17, 2025
6d8c984
Test fault tolerance
UNIDY2002 Sep 17, 2025
6bafc08
feat
ympcMark Sep 18, 2025
46c0589
feat
ympcMark Sep 18, 2025
7162a8f
feat
ympcMark Sep 18, 2025
80c2c0b
feat
ympcMark Sep 18, 2025
175f9f4
feat
ympcMark Sep 19, 2025
5f72a00
feat
ympcMark Sep 19, 2025
a760bab
feat
ympcMark Sep 19, 2025
3a16563
feat
ympcMark Sep 19, 2025
d9e2866
feat
ympcMark Sep 19, 2025
a67f9ab
feat
ympcMark Sep 19, 2025
c018c58
fix
HanHan009527 Oct 15, 2025
8e4d1e9
ut
HanHan009527 Oct 15, 2025
f28ac4b
test
HanHan009527 Oct 15, 2025
c6c241b
lint
HanHan009527 Oct 15, 2025
c4e4dd0
test
HanHan009527 Oct 15, 2025
e5d9d16
t
HanHan009527 Oct 15, 2025
d6e7d26
t
HanHan009527 Oct 15, 2025
f4a6b7c
test
HanHan009527 Oct 15, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 28 additions & 1 deletion python/sglang/srt/distributed/parallel_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,8 @@ def __init__(
group_name: Optional[str] = None,
torch_compile: Optional[bool] = None,
gloo_timeout: timedelta = timedelta(seconds=120 * 60),
active_ranks: Optional[torch.Tensor] = None,
active_ranks_cpu: Optional[torch.Tensor] = None,
):
# Set group info
group_name = group_name or "anonymous"
Expand Down Expand Up @@ -1279,6 +1281,8 @@ def init_model_parallel_group(
use_mscclpp_allreduce: Optional[bool] = None,
use_symm_mem_allreduce: Optional[bool] = None,
torch_compile: Optional[bool] = None,
active_ranks: Optional[torch.Tensor] = None,
active_ranks_cpu: Optional[torch.Tensor] = None,
) -> GroupCoordinator:
if use_custom_allreduce is None:
use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
Expand All @@ -1290,7 +1294,7 @@ def init_model_parallel_group(
group_ranks=group_ranks,
local_rank=local_rank,
torch_distributed_backend=backend,
use_pynccl=not _is_npu,
use_pynccl=not _is_npu and not "mooncake" in backend,
use_pymscclpp=use_mscclpp_allreduce,
use_custom_allreduce=use_custom_allreduce,
use_torch_symm_mem=use_symm_mem_allreduce,
Expand All @@ -1300,10 +1304,23 @@ def init_model_parallel_group(
use_message_queue_broadcaster=use_message_queue_broadcaster,
group_name=group_name,
torch_compile=torch_compile,
active_ranks=active_ranks,
active_ranks_cpu=active_ranks_cpu,
)


_TP: Optional[GroupCoordinator] = None
_TP_ACTIVE_RANKS: Optional[torch.Tensor] = None
_TP_ACTIVE_RANKS_CPU: Optional[torch.Tensor] = None


def get_tp_active_ranks():
    """Return the module-level TP active-rank tensor.

    This is the int32 device tensor created in ``initialize_model_parallel``
    (one entry per TP rank, initialized to all ones), or ``None`` before
    model parallelism has been initialized.
    """
    return _TP_ACTIVE_RANKS


def get_tp_active_ranks_cpu():
    """Return the CPU mirror of the TP active-rank tensor.

    This is the int32 CPU tensor created in ``initialize_model_parallel``
    (one entry per TP rank, initialized to all ones), or ``None`` before
    model parallelism has been initialized.
    """
    return _TP_ACTIVE_RANKS_CPU


# duplicate GroupCoordinator for prefill in PD-Multiplexing
_PDMUX_PREFILL_TP_GROUP: Optional[GroupCoordinator] = None
Expand Down Expand Up @@ -1517,6 +1534,14 @@ def initialize_model_parallel(
)
group_ranks.append(ranks)

global _TP_ACTIVE_RANKS
_TP_ACTIVE_RANKS = torch.ones(
(tensor_model_parallel_size,), dtype=torch.int32, device="cuda"
)
global _TP_ACTIVE_RANKS_CPU
_TP_ACTIVE_RANKS_CPU = torch.ones(
(tensor_model_parallel_size,), dtype=torch.int32, device="cpu"
)
# message queue broadcaster is only used in tensor model parallel group
_TP = init_model_parallel_group(
group_ranks,
Expand All @@ -1527,6 +1552,8 @@ def initialize_model_parallel(
),
group_name="tp",
torch_compile=torch_compile,
active_ranks=_TP_ACTIVE_RANKS,
active_ranks_cpu=_TP_ACTIVE_RANKS_CPU,
)

if duplicate_tp_group:
Expand Down
78 changes: 78 additions & 0 deletions python/sglang/srt/elastic_ep/elastic_ep.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
from __future__ import annotations

import threading
from dataclasses import dataclass
from threading import Lock
from typing import Optional

import torch

from sglang.srt.managers.schedule_batch import ServerArgs
from sglang.srt.utils import is_cpu, is_cuda


@dataclass
class ElasticEPState:
    """Mutable record of which expert-parallel ranks are currently healthy.

    Attributes:
        active_ranks: int32 tensor on the compute device, one entry per rank;
            1 marks a live rank (see ``ElasticEPStateManager.healthy_rank_state``).
        last_active_ranks: copy of ``active_ranks`` taken at the last snapshot,
            used to detect membership changes between steps.
        active_ranks_cpu: CPU mirror of ``active_ranks`` for host-side checks.
    """

    active_ranks: Optional[torch.Tensor]
    last_active_ranks: Optional[torch.Tensor]
    active_ranks_cpu: Optional[torch.Tensor]

    def is_active_equal_last(self) -> bool:
        """Return True when the active-rank set is unchanged since the last snapshot.

        The fields are declared Optional, so guard against ``None`` (where
        ``torch.equal`` would raise a TypeError) by falling back to identity.
        """
        if self.active_ranks is None or self.last_active_ranks is None:
            return self.active_ranks is self.last_active_ranks
        return torch.equal(self.active_ranks, self.last_active_ranks)

    def sync_active_to_cpu(self) -> None:
        """Refresh the CPU mirror from the device tensor (no-op when unset)."""
        if self.active_ranks is not None:
            self.active_ranks_cpu = self.active_ranks.detach().cpu().clone()

    def snapshot_active_to_last(self) -> None:
        """Record the current active set as the new comparison baseline (no-op when unset)."""
        if self.active_ranks is not None:
            self.last_active_ranks = self.active_ranks.clone()


class ElasticEPStateManager:
    """Process-wide singleton holder for the :class:`ElasticEPState`.

    ``init`` builds the state once (only when ``server_args.elastic_ep_backend``
    is configured); ``instance`` returns it, or ``None`` before initialization.
    """

    _instance: Optional[ElasticEPState] = None
    _lock = threading.Lock()

    @classmethod
    def instance(cls) -> Optional[ElasticEPState]:
        """Return the singleton state, or ``None`` if ``init`` has not created one."""
        return cls._instance

    @classmethod
    def init(cls, server_args: ServerArgs) -> Optional[ElasticEPState]:
        """Create the singleton on first call; subsequent calls return the existing state.

        The state is only built when an elastic-EP backend is configured;
        otherwise this leaves the singleton unset and returns ``None``.
        """
        with cls._lock:
            if cls._instance is not None:
                return cls._instance

            if server_args.elastic_ep_backend is not None:
                cls._instance = cls._build_state(ep_size=None, device=None)
            return cls._instance

    @staticmethod
    def _select_device() -> torch.device:
        # Elastic EP currently only supports CUDA and CPU.
        if is_cuda():
            return torch.device("cuda")
        elif is_cpu():
            return torch.device("cpu")
        else:
            raise NotImplementedError("Only CUDA and CPU support elastic ep now.")

    @classmethod
    def _build_state(
        cls, *, ep_size: Optional[int] = None, device: Optional[torch.device] = None
    ) -> ElasticEPState:
        """Build a fresh all-healthy state plus its snapshot and CPU mirror."""
        active = cls.healthy_rank_state(ep_size=ep_size, device=device)
        return ElasticEPState(
            active_ranks=active,
            last_active_ranks=active.clone(),
            active_ranks_cpu=active.detach().cpu().clone(),
        )

    @classmethod
    def healthy_rank_state(
        cls, *, ep_size: Optional[int] = None, device: Optional[torch.device] = None
    ) -> torch.Tensor:
        """Return an int32 all-ones tensor marking every rank healthy.

        Both keyword arguments now default to ``None`` so this can be invoked
        with no arguments (the EPLB algorithm selection calls it bare as a
        fallback, which previously raised TypeError): ``ep_size`` falls back
        to the distributed world size and ``device`` to the auto-selected
        compute device.
        """
        size = ep_size if ep_size is not None else torch.distributed.get_world_size()
        dev = device if device is not None else cls._select_device()

        return torch.ones(size, dtype=torch.int32, device=dev)
22 changes: 21 additions & 1 deletion python/sglang/srt/eplb/eplb_algorithms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,16 @@

import torch

from sglang.srt.eplb.eplb_algorithms import deepseek, deepseek_vec
from sglang.srt.elastic_ep.elastic_ep import ElasticEPStateManager
from sglang.srt.eplb.eplb_algorithms import deepseek, deepseek_vec, elasticity_aware


class EplbAlgorithm(Enum):
deepseek = auto()
deepseek_hierarchical = auto()
deepseek_vec = auto()
deepseek_vec_hierarchical = auto()
elasticity_aware = auto()
# TODO may have more algorithm later


Expand Down Expand Up @@ -45,6 +47,21 @@ def rebalance_experts(
enable_hierarchical=algorithm == EplbAlgorithm.deepseek_vec_hierarchical,
)

if algorithm == EplbAlgorithm.elasticity_aware:
return elasticity_aware.rebalance_experts(
weight=tokens_per_expert.sum(dim=0),
num_replicas=num_physical_experts,
num_groups=num_groups,
num_nodes=num_nodes,
num_gpus=num_physical_experts // num_local_physical_experts,
enable_hierarchical=True,
active_ranks=(
ElasticEPStateManager.instance().active_ranks
if ElasticEPStateManager.instance() is not None
else ElasticEPStateManager.healthy_rank_state()
),
)

raise NotImplementedError


Expand All @@ -56,6 +73,9 @@ def compute_algorithm(
if raw_algorithm != "auto":
return EplbAlgorithm[raw_algorithm]

if get_elastic_ep_state().using_elastic_ep:
return EplbAlgorithm.elasticity_aware

# TODO test on real scenarios and know which ones perform better
if (num_groups is not None) and (num_groups % num_nodes == 0):
return EplbAlgorithm.deepseek_hierarchical
Expand Down
87 changes: 87 additions & 0 deletions python/sglang/srt/eplb/eplb_algorithms/elasticity_aware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
from typing import Tuple

import torch

from sglang.srt.eplb.eplb_algorithms.deepseek import rebalance_experts_hierarchical


def rebalance_experts(
    weight: torch.Tensor,
    num_replicas: int,
    num_groups: int,
    num_nodes: int,
    num_gpus: int,
    enable_hierarchical: bool,
    active_ranks: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Entry point for the elasticity-aware expert-parallelism load balancer.

    Parameters:
        weight: [layers, num_logical_experts], the load statistics for all logical experts
        num_replicas: number of physical experts, must be a multiple of `num_gpus`
        num_groups: number of expert groups
        num_nodes: number of server nodes, where the intra-node network (e.g, NVLink) is faster
        num_gpus: number of GPUs, must be a multiple of `num_nodes`
        enable_hierarchical: use the hierarchical (node-aware) policy; ignored
            when some ranks are inactive, which forces the global fallback
        active_ranks: one 0/1 entry per GPU rank; 0 marks an inactive rank

    Returns:
        physical_to_logical_map: [layers, num_replicas], the expert index of each replica
        logical_to_physical_map: [layers, num_logical_experts, X], the replica indices for each expert
        expert_count: [layers, num_logical_experts], number of physical replicas for each logical expert
    """

    num_layers, num_logical_experts = weight.shape
    weight = weight.float().cpu()
    num_active_ranks = active_ranks.sum().item()
    num_local_experts = num_replicas // num_gpus
    if num_active_ranks < num_gpus:
        # Some ranks are down: balance only across the surviving ranks,
        # with proportionally fewer physical slots.
        # Must fall back to global load-balance policy
        # and fix some params
        phy2log, phyrank, logcnt = rebalance_experts_hierarchical(
            weight,
            num_local_experts * num_active_ranks,  # reduced replica budget
            1,  # num_groups: treat survivors as one flat group
            1,  # num_nodes: node topology no longer applies
            num_active_ranks,
        )
    elif enable_hierarchical:
        # use hierarchical load-balance policy
        phy2log, phyrank, logcnt = rebalance_experts_hierarchical(
            weight, num_replicas, num_groups, num_nodes, num_gpus
        )
    else:
        # use global load-balance policy
        phy2log, phyrank, logcnt = rebalance_experts_hierarchical(
            weight, num_replicas, 1, 1, num_gpus
        )
    # Invert phy2log into log2phy: for each logical expert, the list of its
    # physical replica indices, padded with -1 up to the max replica count.
    maxlogcnt = logcnt.max().item()
    log2phy: torch.Tensor = torch.full(
        (num_layers, num_logical_experts, maxlogcnt),
        -1,
        dtype=torch.int64,
        device=logcnt.device,
    )
    # NOTE(review): in the non-fallback branches this arange length equals
    # num_replicas only when active_ranks sums to num_gpus — presumably
    # callers guarantee len(active_ranks) == num_gpus with 0/1 entries; confirm.
    log2phy.view(num_layers, -1).scatter_(
        -1,
        phy2log * maxlogcnt + phyrank,
        torch.arange(
            num_local_experts * num_active_ranks,
            dtype=torch.int64,
            device=log2phy.device,
        ).expand(num_layers, -1),
    )
    if num_active_ranks < num_gpus:
        # Re-expand the compacted mapping back to the full num_gpus layout:
        # each inactive rank gets a placeholder slice (all expert 0), and every
        # physical index in log2phy at or beyond the insertion point is shifted
        # up by one rank's worth of slots to stay consistent.
        phy2log_slices = list(
            phy2log.view(num_layers, num_active_ranks, -1).unbind(dim=1)
        )
        active_ranks_list = active_ranks.tolist()
        for idx, active_rank in enumerate(active_ranks_list):
            if not active_rank:
                phy2log_slices.insert(idx, torch.zeros_like(phy2log_slices[0]))
                log2phy = torch.where(
                    log2phy >= idx * num_local_experts,
                    log2phy + num_local_experts,
                    log2phy,
                )
        phy2log = torch.stack(phy2log_slices, dim=1).contiguous().view(num_layers, -1)
    return phy2log, log2phy, logcnt
7 changes: 7 additions & 0 deletions python/sglang/srt/layers/dp_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
_ATTN_TP_GROUP: Optional[GroupCoordinator] = None
_ATTN_TP_RANK: Optional[int] = None
_ATTN_TP_SIZE: Optional[int] = None
_ATTN_TP_ACTIVE_RANKS: Optional[torch.Tensor] = None
_ATTN_DP_RANK: Optional[int] = None
_ATTN_DP_SIZE: Optional[int] = None
_LOCAL_ATTN_DP_SIZE: Optional[int] = None
Expand Down Expand Up @@ -252,6 +253,11 @@ def initialize_dp_attention(
_ATTN_DP_SIZE = 1
_LOCAL_ATTN_DP_SIZE = 1

global _ATTN_TP_ACTIVE_RANKS
_ATTN_TP_ACTIVE_RANKS = torch.ones(
(_ATTN_TP_SIZE,), dtype=torch.int32, device="cuda"
)

tp_group = get_tp_group()
_ATTN_TP_GROUP = GroupCoordinator(
[
Expand All @@ -268,6 +274,7 @@ def initialize_dp_attention(
use_xpu_communicator=False,
use_npu_communicator=False,
group_name="attention_tp",
active_ranks=_ATTN_TP_ACTIVE_RANKS,
)

_DpGatheredBufferWrapper.set_metadata(
Expand Down
16 changes: 2 additions & 14 deletions python/sglang/srt/layers/moe/token_dispatcher/mooncake.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from dataclasses import dataclass
from typing import NamedTuple, Optional, Tuple

from sglang.srt.elastic_ep.elastic_ep import ElasticEPStateManager
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
from sglang.srt.layers.moe.token_dispatcher.base import (
BaseDispatcher,
Expand Down Expand Up @@ -62,14 +63,6 @@ def format(self) -> CombineInputFormat:
assert isinstance(MooncakeCombineInput, CombineInput)


_ACTIVE_RANKS: Optional[torch.Tensor] = None


def get_ep_active_ranks() -> torch.Tensor:
assert _ACTIVE_RANKS is not None, "_ACTIVE_RANKS is not initialized"
return _ACTIVE_RANKS


class EPBuffer:
_buffer = None
_hidden_size: Optional[int] = None
Expand Down Expand Up @@ -152,12 +145,7 @@ def __init__(
self.first_execution = True
self.timeout_us = 10000000

global _ACTIVE_RANKS
if _ACTIVE_RANKS is None:
_ACTIVE_RANKS = torch.ones(
(self.num_experts,), dtype=torch.int32, device="cuda"
)
self.active_ranks = _ACTIVE_RANKS
self.active_ranks = ElasticEPStateManager.instance().active_ranks

self.handle = None

Expand Down
Loading
Loading