Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/advanced_features/server_arguments.md
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
| Arguments | Description | Defaults |
|-----------|-------------|----------|
| `--device` | The device to use ('cuda', 'xpu', 'hpu', 'npu', 'cpu'). Defaults to auto-detection if not specified. | None |
| `--elastic-ep-backend` | Select the collective communication backend for elastic EP. Currently supports 'mooncake'. | None |
| `--elastic-ep-backend` | Select the collective communication backend for elastic EP. Supports 'mooncake' and 'deepep'. Use 'none' to disable. | None |
| `--mooncake-ib-device` | The InfiniBand devices for Mooncake Backend, accepts multiple comma-separated devices. Default is None, which triggers automatic device detection when Mooncake Backend is enabled. | None |
| `--tp-size` | The tensor parallelism size. | 1 |
| `--pp-size` | The pipeline parallelism size. | 1 |
Expand Down
108 changes: 108 additions & 0 deletions python/sglang/srt/elastic_ep/elastic_ep.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
from __future__ import annotations

import threading
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional, Union

import torch

from sglang.srt.managers.schedule_batch import ServerArgs
from sglang.srt.utils import is_cpu, is_cuda


@dataclass
class ElasticEPState:
    """Per-process snapshot of which expert-parallel (EP) ranks are active."""

    # int32 0/1 tensor, one entry per EP rank, on the compute device; 1 = active.
    _active_ranks: Optional[torch.Tensor]
    # Copy of _active_ranks taken at the last snapshot_active_to_last() call.
    _last_active_ranks: Optional[torch.Tensor]
    # CPU mirror of _active_ranks, refreshed by sync_active_to_cpu().
    _active_ranks_cpu: Optional[torch.Tensor]
    # Backend-specific hook that refreshes _active_ranks from rank_status.
    on_forward: Optional[Callable] = None
    # Raw per-rank status tensor written by the communication backend.
    rank_status: Optional[torch.Tensor] = None

    def is_active_equal_last(self) -> bool:
        """Return True iff the active-rank set is unchanged since the last snapshot.

        Unlike a bare ``torch.equal``, this tolerates unset state: two ``None``
        tensors compare equal, and a single ``None`` compares unequal instead
        of raising ``TypeError``.
        """
        if self._active_ranks is None or self._last_active_ranks is None:
            return self._active_ranks is self._last_active_ranks
        return torch.equal(self._active_ranks, self._last_active_ranks)

    def sync_active_to_cpu(self):
        """Refresh the CPU mirror of the active-rank tensor (no-op when unset)."""
        if self._active_ranks is not None:
            self._active_ranks_cpu = self._active_ranks.detach().cpu().clone()

    def snapshot_active_to_last(self):
        """Record current active ranks for later change detection (no-op when unset)."""
        if self._active_ranks is not None:
            self._last_active_ranks = self._active_ranks.clone()


