10 changes: 3 additions & 7 deletions tests/kernels/moe/test_moe_layer.py
@@ -37,7 +37,7 @@
     get_eplb_group,
 )
 from vllm.forward_context import set_forward_context
-from vllm.model_executor.layers.fused_moe import FusedMoE, SharedFusedMoE, fused_experts
+from vllm.model_executor.layers.fused_moe import FusedMoE, fused_experts
 from vllm.model_executor.layers.fused_moe.activation import MoEActivation
 from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
 from vllm.model_executor.layers.fused_moe.router.router_factory import (
@@ -858,11 +858,7 @@ def make_fused_moe_layer(
     quant_config, qw = make_quant_config(quantization, w1, w2, global_num_experts)

     kwargs = dict()
-    if shared_experts is None:
-        builder = FusedMoE
-    else:
-        builder = SharedFusedMoE
-        kwargs["shared_experts"] = shared_experts
+    kwargs["shared_experts"] = shared_experts

     # Add gate and routed_input_transform if provided
     if gate is not None:
@@ -872,7 +868,7 @@
         kwargs["routed_input_transform"] = routed_input_transform
         kwargs["routed_output_transform"] = routed_output_transform

-    layer = builder(
+    layer = FusedMoE(
        num_experts=global_num_experts,
        top_k=top_k,
        hidden_size=hidden_size,
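A minimal caller-side sketch of what this change means (constructor arguments beyond the ones visible in the hunks above are elided behind **kwargs and assumed): FusedMoE now accepts shared_experts directly, so the old builder selection between FusedMoE and SharedFusedMoE is unnecessary and shared_experts=None simply means "no shared experts".

# Sketch only; argument names other than shared_experts follow the hunks above.
from vllm.model_executor.layers.fused_moe import FusedMoE


def make_layer(global_num_experts, top_k, hidden_size, shared_experts=None, **kwargs):
    # A single class now covers both the plain and the shared-experts case.
    return FusedMoE(
        num_experts=global_num_experts,
        top_k=top_k,
        hidden_size=hidden_size,
        shared_experts=shared_experts,
        **kwargs,
    )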
20 changes: 10 additions & 10 deletions tests/kernels/moe/test_shared_fused_moe_routed_transform.py
@@ -1,9 +1,9 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """
-Tests for SharedFusedMoE with routed_input_transform.
+Tests for FusedMoE with routed_input_transform.

-Verifies that applying routed_input_transform inside SharedFusedMoE
+Verifies that applying routed_input_transform inside FusedMoE
 produces the same results as applying the transform manually outside.
 """

@@ -13,7 +13,7 @@

 from vllm.config import VllmConfig, set_current_vllm_config
 from vllm.forward_context import set_forward_context
-from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE
+from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.platforms import current_platform
 from vllm.utils.torch_utils import is_torch_equal_or_newer, set_random_seed

@@ -133,9 +133,9 @@ def test_routed_input_transform_inside_vs_outside(
     workspace_init,
     monkeypatch,
 ):
-    """Compare SharedFusedMoE with transform inside vs manually applying outside.
-    Method A (inside): SharedFusedMoE with routed_input_transform
-    Method B (outside): Manually transform, then SharedFusedMoE without transform
+    """Compare FusedMoE with transform inside vs manually applying outside.
+    Method A (inside): FusedMoE with routed_input_transform
+    Method B (outside): Manually transform, then FusedMoE without transform
     """
     if current_platform.is_rocm():
         monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1" if use_rocm_aiter else "0")
@@ -157,8 +157,8 @@ def test_routed_input_transform_inside_vs_outside(
     routed_transform = SimpleLinear(hidden_size, latent_size, dtype)

     with set_current_vllm_config(vllm_config):
-        # Method A: SharedFusedMoE WITH routed_input_transform
-        moe_with_transform = SharedFusedMoE(
+        # Method A: FusedMoE WITH routed_input_transform
+        moe_with_transform = FusedMoE(
             shared_experts=shared_experts,
             routed_input_transform=routed_transform,
             num_experts=num_experts,
@@ -173,9 +173,9 @@ def test_routed_input_transform_inside_vs_outside(
             prefix="moe_with_transform",
         )

-        # Method B: SharedFusedMoE WITHOUT routed_input_transform
+        # Method B: FusedMoE WITHOUT routed_input_transform
         # Note: shared_experts=None because when transform is done outside,
-        moe_without_transform = SharedFusedMoE(
+        moe_without_transform = FusedMoE(
             shared_experts=None,
             routed_input_transform=None,
             num_experts=num_experts,
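In other words, the property this test exercises: with shared experts present, routed_input_transform only changes what the routed experts see; the shared experts still consume the raw hidden states, so applying the transform inside the layer or by hand outside must agree. A toy, self-contained sketch of that contract (illustrative only, not vLLM's actual FusedMoE forward):

import torch
import torch.nn as nn


class ToyMoE(nn.Module):
    """Toy stand-in: shared path sees x, routed path sees in_transform(x)."""

    def __init__(self, shared, routed, in_transform=None, out_transform=None):
        super().__init__()
        self.shared = shared
        self.routed = routed
        self.in_transform = in_transform
        self.out_transform = out_transform

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shared_out = self.shared(x)
        routed_in = self.in_transform(x) if self.in_transform is not None else x
        routed_out = self.routed(routed_in)
        if self.out_transform is not None:
            routed_out = self.out_transform(routed_out)
        return shared_out + routed_out


# Transforming inside the layer vs. manually outside gives the same result.
hidden, latent = 16, 8
shared, routed = nn.Linear(hidden, hidden), nn.Linear(latent, latent)
down, up = nn.Linear(hidden, latent), nn.Linear(latent, hidden)
x = torch.randn(4, hidden)
inside = ToyMoE(shared, routed, down, up)(x)
outside = shared(x) + up(routed(down(x)))
torch.testing.assert_close(inside, outside)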
13 changes: 3 additions & 10 deletions vllm/distributed/device_communicators/base_device_communicator.py
@@ -7,6 +7,8 @@
 import torch.distributed as dist
 from torch.distributed import ProcessGroup

+from vllm.utils import is_moe_layer
+

 class Cache:
     def __init__(self):
@@ -317,16 +319,7 @@ def prepare_communication_buffer_for_model(self, model: torch.nn.Module) -> None:
         if not self.is_ep_communicator:
             return

-        moe_modules = [
-            module
-            for module in model.modules()
-            # TODO(bnell): Should use isinstance but can't. Maybe search for
-            # presence of quant_method.maybe_init_modular_kernel?
-            if (
-                module.__class__.__name__ == "FusedMoE"
-                or module.__class__.__name__ == "SharedFusedMoE"
-            )
-        ]
+        moe_modules = [module for module in model.modules() if is_moe_layer(module)]
         for module in moe_modules:
             module.maybe_init_modular_kernel()

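The is_moe_layer helper imported above is introduced elsewhere in this PR and its body is not part of this diff. Given the removed TODO ("Should use isinstance but can't"), one plausible shape for such a predicate, an assumption rather than the actual vllm.utils code, is an isinstance check behind a deferred import:

# Assumed sketch; the real vllm.utils.is_moe_layer is not shown in this diff.
import torch.nn as nn


def is_moe_layer(module: nn.Module) -> bool:
    # A deferred import sidesteps the reason the old TODO gave for avoiding
    # isinstance (likely an import cycle), while still allowing a real type
    # check now that SharedFusedMoE no longer exists as a separate class.
    from vllm.model_executor.layers.fused_moe import FusedMoE

    return isinstance(module, FusedMoE)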
6 changes: 2 additions & 4 deletions vllm/distributed/elastic_ep/elastic_execute.py
@@ -38,6 +38,7 @@
 from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.layer import FusedMoEParallelConfig
+from vllm.utils import is_moe_layer
 from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
 from vllm.v1.worker.gpu_ubatch_wrapper import UBatchWrapper
 from vllm.v1.worker.workspace import lock_workspace, unlock_workspace
@@ -319,10 +320,7 @@ def switch_and_prepare(self) -> None:
         moe_modules = [
             module
             for module in self.worker.model_runner.model.modules()
-            if (
-                module.__class__.__name__ == "FusedMoE"
-                or module.__class__.__name__ == "SharedFusedMoE"
-            )
+            if is_moe_layer(module)
         ]
         num_local_experts = moe_modules[0].moe_config.num_local_experts
         assert all(
4 changes: 2 additions & 2 deletions vllm/lora/layers/fused_moe.py
@@ -610,7 +610,7 @@ def can_replace_layer(
     ) -> bool:
         """Returns True if the layer can be replaced by this LoRA layer."""

-        # source_layer is FusedMoE or SharedFusedMoE
+        # source_layer is FusedMoE
         return isinstance(source_layer, FusedMoE) and len(packed_modules_list) == 2

@@ -772,5 +772,5 @@ def can_replace_layer(
         model_config: PretrainedConfig | None = None,
     ) -> bool:
         """Returns True if the layer can be replaced by this LoRA layer."""
-        # source_layer is FusedMoE or SharedFusedMoE
+        # source_layer is FusedMoE
         return isinstance(source_layer, FusedMoE) and len(packed_modules_list) == 1
2 changes: 0 additions & 2 deletions vllm/model_executor/layers/fused_moe/__init__.py
@@ -29,7 +29,6 @@
     FusedMoERouter,
 )
 from vllm.model_executor.layers.fused_moe.router.gate_linear import GateLinear
-from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE
 from vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method import (
     UnquantizedFusedMoEMethod,
 )
@@ -64,7 +63,6 @@ def get_config() -> dict[str, Any] | None:
     "FusedMoEPrepareAndFinalizeModular",
     "GateLinear",
     "RoutingMethodType",
-    "SharedFusedMoE",
     "activation_without_mul",
     "apply_moe_activation",
     "override_config",
25 changes: 0 additions & 25 deletions vllm/model_executor/layers/fused_moe/shared_fused_moe.py

This file was deleted.

8 changes: 4 additions & 4 deletions vllm/model_executor/models/AXK1.py
@@ -42,7 +42,7 @@
 )
 from vllm.logger import init_logger
 from vllm.model_executor.layers.attention import Attention
-from vllm.model_executor.layers.fused_moe import SharedFusedMoE
+from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     ColumnParallelLinear,
@@ -163,7 +163,7 @@ def __init__(
             prefix=f"{prefix}.shared_experts",
         )

-        self.experts = SharedFusedMoE(
+        self.experts = FusedMoE(
             shared_experts=self.shared_experts,
             gate=self.gate,
             num_experts=config.n_routed_experts,
@@ -916,7 +916,7 @@ def compute_logits(
     def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
-        return SharedFusedMoE.make_expert_params_mapping(
+        return FusedMoE.make_expert_params_mapping(
             self,
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
@@ -950,7 +950,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:

         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
-        expert_params_mapping = SharedFusedMoE.make_expert_params_mapping(
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
             self,
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
10 changes: 5 additions & 5 deletions vllm/model_executor/models/afmoe.py
@@ -18,7 +18,7 @@
 )
 from vllm.logger import init_logger
 from vllm.model_executor.layers.attention import Attention
-from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE
+from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     ColumnParallelLinear,
@@ -124,8 +124,8 @@ def __init__(
             prefix=f"{prefix}.shared_experts",
         )

-        # Routed experts using SharedFusedMoE
-        self.experts = SharedFusedMoE(
+        # Routed experts using FusedMoE
+        self.experts = FusedMoE(
             shared_experts=self.shared_experts,
             num_experts=config.num_experts,
             top_k=config.num_experts_per_tok,
@@ -479,7 +479,7 @@ def make_empty_intermediate_tensors(
     def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
-        return SharedFusedMoE.make_expert_params_mapping(
+        return FusedMoE.make_expert_params_mapping(
             self,
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
@@ -637,7 +637,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.num_moe_layers = config.num_hidden_layers - config.num_dense_layers
         self.num_expert_groups = config.n_group

-        self.moe_layers: list[SharedFusedMoE] = []
+        self.moe_layers: list[FusedMoE] = []
         example_moe = None
         for layer in self.model.layers:
             if isinstance(layer, PPMissingLayer):
4 changes: 2 additions & 2 deletions vllm/model_executor/models/aria.py
@@ -14,7 +14,7 @@
 from vllm.distributed import get_tensor_model_parallel_rank
 from vllm.inputs import MultiModalDataDict
 from vllm.model_executor.layers.activation import get_act_fn
-from vllm.model_executor.layers.fused_moe import SharedFusedMoE
+from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -214,7 +214,7 @@ def forward(
         return out


-class AriaFusedMoE(SharedFusedMoE):
+class AriaFusedMoE(FusedMoE):
     def weight_loader(
         self, param: nn.Parameter, loaded_weight: torch.Tensor, shard_id: str
     ) -> None:
6 changes: 3 additions & 3 deletions vllm/model_executor/models/bailing_moe.py
@@ -41,7 +41,7 @@
 )
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.attention import Attention
-from vllm.model_executor.layers.fused_moe import SharedFusedMoE
+from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
@@ -285,7 +285,7 @@ def __init__(
         else:
             self.shared_experts = None

-        self.experts = SharedFusedMoE(
+        self.experts = FusedMoE(
             shared_experts=self.shared_experts,
             num_experts=self.num_experts,
             top_k=self.top_k,
@@ -461,7 +461,7 @@ def forward(
         return hidden_states

     def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
-        return SharedFusedMoE.make_expert_params_mapping(
+        return FusedMoE.make_expert_params_mapping(
             self,
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
6 changes: 3 additions & 3 deletions vllm/model_executor/models/bailing_moe_linear.py
@@ -21,7 +21,7 @@
     RMSNormGated,
     layernorm_fn,
 )
-from vllm.model_executor.layers.fused_moe import FusedMoE, SharedFusedMoE
+from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     ColumnParallelLinear,
@@ -351,8 +351,8 @@ def __init__(
         else:
             self.shared_experts = None

-        # Routed experts using SharedFusedMoE
-        self.experts = SharedFusedMoE(
+        # Routed experts using FusedMoE
+        self.experts = FusedMoE(
             shared_experts=self.shared_experts,
             num_experts=self.num_experts,
             top_k=self.top_k,
4 changes: 2 additions & 2 deletions vllm/model_executor/models/deepseek_mtp.py
@@ -11,7 +11,7 @@
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
-from vllm.model_executor.layers.fused_moe import SharedFusedMoE
+from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -252,7 +252,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         ]
         stacked_params_mapping.extend(indexer_fused_mapping)

-        expert_params_mapping = SharedFusedMoE.make_expert_params_mapping(
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
             self,
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
8 changes: 4 additions & 4 deletions vllm/model_executor/models/deepseek_v2.py
@@ -48,9 +48,9 @@
 from vllm.model_executor.layers.attention import Attention
 from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
 from vllm.model_executor.layers.fused_moe import (
+    FusedMoE,
     GateLinear,
     RoutingMethodType,
-    SharedFusedMoE,
 )
 from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm
 from vllm.model_executor.layers.linear import (
@@ -311,7 +311,7 @@ def __init__(
             prefix=f"{prefix}.shared_experts",
         )

-        self.experts = SharedFusedMoE(
+        self.experts = FusedMoE(
             shared_experts=self.shared_experts,
             gate=self.gate,
             num_experts=config.n_routed_experts,
@@ -1432,7 +1432,7 @@ def compute_logits(
     def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
-        return SharedFusedMoE.make_expert_params_mapping(
+        return FusedMoE.make_expert_params_mapping(
             self,
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
@@ -1474,7 +1474,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:

         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
-        expert_params_mapping = SharedFusedMoE.make_expert_params_mapping(
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
             self,
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",