
[Bugfix][MoE] Unpad routed output before shared expert add [Fixes #35949]#40794

Merged
tomeras91 merged 1 commit into vllm-project:main from netanel-haber:bugfix/truncate-padded-fused-output-before-adding-to-shared-output on Apr 24, 2026

Conversation

netanel-haber (Contributor) commented on Apr 24, 2026

Fixes https://huggingface.co/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-NVFP4.

The FlashInfer (FI) TRTLLM NVFP4 MoE path can pad the routed hidden dim, e.g. 2688 -> 2816, via align_trtllm_fp4_moe_hidden_dim_for_fi.

Before #35949, FusedMoE returned the routed and shared outputs separately. The routed output was truncated back to the original hidden dim before model code added the shared expert output, so the shapes looked like:

routed kernel output: [tokens, 2816] -> truncate -> [tokens, 2688]
shared output:        [tokens, 2688]
add:                  [tokens, 2688] + [tokens, 2688]

#35949 moved the shared/routed add into MoERunner. That changed the order to add first and truncate later:

routed kernel output: [tokens, 2816]
shared output:        [tokens, 2688]
add:                  [tokens, 2816] + [tokens, 2688]

Dynamo catches this during fake tensor tracing as a shape mismatch.
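
A tiny, hypothetical PyTorch snippet reproducing the mismatch with the shapes from the description (the tensor names and sizes are illustrative, not vLLM code):

```python
import torch

# Shapes from the description: routed output padded 2688 -> 2816, shared output unpadded.
tokens = 4
routed = torch.randn(tokens, 2816)  # fused kernel output, padded hidden dim
shared = torch.randn(tokens, 2688)  # shared expert output, original hidden dim

routed + shared  # RuntimeError: size mismatch at dim 1 (2816 vs 2688)
```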

This PR records the routed hidden dim before _maybe_pad_hidden_states() and trims the fused routed output back to that dim before shared expert addition.
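
A minimal sketch of that ordering, assuming stand-in callables for the pad, routed-expert, and shared-expert steps (the names below are illustrative, not vLLM's actual API):

```python
import torch

def routed_plus_shared(hidden_states, pad_fn, routed_experts_fn, shared_expert_fn):
    # Record the hidden dim before any padding (what _maybe_pad_hidden_states may change).
    orig_hidden_dim = hidden_states.shape[-1]
    padded = pad_fn(hidden_states)                   # e.g. 2688 -> 2816 on the FI NVFP4 path
    routed_out = routed_experts_fn(padded)           # [tokens, padded_hidden_dim]
    routed_out = routed_out[..., :orig_hidden_dim]   # trim back to the original hidden dim
    return routed_out + shared_expert_fn(hidden_states)  # shapes now match

# Toy usage with stand-ins that only mimic the shapes involved.
x = torch.randn(4, 2688)
out = routed_plus_shared(
    x,
    pad_fn=lambda t: torch.nn.functional.pad(t, (0, 2816 - t.shape[-1])),
    routed_experts_fn=lambda t: t,   # pretend the kernel returns the padded shape
    shared_expert_fn=lambda t: t,
)
assert out.shape == (4, 2688)
```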

DailyOmni results are on par for nano-v3-omni before and after this PR.

Signed-off-by: Netanel Haber <nhaber@nvidia.com>

claude (bot) left a comment

Claude Code Review

This pull request is from a fork — automated review is disabled. A repository maintainer can comment @claude review to run a one-time review.

mergify (bot) added the bug label on Apr 24, 2026
gemini-code-assist (bot, Contributor) left a comment


Code Review

This pull request introduces logic to handle hidden dimension padding in the Fused MoE runner. It records the original hidden dimension before potential padding and ensures that the fused output is sliced back to its original size if padding was applied. I have no feedback to provide as there are no review comments to evaluate.

netanel-haber changed the title from "[Bugfix][MoE] Unpad routed output before shared expert add" to "[Bugfix][MoE] Unpad routed output before shared expert add [Fixes #35949]" on Apr 24, 2026
tomeras91 (Member) commented

Thanks! Approved

Cc @robertgshaw2-redhat

BTW @netanel-haber - do you know how this works with latentMoE (regardless of padding)? Are the routed hidden states added to the shared hidden states only after applying the latent up proj to match hidden dims again?

tomeras91 added the ready label on Apr 24, 2026
tomeras91 enabled auto-merge (squash) on April 24, 2026 10:13
netanel-haber (Contributor, Author) commented

> Thanks! Approved
>
> Cc @robertgshaw2-redhat
>
> BTW @netanel-haber - do you know how this works with latentMoE (regardless of padding)? Are the routed hidden states added to the shared hidden states only after applying the latent up proj to match hidden dims again?

See image re latent: [attached image]

@tomeras91 tomeras91 merged commit e8eb049 into vllm-project:main Apr 24, 2026
72 checks passed
bnellnm (Collaborator) commented on Apr 25, 2026

I think this change might have broken lora/test_gptoss_tp.py::test_gpt_oss_lora_tp2[True-False] again.

netanel-haber (Contributor, Author) commented on Apr 25, 2026

Re gptoss test breakages, see #40865.

bnellnm mentioned this pull request on Apr 25, 2026