Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 96 additions & 0 deletions python/sglang/jit_kernel/activation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Optional

import torch

from sglang.jit_kernel.utils import cache_once, load_jit, make_cpp_args

if TYPE_CHECKING:
from tvm_ffi.module import Module

# True when this torch build targets ROCm/HIP; used below to dispatch to the
# precompiled sgl_kernel ops instead of the CUDA JIT path.
_IS_ROCM: bool = torch.version.hip is not None


@cache_once
def _jit_activation_module(dtype: torch.dtype, act_type: str) -> Module:
    """Build (once per (dtype, act_type), via @cache_once) the JIT activation module.

    Args:
        dtype: Element dtype the C++ template is instantiated for.
        act_type: One of "silu_and_mul", "gelu_and_mul", "gelu_tanh_and_mul",
            "gelu_quick" — must match a template name exported by
            csrc/elementwise/activation.cuh.

    Returns:
        A compiled module exposing ``act_type`` as a callable wrapper around
        the C++ ``run`` entry point.
    """
    # act_type: "silu_and_mul", "gelu_and_mul", "gelu_tanh_and_mul", "gelu_quick"
    args = make_cpp_args(dtype)
    # Instantiate the templated struct for this dtype, e.g. "silu_and_mul<...>".
    kernel_name = f"{act_type}<{args}>"

    # NOTE(review): `args` is formatted into the template argument list above
    # AND splatted into load_jit below — presumably make_cpp_args returns a
    # sequence whose str() renders the C++ type list; confirm against utils.
    return load_jit(
        f"activation_{act_type}",
        *args,
        cuda_files=["elementwise/activation.cuh"],
        cuda_wrappers=[(act_type, f"{kernel_name}::run")],
    )


