107 changes: 82 additions & 25 deletions csrc/activation_kernels.cu
@@ -11,29 +11,74 @@
namespace vllm {

template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&),
-          bool act_first>
+          bool act_first, bool HAS_CLAMP>
__device__ __forceinline__ scalar_t compute(const scalar_t& x,
-                                            const scalar_t& y) {
-  return act_first ? ACT_FN(x) * y : x * ACT_FN(y);
+                                            const scalar_t& y,
+                                            const float limit) {
+  if constexpr (act_first) {
+    scalar_t gate = x;
+    scalar_t up = y;
+    if constexpr (HAS_CLAMP) {
+      gate = (scalar_t)fminf((float)gate, limit);
+      up = (scalar_t)fmaxf(fminf((float)up, limit), -limit);
+    }
+    return ACT_FN(gate) * up;
+  } else {
+    scalar_t gate = x;
+    scalar_t up = y;
+    if constexpr (HAS_CLAMP) {
+      gate = (scalar_t)fmaxf(fminf((float)gate, limit), -limit);
+      up = (scalar_t)fminf((float)up, limit);
+    }
+    return gate * ACT_FN(up);
+  }
Review comment on lines +18 to +34 (Contributor, severity: high):

The implementation of clamping for the activated branch is currently pre-activation (e.g., ACT_FN(fminf(gate, limit))), which contradicts the comment on line 201 stating it should be "post-activation". For SiLU, these operations are not mathematically equivalent. DeepSeek architectures typically apply clamping to the output of the activation function for numerical stability. Please align the implementation with the intended post-activation clamping, or update the comment if pre-activation was indeed intended.
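
A quick PyTorch check (illustrative values, not part of the diff) showing why pre- and post-activation clamping are not interchangeable for SiLU:

import torch
import torch.nn.functional as F

limit = 7.0
gate = torch.tensor([10.0, -12.0, 3.0])

pre = F.silu(gate.clamp(max=limit))   # clamp the input, then activate
post = F.silu(gate).clamp(max=limit)  # activate, then clamp the output

print(pre)   # ~[6.9936, -0.0001, 2.8577]
print(post)  # ~[7.0000, -0.0001, 2.8577]  (differs at the first element)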

}

template <typename packed_t, packed_t (*PACKED_ACT_FN)(const packed_t&),
-          bool act_first>
+          bool act_first, bool HAS_CLAMP>
__device__ __forceinline__ packed_t packed_compute(const packed_t& x,
-                                                   const packed_t& y) {
-  return act_first ? packed_mul(PACKED_ACT_FN(x), y)
-                   : packed_mul(x, PACKED_ACT_FN(y));
+                                                   const packed_t& y,
+                                                   const float limit) {
+  if constexpr (act_first) {
+    packed_t gate = x;
+    packed_t up = y;
+    if constexpr (HAS_CLAMP) {
+      float2 g = cast_to_float2(gate);
+      float2 u = cast_to_float2(up);
+      g.x = fminf(g.x, limit);
+      g.y = fminf(g.y, limit);
+      u.x = fmaxf(fminf(u.x, limit), -limit);
+      u.y = fmaxf(fminf(u.y, limit), -limit);
+      gate = cast_to_packed<packed_t>(g);
+      up = cast_to_packed<packed_t>(u);
+    }
+    return packed_mul(PACKED_ACT_FN(gate), up);
+  } else {
+    packed_t gate = x;
+    packed_t up = y;
+    if constexpr (HAS_CLAMP) {
+      float2 g = cast_to_float2(gate);
+      float2 u = cast_to_float2(up);
+      g.x = fmaxf(fminf(g.x, limit), -limit);
+      g.y = fmaxf(fminf(g.y, limit), -limit);
+      u.x = fminf(u.x, limit);
+      u.y = fminf(u.y, limit);
+      gate = cast_to_packed<packed_t>(g);
+      up = cast_to_packed<packed_t>(u);
+    }
+    return packed_mul(gate, PACKED_ACT_FN(up));
+  }
}

// Activation and gating kernel template.
template <typename scalar_t, typename packed_t,
scalar_t (*ACT_FN)(const scalar_t&),
packed_t (*PACKED_ACT_FN)(const packed_t&), bool act_first,
-          bool use_vec, bool use_256b = false>
+          bool use_vec, bool HAS_CLAMP, bool use_256b = false>
__global__ void act_and_mul_kernel(
scalar_t* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., 2, d]
-    const int d) {
+    const int d, const float limit) {
const scalar_t* x_ptr = input + blockIdx.x * 2 * d;
const scalar_t* y_ptr = x_ptr + d;
scalar_t* out_ptr = out + blockIdx.x * d;
@@ -58,8 +103,9 @@ __global__ void act_and_mul_kernel(
}
#pragma unroll
for (int j = 0; j < pvec_t::NUM_ELTS; j++) {
-      x.elts[j] = packed_compute<packed_t, PACKED_ACT_FN, act_first>(
-          x.elts[j], y.elts[j]);
+      x.elts[j] =
+          packed_compute<packed_t, PACKED_ACT_FN, act_first, HAS_CLAMP>(
+              x.elts[j], y.elts[j], limit);
}
if constexpr (use_256b) {
st256(x, &out_vec[i]);
@@ -72,7 +118,8 @@
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
const scalar_t x = VLLM_LDG(&x_ptr[idx]);
const scalar_t y = VLLM_LDG(&y_ptr[idx]);
-    out_ptr[idx] = compute<scalar_t, ACT_FN, act_first>(x, y);
+    out_ptr[idx] =
+        compute<scalar_t, ACT_FN, act_first, HAS_CLAMP>(x, y, limit);
}
}
}
@@ -151,8 +198,11 @@ packed_gelu_tanh_kernel(const packed_t& val) {

// Launch activation and gating kernel.
// Use ACT_FIRST (bool) indicating whether to apply the activation function
-// first.
-#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL, PACKED_KERNEL, ACT_FIRST) \
+// first. HAS_CLAMP (bool) enables pre-activation clamping: gate input is
+// clamped (max only) and up input is clamped (both sides) before the
+// activation function is applied.
+#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL, PACKED_KERNEL, ACT_FIRST, \
+                                      HAS_CLAMP, LIMIT)                 \
auto dtype = input.scalar_type(); \
int d = input.size(-1) / 2; \
int64_t num_tokens = input.numel() / input.size(-1); \
@@ -177,17 +227,17 @@ packed_gelu_tanh_kernel(const packed_t& val) {
scalar_t, typename vllm::PackedTypeConverter<scalar_t>::Type, \
KERNEL<scalar_t>, \
PACKED_KERNEL<typename vllm::PackedTypeConverter<scalar_t>::Type>, \
-        ACT_FIRST, true, true><<<grid, block, 0, stream>>>( \
-        out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d); \
+        ACT_FIRST, true, HAS_CLAMP, true><<<grid, block, 0, stream>>>( \
+        out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d, LIMIT); \
}); \
} else { \
VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel", [&] { \
vllm::act_and_mul_kernel< \
scalar_t, typename vllm::PackedTypeConverter<scalar_t>::Type, \
KERNEL<scalar_t>, \
PACKED_KERNEL<typename vllm::PackedTypeConverter<scalar_t>::Type>, \
-        ACT_FIRST, true, false><<<grid, block, 0, stream>>>( \
-        out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d); \
+        ACT_FIRST, true, HAS_CLAMP, false><<<grid, block, 0, stream>>>( \
+        out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d, LIMIT); \
}); \
} \
} else { \
@@ -197,16 +247,23 @@ packed_gelu_tanh_kernel(const packed_t& val) {
scalar_t, typename vllm::PackedTypeConverter<scalar_t>::Type, \
KERNEL<scalar_t>, \
PACKED_KERNEL<typename vllm::PackedTypeConverter<scalar_t>::Type>, \
-        ACT_FIRST, false><<<grid, block, 0, stream>>>( \
-        out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d); \
+        ACT_FIRST, false, HAS_CLAMP><<<grid, block, 0, stream>>>( \
+        out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d, LIMIT); \
}); \
}
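
As a cross-check of the HAS_CLAMP semantics described in the macro comment above, a minimal PyTorch reference for the activation-first (SiLU) path; the function name is illustrative, not part of the diff:

import torch
import torch.nn.functional as F

def silu_and_mul_clamp_ref(x: torch.Tensor, limit: float) -> torch.Tensor:
    # Mirrors the HAS_CLAMP + ACT_FIRST kernel path: the gate half is
    # clamped from above only, the up half on both sides, then
    # silu(gate) * up.
    d = x.shape[-1] // 2
    gate = x[..., :d].clamp(max=limit)
    up = x[..., d:].clamp(min=-limit, max=limit)
    return F.silu(gate) * up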

void silu_and_mul(torch::Tensor& out,    // [..., d]
                  torch::Tensor& input)  // [..., 2 * d]
{
  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel, vllm::packed_silu_kernel,
-                                true);
+                                true, false, 0.0f);
}

void silu_and_mul_clamp(torch::Tensor& out,    // [..., d]
                        torch::Tensor& input,  // [..., 2 * d]
                        double limit) {
  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel, vllm::packed_silu_kernel,
                                true, true, (float)limit);
}

void mul_and_silu(torch::Tensor& out, // [..., d]
@@ -215,21 +272,21 @@ void mul_and_silu(torch::Tensor& out,  // [..., d]
// The difference between mul_and_silu and silu_and_mul is that mul_and_silu
// applies the silu to the latter half of the input.
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel, vllm::packed_silu_kernel,
-                                false);
+                                false, false, 0.0f);
}
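
A minimal reference sketch of the distinction described in the comment above (illustrative only, not the kernel code):

import torch
import torch.nn.functional as F

x = torch.randn(4, 2 * 8)
d = x.shape[-1] // 2
silu_and_mul_ref = F.silu(x[..., :d]) * x[..., d:]  # activation on first half
mul_and_silu_ref = x[..., :d] * F.silu(x[..., d:])  # activation on second half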

void gelu_and_mul(torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., 2 * d]
{
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel, vllm::packed_gelu_kernel,
-                                true);
+                                true, false, 0.0f);
}

void gelu_tanh_and_mul(torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., 2 * d]
{
-  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel,
-                                vllm::packed_gelu_tanh_kernel, true);
+  LAUNCH_ACTIVATION_GATE_KERNEL(
+      vllm::gelu_tanh_kernel, vllm::packed_gelu_tanh_kernel, true, false, 0.0f);
}

namespace vllm {
2 changes: 2 additions & 0 deletions csrc/ops.h
@@ -163,6 +163,8 @@ void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,

void silu_and_mul(torch::Tensor& out, torch::Tensor& input);

void silu_and_mul_clamp(torch::Tensor& out, torch::Tensor& input, double limit);

void silu_and_mul_quant(torch::Tensor& out, torch::Tensor& input,
torch::Tensor& scale);

6 changes: 6 additions & 0 deletions csrc/torch_bindings.cpp
@@ -106,6 +106,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("silu_and_mul(Tensor! result, Tensor input) -> ()");
ops.impl("silu_and_mul", torch::kCUDA, &silu_and_mul);

// SwiGLU activation with input clamping.
ops.def(
"silu_and_mul_with_clamp(Tensor! result, Tensor input, float limit) "
"-> ()");
ops.impl("silu_and_mul_with_clamp", torch::kCUDA, &silu_and_mul_clamp);

ops.def(
"silu_and_mul_quant(Tensor! result, Tensor input, Tensor scale) -> ()");
ops.impl("silu_and_mul_quant", torch::kCUDA, &silu_and_mul_quant);
80 changes: 80 additions & 0 deletions tests/kernels/core/test_activation.py
@@ -16,6 +16,7 @@
NewGELU,
QuickGELU,
SiluAndMul,
SiluAndMulWithClamp,
SwigluOAIAndMul,
SwigluStepAndMul,
swiglustep_and_mul_triton,
@@ -116,6 +117,85 @@ def _get_rtol(output) -> float:
opcheck(fn, (out, x))


SWIGLU_LIMITS = [3.0, 7.0, 15.0]


@pytest.mark.parametrize("swiglu_limit", SWIGLU_LIMITS)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_silu_and_mul_with_clamp(
    default_vllm_config,
    swiglu_limit: float,
    num_tokens: int,
    d: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    """SiluAndMulWithClamp: cuda kernel must match native reference."""
    set_random_seed(seed)
    torch.set_default_device(device)
    # Use large values to ensure clamping is exercised.
    x = torch.randn(num_tokens, 2 * d, dtype=dtype) * swiglu_limit * 2

    layer = SiluAndMulWithClamp(swiglu_limit, compile_native=False)
    out = layer(x)
    ref_out = layer.forward_native(x)

    rtol = {
        torch.float16: 2e-3,
        torch.bfloat16: 2e-2,
        torch.float: 1.3e-6,
    }
    torch.testing.assert_close(
        out, ref_out, atol=get_default_atol(out), rtol=rtol[out.dtype]
    )

    # Verify clamping is actually being applied: the clamped output should
    # differ from the unclamped SiluAndMul output when inputs are large.
    unclamped_out = SiluAndMul.forward_native(x)
    assert not torch.equal(ref_out.float(), unclamped_out.float()), (
        "Input was not large enough to exercise the clamp; increase scale"
    )

    # Verify gate clamping semantics with a controlled scalar case.
    # gate=large_val is clamped to limit first, then silu(limit) * 1.0.
    x_gate = torch.tensor(
        [[swiglu_limit * 20.0, 1.0]], dtype=torch.float32, device=device
    )
    out_gate = SiluAndMulWithClamp(swiglu_limit, compile_native=False)(x_gate)
    expected_gate = torch.nn.functional.silu(
        torch.tensor(swiglu_limit, dtype=torch.float32)
    ).item()
    torch.testing.assert_close(
        out_gate,
        torch.tensor([[expected_gate]], dtype=torch.float32, device=device),
        atol=1e-3,
        rtol=1e-3,
    )

    # Verify up clamping semantics: up >> limit gets clamped to limit.
    x_up = torch.tensor(
        [[1.0, swiglu_limit * 20.0]], dtype=torch.float32, device=device
    )
    out_up = SiluAndMulWithClamp(swiglu_limit, compile_native=False)(x_up)
    silu_1 = torch.nn.functional.silu(torch.tensor(1.0)).item()
    torch.testing.assert_close(
        out_up,
        torch.tensor([[silu_1 * swiglu_limit]], dtype=torch.float32, device=device),
        atol=1e-3,
        rtol=1e-3,
    )

    # opcheck
    out_buf = torch.empty(x.shape[:-1] + (d,), dtype=dtype, device=device)
    opcheck(torch.ops._C.silu_and_mul_with_clamp, (out_buf, x, swiglu_limit))


@pytest.mark.parametrize(
"activation",
[
40 changes: 40 additions & 0 deletions vllm/model_executor/layers/activation.py
@@ -151,6 +151,46 @@ def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
return self.forward_cuda(x)


@CustomOp.register("silu_and_mul_with_clamp")
class SiluAndMulWithClamp(CustomOp):
"""SwiGLU activation with input clamping (used by some MoE shared experts).

Computes:
gate = clamp(x[..., :d], max=swiglu_limit)
up = clamp(x[..., d:], min=-swiglu_limit, max=swiglu_limit)
out = silu(gate) * up
where d = x.shape[-1] // 2.

Shapes:
x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
return: (num_tokens, d) or (batch_size, seq_len, d)
"""

def __init__(self, swiglu_limit: float, *, compile_native: bool = True):
super().__init__(compile_native=compile_native)
self.swiglu_limit = float(swiglu_limit)
if current_platform.is_cuda_alike() or current_platform.is_xpu():
self.op = torch.ops._C.silu_and_mul_with_clamp
elif current_platform.is_cpu():
self._forward_method = self.forward_native

def forward_native(self, x: torch.Tensor) -> torch.Tensor:
d = x.shape[-1] // 2
gate = torch.clamp(x[..., :d], max=self.swiglu_limit)
up = torch.clamp(x[..., d:], min=-self.swiglu_limit, max=self.swiglu_limit)
return F.silu(gate) * up

def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
d = x.shape[-1] // 2
output_shape = x.shape[:-1] + (d,)
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
self.op(out, x, self.swiglu_limit)
return out

def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
return self.forward_cuda(x)
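
For context, a hypothetical usage sketch of the new layer (shapes and limit are illustrative):

layer = SiluAndMulWithClamp(swiglu_limit=7.0)
x = torch.randn(4, 2 * 128, device="cuda", dtype=torch.float16)
out = layer(x)  # shape (4, 128): silu(clamped gate) * clamped up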


# --8<-- [start:mul_and_silu]
@CustomOp.register("mul_and_silu")
class MulAndSilu(CustomOp):
2 changes: 1 addition & 1 deletion vllm/model_executor/layers/fused_moe/cpu_fused_moe.py
@@ -45,7 +45,7 @@ def _gelu_and_mul(
# Uses static methods or standalone functions to avoid instantiating CustomOp
# classes, which would call get_current_vllm_config() before config is set.
_CPU_MOE_ACT_FN: dict[MoEActivation, Callable[[torch.Tensor], torch.Tensor]] = {
-    MoEActivation.SILU: SiluAndMul.forward_native,
+    MoEActivation.SILU: lambda x: SiluAndMul(compile_native=False).forward_native(x),
    MoEActivation.SWIGLUOAI: _swigluoai_forward_native,
    MoEActivation.GELU: _gelu_and_mul,
}
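
A short illustrative sketch of why the lambda above helps: the CustomOp is now constructed lazily at call time, after the vLLM config is available, rather than at module import time (per the comment above about get_current_vllm_config()):

act = _CPU_MOE_ACT_FN[MoEActivation.SILU]
# SiluAndMul is instantiated here, on first use, not when this module
# is imported -- so config is already set by the time it is constructed.
y = act(x)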