diff --git a/csrc/activation.cpp b/csrc/activation.cpp
index 33c4f28ca..9bef17807 100644
--- a/csrc/activation.cpp
+++ b/csrc/activation.cpp
@@ -36,6 +36,14 @@ inline T gelu_quick_kernel(const T& x) {
   return (T)(((float)x) / (1.0f + (T)sycl::exp(-1.702f * (float)x)));
 }
 
+template <typename T>
+inline T relu2_no_mul_kernel(const T& x) {
+  // square(relu(x))
+  const float f = (float)x;
+  const float r = f > 0.0f ? f : 0.0f;
+  return (T)(r * r);
+}
+
 template <typename T>
 inline T gelu_kernel(const T& x) {
   // Equivalent to PyTorch GELU with 'none' approximation.
@@ -440,6 +448,15 @@ void gelu_quick(
   });
 }
 
+void relu2_no_mul(
+    torch::Tensor& out,    // [..., d]
+    torch::Tensor& input)  // [..., d]
+{
+  VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "relu2_no_mul", [&] {
+    LAUNCH_ACTIVATION_KERNEL(vllm::relu2_no_mul_kernel);
+  });
+}
+
 void swigluoai_and_mul(
     torch::Tensor& out,    // [..., d]
     torch::Tensor& input,  // [..., 2 * d]
diff --git a/csrc/ops.h b/csrc/ops.h
index 1e0416671..25d38e67d 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -138,6 +138,8 @@ void swigluoai_and_mul(
     double alpha = 1.702,
     double limit = 7.0);
 
+void relu2_no_mul(torch::Tensor& out, torch::Tensor& input);
+
 void swiglustep_and_mul(
     torch::Tensor& out, torch::Tensor& input, double limit = 7.0);
 
diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp
index 2a5727829..04e3480fe 100644
--- a/csrc/torch_bindings.cpp
+++ b/csrc/torch_bindings.cpp
@@ -107,6 +107,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "-> ()");
   ops.impl("swigluoai_and_mul", torch::kXPU, &swigluoai_and_mul);
 
+  // relu2_no_mul
+  ops.def("relu2_no_mul(Tensor! out, Tensor input) -> ()");
+  ops.impl("relu2_no_mul", torch::kXPU, &relu2_no_mul);
+
   // swiglustep_and_mul
   ops.def(
       "swiglustep_and_mul(Tensor! out, Tensor input, float limit=7.0) "
diff --git a/tests/ops/activation_op.py b/tests/ops/activation_op.py
index e691de64c..02b8cf8fd 100644
--- a/tests/ops/activation_op.py
+++ b/tests/ops/activation_op.py
@@ -144,3 +144,27 @@ def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
         out = torch.empty_like(x)
         self.op(out, x)
         return out
+
+
+class Relu2NoMul(CustomOp):
+    """Squared ReLU activation function (without mul).
+
+    The function computes x -> relu(x)^2.
+
+    Shapes:
+        x: (num_tokens, d) or (batch_size, seq_len, d)
+        return: same shape as x
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.op = torch.ops._C.relu2_no_mul
+
+    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
+        """PyTorch-native implementation equivalent to forward()."""
+        return torch.square(F.relu(x))
+
+    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
+        out = torch.empty_like(x)
+        self.op(out, x)
+        return out
diff --git a/tests/register_ops.py b/tests/register_ops.py
index 5293f7731..0c965d441 100644
--- a/tests/register_ops.py
+++ b/tests/register_ops.py
@@ -283,6 +283,11 @@ def swigluoai_and_mul(
     torch.ops._C.swigluoai_and_mul(out, input, alpha, limit)
 
 
+def relu2_no_mul(out: torch.Tensor, input: torch.Tensor) -> None:
+    """Relu2 (squared ReLU) activation function without mul."""
+    torch.ops._C.relu2_no_mul(out, input)
+
+
 def swiglustep_and_mul(
     out: torch.Tensor,
     input: torch.Tensor,
diff --git a/tests/test_activation.py b/tests/test_activation.py
index 5bbb81611..a2fffc397 100644
--- a/tests/test_activation.py
+++ b/tests/test_activation.py
@@ -4,7 +4,7 @@
 from tests.allclose_default import get_default_atol, get_default_rtol
 from tests.ops.activation_op import (FastGELU, GeluAndMul, MulAndSilu, NewGELU,
-                                     QuickGELU, SiluAndMul)
+                                     QuickGELU, Relu2NoMul, SiluAndMul)
 from tests.utils import opcheck, seed_everything
 
 DTYPES = [torch.half, torch.bfloat16, torch.float]
@@ -66,9 +66,11 @@ def test_act_and_mul(
     opcheck(fn, (out, x))
 
 
-@pytest.mark.parametrize("activation", [(FastGELU, torch.ops._C.gelu_fast),
-                                        (NewGELU, torch.ops._C.gelu_new),
-                                        (QuickGELU, torch.ops._C.gelu_quick)])
+@pytest.mark.parametrize("activation",
+                         [(FastGELU, torch.ops._C.gelu_fast),
+                          (NewGELU, torch.ops._C.gelu_new),
+                          (QuickGELU, torch.ops._C.gelu_quick),
+                          (Relu2NoMul, torch.ops._C.relu2_no_mul)])
 @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
 @pytest.mark.parametrize("d", D)
 @pytest.mark.parametrize("dtype", DTYPES)
diff --git a/vllm_xpu_kernels/fused_moe_interface.py b/vllm_xpu_kernels/fused_moe_interface.py
index 7ba0a47d4..5c807afa4 100755
--- a/vllm_xpu_kernels/fused_moe_interface.py
+++ b/vllm_xpu_kernels/fused_moe_interface.py
@@ -246,8 +246,9 @@ def xpu_fused_moe(hidden_states,
                             is_B_int4=is_int4,
                             is_B_mxfp4=is_mxfp4)
 
+    inter_size_scale = 2 if activation == "relu2_no_mul" else 1
     # act
-    act_output = torch.empty((num_moe_inputs, inter_size),
+    act_output = torch.empty((num_moe_inputs, inter_size * inter_size_scale),
                              dtype=gemm1_output.dtype,
                              device=gemm1_output.device)
     if activation == "silu":
@@ -256,6 +257,8 @@ def xpu_fused_moe(hidden_states,
         torch.ops._C.gelu_and_mul(act_output, gemm1_output)
     elif activation == "swigluoai" or ("SWIGLUOAI" in str(activation)):
         torch.ops._C.swigluoai_and_mul(act_output, gemm1_output, 1.702, 7.0)
+    elif activation == "relu2_no_mul":
+        torch.ops._C.relu2_no_mul(act_output, gemm1_output)
     elif activation == "swiglustep":
         torch.ops._C.swiglustep_and_mul(act_output, gemm1_output, 7.0)
     else:
@@ -276,7 +279,7 @@ def xpu_fused_moe(hidden_states,
                           ptr_D=gemm2_output,
                           expert_first_token_offset=expert_first_token_offset,
                           N=hidden_size,
-                          K=inter_size,
+                          K=inter_size * inter_size_scale,
                           num_experts=num_experts,
                           is_B_int4=is_int4,
                           is_B_mxfp4=is_mxfp4)
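
Below is a minimal usage sketch (not part of the patch) showing how the new op is exercised end to end, mirroring the pattern in tests/test_activation.py: the custom XPU kernel writes into a caller-allocated output tensor, and the result is checked against a plain-PyTorch squared ReLU. It assumes the _C extension has been built and its ops registered (as the test suite does via tests/register_ops.py) and that an XPU device is available; the shape, dtype, and tolerances are illustrative choices, not values taken from this change.

import torch
import torch.nn.functional as F

# Illustrative shape/dtype; any [..., d] floating-point tensor on the XPU device works.
x = torch.randn(83, 1024, dtype=torch.bfloat16, device="xpu")

# The custom op is out-of-place over a caller-allocated buffer of the same shape.
out = torch.empty_like(x)
torch.ops._C.relu2_no_mul(out, x)

# Reference: relu(x)^2 computed with stock PyTorch ops, as in Relu2NoMul.forward_native.
ref = torch.square(F.relu(x))
torch.testing.assert_close(out, ref, atol=1e-2, rtol=1e-2)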