class ElasticEPStateManager:
    """Process-wide singleton owner of the :class:`ElasticEPState`.

    ``init()`` builds the singleton exactly once (lock-guarded) when the
    server is started with an elastic-EP backend; ``instance()`` returns it,
    or ``None`` when elastic EP is disabled.
    """

    _instance: Optional[ElasticEPState] = None
    _lock = threading.Lock()

    @staticmethod
    def on_forward_mooncake(
        state: ElasticEPState, status: Optional[torch.Tensor] = None, **kwargs
    ):
        """Refresh active ranks from ``rank_status`` (Mooncake: 1 == active)."""
        state._active_ranks = state.rank_status.to(dtype=torch.int32)

    @staticmethod
    def on_forward_deepep(
        state: ElasticEPState, status: Optional[torch.Tensor] = None, **kwargs
    ):
        """Refresh active ranks from ``rank_status``.

        NOTE(review): this inverts the status, i.e. DeepEP appears to report
        1 for *failed* ranks — confirm against the DeepEP mask-buffer docs.
        """
        state._active_ranks = 1 - state.rank_status.to(torch.int32)

    @classmethod
    def instance(cls) -> Optional[ElasticEPState]:
        """Return the singleton state, or ``None`` if ``init`` never built one."""
        return cls._instance

    @classmethod
    def init(cls, server_args: ServerArgs) -> Optional[ElasticEPState]:
        """Build the singleton once; later calls return the existing instance."""
        with cls._lock:
            if cls._instance is not None:
                return cls._instance

            if server_args.elastic_ep_backend is not None:
                cls._instance = cls._build_state(
                    ep_size=None,
                    device=None,
                    backend_type=server_args.elastic_ep_backend,
                )
            return cls._instance

    @classmethod
    def healthy_rank_state(
        cls, *, ep_size: Optional[int] = None, device: Optional[torch.device] = None
    ) -> torch.Tensor:
        """Return an all-ones (every rank healthy) rank-state tensor.

        Fallback used by callers (e.g. the EPLB algorithm dispatcher) when no
        singleton has been initialized; previously this method was missing,
        so that fallback path raised ``AttributeError``.
        """
        return cls.create_rank_state(ep_size=ep_size, device=device, value=1)

    @staticmethod
    def _select_device() -> torch.device:
        """Pick the default device for rank-state tensors."""
        if is_cuda():
            return torch.device("cuda")
        elif is_cpu():
            return torch.device("cpu")
        else:
            raise NotImplementedError("Only CUDA and CPU support elastic ep now.")

    @classmethod
    def _build_state(
        cls,
        *,
        ep_size: Optional[int],
        device: Optional[torch.device],
        backend_type: str = "none",
    ) -> ElasticEPState:
        """Assemble a fresh state with every rank initially marked active."""

        active = cls.create_rank_state(ep_size=ep_size, device=device, value=1)

        # Each backend exposes a different failure signal; pick the matching
        # refresh hook (None for unknown/disabled backends).
        if backend_type == "mooncake":
            on_forward = cls.on_forward_mooncake
        elif backend_type == "deepep":
            on_forward = cls.on_forward_deepep
        else:
            on_forward = None

        return ElasticEPState(
            _active_ranks=active,
            _last_active_ranks=active.clone(),
            _active_ranks_cpu=active.detach().cpu().clone(),
            rank_status=active.clone(),
            on_forward=on_forward,
        )

    @classmethod
    def create_rank_state(
        cls, *, ep_size: Optional[int], device: Optional[torch.device], value: int = 1
    ) -> torch.Tensor:
        """Build an int32 per-rank state tensor filled with ``value``.

        ``ep_size`` defaults to the global world size (requires an initialized
        torch.distributed process group); ``device`` defaults to auto-detection.
        """
        size = ep_size if ep_size is not None else torch.distributed.get_world_size()
        dev = device if device is not None else cls._select_device()

        return torch.full((size,), value, dtype=torch.int32, device=dev)
19 changes: 18 additions & 1 deletion python/sglang/srt/eplb/eplb_algorithms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,16 @@

import torch

from sglang.srt.eplb.eplb_algorithms import deepseek, deepseek_vec
from sglang.srt.elastic_ep.elastic_ep import ElasticEPStateManager
from sglang.srt.eplb.eplb_algorithms import deepseek, deepseek_vec, elasticity_aware


class EplbAlgorithm(Enum):
    """Expert-load-balancing (EPLB) rebalancing algorithms selectable by name."""

    deepseek = auto()
    deepseek_hierarchical = auto()
    deepseek_vec = auto()
    deepseek_vec_hierarchical = auto()
    # Balances experts across only the currently-active EP ranks (elastic EP).
    elasticity_aware = auto()
    # TODO may have more algorithm later


Expand Down Expand Up @@ -45,6 +47,21 @@ def rebalance_experts(
enable_hierarchical=algorithm == EplbAlgorithm.deepseek_vec_hierarchical,
)

if algorithm == EplbAlgorithm.elasticity_aware:
return elasticity_aware.rebalance_experts(
weight=tokens_per_expert.sum(dim=0),
num_replicas=num_physical_experts,
num_groups=num_groups,
num_nodes=num_nodes,
num_gpus=num_physical_experts // num_local_physical_experts,
enable_hierarchical=True,
active_ranks=(
ElasticEPStateManager.instance()._active_ranks
if ElasticEPStateManager.instance() is not None
else ElasticEPStateManager.healthy_rank_state()
),
)

raise NotImplementedError


Expand Down
87 changes: 87 additions & 0 deletions python/sglang/srt/eplb/eplb_algorithms/elasticity_aware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
from typing import Tuple

import torch

from sglang.srt.eplb.eplb_algorithms.deepseek import rebalance_experts_hierarchical


