diff --git a/docs/advanced_features/server_arguments.md b/docs/advanced_features/server_arguments.md index aad5bbc05a7e..49dfe2be5da0 100644 --- a/docs/advanced_features/server_arguments.md +++ b/docs/advanced_features/server_arguments.md @@ -134,7 +134,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s | Arguments | Description | Defaults | |-----------|-------------|----------| | `--device` | The device to use ('cuda', 'xpu', 'hpu', 'npu', 'cpu'). Defaults to auto-detection if not specified. | None | -| `--elastic-ep-backend` | Select the collective communication backend for elastic EP. Currently supports 'mooncake'. | None | +| `--elastic-ep-backend` | Select the collective communication backend for elastic EP. Supports 'mooncake' and 'deepep'. Use 'none' to disable. | None | | `--mooncake-ib-device` | The InfiniBand devices for Mooncake Backend, accepts multiple comma-separated devices. Default is None, which triggers automatic device detection when Mooncake Backend is enabled. | None | | `--tp-size` | The tensor parallelism size. | 1 | | `--pp-size` | The pipeline parallelism size. | 1 | @@ -248,7 +248,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s | Arguments | Description | Defaults | |-----------|-------------|----------| | `--ep-size` | The expert parallelism size. | 1 | -| `--moe-a2a-backend` | Select the backend for all-to-all communication for expert parallelism, could be `deepep` or `mooncake`. | none | +| `--moe-a2a-backend` | Select the backend for all-to-all communication for expert parallelism, could be `deepep`, `mooncake`, or `nixl`. | none | | `--moe-runner-backend` | Select the runner backend for MoE. | auto | | `--deepep-mode` | Select the mode when enable DeepEP MoE, could be `normal`, `low_latency` or `auto`. Default is `auto`, which means `low_latency` for decode batch and `normal` for prefill batch. | auto | | `--ep-num-redundant-experts` | Allocate this number of redundant experts in expert parallel. | 0 | diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py index 7e18d06db78a..c8002c56bee9 100644 --- a/python/sglang/srt/distributed/parallel_state.py +++ b/python/sglang/srt/distributed/parallel_state.py @@ -39,6 +39,7 @@ import torch.distributed from torch.distributed import Backend, ProcessGroup +from sglang.srt.distributed.utils import set_global_tcp_store from sglang.srt.utils import ( direct_register_custom_op, get_bool_env_var, @@ -1398,6 +1399,59 @@ def set_symm_mem_all_reduce(enable: bool): _ENABLE_SYMM_MEM_ALL_REDUCE = enable +def _create_global_tcp_store( + rank: int, world_size: int +) -> None: + """Create a global TCPStore for coordination across ranks. + + This function creates a TCPStore that all ranks can use for coordination + (e.g., for NIXL buffer setup). + """ + from torch.distributed import TCPStore + + master_ip = os.environ.get("MASTER_ADDR") + + if not master_ip: + logger.warning( + "Could not determine master IP for global TCPStore. " + "Broadcasting from rank 0 to all ranks." 
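The master-IP resolution above depends on two pieces of environment state; a minimal sketch of the environment this helper expects, with example values only (both variable names appear verbatim in the code, the address is a placeholder):

```python
import os

# Typical multi-node launch: every rank is started with the same MASTER_ADDR,
# so the rank-0 broadcast fallback below never runs.
os.environ.setdefault("MASTER_ADDR", "10.0.0.1")          # placeholder address
# The TCPStore listens on its own port, separate from the main process group;
# override this if 29600 collides with another service on the host.
os.environ.setdefault("SGLANG_TCP_STORE_PORT", "29600")
```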
+    )
+
+    base_store_port = int(os.environ.get("SGLANG_TCP_STORE_PORT", "29600"))
+
+    # Rank 0 gets its local IP and broadcasts it to all ranks
+    # Use broadcast_object_list which works with any backend (handles CPU/GPU automatically)
+    if not master_ip:
+        if rank == 0:
+            master_ip = get_local_ip_auto()
+            ip_list = [master_ip]
+        else:
+            ip_list = [None]
+
+        torch.distributed.broadcast_object_list(ip_list, src=0)
+        master_ip = ip_list[0]
+
+    logger.debug("Global TCPStore master_ip: %s, base_store_port: %d", master_ip, base_store_port)
+    try:
+        tcp_store = TCPStore(
+            host_name=master_ip,
+            port=base_store_port,
+            world_size=world_size,
+            is_master=(rank == 0),
+        )
+        set_global_tcp_store(tcp_store)
+        logger.info(
+            "Created global TCPStore at %s:%d (rank=%d, world_size=%d)",
+            master_ip, base_store_port, rank, world_size
+        )
+    except Exception as e:
+        logger.warning(
+            "Failed to create global TCPStore at %s:%d: %s. "
+            "Components requiring TCPStore (like NIXL) may not work.",
+            master_ip, base_store_port, e
+        )
+
+
 def init_distributed_environment(
     world_size: int = -1,
     rank: int = -1,
@@ -1444,6 +1498,9 @@ def init_distributed_environment(
         timeout=timeout,
     )
 
+    # Create a global TCPStore for coordination (used by NIXL)
+    _create_global_tcp_store(rank, world_size)
+
     # set the local rank
     # local_rank is not available in torch ProcessGroup,
     # see https://github.com/pytorch/pytorch/issues/122816
diff --git a/python/sglang/srt/distributed/utils.py b/python/sglang/srt/distributed/utils.py
index bfe54b9d478a..7783e0ad63cf 100644
--- a/python/sglang/srt/distributed/utils.py
+++ b/python/sglang/srt/distributed/utils.py
@@ -17,6 +17,42 @@
 logger = logging.getLogger(__name__)
 
+
+# Global TCPStore that is created during distributed initialization
+# This is the single shared store that all components should use
+_global_tcp_store: Optional[TCPStore] = None
+
+
+def set_global_tcp_store(store: TCPStore) -> None:
+    """Set the global TCPStore instance.
+
+    This should be called during distributed initialization to make
+    the store available to all components that need it.
+    """
+    global _global_tcp_store
+    _global_tcp_store = store
+    logger.info("Global TCPStore has been set")
+
+
+def get_global_tcp_store() -> Optional[TCPStore]:
+    """Get the existing global TCPStore.
+
+    This function provides access to the shared TCPStore instance that was
+    created during distributed initialization. All components (like NIXL buffers)
+    should use this same store for coordination.
+
+    Returns:
+        The global TCPStore instance, or None if not initialized yet.
+    """
+    global _global_tcp_store
+
+    if _global_tcp_store is None:
+        logger.warning(
+            "Global TCPStore not found. Make sure init_distributed_environment "
+            "was called before requesting the store."
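For components that consume the shared store (the NIXL buffer setup is the motivating case here), the intended usage is a simple publish/wait/read rendezvous. A minimal sketch, assuming the default process group and the global store are already initialized; the `nixl_ep/...` key names are made up for illustration, while `set`, `wait`, and `get` are standard `torch.distributed.TCPStore` methods.

```python
import torch.distributed as dist

from sglang.srt.distributed.utils import get_global_tcp_store


def exchange_endpoints(rank: int, endpoint: str) -> list:
    """Publish this rank's endpoint and collect every rank's via the shared store."""
    store = get_global_tcp_store()
    if store is None:
        raise RuntimeError("global TCPStore was not created during distributed init")
    world_size = dist.get_world_size()
    store.set(f"nixl_ep/endpoint/{rank}", endpoint)                    # publish ours
    store.wait([f"nixl_ep/endpoint/{r}" for r in range(world_size)])   # block until all published
    return [store.get(f"nixl_ep/endpoint/{r}").decode() for r in range(world_size)]
```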
+ ) + + return _global_tcp_store + def ensure_divisibility(numerator, denominator): """Ensure that numerator is divisible by the denominator.""" diff --git a/python/sglang/srt/elastic_ep/elastic_ep.py b/python/sglang/srt/elastic_ep/elastic_ep.py new file mode 100644 index 000000000000..6f9b564bdc25 --- /dev/null +++ b/python/sglang/srt/elastic_ep/elastic_ep.py @@ -0,0 +1,108 @@ +from __future__ import annotations + +import threading +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any, Callable, Dict, Optional, Union + +import torch + +from sglang.srt.managers.schedule_batch import ServerArgs +from sglang.srt.utils import is_cpu, is_cuda + + +@dataclass +class ElasticEPState: + _active_ranks: Optional[torch.Tensor] + _last_active_ranks: Optional[torch.Tensor] + _active_ranks_cpu: Optional[torch.Tensor] + on_forward: Optional[Callable] = None + rank_status: Optional[torch.Tensor] = None + + def is_active_equal_last(self) -> bool: + return torch.equal(self._active_ranks, self._last_active_ranks) + + def sync_active_to_cpu(self): + if self._active_ranks is not None: + self._active_ranks_cpu = self._active_ranks.detach().cpu().clone() + + def snapshot_active_to_last(self): + if self._active_ranks is not None: + self._last_active_ranks = self._active_ranks.clone() + + +class ElasticEPStateManager: + _instance: Optional[ElasticEPState] = None + _lock = threading.Lock() + + @staticmethod + def on_forward_mooncake( + state: ElasticEPState, status: torch.Tensor = None, **kwargs + ): + state._active_ranks = state.rank_status.to(dtype=torch.int32) + + @staticmethod + def on_forward_deepep(state: ElasticEPState, status: torch.Tensor = None, **kwargs): + state._active_ranks = 1 - state.rank_status.to(torch.int32) + + @classmethod + def instance(cls) -> ElasticEPState: + return cls._instance + + @classmethod + def init(cls, server_args: ServerArgs): + with cls._lock: + if cls._instance is not None: + return cls._instance + + if server_args.elastic_ep_backend is not None: + cls._instance = cls._build_state( + ep_size=None, + device=None, + backend_type=server_args.elastic_ep_backend, + ) + return cls._instance + + @staticmethod + def _select_device() -> torch.device: + if is_cuda(): + return torch.device("cuda") + elif is_cpu(): + return torch.device("cpu") + else: + raise NotImplementedError("Only CUDA and CPU support elastic ep now.") + + @classmethod + def _build_state( + cls, + *, + ep_size: Optional[int], + device: Optional[torch.device], + backend_type: str = "none", + ) -> ElasticEPState: + + active = cls.create_rank_state(ep_size=ep_size, device=device, value=1) + + if backend_type == "mooncake": + on_forward = cls.on_forward_mooncake + elif backend_type == "deepep": + on_forward = cls.on_forward_deepep + else: + on_forward = None + + return ElasticEPState( + _active_ranks=active, + _last_active_ranks=active.clone(), + _active_ranks_cpu=active.detach().cpu().clone(), + rank_status=active.clone(), + on_forward=on_forward, + ) + + @classmethod + def create_rank_state( + cls, *, ep_size: Optional[int], device: Optional[torch.device], value: int = 1 + ) -> torch.Tensor: + size = ep_size if ep_size is not None else torch.distributed.get_world_size() + dev = device if device is not None else cls._select_device() + + return torch.full((size,), value, dtype=torch.int32, device=dev) diff --git a/python/sglang/srt/eplb/eplb_algorithms/__init__.py b/python/sglang/srt/eplb/eplb_algorithms/__init__.py index e2a2678104af..69664869b551 100644 --- 
a/python/sglang/srt/eplb/eplb_algorithms/__init__.py +++ b/python/sglang/srt/eplb/eplb_algorithms/__init__.py @@ -3,7 +3,8 @@ import torch -from sglang.srt.eplb.eplb_algorithms import deepseek, deepseek_vec +from sglang.srt.elastic_ep.elastic_ep import ElasticEPStateManager +from sglang.srt.eplb.eplb_algorithms import deepseek, deepseek_vec, elasticity_aware class EplbAlgorithm(Enum): @@ -11,6 +12,8 @@ class EplbAlgorithm(Enum): deepseek_hierarchical = auto() deepseek_vec = auto() deepseek_vec_hierarchical = auto() + elasticity_aware = auto() + elasticity_aware_hierarchical = auto() # TODO may have more algorithm later @@ -45,6 +48,24 @@ def rebalance_experts( enable_hierarchical=algorithm == EplbAlgorithm.deepseek_vec_hierarchical, ) + if algorithm in [ + EplbAlgorithm.elasticity_aware, + EplbAlgorithm.elasticity_aware_hierarchical, + ]: + return elasticity_aware.rebalance_experts( + weight=tokens_per_expert.sum(dim=0), + num_replicas=num_physical_experts, + num_groups=num_groups, + num_nodes=num_nodes, + num_gpus=num_physical_experts // num_local_physical_experts, + enable_hierarchical=algorithm == EplbAlgorithm.elasticity_aware_hierarchical, + active_ranks=( + ElasticEPStateManager.instance()._active_ranks + if ElasticEPStateManager.instance() is not None + else ElasticEPStateManager.healthy_rank_state() + ), + ) + raise NotImplementedError diff --git a/python/sglang/srt/eplb/eplb_algorithms/elasticity_aware.py b/python/sglang/srt/eplb/eplb_algorithms/elasticity_aware.py new file mode 100644 index 000000000000..c781c444ae3b --- /dev/null +++ b/python/sglang/srt/eplb/eplb_algorithms/elasticity_aware.py @@ -0,0 +1,87 @@ +from typing import Tuple + +import torch + +from sglang.srt.eplb.eplb_algorithms.deepseek import rebalance_experts_hierarchical + + +def rebalance_experts( + weight: torch.Tensor, + num_replicas: int, + num_groups: int, + num_nodes: int, + num_gpus: int, + enable_hierarchical: bool, + active_ranks: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Entry point for expert-parallelism load balancer. 
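+    For example, with num_gpus = 4, two local experts per GPU and rank 2 inactive, the fallback branch below rebalances onto the three live ranks only, then re-inserts an all-zero slice for rank 2 into phy2log and shifts every log2phy entry at or above 2 * num_local_experts up by num_local_experts, so no logical expert is ever mapped onto the inactive rank.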
+ + Parameters: + weight: [layers, num_logical_experts], the load statistics for all logical experts + num_replicas: number of physical experts, must be a multiple of `num_gpus` + num_groups: number of expert groups + num_nodes: number of server nodes, where the intra-node network (e.g, NVLink) is faster + num_gpus: number of GPUs, must be a multiple of `num_nodes` + + Returns: + physical_to_logical_map: [layers, num_replicas], the expert index of each replica + logical_to_physical_map: [layers, num_logical_experts, X], the replica indices for each expert + expert_count: [layers, num_logical_experts], number of physical replicas for each logical expert + """ + + num_layers, num_logical_experts = weight.shape + weight = weight.float().cpu() + num_active_ranks = active_ranks.sum().item() + num_local_experts = num_replicas // num_gpus + if num_active_ranks < num_gpus: + # Must fall back to global load-balance policy + # and fix some params + phy2log, phyrank, logcnt = rebalance_experts_hierarchical( + weight, + num_local_experts * num_active_ranks, + 1, + 1, + num_active_ranks, + ) + elif enable_hierarchical: + # use hierarchical load-balance policy + phy2log, phyrank, logcnt = rebalance_experts_hierarchical( + weight, num_replicas, num_groups, num_nodes, num_gpus + ) + else: + # use global load-balance policy + phy2log, phyrank, logcnt = rebalance_experts_hierarchical( + weight, num_replicas, 1, 1, num_gpus + ) + maxlogcnt = logcnt.max().item() + log2phy: torch.Tensor = torch.full( + (num_layers, num_logical_experts, maxlogcnt), + -1, + dtype=torch.int64, + device=logcnt.device, + ) + log2phy.view(num_layers, -1).scatter_( + -1, + phy2log * maxlogcnt + phyrank, + torch.arange( + num_local_experts * num_active_ranks, + dtype=torch.int64, + device=log2phy.device, + ).expand(num_layers, -1), + ) + if num_active_ranks < num_gpus: + phy2log_slices = list( + phy2log.view(num_layers, num_active_ranks, -1).unbind(dim=1) + ) + active_ranks_list = active_ranks.tolist() + for idx, active_rank in enumerate(active_ranks_list): + if not active_rank: + phy2log_slices.insert(idx, torch.zeros_like(phy2log_slices[0])) + log2phy = torch.where( + log2phy >= idx * num_local_experts, + log2phy + num_local_experts, + log2phy, + ) + phy2log = torch.stack(phy2log_slices, dim=1).contiguous().view(num_layers, -1) + return phy2log, log2phy, logcnt diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index af82d54a41db..6880420b7e5b 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -706,7 +706,11 @@ def _forward_ll(dispatch_output: DeepEPLLOutput): def get_moe_impl_class(quant_config: Optional[QuantizationConfig]): - if get_moe_a2a_backend().is_deepep() or get_moe_a2a_backend().is_mooncake(): + if ( + get_moe_a2a_backend().is_deepep() + or get_moe_a2a_backend().is_mooncake() + or get_moe_a2a_backend().is_nixl() + ): return DeepEPMoE # NEW: Direct FP4 detection (bypasses EP requirements) diff --git a/python/sglang/srt/layers/moe/token_dispatcher/__init__.py b/python/sglang/srt/layers/moe/token_dispatcher/__init__.py index 7526f73dedc5..731c28f9bb4d 100644 --- a/python/sglang/srt/layers/moe/token_dispatcher/__init__.py +++ b/python/sglang/srt/layers/moe/token_dispatcher/__init__.py @@ -21,6 +21,11 @@ MooncakeDispatchOutput, MooncakeEPDispatcher, ) +from sglang.srt.layers.moe.token_dispatcher.nixl import ( + NixlEPCombineInput, + NixlEPDispatcher, + NixlEPDispatchOutput, +) from 
sglang.srt.layers.moe.token_dispatcher.standard import ( StandardCombineInput, StandardDispatchOutput, @@ -38,6 +43,9 @@ "MooncakeCombineInput", "MooncakeDispatchOutput", "MooncakeEPDispatcher", + "NixlEPCombineInput", + "NixlEPDispatchOutput", + "NixlEPDispatcher", "StandardDispatchOutput", "StandardCombineInput", "DeepEPConfig", diff --git a/python/sglang/srt/layers/moe/token_dispatcher/deepep.py b/python/sglang/srt/layers/moe/token_dispatcher/deepep.py index 618c4cf9eb1c..84380f363197 100644 --- a/python/sglang/srt/layers/moe/token_dispatcher/deepep.py +++ b/python/sglang/srt/layers/moe/token_dispatcher/deepep.py @@ -5,6 +5,7 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Tuple, Union +from sglang.srt.elastic_ep.elastic_ep import ElasticEPStateManager from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder from sglang.srt.layers.moe.token_dispatcher.base import ( BaseDispatcher, @@ -211,6 +212,7 @@ def get_deepep_buffer( low_latency_mode=deepep_mode.enable_low_latency(), num_qps_per_rank=num_qps_per_rank, # TODO can be false when unneeded + enable_shrink=True, allow_mnnvl=True, ) return cls._buffer @@ -299,6 +301,7 @@ def __init__( # DeepEP internode_ll dispatch uses FINISHED_SUM_TAG=1024 # and the logic requires num-tokens-sent-from-one-rank-to-another-rank less than it assert self.num_max_dispatch_tokens_per_rank <= 1024 + self.status_tensor = ElasticEPStateManager.instance().rank_status self.handle = None @@ -664,6 +667,9 @@ def _combine_core( else {} ), ) + torch.cuda.synchronize() + buffer.low_latency_query_mask_buffer(self.status_tensor) + torch.cuda.synchronize() self.packed_recv_count = self.handle = None return combined_hidden_states, event, hook diff --git a/python/sglang/srt/layers/moe/token_dispatcher/mooncake.py b/python/sglang/srt/layers/moe/token_dispatcher/mooncake.py index d6d56186563a..82482c22d493 100644 --- a/python/sglang/srt/layers/moe/token_dispatcher/mooncake.py +++ b/python/sglang/srt/layers/moe/token_dispatcher/mooncake.py @@ -4,6 +4,7 @@ from dataclasses import dataclass from typing import NamedTuple, Optional, Tuple +from sglang.srt.elastic_ep.elastic_ep import ElasticEPStateManager from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder from sglang.srt.layers.moe.token_dispatcher.base import ( BaseDispatcher, @@ -62,14 +63,6 @@ def format(self) -> CombineInputFormat: assert isinstance(MooncakeCombineInput, CombineInput) -_ACTIVE_RANKS: Optional[torch.Tensor] = None - - -def get_ep_active_ranks() -> torch.Tensor: - assert _ACTIVE_RANKS is not None, "_ACTIVE_RANKS is not initialized" - return _ACTIVE_RANKS - - class EPBuffer: _buffer = None _hidden_size: Optional[int] = None @@ -152,12 +145,7 @@ def __init__( self.first_execution = True self.timeout_us = 10000000 - global _ACTIVE_RANKS - if _ACTIVE_RANKS is None: - _ACTIVE_RANKS = torch.ones( - (self.num_experts,), dtype=torch.int32, device="cuda" - ) - self.active_ranks = _ACTIVE_RANKS + self.active_ranks = ElasticEPStateManager.instance().rank_status self.handle = None diff --git a/python/sglang/srt/layers/moe/token_dispatcher/nixl.py b/python/sglang/srt/layers/moe/token_dispatcher/nixl.py new file mode 100644 index 000000000000..88b25c18091a --- /dev/null +++ b/python/sglang/srt/layers/moe/token_dispatcher/nixl.py @@ -0,0 +1,508 @@ +from __future__ import annotations + +import logging +import os +from contextlib import nullcontext +from enum import Enum, auto +from typing import 
TYPE_CHECKING, NamedTuple, Optional, Tuple + +import torch +import torch.distributed as dist + +from sglang.srt.elastic_ep.elastic_ep import ElasticEPStateManager +from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder +from sglang.srt.layers.moe.token_dispatcher.base import ( + BaseDispatcher, + CombineInput, + CombineInputFormat, + DispatchOutput, + DispatchOutputFormat, +) +from sglang.srt.layers.moe.utils import DeepEPMode +from sglang.srt.layers.quantization import deep_gemm_wrapper +from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.distributed.utils import get_global_tcp_store +from sglang.srt.utils import get_bool_env_var, get_int_env_var, is_npu + +_is_npu = is_npu() + +if TYPE_CHECKING: + from sglang.srt.single_batch_overlap import CombineOverlapArgs + +try: + from nixl_ep import Buffer + + use_nixl = True +except ImportError: + use_nixl = False + +logger = logging.getLogger(__name__) + + +class NixlEPDispatchOutput(NamedTuple): + """NixlEP dispatch output. + + Note: Uses same format as DeepEPLLOutput for compatibility with downstream code. + hidden_states_fp8 is a tuple of (hidden_states, scale) or just hidden_states if no scale. + """ + + hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor] + topk_idx: torch.Tensor + topk_weights: torch.Tensor + masked_m: torch.Tensor + expected_m: int + + @property + def format(self) -> DispatchOutputFormat: + return DispatchOutputFormat.DEEPEP_LL + + +assert isinstance(NixlEPDispatchOutput, DispatchOutput) + + +class NixlEPCombineInput(NamedTuple): + """NixlEP combine input.""" + + pass + + @property + def format(self) -> CombineInputFormat: + return CombineInputFormat.DEEPEP_LL + + +assert isinstance(NixlEPCombineInput, CombineInput) + + +class NixlEPBuffer: + _buffer = None + _hidden_size: Optional[int] = None + _num_max_dispatch_tokens_per_rank: Optional[int] = None + _num_experts: Optional[int] = None + _num_local_experts: Optional[int] = None + + @classmethod + def get_nixl_buffer( + cls, + group: dist.ProcessGroup, + hidden_size: int, + deepep_mode: DeepEPMode, + num_max_dispatch_tokens_per_rank: int = -1, + num_experts: int = -1, + num_local_experts: int = -1, + ): + if cls._buffer is not None: + return cls._buffer + + cls._hidden_size = hidden_size + cls._num_max_dispatch_tokens_per_rank = num_max_dispatch_tokens_per_rank + cls._num_experts = num_experts + cls._num_local_experts = num_local_experts + + num_rdma_bytes = 0 + if deepep_mode.enable_normal(): + raise NotImplementedError("Normal mode is not supported for Nixl EP yet.") + if deepep_mode.enable_low_latency(): + assert num_max_dispatch_tokens_per_rank != -1 + assert num_experts != -1 and num_experts % group.size() == 0 + num_rdma_bytes = Buffer.get_rdma_size_hint( + num_max_dispatch_tokens_per_rank, + hidden_size, + group.size(), + num_experts, + ) + + rank = dist.get_rank(group) + world_size = dist.get_world_size(group) + + # Get the global TCPStore for coordination + tcp_store = get_global_tcp_store() + if tcp_store is None: + raise RuntimeError( + "Global TCPStore is not initialized. " + "Make sure init_distributed_environment was called before using NIXL EP." 
+ ) + + logger.info( + f"Using NIXL EP (world_size={world_size}, rank={rank}, " + f"num_experts={cls._num_experts}, num_experts_per_rank={cls._num_local_experts}) " + ) + + cls._buffer = Buffer( + rank=rank, + tcp_store_group=tcp_store, + ) + + cls._buffer.update_memory_buffers( + num_ranks=world_size, + num_experts_per_rank=cls._num_local_experts, + num_rdma_bytes=num_rdma_bytes, + ) + all_ranks = list(range(world_size)) + cls._buffer.connect_ranks(all_ranks) + + return cls._buffer + + @classmethod + def clean_buffer(cls): + cls._buffer.clean_buffer( + cls._num_max_dispatch_tokens_per_rank, + cls._hidden_size, + cls._num_experts, + ) + + +class _NixlEPDispatcherImplBase: + def __init__( + self, + group: torch.distributed.ProcessGroup, + router_topk: int, + permute_fusion: bool, + num_experts: int, + num_local_experts: int, + hidden_size: int, + params_dtype: torch.dtype, + deepep_mode: DeepEPMode, + ): + if not use_nixl: + raise ImportError( + "NixlEP is not installed. Please install NixlEP package from " + "https://github.com/ai-dynamo/nixl." + ) + + self.group = group + self.router_topk = router_topk + self.permute_fusion = permute_fusion + self.num_experts = num_experts + self.num_local_experts = num_local_experts + self.hidden_size = hidden_size + self.params_dtype = params_dtype + self.deepep_mode = deepep_mode + + self.num_max_dispatch_tokens_per_rank = get_int_env_var( + "SGLANG_NIXL_EP_NUM_MAX_DISPATCH_TOKENS_PER_RANK", 128 + ) + # NixlEP internode_ll dispatch uses FINISHED_SUM_TAG=1024 + # and the logic requires num-tokens-sent-from-one-rank-to-another-rank less than it + assert self.num_max_dispatch_tokens_per_rank <= 1024 + self.status_tensor = ElasticEPStateManager.instance().rank_status + + self.handle = None + + def dispatch_a( + self, + hidden_states: torch.Tensor, + input_global_scale: Optional[torch.Tensor], + topk_idx: torch.Tensor, + topk_weights: torch.Tensor, + ): + raise NotImplementedError + + def dispatch_b(self, *args, **kwargs): + raise NotImplementedError + + def combine_a( + self, + hidden_states: torch.Tensor, + topk_idx: torch.Tensor, + topk_weights: torch.Tensor, + overlap_args: Optional["CombineOverlapArgs"] = None, + ): + raise NotImplementedError + + def combine_b(self, *args, **kwargs): + raise NotImplementedError + + def _get_buffer(self): + raise NotImplementedError + + +class _NixlEPDispatcherImpl(_NixlEPDispatcherImplBase): + def __init__(self, return_recv_hook: bool, **kwargs): + super().__init__(**kwargs) + + """ + num_max_dispatch_tokens_per_rank: the actual batch size in the decoding engine should be less than 256 + https://github.com/ai-dynamo/nixl + """ + self.return_recv_hook = return_recv_hook + self.device_module = torch.get_device_module() + + def dispatch_a( + self, + hidden_states: torch.Tensor, + input_global_scale: Optional[torch.Tensor], + topk_idx: torch.Tensor, + topk_weights: torch.Tensor, + ): + buffer = self._get_buffer() + topk_idx = topk_idx.to(torch.int64) + expected_m = ( + hidden_states.shape[0] * buffer.group_size * topk_idx.shape[1] + + self.num_experts + ) // self.num_experts + hidden_states, masked_m, event, hook = self._dispatch_core( + hidden_states, + input_global_scale, + topk_idx, + ) + return ( + hidden_states, + topk_idx, + topk_weights, + masked_m, + expected_m, + event, + hook, + ) + + def dispatch_b( + self, + hidden_states, + topk_idx, + topk_weights, + masked_m, + expected_m, + event, + hook, + ): + hook() if self.return_recv_hook else event.current_stream_wait() + + 
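The `expected_m` estimate in `dispatch_a` above is an average-tokens-per-expert bound that assumes routing spreads evenly across the whole EP group; a tiny worked example with made-up numbers:

```python
# 128 local tokens, 16 ranks in the dispatch group, top-8 routing, 256 experts.
num_tokens, group_size, topk, num_experts = 128, 16, 8, 256
expected_m = (num_tokens * group_size * topk + num_experts) // num_experts
assert expected_m == 65  # 16384 routed copies / 256 experts = 64, plus one for slack
```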
get_global_expert_distribution_recorder().on_deepep_dispatch_low_latency( + masked_m + ) + + nixl_output = NixlEPDispatchOutput( + hidden_states, + topk_idx, + topk_weights, + masked_m, + expected_m, + ) + return nixl_output + + def _dispatch_core( + self, + hidden_states: torch.Tensor, + input_global_scale: Optional[torch.Tensor], + topk_idx: torch.Tensor, + ): + use_nvfp4 = use_fp8 = False + if input_global_scale is not None: + use_nvfp4 = True + elif not get_bool_env_var("SGLANG_NIXL_EP_BF16_DISPATCH"): + use_fp8 = True + + buffer = self._get_buffer() + packed_recv_hidden, self.packed_recv_count, self.handle, event, hook = ( + buffer.dispatch( + hidden_states, + topk_idx, + self.num_max_dispatch_tokens_per_rank, + self.num_experts, + use_fp8=use_fp8, + **(dict(use_nvfp4=True) if use_nvfp4 else dict()), + **( + dict(x_global_scale=input_global_scale) + if input_global_scale is not None + else dict() + ), + async_finish=not self.return_recv_hook, + return_recv_hook=self.return_recv_hook, + round_scale=deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM + and deep_gemm_wrapper.DEEPGEMM_BLACKWELL, + use_ue8m0=deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM + and deep_gemm_wrapper.DEEPGEMM_BLACKWELL, + ) + ) + return packed_recv_hidden, self.packed_recv_count, event, hook + + def combine_a( + self, + hidden_states: torch.Tensor, + topk_idx: torch.Tensor, + topk_weights: torch.Tensor, + overlap_args: Optional["CombineOverlapArgs"] = None, + ): + hidden_states, event, hook = self._combine_core( + hidden_states, + topk_idx, + topk_weights, + overlap_args=overlap_args, + ) + return hidden_states, event, hook, overlap_args + + def combine_b(self, hidden_states, event, hook, overlap_args): + if overlap_args is not None: + overlap_args.stream.wait_stream(self.device_module.current_stream()) + + hook() if self.return_recv_hook else event.current_stream_wait() + + if overlap_args is not None: + self.device_module.current_stream().wait_stream(overlap_args.stream) + + return hidden_states + + def _combine_core( + self, + hidden_states: torch.Tensor, + topk_idx: torch.Tensor, + topk_weights: torch.Tensor, + overlap_args: Optional["CombineOverlapArgs"] = None, + ): + buffer = self._get_buffer() + + ctx = nullcontext() + if overlap_args is not None: + overlap_args.stream.wait_event(overlap_args.wait_event) + ctx = torch.cuda.stream(overlap_args.stream) + + with ctx: + combined_hidden_states, event, hook = buffer.combine( + x=hidden_states, + topk_idx=topk_idx, + topk_weights=topk_weights, + handle=self.handle, + async_finish=not self.return_recv_hook, + return_recv_hook=self.return_recv_hook, + **( + dict( + overlap=overlap_args.overlap, + src_signals=overlap_args.signal, + src_signal_expect_value=overlap_args.threshold, + ) + if overlap_args is not None + else {} + ), + ) + torch.cuda.synchronize() + buffer.query_mask_buffer(self.status_tensor) + torch.cuda.synchronize() + + self.packed_recv_count = self.handle = None + return combined_hidden_states, event, hook + + def _get_buffer(self): + return NixlEPBuffer.get_nixl_buffer( + self.group, + self.hidden_size, + self.deepep_mode, + self.num_max_dispatch_tokens_per_rank, + self.num_experts, + self.num_local_experts, + ) + + +class _Stage(Enum): + INITIAL = auto() + AFTER_DISPATCH_A = auto() + AFTER_DISPATCH_B = auto() + AFTER_COMBINE_A = auto() + + +class NixlEPDispatcher(BaseDispatcher): + def __init__( + self, + group: torch.distributed.ProcessGroup, + router_topk: int, + permute_fusion: bool = False, + num_experts: int = None, + num_local_experts: int = None, + hidden_size: 
int = None, + params_dtype: torch.dtype = None, + deepep_mode: DeepEPMode = DeepEPMode.LOW_LATENCY, + async_finish: bool = False, + return_recv_hook: bool = False, + ): + self.deepep_mode = deepep_mode + + common_kwargs = dict( + group=group, + router_topk=router_topk, + permute_fusion=permute_fusion, + num_experts=num_experts, + num_local_experts=num_local_experts, + hidden_size=hidden_size, + params_dtype=params_dtype, + deepep_mode=deepep_mode, + ) + + if self.deepep_mode.enable_low_latency(): + self._low_latency_dispatcher = _NixlEPDispatcherImpl( + return_recv_hook=return_recv_hook, + **common_kwargs, + ) + if self.deepep_mode.enable_normal(): + raise NotImplementedError("Normal mode is not supported for Nixl EP yet.") + + self._stage = _Stage.INITIAL + + def dispatch(self, *args, **kwargs) -> DispatchOutput: + self.dispatch_a(*args, **kwargs) + ret = self.dispatch_b() + return ret + + def dispatch_a( + self, + hidden_states: torch.Tensor, + input_global_scale: Optional[torch.Tensor], + topk_idx: torch.Tensor, + topk_weights: torch.Tensor, + forward_batch: ForwardBatch, + ): + self._update_stage(_Stage.INITIAL, _Stage.AFTER_DISPATCH_A) + inner_state = self._get_impl(forward_batch).dispatch_a( + hidden_states=hidden_states, + input_global_scale=input_global_scale, + topk_idx=topk_idx, + topk_weights=topk_weights, + ) + self._dispatch_intermediate_state = forward_batch, inner_state + + def dispatch_b(self): + self._update_stage(_Stage.AFTER_DISPATCH_A, _Stage.AFTER_DISPATCH_B) + forward_batch, inner_state = self._dispatch_intermediate_state + del self._dispatch_intermediate_state + return self._get_impl(forward_batch).dispatch_b(*inner_state) + + def combine(self, *args, **kwargs) -> Tuple: + self.combine_a(*args, **kwargs) + ret = self.combine_b() + return ret + + def combine_a( + self, + hidden_states: torch.Tensor, + topk_idx: torch.Tensor, + topk_weights: torch.Tensor, + forward_batch: ForwardBatch, + overlap_args: Optional["CombineOverlapArgs"] = None, + ): + self._update_stage(_Stage.AFTER_DISPATCH_B, _Stage.AFTER_COMBINE_A) + inner_state = self._get_impl(forward_batch).combine_a( + hidden_states=hidden_states, + topk_idx=topk_idx, + topk_weights=topk_weights, + overlap_args=overlap_args, + ) + self._combine_intermediate_state = forward_batch, inner_state + + def combine_b(self): + self._update_stage(_Stage.AFTER_COMBINE_A, _Stage.INITIAL) + forward_batch, inner_state = self._combine_intermediate_state + del self._combine_intermediate_state + return self._get_impl(forward_batch).combine_b(*inner_state) + + def _get_impl(self, forward_batch: ForwardBatch) -> _NixlEPDispatcherImplBase: + resolved_deepep_mode = self.deepep_mode.resolve( + forward_batch.is_extend_in_batch + ) + if resolved_deepep_mode == DeepEPMode.NORMAL: + raise NotImplementedError("Normal mode is not supported for Nixl EP yet.") + elif resolved_deepep_mode == DeepEPMode.LOW_LATENCY: + return self._low_latency_dispatcher + else: + raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}") + + def _update_stage(self, old_stage, new_stage): + assert self._stage == old_stage + self._stage = new_stage diff --git a/python/sglang/srt/layers/moe/utils.py b/python/sglang/srt/layers/moe/utils.py index 65149936343a..b39061ca1932 100644 --- a/python/sglang/srt/layers/moe/utils.py +++ b/python/sglang/srt/layers/moe/utils.py @@ -26,6 +26,7 @@ class MoeA2ABackend(Enum): NONE = "none" DEEPEP = "deepep" MOONCAKE = "mooncake" + NIXL = "nixl" @classmethod def _missing_(cls, value): @@ -45,6 +46,9 @@ def is_deepep(self): def 
is_mooncake(self): return self == MoeA2ABackend.MOONCAKE + def is_nixl(self): + return self == MoeA2ABackend.NIXL + class MoeRunnerBackend(Enum): diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index a1a25102d70c..ac5485ee52af 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -1017,9 +1017,8 @@ def create_moe_runner( moe_runner_backend = get_moe_runner_backend() if moe_runner_backend.is_auto(): - if ( - deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM - and get_moe_a2a_backend().is_deepep() + if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and ( + get_moe_a2a_backend().is_deepep() or get_moe_a2a_backend().is_nixl() ): moe_runner_backend = MoeRunnerBackend.DEEP_GEMM else: diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 4ef8bc99373f..d43c709036c7 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -24,7 +24,7 @@ import time from collections import defaultdict from dataclasses import dataclass -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch import torch.distributed as dist @@ -51,6 +51,7 @@ set_symm_mem_all_reduce, ) from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state +from sglang.srt.elastic_ep.elastic_ep import ElasticEPStateManager from sglang.srt.eplb.eplb_manager import EPLBManager from sglang.srt.eplb.expert_distribution import ( ExpertDistributionRecorder, @@ -382,6 +383,11 @@ def initialize(self, min_per_gpu_memory: float): ) self.expert_location_updater = ExpertLocationUpdater() + ( + ElasticEPStateManager.init(self.server_args) + if self.server_args.elastic_ep_backend + else None + ) # Load the model self.sampler = Sampler() self.load_model() @@ -926,16 +932,33 @@ def update_expert_location( new_expert_location_metadata: ExpertLocationMetadata, update_layer_ids: List[int], ): - self.expert_location_updater.update( - self.model.routed_experts_weights_of_layer, - new_expert_location_metadata, - update_layer_ids=update_layer_ids, - nnodes=self.server_args.nnodes, - rank=self.tp_rank, - ) + if ElasticEPStateManager.instance() is not None: + # TODO: refactor the weights update when elastic ep + old_expert_location_metadata = get_global_expert_location_metadata() + assert old_expert_location_metadata is not None + old_expert_location_metadata.update( + new_expert_location_metadata, + update_layer_ids=update_layer_ids, + ) + self.update_weights_from_disk( + self.server_args.model_path, + self.server_args.load_format, + lambda name: "mlp.experts" in name and "mlp.shared_experts" not in name, + ) + else: + self.expert_location_updater.update( + self.model.routed_experts_weights_of_layer, + new_expert_location_metadata, + update_layer_ids=update_layer_ids, + nnodes=self.server_args.nnodes, + rank=self.tp_rank, + ) def update_weights_from_disk( - self, model_path: str, load_format: str + self, + model_path: str, + load_format: str, + weight_name_filter: Optional[Callable[[str], bool]] = None, ) -> tuple[bool, str]: """Update engine weights in-place from the disk.""" logger.info( @@ -957,6 +980,11 @@ def get_weight_iter(config): iter = loader._get_weights_iterator( DefaultModelLoader.Source.init_new(config, self.model) ) + if weight_name_filter is not None: + iter = ( + (name, weight) for name, weight in iter if weight_name_filter(name) + ) + return iter 
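A sketch of how the new `weight_name_filter` hook is meant to be called from the elastic-EP path in `update_expert_location` above; `model_runner` and `server_args` stand for whatever instances are in scope, and the predicate mirrors the lambda used there.

```python
def routed_experts_only(name: str) -> bool:
    # Reload only the routed-expert weights; shared experts are left untouched.
    return "mlp.experts" in name and "mlp.shared_experts" not in name


success, message = model_runner.update_weights_from_disk(
    model_path=server_args.model_path,
    load_format=server_args.load_format,
    weight_name_filter=routed_experts_only,
)
```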
def model_load_weights(model, iter): diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index fb9cd4f6c9f8..9b7918fb4498 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -593,6 +593,7 @@ def __init__( dict(tp_rank=0, tp_size=1) if get_moe_a2a_backend().is_deepep() or get_moe_a2a_backend().is_mooncake() + or get_moe_a2a_backend().is_nixl() or should_use_flashinfer_cutlass_moe_fp4_allgather() else {} ), @@ -623,7 +624,11 @@ def __init__( self.top_k = config.num_experts_per_tok - if get_moe_a2a_backend().is_deepep() or get_moe_a2a_backend().is_mooncake(): + if ( + get_moe_a2a_backend().is_deepep() + or get_moe_a2a_backend().is_mooncake() + or get_moe_a2a_backend().is_nixl() + ): # TODO: we will support tp < ep in the future self.ep_size = get_moe_expert_parallel_world_size() self.num_experts = ( @@ -653,7 +658,9 @@ def __init__( ) self._enable_a2a_moe = ( - get_moe_a2a_backend().is_deepep() or get_moe_a2a_backend().is_mooncake() + get_moe_a2a_backend().is_deepep() + or get_moe_a2a_backend().is_mooncake() + or get_moe_a2a_backend().is_nixl() ) def get_moe_weights(self): diff --git a/python/sglang/srt/models/glm4_moe.py b/python/sglang/srt/models/glm4_moe.py index 35ce0c40db50..c3dec6cce5ce 100644 --- a/python/sglang/srt/models/glm4_moe.py +++ b/python/sglang/srt/models/glm4_moe.py @@ -467,7 +467,11 @@ def __init__( self.top_k = config.num_experts_per_tok - if get_moe_a2a_backend().is_deepep() or get_moe_a2a_backend().is_mooncake(): + if ( + get_moe_a2a_backend().is_deepep() + or get_moe_a2a_backend().is_mooncake() + or get_moe_a2a_backend().is_nixl() + ): # TODO: we will support tp < ep in the future self.ep_size = get_moe_expert_parallel_world_size() self.num_experts = ( @@ -497,7 +501,9 @@ def __init__( ) self._enable_a2a_moe = ( - get_moe_a2a_backend().is_deepep() or get_moe_a2a_backend().is_mooncake() + get_moe_a2a_backend().is_deepep() + or get_moe_a2a_backend().is_mooncake() + or get_moe_a2a_backend().is_nixl() ) def forward_normal_dual_stream( diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 8644324963c4..896ec206f819 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -229,7 +229,7 @@ class ServerArgs: # Runtime options device: Optional[str] = None - elastic_ep_backend: Literal[None, "mooncake"] = None + elastic_ep_backend: Literal[None, "mooncake", "deepep"] = None mooncake_ib_device: Optional[str] = None tp_size: int = 1 pp_size: int = 1 @@ -347,7 +347,7 @@ class ServerArgs: # Expert parallelism ep_size: int = 1 - moe_a2a_backend: Literal["none", "deepep", "mooncake"] = "none" + moe_a2a_backend: Literal["none", "deepep", "mooncake", "nixl"] = "none" moe_runner_backend: str = "auto" flashinfer_mxfp4_moe_precision: Literal["default", "bf16"] = "default" enable_flashinfer_allreduce_fusion: bool = False @@ -577,6 +577,9 @@ def __post_init__(self): # Handle any other necessary validations. self._handle_other_validations() + # Handle elastic expert parallelism. + self._handle_elastic_ep() + def _handle_deprecated_args(self): # handle deprecated tool call parsers deprecated_tool_call_parsers = {"qwen25": "qwen", "glm45": "glm"} @@ -1130,6 +1133,14 @@ def _handle_a2a_moe(self): f"Mooncake MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]." 
) + if self.moe_a2a_backend == "nixl": + self.ep_size = self.tp_size + self.disable_cuda_graph = True + logger.warning("Cuda graph is disabled because moe_a2a_backend=`nixl`") + logger.warning( + f"Nixl MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]." + ) + def _handle_eplb_and_dispatch(self): if self.enable_eplb and (self.expert_distribution_recorder_mode is None): self.expert_distribution_recorder_mode = "stat" @@ -1145,6 +1156,18 @@ def _handle_eplb_and_dispatch(self): if self.enable_eplb: assert self.ep_size > 1 + def _handle_elastic_ep(self): + if self.elastic_ep_backend is not None: + if self.enable_eplb: + if self.eplb_algorithm == "auto": + self.eplb_algorithm = "elasticity_aware" + assert ( + self.eplb_algorithm in [ + "elasticity_aware", + "elasticity_aware_hierarchical", + ] + ), "Elastic EP requires eplb_algorithm to be set to 'auto' or 'elasticity_aware(_hierarchical)'." + def _handle_expert_distribution_metrics(self): if self.enable_expert_distribution_metrics and ( self.expert_distribution_recorder_mode is None @@ -1752,8 +1775,11 @@ def add_cli_args(parser: argparse.ArgumentParser): "--elastic-ep-backend", type=str, default=ServerArgs.elastic_ep_backend, - choices=["none", "mooncake"], - help="Specify the collective communication backend for elastic EP. Currently supports 'mooncake'.", + choices=["none", "mooncake", "deepep"], + help=( + "Specify the collective communication backend for elastic EP. " + "Supports 'mooncake' and 'deepep'. Use 'none' to disable." + ), ) parser.add_argument( "--mooncake-ib-device", @@ -2384,7 +2410,7 @@ def add_cli_args(parser: argparse.ArgumentParser): parser.add_argument( "--moe-a2a-backend", type=str, - choices=["none", "deepep", "mooncake"], + choices=["none", "deepep", "mooncake", "nixl"], default=ServerArgs.moe_a2a_backend, help="Choose the backend for MoE A2A.", ) diff --git a/python/sglang/srt/two_batch_overlap.py b/python/sglang/srt/two_batch_overlap.py index b09c72dae209..c60cc9c6022b 100644 --- a/python/sglang/srt/two_batch_overlap.py +++ b/python/sglang/srt/two_batch_overlap.py @@ -23,6 +23,7 @@ from sglang.srt.layers.moe.token_dispatcher import ( DeepEPDispatcher, MooncakeEPDispatcher, + NixlEPDispatcher, ) from sglang.srt.layers.quantization import deep_gemm_wrapper from sglang.srt.managers.schedule_batch import ScheduleBatch @@ -979,6 +980,10 @@ def __init__(self, **kwargs): self._inners = [ MooncakeEPDispatcher(**kwargs) for _ in range(num_inner_dispatchers) ] + elif get_moe_a2a_backend().is_nixl(): + self._inners = [ + NixlEPDispatcher(**kwargs) for _ in range(num_inner_dispatchers) + ] def _execute(self, name, tbo_subbatch_index: Optional[int] = None, **kwargs): return getattr(self._inners[tbo_subbatch_index or 0], name)(**kwargs) diff --git a/test/srt/ep/test_mooncake_ep_small.py b/test/srt/ep/test_mooncake_ep_small.py index 111260a8c82d..391cdc4c65f5 100644 --- a/test/srt/ep/test_mooncake_ep_small.py +++ b/test/srt/ep/test_mooncake_ep_small.py @@ -3,6 +3,7 @@ from sglang.srt.utils import kill_process_tree from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k +from sglang.test.test_disaggregation_utils import get_rdma_devices_args from sglang.test.test_utils import ( DEFAULT_MODEL_NAME_FOR_TEST_MLA, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, @@ -11,166 +12,12 @@ popen_launch_server, ) - -class TestPureDP(CustomTestCase): - @classmethod - def setUpClass(cls): - cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA - cls.base_url = DEFAULT_URL_FOR_TEST - 
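Putting the new knobs together, a minimal sketch of a configuration that exercises the code paths above; `model_path` is a placeholder, the field names come from `ServerArgs`, and the comments describe what `_handle_a2a_moe` and `_handle_elastic_ep` are expected to do during `__post_init__`.

```python
from sglang.srt.server_args import ServerArgs

args = ServerArgs(
    model_path="deepseek-ai/DeepSeek-V2-Lite",  # placeholder model
    tp_size=4,
    moe_a2a_backend="nixl",        # new choice next to "deepep" / "mooncake"
    elastic_ep_backend="deepep",   # previously only "mooncake" was accepted
    enable_eplb=True,
)
# After __post_init__ the handlers above should leave:
#   args.ep_size == 4 (forced equal to tp_size by _handle_a2a_moe)
#   args.disable_cuda_graph is True (CUDA graphs unsupported with the nixl backend)
#   args.eplb_algorithm == "elasticity_aware" (auto-selected by _handle_elastic_ep)
```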
cls.process = popen_launch_server( - cls.model, - cls.base_url, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=[ - "--trust-remote-code", - "--tp", - "4", - "--enable-dp-attention", - "--dp", - "4", - "--elastic-ep-backend", - "mooncake", - "--mooncake-ib-device", - "mlx5_roce0,mlx5_roce1,mlx5_roce2,mlx5_roce3,mlx5_roce4,mlx5_roce5,mlx5_roce6,mlx5_roce7", - "--moe-a2a-backend", - "deepep", - "--deepep-mode", - "low_latency", - "--chunked-prefill-size", - "512", - "--cuda-graph-max-bs", - "128", - "--max-running-requests", - "512", - "--mem-fraction-static", - "0.5", - ], - ) - - @classmethod - def tearDownClass(cls): - kill_process_tree(cls.process.pid) - - def test_gsm8k(self): - args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), - ) - metrics = run_eval_few_shot_gsm8k(args) - print(metrics) - - self.assertGreater(metrics["accuracy"], 0.60) - - -class TestHybridDPTP(CustomTestCase): - @classmethod - def setUpClass(cls): - cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA - cls.base_url = DEFAULT_URL_FOR_TEST - cls.process = popen_launch_server( - cls.model, - cls.base_url, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=[ - "--trust-remote-code", - "--tp", - "4", - "--enable-dp-attention", - "--dp", - "2", - "--elastic-ep-backend", - "mooncake", - "--mooncake-ib-device", - "mlx5_roce0,mlx5_roce1,mlx5_roce2,mlx5_roce3,mlx5_roce4,mlx5_roce5,mlx5_roce6,mlx5_roce7", - "--moe-a2a-backend", - "deepep", - "--deepep-mode", - "low_latency", - "--chunked-prefill-size", - "512", - "--cuda-graph-max-bs", - "128", - "--max-running-requests", - "256", - ], - ) - - @classmethod - def tearDownClass(cls): - kill_process_tree(cls.process.pid) - - def test_gsm8k(self): - args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), - ) - metrics = run_eval_few_shot_gsm8k(args) - print(metrics) - - self.assertGreater(metrics["accuracy"], 0.60) +ib_devices = get_rdma_devices_args() class TestTP(CustomTestCase): - @classmethod - def setUpClass(cls): - cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA - cls.base_url = DEFAULT_URL_FOR_TEST - cls.process = popen_launch_server( - cls.model, - cls.base_url, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=[ - "--trust-remote-code", - "--tp", - "4", - "--elastic-ep-backend", - "mooncake", - "--mooncake-ib-device", - "mlx5_roce0,mlx5_roce1,mlx5_roce2,mlx5_roce3,mlx5_roce4,mlx5_roce5,mlx5_roce6,mlx5_roce7", - "--moe-a2a-backend", - "deepep", - "--deepep-mode", - "low_latency", - "--chunked-prefill-size", - "512", - "--cuda-graph-max-bs", - "128", - "--max-running-requests", - "128", - ], - ) - - @classmethod - def tearDownClass(cls): - kill_process_tree(cls.process.pid) - - def test_gsm8k(self): - args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), - ) - metrics = run_eval_few_shot_gsm8k(args) - print(metrics) + extra_args = [] - self.assertGreater(metrics["accuracy"], 0.60) - - -class TestNoGatherdBuffer(CustomTestCase): @classmethod def setUpClass(cls): cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA @@ -183,16 +30,10 @@ def setUpClass(cls): "--trust-remote-code", "--tp", "4", - "--enable-dp-attention", - "--dp", - "4", - "--moe-dense-tp-size", - "1", - 
"--enable-dp-lm-head", "--elastic-ep-backend", "mooncake", "--mooncake-ib-device", - "mlx5_roce0,mlx5_roce1,mlx5_roce2,mlx5_roce3,mlx5_roce4,mlx5_roce5,mlx5_roce6,mlx5_roce7", + ib_devices, "--moe-a2a-backend", "deepep", "--deepep-mode", @@ -200,9 +41,12 @@ def setUpClass(cls): "--chunked-prefill-size", "512", "--cuda-graph-max-bs", - "32", + "128", "--max-running-requests", "512", + "--mem-fraction-static", + "0.5", + *cls.extra_args, ], ) @@ -226,60 +70,73 @@ def test_gsm8k(self): self.assertGreater(metrics["accuracy"], 0.60) -class TestTBO(CustomTestCase): - @classmethod - def setUpClass(cls): - cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA - cls.base_url = DEFAULT_URL_FOR_TEST - cls.process = popen_launch_server( - cls.model, - cls.base_url, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=[ - "--trust-remote-code", - "--tp", - "4", - "--enable-dp-attention", - "--dp", - "4", - "--moe-dense-tp-size", - "1", - "--elastic-ep-backend", - "mooncake", - "--mooncake-ib-device", - "mlx5_roce0,mlx5_roce1,mlx5_roce2,mlx5_roce3,mlx5_roce4,mlx5_roce5,mlx5_roce6,mlx5_roce7", - "--moe-a2a-backend", - "deepep", - "--deepep-mode", - "low_latency", - "--chunked-prefill-size", - "512", - "--enable-two-batch-overlap", - "--cuda-graph-max-bs", - "128", - "--max-running-requests", - "512", - ], - ) - - @classmethod - def tearDownClass(cls): - kill_process_tree(cls.process.pid) - - def test_gsm8k(self): - args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), - ) - metrics = run_eval_few_shot_gsm8k(args) - print(metrics) - - self.assertGreater(metrics["accuracy"], 0.60) +class TestPureDP(TestTP): + extra_args = [ + "--tp", + "4", + "--enable-dp-attention", + "--dp", + "4", + ] + + +class TestHybridDPTP(TestTP): + extra_args = [ + "--tp", + "4", + "--enable-dp-attention", + "--dp", + "2", + ] + + +class TestNoGatherdBuffer(TestTP): + extra_args = [ + "--tp", + "4", + "--enable-dp-attention", + "--dp", + "4", + "--moe-dense-tp-size", + "1", + ] + + +class TestTBO(TestTP): + extra_args = [ + "--tp", + "4", + "--enable-dp-attention", + "--dp", + "4", + "--moe-dense-tp-size", + "1", + "--enable-two-batch-overlap", + ] + + +class TestMooncakeWitchEPLB(TestTP): + extra_args = [ + "--tp", + "4", + "--enable-dp-attention", + "--dp", + "4", + "--moe-dense-tp-size", + "1", + "--enable-two-batch-overlap", + "--enable-eplb", + "--ep-num-redundant-experts", + "4", + "--eplb-rebalance-num-iterations", + "50", + "--expert-distribution-recorder-buffer-size", + "50", + "--expert-distribution-recorder-mode", + "stat", + "--ep-dispatch-algorithm", + "static", + ] if __name__ == "__main__":