diff --git a/python/sglang/jit_kernel/activation.py b/python/sglang/jit_kernel/activation.py
new file mode 100644
index 000000000000..89621129e49d
--- /dev/null
+++ b/python/sglang/jit_kernel/activation.py
@@ -0,0 +1,102 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Optional
+
+import torch
+
+from sglang.jit_kernel.utils import cache_once, load_jit, make_cpp_args
+
+if TYPE_CHECKING:
+    from tvm_ffi.module import Module
+
+_IS_ROCM: bool = torch.version.hip is not None
+
+
+@cache_once
+def _jit_activation_module(dtype: torch.dtype, act_type: str) -> Module:
+    """Load the JIT module exposing ``act_type`` instantiated for ``dtype``.
+
+    ``act_type`` must name an alias declared in ``elementwise/activation.cuh``:
+    "silu_and_mul", "gelu_and_mul", "gelu_tanh_and_mul" or "gelu_quick".
+    """
+    args = make_cpp_args(dtype)
+    kernel_name = f"{act_type}<{args}>"
+
+    return load_jit(
+        f"activation_{act_type}",
+        *args,
+        cuda_files=["elementwise/activation.cuh"],
+        cuda_wrappers=[(act_type, f"{kernel_name}::run")],
+    )
+
+
+def _run_activation(
+    act_type: str,
+    input: torch.Tensor,
+    out: Optional[torch.Tensor],
+    out_dim: int,
+) -> torch.Tensor:
+    """Allocate ``out`` when missing and dispatch ``act_type`` to the backend.
+
+    Args:
+        act_type: Kernel name, also the attribute looked up on the backend.
+        input: Source tensor; the kernels operate along its last dimension.
+        out: Destination tensor, or ``None`` to allocate one here.
+        out_dim: Last-dimension size of a freshly allocated ``out``.
+
+    Returns:
+        The destination tensor (``out`` itself when one was supplied).
+    """
+    if out is None:
+        out = torch.empty(
+            input.shape[:-1] + (out_dim,),
+            device=input.device,
+            dtype=input.dtype,
+        )
+    if _IS_ROCM:
+        import sgl_kernel
+
+        getattr(sgl_kernel, act_type)(out, input)
+        return out
+
+    # A zero-sized grid or block is not a valid CUDA launch configuration.
+    if out.numel() == 0:
+        return out
+
+    # The JIT entry points only match 2-D (num_tokens, d) tensors, so fold all
+    # leading dimensions into one. ``view`` shares storage, so ``out`` is filled.
+    module = _jit_activation_module(input.dtype, act_type)
+    getattr(module, act_type)(
+        out.view(-1, out.shape[-1]),
+        input.view(-1, input.shape[-1]),
+    )
+    return out
+
+
+def silu_and_mul(
+    input: torch.Tensor, out: Optional[torch.Tensor] = None
+) -> torch.Tensor:
+    """Return ``silu(input[..., :d]) * input[..., d:]`` with ``d = last_dim // 2``."""
+    return _run_activation("silu_and_mul", input, out, input.shape[-1] // 2)
+
+
+def gelu_and_mul(
+    input: torch.Tensor, out: Optional[torch.Tensor] = None
+) -> torch.Tensor:
+    """Return ``gelu(input[..., :d]) * input[..., d:]`` (erf-based GELU)."""
+    return _run_activation("gelu_and_mul", input, out, input.shape[-1] // 2)
+
+
+def gelu_tanh_and_mul(
+    input: torch.Tensor, out: Optional[torch.Tensor] = None
+) -> torch.Tensor:
+    """Return ``gelu_tanh(input[..., :d]) * input[..., d:]`` (tanh approximation)."""
+    return _run_activation("gelu_tanh_and_mul", input, out, input.shape[-1] // 2)
+
+
+def gelu_quick(input: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
+    """Return ``input * sigmoid(1.702 * input)`` element-wise."""
+    if out is None:
+        # ``empty_like`` keeps the layout of ``input``; the output is same-shaped.
+        out = torch.empty_like(input)
+    return _run_activation("gelu_quick", input, out, input.shape[-1])
diff --git a/python/sglang/jit_kernel/csrc/elementwise/activation.cuh b/python/sglang/jit_kernel/csrc/elementwise/activation.cuh
new file mode 100644
index 000000000000..95613d96b95c
--- /dev/null
+++ b/python/sglang/jit_kernel/csrc/elementwise/activation.cuh
@@ -0,0 +1,181 @@
+#pragma once
+
+#include
+#include
+
+#include
+#include
+
+#include
+
+#define kBitsToLoad 128
+#define kBytesToLoad (kBitsToLoad / 8)
+
+template
+SGL_DEVICE T silu(const T& x) {
+  float f32_val = device::cast(x);
+  return device::cast(f32_val / (1.0f + expf(-f32_val)));
+}
+
+template
+SGL_DEVICE T gelu(const T& x) {
+  constexpr float kAlpha = M_SQRT1_2;
+  float f32_val = device::cast(x);
+  return device::cast(f32_val * (0.5f * (1.0f + erf(f32_val * kAlpha))));
+}
+
+// gelu_quick(x) = x * torch.sigmoid(1.702 * x)
+template
+SGL_DEVICE T gelu_quick_act(const T& x) {
+  float f32_val = device::cast(x);
+  return device::cast(f32_val / (1.0f + expf(-f32_val * 1.702f)));
+}
+
+template
+SGL_DEVICE T gelu_tanh(const T& x) {
+  constexpr float kAlpha = 0.044715f;
+  constexpr float kBeta = 0.7978845608028654f;
+  float f32_val = device::cast(x);
+  const float cdf = 0.5f * (1.0f + tanhf((kBeta * (f32_val + kAlpha * f32_val * f32_val * f32_val))));
+  return device::cast(f32_val * cdf);
+}
+
+template
+__global__ void
+act_and_mul_kernel(T* __restrict__ out_ptr, const T* __restrict__ input_ptr, int64_t d, int64_t num_tokens) {
+  constexpr uint32_t vec_size = kBytesToLoad / sizeof(T);
+  const int64_t token_idx = blockIdx.x;
+  const int64_t thread_idx = threadIdx.x;
+  const int64_t stride = blockDim.x;
+  const int64_t offset = token_idx * 2 * d;
+
+  if (token_idx >= num_tokens) return;
+
+#pragma unroll 1
+  for (uint32_t idx = thread_idx; idx < d / vec_size; idx += stride) {
+    device::AlignedVector x_vec, y_vec, out_vec;
+    x_vec.load(input_ptr + offset + idx * vec_size);
+    y_vec.load(input_ptr + offset + d + idx * vec_size);
+#pragma unroll
+    for (uint32_t i = 0; i < vec_size; ++i) {
+      out_vec[i] = Activation(x_vec[i]) * y_vec[i];
+    }
+    out_vec.store(out_ptr + token_idx * d + idx * vec_size);
+  }
+
+  const int64_t remaining_offset = d - d % (stride * vec_size);
+  // process the remaining elements
+#pragma unroll 1
+  for (int64_t idx = thread_idx; idx < d % (stride * vec_size); idx += stride) {
+    T x = input_ptr[offset + remaining_offset + idx], y = input_ptr[offset + remaining_offset + d + idx];
+    out_ptr[token_idx * d + remaining_offset + idx] = Activation(x) * y;
+  }
+}
+
+template
+__global__ void
+act_only_kernel(T* __restrict__ out_ptr, const T* __restrict__ input_ptr, int64_t d, int64_t num_tokens) {
+  constexpr uint32_t vec_size = kBytesToLoad / sizeof(T);
+  const int64_t token_idx = blockIdx.x;
+  const int64_t thread_idx = threadIdx.x;
+  const int64_t stride = blockDim.x;
+  const int64_t offset = token_idx * d;
+
+  if (token_idx >= num_tokens) return;
+
+#pragma unroll 1
+  for (uint32_t idx = thread_idx; idx < d / vec_size; idx += stride) {
+    device::AlignedVector x_vec, out_vec;
+    x_vec.load(input_ptr + offset + idx * vec_size);
+#pragma unroll
+    for (uint32_t i = 0; i < vec_size; ++i) {
+      out_vec[i] = Activation(x_vec[i]);
+    }
+    out_vec.store(out_ptr + token_idx * d + idx * vec_size);
+  }
+
+  const int64_t remaining_offset = d - d % (stride * vec_size);
+  // process the remaining elements
+#pragma unroll 1
+  for (int64_t idx = thread_idx; idx < d % (stride * vec_size); idx += stride) {
+    T x = input_ptr[offset + remaining_offset + idx];
+    out_ptr[token_idx * d + remaining_offset + idx] = Activation(x);
+  }
+}
+
+template
+struct ActivationAndMul {
+  static constexpr auto kernel = act_and_mul_kernel;
+
+  static void run(tvm::ffi::TensorView output, tvm::ffi::TensorView input) {
+    using namespace host;
+    auto N = SymbolicSize{"num_tokens"};
+    auto D_half = SymbolicSize{"d"};
+    auto D_full = SymbolicSize{"2*d"};
+    auto device = SymbolicDevice{};
+    device.set_options();
+
+    TensorMatcher({N, D_full})  //
+        .with_dtype()
+        .with_device(device)
+        .verify(input);
+
+    TensorMatcher({N, D_half})  //
+        .with_dtype()
+        .with_device(device)
+        .verify(output);
+
+    const int64_t num_tokens = N.unwrap();
+    const int64_t d = D_half.unwrap();
+
+    RuntimeCheck(D_full.unwrap() == 2 * d, "Input dimension must be 2 * output dimension");
+
+    const uint32_t block_size = std::min(d, 1024);
+
+    LaunchKernel(num_tokens, block_size, device.unwrap())(
+        kernel, static_cast(output.data_ptr()), static_cast(input.data_ptr()), d, num_tokens);
+  }
+};
+
+template
+struct ActivationOnly {
+  static constexpr auto kernel = act_only_kernel;
+
+  static void run(tvm::ffi::TensorView output, tvm::ffi::TensorView input) {
+    using namespace host;
+    auto N = SymbolicSize{"num_tokens"};
+    auto D = SymbolicSize{"d"};
+    auto device = SymbolicDevice{};
+    device.set_options();
+
+    TensorMatcher({N, D})  //
+        .with_dtype()
+        .with_device(device)
+        .verify(input);
+
+    TensorMatcher({N, D})  //
+        .with_dtype()
+        .with_device(device)
+        .verify(output);
+
+    const int64_t num_tokens = N.unwrap();
+    const int64_t d = D.unwrap();
+
+    const uint32_t block_size = std::min(d, 1024);
+
+    LaunchKernel(num_tokens, block_size, device.unwrap())(
+        kernel, static_cast(output.data_ptr()), static_cast(input.data_ptr()), d, num_tokens);
+  }
+};
+
+template
+using silu_and_mul = ActivationAndMul;
+
+template
+using gelu_and_mul = ActivationAndMul;
+
+template
+using gelu_tanh_and_mul = ActivationAndMul;
+
+template
+using gelu_quick = ActivationOnly;
diff --git a/python/sglang/jit_kernel/tests/test_activation.py b/python/sglang/jit_kernel/tests/test_activation.py
new file mode 100644
index 000000000000..4857651ac46c
--- /dev/null
+++ b/python/sglang/jit_kernel/tests/test_activation.py
@@ -0,0 +1,75 @@
+import pytest
+import torch
+import torch.nn.functional as F
+
+from sglang.jit_kernel.activation import (
+    gelu_and_mul,
+    gelu_quick,
+    gelu_tanh_and_mul,
+    silu_and_mul,
+)
+
+
+@pytest.mark.parametrize("dim", [128, 512, 1024])
+@pytest.mark.parametrize("batch_size", [1, 8])
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+def test_silu_and_mul(dim, batch_size, dtype):
+    x = torch.randn(batch_size, 2 * dim, device="cuda", dtype=dtype)
+    y_ref = F.silu(x[..., :dim]) * x[..., dim:]
+    y = silu_and_mul(x)
+    atol = 1e-3 if dtype == torch.float16 else 2e-2
+    rtol = 1e-3 if dtype == torch.float16 else 2e-2
+    torch.testing.assert_close(y, y_ref, rtol=rtol, atol=atol)
+
+
+@pytest.mark.parametrize("dim", [128, 512, 1024])
+@pytest.mark.parametrize("batch_size", [1, 8])
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+def test_gelu_and_mul(dim, batch_size, dtype):
+    x = torch.randn(batch_size, 2 * dim, device="cuda", dtype=dtype)
+    y_ref = F.gelu(x[..., :dim], approximate="none") * x[..., dim:]
+    y = gelu_and_mul(x)
+    atol = 1e-3 if dtype == torch.float16 else 2e-2
+    rtol = 1e-3 if dtype == torch.float16 else 2e-2
+    torch.testing.assert_close(y, y_ref, rtol=rtol, atol=atol)
+
+
+@pytest.mark.parametrize("dim", [128, 512, 1024])
+@pytest.mark.parametrize("batch_size", [1, 8])
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+def test_gelu_tanh_and_mul(dim, batch_size, dtype):
+    x = torch.randn(batch_size, 2 * dim, device="cuda", dtype=dtype)
+    y_ref = F.gelu(x[..., :dim], approximate="tanh") * x[..., dim:]
+    y = gelu_tanh_and_mul(x)
+    atol = 1e-3 if dtype == torch.float16 else 2e-2
+    rtol = 1e-3 if dtype == torch.float16 else 2e-2
+    torch.testing.assert_close(y, y_ref, rtol=rtol, atol=atol)
+
+
+@pytest.mark.parametrize("dim", [128, 512, 1024])
+@pytest.mark.parametrize("batch_size", [1, 8])
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+def test_gelu_quick(dim, batch_size, dtype):
+    x = torch.randn(batch_size, dim, device="cuda", dtype=dtype)
+    y_ref = x * torch.sigmoid(1.702 * x)
+    y = gelu_quick(x)
+    atol = 1e-3 if dtype == torch.float16 else 2e-2
+    rtol = 1e-3 if dtype == torch.float16 else 2e-2
+    torch.testing.assert_close(y, y_ref, rtol=rtol, atol=atol)
+
+
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+def test_leading_dims_are_flattened(dtype):
+    # N-D inputs must give the same result as their flattened 2-D equivalent.
+    x = torch.randn(2, 3, 2 * 128, device="cuda", dtype=dtype)
+    x_2d = x.view(6, -1)
+    y = silu_and_mul(x)
+    assert y.shape == (2, 3, 128)
+    torch.testing.assert_close(y.view(6, -1), silu_and_mul(x_2d))
+    z = gelu_quick(x)
+    assert z.shape == x.shape
+    torch.testing.assert_close(z.view(6, -1), gelu_quick(x_2d))
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])
diff --git a/sgl-kernel/CMakeLists.txt b/sgl-kernel/CMakeLists.txt
index 05b9a056c40e..58380f316b48 100644
--- a/sgl-kernel/CMakeLists.txt
+++ b/sgl-kernel/CMakeLists.txt
@@ -265,7 +265,6 @@ set(SOURCES
   "csrc/attention/merge_attn_states.cu"
   "csrc/attention/vertical_slash_index.cu"
   "csrc/common_extension.cc"
-  "csrc/elementwise/activation.cu"
   "csrc/elementwise/cast.cu"
   "csrc/elementwise/concat_mla.cu"
   "csrc/elementwise/copy.cu"
diff --git a/sgl-kernel/include/hip/hip_act_and_mul.cuh b/sgl-kernel/include/hip/hip_act_and_mul.cuh
deleted file mode 100644
index ddb1b702d92d..000000000000
--- a/sgl-kernel/include/hip/hip_act_and_mul.cuh
+++ /dev/null
@@ -1,87 +0,0 @@
-/* Copyright 2025 SGLang Team. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#pragma once
-
-#include "utils.h"
-
-#define kBitsToLoad 128
-#define kBytesToLoad (kBitsToLoad / 8)
-
-// Adapted from
-// [flashinfer::activation::act_and_mul_kernel](https://github.com/flashinfer-ai/flashinfer/blob/4e8eb1879f9c3ba6d75511e5893183bf8f289a62/include/flashinfer/activation.cuh#L29)
-
-namespace sgl_hip {
-namespace activation {
-
-template
-__global__ void act_and_mul_kernel(T* __restrict__ out, const T* __restrict__ input, const int d) {
-  constexpr uint32_t vec_size = kBytesToLoad / sizeof(T);
-  const int64_t token_idx = blockIdx.x;
-  const int64_t thread_idx = threadIdx.x;
-  const int64_t stride = blockDim.x;
-  const int64_t offset = token_idx * 2 * d;
-
-#pragma unroll 1
-  for (uint32_t idx = thread_idx; idx < d / vec_size; idx += stride) {
-    sgl_hip::vec_t x_vec, y_vec, out_vec;
-    x_vec.cast_load(input + offset + idx * vec_size);
-    y_vec.cast_load(input + offset + d + idx * vec_size);
-#pragma unroll
-    for (uint32_t i = 0; i < vec_size; ++i) {
-      out_vec[i] = Activation(x_vec[i]) * y_vec[i];
-    }
-    out_vec.cast_store(out + token_idx * d + idx * vec_size);
-  }
-
-  const int64_t remaining_offset = d - d % (stride * vec_size);
-  // process the remaining elements
-#pragma unroll 1
-  for (int64_t idx = thread_idx; idx < d % (stride * vec_size); idx += stride) {
-    T x = input[offset + remaining_offset + idx], y = input[offset + remaining_offset + d + idx];
-    out[token_idx * d + remaining_offset + idx] = Activation(x) * y;
-  }
-}
-
-template
-__global__ void act_only_kernel(T* __restrict__ out, const T* __restrict__ input, const int d) {
-  constexpr uint32_t vec_size = kBytesToLoad / sizeof(T);
-  const int64_t token_idx = blockIdx.x;
-  const int64_t thread_idx = threadIdx.x;
-  const int64_t stride = blockDim.x;
-  const int64_t offset = token_idx * d;
-
-#pragma unroll 1
-  for (uint32_t idx = thread_idx; idx < d / vec_size; idx += stride) {
-    sgl_hip::vec_t x_vec, y_vec, out_vec;
-    x_vec.cast_load(input + offset + idx * vec_size);
-#pragma unroll
-    for (uint32_t i = 0; i < vec_size; ++i) {
-      out_vec[i] = Activation(x_vec[i]);
-    }
-    out_vec.cast_store(out + token_idx * d + idx * vec_size);
-  }
-
-  const int64_t remaining_offset = d - d % (stride * vec_size);
-  // process the remaining elements
-#pragma unroll 1
-  for (int64_t idx = thread_idx; idx < d % (stride * vec_size); idx += stride) {
-    T x = input[offset + remaining_offset + idx];
-    out[token_idx * d + remaining_offset + idx] = Activation(x);
-  }
-}
-
-}  // namespace activation
-}  // namespace sgl_hip