Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/advanced_features/server_arguments.md
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
| Arguments | Description | Defaults |
|-----------|-------------|----------|
| `--device` | The device to use ('cuda', 'xpu', 'hpu', 'npu', 'cpu'). Defaults to auto-detection if not specified. | None |
| `--elastic-ep-backend` | Select the collective communication backend for elastic EP. Currently supports 'mooncake'. | None |
| `--elastic-ep-backend` | Select the collective communication backend for elastic EP. Supports 'mooncake' and 'deepep'. Use 'none' to disable. | None |
| `--mooncake-ib-device` | The InfiniBand devices for Mooncake Backend, accepts multiple comma-separated devices. Default is None, which triggers automatic device detection when Mooncake Backend is enabled. | None |
| `--tp-size` | The tensor parallelism size. | 1 |
| `--pp-size` | The pipeline parallelism size. | 1 |
Expand Down Expand Up @@ -248,7 +248,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
| Arguments | Description | Defaults |
|-----------|-------------|----------|
| `--ep-size` | The expert parallelism size. | 1 |
| `--moe-a2a-backend` | Select the backend for all-to-all communication for expert parallelism, could be `deepep` or `mooncake`. | none |
| `--moe-a2a-backend` | Select the backend for all-to-all communication for expert parallelism, could be `deepep`, `mooncake`, or `nixl`. | none |
| `--moe-runner-backend` | Select the runner backend for MoE. | auto |
| `--deepep-mode` | Select the mode when enable DeepEP MoE, could be `normal`, `low_latency` or `auto`. Default is `auto`, which means `low_latency` for decode batch and `normal` for prefill batch. | auto |
| `--ep-num-redundant-experts` | Allocate this number of redundant experts in expert parallel. | 0 |
Expand Down
57 changes: 57 additions & 0 deletions python/sglang/srt/distributed/parallel_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import torch.distributed
from torch.distributed import Backend, ProcessGroup

