-
Notifications
You must be signed in to change notification settings - Fork 5k
[2/N] Added the core structure of elastic EP and the eplb algorithm with faulty rank #10606
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
8ae4347
f01ba58
cd65b69
3e38276
484058c
cb41b43
560c595
e6875fc
63d7b0e
8d8aca9
6d36b5b
56fb09c
2434821
8c0e187
37eeaab
2200898
9808c8d
cb54875
7b1bd4e
2606322
9a99351
06563c0
fd8cc23
7b06878
4da41cd
687ba59
d66c884
fbc874e
47cb4ad
9168004
8fdffde
a133b2d
5b10b90
2cd15a9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,74 @@ | ||
| from __future__ import annotations | ||
|
|
||
| from dataclasses import dataclass | ||
| from typing import Optional | ||
|
|
||
| import torch | ||
|
|
||
| from sglang.srt.managers.schedule_batch import ServerArgs | ||
| from sglang.srt.utils import is_cpu, is_cuda | ||
|
|
||
|
|
||
@dataclass
class ElasticEPState:
    """Tracks which expert-parallel (EP) ranks are currently active.

    Attributes:
        active_ranks: 1-D int tensor of per-rank liveness flags (1 = active),
            or None before initialization.
        last_active_ranks: snapshot of ``active_ranks`` taken at the previous
            check; compared against the live tensor to detect membership changes.
        active_ranks_cpu: CPU mirror of ``active_ranks`` for host-side reads.
    """

    active_ranks: Optional[torch.Tensor]
    last_active_ranks: Optional[torch.Tensor]
    active_ranks_cpu: Optional[torch.Tensor]

    def is_active_equal_last(self) -> bool:
        """Return True when the active-rank set is unchanged since the last snapshot.

        The fields are Optional, and the sibling helpers below guard against
        None — do the same here so an uninitialized state cannot raise a
        TypeError inside ``torch.equal``. Both-None counts as "unchanged";
        exactly one None counts as "changed".
        """
        if self.active_ranks is None or self.last_active_ranks is None:
            return self.active_ranks is None and self.last_active_ranks is None
        return torch.equal(self.active_ranks, self.last_active_ranks)

    def sync_active_to_cpu(self):
        """Refresh the CPU mirror of ``active_ranks`` (no-op when unset)."""
        if self.active_ranks is not None:
            self.active_ranks_cpu = self.active_ranks.detach().cpu().clone()

    def snapshot_active_to_last(self):
        """Record the current ``active_ranks`` as the comparison baseline (no-op when unset)."""
        if self.active_ranks is not None:
            self.last_active_ranks = self.active_ranks.clone()
|
|
||
|
|
||
class ElasticEPStateManager:
    """Process-wide singleton holder for the :class:`ElasticEPState`."""

    _instance: Optional[ElasticEPState] = None

    @classmethod
    def instance(cls) -> ElasticEPState:
        """Return the singleton state, or None when elastic EP was never enabled."""
        return cls._instance

    @classmethod
    def init(cls, server_args: ServerArgs):
        """Create the singleton state once; subsequent calls return the existing one.

        The state is only built when an elastic-EP backend is configured;
        otherwise this returns None.
        """
        if cls._instance is None and server_args.elastic_ep_backend is not None:
            cls._instance = cls._build_state(ep_size=None, device=None)
        return cls._instance

    @staticmethod
    def _select_device() -> torch.device:
        """Pick the device for rank-state tensors; only CUDA and CPU are supported."""
        if is_cuda():
            return torch.device("cuda")
        if is_cpu():
            return torch.device("cpu")
        raise NotImplementedError("Only CUDA and CPU support elastic ep now.")

    @classmethod
    def _build_state(
        cls, *, ep_size: Optional[int] = None, device: Optional[torch.device] = None
    ) -> ElasticEPState:
        """Build a fresh state in which every rank is marked healthy."""
        healthy = cls.healthy_rank_state(ep_size=ep_size, device=device)
        return ElasticEPState(
            active_ranks=healthy,
            last_active_ranks=healthy.clone(),
            active_ranks_cpu=healthy.detach().cpu().clone(),
        )

    @classmethod
    def healthy_rank_state(
        cls, *, ep_size: Optional[int] = None, device: Optional[torch.device] = None
    ) -> torch.Tensor:
        """Return an all-ones int32 liveness vector (1 = rank active).

        Defaults: ``ep_size`` falls back to the distributed world size, and
        ``device`` to the auto-selected compute device.
        """
        rank_count = torch.distributed.get_world_size() if ep_size is None else ep_size
        target = cls._select_device() if device is None else device
        return torch.ones(rank_count, dtype=torch.int32, device=target)
