
[MoE Refactor] Split of DefaultMoERunner class#35326

Merged
robertgshaw2-redhat merged 54 commits into vllm-project:main from neuralmagic:moe-runner-3
Apr 6, 2026

Conversation

@bnellnm
Collaborator

@bnellnm bnellnm commented Feb 25, 2026

Purpose

Split DefaultMoERunner into two additional classes:

  • MoERunnerBase, which contains the code common to all runners.
  • ChunkingMoERunner, a runner that handles DP chunking. It inherits from MoERunnerBase and acts as a wrapper around another MoERunner instance so it can be used with any runner type.

DefaultMoERunner inherits from MoERunnerBase and only handles the non-chunked/naive execution path.

In ChunkingMoERunner, allocate the backing buffers for chunking using the workspace class so there won't be one set of buffers per layer.
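
For orientation, here is a minimal sketch of the resulting class relationships; the method names, constructor arguments, and chunk loop are illustrative placeholders rather than the PR's exact API.

import torch

class MoERunnerBase:
    """Code shared by all MoE runners (configuration, shared-experts hooks, ...)."""

    def forward(self, hidden_states: torch.Tensor,
                router_logits: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError


class DefaultMoERunner(MoERunnerBase):
    """Non-chunked/naive path: hand the whole batch to the expert kernels."""

    def forward(self, hidden_states, router_logits):
        # Placeholder for the real expert dispatch.
        return hidden_states


class ChunkingMoERunner(MoERunnerBase):
    """Wraps any other runner and feeds it DP-sized chunks, so it composes with
    every runner type. Chunk buffers come from the shared workspace."""

    def __init__(self, inner: MoERunnerBase, chunk_size: int):
        self.inner = inner
        self.chunk_size = chunk_size

    def forward(self, hidden_states, router_logits):
        if hidden_states.shape[0] == 0:
            return hidden_states
        chunks = []
        for start in range(0, hidden_states.shape[0], self.chunk_size):
            end = start + self.chunk_size
            chunks.append(self.inner.forward(hidden_states[start:end],
                                             router_logits[start:end]))
        return torch.cat(chunks, dim=0)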

Based off #35153

Test Plan

tests/kernels/moe
tests/lora/{test_gptoss_tp.py, test_olmoe.py}
Hand testing of DeepSeek and other models with EP+DP.
MoE refactor CI tests

Test Result

cc @robertgshaw2-redhat , @yzong-rh


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results.
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request is a significant and well-executed refactoring of the MoE runner logic. It splits the monolithic DefaultMoERunner into a MoERunnerBase, a ChunkingMoERunner for DP chunking, and a DefaultMoERunner for the non-chunked path. Additionally, it introduces a SharedExperts class to encapsulate the logic for shared experts, which greatly improves code organization and clarity. The overall changes make the MoE execution path more modular and easier to understand.

I've found one critical issue in the new ChunkingMoERunner related to incorrect index clamping when handling chunks, which could lead to runtime errors. I've provided a suggestion to fix this. Apart from that, the refactoring is excellent.

Comment on lines +187 to +188
chunk_start = min(chunk_start, num_tokens - 1)
chunk_end = min(chunk_end, num_tokens)
Contributor

critical

The clamping logic for chunk_start is incorrect when num_tokens is 0: num_tokens - 1 becomes -1, which leads to incorrect slicing and potential errors. With num_tokens equal to 0, chunk_start becomes -1 and slice_size in _slice_and_copy_input becomes 1, while the sliced orig_slice is empty, causing a shape mismatch error in copy_.

The clamping should be against num_tokens instead of num_tokens - 1 to correctly handle empty chunks.

Suggested change
- chunk_start = min(chunk_start, num_tokens - 1)
- chunk_end = min(chunk_end, num_tokens)
+ chunk_start = min(chunk_start, num_tokens)
+ chunk_end = min(chunk_end, num_tokens)
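
For illustration, a standalone reproduction of the failure mode described above (tensor shapes and variable names are made up; only the clamping arithmetic mirrors the code under review):

import torch

num_tokens = 0
hidden_states = torch.empty(num_tokens, 8)  # empty batch on this DP rank
chunk_buffer = torch.empty(4, 8)            # preallocated chunk workspace

chunk_start = min(0, num_tokens - 1)        # -1 with the old clamping
chunk_end = min(4, num_tokens)              # 0
orig_slice = hidden_states[chunk_start:chunk_end]  # empty, shape (0, 8)
slice_size = chunk_end - chunk_start               # 1, not 0
# chunk_buffer[:slice_size].copy_(orig_slice) raises: shapes (1, 8) vs (0, 8)

chunk_start = min(0, num_tokens)            # 0 with the suggested clamping
slice_size = chunk_end - chunk_start        # 0, so the copy becomes a no-op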

@mergify
Contributor

mergify Bot commented Feb 26, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @bnellnm.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label Feb 26, 2026
@bnellnm bnellnm marked this pull request as ready for review March 2, 2026 00:03
@bnellnm bnellnm changed the title from "[MoE Refactor] Initial split of DefaultMoERunner class" to "[MoE Refactor] Split of DefaultMoERunner class" on Mar 2, 2026
bnellnm added 4 commits March 18, 2026 16:48
Signed-off-by: Bill Nell <bnell@redhat.com>
Signed-off-by: Bill Nell <bnell@redhat.com>
Signed-off-by: Bill Nell <bnell@redhat.com>
Signed-off-by: Bill Nell <bnell@redhat.com>
@mergify
Contributor

mergify Bot commented Mar 18, 2026

Hi @bnellnm, the pre-commit checks have failed. Please run:

uv pip install "pre-commit>=4.5.1"
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

bnellnm added 2 commits March 19, 2026 16:44
Signed-off-by: Bill Nell <bnell@redhat.com>
Signed-off-by: Bill Nell <bnell@redhat.com>
assert shared_experts_input is not None
self._shared_experts.apply(shared_experts_input, order)

def _apply_quant_method(
Collaborator

Note: in a follow-up, we should change this name. It's not really applying the quant method; if anything, it's running the whole layer.

# Needed for string -> FusedMoE layer lookup in custom ops.
self.layer_name = layer_name

self.forward_entry = self._select_forward()
Collaborator

In a follow-up, changing this name to "custom_op_entry" or something like that would make things a lot clearer.

out_slice.copy_(orig_slice, non_blocking=True)
return out_slice

def _forward_impl(
Collaborator

Is it really needed to chunk the shared expert? I actually don't quite remember why this was required in the first place. I think we could simplify things a lot if we only chunked the grouped expert.

Shouldn't be done in this PR, but something to consider for the follow-up.

Collaborator Author

It makes things easier if it is chunked, since the modular kernel (MK) naturally chunks it when it is the one executing it.

logits_shape = (moe.max_num_tokens, self.moe_config.num_logical_experts)

# Does this need some kind of profiling run check like modular_kernel.py?
return current_workspace_manager().get_simultaneous(
Collaborator

Why is this current_workspace_manager introduced?

I didn't see this in the previous implementation.

Collaborator

Previously, it was:

device = torch.accelerator.current_device_index()
self.batched_hidden_states = torch.zeros(
    states_shape,
    dtype=moe.in_dtype,
    device=device,
)

self.batched_router_logits = torch.zeros(
    logits_shape,
    dtype=moe.router_logits_dtype,
    device=device,
)

Collaborator Author

Yeah, the workspace is now shared among all layers. Previously there were separate buffers for each layer.
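
For context, a minimal sketch of the shared-workspace idea (hypothetical class and method bodies, not vLLM's actual WorkspaceManager API): every layer requests views from one grow-only pool per dtype, so the chunking buffers exist once per process rather than once per FusedMoE layer.

import torch


class SharedWorkspace:
    def __init__(self, device: str = "cpu") -> None:
        self.device = device
        self._pools: dict[torch.dtype, torch.Tensor] = {}

    def get_simultaneous(
        self, *requests: tuple[tuple[int, ...], torch.dtype]
    ) -> list[torch.Tensor]:
        """Return one view per (shape, dtype) request; views of the same dtype
        are carved from disjoint regions of a single backing buffer."""
        # Total elements needed per dtype for this call.
        needed: dict[torch.dtype, int] = {}
        for shape, dtype in requests:
            numel = 1
            for dim in shape:
                numel *= dim
            needed[dtype] = needed.get(dtype, 0) + numel

        # Grow (never shrink) the per-dtype pools so they fit the largest request.
        for dtype, numel in needed.items():
            pool = self._pools.get(dtype)
            if pool is None or pool.numel() < numel:
                self._pools[dtype] = torch.empty(numel, dtype=dtype, device=self.device)

        # Hand out non-overlapping views over the shared pools.
        offsets: dict[torch.dtype, int] = {dtype: 0 for dtype in needed}
        views: list[torch.Tensor] = []
        for shape, dtype in requests:
            numel = 1
            for dim in shape:
                numel *= dim
            start = offsets[dtype]
            views.append(self._pools[dtype][start:start + numel].view(shape))
            offsets[dtype] = start + numel
        return views


# All layers share the same buffers instead of owning their own torch.zeros copies.
workspace = SharedWorkspace()
hidden, logits = workspace.get_simultaneous(
    ((256, 4096), torch.bfloat16),  # batched_hidden_states
    ((256, 64), torch.float32),     # batched_router_logits
)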

@mergify
Contributor

mergify Bot commented Apr 6, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @bnellnm.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label Apr 6, 2026
Signed-off-by: Bill Nell <bnell@redhat.com>
@mergify mergify Bot removed the needs-rebase label Apr 6, 2026
@robertgshaw2-redhat robertgshaw2-redhat merged commit 93bada4 into vllm-project:main Apr 6, 2026
179 of 181 checks passed
@github-project-automation github-project-automation Bot moved this to Done in NVIDIA Apr 6, 2026
HenryTangDev pushed a commit to HenryTangMain/vllm that referenced this pull request Apr 6, 2026
Signed-off-by: Bill Nell <bnell@redhat.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
puririshi98 pushed a commit to puririshi98/vllm that referenced this pull request Apr 7, 2026
Signed-off-by: Bill Nell <bnell@redhat.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
Signed-off-by: Rishi Puri <riship@nvidia.com>
mtparet pushed a commit to blackfuel-ai/vllm that referenced this pull request Apr 9, 2026
Signed-off-by: Bill Nell <bnell@redhat.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
iboiko-habana pushed a commit to vllm-project/vllm-gaudi that referenced this pull request Apr 10, 2026
…xtral, MoE and Granite regressions (#1311)

## Summary
This PR fixes a set of regressions introduced by recent upstream changes
and observed in vLLM-Gaudi hourly validation.

The branch now includes:
- Pixtral HPUAttention projection path fix,
- MoE dispatch and method override alignment updates for fused MoE and
compressed tensors,
- unit test updates to match the new MoE runner API usage,
- a fix for hybrid model page size alignment for Granite 4.0-H.

## Related upstream PRs that introduced the regressions
- vllm-project/vllm#37234
- vllm-project/vllm#35153
- vllm-project/vllm#36963
- vllm-project/vllm#38960
- vllm-project/vllm#35326
- vllm-project/vllm#37467

---------

Signed-off-by: Paweł Olejniczak <pawelx.olejniczak@intel.com>
@bnellnm bnellnm deleted the moe-runner-3 branch April 15, 2026 20:48
Comment on lines +508 to +509
if self.gate is not None:
router_logits, _ = self.gate(hidden_states)
Collaborator

Will this cause the gate to be called twice? E.g., qwen3_moe computes the gate in its modeling file: https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/qwen3_moe.py#L235-L237

Collaborator Author

It does look like it might be getting called twice in qwen3_moe.py. If the gate is passed to FusedMoE, the model should check whether or not to run it, e.g. see self.experts.is_internal_router in deepseek_v2.py. This behavior isn't new, so each model needs to decide whether it wants FusedMoE to handle the gate.
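
As an illustration of the guard described above, here is a sketch of how a model's MoE block could avoid the double gate call; the FusedMoE call signature shown is simplified and the None-logits path is an assumption, not the exact upstream contract:

def forward(self, hidden_states):
    if self.experts.is_internal_router:
        # The gate was handed to FusedMoE, which will run it internally,
        # so the model must not compute router logits a second time.
        final_hidden_states = self.experts(hidden_states=hidden_states,
                                           router_logits=None)
    else:
        # Model-side routing: compute logits once and pass them through.
        router_logits, _ = self.gate(hidden_states)
        final_hidden_states = self.experts(hidden_states=hidden_states,
                                           router_logits=router_logits)
    return final_hidden_states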

Collaborator

We should fix this urgently.


Labels

nvidia, ready-run-all-tests (Trigger CI with all tests for wide-ranging PRs)

Projects

Status: Done