def silu_and_mul(
    input: torch.Tensor, out: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """Fused SiLU-and-multiply: ``silu(input[..., :d]) * input[..., d:]``.

    Args:
        input: Tensor whose last dimension (size ``2*d``) holds the activation
            half followed by the multiplier half.
        out: Optional preallocated output of shape ``input.shape[:-1] + (d,)``;
            allocated on input's device/dtype when omitted.

    Returns:
        The output tensor (``out`` itself when one was provided).
    """
    if out is None:
        half_dim = input.shape[-1] // 2
        out = torch.empty(
            (*input.shape[:-1], half_dim), device=input.device, dtype=input.dtype
        )
    if _IS_ROCM:
        # ROCm builds use the precompiled op rather than the CUDA JIT path.
        import sgl_kernel

        sgl_kernel.silu_and_mul(out, input)
        return out
    _jit_activation_module(input.dtype, "silu_and_mul").silu_and_mul(out, input)
    return out


def gelu_and_mul(
    input: torch.Tensor, out: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """Fused GELU-and-multiply (erf-based): ``gelu(input[..., :d]) * input[..., d:]``.

    Args:
        input: Tensor whose last dimension (size ``2*d``) holds the activation
            half followed by the multiplier half.
        out: Optional preallocated output of shape ``input.shape[:-1] + (d,)``;
            allocated on input's device/dtype when omitted.

    Returns:
        The output tensor (``out`` itself when one was provided).
    """
    if out is None:
        half_dim = input.shape[-1] // 2
        out = torch.empty(
            (*input.shape[:-1], half_dim), device=input.device, dtype=input.dtype
        )
    if _IS_ROCM:
        # ROCm builds use the precompiled op rather than the CUDA JIT path.
        import sgl_kernel

        sgl_kernel.gelu_and_mul(out, input)
        return out
    _jit_activation_module(input.dtype, "gelu_and_mul").gelu_and_mul(out, input)
    return out


def gelu_tanh_and_mul(
    input: torch.Tensor, out: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """Fused tanh-approximated GELU-and-multiply over the two input halves.

    Computes ``gelu_tanh(input[..., :d]) * input[..., d:]``.

    Args:
        input: Tensor whose last dimension (size ``2*d``) holds the activation
            half followed by the multiplier half.
        out: Optional preallocated output of shape ``input.shape[:-1] + (d,)``;
            allocated on input's device/dtype when omitted.

    Returns:
        The output tensor (``out`` itself when one was provided).
    """
    if out is None:
        half_dim = input.shape[-1] // 2
        out = torch.empty(
            (*input.shape[:-1], half_dim), device=input.device, dtype=input.dtype
        )
    if _IS_ROCM:
        # ROCm builds use the precompiled op rather than the CUDA JIT path.
        import sgl_kernel

        sgl_kernel.gelu_tanh_and_mul(out, input)
        return out
    _jit_activation_module(input.dtype, "gelu_tanh_and_mul").gelu_tanh_and_mul(
        out, input
    )
    return out


def gelu_quick(input: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Elementwise quick-GELU: ``x * sigmoid(1.702 * x)`` (no multiplier half).

    Args:
        input: Input tensor; output has the same shape.
        out: Optional preallocated output matching input's shape/device/dtype;
            allocated when omitted.

    Returns:
        The output tensor (``out`` itself when one was provided).
    """
    if out is None:
        out = torch.empty_like(input)
    if _IS_ROCM:
        # ROCm builds use the precompiled op rather than the CUDA JIT path.
        import sgl_kernel

        sgl_kernel.gelu_quick(out, input)
        return out
    _jit_activation_module(input.dtype, "gelu_quick").gelu_quick(out, input)
    return out
181 changes: 181 additions & 0 deletions python/sglang/jit_kernel/csrc/elementwise/activation.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
#pragma once

#include <sgl_kernel/tensor.h>
#include <sgl_kernel/utils.h>

#include <sgl_kernel/type.cuh>
#include <sgl_kernel/vec.cuh>

#include <tvm/ffi/container/tensor.h>

// Width of one vectorized load/store per thread: 128 bits (16 bytes).
#define kBitsToLoad 128
#define kBytesToLoad (kBitsToLoad / 8)

// SiLU (swish): x * sigmoid(x), computed in fp32 and cast back to T.
template <typename T>
SGL_DEVICE T silu(const T& x) {
  float f32_val = device::cast<float>(x);
  // x * sigmoid(x) rewritten as x / (1 + e^-x).
  return device::cast<T>(f32_val / (1.0f + expf(-f32_val)));
}

// Exact (erf-based) GELU: x * 0.5 * (1 + erf(x / sqrt(2))), computed in fp32.
template <typename T>
SGL_DEVICE T gelu(const T& x) {
  constexpr float kAlpha = M_SQRT1_2;  // 1 / sqrt(2)
  float f32_val = device::cast<float>(x);
  return device::cast<T>(f32_val * (0.5f * (1.0f + erf(f32_val * kAlpha))));
}

// gelu_quick(x) = x * torch.sigmoid(1.702 * x)
// Implemented as x / (1 + e^(-1.702 x)), which is algebraically identical.
template <typename T>
SGL_DEVICE T gelu_quick_act(const T& x) {
  float f32_val = device::cast<float>(x);
  return device::cast<T>(f32_val / (1.0f + expf(-f32_val * 1.702f)));
}

// Tanh-approximated GELU: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 x^3))),
// computed in fp32 and cast back to T.
template <typename T>
SGL_DEVICE T gelu_tanh(const T& x) {
  constexpr float kAlpha = 0.044715f;
  constexpr float kBeta = 0.7978845608028654f;  // sqrt(2 / pi)
  float f32_val = device::cast<float>(x);
  const float cdf = 0.5f * (1.0f + tanhf((kBeta * (f32_val + kAlpha * f32_val * f32_val * f32_val))));
  return device::cast<T>(f32_val * cdf);
}

// Fused "Activation(x) * y" kernel. Each block handles one token whose input
// row stores [x | y], each half of length d; the product is written to the
// d-wide output row.
template <typename T, T (*Activation)(const T&)>
__global__ void
act_and_mul_kernel(T* __restrict__ out_ptr, const T* __restrict__ input_ptr, int64_t d, int64_t num_tokens) {
  // Number of T elements per 128-bit vectorized load/store.
  constexpr uint32_t vec_size = kBytesToLoad / sizeof(T);
  const int64_t token_idx = blockIdx.x;
  const int64_t thread_idx = threadIdx.x;
  const int64_t stride = blockDim.x;
  const int64_t offset = token_idx * 2 * d;

  if (token_idx >= num_tokens) return;

  // Vectorized main loop over the full vec_size-wide chunks of the row.
  // NOTE(review): the y-half load at `offset + d + ...` is 128-bit aligned
  // only when d is a multiple of vec_size — presumably guaranteed by callers
  // (hidden dims are typically multiples of 8); confirm.
#pragma unroll 1
  for (uint32_t idx = thread_idx; idx < d / vec_size; idx += stride) {
    device::AlignedVector<T, vec_size> x_vec, y_vec, out_vec;
    x_vec.load(input_ptr + offset + idx * vec_size);
    y_vec.load(input_ptr + offset + d + idx * vec_size);
#pragma unroll
    for (uint32_t i = 0; i < vec_size; ++i) {
      out_vec[i] = Activation(x_vec[i]) * y_vec[i];
    }
    out_vec.store(out_ptr + token_idx * d + idx * vec_size);
  }

  // Scalar tail: exactly d % vec_size elements are left after the vectorized
  // loop. (Fix: was `d % (stride * vec_size)`, which redundantly re-processed
  // up to stride*vec_size - vec_size already-written elements per token —
  // same results, wasted work.)
  const int64_t remaining_offset = d - d % vec_size;
#pragma unroll 1
  for (int64_t idx = thread_idx; idx < d % vec_size; idx += stride) {
    T x = input_ptr[offset + remaining_offset + idx], y = input_ptr[offset + remaining_offset + d + idx];
    out_ptr[token_idx * d + remaining_offset + idx] = Activation(x) * y;
  }
}

// Elementwise activation kernel: out[i] = Activation(in[i]); one token row of
// length d per block.
template <typename T, T (*Activation)(const T&)>
__global__ void
act_only_kernel(T* __restrict__ out_ptr, const T* __restrict__ input_ptr, int64_t d, int64_t num_tokens) {
  // Number of T elements per 128-bit vectorized load/store.
  constexpr uint32_t vec_size = kBytesToLoad / sizeof(T);
  const int64_t token_idx = blockIdx.x;
  const int64_t thread_idx = threadIdx.x;
  const int64_t stride = blockDim.x;
  const int64_t offset = token_idx * d;

  if (token_idx >= num_tokens) return;

  // Vectorized main loop over the full vec_size-wide chunks of the row.
#pragma unroll 1
  for (uint32_t idx = thread_idx; idx < d / vec_size; idx += stride) {
    device::AlignedVector<T, vec_size> x_vec, out_vec;
    x_vec.load(input_ptr + offset + idx * vec_size);
#pragma unroll
    for (uint32_t i = 0; i < vec_size; ++i) {
      out_vec[i] = Activation(x_vec[i]);
    }
    out_vec.store(out_ptr + token_idx * d + idx * vec_size);
  }

  // Scalar tail: exactly d % vec_size elements are left after the vectorized
  // loop. (Fix: was `d % (stride * vec_size)`, which redundantly re-processed
  // up to stride*vec_size - vec_size already-written elements per token —
  // same results, wasted work.)
  const int64_t remaining_offset = d - d % vec_size;
#pragma unroll 1
  for (int64_t idx = thread_idx; idx < d % vec_size; idx += stride) {
    T x = input_ptr[offset + remaining_offset + idx];
    out_ptr[token_idx * d + remaining_offset + idx] = Activation(x);
  }
}

// Host-side launcher for act_and_mul_kernel: validates dtype/device/shape of
// (output, input) via the project's TensorMatcher API, then launches one CUDA
// block per token.
template <typename T, T (*Activation)(const T&)>
struct ActivationAndMul {
  static constexpr auto kernel = act_and_mul_kernel<T, Activation>;

  // input:  [num_tokens, 2*d] — activation half followed by multiplier half.
  // output: [num_tokens, d].
  static void run(tvm::ffi::TensorView output, tvm::ffi::TensorView input) {
    using namespace host;
    // NOTE(review): SymbolicSize presumably binds on the first verify() and is
    // re-checked by later ones — confirm against sgl_kernel/tensor.h.
    auto N = SymbolicSize{"num_tokens"};
    auto D_half = SymbolicSize{"d"};
    auto D_full = SymbolicSize{"2*d"};
    auto device = SymbolicDevice{};
    device.set_options<kDLCUDA>();  // both tensors must live on a CUDA device

    TensorMatcher({N, D_full})  //
        .with_dtype<T>()
        .with_device(device)
        .verify(input);

    TensorMatcher({N, D_half})  //
        .with_dtype<T>()
        .with_device(device)
        .verify(output);

    const int64_t num_tokens = N.unwrap();
    const int64_t d = D_half.unwrap();

    // The matchers bind D_full and D_half independently; tie them together here.
    RuntimeCheck(D_full.unwrap() == 2 * d, "Input dimension must be 2 * output dimension");

    // One block per token; up to 1024 threads cooperate along the d axis.
    const uint32_t block_size = std::min<uint32_t>(d, 1024);

    LaunchKernel(num_tokens, block_size, device.unwrap())(
        kernel, static_cast<T*>(output.data_ptr()), static_cast<T*>(input.data_ptr()), d, num_tokens);
  }
};

// Host-side launcher for act_only_kernel: validates that output and input have
// identical [num_tokens, d] shape, dtype and CUDA device, then launches one
// block per token.
template <typename T, T (*Activation)(const T&)>
struct ActivationOnly {
  static constexpr auto kernel = act_only_kernel<T, Activation>;

  // input / output: both [num_tokens, d]; pure elementwise map.
  static void run(tvm::ffi::TensorView output, tvm::ffi::TensorView input) {
    using namespace host;
    // NOTE(review): SymbolicSize presumably binds on the first verify() and is
    // re-checked by later ones — confirm against sgl_kernel/tensor.h.
    auto N = SymbolicSize{"num_tokens"};
    auto D = SymbolicSize{"d"};
    auto device = SymbolicDevice{};
    device.set_options<kDLCUDA>();  // both tensors must live on a CUDA device

    TensorMatcher({N, D})  //
        .with_dtype<T>()
        .with_device(device)
        .verify(input);

    TensorMatcher({N, D})  //
        .with_dtype<T>()
        .with_device(device)
        .verify(output);

    const int64_t num_tokens = N.unwrap();
    const int64_t d = D.unwrap();

    // One block per token; up to 1024 threads cooperate along the d axis.
    const uint32_t block_size = std::min<uint32_t>(d, 1024);

    LaunchKernel(num_tokens, block_size, device.unwrap())(
        kernel, static_cast<T*>(output.data_ptr()), static_cast<T*>(input.data_ptr()), d, num_tokens);
  }
};

// Public entry points instantiated by the Python JIT wrapper; each exposes a
// static run(output, input).

// silu(x) * y over the two halves of each input row.
template <typename T>
using silu_and_mul = ActivationAndMul<T, silu>;

// erf-based gelu(x) * y.
template <typename T>
using gelu_and_mul = ActivationAndMul<T, gelu>;

// tanh-approximated gelu(x) * y.
template <typename T>
using gelu_tanh_and_mul = ActivationAndMul<T, gelu_tanh>;

// Elementwise quick-GELU (no multiplier half).
template <typename T>
using gelu_quick = ActivationOnly<T, gelu_quick_act>;
62 changes: 62 additions & 0 deletions python/sglang/jit_kernel/tests/test_activation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import pytest
import torch
import torch.nn.functional as F

from sglang.jit_kernel.activation import (
gelu_and_mul,
gelu_quick,
gelu_tanh_and_mul,
silu_and_mul,
)


@pytest.mark.parametrize("dim", [128, 512, 1024])
@pytest.mark.parametrize("batch_size", [1, 8])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_silu_and_mul(dim, batch_size, dtype):
    """Compare the fused kernel against the eager silu(x) * y reference."""
    x = torch.randn(batch_size, 2 * dim, device="cuda", dtype=dtype)
    expected = F.silu(x[..., :dim]) * x[..., dim:]
    actual = silu_and_mul(x)
    # bf16 has fewer mantissa bits than fp16, so it gets a looser tolerance.
    tol = 1e-3 if dtype == torch.float16 else 2e-2
    torch.testing.assert_close(actual, expected, rtol=tol, atol=tol)


@pytest.mark.parametrize("dim", [128, 512, 1024])
@pytest.mark.parametrize("batch_size", [1, 8])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_gelu_and_mul(dim, batch_size, dtype):
    """Compare the fused kernel against the eager exact-gelu(x) * y reference."""
    x = torch.randn(batch_size, 2 * dim, device="cuda", dtype=dtype)
    expected = F.gelu(x[..., :dim], approximate="none") * x[..., dim:]
    actual = gelu_and_mul(x)
    # bf16 has fewer mantissa bits than fp16, so it gets a looser tolerance.
    tol = 1e-3 if dtype == torch.float16 else 2e-2
    torch.testing.assert_close(actual, expected, rtol=tol, atol=tol)


@pytest.mark.parametrize("dim", [128, 512, 1024])
@pytest.mark.parametrize("batch_size", [1, 8])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_gelu_tanh_and_mul(dim, batch_size, dtype):
    """Compare the fused kernel against the eager tanh-gelu(x) * y reference."""
    x = torch.randn(batch_size, 2 * dim, device="cuda", dtype=dtype)
    expected = F.gelu(x[..., :dim], approximate="tanh") * x[..., dim:]
    actual = gelu_tanh_and_mul(x)
    # bf16 has fewer mantissa bits than fp16, so it gets a looser tolerance.
    tol = 1e-3 if dtype == torch.float16 else 2e-2
    torch.testing.assert_close(actual, expected, rtol=tol, atol=tol)


@pytest.mark.parametrize("dim", [128, 512, 1024])
@pytest.mark.parametrize("batch_size", [1, 8])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_gelu_quick(dim, batch_size, dtype):
    """Compare the fused kernel against the eager x * sigmoid(1.702 x) reference."""
    x = torch.randn(batch_size, dim, device="cuda", dtype=dtype)
    expected = x * torch.sigmoid(1.702 * x)
    actual = gelu_quick(x)
    # bf16 has fewer mantissa bits than fp16, so it gets a looser tolerance.
    tol = 1e-3 if dtype == torch.float16 else 2e-2
    torch.testing.assert_close(actual, expected, rtol=tol, atol=tol)


# Allow running this file directly; delegates collection to pytest.
if __name__ == "__main__":
    pytest.main([__file__])
1 change: 0 additions & 1 deletion sgl-kernel/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,6 @@ set(SOURCES
"csrc/attention/merge_attn_states.cu"
"csrc/attention/vertical_slash_index.cu"
"csrc/common_extension.cc"
"csrc/elementwise/activation.cu"
"csrc/elementwise/cast.cu"
"csrc/elementwise/concat_mla.cu"
"csrc/elementwise/copy.cu"
Expand Down
Loading
Loading