diff --git a/vllm/model_executor/layers/fused_moe/runner/moe_runner.py b/vllm/model_executor/layers/fused_moe/runner/moe_runner.py
index 2eee8acf6b8f..8a7e816b14f6 100644
--- a/vllm/model_executor/layers/fused_moe/runner/moe_runner.py
+++ b/vllm/model_executor/layers/fused_moe/runner/moe_runner.py
@@ -93,6 +93,7 @@ def _moe_forward(
     shared_experts_input: torch.Tensor | None,
     input_ids: torch.Tensor | None,
     layer_name: _layer_name_type,
+    hidden_dim_unpadded: int,
 ) -> torch.Tensor:
     layer = get_layer_from_name(_resolve_layer_name(layer_name))
     return layer.runner._forward_impl(
@@ -110,7 +111,14 @@ def _moe_forward_fake(
     shared_experts_input: torch.Tensor | None,
     input_ids: torch.Tensor | None,
     layer_name: _layer_name_type,
+    hidden_dim_unpadded: int,
 ) -> torch.Tensor:
+    # `hidden_dim_unpadded > 0` only on the TRT-LLM MXFP4 path, where the
+    # real kernel writes narrower than `hidden_states.shape[-1]`. Plumbed
+    # as an op arg (not peeked from the layer registry) to keep the fake
+    # a pure shape function of its inputs and preserve subgraph dedup.
+    if hidden_dim_unpadded > 0:
+        return hidden_states.new_empty((*hidden_states.shape[:-1], hidden_dim_unpadded))
     return torch.empty_like(hidden_states)
 
 
@@ -120,6 +128,7 @@ def _moe_forward_shared(
     shared_experts_input: torch.Tensor | None,
     input_ids: torch.Tensor | None,
     layer_name: _layer_name_type,
+    hidden_dim_unpadded: int,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     layer = get_layer_from_name(_resolve_layer_name(layer_name))
     return layer.runner._forward_impl(
@@ -137,13 +146,17 @@ def _moe_forward_shared_fake(
     shared_experts_input: torch.Tensor | None,
     input_ids: torch.Tensor | None,
     layer_name: _layer_name_type,
+    hidden_dim_unpadded: int,
 ) -> tuple[torch.Tensor, torch.Tensor]:
-    # Output shapes:
-    # - fused_out: same as hidden_states (routed experts use transformed size)
-    # - shared_out: same as shared_experts_input if provided, else same as
-    #   hidden_states
-    #   (For latent MoE: shared experts use original hidden_size, not latent size)
-    fused_out = torch.empty_like(hidden_states)
+    # `fused_out`: see `_moe_forward_fake` for hidden_dim_unpadded semantics.
+    # `shared_out`: matches `shared_experts_input` if provided (latent MoE),
+    #   else `hidden_states`.
+    if hidden_dim_unpadded > 0:
+        fused_out = hidden_states.new_empty(
+            (*hidden_states.shape[:-1], hidden_dim_unpadded)
+        )
+    else:
+        fused_out = torch.empty_like(hidden_states)
     if shared_experts_input is not None:
         shared_out = torch.empty_like(shared_experts_input)
     else:
@@ -389,6 +402,29 @@ def _encode_layer_name(self) -> str | LayerName:
             return "from_forward_context"
         return self.layer_name
 
+    def _trtllm_mxfp4_unpadded_dim(self) -> int:
+        """Return ``hidden_dim_unpadded`` when the active backend is TRT-LLM
+        MXFP4 (whose kernel writes narrower than the padded
+        ``hidden_states.shape[-1]``), else 0. Other MXFP4 backends (notably
+        Cutlass MXFP4 MXFP8) write the full padded width, so
+        ``moe_config.hidden_dim_unpadded`` alone is insufficient: it encodes
+        the model's logical hidden, not whether the kernel narrows. Computed
+        caller-side and passed as an op arg; doing the isinstance check
+        inside the fake would specialize per ``layer_name`` and break
+        subgraph dedup for identical-architecture models (e.g. Phi-MoE).
+        """
+        from vllm.model_executor.layers.fused_moe.experts.trtllm_mxfp4_moe import (
+            TrtLlmMxfp4ExpertsBase,
+        )
+
+        moe_kernel = getattr(self._quant_method, "moe_kernel", None)
+        fused_experts = getattr(
+            getattr(moe_kernel, "impl", None), "fused_experts", None
+        )
+        if isinstance(fused_experts, TrtLlmMxfp4ExpertsBase):
+            return self.moe_config.hidden_dim_unpadded or self.moe_config.hidden_dim
+        return 0
+
     def _maybe_pad_hidden_states(
         self,
         shared_experts_input: torch.Tensor | None,
@@ -577,6 +613,7 @@ def forward(
             shared_experts_input,
             input_ids,
             self._encode_layer_name(),
+            self._trtllm_mxfp4_unpadded_dim(),
         )
         #