tests/test_routing_simulator.py (2 changes: 1 addition & 1 deletion)
@@ -127,7 +127,7 @@ def test_routing_strategy_integration(monkeypatch, device):
     envs.environment_variables[env_name] = lambda s=strategy: s
 
     # Test the select_experts method
-    topk_weights, topk_ids, _ = fused_moe.select_experts(
+    topk_weights, topk_ids, _ = fused_moe.router.select_experts(
         hidden_states=hidden_states,
         router_logits=router_logits,
     )
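Expert selection now goes through the layer's router object instead of a method on the layer itself. A minimal sketch of exercising the new entry point, assuming a constructed FusedMoE layer named fused_moe and illustrative tensor shapes (4 tokens, hidden size 128, 8 experts):

import torch

# Illustrative inputs; the real test builds these from the layer's config.
hidden_states = torch.randn(4, 128)
router_logits = torch.randn(4, 8)

# select_experts now lives behind layer.router (a FusedMoERouter).
topk_weights, topk_ids, zero_expert_result = fused_moe.router.select_experts(
    hidden_states=hidden_states,
    router_logits=router_logits,
)
# topk_weights and topk_ids have shape (num_tokens, top_k);
# zero_expert_result is None unless zero experts are configured.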
vllm/model_executor/layers/fused_moe/__init__.py (10 changes: 10 additions & 0 deletions)
@@ -6,11 +6,18 @@

 from vllm.model_executor.layers.fused_moe.config import (
     FusedMoEConfig,
+    FusedMoEQuantConfig,
     RoutingMethodType,
 )
 from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
     FusedMoEMethodBase,
 )
+from vllm.model_executor.layers.fused_moe.fused_moe_params import (
+    FusedMoEParams,
+)
+from vllm.model_executor.layers.fused_moe.fused_moe_router import (
+    FusedMoERouter,
+)
 from vllm.model_executor.layers.fused_moe.layer import (
     FusedMoE,
     FusedMoeWeightScaleSupported,
@@ -45,7 +52,10 @@ def get_config() -> dict[str, Any] | None:

 __all__ = [
     "FusedMoE",
+    "FusedMoEParams",
+    "FusedMoERouter",
     "FusedMoEConfig",
+    "FusedMoEQuantConfig",
     "FusedMoEMethodBase",
     "UnquantizedFusedMoEMethod",
     "FusedMoeWeightScaleSupported",
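The package root now re-exports the new abstractions alongside FusedMoEQuantConfig, so downstream code can import them without reaching into submodules; for example:

# Public imports after this change; the submodule paths still work too.
from vllm.model_executor.layers.fused_moe import (
    FusedMoEParams,
    FusedMoEQuantConfig,
    FusedMoERouter,
)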
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py (13 changes: 0 additions & 13 deletions)
@@ -13,9 +13,6 @@
     batched_moe_align_block_size,
     moe_align_block_size,
 )
-from vllm.model_executor.layers.fused_moe.prepare_finalize import (
-    MoEPrepareAndFinalizeNoEP,
-)
 from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
     TopKWeightAndReduceDelegate,
     TopKWeightAndReduceNoOP,
@@ -711,16 +708,6 @@ def moe_sum(self, input: torch.Tensor, output: torch.Tensor) -> None:
         ops.moe_sum(input, output)
 
 
-def modular_marlin_fused_moe(
-    quant_config: FusedMoEQuantConfig, shared_experts: torch.nn.Module | None = None
-) -> mk.FusedMoEModularKernel:
-    return mk.FusedMoEModularKernel(
-        MoEPrepareAndFinalizeNoEP(),
-        MarlinExperts(quant_config),
-        shared_experts,
-    )
-
-
 class BatchedMarlinExperts(MarlinExpertsBase):
     def __init__(
         self,
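With modular_marlin_fused_moe removed, a caller that still needs this wiring would build the modular kernel inline. A sketch of the equivalent construction, assuming quant_config and shared_experts come from the caller:

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import MarlinExperts
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
    MoEPrepareAndFinalizeNoEP,
)

# Inline equivalent of the deleted helper: no-EP prepare/finalize around
# the Marlin expert kernels, with optional shared experts.
kernel = mk.FusedMoEModularKernel(
    MoEPrepareAndFinalizeNoEP(),
    MarlinExperts(quant_config),  # quant_config: FusedMoEQuantConfig
    shared_experts,               # torch.nn.Module | None
)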
vllm/model_executor/layers/fused_moe/fused_moe_method_base.py
@@ -10,6 +10,9 @@
     FusedMoEConfig,
     FusedMoEQuantConfig,
 )
+from vllm.model_executor.layers.fused_moe.fused_moe_params import (
+    FusedMoEParams,
+)
 from vllm.model_executor.layers.fused_moe.modular_kernel import (
     FusedMoEPermuteExpertsUnpermute,
     FusedMoEPrepareAndFinalize,
@@ -94,9 +97,9 @@
         return self.__class__.__name__
 
     @abstractmethod
     def apply(
[GitHub Actions / pre-commit check failure on line 100: Signature of "apply" incompatible with supertype "QuantizeMethodBase" [override]]
         self,
-        layer: "FusedMoE",  # type: ignore[name-defined]  # noqa: F821
+        params: FusedMoEParams,
         x: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
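The mypy failure above is a direct consequence of the rename: the supertype QuantizeMethodBase types apply's first positional parameter as the layer module, so this override no longer matches until the base signature is adjusted or the override is explicitly ignored. For orientation, a minimal sketch of a subclass under the new signature; the kernel body is elided and the weight attribute names are assumptions carried over from FusedMoE:

import torch

from vllm.model_executor.layers.fused_moe import (
    FusedMoEMethodBase,
    FusedMoEParams,
)


class SketchMoEMethod(FusedMoEMethodBase):
    # Other abstract methods (e.g. weight creation) omitted for brevity.
    def apply(
        self,
        params: FusedMoEParams,
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        # Routing is delegated to the router carried by params.
        topk_weights, topk_ids, _ = params.router.select_experts(
            hidden_states=x, router_logits=router_logits
        )
        # ... run the expert kernels with params.w13_weight / params.w2_weight ...
        raise NotImplementedError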
vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py (23 changes: 13 additions & 10 deletions)
@@ -6,12 +6,15 @@

 from vllm.logger import init_logger
 from vllm.model_executor.custom_op import CustomOp
-from vllm.model_executor.layers.fused_moe.config import (
+from vllm.model_executor.layers.fused_moe import (
     FusedMoEQuantConfig,
 )
 from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
     FusedMoEMethodBase,
 )
+from vllm.model_executor.layers.fused_moe.fused_moe_params import (
+    FusedMoEParams,
+)
 from vllm.model_executor.layers.fused_moe.modular_kernel import (
     FusedMoEModularKernel,
     FusedMoEPrepareAndFinalize,
@@ -91,31 +94,31 @@
     ) -> FusedMoEQuantConfig | None:
         return self.moe_quant_config
 
     def apply(
[GitHub Actions / pre-commit check failure on line 97: Signature of "apply" incompatible with supertype "QuantizeMethodBase" [override]]
         self,
-        layer: "FusedMoE",  # type: ignore[name-defined]  # noqa: F821
+        params: FusedMoEParams,
         x: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
-        topk_weights, topk_ids, zero_expert_result = layer.select_experts(
+        topk_weights, topk_ids, zero_expert_result = params.router.select_experts(
             hidden_states=x,
             router_logits=router_logits,
         )
 
         result = self.fused_experts(
             hidden_states=x,
-            w1=layer.w13_weight,
-            w2=layer.w2_weight,
+            w1=params.w13_weight,
+            w2=params.w2_weight,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
             inplace=self.allow_inplace,
-            activation=layer.activation,
-            global_num_experts=layer.global_num_experts,
-            apply_router_weight_on_input=layer.apply_router_weight_on_input,
-            expert_map=None if self.disable_expert_map else layer.expert_map,
+            activation=params.activation,
+            global_num_experts=params.global_num_experts,
+            apply_router_weight_on_input=params.apply_router_weight_on_input,
+            expert_map=None if self.disable_expert_map else params.expert_map,
         )
 
-        if layer.zero_expert_num != 0 and layer.zero_expert_type is not None:
+        if params.zero_expert_num != 0 and params.zero_expert_type is not None:
             assert not isinstance(result, tuple), (
                 "Shared + zero experts are mutually exclusive not yet supported"
             )
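Taken together, the attribute reads above sketch the contract params must satisfy beyond the router field that FusedMoEParams declares; today FusedMoE meets it by inheritance. A hypothetical typing.Protocol capturing that implied surface (illustrative only, not part of the diff):

from typing import Protocol

import torch

from vllm.model_executor.layers.fused_moe import FusedMoERouter


class MoEParamsLike(Protocol):
    # What fused_moe_modular_method.apply actually touches on `params`.
    router: FusedMoERouter
    w13_weight: torch.Tensor
    w2_weight: torch.Tensor
    activation: str
    global_num_experts: int
    apply_router_weight_on_input: bool
    expert_map: torch.Tensor | None
    zero_expert_num: int
    zero_expert_type: str | None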
vllm/model_executor/layers/fused_moe/fused_moe_params.py (11 changes: 11 additions & 0 deletions)
@@ -0,0 +1,11 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import torch
+
+from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
+
+
+class FusedMoEParams(torch.nn.Module):
+    def __init__(self, router: FusedMoERouter):
+        super().__init__()
+        self.router = router
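The new file is deliberately thin: an nn.Module that owns a router and nothing else, so quantization methods can type against it while FusedMoE supplies the remaining attributes. Note that FusedMoERouter is a plain ABC rather than an nn.Module, so router is stored as an ordinary attribute, not a registered submodule. A small sketch, where stub_router stands in for any FusedMoERouter implementation:

import torch

from vllm.model_executor.layers.fused_moe import FusedMoEParams

params = FusedMoEParams(router=stub_router)
assert isinstance(params, torch.nn.Module)
# The class itself carries no tensors; weights come from subclasses
# such as FusedMoE.
assert len(list(params.parameters())) == 0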
vllm/model_executor/layers/fused_moe/fused_moe_router.py (28 changes: 28 additions & 0 deletions)
@@ -0,0 +1,28 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from abc import ABC, abstractmethod
+
+import torch
+
+from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
+
+
+# TODO: add eplb stuff here
+class FusedMoERouter(ABC):
+    @property
+    @abstractmethod
+    def enable_eplb(self) -> bool:
+        raise NotImplementedError
+
+    @property
+    @abstractmethod
+    def routing_method_type(self) -> RoutingMethodType:
+        raise NotImplementedError
+
+    @abstractmethod
+    def select_experts(
+        self,
+        hidden_states: torch.Tensor,
+        router_logits: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
+        raise NotImplementedError
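A minimal concrete router against this interface, for illustration only: plain softmax-then-top-k gating with an arbitrary top_k, and assuming RoutingMethodType has a TopK member:

import torch

from vllm.model_executor.layers.fused_moe import FusedMoERouter
from vllm.model_executor.layers.fused_moe.config import RoutingMethodType


class SoftmaxTopKRouter(FusedMoERouter):
    def __init__(self, top_k: int = 2):
        self.top_k = top_k

    @property
    def enable_eplb(self) -> bool:
        return False

    @property
    def routing_method_type(self) -> RoutingMethodType:
        return RoutingMethodType.TopK  # assumed enum member

    def select_experts(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
        # Softmax over experts, then keep the top_k weights and indices.
        probs = torch.softmax(router_logits, dim=-1, dtype=torch.float32)
        topk_weights, topk_ids = torch.topk(probs, self.top_k, dim=-1)
        return topk_weights, topk_ids, None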
vllm/model_executor/layers/fused_moe/layer.py (36 changes: 30 additions & 6 deletions)
@@ -33,6 +33,8 @@
     RoutingMethodType,
 )
 from vllm.model_executor.layers.fused_moe.fused_moe import zero_experts_compute_triton
+from vllm.model_executor.layers.fused_moe.fused_moe_params import FusedMoEParams
+from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
 from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
     init_aiter_topK_meta_data,
 )
@@ -297,8 +299,29 @@ def maybe_roundup_hidden_size(
     return hidden_size
 
 
+class FusedMoERouterImpl(FusedMoERouter):
+    def __init__(self, layer: "FusedMoE"):
+        super().__init__()
+        self.layer = layer
+
+    @property
+    def enable_eplb(self) -> bool:
+        return self.layer.enable_eplb
+
+    @property
+    def routing_method_type(self) -> RoutingMethodType:
+        return self.layer.routing_method_type
+
+    def select_experts(
+        self,
+        hidden_states: torch.Tensor,
+        router_logits: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
+        return self.layer.select_experts(hidden_states, router_logits)
+
+
 @CustomOp.register("fused_moe")
-class FusedMoE(CustomOp):
+class FusedMoE(CustomOp, FusedMoEParams):
     """FusedMoE layer for MoE models.
 
     This layer contains both MergedColumnParallel weights (gate_up_proj /
@@ -353,9 +376,10 @@ def __init__(
         zero_expert_type: str | None = None,
         expert_mapping: list[tuple[str, str, int, str]] | None = None,
         n_shared_experts: int | None = None,
-        routing_method_type: int | None = None,
+        routing_method_type: RoutingMethodType | None = None,
     ):
-        super().__init__()
+        CustomOp.__init__(self)
+        FusedMoEParams.__init__(self, router=FusedMoERouterImpl(self))
 
         # Allow disabling of the separate shared experts stream for
         # debug purposes.
@@ -542,7 +566,7 @@ def __init__(

         # ToDo: Better logic to determine the routing method type
         if routing_method_type is not None:
-            self.routing_method_type = routing_method_type
+            self.routing_method_type: RoutingMethodType = routing_method_type
         else:
             if scoring_func == "sigmoid":
                 if self.use_grouped_topk:
@@ -1809,7 +1833,7 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False):

             # Matrix multiply.
             final_hidden_states = self.quant_method.apply(
-                layer=self,
+                params=self,
                 x=staged_hidden_states,
                 router_logits=staged_router_logits,
             )
@@ -1952,7 +1976,7 @@ def forward_impl(

         # Matrix multiply.
         final_hidden_states = self.quant_method.apply(
-            layer=self,
+            params=self,
             x=hidden_states_combined
             if do_naive_dispatch_combine
             else hidden_states,
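Two details in this file are worth noting. First, FusedMoERouterImpl is a pure adapter: it forwards every FusedMoERouter hook back to the layer, so the existing FusedMoE.select_experts logic keeps working while callers migrate to the router interface. Second, because CustomOp and FusedMoEParams are both nn.Module subclasses, the constructor now calls each base's __init__ explicitly instead of relying on one cooperative super().__init__() chain; this is the usual escape hatch when base classes take different constructor arguments. The same pattern in miniature, with illustrative names:

import torch


class Holder(torch.nn.Module):
    def __init__(self, router):
        super().__init__()
        self.router = router


class Op(torch.nn.Module):
    pass


class Layer(Op, Holder):
    def __init__(self):
        # Explicit per-base initialization, mirroring FusedMoE.__init__;
        # nn.Module.__init__ may run more than once, which is harmless here
        # since Op registers nothing before Holder re-initializes the module.
        Op.__init__(self)
        Holder.__init__(self, router=object())  # stand-in router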