|
Comment on lines
+67
to
+74
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should this dtype be changed to
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Mooncake EP currently uses int32. BTW, what does "future usage" refer to?
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @UNIDY2002 Just wonder if it should align with
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I had a quick verification. |
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,87 @@ | ||
| from typing import Tuple | ||
|
|
||
| import torch | ||
|
|
||
| from sglang.srt.eplb.eplb_algorithms.deepseek import rebalance_experts_hierarchical | ||
|
|
||
|
|
||
def rebalance_experts(
    weight: torch.Tensor,
    num_replicas: int,
    num_groups: int,
    num_nodes: int,
    num_gpus: int,
    enable_hierarchical: bool,
    active_ranks: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Entry point for expert-parallelism load balancer, tolerant of faulty ranks.

    Parameters:
        weight: [layers, num_logical_experts], the load statistics for all logical experts
        num_replicas: number of physical experts, must be a multiple of `num_gpus`
        num_groups: number of expert groups
        num_nodes: number of server nodes, where the intra-node network (e.g, NVLink) is faster
        num_gpus: number of GPUs, must be a multiple of `num_nodes`
        enable_hierarchical: use the hierarchical (node-aware) balancing policy;
            ignored when some ranks are inactive, which forces the global policy
        active_ranks: [num_gpus], per-rank liveness flags (nonzero = active);
            inactive ranks are assigned no experts

    Returns:
        physical_to_logical_map: [layers, num_replicas], the expert index of each replica
        logical_to_physical_map: [layers, num_logical_experts, X], the replica indices for each expert
        expert_count: [layers, num_logical_experts], number of physical replicas for each logical expert
    """

    num_layers, num_logical_experts = weight.shape
    # Balancing is done on CPU in float regardless of the caller's device/dtype.
    weight = weight.float().cpu()
    num_active_ranks = active_ranks.sum().item()
    num_local_experts = num_replicas // num_gpus
    if num_active_ranks < num_gpus:
        # Must fall back to global load-balance policy
        # and fix some params: balance only across the surviving ranks,
        # keeping the per-rank expert count (num_local_experts) unchanged.
        phy2log, phyrank, logcnt = rebalance_experts_hierarchical(
            weight,
            num_local_experts * num_active_ranks,
            1,
            1,
            num_active_ranks,
        )
    elif enable_hierarchical:
        # use hierarchical load-balance policy
        phy2log, phyrank, logcnt = rebalance_experts_hierarchical(
            weight, num_replicas, num_groups, num_nodes, num_gpus
        )
    else:
        # use global load-balance policy
        phy2log, phyrank, logcnt = rebalance_experts_hierarchical(
            weight, num_replicas, 1, 1, num_gpus
        )
    maxlogcnt = logcnt.max().item()
    log2phy: torch.Tensor = torch.full(
        (num_layers, num_logical_experts, maxlogcnt),
        -1,
        dtype=torch.int64,
        device=logcnt.device,
    )
    # Invert phy2log: scatter each physical index into its
    # (logical expert, replica rank) slot; unfilled slots keep -1.
    log2phy.view(num_layers, -1).scatter_(
        -1,
        phy2log * maxlogcnt + phyrank,
        torch.arange(
            num_local_experts * num_active_ranks,
            dtype=torch.int64,
            device=log2phy.device,
        ).expand(num_layers, -1),
    )
    if num_active_ranks < num_gpus:
        # Re-expand the compacted (active-ranks-only) layout back onto the full
        # num_gpus grid: insert a placeholder all-zero slice for each inactive
        # rank, and shift every physical index at or past the insertion point
        # so log2phy keeps pointing at the same experts.
        phy2log_slices = list(
            phy2log.view(num_layers, num_active_ranks, -1).unbind(dim=1)
        )
        active_ranks_list = active_ranks.tolist()
        for idx, active_rank in enumerate(active_ranks_list):
            if not active_rank:
                phy2log_slices.insert(idx, torch.zeros_like(phy2log_slices[0]))
                log2phy = torch.where(
                    log2phy >= idx * num_local_experts,
                    log2phy + num_local_experts,
                    log2phy,
                )
        phy2log = torch.stack(phy2log_slices, dim=1).contiguous().view(num_layers, -1)
    return phy2log, log2phy, logcnt
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: maybe give `ep_size` and `device` a default value: `None`

There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done