From 992cd023aff102b52fa680a5e2139a98b872d34a Mon Sep 17 00:00:00 2001
From: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>
Date: Sat, 25 Apr 2026 09:10:04 +0300
Subject: [PATCH 1/3] [Bugfix][MoE] Only unpad routed output before shared
 expert add

Signed-off-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>
---
 vllm/model_executor/layers/fused_moe/runner/moe_runner.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/vllm/model_executor/layers/fused_moe/runner/moe_runner.py b/vllm/model_executor/layers/fused_moe/runner/moe_runner.py
index d37e29938f4b..7941be9fbe1c 100644
--- a/vllm/model_executor/layers/fused_moe/runner/moe_runner.py
+++ b/vllm/model_executor/layers/fused_moe/runner/moe_runner.py
@@ -552,6 +552,7 @@ def forward(
 
         # Record before `_maybe_pad_hidden_states` pads activations to match
         # `moe_config.hidden_dim`, e.g. after `align_trtllm_fp4_moe_hidden_dim_for_fi`
+        # so routed output can be trimmed before shared+routed add if needed.
         routed_hidden_dim = hidden_states.shape[-1]
         hidden_states, og_hidden_dim = self._maybe_pad_hidden_states(
             shared_experts_input,
@@ -577,7 +578,7 @@ def forward(
 
         # Extract outputs from result
         shared_output, fused_output = _unpack(result)
-        if hidden_dim_was_padded:
+        if shared_output is not None and hidden_dim_was_padded:
             fused_output = fused_output[..., :routed_hidden_dim]
 
         # If combine kernel already reduced fused, reduce shared to match.

From ab253dd030c4e3d62d7e25c32a2c57d2771d316b Mon Sep 17 00:00:00 2001
From: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>
Date: Sat, 25 Apr 2026 22:15:18 +0300
Subject: [PATCH 2/3] still truncate if self.routed_output_transform is not
 None, even if there are no shared experts, for latent models

Signed-off-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>
---
 vllm/model_executor/layers/fused_moe/runner/moe_runner.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/vllm/model_executor/layers/fused_moe/runner/moe_runner.py b/vllm/model_executor/layers/fused_moe/runner/moe_runner.py
index 7941be9fbe1c..d55f6340afaf 100644
--- a/vllm/model_executor/layers/fused_moe/runner/moe_runner.py
+++ b/vllm/model_executor/layers/fused_moe/runner/moe_runner.py
@@ -578,7 +578,9 @@ def forward(
 
         # Extract outputs from result
         shared_output, fused_output = _unpack(result)
-        if shared_output is not None and hidden_dim_was_padded:
+        if (
+            shared_output is not None or self.routed_output_transform is not None
+        ) and hidden_dim_was_padded:
             fused_output = fused_output[..., :routed_hidden_dim]
 
         # If combine kernel already reduced fused, reduce shared to match.
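Taken together, patches 1 and 2 reduce to the truncation rule sketched below.
This is a minimal standalone sketch, not the real moe_runner.py code path:
the `moe_forward_sketch` function, the `routed_experts`/`shared_experts`
callables, and the F.pad-based padding are illustrative stand-ins, while
`routed_hidden_dim`, `hidden_dim_was_padded`, `shared_output`,
`fused_output`, and `routed_output_transform` mirror names that appear in
the diff.

import torch
import torch.nn.functional as F


def moe_forward_sketch(
    hidden_states: torch.Tensor,   # [num_tokens, hidden_dim], unpadded
    padded_hidden_dim: int,        # alignment target, cf. moe_config.hidden_dim
    routed_experts,                # callable: padded input -> padded output
    shared_experts=None,           # callable or None
    routed_output_transform=None,  # latent up proj, callable or None
) -> torch.Tensor:
    # Record the unpadded width before padding, so the routed output can be
    # trimmed back to it later (`routed_hidden_dim` in the diff).
    routed_hidden_dim = hidden_states.shape[-1]
    hidden_dim_was_padded = padded_hidden_dim > routed_hidden_dim
    if hidden_dim_was_padded:
        # Zero-pad the last dimension up to the aligned width.
        hidden_states = F.pad(
            hidden_states, (0, padded_hidden_dim - routed_hidden_dim)
        )

    fused_output = routed_experts(hidden_states)  # still at the padded width
    shared_output = (
        shared_experts(hidden_states[..., :routed_hidden_dim])
        if shared_experts is not None
        else None
    )

    # The rule after patch 2: trim only when some downstream consumer needs
    # the unpadded width, i.e. a shared+routed add or a latent transform.
    if (
        shared_output is not None or routed_output_transform is not None
    ) and hidden_dim_was_padded:
        fused_output = fused_output[..., :routed_hidden_dim]

    if routed_output_transform is not None:
        fused_output = routed_output_transform(fused_output)
    if shared_output is not None:
        fused_output = fused_output + shared_output
    return fused_output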
From 71acc3eb8f6402d32ca984e811029ce2dfe80916 Mon Sep 17 00:00:00 2001
From: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>
Date: Sat, 25 Apr 2026 22:16:48 +0300
Subject: [PATCH 3/3] fix comment

Signed-off-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>
---
 vllm/model_executor/layers/fused_moe/runner/moe_runner.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/vllm/model_executor/layers/fused_moe/runner/moe_runner.py b/vllm/model_executor/layers/fused_moe/runner/moe_runner.py
index d55f6340afaf..6b13c0c36323 100644
--- a/vllm/model_executor/layers/fused_moe/runner/moe_runner.py
+++ b/vllm/model_executor/layers/fused_moe/runner/moe_runner.py
@@ -552,7 +552,8 @@ def forward(
 
         # Record before `_maybe_pad_hidden_states` pads activations to match
         # `moe_config.hidden_dim`, e.g. after `align_trtllm_fp4_moe_hidden_dim_for_fi`
-        # so routed output can be trimmed before shared+routed add if needed.
+        # so routed output can be trimmed before
+        # shared+routed add / latent up proj if needed.
         routed_hidden_dim = hidden_states.shape[-1]
         hidden_states, og_hidden_dim = self._maybe_pad_hidden_states(
             shared_experts_input,
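As a quick check of the latent-model case that patch 2 and the reworded
comment address (no shared experts, but a routed output transform), here is
a hypothetical smoke test against the sketch above; all sizes are made up:

# Depends on `moe_forward_sketch` from the sketch above: 4 tokens with
# hidden dim 7 padded to 8, and an up proj from width 7 to width 16.
tokens, hidden, padded = 4, 7, 8
x = torch.randn(tokens, hidden)
up_proj = torch.randn(hidden, 16)
out = moe_forward_sketch(
    x,
    padded_hidden_dim=padded,
    routed_experts=lambda h: h * 2.0,  # stand-in for the routed expert path
    shared_experts=None,
    routed_output_transform=lambda h: h @ up_proj,
)
# The transform saw the trimmed width (7), not the padded width (8);
# without the trim the matmul would fail on the padded input.
assert out.shape == (tokens, 16)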