Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 82 additions & 28 deletions csrc/activation_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -11,29 +11,74 @@
namespace vllm {

template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&),
bool act_first>
bool act_first, bool HAS_CLAMP>
__device__ __forceinline__ scalar_t compute(const scalar_t& x,
const scalar_t& y) {
return act_first ? ACT_FN(x) * y : x * ACT_FN(y);
const scalar_t& y,
const float limit) {
if constexpr (act_first) {
scalar_t gate = ACT_FN(x);
scalar_t up = y;
if constexpr (HAS_CLAMP) {
gate = (scalar_t)fminf((float)gate, limit);
up = (scalar_t)fmaxf(fminf((float)up, limit), -limit);
}
return gate * up;
} else {
scalar_t gate = x;
scalar_t up = ACT_FN(y);
if constexpr (HAS_CLAMP) {
gate = (scalar_t)fmaxf(fminf((float)gate, limit), -limit);
up = (scalar_t)fminf((float)up, limit);
}
return gate * up;
}
}

template <typename packed_t, packed_t (*PACKED_ACT_FN)(const packed_t&),
bool act_first>
bool act_first, bool HAS_CLAMP>
__device__ __forceinline__ packed_t packed_compute(const packed_t& x,
const packed_t& y) {
return act_first ? packed_mul(PACKED_ACT_FN(x), y)
: packed_mul(x, PACKED_ACT_FN(y));
const packed_t& y,
const float limit) {
if constexpr (act_first) {
packed_t gate = PACKED_ACT_FN(x);
packed_t up = y;
if constexpr (HAS_CLAMP) {
float2 g = cast_to_float2(gate);
float2 u = cast_to_float2(up);
g.x = fminf(g.x, limit);
g.y = fminf(g.y, limit);
u.x = fmaxf(fminf(u.x, limit), -limit);
u.y = fmaxf(fminf(u.y, limit), -limit);
gate = cast_to_packed<packed_t>(g);
up = cast_to_packed<packed_t>(u);
}
return packed_mul(gate, up);
} else {
packed_t gate = x;
packed_t up = PACKED_ACT_FN(y);
if constexpr (HAS_CLAMP) {
float2 g = cast_to_float2(gate);
float2 u = cast_to_float2(up);
g.x = fmaxf(fminf(g.x, limit), -limit);
g.y = fmaxf(fminf(g.y, limit), -limit);
u.x = fminf(u.x, limit);
u.y = fminf(u.y, limit);
gate = cast_to_packed<packed_t>(g);
up = cast_to_packed<packed_t>(u);
}
return packed_mul(gate, up);
}
}

// Activation and gating kernel template.
template <typename scalar_t, typename packed_t,
scalar_t (*ACT_FN)(const scalar_t&),
packed_t (*PACKED_ACT_FN)(const packed_t&), bool act_first,
bool use_vec, bool use_256b = false>
bool use_vec, bool HAS_CLAMP, bool use_256b = false>
__global__ void act_and_mul_kernel(
scalar_t* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., 2, d]
const int d) {
const int d, const float limit) {
const scalar_t* x_ptr = input + blockIdx.x * 2 * d;
const scalar_t* y_ptr = x_ptr + d;
scalar_t* out_ptr = out + blockIdx.x * d;
Expand All @@ -58,8 +103,9 @@ __global__ void act_and_mul_kernel(
}
#pragma unroll
for (int j = 0; j < pvec_t::NUM_ELTS; j++) {
x.elts[j] = packed_compute<packed_t, PACKED_ACT_FN, act_first>(
x.elts[j], y.elts[j]);
x.elts[j] =
packed_compute<packed_t, PACKED_ACT_FN, act_first, HAS_CLAMP>(
x.elts[j], y.elts[j], limit);
}
if constexpr (use_256b) {
st256(x, &out_vec[i]);
Expand All @@ -72,7 +118,8 @@ __global__ void act_and_mul_kernel(
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
const scalar_t x = VLLM_LDG(&x_ptr[idx]);
const scalar_t y = VLLM_LDG(&y_ptr[idx]);
out_ptr[idx] = compute<scalar_t, ACT_FN, act_first>(x, y);
out_ptr[idx] =
compute<scalar_t, ACT_FN, act_first, HAS_CLAMP>(x, y, limit);
}
}
}
Expand Down Expand Up @@ -151,8 +198,10 @@ packed_gelu_tanh_kernel(const packed_t& val) {

// Launch activation and gating kernel.
// Use ACT_FIRST (bool) indicating whether to apply the activation function
// first.
#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL, PACKED_KERNEL, ACT_FIRST) \
// first. HAS_CLAMP (bool) enables clamping of gate (post-activation, max only)
// and up (both sides) by LIMIT.
#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL, PACKED_KERNEL, ACT_FIRST, \
HAS_CLAMP, LIMIT) \
auto dtype = input.scalar_type(); \
int d = input.size(-1) / 2; \
int64_t num_tokens = input.numel() / input.size(-1); \
Expand All @@ -177,17 +226,17 @@ packed_gelu_tanh_kernel(const packed_t& val) {
scalar_t, typename vllm::PackedTypeConverter<scalar_t>::Type, \
KERNEL<scalar_t>, \
PACKED_KERNEL<typename vllm::PackedTypeConverter<scalar_t>::Type>, \
ACT_FIRST, true, true><<<grid, block, 0, stream>>>( \
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d); \
ACT_FIRST, true, HAS_CLAMP, true><<<grid, block, 0, stream>>>( \
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d, LIMIT); \
}); \
} else { \
VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel", [&] { \
vllm::act_and_mul_kernel< \
scalar_t, typename vllm::PackedTypeConverter<scalar_t>::Type, \
KERNEL<scalar_t>, \
PACKED_KERNEL<typename vllm::PackedTypeConverter<scalar_t>::Type>, \
ACT_FIRST, true, false><<<grid, block, 0, stream>>>( \
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d); \
ACT_FIRST, true, HAS_CLAMP, false><<<grid, block, 0, stream>>>( \
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d, LIMIT); \
}); \
} \
} else { \
Expand All @@ -197,16 +246,21 @@ packed_gelu_tanh_kernel(const packed_t& val) {
scalar_t, typename vllm::PackedTypeConverter<scalar_t>::Type, \
KERNEL<scalar_t>, \
PACKED_KERNEL<typename vllm::PackedTypeConverter<scalar_t>::Type>, \
ACT_FIRST, false><<<grid, block, 0, stream>>>( \
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d); \
ACT_FIRST, false, HAS_CLAMP><<<grid, block, 0, stream>>>( \
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d, LIMIT); \
}); \
}

