diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index c4fc1fd2557e..80c884c12baa 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -19,7 +19,7 @@
 )
 from vllm.distributed.eplb.eplb_state import EplbLayerState, EplbState
 from vllm.logger import init_logger
-from vllm.model_executor.custom_op import CustomOp
+from vllm.model_executor.custom_op import PluggableLayer
 from vllm.model_executor.layers.fused_moe.activation import MoEActivation
 from vllm.model_executor.layers.fused_moe.config import (
     FusedMoEConfig,
@@ -214,8 +214,8 @@ def get_compressed_expert_map(expert_map: torch.Tensor) -> str:


 # --8<-- [start:fused_moe]
-@CustomOp.register("fused_moe")
-class FusedMoE(CustomOp):
+@PluggableLayer.register("fused_moe")
+class FusedMoE(PluggableLayer):
     """FusedMoE layer for MoE models.

     This layer contains both MergedColumnParallel weights (gate_up_proj /
@@ -1520,7 +1520,7 @@ def maybe_all_reduce_tensor_model_parallel(self, final_hidden_states: torch.Tens
         """
         return self.runner.maybe_all_reduce_tensor_model_parallel(final_hidden_states)

-    def forward_native(
+    def forward(
         self,
         hidden_states: torch.Tensor,
         router_logits: torch.Tensor,
@@ -1536,13 +1536,6 @@ def expert_map(self) -> torch.Tensor | None:
             self._expert_map if not self.rocm_aiter_fmoe_enabled else self.expert_mask
         )

-    def forward_cuda(
-        self,
-        hidden_states: torch.Tensor,
-        router_logits: torch.Tensor,
-    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
-        return self.forward_native(hidden_states, router_logits)
-
     @classmethod
     def make_expert_params_mapping(
         cls,
diff --git a/vllm/model_executor/models/transformers/moe.py b/vllm/model_executor/models/transformers/moe.py
index f65a197abcfc..81d21abbd069 100644
--- a/vllm/model_executor/models/transformers/moe.py
+++ b/vllm/model_executor/models/transformers/moe.py
@@ -24,7 +24,7 @@
 from vllm.config.utils import getattr_iter
 from vllm.distributed import get_dp_group, get_ep_group
 from vllm.forward_context import ForwardContext, get_forward_context
-from vllm.model_executor.custom_op import CustomOp
+from vllm.model_executor.custom_op import PluggableLayer
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.models.interfaces import MixtureOfExperts
 from vllm.model_executor.models.utils import maybe_prefix
@@ -38,7 +38,7 @@


 # --8<-- [start:transformers_fused_moe]
-@CustomOp.register("transformers_fused_moe")
+@PluggableLayer.register("transformers_fused_moe")
 class TransformersFusedMoE(FusedMoE):
     """Custom FusedMoE for the Transformers modeling backend."""
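
For layers that followed the old CustomOp pattern, the patch above amounts to two changes: register against PluggableLayer instead of CustomOp, and implement a single forward() in place of the per-backend forward_native/forward_cuda pair. The listing below is a minimal sketch of that migration, not code from the patch; MyFusedMoE and the "my_fused_moe" registry key are hypothetical names, and it assumes PluggableLayer subclasses are registered and dispatched the same way CustomOp subclasses were.

    # Hypothetical layer illustrating the post-change pattern; assumes
    # PluggableLayer.register()/subclassing behave like CustomOp did.
    import torch

    from vllm.model_executor.custom_op import PluggableLayer


    @PluggableLayer.register("my_fused_moe")  # hypothetical registry key
    class MyFusedMoE(PluggableLayer):
        """Toy layer: one forward() replaces forward_native/forward_cuda."""

        def forward(
            self,
            hidden_states: torch.Tensor,
            router_logits: torch.Tensor,
        ) -> torch.Tensor:
            # Placeholder: a real layer would route tokens and invoke its
            # fused MoE kernels here.
            return hidden_states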