diff --git a/tests/kernels/moe/test_triton_moe_no_act_mul.py b/tests/kernels/moe/test_triton_moe_no_act_mul.py new file mode 100644 index 000000000000..8e8cd7a237a1 --- /dev/null +++ b/tests/kernels/moe/test_triton_moe_no_act_mul.py @@ -0,0 +1,129 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Test for is_act_and_mul=False MoE using Triton. + +This tests the code path used by models like Nemotron-H that use +non-fused activations (e.g., relu2_no_mul) instead of SwiGLU-style +fused activations. + +This feature is supported on both CUDA and ROCm (with AITER disabled). +""" + +import pytest +import torch + +from vllm.platforms import current_platform + +pytestmark = pytest.mark.skipif( + not current_platform.is_cuda_alike(), + reason="Tests for is_act_and_mul=False MoE require CUDA or ROCm", +) + + +@pytest.fixture +def disable_aiter_on_rocm(monkeypatch): + """Fixture to disable AITER on ROCm to use Triton path.""" + if current_platform.is_rocm(): + from vllm._aiter_ops import rocm_aiter_ops + + monkeypatch.setenv("VLLM_ROCM_USE_AITER", "0") + monkeypatch.setenv("VLLM_ROCM_USE_AITER_MOE", "0") + rocm_aiter_ops.refresh_env_variables() + + yield + + rocm_aiter_ops.refresh_env_variables() + else: + # On CUDA, no special setup needed + yield + + +@pytest.fixture +def init_workspace(): + """Initialize workspace manager for MoE tests.""" + from vllm.v1.worker.workspace import ( + init_workspace_manager, + reset_workspace_manager, + ) + + torch.manual_seed(42) + init_workspace_manager(torch.cuda.current_device()) + + yield + + reset_workspace_manager() + + +@pytest.mark.parametrize("m", [1, 33, 64, 222]) +@pytest.mark.parametrize("n", [128, 256, 1024]) +@pytest.mark.parametrize("k", [128, 512]) +@pytest.mark.parametrize("e", [4, 8]) +@pytest.mark.parametrize("topk", [2]) +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) +@pytest.mark.parametrize("activation", ["relu2_no_mul", "silu_no_mul", "gelu_no_mul"]) +@torch.inference_mode() +def test_moe_no_act_mul( + disable_aiter_on_rocm, + init_workspace, + m: int, + n: int, + k: int, + e: int, + topk: int, + dtype: torch.dtype, + activation: str, +): + """Test MoE with is_act_and_mul=False using Triton.""" + from vllm.model_executor.layers.fused_moe import TritonExperts, fused_topk + from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig + from vllm.model_executor.layers.fused_moe.modular_kernel import ( + FusedMoEModularKernel, + ) + from vllm.model_executor.layers.fused_moe.prepare_finalize import ( + MoEPrepareAndFinalizeNoEP, + ) + + a = torch.randn((m, k), device="cuda", dtype=dtype) + w1 = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 + + quant_config = FusedMoEQuantConfig.make(is_act_and_mul=False) + + score = torch.randn((m, e), device="cuda", dtype=dtype) + topk_weights, topk_ids, _ = fused_topk(a, score, topk, renormalize=True) + + fused_experts = FusedMoEModularKernel( + MoEPrepareAndFinalizeNoEP(), + TritonExperts(quant_config), + ) + + output = fused_experts( + hidden_states=a, + w1=w1, + w2=w2, + topk_weights=topk_weights, + topk_ids=topk_ids, + activation=activation, + ) + + assert output.shape == (m, k), f"Expected shape {(m, k)}, got {output.shape}" + assert not torch.isnan(output).any(), "Output contains NaN" + assert not torch.isinf(output).any(), "Output contains Inf" + assert output.abs().sum() > 0, "Output is all zeros" + + +@torch.inference_mode() +def test_moe_workspace_shapes_no_act_mul(disable_aiter_on_rocm): + """Test workspace_shapes returns correct sizes for is_act_and_mul=False.""" + from vllm.model_executor.layers.fused_moe import TritonExperts + from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig + + M, N, K, topk = 64, 256, 128, 2 + + quant_config = FusedMoEQuantConfig.make(is_act_and_mul=False) + experts = TritonExperts(quant_config) + ws1, ws2, out = experts.workspace_shapes(M, N, K, topk, 8, 8, None) + + assert ws1[2] == max(N, K) + assert out == (M, K) diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index 17d5ec4bcda7..4dd76b8f15d1 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -201,6 +201,11 @@ class FusedMoEQuantConfig: _w1: FusedMoEQuantDesc _w2: FusedMoEQuantDesc + # Whether activation is fused with gate multiplication (SwiGLU-style). + # When True: intermediate_size = N // 2 (gate and up are combined) + # When False: intermediate_size = N (no gate multiplication) + is_act_and_mul: bool = True + def __post_init__(self): assert not self.per_act_token_quant or self.block_shape is None, ( "illegal quantization" @@ -435,6 +440,7 @@ def make( w1_zp: torch.Tensor | None = None, w2_zp: torch.Tensor | None = None, weight_dtype: torch.dtype | str | None = None, + is_act_and_mul: bool = True, ) -> "FusedMoEQuantConfig": """ General builder function for a FusedMoEQuantConfig. @@ -494,6 +500,7 @@ def make( _w2=FusedMoEQuantDesc( weight_dtype, w_shape, w2_scale, g2_alphas, w2_zp, w2_bias ), + is_act_and_mul=is_act_and_mul, ) assert quant_config.per_act_token_quant == per_act_token_quant assert quant_config.per_out_ch_quant == per_out_ch_quant @@ -806,6 +813,7 @@ def awq_marlin_moe_quant_config( def biased_moe_quant_config( w1_bias: torch.Tensor | None, w2_bias: torch.Tensor | None, + is_act_and_mul: bool = True, ) -> FusedMoEQuantConfig: """ Construct a quant config for unquantized activations with biases. @@ -815,6 +823,7 @@ def biased_moe_quant_config( _a2=FusedMoEQuantDesc(), _w1=FusedMoEQuantDesc(bias=w1_bias), _w2=FusedMoEQuantDesc(bias=w2_bias), + is_act_and_mul=is_act_and_mul, ) diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index 7fd8511e297d..b4d3a6c16a93 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -871,8 +871,11 @@ def workspace_shapes( num_dp = self.num_dispatchers num_experts = local_num_experts max_num_tokens = self.max_num_tokens + # For fused activations (SwiGLU): N = 2 * intermediate, after act = N/2 + # For non-fused activations: N = intermediate, after act = N + intermediate_size = N // 2 if self.quant_config.is_act_and_mul else N workspace13 = (num_experts, max_num_tokens * num_dp, max(K, N)) - workspace2 = (num_experts, max_num_tokens * num_dp, (N // 2)) + workspace2 = (num_experts, max_num_tokens * num_dp, intermediate_size) output = (num_experts, max_num_tokens * num_dp, K) return (workspace13, workspace2, output) @@ -947,7 +950,11 @@ def apply( # We can reuse the memory between these because by the time we need # cache3, we're done with cache1 intermediate_cache1 = _resize_cache(workspace13, (E, max_num_tokens, N)) - intermediate_cache2 = _resize_cache(workspace2, (E, max_num_tokens, N // 2)) + # For fused activations (SwiGLU): output is N/2, for non-fused: output is N + intermediate_size = N // 2 if self.quant_config.is_act_and_mul else N + intermediate_cache2 = _resize_cache( + workspace2, (E, max_num_tokens, intermediate_size) + ) # TODO(bnell): should this be done for any quantized type? if self.quant_config.use_fp8_w8a8: @@ -978,7 +985,7 @@ def apply( # TODO (bnell): use triton utility from batched deep gemm. self.activation( activation, - intermediate_cache2.view(-1, N // 2), + intermediate_cache2.view(-1, intermediate_size), intermediate_cache1.view(-1, N), ) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 1e2f96c86a55..cdff1b700c59 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -2292,7 +2292,10 @@ def workspace_shapes( local_num_experts: int, expert_tokens_meta: mk.ExpertTokensMetadata | None, ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: - workspace1 = (M, topk, max(N // 2, K)) + # For fused activations (SwiGLU): N = 2 * intermediate, after act = N/2 + # For non-fused activations: N = intermediate, after act = N + intermediate_size = N // 2 if self.quant_config.is_act_and_mul else N + workspace1 = (M, topk, max(intermediate_size, K)) workspace2 = (M, topk, max(N, K)) output = (M, K) return (workspace1, workspace2, output) @@ -2367,8 +2370,10 @@ def apply( # Note that the output tensor might be in workspace1 intermediate_cache1 = _resize_cache(workspace2, (num_tokens, top_k_num, N)) + # For fused activations (SwiGLU): output is N/2, for non-fused: output is N + intermediate_size = N // 2 if self.quant_config.is_act_and_mul else N intermediate_cache2 = _resize_cache( - workspace13, (num_tokens * top_k_num, N // 2) + workspace13, (num_tokens * top_k_num, intermediate_size) ) intermediate_cache3 = _resize_cache(workspace2, (num_tokens, top_k_num, K)) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index fb441963a97d..cf1b3e1b8626 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -603,9 +603,15 @@ def _get_quant_method() -> FusedMoEMethodBase: "is_act_and_mul=False is supported only for unquantized " ", ModelOpt FP8, and ModelOpt NvFp4 checkpoints" ) - if not current_platform.is_cuda(): + # ROCm without AITER MoE uses Triton which supports + # is_act_and_mul=False via standard PyTorch ops (F.silu, F.gelu) + rocm_without_aiter_moe = ( + current_platform.is_rocm() and not rocm_aiter_ops.is_fused_moe_enabled() + ) + if not current_platform.is_cuda() and not rocm_without_aiter_moe: raise NotImplementedError( - "is_act_and_mul=False is supported only for CUDA for now" + "is_act_and_mul=False is supported only for CUDA, or ROCm " + "(when AITER MoE is disabled) for now" ) if self.enable_eplb and not self.quant_method.supports_eplb: diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 79168948f04a..117359a8f916 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -4,7 +4,7 @@ from collections.abc import Callable from dataclasses import dataclass from enum import Enum -from math import prod +from math import prod, sqrt from typing import final import torch @@ -575,14 +575,35 @@ def workspace_shapes( def activation( self, activation: str, output: torch.Tensor, input: torch.Tensor ) -> None: - assert output.size(-1) * 2 == input.size(-1) + # Fused activations (SwiGLU-style): output is half the size of input if activation == "silu": + assert output.size(-1) * 2 == input.size(-1) torch.ops._C.silu_and_mul(output, input) elif activation == "gelu": + assert output.size(-1) * 2 == input.size(-1) torch.ops._C.gelu_and_mul(output, input) elif activation == "swigluoai": # alpha = 1.702, limit = 7.0 + assert output.size(-1) * 2 == input.size(-1) torch.ops._C.swigluoai_and_mul(output, input) + # Non-fused activations (is_act_and_mul=False): output same size as input + elif activation == "silu_no_mul": + assert output.size(-1) == input.size(-1) + # Use out= argument to avoid intermediate tensor + torch.sigmoid(input, out=output) + output.mul_(input) + elif activation == "gelu_no_mul": + assert output.size(-1) == input.size(-1) + # GELU(x) = 0.5 * x * (1 + erf(x / sqrt(2))) + # Use out= and in-place ops to avoid intermediate tensors + output.copy_(input).div_(sqrt(2)) + torch.erf(output, out=output) + output.add_(1).mul_(input).mul_(0.5) + elif activation == "relu2_no_mul": + assert output.size(-1) == input.size(-1) + # ReLU²: clamp has out=, then in-place square + torch.clamp(input, min=0, out=output) + output.square_() else: raise ValueError(f"Unsupported FusedMoe activation: {activation}") diff --git a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py index ca6b90d6b31c..92ef850205fc 100644 --- a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py +++ b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py @@ -299,7 +299,12 @@ def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantCon return biased_moe_quant_config( layer.w13_bias, layer.w2_bias, + is_act_and_mul=self.moe.is_act_and_mul, ) + elif not self.moe.is_act_and_mul: + # Create a config with is_act_and_mul=False since + # FUSED_MOE_UNQUANTIZED_CONFIG has is_act_and_mul=True + return FusedMoEQuantConfig.make(is_act_and_mul=False) else: return FUSED_MOE_UNQUANTIZED_CONFIG