void silu_and_mul(torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., 2 * d]
{
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel, vllm::packed_silu_kernel,
true);
torch::Tensor& input, // [..., 2 * d]
double limit) {
if (limit > 0.0) {
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel, vllm::packed_silu_kernel,
true, true, (float)limit);
} else {
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel, vllm::packed_silu_kernel,
true, false, 0.0f);
}
}

void mul_and_silu(torch::Tensor& out, // [..., d]
Expand All @@ -215,21 +269,21 @@ void mul_and_silu(torch::Tensor& out, // [..., d]
// The difference between mul_and_silu and silu_and_mul is that mul_and_silu
// applies the silu to the latter half of the input.
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel, vllm::packed_silu_kernel,
false);
false, false, 0.0f);
}

void gelu_and_mul(torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., 2 * d]
{
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel, vllm::packed_gelu_kernel,
true);
true, false, 0.0f);
}

void gelu_tanh_and_mul(torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., 2 * d]
{
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel,
vllm::packed_gelu_tanh_kernel, true);
LAUNCH_ACTIVATION_GATE_KERNEL(
vllm::gelu_tanh_kernel, vllm::packed_gelu_tanh_kernel, true, false, 0.0f);
}

namespace vllm {
Expand Down
2 changes: 1 addition & 1 deletion csrc/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
torch::Tensor& cos_sin_cache, bool is_neox,
int64_t rope_dim_offset, bool inverse);

void silu_and_mul(torch::Tensor& out, torch::Tensor& input);
void silu_and_mul(torch::Tensor& out, torch::Tensor& input, double limit = 0.0);

