From 8ae43471322b049ebf43eca6ed55077d8fd1b693 Mon Sep 17 00:00:00 2001 From: Hank Han Date: Tue, 16 Sep 2025 01:15:35 +0800 Subject: [PATCH 01/44] pr2 eplb fix fix fix fix fix fix fix ut ut ut fix fit --- python/sglang/srt/elastic_ep/elastic_ep.py | 43 +++++++++ .../srt/eplb/eplb_algorithms/__init__.py | 15 +++- .../srt/eplb/eplb_algorithms/elastic_ep.py | 88 +++++++++++++++++++ .../sglang/srt/model_executor/model_runner.py | 41 +++++++-- test/srt/ep/test_elastic_ep_eplb.py | 81 +++++++++++++++++ 5 files changed, 258 insertions(+), 10 deletions(-) create mode 100644 python/sglang/srt/elastic_ep/elastic_ep.py create mode 100644 python/sglang/srt/eplb/eplb_algorithms/elastic_ep.py create mode 100755 test/srt/ep/test_elastic_ep_eplb.py diff --git a/python/sglang/srt/elastic_ep/elastic_ep.py b/python/sglang/srt/elastic_ep/elastic_ep.py new file mode 100644 index 000000000000..79795060ad6b --- /dev/null +++ b/python/sglang/srt/elastic_ep/elastic_ep.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +from curses import use_env +from dataclasses import dataclass +from typing import Optional + +import torch + + +@dataclass +class ElasticEpMetadata: + using_elastic_ep: bool + active_ranks: Optional[torch.Tensor] + last_active_ranks: Optional[torch.Tensor] + + +_global_elastic_ep_metadata: Optional[ElasticEpMetadata] = None + + +def get_global_elastic_ep_metadata(): + return _global_elastic_ep_metadata + + +def set_global_elastic_ep_metadata(value): + global _global_elastic_ep_metadata + assert _global_elastic_ep_metadata is None + _global_elastic_ep_metadata = value + + +def _init_global_elastic_ep_metadata(): + global _global_elastic_ep_metadata + if _global_elastic_ep_metadata is not None: + return + + ep_size = torch.distributed.get_world_size() + active_ranks = torch.ones(ep_size, dtype=torch.int32) + last_active_ranks = active_ranks.clone() + + _global_elastic_ep_metadata = ElasticEpMetadata( + using_elastic_ep=False, # TODO pr elastic_ep to add args decide whether use elastic ep + active_ranks=active_ranks, + last_active_ranks=last_active_ranks, + ) diff --git a/python/sglang/srt/eplb/eplb_algorithms/__init__.py b/python/sglang/srt/eplb/eplb_algorithms/__init__.py index e2a2678104af..85ffb3cf7be9 100644 --- a/python/sglang/srt/eplb/eplb_algorithms/__init__.py +++ b/python/sglang/srt/eplb/eplb_algorithms/__init__.py @@ -3,7 +3,8 @@ import torch -from sglang.srt.eplb.eplb_algorithms import deepseek, deepseek_vec +from sglang.srt.elastic_ep.elastic_ep import get_global_elastic_ep_metadata +from sglang.srt.eplb.eplb_algorithms import deepseek, deepseek_vec, elastic_ep class EplbAlgorithm(Enum): @@ -11,6 +12,7 @@ class EplbAlgorithm(Enum): deepseek_hierarchical = auto() deepseek_vec = auto() deepseek_vec_hierarchical = auto() + elastic_ep = auto() # TODO may have more algorithm later @@ -45,6 +47,17 @@ def rebalance_experts( enable_hierarchical=algorithm == EplbAlgorithm.deepseek_vec_hierarchical, ) + if algorithm == EplbAlgorithm.elastic_ep: + return elastic_ep.rebalance_experts( + weight=tokens_per_expert.sum(dim=0), + num_replicas=num_physical_experts, + num_groups=num_groups, + num_nodes=num_nodes, + num_gpus=num_physical_experts // num_local_physical_experts, + enable_hierarchical=algorithm == EplbAlgorithm.deepseek_hierarchical, + active_ranks=get_global_elastic_ep_metadata().active_ranks, + ) + raise NotImplementedError diff --git a/python/sglang/srt/eplb/eplb_algorithms/elastic_ep.py b/python/sglang/srt/eplb/eplb_algorithms/elastic_ep.py new file mode 100644 index 000000000000..098f920e6d4c --- /dev/null +++ b/python/sglang/srt/eplb/eplb_algorithms/elastic_ep.py @@ -0,0 +1,88 @@ +from typing import Tuple + +import torch + +from sglang.srt.eplb.eplb_algorithms.deepseek import rebalance_experts_hierarchical + + +def rebalance_experts( + weight: torch.Tensor, + num_replicas: int, + num_groups: int, + num_nodes: int, + num_gpus: int, + enable_hierarchical: bool, + active_ranks: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Entry point for expert-parallelism load balancer. + + Parameters: + weight: [layers, num_logical_experts], the load statistics for all logical experts + num_replicas: number of physical experts, must be a multiple of `num_gpus` + num_groups: number of expert groups + num_nodes: number of server nodes, where the intra-node network (e.g, NVLink) is faster + num_gpus: number of GPUs, must be a multiple of `num_nodes` + + Returns: + physical_to_logical_map: [layers, num_replicas], the expert index of each replica + logical_to_physical_map: [layers, num_logical_experts, X], the replica indices for each expert + expert_count: [layers, num_logical_experts], number of physical replicas for each logical expert + """ + + num_layers, num_logical_experts = weight.shape + weight = weight.float().cpu() + num_active_ranks = active_ranks.sum().item() + num_local_experts = num_replicas // num_gpus + if num_active_ranks < num_gpus: + # Must fall back to global load-balance policy + # and fix some params + phy2log, phyrank, logcnt = rebalance_experts_hierarchical( + weight, + num_local_experts * num_active_ranks, + 1, + 1, + num_alive_gpus, + # weight, num_local_experts * num_alive_gpus, 1, 1, num_alive_gpus + ) + elif enable_hierarchical: + # use hierarchical load-balance policy + phy2log, phyrank, logcnt = rebalance_experts_hierarchical( + weight, num_replicas, num_groups, num_nodes, num_gpus + ) + else: + # use global load-balance policy + phy2log, phyrank, logcnt = rebalance_experts_hierarchical( + weight, num_replicas, 1, 1, num_gpus + ) + maxlogcnt = logcnt.max().item() + log2phy: torch.Tensor = torch.full( + (num_layers, num_logical_experts, maxlogcnt), + -1, + dtype=torch.int64, + device=logcnt.device, + ) + log2phy.view(num_layers, -1).scatter_( + -1, + phy2log * maxlogcnt + phyrank, + torch.arange( + num_local_experts * num_active_ranks, + dtype=torch.int64, + device=log2phy.device, + ).expand(num_layers, -1), + ) + if num_active_ranks < num_gpus: + phy2log_slices = list( + phy2log.view(num_layers, num_active_ranks, -1).unbind(dim=1) + ) + active_ranks_list = active_ranks.tolist() + for idx, active_rank in enumerate(active_ranks_list): + if not active_rank: + phy2log_slices.insert(idx, torch.zeros_like(phy2log_slices[0])) + log2phy = torch.where( + log2phy >= idx * num_local_experts, + log2phy + num_local_experts, + log2phy, + ) + phy2log = torch.stack(phy2log_slices, dim=1).contiguous().view(num_layers, -1) + return phy2log, log2phy, logcnt diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 39ee02aaf673..cf8fc084b13b 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -24,7 +24,7 @@ import time from collections import defaultdict from dataclasses import dataclass -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch import torch.distributed as dist @@ -51,6 +51,7 @@ set_symm_mem_all_reduce, ) from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state +from sglang.srt.elastic_ep.elastic_ep import get_global_elastic_ep_metadata from sglang.srt.eplb.eplb_manager import EPLBManager from sglang.srt.eplb.expert_distribution import ( ExpertDistributionRecorder, @@ -927,16 +928,32 @@ def update_expert_location( new_expert_location_metadata: ExpertLocationMetadata, update_layer_ids: List[int], ): - self.expert_location_updater.update( - self.model.routed_experts_weights_of_layer, - new_expert_location_metadata, - update_layer_ids=update_layer_ids, - nnodes=self.server_args.nnodes, - rank=self.tp_rank, - ) + if get_global_elastic_ep_metadata().using_elastic_ep: + old_expert_location_metadata = get_global_expert_location_metadata() + assert old_expert_location_metadata is not None + old_expert_location_metadata.update( + new_expert_location_metadata, + update_layer_ids=update_layer_ids, + ) + self.update_weights_from_disk( + self.server_args.model_path, + self.server_args.load_format, + lambda name: "mlp.experts" in name and "mlp.shared_experts" not in name, + ) + else: + self.expert_location_updater.update( + self.model.routed_experts_weights_of_layer, + new_expert_location_metadata, + update_layer_ids=update_layer_ids, + nnodes=self.server_args.nnodes, + rank=self.tp_rank, + ) def update_weights_from_disk( - self, model_path: str, load_format: str + self, + model_path: str, + load_format: str, + weight_name_filter: Optional[Callable[[str], bool]] = None, ) -> tuple[bool, str]: """Update engine weights in-place from the disk.""" logger.info( @@ -958,6 +975,11 @@ def get_weight_iter(config): iter = loader._get_weights_iterator( DefaultModelLoader.Source.init_new(config, self.model) ) + if weight_name_filter is not None: + iter = ( + (name, weight) for name, weight in iter if weight_name_filter(name) + ) + return iter def model_load_weights(model, iter): @@ -2131,6 +2153,7 @@ def forward( self.forward_pass_id, forward_batch, ): + output = self._forward_raw( forward_batch, skip_attn_backend_init, diff --git a/test/srt/ep/test_elastic_ep_eplb.py b/test/srt/ep/test_elastic_ep_eplb.py new file mode 100755 index 000000000000..be4ab3e06e9a --- /dev/null +++ b/test/srt/ep/test_elastic_ep_eplb.py @@ -0,0 +1,81 @@ +import os +import unittest +from types import SimpleNamespace + +from sglang.srt.utils import kill_process_tree +from sglang.test.run_eval import run_eval +from sglang.test.test_utils import ( + DEFAULT_MLA_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + CustomTestCase, + popen_launch_server, +) + + +class _BaseTestDynamicEPLB(CustomTestCase): + extra_args = [] + + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--trust-remote-code", + "--tp", + "2", + "--dp", + "2", + "--enable-dp-attention", + "--moe-a2a-backend", + "mooncake", + "--enable-eplb", + "--ep-num-redundant-experts", + "4", + "--eplb-rebalance-num-iterations", + "50", + "--expert-distribution-recorder-buffer-size", + "50", + # TODO pr-chain: enable later + # "--enable-expert-distribution-metrics", + # TODO auto determine these flags + "--expert-distribution-recorder-mode", + "stat", + "--ep-dispatch-algorithm", + "static", + *cls.extra_args, + ], + env={ + "SGL_ENABLE_JIT_DEEPGEMM": "0", + "SGLANG_EXPERT_LOCATION_UPDATER_CANARY": "1", + **os.environ, + }, + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_mmlu(self): + args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="mmlu", + num_examples=64, + num_threads=32, + ) + + metrics = run_eval(args) + self.assertGreater(metrics["score"], 0.5) + + +class TestDynamicEPLBSimple(_BaseTestDynamicEPLB): + pass + + +if __name__ == "__main__": + unittest.main() From f01ba58fd1ce2235a1324cfdf0febef7fb896995 Mon Sep 17 00:00:00 2001 From: Xun Sun Date: Wed, 17 Sep 2025 22:26:37 +0800 Subject: [PATCH 02/44] Let `token_dispatcher/mooncake.py` use the `global_elastic_ep_metadata` (#13) --- python/sglang/srt/elastic_ep/elastic_ep.py | 2 +- .../srt/eplb/eplb_algorithms/elastic_ep.py | 4 ++-- .../layers/moe/token_dispatcher/mooncake.py | 18 ++++-------------- 3 files changed, 7 insertions(+), 17 deletions(-) diff --git a/python/sglang/srt/elastic_ep/elastic_ep.py b/python/sglang/srt/elastic_ep/elastic_ep.py index 79795060ad6b..707549bb1649 100644 --- a/python/sglang/srt/elastic_ep/elastic_ep.py +++ b/python/sglang/srt/elastic_ep/elastic_ep.py @@ -33,7 +33,7 @@ def _init_global_elastic_ep_metadata(): return ep_size = torch.distributed.get_world_size() - active_ranks = torch.ones(ep_size, dtype=torch.int32) + active_ranks = torch.ones(ep_size, dtype=torch.int32, device="cuda") last_active_ranks = active_ranks.clone() _global_elastic_ep_metadata = ElasticEpMetadata( diff --git a/python/sglang/srt/eplb/eplb_algorithms/elastic_ep.py b/python/sglang/srt/eplb/eplb_algorithms/elastic_ep.py index 098f920e6d4c..eb3c3a7126cd 100644 --- a/python/sglang/srt/eplb/eplb_algorithms/elastic_ep.py +++ b/python/sglang/srt/eplb/eplb_algorithms/elastic_ep.py @@ -42,8 +42,8 @@ def rebalance_experts( num_local_experts * num_active_ranks, 1, 1, - num_alive_gpus, - # weight, num_local_experts * num_alive_gpus, 1, 1, num_alive_gpus + num_active_ranks, + # weight, num_local_experts * num_active_ranks, 1, 1, num_active_ranks ) elif enable_hierarchical: # use hierarchical load-balance policy diff --git a/python/sglang/srt/layers/moe/token_dispatcher/mooncake.py b/python/sglang/srt/layers/moe/token_dispatcher/mooncake.py index d6d56186563a..0fd2e7f7d009 100644 --- a/python/sglang/srt/layers/moe/token_dispatcher/mooncake.py +++ b/python/sglang/srt/layers/moe/token_dispatcher/mooncake.py @@ -4,6 +4,7 @@ from dataclasses import dataclass from typing import NamedTuple, Optional, Tuple +from sglang.srt.elastic_ep.elastic_ep import _init_global_elastic_ep_metadata, get_global_elastic_ep_metadata from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder from sglang.srt.layers.moe.token_dispatcher.base import ( BaseDispatcher, @@ -62,14 +63,6 @@ def format(self) -> CombineInputFormat: assert isinstance(MooncakeCombineInput, CombineInput) -_ACTIVE_RANKS: Optional[torch.Tensor] = None - - -def get_ep_active_ranks() -> torch.Tensor: - assert _ACTIVE_RANKS is not None, "_ACTIVE_RANKS is not initialized" - return _ACTIVE_RANKS - - class EPBuffer: _buffer = None _hidden_size: Optional[int] = None @@ -152,12 +145,9 @@ def __init__( self.first_execution = True self.timeout_us = 10000000 - global _ACTIVE_RANKS - if _ACTIVE_RANKS is None: - _ACTIVE_RANKS = torch.ones( - (self.num_experts,), dtype=torch.int32, device="cuda" - ) - self.active_ranks = _ACTIVE_RANKS + _init_global_elastic_ep_metadata() + global_elastic_ep_metadata = get_global_elastic_ep_metadata() + self.active_ranks = global_elastic_ep_metadata.active_ranks self.handle = None From cd65b6930879ea8d9d3851d4ca09ba9249a20bbd Mon Sep 17 00:00:00 2001 From: Hank Han Date: Wed, 17 Sep 2025 23:57:07 +0800 Subject: [PATCH 03/44] fix fi fi fix fix fix fix fix fix fix fix fix fit fix --- python/sglang/srt/elastic_ep/elastic_ep.py | 73 ++++++++++++----- .../srt/eplb/eplb_algorithms/__init__.py | 12 +-- .../{elastic_ep.py => elasticity_aware.py} | 0 .../layers/moe/token_dispatcher/mooncake.py | 6 +- .../sglang/srt/model_executor/model_runner.py | 5 +- test/srt/ep/test_elastic_ep_eplb.py | 81 ------------------- 6 files changed, 64 insertions(+), 113 deletions(-) rename python/sglang/srt/eplb/eplb_algorithms/{elastic_ep.py => elasticity_aware.py} (100%) delete mode 100755 test/srt/ep/test_elastic_ep_eplb.py diff --git a/python/sglang/srt/elastic_ep/elastic_ep.py b/python/sglang/srt/elastic_ep/elastic_ep.py index 707549bb1649..ba9c1373d493 100644 --- a/python/sglang/srt/elastic_ep/elastic_ep.py +++ b/python/sglang/srt/elastic_ep/elastic_ep.py @@ -1,43 +1,76 @@ from __future__ import annotations -from curses import use_env from dataclasses import dataclass +from threading import Lock from typing import Optional import torch +from sglang.srt.utils import is_cpu, is_cuda + @dataclass -class ElasticEpMetadata: +class ElasticEPState: using_elastic_ep: bool active_ranks: Optional[torch.Tensor] last_active_ranks: Optional[torch.Tensor] + active_ranks_cpu: Optional[torch.Tensor] + + def is_active_equal_last(self) -> bool: + if self.active_ranks is None or self.last_active_ranks is None: + return False + return torch.equal(self.active_ranks, self.last_active_ranks) + + def sync_active_to_cpu(self): + if self.active_ranks is not None: + self.active_ranks_cpu = self.active_ranks.detach().cpu().clone() + + def snapshot_active_to_last(self): + if self.active_ranks is not None: + self.last_active_ranks = self.active_ranks.clone() + +__elastic_ep_state: Optional[ElasticEPState] = None +__state_lock = Lock() -_global_elastic_ep_metadata: Optional[ElasticEpMetadata] = None +def get_elastic_ep_state(): + global __elastic_ep_state + if __elastic_ep_state is None: + with __state_lock: + if __elastic_ep_state is None: + __elastic_ep_state = _build_default_state() + return __elastic_ep_state -def get_global_elastic_ep_metadata(): - return _global_elastic_ep_metadata +def _build_default_state() -> ElasticEPState: + return _build_state(ep_size=None, device=None, using_elastic_ep=True) -def set_global_elastic_ep_metadata(value): - global _global_elastic_ep_metadata - assert _global_elastic_ep_metadata is None - _global_elastic_ep_metadata = value +def _select_device() -> torch.device: + # cuda or cpu for now + if is_cuda(): + return torch.device("cuda") + elif is_cpu(): + return torch.device("cpu") + else: + raise NotImplementedError("Only CUDA and CPU are supported now.") -def _init_global_elastic_ep_metadata(): - global _global_elastic_ep_metadata - if _global_elastic_ep_metadata is not None: - return - ep_size = torch.distributed.get_world_size() - active_ranks = torch.ones(ep_size, dtype=torch.int32, device="cuda") - last_active_ranks = active_ranks.clone() +def _build_state( + *, + ep_size: Optional[int], + device: Optional[torch.device], + using_elastic_ep: bool, +) -> ElasticEPState: + ep = ep_size if ep_size is not None else torch.distributed.get_world_size() + dev = device if device is not None else _select_device() - _global_elastic_ep_metadata = ElasticEpMetadata( - using_elastic_ep=False, # TODO pr elastic_ep to add args decide whether use elastic ep - active_ranks=active_ranks, - last_active_ranks=last_active_ranks, + active = torch.ones(ep, dtype=torch.int32, device=dev) + state = ElasticEPState( + using_elastic_ep=using_elastic_ep, + active_ranks=active, + last_active_ranks=active.clone(), + active_ranks_cpu=active.detach().cpu().clone(), ) + return state diff --git a/python/sglang/srt/eplb/eplb_algorithms/__init__.py b/python/sglang/srt/eplb/eplb_algorithms/__init__.py index 85ffb3cf7be9..7109726b147d 100644 --- a/python/sglang/srt/eplb/eplb_algorithms/__init__.py +++ b/python/sglang/srt/eplb/eplb_algorithms/__init__.py @@ -3,8 +3,8 @@ import torch -from sglang.srt.elastic_ep.elastic_ep import get_global_elastic_ep_metadata -from sglang.srt.eplb.eplb_algorithms import deepseek, deepseek_vec, elastic_ep +from sglang.srt.elastic_ep.elastic_ep import get_elastic_ep_state +from sglang.srt.eplb.eplb_algorithms import deepseek, deepseek_vec, elasticity_aware class EplbAlgorithm(Enum): @@ -12,7 +12,7 @@ class EplbAlgorithm(Enum): deepseek_hierarchical = auto() deepseek_vec = auto() deepseek_vec_hierarchical = auto() - elastic_ep = auto() + elasticity_aware = auto() # TODO may have more algorithm later @@ -47,15 +47,15 @@ def rebalance_experts( enable_hierarchical=algorithm == EplbAlgorithm.deepseek_vec_hierarchical, ) - if algorithm == EplbAlgorithm.elastic_ep: - return elastic_ep.rebalance_experts( + if algorithm == EplbAlgorithm.elasticity_aware: + return elasticity_aware.rebalance_experts( weight=tokens_per_expert.sum(dim=0), num_replicas=num_physical_experts, num_groups=num_groups, num_nodes=num_nodes, num_gpus=num_physical_experts // num_local_physical_experts, enable_hierarchical=algorithm == EplbAlgorithm.deepseek_hierarchical, - active_ranks=get_global_elastic_ep_metadata().active_ranks, + active_ranks=get_elastic_ep_state().active_ranks, ) raise NotImplementedError diff --git a/python/sglang/srt/eplb/eplb_algorithms/elastic_ep.py b/python/sglang/srt/eplb/eplb_algorithms/elasticity_aware.py similarity index 100% rename from python/sglang/srt/eplb/eplb_algorithms/elastic_ep.py rename to python/sglang/srt/eplb/eplb_algorithms/elasticity_aware.py diff --git a/python/sglang/srt/layers/moe/token_dispatcher/mooncake.py b/python/sglang/srt/layers/moe/token_dispatcher/mooncake.py index 0fd2e7f7d009..76be31a626f2 100644 --- a/python/sglang/srt/layers/moe/token_dispatcher/mooncake.py +++ b/python/sglang/srt/layers/moe/token_dispatcher/mooncake.py @@ -4,7 +4,7 @@ from dataclasses import dataclass from typing import NamedTuple, Optional, Tuple -from sglang.srt.elastic_ep.elastic_ep import _init_global_elastic_ep_metadata, get_global_elastic_ep_metadata +from sglang.srt.elastic_ep.elastic_ep import get_elastic_ep_state from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder from sglang.srt.layers.moe.token_dispatcher.base import ( BaseDispatcher, @@ -145,9 +145,7 @@ def __init__( self.first_execution = True self.timeout_us = 10000000 - _init_global_elastic_ep_metadata() - global_elastic_ep_metadata = get_global_elastic_ep_metadata() - self.active_ranks = global_elastic_ep_metadata.active_ranks + self.active_ranks = get_elastic_ep_state().active_ranks self.handle = None diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index cf8fc084b13b..0d9b4c055a33 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -51,7 +51,7 @@ set_symm_mem_all_reduce, ) from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state -from sglang.srt.elastic_ep.elastic_ep import get_global_elastic_ep_metadata +from sglang.srt.elastic_ep.elastic_ep import get_elastic_ep_state from sglang.srt.eplb.eplb_manager import EPLBManager from sglang.srt.eplb.expert_distribution import ( ExpertDistributionRecorder, @@ -928,7 +928,8 @@ def update_expert_location( new_expert_location_metadata: ExpertLocationMetadata, update_layer_ids: List[int], ): - if get_global_elastic_ep_metadata().using_elastic_ep: + if get_elastic_ep_state().using_elastic_ep: + # TODO: refactor the weights update when elastic ep old_expert_location_metadata = get_global_expert_location_metadata() assert old_expert_location_metadata is not None old_expert_location_metadata.update( diff --git a/test/srt/ep/test_elastic_ep_eplb.py b/test/srt/ep/test_elastic_ep_eplb.py deleted file mode 100755 index be4ab3e06e9a..000000000000 --- a/test/srt/ep/test_elastic_ep_eplb.py +++ /dev/null @@ -1,81 +0,0 @@ -import os -import unittest -from types import SimpleNamespace - -from sglang.srt.utils import kill_process_tree -from sglang.test.run_eval import run_eval -from sglang.test.test_utils import ( - DEFAULT_MLA_MODEL_NAME_FOR_TEST, - DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - DEFAULT_URL_FOR_TEST, - CustomTestCase, - popen_launch_server, -) - - -class _BaseTestDynamicEPLB(CustomTestCase): - extra_args = [] - - @classmethod - def setUpClass(cls): - cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST - cls.base_url = DEFAULT_URL_FOR_TEST - cls.process = popen_launch_server( - cls.model, - cls.base_url, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=[ - "--trust-remote-code", - "--tp", - "2", - "--dp", - "2", - "--enable-dp-attention", - "--moe-a2a-backend", - "mooncake", - "--enable-eplb", - "--ep-num-redundant-experts", - "4", - "--eplb-rebalance-num-iterations", - "50", - "--expert-distribution-recorder-buffer-size", - "50", - # TODO pr-chain: enable later - # "--enable-expert-distribution-metrics", - # TODO auto determine these flags - "--expert-distribution-recorder-mode", - "stat", - "--ep-dispatch-algorithm", - "static", - *cls.extra_args, - ], - env={ - "SGL_ENABLE_JIT_DEEPGEMM": "0", - "SGLANG_EXPERT_LOCATION_UPDATER_CANARY": "1", - **os.environ, - }, - ) - - @classmethod - def tearDownClass(cls): - kill_process_tree(cls.process.pid) - - def test_mmlu(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mmlu", - num_examples=64, - num_threads=32, - ) - - metrics = run_eval(args) - self.assertGreater(metrics["score"], 0.5) - - -class TestDynamicEPLBSimple(_BaseTestDynamicEPLB): - pass - - -if __name__ == "__main__": - unittest.main() From 3e38276238d367dca4165bc6ddee426228cf31df Mon Sep 17 00:00:00 2001 From: Hank Han Date: Thu, 2 Oct 2025 12:12:28 +0800 Subject: [PATCH 04/44] some fix --- python/sglang/srt/elastic_ep/elastic_ep.py | 24 +++++++------------ .../sglang/srt/model_executor/model_runner.py | 4 ++-- python/sglang/srt/server_args.py | 12 ++++++++++ 3 files changed, 23 insertions(+), 17 deletions(-) diff --git a/python/sglang/srt/elastic_ep/elastic_ep.py b/python/sglang/srt/elastic_ep/elastic_ep.py index ba9c1373d493..ccd7a5e6cb78 100644 --- a/python/sglang/srt/elastic_ep/elastic_ep.py +++ b/python/sglang/srt/elastic_ep/elastic_ep.py @@ -1,17 +1,16 @@ from __future__ import annotations from dataclasses import dataclass -from threading import Lock from typing import Optional import torch +from sglang.srt.managers.schedule_batch import ServerArgs from sglang.srt.utils import is_cpu, is_cuda @dataclass class ElasticEPState: - using_elastic_ep: bool active_ranks: Optional[torch.Tensor] last_active_ranks: Optional[torch.Tensor] active_ranks_cpu: Optional[torch.Tensor] @@ -30,21 +29,18 @@ def snapshot_active_to_last(self): self.last_active_ranks = self.active_ranks.clone() -__elastic_ep_state: Optional[ElasticEPState] = None -__state_lock = Lock() +_elastic_ep_state: Optional[ElasticEPState] = None def get_elastic_ep_state(): - global __elastic_ep_state - if __elastic_ep_state is None: - with __state_lock: - if __elastic_ep_state is None: - __elastic_ep_state = _build_default_state() - return __elastic_ep_state + return _elastic_ep_state -def _build_default_state() -> ElasticEPState: - return _build_state(ep_size=None, device=None, using_elastic_ep=True) +def init_elastic_ep_state(server_args: ServerArgs): + global _elastic_ep_state + assert _elastic_ep_state is None + if server_args.elastic_ep_backend is not None: + return _build_state(ep_size=None, device=None) def _select_device() -> torch.device: @@ -54,21 +50,19 @@ def _select_device() -> torch.device: elif is_cpu(): return torch.device("cpu") else: - raise NotImplementedError("Only CUDA and CPU are supported now.") + raise NotImplementedError("Only CUDA and CPU support elastic ep now.") def _build_state( *, ep_size: Optional[int], device: Optional[torch.device], - using_elastic_ep: bool, ) -> ElasticEPState: ep = ep_size if ep_size is not None else torch.distributed.get_world_size() dev = device if device is not None else _select_device() active = torch.ones(ep, dtype=torch.int32, device=dev) state = ElasticEPState( - using_elastic_ep=using_elastic_ep, active_ranks=active, last_active_ranks=active.clone(), active_ranks_cpu=active.detach().cpu().clone(), diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 0d9b4c055a33..aac7ef398f24 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -51,7 +51,7 @@ set_symm_mem_all_reduce, ) from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state -from sglang.srt.elastic_ep.elastic_ep import get_elastic_ep_state +from sglang.srt.elastic_ep.elastic_ep import get_elastic_ep_state, init_elastic_ep_state from sglang.srt.eplb.eplb_manager import EPLBManager from sglang.srt.eplb.expert_distribution import ( ExpertDistributionRecorder, @@ -384,7 +384,7 @@ def initialize(self, min_per_gpu_memory: float): else None ) self.expert_location_updater = ExpertLocationUpdater() - + init_elastic_ep_state(self.server_args) # Load the model self.sampler = Sampler() self.load_model() diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 8d179b2c7391..6be7884f8204 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -577,6 +577,9 @@ def __post_init__(self): # Handle any other necessary validations. self._handle_other_validations() + # Handle elastic expert parallelism. + self._handle_elastic_ep() + def _handle_deprecated_args(self): # handle deprecated tool call parsers deprecated_tool_call_parsers = {"qwen25": "qwen", "glm45": "glm"} @@ -1125,6 +1128,15 @@ def _handle_eplb_and_dispatch(self): if self.enable_eplb: assert self.ep_size > 1 + + def _handle_elastic_ep(self): + if self.elastic_ep_backend is not None: + assert self.enable_eplb, "Elastic EP requires EPLB to be enabled." + assert self.ep_dispatch_algorithm == "dynamic", "Elastic EP requires EPLB dynamic dispatch." + if self.eplb_algorithm == "auto" : + self.eplb_algorithm = "elasticity_aware" + assert self.eplb_algorithm != "elasticity_aware", "Elastic EP requires eplb_algorithm to be set to 'auto' or 'elasticity_aware'." + def _handle_expert_distribution_metrics(self): if self.enable_expert_distribution_metrics and ( self.expert_distribution_recorder_mode is None From 484058c3af0101babc148890ac15278625ba73a2 Mon Sep 17 00:00:00 2001 From: Hank Han Date: Thu, 2 Oct 2025 16:13:02 +0800 Subject: [PATCH 05/44] fxi --- python/sglang/srt/eplb/eplb_algorithms/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/eplb/eplb_algorithms/__init__.py b/python/sglang/srt/eplb/eplb_algorithms/__init__.py index 7109726b147d..af467f380c5c 100644 --- a/python/sglang/srt/eplb/eplb_algorithms/__init__.py +++ b/python/sglang/srt/eplb/eplb_algorithms/__init__.py @@ -54,7 +54,7 @@ def rebalance_experts( num_groups=num_groups, num_nodes=num_nodes, num_gpus=num_physical_experts // num_local_physical_experts, - enable_hierarchical=algorithm == EplbAlgorithm.deepseek_hierarchical, + enable_hierarchical=True, active_ranks=get_elastic_ep_state().active_ranks, ) From cb41b43154315b7393652cfbb8802a48befd7ebe Mon Sep 17 00:00:00 2001 From: Hank Han Date: Wed, 15 Oct 2025 11:12:21 +0800 Subject: [PATCH 06/44] lint --- python/sglang/srt/server_args.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 6be7884f8204..474445335aa9 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -1128,14 +1128,17 @@ def _handle_eplb_and_dispatch(self): if self.enable_eplb: assert self.ep_size > 1 - def _handle_elastic_ep(self): if self.elastic_ep_backend is not None: assert self.enable_eplb, "Elastic EP requires EPLB to be enabled." - assert self.ep_dispatch_algorithm == "dynamic", "Elastic EP requires EPLB dynamic dispatch." - if self.eplb_algorithm == "auto" : + assert ( + self.ep_dispatch_algorithm == "dynamic" + ), "Elastic EP requires EPLB dynamic dispatch." + if self.eplb_algorithm == "auto": self.eplb_algorithm = "elasticity_aware" - assert self.eplb_algorithm != "elasticity_aware", "Elastic EP requires eplb_algorithm to be set to 'auto' or 'elasticity_aware'." + assert ( + self.eplb_algorithm != "elasticity_aware" + ), "Elastic EP requires eplb_algorithm to be set to 'auto' or 'elasticity_aware'." def _handle_expert_distribution_metrics(self): if self.enable_expert_distribution_metrics and ( From 560c595d460c4fb9d60f0f64def08f3d6de85453 Mon Sep 17 00:00:00 2001 From: Hank Han Date: Wed, 15 Oct 2025 11:38:31 +0800 Subject: [PATCH 07/44] fix --- python/sglang/srt/model_executor/model_runner.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index aac7ef398f24..0dabd80fe927 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -928,7 +928,7 @@ def update_expert_location( new_expert_location_metadata: ExpertLocationMetadata, update_layer_ids: List[int], ): - if get_elastic_ep_state().using_elastic_ep: + if get_elastic_ep_state() is not None: # TODO: refactor the weights update when elastic ep old_expert_location_metadata = get_global_expert_location_metadata() assert old_expert_location_metadata is not None @@ -2154,7 +2154,6 @@ def forward( self.forward_pass_id, forward_batch, ): - output = self._forward_raw( forward_batch, skip_attn_backend_init, From e6875fcaef69952e9ee97b4eb42e1f552e54056a Mon Sep 17 00:00:00 2001 From: Hank Han Date: Wed, 15 Oct 2025 22:01:37 +0800 Subject: [PATCH 08/44] test --- python/sglang/srt/elastic_ep/elastic_ep.py | 2 -- python/sglang/srt/server_args.py | 15 ++++++--------- 2 files changed, 6 insertions(+), 11 deletions(-) diff --git a/python/sglang/srt/elastic_ep/elastic_ep.py b/python/sglang/srt/elastic_ep/elastic_ep.py index ccd7a5e6cb78..f22f76057dc5 100644 --- a/python/sglang/srt/elastic_ep/elastic_ep.py +++ b/python/sglang/srt/elastic_ep/elastic_ep.py @@ -16,8 +16,6 @@ class ElasticEPState: active_ranks_cpu: Optional[torch.Tensor] def is_active_equal_last(self) -> bool: - if self.active_ranks is None or self.last_active_ranks is None: - return False return torch.equal(self.active_ranks, self.last_active_ranks) def sync_active_to_cpu(self): diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 474445335aa9..04220876f447 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -1130,15 +1130,12 @@ def _handle_eplb_and_dispatch(self): def _handle_elastic_ep(self): if self.elastic_ep_backend is not None: - assert self.enable_eplb, "Elastic EP requires EPLB to be enabled." - assert ( - self.ep_dispatch_algorithm == "dynamic" - ), "Elastic EP requires EPLB dynamic dispatch." - if self.eplb_algorithm == "auto": - self.eplb_algorithm = "elasticity_aware" - assert ( - self.eplb_algorithm != "elasticity_aware" - ), "Elastic EP requires eplb_algorithm to be set to 'auto' or 'elasticity_aware'." + if self.enable_eplb: + if self.eplb_algorithm == "auto": + self.eplb_algorithm = "elasticity_aware" + assert ( + self.eplb_algorithm == "elasticity_aware" + ), "Elastic EP requires eplb_algorithm to be set to 'auto' or 'elasticity_aware'." def _handle_expert_distribution_metrics(self): if self.enable_expert_distribution_metrics and ( From 63d7b0e719ac1d3e3509050f40c2168cb112603d Mon Sep 17 00:00:00 2001 From: Hank Han Date: Wed, 15 Oct 2025 23:00:39 +0800 Subject: [PATCH 09/44] add ut --- .../eplb/eplb_algorithms/elasticity_aware.py | 1 - .../sglang/test/test_disaggregation_utils.py | 49 ++- test/srt/ep/test_mooncake_ep_small.py | 295 +++++------------- 3 files changed, 114 insertions(+), 231 deletions(-) diff --git a/python/sglang/srt/eplb/eplb_algorithms/elasticity_aware.py b/python/sglang/srt/eplb/eplb_algorithms/elasticity_aware.py index eb3c3a7126cd..c781c444ae3b 100644 --- a/python/sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +++ b/python/sglang/srt/eplb/eplb_algorithms/elasticity_aware.py @@ -43,7 +43,6 @@ def rebalance_experts( 1, 1, num_active_ranks, - # weight, num_local_experts * num_active_ranks, 1, 1, num_active_ranks ) elif enable_hierarchical: # use hierarchical load-balance policy diff --git a/python/sglang/test/test_disaggregation_utils.py b/python/sglang/test/test_disaggregation_utils.py index e8084f802d1b..72960c9af905 100644 --- a/python/sglang/test/test_disaggregation_utils.py +++ b/python/sglang/test/test_disaggregation_utils.py @@ -1,3 +1,4 @@ +import logging import os import time import warnings @@ -15,6 +16,8 @@ popen_with_error_check, ) +logger = logging.getLogger(__name__) + class TestDisaggregationBase(CustomTestCase): @classmethod @@ -100,11 +103,27 @@ def tearDownClass(cls): def get_rdma_devices_args(): + def _parse_list_env(var_name: str): + val = os.getenv(var_name) + if not val: + return None + items = [x.strip() for x in val.split(",") if x.strip()] + return items or None + + def _pick_default_pair(rdma_all_devices): + return [rdma_all_devices[0], rdma_all_devices[len(rdma_all_devices) // 2]] + + rdma_all_devices = _parse_list_env("SGLANG_CI_RDMA_ALL_DEVICES") or [ + f"mlx5_roce{i}" for i in range(8) + ] + logger.info("Resolved rdma_all_devices=%s", rdma_all_devices) + + n_rdma = len(rdma_all_devices) # 1. Get visible GPU indices cuda_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES") if not cuda_visible_devices: warnings.warn("CUDA_VISIBLE_DEVICES is not set. Using default RDMA devices.") - return "mlx5_roce0,mlx5_roce4" + return ",".join(_pick_default_pair(rdma_all_devices)) try: # Convert to list of integers (handling possible spaces and empty strings) @@ -112,29 +131,29 @@ def get_rdma_devices_args(): int(idx.strip()) for idx in cuda_visible_devices.split(",") if idx.strip() ] if not gpu_indices or len(gpu_indices) > 4: - return "mlx5_roce0,mlx5_roce4" + return ",".join(_pick_default_pair(rdma_all_devices)) + except ValueError: warnings.warn(f"Invalid CUDA_VISIBLE_DEVICES format: {cuda_visible_devices}") - return "mlx5_roce0,mlx5_roce4" + return ",".join(_pick_default_pair(rdma_all_devices)) # 2. Calculate base RDMA index group (each group of 4 GPUs uses consecutive devices) - base_rdma_group = min(gpu_indices) // 4 * 4 + base_rdma_group = (min(gpu_indices) // 4) * 4 - # 3. Generate RDMA device names - rdma_devices = [] for gpu_idx in gpu_indices: - # Validate GPU index within expected range - if gpu_idx < base_rdma_group or gpu_idx >= base_rdma_group + 4: + if not (base_rdma_group <= gpu_idx < base_rdma_group + 4): + warnings.warn( - f"GPU index {gpu_idx} is outside expected group {base_rdma_group}-{base_rdma_group+3}" + f"GPU index {gpu_idx} is outside expected group " + f"{base_rdma_group}-{base_rdma_group+3}" ) - continue - - # Map GPU index to RDMA device index - rdma_index = base_rdma_group // 4 * 4 + (gpu_idx % 4) - rdma_devices.append(f"mlx5_roce{rdma_index}") + # 3. Generate RDMA device names + rdma_devices = [] + for gpu_idx in gpu_indices: + nic_index = gpu_idx // (8 // n_rdma) + rdma_devices.append(rdma_all_devices[nic_index]) if not rdma_devices: - return "mlx5_roce0,mlx5_roce4" + return ",".join(_pick_default_pair(rdma_all_devices)) return ",".join(rdma_devices) diff --git a/test/srt/ep/test_mooncake_ep_small.py b/test/srt/ep/test_mooncake_ep_small.py index 111260a8c82d..1fb560b97fa4 100644 --- a/test/srt/ep/test_mooncake_ep_small.py +++ b/test/srt/ep/test_mooncake_ep_small.py @@ -3,6 +3,7 @@ from sglang.srt.utils import kill_process_tree from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k +from sglang.test.test_disaggregation_utils import get_rdma_devices_args from sglang.test.test_utils import ( DEFAULT_MODEL_NAME_FOR_TEST_MLA, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, @@ -11,8 +12,10 @@ popen_launch_server, ) +ib_devices = get_rdma_devices_args() -class TestPureDP(CustomTestCase): + +class _BaseTestMooncakeEp(CustomTestCase): @classmethod def setUpClass(cls): cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA @@ -31,7 +34,7 @@ def setUpClass(cls): "--elastic-ep-backend", "mooncake", "--mooncake-ib-device", - "mlx5_roce0,mlx5_roce1,mlx5_roce2,mlx5_roce3,mlx5_roce4,mlx5_roce5,mlx5_roce6,mlx5_roce7", + ib_devices, "--moe-a2a-backend", "deepep", "--deepep-mode", @@ -44,6 +47,7 @@ def setUpClass(cls): "512", "--mem-fraction-static", "0.5", + *cls.extra_args, ], ) @@ -67,219 +71,80 @@ def test_gsm8k(self): self.assertGreater(metrics["accuracy"], 0.60) -class TestHybridDPTP(CustomTestCase): - @classmethod - def setUpClass(cls): - cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA - cls.base_url = DEFAULT_URL_FOR_TEST - cls.process = popen_launch_server( - cls.model, - cls.base_url, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=[ - "--trust-remote-code", - "--tp", - "4", - "--enable-dp-attention", - "--dp", - "2", - "--elastic-ep-backend", - "mooncake", - "--mooncake-ib-device", - "mlx5_roce0,mlx5_roce1,mlx5_roce2,mlx5_roce3,mlx5_roce4,mlx5_roce5,mlx5_roce6,mlx5_roce7", - "--moe-a2a-backend", - "deepep", - "--deepep-mode", - "low_latency", - "--chunked-prefill-size", - "512", - "--cuda-graph-max-bs", - "128", - "--max-running-requests", - "256", - ], - ) - - @classmethod - def tearDownClass(cls): - kill_process_tree(cls.process.pid) - - def test_gsm8k(self): - args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), - ) - metrics = run_eval_few_shot_gsm8k(args) - print(metrics) - - self.assertGreater(metrics["accuracy"], 0.60) - - -class TestTP(CustomTestCase): - @classmethod - def setUpClass(cls): - cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA - cls.base_url = DEFAULT_URL_FOR_TEST - cls.process = popen_launch_server( - cls.model, - cls.base_url, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=[ - "--trust-remote-code", - "--tp", - "4", - "--elastic-ep-backend", - "mooncake", - "--mooncake-ib-device", - "mlx5_roce0,mlx5_roce1,mlx5_roce2,mlx5_roce3,mlx5_roce4,mlx5_roce5,mlx5_roce6,mlx5_roce7", - "--moe-a2a-backend", - "deepep", - "--deepep-mode", - "low_latency", - "--chunked-prefill-size", - "512", - "--cuda-graph-max-bs", - "128", - "--max-running-requests", - "128", - ], - ) - - @classmethod - def tearDownClass(cls): - kill_process_tree(cls.process.pid) - - def test_gsm8k(self): - args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), - ) - metrics = run_eval_few_shot_gsm8k(args) - print(metrics) - - self.assertGreater(metrics["accuracy"], 0.60) - - -class TestNoGatherdBuffer(CustomTestCase): - @classmethod - def setUpClass(cls): - cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA - cls.base_url = DEFAULT_URL_FOR_TEST - cls.process = popen_launch_server( - cls.model, - cls.base_url, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=[ - "--trust-remote-code", - "--tp", - "4", - "--enable-dp-attention", - "--dp", - "4", - "--moe-dense-tp-size", - "1", - "--enable-dp-lm-head", - "--elastic-ep-backend", - "mooncake", - "--mooncake-ib-device", - "mlx5_roce0,mlx5_roce1,mlx5_roce2,mlx5_roce3,mlx5_roce4,mlx5_roce5,mlx5_roce6,mlx5_roce7", - "--moe-a2a-backend", - "deepep", - "--deepep-mode", - "low_latency", - "--chunked-prefill-size", - "512", - "--cuda-graph-max-bs", - "32", - "--max-running-requests", - "512", - ], - ) - - @classmethod - def tearDownClass(cls): - kill_process_tree(cls.process.pid) - - def test_gsm8k(self): - args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), - ) - metrics = run_eval_few_shot_gsm8k(args) - print(metrics) - - self.assertGreater(metrics["accuracy"], 0.60) - - -class TestTBO(CustomTestCase): - @classmethod - def setUpClass(cls): - cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA - cls.base_url = DEFAULT_URL_FOR_TEST - cls.process = popen_launch_server( - cls.model, - cls.base_url, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=[ - "--trust-remote-code", - "--tp", - "4", - "--enable-dp-attention", - "--dp", - "4", - "--moe-dense-tp-size", - "1", - "--elastic-ep-backend", - "mooncake", - "--mooncake-ib-device", - "mlx5_roce0,mlx5_roce1,mlx5_roce2,mlx5_roce3,mlx5_roce4,mlx5_roce5,mlx5_roce6,mlx5_roce7", - "--moe-a2a-backend", - "deepep", - "--deepep-mode", - "low_latency", - "--chunked-prefill-size", - "512", - "--enable-two-batch-overlap", - "--cuda-graph-max-bs", - "128", - "--max-running-requests", - "512", - ], - ) - - @classmethod - def tearDownClass(cls): - kill_process_tree(cls.process.pid) - - def test_gsm8k(self): - args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), - ) - metrics = run_eval_few_shot_gsm8k(args) - print(metrics) - - self.assertGreater(metrics["accuracy"], 0.60) +class TestPureDP(_BaseTestMooncakeEp): + extra_args = [ + "--tp", + "4", + "--enable-dp-attention", + "--dp", + "4", + ] + + +class TestHybridDPTP(_BaseTestMooncakeEp): + extra_args = [ + "--tp", + "4", + "--enable-dp-attention", + "--dp", + "2", + ] + + +class TestTP(_BaseTestMooncakeEp): + extra_args = [ + "--tp", + "4", + ] + + +class TestNoGatherdBuffer(_BaseTestMooncakeEp): + extra_args = [ + "--tp", + "4", + "--enable-dp-attention", + "--dp", + "4", + "--moe-dense-tp-size", + "1", + ] + + +class TestTBO(_BaseTestMooncakeEp): + extra_args = [ + "--tp", + "4", + "--enable-dp-attention", + "--dp", + "4", + "--moe-dense-tp-size", + "1", + "--enable-two-batch-overlap", + ] + + +class TestMooncakeWitchEPLB(_BaseTestMooncakeEp): + extra_args = [ + "--tp", + "4", + "--enable-dp-attention", + "--dp", + "4", + "--moe-dense-tp-size", + "1", + "--enable-two-batch-overlap", + "--enable-eplb", + "--ep-num-redundant-experts", + "4", + "--eplb-rebalance-num-iterations", + "50", + "--expert-distribution-recorder-buffer-size", + "50", + "--expert-distribution-recorder-mode", + "stat", + "--ep-dispatch-algorithm", + "static", + ] if __name__ == "__main__": From 8d8aca9d624b09f8c51378ed1d8d90e9d9c8e6b1 Mon Sep 17 00:00:00 2001 From: Hank Han Date: Thu, 16 Oct 2025 00:01:25 +0800 Subject: [PATCH 10/44] test --- test/srt/ep/test_mooncake_ep_small.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/srt/ep/test_mooncake_ep_small.py b/test/srt/ep/test_mooncake_ep_small.py index 1fb560b97fa4..6451ba674793 100644 --- a/test/srt/ep/test_mooncake_ep_small.py +++ b/test/srt/ep/test_mooncake_ep_small.py @@ -13,7 +13,7 @@ ) ib_devices = get_rdma_devices_args() - +DEFAULT_MODEL_NAME_FOR_TEST_MLA="/data/modelssglang-ci-dsv3-test" class _BaseTestMooncakeEp(CustomTestCase): @classmethod From 6d36b5b94d571e165ebca3c36afe737ce09ac06b Mon Sep 17 00:00:00 2001 From: Hank Han Date: Thu, 16 Oct 2025 00:02:29 +0800 Subject: [PATCH 11/44] test --- test/srt/ep/test_mooncake_ep_small.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/srt/ep/test_mooncake_ep_small.py b/test/srt/ep/test_mooncake_ep_small.py index 6451ba674793..f63263d6e7fd 100644 --- a/test/srt/ep/test_mooncake_ep_small.py +++ b/test/srt/ep/test_mooncake_ep_small.py @@ -13,7 +13,7 @@ ) ib_devices = get_rdma_devices_args() -DEFAULT_MODEL_NAME_FOR_TEST_MLA="/data/modelssglang-ci-dsv3-test" +DEFAULT_MODEL_NAME_FOR_TEST_MLA="/data/models/sglang-ci-dsv3-test" class _BaseTestMooncakeEp(CustomTestCase): @classmethod From 56fb09c84410e937c1b1d2eeb29676f1d1000768 Mon Sep 17 00:00:00 2001 From: Hank Han Date: Thu, 16 Oct 2025 00:21:12 +0800 Subject: [PATCH 12/44] fix --- python/sglang/srt/elastic_ep/elastic_ep.py | 78 +++++++++++----------- 1 file changed, 39 insertions(+), 39 deletions(-) diff --git a/python/sglang/srt/elastic_ep/elastic_ep.py b/python/sglang/srt/elastic_ep/elastic_ep.py index f22f76057dc5..bc9cb310bcba 100644 --- a/python/sglang/srt/elastic_ep/elastic_ep.py +++ b/python/sglang/srt/elastic_ep/elastic_ep.py @@ -1,6 +1,7 @@ from __future__ import annotations from dataclasses import dataclass +import threading from typing import Optional import torch @@ -27,42 +28,41 @@ def snapshot_active_to_last(self): self.last_active_ranks = self.active_ranks.clone() -_elastic_ep_state: Optional[ElasticEPState] = None - - -def get_elastic_ep_state(): - return _elastic_ep_state - - -def init_elastic_ep_state(server_args: ServerArgs): - global _elastic_ep_state - assert _elastic_ep_state is None - if server_args.elastic_ep_backend is not None: - return _build_state(ep_size=None, device=None) - - -def _select_device() -> torch.device: - # cuda or cpu for now - if is_cuda(): - return torch.device("cuda") - elif is_cpu(): - return torch.device("cpu") - else: - raise NotImplementedError("Only CUDA and CPU support elastic ep now.") - - -def _build_state( - *, - ep_size: Optional[int], - device: Optional[torch.device], -) -> ElasticEPState: - ep = ep_size if ep_size is not None else torch.distributed.get_world_size() - dev = device if device is not None else _select_device() - - active = torch.ones(ep, dtype=torch.int32, device=dev) - state = ElasticEPState( - active_ranks=active, - last_active_ranks=active.clone(), - active_ranks_cpu=active.detach().cpu().clone(), - ) - return state +class ElasticEPStateManager: + _instance: Optional[ElasticEPState] = None + _lock = threading.Lock() + + @classmethod + def instance(cls) -> ElasticEPState: + return cls._instance + + @classmethod + def init(cls, server_args: ServerArgs): + with cls._lock: + if cls._instance is not None: + return cls._instance + + if server_args.elastic_ep_backend is not None: + cls._instance = cls._build_state(ep_size=None, device=None) + return cls._instance + + @staticmethod + def _select_device() -> torch.device: + if is_cuda(): + return torch.device("cuda") + elif is_cpu(): + return torch.device("cpu") + else: + raise NotImplementedError("Only CUDA and CPU support elastic ep now.") + + @classmethod + def _build_state(cls, *, ep_size: Optional[int], device: Optional[torch.device]) -> ElasticEPState: + ep = ep_size if ep_size is not None else torch.distributed.get_world_size() + dev = device if device is not None else cls._select_device() + + active = torch.ones(ep, dtype=torch.int32, device=dev) + return ElasticEPState( + active_ranks=active, + last_active_ranks=active.clone(), + active_ranks_cpu=active.detach().cpu().clone(), + ) \ No newline at end of file From 2434821e14a6213d570de5dfd79554970fc44a87 Mon Sep 17 00:00:00 2001 From: Hank Han Date: Thu, 16 Oct 2025 00:25:25 +0800 Subject: [PATCH 13/44] fix --- python/sglang/srt/eplb/eplb_algorithms/__init__.py | 4 ++-- python/sglang/srt/layers/moe/token_dispatcher/mooncake.py | 4 ++-- python/sglang/srt/model_executor/model_runner.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/eplb/eplb_algorithms/__init__.py b/python/sglang/srt/eplb/eplb_algorithms/__init__.py index af467f380c5c..6cb3da62a81f 100644 --- a/python/sglang/srt/eplb/eplb_algorithms/__init__.py +++ b/python/sglang/srt/eplb/eplb_algorithms/__init__.py @@ -3,7 +3,7 @@ import torch -from sglang.srt.elastic_ep.elastic_ep import get_elastic_ep_state +from sglang.srt.elastic_ep.elastic_ep import ElasticEPStateManager from sglang.srt.eplb.eplb_algorithms import deepseek, deepseek_vec, elasticity_aware @@ -55,7 +55,7 @@ def rebalance_experts( num_nodes=num_nodes, num_gpus=num_physical_experts // num_local_physical_experts, enable_hierarchical=True, - active_ranks=get_elastic_ep_state().active_ranks, + active_ranks=ElasticEPStateManager.instance().active_ranks, ) raise NotImplementedError diff --git a/python/sglang/srt/layers/moe/token_dispatcher/mooncake.py b/python/sglang/srt/layers/moe/token_dispatcher/mooncake.py index 76be31a626f2..c8fed07564a4 100644 --- a/python/sglang/srt/layers/moe/token_dispatcher/mooncake.py +++ b/python/sglang/srt/layers/moe/token_dispatcher/mooncake.py @@ -4,7 +4,7 @@ from dataclasses import dataclass from typing import NamedTuple, Optional, Tuple -from sglang.srt.elastic_ep.elastic_ep import get_elastic_ep_state +from sglang.srt.elastic_ep.elastic_ep import ElasticEPStateManager from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder from sglang.srt.layers.moe.token_dispatcher.base import ( BaseDispatcher, @@ -145,7 +145,7 @@ def __init__( self.first_execution = True self.timeout_us = 10000000 - self.active_ranks = get_elastic_ep_state().active_ranks + self.active_ranks = ElasticEPStateManager.instance().active_ranks self.handle = None diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 0dabd80fe927..bd0195984cf8 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -51,7 +51,7 @@ set_symm_mem_all_reduce, ) from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state -from sglang.srt.elastic_ep.elastic_ep import get_elastic_ep_state, init_elastic_ep_state +from sglang.srt.elastic_ep.elastic_ep import ElasticEPStateManager from sglang.srt.eplb.eplb_manager import EPLBManager from sglang.srt.eplb.expert_distribution import ( ExpertDistributionRecorder, @@ -928,7 +928,7 @@ def update_expert_location( new_expert_location_metadata: ExpertLocationMetadata, update_layer_ids: List[int], ): - if get_elastic_ep_state() is not None: + if ElasticEPStateManager.instance() is not None: # TODO: refactor the weights update when elastic ep old_expert_location_metadata = get_global_expert_location_metadata() assert old_expert_location_metadata is not None From 8c0e187e9af91d53bf1c0ac9441cbedef7175a05 Mon Sep 17 00:00:00 2001 From: Hank Han Date: Thu, 16 Oct 2025 00:31:56 +0800 Subject: [PATCH 14/44] fix --- python/sglang/srt/model_executor/model_runner.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index bd0195984cf8..280106c3b9d1 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -384,7 +384,8 @@ def initialize(self, min_per_gpu_memory: float): else None ) self.expert_location_updater = ExpertLocationUpdater() - init_elastic_ep_state(self.server_args) + + ElasticEPStateManager.init(self.server_args) if self.server_args.elastic_ep_backend else None # Load the model self.sampler = Sampler() self.load_model() From 37eeaab5ee8204c648ebd054c88961c262ea7bf3 Mon Sep 17 00:00:00 2001 From: Hank Han Date: Thu, 16 Oct 2025 00:56:35 +0800 Subject: [PATCH 15/44] fix --- test/srt/ep/test_mooncake_ep_small.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/srt/ep/test_mooncake_ep_small.py b/test/srt/ep/test_mooncake_ep_small.py index f63263d6e7fd..c8ee978c177d 100644 --- a/test/srt/ep/test_mooncake_ep_small.py +++ b/test/srt/ep/test_mooncake_ep_small.py @@ -13,7 +13,6 @@ ) ib_devices = get_rdma_devices_args() -DEFAULT_MODEL_NAME_FOR_TEST_MLA="/data/models/sglang-ci-dsv3-test" class _BaseTestMooncakeEp(CustomTestCase): @classmethod From 220089812fcfb56e1d1cfba01ae797c67f805bd7 Mon Sep 17 00:00:00 2001 From: Hank Han Date: Thu, 16 Oct 2025 01:00:13 +0800 Subject: [PATCH 16/44] lint --- python/sglang/srt/elastic_ep/elastic_ep.py | 8 +++++--- python/sglang/srt/model_executor/model_runner.py | 8 ++++++-- test/srt/ep/test_mooncake_ep_small.py | 1 + 3 files changed, 12 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/elastic_ep/elastic_ep.py b/python/sglang/srt/elastic_ep/elastic_ep.py index bc9cb310bcba..08fccdfb88e3 100644 --- a/python/sglang/srt/elastic_ep/elastic_ep.py +++ b/python/sglang/srt/elastic_ep/elastic_ep.py @@ -1,7 +1,7 @@ from __future__ import annotations -from dataclasses import dataclass import threading +from dataclasses import dataclass from typing import Optional import torch @@ -56,7 +56,9 @@ def _select_device() -> torch.device: raise NotImplementedError("Only CUDA and CPU support elastic ep now.") @classmethod - def _build_state(cls, *, ep_size: Optional[int], device: Optional[torch.device]) -> ElasticEPState: + def _build_state( + cls, *, ep_size: Optional[int], device: Optional[torch.device] + ) -> ElasticEPState: ep = ep_size if ep_size is not None else torch.distributed.get_world_size() dev = device if device is not None else cls._select_device() @@ -65,4 +67,4 @@ def _build_state(cls, *, ep_size: Optional[int], device: Optional[torch.device]) active_ranks=active, last_active_ranks=active.clone(), active_ranks_cpu=active.detach().cpu().clone(), - ) \ No newline at end of file + ) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 280106c3b9d1..74baaac00b61 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -384,8 +384,12 @@ def initialize(self, min_per_gpu_memory: float): else None ) self.expert_location_updater = ExpertLocationUpdater() - - ElasticEPStateManager.init(self.server_args) if self.server_args.elastic_ep_backend else None + + ( + ElasticEPStateManager.init(self.server_args) + if self.server_args.elastic_ep_backend + else None + ) # Load the model self.sampler = Sampler() self.load_model() diff --git a/test/srt/ep/test_mooncake_ep_small.py b/test/srt/ep/test_mooncake_ep_small.py index c8ee978c177d..1fb560b97fa4 100644 --- a/test/srt/ep/test_mooncake_ep_small.py +++ b/test/srt/ep/test_mooncake_ep_small.py @@ -14,6 +14,7 @@ ib_devices = get_rdma_devices_args() + class _BaseTestMooncakeEp(CustomTestCase): @classmethod def setUpClass(cls): From 9808c8dd586d4a262267809d8b28d6dbdd350f02 Mon Sep 17 00:00:00 2001 From: Hank Han Date: Thu, 16 Oct 2025 01:15:15 +0800 Subject: [PATCH 17/44] fix --- python/sglang/srt/elastic_ep/elastic_ep.py | 13 ++++++++++--- python/sglang/srt/eplb/eplb_algorithms/__init__.py | 6 +++++- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/elastic_ep/elastic_ep.py b/python/sglang/srt/elastic_ep/elastic_ep.py index 08fccdfb88e3..1296a545eb9b 100644 --- a/python/sglang/srt/elastic_ep/elastic_ep.py +++ b/python/sglang/srt/elastic_ep/elastic_ep.py @@ -59,12 +59,19 @@ def _select_device() -> torch.device: def _build_state( cls, *, ep_size: Optional[int], device: Optional[torch.device] ) -> ElasticEPState: - ep = ep_size if ep_size is not None else torch.distributed.get_world_size() - dev = device if device is not None else cls._select_device() - active = torch.ones(ep, dtype=torch.int32, device=dev) + active = cls.healthy_rank_state(cls, ep_size=ep_size, device=device) return ElasticEPState( active_ranks=active, last_active_ranks=active.clone(), active_ranks_cpu=active.detach().cpu().clone(), ) + + @classmethod + def healthy_rank_state( + cls, *, ep_size: Optional[int], device: Optional[torch.device] + ) -> torch.Tensor: + size = ep_size if ep_size is not None else torch.distributed.get_world_size() + dev = device if device is not None else cls._select_device() + + return torch.ones(size, dtype=torch.int32, device=dev) diff --git a/python/sglang/srt/eplb/eplb_algorithms/__init__.py b/python/sglang/srt/eplb/eplb_algorithms/__init__.py index 6cb3da62a81f..fc4d8f0f88bb 100644 --- a/python/sglang/srt/eplb/eplb_algorithms/__init__.py +++ b/python/sglang/srt/eplb/eplb_algorithms/__init__.py @@ -55,7 +55,11 @@ def rebalance_experts( num_nodes=num_nodes, num_gpus=num_physical_experts // num_local_physical_experts, enable_hierarchical=True, - active_ranks=ElasticEPStateManager.instance().active_ranks, + active_ranks=( + ElasticEPStateManager.instance().active_ranks + if ElasticEPStateManager.instance() is not None + else ElasticEPStateManager.healthy_rank_state() + ), ) raise NotImplementedError From cb548756e2ec7500ab1ffc0183f2b2751a05a6d7 Mon Sep 17 00:00:00 2001 From: Hank Han Date: Thu, 16 Oct 2025 01:17:22 +0800 Subject: [PATCH 18/44] fix --- test/srt/ep/test_mooncake_ep_small.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/srt/ep/test_mooncake_ep_small.py b/test/srt/ep/test_mooncake_ep_small.py index 1fb560b97fa4..f63263d6e7fd 100644 --- a/test/srt/ep/test_mooncake_ep_small.py +++ b/test/srt/ep/test_mooncake_ep_small.py @@ -13,7 +13,7 @@ ) ib_devices = get_rdma_devices_args() - +DEFAULT_MODEL_NAME_FOR_TEST_MLA="/data/models/sglang-ci-dsv3-test" class _BaseTestMooncakeEp(CustomTestCase): @classmethod From 7b1bd4edb988d6506ba9d26c120648aac7669b42 Mon Sep 17 00:00:00 2001 From: Hank Han Date: Thu, 16 Oct 2025 01:19:46 +0800 Subject: [PATCH 19/44] test --- python/sglang/srt/elastic_ep/elastic_ep.py | 2 +- test/srt/ep/test_mooncake_ep_small.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/elastic_ep/elastic_ep.py b/python/sglang/srt/elastic_ep/elastic_ep.py index 1296a545eb9b..5258b0a2fc87 100644 --- a/python/sglang/srt/elastic_ep/elastic_ep.py +++ b/python/sglang/srt/elastic_ep/elastic_ep.py @@ -60,7 +60,7 @@ def _build_state( cls, *, ep_size: Optional[int], device: Optional[torch.device] ) -> ElasticEPState: - active = cls.healthy_rank_state(cls, ep_size=ep_size, device=device) + active = cls.healthy_rank_state(ep_size=ep_size, device=device) return ElasticEPState( active_ranks=active, last_active_ranks=active.clone(), diff --git a/test/srt/ep/test_mooncake_ep_small.py b/test/srt/ep/test_mooncake_ep_small.py index f63263d6e7fd..2a1c2e79626f 100644 --- a/test/srt/ep/test_mooncake_ep_small.py +++ b/test/srt/ep/test_mooncake_ep_small.py @@ -13,7 +13,8 @@ ) ib_devices = get_rdma_devices_args() -DEFAULT_MODEL_NAME_FOR_TEST_MLA="/data/models/sglang-ci-dsv3-test" +DEFAULT_MODEL_NAME_FOR_TEST_MLA = "/data/models/sglang-ci-dsv3-test" + class _BaseTestMooncakeEp(CustomTestCase): @classmethod From 26063224c9204d9ea72ac913d21dc7f3b4bb339f Mon Sep 17 00:00:00 2001 From: Hank Han Date: Thu, 16 Oct 2025 01:24:05 +0800 Subject: [PATCH 20/44] t --- test/srt/ep/test_mooncake_ep_small.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/srt/ep/test_mooncake_ep_small.py b/test/srt/ep/test_mooncake_ep_small.py index 2a1c2e79626f..1fb560b97fa4 100644 --- a/test/srt/ep/test_mooncake_ep_small.py +++ b/test/srt/ep/test_mooncake_ep_small.py @@ -13,7 +13,6 @@ ) ib_devices = get_rdma_devices_args() -DEFAULT_MODEL_NAME_FOR_TEST_MLA = "/data/models/sglang-ci-dsv3-test" class _BaseTestMooncakeEp(CustomTestCase): From 642fa37b4c84f99610e218821e9edb3b65e2bf01 Mon Sep 17 00:00:00 2001 From: UNIDY2002 Date: Thu, 11 Sep 2025 15:59:15 +0800 Subject: [PATCH 21/44] Introduce Mooncake Backend and Mooncake EP --- python/sglang/srt/model_executor/model_runner.py | 3 ++- python/sglang/srt/server_args.py | 11 +++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 74baaac00b61..aeb835047725 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -258,6 +258,7 @@ def __init__( # Parse args self.mem_fraction_static = mem_fraction_static self.device = server_args.device + self.dist_backend = server_args.dist_backend self.gpu_id = gpu_id self.tp_rank = tp_rank self.tp_size = tp_size @@ -745,7 +746,7 @@ def _(data, dim): # Only initialize the distributed environment on the target model worker. init_distributed_environment( - backend=backend, + backend=self.dist_backend, world_size=self.tp_size * self.pp_size, rank=self.tp_size * self.pp_rank + self.tp_rank, local_rank=self.gpu_id, diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 04220876f447..44aec44498bd 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -596,6 +596,17 @@ def _handle_missing_default_values(self): self.served_model_name = self.model_path if self.device is None: self.device = get_device() + if self.dist_backend is None: + if self.device == "cuda": + self.dist_backend = "nccl" + elif self.device == "xpu": + self.dist_backend = "xccl" + elif self.device == "hpu": + self.dist_backend = "hccl" + elif self.device == "cpu": + self.dist_backend = "gloo" + elif self.device == "npu": + self.dist_backend = "hccl" if self.random_seed is None: self.random_seed = random.randint(0, 1 << 30) From f50b9c36e3b084b96937b696e8b6cddd049d0022 Mon Sep 17 00:00:00 2001 From: Hank Han <54751605+HanHan009527@users.noreply.github.com> Date: Thu, 11 Sep 2025 23:10:19 +0800 Subject: [PATCH 22/44] tiny fix mooncake pr (#12) * fix * fix * tiny fix * fix --- python/sglang/srt/server_args.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 44aec44498bd..04220876f447 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -596,17 +596,6 @@ def _handle_missing_default_values(self): self.served_model_name = self.model_path if self.device is None: self.device = get_device() - if self.dist_backend is None: - if self.device == "cuda": - self.dist_backend = "nccl" - elif self.device == "xpu": - self.dist_backend = "xccl" - elif self.device == "hpu": - self.dist_backend = "hccl" - elif self.device == "cpu": - self.dist_backend = "gloo" - elif self.device == "npu": - self.dist_backend = "hccl" if self.random_seed is None: self.random_seed = random.randint(0, 1 << 30) From 7a96c6a1bb2913964617de8d87e5047e9b42161f Mon Sep 17 00:00:00 2001 From: UNIDY Date: Mon, 15 Sep 2025 10:36:25 +0800 Subject: [PATCH 23/44] Fix for more readable code --- python/sglang/srt/models/deepseek_v2.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 3100e14907b4..40d2cb2d87d5 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -623,7 +623,14 @@ def __init__( self.top_k = config.num_experts_per_tok +<<<<<<< HEAD if get_moe_a2a_backend().is_deepep() or get_moe_a2a_backend().is_mooncake(): +======= + if get_moe_a2a_backend().is_none(): + self._enable_a2a_moe = False + else: + self._enable_a2a_moe = True +>>>>>>> fa7d9d26a (Fix for more readable code) # TODO: we will support tp < ep in the future self.ep_size = get_moe_expert_parallel_world_size() self.num_experts = ( From d71258cccd2bfa538d4d19ed16c681dfea996564 Mon Sep 17 00:00:00 2001 From: Hank Han Date: Tue, 16 Sep 2025 01:15:35 +0800 Subject: [PATCH 24/44] pr2 eplb fix fix fix fix fix fix fix ut ut ut fix fit --- .../srt/eplb/eplb_algorithms/elastic_ep.py | 88 +++++++++++++++++++ .../sglang/srt/model_executor/model_runner.py | 1 + test/srt/ep/test_elastic_ep_eplb.py | 81 +++++++++++++++++ 3 files changed, 170 insertions(+) create mode 100644 python/sglang/srt/eplb/eplb_algorithms/elastic_ep.py create mode 100755 test/srt/ep/test_elastic_ep_eplb.py diff --git a/python/sglang/srt/eplb/eplb_algorithms/elastic_ep.py b/python/sglang/srt/eplb/eplb_algorithms/elastic_ep.py new file mode 100644 index 000000000000..098f920e6d4c --- /dev/null +++ b/python/sglang/srt/eplb/eplb_algorithms/elastic_ep.py @@ -0,0 +1,88 @@ +from typing import Tuple + +import torch + +from sglang.srt.eplb.eplb_algorithms.deepseek import rebalance_experts_hierarchical + + +def rebalance_experts( + weight: torch.Tensor, + num_replicas: int, + num_groups: int, + num_nodes: int, + num_gpus: int, + enable_hierarchical: bool, + active_ranks: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Entry point for expert-parallelism load balancer. + + Parameters: + weight: [layers, num_logical_experts], the load statistics for all logical experts + num_replicas: number of physical experts, must be a multiple of `num_gpus` + num_groups: number of expert groups + num_nodes: number of server nodes, where the intra-node network (e.g, NVLink) is faster + num_gpus: number of GPUs, must be a multiple of `num_nodes` + + Returns: + physical_to_logical_map: [layers, num_replicas], the expert index of each replica + logical_to_physical_map: [layers, num_logical_experts, X], the replica indices for each expert + expert_count: [layers, num_logical_experts], number of physical replicas for each logical expert + """ + + num_layers, num_logical_experts = weight.shape + weight = weight.float().cpu() + num_active_ranks = active_ranks.sum().item() + num_local_experts = num_replicas // num_gpus + if num_active_ranks < num_gpus: + # Must fall back to global load-balance policy + # and fix some params + phy2log, phyrank, logcnt = rebalance_experts_hierarchical( + weight, + num_local_experts * num_active_ranks, + 1, + 1, + num_alive_gpus, + # weight, num_local_experts * num_alive_gpus, 1, 1, num_alive_gpus + ) + elif enable_hierarchical: + # use hierarchical load-balance policy + phy2log, phyrank, logcnt = rebalance_experts_hierarchical( + weight, num_replicas, num_groups, num_nodes, num_gpus + ) + else: + # use global load-balance policy + phy2log, phyrank, logcnt = rebalance_experts_hierarchical( + weight, num_replicas, 1, 1, num_gpus + ) + maxlogcnt = logcnt.max().item() + log2phy: torch.Tensor = torch.full( + (num_layers, num_logical_experts, maxlogcnt), + -1, + dtype=torch.int64, + device=logcnt.device, + ) + log2phy.view(num_layers, -1).scatter_( + -1, + phy2log * maxlogcnt + phyrank, + torch.arange( + num_local_experts * num_active_ranks, + dtype=torch.int64, + device=log2phy.device, + ).expand(num_layers, -1), + ) + if num_active_ranks < num_gpus: + phy2log_slices = list( + phy2log.view(num_layers, num_active_ranks, -1).unbind(dim=1) + ) + active_ranks_list = active_ranks.tolist() + for idx, active_rank in enumerate(active_ranks_list): + if not active_rank: + phy2log_slices.insert(idx, torch.zeros_like(phy2log_slices[0])) + log2phy = torch.where( + log2phy >= idx * num_local_experts, + log2phy + num_local_experts, + log2phy, + ) + phy2log = torch.stack(phy2log_slices, dim=1).contiguous().view(num_layers, -1) + return phy2log, log2phy, logcnt diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index aeb835047725..4a5d8e2d43df 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -2160,6 +2160,7 @@ def forward( self.forward_pass_id, forward_batch, ): + output = self._forward_raw( forward_batch, skip_attn_backend_init, diff --git a/test/srt/ep/test_elastic_ep_eplb.py b/test/srt/ep/test_elastic_ep_eplb.py new file mode 100755 index 000000000000..be4ab3e06e9a --- /dev/null +++ b/test/srt/ep/test_elastic_ep_eplb.py @@ -0,0 +1,81 @@ +import os +import unittest +from types import SimpleNamespace + +from sglang.srt.utils import kill_process_tree +from sglang.test.run_eval import run_eval +from sglang.test.test_utils import ( + DEFAULT_MLA_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + CustomTestCase, + popen_launch_server, +) + + +class _BaseTestDynamicEPLB(CustomTestCase): + extra_args = [] + + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--trust-remote-code", + "--tp", + "2", + "--dp", + "2", + "--enable-dp-attention", + "--moe-a2a-backend", + "mooncake", + "--enable-eplb", + "--ep-num-redundant-experts", + "4", + "--eplb-rebalance-num-iterations", + "50", + "--expert-distribution-recorder-buffer-size", + "50", + # TODO pr-chain: enable later + # "--enable-expert-distribution-metrics", + # TODO auto determine these flags + "--expert-distribution-recorder-mode", + "stat", + "--ep-dispatch-algorithm", + "static", + *cls.extra_args, + ], + env={ + "SGL_ENABLE_JIT_DEEPGEMM": "0", + "SGLANG_EXPERT_LOCATION_UPDATER_CANARY": "1", + **os.environ, + }, + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_mmlu(self): + args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="mmlu", + num_examples=64, + num_threads=32, + ) + + metrics = run_eval(args) + self.assertGreater(metrics["score"], 0.5) + + +class TestDynamicEPLBSimple(_BaseTestDynamicEPLB): + pass + + +if __name__ == "__main__": + unittest.main() From 66202786da79af1c8c34af50efd10a605ab1b574 Mon Sep 17 00:00:00 2001 From: Hank Han Date: Wed, 17 Sep 2025 23:57:07 +0800 Subject: [PATCH 25/44] fix fi fi fix fix fix fix fix fix fix fix fix fit fix --- python/sglang/srt/elastic_ep/elastic_ep.py | 1 + .../srt/eplb/eplb_algorithms/elastic_ep.py | 88 ------------------- test/srt/ep/test_elastic_ep_eplb.py | 81 ----------------- 3 files changed, 1 insertion(+), 169 deletions(-) delete mode 100644 python/sglang/srt/eplb/eplb_algorithms/elastic_ep.py delete mode 100755 test/srt/ep/test_elastic_ep_eplb.py diff --git a/python/sglang/srt/elastic_ep/elastic_ep.py b/python/sglang/srt/elastic_ep/elastic_ep.py index 5258b0a2fc87..f1b2fb661d3e 100644 --- a/python/sglang/srt/elastic_ep/elastic_ep.py +++ b/python/sglang/srt/elastic_ep/elastic_ep.py @@ -2,6 +2,7 @@ import threading from dataclasses import dataclass +from threading import Lock from typing import Optional import torch diff --git a/python/sglang/srt/eplb/eplb_algorithms/elastic_ep.py b/python/sglang/srt/eplb/eplb_algorithms/elastic_ep.py deleted file mode 100644 index 098f920e6d4c..000000000000 --- a/python/sglang/srt/eplb/eplb_algorithms/elastic_ep.py +++ /dev/null @@ -1,88 +0,0 @@ -from typing import Tuple - -import torch - -from sglang.srt.eplb.eplb_algorithms.deepseek import rebalance_experts_hierarchical - - -def rebalance_experts( - weight: torch.Tensor, - num_replicas: int, - num_groups: int, - num_nodes: int, - num_gpus: int, - enable_hierarchical: bool, - active_ranks: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Entry point for expert-parallelism load balancer. - - Parameters: - weight: [layers, num_logical_experts], the load statistics for all logical experts - num_replicas: number of physical experts, must be a multiple of `num_gpus` - num_groups: number of expert groups - num_nodes: number of server nodes, where the intra-node network (e.g, NVLink) is faster - num_gpus: number of GPUs, must be a multiple of `num_nodes` - - Returns: - physical_to_logical_map: [layers, num_replicas], the expert index of each replica - logical_to_physical_map: [layers, num_logical_experts, X], the replica indices for each expert - expert_count: [layers, num_logical_experts], number of physical replicas for each logical expert - """ - - num_layers, num_logical_experts = weight.shape - weight = weight.float().cpu() - num_active_ranks = active_ranks.sum().item() - num_local_experts = num_replicas // num_gpus - if num_active_ranks < num_gpus: - # Must fall back to global load-balance policy - # and fix some params - phy2log, phyrank, logcnt = rebalance_experts_hierarchical( - weight, - num_local_experts * num_active_ranks, - 1, - 1, - num_alive_gpus, - # weight, num_local_experts * num_alive_gpus, 1, 1, num_alive_gpus - ) - elif enable_hierarchical: - # use hierarchical load-balance policy - phy2log, phyrank, logcnt = rebalance_experts_hierarchical( - weight, num_replicas, num_groups, num_nodes, num_gpus - ) - else: - # use global load-balance policy - phy2log, phyrank, logcnt = rebalance_experts_hierarchical( - weight, num_replicas, 1, 1, num_gpus - ) - maxlogcnt = logcnt.max().item() - log2phy: torch.Tensor = torch.full( - (num_layers, num_logical_experts, maxlogcnt), - -1, - dtype=torch.int64, - device=logcnt.device, - ) - log2phy.view(num_layers, -1).scatter_( - -1, - phy2log * maxlogcnt + phyrank, - torch.arange( - num_local_experts * num_active_ranks, - dtype=torch.int64, - device=log2phy.device, - ).expand(num_layers, -1), - ) - if num_active_ranks < num_gpus: - phy2log_slices = list( - phy2log.view(num_layers, num_active_ranks, -1).unbind(dim=1) - ) - active_ranks_list = active_ranks.tolist() - for idx, active_rank in enumerate(active_ranks_list): - if not active_rank: - phy2log_slices.insert(idx, torch.zeros_like(phy2log_slices[0])) - log2phy = torch.where( - log2phy >= idx * num_local_experts, - log2phy + num_local_experts, - log2phy, - ) - phy2log = torch.stack(phy2log_slices, dim=1).contiguous().view(num_layers, -1) - return phy2log, log2phy, logcnt diff --git a/test/srt/ep/test_elastic_ep_eplb.py b/test/srt/ep/test_elastic_ep_eplb.py deleted file mode 100755 index be4ab3e06e9a..000000000000 --- a/test/srt/ep/test_elastic_ep_eplb.py +++ /dev/null @@ -1,81 +0,0 @@ -import os -import unittest -from types import SimpleNamespace - -from sglang.srt.utils import kill_process_tree -from sglang.test.run_eval import run_eval -from sglang.test.test_utils import ( - DEFAULT_MLA_MODEL_NAME_FOR_TEST, - DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - DEFAULT_URL_FOR_TEST, - CustomTestCase, - popen_launch_server, -) - - -class _BaseTestDynamicEPLB(CustomTestCase): - extra_args = [] - - @classmethod - def setUpClass(cls): - cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST - cls.base_url = DEFAULT_URL_FOR_TEST - cls.process = popen_launch_server( - cls.model, - cls.base_url, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=[ - "--trust-remote-code", - "--tp", - "2", - "--dp", - "2", - "--enable-dp-attention", - "--moe-a2a-backend", - "mooncake", - "--enable-eplb", - "--ep-num-redundant-experts", - "4", - "--eplb-rebalance-num-iterations", - "50", - "--expert-distribution-recorder-buffer-size", - "50", - # TODO pr-chain: enable later - # "--enable-expert-distribution-metrics", - # TODO auto determine these flags - "--expert-distribution-recorder-mode", - "stat", - "--ep-dispatch-algorithm", - "static", - *cls.extra_args, - ], - env={ - "SGL_ENABLE_JIT_DEEPGEMM": "0", - "SGLANG_EXPERT_LOCATION_UPDATER_CANARY": "1", - **os.environ, - }, - ) - - @classmethod - def tearDownClass(cls): - kill_process_tree(cls.process.pid) - - def test_mmlu(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mmlu", - num_examples=64, - num_threads=32, - ) - - metrics = run_eval(args) - self.assertGreater(metrics["score"], 0.5) - - -class TestDynamicEPLBSimple(_BaseTestDynamicEPLB): - pass - - -if __name__ == "__main__": - unittest.main() From 6d8c984c9a7431f9c3ab58ffabc0e653d9b9683f Mon Sep 17 00:00:00 2001 From: UNIDY Date: Wed, 17 Sep 2025 14:14:15 +0800 Subject: [PATCH 26/44] Test fault tolerance --- .../sglang/srt/distributed/parallel_state.py | 33 +++++++++++++++---- .../srt/eplb/eplb_algorithms/__init__.py | 3 ++ python/sglang/srt/layers/dp_attention.py | 7 ++++ python/sglang/srt/managers/scheduler.py | 8 +++++ .../sglang/srt/model_executor/model_runner.py | 25 ++++++++++++-- test/srt/ep/test_mooncake_ep_small.py | 23 ++++++++++++- 6 files changed, 90 insertions(+), 9 deletions(-) diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py index 7e18d06db78a..a720279c12d2 100644 --- a/python/sglang/srt/distributed/parallel_state.py +++ b/python/sglang/srt/distributed/parallel_state.py @@ -242,6 +242,8 @@ def __init__( group_name: Optional[str] = None, torch_compile: Optional[bool] = None, gloo_timeout: timedelta = timedelta(seconds=120 * 60), + active_ranks: Optional[torch.Tensor] = None, + active_ranks_cpu: Optional[torch.Tensor] = None, ): # Set group info group_name = group_name or "anonymous" @@ -256,11 +258,6 @@ def __init__( self.local_size = get_int_env_var("LOCAL_SIZE", 0) for ranks in group_ranks: - device_group = torch.distributed.new_group( - ranks, backend=torch_distributed_backend - ) - # a cpu_group to allow direct coordination between processes through - # the CPU. The backend is chosen based on `torch_distributed_backend` if "mooncake" in torch_distributed_backend: cpu_group = torch.distributed.new_group(ranks, backend="mooncake-cpu") else: @@ -1279,6 +1276,8 @@ def init_model_parallel_group( use_mscclpp_allreduce: Optional[bool] = None, use_symm_mem_allreduce: Optional[bool] = None, torch_compile: Optional[bool] = None, + active_ranks: Optional[torch.Tensor] = None, + active_ranks_cpu: Optional[torch.Tensor] = None, ) -> GroupCoordinator: if use_custom_allreduce is None: use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE @@ -1290,7 +1289,7 @@ def init_model_parallel_group( group_ranks=group_ranks, local_rank=local_rank, torch_distributed_backend=backend, - use_pynccl=not _is_npu, + use_pynccl=not _is_npu and not "mooncake" in backend, use_pymscclpp=use_mscclpp_allreduce, use_custom_allreduce=use_custom_allreduce, use_torch_symm_mem=use_symm_mem_allreduce, @@ -1300,10 +1299,22 @@ def init_model_parallel_group( use_message_queue_broadcaster=use_message_queue_broadcaster, group_name=group_name, torch_compile=torch_compile, + active_ranks=active_ranks, + active_ranks_cpu=active_ranks_cpu, ) _TP: Optional[GroupCoordinator] = None +_TP_ACTIVE_RANKS: Optional[torch.Tensor] = None +_TP_ACTIVE_RANKS_CPU: Optional[torch.Tensor] = None + + +def get_tp_active_ranks(): + return _TP_ACTIVE_RANKS + +def get_tp_active_ranks_cpu(): + return _TP_ACTIVE_RANKS_CPU + # duplicate GroupCoordinator for prefill in PD-Multiplexing _PDMUX_PREFILL_TP_GROUP: Optional[GroupCoordinator] = None @@ -1517,6 +1528,14 @@ def initialize_model_parallel( ) group_ranks.append(ranks) + global _TP_ACTIVE_RANKS + _TP_ACTIVE_RANKS = torch.ones( + (tensor_model_parallel_size,), dtype=torch.int32, device="cuda" + ) + global _TP_ACTIVE_RANKS_CPU + _TP_ACTIVE_RANKS_CPU = torch.ones( + (tensor_model_parallel_size,), dtype=torch.int32, device="cpu" + ) # message queue broadcaster is only used in tensor model parallel group _TP = init_model_parallel_group( group_ranks, @@ -1527,6 +1546,8 @@ def initialize_model_parallel( ), group_name="tp", torch_compile=torch_compile, + active_ranks=_TP_ACTIVE_RANKS, + active_ranks_cpu=_TP_ACTIVE_RANKS_CPU, ) if duplicate_tp_group: diff --git a/python/sglang/srt/eplb/eplb_algorithms/__init__.py b/python/sglang/srt/eplb/eplb_algorithms/__init__.py index fc4d8f0f88bb..fcac2f35556f 100644 --- a/python/sglang/srt/eplb/eplb_algorithms/__init__.py +++ b/python/sglang/srt/eplb/eplb_algorithms/__init__.py @@ -73,6 +73,9 @@ def compute_algorithm( if raw_algorithm != "auto": return EplbAlgorithm[raw_algorithm] + if get_elastic_ep_state().using_elastic_ep: + return EplbAlgorithm.elasticity_aware + # TODO test on real scenarios and know which ones perform better if (num_groups is not None) and (num_groups % num_nodes == 0): return EplbAlgorithm.deepseek_hierarchical diff --git a/python/sglang/srt/layers/dp_attention.py b/python/sglang/srt/layers/dp_attention.py index d4db39a33b3d..c48099e26971 100644 --- a/python/sglang/srt/layers/dp_attention.py +++ b/python/sglang/srt/layers/dp_attention.py @@ -31,6 +31,7 @@ _ATTN_TP_GROUP: Optional[GroupCoordinator] = None _ATTN_TP_RANK: Optional[int] = None _ATTN_TP_SIZE: Optional[int] = None +_ATTN_TP_ACTIVE_RANKS: Optional[torch.Tensor] = None _ATTN_DP_RANK: Optional[int] = None _ATTN_DP_SIZE: Optional[int] = None _LOCAL_ATTN_DP_SIZE: Optional[int] = None @@ -252,6 +253,11 @@ def initialize_dp_attention( _ATTN_DP_SIZE = 1 _LOCAL_ATTN_DP_SIZE = 1 + global _ATTN_TP_ACTIVE_RANKS + _ATTN_TP_ACTIVE_RANKS = torch.ones( + (_ATTN_TP_SIZE,), dtype=torch.int32, device="cuda" + ) + tp_group = get_tp_group() _ATTN_TP_GROUP = GroupCoordinator( [ @@ -268,6 +274,7 @@ def initialize_dp_attention( use_xpu_communicator=False, use_npu_communicator=False, group_name="attention_tp", + active_ranks=_ATTN_TP_ACTIVE_RANKS, ) _DpGatheredBufferWrapper.set_metadata( diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index f9601c9ac3d2..97e1f470b930 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -2385,9 +2385,12 @@ def prepare_mlp_sync_batch_raw( if disable_overlap_schedule: group = tp_group.device_group device = tp_group.device + torch.distributed.barrier(group=tp_group.cpu_group) + tp_active_ranks = get_tp_active_ranks() else: group = tp_group.cpu_group device = "cpu" + tp_active_ranks = get_tp_active_ranks_cpu() local_info = torch.tensor( [ @@ -2412,6 +2415,11 @@ def prepare_mlp_sync_batch_raw( local_info, group=group, ) + global_info.view(-1, 6)[tp_active_ranks == 0, :] = torch.tensor( + [0, 1, 0, 0, 1, ForwardMode.IDLE.value], + device=global_info.device, + dtype=global_info.dtype, + ) global_num_tokens = global_info[:, 0, 0].tolist() can_cuda_graph = min(global_info[:, 0, 1].tolist()) global_num_tokens_for_logprob = global_info[:, 0, 2].tolist() diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 4a5d8e2d43df..ae186dc18d4a 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -2169,8 +2169,29 @@ def forward( split_forward_count, ) - if self.eplb_manager is not None: - self.eplb_manager.on_forward_pass_end() + if not torch.equal( + get_elastic_ep_state().active_ranks, + get_elastic_ep_state().last_active_ranks, + ): + get_elastic_ep_state().last_active_ranks = ( + get_elastic_ep_state().active_ranks.clone() + ) + logging.info(f"recompute _forward_raw") + gen = self.eplb_manager.rebalance() + while True: + try: + next(gen) + except StopIteration: + break + output = self._forward_raw( + forward_batch, + skip_attn_backend_init, + pp_proxy_tensors, + reinit_attn_backend, + split_forward_count, + ) + if self.eplb_manager is not None: + self.eplb_manager.on_forward_pass_end() return output diff --git a/test/srt/ep/test_mooncake_ep_small.py b/test/srt/ep/test_mooncake_ep_small.py index 1fb560b97fa4..af64ad3751fd 100644 --- a/test/srt/ep/test_mooncake_ep_small.py +++ b/test/srt/ep/test_mooncake_ep_small.py @@ -1,9 +1,11 @@ +import os import unittest from types import SimpleNamespace from sglang.srt.utils import kill_process_tree from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k from sglang.test.test_disaggregation_utils import get_rdma_devices_args +from sglang.test.send_one import BenchArgs, send_one_prompt from sglang.test.test_utils import ( DEFAULT_MODEL_NAME_FOR_TEST_MLA, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, @@ -47,7 +49,13 @@ def setUpClass(cls): "512", "--mem-fraction-static", "0.5", - *cls.extra_args, + "--disable-custom-all-reduce", + "--enable-eplb", + "--ep-num-redundant-experts", + "24", + "--enable-dp-lm-head", + "--moe-dense-tp-size", + "1", ], ) @@ -70,6 +78,19 @@ def test_gsm8k(self): self.assertGreater(metrics["accuracy"], 0.60) + def test_bs_1_speed(self): + args = BenchArgs(port=int(self.base_url.split(":")[-1]), max_new_tokens=2048) + acc_length, speed = send_one_prompt(args) + print(f"{speed=:.2f}") + + def test_bs_1_fault_tolerance(self): + args = BenchArgs(port=int(self.base_url.split(":")[-1]), max_new_tokens=2048) + acc_length, speed = send_one_prompt(args) + print(f"{speed=:.2f}") + os.system("pkill -f sglang::scheduler_DP2_TP2_EP2") + acc_length, speed = send_one_prompt(args) + print(f"{speed=:.2f}") + class TestPureDP(_BaseTestMooncakeEp): extra_args = [ From 6bafc086c6a71fa069acaf0c057d1f0d23a33b0d Mon Sep 17 00:00:00 2001 From: ympcMark Date: Thu, 18 Sep 2025 21:20:51 +0800 Subject: [PATCH 27/44] feat --- python/sglang/srt/managers/data_parallel_controller.py | 6 ++++++ python/sglang/srt/managers/io_struct.py | 5 +++++ python/sglang/srt/managers/scheduler.py | 1 + python/sglang/srt/managers/tokenizer_manager.py | 5 +++++ 4 files changed, 17 insertions(+) diff --git a/python/sglang/srt/managers/data_parallel_controller.py b/python/sglang/srt/managers/data_parallel_controller.py index 56a87516d2b4..40d345ec6171 100644 --- a/python/sglang/srt/managers/data_parallel_controller.py +++ b/python/sglang/srt/managers/data_parallel_controller.py @@ -30,6 +30,7 @@ from sglang.srt.layers.dp_attention import compute_dp_attention_world_info from sglang.srt.managers.io_struct import ( BlockReqInput, + Ranks, TokenizedEmbeddingReqInput, TokenizedGenerateReqInput, WatchLoadUpdateReq, @@ -142,6 +143,7 @@ def __init__(self, server_args: ServerArgs, port_args: PortArgs) -> None: # Launch data parallel workers self.scheduler_procs = [] self.workers: List[zmq.Socket] = [None] * server_args.dp_size + self.status: List[int] = [1] * server_args.dp_size if server_args.enable_dp_attention: self.launch_dp_attention_schedulers(server_args, port_args) @@ -166,6 +168,9 @@ def send_control_message(self, obj): def handle_load_update_req(self, obj): self.dp_budget.update_budget(obj) + def update_ranks(self, ranks: Ranks): + self.status Ranks.status + def init_dispatcher(self): self._request_dispatcher = TypeBasedDispatcher( [ @@ -173,6 +178,7 @@ def init_dispatcher(self): (TokenizedEmbeddingReqInput, self.dispatching), (BlockReqInput, self.send_to_all_workers), (WatchLoadUpdateReq, self.handle_load_update_req), + (Ranks, self.update_ranks), ] ) self._request_dispatcher.add_fallback_fn(self.send_control_message) diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 5a7e5ec6d708..f19c1a485f24 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -1197,6 +1197,11 @@ def __post_init__(self): if self.rid is None: self.rid = "" +class Ranks: + status: List[int] + + def __init__(self, status: List[int]): + self.status = status @dataclass class GetInternalStateReq(BaseReq): diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 97e1f470b930..00c5654f0987 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -95,6 +95,7 @@ OpenSessionReqInput, OpenSessionReqOutput, ProfileReq, + Ranks, ReleaseMemoryOccupationReqInput, ResumeMemoryOccupationReqInput, RpcReqInput, diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index e646c2a6cdc6..27581acc80a5 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -61,6 +61,7 @@ HealthCheckOutput, MultiTokenizerWrapper, OpenSessionReqOutput, + Ranks, SessionParams, TokenizedEmbeddingReqInput, TokenizedGenerateReqInput, @@ -362,6 +363,7 @@ def __init__( ), # For handling case when scheduler skips detokenizer and forwards back to the tokenizer manager, we ignore it. (HealthCheckOutput, lambda x: None), + (Ranks, self.update_ranks), ] ) @@ -1705,6 +1707,9 @@ def _handle_abort_req(self, recv_obj: AbortReq): } state.out_list.append(out) state.event.set() + + def update_ranks(self, ranks: Ranks): + self.send_to_scheduler.send_pyobj(ranks) def _handle_open_session_req_output(self, recv_obj): self.session_futures[recv_obj.session_id].set_result( From 46c0589c4b27edf3aa4acc4cb0cd636e376a4188 Mon Sep 17 00:00:00 2001 From: ympcMark Date: Thu, 18 Sep 2025 21:22:44 +0800 Subject: [PATCH 28/44] feat --- python/sglang/srt/managers/data_parallel_controller.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/data_parallel_controller.py b/python/sglang/srt/managers/data_parallel_controller.py index 40d345ec6171..5fd3cf02d00d 100644 --- a/python/sglang/srt/managers/data_parallel_controller.py +++ b/python/sglang/srt/managers/data_parallel_controller.py @@ -169,7 +169,7 @@ def handle_load_update_req(self, obj): self.dp_budget.update_budget(obj) def update_ranks(self, ranks: Ranks): - self.status Ranks.status + self.status = ranks.status def init_dispatcher(self): self._request_dispatcher = TypeBasedDispatcher( From 7162a8fcf0c57991d87845526d147bfd7029e637 Mon Sep 17 00:00:00 2001 From: ympcMark Date: Thu, 18 Sep 2025 21:26:20 +0800 Subject: [PATCH 29/44] feat --- .../srt/managers/data_parallel_controller.py | 21 ++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/managers/data_parallel_controller.py b/python/sglang/srt/managers/data_parallel_controller.py index 5fd3cf02d00d..74568e61f003 100644 --- a/python/sglang/srt/managers/data_parallel_controller.py +++ b/python/sglang/srt/managers/data_parallel_controller.py @@ -443,12 +443,23 @@ def round_robin_scheduler(self, req: Req): return if self.server_args.disaggregation_mode == "null": - self.workers[self.round_robin_counter].send_pyobj(req) - self.round_robin_counter = (self.round_robin_counter + 1) % len( - self.workers - ) + while True: + if self.status[self.round_robin_counter] == 1: + self.workers[self.round_robin_counter].send_pyobj(req) + self.round_robin_counter = (self.round_robin_counter + 1) % len( + self.workers + ) + break + self.round_robin_counter = (self.round_robin_counter + 1) % len( + self.workers + ) else: - self.workers[req.bootstrap_room % len(self.workers)].send_pyobj(req) + id = req.bootstrap_room % len(self.workers) + while True: + if self.status[id] == 1: + self.workers[id].send_pyobj(req) + break + id = (id + 1) % len(self.workers) def shortest_queue_scheduler(self, req): if self.maybe_external_dp_rank_routing(req): From 80c2c0b6156dca32d449a37b07ff24f9457b615f Mon Sep 17 00:00:00 2001 From: ympcMark Date: Thu, 18 Sep 2025 21:32:10 +0800 Subject: [PATCH 30/44] feat --- python/sglang/srt/managers/data_parallel_controller.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/sglang/srt/managers/data_parallel_controller.py b/python/sglang/srt/managers/data_parallel_controller.py index 74568e61f003..dce763e948a5 100644 --- a/python/sglang/srt/managers/data_parallel_controller.py +++ b/python/sglang/srt/managers/data_parallel_controller.py @@ -170,6 +170,7 @@ def handle_load_update_req(self, obj): def update_ranks(self, ranks: Ranks): self.status = ranks.status + print(f"update: {self.status}") def init_dispatcher(self): self._request_dispatcher = TypeBasedDispatcher( From 175f9f42d6350f512a8a98bc5c3596eb0bf6b7c5 Mon Sep 17 00:00:00 2001 From: ympcMark Date: Fri, 19 Sep 2025 21:20:05 +0800 Subject: [PATCH 31/44] feat --- python/sglang/srt/managers/data_parallel_controller.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/sglang/srt/managers/data_parallel_controller.py b/python/sglang/srt/managers/data_parallel_controller.py index dce763e948a5..74568e61f003 100644 --- a/python/sglang/srt/managers/data_parallel_controller.py +++ b/python/sglang/srt/managers/data_parallel_controller.py @@ -170,7 +170,6 @@ def handle_load_update_req(self, obj): def update_ranks(self, ranks: Ranks): self.status = ranks.status - print(f"update: {self.status}") def init_dispatcher(self): self._request_dispatcher = TypeBasedDispatcher( From 5f72a00fc689442c8bc9b4f1f26c922a250585a3 Mon Sep 17 00:00:00 2001 From: ympcMark Date: Fri, 19 Sep 2025 21:29:28 +0800 Subject: [PATCH 32/44] feat --- test/srt/ep/test_mooncake_ep_small.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/test/srt/ep/test_mooncake_ep_small.py b/test/srt/ep/test_mooncake_ep_small.py index af64ad3751fd..7fd8a6fb2712 100644 --- a/test/srt/ep/test_mooncake_ep_small.py +++ b/test/srt/ep/test_mooncake_ep_small.py @@ -87,7 +87,17 @@ def test_bs_1_fault_tolerance(self): args = BenchArgs(port=int(self.base_url.split(":")[-1]), max_new_tokens=2048) acc_length, speed = send_one_prompt(args) print(f"{speed=:.2f}") - os.system("pkill -f sglang::scheduler_DP2_TP2_EP2") + os.system("pkill -f sglang::scheduler_DP0_TP0_EP0") + acc_length, speed = send_one_prompt(args) + print(f"{speed=:.2f}") + acc_length, speed = send_one_prompt(args) + print(f"{speed=:.2f}") + acc_length, speed = send_one_prompt(args) + print(f"{speed=:.2f}") + acc_length, speed = send_one_prompt(args) + print(f"{speed=:.2f}") + acc_length, speed = send_one_prompt(args) + print(f"{speed=:.2f}") acc_length, speed = send_one_prompt(args) print(f"{speed=:.2f}") From a760bab84a14959fa173ffbf7cd7088c5b5dce3f Mon Sep 17 00:00:00 2001 From: ympcMark Date: Fri, 19 Sep 2025 21:50:23 +0800 Subject: [PATCH 33/44] feat --- python/sglang/srt/managers/data_parallel_controller.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/sglang/srt/managers/data_parallel_controller.py b/python/sglang/srt/managers/data_parallel_controller.py index 74568e61f003..ecbdccbcbbc3 100644 --- a/python/sglang/srt/managers/data_parallel_controller.py +++ b/python/sglang/srt/managers/data_parallel_controller.py @@ -445,6 +445,10 @@ def round_robin_scheduler(self, req: Req): if self.server_args.disaggregation_mode == "null": while True: if self.status[self.round_robin_counter] == 1: + print("ATTION !!!!!!!") + print(self.round_robin_counter) + print(self.status) + print("END!!!!!!!") self.workers[self.round_robin_counter].send_pyobj(req) self.round_robin_counter = (self.round_robin_counter + 1) % len( self.workers From 3a16563063dc6933bc33c8ac7ee12ca32c37d830 Mon Sep 17 00:00:00 2001 From: ympcMark Date: Fri, 19 Sep 2025 22:06:18 +0800 Subject: [PATCH 34/44] feat --- python/sglang/srt/managers/data_parallel_controller.py | 3 +++ test/srt/ep/test_mooncake_ep_small.py | 2 ++ 2 files changed, 5 insertions(+) diff --git a/python/sglang/srt/managers/data_parallel_controller.py b/python/sglang/srt/managers/data_parallel_controller.py index ecbdccbcbbc3..d64fa1e3ecc5 100644 --- a/python/sglang/srt/managers/data_parallel_controller.py +++ b/python/sglang/srt/managers/data_parallel_controller.py @@ -170,6 +170,9 @@ def handle_load_update_req(self, obj): def update_ranks(self, ranks: Ranks): self.status = ranks.status + print("NEW RANKS!!!") + print(self.status) + print("END!!!!!") def init_dispatcher(self): self._request_dispatcher = TypeBasedDispatcher( diff --git a/test/srt/ep/test_mooncake_ep_small.py b/test/srt/ep/test_mooncake_ep_small.py index 7fd8a6fb2712..3d87d917ab8a 100644 --- a/test/srt/ep/test_mooncake_ep_small.py +++ b/test/srt/ep/test_mooncake_ep_small.py @@ -1,6 +1,8 @@ import os import unittest from types import SimpleNamespace +import time + from sglang.srt.utils import kill_process_tree from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k From d9e286676c5a1ecdf04689e780d21f310a84d857 Mon Sep 17 00:00:00 2001 From: ympcMark Date: Fri, 19 Sep 2025 22:19:14 +0800 Subject: [PATCH 35/44] feat --- python/sglang/srt/managers/data_parallel_controller.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/python/sglang/srt/managers/data_parallel_controller.py b/python/sglang/srt/managers/data_parallel_controller.py index d64fa1e3ecc5..74568e61f003 100644 --- a/python/sglang/srt/managers/data_parallel_controller.py +++ b/python/sglang/srt/managers/data_parallel_controller.py @@ -170,9 +170,6 @@ def handle_load_update_req(self, obj): def update_ranks(self, ranks: Ranks): self.status = ranks.status - print("NEW RANKS!!!") - print(self.status) - print("END!!!!!") def init_dispatcher(self): self._request_dispatcher = TypeBasedDispatcher( @@ -448,10 +445,6 @@ def round_robin_scheduler(self, req: Req): if self.server_args.disaggregation_mode == "null": while True: if self.status[self.round_robin_counter] == 1: - print("ATTION !!!!!!!") - print(self.round_robin_counter) - print(self.status) - print("END!!!!!!!") self.workers[self.round_robin_counter].send_pyobj(req) self.round_robin_counter = (self.round_robin_counter + 1) % len( self.workers From a67f9ab1f6cbcce3c0b19b7cc12091029ec68322 Mon Sep 17 00:00:00 2001 From: ympcMark Date: Fri, 19 Sep 2025 23:37:00 +0800 Subject: [PATCH 36/44] feat --- python/sglang/srt/managers/data_parallel_controller.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/sglang/srt/managers/data_parallel_controller.py b/python/sglang/srt/managers/data_parallel_controller.py index 74568e61f003..5cead91b4d9e 100644 --- a/python/sglang/srt/managers/data_parallel_controller.py +++ b/python/sglang/srt/managers/data_parallel_controller.py @@ -445,10 +445,12 @@ def round_robin_scheduler(self, req: Req): if self.server_args.disaggregation_mode == "null": while True: if self.status[self.round_robin_counter] == 1: + print(f"choose worker {self.round_robin_counter}") self.workers[self.round_robin_counter].send_pyobj(req) self.round_robin_counter = (self.round_robin_counter + 1) % len( self.workers ) + break self.round_robin_counter = (self.round_robin_counter + 1) % len( self.workers From c018c58821b0ca38dc81e5ff6c7c5c966af39a73 Mon Sep 17 00:00:00 2001 From: Hank Han Date: Wed, 15 Oct 2025 15:12:41 +0800 Subject: [PATCH 37/44] fix --- python/sglang/srt/models/deepseek_v2.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 40d2cb2d87d5..3100e14907b4 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -623,14 +623,7 @@ def __init__( self.top_k = config.num_experts_per_tok -<<<<<<< HEAD if get_moe_a2a_backend().is_deepep() or get_moe_a2a_backend().is_mooncake(): -======= - if get_moe_a2a_backend().is_none(): - self._enable_a2a_moe = False - else: - self._enable_a2a_moe = True ->>>>>>> fa7d9d26a (Fix for more readable code) # TODO: we will support tp < ep in the future self.ep_size = get_moe_expert_parallel_world_size() self.num_experts = ( From 8e4d1e9f24f143af413af7fc2058f7334396e2b2 Mon Sep 17 00:00:00 2001 From: Hank Han Date: Wed, 15 Oct 2025 15:15:44 +0800 Subject: [PATCH 38/44] ut --- test/srt/ep/test_mooncake_ep_small.py | 227 ++++++++++++++++++++++---- 1 file changed, 196 insertions(+), 31 deletions(-) diff --git a/test/srt/ep/test_mooncake_ep_small.py b/test/srt/ep/test_mooncake_ep_small.py index 3d87d917ab8a..e5d9c9597e6e 100644 --- a/test/srt/ep/test_mooncake_ep_small.py +++ b/test/srt/ep/test_mooncake_ep_small.py @@ -66,6 +66,7 @@ def tearDownClass(cls): kill_process_tree(cls.process.pid) def test_gsm8k(self): + os.system("pkill -f sglang::scheduler_DP0_TP0_EP0") args = SimpleNamespace( num_shots=5, data_path=None, @@ -104,43 +105,207 @@ def test_bs_1_fault_tolerance(self): print(f"{speed=:.2f}") -class TestPureDP(_BaseTestMooncakeEp): - extra_args = [ - "--tp", - "4", - "--enable-dp-attention", - "--dp", - "4", - ] +class TestHybridDPTP(CustomTestCase): + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--trust-remote-code", + "--tp", + "8", + "--enable-dp-attention", + "--dp", + "2", + "--elastic-ep-backend", + "mooncake", + "--mooncake-ib-device", +ib_devices, "--moe-a2a-backend", + "deepep", + "--deepep-mode", + "low_latency", + "--chunked-prefill-size", + "512", + "--cuda-graph-max-bs", + "128", + "--mem-fraction-static", + "0.7", + "--max-running-requests", + "256", + "--disable-custom-all-reduce", + "--enable-eplb", + "--ep-num-redundant-experts", + "24", + "--enable-dp-lm-head", + "--moe-dense-tp-size", + "1", + ], + ) + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) -class TestHybridDPTP(_BaseTestMooncakeEp): - extra_args = [ - "--tp", - "4", - "--enable-dp-attention", - "--dp", - "2", - ] + def test_gsm8k(self): + #os.system("pkill -f sglang::scheduler_DP0_TP0_EP0") + os.system("pkill -f sglang::scheduler_DP1_TP2_EP2") + args = SimpleNamespace( + num_shots=5, + data_path=None, + num_questions=200, + max_new_tokens=512, + parallel=128, + host="http://127.0.0.1", + port=int(self.base_url.split(":")[-1]), + ) + metrics = run_eval_few_shot_gsm8k(args) + print(metrics) + self.assertGreater(metrics["accuracy"], 0.60) -class TestTP(_BaseTestMooncakeEp): - extra_args = [ - "--tp", - "4", - ] +class TestTP(CustomTestCase): + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--trust-remote-code", + "--tp", + "4", + "--elastic-ep-backend", + "mooncake", + "--mooncake-ib-device", +ib_devices, "--moe-a2a-backend", + "deepep", + "--deepep-mode", + "low_latency", + "--chunked-prefill-size", + "512", + "--cuda-graph-max-bs", + "128", + "--max-running-requests", + "128", + ], + ) -class TestNoGatherdBuffer(_BaseTestMooncakeEp): - extra_args = [ - "--tp", - "4", - "--enable-dp-attention", - "--dp", - "4", - "--moe-dense-tp-size", - "1", - ] + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_gsm8k(self): + args = SimpleNamespace( + num_shots=5, + data_path=None, + num_questions=200, + max_new_tokens=512, + parallel=128, + host="http://127.0.0.1", + port=int(self.base_url.split(":")[-1]), + ) + metrics = run_eval_few_shot_gsm8k(args) + print(metrics) + + self.assertGreater(metrics["accuracy"], 0.60) + + +class TestNoGatherdBuffer(CustomTestCase): + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--trust-remote-code", + "--tp", + "4", + "--enable-dp-attention", + "--dp", + "4", + "--moe-dense-tp-size", + "1", + "--enable-dp-lm-head", + "--elastic-ep-backend", + "mooncake", + "--mooncake-ib-device", +ib_devices, "--moe-a2a-backend", + "deepep", + "--deepep-mode", + "low_latency", + "--chunked-prefill-size", + "512", + "--cuda-graph-max-bs", + "32", + "--max-running-requests", + "512", + ], + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_gsm8k(self): + args = SimpleNamespace( + num_shots=5, + data_path=None, + num_questions=200, + max_new_tokens=512, + parallel=128, + host="http://127.0.0.1", + port=int(self.base_url.split(":")[-1]), + ) + metrics = run_eval_few_shot_gsm8k(args) + print(metrics) + + self.assertGreater(metrics["accuracy"], 0.60) + + +class TestTBO(CustomTestCase): + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--trust-remote-code", + "--tp", + "4", + "--enable-dp-attention", + "--dp", + "4", + "--moe-dense-tp-size", + "1", + "--elastic-ep-backend", + "mooncake", + "--mooncake-ib-device", +ib_devices, "--moe-a2a-backend", + "deepep", + "--deepep-mode", + "low_latency", + "--chunked-prefill-size", + "512", + "--enable-two-batch-overlap", + "--cuda-graph-max-bs", + "128", + "--max-running-requests", + "512", + ], + ) class TestTBO(_BaseTestMooncakeEp): From f28ac4b401a4deb840731ed76f0de72d035733f0 Mon Sep 17 00:00:00 2001 From: Hank Han Date: Wed, 15 Oct 2025 16:36:35 +0800 Subject: [PATCH 39/44] test --- python/sglang/srt/model_executor/model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index ae186dc18d4a..f56bf51139d5 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -258,7 +258,7 @@ def __init__( # Parse args self.mem_fraction_static = mem_fraction_static self.device = server_args.device - self.dist_backend = server_args.dist_backend + self.dist_backend = server_args.elastic_ep_backend self.gpu_id = gpu_id self.tp_rank = tp_rank self.tp_size = tp_size From c6c241b6478eda0d7e5f23275e886a92b76094cb Mon Sep 17 00:00:00 2001 From: Hank Han Date: Wed, 15 Oct 2025 16:47:01 +0800 Subject: [PATCH 40/44] lint --- python/sglang/srt/distributed/parallel_state.py | 6 ++++++ python/sglang/srt/managers/data_parallel_controller.py | 6 +++--- python/sglang/srt/managers/io_struct.py | 4 +++- python/sglang/srt/managers/tokenizer_manager.py | 2 +- test/srt/ep/test_mooncake_ep_small.py | 5 ++--- 5 files changed, 15 insertions(+), 8 deletions(-) diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py index a720279c12d2..0fb38876e8b9 100644 --- a/python/sglang/srt/distributed/parallel_state.py +++ b/python/sglang/srt/distributed/parallel_state.py @@ -258,6 +258,11 @@ def __init__( self.local_size = get_int_env_var("LOCAL_SIZE", 0) for ranks in group_ranks: + device_group = torch.distributed.new_group( + ranks, backend=torch_distributed_backend + ) + # a cpu_group to allow direct coordination between processes through + # the CPU. The backend is chosen based on `torch_distributed_backend` if "mooncake" in torch_distributed_backend: cpu_group = torch.distributed.new_group(ranks, backend="mooncake-cpu") else: @@ -1312,6 +1317,7 @@ def init_model_parallel_group( def get_tp_active_ranks(): return _TP_ACTIVE_RANKS + def get_tp_active_ranks_cpu(): return _TP_ACTIVE_RANKS_CPU diff --git a/python/sglang/srt/managers/data_parallel_controller.py b/python/sglang/srt/managers/data_parallel_controller.py index 5cead91b4d9e..5a7ae982e76a 100644 --- a/python/sglang/srt/managers/data_parallel_controller.py +++ b/python/sglang/srt/managers/data_parallel_controller.py @@ -450,11 +450,11 @@ def round_robin_scheduler(self, req: Req): self.round_robin_counter = (self.round_robin_counter + 1) % len( self.workers ) - + break self.round_robin_counter = (self.round_robin_counter + 1) % len( - self.workers - ) + self.workers + ) else: id = req.bootstrap_room % len(self.workers) while True: diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index f19c1a485f24..7263fb838692 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -1197,12 +1197,14 @@ def __post_init__(self): if self.rid is None: self.rid = "" + class Ranks: status: List[int] - + def __init__(self, status: List[int]): self.status = status + @dataclass class GetInternalStateReq(BaseReq): pass diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 27581acc80a5..d6a4c8820859 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -1707,7 +1707,7 @@ def _handle_abort_req(self, recv_obj: AbortReq): } state.out_list.append(out) state.event.set() - + def update_ranks(self, ranks: Ranks): self.send_to_scheduler.send_pyobj(ranks) diff --git a/test/srt/ep/test_mooncake_ep_small.py b/test/srt/ep/test_mooncake_ep_small.py index e5d9c9597e6e..a73be20959a8 100644 --- a/test/srt/ep/test_mooncake_ep_small.py +++ b/test/srt/ep/test_mooncake_ep_small.py @@ -1,8 +1,7 @@ import os +import time import unittest from types import SimpleNamespace -import time - from sglang.srt.utils import kill_process_tree from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k @@ -151,7 +150,7 @@ def tearDownClass(cls): kill_process_tree(cls.process.pid) def test_gsm8k(self): - #os.system("pkill -f sglang::scheduler_DP0_TP0_EP0") + # os.system("pkill -f sglang::scheduler_DP0_TP0_EP0") os.system("pkill -f sglang::scheduler_DP1_TP2_EP2") args = SimpleNamespace( num_shots=5, From c4e4dd02156c5293183e1a40e3dc5a2b3ca8df03 Mon Sep 17 00:00:00 2001 From: Hank Han Date: Wed, 15 Oct 2025 16:58:36 +0800 Subject: [PATCH 41/44] test --- python/sglang/srt/managers/scheduler.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 00c5654f0987..546078a9ceb8 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -59,7 +59,12 @@ TransferBackend, prepare_abort, ) -from sglang.srt.distributed import get_pp_group, get_world_group +from sglang.srt.distributed import ( + get_pp_group, + get_tp_active_ranks, + get_tp_active_ranks_cpu, + get_world_group, +) from sglang.srt.environ import envs from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder from sglang.srt.layers.dp_attention import compute_dp_attention_world_info From e5d9d16be3fa24062739daf02600626e13d99425 Mon Sep 17 00:00:00 2001 From: Hank Han Date: Wed, 15 Oct 2025 17:02:31 +0800 Subject: [PATCH 42/44] t --- python/sglang/srt/managers/scheduler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 546078a9ceb8..f8e7ea1ad763 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -154,7 +154,7 @@ from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache from sglang.srt.mem_cache.radix_cache import RadixCache from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache -from sglang.srt.model_executor.forward_batch_info import PPProxyTensors +from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors from sglang.srt.parser.reasoning_parser import ReasoningParser from sglang.srt.server_args import PortArgs, ServerArgs, get_global_server_args from sglang.srt.speculative.eagle_info import EagleDraftInput From d6e7d26d480a0249836dfde211bb56ca74599919 Mon Sep 17 00:00:00 2001 From: Hank Han Date: Wed, 15 Oct 2025 17:41:32 +0800 Subject: [PATCH 43/44] t --- python/sglang/srt/model_executor/model_runner.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index f56bf51139d5..aad6210616ad 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -2169,13 +2169,9 @@ def forward( split_forward_count, ) - if not torch.equal( - get_elastic_ep_state().active_ranks, - get_elastic_ep_state().last_active_ranks, - ): - get_elastic_ep_state().last_active_ranks = ( - get_elastic_ep_state().active_ranks.clone() - ) + if get_elastic_ep_state() is not None and not get_elastic_ep_state().is_active_equal_last(): + get_elastic_ep_state().snapshot_active_to_last() + get_elastic_ep_state().sync_active_to_cpu() logging.info(f"recompute _forward_raw") gen = self.eplb_manager.rebalance() while True: From f4a6b7c92b31ff4c2518e97532819389351e386d Mon Sep 17 00:00:00 2001 From: Hank Han Date: Wed, 15 Oct 2025 17:47:46 +0800 Subject: [PATCH 44/44] test --- test/srt/ep/test_mooncake_ep_small.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/srt/ep/test_mooncake_ep_small.py b/test/srt/ep/test_mooncake_ep_small.py index a73be20959a8..ad7c49535ae2 100644 --- a/test/srt/ep/test_mooncake_ep_small.py +++ b/test/srt/ep/test_mooncake_ep_small.py @@ -65,7 +65,7 @@ def tearDownClass(cls): kill_process_tree(cls.process.pid) def test_gsm8k(self): - os.system("pkill -f sglang::scheduler_DP0_TP0_EP0") + os.system("pkill -f sglang::scheduler_DP1_TP1_EP1") args = SimpleNamespace( num_shots=5, data_path=None,