diff --git a/benchmarks/kernels/benchmark_cutlass_moe_fp8.py b/benchmarks/kernels/benchmark_cutlass_moe_fp8.py index f1234d821347..b33282523db5 100644 --- a/benchmarks/kernels/benchmark_cutlass_moe_fp8.py +++ b/benchmarks/kernels/benchmark_cutlass_moe_fp8.py @@ -11,6 +11,7 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from tests.kernels.moe.utils import make_dummy_moe_config from vllm import _custom_ops as ops +from vllm.model_executor.layers.fused_moe.activation import MoEActivation from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8 from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts, fused_topk @@ -161,7 +162,7 @@ def bench_run( w2_fp8q_cutlass, topk_weights, topk_ids, - activation="silu", + activation=MoEActivation.SILU, global_num_experts=num_experts, ) torch.cuda.synchronize() diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py index c5e3dabe5796..5ee1cf1995c3 100644 --- a/benchmarks/kernels/benchmark_moe.py +++ b/benchmarks/kernels/benchmark_moe.py @@ -16,6 +16,7 @@ from ray.experimental.tqdm_ray import tqdm from vllm.model_executor.layers.fused_moe import fused_topk +from vllm.model_executor.layers.fused_moe.activation import MoEActivation from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, FusedMoEParallelConfig, @@ -211,7 +212,8 @@ def run(): hidden_dim=hidden_size, intermediate_size_per_partition=shard_intermediate_size, num_local_experts=num_experts, - activation="silu", + num_logical_experts=num_experts, + activation=MoEActivation.SILU, moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(), in_dtype=init_dtype, routing_method=RoutingMethodType.TopK, diff --git a/tests/kernels/moe/modular_kernel_tools/common.py b/tests/kernels/moe/modular_kernel_tools/common.py index 6dfcd5ebe51e..87cf0453bea1 100644 --- 
a/tests/kernels/moe/modular_kernel_tools/common.py +++ b/tests/kernels/moe/modular_kernel_tools/common.py @@ -22,6 +22,7 @@ ) from vllm.forward_context import set_forward_context from vllm.model_executor.layers.fused_moe import fused_topk +from vllm.model_executor.layers.fused_moe.activation import MoEActivation from vllm.model_executor.layers.fused_moe.all2all_utils import ( maybe_make_prepare_finalize, ) @@ -599,7 +600,7 @@ def next_power_of_2(x): moe_parallel_config=moe_parallel_config, in_dtype=config.dtype, max_num_tokens=next_power_of_2(config.M), - activation="silu", + activation=MoEActivation.SILU, device=vllm_config.device_config.device, routing_method=RoutingMethodType.DeepSeekV3, ) diff --git a/tests/kernels/moe/test_cpu_fused_moe.py b/tests/kernels/moe/test_cpu_fused_moe.py index 681f42091742..839eceeeb2fc 100644 --- a/tests/kernels/moe/test_cpu_fused_moe.py +++ b/tests/kernels/moe/test_cpu_fused_moe.py @@ -6,6 +6,7 @@ from tests.kernels.allclose_default import get_default_atol, get_default_rtol from vllm._custom_ops import cpu_fused_moe, cpu_prepack_moe_weight +from vllm.model_executor.layers.fused_moe.activation import MoEActivation from vllm.model_executor.layers.fused_moe.cpu_fused_moe import _CPU_MOE_ACT_FN from vllm.platforms import current_platform from vllm.utils.torch_utils import set_random_seed @@ -19,7 +20,7 @@ HIDDEN_DIM = [128, 2880] INTERMEDIATE_DIM = [128, 2880] BATCH_SIZE = [1, 64, 256] -ACT = ["silu", "swigluoai"] +ACT = [MoEActivation.SILU, MoEActivation.SWIGLUOAI] USE_BIAS = [True, False] ISA = ["amx", "vec"] if torch._C._cpu._is_amx_tile_supported() else ["vec"] DTYPE = [torch.bfloat16] @@ -33,7 +34,7 @@ def ref_fused_moe( w2_bias: torch.Tensor | None, topk_weights: torch.Tensor, topk_ids: torch.Tensor, - activation: str, + activation: MoEActivation, ) -> torch.Tensor: len_experts = w13.size(0) @@ -103,7 +104,7 @@ def test_cpu_fused_moe( intermediate_size: int, use_bias: bool, dtype: torch.dtype, - act: str, + act: MoEActivation, 
isa: str, ): set_random_seed(0) @@ -153,7 +154,7 @@ def test_cpu_fused_moe( w2_bias, topk_weight, topk_ids, - act, + act.value, isa, ) diff --git a/tests/kernels/moe/test_cutlass_moe.py b/tests/kernels/moe/test_cutlass_moe.py index d232d00fcbb9..ec23008dfa1f 100644 --- a/tests/kernels/moe/test_cutlass_moe.py +++ b/tests/kernels/moe/test_cutlass_moe.py @@ -12,6 +12,7 @@ from vllm import _custom_ops as ops from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk +from vllm.model_executor.layers.fused_moe.activation import MoEActivation from vllm.model_executor.layers.fused_moe.config import ( FUSED_MOE_UNQUANTIZED_CONFIG, FusedMoEQuantConfig, @@ -531,7 +532,7 @@ def test_run_cutlass_moe_fp8( c_strides1 = torch.full((e,), 2 * n, device="cuda", dtype=torch.int64) c_strides2 = torch.full((e,), k, device="cuda", dtype=torch.int64) - activation = "silu" + activation = MoEActivation.SILU a1q, a1q_scale = moe_kernel_quantize_input( mt.a, mt.a_scale, torch.float8_e4m3fn, per_act_token ) diff --git a/tests/kernels/moe/test_deepep_deepgemm_moe.py b/tests/kernels/moe/test_deepep_deepgemm_moe.py index 11f5357157d2..2b8240482829 100644 --- a/tests/kernels/moe/test_deepep_deepgemm_moe.py +++ b/tests/kernels/moe/test_deepep_deepgemm_moe.py @@ -16,6 +16,7 @@ from vllm.config import VllmConfig, set_current_vllm_config from vllm.forward_context import set_forward_context +from vllm.model_executor.layers.fused_moe.activation import MoEActivation from vllm.model_executor.layers.fused_moe.config import ( FusedMoEQuantConfig, fp8_w8a8_moe_quant_config, @@ -324,7 +325,7 @@ def build_expert_map(): w2=w2, topk_weights=test_tensors.topk_weights, topk_ids=test_tensors.topk, - activation="silu", + activation=MoEActivation.SILU, global_num_experts=num_experts, expert_map=build_expert_map(), apply_router_weight_on_input=False, diff --git a/tests/kernels/moe/test_deepep_moe.py 
b/tests/kernels/moe/test_deepep_moe.py index 8d3ca165076c..01f340730af3 100644 --- a/tests/kernels/moe/test_deepep_moe.py +++ b/tests/kernels/moe/test_deepep_moe.py @@ -15,6 +15,7 @@ from vllm.config import VllmConfig, set_current_vllm_config from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import TritonExperts +from vllm.model_executor.layers.fused_moe.activation import MoEActivation from vllm.model_executor.layers.fused_moe.config import ( FusedMoEQuantConfig, ) @@ -260,7 +261,7 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False): w2=w2, topk_weights=topk_weights_chunk, topk_ids=topk_chunk, - activation="silu", + activation=MoEActivation.SILU, global_num_experts=num_experts, expert_map=build_expert_map(), apply_router_weight_on_input=False, diff --git a/tests/kernels/moe/test_flashinfer.py b/tests/kernels/moe/test_flashinfer.py index c5d34ef0b603..9c31d9325962 100644 --- a/tests/kernels/moe/test_flashinfer.py +++ b/tests/kernels/moe/test_flashinfer.py @@ -7,6 +7,7 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config +from vllm.model_executor.layers.fused_moe.activation import MoEActivation from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, FusedMoEParallelConfig, @@ -93,9 +94,14 @@ class TestData: @staticmethod def make_moe_tensors_8bit( - m: int, k: int, n: int, e: int, is_trtllm: bool, activation: str = "silu" + m: int, + k: int, + n: int, + e: int, + is_trtllm: bool, + activation: MoEActivation = MoEActivation.SILU, ) -> "TestData": - is_gated = activation != "relu2_no_mul" + is_gated = activation.is_gated hidden_states = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10 w13 = torch.randn( @@ -194,7 +200,7 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph( topk_weights=topk_weights, topk_ids=topk_ids, inplace=False, - activation="silu", + 
activation=MoEActivation.SILU, global_num_experts=e, expert_map=None, apply_router_weight_on_input=True, @@ -219,21 +225,19 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph( @pytest.mark.parametrize("m,n,k", MNK_FACTORS) @pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOP_KS) -@pytest.mark.parametrize("activation", ["silu", "relu2_no_mul"]) +@pytest.mark.parametrize("activation", [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]) def test_flashinfer_cutlass_moe_fp8_no_graph( m: int, n: int, k: int, e: int, topk: int, - activation: str, + activation: MoEActivation, monkeypatch, workspace_init, ): set_random_seed(7) monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192") - assert activation in ["silu", "relu2_no_mul"] - is_act_and_mul = activation == "silu_and_mul" with set_current_vllm_config(vllm_config): td = TestData.make_moe_tensors_8bit( m, k, n, e, is_trtllm=False, activation=activation @@ -292,7 +296,7 @@ def get_fused_moe_quant_config(n: torch.nn.Module) -> FusedMoEQuantConfig: device="cuda", moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(), in_dtype=torch.bfloat16, - is_act_and_mul=is_act_and_mul, + is_act_and_mul=activation.is_gated, routing_method=RoutingMethodType.TopK, ) diff --git a/tests/kernels/moe/test_flashinfer_moe.py b/tests/kernels/moe/test_flashinfer_moe.py index c61bca31360e..1f1349cff841 100644 --- a/tests/kernels/moe/test_flashinfer_moe.py +++ b/tests/kernels/moe/test_flashinfer_moe.py @@ -13,6 +13,7 @@ from vllm import _custom_ops as ops from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config from vllm.model_executor.layers.fused_moe import fused_topk +from vllm.model_executor.layers.fused_moe.activation import MoEActivation from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, FusedMoEParallelConfig, @@ -54,7 +55,7 @@ @pytest.mark.parametrize("e", [40, 64, 256]) @pytest.mark.parametrize("topk", [1, 6, 8]) @pytest.mark.parametrize("dtype", [torch.bfloat16]) 
-@pytest.mark.parametrize("activation", ["silu_and_mul", "relu2"]) +@pytest.mark.parametrize("activation", [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]) @torch.inference_mode() def test_flashinfer_fp4_moe_no_graph( m: int, @@ -63,7 +64,7 @@ def test_flashinfer_fp4_moe_no_graph( e: int, topk: int, dtype: torch.dtype, - activation: str, + activation: MoEActivation, workspace_init, ): set_random_seed(7) @@ -73,7 +74,7 @@ def test_flashinfer_fp4_moe_no_graph( a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 quant_blocksize = 16 - is_gated_act = activation == "silu_and_mul" + is_gated_act = activation.is_gated w1_q, w2_q, quant_config = make_test_quant_config( e, @@ -112,15 +113,13 @@ def test_flashinfer_fp4_moe_no_graph( inplace=False, ) - fi_activation = {"silu_and_mul": "silu", "relu2": "relu2_no_mul"}[activation] - flashinfer_output = flashinfer_experts( hidden_states=a, w1=w1_q, w2=w2_q, topk_weights=topk_weights, topk_ids=topk_ids, - activation=fi_activation, + activation=activation, ) # Reference check: diff --git a/tests/kernels/moe/test_modular_oai_triton_moe.py b/tests/kernels/moe/test_modular_oai_triton_moe.py index bebf18ef0aaf..cf9ff18634d0 100644 --- a/tests/kernels/moe/test_modular_oai_triton_moe.py +++ b/tests/kernels/moe/test_modular_oai_triton_moe.py @@ -7,6 +7,7 @@ import pytest import torch +from vllm.model_executor.layers.fused_moe.activation import MoEActivation from vllm.utils.import_utils import has_triton_kernels if not has_triton_kernels(): @@ -192,7 +193,7 @@ def oai_triton_moe_impl( w2=w2, topk_weights=topk_weights, topk_ids=topk_ids, - activation="swigluoai", + activation=MoEActivation.SWIGLUOAI, global_num_experts=num_experts, expert_map=None, apply_router_weight_on_input=False, diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index 6a622ac8e4d5..eddc395ccbb7 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -29,6 +29,7 @@ from vllm.distributed.parallel_state import 
init_distributed_environment from vllm.forward_context import get_forward_context, set_forward_context from vllm.model_executor.layers.fused_moe import ( + MoEActivation, fused_topk, ) from vllm.model_executor.layers.fused_moe.config import ( @@ -1155,7 +1156,10 @@ def test_fused_marlin_moe_with_bias(m): @pytest.mark.parametrize("m", [1, 64, 256]) @pytest.mark.parametrize("n,k", [(1024, 1024), (2048, 2048)]) @pytest.mark.parametrize("e,topk", [(8, 2), (64, 4)]) -def test_fused_marlin_moe_non_gated(m: int, n: int, k: int, e: int, topk: int): +@pytest.mark.parametrize("activation", [MoEActivation.RELU2_NO_MUL]) +def test_fused_marlin_moe_non_gated( + m: int, n: int, k: int, e: int, topk: int, activation: MoEActivation +): """Test Marlin MoE with non-gated activation (relu2_no_mul). Non-gated activations like relu2 don't have the gate-up projection pattern, @@ -1198,7 +1202,7 @@ def test_fused_marlin_moe_non_gated(m: int, n: int, k: int, e: int, topk: int): w2_data.w_ref, score, topk, - activation="relu2", + activation=activation, ) marlin_output = fused_marlin_moe( @@ -1223,7 +1227,7 @@ def test_fused_marlin_moe_non_gated(m: int, n: int, k: int, e: int, topk: int): w2_zeros=w2_data.zeros, quant_type_id=quant_type.id, is_k_full=is_k_full, - activation="relu2_no_mul", + activation=activation, ) torch.testing.assert_close(marlin_output, torch_output, atol=1e-1, rtol=0) @@ -1330,9 +1334,18 @@ def test_moe_sum(m: int, topk: int, k: int, dtype: torch.dtype): @pytest.mark.parametrize("topk", [2]) @pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16]) @pytest.mark.parametrize("with_bias", [False, True]) -@pytest.mark.parametrize("activation", ["silu"]) +@pytest.mark.parametrize("activation", [MoEActivation.SILU]) @pytest.mark.skipif(not current_platform.is_cpu(), reason="CPU only test") -def test_cpu_fused_moe_basic(m, n, k, e, topk, dtype, with_bias, activation): +def test_cpu_fused_moe_basic( + m: int, + n: int, + k: int, + e: int, + topk: int, + dtype: 
torch.dtype, + with_bias: bool, + activation: MoEActivation, +): from vllm.model_executor.layers.fused_moe.cpu_fused_moe import CPUFusedMOE device = "cpu" @@ -1608,6 +1621,7 @@ def test_unquantized_bf16_flashinfer_trtllm_backend( hidden_dim=k, intermediate_size_per_partition=n, num_local_experts=e, + num_logical_experts=e, - activation="silu", + activation=MoEActivation.SILU, device="cuda", moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(), diff --git a/tests/kernels/moe/test_pplx_cutlass_moe.py b/tests/kernels/moe/test_pplx_cutlass_moe.py index 894e57fe2d68..d8a6600743e2 100644 --- a/tests/kernels/moe/test_pplx_cutlass_moe.py +++ b/tests/kernels/moe/test_pplx_cutlass_moe.py @@ -9,6 +9,7 @@ from vllm import _custom_ops as ops from vllm.config import VllmConfig, set_current_vllm_config from vllm.model_executor.layers.fused_moe import fused_topk +from vllm.model_executor.layers.fused_moe.activation import MoEActivation from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, FusedMoEParallelConfig, @@ -149,7 +150,7 @@ def make_moe_config() -> FusedMoEConfig: num_local_experts=num_local_experts, num_logical_experts=num_experts, moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(), - activation="silu", + activation=MoEActivation.SILU, in_dtype=torch.bfloat16, device="cuda", routing_method=RoutingMethodType.Llama4, diff --git a/tests/kernels/moe/test_triton_moe_no_act_mul.py b/tests/kernels/moe/test_triton_moe_no_act_mul.py index ab15f898b625..1dfac3cf0fdc 100644 --- a/tests/kernels/moe/test_triton_moe_no_act_mul.py +++ b/tests/kernels/moe/test_triton_moe_no_act_mul.py @@ -11,15 +11,11 @@ import torch from tests.kernels.moe.utils import make_dummy_moe_config +from vllm.model_executor.layers.fused_moe.activation import MoEActivation from vllm.model_executor.layers.fused_moe.config import ( FUSED_MOE_UNQUANTIZED_CONFIG, ) from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts -from vllm.model_executor.layers.fused_moe.utils import ( - GELU_NO_MUL, - 
RELU2_NO_MUL, - SILU_NO_MUL, -) from vllm.platforms import current_platform # Test parameters @@ -28,7 +24,11 @@ K_SIZES = [64, 128] TOPK_VALUES = [1, 2] NUM_EXPERTS = 8 -NO_MUL_ACTIVATIONS = [SILU_NO_MUL, GELU_NO_MUL, RELU2_NO_MUL] +NO_MUL_ACTIVATIONS = [ + MoEActivation.SILU_NO_MUL, + MoEActivation.GELU_NO_MUL, + MoEActivation.RELU2_NO_MUL, +] def make_test_tensors( @@ -73,7 +73,7 @@ def test_triton_experts_no_mul_activation( n: int, k: int, topk: int, - activation: str, + activation: MoEActivation, ): hidden_states, w1, w2, topk_weights, topk_ids = make_test_tensors( m, n, k, NUM_EXPERTS, topk @@ -161,11 +161,11 @@ def test_workspace_shapes_no_mul_vs_gated(): ) ws1_no_mul, _, out_no_mul = experts.workspace_shapes( - M, N, K, topk, 8, 8, None, SILU_NO_MUL + M, N, K, topk, 8, 8, None, MoEActivation.SILU_NO_MUL ) ws1_gated, _, out_gated = experts.workspace_shapes( - M, N, K, topk, 8, 8, None, "silu" + M, N, K, topk, 8, 8, None, MoEActivation.SILU ) # For no_mul: activation_out_dim = N @@ -202,10 +202,10 @@ def test_adjust_n_for_activation(): N = 256 # Gated activations should return N // 2 - assert experts.adjust_N_for_activation(N, "silu") == N // 2 - assert experts.adjust_N_for_activation(N, "gelu") == N // 2 + assert experts.adjust_N_for_activation(N, MoEActivation.SILU) == N // 2 + assert experts.adjust_N_for_activation(N, MoEActivation.GELU) == N // 2 # Non-gated activations should return N - assert experts.adjust_N_for_activation(N, SILU_NO_MUL) == N - assert experts.adjust_N_for_activation(N, GELU_NO_MUL) == N - assert experts.adjust_N_for_activation(N, RELU2_NO_MUL) == N + assert experts.adjust_N_for_activation(N, MoEActivation.SILU_NO_MUL) == N + assert experts.adjust_N_for_activation(N, MoEActivation.GELU_NO_MUL) == N + assert experts.adjust_N_for_activation(N, MoEActivation.RELU2_NO_MUL) == N diff --git a/tests/kernels/moe/utils.py b/tests/kernels/moe/utils.py index 984fabc477c4..6cf01ac47b8b 100644 --- a/tests/kernels/moe/utils.py +++ 
b/tests/kernels/moe/utils.py @@ -12,6 +12,7 @@ fused_experts, fused_topk, ) +from vllm.model_executor.layers.fused_moe.activation import MoEActivation from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, FusedMoEParallelConfig, @@ -54,7 +55,7 @@ def make_dummy_moe_config( num_local_experts=num_experts, num_logical_experts=num_experts, moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(), - activation="silu", + activation=MoEActivation.SILU, in_dtype=in_dtype, device="cuda", routing_method=RoutingMethodType.TopK, diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 9c6cc4dabb26..c1a111e1f14d 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -15,6 +15,7 @@ from tests.kernels.quant_utils import native_w8a8_block_matmul from vllm.model_executor.custom_op import op_registry from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe.activation import MoEActivation from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input from vllm.utils.torch_utils import make_tensor_with_pad from vllm.v1.attention.backend import AttentionType @@ -840,7 +841,7 @@ def torch_experts( per_act_token_quant=False, block_shape: list[int] | None = None, apply_router_weights_on_input: bool = False, - activation: str = "silu_and_mul", + activation: MoEActivation = MoEActivation.SILU, ) -> torch.Tensor: assert ( global_num_experts == -1 @@ -883,7 +884,7 @@ def torch_experts( f32 = torch.float32 - act = op_registry[activation] + act = op_registry[activation.custom_op_name] for i in range(num_experts): mask = topk_ids == i @@ -973,7 +974,7 @@ def torch_moe( b_bias2: torch.Tensor | None = None, global_num_experts: int = -1, expert_map: torch.Tensor | None = None, - activation: str = "silu_and_mul", + activation: MoEActivation = MoEActivation.SILU, ) -> torch.Tensor: score = torch.softmax(score, dim=-1, dtype=torch.float32) topk_weight, topk_ids = torch.topk(score, topk) 
diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index dc17af87e164..c6cb31b629a0 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -4,6 +4,11 @@ from contextlib import contextmanager from typing import Any +from vllm.model_executor.layers.fused_moe.activation import ( + MoEActivation, + activation_without_mul, + apply_moe_activation, +) from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, RoutingMethodType, @@ -27,7 +32,6 @@ from vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method import ( UnquantizedFusedMoEMethod, ) -from vllm.model_executor.layers.fused_moe.utils import activation_without_mul from vllm.model_executor.layers.fused_moe.zero_expert_fused_moe import ( ZeroExpertFusedMoE, ) @@ -54,6 +58,7 @@ def get_config() -> dict[str, Any] | None: "FusedMoERouter", "FusedMoEConfig", "FusedMoEMethodBase", + "MoEActivation", "UnquantizedFusedMoEMethod", "FusedMoeWeightScaleSupported", "FusedMoEPermuteExpertsUnpermute", @@ -63,6 +68,7 @@ def get_config() -> dict[str, Any] | None: "SharedFusedMoE", "ZeroExpertFusedMoE", "activation_without_mul", + "apply_moe_activation", "override_config", "get_config", ] diff --git a/vllm/model_executor/layers/fused_moe/activation.py b/vllm/model_executor/layers/fused_moe/activation.py new file mode 100644 index 000000000000..3112b3054fcd --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/activation.py @@ -0,0 +1,136 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""MoE activation function enum and utilities.""" + +from enum import Enum + +import torch +import torch.nn.functional as F + + +class MoEActivation(Enum): + """Activation functions for MoE layers.""" + + # Gated activations (gate * activation(up)) expect input of shape [..., 2*d] + # and produce output of shape [..., d] + SILU = "silu" 
+ GELU = "gelu" + RELU2 = "relu2" + SWIGLUOAI = "swigluoai" + SWIGLUSTEP = "swiglustep" + + # Non-gated activations (no mul with gate) expect input of shape [..., d] + # and produce output of shape [..., d]. + # NOTE: Non-gated activations require the "_no_mul" suffix to be present. + SILU_NO_MUL = "silu_no_mul" + GELU_NO_MUL = "gelu_no_mul" + RELU2_NO_MUL = "relu2_no_mul" + + @property + def is_gated(self) -> bool: + """Returns True if activation expects gate*activation(up) pattern. + + Gated activations expect input tensor with 2x the output size, + where the first half is the gate and second half is the up projection. + """ + return not self.value.endswith("_no_mul") + + @property + def custom_op_name(self) -> str: + """Maps to the CustomOp name of activations + in vllm/model_executor/layers/activation.py.""" + return _CUSTOM_OP_NAMES[self] + + def without_mul(self) -> "MoEActivation": + """Get the non-gated variant of this activation. + + For activations that have a _no_mul variant, returns that variant. + For activations without a _no_mul variant (or already _no_mul), + returns self. + """ + return _WITHOUT_MUL.get(self, self) + + @classmethod + def from_str(cls, s: str) -> "MoEActivation": + """Parse from string for backward compatibility.""" + for member in cls: + if member.value == s: + return member + valid = [m.value for m in cls] + raise ValueError(f"Unknown MoE activation: {s!r}. Valid activations: {valid}") + + +# Module-level lookup tables used by MoEActivation functions. 
+_CUSTOM_OP_NAMES: dict[MoEActivation, str] = { + MoEActivation.SILU: "silu_and_mul", + MoEActivation.GELU: "gelu_and_mul", + MoEActivation.SWIGLUOAI: "swigluoai_and_mul", + MoEActivation.SWIGLUSTEP: "swiglustep_and_mul", + MoEActivation.RELU2: "relu2", + MoEActivation.SILU_NO_MUL: "silu_and_mul", + MoEActivation.GELU_NO_MUL: "gelu_and_mul", + MoEActivation.RELU2_NO_MUL: "relu2", +} + +_WITHOUT_MUL: dict[MoEActivation, MoEActivation] = { + MoEActivation.SILU: MoEActivation.SILU_NO_MUL, + MoEActivation.GELU: MoEActivation.GELU_NO_MUL, + MoEActivation.RELU2: MoEActivation.RELU2_NO_MUL, +} + + +def activation_without_mul(activation: str) -> str: + """Get the non-gated variant of an activation function. + + Args: + activation: The activation function name (e.g., "silu", "gelu") + + Returns: + The non-gated activation name (e.g., "silu_no_mul", "gelu_no_mul") + """ + return MoEActivation.from_str(activation).without_mul().value + + +def apply_moe_activation( + activation: MoEActivation, + output: torch.Tensor, + input: torch.Tensor, +) -> torch.Tensor: + """Apply MoE activation function.""" + assert input.dim() == 2, "Input must be 2D" + assert output.dim() == 2, "Output must be 2D" + if activation.is_gated: + assert output.size(-1) * 2 == input.size(-1), ( + f"{activation.value} expects 2x ratio: " + f"{output.size(-1) * 2} vs {input.size(-1)}" + ) + else: + assert output.size(-1) == input.size(-1), ( + f"{activation.value} expects equal sizes: " + f"{output.size(-1)} vs {input.size(-1)}" + ) + + # Activations with gated multiplication (gate × activation(up)) + if activation == MoEActivation.SILU: + torch.ops._C.silu_and_mul(output, input) + elif activation == MoEActivation.GELU: + torch.ops._C.gelu_and_mul(output, input) + elif activation == MoEActivation.SWIGLUOAI: + torch.ops._C.swigluoai_and_mul(output, input) + elif activation == MoEActivation.SWIGLUSTEP: + from vllm.model_executor.layers.activation import swiglustep_and_mul_triton + + 
swiglustep_and_mul_triton(output, input) + + # Activations without gated multiplication + elif activation == MoEActivation.SILU_NO_MUL: + output.copy_(F.silu(input)) + elif activation == MoEActivation.GELU_NO_MUL: + output.copy_(F.gelu(input)) + elif activation == MoEActivation.RELU2_NO_MUL: + F.relu(input, inplace=True) + torch.square(input, out=output) + else: + raise ValueError(f"Unsupported FusedMoe activation: {activation}") + + return output diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py index ac37cff9329a..405965c5395b 100644 --- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py @@ -7,6 +7,7 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.forward_context import get_forward_context, is_forward_context_available from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.activation import MoEActivation from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, FusedMoEParallelConfig, @@ -303,8 +304,8 @@ def _supports_quant_scheme( return (weight_key, activation_key) in SUPPORTED_W_A @staticmethod - def _supports_activation(activation: str) -> bool: - return activation in ["silu"] + def _supports_activation(activation: MoEActivation) -> bool: + return activation == MoEActivation.SILU @staticmethod def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: @@ -338,7 +339,7 @@ def workspace_shapes( global_num_experts: int, local_num_experts: int, expert_tokens_meta: mk.ExpertTokensMetadata | None, - activation: str, + activation: MoEActivation, ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: # FIXME (varun): We should be able to dispatch only from the leader # DP ranks in the case of TP > 1. 
At the moment, all the Ranks @@ -389,7 +390,7 @@ def apply( w2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, - activation: str, + activation: MoEActivation, global_num_experts: int, expert_map: torch.Tensor | None, a1q_scale: torch.Tensor | None, diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index 6dce6875df83..c999673e854b 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -14,6 +14,7 @@ get_tensor_model_parallel_rank, ) from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.activation import MoEActivation from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import ( OCP_MX_DTYPES, OCP_MX_Scheme, @@ -1132,7 +1133,7 @@ class FusedMoEConfig: intermediate_size_per_partition: int num_local_experts: int num_logical_experts: int - activation: str + activation: MoEActivation device: torch.device | str routing_method: RoutingMethodType moe_parallel_config: FusedMoEParallelConfig diff --git a/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py b/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py index 1275388227d0..7a78faafb97c 100644 --- a/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py @@ -9,6 +9,7 @@ from vllm import _custom_ops as ops from vllm._custom_ops import cpu_fused_moe, cpu_prepack_moe_weight from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe.activation import MoEActivation from vllm.model_executor.layers.quantization.utils.layer_utils import replace_parameter from vllm.utils.torch_utils import direct_register_custom_op @@ -36,9 +37,9 @@ def _swigluoai_forward_native( # Map activation names to their native forward functions. 
# Uses static methods or standalone functions to avoid instantiating CustomOp # classes, which would call get_current_vllm_config() before config is set. -_CPU_MOE_ACT_FN: dict[str, Callable[[torch.Tensor], torch.Tensor]] = { - "silu": SiluAndMul.forward_native, - "swigluoai": _swigluoai_forward_native, +_CPU_MOE_ACT_FN: dict[MoEActivation, Callable[[torch.Tensor], torch.Tensor]] = { + MoEActivation.SILU: SiluAndMul.forward_native, + MoEActivation.SWIGLUOAI: _swigluoai_forward_native, } @@ -168,9 +169,9 @@ def __call__( routed_scaling_factor: float = 1.0, e_score_correction_bias: torch.Tensor | None = None, apply_router_weight_on_input: bool = False, - activation: str = "silu", + activation: MoEActivation = MoEActivation.SILU, ) -> torch.Tensor: - assert activation == "silu", f"{activation} is not supported." + assert activation == MoEActivation.SILU, f"{activation} is not supported." assert not apply_router_weight_on_input topk_weights, topk_ids = select_experts( hidden_states=x, @@ -235,7 +236,7 @@ def __call__( routed_scaling_factor: float = 1.0, e_score_correction_bias: torch.Tensor | None = None, apply_router_weight_on_input: bool = False, - activation: str = "silu", + activation: MoEActivation = MoEActivation.SILU, ) -> torch.Tensor: assert activation in _CPU_MOE_ACT_FN, f"{activation} is not supported." 
@@ -353,7 +354,7 @@ def forward_grouped_gemm( input: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, - activation: str, + activation: MoEActivation, global_num_experts: int = -1, skip_weighted: bool = False, ) -> torch.Tensor: @@ -371,7 +372,7 @@ def forward_grouped_gemm( getattr(layer, "w2_bias", None), topk_weights, topk_ids, - activation, + activation.value, self.isa, skip_weighted, ) @@ -383,7 +384,7 @@ def forward_torch( input: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, - activation: str, + activation: MoEActivation, global_num_experts: int = -1, skip_weighted: bool = False, ) -> torch.Tensor: @@ -419,6 +420,7 @@ def cpu_fused_moe_torch( global_num_experts: int = -1, skip_weighted: bool = False, ) -> None: + act = MoEActivation.from_str(activation) layer = _CPU_MOE_LAYER_CACHE[layer_id]() # Ref code from https://github.com/sgl-project/sglang/blob/716e682721397df103f347d22da8bd46c6016dab/python/sglang/srt/layers/moe/fused_moe_native.py#L53 @@ -442,7 +444,7 @@ def cpu_fused_moe_torch( tokens_for_this_expert = sorted_tokens[start_idx:end_idx] gate_up = layer.gate_up_linear[i](tokens_for_this_expert) # type: ignore - gate_up = _CPU_MOE_ACT_FN[activation](gate_up) + gate_up = _CPU_MOE_ACT_FN[act](gate_up) expert_out = layer.down_linear[i](gate_up) # type: ignore outputs.append(expert_out) start_idx = end_idx diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 77d439d320ad..4f89487784e3 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -7,6 +7,10 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import _custom_ops as ops from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.activation import ( + MoEActivation, + apply_moe_activation, +) from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, 
FusedMoEParallelConfig, @@ -25,7 +29,6 @@ ) from vllm.model_executor.layers.fused_moe.utils import ( _resize_cache, - apply_moe_activation, ) from vllm.model_executor.layers.quantization.utils.quant_utils import ( QuantKey, @@ -51,7 +54,7 @@ def run_cutlass_moe_fp8( w1: torch.Tensor, w2: torch.Tensor, topk_ids: torch.Tensor, - activation: str, + activation: MoEActivation, global_num_experts: int, expert_map: torch.Tensor | None, w1_scale: torch.Tensor | None, @@ -73,7 +76,7 @@ def run_cutlass_moe_fp8( ): a1q = hidden_states - assert not activation.endswith("_no_mul"), "Only gated activation is supported" + assert activation.is_gated, "Only gated activation is supported" assert w1_scale is not None assert w2_scale is not None assert w1.dtype == torch.float8_e4m3fn @@ -310,8 +313,12 @@ def _supports_quant_scheme( return (weight_key, activation_key) in SUPPORTED_W_A @staticmethod - def _supports_activation(activation: str) -> bool: - return activation in ["silu", "gelu", "swigluoai"] + def _supports_activation(activation: MoEActivation) -> bool: + return activation in [ + MoEActivation.SILU, + MoEActivation.GELU, + MoEActivation.SWIGLUOAI, + ] def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: # Let PrepareAndFinalize::finalize() decide the impl. 
@@ -325,7 +332,7 @@ def apply( w2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, - activation: str, + activation: MoEActivation, global_num_experts: int, expert_map: torch.Tensor | None, a1q_scale: torch.Tensor | None, @@ -415,7 +422,7 @@ def workspace_shapes( global_num_experts: int, local_num_experts: int, expert_tokens_meta: mk.ExpertTokensMetadata | None, - activation: str, + activation: MoEActivation, ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: activation_out_dim = self.adjust_N_for_activation(N, activation) workspace1 = (M * topk, max(N, K)) @@ -456,7 +463,7 @@ def workspace_shapes( global_num_experts: int, local_num_experts: int, expert_tokens_meta: mk.ExpertTokensMetadata | None, - activation: str, + activation: MoEActivation, ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: num_dp = self.num_dispatchers assert num_dp is not None @@ -489,7 +496,7 @@ def run_cutlass_moe_fp4( w2_alphas: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, - activation: str, + activation: MoEActivation, workspace13: torch.Tensor, workspace2: torch.Tensor, m: int, @@ -612,7 +619,7 @@ def run_cutlass_moe_fp4( blockscale_offsets[:-1], ) del rep_a_fp4, rep_a_blockscale - if activation == "silu": + if activation == MoEActivation.SILU: # Fused SiLU+Mul+NVFP4 quantization # Note: c2 workspace is no longer needed since SiLU is fused with quantization. # c3 reuses workspace13 after c1 is consumed. 
@@ -682,8 +689,12 @@ def _supports_quant_scheme( return (weight_key, activation_key) == (kNvfp4Static, kNvfp4Dynamic) @staticmethod - def _supports_activation(activation: str) -> bool: - return activation in ["silu", "gelu", "swigluoai"] + def _supports_activation(activation: MoEActivation) -> bool: + return activation in [ + MoEActivation.SILU, + MoEActivation.GELU, + MoEActivation.SWIGLUOAI, + ] @staticmethod def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: @@ -716,7 +727,7 @@ def workspace_shapes( global_num_experts: int, local_num_experts: int, expert_tokens_meta: mk.ExpertTokensMetadata | None, - activation: str, + activation: MoEActivation, ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: workspace1 = (M * topk, max(2 * N, K)) workspace2 = (M * topk, N) @@ -731,7 +742,7 @@ def apply( w2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, - activation: str, + activation: MoEActivation, global_num_experts: int, expert_map: torch.Tensor | None, a1q_scale: torch.Tensor | None, # unused @@ -776,7 +787,7 @@ def run_cutlass_moe_w4a8_fp8( w1: torch.Tensor, w2: torch.Tensor, topk_ids: torch.Tensor, - activation: str, + activation: MoEActivation, global_num_experts: int, expert_map: torch.Tensor | None, w1_scale: torch.Tensor | None, @@ -970,7 +981,7 @@ def _supports_quant_scheme( ) @staticmethod - def _supports_activation(activation: str) -> bool: + def _supports_activation(activation: MoEActivation) -> bool: raise NotImplementedError( "CutlassExpertsW4A8Fp8 is not yet used by an Oracle. " "This method should not be called." 
@@ -1005,7 +1016,7 @@ def workspace_shapes( global_num_experts: int, local_num_experts: int, expert_tokens_meta: mk.ExpertTokensMetadata | None, - activation: str, + activation: MoEActivation, ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: activation_out_dim = self.adjust_N_for_activation(N, activation) workspace1 = (M * topk, max(N, K)) @@ -1021,7 +1032,7 @@ def apply( w2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, - activation: str, + activation: MoEActivation, global_num_experts: int, expert_map: torch.Tensor | None, a1q_scale: torch.Tensor | None, @@ -1094,7 +1105,7 @@ def cutlass_moe_w4a8_fp8( s_strides2: torch.Tensor, quant_config: FusedMoEQuantConfig, moe_config: FusedMoEConfig, - activation: str = "silu", + activation: MoEActivation = MoEActivation.SILU, expert_map: torch.Tensor | None = None, apply_router_weight_on_input: bool = False, global_num_experts: int = -1, @@ -1137,7 +1148,7 @@ def cutlass_moe_w4a8_fp8( dtype: torch.int64 - per_act_token (Optional[bool]): Whether the scale is per-token or per-tensor. - - activation (str): The activation function to use. + - activation (MoEActivation): The activation function to use. - expert_map (Optional[torch.Tensor]): In the case of Expert parallel, every Rank is responsible for a subset of experts. expert_map is a mapping from global expert-id to local expert-id. 
When expert_map[i] diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index 59dde3ca9e36..69ca7c91cfda 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -5,6 +5,7 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.activation import MoEActivation from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, FusedMoEParallelConfig, @@ -145,8 +146,8 @@ def _supports_quant_scheme( return (weight_key, activation_key) in SUPPORTED_W_A @staticmethod - def _supports_activation(activation: str) -> bool: - return activation in ["silu", "swiglustep"] + def _supports_activation(activation: MoEActivation) -> bool: + return activation in [MoEActivation.SILU, MoEActivation.SWIGLUSTEP] @staticmethod def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: @@ -171,7 +172,7 @@ def workspace_shapes( global_num_experts: int, local_num_experts: int, expert_tokens_meta: mk.ExpertTokensMetadata | None, - activation: str, + activation: MoEActivation, ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: assert self.block_shape is not None block_m = self.block_shape[0] @@ -187,7 +188,7 @@ def workspace_shapes( return (workspace1, workspace2, output) def _act_mul_quant( - self, input: torch.Tensor, output: torch.Tensor, activation: str + self, input: torch.Tensor, output: torch.Tensor, activation: MoEActivation ) -> tuple[torch.Tensor, torch.Tensor]: assert self.block_shape is not None block_k = self.block_shape[1] @@ -210,7 +211,7 @@ def _act_mul_quant( return a2q, a2q_scale # 2. 
Hopper / non‑E8M0: prefer the fused SiLU+mul+quant kernel - if activation == "silu": + if activation == MoEActivation.SILU: use_ue8m0 = scale_fmt == DeepGemmQuantScaleFMT.FLOAT32_CEIL_UE8M0 return silu_mul_per_token_group_quant_fp8_colmajor( input=input, @@ -235,7 +236,7 @@ def apply( w2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, - activation: str, + activation: MoEActivation, global_num_experts: int, expert_map: torch.Tensor | None, a1q_scale: torch.Tensor | None, diff --git a/vllm/model_executor/layers/fused_moe/fallback.py b/vllm/model_executor/layers/fused_moe/fallback.py index 07e5b80059f0..4b6458e7fd33 100644 --- a/vllm/model_executor/layers/fused_moe/fallback.py +++ b/vllm/model_executor/layers/fused_moe/fallback.py @@ -6,6 +6,7 @@ import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.model_executor.layers.fused_moe.activation import MoEActivation from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey @@ -76,7 +77,7 @@ def _supports_quant_scheme( ) and fallback_cls._supports_quant_scheme(weight_key, activation_key) @classmethod - def _supports_activation(cls, activation: str) -> bool: + def _supports_activation(cls, activation: MoEActivation) -> bool: experts_cls, fallback_cls = cls.get_clses() return experts_cls._supports_activation( activation @@ -138,7 +139,7 @@ def workspace_shapes( global_num_experts: int, local_num_experts: int, expert_tokens_meta: mk.ExpertTokensMetadata | None, - activation: str, + activation: MoEActivation, ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: raise NotImplementedError @@ -159,7 +160,7 @@ def apply( w2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, - activation: str, + activation: MoEActivation, global_num_experts: int, expert_map: torch.Tensor | None, a1q_scale: torch.Tensor | None, diff --git 
a/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py index 2ad949577664..d0cf7533d70f 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py @@ -6,6 +6,7 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import envs from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.activation import MoEActivation from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, FusedMoEParallelConfig, @@ -72,8 +73,8 @@ def _supports_quant_scheme( return (weight_key, activation_key) in SUPPORTED_W_A @staticmethod - def _supports_activation(activation: str) -> bool: - return activation in ["silu"] + def _supports_activation(activation: MoEActivation) -> bool: + return activation == MoEActivation.SILU @staticmethod def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: @@ -101,7 +102,7 @@ def workspace_shapes( global_num_experts: int, local_num_experts: int, expert_tokens_meta: mk.ExpertTokensMetadata | None, - activation: str, + activation: MoEActivation, ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: # We use global_num_experts due to how moe_align_block_size handles # expert_maps. 
@@ -135,7 +136,7 @@ def apply( w2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, - activation: str, + activation: MoEActivation, global_num_experts: int, expert_map: torch.Tensor | None, a1q_scale: torch.Tensor | None, diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py index 85df6cb66a01..4ec76ee9820c 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py @@ -5,6 +5,7 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.activation import MoEActivation from vllm.model_executor.layers.fused_moe.config import ( FusedMoEParallelConfig, FusedMoEQuantConfig, @@ -130,8 +131,8 @@ def _supports_quant_scheme( ) @staticmethod - def _supports_activation(activation: str) -> bool: - return activation in ["silu", "relu2_no_mul"] + def _supports_activation(activation: MoEActivation) -> bool: + return activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL] @staticmethod def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: @@ -164,7 +165,7 @@ def workspace_shapes( global_num_experts: int, local_num_experts: int, expert_tokens_meta: mk.ExpertTokensMetadata | None, - activation: str, + activation: MoEActivation, ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: # We use global_num_experts due to how moe_align_block_size handles # expert_maps. 
@@ -201,7 +202,7 @@ def apply( w2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, - activation: str, + activation: MoEActivation, global_num_experts: int, expert_map: torch.Tensor | None, a1q_scale: torch.Tensor | None, @@ -214,8 +215,8 @@ def apply( from flashinfer.fused_moe.core import ActivationType activation_str_to_value_map = { - "silu": ActivationType.Swiglu, # This is the default - "relu2_no_mul": ActivationType.Relu2, + MoEActivation.SILU: ActivationType.Swiglu, # This is the default + MoEActivation.RELU2_NO_MUL: ActivationType.Relu2, } assert activation in activation_str_to_value_map, ( f"{activation=} missing from {activation_str_to_value_map.keys()=}" diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py index 9af18485e057..a50ad6722078 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py @@ -4,6 +4,7 @@ import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.model_executor.layers.fused_moe.activation import MoEActivation from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, FusedMoEParallelConfig, @@ -50,9 +51,9 @@ def _supports_quant_scheme( return (weight_key, activation_key) in SUPPORTED_W_A -def _supports_activation(activation: str) -> bool: +def _supports_activation(activation: MoEActivation) -> bool: """Supports silu activation only.""" - return activation in ["silu"] + return activation == MoEActivation.SILU def _supports_routing_method( diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index 8822b8a8a18e..fbd47f8c4236 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -5,6 +5,7 @@ import torch import 
vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.model_executor.layers.fused_moe.activation import MoEActivation from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, FusedMoEParallelConfig, @@ -698,7 +699,7 @@ def _supports_quant_scheme( ) @staticmethod - def _supports_activation(activation: str) -> bool: + def _supports_activation(activation: MoEActivation) -> bool: raise NotImplementedError( "NaiveBatchedExperts is not yet used by an Oracle. " "This method should not be called." @@ -730,7 +731,7 @@ def workspace_shapes( global_num_experts: int, local_num_experts: int, expert_tokens_meta: mk.ExpertTokensMetadata | None, - activation: str, + activation: MoEActivation, ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: assert self.num_dispatchers is not None assert self.max_num_tokens is not None @@ -757,7 +758,7 @@ def apply( w2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, - activation: str, + activation: MoEActivation, global_num_experts: int, expert_map: torch.Tensor | None, a1q_scale: torch.Tensor | None, @@ -942,14 +943,14 @@ def _supports_quant_scheme( ) @staticmethod - def _supports_activation(activation: str) -> bool: + def _supports_activation(activation: MoEActivation) -> bool: return activation in [ - "silu", - "gelu", - "swigluoai", - "silu_no_mul", - "gelu_no_mul", - "relu2_no_mul", + MoEActivation.SILU, + MoEActivation.GELU, + MoEActivation.SWIGLUOAI, + MoEActivation.SILU_NO_MUL, + MoEActivation.GELU_NO_MUL, + MoEActivation.RELU2_NO_MUL, ] @staticmethod @@ -975,7 +976,7 @@ def workspace_shapes( global_num_experts: int, local_num_experts: int, expert_tokens_meta: mk.ExpertTokensMetadata | None, - activation: str, + activation: MoEActivation, ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: assert self.num_dispatchers is not None assert self.max_num_tokens is not None @@ -996,7 +997,7 @@ def apply( w2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, 
- activation: str, + activation: MoEActivation, global_num_experts: int, expert_map: torch.Tensor | None, a1q_scale: torch.Tensor | None, diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index 3d3a21f81c72..57fb3561d1d2 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -8,6 +8,10 @@ import vllm._custom_ops as ops import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.model_executor.layers.fused_moe.activation import ( + MoEActivation, + apply_moe_activation, +) from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, FusedMoEParallelConfig, @@ -23,7 +27,6 @@ ) from vllm.model_executor.layers.fused_moe.utils import ( _resize_cache, - apply_moe_activation, disable_inplace, ) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( @@ -59,9 +62,9 @@ def _fused_marlin_moe( sorted_token_ids: torch.Tensor, expert_ids: torch.Tensor, num_tokens_post_padded: torch.Tensor, - activation: str = "silu", + activation: MoEActivation = MoEActivation.SILU, activation_func: Callable[ - [str, torch.Tensor, torch.Tensor], None + [MoEActivation, torch.Tensor, torch.Tensor], None ] = apply_moe_activation, input_global_scale1: torch.Tensor | None = None, input_global_scale2: torch.Tensor | None = None, @@ -83,7 +86,7 @@ def _fused_marlin_moe( assert hidden_states.ndim == 2 M, K = hidden_states.size() N = marlin_moe_intermediate_size(w1, w2) - w13_num_shards = 1 if "no_mul" in activation else 2 + w13_num_shards = 2 if activation.is_gated else 1 if workspace is None: workspace = marlin_make_workspace_new(hidden_states.device, 4) @@ -215,9 +218,9 @@ def fused_marlin_moe( quant_type_id: int, apply_router_weight_on_input: bool = False, global_num_experts: int = -1, - activation: str = "silu", + activation: MoEActivation = MoEActivation.SILU, activation_func: Callable[ - [str, 
torch.Tensor, torch.Tensor], None + [MoEActivation, torch.Tensor, torch.Tensor], None ] = apply_moe_activation, moe_sum: Callable[[torch.Tensor, torch.Tensor], None] | None = None, expert_map: torch.Tensor | None = None, @@ -377,7 +380,7 @@ def batched_fused_marlin_moe( quant_type_id: int, apply_router_weight_on_input: bool = False, global_num_experts: int = -1, - activation: str | None = "silu", + activation: MoEActivation = MoEActivation.SILU, expert_map: torch.Tensor | None = None, global_scale1: torch.Tensor | None = None, global_scale2: torch.Tensor | None = None, @@ -579,14 +582,14 @@ def _supports_quant_scheme( return weight_key in SUPPORTED_W @staticmethod - def _supports_activation(activation: str) -> bool: + def _supports_activation(activation: MoEActivation) -> bool: return activation in [ - "silu", - "gelu", - "swigluoai", - "silu_no_mul", - "gelu_no_mul", - "relu2_no_mul", + MoEActivation.SILU, + MoEActivation.GELU, + MoEActivation.SWIGLUOAI, + MoEActivation.SILU_NO_MUL, + MoEActivation.GELU_NO_MUL, + MoEActivation.RELU2_NO_MUL, ] @staticmethod @@ -661,7 +664,7 @@ def workspace_shapes( global_num_experts: int, local_num_experts: int, expert_tokens_meta: mk.ExpertTokensMetadata | None, - activation: str, + activation: MoEActivation, ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: # Modular Kernel provisions output buffer from workspace1. 
However in # the fused_marlin_moe() function, the final torch.sum(), is defined @@ -692,7 +695,7 @@ def apply( w2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, - activation: str, + activation: MoEActivation, global_num_experts: int, expert_map: torch.Tensor | None, a1q_scale: torch.Tensor | None, @@ -788,7 +791,7 @@ def workspace_shapes( global_num_experts: int, local_num_experts: int, expert_tokens_meta: mk.ExpertTokensMetadata | None, - activation: str, + activation: MoEActivation, ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: assert self.num_dispatchers is not None assert self.max_num_tokens is not None @@ -808,7 +811,7 @@ def apply( w2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, - activation: str, + activation: MoEActivation, global_num_experts: int, expert_map: torch.Tensor | None, a1q_scale: torch.Tensor | None, diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 352288e173c8..f988e91c2478 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -17,6 +17,10 @@ from vllm.model_executor.layers.batch_invariant import ( vllm_is_batch_invariant, ) +from vllm.model_executor.layers.fused_moe.activation import ( + MoEActivation, + apply_moe_activation, +) from vllm.model_executor.layers.fused_moe.config import ( FUSED_MOE_UNQUANTIZED_CONFIG, FusedMoEConfig, @@ -32,7 +36,6 @@ ) from vllm.model_executor.layers.fused_moe.utils import ( _resize_cache, - apply_moe_activation, disable_inplace, moe_kernel_quantize_input, ) @@ -1468,6 +1471,7 @@ def outplace_fused_experts_fake( topk_weights: torch.Tensor, topk_ids: torch.Tensor, activation: str = "silu", + apply_router_weight_on_input: bool = False, use_fp8_w8a8: bool = False, use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, @@ -1521,7 +1525,7 @@ def fused_experts( topk_weights: torch.Tensor, topk_ids: torch.Tensor, 
inplace: bool = False, - activation: str = "silu", + activation: MoEActivation = MoEActivation.SILU, apply_router_weight_on_input: bool = False, global_num_experts: int = -1, expert_map: torch.Tensor | None = None, @@ -1539,7 +1543,7 @@ def fused_experts( w2=w2, topk_weights=topk_weights, topk_ids=topk_ids, - activation=activation, + activation=activation.value, apply_router_weight_on_input=apply_router_weight_on_input, use_fp8_w8a8=quant_config.use_fp8_w8a8, use_int8_w8a8=quant_config.use_int8_w8a8, @@ -1618,6 +1622,9 @@ def fused_experts_impl( w1_bias: torch.Tensor | None = None, w2_bias: torch.Tensor | None = None, ) -> torch.Tensor: + # Convert string activation to enum for internal use + activation_enum = MoEActivation.from_str(activation) + # Check constraints. if use_int4_w4a16: assert hidden_states.size(1) // 2 == w1.size(2), "Hidden size mismatch" @@ -1692,7 +1699,7 @@ def fused_experts_impl( # This needs separate memory since it's used concurrently with cache1 activation_out_dim = mk.FusedMoEPermuteExpertsUnpermute.adjust_N_for_activation( - N, activation + N, activation_enum ) intermediate_cache2 = torch.empty( (M * top_k_num, activation_out_dim), @@ -1832,7 +1839,7 @@ def fused_experts_impl( ) apply_moe_activation( - activation, intermediate_cache2, intermediate_cache1.view(-1, N) + activation_enum, intermediate_cache2, intermediate_cache1.view(-1, N) ) qintermediate_cache2, a2q_scale = moe_kernel_quantize_input( @@ -1932,8 +1939,13 @@ def _supports_quant_scheme( return (weight_key, activation_key) in SUPPORTED_W_A @staticmethod - def _supports_activation(activation: str) -> bool: - return activation in ["silu", "gelu", "swigluoai", "swiglustep"] + def _supports_activation(activation: MoEActivation) -> bool: + return activation in [ + MoEActivation.SILU, + MoEActivation.GELU, + MoEActivation.SWIGLUOAI, + MoEActivation.SWIGLUSTEP, + ] @staticmethod def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: @@ -1957,7 +1969,7 @@ 
def workspace_shapes( global_num_experts: int, local_num_experts: int, expert_tokens_meta: mk.ExpertTokensMetadata | None, - activation: str, + activation: MoEActivation, ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: activation_out_dim = self.adjust_N_for_activation(N, activation) workspace1 = (M, topk, max(activation_out_dim, K)) @@ -1973,7 +1985,7 @@ def apply( w2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, - activation: str, + activation: MoEActivation, global_num_experts: int, expert_map: torch.Tensor | None, a1q_scale: torch.Tensor | None, @@ -2138,7 +2150,7 @@ def _supports_quant_scheme( ) @staticmethod - def _supports_activation(activation: str) -> bool: + def _supports_activation(activation: MoEActivation) -> bool: raise NotImplementedError( "TritonWNA16Experts is not yet used by an Oracle. " "This method should not be called." @@ -2159,7 +2171,7 @@ def apply( w2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, - activation: str, + activation: MoEActivation, global_num_experts: int, expert_map: torch.Tensor | None, a1q_scale: torch.Tensor | None, diff --git a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py index 5aaf2a8c39d4..70d11f44f43b 100644 --- a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py +++ b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py @@ -7,6 +7,7 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import _custom_ops as ops from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.activation import MoEActivation from vllm.model_executor.layers.fused_moe.config import ( FUSED_MOE_UNQUANTIZED_CONFIG, FusedMoEParallelConfig, @@ -172,7 +173,7 @@ def triton_kernel_moe_forward( gating_output: torch.Tensor, topk: int, renormalize: bool, - activation: str = "silu", + activation: MoEActivation = MoEActivation.SWIGLUOAI, 
quant_config: FusedMoEQuantConfig | None = None, apply_router_weight_on_input: bool = False, global_num_experts: int = -1, @@ -211,7 +212,7 @@ def triton_kernel_fused_experts( gather_indx, # GatherIndx scatter_indx, # ScatterIndx topk: int, - activation: str = "silu", + activation: MoEActivation = MoEActivation.SWIGLUOAI, quant_config: FusedMoEQuantConfig | None = None, swiglu_alpha: float = 1.702, swiglu_limit: float = 7.0, @@ -222,6 +223,9 @@ def triton_kernel_fused_experts( a1q_scale: torch.Tensor | None = None, ) -> torch.Tensor: """Triton implementation of fused expert computation using OAI kernels.""" + assert activation == MoEActivation.SWIGLUOAI, ( + "Only SWIGLUOAI activation is supported" + ) if quant_config is None: quant_config = FUSED_MOE_UNQUANTIZED_CONFIG @@ -379,7 +383,7 @@ def _supports_quant_scheme( ) @staticmethod - def _supports_activation(activation: str) -> bool: + def _supports_activation(activation: MoEActivation) -> bool: raise NotImplementedError( "OAITritonExperts is not yet used by an Oracle. " "This method should not be called." 
@@ -463,7 +467,7 @@ def workspace_shapes( global_num_experts: int, local_num_experts: int, expert_tokens_meta: mk.ExpertTokensMetadata | None, - activation: str, + activation: MoEActivation, ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: # workspace are allocated inside the kernel activation_out_dim = self.adjust_N_for_activation(N, activation) @@ -480,7 +484,7 @@ def apply( w2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, - activation: str, + activation: MoEActivation, global_num_experts: int, expert_map: torch.Tensor | None, a1q_scale: torch.Tensor | None, @@ -547,7 +551,7 @@ def workspace_shapes( global_num_experts: int, local_num_experts: int, expert_tokens_meta: mk.ExpertTokensMetadata | None, - activation: str, + activation: MoEActivation, ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: # workspace are allocated inside the kernel activation_out_dim = self.adjust_N_for_activation(N, activation) @@ -567,7 +571,7 @@ def apply( w2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, - activation: str, + activation: MoEActivation, global_num_experts: int, expert_map: torch.Tensor | None, a1q_scale: torch.Tensor | None, diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 5a8f51de6462..a181b18c9d1c 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -20,6 +20,7 @@ from vllm.distributed.eplb.eplb_state import EplbLayerState, EplbState from vllm.logger import init_logger from vllm.model_executor.custom_op import CustomOp +from vllm.model_executor.layers.fused_moe.activation import MoEActivation from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, FusedMoEParallelConfig, @@ -500,7 +501,7 @@ def __init__( # TODO(bnell): end attributes self.apply_router_weight_on_input = apply_router_weight_on_input - self.activation = activation + self.activation = 
MoEActivation.from_str(activation) self.router = create_fused_moe_router( top_k=top_k, @@ -554,7 +555,7 @@ def __init__( has_bias=has_bias, is_act_and_mul=is_act_and_mul, is_lora_enabled=vllm_config.lora_config is not None, - activation=activation, + activation=self.activation, device=vllm_config.device_config.device, routing_method=self.routing_method_type, # TODO: in_dtype == out_dtype? diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index e2f77d6c8509..7e6855778fd4 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -12,6 +12,10 @@ import vllm.envs as envs from vllm.forward_context import get_forward_context, is_forward_context_available from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.activation import ( + MoEActivation, + apply_moe_activation, +) from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, FusedMoEParallelConfig, @@ -19,7 +23,6 @@ ) from vllm.model_executor.layers.fused_moe.utils import ( _resize_cache, - apply_moe_activation, count_expert_num_tokens, disable_inplace, ) @@ -536,7 +539,7 @@ def _supports_quant_scheme( @staticmethod @abstractmethod - def _supports_activation(activation: str) -> bool: + def _supports_activation(activation: MoEActivation) -> bool: """ Whether the kernel supports a particular act function. 
""" @@ -658,7 +661,7 @@ def workspace_shapes( global_num_experts: int, local_num_experts: int, expert_tokens_meta: ExpertTokensMetadata | None, - activation: str, + activation: MoEActivation, ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: """ Compute the shapes for the temporary and final outputs of the two gemms @@ -690,7 +693,7 @@ def workspace_shapes( raise NotImplementedError @staticmethod - def adjust_N_for_activation(N: int, activation: str) -> int: + def adjust_N_for_activation(N: int, activation: MoEActivation) -> int: """ Calculate the output dimension for the activation function. @@ -702,16 +705,15 @@ def adjust_N_for_activation(N: int, activation: str) -> int: Args: N: The intermediate size (width of w1/w3 weights). - activation: The activation function name. + activation: The activation function enum. Returns: The output dimension after activation. """ - is_no_mul = activation.endswith("_no_mul") - return N if is_no_mul else N // 2 + return N if not activation.is_gated else N // 2 def activation( - self, activation: str, output: torch.Tensor, input: torch.Tensor + self, activation: MoEActivation, output: torch.Tensor, input: torch.Tensor ) -> None: apply_moe_activation(activation, output, input) @@ -732,7 +734,7 @@ def apply( w2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, - activation: str, + activation: MoEActivation, global_num_experts: int, expert_map: torch.Tensor | None, a1q_scale: torch.Tensor | None, @@ -892,7 +894,7 @@ def _allocate_buffers( global_num_experts: int, local_num_experts: int, expert_tokens_meta: ExpertTokensMetadata | None, - activation: str, + activation: MoEActivation, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Allocate temporary and output buffers for the fused experts op. 
@@ -1135,7 +1137,7 @@ def _fused_experts( w2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, - activation: str, + activation: MoEActivation, global_num_experts: int, local_num_experts: int, expert_map: torch.Tensor | None, @@ -1309,7 +1311,7 @@ def forward( w2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, - activation: str = "silu", + activation: MoEActivation = MoEActivation.SILU, global_num_experts: int = -1, expert_map: torch.Tensor | None = None, apply_router_weight_on_input: bool = False, @@ -1326,7 +1328,7 @@ def forward( - topk_weights (torch.Tensor): The topk weights applied at the end of the layer. - topk_ids (torch.Tensor): A map of row to expert id. - - activation (str): The activation function to apply after the first + - activation (MoEActivation): The activation function to apply after the first MoE layer. - global_num_experts (int): The total number of experts in the global expert space. diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py index 535abc420e37..def1ec9dcb44 100644 --- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py @@ -7,6 +7,7 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm._aiter_ops import rocm_aiter_ops +from vllm.model_executor.layers.fused_moe.activation import MoEActivation from vllm.model_executor.layers.fused_moe.config import ( FUSED_MOE_UNQUANTIZED_CONFIG, FusedMoEParallelConfig, @@ -184,7 +185,7 @@ def rocm_aiter_fused_experts( w2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, - activation: str = "silu", + activation: MoEActivation = MoEActivation.SILU, apply_router_weight_on_input: bool = False, expert_map: torch.Tensor | None = None, quant_config: FusedMoEQuantConfig | None = None, @@ -196,9 +197,13 @@ def rocm_aiter_fused_experts( if quant_config is None: quant_config = 
FUSED_MOE_UNQUANTIZED_CONFIG - activation_method = ( - ActivationMethod.SILU if activation == "silu" else ActivationMethod.GELU - ) + if activation == MoEActivation.SILU: + activation_method = ActivationMethod.SILU + elif activation == MoEActivation.GELU: + activation_method = ActivationMethod.GELU + else: + raise ValueError(f"Unsupported activation: {activation}") + # All AITER Fused MoE kernels are expecting the following datatypes topk_weights = topk_weights.to(torch.float32) topk_ids = topk_ids.to(torch.int32) @@ -322,8 +327,8 @@ def _supports_quant_scheme( return (weight_key, activation_key) in SUPPORTED_W_A @staticmethod - def _supports_activation(activation: str) -> bool: - return activation in ["silu", "gelu"] + def _supports_activation(activation: MoEActivation) -> bool: + return activation in [MoEActivation.SILU, MoEActivation.GELU] @staticmethod def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: @@ -347,7 +352,7 @@ def workspace_shapes( global_num_experts: int, local_num_experts: int, expert_tokens_meta: mk.ExpertTokensMetadata | None, - activation: str, + activation: MoEActivation, ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: # Workspaces are managed internally by AITER. 
workspace1 = (0,) @@ -363,7 +368,7 @@ def apply( w2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, - activation: str, + activation: MoEActivation, global_num_experts: int, expert_map: torch.Tensor | None, a1q_scale: torch.Tensor | None, diff --git a/vllm/model_executor/layers/fused_moe/triton_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/triton_cutlass_moe.py index f537f2f99ade..21a3d05f4cd2 100644 --- a/vllm/model_executor/layers/fused_moe/triton_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/triton_cutlass_moe.py @@ -5,6 +5,7 @@ import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.model_executor.layers.fused_moe.activation import MoEActivation from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, FusedMoEQuantConfig, @@ -45,7 +46,7 @@ def workspace_shapes( global_num_experts: int, local_num_experts: int, expert_tokens_meta: mk.ExpertTokensMetadata | None, - activation: str, + activation: MoEActivation, ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: # Small batch fallback for sm100. 
if self.is_sm100 and M <= 8: diff --git a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py index 7e41269dc538..a3f2f59c5b3c 100644 --- a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py @@ -4,6 +4,7 @@ import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.model_executor.layers.fused_moe.activation import MoEActivation from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, FusedMoEQuantConfig, @@ -45,7 +46,7 @@ def workspace_shapes( global_num_experts: int, local_num_experts: int, expert_tokens_meta: mk.ExpertTokensMetadata | None, - activation: str, + activation: MoEActivation, ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: # Note: the deep gemm workspaces are strictly larger than the triton # workspaces so we can be pessimistic here and allocate for DeepGemm diff --git a/vllm/model_executor/layers/fused_moe/trtllm_moe.py b/vllm/model_executor/layers/fused_moe/trtllm_moe.py index 074b8154a95a..61e06fa603d6 100644 --- a/vllm/model_executor/layers/fused_moe/trtllm_moe.py +++ b/vllm/model_executor/layers/fused_moe/trtllm_moe.py @@ -4,6 +4,7 @@ import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.model_executor.layers.fused_moe.activation import MoEActivation from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, FusedMoEParallelConfig, @@ -64,7 +65,7 @@ def _supports_quant_scheme( ) @staticmethod - def _supports_activation(activation: str) -> bool: + def _supports_activation(activation: MoEActivation) -> bool: raise NotImplementedError( "TrtLlmGenExperts is not yet used by an Oracle. " "This method should not be called." 
@@ -95,7 +96,7 @@ def workspace_shapes( global_num_experts: int, local_num_experts: int, expert_tokens_meta: mk.ExpertTokensMetadata | None, - activation: str, + activation: MoEActivation, ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: # The workspaces for this implementation are managed by flashinfer. workspace1 = (0,) @@ -111,7 +112,7 @@ def apply( w2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, - activation: str, + activation: MoEActivation, global_num_experts: int, expert_map: torch.Tensor | None, a1q_scale: torch.Tensor | None, diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py index 7d5ca876bdcc..a1d4f46aa220 100644 --- a/vllm/model_executor/layers/fused_moe/utils.py +++ b/vllm/model_executor/layers/fused_moe/utils.py @@ -4,7 +4,6 @@ from math import prod import torch -import torch.nn.functional as F from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.utils.fp8_utils import ( @@ -341,65 +340,6 @@ def _validate_scale_shape( assert a_scale.shape == expected, f"{a_scale.shape} == {expected}" -def activation_without_mul(activation: str) -> str: - return activation + "_no_mul" - - -RELU2_NO_MUL: str = activation_without_mul("relu2") -SILU_NO_MUL: str = activation_without_mul("silu") -GELU_NO_MUL: str = activation_without_mul("gelu") - - -def apply_moe_activation( - activation: str, - output: torch.Tensor, - input: torch.Tensor, -) -> torch.Tensor: - """ - Apply MoE activation function. 
- - For *_and_mul activations (silu, gelu, swigluoai): - - Expects output.size(-1) * 2 == input.size(-1) - - For *_no_mul activations (silu_no_mul, gelu_no_mul, relu2_no_mul): - - Expects output.size(-1) == input.size(-1) - """ - is_no_mul = activation.endswith("_no_mul") - if is_no_mul: - assert output.size(-1) == input.size(-1), ( - f"{activation} expects equal sizes: {output.size(-1)} vs {input.size(-1)}" - ) - else: - assert output.size(-1) * 2 == input.size(-1), ( - f"{activation} expects 2x ratio: {output.size(-1) * 2} vs {input.size(-1)}" - ) - - # Activations with gated multiplication (gate × activation(up)) - if activation == "silu": - torch.ops._C.silu_and_mul(output, input) - elif activation == "gelu": - torch.ops._C.gelu_and_mul(output, input) - elif activation == "swigluoai": - torch.ops._C.swigluoai_and_mul(output, input) - elif activation == "swiglustep": - from vllm.model_executor.layers.activation import swiglustep_and_mul_triton - - swiglustep_and_mul_triton(output, input) - - # Activations without gated multiplication - elif activation == SILU_NO_MUL: - output.copy_(F.silu(input)) - elif activation == GELU_NO_MUL: - output.copy_(F.gelu(input)) - elif activation == RELU2_NO_MUL: - F.relu(input, inplace=True) - torch.square(input, out=output) - else: - raise ValueError(f"Unsupported FusedMoe activation: {activation}") - - return output - - # Torch custom ops can't deal with outputs aliasing inputs so we need to # disable inplace for torch >= 2.9. 
# See https://github.com/vllm-project/vllm/issues/26378 diff --git a/vllm/model_executor/layers/fused_moe/xpu_fused_moe.py b/vllm/model_executor/layers/fused_moe/xpu_fused_moe.py index a20679ea6c4d..e6f8b8efa804 100644 --- a/vllm/model_executor/layers/fused_moe/xpu_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/xpu_fused_moe.py @@ -3,6 +3,7 @@ import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.model_executor.layers.fused_moe.activation import MoEActivation from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, FusedMoEParallelConfig, @@ -55,8 +56,12 @@ def _supports_no_act_and_mul() -> bool: return False @staticmethod - def _supports_activation(activation: str) -> bool: - return activation in ["silu", "gelu", "swigluoai"] + def _supports_activation(activation: MoEActivation) -> bool: + return activation in [ + MoEActivation.SILU, + MoEActivation.GELU, + MoEActivation.SWIGLUOAI, + ] @staticmethod def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: @@ -92,7 +97,7 @@ def workspace_shapes( global_num_experts: int, local_num_experts: int, expert_tokens_meta: mk.ExpertTokensMetadata | None, - activation: str, + activation: MoEActivation, ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: workspace1 = (0,) workspace2 = (0,) @@ -107,7 +112,7 @@ def apply( w2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, - activation: str, + activation: MoEActivation, global_num_experts: int, expert_map: torch.Tensor | None, a1q_scale: torch.Tensor | None, @@ -129,7 +134,7 @@ def apply( topk_weights=topk_weights, topk_ids=topk_ids, n_experts_per_token=topk, - activation=activation, + activation=activation.value, num_experts=self.moe_config.num_local_experts, ep_rank=self.moe_config.ep_rank, ep_size=self.moe_config.ep_size, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py 
b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 690ff0454407..0fecc7bbcc85 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -24,6 +24,7 @@ FusedMoeWeightScaleSupported, UnquantizedFusedMoEMethod, ) +from vllm.model_executor.layers.fused_moe.activation import MoEActivation from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, FusedMoEQuantConfig, @@ -622,7 +623,9 @@ def apply_monolithic( router_logits: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: assert self.is_monolithic - assert layer.activation == "silu", "Only SiLU activation is supported." + assert layer.activation == MoEActivation.SILU, ( + f"Only SiLU activation is supported, not {layer.activation}." + ) assert ( self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM and not layer.enable_eplb @@ -649,7 +652,9 @@ def apply( shared_experts_input: torch.Tensor | None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: assert not self.is_monolithic - assert layer.activation == "silu", "Only SiLU activation is supported." + assert layer.activation == MoEActivation.SILU, ( + f"Only SiLU activation is supported, not {layer.activation}." + ) # EPLB path if self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM: @@ -1025,7 +1030,9 @@ def apply_monolithic( ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: assert self.is_monolithic assert self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM - assert layer.activation == "silu" + assert layer.activation == MoEActivation.SILU, ( + f"Only SiLU activation is supported, not {layer.activation}." 
+ ) if self.block_quant: import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401 @@ -2271,19 +2278,21 @@ def apply_monolithic( router_logits: torch.Tensor, ) -> torch.Tensor: assert not layer.enable_eplb, "EPLB not supported for W4A8-int MoE yet." - assert layer.activation in ("silu", "swigluoai", "swiglu"), ( - "Only SiLU/SwiGLUGU/SwiGLUUG are supported." - ) + assert layer.activation in ( + MoEActivation.SILU, + MoEActivation.SWIGLUOAI, + MoEActivation.SWIGLUSTEP, + ), "Only SiLU/SwiGLUGU/SwiGLUUG are supported." assert layer.expert_map is None, """expert_map/EP not implemented for CPU dyn-4bit MoE.""" - def _act_kind(s: str) -> int: + def _act_kind(s: MoEActivation) -> int: # 0 = SwiGLU_Gu (SiLU(g)*u), 1 = SwiGLU_Ug (SiLU(u)*g), 2 = SiLU - if s == "swiglu": + if s == MoEActivation.SWIGLUSTEP: return 0 - if s == "swigluoai": + if s == MoEActivation.SWIGLUOAI: return 1 - if s == "silu": + if s == MoEActivation.SILU: return 2 raise ValueError(f"Unknown activation '{s}'") diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 279f97dd6e2b..cd589b315b4e 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -23,6 +23,7 @@ FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize, FusedMoeWeightScaleSupported, + MoEActivation, ) from vllm.model_executor.layers.fused_moe.config import ( FusedMoEQuantConfig, @@ -965,7 +966,7 @@ def apply_monolithic( # TODO(rob): convert this to MK. 
if layer.enable_eplb: raise NotImplementedError("EPLB not supported for `Fp8MoEMethod` yet.") - assert layer.activation == "silu", ( + assert layer.activation == MoEActivation.SILU, ( f"Expected 'silu' activation but got {layer.activation}" ) diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py index f7d995598564..88023349e779 100644 --- a/vllm/model_executor/layers/quantization/gguf.py +++ b/vllm/model_executor/layers/quantization/gguf.py @@ -12,6 +12,10 @@ from vllm import _custom_ops as ops from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.activation import ( + MoEActivation, + apply_moe_activation, +) from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, FusedMoEQuantConfig, @@ -246,16 +250,13 @@ def _fused_moe_gguf( qweight_type2: int, activation: str, ) -> torch.Tensor: + activation_enum = MoEActivation.from_str(activation) + def act(x: torch.Tensor): d = x.shape[-1] // 2 output_shape = x.shape[:-1] + (d,) out = torch.empty(output_shape, dtype=x.dtype, device=x.device) - if activation == "silu": - torch.ops._C.silu_and_mul(out, x) - elif activation == "gelu": - torch.ops._C.gelu_and_mul(out, x) - else: - raise ValueError(f"Unsupported activation: {activation}") + apply_moe_activation(activation_enum, out, x) return out # lazy import to avoid triggering triton import in CPU backend @@ -637,7 +638,6 @@ def apply( topk_ids: torch.Tensor, shared_experts_input: torch.Tensor | None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - assert layer.activation == "silu", "Only SiLU activation is supported." 
if layer.apply_router_weight_on_input: raise NotImplementedError( "Apply router weight on input is not supported for" @@ -652,7 +652,7 @@ def apply( topk_ids, layer.w13_qweight_type.weight_type, layer.w2_qweight_type.weight_type, - layer.activation, + layer.activation.value, ) diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 570317ad3975..e0322a46f01a 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -10,6 +10,7 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.logger import init_logger from vllm.model_executor.layers.attention import Attention +from vllm.model_executor.layers.fused_moe.activation import MoEActivation from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, FusedMoEQuantConfig, @@ -936,7 +937,7 @@ def apply_monolithic( ) # TODO(rob): this validation should happen at kernel selection # time in the oracle rather than here. - assert layer.activation == "silu", ( + assert layer.activation == MoEActivation.SILU, ( f"Expected 'silu' activation but got {layer.activation}" ) assert not layer.renormalize @@ -965,7 +966,10 @@ def apply( # TODO(rob): this validation should happen at kernel selection # time in the oracle rather than here. 
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS: - assert layer.activation in ("silu", "relu2_no_mul"), ( + assert layer.activation in ( + MoEActivation.SILU, + MoEActivation.RELU2_NO_MUL, + ), ( "Expected activation to be in ('silu', 'relu2_no_mul')," f"but got {layer.activation}" ) diff --git a/vllm/model_executor/layers/quantization/moe_wna16.py b/vllm/model_executor/layers/quantization/moe_wna16.py index 4365d16935dc..f5c679840432 100644 --- a/vllm/model_executor/layers/quantization/moe_wna16.py +++ b/vllm/model_executor/layers/quantization/moe_wna16.py @@ -6,6 +6,7 @@ import torch from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group +from vllm.model_executor.layers.fused_moe.activation import MoEActivation from vllm.model_executor.layers.fused_moe.config import ( FusedMoEQuantConfig, int4_w4a16_moe_quant_config, @@ -371,7 +372,9 @@ def apply( ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: from vllm.model_executor.layers.fused_moe import fused_experts - assert layer.activation == "silu", "Only SiLU activation is supported." + assert layer.activation == MoEActivation.SILU, ( + f"Only SiLU activation is supported, not {layer.activation}." 
+ ) return fused_experts( x, diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index 5cd6d5d79df3..5c6837e7afc0 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -13,6 +13,7 @@ FusedMoE, FusedMoEConfig, FusedMoEMethodBase, + MoEActivation, ) from vllm.model_executor.layers.fused_moe import modular_kernel as mk from vllm.model_executor.layers.fused_moe.config import ( @@ -1141,8 +1142,9 @@ def apply_monolithic( x: torch.Tensor, router_logits: torch.Tensor, ) -> torch.Tensor: - assert layer.activation == "swigluoai", ( - "Only swiglu_oai activation is supported for XPU MXFP4 MoE" + assert layer.activation == MoEActivation.SWIGLUOAI, ( + "Only swiglu_oai activation is supported for " + f"XPU MXFP4 MoE, not {layer.activation}." ) from vllm_xpu_kernels.fused_moe_interface import xpu_fused_moe diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py index 7faa4fcc9ab8..555b94c1cea7 100644 --- a/vllm/model_executor/layers/quantization/quark/quark_moe.py +++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py @@ -15,6 +15,7 @@ FusedMoEConfig, FusedMoEMethodBase, FusedMoeWeightScaleSupported, + MoEActivation, ) from vllm.model_executor.layers.fused_moe.config import ( FusedMoEQuantConfig, @@ -438,7 +439,7 @@ def apply( expert_map=layer.expert_map, ) elif self.use_marlin: - assert layer.activation == "silu", ( + assert layer.activation == MoEActivation.SILU, ( f"{layer.activation} not supported for Marlin MoE." 
) return fused_marlin_moe( diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py index bbe2068006da..9d9fd31ad09d 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py @@ -9,6 +9,7 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import _custom_ops as ops from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.activation import MoEActivation from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, FusedMoEParallelConfig, @@ -64,9 +65,9 @@ def _supports_quant_scheme( return (weight_key, activation_key) in SUPPORTED_W_A -def _supports_activation(activation: str) -> bool: +def _supports_activation(activation: MoEActivation) -> bool: """Supports silu activation only.""" - return activation in ["silu"] + return activation in [MoEActivation.SILU] def _supports_routing_method( @@ -267,7 +268,7 @@ def flashinfer_trtllm_fp4_moe( x: torch.Tensor | tuple[torch.Tensor, torch.Tensor], router_logits: torch.Tensor, top_k: int, - activation: str, + activation: MoEActivation, global_num_experts: int, num_expert_group: int | None, topk_group: int | None, @@ -297,7 +298,7 @@ def flashinfer_trtllm_fp4_moe( from vllm.model_executor.models.llama4 import Llama4MoE # https://github.com/flashinfer-ai/flashinfer/blob/f0277fd1bff90e309e5c19cab36c5dae056d685d/flashinfer/fused_moe/core.py#L2404 - assert activation == "silu", ( + assert activation == MoEActivation.SILU, ( "Only SiLU activation is supported for FlashInfer TRTLLM FP4 MoE. " f"{activation} found instead." 
) @@ -365,7 +366,7 @@ def flashinfer_trtllm_fp4_routed_moe( topk_ids: torch.Tensor, topk_weights: torch.Tensor, top_k: int, - activation: str, + activation: MoEActivation, global_num_experts: int, ) -> torch.Tensor: """ @@ -387,7 +388,7 @@ def flashinfer_trtllm_fp4_routed_moe( import flashinfer # https://github.com/flashinfer-ai/flashinfer/blob/f0277fd1bff90e309e5c19cab36c5dae056d685d/flashinfer/fused_moe/core.py#L2535 - assert activation == "silu", ( + assert activation == MoEActivation.SILU, ( "Only SiLU activation is supported for FlashInfer TRTLLM FP4 Routed MoE. " f"{activation} found instead." ) diff --git a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py index e9ecf0547033..9dbfc6ecad7b 100644 --- a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py +++ b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py @@ -6,6 +6,7 @@ import torch from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.activation import MoEActivation from vllm.platforms import current_platform from vllm.triton_utils import triton from vllm.utils.import_utils import has_triton_kernels @@ -88,7 +89,7 @@ def _can_support_mxfp4( e_score_correction_bias: torch.Tensor | None = None, apply_router_weight_on_input: bool = False, scoring_func: str = "softmax", - activation: str = "swigluoai", + activation: MoEActivation = MoEActivation.SWIGLUOAI, expert_load_view: torch.Tensor | None = None, logical_to_physical_map: torch.Tensor | None = None, logical_replica_count: torch.Tensor | None = None, @@ -101,7 +102,7 @@ def _can_support_mxfp4( or e_score_correction_bias or apply_router_weight_on_input or scoring_func != "softmax" - or activation != "swigluoai" + or activation != MoEActivation.SWIGLUOAI or expert_load_view or logical_to_physical_map or logical_replica_count diff --git a/vllm/model_executor/models/nemotron_h.py b/vllm/model_executor/models/nemotron_h.py index 
a935071fc6fe..06141013c468 100644 --- a/vllm/model_executor/models/nemotron_h.py +++ b/vllm/model_executor/models/nemotron_h.py @@ -33,8 +33,11 @@ from vllm.distributed.parallel_state import get_pp_group from vllm.model_executor.layers.activation import ReLUSquaredActivation from vllm.model_executor.layers.attention import Attention -from vllm.model_executor.layers.fused_moe import FusedMoE, SharedFusedMoE -from vllm.model_executor.layers.fused_moe.utils import activation_without_mul +from vllm.model_executor.layers.fused_moe import ( + FusedMoE, + SharedFusedMoE, + activation_without_mul, +) from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( ColumnParallelLinear,