Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
358393b
init offline eplb
Patryk999 Oct 3, 2025
e104aac
fixes
Patryk999 Oct 3, 2025
f7b4645
add space
Patryk999 Oct 3, 2025
58cc52c
precommit
Patryk999 Oct 3, 2025
22e0a87
mxfp4 forward arg
Patryk999 Oct 3, 2025
d1e34fd
Merge branch 'main' into patryk/offline-eplb
PatrykSaffer Oct 7, 2025
3439d05
precommit
Patryk999 Oct 7, 2025
d0c2bef
support transformers_moe.py
Patryk999 Oct 7, 2025
482fd0b
Update transformers_moe.py
PatrykSaffer Oct 7, 2025
db81577
Update transformers_moe.py
PatrykSaffer Oct 7, 2025
65f777e
Merge branch 'main' into patryk/offline-eplb
Patryk999 Oct 17, 2025
dfd8056
merge
Patryk999 Oct 17, 2025
64975d7
precommit
Patryk999 Oct 17, 2025
2ebf603
don't double eplb prefix
Patryk999 Oct 27, 2025
abbeb70
don't double eplb prefix
Patryk999 Oct 27, 2025
66c375c
don't double eplb prefix
Patryk999 Oct 27, 2025
45c9ce4
Merge branch 'main' into patryk/offline-eplb
PatrykSaffer Nov 14, 2025
783ca91
offline eplb with eagle
Patryk999 Nov 14, 2025
ecec01d
more fixes
Patryk999 Nov 14, 2025
ed322fc
more fixes
Patryk999 Nov 17, 2025
3f245da
test fixes
Patryk999 Nov 17, 2025
7874a61
precommit fixes
Patryk999 Nov 17, 2025
1f71042
precommit fixes
Patryk999 Nov 17, 2025
55de3d0
Merge branch 'main' into patryk/offline-eplb
PatrykSaffer Dec 3, 2025
dcf0733
fixes post rebase
Patryk999 Dec 3, 2025
7a03145
fixes post rebase
Patryk999 Dec 3, 2025
ae6dad2
fixes post rebase
Patryk999 Dec 3, 2025
3ca1e6c
fixes post rebase
Patryk999 Dec 3, 2025
2f2a04b
readability \diffs
Patryk999 Dec 3, 2025
dbc6dc8
fixing integration tests
Patryk999 Dec 3, 2025
1c8480d
Merge branch 'main' into patryk/offline-eplb
PatrykSaffer Dec 4, 2025
957b54c
precommit
Patryk999 Dec 4, 2025
6dbeeb6
Merge branch 'main' into patryk/offline-eplb
PatrykSaffer Dec 5, 2025
e06c9ba
Merge remote-tracking branch 'origin/main' into patryk/offline-eplb
Patryk999 Mar 9, 2026
87079af
Merge main fixes
Patryk999 Mar 10, 2026
00289b9
Merge main fixes
Patryk999 Mar 10, 2026
c4bc60c
Merge main fixes
Patryk999 Mar 10, 2026
efdbe3e
Merge main fixes3
Patryk999 Mar 10, 2026
5f4f62c
Merge main fixes3
Patryk999 Mar 10, 2026
484752d
Merge main fixes3
Patryk999 Mar 10, 2026
56a7d81
add not
Patryk999 Mar 10, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions tests/v1/e2e/test_eplb_offline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations

import pytest

from vllm import LLM
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.engine.arg_utils import EPLBConfig


@pytest.mark.parametrize(
    "model_setup",
    [
        ("deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct", 2),
    ],
)
def test_eplb_model(
    model_setup: tuple[str, int],
    tmp_path,
):
    """End-to-end round trip of offline EPLB statistics.

    The first LLM run saves the expert-load window to disk; the second run
    loads it back as the initial load window (async EPLB enabled).
    """
    model_name, tp_size = model_setup
    test_prompt = ["This is a prompt which has more than 10 tokens."]

    llm_args = dict(
        model=model_name,
        tensor_parallel_size=tp_size,
        max_model_len=2048,
        enable_expert_parallel=True,
        enable_eplb=True,
        load_format="dummy",
        gpu_memory_utilization=0.95,
    )

    # Save EPLB statistics to disk. Use pytest's per-test tmp_path instead
    # of a hard-coded /tmp so parallel test runs cannot collide.
    eplb_config_save = EPLBConfig(
        window_size=8, step_interval=10, save_load_window=True, save_dir=tmp_path
    )
    llm = LLM(eplb_config=eplb_config_save, **llm_args)
    llm.generate(test_prompt)
    del llm
    cleanup_dist_env_and_memory()

    # Load EPLB statistics from disk
    eplb_config_load = EPLBConfig(
        load_initial_load_window=True,
        load_path=tmp_path / "global_expert_load_window_i0.safetensors",
        use_async=True,
    )
    llm = LLM(eplb_config=eplb_config_load, **llm_args)
    llm.generate(test_prompt)
    del llm
    # Tear down the distributed environment after the second run as well,
    # so this test does not leak state into subsequent tests.
    cleanup_dist_env_and_memory()
