diff --git a/tests/kernels/moe/test_triton_moe_no_act_mul.py b/tests/kernels/moe/test_triton_moe_no_act_mul.py new file mode 100644 index 000000000000..12d5180f964e --- /dev/null +++ b/tests/kernels/moe/test_triton_moe_no_act_mul.py @@ -0,0 +1,201 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for MoE with non-gated activations (*_no_mul). + +These tests verify that MoE layers work correctly with activations like +silu_no_mul, gelu_no_mul, relu2_no_mul where the activation output dimension +equals N (not N // 2 like gated activations). +""" + +import pytest +import torch + +from vllm.model_executor.layers.fused_moe.config import ( + FUSED_MOE_UNQUANTIZED_CONFIG, +) +from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts +from vllm.model_executor.layers.fused_moe.utils import ( + GELU_NO_MUL, + RELU2_NO_MUL, + SILU_NO_MUL, +) +from vllm.platforms import current_platform + +# Test parameters +M_SIZES = [1, 16, 64] +N_SIZES = [128, 256] +K_SIZES = [64, 128] +TOPK_VALUES = [1, 2] +NUM_EXPERTS = 8 +NO_MUL_ACTIVATIONS = [SILU_NO_MUL, GELU_NO_MUL, RELU2_NO_MUL] + + +def make_test_tensors( + m: int, + n: int, + k: int, + num_experts: int, + topk: int, + dtype: torch.dtype = torch.bfloat16, + device: str = "cuda", +): + """Create test tensors for MoE with non-gated activation. + + For non-gated activations (*_no_mul): + - w1: (E, N, K) - projects from K to N + - w2: (E, K, N) - projects from N back to K (note: N, not N//2) + """ + hidden_states = torch.randn(m, k, dtype=dtype, device=device) + + # For non-gated: w1 projects K -> N, w2 projects N -> K + w1 = torch.randn(num_experts, n, k, dtype=dtype, device=device) * 0.1 + w2 = torch.randn(num_experts, k, n, dtype=dtype, device=device) * 0.1 + + topk_weights = torch.ones(m, topk, dtype=torch.float32, device=device) / topk + topk_ids = torch.randint(0, num_experts, (m, topk), device=device) + + return hidden_states, w1, w2, topk_weights, topk_ids + + +@pytest.mark.skipif( + not current_platform.has_device_capability(80), + reason="Requires compute capability >= 8.0", +) +@pytest.mark.parametrize("m", M_SIZES) +@pytest.mark.parametrize("n", N_SIZES) +@pytest.mark.parametrize("k", K_SIZES) +@pytest.mark.parametrize("topk", TOPK_VALUES) +@pytest.mark.parametrize("activation", NO_MUL_ACTIVATIONS) +@torch.inference_mode() +def test_triton_experts_no_mul_activation( + m: int, + n: int, + k: int, + topk: int, + activation: str, +): + hidden_states, w1, w2, topk_weights, topk_ids = make_test_tensors( + m, n, k, NUM_EXPERTS, topk + ) + + experts = TritonExperts(FUSED_MOE_UNQUANTIZED_CONFIG) + + ws1_shape, ws2_shape, out_shape = experts.workspace_shapes( + M=m, + N=n, + K=k, + topk=topk, + global_num_experts=NUM_EXPERTS, + local_num_experts=NUM_EXPERTS, + expert_tokens_meta=None, + activation=activation, + ) + + # Verify workspace shapes are correct for no_mul activation + # workspace1 should handle activation_out_dim = N (not N//2) + assert ws1_shape == (m, topk, max(n, k)), ( + f"workspace1 shape mismatch: expected {(m, topk, max(n, k))}, got {ws1_shape}" + ) + # workspace2 should handle max(N, K) for intermediate_cache1/cache3 + assert ws2_shape == (m, topk, max(n, k)), ( + f"workspace2 shape mismatch: expected {(m, topk, max(n, k))}, got {ws2_shape}" + ) + assert out_shape == (m, k), ( + f"output shape mismatch: expected {(m, k)}, got {out_shape}" + ) + + workspace1 = torch.empty( + ws1_shape[0] * ws1_shape[1] * ws1_shape[2], + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + workspace2 = torch.empty( + ws2_shape[0] * ws2_shape[1] * ws2_shape[2], + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + output = torch.zeros(m, k, dtype=hidden_states.dtype, device=hidden_states.device) + + experts.apply( + output=output, + hidden_states=hidden_states, + w1=w1, + w2=w2, + topk_weights=topk_weights, + topk_ids=topk_ids, + activation=activation, + global_num_experts=NUM_EXPERTS, + expert_map=None, + a1q_scale=None, + a2_scale=None, + workspace13=workspace1, + workspace2=workspace2, + expert_tokens_meta=None, + apply_router_weight_on_input=False, + ) + + assert output.shape == (m, k), f"Expected shape {(m, k)}, got {output.shape}" + assert not torch.isnan(output).any(), "Output contains NaN" + assert not torch.isinf(output).any(), "Output contains Inf" + assert output.abs().sum() > 0, "Output is all zeros" + + +@pytest.mark.skipif( + not current_platform.has_device_capability(80), + reason="Requires compute capability >= 8.0", +) +@torch.inference_mode() +def test_workspace_shapes_no_mul_vs_gated(): + """Test that workspace shapes differ correctly between gated and non-gated.""" + from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts + + M, N, K, topk = 64, 256, 128, 2 + + experts = TritonExperts(FUSED_MOE_UNQUANTIZED_CONFIG) + + ws1_no_mul, _, out_no_mul = experts.workspace_shapes( + M, N, K, topk, 8, 8, None, SILU_NO_MUL + ) + + ws1_gated, _, out_gated = experts.workspace_shapes( + M, N, K, topk, 8, 8, None, "silu" + ) + + # For no_mul: activation_out_dim = N + # For gated: activation_out_dim = N // 2 + # workspace1 should use max(activation_out_dim, K) + activation_out_dim_no_mul = N + activation_out_dim_gated = N // 2 + + assert ws1_no_mul[2] == max(activation_out_dim_no_mul, K), ( + f"no_mul workspace1 last dim should be max({activation_out_dim_no_mul}, {K})" + ) + assert ws1_gated[2] == max(activation_out_dim_gated, K), ( + f"gated workspace1 last dim should be max({activation_out_dim_gated}, {K})" + ) + + # Output shapes should be the same + assert out_no_mul == out_gated == (M, K) + + +@pytest.mark.skipif( + not current_platform.has_device_capability(80), + reason="Requires compute capability >= 8.0", +) +@torch.inference_mode() +def test_adjust_n_for_activation(): + """Test the adjust_N_for_activation method.""" + from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts + + experts = TritonExperts(FUSED_MOE_UNQUANTIZED_CONFIG) + + N = 256 + + # Gated activations should return N // 2 + assert experts.adjust_N_for_activation(N, "silu") == N // 2 + assert experts.adjust_N_for_activation(N, "gelu") == N // 2 + + # Non-gated activations should return N + assert experts.adjust_N_for_activation(N, SILU_NO_MUL) == N + assert experts.adjust_N_for_activation(N, GELU_NO_MUL) == N + assert experts.adjust_N_for_activation(N, RELU2_NO_MUL) == N diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py index 15f6e3a18ed6..e598ec3acb3d 100644 --- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py @@ -305,6 +305,7 @@ def workspace_shapes( global_num_experts: int, local_num_experts: int, expert_tokens_meta: mk.ExpertTokensMetadata | None, + activation: str, ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: # FIXME (varun): We should be able to dispatch only from the leader # DP ranks in the case of TP > 1. At the moment, all the Ranks @@ -312,8 +313,9 @@ def workspace_shapes( num_dispatchers = self.num_dispatchers num_experts = local_num_experts max_num_tokens = M if self.max_num_tokens is None else self.max_num_tokens + activation_out_dim = self.adjust_N_for_activation(N, activation) workspace13 = (num_experts, max_num_tokens * num_dispatchers, max(K, N)) - workspace2 = (num_experts, max_num_tokens * num_dispatchers, (N // 2)) + workspace2 = (num_experts, max_num_tokens * num_dispatchers, activation_out_dim) output = (num_experts, max_num_tokens * num_dispatchers, K) return (workspace13, workspace2, output) diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index fdac768da8f9..c0ffa38fdb2c 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -355,9 +355,11 @@ def workspace_shapes( global_num_experts: int, local_num_experts: int, expert_tokens_meta: mk.ExpertTokensMetadata | None, + activation: str, ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: + activation_out_dim = self.adjust_N_for_activation(N, activation) workspace1 = (M * topk, max(N, K)) - workspace2 = (M * topk, max(N // 2, K)) + workspace2 = (M * topk, max(activation_out_dim, K)) output = (M, K) return (workspace1, workspace2, output) @@ -402,11 +404,17 @@ def workspace_shapes( global_num_experts: int, local_num_experts: int, expert_tokens_meta: mk.ExpertTokensMetadata | None, + activation: str, ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: num_dp = self.num_dispatchers assert num_dp is not None + activation_out_dim = self.adjust_N_for_activation(N, activation) workspace1 = (self.max_experts_per_worker, M * num_dp, max(N, K)) - workspace2 = (self.max_experts_per_worker, M * num_dp, max(N // 2, K)) + workspace2 = ( + self.max_experts_per_worker, + M * num_dp, + max(activation_out_dim, K), + ) output = (self.max_experts_per_worker, M, K) return (workspace1, workspace2, output) @@ -635,13 +643,15 @@ def workspace_shapes( global_num_experts: int, local_num_experts: int, expert_tokens_meta: mk.ExpertTokensMetadata | None, + activation: str, ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: + activation_out_dim = self.adjust_N_for_activation(N, activation) workspace1: tuple[int, ...] = () workspace2: tuple[int, ...] = () output: tuple[int, ...] = () if self.use_batched_format: workspace1 = (self.max_experts_per_worker, M, max(N, K)) - workspace2 = (self.max_experts_per_worker, M, (N // 2)) + workspace2 = (self.max_experts_per_worker, M, activation_out_dim) output = (self.max_experts_per_worker, M, K) else: workspace1 = (M * topk, max(2 * N, K)) @@ -896,9 +906,11 @@ def workspace_shapes( global_num_experts: int, local_num_experts: int, expert_tokens_meta: mk.ExpertTokensMetadata | None, + activation: str, ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: + activation_out_dim = self.adjust_N_for_activation(N, activation) workspace1 = (M * topk, max(N, K)) - workspace2 = (M * topk, max(N // 2, K)) + workspace2 = (M * topk, max(activation_out_dim, K)) output = (M, K) return (workspace1, workspace2, output) diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index 5ca91768c976..a2e5a07fbfd2 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -143,6 +143,7 @@ def workspace_shapes( global_num_experts: int, local_num_experts: int, expert_tokens_meta: mk.ExpertTokensMetadata | None, + activation: str, ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: assert self.block_shape is not None block_m = self.block_shape[0] @@ -151,7 +152,8 @@ def workspace_shapes( ) assert M_sum % block_m == 0 - workspace1 = (M_sum, max(N // 2, K)) + activation_out_dim = self.adjust_N_for_activation(N, activation) + workspace1 = (M_sum, max(activation_out_dim, K)) workspace2 = (M_sum, max(N, K)) output = (M, K) return (workspace1, workspace2, output) @@ -163,11 +165,13 @@ def _act_mul_quant( block_k = self.block_shape[1] scale_fmt = DeepGemmQuantScaleFMT.from_oracle() + M_sum, N = input.size() + activation_out_dim = self.adjust_N_for_activation(N, activation) + # 1. DeepGemm UE8M0: use packed per-token-group quant if scale_fmt == DeepGemmQuantScaleFMT.UE8M0: - M_sum, N = input.size() act_out = torch.empty( - (M_sum, N // 2), dtype=input.dtype, device=input.device + (M_sum, activation_out_dim), dtype=input.dtype, device=input.device ) self.activation(activation, act_out, input) a2q, a2q_scale = per_token_group_quant_fp8_packed_for_deepgemm( @@ -187,8 +191,9 @@ def _act_mul_quant( ) # 3. fallback path for non-SiLU activations in non‑UE8M0 cases. - M_sum, N = input.size() - act_out = torch.empty((M_sum, N // 2), dtype=input.dtype, device=input.device) + act_out = torch.empty( + (M_sum, activation_out_dim), dtype=input.dtype, device=input.device + ) self.activation(activation, act_out, input) return per_token_group_quant_fp8( act_out, block_k, column_major_scales=True, out_q=output @@ -254,8 +259,9 @@ def apply( (a1q, a1q_scale), (w1, self.w1_scale), mm1_out, expert_ids ) + activation_out_dim = self.adjust_N_for_activation(N, activation) quant_out = _resize_cache( - workspace13.view(dtype=torch.float8_e4m3fn), (M_sum, N // 2) + workspace13.view(dtype=torch.float8_e4m3fn), (M_sum, activation_out_dim) ) a2q, a2q_scale = self._act_mul_quant( input=mm1_out.view(-1, N), output=quant_out, activation=activation diff --git a/vllm/model_executor/layers/fused_moe/fallback.py b/vllm/model_executor/layers/fused_moe/fallback.py index 14ef6b9aaa5e..4556392144a0 100644 --- a/vllm/model_executor/layers/fused_moe/fallback.py +++ b/vllm/model_executor/layers/fused_moe/fallback.py @@ -76,6 +76,7 @@ def workspace_shapes( global_num_experts: int, local_num_experts: int, expert_tokens_meta: mk.ExpertTokensMetadata | None, + activation: str, ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: raise NotImplementedError diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py index ce93ae235f27..1651f3530eef 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py @@ -91,6 +91,7 @@ def workspace_shapes( global_num_experts: int, local_num_experts: int, expert_tokens_meta: mk.ExpertTokensMetadata | None, + activation: str, ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: # We use global_num_experts due to how moe_align_block_size handles # expert_maps. diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py index 09c3d9b2190f..ae60e15db841 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py @@ -103,6 +103,7 @@ def workspace_shapes( global_num_experts: int, local_num_experts: int, expert_tokens_meta: mk.ExpertTokensMetadata | None, + activation: str, ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: # We use global_num_experts due to how moe_align_block_size handles # expert_maps. diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index 7fd8511e297d..fb93464392ea 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -673,6 +673,7 @@ def workspace_shapes( global_num_experts: int, local_num_experts: int, expert_tokens_meta: mk.ExpertTokensMetadata | None, + activation: str, ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: num_dp = self.num_dispatchers num_experts = local_num_experts @@ -867,12 +868,14 @@ def workspace_shapes( global_num_experts: int, local_num_experts: int, expert_tokens_meta: mk.ExpertTokensMetadata | None, + activation: str, ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: num_dp = self.num_dispatchers num_experts = local_num_experts max_num_tokens = self.max_num_tokens + activation_out_dim = self.adjust_N_for_activation(N, activation) workspace13 = (num_experts, max_num_tokens * num_dp, max(K, N)) - workspace2 = (num_experts, max_num_tokens * num_dp, (N // 2)) + workspace2 = (num_experts, max_num_tokens * num_dp, activation_out_dim) output = (num_experts, max_num_tokens * num_dp, K) return (workspace13, workspace2, output) @@ -947,7 +950,10 @@ def apply( # We can reuse the memory between these because by the time we need # cache3, we're done with cache1 intermediate_cache1 = _resize_cache(workspace13, (E, max_num_tokens, N)) - intermediate_cache2 = _resize_cache(workspace2, (E, max_num_tokens, N // 2)) + activation_out_dim = self.adjust_N_for_activation(N, activation) + intermediate_cache2 = _resize_cache( + workspace2, (E, max_num_tokens, activation_out_dim) + ) # TODO(bnell): should this be done for any quantized type? if self.quant_config.use_fp8_w8a8: @@ -978,7 +984,7 @@ def apply( # TODO (bnell): use triton utility from batched deep gemm. self.activation( activation, - intermediate_cache2.view(-1, N // 2), + intermediate_cache2.view(-1, activation_out_dim), intermediate_cache1.view(-1, N), ) diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index e82a838959de..77c6b97eaea3 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -640,6 +640,7 @@ def workspace_shapes( global_num_experts: int, local_num_experts: int, expert_tokens_meta: mk.ExpertTokensMetadata | None, + activation: str, ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: # Modular Kernel provisions output buffer from workspace1. However in # the fused_marlin_moe() function, the final torch.sum(), is defined @@ -768,6 +769,7 @@ def workspace_shapes( global_num_experts: int, local_num_experts: int, expert_tokens_meta: mk.ExpertTokensMetadata | None, + activation: str, ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: num_dispatchers = self.num_dispatchers num_experts = local_num_experts diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index d1b25aa92e8d..afc546a239b8 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -9,7 +9,6 @@ from typing import Any import torch -import torch.nn.functional as F import vllm.envs as envs import vllm.model_executor.layers.fused_moe.modular_kernel as mk @@ -43,7 +42,7 @@ ) from vllm.model_executor.layers.fused_moe.utils import ( _resize_cache, - activation_without_mul, + apply_moe_activation, disable_inplace, moe_kernel_quantize_input, ) @@ -1957,11 +1956,6 @@ def fused_experts( ) -SILU_NO_MUL: str = activation_without_mul("silu") -GELU_NO_MUL: str = activation_without_mul("gelu") -RELU2_NO_MUL: str = activation_without_mul("relu2") - - def _get_config_quant_dtype( use_fp8_w8a8: bool, use_int8_w8a8: bool, @@ -2094,8 +2088,13 @@ def fused_experts_impl( intermediate_cache3 = cache13[: M * top_k_num * K].view(M, top_k_num, K) # This needs separate memory since it's used concurrently with cache1 + activation_out_dim = mk.FusedMoEPermuteExpertsUnpermute.adjust_N_for_activation( + N, activation + ) intermediate_cache2 = torch.empty( - (M * top_k_num, N // 2), device=hidden_states.device, dtype=hidden_states.dtype + (M * top_k_num, activation_out_dim), + device=hidden_states.device, + dtype=hidden_states.dtype, ) if hidden_states.dtype == torch.bfloat16: @@ -2235,29 +2234,9 @@ def fused_experts_impl( B_bias=w1_bias, ) - # Activation function with multiplication - if activation == "silu": - torch.ops._C.silu_and_mul( - intermediate_cache2, intermediate_cache1.view(-1, N) - ) - elif activation == "gelu": - torch.ops._C.gelu_and_mul( - intermediate_cache2, intermediate_cache1.view(-1, N) - ) - elif activation == "swigluoai": - # alpha = 1.702, limit = 7.0 - torch.ops._C.swigluoai_and_mul( - intermediate_cache2, intermediate_cache1.view(-1, N) - ) - # Activation function without multiplication - elif activation == SILU_NO_MUL: - intermediate_cache2 = F.silu(intermediate_cache1.view(-1, N)) - elif activation == GELU_NO_MUL: - intermediate_cache2 = F.gelu(intermediate_cache1.view(-1, N)) - elif activation == RELU2_NO_MUL: - intermediate_cache2 = torch.square(F.relu(intermediate_cache1.view(-1, N))) - else: - raise ValueError(f"Unsupported FusedMoe activation: {activation}.") + apply_moe_activation( + activation, intermediate_cache2, intermediate_cache1.view(-1, N) + ) qintermediate_cache2, a2q_scale = moe_kernel_quantize_input( A=intermediate_cache2, @@ -2336,8 +2315,10 @@ def workspace_shapes( global_num_experts: int, local_num_experts: int, expert_tokens_meta: mk.ExpertTokensMetadata | None, + activation: str, ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: - workspace1 = (M, topk, max(N // 2, K)) + activation_out_dim = self.adjust_N_for_activation(N, activation) + workspace1 = (M, topk, max(activation_out_dim, K)) workspace2 = (M, topk, max(N, K)) output = (M, K) return (workspace1, workspace2, output) @@ -2412,8 +2393,9 @@ def apply( # Note that the output tensor might be in workspace1 intermediate_cache1 = _resize_cache(workspace2, (num_tokens, top_k_num, N)) + cache2_dim = self.adjust_N_for_activation(N, activation) intermediate_cache2 = _resize_cache( - workspace13, (num_tokens * top_k_num, N // 2) + workspace13, (num_tokens * top_k_num, cache2_dim) ) intermediate_cache3 = _resize_cache(workspace2, (num_tokens, top_k_num, K)) @@ -2565,8 +2547,9 @@ def apply( # Note that the output tensor might be in workspace1 intermediate_cache1 = _resize_cache(workspace2, (num_tokens, top_k_num, N)) + activation_out_dim = self.adjust_N_for_activation(N, activation) intermediate_cache2 = _resize_cache( - workspace13, (num_tokens * top_k_num, N // 2) + workspace13, (num_tokens * top_k_num, activation_out_dim) ) intermediate_cache3 = _resize_cache(workspace2, (num_tokens, top_k_num, K)) diff --git a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py index dff8a9f3a8f0..c4bc1824aa1f 100644 --- a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py +++ b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py @@ -323,10 +323,12 @@ def workspace_shapes( global_num_experts: int, local_num_experts: int, expert_tokens_meta: mk.ExpertTokensMetadata | None, + activation: str, ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: # workspace are allocated inside the kernel + activation_out_dim = self.adjust_N_for_activation(N, activation) workspace1 = (0, 0) - workspace2 = (M * topk, N // 2) + workspace2 = (M * topk, activation_out_dim) output = (M, K) return (workspace1, workspace2, output) @@ -415,9 +417,11 @@ def workspace_shapes( global_num_experts: int, local_num_experts: int, expert_tokens_meta: mk.ExpertTokensMetadata | None, + activation: str, ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: # workspace are allocated inside the kernel - workspace1 = (M * topk, N // 2) + activation_out_dim = self.adjust_N_for_activation(N, activation) + workspace1 = (M * topk, activation_out_dim) workspace2 = (M * topk, max(N, K)) output = (M, K) return (workspace1, workspace2, output) @@ -443,8 +447,10 @@ def apply( expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool, ): - if self.quant_config is None: - self.quant_config = FUSED_MOE_UNQUANTIZED_CONFIG + # Use local variable to help mypy narrow the type after None check + quant_config = self.quant_config + if quant_config is None: + quant_config = FUSED_MOE_UNQUANTIZED_CONFIG if expert_map is not None: topk_ids = expert_map[topk_ids] @@ -462,12 +468,10 @@ def apply( # type check, uint8 means mxfp4 assert hidden_states.dtype == torch.bfloat16 assert ( - self.quant_config.w1_bias is None - or self.quant_config.w1_bias.dtype == torch.float32 + quant_config.w1_bias is None or quant_config.w1_bias.dtype == torch.float32 ) assert ( - self.quant_config.w2_bias is None - or self.quant_config.w2_bias.dtype == torch.float32 + quant_config.w2_bias is None or quant_config.w2_bias.dtype == torch.float32 ) # Shape check, only check non-mxfp4 @@ -485,17 +489,18 @@ def apply( # Note that the output tensor might be in workspace13 intermediate_cache1 = _resize_cache(workspace2, (batch_dim, M * topk, N)) intermediate_cache3 = _resize_cache(workspace2, (batch_dim, M * topk, K)) - intermediate_cache2 = _resize_cache(workspace13, (M * topk, N // 2)) + activation_out_dim = self.adjust_N_for_activation(N, activation) + intermediate_cache2 = _resize_cache(workspace13, (M * topk, activation_out_dim)) gammas = routing_data.gate_scal if routing_data else None matmul_ogs( hidden_states, w1, - self.quant_config.w1_bias, + quant_config.w1_bias, routing_data, gather_indx=gather_indx, - precision_config=self.quant_config.w1_precision, + precision_config=quant_config.w1_precision, gammas=gammas if apply_router_weight_on_input else None, fused_activation=None, y=intermediate_cache1, @@ -515,10 +520,10 @@ def apply( matmul_ogs( intermediate_cache2[gather_indx.src_indx], w2, - self.quant_config.w2_bias, + quant_config.w2_bias, routing_data, scatter_indx=scatter_indx, - precision_config=self.quant_config.w2_precision, + precision_config=quant_config.w2_precision, gammas=None if apply_router_weight_on_input else gammas, y=intermediate_cache3, ) diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 79168948f04a..a6df2b20af9c 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -18,6 +18,7 @@ ) from vllm.model_executor.layers.fused_moe.utils import ( _resize_cache, + apply_moe_activation, count_expert_num_tokens, disable_inplace, ) @@ -542,6 +543,7 @@ def workspace_shapes( global_num_experts: int, local_num_experts: int, expert_tokens_meta: ExpertTokensMetadata | None, + activation: str, ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: """ Compute the shapes for the temporary and final outputs of the two gemms @@ -572,19 +574,31 @@ def workspace_shapes( """ raise NotImplementedError + @staticmethod + def adjust_N_for_activation(N: int, activation: str) -> int: + """ + Calculate the output dimension for the activation function. + + For *_no_mul activations (e.g. relu2_no_mul), + there's no gate/up split, so output size equals input size (N). + + For regular gated activations (e.g., silu, gelu, swigluoai), + output size is N // 2 due to gate × activation(up) multiplication. + + Args: + N: The intermediate size (width of w1/w3 weights). + activation: The activation function name. + + Returns: + The output dimension after activation. + """ + is_no_mul = activation.endswith("_no_mul") + return N if is_no_mul else N // 2 + def activation( self, activation: str, output: torch.Tensor, input: torch.Tensor ) -> None: - assert output.size(-1) * 2 == input.size(-1) - if activation == "silu": - torch.ops._C.silu_and_mul(output, input) - elif activation == "gelu": - torch.ops._C.gelu_and_mul(output, input) - elif activation == "swigluoai": - # alpha = 1.702, limit = 7.0 - torch.ops._C.swigluoai_and_mul(output, input) - else: - raise ValueError(f"Unsupported FusedMoe activation: {activation}") + apply_moe_activation(activation, output, input) def enable_chunking(self): return ( @@ -761,6 +775,7 @@ def _allocate_buffers( global_num_experts: int, local_num_experts: int, expert_tokens_meta: ExpertTokensMetadata | None, + activation: str, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Allocate temporary and output buffers for the fused experts op. @@ -796,6 +811,7 @@ def _allocate_buffers( # amount of workspace. Mark it None, so we allocate for # the worst-case scenario. expert_tokens_meta=None, + activation=activation, ) ) @@ -814,6 +830,7 @@ def _allocate_buffers( global_num_experts, local_num_experts, expert_tokens_meta, + activation, ) # Get final output shape based on the full M size. @@ -825,6 +842,7 @@ def _allocate_buffers( global_num_experts, local_num_experts, expert_tokens_meta, + activation, ) # We can reuse the memory between cache1 and cache3 because by the @@ -1043,6 +1061,7 @@ def input_chunk_range(chunk_idx: int) -> tuple[int, int]: global_num_experts, local_num_experts, expert_tokens_meta, + activation, ) for chunk_idx in range(num_chunks): diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py index 06707e5e4892..b78794c6bd83 100644 --- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py @@ -299,6 +299,7 @@ def workspace_shapes( global_num_experts: int, local_num_experts: int, expert_tokens_meta: mk.ExpertTokensMetadata | None, + activation: str, ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: # Workspaces are managed internally by AITER. workspace1 = (0,) diff --git a/vllm/model_executor/layers/fused_moe/triton_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/triton_cutlass_moe.py index e874ba609be0..09d5e45c1ec2 100644 --- a/vllm/model_executor/layers/fused_moe/triton_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/triton_cutlass_moe.py @@ -39,6 +39,7 @@ def workspace_shapes( global_num_experts: int, local_num_experts: int, expert_tokens_meta: mk.ExpertTokensMetadata | None, + activation: str, ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: # Small batch fallback for sm100. if self.is_sm100 and M <= 8: @@ -50,6 +51,7 @@ def workspace_shapes( global_num_experts, local_num_experts, expert_tokens_meta, + activation, ) else: return self.experts.workspace_shapes( @@ -60,6 +62,7 @@ def workspace_shapes( global_num_experts, local_num_experts, expert_tokens_meta, + activation, ) def _select_experts_impl( diff --git a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py index 4fcc1a7c1fc0..55b1e1211b0a 100644 --- a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py @@ -35,6 +35,7 @@ def workspace_shapes( global_num_experts: int, local_num_experts: int, expert_tokens_meta: mk.ExpertTokensMetadata | None, + activation: str, ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: # Note: the deep gemm workspaces are strictly larger than the triton # workspaces so we can be pessimistic here and allocate for DeepGemm @@ -48,6 +49,7 @@ def workspace_shapes( global_num_experts, local_num_experts, expert_tokens_meta, + activation, ) else: return self.fallback_experts.workspace_shapes( @@ -58,6 +60,7 @@ def workspace_shapes( global_num_experts, local_num_experts, expert_tokens_meta, + activation, ) def _select_experts_impl( diff --git a/vllm/model_executor/layers/fused_moe/trtllm_moe.py b/vllm/model_executor/layers/fused_moe/trtllm_moe.py index 132d35e65aba..c46f59564930 100644 --- a/vllm/model_executor/layers/fused_moe/trtllm_moe.py +++ b/vllm/model_executor/layers/fused_moe/trtllm_moe.py @@ -57,6 +57,7 @@ def workspace_shapes( global_num_experts: int, local_num_experts: int, expert_tokens_meta: mk.ExpertTokensMetadata | None, + activation: str, ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: # The workspaces for this implementation are managed by flashinfer. workspace1 = (0,) diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py index 1f946d67a8f5..e74b4fd212b3 100644 --- a/vllm/model_executor/layers/fused_moe/utils.py +++ b/vllm/model_executor/layers/fused_moe/utils.py @@ -4,6 +4,7 @@ from math import prod import torch +import torch.nn.functional as F from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.utils.fp8_utils import ( @@ -324,6 +325,55 @@ def activation_without_mul(activation: str) -> str: return activation + "_no_mul" +RELU2_NO_MUL: str = activation_without_mul("relu2") +SILU_NO_MUL: str = activation_without_mul("silu") +GELU_NO_MUL: str = activation_without_mul("gelu") + + +def apply_moe_activation( + activation: str, + output: torch.Tensor, + input: torch.Tensor, +) -> torch.Tensor: + """ + Apply MoE activation function. + + For *_and_mul activations (silu, gelu, swigluoai): + - Expects output.size(-1) * 2 == input.size(-1) + + For *_no_mul activations (silu_no_mul, gelu_no_mul, relu2_no_mul): + - Expects output.size(-1) == input.size(-1) + """ + is_no_mul = activation.endswith("_no_mul") + if is_no_mul: + assert output.size(-1) == input.size(-1), ( + f"{activation} expects equal sizes: {output.size(-1)} vs {input.size(-1)}" + ) + else: + assert output.size(-1) * 2 == input.size(-1), ( + f"{activation} expects 2x ratio: {output.size(-1) * 2} vs {input.size(-1)}" + ) + + # Activations with gated multiplication (gate × activation(up)) + if activation == "silu": + torch.ops._C.silu_and_mul(output, input) + elif activation == "gelu": + torch.ops._C.gelu_and_mul(output, input) + elif activation == "swigluoai": + torch.ops._C.swigluoai_and_mul(output, input) + # Activations without gated multiplication + elif activation == SILU_NO_MUL: + output.copy_(F.silu(input)) + elif activation == GELU_NO_MUL: + output.copy_(F.gelu(input)) + elif activation == RELU2_NO_MUL: + torch.square(F.relu(input), out=output) + else: + raise ValueError(f"Unsupported FusedMoe activation: {activation}") + + return output + + # Torch custom ops can't deal with outputs aliasing inputs so we need to # disable inplace for torch >= 2.9. # See https://github.com/vllm-project/vllm/issues/26378