Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 4 additions & 11 deletions vllm/model_executor/layers/fused_moe/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
)
from vllm.distributed.eplb.eplb_state import EplbLayerState, EplbState
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.custom_op import PluggableLayer
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
Expand Down Expand Up @@ -214,8 +214,8 @@ def get_compressed_expert_map(expert_map: torch.Tensor) -> str:


# --8<-- [start:fused_moe]
@CustomOp.register("fused_moe")
class FusedMoE(CustomOp):
@PluggableLayer.register("fused_moe")
class FusedMoE(PluggableLayer):
Comment thread
ProExpertProg marked this conversation as resolved.
"""FusedMoE layer for MoE models.

This layer contains both MergedColumnParallel weights (gate_up_proj /
Expand Down Expand Up @@ -1520,7 +1520,7 @@ def maybe_all_reduce_tensor_model_parallel(self, final_hidden_states: torch.Tens
"""
return self.runner.maybe_all_reduce_tensor_model_parallel(final_hidden_states)

def forward_native(
def forward(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
Expand All @@ -1536,13 +1536,6 @@ def expert_map(self) -> torch.Tensor | None:
self._expert_map if not self.rocm_aiter_fmoe_enabled else self.expert_mask
)

def forward_cuda(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
return self.forward_native(hidden_states, router_logits)

@classmethod
def make_expert_params_mapping(
cls,
Expand Down
4 changes: 2 additions & 2 deletions vllm/model_executor/models/transformers/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from vllm.config.utils import getattr_iter
from vllm.distributed import get_dp_group, get_ep_group
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.custom_op import PluggableLayer
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.models.interfaces import MixtureOfExperts
from vllm.model_executor.models.utils import maybe_prefix
Expand All @@ -38,7 +38,7 @@


# --8<-- [start:transformers_fused_moe]
@CustomOp.register("transformers_fused_moe")
@PluggableLayer.register("transformers_fused_moe")
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the avoidance of doubt, this is a custom op that allows us to accept `topk_ids` in `FusedMoE.forward` and have it reappear as the output of `custom_routing_function` when torch.compile/CUDA Graphs are enabled.

Copy link
Copy Markdown
Contributor Author

@whx-sjtu whx-sjtu Feb 13, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I treat this class as PluggableLayer because it doesn't have different implementations for different in-tree platforms. Do you mean that this extra functionality of transformers_fused_moe needs to be compiled by torch through CustomOp.maybe_compile? @hmellor

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was just annotating this particular change so that nobody thinks it can be removed entirely.

A quick way to check that it still works is to install transformers from main and run the following test in vLLM: `pytest tests/models/test_transformers.py -k olmoe -vsx`. If this still passes, then the change from CustomOp to PluggableLayer is OK.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@hmellor I don't understand your comment either - CustomOp class doesn't affect compilation/Dynamo/cudagraphs. Are you talking about direct_register_custom_op?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh I see. Yeah it sounds like I should have been referencing direct_register_custom_op.

TL;DR: the Transformers modeling backend needs `direct_register_custom_op` to support compilation (torch.compile and CUDA Graphs) with MoE models.

class TransformersFusedMoE(FusedMoE):
"""Custom FusedMoE for the Transformers modeling backend."""

Expand Down
Loading