diff --git a/docs/advanced_features/server_arguments.md b/docs/advanced_features/server_arguments.md index aad5bbc05a7e..a0c520b2c2ea 100644 --- a/docs/advanced_features/server_arguments.md +++ b/docs/advanced_features/server_arguments.md @@ -134,7 +134,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s | Arguments | Description | Defaults | |-----------|-------------|----------| | `--device` | The device to use ('cuda', 'xpu', 'hpu', 'npu', 'cpu'). Defaults to auto-detection if not specified. | None | -| `--elastic-ep-backend` | Select the collective communication backend for elastic EP. Currently supports 'mooncake'. | None | +| `--elastic-ep-backend` | Select the collective communication backend for elastic EP. Supports 'mooncake' and 'deepep'. Use 'none' to disable. | None | | `--mooncake-ib-device` | The InfiniBand devices for Mooncake Backend, accepts multiple comma-separated devices. Default is None, which triggers automatic device detection when Mooncake Backend is enabled. | None | | `--tp-size` | The tensor parallelism size. | 1 | | `--pp-size` | The pipeline parallelism size. 
| 1 | diff --git a/python/sglang/srt/elastic_ep/elastic_ep.py b/python/sglang/srt/elastic_ep/elastic_ep.py new file mode 100644 index 000000000000..6f9b564bdc25 --- /dev/null +++ b/python/sglang/srt/elastic_ep/elastic_ep.py @@ -0,0 +1,108 @@ +from __future__ import annotations + +import threading +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any, Callable, Dict, Optional, Union + +import torch + +from sglang.srt.managers.schedule_batch import ServerArgs +from sglang.srt.utils import is_cpu, is_cuda + + +@dataclass +class ElasticEPState: + _active_ranks: Optional[torch.Tensor] + _last_active_ranks: Optional[torch.Tensor] + _active_ranks_cpu: Optional[torch.Tensor] + on_forward: Optional[Callable] = None + rank_status: Optional[torch.Tensor] = None + + def is_active_equal_last(self) -> bool: + return torch.equal(self._active_ranks, self._last_active_ranks) + + def sync_active_to_cpu(self): + if self._active_ranks is not None: + self._active_ranks_cpu = self._active_ranks.detach().cpu().clone() + + def snapshot_active_to_last(self): + if self._active_ranks is not None: + self._last_active_ranks = self._active_ranks.clone() + + +class ElasticEPStateManager: + _instance: Optional[ElasticEPState] = None + _lock = threading.Lock() + + @staticmethod + def on_forward_mooncake( + state: ElasticEPState, status: torch.Tensor = None, **kwargs + ): + state._active_ranks = state.rank_status.to(dtype=torch.int32) + + @staticmethod + def on_forward_deepep(state: ElasticEPState, status: torch.Tensor = None, **kwargs): + state._active_ranks = 1 - state.rank_status.to(torch.int32) + + @classmethod + def instance(cls) -> ElasticEPState: + return cls._instance + + @classmethod + def init(cls, server_args: ServerArgs): + with cls._lock: + if cls._instance is not None: + return cls._instance + + if server_args.elastic_ep_backend is not None: + cls._instance = cls._build_state( + ep_size=None, + device=None, + 
backend_type=server_args.elastic_ep_backend, + ) + return cls._instance + + @staticmethod + def _select_device() -> torch.device: + if is_cuda(): + return torch.device("cuda") + elif is_cpu(): + return torch.device("cpu") + else: + raise NotImplementedError("Only CUDA and CPU support elastic ep now.") + + @classmethod + def _build_state( + cls, + *, + ep_size: Optional[int], + device: Optional[torch.device], + backend_type: str = "none", + ) -> ElasticEPState: + + active = cls.create_rank_state(ep_size=ep_size, device=device, value=1) + + if backend_type == "mooncake": + on_forward = cls.on_forward_mooncake + elif backend_type == "deepep": + on_forward = cls.on_forward_deepep + else: + on_forward = None + + return ElasticEPState( + _active_ranks=active, + _last_active_ranks=active.clone(), + _active_ranks_cpu=active.detach().cpu().clone(), + rank_status=active.clone(), + on_forward=on_forward, + ) + + @classmethod + def create_rank_state( + cls, *, ep_size: Optional[int], device: Optional[torch.device], value: int = 1 + ) -> torch.Tensor: + size = ep_size if ep_size is not None else torch.distributed.get_world_size() + dev = device if device is not None else cls._select_device() + + return torch.full((size,), value, dtype=torch.int32, device=dev) diff --git a/python/sglang/srt/eplb/eplb_algorithms/__init__.py b/python/sglang/srt/eplb/eplb_algorithms/__init__.py index e2a2678104af..a75559c3e108 100644 --- a/python/sglang/srt/eplb/eplb_algorithms/__init__.py +++ b/python/sglang/srt/eplb/eplb_algorithms/__init__.py @@ -3,7 +3,8 @@ import torch -from sglang.srt.eplb.eplb_algorithms import deepseek, deepseek_vec +from sglang.srt.elastic_ep.elastic_ep import ElasticEPStateManager +from sglang.srt.eplb.eplb_algorithms import deepseek, deepseek_vec, elasticity_aware class EplbAlgorithm(Enum): @@ -11,6 +12,7 @@ class EplbAlgorithm(Enum): deepseek_hierarchical = auto() deepseek_vec = auto() deepseek_vec_hierarchical = auto() + elasticity_aware = auto() # TODO may have more 
from typing import Tuple

