Merged

93 commits
4aeabf2
initial MoERunner refactor
bnellnm Jan 13, 2026
a4d3acb
fix lint
bnellnm Feb 12, 2026
5b7f133
rebase
bnellnm Feb 24, 2026
fad7f33
rebase + remove dead code
bnellnm Mar 5, 2026
83c1863
wip
bnellnm Feb 4, 2026
7c7953e
fix
bnellnm Feb 9, 2026
d5b5805
WIP DOUBLE CHECK THIS
bnellnm Feb 11, 2026
42de827
wip more refactoring
bnellnm Feb 19, 2026
2e4ce00
wip
bnellnm Feb 19, 2026
0144b8b
SharedExperts wip
bnellnm Feb 23, 2026
27ab769
cleanups
bnellnm Feb 23, 2026
bf8b3f3
fix circular import
bnellnm Feb 23, 2026
bb541ae
fixes
bnellnm Feb 24, 2026
0eb70c6
renames
bnellnm Feb 24, 2026
8351d89
add comment
bnellnm Feb 24, 2026
943e667
more renames
bnellnm Feb 24, 2026
0ae2b4f
cleanup
bnellnm Feb 25, 2026
657a4ef
remove memoizing router, not needed yet
bnellnm Feb 26, 2026
6b4ef3b
fix UBD bug
bnellnm Feb 27, 2026
125b34b
cleanup merge
bnellnm Mar 5, 2026
78d5ed8
fix merge
bnellnm Mar 5, 2026
8a5445a
fix merge
bnellnm Mar 5, 2026
5ee510f
fix typos
bnellnm Mar 5, 2026
d675882
fix merge
bnellnm Mar 18, 2026
1dc047a
fix format
bnellnm Mar 18, 2026
111c628
Split of DefaultMoERunner class
bnellnm Jan 13, 2026
ced8fec
fix merge
bnellnm Mar 18, 2026
5c8d40d
fix merge
bnellnm Mar 18, 2026
74eeba9
attempt to fix zero experts
bnellnm Feb 26, 2026
45a48b0
simplify ZeroExpertFusedMoE and add ZeroExpertRouter
bnellnm Feb 27, 2026
65be7f7
add value test
bnellnm Feb 27, 2026
3b88ab2
move ZeroExpertRouter construction into router factory
bnellnm Feb 27, 2026
6cd80fc
move zero expert handling into MoERunnerBase
bnellnm Feb 27, 2026
992672b
slightly improved test
bnellnm Feb 27, 2026
900fc40
simplifications
bnellnm Feb 27, 2026
71f8fb8
better test
bnellnm Feb 27, 2026
c0e12b5
remove ZeroExpertFusedMoE
bnellnm Feb 27, 2026
47205a3
Add comment
bnellnm Mar 2, 2026
3b463bc
fix lint
bnellnm Mar 18, 2026
ec88db3
fix gate overlap
bnellnm Mar 19, 2026
76aff0a
wip
bnellnm Feb 4, 2026
4fab915
fix
bnellnm Feb 9, 2026
d8a7f91
WIP DOUBLE CHECK THIS
bnellnm Feb 11, 2026
3dec78f
wip more refactoring
bnellnm Feb 19, 2026
e94b863
wip
bnellnm Feb 19, 2026
6cc5074
SharedExperts wip
bnellnm Feb 23, 2026
e8865e6
cleanups
bnellnm Feb 23, 2026
f83e0f5
fix circular import
bnellnm Feb 23, 2026
88e80b9
fixes
bnellnm Feb 24, 2026
781d4ea
renames
bnellnm Feb 24, 2026
3695016
add comment
bnellnm Feb 24, 2026
053f66f
more renames
bnellnm Feb 24, 2026
708dd2b
cleanup
bnellnm Feb 25, 2026
5748f7c
remove memoizing router, not needed yet
bnellnm Feb 26, 2026
9123f15
fix UBD bug
bnellnm Feb 27, 2026
04b430f
cleanup merge
bnellnm Mar 5, 2026
526db38
fix merge
bnellnm Mar 5, 2026
67bdab2
fix merge
bnellnm Mar 5, 2026
e9afbe6
fix typos
bnellnm Mar 5, 2026
453ab3d
fix merge
bnellnm Mar 18, 2026
48acc59
fix format
bnellnm Mar 18, 2026
c067844
fix gate overlap
bnellnm Mar 19, 2026
9f0e8d7
merge with main
bnellnm Mar 19, 2026
bc82978
renames, revert lora changes
bnellnm Mar 19, 2026
3dc9d4f
review comments + cleanup
bnellnm Mar 20, 2026
12bda3d
remove _must_reduce_shared_expert_outputs
bnellnm Mar 20, 2026
8aaddea
undo some changes + add Rob's changes
bnellnm Mar 23, 2026
bbaaca7
Merge remote-tracking branch 'origin/main' into moe-runner-2
bnellnm Mar 23, 2026
392f311
hacky fix for unquantized method
bnellnm Mar 23, 2026
7d5adbe
fix lint
bnellnm Mar 23, 2026
f345165
fix lint
bnellnm Feb 12, 2026
bdefdf5
fix merge
bnellnm Mar 25, 2026
377acc8
fix merge
bnellnm Mar 25, 2026
14e58dc
don't pass shared_experts to MK in lora code
bnellnm Mar 25, 2026
392fb60
Merge remote-tracking branch 'origin/main' into moe-runner-3
bnellnm Apr 1, 2026
7b86f43
remove cruft
bnellnm Apr 1, 2026
fd7a324
review comments
bnellnm Apr 1, 2026
ea79ff7
fix lint
bnellnm Apr 1, 2026
b6ed07b
remove EXTERNAL SharedExperts order
bnellnm Apr 1, 2026
dd1f23a
make sure some methods are handled properly on ChunkingMoERunner
bnellnm Apr 1, 2026
acebc42
fixes
bnellnm Apr 2, 2026
b08bf02
Merge branch 'main' into moe-runner-3
bnellnm Apr 2, 2026
73a0356
Merge remote-tracking branch 'origin/main' into moe-runner-3
bnellnm Apr 2, 2026
9934e37
Merge branch 'main' into moe-runner-3
bnellnm Apr 3, 2026
ed0ff6e
remove assert
bnellnm Apr 3, 2026
0d2b4dc
Merge remote-tracking branch 'origin/main' into moe-runner-3
bnellnm Apr 3, 2026
bae2080
Merge remote-tracking branch 'nm-vllm/moe-runner-3' into moe-runner-3
bnellnm Apr 3, 2026
e015682
merge with moe-runner-3
bnellnm Apr 3, 2026
30da43c
remove memoizing router
bnellnm Apr 3, 2026
5785a4e
merge
bnellnm Apr 6, 2026
0405c75
revert bogus changes
bnellnm Apr 6, 2026
c2cc30f
remove test
bnellnm Apr 14, 2026
8d96c38
Merge remote-tracking branch 'origin/main' into moe-runner-4
bnellnm Apr 14, 2026
282 changes: 282 additions & 0 deletions tests/kernels/moe/test_zero_expert_moe.py
@@ -0,0 +1,282 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for FusedMoE with zero experts.

Verifies that:
- The ZeroExpertRouter is properly created and used as the layer router.
- A forward pass through FusedMoE with zero experts produces correct output.
- The output decomposes correctly into real expert + zero expert contributions.

Note: tests generated with Claude.
"""

import pytest
import torch

from vllm.config import VllmConfig, set_current_vllm_config
from vllm.forward_context import get_forward_context, set_forward_context
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
from vllm.model_executor.layers.fused_moe.router.zero_expert_router import (
ZeroExpertRouter,
)
from vllm.v1.worker.workspace import init_workspace_manager


@pytest.fixture
def zero_expert_moe(dist_init, default_vllm_config):
"""Create a FusedMoE layer with zero experts."""
num_experts = 4
top_k = 2
# hidden_size must be >= 256 for the zero expert identity kernel to
# produce output (its BLOCK_SIZE=256 causes grid=0 when hidden_dim<256).
hidden_size = 256
intermediate_size = 512
zero_expert_num = 1

e_score_correction_bias = torch.zeros(
num_experts + zero_expert_num,
dtype=torch.float32,
device="cuda",
)

vllm_config = VllmConfig()
vllm_config.compilation_config.static_forward_context = dict()

with set_current_vllm_config(vllm_config), set_forward_context(None, vllm_config):
init_workspace_manager(torch.accelerator.current_device_index())

layer = FusedMoE(
zero_expert_type="identity",
e_score_correction_bias=e_score_correction_bias,
num_experts=num_experts,
top_k=top_k,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
params_dtype=torch.bfloat16,
prefix="test_zero_expert_moe",
renormalize=False,
routed_scaling_factor=1.0,
scoring_func="softmax",
).cuda()

layer.quant_method.process_weights_after_loading(layer)

yield layer, vllm_config


@pytest.mark.parametrize("num_tokens", [1, 32])
def test_zero_expert_moe_router_is_zero_expert_router(zero_expert_moe, num_tokens):
"""Verify that FusedMoE with zero_expert_type creates a ZeroExpertRouter."""
layer, _ = zero_expert_moe
assert isinstance(layer.router, ZeroExpertRouter), (
f"Expected ZeroExpertRouter but got {type(layer.router).__name__}."
)


@pytest.mark.parametrize("num_tokens", [1, 32])
def test_zero_expert_moe_no_custom_routing_fn(zero_expert_moe, num_tokens):
"""Verify that custom_routing_function is not set (routing is handled
by ZeroExpertRouter, not a memoizing closure)."""
layer, _ = zero_expert_moe
assert layer.custom_routing_function is None


@pytest.mark.parametrize("num_tokens", [1, 32])
def test_zero_expert_moe_forward(zero_expert_moe, num_tokens):
"""Run a forward pass through FusedMoE with zero experts and verify output shape."""
layer, vllm_config = zero_expert_moe

hidden_size = layer.hidden_size
num_experts = 4
zero_expert_num = 1
total_experts = num_experts + zero_expert_num

hidden_states = torch.randn(
num_tokens, hidden_size, dtype=torch.bfloat16, device="cuda"
)
router_logits = torch.randn(
num_tokens, total_experts, dtype=torch.float32, device="cuda"
)

# Initialize weights to small random values to avoid NaN from
# uninitialized memory.
with torch.no_grad():
for param in layer.parameters():
if param.dtype.is_floating_point:
param.normal_(0, 0.01)

with set_current_vllm_config(vllm_config), set_forward_context(None, vllm_config):
get_forward_context().all_moe_layers = None
output = layer.forward(hidden_states, router_logits)

assert output.shape == hidden_states.shape, (
f"Expected output shape {hidden_states.shape}, got {output.shape}"
)
assert output.dtype == hidden_states.dtype
assert not torch.isnan(output).any(), "Output contains NaN values"


@pytest.mark.parametrize("num_tokens", [1, 32])
def test_zero_expert_moe_output_decomposition(zero_expert_moe, num_tokens):
"""Validate that the FusedMoE output equals a plain FusedMoE
output (real experts only) plus the zero expert contribution.

The key invariant is:
zero_layer.forward(h, r_full) == plain_layer.forward(h, r_real)
+ zero_expert_output

We create a plain FusedMoE layer with the same weights and real-expert-only
router logits, compute the zero expert output via the ZeroExpertRouter, and
verify the sum matches the FusedMoE output.
"""
layer, vllm_config = zero_expert_moe
num_experts = 4
zero_expert_num = 1
total_experts = num_experts + zero_expert_num

hidden_states = torch.randn(
num_tokens, layer.hidden_size, dtype=torch.bfloat16, device="cuda"
)
router_logits = torch.randn(
num_tokens, total_experts, dtype=torch.float32, device="cuda"
)

with torch.no_grad():
for param in layer.parameters():
if param.dtype.is_floating_point:
param.normal_(0, 0.01)

with set_current_vllm_config(vllm_config), set_forward_context(None, vllm_config):
get_forward_context().all_moe_layers = None

# Create a plain FusedMoE layer with the same config but no zero
# experts. Use a separate prefix to avoid collision.
plain_layer = FusedMoE(
num_experts=num_experts,
top_k=layer.top_k,
hidden_size=layer.hidden_size,
intermediate_size=layer.intermediate_size_per_partition,
params_dtype=torch.bfloat16,
prefix="test_zero_expert_moe_plain",
renormalize=False,
scoring_func="softmax",
e_score_correction_bias=layer.e_score_correction_bias,
).cuda()

# Share weights from the zero expert layer.
plain_layer.w13_weight.data.copy_(layer.w13_weight.data)
plain_layer.w2_weight.data.copy_(layer.w2_weight.data)
plain_layer.quant_method.process_weights_after_loading(plain_layer)

# Compute routing via the ZeroExpertRouter. This produces masked
# topk_weights/topk_ids (zero expert entries have weight=0, id=0)
# and stores zero_expert_output as a side effect.
topk_weights, topk_ids = layer.router.select_experts(
hidden_states, router_logits
)
zero_output = layer.router.zero_expert_output

# Compute real expert output using the plain layer with the masked
# routing from the ZeroExpertRouter.
real_output = plain_layer.quant_method.apply(
layer=plain_layer,
x=hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
shared_experts_input=None,
)

# Get the combined output from the zero expert layer.
full_output = layer.forward(hidden_states, router_logits)

assert zero_output is not None, "Zero expert output should not be None"
assert not torch.isnan(real_output).any(), "Real expert output has NaN"
assert not torch.isnan(zero_output).any(), "Zero expert output has NaN"
assert not torch.isnan(full_output).any(), "Full output has NaN"

expected = real_output + zero_output
torch.testing.assert_close(
full_output,
expected,
atol=0,
rtol=0,
msg="FusedMoE output should equal plain FusedMoE output "
"plus zero expert contribution",
)


@pytest.mark.parametrize("num_tokens", [1, 32])
def test_zero_expert_moe_zero_expert_is_identity(zero_expert_moe, num_tokens):
"""Validate zero expert identity behavior.

When routing strongly favors the zero expert, its contribution should
be a scaled version of hidden_states (identity operation). We verify
this by manually computing the expected zero expert output from the
routing weights and comparing against what the router produces.
"""
layer, vllm_config = zero_expert_moe
num_experts = 4
zero_expert_num = 1
total_experts = num_experts + zero_expert_num

hidden_states = torch.randn(
num_tokens, layer.hidden_size, dtype=torch.bfloat16, device="cuda"
)
# Strongly bias toward the zero expert (index 4).
router_logits = torch.full(
(num_tokens, total_experts), -10.0, dtype=torch.float32, device="cuda"
)
router_logits[:, num_experts] = 10.0 # zero expert gets high logit

with torch.no_grad():
for param in layer.parameters():
if param.dtype.is_floating_point:
param.normal_(0, 0.01)

with set_current_vllm_config(vllm_config), set_forward_context(None, vllm_config):
get_forward_context().all_moe_layers = None

# Run routing to get topk_weights/topk_ids before masking.
from vllm.model_executor.layers.fused_moe.router.fused_topk_bias_router import (
fused_topk_bias,
)

topk_weights, topk_ids = fused_topk_bias(
hidden_states=hidden_states,
gating_output=router_logits,
e_score_correction_bias=layer.router.e_score_correction_bias.data,
topk=layer.top_k,
renormalize=layer.router.renormalize,
scoring_func=layer.router.scoring_func,
)

# Manually compute expected zero expert identity output:
# For each token, sum routing weights assigned to zero expert slots,
# then multiply by hidden_states.
zero_mask = topk_ids >= num_experts
zero_weight_per_token = (topk_weights * zero_mask.float()).sum(
dim=-1, keepdim=True
)
expected_zero_output = (hidden_states.float() * zero_weight_per_token).to(
hidden_states.dtype
)

# Run routing directly to trigger zero expert computation
# without going through the runner (which consumes the output).
layer.router.select_experts(hidden_states, router_logits)
actual_zero_output = layer.router.zero_expert_output

assert actual_zero_output is not None
assert zero_mask.any(), (
"With high zero expert logit, at least some slots should route "
"to the zero expert"
)

torch.testing.assert_close(
actual_zero_output,
expected_zero_output,
atol=1e-3,
rtol=1e-3,
msg="Zero expert identity output should equal "
"hidden_states * sum(zero_expert_weights)",
)
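The invariant exercised by `test_zero_expert_moe_output_decomposition` and `test_zero_expert_moe_zero_expert_is_identity` can be stated as a standalone reference. The following is a minimal sketch (the function name and signature are illustrative, not vLLM code) of the expected "identity" zero-expert contribution: hidden states scaled by the total routing weight that landed on zero-expert slots, which the full layer output then adds to the real-expert output.

import torch

def identity_zero_expert_reference(
    hidden_states: torch.Tensor,  # [num_tokens, hidden_size]
    topk_weights: torch.Tensor,   # [num_tokens, top_k], pre-masking
    topk_ids: torch.Tensor,       # [num_tokens, top_k]
    num_real_experts: int,
) -> torch.Tensor:
    """Reference zero-expert contribution for zero_expert_type="identity"."""
    # Slots routed to ids beyond the real experts belong to the zero expert.
    zero_mask = topk_ids >= num_real_experts
    # Per-token total routing weight assigned to zero-expert slots.
    zero_weight = (topk_weights * zero_mask.float()).sum(dim=-1, keepdim=True)
    # Identity expert: pass hidden_states through, scaled by that weight.
    return (hidden_states.float() * zero_weight).to(hidden_states.dtype)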
4 changes: 0 additions & 4 deletions vllm/model_executor/layers/fused_moe/__init__.py
@@ -33,9 +33,6 @@
from vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method import (
UnquantizedFusedMoEMethod,
)
from vllm.model_executor.layers.fused_moe.zero_expert_fused_moe import (
ZeroExpertFusedMoE,
)
from vllm.triton_utils import HAS_TRITON

_config: dict[str, Any] | None = None
@@ -68,7 +65,6 @@ def get_config() -> dict[str, Any] | None:
"GateLinear",
"RoutingMethodType",
"SharedFusedMoE",
"ZeroExpertFusedMoE",
"activation_without_mul",
"apply_moe_activation",
"override_config",
3 changes: 3 additions & 0 deletions vllm/model_executor/layers/fused_moe/layer.py
@@ -274,6 +274,7 @@ def __init__(
gate: torch.nn.Module | None = None,
shared_experts: torch.nn.Module | None = None,
routed_input_transform: torch.nn.Module | None = None,
zero_expert_type: str | None = None,
):
super().__init__()

@@ -462,6 +463,8 @@ def __init__(
# TODO(bnell): once we can construct the MK at init time, we
# can make this a value.
indices_type_getter=lambda: self.quant_method.topk_indices_dtype,
zero_expert_type=zero_expert_type,
num_logical_experts=self.logical_num_experts,
)
self.routing_method_type: RoutingMethodType = self.router.routing_method_type

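For context, a hedged usage sketch of the new constructor parameter (argument values are illustrative and mirror the test fixture above; it requires the same vLLM config and forward-context setup as the fixture): FusedMoE forwards zero_expert_type, together with its logical expert count, to the router factory, which then builds a ZeroExpertRouter.

import torch
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
from vllm.model_executor.layers.fused_moe.router.zero_expert_router import (
    ZeroExpertRouter,
)

# Bias covers real experts plus zero experts (4 + 1 here).
bias = torch.zeros(5, dtype=torch.float32, device="cuda")
moe = FusedMoE(
    num_experts=4,
    top_k=2,
    hidden_size=256,
    intermediate_size=512,
    params_dtype=torch.bfloat16,
    prefix="sketch",
    renormalize=False,
    scoring_func="softmax",
    zero_expert_type="identity",   # parameter added by this diff
    e_score_correction_bias=bias,  # required when zero_expert_type is set
).cuda()
assert isinstance(moe.router, ZeroExpertRouter)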
42 changes: 38 additions & 4 deletions vllm/model_executor/layers/fused_moe/router/router_factory.py
@@ -25,6 +25,9 @@
from vllm.model_executor.layers.fused_moe.router.routing_simulator_router import (
RoutingSimulatorRouter,
)
from vllm.model_executor.layers.fused_moe.router.zero_expert_router import (
ZeroExpertRouter,
)

EMPTY_EPLB_STATE: EplbLayerState = EplbLayerState()

@@ -49,17 +52,21 @@ def create_fused_moe_router(
# eplb parameters
enable_eplb: bool = False,
eplb_state: EplbLayerState = EMPTY_EPLB_STATE,
# zero expert parameters
zero_expert_type: str | None = None,
num_logical_experts: int | None = None,
) -> FusedMoERouter:
"""
Factory function to create the appropriate FusedMoERouter subclass based on
the provided parameters.

The selection logic follows this priority order:
1. RoutingSimulatorRouter - if VLLM_MOE_ROUTING_SIMULATION_STRATEGY env var is set
2. GroupedTopKRouter - if use_grouped_topk is True
3. CustomRoutingRouter - if custom_routing_function is not None
4. FusedTopKBiasRouter - if e_score_correction_bias is not None
5. FusedTopKRouter - default fallback
2. ZeroExpertRouter - if zero_expert_type is not None
3. GroupedTopKRouter - if use_grouped_topk is True
4. CustomRoutingRouter - if custom_routing_function is not None
5. FusedTopKBiasRouter - if e_score_correction_bias is not None
6. FusedTopKRouter - default fallback

Common arguments:
top_k: Number of experts to select per token
@@ -86,6 +93,12 @@
enable_eplb: Whether EPLB is enabled
eplb_state: EPLB (Expert Parallelism Load Balancing) state

Zero expert arguments:
zero_expert_type: Type of zero expert (e.g. identity). If not None,
creates a ZeroExpertRouter.
num_logical_experts: Number of real (non-zero) experts. Required when
zero_expert_type is not None.

Returns:
An instance of the appropriate FusedMoERouter subclass
"""
@@ -100,6 +113,27 @@
indices_type_getter=indices_type_getter,
)

if zero_expert_type is not None:
assert num_logical_experts is not None, (
"num_logical_experts is required when zero_expert_type is set"
)
assert e_score_correction_bias is not None, (
"e_score_correction_bias is required when zero_expert_type is set"
)
return ZeroExpertRouter(
top_k=top_k,
global_num_experts=global_num_experts,
eplb_state=eplb_state,
e_score_correction_bias=e_score_correction_bias,
num_logical_experts=num_logical_experts,
zero_expert_type=zero_expert_type,
scoring_func=scoring_func,
renormalize=renormalize,
routed_scaling_factor=routed_scaling_factor,
enable_eplb=enable_eplb,
indices_type_getter=indices_type_getter,
)

if use_grouped_topk:
assert custom_routing_function is None
if num_expert_group is None or topk_group is None:
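A condensed sketch of the selection order documented in create_fused_moe_router's docstring (an illustration only: the real factory constructs and returns router instances rather than names, and takes many more arguments):

def _router_kind(
    routing_simulation_enabled: bool,
    zero_expert_type: str | None,
    use_grouped_topk: bool,
    custom_routing_function,
    e_score_correction_bias,
) -> str:
    # Mirrors the documented priority order, highest first.
    if routing_simulation_enabled:    # VLLM_MOE_ROUTING_SIMULATION_STRATEGY set
        return "RoutingSimulatorRouter"
    if zero_expert_type is not None:  # branch added by this diff
        return "ZeroExpertRouter"
    if use_grouped_topk:
        return "GroupedTopKRouter"
    if custom_routing_function is not None:
        return "CustomRoutingRouter"
    if e_score_correction_bias is not None:
        return "FusedTopKBiasRouter"
    return "FusedTopKRouter"

Note that ZeroExpertRouter takes precedence over grouped top-k and custom routing functions, so a layer with zero experts always routes through the zero-expert masking path.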