40 changes: 40 additions & 0 deletions vllm/config/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import os
from collections.abc import Callable
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal, overload

import torch
Expand Down Expand Up @@ -81,6 +82,30 @@ class EPLBConfig:

policy: EPLBPolicyOption = "default"
"""The policy type for expert parallel load balancing (EPLB)."""
load_initial_load_window: bool = False
"""
Whether to load initial load window.
"""
save_load_window: bool = False
"""
Whether to save load window.
"""
static: bool = False
"""
Whether to do just one expert reshuffling at the start.
"""
save_dir: Path | None = None
"""Directory to save expert load balance metrics."""
load_path: Path | None = None
"""Path to load expert load balance metrics."""

@model_validator(mode="after")
def _validate_eplb_config(self) -> Self:
if self.use_async and self.policy != "default":
raise ValueError("Async EPLB is only supported with the default policy.")
if self.log_balancedness and self.log_balancedness_interval <= 0:
raise ValueError("log_balancedness_interval must be greater than 0.")
return self

@model_validator(mode="after")
def _validate_eplb_config(self) -> Self:
Expand Down Expand Up @@ -865,5 +890,20 @@ def _verify_args(self) -> Self:
raise ValueError(
"Unable to use nsight profiling unless workers run with Ray."
)
if (
self.eplb_config.load_initial_load_window
and self.eplb_config.load_path is None
):
raise ValueError(
"load_initial_load_window is set to True, but load_path is not provided."
)
if self.eplb_config.save_load_window and self.eplb_config.save_dir is None:
raise ValueError(
"save_load_window is set to True, but save_dir is not provided."
)
if self.eplb_config.save_load_window and self.eplb_config.static:
raise ValueError("save_load_window cannot be set to true with static eplb.")
if self.eplb_config.use_async and self.eplb_config.static:
raise ValueError("use_async cannot be set to true with static eplb.")

return self
125 changes: 99 additions & 26 deletions vllm/distributed/eplb/eplb_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,11 @@
import threading
from collections.abc import Sequence
from dataclasses import dataclass
from pathlib import Path

import numpy as np
import torch
from safetensors.torch import load_file, save_file
from torch.distributed import ProcessGroup, all_reduce

from vllm.config import ModelConfig, ParallelConfig
Expand Down Expand Up @@ -247,6 +249,40 @@ class EplbModelState:
"""


def save_eplb_state(
    global_expert_load_windows: list[torch.Tensor],
    save_dir: Path,
    state_iter: int,
    model_states: list[EplbModelState],
) -> None:
    """Persist the per-model global expert load windows to a safetensors file.

    Each tensor is keyed by its model's class name so the same file can be
    read back by `load_eplb_state` for the same set of models. Saving is
    best-effort: failures are logged but never raised.

    Args:
        global_expert_load_windows: One aggregated load tensor per model,
            in the same order as `model_states`.
        save_dir: Directory in which to write the snapshot file.
        state_iter: Monotonic counter used to build a unique file name.
        model_states: EPLB model states whose class names key the tensors.
    """
    tensors = {}
    # strict=True: a silent length mismatch here would drop load windows
    # without any error; fail loudly instead.
    for eplb_model_state, global_expert_load_window in zip(
        model_states, global_expert_load_windows, strict=True
    ):
        name = type(eplb_model_state.model).__name__
        tensors[f"global_expert_load_window_{name}"] = global_expert_load_window
    try:
        # pathlib join instead of string formatting, so save_dir works
        # regardless of a trailing separator.
        file_path = (
            Path(save_dir) / f"global_expert_load_window_i{state_iter}.safetensors"
        )
        save_file(tensors, str(file_path))
        logger.info("Successfully saved to %s.", file_path)
    except Exception as e:
        logger.error("An error occurred while saving the tensor: %s.", e)


def load_eplb_state(
    eplb_load_path: Path, model_states: list[EplbModelState]
) -> list[torch.Tensor]:
    """Load the per-model global expert load windows saved by EPLB.

    The file is expected to hold one tensor per model, keyed by the model
    class name; each tensor is moved onto that model's load-window device
    and the results are returned in `model_states` order.
    """
    saved = load_file(eplb_load_path)
    windows: list[torch.Tensor] = []
    for state in model_states:
        key = f"global_expert_load_window_{type(state.model).__name__}"
        windows.append(saved[key].to(state.expert_load_window.device))
    logger.info("Successfully loaded %s.", eplb_load_path)
    return windows


class EplbState:
"""
EplbState of each expert parallel model. Key is the model config hash.
Expand Down Expand Up @@ -287,6 +323,10 @@ def __init__(self, parallel_config: ParallelConfig, device: torch.device):
Interval for expert rearrangement steps.
This is a constant and is taken from the config.
"""
self.saved_state_iter = 0
"""
Iterator for the saved state file.
"""
self.is_async: bool = False
"""
The flag indicates whether the EPLB is running in async mode.
Expand Down Expand Up @@ -461,11 +501,17 @@ def add_model(
device=self.device,
)

