Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion tests/distributed/test_eplb_fused_moe_layer_dep_nvfp4.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from tests.kernels.moe.utils import make_test_quant_config
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.distributed.eplb.eplb_state import EplbLayerState
from vllm.distributed.eplb.rebalance_execute import rearrange_expert_weights_inplace
from vllm.distributed.parallel_state import (
ensure_model_parallel_initialized,
Expand Down Expand Up @@ -201,7 +202,7 @@ def _test_eplb_fml(env, world_size: int, test_config: TestConfig):
dtype=torch.int32,
device=device,
)
fml.enable_eplb = True
fml.eplb_state = EplbLayerState()
fml.set_eplb_state(
lidx,
torch.zeros(
Expand Down
3 changes: 0 additions & 3 deletions tests/kernels/moe/test_moe_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1004,7 +1004,6 @@ def make_fake_moe_layer(
activation: str = "silu",
indices_type: torch.dtype | None = None,
expert_map: torch.Tensor | None = None,
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
Expand All @@ -1022,7 +1021,6 @@ def make_fake_moe_layer(
router = create_fused_moe_router(
top_k=top_k,
global_num_experts=global_num_experts,
# eplb_state=None, # TODO
renormalize=renormalize,
use_grouped_topk=use_grouped_topk,
num_expert_group=num_expert_group,
Expand All @@ -1032,7 +1030,6 @@ def make_fake_moe_layer(
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
num_fused_shared_experts=0, # TODO
enable_eplb=enable_eplb,
# TODO(bnell): once we can construct the MK at init time, we
# can make this a value.
indices_type_getter=lambda: indices_type,
Expand Down
11 changes: 4 additions & 7 deletions tests/kernels/moe/test_routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,11 @@ def _is_aiter_capable() -> bool:
NUM_EXPERTS = [8, 16, 64]


def setup_eplb_state(enable_eplb: bool, global_num_experts: int) -> EplbLayerState:
def setup_eplb_state(
enable_eplb: bool, global_num_experts: int
) -> EplbLayerState | None:
if not enable_eplb:
return EplbLayerState()
return None

# Initialize EPLB state with proper tensors for testing
# For testing purposes, we use a simple 1:1 mapping (no redundant experts)
Expand Down Expand Up @@ -349,7 +351,6 @@ def test_fused_topk(
top_k=top_k,
global_num_experts=global_num_experts,
renormalize=renormalize,
enable_eplb=enable_eplb,
eplb_state=eplb_state,
)

Expand Down Expand Up @@ -400,7 +401,6 @@ def test_fused_topk_bias(
top_k=top_k,
global_num_experts=global_num_experts,
renormalize=renormalize,
enable_eplb=enable_eplb,
eplb_state=eplb_state,
)

Expand Down Expand Up @@ -469,7 +469,6 @@ def test_grouped_topk(
top_k=top_k,
global_num_experts=global_num_experts,
renormalize=renormalize,
enable_eplb=enable_eplb,
eplb_state=eplb_state,
)

Expand Down Expand Up @@ -540,7 +539,6 @@ def test_custom(
global_num_experts=global_num_experts,
custom_routing_function=custom_routing_function,
renormalize=renormalize,
enable_eplb=enable_eplb,
eplb_state=eplb_state,
)

Expand Down Expand Up @@ -580,7 +578,6 @@ def test_custom(
# router = create_fused_moe_router(
# top_k=top_k,
# global_num_experts=global_num_experts,
# enable_eplb=enable_eplb,
# eplb_state=eplb_state,
# )

Expand Down
11 changes: 11 additions & 0 deletions vllm/distributed/eplb/eplb_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -940,6 +940,17 @@ class EplbLayerState:
GPU work.
"""

def set_layer_state(
self,
moe_layer_idx: int,
expert_load_view: torch.Tensor,
logical_to_physical_map: torch.Tensor,
logical_replica_count: torch.Tensor,
) -> None:
self.expert_load_view = expert_load_view[moe_layer_idx]
self.logical_to_physical_map = logical_to_physical_map[moe_layer_idx]
self.logical_replica_count = logical_replica_count[moe_layer_idx]


def _node_count_with_rank_mapping(
pg: ProcessGroup | StatelessProcessGroup,
Expand Down
49 changes: 29 additions & 20 deletions vllm/model_executor/layers/fused_moe/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,13 +203,24 @@ def __init__(
compilation_config.static_all_moe_layers.append(prefix)
self.layer_name = prefix

self.enable_eplb = enable_eplb
# TODO(bnell): should this be owned by router?
self.eplb_state = EplbLayerState()
self.expert_placement_strategy: ExpertPlacementStrategy = (
vllm_config.parallel_config.expert_placement_strategy
)

self.eplb_state: EplbLayerState | None = None
if enable_eplb:
if self.use_ep and self.global_num_experts % self.ep_size != 0:
raise ValueError(
f"EPLB currently only supports even distribution of "
f"experts across ranks. Got {self.global_num_experts} experts "
f"and {self.ep_size} EP ranks."
)
self.eplb_state = EplbLayerState()
else:
assert not self.use_ep or num_redundant_experts == 0, (
"Redundant experts are only supported with EPLB."
)

# ROCm aiter shared experts fusion
# AITER only supports gated activations (silu/gelu), so disable it
# for non-gated MoE (is_act_and_mul=False)
Expand Down Expand Up @@ -237,17 +248,6 @@ def __init__(
)

# Determine expert maps
if self.use_ep:
if self.enable_eplb:
assert self.global_num_experts % self.ep_size == 0, (
"EPLB currently only supports even distribution of "
"experts across ranks."
)
else:
assert num_redundant_experts == 0, (
"Redundant experts are only supported with EPLB."
)

max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens

# Create ExpertMapManager to handle expert mapping and placement for EP.
Expand All @@ -261,7 +261,7 @@ def __init__(
num_expert_group=num_expert_group,
moe_parallel_config=self.moe_parallel_config,
placement_strategy=self.expert_placement_strategy,
enable_eplb=self.enable_eplb,
enable_eplb=enable_eplb,
num_fused_shared_experts=self.num_fused_shared_experts,
rocm_aiter_enabled=self.rocm_aiter_fmoe_enabled,
)
Expand Down Expand Up @@ -313,7 +313,6 @@ def __init__(
routed_scaling_factor=self.routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
num_fused_shared_experts=self.num_fused_shared_experts,
enable_eplb=enable_eplb,
# TODO(bnell): once we can construct the MK at init time, we
# can make this a value.
indices_type_getter=lambda: self.quant_method.topk_indices_dtype,
Expand Down Expand Up @@ -381,7 +380,7 @@ def _get_quant_method() -> FusedMoEMethodBase:
"is_act_and_mul=False is supported only for CUDA and XPU for now"
)

if self.enable_eplb and not self.quant_method.supports_eplb:
if enable_eplb and not self.quant_method.supports_eplb:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note to self, it seems weird that It would be the quant method that determines this

Shoudlnt it be the kernel?

# TODO: Add support for additional quantization methods.
# The implementation for other quantization methods does not
# contain essential differences, but the current quant API
Expand Down Expand Up @@ -1283,10 +1282,20 @@ def set_eplb_state(

This is used later in forward pass, where we get the expert mapping
and record the load metrics in `expert_load_view`.

Args:
moe_layer_idx: Index of this MoE layer
expert_load_view: View into global expert load tracking tensor
logical_to_physical_map: Mapping from logical to physical expert IDs
logical_replica_count: Number of replicas for each logical expert
"""
self.eplb_state.expert_load_view = expert_load_view[moe_layer_idx]
self.eplb_state.logical_to_physical_map = logical_to_physical_map[moe_layer_idx]
self.eplb_state.logical_replica_count = logical_replica_count[moe_layer_idx]
if self.eplb_state is not None:
self.eplb_state.set_layer_state(
moe_layer_idx,
expert_load_view,
logical_to_physical_map,
logical_replica_count,
)

def ensure_moe_quant_config_init(self):
if self.quant_method.moe_quant_config is None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,18 +33,16 @@ def __init__(
self,
top_k: int,
global_num_experts: int,
eplb_state: EplbLayerState,
num_fused_shared_experts: int,
eplb_state: EplbLayerState | None = None,
scoring_func: str = "softmax",
renormalize: bool = True,
enable_eplb: bool = False,
indices_type_getter: Callable[[], torch.dtype | None] | None = None,
):
super().__init__(
top_k=top_k,
global_num_experts=global_num_experts,
eplb_state=eplb_state,
enable_eplb=enable_eplb,
indices_type_getter=indices_type_getter,
)
self.renormalize = renormalize
Expand Down
54 changes: 27 additions & 27 deletions vllm/model_executor/layers/fused_moe/router/base_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,7 @@ def __init__(
self,
top_k: int,
global_num_experts: int,
eplb_state: EplbLayerState,
enable_eplb: bool = False,
eplb_state: EplbLayerState | None = None,
# TODO(bnell): Once the MK is constructed at layer init time, we
# can make this a plain value instead of a callback.
indices_type_getter: Callable[[], torch.dtype | None] | None = None,
Expand All @@ -159,12 +158,17 @@ def __init__(
time, so we need to supply a callback to get it at runtime. This is
because the indices type is supplied by modular kernels which are
created after MoE layer/router construction.

Args:
top_k: Number of experts to select per token
global_num_experts: Total number of experts
eplb_state: Optional EPLBLayerState for load balancing
indices_type_getter: Optional callback to get indices dtype
"""
super().__init__()
self.top_k = top_k
self.global_num_experts = global_num_experts
self.eplb_state = eplb_state
self.enable_eplb = enable_eplb
self.indices_type_getter = indices_type_getter
self.capture_fn: Callable[[torch.Tensor], None] | None = None

Expand All @@ -174,21 +178,16 @@ def set_capture_fn(self, capture_fn: Callable[[torch.Tensor], None] | None) -> N

def _validate_eplb_state(self) -> None:
"""Validate that EPLB state is properly initialized if EPLB is enabled."""
if self.enable_eplb:
if self.eplb_state.expert_load_view is None:
raise ValueError("enable_eplb=True requires expert_load_view != None")
if self.eplb_state.logical_to_physical_map is None:
raise ValueError(
"enable_eplb=True requires logical_to_physical_map != None"
)
if self.eplb_state.logical_replica_count is None:
raise ValueError(
"enable_eplb=True requires logical_replica_count != None"
)
if self.eplb_state.should_record_tensor is None:
raise ValueError(
"enable_eplb=True requires should_record_tensor != None"
)
if self.eplb_state is not None:
eplb_state = self.eplb_state
if eplb_state.expert_load_view is None:
raise ValueError("EPLB requires expert_load_view != None")
if eplb_state.logical_to_physical_map is None:
raise ValueError("EPLB requires logical_to_physical_map != None")
if eplb_state.logical_replica_count is None:
raise ValueError("EPLB requires logical_replica_count != None")
if eplb_state.should_record_tensor is None:
raise ValueError("EPLB requires should_record_tensor != None")

def _get_indices_type(self) -> torch.dtype | None:
"""Get the desired indices dtype from the getter function."""
Expand All @@ -198,17 +197,18 @@ def _get_indices_type(self) -> torch.dtype | None:

def _apply_eplb_mapping(self, topk_ids: torch.Tensor) -> torch.Tensor:
"""Apply EPLB mapping to convert logical expert IDs to physical expert IDs."""
if self.enable_eplb:
assert self.eplb_state.expert_load_view is not None
assert self.eplb_state.logical_to_physical_map is not None
assert self.eplb_state.logical_replica_count is not None
assert self.eplb_state.should_record_tensor is not None
if self.eplb_state is not None:
eplb_state = self.eplb_state
assert eplb_state.expert_load_view is not None
assert eplb_state.logical_to_physical_map is not None
assert eplb_state.logical_replica_count is not None
assert eplb_state.should_record_tensor is not None
return eplb_map_to_physical_and_record(
topk_ids=topk_ids,
logical_to_physical_map=self.eplb_state.logical_to_physical_map,
logical_replica_count=self.eplb_state.logical_replica_count,
expert_load_view=self.eplb_state.expert_load_view,
record_enabled=self.eplb_state.should_record_tensor,
logical_to_physical_map=eplb_state.logical_to_physical_map,
logical_replica_count=eplb_state.logical_replica_count,
expert_load_view=eplb_state.expert_load_view,
record_enabled=eplb_state.should_record_tensor,
)
return topk_ids

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,15 @@ def __init__(
self,
top_k: int,
global_num_experts: int,
eplb_state: EplbLayerState,
custom_routing_function: Callable,
eplb_state: EplbLayerState | None = None,
renormalize: bool = True,
enable_eplb: bool = False,
indices_type_getter: Callable[[], torch.dtype | None] | None = None,
):
super().__init__(
top_k=top_k,
global_num_experts=global_num_experts,
eplb_state=eplb_state,
enable_eplb=enable_eplb,
indices_type_getter=indices_type_getter,
)
self.custom_routing_function = custom_routing_function
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -235,11 +235,10 @@ def __init__(
self,
top_k: int,
global_num_experts: int,
eplb_state: EplbLayerState,
e_score_correction_bias: torch.Tensor | None = None,
renormalize: bool = True,
routed_scaling_factor: float = 1.0,
enable_eplb: bool = False,
eplb_state: EplbLayerState | None = None,
indices_type_getter: Callable[[], torch.dtype | None] | None = None,
*,
scoring_func: str = "sigmoid",
Expand All @@ -249,7 +248,6 @@ def __init__(
top_k=top_k,
global_num_experts=global_num_experts,
eplb_state=eplb_state,
enable_eplb=enable_eplb,
indices_type_getter=indices_type_getter,
)
self.e_score_correction_bias = e_score_correction_bias
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,17 +120,15 @@ def __init__(
self,
top_k: int,
global_num_experts: int,
eplb_state: EplbLayerState,
scoring_func: str = "softmax",
renormalize: bool = True,
enable_eplb: bool = False,
eplb_state: EplbLayerState | None = None,
indices_type_getter: Callable[[], torch.dtype | None] | None = None,
):
super().__init__(
top_k=top_k,
global_num_experts=global_num_experts,
eplb_state=eplb_state,
enable_eplb=enable_eplb,
indices_type_getter=indices_type_getter,
)
self.renormalize = renormalize
Expand Down
Loading
Loading