49 changes: 43 additions & 6 deletions vllm/model_executor/layers/fused_moe/runner/moe_runner.py
@@ -93,6 +93,7 @@ def _moe_forward(
shared_experts_input: torch.Tensor | None,
input_ids: torch.Tensor | None,
layer_name: _layer_name_type,
hidden_dim_unpadded: int,
) -> torch.Tensor:
layer = get_layer_from_name(_resolve_layer_name(layer_name))
return layer.runner._forward_impl(
@@ -110,7 +111,14 @@ def _moe_forward_fake(
shared_experts_input: torch.Tensor | None,
input_ids: torch.Tensor | None,
layer_name: _layer_name_type,
hidden_dim_unpadded: int,
) -> torch.Tensor:
    # `hidden_dim_unpadded` is > 0 only on the TRT-LLM MXFP4 path, where the
    # real kernel writes a narrower output than `hidden_states.shape[-1]`. It
    # is plumbed through as an op arg (not peeked from the layer registry) so
    # the fake stays a pure shape function of its inputs and subgraph dedup is
    # preserved.
if hidden_dim_unpadded > 0:
return hidden_states.new_empty((*hidden_states.shape[:-1], hidden_dim_unpadded))
return torch.empty_like(hidden_states)
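
The comment above hinges on the fake implementation being a pure shape function of the op's arguments. Below is a minimal standalone sketch of that pattern — not part of this diff and not vLLM's actual op registration (which uses its own helper); the op name `demo::moe_forward` and its reduced signature are hypothetical, assuming PyTorch's `torch.library.custom_op` / `register_fake` API:

```python
import torch


@torch.library.custom_op("demo::moe_forward", mutates_args=())
def moe_forward(hidden_states: torch.Tensor, hidden_dim_unpadded: int) -> torch.Tensor:
    # Stand-in for the real kernel: it may write fewer columns than the
    # (padded) input width.
    out_dim = hidden_dim_unpadded if hidden_dim_unpadded > 0 else hidden_states.shape[-1]
    return hidden_states[..., :out_dim].clone()


@moe_forward.register_fake
def _(hidden_states: torch.Tensor, hidden_dim_unpadded: int) -> torch.Tensor:
    # The output shape depends only on the op's inputs, never on global state
    # such as a layer registry, so identical subgraphs stay identical and can
    # be deduplicated by the compiler.
    if hidden_dim_unpadded > 0:
        return hidden_states.new_empty((*hidden_states.shape[:-1], hidden_dim_unpadded))
    return torch.empty_like(hidden_states)
```

If the fake instead looked up the layer by `layer_name` and branched on the backend type, every layer would trace to a slightly different graph even when the architectures are identical, defeating the dedup this PR is trying to preserve.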


@@ -120,6 +128,7 @@ def _moe_forward_shared(
shared_experts_input: torch.Tensor | None,
input_ids: torch.Tensor | None,
layer_name: _layer_name_type,
hidden_dim_unpadded: int,
) -> tuple[torch.Tensor, torch.Tensor]:
layer = get_layer_from_name(_resolve_layer_name(layer_name))
return layer.runner._forward_impl(
@@ -137,13 +146,17 @@ def _moe_forward_shared_fake(
shared_experts_input: torch.Tensor | None,
input_ids: torch.Tensor | None,
layer_name: _layer_name_type,
hidden_dim_unpadded: int,
) -> tuple[torch.Tensor, torch.Tensor]:
# Output shapes:
# - fused_out: same as hidden_states (routed experts use transformed size)
# - shared_out: same as shared_experts_input if provided, else same as
# hidden_states
# (For latent MoE: shared experts use original hidden_size, not latent size)
fused_out = torch.empty_like(hidden_states)
# `fused_out`: see `_moe_forward_fake` for hidden_dim_unpadded semantics.
# `shared_out`: matches `shared_experts_input` if provided (latent MoE),
# else `hidden_states`.
if hidden_dim_unpadded > 0:
fused_out = hidden_states.new_empty(
(*hidden_states.shape[:-1], hidden_dim_unpadded)
)
else:
fused_out = torch.empty_like(hidden_states)
if shared_experts_input is not None:
shared_out = torch.empty_like(shared_experts_input)
else:
@@ -389,6 +402,29 @@ def _encode_layer_name(self) -> str | LayerName:
return "from_forward_context"
return self.layer_name

def _trtllm_mxfp4_unpadded_dim(self) -> int:
"""Return ``hidden_dim_unpadded`` when the active backend is TRT-LLM
        MXFP4 (whose kernel writes a narrower output than the padded
        ``hidden_states.shape[-1]``), else 0. Other MXFP4 backends (notably
        Cutlass MXFP4 MXFP8) write the full padded width, so
        ``moe_config.hidden_dim_unpadded`` alone is insufficient: it encodes
        the model's logical hidden size, not whether the kernel narrows its
        output. Computed caller-side and passed as an op arg; doing the
        isinstance check inside the fake would specialize per ``layer_name``
        and break subgraph dedup for identical-architecture models
        (e.g. Phi-MoE).
"""
from vllm.model_executor.layers.fused_moe.experts.trtllm_mxfp4_moe import (
TrtLlmMxfp4ExpertsBase,
)

moe_kernel = getattr(self._quant_method, "moe_kernel", None)
fused_experts = getattr(
getattr(moe_kernel, "impl", None), "fused_experts", None
)
if isinstance(fused_experts, TrtLlmMxfp4ExpertsBase):
return self.moe_config.hidden_dim_unpadded or self.moe_config.hidden_dim
return 0
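
To make the padded vs. unpadded distinction concrete, here is a small numeric sketch — not part of this diff; the hidden size and alignment are made-up illustrative values, not taken from any particular model or kernel:

```python
# Assumed illustrative values: a logical hidden size of 2880 padded up to a
# 256-element alignment.
hidden_dim_unpadded = 2880
alignment = 256
hidden_dim = ((hidden_dim_unpadded + alignment - 1) // alignment) * alignment  # 3072

# TRT-LLM MXFP4: the kernel's output is 2880 wide, narrower than the
# 3072-wide padded hidden_states, so the fake must shrink its last dim.
# Cutlass MXFP4/MXFP8: the kernel writes the full 3072-wide padded output, so
# hidden_dim_unpadded being set in moe_config is not, by itself, enough to
# decide the fake's output shape; the backend check above is what decides.
```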

def _maybe_pad_hidden_states(
self,
shared_experts_input: torch.Tensor | None,
Expand Down Expand Up @@ -577,6 +613,7 @@ def forward(
shared_experts_input,
input_ids,
self._encode_layer_name(),
self._trtllm_mxfp4_unpadded_dim(),
)

#