void silu_and_mul_quant(torch::Tensor& out, torch::Tensor& input,
torch::Tensor& scale);
Expand Down
2 changes: 1 addition & 1 deletion csrc/torch_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {

// Activation ops
// Activation function used in SwiGLU.
ops.def("silu_and_mul(Tensor! result, Tensor input) -> ()");
ops.def("silu_and_mul(Tensor! result, Tensor input, float limit=0.0) -> ()");
ops.impl("silu_and_mul", torch::kCUDA, &silu_and_mul);

ops.def(
Expand Down
78 changes: 78 additions & 0 deletions tests/kernels/core/test_activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,10 +112,88 @@ def _get_rtol(output) -> float:
opcheck(fn, (out, x, threshold))
elif activation == "swigluoai_and_mul":
opcheck(fn, (out, x, layer.alpha, layer.limit))
elif activation == "silu_and_mul":
opcheck(fn, (out, x, 0.0))
elif activation != "swiglustep_and_mul":
opcheck(fn, (out, x))


SWIGLU_LIMITS = [3.0, 7.0, 15.0]


@pytest.mark.parametrize("swiglu_limit", SWIGLU_LIMITS)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_silu_and_mul_clamp(
default_vllm_config,
swiglu_limit: float,
num_tokens: int,
d: int,
dtype: torch.dtype,
seed: int,
device: str,
) -> None:
"""SiluAndMul with swiglu_limit: cuda kernel must match native reference."""
set_random_seed(seed)
torch.set_default_device(device)
# Use large values to ensure clamping is exercised.
x = torch.randn(num_tokens, 2 * d, dtype=dtype) * swiglu_limit * 2

layer = SiluAndMul(compile_native=False, swiglu_limit=swiglu_limit)
out = layer(x)
ref_out = layer.forward_native(x)

rtol = {
torch.float16: 2e-3,
torch.bfloat16: 2e-2,
torch.float: 1.3e-6,
}
torch.testing.assert_close(
out, ref_out, atol=get_default_atol(out), rtol=rtol[out.dtype]
)

# Verify clamping is actually being applied: the clamped output should
# differ from the unclamped output when inputs are large.
unclamped_out = SiluAndMul(compile_native=False).forward_native(x)
assert not torch.equal(ref_out.float(), unclamped_out.float()), (
"Input was not large enough to exercise the clamp; increase scale"
)

# Verify gate clamping semantics with a controlled scalar case.
# silu(large_val) >> swiglu_limit, so after clamp(max=limit) * 1.0 == limit.
x_gate = torch.tensor(
[[swiglu_limit * 20.0, 1.0]], dtype=torch.float32, device=device
)
out_gate = SiluAndMul(compile_native=False, swiglu_limit=swiglu_limit)(x_gate)
torch.testing.assert_close(
out_gate,
torch.tensor([[swiglu_limit]], dtype=torch.float32, device=device),
atol=1e-3,
rtol=1e-3,
)

# Verify up clamping semantics: up >> limit gets clamped to limit.
x_up = torch.tensor(
[[1.0, swiglu_limit * 20.0]], dtype=torch.float32, device=device
)
out_up = SiluAndMul(compile_native=False, swiglu_limit=swiglu_limit)(x_up)
silu_1 = torch.nn.functional.silu(torch.tensor(1.0)).item()
torch.testing.assert_close(
out_up,
torch.tensor([[silu_1 * swiglu_limit]], dtype=torch.float32, device=device),
atol=1e-3,
rtol=1e-3,
)

# opcheck
out_buf = torch.empty(x.shape[:-1] + (d,), dtype=dtype, device=device)
opcheck(torch.ops._C.silu_and_mul, (out_buf, x, swiglu_limit))


@pytest.mark.parametrize(
"activation",
[
Expand Down
18 changes: 13 additions & 5 deletions vllm/model_executor/layers/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,24 +127,32 @@ class SiluAndMul(CustomOp):

# --8<-- [end:silu_and_mul]

def __init__(self, *, compile_native: bool = True):
def __init__(
self, *, compile_native: bool = True, swiglu_limit: float | None = None
):
super().__init__(compile_native=compile_native)
self.swiglu_limit = swiglu_limit
if current_platform.is_cuda_alike() or current_platform.is_xpu():
self.op = torch.ops._C.silu_and_mul
elif current_platform.is_cpu():
self._forward_method = self.forward_native

@staticmethod
def forward_native(x: torch.Tensor) -> torch.Tensor:
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward()."""
d = x.shape[-1] // 2
return F.silu(x[..., :d]) * x[..., d:]
gate = F.silu(x[..., :d])
up = x[..., d:]
if self.swiglu_limit is not None:
gate = torch.clamp(gate, max=self.swiglu_limit)
up = torch.clamp(up, min=-self.swiglu_limit, max=self.swiglu_limit)
Comment on lines +145 to +147

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

There is a potential inconsistency between the native and CUDA implementations when swiglu_limit is set to 0.0. In forward_native, any value of swiglu_limit that is not None (including 0.0) will trigger clamping, effectively zeroing out the output. However, in forward_cuda (and the underlying CUDA kernel), clamping is only enabled if limit > 0.0. While swiglu_limit is typically a positive value, it's safer to align the logic to avoid discrepancies.

Suggested change
if self.swiglu_limit is not None:
gate = torch.clamp(gate, max=self.swiglu_limit)
up = torch.clamp(up, min=-self.swiglu_limit, max=self.swiglu_limit)
if self.swiglu_limit is not None and self.swiglu_limit > 0:
gate = torch.clamp(gate, max=self.swiglu_limit)
up = torch.clamp(up, min=-self.swiglu_limit, max=self.swiglu_limit)

return gate * up

def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
d = x.shape[-1] // 2
output_shape = x.shape[:-1] + (d,)
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
self.op(out, x)
limit = self.swiglu_limit if self.swiglu_limit is not None else 0.0
self.op(out, x, limit)
return out

def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
Expand Down
Loading
Loading