[Feature] Expert Parallelism Load Balancer (EPLB) #18343
Merged
Changes from 34 commits.

Commits (67, all authored by abmfy):
- 8fe6f82 [Feature] Core EPLB algorithm
- bdda8dc [Feature] Register expert weights for DeepSeek MoE
- 43d52ac [Chore] Rename EPLB rebalance algo module name
- 58bf9fd [Feature] Store EPLB states in model runner
- 52b141f [Feature] EPLB rearrangement execution
- 98312d3 [Feature] Add expert load metrics collection during forward
- 22a963d [Feature] Rearrange experts after a preset step interval
- f88d836 Merge branch 'main' into eplb
- 43ac672 [Feature] Use unified `FusedMoE` in DeepSeek-V3/R1
- f7ba162 [Bugfix] Copy expert mappings after rearrangement
- ba3d60f [Chore] Move implementations to `deepseek_v2.py`
- ebcfcc7 [Chore] Remove expert load stats from forward context
- 620f59a [Feature] Weight loading for redundant experts
- 90f3ed5 [Feature] Expert replica selection and load metrics recording
- b3697de [Feature] Map logical experts in weight loading
- 5d85f61 [Bugfix] Use `scatter_add_` instead of `bincount` for compile
- e416e3c [Bugfix] Add EPLB args in `EngineArgs`
- 233741c [Bugfix] Sum up steps on EPLB rearrange
- cfcd42c [Bugfix] Collect expert weights into a list
- 36b0b11 [Bugfix] Fix typo in assertion
- d5add3a [Bugfix] Pad `log2phy` mapping in rebalance algo
- b00bdb9 [Bugfix] Fix EP group in `DeepseekV2MoE`
- c9cf2d4 [Refactor] Use local physical ids in expert load collection
- 4f79fef [Bugfix] Map physical id before recording expert load metrics
- a97ee39 [Perf] Reduce overhead of expert load recording
- 0c9340d Merge branch 'main' into eplb
- 2b14d51 [Bugfix] Step EPLB state in dummy run to avoid blocking DP
- 306b21a [Feature] Do not record expert loads for dummy batches
- 021578e [Bugfix] Collect expert weights after weight post-processing
- c2e0516 [Bugfix] Fix weight loading of replica experts
- 0071b24 Merge branch 'main' into eplb
- 38f9218 Merge branch 'main' into eplb
- 79c0d41 [Bugfix] Remove `e_score_correction_bias` in expert weights
- b011065 [Bugfix] Fix shapes and dtypes in `FusedMoE`
- 82a6299 Merge branch 'main' into eplb
- 90706aa [Feature] Disable EPLB step during profile run
- f1f62b2 [Bugfix] Synchronize CUDA before shuffling layer to avoid hang
- 332a4d6 Merge branch 'main' into eplb
- 90d23ec Merge branch 'eplb-graph' into eplb
- 993d7d7 [Style] Rename module `eplb.states` to `eplb.eplb_state`
- 90afdaf [Feature] Run a dummy rearrangement during profile run for CUDA graphs
- 7774e0a Merge branch 'eplb-graph' into eplb
- f5d171f [Feature] Constrain EPLB to main models
- aaa66a2 [Refactor] Move `EplbState` in model runner out of classvars
- 934bbf0 Merge branch 'main' into eplb
- 4e346be [Style] Rename `--num-extra-experts` to `--num-redundant-experts`
- 2496a54 [Doc] Add glossary for different types of experts
- 9916913 [Doc] Add statements in `EplbState` that some vars are config only
- 420cb99 [Doc] Add notes on synchronization of rearrangement step
- ff368a1 [Doc] Add examples for expert mappings
- 425d56c [Doc] Add explanation of why the last layer is picked for MoE config
- 76fbdf8 [Refactor] Revert `fused_moe.py` since it is not used
- 6777877 [Doc] Add explanations for calling points of `_dummy_run`
- 12401b1 [Doc] Add comments on when real communication happens
- 80b3a1b [Doc] Add comments that only the last `eplb_window_size` steps are used
- 3ea6f2c [Feature] Disable balancedness logging by default
- aff7991 [Style] Rename shadowed variables to make linter happy
- 8ac089e [Style] Add parameters of `apply` for subclasses of `FusedMoEMethodBase`
- a6a4a3a [Test] Add test for EPLB algorithm
- 1ed45b2 [Test] Add test for EPLB execution
- 4eeb0ff [Style] Split some long lines
- 0c177d0 Merge branch 'main' into eplb
- 5b1e354 [Feature] Use `get_node_count` and remove magic number
- 495f782 [Test] Disable `first_k_dense_replace` in `test_initialization`
- 66fe93f [Test] Use only 2 experts in `test_initialization`
- 3ec9032 [Test] Get at least `n_group` experts in `test_initialization`
- c479d2c [Test] Allow 2 experts per group in `test_initialization`
The diff adds two new files. The first (4 added lines) is the EPLB package `__init__.py`, which re-exports the submodules:

```python
# SPDX-License-Identifier: Apache-2.0

from .rebalance_algo import *
from .states import *
```
The second new file (230 added lines) is the rebalance algorithm module itself, `rebalance_algo.py`. Module docstring and the first helper, `balanced_packing`:

```python
# SPDX-License-Identifier: Apache-2.0
"""
Expert parallelism load balancer (EPLB) for vLLM.

This module implements the core rearrangement algorithm.

The rearrangement algorithm is adapted from
[DeepSeek EPLB](https://github.com/deepseek-ai/eplb).
"""

import torch


def balanced_packing(weight: torch.Tensor,
                     num_packs: int) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Pack n weighted objects to m packs, such that each pack contains exactly
    n/m objects and the weights of all packs are as balanced as possible.

    Parameters:
        weight: [X, n], the weight of each item
        num_packs: number of packs

    Returns:
        pack_index: [X, n], the pack index of each item
        rank_in_pack: [X, n], the rank of the item in the pack
    """
    num_layers, num_groups = weight.shape
    assert num_groups % num_packs == 0
    groups_per_pack = num_groups // num_packs

    if groups_per_pack == 1:
        pack_index = torch.arange(weight.size(-1),
                                  dtype=torch.int64,
                                  device=weight.device).expand(weight.shape)
        rank_in_pack = torch.zeros_like(weight, dtype=torch.int64)
        return pack_index, rank_in_pack

    # Greedy longest-processing-time packing: visit items in descending
    # weight order and place each into the lightest pack with room left.
    indices = weight.float().sort(-1, descending=True).indices.cpu()
    pack_index = torch.full_like(weight,
                                 fill_value=-1,
                                 dtype=torch.int64,
                                 device="cpu")
    rank_in_pack = torch.full_like(pack_index, fill_value=-1)
    for i in range(num_layers):
        pack_weights = [0] * num_packs
        pack_items = [0] * num_packs
        for group in indices[i]:
            pack = min(
                (idx for idx in range(num_packs)
                 if pack_items[idx] < groups_per_pack),
                key=pack_weights.__getitem__,
            )
            assert pack_items[pack] < groups_per_pack
            pack_index[i, group] = pack
            rank_in_pack[i, group] = pack_items[pack]
            pack_weights[pack] += weight[i, group]
            pack_items[pack] += 1
    return pack_index, rank_in_pack
```
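A quick illustrative check of the greedy packing (a minimal sketch, not part of the diff; assumes `balanced_packing` is in scope):

```python
import torch

# Hypothetical import path; adjust to wherever this module lives:
# from vllm.distributed.eplb.rebalance_algo import balanced_packing

# One layer, four groups with loads 30/10/25/15, packed into two packs.
weight = torch.tensor([[30.0, 10.0, 25.0, 15.0]])
pack_index, rank_in_pack = balanced_packing(weight, num_packs=2)

# Greedy order is by descending load: 30 -> pack 0, 25 -> pack 1,
# 15 -> pack 1 (still lighter), then 10 -> pack 0 (pack 1 is full).
assert pack_index.tolist() == [[0, 0, 1, 1]]
assert rank_in_pack.tolist() == [[0, 1, 0, 1]]
# Both packs end up with a total load of 40.
```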
Next, `replicate_experts` assigns the redundant physical slots:

```python
def replicate_experts(
        weight: torch.Tensor,
        num_phy: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Replicate `num_log` experts to `num_phy` replicas, such that the maximum
    load of all replicas is minimized.

    Parameters:
        weight: [X, num_log]
        num_phy: total number of experts after replication

    Returns:
        phy2log: [X, num_phy], logical expert id of each physical expert
        rank: [X, num_phy], the replica rank
        logcnt: [X, num_log], number of replicas for each logical expert
    """
    n, num_log = weight.shape
    num_redundant = num_phy - num_log
    assert num_redundant >= 0
    device = weight.device
    phy2log = torch.arange(num_phy, dtype=torch.int64,
                           device=device).repeat(n, 1)
    rank = torch.zeros(n, num_phy, dtype=torch.int64, device=device)
    logcnt = torch.ones(n, num_log, dtype=torch.int64, device=device)
    arangen = torch.arange(n, dtype=torch.int64, device=device)
    # Greedily hand each redundant slot to the expert with the highest
    # per-replica load, then bump that expert's replica count.
    for i in range(num_log, num_phy):
        redundant_indices = (weight / logcnt).max(dim=-1).indices
        phy2log[:, i] = redundant_indices
        rank[:, i] = logcnt[arangen, redundant_indices]
        logcnt[arangen, redundant_indices] += 1
    return phy2log, rank, logcnt
```
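A minimal sketch of how the replication greedy behaves (illustrative, not part of the diff):

```python
import torch

# One layer, four logical experts with loads 90/30/20/20, two redundant slots.
weight = torch.tensor([[90.0, 30.0, 20.0, 20.0]])
phy2log, rank, logcnt = replicate_experts(weight, num_phy=6)

# Both redundant slots go to expert 0: its per-replica load falls
# 90 -> 45 -> 30, after which no replica carries more than 30.
assert phy2log.tolist() == [[0, 1, 2, 3, 0, 0]]
assert rank.tolist() == [[0, 0, 0, 0, 1, 2]]
assert logcnt.tolist() == [[3, 1, 1, 1]]
```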
The hierarchical policy packs expert groups onto nodes, replicates within each node, then packs physical experts onto GPUs:

```python
def rebalance_experts_hierarchical(
    weight: torch.Tensor,
    num_physical_experts: int,
    num_groups: int,
    num_nodes: int,
    num_gpus: int,
):
    """
    Parameters:
        weight: [num_moe_layers, num_logical_experts]
        num_physical_experts: number of physical experts after replication
        num_groups: number of expert groups
        num_nodes: number of server nodes, where the intra-node network
            (e.g., NVLink) is faster
        num_gpus: number of GPUs, must be a multiple of `num_nodes`

    Returns:
        physical_to_logical_map: [num_moe_layers, num_physical_experts]
        logical_to_physical_map: [num_moe_layers, num_logical_experts, X]
        logical_count: [num_moe_layers, num_logical_experts]
    """
    num_layers, num_logical_experts = weight.shape
    assert num_logical_experts % num_groups == 0
    group_size = num_logical_experts // num_groups
    assert num_groups % num_nodes == 0
    groups_per_node = num_groups // num_nodes
    assert num_gpus % num_nodes == 0
    assert num_physical_experts % num_gpus == 0
    phy_experts_per_gpu = num_physical_experts // num_gpus

    def inverse(perm: torch.Tensor) -> torch.Tensor:
        inv = torch.empty_like(perm)
        inv.scatter_(
            1,
            perm,
            torch.arange(perm.size(1), dtype=torch.int64,
                         device=perm.device).expand(perm.shape),
        )
        return inv

    # Step 1: pack groups to nodes
    tokens_per_group = weight.unflatten(-1, (num_groups, group_size)).sum(-1)
    group_pack_index, group_rank_in_pack = balanced_packing(
        tokens_per_group, num_nodes)
    log2mlog = (((group_pack_index * groups_per_node + group_rank_in_pack) *
                 group_size).unsqueeze(-1) +
                torch.arange(group_size,
                             dtype=torch.int64,
                             device=group_pack_index.device)).flatten(-2)
    mlog2log = inverse(log2mlog)

    # Step 2: construct redundant experts within nodes
    # [num_layers * num_nodes, num_logical_experts // num_nodes]
    tokens_per_mlog = weight.gather(-1, mlog2log).view(
        -1, num_logical_experts // num_nodes)
    phy2mlog, phyrank, mlogcnt = replicate_experts(
        tokens_per_mlog, num_physical_experts // num_nodes)

    # Step 3: pack physical experts to GPUs
    # [num_layers * num_nodes, num_physical_experts // num_nodes]
    tokens_per_phy = (tokens_per_mlog / mlogcnt).gather(-1, phy2mlog)
    pack_index, rank_in_pack = balanced_packing(tokens_per_phy,
                                                num_gpus // num_nodes)
    phy2pphy = pack_index * phy_experts_per_gpu + rank_in_pack
    pphy2phy = inverse(phy2pphy)

    pphy2mlog = phy2mlog.gather(
        -1, pphy2phy)  # [num_layers * num_nodes, num_log_per_nodes]
    pphy2mlog = (pphy2mlog.view(num_layers, num_nodes, -1) + torch.arange(
        0,
        num_logical_experts,
        num_logical_experts // num_nodes,
        device=group_pack_index.device,
    ).view(1, -1, 1)).flatten(-2)
    pphy2log = mlog2log.gather(-1, pphy2mlog)
    pphyrank = phyrank.gather(-1, pphy2phy).view(num_layers, -1)
    logcnt = mlogcnt.view(num_layers, -1).gather(-1, log2mlog)
    return pphy2log, pphyrank, logcnt
```
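The divisibility asserts constrain valid topologies. A shape-level sketch with hypothetical sizes that satisfy them (not part of the diff):

```python
import torch

# Hypothetical topology: 2 layers, 8 logical experts in 4 groups,
# 2 nodes with 2 GPUs each, and 12 physical slots (3 per GPU).
weight = torch.rand(2, 8)
phy2log, phyrank, logcnt = rebalance_experts_hierarchical(
    weight, num_physical_experts=12, num_groups=4, num_nodes=2, num_gpus=4)

assert phy2log.shape == (2, 12)  # logical expert id of every physical slot
assert phyrank.shape == (2, 12)  # replica rank of every physical slot
assert logcnt.shape == (2, 8)    # replica count per logical expert
assert (logcnt.sum(-1) == 12).all()  # counts sum to the physical slots
```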
Finally, the public entry point, which falls back to a single global pool when the group count is not divisible by the node count:

```python
def rebalance_experts(
    weight: torch.Tensor,
    num_replicas: int,
    num_groups: int,
    num_nodes: int,
    num_gpus: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Entry point for the expert-parallelism load balancer.

    Parameters:
        weight: [layers, num_logical_experts], the load statistics for all
            logical experts
        num_replicas: number of physical experts, must be a multiple of
            `num_gpus`
        num_groups: number of expert groups
        num_nodes: number of server nodes, where the intra-node network
            (e.g., NVLink) is faster
        num_gpus: number of GPUs, must be a multiple of `num_nodes`

    Returns:
        physical_to_logical_map: [layers, num_replicas], the expert index of
            each replica
        logical_to_physical_map: [layers, num_logical_experts, X], the replica
            indices for each expert
        expert_count: [layers, num_logical_experts], number of physical
            replicas for each logical expert
    """
    num_layers, num_logical_experts = weight.shape
    weight = weight.float().cpu()
    if num_groups % num_nodes == 0:
        # use hierarchical load-balance policy
        phy2log, phyrank, logcnt = rebalance_experts_hierarchical(
            weight, num_replicas, num_groups, num_nodes, num_gpus)
    else:
        # use global load-balance policy
        phy2log, phyrank, logcnt = rebalance_experts_hierarchical(
            weight, num_replicas, 1, 1, num_gpus)
    num_redundant_experts = num_replicas - num_logical_experts
    maxlogcnt = num_redundant_experts + 1
    log2phy: torch.Tensor = torch.full(
        (num_layers, num_logical_experts, maxlogcnt),
        -1,
        dtype=torch.int64,
        device=logcnt.device,
    )
    log2phy.view(num_layers, -1).scatter_(
        -1,
        phy2log * maxlogcnt + phyrank,
        torch.arange(num_replicas, dtype=torch.int64,
                     device=log2phy.device).expand(num_layers, -1),
    )
    return phy2log, log2phy, logcnt


__all__ = ["rebalance_experts"]
```
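A hypothetical end-to-end call through this entry point; `rebalance_experts` casts the loads to float and moves them to CPU internally, so integer counters are fine (a sketch, not part of the diff):

```python
import torch

# Hypothetical load statistics: 2 MoE layers, 4 logical experts,
# balanced onto 8 physical slots over 2 nodes x 2 GPUs in 2 groups.
weight = torch.tensor([
    [90, 30, 20, 20],
    [10, 10, 10, 130],
])
phy2log, log2phy, logcnt = rebalance_experts(
    weight, num_replicas=8, num_groups=2, num_nodes=2, num_gpus=4)

assert phy2log.shape == (2, 8)      # physical slot -> logical expert
assert log2phy.shape == (2, 4, 5)   # up to 1 + num_redundant replicas each
assert logcnt.shape == (2, 4)
assert (logcnt.sum(-1) == 8).all()  # every physical slot is assigned
# Unused replica slots in log2phy stay -1 (the padding added in d5add3a).
```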