-
Notifications
You must be signed in to change notification settings - Fork 5.4k
[Feature] Integrate NIXL-EP into SGLang #17605
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Closed
Changes from all commits
Commits
Show all changes
28 commits
Select commit
Hold shift + click to select a range
8ae4347
pr2 eplb
HanHan009527 f01ba58
Let `token_dispatcher/mooncake.py` use the `global_elastic_ep_metadat…
UNIDY2002 cd65b69
fix
HanHan009527 3e38276
some fix
HanHan009527 484058c
fxi
HanHan009527 cb41b43
lint
HanHan009527 560c595
fix
HanHan009527 e6875fc
test
HanHan009527 63d7b0e
add ut
HanHan009527 8d8aca9
test
HanHan009527 6d36b5b
test
HanHan009527 56fb09c
fix
HanHan009527 2434821
fix
HanHan009527 8c0e187
fix
HanHan009527 37eeaab
fix
HanHan009527 2200898
lint
HanHan009527 9808c8d
fix
HanHan009527 cb54875
fix
HanHan009527 7b1bd4e
test
HanHan009527 2606322
t
HanHan009527 9a99351
ut
HanHan009527 06563c0
Merge branch 'main' into mooncake-pr-eplb
HanHan009527 fd8cc23
Merge branch 'main' into mooncake-pr-eplb
ShangmingCai 7b06878
Merge branch 'main' into mooncake-pr-eplb
HanHan009527 4da41cd
Merge branch 'main' into mooncake-pr-eplb
HanHan009527 f6e86b5
support deepep elastic
96653d3
[NIXL] Support NIXL EP MoE a2a backend
BBiber1 ce83ac7
Add support for tcp_store
zackyoray File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,108 @@ | ||
| from __future__ import annotations | ||
|
|
||
| import threading | ||
| from abc import ABC, abstractmethod | ||
| from dataclasses import dataclass | ||
| from typing import Any, Callable, Dict, Optional, Union | ||
|
|
||
| import torch | ||
|
|
||
| from sglang.srt.managers.schedule_batch import ServerArgs | ||
| from sglang.srt.utils import is_cpu, is_cuda | ||
|
|
||
|
|
||
@dataclass
class ElasticEPState:
    """Per-process snapshot of which expert-parallel (EP) ranks are active.

    Tensors may be ``None`` until the state is initialized by
    ``ElasticEPStateManager``.
    """

    # Integer mask of currently-active ranks, kept on the compute device.
    _active_ranks: Optional[torch.Tensor]
    # Copy of ``_active_ranks`` taken at the previous snapshot.
    _last_active_ranks: Optional[torch.Tensor]
    # CPU-side mirror of ``_active_ranks``.
    _active_ranks_cpu: Optional[torch.Tensor]
    # Backend-specific hook that refreshes ``_active_ranks`` from ``rank_status``.
    on_forward: Optional[Callable] = None
    # Raw per-rank status tensor reported by the communication backend.
    rank_status: Optional[torch.Tensor] = None

    def is_active_equal_last(self) -> bool:
        """Return True iff the active-rank mask is unchanged since the last snapshot."""
        # Fix: the original called torch.equal unconditionally and raised
        # TypeError when either tensor was still None; the sibling methods
        # below already guard against None, so be consistent here.
        if self._active_ranks is None or self._last_active_ranks is None:
            # Equal only when both are still unset.
            return self._active_ranks is self._last_active_ranks
        return torch.equal(self._active_ranks, self._last_active_ranks)

    def sync_active_to_cpu(self):
        """Refresh the CPU mirror from the device-side active mask (no-op if unset)."""
        if self._active_ranks is not None:
            self._active_ranks_cpu = self._active_ranks.detach().cpu().clone()

    def snapshot_active_to_last(self):
        """Record the current active mask as the 'last seen' mask (no-op if unset)."""
        if self._active_ranks is not None:
            self._last_active_ranks = self._active_ranks.clone()
|
||
|
|
||
class ElasticEPStateManager:
    """Process-wide singleton holder for this process's :class:`ElasticEPState`.

    ``init`` builds the singleton once (guarded by a lock); ``instance``
    returns it (or ``None`` if elastic EP is not enabled).
    """

    _instance: Optional[ElasticEPState] = None
    _lock = threading.Lock()

    @staticmethod
    def on_forward_mooncake(
        state: ElasticEPState, status: Optional[torch.Tensor] = None, **kwargs
    ):
        """Mooncake backend hook: ``rank_status`` directly encodes active ranks."""
        state._active_ranks = state.rank_status.to(dtype=torch.int32)

    @staticmethod
    def on_forward_deepep(
        state: ElasticEPState, status: Optional[torch.Tensor] = None, **kwargs
    ):
        """DeepEP backend hook: ``rank_status`` marks failed ranks, so invert it."""
        state._active_ranks = 1 - state.rank_status.to(torch.int32)

    @classmethod
    def instance(cls) -> Optional[ElasticEPState]:
        """Return the singleton state, or ``None`` if :meth:`init` never built one."""
        # Fixed annotation: before init() (or with no elastic backend) this is None.
        return cls._instance

    @classmethod
    def init(cls, server_args: ServerArgs) -> Optional[ElasticEPState]:
        """Build the singleton state once, based on ``server_args.elastic_ep_backend``.

        Returns the existing instance on repeated calls, and ``None`` when no
        elastic EP backend is configured.
        """
        with cls._lock:
            if cls._instance is not None:
                return cls._instance

            if server_args.elastic_ep_backend is not None:
                cls._instance = cls._build_state(
                    ep_size=None,
                    device=None,
                    backend_type=server_args.elastic_ep_backend,
                )
            # Explicit return instead of the original implicit None fall-through
            # when no backend is configured; value is identical either way.
            return cls._instance

    @staticmethod
    def _select_device() -> torch.device:
        """Pick the device for rank-state tensors (CUDA preferred, else CPU)."""
        if is_cuda():
            return torch.device("cuda")
        elif is_cpu():
            return torch.device("cpu")
        else:
            raise NotImplementedError("Only CUDA and CPU support elastic ep now.")

    @classmethod
    def _build_state(
        cls,
        *,
        ep_size: Optional[int],
        device: Optional[torch.device],
        backend_type: str = "none",
    ) -> ElasticEPState:
        """Create a fresh state with all ranks marked active.

        ``ep_size``/``device`` default to the distributed world size and the
        auto-selected device when ``None``.
        """
        active = cls.create_rank_state(ep_size=ep_size, device=device, value=1)

        # Pick the backend-specific refresh hook; unknown backends get no hook.
        if backend_type == "mooncake":
            on_forward = cls.on_forward_mooncake
        elif backend_type == "deepep":
            on_forward = cls.on_forward_deepep
        else:
            on_forward = None

        return ElasticEPState(
            _active_ranks=active,
            _last_active_ranks=active.clone(),
            _active_ranks_cpu=active.detach().cpu().clone(),
            rank_status=active.clone(),
            on_forward=on_forward,
        )

    @classmethod
    def create_rank_state(
        cls, *, ep_size: Optional[int], device: Optional[torch.device], value: int = 1
    ) -> torch.Tensor:
        """Return an int32 tensor of shape ``(ep_size,)`` filled with ``value``."""
        size = ep_size if ep_size is not None else torch.distributed.get_world_size()
        dev = device if device is not None else cls._select_device()

        return torch.full((size,), value, dtype=torch.int32, device=dev)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
87 changes: 87 additions & 0 deletions
87
python/sglang/srt/eplb/eplb_algorithms/elasticity_aware.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,87 @@ | ||
| from typing import Tuple | ||
|
|
||
| import torch | ||
|
|
||
| from sglang.srt.eplb.eplb_algorithms.deepseek import rebalance_experts_hierarchical | ||
|
|
||
|
|
||
def rebalance_experts(
    weight: torch.Tensor,
    num_replicas: int,
    num_groups: int,
    num_nodes: int,
    num_gpus: int,
    enable_hierarchical: bool,
    active_ranks: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Entry point for expert-parallelism load balancer.

    Parameters:
        weight: [layers, num_logical_experts], the load statistics for all logical experts
        num_replicas: number of physical experts, must be a multiple of `num_gpus`
        num_groups: number of expert groups
        num_nodes: number of server nodes, where the intra-node network (e.g, NVLink) is faster
        num_gpus: number of GPUs, must be a multiple of `num_nodes`
        enable_hierarchical: if True (and all ranks are active), use the
            hierarchical (group/node-aware) balancing policy instead of the
            global one
        active_ranks: 0/1 tensor marking which EP ranks are currently active;
            when some ranks are inactive, experts are balanced over active
            ranks only and the maps are padded back to the full rank layout

    Returns:
        physical_to_logical_map: [layers, num_replicas], the expert index of each replica
        logical_to_physical_map: [layers, num_logical_experts, X], the replica indices for each expert
        expert_count: [layers, num_logical_experts], number of physical replicas for each logical expert
    """

    num_layers, num_logical_experts = weight.shape
    # Balancing runs on CPU floats regardless of the input device/dtype.
    weight = weight.float().cpu()
    num_active_ranks = active_ranks.sum().item()
    num_local_experts = num_replicas // num_gpus
    if num_active_ranks < num_gpus:
        # Must fall back to global load-balance policy
        # and fix some params
        # (groups/nodes collapse to 1; replicas shrink to what the active
        # ranks can hold: num_local_experts per active rank).
        phy2log, phyrank, logcnt = rebalance_experts_hierarchical(
            weight,
            num_local_experts * num_active_ranks,
            1,
            1,
            num_active_ranks,
        )
    elif enable_hierarchical:
        # use hierarchical load-balance policy
        phy2log, phyrank, logcnt = rebalance_experts_hierarchical(
            weight, num_replicas, num_groups, num_nodes, num_gpus
        )
    else:
        # use global load-balance policy
        phy2log, phyrank, logcnt = rebalance_experts_hierarchical(
            weight, num_replicas, 1, 1, num_gpus
        )
    # Invert phy2log into log2phy: for each logical expert, list the physical
    # replica indices that host it, padded with -1 up to the max replica count.
    maxlogcnt = logcnt.max().item()
    log2phy: torch.Tensor = torch.full(
        (num_layers, num_logical_experts, maxlogcnt),
        -1,
        dtype=torch.int64,
        device=logcnt.device,
    )
    # Scatter physical indices: slot (logical_expert, replica_rank) receives
    # the physical index. The arange length matches the number of physical
    # replicas actually produced above (shrunk when ranks are inactive).
    log2phy.view(num_layers, -1).scatter_(
        -1,
        phy2log * maxlogcnt + phyrank,
        torch.arange(
            num_local_experts * num_active_ranks,
            dtype=torch.int64,
            device=log2phy.device,
        ).expand(num_layers, -1),
    )
    if num_active_ranks < num_gpus:
        # Pad the compacted maps back to the full num_gpus rank layout:
        # insert a zero placeholder slice at each inactive rank's position,
        # and shift every physical index at/after that position in log2phy
        # by one rank's worth of experts to keep indices consistent.
        phy2log_slices = list(
            phy2log.view(num_layers, num_active_ranks, -1).unbind(dim=1)
        )
        active_ranks_list = active_ranks.tolist()
        for idx, active_rank in enumerate(active_ranks_list):
            if not active_rank:
                phy2log_slices.insert(idx, torch.zeros_like(phy2log_slices[0]))
                log2phy = torch.where(
                    log2phy >= idx * num_local_experts,
                    log2phy + num_local_experts,
                    log2phy,
                )
        phy2log = torch.stack(phy2log_slices, dim=1).contiguous().view(num_layers, -1)
    return phy2log, log2phy, logcnt
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The method `ElasticEPStateManager.healthy_rank_state()` does not exist. This will raise an `AttributeError` if the `elasticity_aware` algorithm is used when `ElasticEPStateManager` is not initialized (i.e., no elastic EP backend is active). To prevent this crash, you should provide a fallback that creates a tensor indicating all ranks are healthy. A tensor of ones with a shape corresponding to the number of GPUs would be a suitable default.