Merged

54 commits:
- 4aeabf2 initial MoERunner refactor (bnellnm, Jan 13, 2026)
- a4d3acb fix lint (bnellnm, Feb 12, 2026)
- 5b7f133 rebase (bnellnm, Feb 24, 2026)
- fad7f33 rebase + remove dead code (bnellnm, Mar 5, 2026)
- ec88db3 fix gate overlap (bnellnm, Mar 19, 2026)
- 76aff0a wip (bnellnm, Feb 4, 2026)
- 4fab915 fix (bnellnm, Feb 9, 2026)
- d8a7f91 WIP DOUBLE CHECK THIS (bnellnm, Feb 11, 2026)
- 3dec78f wip more refactoring (bnellnm, Feb 19, 2026)
- e94b863 wip (bnellnm, Feb 19, 2026)
- 6cc5074 SharedExperts wip (bnellnm, Feb 23, 2026)
- e8865e6 cleanups (bnellnm, Feb 23, 2026)
- f83e0f5 fix circular import (bnellnm, Feb 23, 2026)
- 88e80b9 fixes (bnellnm, Feb 24, 2026)
- 781d4ea renames (bnellnm, Feb 24, 2026)
- 3695016 add comment (bnellnm, Feb 24, 2026)
- 053f66f more renames (bnellnm, Feb 24, 2026)
- 708dd2b cleanup (bnellnm, Feb 25, 2026)
- 5748f7c remove memoizing router, not needed yet (bnellnm, Feb 26, 2026)
- 9123f15 fix UBD bug (bnellnm, Feb 27, 2026)
- 04b430f cleanup merge (bnellnm, Mar 5, 2026)
- 526db38 fix merge (bnellnm, Mar 5, 2026)
- 67bdab2 fix merge (bnellnm, Mar 5, 2026)
- e9afbe6 fix typos (bnellnm, Mar 5, 2026)
- 453ab3d fix merge (bnellnm, Mar 18, 2026)
- 48acc59 fix format (bnellnm, Mar 18, 2026)
- c067844 fix gate overlap (bnellnm, Mar 19, 2026)
- 9f0e8d7 merge with main (bnellnm, Mar 19, 2026)
- bc82978 renames, revert lora changes (bnellnm, Mar 19, 2026)
- 3dc9d4f review comments + cleanup (bnellnm, Mar 20, 2026)
- 12bda3d remove _must_reduce_shared_expert_outputs (bnellnm, Mar 20, 2026)
- 8aaddea undo some changes + add Rob's changes (bnellnm, Mar 23, 2026)
- bbaaca7 Merge remote-tracking branch 'origin/main' into moe-runner-2 (bnellnm, Mar 23, 2026)
- 392f311 hacky fix for unquantized method (bnellnm, Mar 23, 2026)
- 7d5adbe fix lint (bnellnm, Mar 23, 2026)
- f345165 fix lint (bnellnm, Feb 12, 2026)
- bdefdf5 fix merge (bnellnm, Mar 25, 2026)
- 377acc8 fix merge (bnellnm, Mar 25, 2026)
- 14e58dc don't pass shared_experts to MK in lora code (bnellnm, Mar 25, 2026)
- 392fb60 Merge remote-tracking branch 'origin/main' into moe-runner-3 (bnellnm, Apr 1, 2026)
- 7b86f43 remove cruft (bnellnm, Apr 1, 2026)
- fd7a324 review comments (bnellnm, Apr 1, 2026)
- ea79ff7 fix lint (bnellnm, Apr 1, 2026)
- b6ed07b remove EXTERNAL SharedExperts order (bnellnm, Apr 1, 2026)
- dd1f23a make sure some methods are handled properly on ChunkingMoERunner (bnellnm, Apr 1, 2026)
- acebc42 fixes (bnellnm, Apr 2, 2026)
- b08bf02 Merge branch 'main' into moe-runner-3 (bnellnm, Apr 2, 2026)
- 73a0356 Merge remote-tracking branch 'origin/main' into moe-runner-3 (bnellnm, Apr 2, 2026)
- 9934e37 Merge branch 'main' into moe-runner-3 (bnellnm, Apr 3, 2026)
- ed0ff6e remove assert (bnellnm, Apr 3, 2026)
- 0d2b4dc Merge remote-tracking branch 'origin/main' into moe-runner-3 (bnellnm, Apr 3, 2026)
- bae2080 Merge remote-tracking branch 'nm-vllm/moe-runner-3' into moe-runner-3 (bnellnm, Apr 3, 2026)
- 9cfa719 Merge branch 'main' into moe-runner-3 (robertgshaw2-redhat, Apr 4, 2026)
- 4af2087 Merge remote-tracking branch 'origin/main' into moe-runner-3 (bnellnm, Apr 6, 2026)
8 changes: 4 additions & 4 deletions vllm/model_executor/layers/fused_moe/layer.py
@@ -39,8 +39,8 @@
 from vllm.model_executor.layers.fused_moe.router.router_factory import (
     create_fused_moe_router,
 )
-from vllm.model_executor.layers.fused_moe.runner.default_moe_runner import (
-    DefaultMoERunner,
+from vllm.model_executor.layers.fused_moe.runner.moe_runner_factory import (
+    create_moe_runner,
 )
 from vllm.model_executor.layers.fused_moe.runner.shared_experts import (
     SharedExperts,
@@ -572,8 +572,8 @@ def _get_quant_method() -> FusedMoEMethodBase:
         # Storing the runner in the FusedMoE is an intermediate state, eventually
         # the runner will own the FusedMoE layer and provide the execution interface
         # for MoE ops.
-        self.runner = DefaultMoERunner(
-            layer=self,
+        self.runner = create_moe_runner(
+            layer_name=self.layer_name,
             moe_config=self.moe_config,
             router=self.router,
             routed_input_transform=self._routed_input_transform,
243 changes: 243 additions & 0 deletions vllm/model_executor/layers/fused_moe/runner/chunking_moe_runner.py
@@ -0,0 +1,243 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch

from vllm.forward_context import (
    get_forward_context,
)
from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
    FusedMoEMethodBase,
)
from vllm.model_executor.layers.fused_moe.runner.moe_runner_base import MoERunnerBase
from vllm.model_executor.layers.fused_moe.runner.shared_experts import (
    SharedExperts,
)
from vllm.utils.math_utils import cdiv
from vllm.v1.worker.ubatching import dbo_current_ubatch_id
from vllm.v1.worker.workspace import current_workspace_manager


class ChunkingMoERunner(MoERunnerBase):
[Review thread on the class declaration]
Collaborator: just noting that it's a bit odd to have both the inner and outer being MoERunnerBase.
robertgshaw2-redhat (Collaborator), Apr 1, 2026: Should this actually be a mixin ... just an idea.
bnellnm (Author): I agree this is a bit awkward. I could probably tease apart MoERunnerBase so that this could inherit from something a little more abstract. I'm not sure I follow how a mixin would work here.
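For readers following this thread, here is a minimal, runnable sketch of the wrapper-delegation pattern being discussed. Inner and Wrapper are illustrative names, not the PR's classes; the point is that __getattr__ only fires when normal lookup fails, so the wrapper's own members shadow the inner object's.

class Inner:
    def __init__(self):
        self.moe_config = "config"

    def run(self):
        return "inner result"

class Wrapper:
    def __init__(self, inner):
        self._inner = inner

    def __getattr__(self, name):
        # Reached only when Wrapper itself has no such attribute.
        return getattr(self._inner, name)

w = Wrapper(Inner())
assert w.moe_config == "config"   # delegated to Inner
assert w.run() == "inner result"  # delegated to Inner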

"""
MoE runner wrapper that adds chunked processing to any MoERunnerBase.

This runner wraps an inner MoERunnerBase and overrides _forward_impl to
process large batches by breaking them into smaller chunks. Each chunk
is delegated to the inner runner's _forward_impl, making chunking
composable with any runner implementation.

All MoERunnerBase state (moe_config, router, quant_method, etc.) is
transparently delegated to the inner runner via __getattr__.
ChunkingMoERunner only owns chunking-specific state: the pre-allocated
workspace buffers and the reduce_results override.

Key behaviors:
- Pre-allocates workspace tensors for CUDA graph compatibility
- Processes chunks via inner._forward_impl per chunk
- Never reduces results (reduce_results always returns False)
"""

    def __init__(self, inner: MoERunnerBase):
        # Assert that _maybe_dispatch/_maybe_combine will be nops.
        assert inner.moe_config.pcp_size == 1

        # Skip MoERunnerBase.__init__ — all state is delegated to inner
        # via __getattr__. Only chunking-specific state lives here.
        self._inner = inner

        # Pre-allocated staging buffers. These need to exist ahead of time
        # due to CUDA graph construction needing fixed buffer addresses.
        self.batched_hidden_states, self.batched_router_logits = (
            self._init_dp_chunking()
        )

    def __getattr__(self, name):
        # Delegate attribute access to the inner runner. This is only
        # called when normal lookup (instance __dict__, class MRO) fails,
        # so ChunkingMoERunner's own attributes and methods take priority.
        return getattr(self._inner, name)

    @property
    def shared_experts(self) -> SharedExperts | None:
        return self._inner.shared_experts

    # TODO(bnell): temporary hack, do not call this method.
    def _replace_quant_method(self, quant_method: FusedMoEMethodBase):
        self._inner._replace_quant_method(quant_method)
        self.quant_method = quant_method

    def is_internal_router(self) -> bool:
        return self._inner.gate is not None

    # When DP chunking is enabled, reducing the results is handled by the
    # MK finalize operations, so this runner never reduces.
    # This will be removed by #35949
    @property
    def reduce_results(self) -> bool:
        return False

    def _init_dp_chunking(self) -> list[torch.Tensor]:
        states_shape: tuple[int, ...]
        logits_shape: tuple[int, ...]

        moe = self.moe_config

        if self.enable_dbo:
            states_shape = (2, moe.max_num_tokens, self.moe_config.hidden_dim)
            logits_shape = (2, moe.max_num_tokens, self.moe_config.num_logical_experts)
        else:
            states_shape = (moe.max_num_tokens, self.moe_config.hidden_dim)
            logits_shape = (moe.max_num_tokens, self.moe_config.num_logical_experts)

        # Does this need some kind of profiling run check like modular_kernel.py?
        return current_workspace_manager().get_simultaneous(
[Review thread on the current_workspace_manager().get_simultaneous(...) call]
Collaborator: Why is this current_workspace_manager introduced? I didn't see this in the previous implementation.
Collaborator: Previously, it was:

device = torch.accelerator.current_device_index()
self.batched_hidden_states = torch.zeros(
    states_shape,
    dtype=moe.in_dtype,
    device=device,
)

self.batched_router_logits = torch.zeros(
    logits_shape,
    dtype=moe.router_logits_dtype,
    device=device,
)

bnellnm (Author): Yeah, the workspace is now shared among all layers. Previously there were separate buffers for each layer.
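To illustrate the author's point, a toy sketch of the shared-workspace idea. ToyWorkspaceManager and its dict-based caching are assumptions for illustration only, not vLLM's actual workspace API; the property being demonstrated is that repeated requests for the same (shape, dtype) spec hand back the same storage, so layers share one set of staging buffers instead of each allocating its own.

import torch

class ToyWorkspaceManager:
    def __init__(self):
        self._buffers = {}

    def get_simultaneous(self, *specs):
        out = []
        for shape, dtype in specs:
            key = (tuple(shape), dtype)
            if key not in self._buffers:
                # First requester allocates; later requesters reuse.
                self._buffers[key] = torch.zeros(shape, dtype=dtype)
            out.append(self._buffers[key])
        return out

ws = ToyWorkspaceManager()
layer0_states, _ = ws.get_simultaneous(((8, 16), torch.float16), ((8, 4), torch.float32))
layer1_states, _ = ws.get_simultaneous(((8, 16), torch.float16), ((8, 4), torch.float32))
assert layer0_states.data_ptr() == layer1_states.data_ptr()  # shared storage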

            (states_shape, moe.in_dtype),
            (logits_shape, moe.router_logits_dtype),
        )

    def _allocate_dp_chunking_outputs(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        shared_experts_input: torch.Tensor | None,
    ) -> tuple[torch.Tensor | None, torch.Tensor]:
        # Assert the inputs are of the proper type and shape.
        assert self.batched_hidden_states is not None
        assert self.batched_router_logits is not None

        assert self.batched_hidden_states.dtype == hidden_states.dtype, (
            f"{self.batched_hidden_states.dtype} == {hidden_states.dtype}"
        )
        assert self.batched_router_logits.dtype == router_logits.dtype, (
            f"{self.batched_router_logits.dtype} == {router_logits.dtype}"
        )

        # Check size compatibility.
        assert self.batched_hidden_states.size(-1) == hidden_states.size(-1)
        assert self.batched_router_logits.size(-1) == router_logits.size(-1)

        final_fused_hidden_states = torch.empty_like(hidden_states)
        if self.shared_experts is not None:
            if shared_experts_input is not None:
                final_shared_hidden_states = torch.empty_like(shared_experts_input)
            else:
                final_shared_hidden_states = torch.empty_like(hidden_states)
        else:
            final_shared_hidden_states = None

        return final_shared_hidden_states, final_fused_hidden_states

    def _slice_and_copy_input(
        self,
        out_slice: torch.Tensor,
        orig: torch.Tensor | None,
        start: int,
        end: int,
    ) -> torch.Tensor:
        assert orig is not None
        slice_size = end - start
        orig_slice = orig[start:end, :]
        if self.enable_dbo:
            assert out_slice.dim() == 3
            batch_buffer_idx = dbo_current_ubatch_id()
            out_slice = out_slice[batch_buffer_idx, :]

        assert out_slice.size(0) >= slice_size
        out_slice = out_slice[:slice_size, :]
        out_slice.copy_(orig_slice, non_blocking=True)
        return out_slice

[Review thread on _forward_impl]
Collaborator: Is it really needed to chunk the shared expert? I actually don't quite remember why this was required in the first place. I think that we could simplify things a lot if we only chunked the grouped expert. Shouldn't be done in this PR, but something to consider for the follow-up.
bnellnm (Author): It makes things easier if it is chunked since the MK naturally chunks it if it is the one executing it.

    def _forward_impl(
        self,
        layer: torch.nn.Module,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        shared_experts_input: torch.Tensor | None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        final_shared_hidden_states, final_fused_hidden_states = (
            self._allocate_dp_chunking_outputs(
                hidden_states, router_logits, shared_experts_input
            )
        )

        ctx = get_forward_context()
        # flashinfer_cutlass_kernels can handle: optional DP + TP/EP
        max_tokens_across_dispatchers = ctx.dp_metadata.max_tokens_across_dp_cpu
        moe_dp_chunk_size_per_rank = self.moe_config.max_num_tokens

        # If the input to the MoE is sequence parallel then divide by sp_size
        # to find the maximum number of tokens for any individual dispatcher.
        if self.moe_config.is_sequence_parallel:
            max_tokens_across_dispatchers = cdiv(
                max_tokens_across_dispatchers, self.moe_config.sp_size
            )
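            # Illustrative arithmetic (not in the original source): with
            # max_tokens_across_dispatchers == 4100 and sp_size == 4,
            # cdiv(4100, 4) == 1025 tokens for the busiest dispatcher.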

        num_tokens = hidden_states.size(0)
        for chunk_idx, chunk_start_ in enumerate(
            range(0, max_tokens_across_dispatchers, moe_dp_chunk_size_per_rank)
        ):
            chunk_start = chunk_start_
            chunk_end = min(
                chunk_start + moe_dp_chunk_size_per_rank, max_tokens_across_dispatchers
            )
            # clamp start and end
            chunk_start = min(chunk_start, num_tokens - 1)
            chunk_end = min(chunk_end, num_tokens)
[Review thread on the clamping lines (+185 to +186)]
Contributor (critical): The clamping logic for chunk_start is incorrect when num_tokens is 0. num_tokens - 1 becomes -1, which can lead to incorrect slicing and potential errors. For example, if num_tokens is 0, chunk_start becomes -1, and slice_size in _slice_and_copy_input becomes 1, while the sliced orig_slice is empty, causing a shape mismatch error in copy_. The clamping should be against num_tokens instead of num_tokens - 1 to correctly handle empty chunks.

Suggested change:
-            chunk_start = min(chunk_start, num_tokens - 1)
+            chunk_start = min(chunk_start, num_tokens)
             chunk_end = min(chunk_end, num_tokens)
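A standalone repro of the reviewer's scenario (illustrative: the 1024-row staging buffer and hidden size 8 are made-up values):

import torch

num_tokens = 0                         # empty local batch on this rank
chunk_start = min(0, num_tokens - 1)   # -1 under the original clamp
chunk_end = min(1024, num_tokens)      # 0

orig = torch.empty(num_tokens, 8)
buf = torch.empty(1024, 8)

slice_size = chunk_end - chunk_start   # 0 - (-1) == 1
src = orig[chunk_start:chunk_end, :]   # shape (0, 8): empty source
dst = buf[:slice_size, :]              # shape (1, 8): one-row destination
# dst.copy_(src)                       # raises RuntimeError: shape mismatch

# With the suggested clamp, the chunk is genuinely empty and the copy is a no-op:
chunk_start = min(0, num_tokens)       # 0
slice_size = chunk_end - chunk_start   # 0
buf[:slice_size, :].copy_(orig[chunk_start:chunk_end, :])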

            chunk_sizes = ctx.dp_metadata.chunked_sizes(
                self.moe_config.sp_size, moe_dp_chunk_size_per_rank, chunk_idx
            )
            with chunk_sizes:
                hidden_states_chunk = self._slice_and_copy_input(
                    self.batched_hidden_states,
                    hidden_states,
                    chunk_start,
                    chunk_end,
                )

                router_logits_chunk = self._slice_and_copy_input(
                    self.batched_router_logits,
                    router_logits,
                    chunk_start,
                    chunk_end,
                )

                shared_experts_input_chunk = (
                    shared_experts_input[chunk_start:chunk_end, :]
                    if shared_experts_input is not None
                    else None
                )

                # Delegate per-chunk computation to the inner runner.
                chunk_result = self._inner._forward_impl(
                    layer=layer,
                    hidden_states=hidden_states_chunk,
                    router_logits=router_logits_chunk,
                    shared_experts_input=shared_experts_input_chunk,
                )

                # Store outputs
                # TODO(bnell): document when chunk_start >= num_tokens
                if chunk_start < num_tokens:
                    if self.shared_experts is not None:
                        assert isinstance(chunk_result, tuple)
                        shared_output_chunk, hidden_states_chunk = chunk_result
                        final_fused_hidden_states[chunk_start:chunk_end, :].copy_(
                            hidden_states_chunk, non_blocking=True
                        )
                        assert shared_output_chunk is not None
                        assert final_shared_hidden_states is not None
                        final_shared_hidden_states[chunk_start:chunk_end, :].copy_(
                            shared_output_chunk, non_blocking=True
                        )
                    else:
                        assert isinstance(chunk_result, torch.Tensor)
                        final_fused_hidden_states[chunk_start:chunk_end, :].copy_(
                            chunk_result, non_blocking=True
                        )

        if self.shared_experts is None:
            return final_fused_hidden_states
        else:
            assert final_shared_hidden_states is not None
            return (final_shared_hidden_states, final_fused_hidden_states)
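Tying the two files together: a hedged sketch of how the layer.py change at the top of this diff can compose with ChunkingMoERunner. The body of create_moe_runner and the use_dp_chunking flag are assumptions for illustration, not the PR's actual factory code; the structural point is that chunking is added by wrapping, so it composes with any inner runner.

def create_moe_runner(layer_name, moe_config, **kwargs):
    # Build the ordinary runner first (argument names assumed here).
    inner = DefaultMoERunner(layer_name=layer_name, moe_config=moe_config, **kwargs)
    # Wrap it only when DP chunking applies; 'use_dp_chunking' is hypothetical.
    if getattr(moe_config, "use_dp_chunking", False):
        return ChunkingMoERunner(inner)
    return inner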