Merged
Changes from 34 commits (67 commits total)
8fe6f82
[Feature] Core EPLB algorithm
abmfy May 14, 2025
bdda8dc
[Feature] Register expert weights for DeepSeek MoE
abmfy May 16, 2025
43d52ac
[Chore] Rename EPLB rebalance algo module name
abmfy May 16, 2025
58bf9fd
[Feature] Store EPLB states in model runner
abmfy May 16, 2025
52b141f
[Feature] EPLB rearrangement execution
abmfy May 16, 2025
98312d3
[Feature] Add expert load metrics collection during forward
abmfy May 19, 2025
22a963d
[Feature] Rearrange experts after a preset step interval
abmfy May 19, 2025
f88d836
Merge branch 'main' into eplb
abmfy May 19, 2025
43ac672
[Feature] Use unified `FusedMoE` in DeepSeek-V3/R1
abmfy May 20, 2025
f7ba162
[Bugfix] Copy expert mappings after rearrangement
abmfy May 20, 2025
ba3d60f
[Chore] Move implementations to `deepseek_v2.py`
abmfy May 23, 2025
ebcfcc7
[Chore] Remove expert load stats from forward context
abmfy May 23, 2025
620f59a
[Feature] Weight loading for redundant experts
abmfy May 23, 2025
90f3ed5
[Feature] Expert replica selection and load metrics recording
abmfy May 27, 2025
b3697de
[Feature] Map logical experts in weight loading
abmfy May 27, 2025
5d85f61
[Bugfix] Use `scatter_add_` instead of `bincount` for compile
abmfy May 27, 2025
e416e3c
[Bugfix] Add EPLB args in `EngineArgs`
abmfy May 27, 2025
233741c
[Bugfix] Sum up steps on EPLB rearrange
abmfy May 27, 2025
cfcd42c
[Bugfix] Collect expert weights into a list
abmfy May 27, 2025
36b0b11
[Bugfix] Fix typo in assertion
abmfy May 27, 2025
d5add3a
[Bugfix] Pad `log2phy` mapping in rebalance algo
abmfy May 27, 2025
b00bdb9
[Bugfix] Fix EP group in `DeepseekV2MoE`
abmfy May 27, 2025
c9cf2d4
[Refactor] Use local physical ids in expert load collection
abmfy May 27, 2025
4f79fef
[Bugfix] Map physical id before recording expert load metrics
abmfy May 27, 2025
a97ee39
[Perf] Reduce overhead of expert load recording
abmfy May 28, 2025
0c9340d
Merge branch 'main' into eplb
abmfy May 28, 2025
2b14d51
[Bugfix] Step EPLB state in dummy run to avoid blocking DP
abmfy May 29, 2025
306b21a
[Feature] Do not record expert loads for dummy batches
abmfy May 30, 2025
021578e
[Bugfix] Collect expert weights after weight post-processing
abmfy Jun 2, 2025
c2e0516
[Bugfix] Fix weight loading of replica experts
abmfy Jun 3, 2025
0071b24
Merge branch 'main' into eplb
abmfy Jun 6, 2025
38f9218
Merge branch 'main' into eplb
abmfy Jun 9, 2025
79c0d41
[Bugfix] Remove `e_score_correction_bias` in expert weights
abmfy Jun 9, 2025
b011065
[Bugfix] Fix shapes and dtypes in `FusedMoE`
abmfy Jun 10, 2025
82a6299
Merge branch 'main' into eplb
abmfy Jun 12, 2025
90706aa
[Feature] Disable EPLB step during profile run
abmfy Jun 16, 2025
f1f62b2
[Bugfix] Synchronize CUDA before shuffling layer to avoid hang
abmfy Jun 17, 2025
332a4d6
Merge branch 'main' into eplb
abmfy Jun 18, 2025
90d23ec
Merge branch 'eplb-graph' into eplb
abmfy Jun 19, 2025
993d7d7
[Style] Rename module `eplb.states` to `eplb.eplb_state`
abmfy Jun 19, 2025
90afdaf
[Feature] Run a dummy rearrangement during profile run for CUDA graphs
abmfy Jun 20, 2025
7774e0a
Merge branch 'eplb-graph' into eplb
abmfy Jun 20, 2025
f5d171f
[Feature] Constrain EPLB to main models
abmfy Jun 20, 2025
aaa66a2
[Refactor] Move out `EplbState` in model runner from classvars
abmfy Jun 20, 2025
934bbf0
Merge branch 'main' into eplb
abmfy Jun 23, 2025
4e346be
[Style] Rename `--num-extra-experts` to `--num-redundant-experts`
abmfy Jun 23, 2025
2496a54
[Doc] Add glossary for different types of experts
abmfy Jun 23, 2025
9916913
[Doc] Add statements in `EplbState` that some vars are just config
abmfy Jun 23, 2025
420cb99
[Doc] Add notes on synchronization of rearrangement step
abmfy Jun 23, 2025
ff368a1
[Doc] Add examples for expert mappings
abmfy Jun 23, 2025
425d56c
[Doc] Add explanation on why picking the last layer for MoE config
abmfy Jun 23, 2025
76fbdf8
[Refactor] Revert `fused_moe.py` since not used
abmfy Jun 23, 2025
6777877
[Doc] Add explanations for calling points of `_dummy_run`
abmfy Jun 23, 2025
12401b1
[Doc] Add comments on when real communication happens
abmfy Jun 23, 2025
80b3a1b
[Doc] Add comments that only the last `eplb_window_size` steps will be used
abmfy Jun 23, 2025
3ea6f2c
[Feature] Disable balancedness logging by default
abmfy Jun 23, 2025
aff7991
[Style] Rename shadowed variables to make linter happy
abmfy Jun 24, 2025
8ac089e
[Style] Add parameters of `apply` for subclasses of `FusedMoEMethodBase`
abmfy Jun 24, 2025
a6a4a3a
[Test] Add test for EPLB algo
abmfy Jun 24, 2025
1ed45b2
[Test] Add test for EPLB execute
abmfy Jun 25, 2025
4eeb0ff
[Style] Split some long lines
abmfy Jun 25, 2025
0c177d0
Merge branch 'main' into eplb
abmfy Jun 25, 2025
5b1e354
[Feature] Use `get_node_count` and remove magic number
abmfy Jun 25, 2025
495f782
[Test] Disable `first_k_dense_replace` in `test_initialization`
abmfy Jun 26, 2025
66fe93f
[Test] Use only 2 experts in `test_initialization`
abmfy Jun 26, 2025
3ec9032
[Test] Get at least `n_group` experts in `test_initialization`
abmfy Jun 26, 2025
c479d2c
[Test] Allow 2 experts per group in `test_initialization`
abmfy Jun 26, 2025
23 changes: 23 additions & 0 deletions vllm/config.py
@@ -1760,6 +1760,15 @@ class ParallelConfig:
    """Backend to use for data parallel, either "mp" or "ray"."""
    enable_expert_parallel: bool = False
    """Use expert parallelism instead of tensor parallelism for MoE layers."""
    enable_eplb: bool = False
    """Enable expert parallelism load balancing for MoE layers."""
    num_extra_experts: int = 0
    """Number of redundant experts to use for expert parallelism."""
    eplb_window_size: int = 1000
    """Window size for expert load recording."""
    eplb_step_interval: int = 3000
    """Interval for rearranging experts in expert parallelism."""

    max_parallel_loading_workers: Optional[int] = None
    """Maximum number of parallel loading workers when loading model
    sequentially in multiple batches. To avoid RAM OOM when using tensor
@@ -1909,6 +1918,20 @@ def __post_init__(self) -> None:
                    f"{current_platform.device_type.upper()} backend only "
                    "supports Ray for distributed inference.")

        if self.enable_eplb:
            if not current_platform.is_cuda():
                raise ValueError(
                    "Expert parallelism load balancing is only supported on "
                    "CUDA devices now.")
            if self.num_extra_experts < 0:
                raise ValueError(
                    "num_extra_experts must be non-negative, but got "
                    f"{self.num_extra_experts}.")
        else:
            if self.num_extra_experts != 0:
                raise ValueError(
                    "num_extra_experts should be used with EPLB, but got "
                    f"{self.num_extra_experts} with EPLB disabled.")

        if self.distributed_executor_backend is None and self.world_size > 1:
            # We use multiprocessing by default if world_size fits on the
            # current node and we aren't in a ray placement group.
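To make the new configuration knobs concrete, here is a minimal sketch that constructs a ParallelConfig directly with the fields added above (in practice they are supplied through EngineArgs, per the "[Bugfix] Add EPLB args in `EngineArgs`" commit, and `num_extra_experts` is later renamed to `num_redundant_experts` in this PR). The sketch assumes a CUDA platform, since the `__post_init__` check above rejects EPLB elsewhere.

from vllm.config import ParallelConfig

# Enable EPLB on top of expert parallelism, with 2 redundant physical experts
# per MoE layer; the window and interval values are the defaults from this diff.
parallel_config = ParallelConfig(
    enable_expert_parallel=True,
    enable_eplb=True,
    num_extra_experts=2,
    eplb_window_size=1000,
    eplb_step_interval=3000,
)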
4 changes: 4 additions & 0 deletions vllm/distributed/eplb/__init__.py
@@ -0,0 +1,4 @@
# SPDX-License-Identifier: Apache-2.0

from .rebalance_algo import *
from .states import *
230 changes: 230 additions & 0 deletions vllm/distributed/eplb/rebalance_algo.py
@@ -0,0 +1,230 @@
# SPDX-License-Identifier: Apache-2.0
"""
Expert parallelism load balancer (EPLB) for vLLM.

This module implements the core rearrangement algorithm.

The rearrangement algorithm is adapted from
[DeepSeek EPLB](https://github.com/deepseek-ai/eplb).
"""

import torch


def balanced_packing(weight: torch.Tensor,
                     num_packs: int) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Pack n weighted objects to m packs, such that each pack contains exactly
    n/m objects and the weights of all packs are as balanced as possible.

    Parameters:
        weight: [X, n], the weight of each item
        num_packs: number of packs

    Returns:
        pack_index: [X, n], the pack index of each item
        rank_in_pack: [X, n], the rank of the item in the pack
    """
    num_layers, num_groups = weight.shape
    assert num_groups % num_packs == 0
    groups_per_pack = num_groups // num_packs

    if groups_per_pack == 1:
        pack_index = torch.arange(weight.size(-1),
                                  dtype=torch.int64,
                                  device=weight.device).expand(weight.shape)
        rank_in_pack = torch.zeros_like(weight, dtype=torch.int64)
        return pack_index, rank_in_pack

    indices = weight.float().sort(-1, descending=True).indices.cpu()
    pack_index = torch.full_like(weight,
                                 fill_value=-1,
                                 dtype=torch.int64,
                                 device="cpu")
    rank_in_pack = torch.full_like(pack_index, fill_value=-1)
    for i in range(num_layers):
        pack_weights = [0] * num_packs
        pack_items = [0] * num_packs
        for group in indices[i]:
            pack = min(
                (i
                 for i in range(num_packs) if pack_items[i] < groups_per_pack),
                key=pack_weights.__getitem__,
            )
            assert pack_items[pack] < groups_per_pack
            pack_index[i, group] = pack
            rank_in_pack[i, group] = pack_items[pack]
            pack_weights[pack] += weight[i, group]
            pack_items[pack] += 1
    return pack_index, rank_in_pack
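# Illustrative trace of `balanced_packing` (values checked by hand): for a
# single layer with group loads [10, 30, 20, 40] and num_packs=2, groups are
# assigned heaviest-first to the lightest pack that still has room, giving
# pack_index == [[0, 1, 1, 0]] and rank_in_pack == [[1, 0, 1, 0]], i.e. packs
# {40, 10} and {30, 20}, both with total weight 50.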


def replicate_experts(
        weight: torch.Tensor,
        num_phy: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Replicate `num_log` experts to `num_phy` replicas, such that the maximum
    load of all replicas is minimized.

    Parameters:
        weight: [X, num_log]
        num_phy: total number of experts after replication

    Returns:
        phy2log: [X, num_phy], logical expert id of each physical expert
        rank: [X, num_phy], the replica rank
        logcnt: [X, num_log], number of replicas for each logical expert
    """
    n, num_log = weight.shape
    num_redundant = num_phy - num_log
    assert num_redundant >= 0
    device = weight.device
    phy2log = torch.arange(num_phy, dtype=torch.int64,
                           device=device).repeat(n, 1)
    rank = torch.zeros(n, num_phy, dtype=torch.int64, device=device)
    logcnt = torch.ones(n, num_log, dtype=torch.int64, device=device)
    arangen = torch.arange(n, dtype=torch.int64, device=device)
    for i in range(num_log, num_phy):
        redundant_indices = (weight / logcnt).max(dim=-1).indices
        phy2log[:, i] = redundant_indices
        rank[:, i] = logcnt[arangen, redundant_indices]
        logcnt[arangen, redundant_indices] += 1
    return phy2log, rank, logcnt
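# Illustrative trace of `replicate_experts`: for one layer with loads
# [9, 3, 6] and num_phy=5, each redundant slot goes to the expert with the
# highest per-replica load, giving phy2log == [[0, 1, 2, 0, 2]],
# rank == [[0, 0, 0, 1, 1]] and logcnt == [[2, 1, 2]].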


def rebalance_experts_hierarchical(
    weight: torch.Tensor,
    num_physical_experts: int,
    num_groups: int,
    num_nodes: int,
    num_gpus: int,
):
    """
    Parameters:
        weight: [num_moe_layers, num_logical_experts]
        num_physical_experts: number of physical experts after replication
        num_groups: number of expert groups
        num_nodes: number of server nodes, where the intra-node network
            (e.g., NVLink) is faster
        num_gpus: number of GPUs, must be a multiple of `num_nodes`

    Returns:
        physical_to_logical_map: [num_moe_layers, num_physical_experts]
        logical_to_physical_map: [num_moe_layers, num_logical_experts, X]
        logical_count: [num_moe_layers, num_logical_experts]
    """
    num_layers, num_logical_experts = weight.shape
    assert num_logical_experts % num_groups == 0
    group_size = num_logical_experts // num_groups
    assert num_groups % num_nodes == 0
    groups_per_node = num_groups // num_nodes
    assert num_gpus % num_nodes == 0
    assert num_physical_experts % num_gpus == 0
    phy_experts_per_gpu = num_physical_experts // num_gpus

    def inverse(perm: torch.Tensor) -> torch.Tensor:
        inv = torch.empty_like(perm)
        inv.scatter_(
            1,
            perm,
            torch.arange(perm.size(1), dtype=torch.int64,
                         device=perm.device).expand(perm.shape),
        )
        return inv

    # Step 1: pack groups to nodes
    tokens_per_group = weight.unflatten(-1, (num_groups, group_size)).sum(-1)
    group_pack_index, group_rank_in_pack = balanced_packing(
        tokens_per_group, num_nodes)
    log2mlog = (((group_pack_index * groups_per_node + group_rank_in_pack) *
                 group_size).unsqueeze(-1) +
                torch.arange(group_size,
                             dtype=torch.int64,
                             device=group_pack_index.device)).flatten(-2)
    mlog2log = inverse(log2mlog)

    # Step 2: construct redundant experts within nodes
    # [num_layers * num_nodes, num_logical_experts // num_nodes]
    tokens_per_mlog = weight.gather(-1, mlog2log).view(
        -1, num_logical_experts // num_nodes)
    phy2mlog, phyrank, mlogcnt = replicate_experts(
        tokens_per_mlog, num_physical_experts // num_nodes)

    # Step 3: pack physical_experts to GPUs
    # [num_layers * num_nodes, num_physical_experts // num_nodes]
    tokens_per_phy = (tokens_per_mlog / mlogcnt).gather(-1, phy2mlog)
    pack_index, rank_in_pack = balanced_packing(tokens_per_phy,
                                                num_gpus // num_nodes)
    phy2pphy = pack_index * phy_experts_per_gpu + rank_in_pack
    pphy2phy = inverse(phy2pphy)

    pphy2mlog = phy2mlog.gather(
        -1, pphy2phy)  # [num_layers * num_nodes, num_log_per_nodes]
    pphy2mlog = (pphy2mlog.view(num_layers, num_nodes, -1) + torch.arange(
        0,
        num_logical_experts,
        num_logical_experts // num_nodes,
        device=group_pack_index.device,
    ).view(1, -1, 1)).flatten(-2)
    pphy2log = mlog2log.gather(-1, pphy2mlog)
    pphyrank = phyrank.gather(-1, pphy2phy).view(num_layers, -1)
    logcnt = mlogcnt.view(num_layers, -1).gather(-1, log2mlog)
    return pphy2log, pphyrank, logcnt
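# Shape sketch for `rebalance_experts_hierarchical`: with weight of shape
# [2, 8] (2 layers, 8 logical experts), num_physical_experts=12, num_groups=4,
# num_nodes=2 and num_gpus=4, the returned tensors have shapes
# pphy2log: [2, 12], pphyrank: [2, 12], logcnt: [2, 8].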


def rebalance_experts(
    weight: torch.Tensor,
    num_replicas: int,
    num_groups: int,
    num_nodes: int,
    num_gpus: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Entry point for expert-parallelism load balancer.

    Parameters:
        weight: [layers, num_logical_experts], the load statistics for all
            logical experts
        num_replicas: number of physical experts, must be a multiple of
            `num_gpus`
        num_groups: number of expert groups
        num_nodes: number of server nodes, where the intra-node network
            (e.g., NVLink) is faster
        num_gpus: number of GPUs, must be a multiple of `num_nodes`

    Returns:
        physical_to_logical_map: [layers, num_replicas], the expert index of
            each replica
        logical_to_physical_map: [layers, num_logical_experts, X], the replica
            indices for each expert
        expert_count: [layers, num_logical_experts], number of physical
            replicas for each logical expert
    """
    num_layers, num_logical_experts = weight.shape
    weight = weight.float().cpu()
    if num_groups % num_nodes == 0:
        # use hierarchical load-balance policy
        phy2log, phyrank, logcnt = rebalance_experts_hierarchical(
            weight, num_replicas, num_groups, num_nodes, num_gpus)
    else:
        # use global load-balance policy
        phy2log, phyrank, logcnt = rebalance_experts_hierarchical(
            weight, num_replicas, 1, 1, num_gpus)
    num_redundant_experts = num_replicas - num_logical_experts
    maxlogcnt = num_redundant_experts + 1
    log2phy: torch.Tensor = torch.full(
        (num_layers, num_logical_experts, maxlogcnt),
        -1,
        dtype=torch.int64,
        device=logcnt.device,
    )
    log2phy.view(num_layers, -1).scatter_(
        -1,
        phy2log * maxlogcnt + phyrank,
        torch.arange(num_replicas, dtype=torch.int64,
                     device=log2phy.device).expand(num_layers, -1),
    )
    return phy2log, log2phy, logcnt


__all__ = ["rebalance_experts"]
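To close the loop, a small self-contained usage sketch of the entry point above with toy load statistics; the asserted shapes follow directly from the docstring, while the numbers themselves are illustrative only.

import torch

from vllm.distributed.eplb.rebalance_algo import rebalance_experts

# Toy load statistics: 1 MoE layer, 4 logical experts, expert 3 is the hottest.
weight = torch.tensor([[20.0, 10.0, 30.0, 40.0]])

# Place 6 physical experts (2 redundant) onto 2 GPUs of a single node,
# with the 4 logical experts organized into 2 expert groups.
phy2log, log2phy, logcnt = rebalance_experts(weight,
                                             num_replicas=6,
                                             num_groups=2,
                                             num_nodes=1,
                                             num_gpus=2)

assert phy2log.shape == (1, 6)     # logical expert id of each physical slot
assert log2phy.shape == (1, 4, 3)  # up to num_redundant + 1 replicas per expert
assert logcnt.shape == (1, 4)      # replica count per logical expert
assert int(logcnt.sum()) == 6      # every physical slot backs some expert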