diff --git a/tests/kernels/moe/test_moe_layer.py b/tests/kernels/moe/test_moe_layer.py index 85619a91005c..fca4096b0860 100644 --- a/tests/kernels/moe/test_moe_layer.py +++ b/tests/kernels/moe/test_moe_layer.py @@ -37,7 +37,7 @@ get_eplb_group, ) from vllm.forward_context import set_forward_context -from vllm.model_executor.layers.fused_moe import FusedMoE, SharedFusedMoE, fused_experts +from vllm.model_executor.layers.fused_moe import FusedMoE, fused_experts from vllm.model_executor.layers.fused_moe.activation import MoEActivation from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.router.router_factory import ( @@ -858,11 +858,7 @@ def make_fused_moe_layer( quant_config, qw = make_quant_config(quantization, w1, w2, global_num_experts) kwargs = dict() - if shared_experts is None: - builder = FusedMoE - else: - builder = SharedFusedMoE - kwargs["shared_experts"] = shared_experts + kwargs["shared_experts"] = shared_experts # Add gate and routed_input_transform if provided if gate is not None: @@ -872,7 +868,7 @@ def make_fused_moe_layer( kwargs["routed_input_transform"] = routed_input_transform kwargs["routed_output_transform"] = routed_output_transform - layer = builder( + layer = FusedMoE( num_experts=global_num_experts, top_k=top_k, hidden_size=hidden_size, diff --git a/tests/kernels/moe/test_shared_fused_moe_routed_transform.py b/tests/kernels/moe/test_shared_fused_moe_routed_transform.py index 464754c9f1b0..4515021a4e91 100644 --- a/tests/kernels/moe/test_shared_fused_moe_routed_transform.py +++ b/tests/kernels/moe/test_shared_fused_moe_routed_transform.py @@ -1,9 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """ -Tests for SharedFusedMoE with routed_input_transform. +Tests for FusedMoE with routed_input_transform. -Verifies that applying routed_input_transform inside SharedFusedMoE +Verifies that applying routed_input_transform inside FusedMoE produces the same results as applying the transform manually outside. """ @@ -13,7 +13,7 @@ from vllm.config import VllmConfig, set_current_vllm_config from vllm.forward_context import set_forward_context -from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE +from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.platforms import current_platform from vllm.utils.torch_utils import is_torch_equal_or_newer, set_random_seed @@ -133,9 +133,9 @@ def test_routed_input_transform_inside_vs_outside( workspace_init, monkeypatch, ): - """Compare SharedFusedMoE with transform inside vs manually applying outside. - Method A (inside): SharedFusedMoE with routed_input_transform - Method B (outside): Manually transform, then SharedFusedMoE without transform + """Compare FusedMoE with transform inside vs manually applying outside. 
+ Method A (inside): FusedMoE with routed_input_transform + Method B (outside): Manually transform, then FusedMoE without transform """ if current_platform.is_rocm(): monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1" if use_rocm_aiter else "0") @@ -157,8 +157,8 @@ def test_routed_input_transform_inside_vs_outside( routed_transform = SimpleLinear(hidden_size, latent_size, dtype) with set_current_vllm_config(vllm_config): - # Method A: SharedFusedMoE WITH routed_input_transform - moe_with_transform = SharedFusedMoE( + # Method A: FusedMoE WITH routed_input_transform + moe_with_transform = FusedMoE( shared_experts=shared_experts, routed_input_transform=routed_transform, num_experts=num_experts, @@ -173,9 +173,9 @@ def test_routed_input_transform_inside_vs_outside( prefix="moe_with_transform", ) - # Method B: SharedFusedMoE WITHOUT routed_input_transform + # Method B: FusedMoE WITHOUT routed_input_transform # Note: shared_experts=None because when transform is done outside, - moe_without_transform = SharedFusedMoE( + moe_without_transform = FusedMoE( shared_experts=None, routed_input_transform=None, num_experts=num_experts, diff --git a/vllm/distributed/device_communicators/base_device_communicator.py b/vllm/distributed/device_communicators/base_device_communicator.py index 2125f7381fe2..0b4b81f93bb4 100644 --- a/vllm/distributed/device_communicators/base_device_communicator.py +++ b/vllm/distributed/device_communicators/base_device_communicator.py @@ -7,6 +7,8 @@ import torch.distributed as dist from torch.distributed import ProcessGroup +from vllm.utils import is_moe_layer + class Cache: def __init__(self): @@ -317,16 +319,7 @@ def prepare_communication_buffer_for_model(self, model: torch.nn.Module) -> None if not self.is_ep_communicator: return - moe_modules = [ - module - for module in model.modules() - # TODO(bnell): Should use isinstance but can't. Maybe search for - # presence of quant_method.maybe_init_modular_kernel? 
- if ( - module.__class__.__name__ == "FusedMoE" - or module.__class__.__name__ == "SharedFusedMoE" - ) - ] + moe_modules = [module for module in model.modules() if is_moe_layer(module)] for module in moe_modules: module.maybe_init_modular_kernel() diff --git a/vllm/distributed/elastic_ep/elastic_execute.py b/vllm/distributed/elastic_ep/elastic_execute.py index a316a54bd519..24979b62af6d 100644 --- a/vllm/distributed/elastic_ep/elastic_execute.py +++ b/vllm/distributed/elastic_ep/elastic_execute.py @@ -38,6 +38,7 @@ from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.layer import FusedMoEParallelConfig +from vllm.utils import is_moe_layer from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType from vllm.v1.worker.gpu_ubatch_wrapper import UBatchWrapper from vllm.v1.worker.workspace import lock_workspace, unlock_workspace @@ -319,10 +320,7 @@ def switch_and_prepare(self) -> None: moe_modules = [ module for module in self.worker.model_runner.model.modules() - if ( - module.__class__.__name__ == "FusedMoE" - or module.__class__.__name__ == "SharedFusedMoE" - ) + if is_moe_layer(module) ] num_local_experts = moe_modules[0].moe_config.num_local_experts assert all( diff --git a/vllm/lora/layers/fused_moe.py b/vllm/lora/layers/fused_moe.py index d6eec675c6dd..b07b471c4b89 100644 --- a/vllm/lora/layers/fused_moe.py +++ b/vllm/lora/layers/fused_moe.py @@ -610,7 +610,7 @@ def can_replace_layer( ) -> bool: """Returns True if the layer can be replaced by this LoRA layer.""" - # source_layer is FusedMoE or SharedFusedMoE + # source_layer is FusedMoE return isinstance(source_layer, FusedMoE) and len(packed_modules_list) == 2 @@ -772,5 +772,5 @@ def can_replace_layer( model_config: PretrainedConfig | None = None, ) -> bool: """Returns True if the layer can be replaced by this LoRA layer.""" - # source_layer is FusedMoE or SharedFusedMoE + # source_layer is FusedMoE return isinstance(source_layer, FusedMoE) and len(packed_modules_list) == 1 diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index 926f0d1d0154..1b2ce61f7c8d 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -29,7 +29,6 @@ FusedMoERouter, ) from vllm.model_executor.layers.fused_moe.router.gate_linear import GateLinear -from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE from vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method import ( UnquantizedFusedMoEMethod, ) @@ -64,7 +63,6 @@ def get_config() -> dict[str, Any] | None: "FusedMoEPrepareAndFinalizeModular", "GateLinear", "RoutingMethodType", - "SharedFusedMoE", "activation_without_mul", "apply_moe_activation", "override_config", diff --git a/vllm/model_executor/layers/fused_moe/shared_fused_moe.py b/vllm/model_executor/layers/fused_moe/shared_fused_moe.py deleted file mode 100644 index 9cfcb1baa9bb..000000000000 --- a/vllm/model_executor/layers/fused_moe/shared_fused_moe.py +++ /dev/null @@ -1,25 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import torch - -from vllm.model_executor.layers.fused_moe.layer import FusedMoE - - -# TODO(bnell): Remove this entirely -class SharedFusedMoE(FusedMoE): - """ - A FusedMoE operation that also computes the results of shared experts. 
- If an all2all communicator is being used the shared expert computation - can be interleaved with the fused all2all dispatch communication step. - """ - - def forward( - self, - hidden_states: torch.Tensor, - router_logits: torch.Tensor, - ) -> torch.Tensor: - return super().forward( - hidden_states=hidden_states, - router_logits=router_logits, - ) diff --git a/vllm/model_executor/models/AXK1.py b/vllm/model_executor/models/AXK1.py index d42fbed42ae2..c33d5b973722 100644 --- a/vllm/model_executor/models/AXK1.py +++ b/vllm/model_executor/models/AXK1.py @@ -42,7 +42,7 @@ ) from vllm.logger import init_logger from vllm.model_executor.layers.attention import Attention -from vllm.model_executor.layers.fused_moe import SharedFusedMoE +from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( ColumnParallelLinear, @@ -163,7 +163,7 @@ def __init__( prefix=f"{prefix}.shared_experts", ) - self.experts = SharedFusedMoE( + self.experts = FusedMoE( shared_experts=self.shared_experts, gate=self.gate, num_experts=config.n_routed_experts, @@ -916,7 +916,7 @@ def compute_logits( def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) - return SharedFusedMoE.make_expert_params_mapping( + return FusedMoE.make_expert_params_mapping( self, ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", @@ -950,7 +950,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) - expert_params_mapping = SharedFusedMoE.make_expert_params_mapping( + expert_params_mapping = FusedMoE.make_expert_params_mapping( self, ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", diff --git a/vllm/model_executor/models/afmoe.py b/vllm/model_executor/models/afmoe.py index 5bad52a0c496..e34a418c9814 100644 --- a/vllm/model_executor/models/afmoe.py +++ b/vllm/model_executor/models/afmoe.py @@ -18,7 +18,7 @@ ) from vllm.logger import init_logger from vllm.model_executor.layers.attention import Attention -from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE +from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( ColumnParallelLinear, @@ -124,8 +124,8 @@ def __init__( prefix=f"{prefix}.shared_experts", ) - # Routed experts using SharedFusedMoE - self.experts = SharedFusedMoE( + # Routed experts using FusedMoE + self.experts = FusedMoE( shared_experts=self.shared_experts, num_experts=config.num_experts, top_k=config.num_experts_per_tok, @@ -479,7 +479,7 @@ def make_empty_intermediate_tensors( def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) - return SharedFusedMoE.make_expert_params_mapping( + return FusedMoE.make_expert_params_mapping( self, ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", @@ -637,7 +637,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.num_moe_layers = config.num_hidden_layers - config.num_dense_layers self.num_expert_groups = config.n_group - self.moe_layers: list[SharedFusedMoE] = [] + self.moe_layers: list[FusedMoE] = [] example_moe = 
None for layer in self.model.layers: if isinstance(layer, PPMissingLayer): diff --git a/vllm/model_executor/models/aria.py b/vllm/model_executor/models/aria.py index 7a079c565402..9696dec6d877 100644 --- a/vllm/model_executor/models/aria.py +++ b/vllm/model_executor/models/aria.py @@ -14,7 +14,7 @@ from vllm.distributed import get_tensor_model_parallel_rank from vllm.inputs import MultiModalDataDict from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.fused_moe import SharedFusedMoE +from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig @@ -214,7 +214,7 @@ def forward( return out -class AriaFusedMoE(SharedFusedMoE): +class AriaFusedMoE(FusedMoE): def weight_loader( self, param: nn.Parameter, loaded_weight: torch.Tensor, shard_id: str ) -> None: diff --git a/vllm/model_executor/models/bailing_moe.py b/vllm/model_executor/models/bailing_moe.py index 510d605f8046..ef4f66614a3f 100644 --- a/vllm/model_executor/models/bailing_moe.py +++ b/vllm/model_executor/models/bailing_moe.py @@ -41,7 +41,7 @@ ) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.attention import Attention -from vllm.model_executor.layers.fused_moe import SharedFusedMoE +from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( MergedColumnParallelLinear, @@ -285,7 +285,7 @@ def __init__( else: self.shared_experts = None - self.experts = SharedFusedMoE( + self.experts = FusedMoE( shared_experts=self.shared_experts, num_experts=self.num_experts, top_k=self.top_k, @@ -461,7 +461,7 @@ def forward( return hidden_states def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: - return SharedFusedMoE.make_expert_params_mapping( + return FusedMoE.make_expert_params_mapping( self, ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", diff --git a/vllm/model_executor/models/bailing_moe_linear.py b/vllm/model_executor/models/bailing_moe_linear.py index a63ad83f45bb..df36659b10c3 100644 --- a/vllm/model_executor/models/bailing_moe_linear.py +++ b/vllm/model_executor/models/bailing_moe_linear.py @@ -21,7 +21,7 @@ RMSNormGated, layernorm_fn, ) -from vllm.model_executor.layers.fused_moe import FusedMoE, SharedFusedMoE +from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( ColumnParallelLinear, @@ -351,8 +351,8 @@ def __init__( else: self.shared_experts = None - # Routed experts using SharedFusedMoE - self.experts = SharedFusedMoE( + # Routed experts using FusedMoE + self.experts = FusedMoE( shared_experts=self.shared_experts, num_experts=self.num_experts, top_k=self.top_k, diff --git a/vllm/model_executor/models/deepseek_mtp.py b/vllm/model_executor/models/deepseek_mtp.py index a66ec7aa3e69..898e4333409f 100644 --- a/vllm/model_executor/models/deepseek_mtp.py +++ b/vllm/model_executor/models/deepseek_mtp.py @@ -11,7 +11,7 @@ from vllm.compilation.decorators import support_torch_compile from vllm.config import VllmConfig from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe import SharedFusedMoE +from vllm.model_executor.layers.fused_moe import FusedMoE from 
vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig @@ -252,7 +252,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: ] stacked_params_mapping.extend(indexer_fused_mapping) - expert_params_mapping = SharedFusedMoE.make_expert_params_mapping( + expert_params_mapping = FusedMoE.make_expert_params_mapping( self, ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 1b01caded94c..53bcf87c6cc6 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -48,9 +48,9 @@ from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.layers.fused_moe import ( + FusedMoE, GateLinear, RoutingMethodType, - SharedFusedMoE, ) from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm from vllm.model_executor.layers.linear import ( @@ -311,7 +311,7 @@ def __init__( prefix=f"{prefix}.shared_experts", ) - self.experts = SharedFusedMoE( + self.experts = FusedMoE( shared_experts=self.shared_experts, gate=self.gate, num_experts=config.n_routed_experts, @@ -1432,7 +1432,7 @@ def compute_logits( def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) - return SharedFusedMoE.make_expert_params_mapping( + return FusedMoE.make_expert_params_mapping( self, ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", @@ -1474,7 +1474,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) - expert_params_mapping = SharedFusedMoE.make_expert_params_mapping( + expert_params_mapping = FusedMoE.make_expert_params_mapping( self, ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", diff --git a/vllm/model_executor/models/dots1.py b/vllm/model_executor/models/dots1.py index c176b7365689..181bd598e8e1 100644 --- a/vllm/model_executor/models/dots1.py +++ b/vllm/model_executor/models/dots1.py @@ -40,7 +40,7 @@ ) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.attention import Attention -from vllm.model_executor.layers.fused_moe import SharedFusedMoE +from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( MergedColumnParallelLinear, @@ -155,7 +155,7 @@ def __init__( else: self.shared_experts = None - self.experts = SharedFusedMoE( + self.experts = FusedMoE( shared_experts=self.shared_experts, num_experts=config.n_routed_experts, top_k=config.num_experts_per_tok, @@ -413,7 +413,7 @@ def forward( return hidden_states def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: - return SharedFusedMoE.make_expert_params_mapping( + return FusedMoE.make_expert_params_mapping( self, ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", diff --git a/vllm/model_executor/models/ernie45_moe.py b/vllm/model_executor/models/ernie45_moe.py index c92e230bcd21..58dd61e9d928 100644 --- a/vllm/model_executor/models/ernie45_moe.py +++ 
b/vllm/model_executor/models/ernie45_moe.py @@ -42,7 +42,7 @@ from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.attention import Attention -from vllm.model_executor.layers.fused_moe import SharedFusedMoE +from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( MergedColumnParallelLinear, @@ -188,7 +188,7 @@ def __init__( else: self.shared_experts = None - self.experts = SharedFusedMoE( + self.experts = FusedMoE( shared_experts=self.shared_experts, num_experts=config.moe_num_experts, top_k=config.moe_k, @@ -485,7 +485,7 @@ def forward( def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) - return SharedFusedMoE.make_expert_params_mapping( + return FusedMoE.make_expert_params_mapping( self, ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", @@ -667,7 +667,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.num_moe_layers = len(moe_layers_indices) self.num_expert_groups = 1 - self.moe_layers: list[SharedFusedMoE] = [] + self.moe_layers: list[FusedMoE] = [] example_moe = None for layer in self.model.layers: if isinstance(layer, PPMissingLayer): diff --git a/vllm/model_executor/models/ernie45_vl_moe.py b/vllm/model_executor/models/ernie45_vl_moe.py index e4b7ac6fb006..b4e7af9304b1 100644 --- a/vllm/model_executor/models/ernie45_vl_moe.py +++ b/vllm/model_executor/models/ernie45_vl_moe.py @@ -36,7 +36,7 @@ from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.logger import init_logger from vllm.model_executor.layers.attention import Attention -from vllm.model_executor.layers.fused_moe import SharedFusedMoE +from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( QKVParallelLinear, @@ -257,7 +257,7 @@ def __init__( prefix=f"{prefix}.text_experts_gate", ) - self.text_experts = SharedFusedMoE( + self.text_experts = FusedMoE( shared_experts=self.shared_experts, num_experts=config.moe_num_experts[0], top_k=config.moe_k, @@ -294,7 +294,7 @@ def __init__( prefix=f"{prefix}.vision_experts_gate", ) - self.vision_experts = SharedFusedMoE( + self.vision_experts = FusedMoE( shared_experts=self.shared_experts, num_experts=config.moe_num_experts[1], top_k=config.moe_k, @@ -649,7 +649,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) - expert_params_mapping = SharedFusedMoE.make_expert_params_mapping( + expert_params_mapping = FusedMoE.make_expert_params_mapping( self, ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", diff --git a/vllm/model_executor/models/exaone_moe.py b/vllm/model_executor/models/exaone_moe.py index a46cadf007ee..dd91a1896289 100644 --- a/vllm/model_executor/models/exaone_moe.py +++ b/vllm/model_executor/models/exaone_moe.py @@ -31,7 +31,6 @@ get_tensor_model_parallel_world_size, ) from vllm.model_executor.layers.fused_moe import FusedMoE -from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ReplicatedLinear from 
vllm.model_executor.layers.logits_processor import LogitsProcessor @@ -130,7 +129,7 @@ def __init__( else: self.shared_experts = None - self.experts = SharedFusedMoE( + self.experts = FusedMoE( shared_experts=self.shared_experts, gate=self.gate, num_experts=self.n_routed_experts, diff --git a/vllm/model_executor/models/glm4_moe.py b/vllm/model_executor/models/glm4_moe.py index 671e868da0ad..680e74609927 100644 --- a/vllm/model_executor/models/glm4_moe.py +++ b/vllm/model_executor/models/glm4_moe.py @@ -42,7 +42,7 @@ from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.attention import Attention -from vllm.model_executor.layers.fused_moe import SharedFusedMoE +from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( MergedColumnParallelLinear, @@ -178,7 +178,7 @@ def __init__( else: self.shared_experts = None - self.experts = SharedFusedMoE( + self.experts = FusedMoE( shared_experts=self.shared_experts, num_experts=config.n_routed_experts, top_k=config.num_experts_per_tok, @@ -466,7 +466,7 @@ def forward( def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) - return SharedFusedMoE.make_expert_params_mapping( + return FusedMoE.make_expert_params_mapping( self, ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", diff --git a/vllm/model_executor/models/glm4_moe_lite.py b/vllm/model_executor/models/glm4_moe_lite.py index 6d96f748e3ea..5dc33ec18bff 100644 --- a/vllm/model_executor/models/glm4_moe_lite.py +++ b/vllm/model_executor/models/glm4_moe_lite.py @@ -41,7 +41,7 @@ get_pp_group, ) from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe import SharedFusedMoE +from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -308,7 +308,7 @@ def make_empty_intermediate_tensors( def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) - return SharedFusedMoE.make_expert_params_mapping( + return FusedMoE.make_expert_params_mapping( self, ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", @@ -334,7 +334,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) - expert_params_mapping = SharedFusedMoE.make_expert_params_mapping( + expert_params_mapping = FusedMoE.make_expert_params_mapping( self, ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", @@ -616,7 +616,7 @@ def compute_logits( def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) - return SharedFusedMoE.make_expert_params_mapping( + return FusedMoE.make_expert_params_mapping( self, ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", diff --git a/vllm/model_executor/models/glm4_moe_lite_mtp.py b/vllm/model_executor/models/glm4_moe_lite_mtp.py index efa96c40d042..e00476abac6a 100644 --- 
a/vllm/model_executor/models/glm4_moe_lite_mtp.py +++ b/vllm/model_executor/models/glm4_moe_lite_mtp.py @@ -32,7 +32,7 @@ from vllm._aiter_ops import rocm_aiter_ops from vllm.config import VllmConfig -from vllm.model_executor.layers.fused_moe import FusedMoE, SharedFusedMoE +from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig @@ -260,7 +260,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1), ] - expert_params_mapping = SharedFusedMoE.make_expert_params_mapping( + expert_params_mapping = FusedMoE.make_expert_params_mapping( self, ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", diff --git a/vllm/model_executor/models/hunyuan_v1.py b/vllm/model_executor/models/hunyuan_v1.py index 35d30006a66a..9d3ebe4ed9c7 100644 --- a/vllm/model_executor/models/hunyuan_v1.py +++ b/vllm/model_executor/models/hunyuan_v1.py @@ -42,7 +42,7 @@ ) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.attention import Attention -from vllm.model_executor.layers.fused_moe import SharedFusedMoE +from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( ColumnParallelLinear, @@ -438,7 +438,7 @@ def __init__( else: self.shared_mlp = None - self.experts = SharedFusedMoE( + self.experts = FusedMoE( shared_experts=self.shared_mlp, num_experts=self.n_routed_experts, top_k=top_k, @@ -712,7 +712,7 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: if _is_moe(self.config): # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) - return SharedFusedMoE.make_expert_params_mapping( + return FusedMoE.make_expert_params_mapping( self, ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", diff --git a/vllm/model_executor/models/kimi_linear.py b/vllm/model_executor/models/kimi_linear.py index e586a3ac3469..21940fb2e1f5 100644 --- a/vllm/model_executor/models/kimi_linear.py +++ b/vllm/model_executor/models/kimi_linear.py @@ -14,7 +14,7 @@ ) from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.fused_moe import SharedFusedMoE +from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.kda import KimiDeltaAttention from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( @@ -144,7 +144,7 @@ def __init__( else: self.shared_experts = None - self.experts = SharedFusedMoE( + self.experts = FusedMoE( shared_experts=self.shared_experts, num_experts=num_experts, top_k=config.num_experts_per_token, @@ -476,7 +476,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: if self.config.is_moe: # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) - expert_params_mapping = SharedFusedMoE.make_expert_params_mapping( + expert_params_mapping = FusedMoE.make_expert_params_mapping( self, ckpt_gate_proj_name="w1", ckpt_down_proj_name="w2", diff --git a/vllm/model_executor/models/llama4.py b/vllm/model_executor/models/llama4.py index a1c0ac896052..c9495a743b7f 100644 --- 
a/vllm/model_executor/models/llama4.py +++ b/vllm/model_executor/models/llama4.py @@ -36,7 +36,7 @@ Attention, ChunkedLocalAttention, ) -from vllm.model_executor.layers.fused_moe import SharedFusedMoE +from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( QKVParallelLinear, @@ -127,7 +127,7 @@ def __init__(self, vllm_config: VllmConfig, prefix: str = ""): self.n_physical_experts = self.n_local_experts + self.n_redundant_experts self.n_local_physical_experts = self.n_physical_experts // self.ep_size - self.experts = SharedFusedMoE( + self.experts = FusedMoE( shared_experts=self.shared_expert, num_experts=config.num_local_experts, top_k=config.num_experts_per_tok, @@ -414,7 +414,7 @@ def load_moe_expert_weights( params_dict: The dictionary of module parameters. loaded_params: The set of already loaded parameters. expert_params_mapping: The mapping of expert parameters. Must be - generated by SharedFusedMoE.make_expert_params_mapping(). + generated by FusedMoE.make_expert_params_mapping(). fused: Whether the expert weights are fused into a single weight tensor or are separate weight tensors for each expert. When fused is True, loaded_weight should have shape of: @@ -554,7 +554,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: fused_experts_params = False # Expert parameter mapping for the case where the expert weights are # not fused into a single weight tensor. - expert_params_mapping = SharedFusedMoE.make_expert_params_mapping( + expert_params_mapping = FusedMoE.make_expert_params_mapping( self, ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", @@ -564,7 +564,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: ) # Expert parameter mapping for the case where the expert weights are # fused into a single weight tensor. 
- expert_params_mapping_fused = SharedFusedMoE.make_expert_params_mapping( + expert_params_mapping_fused = FusedMoE.make_expert_params_mapping( self, ckpt_gate_proj_name="gate_up_proj", ckpt_down_proj_name="down_proj", diff --git a/vllm/model_executor/models/nemotron_h.py b/vllm/model_executor/models/nemotron_h.py index fa068639648e..9b8ed68560cf 100644 --- a/vllm/model_executor/models/nemotron_h.py +++ b/vllm/model_executor/models/nemotron_h.py @@ -34,8 +34,8 @@ from vllm.model_executor.layers.activation import ReLUSquaredActivation from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.fused_moe import ( + FusedMoE, GateLinear, - SharedFusedMoE, activation_without_mul, ) from vllm.model_executor.layers.layernorm import RMSNorm @@ -210,7 +210,7 @@ def __init__( self.fc1_latent_proj = None self.fc2_latent_proj = None - self.experts = SharedFusedMoE( + self.experts = FusedMoE( shared_experts=self.shared_experts, num_experts=config.n_routed_experts, top_k=config.num_experts_per_tok, @@ -652,7 +652,7 @@ def _get_max_n_routed_experts(self) -> int: def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: if self.has_moe: # (param_name, weight_name, expert_id, shard_id) - expert_params_mapping = SharedFusedMoE.make_expert_params_mapping( + expert_params_mapping = FusedMoE.make_expert_params_mapping( # - FusedMoe.w1 (aka gate_proj) should be up_proj since that's # what the activation is applied to # - FusedMoe.w3 (aka up_proj) should be ignored since we're diff --git a/vllm/model_executor/models/openpangu.py b/vllm/model_executor/models/openpangu.py index 7de84da51935..96b837e42a8d 100644 --- a/vllm/model_executor/models/openpangu.py +++ b/vllm/model_executor/models/openpangu.py @@ -44,7 +44,7 @@ Attention, StaticSinkAttention, ) -from vllm.model_executor.layers.fused_moe import SharedFusedMoE +from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( ColumnParallelLinear, @@ -200,7 +200,7 @@ def __init__( else: self.shared_experts = None - self.experts = SharedFusedMoE( + self.experts = FusedMoE( shared_experts=self.shared_experts, num_experts=config.n_routed_experts, top_k=config.num_experts_per_tok, @@ -1149,7 +1149,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: ] has_experts = hasattr(self.config, "n_routed_experts") if has_experts: - expert_merge_mapping = SharedFusedMoE.make_expert_params_mapping( + expert_merge_mapping = FusedMoE.make_expert_params_mapping( self, ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", diff --git a/vllm/model_executor/models/param2moe.py b/vllm/model_executor/models/param2moe.py index fddd1a8f1733..4d1b3ff1b991 100644 --- a/vllm/model_executor/models/param2moe.py +++ b/vllm/model_executor/models/param2moe.py @@ -32,7 +32,7 @@ ) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.attention import Attention -from vllm.model_executor.layers.fused_moe import SharedFusedMoE +from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( MergedColumnParallelLinear, @@ -353,7 +353,7 @@ def __init__( else: self.shared_experts = None # type: ignore[assignment] - self.experts = SharedFusedMoE( + self.experts = FusedMoE( shared_experts=self.shared_experts, num_experts=self.num_experts, top_k=self.top_k, @@ -370,7 +370,7 @@ def 
__init__( routed_scaling_factor=self.routed_scaling_factor, ) - def maybe_get_fused_moe(self) -> SharedFusedMoE: + def maybe_get_fused_moe(self) -> FusedMoE: return self.experts def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: @@ -690,7 +690,7 @@ def load_weights( return loaded_params def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: - return SharedFusedMoE.make_expert_params_mapping( + return FusedMoE.make_expert_params_mapping( self, ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index b5d13e926d7c..7fc3c6a7dde4 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -40,7 +40,7 @@ from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.attention import Attention -from vllm.model_executor.layers.fused_moe import SharedFusedMoE +from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( MergedColumnParallelLinear, @@ -164,7 +164,7 @@ def __init__( else: self.shared_expert = None - self.experts = SharedFusedMoE( + self.experts = FusedMoE( shared_experts=self.shared_expert, num_experts=config.num_experts, top_k=config.num_experts_per_tok, @@ -418,7 +418,7 @@ def forward( def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) - return SharedFusedMoE.make_expert_params_mapping( + return FusedMoE.make_expert_params_mapping( self, ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py index f0f69d435379..6f080d07795e 100644 --- a/vllm/model_executor/models/qwen3_moe.py +++ b/vllm/model_executor/models/qwen3_moe.py @@ -43,7 +43,7 @@ from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.attention import Attention -from vllm.model_executor.layers.fused_moe import SharedFusedMoE +from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( MergedColumnParallelLinear, @@ -205,7 +205,7 @@ def __init__( self.shared_expert_gate = None self.shared_expert = None - self.experts = SharedFusedMoE( + self.experts = FusedMoE( shared_experts=self.shared_expert, gate=self.gate, num_experts=self.n_routed_experts, @@ -508,7 +508,7 @@ def forward( def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) - return SharedFusedMoE.make_expert_params_mapping( + return FusedMoE.make_expert_params_mapping( self, ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py index 50d44dbbf635..2a4021be6e40 100644 --- a/vllm/model_executor/models/qwen3_next.py +++ b/vllm/model_executor/models/qwen3_next.py @@ -23,7 +23,7 @@ ) from vllm.logger import init_logger from vllm.model_executor.layers.attention import Attention -from vllm.model_executor.layers.fused_moe import SharedFusedMoE +from vllm.model_executor.layers.fused_moe import FusedMoE from 
vllm.model_executor.layers.layernorm import ( GemmaRMSNorm as Qwen3NextRMSNorm, ) @@ -146,7 +146,7 @@ def __init__(self, vllm_config: VllmConfig, prefix: str = ""): else: self.shared_expert = None - self.experts = SharedFusedMoE( + self.experts = FusedMoE( shared_experts=self.shared_expert, gate=self.gate, num_experts=self.n_routed_experts, @@ -533,7 +533,7 @@ def forward( def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) - return SharedFusedMoE.make_expert_params_mapping( + return FusedMoE.make_expert_params_mapping( self, ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", diff --git a/vllm/model_executor/models/sarvam.py b/vllm/model_executor/models/sarvam.py index 3656fc921b25..c770e2032000 100644 --- a/vllm/model_executor/models/sarvam.py +++ b/vllm/model_executor/models/sarvam.py @@ -35,7 +35,7 @@ get_tensor_model_parallel_world_size, ) from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.fused_moe import SharedFusedMoE +from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( ColumnParallelLinear, @@ -335,7 +335,7 @@ def __init__( else: self.shared_experts = None - self.experts = SharedFusedMoE( + self.experts = FusedMoE( shared_experts=self.shared_experts, num_experts=self.num_experts, top_k=self.top_k, @@ -352,7 +352,7 @@ def __init__( routed_scaling_factor=self.routed_scaling_factor, ) - def maybe_get_fused_moe(self) -> SharedFusedMoE: + def maybe_get_fused_moe(self) -> FusedMoE: return self.experts def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: @@ -529,7 +529,7 @@ def forward( return hidden_states def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: - return SharedFusedMoE.make_expert_params_mapping( + return FusedMoE.make_expert_params_mapping( self, ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", diff --git a/vllm/model_executor/models/step3p5.py b/vllm/model_executor/models/step3p5.py index 018f78956029..7ed38a3a6f90 100644 --- a/vllm/model_executor/models/step3p5.py +++ b/vllm/model_executor/models/step3p5.py @@ -24,7 +24,6 @@ from vllm.model_executor.layers.activation import SiluAndMul, SwigluStepAndMul from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.fused_moe import FusedMoE -from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE from vllm.model_executor.layers.layernorm import GemmaRMSNorm from vllm.model_executor.layers.linear import ( ColumnParallelLinear, @@ -372,7 +371,7 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.share_expert", ) - self.experts = SharedFusedMoE( + self.experts = FusedMoE( shared_experts=self.share_expert, gate=self.gate, num_experts=config.moe_num_experts, diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py index 9b481d63990b..bf455c261f4f 100644 --- a/vllm/utils/__init__.py +++ b/vllm/utils/__init__.py @@ -34,3 +34,16 @@ def length_from_prompt_token_ids_or_embeds( f" prompt_embeds={prompt_embeds_len}" ) return prompt_token_len + + +def is_moe_layer(module: torch.nn.Module) -> bool: + # TODO(bnell): Should use isinstance but can't due to circular dependencies. 
+ def _check_bases(cls): + if cls.__name__ == "FusedMoE": + return True + + for b in cls.__bases__: + if _check_bases(b): + return True + return False + + return _check_bases(module.__class__)
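
A minimal, self-contained sketch of how the new is_moe_layer helper is expected to behave, using stand-in classes rather than the real vLLM layers (FusedMoE below is a dummy nn.Module and AriaLikeMoE is a hypothetical subclass standing in for something like AriaFusedMoE); the name-based base-class walk added in vllm/utils/__init__.py is functionally equivalent to looking for "FusedMoE" anywhere in the class MRO:

import torch.nn as nn

# Stand-in for vllm.model_executor.layers.fused_moe.FusedMoE (not the real layer).
class FusedMoE(nn.Module):
    pass

# Stand-in for a model-specific subclass such as AriaFusedMoE.
class AriaLikeMoE(FusedMoE):
    pass

def is_moe_layer(module: nn.Module) -> bool:
    # Name-based check, equivalent to recursively walking __bases__ as in the patch;
    # avoids importing FusedMoE and the circular dependency noted in the TODO.
    return any(cls.__name__ == "FusedMoE" for cls in type(module).__mro__)

assert is_moe_layer(FusedMoE())
assert is_moe_layer(AriaLikeMoE())        # subclasses are detected too
assert not is_moe_layer(nn.Linear(4, 4))  # non-MoE modules return False

This is why the class-name string comparisons in base_device_communicator.py and elastic_execute.py, which previously had to list both "FusedMoE" and "SharedFusedMoE", collapse to a single is_moe_layer(module) call once SharedFusedMoE is removed.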