17 changes: 17 additions & 0 deletions csrc/activation.cpp
@@ -36,6 +36,14 @@ inline T gelu_quick_kernel(const T& x) {
return (T)(((float)x) / (1.0f + (T)sycl::exp(-1.702f * (float)x)));
}

template <typename T>
inline T relu2_no_mul_kernel(const T& x) {
// square(relu(x))
const float f = (float)x;
const float r = f > 0.0f ? f : 0.0f;
return (T)(r * r);
}

template <typename T>
inline T gelu_kernel(const T& x) {
// Equivalent to PyTorch GELU with 'none' approximation.
@@ -440,6 +448,15 @@ void gelu_quick(
});
}

void relu2_no_mul(
torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., d]
{
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "relu2_no_mul", [&] {
LAUNCH_ACTIVATION_KERNEL(vllm::relu2_no_mul_kernel);
});
}

void swigluoai_and_mul(
torch::Tensor& out, // [..., d]
torch::Tensor& input, // [..., 2 * d]
2 changes: 2 additions & 0 deletions csrc/ops.h
@@ -138,6 +138,8 @@ void swigluoai_and_mul(
double alpha = 1.702,
double limit = 7.0);

void relu2_no_mul(torch::Tensor& out, torch::Tensor& input);

Review comment (Collaborator):
Minor: the vLLM repo doesn't have this CUDA kernel yet. I'd prefer to expose it as torch.ops._xpu_C, since this is an XPU-specific kernel (though it will not be used on the vLLM side yet). Keeping it here is fine.


void swiglustep_and_mul(
torch::Tensor& out, torch::Tensor& input, double limit = 7.0);
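
Following the reviewer's suggestion above, here is a minimal sketch of what an XPU-only namespace could look like if it were registered from Python via torch.library. The `_xpu_C` namespace, the CPU fallback, and the reference implementation are illustrative assumptions; in this repo the real registration would presumably stay in the C++ bindings in torch_bindings.cpp.

```python
import torch
import torch.nn.functional as F

# Hypothetical namespace split: define the schema under "_xpu_C" instead of "_C".
_xpu_lib = torch.library.Library("_xpu_C", "DEF")
_xpu_lib.define("relu2_no_mul(Tensor! out, Tensor! input) -> ()")


def _relu2_no_mul_ref(out: torch.Tensor, input: torch.Tensor) -> None:
    # Reference behavior: out = relu(input) ** 2, written in place into `out`.
    torch.square(F.relu(input), out=out)


# Register the same Python function for XPU and CPU so that
# torch.ops._xpu_C.relu2_no_mul resolves on either device in this sketch.
_xpu_lib.impl("relu2_no_mul", _relu2_no_mul_ref, "XPU")
_xpu_lib.impl("relu2_no_mul", _relu2_no_mul_ref, "CPU")
```

Callers would then invoke torch.ops._xpu_C.relu2_no_mul(out, x) instead of torch.ops._C.relu2_no_mul(out, x).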

4 changes: 4 additions & 0 deletions csrc/torch_bindings.cpp
@@ -107,6 +107,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"-> ()");
ops.impl("swigluoai_and_mul", torch::kXPU, &swigluoai_and_mul);

// relu2_no_mul
ops.def("relu2_no_mul(Tensor! out, Tensor! input) -> ()");
ops.impl("relu2_no_mul", torch::kXPU, &relu2_no_mul);

// swiglustep_and_mul
ops.def(
"swiglustep_and_mul(Tensor! out, Tensor input, float limit=7.0) "
24 changes: 24 additions & 0 deletions tests/ops/activation_op.py
@@ -144,3 +144,27 @@ def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
out = torch.empty_like(x)
self.op(out, x)
return out


class Relu2NoMul(CustomOp):
"""Squared ReLU activation function (without mul).

The function computes x -> relu(x)^2.

Shapes:
x: (num_tokens, d) or (batch_size, seq_len, d)
return: same shape as x
"""

def __init__(self):
super().__init__()
self.op = torch.ops._C.relu2_no_mul

def forward_native(self, x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward()."""
return torch.square(F.relu(x))

def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
out = torch.empty_like(x)
self.op(out, x)
return out
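
Since the diff only adds the module, a small usage sketch showing how the new Relu2NoMul wrapper can be checked against its native reference; the shapes, dtype, and tolerances are illustrative, and this assumes the _C extension is built and an XPU device is available.

```python
import torch

from tests.ops.activation_op import Relu2NoMul

x = torch.randn(16, 128, dtype=torch.bfloat16, device="xpu")
layer = Relu2NoMul()

out_kernel = layer.forward_xpu(x)   # calls torch.ops._C.relu2_no_mul
out_ref = layer.forward_native(x)   # torch.square(F.relu(x))

# The kernel output should match the PyTorch-native reference.
torch.testing.assert_close(out_kernel, out_ref, atol=1e-2, rtol=1e-2)
```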
5 changes: 5 additions & 0 deletions tests/register_ops.py
@@ -283,6 +283,11 @@ def swigluoai_and_mul(
torch.ops._C.swigluoai_and_mul(out, input, alpha, limit)


def relu2_no_mul(out: torch.Tensor, input: torch.Tensor) -> None:
"""Relu2 (squared ReLU) activation function without mul."""
torch.ops._C.relu2_no_mul(out, input)


def swiglustep_and_mul(
out: torch.Tensor,
input: torch.Tensor,
10 changes: 6 additions & 4 deletions tests/test_activation.py
@@ -4,7 +4,7 @@

from tests.allclose_default import get_default_atol, get_default_rtol
from tests.ops.activation_op import (FastGELU, GeluAndMul, MulAndSilu, NewGELU,
QuickGELU, SiluAndMul)
QuickGELU, Relu2NoMul, SiluAndMul)
from tests.utils import opcheck, seed_everything

DTYPES = [torch.half, torch.bfloat16, torch.float]
@@ -66,9 +66,11 @@ def test_act_and_mul(
opcheck(fn, (out, x))


@pytest.mark.parametrize("activation", [(FastGELU, torch.ops._C.gelu_fast),
(NewGELU, torch.ops._C.gelu_new),
(QuickGELU, torch.ops._C.gelu_quick)])
@pytest.mark.parametrize("activation",
[(FastGELU, torch.ops._C.gelu_fast),
(NewGELU, torch.ops._C.gelu_new),
(QuickGELU, torch.ops._C.gelu_quick),
(Relu2NoMul, torch.ops._C.relu2_no_mul)])
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
7 changes: 5 additions & 2 deletions vllm_xpu_kernels/fused_moe_interface.py
@@ -246,8 +246,9 @@ def xpu_fused_moe(hidden_states,
is_B_int4=is_int4,
is_B_mxfp4=is_mxfp4)

inter_size_scale = 2 if activation == "relu2_no_mul" else 1
# act
act_output = torch.empty((num_moe_inputs, inter_size),
act_output = torch.empty((num_moe_inputs, inter_size * inter_size_scale),
dtype=gemm1_output.dtype,
device=gemm1_output.device)
if activation == "silu":
@@ -256,6 +257,8 @@
torch.ops._C.gelu_and_mul(act_output, gemm1_output)
elif activation == "swigluoai" or ("SWIGLUOAI" in str(activation)):
torch.ops._C.swigluoai_and_mul(act_output, gemm1_output, 1.702, 7.0)
elif activation == "relu2_no_mul":
torch.ops._C.relu2_no_mul(act_output, gemm1_output)
elif activation == "swiglustep":
torch.ops._C.swiglustep_and_mul(act_output, gemm1_output, 7.0)
else:
@@ -276,7 +279,7 @@
ptr_D=gemm2_output,
expert_first_token_offset=expert_first_token_offset,
N=hidden_size,
K=inter_size,
K=inter_size * inter_size_scale,
num_experts=num_experts,
Review comment (Copilot AI, Mar 29, 2026) on lines +282 to 283:

K is derived from inter_size * inter_size_scale, but cutlass_grouped_gemm_interface does not validate tensor shapes and will trust the provided K. For relu2_no_mul, this makes correctness/safety depend on w2 actually having the matching K-dimension (2*inter_size for non-int4 layouts). Please add an explicit shape assertion for w2 (and possibly w13) for this activation to prevent out-of-bounds reads if a caller passes incompatible weights.

is_B_int4=is_int4,
is_B_mxfp4=is_mxfp4)
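
Picking up the review comment above, one way the suggested guard could look. This is a sketch only: the argument names `w13` and `w2` and the `[num_experts, N, K]` layout for non-int4 weights are assumptions, since the weight tensors themselves are not shown in this diff.

```python
import torch


def _check_relu2_no_mul_shapes(w13: torch.Tensor, w2: torch.Tensor,
                               inter_size: int, hidden_size: int,
                               num_experts: int) -> None:
    # relu2_no_mul is not a gated activation, so the second grouped GEMM
    # consumes the full gemm1 width: K = 2 * inter_size instead of inter_size.
    expected_k = 2 * inter_size
    assert w13.shape == (num_experts, 2 * inter_size, hidden_size), (
        f"w13 shape {tuple(w13.shape)} != "
        f"({num_experts}, {2 * inter_size}, {hidden_size})")
    assert w2.shape == (num_experts, hidden_size, expected_k), (
        f"w2 shape {tuple(w2.shape)} != "
        f"({num_experts}, {hidden_size}, {expected_k}); gemm2 would read out "
        f"of bounds with K={expected_k}")
```

Calling something like this before the second grouped GEMM when activation == "relu2_no_mul" would fail fast instead of letting cutlass_grouped_gemm_interface trust a mismatched K.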