diff --git a/sgl-kernel/csrc/cpu/CMakeLists.txt b/sgl-kernel/csrc/cpu/CMakeLists.txt
index 421fbedab84e..5b239a5a4319 100755
--- a/sgl-kernel/csrc/cpu/CMakeLists.txt
+++ b/sgl-kernel/csrc/cpu/CMakeLists.txt
@@ -24,6 +24,7 @@ include_directories(
   ${Python_INCLUDE_DIRS}
   ${CMAKE_CURRENT_SOURCE_DIR}/../../csrc
   ${CMAKE_CURRENT_SOURCE_DIR}/../../include
+  ${CMAKE_CURRENT_SOURCE_DIR}
 )
 
 # Platform-specific library directory
@@ -56,7 +57,7 @@ if(DEFINED ENV{CONDA_PREFIX})
   endif()
 endif()
 
-file(GLOB SOURCES "${CMAKE_CURRENT_SOURCE_DIR}/*.cpp")
+file(GLOB_RECURSE SOURCES "${CMAKE_CURRENT_SOURCE_DIR}/*.cpp")
 
 if(NOT DEFINED ENV{SGLANG_CPU_FP8_CVT_FTZ})
   set(ENV{SGLANG_CPU_FP8_CVT_FTZ} "1")
diff --git a/sgl-kernel/csrc/cpu/mamba/fused_gdn_gating.cpp b/sgl-kernel/csrc/cpu/mamba/fused_gdn_gating.cpp
new file mode 100644
index 000000000000..dde7d8542f48
--- /dev/null
+++ b/sgl-kernel/csrc/cpu/mamba/fused_gdn_gating.cpp
@@ -0,0 +1,88 @@
+#include "common.h"
+#include "vec.h"
+
+namespace {
+
+// Numerically stable softplus: log(1 + exp(x)), with the standard
+// large/small-x shortcuts (x for x >> 0, exp(x) for x << 0).
+inline float softplus(float x, double threshold = 20.0) {
+  if (x > threshold)
+    return x;
+  else if (x < -threshold)
+    return std::exp(x);
+  else
+    return std::log1p(std::exp(x));
+}
+
+// Vectorized softplus; blendv applies the same threshold shortcuts lane-wise.
+inline at::vec::Vectorized<float> softplus(const at::vec::Vectorized<float>& x, double threshold = 20.0) {
+  at::vec::Vectorized<float> mask_hi = x > at::vec::Vectorized<float>(threshold);
+  at::vec::Vectorized<float> mask_lo = x < at::vec::Vectorized<float>(-threshold);
+
+  at::vec::Vectorized<float> expx = x.exp();
+  at::vec::Vectorized<float> log1pex = (expx + at::vec::Vectorized<float>(1.0f)).log();
+
+  return at::vec::Vectorized<float>::blendv(at::vec::Vectorized<float>::blendv(log1pex, expx, mask_lo), x, mask_hi);
+}
+
+template <typename scalar_t>
+void fused_gdn_gating_kernel_impl(
+    float* __restrict__ A_log,
+    const scalar_t* __restrict__ a,
+    const scalar_t* __restrict__ dt_bias,
+    float* __restrict__ out,
+    int64_t batch,
+    int64_t num_heads) {
+  using bVec = at::vec::Vectorized<scalar_t>;
+  using fVec = at::vec::Vectorized<float>;
+  constexpr int vec_size = bVec::size();
+  constexpr int fvec_size = fVec::size();
+
+  fVec neg_one(-1.0f);
+  at::parallel_for(0, batch, 0, [&](int64_t begin, int64_t end) {
+    for (int64_t i = begin; i < end; ++i) {
+      int64_t j = 0;
+      // One reduced-precision (bf16/fp16) vector expands to two float vectors.
+      for (; j < num_heads - (num_heads % vec_size); j += vec_size) {
+        fVec A_log_vec0 = fVec::loadu(A_log + j);
+        fVec A_log_vec1 = fVec::loadu(A_log + j + fvec_size);
+        bVec dt_bias_vec = bVec::loadu(dt_bias + j);
+        bVec a_bvec = bVec::loadu(a + i * num_heads + j);
+        fVec a0, a1, dt_bias_vec0, dt_bias_vec1;
+        std::tie(a0, a1) = at::vec::convert_to_float(a_bvec);
+        std::tie(dt_bias_vec0, dt_bias_vec1) = at::vec::convert_to_float(dt_bias_vec);
+
+        fVec g0 = neg_one * A_log_vec0.exp() * softplus(a0 + dt_bias_vec0);
+        fVec g1 = neg_one * A_log_vec1.exp() * softplus(a1 + dt_bias_vec1);
+
+        g0.store(out + i * num_heads + j);
+        g1.store(out + i * num_heads + j + fvec_size);
+      }
+      // Scalar tail for num_heads not divisible by vec_size.
+      for (; j < num_heads; ++j) {
+        out[i * num_heads + j] = -std::exp(A_log[j]) * softplus(float(a[i * num_heads + j]) + float(dt_bias[j]));
+      }
+    }
+  });
+}
+} // anonymous namespace
+
+// A_log: [num_v_heads]
+// a: [batch, num_v_heads]
+// dt_bias: [num_v_heads]
+// -A_log.float().exp() * F.softplus(a.float() + dt_bias)
+at::Tensor fused_gdn_gating_cpu(const at::Tensor& A_log, const at::Tensor& a, const at::Tensor& dt_bias) {
+  RECORD_FUNCTION("sgl-kernel::fused_gdn_gating_cpu", std::vector<c10::IValue>({A_log, a, dt_bias}));
+  CHECK_DIM(1, A_log);
+  CHECK_DIM(2, a);
+  CHECK_DIM(1, dt_bias);
+  CHECK_CONTIGUOUS(a);
+  CHECK_EQ(A_log.size(0), a.size(1));
+  CHECK_EQ(A_log.size(0), dt_bias.size(0));
+  int batch = a.size(0);
+  int num_heads = a.size(1);
+  at::Tensor out = at::empty_like(a, a.options().dtype(at::kFloat));
+  AT_DISPATCH_REDUCED_FLOATING_TYPES(a.scalar_type(), "fused_gdn_gating_kernel", [&] {
+    fused_gdn_gating_kernel_impl<scalar_t>(
+        A_log.data_ptr<float>(),
+        a.data_ptr<scalar_t>(),
+        dt_bias.data_ptr<scalar_t>(),
+        out.data_ptr<float>(),
+        batch,
+        num_heads);
+  });
+  return out;
+}
diff --git a/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp b/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp
index 2c8d9e3ececc..1d93aa55c798 100644
--- a/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp
+++ b/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp
@@ -233,6 +233,9 @@ std::tuple<at::Tensor, at::Tensor> rotary_embedding_cpu(
 // CPU and memory binding
 std::string init_cpu_threads_env(const std::string& cpu_ids);
 
+// fused_gdn_gating
+at::Tensor fused_gdn_gating_cpu(const at::Tensor& A_log, const at::Tensor& a, const at::Tensor& dt_bias);
+
 TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   // activation
   m.def("silu_and_mul_cpu(Tensor input) -> Tensor");
@@ -363,6 +366,9 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   // CPU and memory binding
   m.def("init_cpu_threads_env(str cpu_ids) -> str");
 
+  // fused_gdn_gating
+  m.def("fused_gdn_gating_cpu(Tensor A_log, Tensor a, Tensor dt_bias) -> Tensor");
+  m.impl("fused_gdn_gating_cpu", torch::kCPU, &fused_gdn_gating_cpu);
 }
 
 TORCH_LIBRARY_IMPL(sgl_kernel, CatchAll, m) {
diff --git a/test/srt/cpu/test_mamba.py b/test/srt/cpu/test_mamba.py
new file mode 100644
index 000000000000..8abde641c24a
--- /dev/null
+++ b/test/srt/cpu/test_mamba.py
@@ -0,0 +1,33 @@
+import unittest
+
+import sgl_kernel
+import torch
+import torch.nn.functional as F
+from torch.nn.functional import softplus
+from utils import precision
+
+from sglang.test.test_utils import CustomTestCase
+
+torch.manual_seed(1234)
+
+
+def torch_gdn_gating(A_log, a, dt_bias):
+    return -A_log.float().exp() * softplus(a.float() + dt_bias)
+
+
+class TestMambaAttention(CustomTestCase):
+    def test_fused_gdn_gating(self):
+        dims = [6, 32]
+        for dim in dims:
+            A_log = torch.rand(dim)
+            a = torch.rand(1024, dim, dtype=torch.bfloat16)
+            dt_bias = torch.rand(dim, dtype=torch.bfloat16)
+
+            g = torch_gdn_gating(A_log, a, dt_bias)
+            g_sgl = torch.ops.sgl_kernel.fused_gdn_gating_cpu(A_log, a, dt_bias)
+            atol = rtol = precision[g.dtype]
+            self.assertTrue(torch.allclose(g, g_sgl, atol=atol, rtol=rtol))
+
+
+if __name__ == "__main__":
+    unittest.main()