
[Feature]: Implement naive prepare/finalize class to replace naive dispatching in fused_moe/layer.py#30775

Closed
teddygood wants to merge 14 commits into vllm-project:main from teddygood:naive-moe-refactor-clean

Conversation

@teddygood

@teddygood teddygood commented Dec 16, 2025

Purpose

  • I’m reopening this PR because the commit history in the previous PR (#28570) was corrupted by an incorrect rebase.
  • Resolve #28236 by moving the legacy EP+DP "naive" dispatch/combine path into a dedicated FusedMoENaivePrepareAndFinalize subclass and wiring it through the modular kernel hooks.
  • Extend FusedMoEPrepareAndFinalize with lightweight pre/post hooks so all prepare/finalize implementations can control routing/combine, eliminating the special-case logic from fused_moe/layer.py.
  • Add unit coverage that proves the new subclass dispatches/combines correctly for plain, shared-expert, and zero-expert outputs.
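The prepare/finalize split described above can be sketched as follows. This is an illustrative sketch only: the class and method names follow the PR's naming (`FusedMoEPrepareAndFinalize`, `FusedMoENaivePrepareAndFinalize`, `prepare`/`finalize`), but the bodies simulate the naive EP+DP path with plain Python lists instead of `torch.distributed` collectives.

```python
# Sketch only: names follow the PR, bodies are a plain-Python stand-in for
# the real torch.distributed dispatch/combine collectives.

class FusedMoEPrepareAndFinalize:
    """Base interface: prepare() dispatches tokens, finalize() combines."""

    def prepare(self, per_rank_tokens):
        raise NotImplementedError

    def finalize(self, per_rank_partial_outputs):
        raise NotImplementedError


class FusedMoENaivePrepareAndFinalize(FusedMoEPrepareAndFinalize):
    """Naive path: every rank sees every token (all-gather-style dispatch),
    then the per-rank partial expert outputs are summed (combine)."""

    def prepare(self, per_rank_tokens):
        # "Dispatch": concatenate every DP rank's tokens so each rank can
        # run its local experts over the full batch.
        return [tok for rank_tokens in per_rank_tokens for tok in rank_tokens]

    def finalize(self, per_rank_partial_outputs):
        # "Combine": sum partial outputs elementwise across ranks; a token
        # routed to experts on another rank contributed zeros locally.
        n = len(per_rank_partial_outputs[0])
        return [
            sum(out[i] for out in per_rank_partial_outputs) for i in range(n)
        ]
```

In this shape, the special-case branch in `fused_moe/layer.py` reduces to picking the right `FusedMoEPrepareAndFinalize` subclass and calling its hooks.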

Test Plan

I ran the suggested lm_eval sanity tests using 4×A100 GPUs, serving Qwen/Qwen1.5-MoE-A2.7B with vLLM configured as TP=2, DP=2, and expert parallelism enabled, and evaluated the model via the local completions endpoint.
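For reference, an invocation along these lines reproduces that setup (the exact flags and endpoint path are an assumption reconstructed from the description above, not the literal commands used):

```shell
# Serve the model with TP=2, DP=2, and expert parallelism on 4 GPUs
vllm serve Qwen/Qwen1.5-MoE-A2.7B \
    --tensor-parallel-size 2 \
    --data-parallel-size 2 \
    --enable-expert-parallel

# Evaluate gsm8k against the local completions endpoint
lm_eval --model local-completions \
    --tasks gsm8k \
    --num_fewshot 5 \
    --model_args model=Qwen/Qwen1.5-MoE-A2.7B,base_url=http://127.0.0.1:8000/v1/completions
```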

Test Result

| Tasks | Version | Filter           | n-shot | Metric      | Value | Stderr |
|-------|--------:|------------------|-------:|-------------|------:|-------:|
| gsm8k |       3 | flexible-extract |      5 | exact_match |  0.01 | ± 0.01 |
|       |         | strict-match     |      5 | exact_match |  0.00 | ± 0.00 |

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as the test commands used.
  • The test results, such as a before/after comparison or end-to-end results.
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@teddygood teddygood changed the title Naive moe refactor clean [Feature]: Implement naive prepare/finalize class to replace naive dispatching in fused_moe/layer.py Dec 16, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request refactors the MoE implementation to better separate concerns, moving the 'naive' EP+DP dispatch/combine logic into its own FusedMoENaivePrepareAndFinalize class. This is a good architectural improvement that increases modularity. However, I've identified two critical issues in vllm/model_executor/layers/fused_moe/layer.py that need to be addressed: a redundant computation of shared expert outputs and a double reduction of the final hidden states. Both issues will lead to incorrect model outputs.

Comment on lines +1927 to +1952
```python
# If there are shared experts but we are not using a modular kernel, the
# shared experts must be called here
if has_separate_shared_experts:
    assert self.shared_experts is not None

    if self.shared_experts_stream is not None:
        # Clone BEFORE switching streams to avoid race condition
        # where routed_expert kernel may mutate hidden_states.
        hidden_states_clone = hidden_states.clone()
        self.shared_experts_stream.wait_stream(current_stream())

        # Run shared experts in parallel on a separate stream
        with torch.cuda.stream(self.shared_experts_stream):
            shared_output = self.shared_experts(hidden_states_clone)

        # Record that the clone will be used by shared_experts_stream
        # to avoid gc issue from deallocation of hidden_states_clone.
        # For more details: https://docs.pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html # noqa: E501
        # NOTE: we don't need shared_output.record_stream(current_stream())
        # because we sync the streams before using shared_output.
        hidden_states_clone.record_stream(self.shared_experts_stream)
    else:
        shared_output = self.shared_experts(hidden_states)
else:
    shared_output = None
```
Contributor


critical

This block for handling shared experts appears to be redundant and introduces a bug. The logic for computing shared_output is already handled later in this function (lines 1983-1985 and 1997-2007). This new block computes shared_output, which is then overwritten. Additionally, the logic for using a separate stream here is incorrect as it doesn't use the use_shared_experts_stream flag, which depends on the number of tokens. This entire block should be removed to fix the bug.

Comment on lines +2027 to +2033
```python
if (
    not self.is_sequence_parallel
    and self.reduce_results
    and (self.tp_size > 1 or self.ep_size > 1)
):
    states = self.maybe_all_reduce_tensor_model_parallel(states)
```
Contributor


critical

This all_reduce operation is redundant. The forward_native method, which calls forward_impl, already performs a reduction on the output. Adding another reduction here will result in a double all_reduce, which is incorrect and will lead to wrong results. This block should be removed to avoid the double reduction.
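A toy illustration in plain Python (not vLLM code) of why reducing twice is wrong: a second all-reduce sums the already-reduced value across ranks, scaling the result by the world size.

```python
# Toy model of a sum all-reduce over per-rank values (not vLLM code).

def all_reduce(per_rank_values):
    # Sum all-reduce: every rank ends up holding the same total.
    total = sum(per_rank_values)
    return [total] * len(per_rank_values)

world_size = 2
partial = [1.0, 2.0]        # each rank's partial output for one element
once = all_reduce(partial)  # correct: every rank holds 3.0
twice = all_reduce(once)    # buggy: every rank holds 6.0 = 3.0 * world_size
```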

@mergify

mergify bot commented Dec 17, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @teddygood.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Dec 17, 2025
@teddygood teddygood force-pushed the naive-moe-refactor-clean branch from 77e026b to ebc1ecb Compare December 18, 2025 13:09
@mergify mergify bot removed the needs-rebase label Dec 18, 2025
@mergify

mergify bot commented Dec 22, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @teddygood.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Dec 22, 2025
@teddygood teddygood force-pushed the naive-moe-refactor-clean branch from ebc1ecb to 8f637d7 Compare December 22, 2025 11:40
@mergify mergify bot removed the needs-rebase label Dec 22, 2025
Signed-off-by: teddygood <ibear6954@gmail.com>
@teddygood teddygood force-pushed the naive-moe-refactor-clean branch from 8f637d7 to 8fa8b01 Compare December 22, 2025 11:45
@robertgshaw2-redhat
Collaborator

robertgshaw2-redhat commented Jan 8, 2026

hey @teddygood - thanks for your hard work here. I am also motivated to make this change.

I think we can have a simpler approach by instead just doing the dispatch/combine on the topk_weight and topk_ids. I have this POC PR which is giving me correctness: #31933

In addition, this is also a performance optimization because we send much less data (only the topk experts).
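A back-of-the-envelope sketch of that data-volume difference, with illustrative (not measured) sizes: per token, `(topk_ids, topk_weights)` is far smaller than the full hidden state, so dispatching/combining only the routing data cuts communication volume substantially.

```python
# Illustrative numbers only; hidden_size and topk are hypothetical.
num_tokens = 1024
hidden_size = 2048  # hypothetical model dimension
topk = 4            # experts selected per token

bytes_hidden = num_tokens * hidden_size * 2  # fp16 hidden states
bytes_topk = num_tokens * topk * (4 + 2)     # int32 id + fp16 weight per pick

ratio = bytes_hidden / bytes_topk            # ~170x less data in this sketch
```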

Would you like to collaborate with me on it?

@teddygood
Author

> hey @teddygood - thanks for your hard work here. I am also motivated to make this change.
>
> I think we can have a simpler approach by instead just doing the dispatch/combine on the topk_weight and topk_ids. I have this POC PR which is giving me correctness: #31933
>
> In addition, this is also a performance optimization because we send much less data (only the topk experts).
>
> Would you like to collaborate with me on it?

hey @robertgshaw2-redhat thanks for the comment. I agree with dispatching/combining only on topk_weights / topk_ids. It looks simpler than what I was doing in #30775, and the reduced communication should help performance. I’d love to collaborate with you on #31933.

@robertgshaw2-redhat
Collaborator

robertgshaw2-redhat commented Jan 8, 2026

> hey @teddygood - thanks for your hard work here. I am also motivated to make this change.
> I think we can have a simpler approach by instead just doing the dispatch/combine on the topk_weight and topk_ids. I have this POC PR which is giving me correctness: #31933
> In addition, this is also a performance optimization because we send much less data (only the topk experts).
> Would you like to collaborate with me on it?

> hey @robertgshaw2-redhat thanks for the comment. I agree with dispatching/combining only on topk_weights / topk_ids. It looks simpler than what I was doing in #30775, and the reduced communication should help performance. I’d love to collaborate with you on #31933.

Sounds good, let's discuss via Slack. I'll close this PR for now if that's okay?

@teddygood
Author

> hey @teddygood - thanks for your hard work here. I am also motivated to make this change.
> I think we can have a simpler approach by instead just doing the dispatch/combine on the topk_weight and topk_ids. I have this POC PR which is giving me correctness: #31933
> In addition, this is also a performance optimization because we send much less data (only the topk experts).
> Would you like to collaborate with me on it?

> hey @robertgshaw2-redhat thanks for the comment. I agree with dispatching/combining only on topk_weights / topk_ids. It looks simpler than what I was doing in #30775, and the reduced communication should help performance. I’d love to collaborate with you on #31933.

> sounds good, lets discuss via slack. Ill close this PR for now if thats okay?

sure, thanks.

@mergify

mergify bot commented Jan 8, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @teddygood.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

Development

Successfully merging this pull request may close these issues.

[Feature]: Implement naive prepare/finalize class to replace naive dispatching in fused_moe/layer.py