def rebalance_experts(
    weight: torch.Tensor,
    num_replicas: int,
    num_groups: int,
    num_nodes: int,
    num_gpus: int,
    enable_hierarchical: bool,
    active_ranks: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Entry point for the elasticity-aware expert-parallelism load balancer.

    When every rank is active this behaves like the regular DeepSeek balancer;
    when some ranks are inactive it balances over the active ranks only and
    then remaps the result back onto the full physical rank layout, leaving
    zero-filled placeholder slots for the inactive ranks.

    Parameters:
        weight: [layers, num_logical_experts], the load statistics for all logical experts
        num_replicas: number of physical experts, must be a multiple of `num_gpus`
        num_groups: number of expert groups
        num_nodes: number of server nodes, where the intra-node network (e.g, NVLink) is faster
        num_gpus: number of GPUs, must be a multiple of `num_nodes`
        enable_hierarchical: use the hierarchical policy; ignored (a global
            single-node/single-group policy is forced) when some ranks are inactive
        active_ranks: per-rank 0/1 tensor, 1 marks an active rank
            # NOTE(review): assumes len(active_ranks) == num_gpus and
            # active_ranks.sum() <= num_gpus — confirm with callers.

    Returns:
        physical_to_logical_map: [layers, num_replicas], the expert index of each replica
        logical_to_physical_map: [layers, num_logical_experts, X], the replica indices for each expert
        expert_count: [layers, num_logical_experts], number of physical replicas for each logical expert
    """

    num_layers, num_logical_experts = weight.shape
    weight = weight.float().cpu()
    num_active_ranks = active_ranks.sum().item()
    num_local_experts = num_replicas // num_gpus
    if num_active_ranks < num_gpus:
        # Must fall back to global load-balance policy
        # and fix some params: balance only num_local_experts * num_active_ranks
        # replicas, as one node / one group, across the active ranks only.
        phy2log, phyrank, logcnt = rebalance_experts_hierarchical(
            weight,
            num_local_experts * num_active_ranks,
            1,
            1,
            num_active_ranks,
        )
    elif enable_hierarchical:
        # use hierarchical load-balance policy
        phy2log, phyrank, logcnt = rebalance_experts_hierarchical(
            weight, num_replicas, num_groups, num_nodes, num_gpus
        )
    else:
        # use global load-balance policy (single node, single group)
        phy2log, phyrank, logcnt = rebalance_experts_hierarchical(
            weight, num_replicas, 1, 1, num_gpus
        )
    # Invert phy2log into log2phy: for each logical expert, the physical slot
    # of each of its replicas; -1 pads experts with fewer than maxlogcnt replicas.
    maxlogcnt = logcnt.max().item()
    log2phy: torch.Tensor = torch.full(
        (num_layers, num_logical_experts, maxlogcnt),
        -1,
        dtype=torch.int64,
        device=logcnt.device,
    )
    # Scatter physical indices 0..(num_local_experts*num_active_ranks - 1) into
    # slot (logical_expert * maxlogcnt + replica_rank) of the flattened map.
    log2phy.view(num_layers, -1).scatter_(
        -1,
        phy2log * maxlogcnt + phyrank,
        torch.arange(
            num_local_experts * num_active_ranks,
            dtype=torch.int64,
            device=log2phy.device,
        ).expand(num_layers, -1),
    )
    if num_active_ranks < num_gpus:
        # Re-expand the compacted (active-ranks-only) layout to the full rank
        # grid: insert a zero placeholder slice per inactive rank and shift
        # every physical index at or above that rank's slot range accordingly.
        phy2log_slices = list(
            phy2log.view(num_layers, num_active_ranks, -1).unbind(dim=1)
        )
        active_ranks_list = active_ranks.tolist()
        for idx, active_rank in enumerate(active_ranks_list):
            if not active_rank:
                phy2log_slices.insert(idx, torch.zeros_like(phy2log_slices[0]))
                log2phy = torch.where(
                    log2phy >= idx * num_local_experts,
                    log2phy + num_local_experts,
                    log2phy,
                )
        phy2log = torch.stack(phy2log_slices, dim=1).contiguous().view(num_layers, -1)
    return phy2log, log2phy, logcnt
6 changes: 6 additions & 0 deletions python/sglang/srt/layers/moe/token_dispatcher/deepep.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Tuple, Union

from sglang.srt.elastic_ep.elastic_ep import ElasticEPStateManager
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
from sglang.srt.layers.moe.token_dispatcher.base import (
BaseDispatcher,
Expand Down Expand Up @@ -211,6 +212,7 @@ def get_deepep_buffer(
low_latency_mode=deepep_mode.enable_low_latency(),
num_qps_per_rank=num_qps_per_rank,
# TODO can be false when unneeded
enable_shrink=True,
allow_mnnvl=True,
)
return cls._buffer
Expand Down Expand Up @@ -299,6 +301,7 @@ def __init__(
# DeepEP internode_ll dispatch uses FINISHED_SUM_TAG=1024
# and the logic requires num-tokens-sent-from-one-rank-to-another-rank less than it
assert self.num_max_dispatch_tokens_per_rank <= 1024
self.status_tensor = ElasticEPStateManager.instance().rank_status

self.handle = None

Expand Down Expand Up @@ -664,6 +667,9 @@ def _combine_core(
else {}
),
)
torch.cuda.synchronize()
buffer.low_latency_query_mask_buffer(self.status_tensor)
torch.cuda.synchronize()

self.packed_recv_count = self.handle = None
return combined_hidden_states, event, hook
Expand Down
16 changes: 2 additions & 14 deletions python/sglang/srt/layers/moe/token_dispatcher/mooncake.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from dataclasses import dataclass
from typing import NamedTuple, Optional, Tuple

from sglang.srt.elastic_ep.elastic_ep import ElasticEPStateManager
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
from sglang.srt.layers.moe.token_dispatcher.base import (
BaseDispatcher,
Expand Down Expand Up @@ -62,14 +63,6 @@ def format(self) -> CombineInputFormat:
assert isinstance(MooncakeCombineInput, CombineInput)


_ACTIVE_RANKS: Optional[torch.Tensor] = None


def get_ep_active_ranks() -> torch.Tensor:
assert _ACTIVE_RANKS is not None, "_ACTIVE_RANKS is not initialized"
return _ACTIVE_RANKS


class EPBuffer:
_buffer = None
_hidden_size: Optional[int] = None
Expand Down Expand Up @@ -152,12 +145,7 @@ def __init__(
self.first_execution = True
self.timeout_us = 10000000

global _ACTIVE_RANKS
if _ACTIVE_RANKS is None:
_ACTIVE_RANKS = torch.ones(
(self.num_experts,), dtype=torch.int32, device="cuda"
)
self.active_ranks = _ACTIVE_RANKS
self.active_ranks = ElasticEPStateManager.instance().rank_status

self.handle = None

Expand Down
46 changes: 37 additions & 9 deletions python/sglang/srt/model_executor/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import time
from collections import defaultdict
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
from typing import Callable, List, Optional, Tuple, Union

import torch
import torch.distributed as dist
Expand All @@ -51,6 +51,7 @@
set_symm_mem_all_reduce,
)
from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
from sglang.srt.elastic_ep.elastic_ep import ElasticEPStateManager
from sglang.srt.eplb.eplb_manager import EPLBManager
from sglang.srt.eplb.expert_distribution import (
ExpertDistributionRecorder,
Expand Down Expand Up @@ -382,6 +383,11 @@ def initialize(self, min_per_gpu_memory: float):
)
self.expert_location_updater = ExpertLocationUpdater()

(
ElasticEPStateManager.init(self.server_args)
if self.server_args.elastic_ep_backend
else None
)
# Load the model
self.sampler = Sampler()
self.load_model()
Expand Down Expand Up @@ -926,16 +932,33 @@ def update_expert_location(
new_expert_location_metadata: ExpertLocationMetadata,
update_layer_ids: List[int],
):
self.expert_location_updater.update(
self.model.routed_experts_weights_of_layer,
new_expert_location_metadata,
update_layer_ids=update_layer_ids,
nnodes=self.server_args.nnodes,
rank=self.tp_rank,
)
if ElasticEPStateManager.instance() is not None:
# TODO: refactor the weights update when elastic ep
old_expert_location_metadata = get_global_expert_location_metadata()
assert old_expert_location_metadata is not None
old_expert_location_metadata.update(
new_expert_location_metadata,
update_layer_ids=update_layer_ids,
)
self.update_weights_from_disk(
self.server_args.model_path,
self.server_args.load_format,
lambda name: "mlp.experts" in name and "mlp.shared_experts" not in name,
)
else:
self.expert_location_updater.update(
self.model.routed_experts_weights_of_layer,
new_expert_location_metadata,
update_layer_ids=update_layer_ids,
nnodes=self.server_args.nnodes,
rank=self.tp_rank,
)

def update_weights_from_disk(
self, model_path: str, load_format: str
self,
model_path: str,
load_format: str,
weight_name_filter: Optional[Callable[[str], bool]] = None,
) -> tuple[bool, str]:
"""Update engine weights in-place from the disk."""
logger.info(
Expand All @@ -957,6 +980,11 @@ def get_weight_iter(config):
iter = loader._get_weights_iterator(
DefaultModelLoader.Source.init_new(config, self.model)
)
if weight_name_filter is not None:
iter = (
(name, weight) for name, weight in iter if weight_name_filter(name)
)

return iter

def model_load_weights(model, iter):
Expand Down
Loading
Loading