import torch

from sglang.srt.eplb.eplb_algorithms.deepseek import rebalance_experts_hierarchical


def rebalance_experts(
    weight: torch.Tensor,
    num_replicas: int,
    num_groups: int,
    num_nodes: int,
    num_gpus: int,
    enable_hierarchical: bool,
    active_ranks: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Entry point for the elasticity-aware expert-parallelism load balancer.

    Parameters:
        weight: [layers, num_logical_experts], the load statistics for all logical experts
        num_replicas: number of physical experts, must be a multiple of `num_gpus`
        num_groups: number of expert groups
        num_nodes: number of server nodes, where the intra-node network (e.g., NVLink) is faster
        num_gpus: number of GPUs, must be a multiple of `num_nodes`
        enable_hierarchical: use the hierarchical (per-node) placement policy
            when all ranks are active; ignored when some ranks are inactive
        active_ranks: int tensor with one entry per rank; nonzero = active.
            When some ranks are inactive, placement is computed over the
            active ranks only and then re-expanded to the full rank layout.

    Returns:
        physical_to_logical_map: [layers, num_replicas], the expert index of each replica
        logical_to_physical_map: [layers, num_logical_experts, X], the replica indices for each expert
        expert_count: [layers, num_logical_experts], number of physical replicas for each logical expert
    """

    num_layers, num_logical_experts = weight.shape
    weight = weight.float().cpu()
    num_active_ranks = active_ranks.sum().item()
    num_local_experts = num_replicas // num_gpus
    if num_active_ranks < num_gpus:
        # Must fall back to global load-balance policy
        # and fix some params: only num_local_experts * num_active_ranks
        # physical slots exist, spread across the active ranks alone.
        phy2log, phyrank, logcnt = rebalance_experts_hierarchical(
            weight,
            num_local_experts * num_active_ranks,
            1,
            1,
            num_active_ranks,
        )
    elif enable_hierarchical:
        # use hierarchical load-balance policy
        phy2log, phyrank, logcnt = rebalance_experts_hierarchical(
            weight, num_replicas, num_groups, num_nodes, num_gpus
        )
    else:
        # use global load-balance policy
        phy2log, phyrank, logcnt = rebalance_experts_hierarchical(
            weight, num_replicas, 1, 1, num_gpus
        )
    maxlogcnt = logcnt.max().item()
    # log2phy[l, e, r] = physical slot of the r-th replica of logical expert e
    # in layer l; unused replica slots stay -1.
    log2phy: torch.Tensor = torch.full(
        (num_layers, num_logical_experts, maxlogcnt),
        -1,
        dtype=torch.int64,
        device=logcnt.device,
    )
    # Invert phy2log: flatten (expert, replica-rank) into one axis and scatter
    # each physical slot index into its (phy2log, phyrank) cell in place.
    log2phy.view(num_layers, -1).scatter_(
        -1,
        phy2log * maxlogcnt + phyrank,
        torch.arange(
            num_local_experts * num_active_ranks,
            dtype=torch.int64,
            device=log2phy.device,
        ).expand(num_layers, -1),
    )
    if num_active_ranks < num_gpus:
        # Re-expand the compact active-rank placement back to the full rank
        # layout: insert a placeholder slice (expert 0) for every inactive
        # rank and shift subsequent physical indices in log2phy accordingly.
        # NOTE: the in-place insert while iterating idx in increasing order is
        # intentional — each insertion realigns later positions.
        phy2log_slices = list(
            phy2log.view(num_layers, num_active_ranks, -1).unbind(dim=1)
        )
        active_ranks_list = active_ranks.tolist()
        for idx, active_rank in enumerate(active_ranks_list):
            if not active_rank:
                phy2log_slices.insert(idx, torch.zeros_like(phy2log_slices[0]))
                log2phy = torch.where(
                    log2phy >= idx * num_local_experts,
                    log2phy + num_local_experts,
                    log2phy,
                )
        phy2log = torch.stack(phy2log_slices, dim=1).contiguous().view(num_layers, -1)
    return phy2log, log2phy, logcnt
a/python/sglang/srt/layers/moe/token_dispatcher/mooncake.py +++ b/python/sglang/srt/layers/moe/token_dispatcher/mooncake.py @@ -4,6 +4,7 @@ from dataclasses import dataclass from typing import NamedTuple, Optional, Tuple +from sglang.srt.elastic_ep.elastic_ep import ElasticEPStateManager from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder from sglang.srt.layers.moe.token_dispatcher.base import ( BaseDispatcher, @@ -62,14 +63,6 @@ def format(self) -> CombineInputFormat: assert isinstance(MooncakeCombineInput, CombineInput) -_ACTIVE_RANKS: Optional[torch.Tensor] = None - - -def get_ep_active_ranks() -> torch.Tensor: - assert _ACTIVE_RANKS is not None, "_ACTIVE_RANKS is not initialized" - return _ACTIVE_RANKS - - class EPBuffer: _buffer = None _hidden_size: Optional[int] = None @@ -152,12 +145,7 @@ def __init__( self.first_execution = True self.timeout_us = 10000000 - global _ACTIVE_RANKS - if _ACTIVE_RANKS is None: - _ACTIVE_RANKS = torch.ones( - (self.num_experts,), dtype=torch.int32, device="cuda" - ) - self.active_ranks = _ACTIVE_RANKS + self.active_ranks = ElasticEPStateManager.instance().rank_status self.handle = None diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 4ef8bc99373f..d43c709036c7 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -24,7 +24,7 @@ import time from collections import defaultdict from dataclasses import dataclass -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch import torch.distributed as dist @@ -51,6 +51,7 @@ set_symm_mem_all_reduce, ) from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state +from sglang.srt.elastic_ep.elastic_ep import ElasticEPStateManager from sglang.srt.eplb.eplb_manager import EPLBManager from sglang.srt.eplb.expert_distribution import ( 
ExpertDistributionRecorder, @@ -382,6 +383,11 @@ def initialize(self, min_per_gpu_memory: float): ) self.expert_location_updater = ExpertLocationUpdater() + ( + ElasticEPStateManager.init(self.server_args) + if self.server_args.elastic_ep_backend + else None + ) # Load the model self.sampler = Sampler() self.load_model() @@ -926,16 +932,33 @@ def update_expert_location( new_expert_location_metadata: ExpertLocationMetadata, update_layer_ids: List[int], ): - self.expert_location_updater.update( - self.model.routed_experts_weights_of_layer, - new_expert_location_metadata, - update_layer_ids=update_layer_ids, - nnodes=self.server_args.nnodes, - rank=self.tp_rank, - ) + if ElasticEPStateManager.instance() is not None: + # TODO: refactor the weights update when elastic ep + old_expert_location_metadata = get_global_expert_location_metadata() + assert old_expert_location_metadata is not None + old_expert_location_metadata.update( + new_expert_location_metadata, + update_layer_ids=update_layer_ids, + ) + self.update_weights_from_disk( + self.server_args.model_path, + self.server_args.load_format, + lambda name: "mlp.experts" in name and "mlp.shared_experts" not in name, + ) + else: + self.expert_location_updater.update( + self.model.routed_experts_weights_of_layer, + new_expert_location_metadata, + update_layer_ids=update_layer_ids, + nnodes=self.server_args.nnodes, + rank=self.tp_rank, + ) def update_weights_from_disk( - self, model_path: str, load_format: str + self, + model_path: str, + load_format: str, + weight_name_filter: Optional[Callable[[str], bool]] = None, ) -> tuple[bool, str]: """Update engine weights in-place from the disk.""" logger.info( @@ -957,6 +980,11 @@ def get_weight_iter(config): iter = loader._get_weights_iterator( DefaultModelLoader.Source.init_new(config, self.model) ) + if weight_name_filter is not None: + iter = ( + (name, weight) for name, weight in iter if weight_name_filter(name) + ) + return iter def model_load_weights(model, iter): diff 
--git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 8644324963c4..0e13c0613ae3 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -229,7 +229,7 @@ class ServerArgs: # Runtime options device: Optional[str] = None - elastic_ep_backend: Literal[None, "mooncake"] = None + elastic_ep_backend: Literal[None, "mooncake", "deepep"] = None mooncake_ib_device: Optional[str] = None tp_size: int = 1 pp_size: int = 1 @@ -577,6 +577,9 @@ def __post_init__(self): # Handle any other necessary validations. self._handle_other_validations() + # Handle elastic expert parallelism. + self._handle_elastic_ep() + def _handle_deprecated_args(self): # handle deprecated tool call parsers deprecated_tool_call_parsers = {"qwen25": "qwen", "glm45": "glm"} @@ -1145,6 +1148,15 @@ def _handle_eplb_and_dispatch(self): if self.enable_eplb: assert self.ep_size > 1 + def _handle_elastic_ep(self): + if self.elastic_ep_backend is not None: + if self.enable_eplb: + if self.eplb_algorithm == "auto": + self.eplb_algorithm = "elasticity_aware" + assert ( + self.eplb_algorithm == "elasticity_aware" + ), "Elastic EP requires eplb_algorithm to be set to 'auto' or 'elasticity_aware'." + def _handle_expert_distribution_metrics(self): if self.enable_expert_distribution_metrics and ( self.expert_distribution_recorder_mode is None @@ -1752,8 +1764,11 @@ def add_cli_args(parser: argparse.ArgumentParser): "--elastic-ep-backend", type=str, default=ServerArgs.elastic_ep_backend, - choices=["none", "mooncake"], - help="Specify the collective communication backend for elastic EP. Currently supports 'mooncake'.", + choices=["none", "mooncake", "deepep"], + help=( + "Specify the collective communication backend for elastic EP. " + "Supports 'mooncake' and 'deepep'. Use 'none' to disable." 
+ ), ) parser.add_argument( "--mooncake-ib-device", diff --git a/test/srt/ep/test_mooncake_ep_small.py b/test/srt/ep/test_mooncake_ep_small.py index 111260a8c82d..391cdc4c65f5 100644 --- a/test/srt/ep/test_mooncake_ep_small.py +++ b/test/srt/ep/test_mooncake_ep_small.py @@ -3,6 +3,7 @@ from sglang.srt.utils import kill_process_tree from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k +from sglang.test.test_disaggregation_utils import get_rdma_devices_args from sglang.test.test_utils import ( DEFAULT_MODEL_NAME_FOR_TEST_MLA, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, @@ -11,166 +12,12 @@ popen_launch_server, ) - -class TestPureDP(CustomTestCase): - @classmethod - def setUpClass(cls): - cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA - cls.base_url = DEFAULT_URL_FOR_TEST - cls.process = popen_launch_server( - cls.model, - cls.base_url, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=[ - "--trust-remote-code", - "--tp", - "4", - "--enable-dp-attention", - "--dp", - "4", - "--elastic-ep-backend", - "mooncake", - "--mooncake-ib-device", - "mlx5_roce0,mlx5_roce1,mlx5_roce2,mlx5_roce3,mlx5_roce4,mlx5_roce5,mlx5_roce6,mlx5_roce7", - "--moe-a2a-backend", - "deepep", - "--deepep-mode", - "low_latency", - "--chunked-prefill-size", - "512", - "--cuda-graph-max-bs", - "128", - "--max-running-requests", - "512", - "--mem-fraction-static", - "0.5", - ], - ) - - @classmethod - def tearDownClass(cls): - kill_process_tree(cls.process.pid) - - def test_gsm8k(self): - args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), - ) - metrics = run_eval_few_shot_gsm8k(args) - print(metrics) - - self.assertGreater(metrics["accuracy"], 0.60) - - -class TestHybridDPTP(CustomTestCase): - @classmethod - def setUpClass(cls): - cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA - cls.base_url = DEFAULT_URL_FOR_TEST - cls.process = popen_launch_server( - 
cls.model, - cls.base_url, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=[ - "--trust-remote-code", - "--tp", - "4", - "--enable-dp-attention", - "--dp", - "2", - "--elastic-ep-backend", - "mooncake", - "--mooncake-ib-device", - "mlx5_roce0,mlx5_roce1,mlx5_roce2,mlx5_roce3,mlx5_roce4,mlx5_roce5,mlx5_roce6,mlx5_roce7", - "--moe-a2a-backend", - "deepep", - "--deepep-mode", - "low_latency", - "--chunked-prefill-size", - "512", - "--cuda-graph-max-bs", - "128", - "--max-running-requests", - "256", - ], - ) - - @classmethod - def tearDownClass(cls): - kill_process_tree(cls.process.pid) - - def test_gsm8k(self): - args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), - ) - metrics = run_eval_few_shot_gsm8k(args) - print(metrics) - - self.assertGreater(metrics["accuracy"], 0.60) +ib_devices = get_rdma_devices_args() class TestTP(CustomTestCase): - @classmethod - def setUpClass(cls): - cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA - cls.base_url = DEFAULT_URL_FOR_TEST - cls.process = popen_launch_server( - cls.model, - cls.base_url, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=[ - "--trust-remote-code", - "--tp", - "4", - "--elastic-ep-backend", - "mooncake", - "--mooncake-ib-device", - "mlx5_roce0,mlx5_roce1,mlx5_roce2,mlx5_roce3,mlx5_roce4,mlx5_roce5,mlx5_roce6,mlx5_roce7", - "--moe-a2a-backend", - "deepep", - "--deepep-mode", - "low_latency", - "--chunked-prefill-size", - "512", - "--cuda-graph-max-bs", - "128", - "--max-running-requests", - "128", - ], - ) - - @classmethod - def tearDownClass(cls): - kill_process_tree(cls.process.pid) - - def test_gsm8k(self): - args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), - ) - metrics = run_eval_few_shot_gsm8k(args) - print(metrics) + extra_args 
= [] - self.assertGreater(metrics["accuracy"], 0.60) - - -class TestNoGatherdBuffer(CustomTestCase): @classmethod def setUpClass(cls): cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA @@ -183,16 +30,10 @@ def setUpClass(cls): "--trust-remote-code", "--tp", "4", - "--enable-dp-attention", - "--dp", - "4", - "--moe-dense-tp-size", - "1", - "--enable-dp-lm-head", "--elastic-ep-backend", "mooncake", "--mooncake-ib-device", - "mlx5_roce0,mlx5_roce1,mlx5_roce2,mlx5_roce3,mlx5_roce4,mlx5_roce5,mlx5_roce6,mlx5_roce7", + ib_devices, "--moe-a2a-backend", "deepep", "--deepep-mode", @@ -200,9 +41,12 @@ def setUpClass(cls): "--chunked-prefill-size", "512", "--cuda-graph-max-bs", - "32", + "128", "--max-running-requests", "512", + "--mem-fraction-static", + "0.5", + *cls.extra_args, ], ) @@ -226,60 +70,73 @@ def test_gsm8k(self): self.assertGreater(metrics["accuracy"], 0.60) -class TestTBO(CustomTestCase): - @classmethod - def setUpClass(cls): - cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA - cls.base_url = DEFAULT_URL_FOR_TEST - cls.process = popen_launch_server( - cls.model, - cls.base_url, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=[ - "--trust-remote-code", - "--tp", - "4", - "--enable-dp-attention", - "--dp", - "4", - "--moe-dense-tp-size", - "1", - "--elastic-ep-backend", - "mooncake", - "--mooncake-ib-device", - "mlx5_roce0,mlx5_roce1,mlx5_roce2,mlx5_roce3,mlx5_roce4,mlx5_roce5,mlx5_roce6,mlx5_roce7", - "--moe-a2a-backend", - "deepep", - "--deepep-mode", - "low_latency", - "--chunked-prefill-size", - "512", - "--enable-two-batch-overlap", - "--cuda-graph-max-bs", - "128", - "--max-running-requests", - "512", - ], - ) - - @classmethod - def tearDownClass(cls): - kill_process_tree(cls.process.pid) - - def test_gsm8k(self): - args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), - ) - metrics = run_eval_few_shot_gsm8k(args) - 
class TestPureDP(TestTP):
    """Pure data-parallel attention: dp == tp."""

    extra_args = [
        "--tp",
        "4",
        "--enable-dp-attention",
        "--dp",
        "4",
    ]


class TestHybridDPTP(TestTP):
    """Hybrid parallelism: dp < tp."""

    extra_args = [
        "--tp",
        "4",
        "--enable-dp-attention",
        "--dp",
        "2",
    ]


class TestNoGatherdBuffer(TestTP):
    """DP attention with dense layers run at TP size 1 (no gathered buffer)."""

    extra_args = [
        "--tp",
        "4",
        "--enable-dp-attention",
        "--dp",
        "4",
        "--moe-dense-tp-size",
        "1",
    ]


class TestTBO(TestTP):
    """Two-batch overlap on top of DP attention."""

    extra_args = [
        "--tp",
        "4",
        "--enable-dp-attention",
        "--dp",
        "4",
        "--moe-dense-tp-size",
        "1",
        "--enable-two-batch-overlap",
    ]


class TestMooncakeWithEPLB(TestTP):
    """Mooncake elastic EP combined with EPLB expert rebalancing."""

    extra_args = [
        "--tp",
        "4",
        "--enable-dp-attention",
        "--dp",
        "4",
        "--moe-dense-tp-size",
        "1",
        "--enable-two-batch-overlap",
        "--enable-eplb",
        "--ep-num-redundant-experts",
        "4",
        "--eplb-rebalance-num-iterations",
        "50",
        "--expert-distribution-recorder-buffer-size",
        "50",
        "--expert-distribution-recorder-mode",
        "stat",
        "--ep-dispatch-algorithm",
        "static",
    ]


# Backward-compatible alias: the class was originally committed with a typo
# ("Witch" for "With"); keep the old name resolvable for external references.
TestMooncakeWitchEPLB = TestMooncakeWithEPLB


if __name__ == "__main__":
    unittest.main()