# Set the initial progress of rearrangement to 3/4
eplb_step_interval = self.parallel_config.eplb_config.step_interval
self.expert_rearrangement_step = max(
0, eplb_step_interval - eplb_step_interval // 4
)
if (
self.parallel_config.eplb_config.load_initial_load_window
or self.parallel_config.eplb_config.save_load_window
):
self.expert_rearrangement_step = 0
else:
# Set the initial progress of rearrangement to 3/4
self.expert_rearrangement_step = max(
0, eplb_step_interval - eplb_step_interval // 4
)
self.expert_rearrangement_step_interval = eplb_step_interval

policy_type = self.parallel_config.eplb_config.policy
Expand Down Expand Up @@ -643,6 +689,9 @@ def step(
def rearrange(
self,
is_profile: bool = False,
execute_shuffle: bool = True,
load_initial_load_window: bool = False,
global_expert_loads: list[torch.Tensor] | None = None,
rank_mapping: dict[int, int] | None = None,
) -> torch.Tensor | None:
"""
Expand Down Expand Up @@ -673,32 +722,56 @@ def rearrange(
"(profile)" if is_profile else "",
)

# Map the physical expert load to global logical experts
global_expert_load_windows = []
for eplb_model_state in self.model_states.values():
expert_load_window = eplb_model_state.expert_load_window[
:, :, : self.num_valid_physical_experts
]
logical_expert_load_window = torch.zeros(
self.expert_load_window_size,
eplb_model_state.model.num_moe_layers,
eplb_model_state.model.num_logical_experts,
dtype=eplb_model_state.expert_load_window.dtype,
device=eplb_model_state.expert_load_window.device,
if load_initial_load_window:
assert self.parallel_config.eplb_config.load_path is not None
global_expert_load_windows = load_eplb_state(
self.parallel_config.eplb_config.load_path,
list(self.model_states.values()),
)
logical_expert_load_window.scatter_add_(
dim=-1,
index=eplb_model_state.physical_to_logical_map[
:, : self.num_valid_physical_experts
assert global_expert_load_windows is not None
else:
# Map the physical expert load to global logical experts
global_expert_load_windows = []
for eplb_model_state in self.model_states.values():
expert_load_window = eplb_model_state.expert_load_window[
:, :, : self.num_valid_physical_experts
]
.unsqueeze(0)
.expand_as(expert_load_window)
.long(),
src=expert_load_window,
logical_expert_load_window = torch.zeros(
self.expert_load_window_size,
eplb_model_state.model.num_moe_layers,
eplb_model_state.model.num_logical_experts,
dtype=eplb_model_state.expert_load_window.dtype,
device=eplb_model_state.expert_load_window.device,
)
logical_expert_load_window.scatter_add_(
dim=-1,
index=eplb_model_state.physical_to_logical_map[
:, : self.num_valid_physical_experts
]
.unsqueeze(0)
.expand_as(expert_load_window)
.long(),
src=expert_load_window,
)

global_expert_load_window = logical_expert_load_window.sum(dim=0)
global_expert_load_windows.append(global_expert_load_window)

should_save_eplb_state = (
self.parallel_config.eplb_config.save_load_window
and not is_profile
and not load_initial_load_window
)
if is_main_rank and should_save_eplb_state:
assert self.parallel_config.eplb_config.save_dir is not None
save_eplb_state(
global_expert_load_windows,
self.parallel_config.eplb_config.save_dir,
self.saved_state_iter,
list(self.model_states.values()),
)
self.saved_state_iter += 1

global_expert_load_window = logical_expert_load_window.sum(dim=0)
global_expert_load_windows.append(global_expert_load_window)
# Perform all-reduce to get the expert load across all ranks for each model
global_expert_load_windows = self._allreduce_list(global_expert_load_windows)

Expand Down
3 changes: 3 additions & 0 deletions vllm/model_executor/layers/fused_moe/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,7 @@ def __init__(
activation: str = "silu",
is_act_and_mul: bool = True,
enable_eplb: bool = False,
eplb_static: bool = False,
num_redundant_experts: int = 0,
has_bias: bool = False,
is_sequence_parallel=False,
Expand Down Expand Up @@ -390,6 +391,7 @@ def __init__(
self.layer_name = prefix

self.enable_eplb = enable_eplb
self.eplb_static = eplb_static
# TODO(bnell): should this be owned by router?
self.eplb_state = EplbLayerState()
self.expert_placement_strategy: ExpertPlacementStrategy = (
Expand Down Expand Up @@ -517,6 +519,7 @@ def __init__(
e_score_correction_bias=e_score_correction_bias,
num_fused_shared_experts=self.num_fused_shared_experts,
enable_eplb=enable_eplb,
eplb_static=self.eplb_static,
# TODO(bnell): once we can construct the MK at init time, we
# can make this a value.
indices_type_getter=lambda: self.quant_method.topk_indices_dtype,
Expand Down
51 changes: 28 additions & 23 deletions vllm/model_executor/layers/fused_moe/router/base_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def eplb_map_to_physical_and_record(
expert_load_view: torch.Tensor,
logical_to_physical_map: torch.Tensor,
logical_replica_count: torch.Tensor,
eplb_static: bool = False,
) -> torch.Tensor:
"""
Map the logical expert ids to physical expert ids
Expand Down Expand Up @@ -60,29 +61,30 @@ def eplb_map_to_physical_and_record(

topk_ids = physical_ids

# 2. Record expert load metrics.

# TODO(bowen): When using `FusedMoEModularKernel`, this
# can be done in a more unified way, since
# `FusedMoEPrepareAndFinalizeModular` will return the expert
# token count, in some cases directly from the kernel.
# However, now there are many code paths not using
# the modular kernel, e.g. calling `fused_experts`,
# so we decide to keep the logic here.
#
# If later refactor moved all the MoE kernel calls
# to the modular kernel, we can move this logic there
# to achieve better efficiency.

# `expert_load_view`: (num_physical_experts,)

# `torch.bincount` is not compilable, so use `scatter_add_` instead.
topk_ids_flatten = topk_ids.flatten()
expert_load_view.scatter_add_(
dim=0,
index=topk_ids_flatten.long(),
src=torch.ones_like(topk_ids_flatten).to(expert_load_view),
)
if not eplb_static:
# 2. Record expert load metrics.

# TODO(bowen): When using `FusedMoEModularKernel`, this
# can be done in a more unified way, since
# `FusedMoEPrepareAndFinalize` will return the expert
# token count, in some cases directly from the kernel.
# However, now there are many code paths not using
# the modular kernel, e.g. calling `fused_experts`,
# so we decide to keep the logic here.
#
# If later refactor moved all the MoE kernel calls
# to the modular kernel, we can move this logic there
# to achieve better efficiency.

# `expert_load_view`: (num_physical_experts,)

# `torch.bincount` is not compilable, so use `scatter_add_` instead.
topk_ids_flatten = topk_ids.flatten()
expert_load_view.scatter_add_(
dim=0,
index=topk_ids_flatten.long(),
src=torch.ones_like(topk_ids_flatten).to(expert_load_view),
)
return topk_ids
else:

Expand Down Expand Up @@ -111,6 +113,7 @@ def __init__(
global_num_experts: int,
eplb_state: EplbLayerState,
enable_eplb: bool = False,
eplb_static: bool = False,
# TODO(bnell): Once the MK is constructed at layer init time, we
# can make this a plain value instead of a callback.
indices_type_getter: Callable[[], torch.dtype | None] | None = None,
Expand All @@ -126,6 +129,7 @@ def __init__(
self.global_num_experts = global_num_experts
self.eplb_state = eplb_state
self.enable_eplb = enable_eplb
self.eplb_static = eplb_static
self.indices_type_getter = indices_type_getter
self.capture_fn: Callable[[torch.Tensor], None] | None = None

Expand Down Expand Up @@ -164,6 +168,7 @@ def _apply_eplb_mapping(self, topk_ids: torch.Tensor) -> torch.Tensor:
expert_load_view=self.eplb_state.expert_load_view,
logical_to_physical_map=self.eplb_state.logical_to_physical_map,
logical_replica_count=self.eplb_state.logical_replica_count,
eplb_static=self.eplb_static,
)
return topk_ids

Expand Down
Loading