from sglang.srt.distributed.utils import set_global_tcp_store
from sglang.srt.utils import (
direct_register_custom_op,
get_bool_env_var,
Expand Down Expand Up @@ -1398,6 +1399,59 @@ def set_symm_mem_all_reduce(enable: bool):
_ENABLE_SYMM_MEM_ALL_REDUCE = enable


def _create_global_tcp_store(
    rank: int, world_size: int
) -> None:
    """Create a global TCPStore for coordination across ranks.

    Rank 0 hosts the store; every other rank connects to it. On success the
    store is published via ``set_global_tcp_store`` so that components that
    need cross-rank coordination (e.g. NIXL buffer setup) can reuse it.

    Args:
        rank: Global rank of this process.
        world_size: Total number of ranks that will attach to the store.
    """
    from torch.distributed import TCPStore

    master_ip = os.environ.get("MASTER_ADDR")
    base_store_port = int(os.environ.get("SGLANG_TCP_STORE_PORT", "29600"))

    if not master_ip:
        # MASTER_ADDR is unset (e.g. a single-node launch). Rank 0 detects its
        # local IP and shares it so every rank connects to the same host.
        # broadcast_object_list works with any backend (handles CPU/GPU automatically).
        logger.warning(
            "Could not determine master IP for global TCPStore. "
            "Broadcasting from rank 0 to all ranks."
        )
        if rank == 0:
            ip_list = [get_local_ip_auto()]
        else:
            ip_list = [None]
        torch.distributed.broadcast_object_list(ip_list, src=0)
        master_ip = ip_list[0]

    try:
        tcp_store = TCPStore(
            host_name=master_ip,
            port=base_store_port,
            world_size=world_size,
            is_master=(rank == 0),
        )
        set_global_tcp_store(tcp_store)
        logger.info(
            "Created global TCPStore at %s:%d (rank=%d, world_size=%d)",
            master_ip, base_store_port, rank, world_size
        )
    except Exception as e:
        # Best-effort: distributed init should not fail just because this
        # optional coordination store could not be created.
        logger.warning(
            "Failed to create global TCPStore at %s:%d: %s. "
            "Components requiring TCPStore (like NIXL) may not work.",
            master_ip, base_store_port, e
        )


def init_distributed_environment(
world_size: int = -1,
rank: int = -1,
Expand Down Expand Up @@ -1444,6 +1498,9 @@ def init_distributed_environment(
timeout=timeout,
)

# Create a global TCPStore for coordination (used by NIXL)
_create_global_tcp_store(rank, world_size)

# set the local rank
# local_rank is not available in torch ProcessGroup,
# see https://github.com/pytorch/pytorch/issues/122816
Expand Down
36 changes: 36 additions & 0 deletions python/sglang/srt/distributed/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,42 @@

logger = logging.getLogger(__name__)

# Global TCPStore that is created during distributed initialization
# This is the single shared store that all components should use
_global_tcp_store: Optional[TCPStore] = None


def set_global_tcp_store(store: TCPStore) -> None:
    """Publish *store* as the process-wide shared TCPStore.

    Called once during distributed initialization so that every component
    needing cross-rank coordination can later fetch the same store via
    ``get_global_tcp_store``.
    """
    global _global_tcp_store
    _global_tcp_store = store
    logger.info("Global TCPStore has been set")


def get_global_tcp_store() -> Optional[TCPStore]:
    """Get the existing global TCPStore.

    This function provides access to the shared TCPStore instance that was
    created during distributed initialization. All components (like NIXL buffers)
    should use this same store for coordination.

    Returns:
        The global TCPStore instance, or None if not initialized yet.
    """
    # NOTE: no `global` statement needed — this function only reads the module
    # global; `global` is required only for assignment.
    if _global_tcp_store is None:
        # The store is created unconditionally in init_distributed_environment,
        # so a missing store means distributed init has not run yet (the init
        # method used is irrelevant).
        logger.warning(
            "Global TCPStore not found. Make sure init_distributed_environment "
            "was called before accessing the global TCPStore."
        )

    return _global_tcp_store


def ensure_divisibility(numerator, denominator):
"""Ensure that numerator is divisible by the denominator."""
Expand Down
108 changes: 108 additions & 0 deletions python/sglang/srt/elastic_ep/elastic_ep.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
from __future__ import annotations

import threading
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional, Union

import torch

from sglang.srt.managers.schedule_batch import ServerArgs
from sglang.srt.utils import is_cpu, is_cuda


@dataclass
class ElasticEPState:
    """Per-process view of which EP ranks are currently active.

    ``_active_ranks`` lives on the compute device, ``_active_ranks_cpu`` is its
    host-side copy, and ``_last_active_ranks`` is the snapshot taken at the
    previous update (used for change detection).
    """

    # int32 0/1 flags, one per EP rank, on the compute device.
    _active_ranks: Optional[torch.Tensor]
    # Snapshot of _active_ranks from the previous update.
    _last_active_ranks: Optional[torch.Tensor]
    # CPU copy of _active_ranks for host-side consumers.
    _active_ranks_cpu: Optional[torch.Tensor]
    # Backend-specific hook that refreshes _active_ranks from rank_status.
    on_forward: Optional[Callable] = None
    # Raw per-rank status tensor; its encoding is backend-specific.
    rank_status: Optional[torch.Tensor] = None

    def is_active_equal_last(self) -> bool:
        """Return True when the active-rank set is unchanged since the last snapshot."""
        if self._active_ranks is None or self._last_active_ranks is None:
            # torch.equal cannot handle None; two missing tensors count as equal,
            # a single missing one as a change.
            return self._active_ranks is self._last_active_ranks
        return torch.equal(self._active_ranks, self._last_active_ranks)

    def sync_active_to_cpu(self):
        """Refresh the CPU copy of the active-rank tensor (no-op when unset)."""
        if self._active_ranks is not None:
            self._active_ranks_cpu = self._active_ranks.detach().cpu().clone()

    def snapshot_active_to_last(self):
        """Record the current active ranks as the baseline for change detection."""
        if self._active_ranks is not None:
            self._last_active_ranks = self._active_ranks.clone()


class ElasticEPStateManager:
    """Process-wide singleton owning the ElasticEPState for this rank.

    The state is created once via ``init`` (only when an elastic EP backend is
    configured) and retrieved through ``instance``. When no backend is active,
    ``instance`` returns None and callers should fall back to
    ``healthy_rank_state``.
    """

    _instance: Optional[ElasticEPState] = None
    _lock = threading.Lock()

    @staticmethod
    def on_forward_mooncake(
        state: ElasticEPState, status: torch.Tensor = None, **kwargs
    ):
        """Mooncake backend: rank_status directly encodes active ranks (1 = active)."""
        state._active_ranks = state.rank_status.to(dtype=torch.int32)

    @staticmethod
    def on_forward_deepep(state: ElasticEPState, status: torch.Tensor = None, **kwargs):
        """DeepEP backend: rank_status flags failed ranks, so active = 1 - status."""
        state._active_ranks = 1 - state.rank_status.to(torch.int32)

    @classmethod
    def instance(cls) -> ElasticEPState:
        """Return the singleton state, or None when elastic EP is disabled."""
        return cls._instance

    @classmethod
    def init(cls, server_args: ServerArgs):
        """Build the singleton once (thread-safe); no-op if already initialized."""
        with cls._lock:
            if cls._instance is not None:
                return cls._instance

            if server_args.elastic_ep_backend is not None:
                cls._instance = cls._build_state(
                    ep_size=None,
                    device=None,
                    backend_type=server_args.elastic_ep_backend,
                )
            return cls._instance

    @staticmethod
    def _select_device() -> torch.device:
        """Pick the device for rank-state tensors; only CUDA/CPU are supported."""
        if is_cuda():
            return torch.device("cuda")
        elif is_cpu():
            return torch.device("cpu")
        else:
            raise NotImplementedError("Only CUDA and CPU support elastic ep now.")

    @classmethod
    def healthy_rank_state(
        cls,
        *,
        ep_size: Optional[int] = None,
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        """Return a rank-state tensor marking every rank as active.

        Fallback for callers (e.g. the elasticity-aware EPLB algorithm) when no
        elastic EP backend is configured and ``instance`` returns None.
        Defaults mirror ``create_rank_state`` (world size, auto-selected device).
        """
        return cls.create_rank_state(ep_size=ep_size, device=device, value=1)

    @classmethod
    def _build_state(
        cls,
        *,
        ep_size: Optional[int],
        device: Optional[torch.device],
        backend_type: str = "none",
    ) -> ElasticEPState:
        """Construct an ElasticEPState with every rank initially active."""
        active = cls.create_rank_state(ep_size=ep_size, device=device, value=1)

        # Choose the backend-specific refresh hook; unknown backends get none.
        if backend_type == "mooncake":
            on_forward = cls.on_forward_mooncake
        elif backend_type == "deepep":
            on_forward = cls.on_forward_deepep
        else:
            on_forward = None

        return ElasticEPState(
            _active_ranks=active,
            _last_active_ranks=active.clone(),
            _active_ranks_cpu=active.detach().cpu().clone(),
            rank_status=active.clone(),
            on_forward=on_forward,
        )

    @classmethod
    def create_rank_state(
        cls, *, ep_size: Optional[int], device: Optional[torch.device], value: int = 1
    ) -> torch.Tensor:
        """Build an int32 tensor of length ``ep_size`` filled with ``value``.

        Defaults to the distributed world size and an auto-selected device
        when the corresponding argument is None.
        """
        size = ep_size if ep_size is not None else torch.distributed.get_world_size()
        dev = device if device is not None else cls._select_device()

        return torch.full((size,), value, dtype=torch.int32, device=dev)
23 changes: 22 additions & 1 deletion python/sglang/srt/eplb/eplb_algorithms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,17 @@

import torch

from sglang.srt.eplb.eplb_algorithms import deepseek, deepseek_vec
from sglang.srt.elastic_ep.elastic_ep import ElasticEPStateManager
from sglang.srt.eplb.eplb_algorithms import deepseek, deepseek_vec, elasticity_aware


class EplbAlgorithm(Enum):
    """Available EPLB (expert-parallel load balancing) rebalancing algorithms."""
    deepseek = auto()
    deepseek_hierarchical = auto()
    deepseek_vec = auto()
    deepseek_vec_hierarchical = auto()
    # Elasticity-aware variants additionally account for inactive EP ranks.
    elasticity_aware = auto()
    elasticity_aware_hierarchical = auto()
    # TODO may have more algorithm later


Expand Down Expand Up @@ -45,6 +48,24 @@ def rebalance_experts(
enable_hierarchical=algorithm == EplbAlgorithm.deepseek_vec_hierarchical,
)

if algorithm in [
EplbAlgorithm.elasticity_aware,
EplbAlgorithm.elasticity_aware_hierarchical,
]:
return elasticity_aware.rebalance_experts(
weight=tokens_per_expert.sum(dim=0),
num_replicas=num_physical_experts,
num_groups=num_groups,
num_nodes=num_nodes,
num_gpus=num_physical_experts // num_local_physical_experts,
enable_hierarchical=algorithm == EplbAlgorithm.elasticity_aware_hierarchical,
active_ranks=(
ElasticEPStateManager.instance()._active_ranks
if ElasticEPStateManager.instance() is not None
else ElasticEPStateManager.healthy_rank_state()
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The method ElasticEPStateManager.healthy_rank_state() does not exist. This will raise an AttributeError if the elasticity_aware algorithm is used when ElasticEPStateManager is not initialized (i.e., no elastic EP backend is active).

To prevent this crash, you should provide a fallback that creates a tensor indicating all ranks are healthy. A tensor of ones with a shape corresponding to the number of GPUs would be a suitable default.

Suggested change
else ElasticEPStateManager.healthy_rank_state()
else torch.ones(num_physical_experts // num_local_physical_experts, dtype=torch.int32)

),
)

raise NotImplementedError


Expand Down
87 changes: 87 additions & 0 deletions python/sglang/srt/eplb/eplb_algorithms/elasticity_aware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
from typing import Tuple

import torch

from sglang.srt.eplb.eplb_algorithms.deepseek import rebalance_experts_hierarchical


def rebalance_experts(
    weight: torch.Tensor,
    num_replicas: int,
    num_groups: int,
    num_nodes: int,
    num_gpus: int,
    enable_hierarchical: bool,
    active_ranks: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Entry point for the elasticity-aware expert-parallelism load balancer.

    When some ranks are inactive, experts are balanced only across the active
    ranks and the resulting mapping is re-expanded to the full `num_gpus`
    layout afterwards.

    Parameters:
        weight: [layers, num_logical_experts], the load statistics for all logical experts
        num_replicas: number of physical experts, must be a multiple of `num_gpus`
        num_groups: number of expert groups
        num_nodes: number of server nodes, where the intra-node network (e.g, NVLink) is faster
        num_gpus: number of GPUs, must be a multiple of `num_nodes`
        enable_hierarchical: use the hierarchical (node-aware) policy when all ranks are active
        active_ranks: 0/1 tensor with one entry per GPU rank; ranks marked 0
            receive only placeholder experts (assumed length == `num_gpus` — TODO confirm)

    Returns:
        physical_to_logical_map: [layers, num_replicas], the expert index of each replica
            (slots belonging to inactive ranks are filled with placeholder expert 0)
        logical_to_physical_map: [layers, num_logical_experts, X], the replica indices for each expert
        expert_count: [layers, num_logical_experts], number of physical replicas for each logical expert
    """

    num_layers, num_logical_experts = weight.shape
    # Balancing runs on CPU floats regardless of the input device/dtype.
    weight = weight.float().cpu()
    num_active_ranks = active_ranks.sum().item()
    num_local_experts = num_replicas // num_gpus
    if num_active_ranks < num_gpus:
        # Must fall back to global load-balance policy
        # and fix some params
        # (balance over the reduced replica count across active ranks only;
        # groups/nodes set to 1 since the degraded topology is treated as flat)
        phy2log, phyrank, logcnt = rebalance_experts_hierarchical(
            weight,
            num_local_experts * num_active_ranks,
            1,
            1,
            num_active_ranks,
        )
    elif enable_hierarchical:
        # use hierarchical load-balance policy
        phy2log, phyrank, logcnt = rebalance_experts_hierarchical(
            weight, num_replicas, num_groups, num_nodes, num_gpus
        )
    else:
        # use global load-balance policy
        phy2log, phyrank, logcnt = rebalance_experts_hierarchical(
            weight, num_replicas, 1, 1, num_gpus
        )
    # Invert phy2log into log2phy: cell (expert, r) holds the index of the
    # physical slot carrying the r-th replica of that expert; -1 marks unused
    # replica slots (experts with fewer than maxlogcnt replicas).
    maxlogcnt = logcnt.max().item()
    log2phy: torch.Tensor = torch.full(
        (num_layers, num_logical_experts, maxlogcnt),
        -1,
        dtype=torch.int64,
        device=logcnt.device,
    )
    log2phy.view(num_layers, -1).scatter_(
        -1,
        phy2log * maxlogcnt + phyrank,
        torch.arange(
            num_local_experts * num_active_ranks,
            dtype=torch.int64,
            device=log2phy.device,
        ).expand(num_layers, -1),
    )
    if num_active_ranks < num_gpus:
        # Re-expand the compact (active-ranks-only) mapping to the full
        # num_gpus layout: insert a placeholder slice (all expert 0) for each
        # inactive rank and shift log2phy indices past every inserted slot.
        phy2log_slices = list(
            phy2log.view(num_layers, num_active_ranks, -1).unbind(dim=1)
        )
        active_ranks_list = active_ranks.tolist()
        for idx, active_rank in enumerate(active_ranks_list):
            if not active_rank:
                phy2log_slices.insert(idx, torch.zeros_like(phy2log_slices[0]))
                log2phy = torch.where(
                    log2phy >= idx * num_local_experts,
                    log2phy + num_local_experts,
                    log2phy,
                )
        phy2log = torch.stack(phy2log_slices, dim=1).contiguous().view(num_layers, -1)
    return phy2log, log2phy, logcnt
6 changes: 5 additions & 1 deletion python/sglang/srt/layers/moe/ep_moe/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -706,7 +706,11 @@ def _forward_ll(dispatch_output: DeepEPLLOutput):


def get_moe_impl_class(quant_config: Optional[QuantizationConfig]):
if get_moe_a2a_backend().is_deepep() or get_moe_a2a_backend().is_mooncake():
if (
get_moe_a2a_backend().is_deepep()
or get_moe_a2a_backend().is_mooncake()
or get_moe_a2a_backend().is_nixl()
):
return DeepEPMoE

# NEW: Direct FP4 detection (bypasses EP requirements)
Expand Down
Loading