Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
8ae4347
pr2 eplb
HanHan009527 Sep 15, 2025
f01ba58
Let `token_dispatcher/mooncake.py` use the `global_elastic_ep_metadat…
UNIDY2002 Sep 17, 2025
cd65b69
fix
HanHan009527 Sep 17, 2025
3e38276
some fix
HanHan009527 Oct 2, 2025
484058c
fix
HanHan009527 Oct 2, 2025
cb41b43
lint
HanHan009527 Oct 15, 2025
560c595
fix
HanHan009527 Oct 15, 2025
e6875fc
test
HanHan009527 Oct 15, 2025
63d7b0e
add ut
HanHan009527 Oct 15, 2025
8d8aca9
test
HanHan009527 Oct 15, 2025
6d36b5b
test
HanHan009527 Oct 15, 2025
56fb09c
fix
HanHan009527 Oct 15, 2025
2434821
fix
HanHan009527 Oct 15, 2025
8c0e187
fix
HanHan009527 Oct 15, 2025
37eeaab
fix
HanHan009527 Oct 15, 2025
2200898
lint
HanHan009527 Oct 15, 2025
9808c8d
fix
HanHan009527 Oct 15, 2025
cb54875
fix
HanHan009527 Oct 15, 2025
7b1bd4e
test
HanHan009527 Oct 15, 2025
2606322
t
HanHan009527 Oct 15, 2025
9a99351
ut
HanHan009527 Oct 16, 2025
06563c0
Merge branch 'main' into mooncake-pr-eplb
HanHan009527 Oct 16, 2025
fd8cc23
Merge branch 'main' into mooncake-pr-eplb
ShangmingCai Oct 16, 2025
7b06878
Merge branch 'main' into mooncake-pr-eplb
HanHan009527 Oct 17, 2025
4da41cd
Merge branch 'main' into mooncake-pr-eplb
HanHan009527 Oct 17, 2025
687ba59
review
Oct 20, 2025
d66c884
Merge branch 'main' into mooncake-pr-eplb
HanHan009527 Oct 20, 2025
fbc874e
fix lint
ShangmingCai Oct 20, 2025
47cb4ad
Merge branch 'main' into mooncake-pr-eplb
HanHan009527 Oct 21, 2025
9168004
Merge branch 'main' into mooncake-pr-eplb
ShangmingCai Oct 21, 2025
8fdffde
Merge branch 'main' into mooncake-pr-eplb
HanHan009527 Oct 21, 2025
a133b2d
Merge branch 'main' into mooncake-pr-eplb
ShangmingCai Oct 21, 2025
5b10b90
Merge branch 'main' into mooncake-pr-eplb
HanHan009527 Oct 21, 2025
2cd15a9
Merge branch 'main' into mooncake-pr-eplb
ShangmingCai Oct 21, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 74 additions & 0 deletions python/sglang/srt/elastic_ep/elastic_ep.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Optional

import torch

from sglang.srt.managers.schedule_batch import ServerArgs
from sglang.srt.utils import is_cpu, is_cuda


@dataclass
class ElasticEPState:
    """Mutable snapshot of which elastic-EP ranks are currently healthy.

    Attributes:
        active_ranks: per-rank health flags (nonzero = active); may live on
            an accelerator device.
        last_active_ranks: copy of ``active_ranks`` taken at the last
            snapshot, used to detect membership changes.
        active_ranks_cpu: CPU mirror of ``active_ranks`` for host-side reads.
    """

    active_ranks: Optional[torch.Tensor]
    last_active_ranks: Optional[torch.Tensor]
    active_ranks_cpu: Optional[torch.Tensor]

    def is_active_equal_last(self) -> bool:
        """Return True iff the active set is unchanged since the last snapshot.

        Fix: the fields are typed Optional, and the sibling helpers below
        guard against None, but this method previously passed the tensors
        straight to ``torch.equal`` and crashed on None. Guard explicitly:
        two missing sets compare equal; missing vs. present does not.
        """
        if self.active_ranks is None or self.last_active_ranks is None:
            return self.active_ranks is None and self.last_active_ranks is None
        return torch.equal(self.active_ranks, self.last_active_ranks)

    def sync_active_to_cpu(self):
        """Refresh the CPU mirror from the current ``active_ranks`` tensor."""
        if self.active_ranks is not None:
            self.active_ranks_cpu = self.active_ranks.detach().cpu().clone()

    def snapshot_active_to_last(self):
        """Record the current active set for later comparison."""
        if self.active_ranks is not None:
            self.last_active_ranks = self.active_ranks.clone()


class ElasticEPStateManager:
    """Process-wide singleton holding the elastic-EP rank-health state."""

    _instance: Optional[ElasticEPState] = None

    @classmethod
    def instance(cls) -> Optional[ElasticEPState]:
        # Fix: annotated Optional — this legitimately returns None until
        # init() has been called with an elastic-EP backend configured.
        return cls._instance

    @classmethod
    def init(cls, server_args: ServerArgs) -> Optional[ElasticEPState]:
        """Create the singleton state if elastic EP is enabled; idempotent.

        Returns the existing instance when already initialized, a fresh one
        when ``server_args.elastic_ep_backend`` is set, and None otherwise.
        """
        if cls._instance is not None:
            return cls._instance

        if server_args.elastic_ep_backend is not None:
            # ep_size/device default to None: world size and device are
            # discovered inside healthy_rank_state().
            cls._instance = cls._build_state(ep_size=None, device=None)
            return cls._instance
        return None

    @staticmethod
    def _select_device() -> torch.device:
        """Pick the device for the rank-health tensor."""
        if is_cuda():
            return torch.device("cuda")
        elif is_cpu():
            return torch.device("cpu")
        else:
            raise NotImplementedError("Only CUDA and CPU support elastic ep now.")

    @classmethod
    def _build_state(
        cls, *, ep_size: Optional[int] = None, device: Optional[torch.device] = None
    ) -> ElasticEPState:
        """Build a state where every rank starts healthy."""
        active = cls.healthy_rank_state(ep_size=ep_size, device=device)
        return ElasticEPState(
            active_ranks=active,
            last_active_ranks=active.clone(),
            active_ranks_cpu=active.detach().cpu().clone(),
        )

    @classmethod
    def healthy_rank_state(
        cls, *, ep_size: Optional[int] = None, device: Optional[torch.device] = None
    ) -> torch.Tensor:
        """Return an all-ones int32 tensor marking every rank active.

        Defaults: ``ep_size`` falls back to the torch.distributed world size
        (requires an initialized process group); ``device`` falls back to
        :meth:`_select_device`. int32 matches what Mooncake EP consumes.
        """
        size = ep_size if ep_size is not None else torch.distributed.get_world_size()
        dev = device if device is not None else cls._select_device()
        return torch.ones(size, dtype=torch.int32, device=dev)
Comment on lines +67 to +74
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this dtype be changed to torch.int64 to align with the future usage?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mooncake EP currently uses int32. BTW, what does "future usage" refer to?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@UNIDY2002 Just wonder if it should align with log2phy, which is int64. But I am not sure.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had a quick verification. active_ranks has two usages in rebalance_experts: (1) num_active_ranks = active_ranks.sum().item(); (2) active_ranks_list = active_ranks.tolist(), so I think using int32 for active_ranks may be okay.

19 changes: 18 additions & 1 deletion python/sglang/srt/eplb/eplb_algorithms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,16 @@

import torch

from sglang.srt.eplb.eplb_algorithms import deepseek, deepseek_vec
from sglang.srt.elastic_ep.elastic_ep import ElasticEPStateManager
from sglang.srt.eplb.eplb_algorithms import deepseek, deepseek_vec, elasticity_aware


class EplbAlgorithm(Enum):
    """Available EPLB (expert-parallel load balancing) rebalancing algorithms."""

    deepseek = auto()
    deepseek_hierarchical = auto()
    deepseek_vec = auto()
    deepseek_vec_hierarchical = auto()
    # Elasticity-aware variant: balances over only the currently-active EP
    # ranks when some ranks are down (see eplb_algorithms/elasticity_aware.py).
    elasticity_aware = auto()
    # TODO may have more algorithm later


Expand Down Expand Up @@ -45,6 +47,21 @@ def rebalance_experts(
enable_hierarchical=algorithm == EplbAlgorithm.deepseek_vec_hierarchical,
)

if algorithm == EplbAlgorithm.elasticity_aware:
return elasticity_aware.rebalance_experts(
weight=tokens_per_expert.sum(dim=0),
num_replicas=num_physical_experts,
num_groups=num_groups,
num_nodes=num_nodes,
num_gpus=num_physical_experts // num_local_physical_experts,
enable_hierarchical=True,
active_ranks=(
ElasticEPStateManager.instance().active_ranks
if ElasticEPStateManager.instance() is not None
else ElasticEPStateManager.healthy_rank_state()
),
)

raise NotImplementedError


Expand Down
87 changes: 87 additions & 0 deletions python/sglang/srt/eplb/eplb_algorithms/elasticity_aware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
from typing import Tuple

import torch

from sglang.srt.eplb.eplb_algorithms.deepseek import rebalance_experts_hierarchical


def rebalance_experts(
    weight: torch.Tensor,
    num_replicas: int,
    num_groups: int,
    num_nodes: int,
    num_gpus: int,
    enable_hierarchical: bool,
    active_ranks: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Entry point for the elasticity-aware expert-parallelism load balancer.

    Parameters:
        weight: [layers, num_logical_experts], the load statistics for all logical experts
        num_replicas: number of physical experts, must be a multiple of `num_gpus`
        num_groups: number of expert groups
        num_nodes: number of server nodes, where the intra-node network (e.g, NVLink) is faster
        num_gpus: number of GPUs, must be a multiple of `num_nodes`
        enable_hierarchical: use the hierarchical policy when all ranks are
            active; ignored (global policy forced) when some ranks are down
        active_ranks: int tensor of per-rank health flags (nonzero = active);
            presumably length num_gpus — TODO confirm against caller

    Returns:
        physical_to_logical_map: [layers, num_replicas], the expert index of each replica
        logical_to_physical_map: [layers, num_logical_experts, X], the replica indices for each expert
        expert_count: [layers, num_logical_experts], number of physical replicas for each logical expert
    """

    num_layers, num_logical_experts = weight.shape
    # Balancing math runs on CPU float regardless of where stats were collected.
    weight = weight.float().cpu()
    num_active_ranks = active_ranks.sum().item()
    num_local_experts = num_replicas // num_gpus
    if num_active_ranks < num_gpus:
        # Must fall back to global load-balance policy
        # and fix some params: balance only num_active_ranks * num_local_experts
        # replicas over the active ranks (1 group, 1 node).
        phy2log, phyrank, logcnt = rebalance_experts_hierarchical(
            weight,
            num_local_experts * num_active_ranks,
            1,
            1,
            num_active_ranks,
        )
    elif enable_hierarchical:
        # use hierarchical load-balance policy
        phy2log, phyrank, logcnt = rebalance_experts_hierarchical(
            weight, num_replicas, num_groups, num_nodes, num_gpus
        )
    else:
        # use global load-balance policy
        phy2log, phyrank, logcnt = rebalance_experts_hierarchical(
            weight, num_replicas, 1, 1, num_gpus
        )
    # Invert phy2log into log2phy: for each logical expert, the physical slot
    # indices of its replicas, padded with -1 up to the max replica count.
    maxlogcnt = logcnt.max().item()
    log2phy: torch.Tensor = torch.full(
        (num_layers, num_logical_experts, maxlogcnt),
        -1,
        dtype=torch.int64,
        device=logcnt.device,
    )
    log2phy.view(num_layers, -1).scatter_(
        -1,
        phy2log * maxlogcnt + phyrank,
        torch.arange(
            num_local_experts * num_active_ranks,
            dtype=torch.int64,
            device=log2phy.device,
        ).expand(num_layers, -1),
    )
    if num_active_ranks < num_gpus:
        # Re-expand the compact (active-ranks-only) mapping back to the full
        # num_gpus layout: insert a zero-filled per-rank slice for each
        # inactive rank, and shift every physical index at or beyond the
        # insertion point by one rank's worth of experts.
        phy2log_slices = list(
            phy2log.view(num_layers, num_active_ranks, -1).unbind(dim=1)
        )
        active_ranks_list = active_ranks.tolist()
        for idx, active_rank in enumerate(active_ranks_list):
            if not active_rank:
                phy2log_slices.insert(idx, torch.zeros_like(phy2log_slices[0]))
                log2phy = torch.where(
                    log2phy >= idx * num_local_experts,
                    log2phy + num_local_experts,
                    log2phy,
                )
        phy2log = torch.stack(phy2log_slices, dim=1).contiguous().view(num_layers, -1)
    return phy2log, log2phy, logcnt
16 changes: 2 additions & 14 deletions python/sglang/srt/layers/moe/token_dispatcher/mooncake.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from dataclasses import dataclass
from typing import NamedTuple, Optional, Tuple

from sglang.srt.elastic_ep.elastic_ep import ElasticEPStateManager
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
from sglang.srt.layers.dp_attention import get_is_extend_in_batch
from sglang.srt.layers.moe.token_dispatcher.base import (
Expand Down Expand Up @@ -63,14 +64,6 @@ def format(self) -> CombineInputFormat:
assert isinstance(MooncakeCombineInput, CombineInput)


_ACTIVE_RANKS: Optional[torch.Tensor] = None


def get_ep_active_ranks() -> torch.Tensor:
assert _ACTIVE_RANKS is not None, "_ACTIVE_RANKS is not initialized"
return _ACTIVE_RANKS


class EPBuffer:
_buffer = None
_hidden_size: Optional[int] = None
Expand Down Expand Up @@ -153,12 +146,7 @@ def __init__(
self.first_execution = True
self.timeout_us = 10000000

global _ACTIVE_RANKS
if _ACTIVE_RANKS is None:
_ACTIVE_RANKS = torch.ones(
(self.num_experts,), dtype=torch.int32, device="cuda"
)
self.active_ranks = _ACTIVE_RANKS
self.active_ranks = ElasticEPStateManager.instance().active_ranks

self.handle = None

Expand Down
46 changes: 37 additions & 9 deletions python/sglang/srt/model_executor/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import time
from collections import defaultdict
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
from typing import Callable, List, Optional, Tuple, Union

import torch
import torch.distributed as dist
Expand All @@ -51,6 +51,7 @@
set_symm_mem_all_reduce,
)
from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
from sglang.srt.elastic_ep.elastic_ep import ElasticEPStateManager
from sglang.srt.eplb.eplb_manager import EPLBManager
from sglang.srt.eplb.expert_distribution import (
ExpertDistributionRecorder,
Expand Down Expand Up @@ -379,6 +380,11 @@ def initialize(self, min_per_gpu_memory: float):
)
self.expert_location_updater = ExpertLocationUpdater()

(
ElasticEPStateManager.init(self.server_args)
if self.server_args.elastic_ep_backend
else None
)
# Load the model
self.sampler = Sampler()
self.load_model()
Expand Down Expand Up @@ -945,16 +951,33 @@ def update_expert_location(
new_expert_location_metadata: ExpertLocationMetadata,
update_layer_ids: List[int],
):
self.expert_location_updater.update(
self.model.routed_experts_weights_of_layer,
new_expert_location_metadata,
update_layer_ids=update_layer_ids,
nnodes=self.server_args.nnodes,
rank=self.tp_rank,
)
if ElasticEPStateManager.instance() is not None:
# TODO: refactor the weights update when elastic ep
old_expert_location_metadata = get_global_expert_location_metadata()
assert old_expert_location_metadata is not None
old_expert_location_metadata.update(
new_expert_location_metadata,
update_layer_ids=update_layer_ids,
)
self.update_weights_from_disk(
self.server_args.model_path,
self.server_args.load_format,
lambda name: "mlp.experts" in name and "mlp.shared_experts" not in name,
)
else:
self.expert_location_updater.update(
self.model.routed_experts_weights_of_layer,
new_expert_location_metadata,
update_layer_ids=update_layer_ids,
nnodes=self.server_args.nnodes,
rank=self.tp_rank,
)

def update_weights_from_disk(
self, model_path: str, load_format: str
self,
model_path: str,
load_format: str,
weight_name_filter: Optional[Callable[[str], bool]] = None,
) -> tuple[bool, str]:
"""Update engine weights in-place from the disk."""
logger.info(
Expand All @@ -976,6 +999,11 @@ def get_weight_iter(config):
iter = loader._get_weights_iterator(
DefaultModelLoader.Source.init_new(config, self.model)
)
if weight_name_filter is not None:
iter = (
(name, weight) for name, weight in iter if weight_name_filter(name)
)

return iter

def model_load_weights(model, iter):
Expand Down
12 changes: 12 additions & 0 deletions python/sglang/srt/server_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,6 +597,9 @@ def __post_init__(self):
# Handle any other necessary validations.
self._handle_other_validations()

# Handle elastic expert parallelism.
self._handle_elastic_ep()

def _handle_deprecated_args(self):
# handle deprecated tool call parsers
deprecated_tool_call_parsers = {"qwen25": "qwen", "glm45": "glm"}
Expand Down Expand Up @@ -1222,6 +1225,15 @@ def _handle_eplb_and_dispatch(self):
if self.enable_eplb:
assert self.ep_size > 1

def _handle_elastic_ep(self):
if self.elastic_ep_backend is not None:
if self.enable_eplb:
if self.eplb_algorithm == "auto":
self.eplb_algorithm = "elasticity_aware"
assert (
self.eplb_algorithm == "elasticity_aware"
), "Elastic EP requires eplb_algorithm to be set to 'auto' or 'elasticity_aware'."

def _handle_expert_distribution_metrics(self):
if self.enable_expert_distribution_metrics and (
self.expert_distribution_recorder_mode is None
Expand Down
Loading
Loading