tests/distributed/test_eplb_algo.py (63 additions, 18 deletions)
@@ -5,6 +5,7 @@
 import pytest
 import torch
 
+from vllm.distributed.eplb.eplb_state import compute_logical_maps
 from vllm.distributed.eplb.policy.default import DefaultEplbPolicy
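
Every call site in this file follows the same updated pattern: `rebalance_experts` now returns only the physical-to-logical map, and callers derive the inverse map plus per-expert replica counts via `compute_logical_maps`. A minimal sketch of the new call sequence (the weights and sizes below are illustrative, not taken from the tests):

```python
import torch

from vllm.distributed.eplb.eplb_state import compute_logical_maps
from vllm.distributed.eplb.policy.default import DefaultEplbPolicy

# One layer, four logical experts with skewed load (illustrative numbers).
weight = torch.tensor([[10, 20, 30, 40]])
# Arguments: (weight, num_replicas, num_groups, num_nodes, num_gpus).
phy2log = DefaultEplbPolicy.rebalance_experts(weight, 8, 2, 1, 2)
# Derive the inverse mapping and replica counts from phy2log alone.
log2phy, logcnt = compute_logical_maps(phy2log, weight.shape[-1])
assert phy2log.shape == (1, 8)
assert logcnt.sum() == 8  # every slot is used in this evenly divisible case
```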


@@ -24,9 +25,10 @@ def test_basic_rebalance():
     num_nodes = 2
     num_gpus = 8
 
-    phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
+    phy2log = DefaultEplbPolicy.rebalance_experts(
         weight, num_replicas, num_groups, num_nodes, num_gpus
     )
+    log2phy, logcnt = compute_logical_maps(phy2log, weight.shape[-1])
 
     # Verify output shapes
     assert phy2log.shape == (
@@ -78,9 +80,10 @@ def test_single_gpu_case():
     num_nodes = 1
     num_gpus = 1
 
-    phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
+    phy2log = DefaultEplbPolicy.rebalance_experts(
         weight, num_replicas, num_groups, num_nodes, num_gpus
     )
+    log2phy, logcnt = compute_logical_maps(phy2log, weight.shape[-1])
 
     # Verify shapes
     assert phy2log.shape == (1, 4)
@@ -100,9 +103,10 @@ def test_equal_weights():
     num_nodes = 2
     num_gpus = 4
 
-    phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
+    phy2log = DefaultEplbPolicy.rebalance_experts(
         weight, num_replicas, num_groups, num_nodes, num_gpus
     )
+    _, logcnt = compute_logical_maps(phy2log, weight.shape[-1])
 
     # Verify shapes
     assert phy2log.shape == (1, 8)
@@ -123,9 +127,10 @@ def test_extreme_weight_imbalance():
     num_nodes = 2
     num_gpus = 4
 
-    phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
+    phy2log = DefaultEplbPolicy.rebalance_experts(
         weight, num_replicas, num_groups, num_nodes, num_gpus
     )
+    _, logcnt = compute_logical_maps(phy2log, weight.shape[-1])
 
     # Verify shapes
     assert phy2log.shape == (1, 12)
@@ -151,9 +156,10 @@ def test_multiple_layers():
     num_nodes = 2
     num_gpus = 4
 
-    phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
+    phy2log = DefaultEplbPolicy.rebalance_experts(
         weight, num_replicas, num_groups, num_nodes, num_gpus
     )
+    _, logcnt = compute_logical_maps(phy2log, weight.shape[-1])
 
     # Verify shapes
     assert phy2log.shape == (3, 8)
@@ -176,7 +182,8 @@ def test_parameter_validation():
     # Test non-divisible case - this should handle normally without throwing
     # errors because the function will fall back to global load balancing
     # strategy
-    phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(weight, 8, 3, 2, 4)
+    phy2log = DefaultEplbPolicy.rebalance_experts(weight, 8, 3, 2, 4)
+    _, logcnt = compute_logical_maps(phy2log, weight.shape[-1])
     assert phy2log.shape == (1, 8)
     assert logcnt.shape == (1, 4)
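
The fallback mentioned in the comment above can be made concrete with a tiny predicate. The divisibility condition is an assumption about the policy's internals, shown only to illustrate why `num_groups=3, num_nodes=2` takes the global path:

```python
def uses_hierarchical(num_groups: int, num_nodes: int) -> bool:
    # Assumed condition: hierarchical (group-aware) placement needs expert
    # groups to split evenly across nodes; otherwise the policy falls back
    # to global load balancing.
    return num_groups % num_nodes == 0


assert not uses_hierarchical(3, 2)  # the non-divisible case tested above
assert uses_hierarchical(4, 2)
```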

@@ -198,9 +205,10 @@ def test_small_scale_hierarchical():
     num_nodes = 2  # 2 nodes
     num_gpus = 4  # 4 GPUs
 
-    phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
+    phy2log = DefaultEplbPolicy.rebalance_experts(
         weight, num_replicas, num_groups, num_nodes, num_gpus
     )
+    _, logcnt = compute_logical_maps(phy2log, weight.shape[-1])
 
     # Verify basic constraints
     assert phy2log.shape == (1, 12)
@@ -225,9 +233,10 @@ def test_global_load_balance_fallback():
     num_nodes = 2
     num_gpus = 4
 
-    phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
+    phy2log = DefaultEplbPolicy.rebalance_experts(
         weight, num_replicas, num_groups, num_nodes, num_gpus
     )
+    _, logcnt = compute_logical_maps(phy2log, weight.shape[-1])
 
     # Should work normally, just using global load balancing strategy
     assert phy2log.shape == (1, 8)
@@ -247,9 +256,10 @@ def test_device_compatibility(device):
     num_nodes = 1
     num_gpus = 2
 
-    phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
+    phy2log = DefaultEplbPolicy.rebalance_experts(
         weight, num_replicas, num_groups, num_nodes, num_gpus
     )
+    _, logcnt = compute_logical_maps(phy2log, weight.shape[-1])
 
     # Function will convert to CPU internally, but should handle different
     # device inputs normally
@@ -264,9 +274,8 @@ def test_additional_cases():
     weight1 = torch.tensor(
         [[50, 100, 75, 120, 90, 60, 80, 110, 40, 70, 95, 85, 65, 55, 45, 35]]
     )
-    phy2log1, log2phy1, logcnt1 = DefaultEplbPolicy.rebalance_experts(
-        weight1, 24, 8, 4, 8
-    )
+    phy2log1 = DefaultEplbPolicy.rebalance_experts(weight1, 24, 8, 4, 8)
+    _, logcnt1 = compute_logical_maps(phy2log1, weight1.shape[-1])
 
     assert phy2log1.shape == (1, 24)
     assert logcnt1.shape == (1, 16)
@@ -279,9 +288,8 @@ def test_additional_cases():
             [12, 25, 50, 100, 150, 200],  # Increasing weights
         ]
     )
-    phy2log2, log2phy2, logcnt2 = DefaultEplbPolicy.rebalance_experts(
-        weight2, 10, 3, 1, 2
-    )
+    phy2log2 = DefaultEplbPolicy.rebalance_experts(weight2, 10, 3, 1, 2)
+    _, logcnt2 = compute_logical_maps(phy2log2, weight2.shape[-1])
 
     assert phy2log2.shape == (2, 10)
     assert logcnt2.shape == (2, 6)
@@ -292,6 +300,42 @@ def test_additional_cases():
         assert logcnt2[layer, max_weight_idx] >= 2
 
 
+def test_compute_logical_maps_with_negative_indices():
+    """
+    Test that compute_logical_maps correctly handles physical slots containing
+    -1 (unused slots).
+    """
+    # 2 layers, 6 physical slots, 4 logical experts.
+    # Slots 2 and 5 are unused (-1).
+    phy2log = torch.tensor(
+        [
+            [0, 1, -1, 2, 3, -1],
+            [3, -1, 2, 1, 0, -1],
+        ]
+    )
+    num_layers = 2
+    num_logical_experts = 4
+
+    log2phy, logcnt = compute_logical_maps(phy2log, num_logical_experts)
+
+    assert logcnt.shape == (num_layers, num_logical_experts)
+    assert log2phy.shape == (num_layers, num_logical_experts, 1)
+
+    expected_logcnt = torch.ones(num_layers, num_logical_experts, dtype=phy2log.dtype)
+    assert torch.all(logcnt == expected_logcnt), (
+        f"Expected that all replica counts == 1, got {logcnt}"
+    )
+
+    assert torch.all(log2phy >= 0), (
+        "log2phy should only contain valid physical indices, not -1"
+    )
+
+    assert log2phy[0, 0, 0] == 0
+    assert log2phy[0, 1, 0] == 1
+    assert log2phy[0, 2, 0] == 3
+    assert log2phy[0, 3, 0] == 4
+
+
 if __name__ == "__main__":
     weight = torch.tensor(
         [
@@ -305,7 +349,7 @@ def test_additional_cases():
     num_nodes = 2
     num_gpus = 8
 
-    phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
+    phy2log = DefaultEplbPolicy.rebalance_experts(
         weight, num_replicas, num_groups, num_nodes, num_gpus
     )
     print(phy2log)
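
The semantics that `test_compute_logical_maps_with_negative_indices` exercises can be reproduced with a standalone sketch. This illustrates the expected behavior (replica counting, inverse-map construction, and skipping of `-1` slots); it is not vLLM's actual implementation:

```python
import torch


def compute_logical_maps_sketch(
    phy2log: torch.Tensor, num_logical: int
) -> tuple[torch.Tensor, torch.Tensor]:
    """Re-derive (log2phy, logcnt) from phy2log; -1 slots are unused."""
    num_layers, num_physical = phy2log.shape
    # Count replicas per logical expert, ignoring unused (-1) slots.
    logcnt = torch.zeros(num_layers, num_logical, dtype=phy2log.dtype)
    for layer in range(num_layers):
        valid = phy2log[layer][phy2log[layer] >= 0]
        logcnt[layer] = torch.bincount(valid, minlength=num_logical).to(phy2log.dtype)
    # Size the last dimension by the most-replicated expert; unfilled
    # entries stay -1.
    max_replicas = max(int(logcnt.max().item()), 1)
    log2phy = torch.full(
        (num_layers, num_logical, max_replicas), -1, dtype=phy2log.dtype
    )
    fill = torch.zeros(num_layers, num_logical, dtype=torch.long)
    for layer in range(num_layers):
        for slot in range(num_physical):
            expert = int(phy2log[layer, slot].item())
            if expert >= 0:
                log2phy[layer, expert, fill[layer, expert]] = slot
                fill[layer, expert] += 1
    return log2phy, logcnt
```

Running the new test's inputs through this sketch yields the asserted values: each expert appears exactly once per layer, so `log2phy` has a final dimension of 1 and contains no `-1` entries.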
@@ -434,9 +478,10 @@ def test_preserve_intragpu_slots(
     """Experts that stay on a GPU keep their old slots; incoming not lost."""
     phy_replicas_idx = _make_phy_replicas_idx_from_phy2log(new_phy2log)
 
-    post_phy2log, post_phy_replicas_idx = DefaultEplbPolicy.preserve_intragpu_slots(
-        new_phy2log, phy_replicas_idx, num_ranks, old_phy2log
+    post_phy2log = DefaultEplbPolicy.preserve_intragpu_slots(
+        new_phy2log, num_ranks, old_phy2log
     )
+    post_phy_replicas_idx = _make_phy_replicas_idx_from_phy2log(post_phy2log)
 
     # Shapes preserved
     assert post_phy2log.shape == new_phy2log.shape
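For context on `test_preserve_intragpu_slots`, here is a hypothetical illustration of the behavior it checks, assuming slots are laid out contiguously per rank. This helper is not vLLM's implementation, only a sketch of the invariant "experts that stay on a GPU keep their old slots; incoming experts fill the freed ones":

```python
import torch


def preserve_intragpu_slots_sketch(
    new_phy2log: torch.Tensor, num_ranks: int, old_phy2log: torch.Tensor
) -> torch.Tensor:
    num_layers, num_slots = new_phy2log.shape
    slots_per_rank = num_slots // num_ranks
    out = new_phy2log.clone()
    for layer in range(num_layers):
        for rank in range(num_ranks):
            lo, hi = rank * slots_per_rank, (rank + 1) * slots_per_rank
            old = old_phy2log[layer, lo:hi].tolist()
            pool = new_phy2log[layer, lo:hi].tolist()
            # Keep an expert in its old slot when it is still assigned to
            # this rank (multiset match); mark freed slots with None.
            kept = []
            for expert in old:
                if expert in pool:
                    kept.append(expert)
                    pool.remove(expert)
                else:
                    kept.append(None)
            # Whatever remains in the pool is incoming; it fills freed slots.
            merged = [e if e is not None else pool.pop(0) for e in kept]
            out[layer, lo:hi] = torch.tensor(merged, dtype=new_phy2log.dtype)
    return out
```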
vllm/distributed/eplb/async_worker.py (1 addition, 15 deletions)
@@ -73,11 +73,7 @@ def run_rebalance_experts(
     # Move the global expert load window to CPU for computation.
     global_expert_load_window = eplb_stats.global_expert_load_window.cpu()
     # Compute new expert mappings for the model
-    (
-        new_physical_to_logical_map,
-        new_logical_to_physical_map,
-        new_logical_replica_count,
-    ) = eplb_state.policy.rebalance_experts(
+    new_physical_to_logical_map = eplb_state.policy.rebalance_experts(
         global_expert_load_window,
         eplb_stats.num_replicas,
         eplb_stats.num_groups,
@@ -89,16 +85,6 @@
 
     model_state.new_physical_to_logical_map = new_physical_to_logical_map
 
-    max_slots = model_state.logical_to_physical_map.shape[-1]
-    padded_logical = torch.nn.functional.pad(
-        new_logical_to_physical_map,
-        (0, max(0, max_slots - new_logical_to_physical_map.shape[-1])),
-        value=-1,
-    ).to(model_state.logical_to_physical_map.device)
-    new_replica = new_logical_replica_count.to(model_state.logical_replica_count.device)
-    model_state.new_logical_to_physical_map = padded_logical
-    model_state.new_logical_replica_count = new_replica
-
 
 async def transfer_run_periodically(
     state: "EplbState",
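
With the worker no longer propagating `new_logical_to_physical_map` and `new_logical_replica_count`, both maps can presumably be recovered wherever they are consumed. A hedged sketch, assuming the attribute names visible in the diff above (this is not necessarily where vLLM performs the step):

```python
from vllm.distributed.eplb.eplb_state import compute_logical_maps


def derive_logical_maps(model_state):
    """Recover (log2phy, logcnt) on demand from the stored phy2log."""
    # logical_to_physical_map has shape (layers, num_logical, max_replicas).
    num_logical = model_state.logical_to_physical_map.shape[1]
    return compute_logical_maps(
        model_state.new_physical_to_logical_map, num_logical
    )
```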