From 0dfc3c719a11b3b79097eb8bb94a4a21f283257e Mon Sep 17 00:00:00 2001 From: weiminc Date: Sat, 7 Feb 2026 07:35:32 +0000 Subject: [PATCH 1/7] init --- 3rdparty/amd/sgl-kernel/rocm_hipify.py | 2 +- sgl-kernel/CMakeLists.txt | 2 +- sgl-kernel/setup_rocm.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/3rdparty/amd/sgl-kernel/rocm_hipify.py b/3rdparty/amd/sgl-kernel/rocm_hipify.py index 8373f741d1d6..aa44050141dd 100644 --- a/3rdparty/amd/sgl-kernel/rocm_hipify.py +++ b/3rdparty/amd/sgl-kernel/rocm_hipify.py @@ -16,7 +16,7 @@ "csrc/allreduce/deterministic_all_reduce.hip", "csrc/allreduce/quick_all_reduce.cu", "csrc/common_extension_rocm.cc", - "csrc/elementwise/activation.cu", + #"csrc/elementwise/activation.cu", "csrc/elementwise/pos_enc.cu", "csrc/elementwise/topk.cu", "csrc/grammar/apply_token_bitmask_inplace_cuda.cu", diff --git a/sgl-kernel/CMakeLists.txt b/sgl-kernel/CMakeLists.txt index 05b9a056c40e..fc6c19b11a55 100644 --- a/sgl-kernel/CMakeLists.txt +++ b/sgl-kernel/CMakeLists.txt @@ -265,7 +265,7 @@ set(SOURCES "csrc/attention/merge_attn_states.cu" "csrc/attention/vertical_slash_index.cu" "csrc/common_extension.cc" - "csrc/elementwise/activation.cu" + # "csrc/elementwise/activation.cu" "csrc/elementwise/cast.cu" "csrc/elementwise/concat_mla.cu" "csrc/elementwise/copy.cu" diff --git a/sgl-kernel/setup_rocm.py b/sgl-kernel/setup_rocm.py index 66713bf0ae7d..4f7017942ea9 100644 --- a/sgl-kernel/setup_rocm.py +++ b/sgl-kernel/setup_rocm.py @@ -45,7 +45,7 @@ def _get_version(): "csrc/allreduce/deterministic_all_reduce.hip", "csrc/allreduce/quick_all_reduce.cu", "csrc/common_extension_rocm.cc", - "csrc/elementwise/activation.cu", + # "csrc/elementwise/activation.cu", "csrc/elementwise/topk.cu", "csrc/grammar/apply_token_bitmask_inplace_cuda.cu", "csrc/moe/moe_align_kernel.cu", From 6ac0d0950378790a483093e0445966bfc2e399a8 Mon Sep 17 00:00:00 2001 From: weiminc Date: Sat, 7 Feb 2026 16:55:41 +0000 Subject: [PATCH 2/7] basic impl --- python/sglang/jit_kernel/activation.py | 74 ++++++++ .../csrc/elementwise/activation.cuh | 174 ++++++++++++++++++ .../jit_kernel/tests/test_activation.py | 56 ++++++ 3 files changed, 304 insertions(+) create mode 100644 python/sglang/jit_kernel/activation.py create mode 100644 python/sglang/jit_kernel/csrc/elementwise/activation.cuh create mode 100644 python/sglang/jit_kernel/tests/test_activation.py diff --git a/python/sglang/jit_kernel/activation.py b/python/sglang/jit_kernel/activation.py new file mode 100644 index 000000000000..3bdc414ce4cf --- /dev/null +++ b/python/sglang/jit_kernel/activation.py @@ -0,0 +1,74 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Optional + +import torch + +from sglang.jit_kernel.utils import ( + cache_once, + load_jit, + make_cpp_args, +) + +if TYPE_CHECKING: + from tvm_ffi.module import Module + + +@cache_once +def _jit_activation_module(dtype: torch.dtype, act_type: str) -> Module: + # act_type: "silu_and_mul", "gelu_and_mul", "gelu_tanh_and_mul", "gelu_quick" + args = make_cpp_args(dtype) + kernel_name = f"{act_type}<{args}>" + + return load_jit( + f"activation_{act_type}", + *args, + cuda_files=["elementwise/activation.cuh"], + cuda_wrappers=[(act_type, f"{kernel_name}::run")], + ) + + +def silu_and_mul(input: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor: + if out is None: + out = torch.empty( + input.shape[:-1] + (input.shape[-1] // 2,), + device=input.device, + dtype=input.dtype, + ) + module = _jit_activation_module(input.dtype, "silu_and_mul") + module.silu_and_mul(out, input) + return out + + +def gelu_and_mul(input: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor: + if out is None: + out = torch.empty( + input.shape[:-1] + (input.shape[-1] // 2,), + device=input.device, + dtype=input.dtype, + ) + module = _jit_activation_module(input.dtype, "gelu_and_mul") + module.gelu_and_mul(out, input) + return out + + +def gelu_tanh_and_mul( + input: torch.Tensor, out: Optional[torch.Tensor] = None +) -> torch.Tensor: + if out is None: + out = torch.empty( + input.shape[:-1] + (input.shape[-1] // 2,), + device=input.device, + dtype=input.dtype, + ) + module = _jit_activation_module(input.dtype, "gelu_tanh_and_mul") + module.gelu_tanh_and_mul(out, input) + return out + + +def gelu_quick(input: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor: + if out is None: + out = torch.empty_like(input) + module = _jit_activation_module(input.dtype, "gelu_quick") + module.gelu_quick(out, input) + return out diff --git a/python/sglang/jit_kernel/csrc/elementwise/activation.cuh b/python/sglang/jit_kernel/csrc/elementwise/activation.cuh new file mode 100644 index 000000000000..636adb0f3366 --- /dev/null +++ b/python/sglang/jit_kernel/csrc/elementwise/activation.cuh @@ -0,0 +1,174 @@ +#pragma once + +#include +#include +#include +#include + +namespace { + +namespace detail { + +template +SGL_DEVICE float to_f32(const T& x) { +#if USE_ROCM + return castToFloat(x); +#else + return static_cast(x); +#endif +} + +template +SGL_DEVICE T from_f32(float f32) { +#if USE_ROCM + return castFromFloat(f32); +#else + return static_cast(f32); +#endif +} + +} // namespace detail + +template +SGL_DEVICE T silu(const T& x) { + float f32_val = detail::to_f32(x); + return detail::from_f32(f32_val / (1.0f + expf(-f32_val))); +} + +template +SGL_DEVICE T gelu(const T& x) { + constexpr float kAlpha = M_SQRT1_2; + float f32_val = detail::to_f32(x); + return detail::from_f32(f32_val * (0.5f * (1.0f + erf(f32_val * kAlpha)))); +} + +// gelu_quick(x) = x * torch.sigmoid(1.702 * x) +template +SGL_DEVICE T gelu_quick_act(const T& x) { + float f32_val = detail::to_f32(x); + return detail::from_f32(f32_val / (1.0f + expf(-f32_val * 1.702f))); +} + +template +SGL_DEVICE T gelu_tanh(const T& x) { + constexpr float kAlpha = 0.044715f; + constexpr float kBeta = 0.7978845608028654f; + float f32_val = detail::to_f32(x); + const float cdf = 0.5f * (1.0f + tanhf((kBeta * (f32_val + kAlpha * f32_val * f32_val * f32_val)))); + return detail::from_f32(f32_val * cdf); +} + +template +__global__ void act_and_mul_kernel(void* __restrict__ out_ptr, const void* __restrict__ input_ptr, int64_t d, + int64_t num_tokens) { + const int64_t token_idx = blockIdx.x; + const int64_t thread_idx = threadIdx.x; + const int64_t stride = blockDim.x; + const int64_t offset = token_idx * 2 * d; + + if (token_idx >= num_tokens) return; + + const T* input = static_cast(input_ptr) + offset; + T* out = static_cast(out_ptr) + token_idx * d; + + for (int64_t i = thread_idx; i < d; i += stride) { + T x = input[i]; + T y = input[i + d]; + out[i] = Activation(x) * y; + } +} + +template +__global__ void act_only_kernel(void* __restrict__ out_ptr, const void* __restrict__ input_ptr, int64_t d, + int64_t num_tokens) { + const int64_t token_idx = blockIdx.x; + const int64_t thread_idx = threadIdx.x; + const int64_t stride = blockDim.x; + const int64_t offset = token_idx * d; + + if (token_idx >= num_tokens) return; + + const T* input = static_cast(input_ptr) + offset; + T* out = static_cast(out_ptr) + token_idx * d; + + for (int64_t i = thread_idx; i < d; i += stride) { + out[i] = Activation(input[i]); + } +} + +template +struct ActivationAndMul { + static constexpr auto kernel = act_and_mul_kernel; + + static void run(tvm::ffi::TensorView output, tvm::ffi::TensorView input) { + using namespace host; + auto N = SymbolicSize{"num_tokens"}; + auto D_half = SymbolicSize{"d"}; + auto D_full = SymbolicSize{"2*d"}; + auto device = SymbolicDevice{}; + device.set_options(); + + TensorMatcher({N, D_full}) // + .with_dtype() + .with_device(device) + .verify(input); + + TensorMatcher({N, D_half}) // + .with_dtype() + .with_device(device) + .verify(output); + + const int64_t num_tokens = N.unwrap(); + const int64_t d = D_half.unwrap(); + + RuntimeCheck(D_full.unwrap() == 2 * d, "Input dimension must be 2 * output dimension"); + + const uint32_t block_size = std::min(d, 1024); + + LaunchKernel(num_tokens, block_size, device.unwrap())(kernel, output.data_ptr(), input.data_ptr(), d, num_tokens); + } +}; + +template +struct ActivationOnly { + static constexpr auto kernel = act_only_kernel; + + static void run(tvm::ffi::TensorView output, tvm::ffi::TensorView input) { + using namespace host; + auto N = SymbolicSize{"num_tokens"}; + auto D = SymbolicSize{"d"}; + auto device = SymbolicDevice{}; + device.set_options(); + + TensorMatcher({N, D}) // + .with_dtype() + .with_device(device) + .verify(input); + + TensorMatcher({N, D}) // + .with_dtype() + .with_device(device) + .verify(output); + + const int64_t num_tokens = N.unwrap(); + const int64_t d = D.unwrap(); + + const uint32_t block_size = std::min(d, 1024); + + LaunchKernel(num_tokens, block_size, device.unwrap())(kernel, output.data_ptr(), input.data_ptr(), d, num_tokens); + } +}; + +template +using silu_and_mul = ActivationAndMul; + +template +using gelu_and_mul = ActivationAndMul; + +template +using gelu_tanh_and_mul = ActivationAndMul; + +template +using gelu_quick = ActivationOnly; + +} // namespace diff --git a/python/sglang/jit_kernel/tests/test_activation.py b/python/sglang/jit_kernel/tests/test_activation.py new file mode 100644 index 000000000000..59b559580314 --- /dev/null +++ b/python/sglang/jit_kernel/tests/test_activation.py @@ -0,0 +1,56 @@ +import pytest +import torch +import torch.nn.functional as F +from sglang.jit_kernel.activation import ( + silu_and_mul, + gelu_and_mul, + gelu_tanh_and_mul, + gelu_quick, +) + +@pytest.mark.parametrize("dim", [128, 512, 1024]) +@pytest.mark.parametrize("batch_size", [1, 8]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_silu_and_mul(dim, batch_size, dtype): + x = torch.randn(batch_size, 2 * dim, device="cuda", dtype=dtype) + y_ref = F.silu(x[..., :dim]) * x[..., dim:] + y = silu_and_mul(x) + atol = 1e-3 if dtype == torch.float16 else 2e-2 + rtol = 1e-3 if dtype == torch.float16 else 2e-2 + torch.testing.assert_close(y, y_ref, rtol=rtol, atol=atol) + +@pytest.mark.parametrize("dim", [128, 512, 1024]) +@pytest.mark.parametrize("batch_size", [1, 8]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_gelu_and_mul(dim, batch_size, dtype): + x = torch.randn(batch_size, 2 * dim, device="cuda", dtype=dtype) + y_ref = F.gelu(x[..., :dim], approximate="none") * x[..., dim:] + y = gelu_and_mul(x) + atol = 1e-3 if dtype == torch.float16 else 2e-2 + rtol = 1e-3 if dtype == torch.float16 else 2e-2 + torch.testing.assert_close(y, y_ref, rtol=rtol, atol=atol) + +@pytest.mark.parametrize("dim", [128, 512, 1024]) +@pytest.mark.parametrize("batch_size", [1, 8]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_gelu_tanh_and_mul(dim, batch_size, dtype): + x = torch.randn(batch_size, 2 * dim, device="cuda", dtype=dtype) + y_ref = F.gelu(x[..., :dim], approximate="tanh") * x[..., dim:] + y = gelu_tanh_and_mul(x) + atol = 1e-3 if dtype == torch.float16 else 2e-2 + rtol = 1e-3 if dtype == torch.float16 else 2e-2 + torch.testing.assert_close(y, y_ref, rtol=rtol, atol=atol) + +@pytest.mark.parametrize("dim", [128, 512, 1024]) +@pytest.mark.parametrize("batch_size", [1, 8]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_gelu_quick(dim, batch_size, dtype): + x = torch.randn(batch_size, dim, device="cuda", dtype=dtype) + y_ref = x * torch.sigmoid(1.702 * x) + y = gelu_quick(x) + atol = 1e-3 if dtype == torch.float16 else 2e-2 + rtol = 1e-3 if dtype == torch.float16 else 2e-2 + torch.testing.assert_close(y, y_ref, rtol=rtol, atol=atol) + +if __name__ == "__main__": + pytest.main([__file__]) From 0eb823cd2d168296d46d79d08daee07ec9070ef6 Mon Sep 17 00:00:00 2001 From: weiminc Date: Wed, 11 Feb 2026 06:09:49 +0000 Subject: [PATCH 3/7] update --- 3rdparty/amd/sgl-kernel/rocm_hipify.py | 1 - python/sglang/jit_kernel/activation.py | 6 +- .../csrc/elementwise/activation.cuh | 59 ++++-- sgl-kernel/CMakeLists.txt | 1 - sgl-kernel/csrc/elementwise/activation.cu | 170 ------------------ sgl-kernel/include/hip/hip_act_and_mul.cuh | 87 --------- sgl-kernel/setup_rocm.py | 1 - 7 files changed, 46 insertions(+), 279 deletions(-) delete mode 100644 sgl-kernel/csrc/elementwise/activation.cu delete mode 100644 sgl-kernel/include/hip/hip_act_and_mul.cuh diff --git a/3rdparty/amd/sgl-kernel/rocm_hipify.py b/3rdparty/amd/sgl-kernel/rocm_hipify.py index aa44050141dd..0187ffddc058 100644 --- a/3rdparty/amd/sgl-kernel/rocm_hipify.py +++ b/3rdparty/amd/sgl-kernel/rocm_hipify.py @@ -16,7 +16,6 @@ "csrc/allreduce/deterministic_all_reduce.hip", "csrc/allreduce/quick_all_reduce.cu", "csrc/common_extension_rocm.cc", - #"csrc/elementwise/activation.cu", "csrc/elementwise/pos_enc.cu", "csrc/elementwise/topk.cu", "csrc/grammar/apply_token_bitmask_inplace_cuda.cu", diff --git a/python/sglang/jit_kernel/activation.py b/python/sglang/jit_kernel/activation.py index 3bdc414ce4cf..407066916acf 100644 --- a/python/sglang/jit_kernel/activation.py +++ b/python/sglang/jit_kernel/activation.py @@ -4,11 +4,7 @@ import torch -from sglang.jit_kernel.utils import ( - cache_once, - load_jit, - make_cpp_args, -) +from sglang.jit_kernel.utils import cache_once, load_jit, make_cpp_args if TYPE_CHECKING: from tvm_ffi.module import Module diff --git a/python/sglang/jit_kernel/csrc/elementwise/activation.cuh b/python/sglang/jit_kernel/csrc/elementwise/activation.cuh index 636adb0f3366..4ffb8868bbdb 100644 --- a/python/sglang/jit_kernel/csrc/elementwise/activation.cuh +++ b/python/sglang/jit_kernel/csrc/elementwise/activation.cuh @@ -1,10 +1,14 @@ #pragma once #include +#include #include #include #include +#define kBitsToLoad 128 +#define kBytesToLoad (kBitsToLoad / 8) + namespace { namespace detail { @@ -59,8 +63,9 @@ SGL_DEVICE T gelu_tanh(const T& x) { } template -__global__ void act_and_mul_kernel(void* __restrict__ out_ptr, const void* __restrict__ input_ptr, int64_t d, +__global__ void act_and_mul_kernel(T* __restrict__ out_ptr, const T* __restrict__ input_ptr, int64_t d, int64_t num_tokens) { + constexpr uint32_t vec_size = kBytesToLoad / sizeof(T); const int64_t token_idx = blockIdx.x; const int64_t thread_idx = threadIdx.x; const int64_t stride = blockDim.x; @@ -68,19 +73,31 @@ __global__ void act_and_mul_kernel(void* __restrict__ out_ptr, const void* __res if (token_idx >= num_tokens) return; - const T* input = static_cast(input_ptr) + offset; - T* out = static_cast(out_ptr) + token_idx * d; +#pragma unroll 1 + for (uint32_t idx = thread_idx; idx < d / vec_size; idx += stride) { + device::AlignedVector x_vec, y_vec, out_vec; + x_vec.load(input_ptr + offset + idx * vec_size); + y_vec.load(input_ptr + offset + d + idx * vec_size); +#pragma unroll + for (uint32_t i = 0; i < vec_size; ++i) { + out_vec[i] = Activation(x_vec[i]) * y_vec[i]; + } + out_vec.store(out_ptr + token_idx * d + idx * vec_size); + } - for (int64_t i = thread_idx; i < d; i += stride) { - T x = input[i]; - T y = input[i + d]; - out[i] = Activation(x) * y; + const int64_t remaining_offset = d - d % (stride * vec_size); + // process the remaining elements +#pragma unroll 1 + for (int64_t idx = thread_idx; idx < d % (stride * vec_size); idx += stride) { + T x = input_ptr[offset + remaining_offset + idx], y = input_ptr[offset + remaining_offset + d + idx]; + out_ptr[token_idx * d + remaining_offset + idx] = Activation(x) * y; } } template -__global__ void act_only_kernel(void* __restrict__ out_ptr, const void* __restrict__ input_ptr, int64_t d, +__global__ void act_only_kernel(T* __restrict__ out_ptr, const T* __restrict__ input_ptr, int64_t d, int64_t num_tokens) { + constexpr uint32_t vec_size = kBytesToLoad / sizeof(T); const int64_t token_idx = blockIdx.x; const int64_t thread_idx = threadIdx.x; const int64_t stride = blockDim.x; @@ -88,11 +105,23 @@ __global__ void act_only_kernel(void* __restrict__ out_ptr, const void* __restri if (token_idx >= num_tokens) return; - const T* input = static_cast(input_ptr) + offset; - T* out = static_cast(out_ptr) + token_idx * d; +#pragma unroll 1 + for (uint32_t idx = thread_idx; idx < d / vec_size; idx += stride) { + device::AlignedVector x_vec, out_vec; + x_vec.load(input_ptr + offset + idx * vec_size); +#pragma unroll + for (uint32_t i = 0; i < vec_size; ++i) { + out_vec[i] = Activation(x_vec[i]); + } + out_vec.store(out_ptr + token_idx * d + idx * vec_size); + } - for (int64_t i = thread_idx; i < d; i += stride) { - out[i] = Activation(input[i]); + const int64_t remaining_offset = d - d % (stride * vec_size); + // process the remaining elements +#pragma unroll 1 + for (int64_t idx = thread_idx; idx < d % (stride * vec_size); idx += stride) { + T x = input_ptr[offset + remaining_offset + idx]; + out_ptr[token_idx * d + remaining_offset + idx] = Activation(x); } } @@ -125,7 +154,8 @@ struct ActivationAndMul { const uint32_t block_size = std::min(d, 1024); - LaunchKernel(num_tokens, block_size, device.unwrap())(kernel, output.data_ptr(), input.data_ptr(), d, num_tokens); + LaunchKernel(num_tokens, block_size, device.unwrap())( + kernel, static_cast(output.data_ptr()), static_cast(input.data_ptr()), d, num_tokens); } }; @@ -155,7 +185,8 @@ struct ActivationOnly { const uint32_t block_size = std::min(d, 1024); - LaunchKernel(num_tokens, block_size, device.unwrap())(kernel, output.data_ptr(), input.data_ptr(), d, num_tokens); + LaunchKernel(num_tokens, block_size, device.unwrap())( + kernel, static_cast(output.data_ptr()), static_cast(input.data_ptr()), d, num_tokens); } }; diff --git a/sgl-kernel/CMakeLists.txt b/sgl-kernel/CMakeLists.txt index fc6c19b11a55..58380f316b48 100644 --- a/sgl-kernel/CMakeLists.txt +++ b/sgl-kernel/CMakeLists.txt @@ -265,7 +265,6 @@ set(SOURCES "csrc/attention/merge_attn_states.cu" "csrc/attention/vertical_slash_index.cu" "csrc/common_extension.cc" - # "csrc/elementwise/activation.cu" "csrc/elementwise/cast.cu" "csrc/elementwise/concat_mla.cu" "csrc/elementwise/copy.cu" diff --git a/sgl-kernel/csrc/elementwise/activation.cu b/sgl-kernel/csrc/elementwise/activation.cu deleted file mode 100644 index 43617f87f318..000000000000 --- a/sgl-kernel/csrc/elementwise/activation.cu +++ /dev/null @@ -1,170 +0,0 @@ -/* - * Copyright (c) 2024 by FlashInfer team. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include - -#ifndef USE_ROCM - -#include - -#include "utils.h" - -#else -#include "hip/hip_act_and_mul.cuh" -#endif - -// Adapted from flashinfer activation -// https://github.com/flashinfer-ai/flashinfer/blob/4e8eb1879f9c3ba6d75511e5893183bf8f289a62/csrc/activation.cu#L44 - -namespace detail { - -template -__device__ __forceinline__ float to_f32(const T& x) { -#if USE_ROCM - return castToFloat(x); -#else - return static_cast(x); -#endif -} - -template -__device__ __forceinline__ T from_f32(float f32) { -#if USE_ROCM - return castFromFloat(f32); -#else - return static_cast(f32); -#endif -} - -} // namespace detail - -template -__device__ __forceinline__ T silu(const T& x) { - float f32_val = detail::to_f32(x); - return detail::from_f32(f32_val / (1.0f + expf(-f32_val))); -} - -template -__device__ __forceinline__ T gelu(const T& x) { - constexpr float kAlpha = M_SQRT1_2; - float f32_val = detail::to_f32(x); - return detail::from_f32(f32_val * (0.5f * (1.0f + erf(f32_val * kAlpha)))); -} - -// gelu_quick(x) = x * torch.sigmoid(1.702 * x) -template -__device__ __forceinline__ T gelu_quick_act(const T& x) { - float f32_val = detail::to_f32(x); - return detail::from_f32(f32_val / (1.0f + expf(-f32_val * 1.702f))); -} - -template -__device__ __forceinline__ T gelu_tanh(const T& x) { - constexpr float kAlpha = 0.044715f; - constexpr float kBeta = 0.7978845608028654f; - float f32_val = detail::to_f32(x); - const float cdf = 0.5f * (1.0f + tanhf((kBeta * (f32_val + kAlpha * f32_val * f32_val * f32_val)))); - return detail::from_f32(f32_val * cdf); -} - -void silu_and_mul(at::Tensor& out, at::Tensor& input) { - int d = input.size(-1) / 2; - int64_t num_tokens = input.numel() / input.size(-1); - dim3 grid(num_tokens); - - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); - - DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] { - uint32_t vec_size = 16 / sizeof(c_type); - dim3 block(std::min(d / vec_size, 1024U)); -#if USE_ROCM - sgl_hip::activation::act_and_mul_kernel - <<>>(static_cast(out.data_ptr()), static_cast(input.data_ptr()), d); -#else - flashinfer::activation::act_and_mul_kernel - <<>>(static_cast(out.data_ptr()), static_cast(input.data_ptr()), d); -#endif - return true; - }); -} - -void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input) { - int d = input.size(-1) / 2; - int64_t num_tokens = input.numel() / input.size(-1); - dim3 grid(num_tokens); - - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); - - DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] { - uint32_t vec_size = 16 / sizeof(c_type); - dim3 block(std::min(d / vec_size, 1024U)); -#if USE_ROCM - sgl_hip::activation::act_and_mul_kernel - <<>>(static_cast(out.data_ptr()), static_cast(input.data_ptr()), d); -#else - flashinfer::activation::act_and_mul_kernel - <<>>(static_cast(out.data_ptr()), static_cast(input.data_ptr()), d); -#endif - return true; - }); -} - -void gelu_and_mul(at::Tensor& out, at::Tensor& input) { - int d = input.size(-1) / 2; - int64_t num_tokens = input.numel() / input.size(-1); - dim3 grid(num_tokens); - - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); - - DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] { - uint32_t vec_size = 16 / sizeof(c_type); - dim3 block(std::min(d / vec_size, 1024U)); -#if USE_ROCM - sgl_hip::activation::act_and_mul_kernel - <<>>(static_cast(out.data_ptr()), static_cast(input.data_ptr()), d); -#else - flashinfer::activation::act_and_mul_kernel - <<>>(static_cast(out.data_ptr()), static_cast(input.data_ptr()), d); -#endif - - return true; - }); -} - -#if USE_ROCM -void gelu_quick(at::Tensor& out, const at::Tensor& input) { - int d = input.size(-1); - int64_t num_tokens = input.numel() / input.size(-1); - dim3 grid(num_tokens); - - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); - - DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] { - uint32_t vec_size = 16 / sizeof(c_type); - dim3 block(std::min(d / vec_size, 1024U)); - sgl_hip::activation::act_only_kernel - <<>>(static_cast(out.data_ptr()), static_cast(input.data_ptr()), d); - - return true; - }); -} -#endif diff --git a/sgl-kernel/include/hip/hip_act_and_mul.cuh b/sgl-kernel/include/hip/hip_act_and_mul.cuh deleted file mode 100644 index ddb1b702d92d..000000000000 --- a/sgl-kernel/include/hip/hip_act_and_mul.cuh +++ /dev/null @@ -1,87 +0,0 @@ -/* Copyright 2025 SGLang Team. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#pragma once - -#include "utils.h" - -#define kBitsToLoad 128 -#define kBytesToLoad (kBitsToLoad / 8) - -// Adapted from -// [flashinfer::activation::act_and_mul_kernel](https://github.com/flashinfer-ai/flashinfer/blob/4e8eb1879f9c3ba6d75511e5893183bf8f289a62/include/flashinfer/activation.cuh#L29) - -namespace sgl_hip { -namespace activation { - -template -__global__ void act_and_mul_kernel(T* __restrict__ out, const T* __restrict__ input, const int d) { - constexpr uint32_t vec_size = kBytesToLoad / sizeof(T); - const int64_t token_idx = blockIdx.x; - const int64_t thread_idx = threadIdx.x; - const int64_t stride = blockDim.x; - const int64_t offset = token_idx * 2 * d; - -#pragma unroll 1 - for (uint32_t idx = thread_idx; idx < d / vec_size; idx += stride) { - sgl_hip::vec_t x_vec, y_vec, out_vec; - x_vec.cast_load(input + offset + idx * vec_size); - y_vec.cast_load(input + offset + d + idx * vec_size); -#pragma unroll - for (uint32_t i = 0; i < vec_size; ++i) { - out_vec[i] = Activation(x_vec[i]) * y_vec[i]; - } - out_vec.cast_store(out + token_idx * d + idx * vec_size); - } - - const int64_t remaining_offset = d - d % (stride * vec_size); - // process the remaining elements -#pragma unroll 1 - for (int64_t idx = thread_idx; idx < d % (stride * vec_size); idx += stride) { - T x = input[offset + remaining_offset + idx], y = input[offset + remaining_offset + d + idx]; - out[token_idx * d + remaining_offset + idx] = Activation(x) * y; - } -} - -template -__global__ void act_only_kernel(T* __restrict__ out, const T* __restrict__ input, const int d) { - constexpr uint32_t vec_size = kBytesToLoad / sizeof(T); - const int64_t token_idx = blockIdx.x; - const int64_t thread_idx = threadIdx.x; - const int64_t stride = blockDim.x; - const int64_t offset = token_idx * d; - -#pragma unroll 1 - for (uint32_t idx = thread_idx; idx < d / vec_size; idx += stride) { - sgl_hip::vec_t x_vec, y_vec, out_vec; - x_vec.cast_load(input + offset + idx * vec_size); -#pragma unroll - for (uint32_t i = 0; i < vec_size; ++i) { - out_vec[i] = Activation(x_vec[i]); - } - out_vec.cast_store(out + token_idx * d + idx * vec_size); - } - - const int64_t remaining_offset = d - d % (stride * vec_size); - // process the remaining elements -#pragma unroll 1 - for (int64_t idx = thread_idx; idx < d % (stride * vec_size); idx += stride) { - T x = input[offset + remaining_offset + idx]; - out[token_idx * d + remaining_offset + idx] = Activation(x); - } -} - -} // namespace activation -} // namespace sgl_hip diff --git a/sgl-kernel/setup_rocm.py b/sgl-kernel/setup_rocm.py index 4f7017942ea9..3bd5a136cbf2 100644 --- a/sgl-kernel/setup_rocm.py +++ b/sgl-kernel/setup_rocm.py @@ -45,7 +45,6 @@ def _get_version(): "csrc/allreduce/deterministic_all_reduce.hip", "csrc/allreduce/quick_all_reduce.cu", "csrc/common_extension_rocm.cc", - # "csrc/elementwise/activation.cu", "csrc/elementwise/topk.cu", "csrc/grammar/apply_token_bitmask_inplace_cuda.cu", "csrc/moe/moe_align_kernel.cu", From f4922260df8e4161be752e4f16885176494e0701 Mon Sep 17 00:00:00 2001 From: weiminc Date: Wed, 11 Feb 2026 06:23:05 +0000 Subject: [PATCH 4/7] pre-commit --- python/sglang/jit_kernel/activation.py | 8 ++++++-- .../jit_kernel/csrc/elementwise/activation.cuh | 14 ++++++++------ python/sglang/jit_kernel/tests/test_activation.py | 10 ++++++++-- 3 files changed, 22 insertions(+), 10 deletions(-) diff --git a/python/sglang/jit_kernel/activation.py b/python/sglang/jit_kernel/activation.py index 407066916acf..266a1e0d900f 100644 --- a/python/sglang/jit_kernel/activation.py +++ b/python/sglang/jit_kernel/activation.py @@ -24,7 +24,9 @@ def _jit_activation_module(dtype: torch.dtype, act_type: str) -> Module: ) -def silu_and_mul(input: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor: +def silu_and_mul( + input: torch.Tensor, out: Optional[torch.Tensor] = None +) -> torch.Tensor: if out is None: out = torch.empty( input.shape[:-1] + (input.shape[-1] // 2,), @@ -36,7 +38,9 @@ def silu_and_mul(input: torch.Tensor, out: Optional[torch.Tensor] = None) -> tor return out -def gelu_and_mul(input: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor: +def gelu_and_mul( + input: torch.Tensor, out: Optional[torch.Tensor] = None +) -> torch.Tensor: if out is None: out = torch.empty( input.shape[:-1] + (input.shape[-1] // 2,), diff --git a/python/sglang/jit_kernel/csrc/elementwise/activation.cuh b/python/sglang/jit_kernel/csrc/elementwise/activation.cuh index 4ffb8868bbdb..fa129fbcba8f 100644 --- a/python/sglang/jit_kernel/csrc/elementwise/activation.cuh +++ b/python/sglang/jit_kernel/csrc/elementwise/activation.cuh @@ -1,9 +1,11 @@ #pragma once #include -#include -#include #include + +#include +#include + #include #define kBitsToLoad 128 @@ -63,8 +65,8 @@ SGL_DEVICE T gelu_tanh(const T& x) { } template -__global__ void act_and_mul_kernel(T* __restrict__ out_ptr, const T* __restrict__ input_ptr, int64_t d, - int64_t num_tokens) { +__global__ void +act_and_mul_kernel(T* __restrict__ out_ptr, const T* __restrict__ input_ptr, int64_t d, int64_t num_tokens) { constexpr uint32_t vec_size = kBytesToLoad / sizeof(T); const int64_t token_idx = blockIdx.x; const int64_t thread_idx = threadIdx.x; @@ -95,8 +97,8 @@ __global__ void act_and_mul_kernel(T* __restrict__ out_ptr, const T* __restrict_ } template -__global__ void act_only_kernel(T* __restrict__ out_ptr, const T* __restrict__ input_ptr, int64_t d, - int64_t num_tokens) { +__global__ void +act_only_kernel(T* __restrict__ out_ptr, const T* __restrict__ input_ptr, int64_t d, int64_t num_tokens) { constexpr uint32_t vec_size = kBytesToLoad / sizeof(T); const int64_t token_idx = blockIdx.x; const int64_t thread_idx = threadIdx.x; diff --git a/python/sglang/jit_kernel/tests/test_activation.py b/python/sglang/jit_kernel/tests/test_activation.py index 59b559580314..4857651ac46c 100644 --- a/python/sglang/jit_kernel/tests/test_activation.py +++ b/python/sglang/jit_kernel/tests/test_activation.py @@ -1,13 +1,15 @@ import pytest import torch import torch.nn.functional as F + from sglang.jit_kernel.activation import ( - silu_and_mul, gelu_and_mul, - gelu_tanh_and_mul, gelu_quick, + gelu_tanh_and_mul, + silu_and_mul, ) + @pytest.mark.parametrize("dim", [128, 512, 1024]) @pytest.mark.parametrize("batch_size", [1, 8]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @@ -19,6 +21,7 @@ def test_silu_and_mul(dim, batch_size, dtype): rtol = 1e-3 if dtype == torch.float16 else 2e-2 torch.testing.assert_close(y, y_ref, rtol=rtol, atol=atol) + @pytest.mark.parametrize("dim", [128, 512, 1024]) @pytest.mark.parametrize("batch_size", [1, 8]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @@ -30,6 +33,7 @@ def test_gelu_and_mul(dim, batch_size, dtype): rtol = 1e-3 if dtype == torch.float16 else 2e-2 torch.testing.assert_close(y, y_ref, rtol=rtol, atol=atol) + @pytest.mark.parametrize("dim", [128, 512, 1024]) @pytest.mark.parametrize("batch_size", [1, 8]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @@ -41,6 +45,7 @@ def test_gelu_tanh_and_mul(dim, batch_size, dtype): rtol = 1e-3 if dtype == torch.float16 else 2e-2 torch.testing.assert_close(y, y_ref, rtol=rtol, atol=atol) + @pytest.mark.parametrize("dim", [128, 512, 1024]) @pytest.mark.parametrize("batch_size", [1, 8]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @@ -52,5 +57,6 @@ def test_gelu_quick(dim, batch_size, dtype): rtol = 1e-3 if dtype == torch.float16 else 2e-2 torch.testing.assert_close(y, y_ref, rtol=rtol, atol=atol) + if __name__ == "__main__": pytest.main([__file__]) From 7b6589b81daec943ab0dbbcf6d77b7dd75342050 Mon Sep 17 00:00:00 2001 From: weiminc Date: Thu, 12 Feb 2026 02:51:01 +0000 Subject: [PATCH 5/7] update --- .../csrc/elementwise/activation.cuh | 44 +---- sgl-kernel/csrc/elementwise/activation.cu | 170 ++++++++++++++++++ sgl-kernel/setup_rocm.py | 1 + 3 files changed, 180 insertions(+), 35 deletions(-) create mode 100644 sgl-kernel/csrc/elementwise/activation.cu diff --git a/python/sglang/jit_kernel/csrc/elementwise/activation.cuh b/python/sglang/jit_kernel/csrc/elementwise/activation.cuh index fa129fbcba8f..95613d96b95c 100644 --- a/python/sglang/jit_kernel/csrc/elementwise/activation.cuh +++ b/python/sglang/jit_kernel/csrc/elementwise/activation.cuh @@ -3,7 +3,7 @@ #include #include -#include +#include #include #include @@ -11,57 +11,33 @@ #define kBitsToLoad 128 #define kBytesToLoad (kBitsToLoad / 8) -namespace { - -namespace detail { - -template -SGL_DEVICE float to_f32(const T& x) { -#if USE_ROCM - return castToFloat(x); -#else - return static_cast(x); -#endif -} - -template -SGL_DEVICE T from_f32(float f32) { -#if USE_ROCM - return castFromFloat(f32); -#else - return static_cast(f32); -#endif -} - -} // namespace detail - template SGL_DEVICE T silu(const T& x) { - float f32_val = detail::to_f32(x); - return detail::from_f32(f32_val / (1.0f + expf(-f32_val))); + float f32_val = device::cast(x); + return device::cast(f32_val / (1.0f + expf(-f32_val))); } template SGL_DEVICE T gelu(const T& x) { constexpr float kAlpha = M_SQRT1_2; - float f32_val = detail::to_f32(x); - return detail::from_f32(f32_val * (0.5f * (1.0f + erf(f32_val * kAlpha)))); + float f32_val = device::cast(x); + return device::cast(f32_val * (0.5f * (1.0f + erf(f32_val * kAlpha)))); } // gelu_quick(x) = x * torch.sigmoid(1.702 * x) template SGL_DEVICE T gelu_quick_act(const T& x) { - float f32_val = detail::to_f32(x); - return detail::from_f32(f32_val / (1.0f + expf(-f32_val * 1.702f))); + float f32_val = device::cast(x); + return device::cast(f32_val / (1.0f + expf(-f32_val * 1.702f))); } template SGL_DEVICE T gelu_tanh(const T& x) { constexpr float kAlpha = 0.044715f; constexpr float kBeta = 0.7978845608028654f; - float f32_val = detail::to_f32(x); + float f32_val = device::cast(x); const float cdf = 0.5f * (1.0f + tanhf((kBeta * (f32_val + kAlpha * f32_val * f32_val * f32_val)))); - return detail::from_f32(f32_val * cdf); + return device::cast(f32_val * cdf); } template @@ -203,5 +179,3 @@ using gelu_tanh_and_mul = ActivationAndMul; template using gelu_quick = ActivationOnly; - -} // namespace diff --git a/sgl-kernel/csrc/elementwise/activation.cu b/sgl-kernel/csrc/elementwise/activation.cu new file mode 100644 index 000000000000..43617f87f318 --- /dev/null +++ b/sgl-kernel/csrc/elementwise/activation.cu @@ -0,0 +1,170 @@ +/* + * Copyright (c) 2024 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +#ifndef USE_ROCM + +#include + +#include "utils.h" + +#else +#include "hip/hip_act_and_mul.cuh" +#endif + +// Adapted from flashinfer activation +// https://github.com/flashinfer-ai/flashinfer/blob/4e8eb1879f9c3ba6d75511e5893183bf8f289a62/csrc/activation.cu#L44 + +namespace detail { + +template +__device__ __forceinline__ float to_f32(const T& x) { +#if USE_ROCM + return castToFloat(x); +#else + return static_cast(x); +#endif +} + +template +__device__ __forceinline__ T from_f32(float f32) { +#if USE_ROCM + return castFromFloat(f32); +#else + return static_cast(f32); +#endif +} + +} // namespace detail + +template +__device__ __forceinline__ T silu(const T& x) { + float f32_val = detail::to_f32(x); + return detail::from_f32(f32_val / (1.0f + expf(-f32_val))); +} + +template +__device__ __forceinline__ T gelu(const T& x) { + constexpr float kAlpha = M_SQRT1_2; + float f32_val = detail::to_f32(x); + return detail::from_f32(f32_val * (0.5f * (1.0f + erf(f32_val * kAlpha)))); +} + +// gelu_quick(x) = x * torch.sigmoid(1.702 * x) +template +__device__ __forceinline__ T gelu_quick_act(const T& x) { + float f32_val = detail::to_f32(x); + return detail::from_f32(f32_val / (1.0f + expf(-f32_val * 1.702f))); +} + +template +__device__ __forceinline__ T gelu_tanh(const T& x) { + constexpr float kAlpha = 0.044715f; + constexpr float kBeta = 0.7978845608028654f; + float f32_val = detail::to_f32(x); + const float cdf = 0.5f * (1.0f + tanhf((kBeta * (f32_val + kAlpha * f32_val * f32_val * f32_val)))); + return detail::from_f32(f32_val * cdf); +} + +void silu_and_mul(at::Tensor& out, at::Tensor& input) { + int d = input.size(-1) / 2; + int64_t num_tokens = input.numel() / input.size(-1); + dim3 grid(num_tokens); + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] { + uint32_t vec_size = 16 / sizeof(c_type); + dim3 block(std::min(d / vec_size, 1024U)); +#if USE_ROCM + sgl_hip::activation::act_and_mul_kernel + <<>>(static_cast(out.data_ptr()), static_cast(input.data_ptr()), d); +#else + flashinfer::activation::act_and_mul_kernel + <<>>(static_cast(out.data_ptr()), static_cast(input.data_ptr()), d); +#endif + return true; + }); +} + +void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input) { + int d = input.size(-1) / 2; + int64_t num_tokens = input.numel() / input.size(-1); + dim3 grid(num_tokens); + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] { + uint32_t vec_size = 16 / sizeof(c_type); + dim3 block(std::min(d / vec_size, 1024U)); +#if USE_ROCM + sgl_hip::activation::act_and_mul_kernel + <<>>(static_cast(out.data_ptr()), static_cast(input.data_ptr()), d); +#else + flashinfer::activation::act_and_mul_kernel + <<>>(static_cast(out.data_ptr()), static_cast(input.data_ptr()), d); +#endif + return true; + }); +} + +void gelu_and_mul(at::Tensor& out, at::Tensor& input) { + int d = input.size(-1) / 2; + int64_t num_tokens = input.numel() / input.size(-1); + dim3 grid(num_tokens); + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] { + uint32_t vec_size = 16 / sizeof(c_type); + dim3 block(std::min(d / vec_size, 1024U)); +#if USE_ROCM + sgl_hip::activation::act_and_mul_kernel + <<>>(static_cast(out.data_ptr()), static_cast(input.data_ptr()), d); +#else + flashinfer::activation::act_and_mul_kernel + <<>>(static_cast(out.data_ptr()), static_cast(input.data_ptr()), d); +#endif + + return true; + }); +} + +#if USE_ROCM +void gelu_quick(at::Tensor& out, const at::Tensor& input) { + int d = input.size(-1); + int64_t num_tokens = input.numel() / input.size(-1); + dim3 grid(num_tokens); + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] { + uint32_t vec_size = 16 / sizeof(c_type); + dim3 block(std::min(d / vec_size, 1024U)); + sgl_hip::activation::act_only_kernel + <<>>(static_cast(out.data_ptr()), static_cast(input.data_ptr()), d); + + return true; + }); +} +#endif diff --git a/sgl-kernel/setup_rocm.py b/sgl-kernel/setup_rocm.py index 3bd5a136cbf2..66713bf0ae7d 100644 --- a/sgl-kernel/setup_rocm.py +++ b/sgl-kernel/setup_rocm.py @@ -45,6 +45,7 @@ def _get_version(): "csrc/allreduce/deterministic_all_reduce.hip", "csrc/allreduce/quick_all_reduce.cu", "csrc/common_extension_rocm.cc", + "csrc/elementwise/activation.cu", "csrc/elementwise/topk.cu", "csrc/grammar/apply_token_bitmask_inplace_cuda.cu", "csrc/moe/moe_align_kernel.cu", From ddb58603d82504ec30b1a327e9eeafed9f8065ad Mon Sep 17 00:00:00 2001 From: weiminc Date: Sun, 22 Feb 2026 09:16:10 +0000 Subject: [PATCH 6/7] keep HIP code --- 3rdparty/amd/sgl-kernel/rocm_hipify.py | 1 + 1 file changed, 1 insertion(+) diff --git a/3rdparty/amd/sgl-kernel/rocm_hipify.py b/3rdparty/amd/sgl-kernel/rocm_hipify.py index 0187ffddc058..8373f741d1d6 100644 --- a/3rdparty/amd/sgl-kernel/rocm_hipify.py +++ b/3rdparty/amd/sgl-kernel/rocm_hipify.py @@ -16,6 +16,7 @@ "csrc/allreduce/deterministic_all_reduce.hip", "csrc/allreduce/quick_all_reduce.cu", "csrc/common_extension_rocm.cc", + "csrc/elementwise/activation.cu", "csrc/elementwise/pos_enc.cu", "csrc/elementwise/topk.cu", "csrc/grammar/apply_token_bitmask_inplace_cuda.cu", From d39d1767cf022c57ad2e174c45b9a8258653353a Mon Sep 17 00:00:00 2001 From: weiminc Date: Sun, 22 Feb 2026 10:01:27 +0000 Subject: [PATCH 7/7] add _IS_ROCM guard --- python/sglang/jit_kernel/activation.py | 38 ++++++++++++++++++++------ 1 file changed, 30 insertions(+), 8 deletions(-) diff --git a/python/sglang/jit_kernel/activation.py b/python/sglang/jit_kernel/activation.py index 266a1e0d900f..89621129e49d 100644 --- a/python/sglang/jit_kernel/activation.py +++ b/python/sglang/jit_kernel/activation.py @@ -9,6 +9,8 @@ if TYPE_CHECKING: from tvm_ffi.module import Module +_IS_ROCM: bool = torch.version.hip is not None + @cache_once def _jit_activation_module(dtype: torch.dtype, act_type: str) -> Module: @@ -33,8 +35,13 @@ def silu_and_mul( device=input.device, dtype=input.dtype, ) - module = _jit_activation_module(input.dtype, "silu_and_mul") - module.silu_and_mul(out, input) + if _IS_ROCM: + import sgl_kernel + + sgl_kernel.silu_and_mul(out, input) + else: + module = _jit_activation_module(input.dtype, "silu_and_mul") + module.silu_and_mul(out, input) return out @@ -47,8 +54,13 @@ def gelu_and_mul( device=input.device, dtype=input.dtype, ) - module = _jit_activation_module(input.dtype, "gelu_and_mul") - module.gelu_and_mul(out, input) + if _IS_ROCM: + import sgl_kernel + + sgl_kernel.gelu_and_mul(out, input) + else: + module = _jit_activation_module(input.dtype, "gelu_and_mul") + module.gelu_and_mul(out, input) return out @@ -61,14 +73,24 @@ def gelu_tanh_and_mul( device=input.device, dtype=input.dtype, ) - module = _jit_activation_module(input.dtype, "gelu_tanh_and_mul") - module.gelu_tanh_and_mul(out, input) + if _IS_ROCM: + import sgl_kernel + + sgl_kernel.gelu_tanh_and_mul(out, input) + else: + module = _jit_activation_module(input.dtype, "gelu_tanh_and_mul") + module.gelu_tanh_and_mul(out, input) return out def gelu_quick(input: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor: if out is None: out = torch.empty_like(input) - module = _jit_activation_module(input.dtype, "gelu_quick") - module.gelu_quick(out, input) + if _IS_ROCM: + import sgl_kernel + + sgl_kernel.gelu_quick(out, input) + else: + module = _jit_activation_module(input.dtype, "gelu_quick") + module.gelu_quick(out, input) return out