diff --git a/sgl-kernel/csrc/cpu/common.h b/sgl-kernel/csrc/cpu/common.h
index d6e5766c83a3..1373c93fea65 100644
--- a/sgl-kernel/csrc/cpu/common.h
+++ b/sgl-kernel/csrc/cpu/common.h
@@ -128,6 +128,16 @@ namespace {
 #define CHECK_EQ(a, b) TORCH_CHECK((a) == (b), "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b)
 
+template <bool is_only_lastdim_contiguous = false>
+static inline void CHECK_INPUT_SHAPE_DTYPE(const at::Tensor& tensor, const at::IntArrayRef sizes, at::ScalarType st) {
+  TORCH_CHECK(tensor.sizes() == sizes, "Input tensor shape mismatch: expected ", sizes, ", got ", tensor.sizes());
+  TORCH_CHECK(tensor.scalar_type() == st, "Input tensor dtype mismatch");
+  if constexpr (is_only_lastdim_contiguous) {
+    CHECK_LAST_DIM_CONTIGUOUS_INPUT(tensor);
+  } else {
+    CHECK_INPUT(tensor);
+  }
+}
 #define CHECK_GE(a, b) TORCH_CHECK((a) >= (b), "CHECK_GE(" #a ", " #b ") failed. ", a, " vs ", b)
 
 // [NB] Parallel Routines
diff --git a/sgl-kernel/csrc/cpu/mamba/fla.cpp b/sgl-kernel/csrc/cpu/mamba/fla.cpp
index 9c4e5e7aecc0..dc0cdec23b0b 100644
--- a/sgl-kernel/csrc/cpu/mamba/fla.cpp
+++ b/sgl-kernel/csrc/cpu/mamba/fla.cpp
@@ -793,6 +793,234 @@ void chunk_gated_delta_rule_kernel_impl(
     }
   });
 }
+
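+// Numerically stable softplus: log1p(exp(x)) in the central range, saturating
+// to x for x > threshold and to exp(x) for x < -threshold (where
+// log1p(exp(x)) ~= exp(x)). The vectorized overload below mirrors the scalar
+// one, selecting between the three branches with blendv masks.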
+inline float softplus(float x, double threshold = 20.0) {
+  if (x > threshold)
+    return x;
+  else if (x < -threshold)
+    return std::exp(x);
+  else
+    return std::log1p(std::exp(x));
+}
+
+inline at::vec::Vectorized<float> softplus(const at::vec::Vectorized<float>& x, double threshold = 20.0) {
+  using Vec = at::vec::Vectorized<float>;
+  Vec mask_hi = x > Vec(threshold);
+  Vec mask_lo = x < Vec(-threshold);
+
+  Vec expx = x.exp_u20();
+  Vec log1pex = (expx + Vec(1.0f)).log();
+
+  return Vec::blendv(Vec::blendv(log1pex, expx, mask_lo), x, mask_hi);
+}
+
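+// Decode-path recurrence computed by the kernel below, per (batch, token,
+// value head), with a [head_dim, v_head_dim] float state S:
+//   g     = -exp(A_log) * softplus(a + dt_bias)   (log-space decay)
+//   S     = exp(g) * S
+//   kv    = sum_d k[d] * S[d, :]
+//   delta = (v - kv) * sigmoid(b)
+//   S    += outer(k, delta)
+//   o     = scale * sum_d q[d] * S[d, :]
+// When use_qk_l2norm_in_kernel is set, q and k are L2-normalized first; the
+// first parallel_for caches the reciprocal norms in qk_scale_buf (q scales in
+// the first half, k scales in the second half).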
+template <typename scalar_t>
+void fused_sigmoid_gating_delta_rule_update_kernel_impl(
+    const scalar_t* __restrict__ q_ptr,
+    const scalar_t* __restrict__ k_ptr,
+    const scalar_t* __restrict__ v_ptr,
+    const float* __restrict__ A_log_ptr,
+    const scalar_t* __restrict__ a_ptr,
+    const scalar_t* __restrict__ dt_bias_ptr,
+    const scalar_t* __restrict__ b_ptr,
+    const int32_t* __restrict__ indices_ptr,
+    float* __restrict__ state_ptr,
+    scalar_t* __restrict__ o_ptr,
+    float* __restrict__ qk_scale_buf,
+    int64_t seq_len,
+    int64_t batch_size,
+    int64_t num_heads,
+    int64_t head_dim,
+    int64_t v_num_heads,
+    int64_t v_head_dim,
+    int64_t q_strideB,
+    int64_t q_strideS,
+    int64_t q_strideH,
+    int64_t k_strideB,
+    int64_t k_strideS,
+    int64_t k_strideH,
+    int64_t v_strideB,
+    int64_t v_strideS,
+    int64_t v_strideH,
+    bool use_qk_l2norm_in_kernel,
+    double softplus_threshold) {
+  using bVec = at::vec::Vectorized<scalar_t>;
+  using fVec = at::vec::Vectorized<float>;
+
+  constexpr int64_t VecSize = bVec::size();
+  constexpr int64_t fVecSize = fVec::size();
+  int64_t group_size = v_num_heads / num_heads;
+  double scale = 1 / std::sqrt(head_dim);
+  fVec scale_vec = fVec(scale);
+  // Pass 1: per-(batch, token, head) L2 norms of q and k, cached as reciprocal scales.
+  if (use_qk_l2norm_in_kernel) {
+    float eps = 1e-5;
+    at::parallel_for(0, batch_size * seq_len * num_heads, 0, [&](int64_t begin, int64_t end) {
+      int64_t bi{0}, si{0}, ni{0};
+      data_index_init(begin, bi, batch_size, si, seq_len, ni, num_heads);
+      for (int64_t i = begin; i < end; ++i) {
+        float sum_q = float(0);
+        float sum_k = float(0);
+        fVec sum_q_fvec = fVec(float(0));
+        fVec sum_k_fvec = fVec(float(0));
+        int64_t q_offset = bi * q_strideB + si * q_strideS + ni * q_strideH;
+        int64_t k_offset = bi * k_strideB + si * k_strideS + ni * k_strideH;
+        int64_t q_scale_offset = bi * seq_len * num_heads + si * num_heads + ni;
+        int64_t k_scale_offset = q_scale_offset + batch_size * seq_len * num_heads;
+        int64_t d;
+#pragma GCC unroll 4
+        for (d = 0; d <= head_dim - VecSize; d += VecSize) {
+          bVec q_bvec = bVec::loadu(q_ptr + q_offset + d);
+          fVec q_fvec0, q_fvec1;
+          std::tie(q_fvec0, q_fvec1) = at::vec::convert_to_float(q_bvec);
+          sum_q_fvec += q_fvec0 * q_fvec0;
+          sum_q_fvec += q_fvec1 * q_fvec1;
+          bVec k_bvec = bVec::loadu(k_ptr + k_offset + d);
+          fVec k_fvec0, k_fvec1;
+          std::tie(k_fvec0, k_fvec1) = at::vec::convert_to_float(k_bvec);
+          sum_k_fvec += k_fvec0 * k_fvec0;
+          sum_k_fvec += k_fvec1 * k_fvec1;
+        }
+#pragma GCC unroll 4
+        for (; d < head_dim; ++d) {
+          float q_val = static_cast<float>(q_ptr[q_offset + d]);
+          sum_q += q_val * q_val;
+          float k_val = static_cast<float>(k_ptr[k_offset + d]);
+          sum_k += k_val * k_val;
+        }
+
+        sum_q += vec_reduce_sum(sum_q_fvec);
+        sum_k += vec_reduce_sum(sum_k_fvec);
+        qk_scale_buf[q_scale_offset] = float(1) / std::sqrt(sum_q + eps);
+        qk_scale_buf[k_scale_offset] = float(1) / std::sqrt(sum_k + eps);
+
+        data_index_step(bi, batch_size, si, seq_len, ni, num_heads);
+      }
+    });
+  }
+  // Pass 2: gated delta-rule state update and output projection.
+  at::parallel_for(0, batch_size * seq_len * v_num_heads, 0, [&](int64_t begin, int64_t end) {
+    int64_t bi{0}, si{0}, ni{0};
+    data_index_init(begin, bi, batch_size, si, seq_len, ni, v_num_heads);
+    for (int64_t i = begin; i < end; ++i) {
+      int64_t cache_index = indices_ptr[bi];
+      int64_t state_offset = (cache_index * v_num_heads + ni) * head_dim * v_head_dim;
+      float g_val = -std::exp(A_log_ptr[ni]) *
+          softplus(float(a_ptr[bi * v_num_heads + ni]) + float(dt_bias_ptr[ni]), softplus_threshold);
+      float g_val_exp = std::exp(g_val);
+      fVec g_val_exp_vec = fVec(g_val_exp);
+      int64_t q_offset = si * q_strideS + bi * q_strideB + (ni / group_size) * q_strideH;
+      int64_t k_offset = si * k_strideS + bi * k_strideB + (ni / group_size) * k_strideH;
+      int64_t q_scale_offset = bi * seq_len * num_heads + si * num_heads + (ni / group_size);
+      int64_t k_scale_offset = q_scale_offset + batch_size * seq_len * num_heads;
+      float q_scale = use_qk_l2norm_in_kernel ? qk_scale_buf[q_scale_offset] : 1.0f;
+      float k_scale = use_qk_l2norm_in_kernel ? qk_scale_buf[k_scale_offset] : 1.0f;
+      int64_t v_offset = si * v_strideS + bi * v_strideB + ni * v_strideH;
+      int64_t o_offset = ((bi * seq_len + si) * v_num_heads + ni) * v_head_dim;
+      float beta_val = 1 / (1 + std::exp(-float(b_ptr[bi * v_num_heads + ni])));
+      fVec beta_vec = fVec(beta_val);
+      int64_t dvi = 0;
+      for (; dvi <= v_head_dim - VecSize; dvi += VecSize) {
+        fVec kv_mem_vec0 = fVec(float(0));
+        fVec kv_mem_vec1 = fVec(float(0));
+        for (int di = 0; di < head_dim; ++di) {
+          fVec k_val_vec = fVec(k_ptr[k_offset + di] * k_scale);
+          fVec state_vec0 = fVec::loadu(state_ptr + state_offset + di * v_head_dim + dvi);
+          fVec state_vec1 = fVec::loadu(state_ptr + state_offset + di * v_head_dim + dvi + fVecSize);
+          kv_mem_vec0 = kv_mem_vec0 + state_vec0 * g_val_exp_vec * k_val_vec;
+          kv_mem_vec1 = kv_mem_vec1 + state_vec1 * g_val_exp_vec * k_val_vec;
+        }
+        bVec v_bvec = bVec::loadu(v_ptr + v_offset + dvi);
+        fVec v_vec0, v_vec1;
+        std::tie(v_vec0, v_vec1) = at::vec::convert_to_float(v_bvec);
+        fVec dt_vec0 = (v_vec0 - kv_mem_vec0) * beta_vec;
+        fVec dt_vec1 = (v_vec1 - kv_mem_vec1) * beta_vec;
+        fVec o_vec0 = fVec(float(0));
+        fVec o_vec1 = fVec(float(0));
+        for (int di = 0; di < head_dim; ++di) {
+          fVec q_vec = fVec(q_ptr[q_offset + di] * q_scale);
+          fVec k_vec = fVec(k_ptr[k_offset + di] * k_scale);
+          fVec state_vec0 = fVec::loadu(state_ptr + state_offset + di * v_head_dim + dvi);
+          fVec state_vec1 = fVec::loadu(state_ptr + state_offset + di * v_head_dim + dvi + fVecSize);
+          state_vec0 = state_vec0 * g_val_exp_vec + k_vec * dt_vec0;
+          state_vec1 = state_vec1 * g_val_exp_vec + k_vec * dt_vec1;
+          o_vec0 = o_vec0 + state_vec0 * q_vec * scale_vec;
+          o_vec1 = o_vec1 + state_vec1 * q_vec * scale_vec;
+          state_vec0.store(state_ptr + state_offset + di * v_head_dim + dvi);
+          state_vec1.store(state_ptr + state_offset + di * v_head_dim + dvi + fVecSize);
+        }
+        bVec o_vec = at::vec::convert_from_float<scalar_t>(o_vec0, o_vec1);
+        o_vec.store(o_ptr + o_offset + dvi);
+      }
+      // Scalar tail; here the decay is folded into the state in place.
+      for (; dvi < v_head_dim; ++dvi) {
+        float kv_mem_val = 0;
+        for (int di = 0; di < head_dim; ++di) {
+          float k_val = k_ptr[k_offset + di] * k_scale;
+          state_ptr[state_offset + di * v_head_dim + dvi] *= g_val_exp;
+          kv_mem_val += state_ptr[state_offset + di * v_head_dim + dvi] * k_val;
+        }
+        float v_val = v_ptr[v_offset + dvi];
+        float dt_val = (v_val - kv_mem_val) * beta_val;
+        float o_val = 0;
+        for (int di = 0; di < head_dim; ++di) {
+          float q_val = q_ptr[q_offset + di] * q_scale;
+          float k_val = k_ptr[k_offset + di] * k_scale;
+          state_ptr[state_offset + di * v_head_dim + dvi] += k_val * dt_val;
+          o_val += state_ptr[state_offset + di * v_head_dim + dvi] * q_val * scale;
+        }
+        o_ptr[o_offset + dvi] = o_val;
+      }
+      data_index_step(bi, batch_size, si, seq_len, ni, v_num_heads);
+    }
+  });
+}
+
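+// Fused GDN gating: g = -exp(A_log) * softplus(a + dt_bias), beta = sigmoid(b).
+// Heads are processed vec_size at a time, widening each reduced-precision
+// (bf16/fp16) vector into two float vectors; a scalar loop handles the tail.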
+template <typename scalar_t>
+void fused_gdn_gating_kernel_impl(
+    float* __restrict__ A_log,
+    const scalar_t* __restrict__ a,
+    const scalar_t* __restrict__ b,
+    const scalar_t* __restrict__ dt_bias,
+    float* __restrict__ out,
+    scalar_t* __restrict__ beta,
+    int64_t batch,
+    int64_t num_heads) {
+  using bVec = at::vec::Vectorized<scalar_t>;
+  using fVec = at::vec::Vectorized<float>;
+  constexpr int vec_size = bVec::size();
+  constexpr int fvec_size = fVec::size();
+  const fVec neg_one(-1.0f);
+  const fVec one(1.0f);
+  at::parallel_for(0, batch, 0, [&](int64_t begin, int64_t end) {
+    for (int64_t i = begin; i < end; ++i) {
+      int64_t j = 0;
+      for (; j < num_heads - (num_heads % vec_size); j += vec_size) {
+        fVec A_log_vec0 = fVec::loadu(A_log + j);
+        fVec A_log_vec1 = fVec::loadu(A_log + j + fvec_size);
+        bVec dt_bias_vec = bVec::loadu(dt_bias + j);
+        bVec a_bvec = bVec::loadu(a + i * num_heads + j);
+        bVec b_bvec = bVec::loadu(b + i * num_heads + j);
+        fVec a0, a1, dt_bias_vec0, dt_bias_vec1, b0, b1;
+        std::tie(a0, a1) = at::vec::convert_to_float(a_bvec);
+        std::tie(b0, b1) = at::vec::convert_to_float(b_bvec);
+        std::tie(dt_bias_vec0, dt_bias_vec1) = at::vec::convert_to_float(dt_bias_vec);
+
+        fVec g0 = neg_one * A_log_vec0.exp_u20() * softplus(a0 + dt_bias_vec0);
+        fVec g1 = neg_one * A_log_vec1.exp_u20() * softplus(a1 + dt_bias_vec1);
+        fVec beta0 = one / (one + (neg_one * b0).exp_u20());
+        fVec beta1 = one / (one + (neg_one * b1).exp_u20());
+
+        g0.store(out + i * num_heads + j);
+        g1.store(out + i * num_heads + j + fvec_size);
+        bVec beta_vec = at::vec::convert_from_float<scalar_t>(beta0, beta1);
+        beta_vec.store(beta + i * num_heads + j);
+      }
+      for (; j < num_heads; ++j) {
+        out[i * num_heads + j] = -std::exp(A_log[j]) * softplus(float(a[i * num_heads + j]) + float(dt_bias[j]));
+        beta[i * num_heads + j] = 1 / (1 + std::exp(-b[i * num_heads + j]));
+      }
+    }
+  });
+}
+
 } // anonymous namespace
 
 template <typename scalar_t>
@@ -981,3 +1209,133 @@ std::tuple<at::Tensor, at::Tensor> chunk_gated_delta_rule_cpu(
   });
   return std::make_tuple(std::move(output), std::move(final_state));
 }
+
+// A_log: [v_num_heads]
+// dt_bias: [v_num_heads]
+// query: [seq_len, batch_size, num_heads, head_dim]
+// key: [seq_len, batch_size, num_heads, head_dim]
+// value: [seq_len, batch_size, v_num_heads, v_head_dim]
+// a: [batch_size, v_num_heads]
+// b: [batch_size, v_num_heads]
+// initial_state_source: [num_tokens, v_num_heads, head_dim, v_head_dim]
+// initial_state_indices: [batch_size]
+// cu_seqlens: [batch_size + 1]
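+// out: [batch_size, seq_len, v_num_heads, v_head_dim]
+// Note the output is batch-major while query/key/value are seq-major.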
+at::Tensor fused_sigmoid_gating_delta_rule_update_cpu(
+    const at::Tensor& A_log,
+    const at::Tensor& dt_bias,
+    const at::Tensor& q,
+    const at::Tensor& k,
+    const at::Tensor& v,
+    const at::Tensor& a,
+    const at::Tensor& b,
+    at::Tensor& initial_state_source,
+    const at::Tensor& initial_state_indices,
+    const at::Tensor& cu_seqlens,
+    bool use_qk_l2norm_in_kernel,
+    double softplus_beta = 1.0,
+    double softplus_threshold = 20.0) {
+  // NOTE: softplus_beta is part of the op schema but is currently unused by
+  // this CPU kernel (the softplus is evaluated with beta == 1).
+  RECORD_FUNCTION(
+      "sgl-kernel::fused_sigmoid_gating_delta_rule_update_cpu",
+      std::vector<c10::IValue>(
+          {A_log, dt_bias, q, k, v, a, b, initial_state_source, initial_state_indices, cu_seqlens}));
+  CHECK_DIM(4, q);
+  CHECK_DIM(4, v);
+  CHECK_LAST_DIM_CONTIGUOUS_INPUT(q);
+  int64_t seq_len = q.size(0);
+  int64_t batch_size = q.size(1);
+  int64_t num_heads = q.size(2);
+  int64_t head_dim = q.size(3);
+  int64_t v_num_heads = v.size(2);
+  int64_t v_head_dim = v.size(3);
+  CHECK_INPUT_SHAPE_DTYPE<true>(k, {seq_len, batch_size, num_heads, head_dim}, q.scalar_type());
+  CHECK_INPUT_SHAPE_DTYPE<true>(v, {seq_len, batch_size, v_num_heads, v_head_dim}, q.scalar_type());
+  CHECK_INPUT_SHAPE_DTYPE(A_log, {v_num_heads}, at::kFloat);
+  CHECK_INPUT_SHAPE_DTYPE(a, {batch_size, v_num_heads}, q.scalar_type());
+  CHECK_INPUT_SHAPE_DTYPE(dt_bias, {v_num_heads}, q.scalar_type());
+  CHECK_INPUT_SHAPE_DTYPE(b, {batch_size, v_num_heads}, q.scalar_type());
+  CHECK_INPUT_SHAPE_DTYPE(initial_state_indices, {batch_size}, at::kInt);
+  CHECK_INPUT_SHAPE_DTYPE(cu_seqlens, {batch_size + 1}, at::kInt);
+  CHECK_INPUT_SHAPE_DTYPE(
+      initial_state_source, {initial_state_source.size(0), v_num_heads, head_dim, v_head_dim}, at::kFloat);
+  CHECK_GE(initial_state_source.size(0), batch_size);
+  CHECK_EQ(v_num_heads % num_heads, 0);
+
+  int64_t q_strideB = q.stride(1);
+  int64_t q_strideS = q.stride(0);
+  int64_t q_strideH = q.stride(2);
+  int64_t k_strideB = k.stride(1);
+  int64_t k_strideS = k.stride(0);
+  int64_t k_strideH = k.stride(2);
+  int64_t v_strideB = v.stride(1);
+  int64_t v_strideS = v.stride(0);
+  int64_t v_strideH = v.stride(2);
+  at::Tensor core_attn_out = at::empty({batch_size, seq_len, v_num_heads, v_head_dim}, q.options());
+  at::Tensor qk_scale_buf = at::empty({2 * batch_size, seq_len, num_heads}, at::kFloat);
+  AT_DISPATCH_REDUCED_FLOATING_TYPES(q.scalar_type(), "fused_sigmoid_gating_delta_rule_update_kernel_impl", [&] {
+    fused_sigmoid_gating_delta_rule_update_kernel_impl<scalar_t>(
+        q.data_ptr<scalar_t>(),
+        k.data_ptr<scalar_t>(),
+        v.data_ptr<scalar_t>(),
+        A_log.data_ptr<float>(),
+        a.data_ptr<scalar_t>(),
+        dt_bias.data_ptr<scalar_t>(),
+        b.data_ptr<scalar_t>(),
+        initial_state_indices.data_ptr<int32_t>(),
+        initial_state_source.data_ptr<float>(),
+        core_attn_out.data_ptr<scalar_t>(),
+        qk_scale_buf.data_ptr<float>(),
+        seq_len,
+        batch_size,
+        num_heads,
+        head_dim,
+        v_num_heads,
+        v_head_dim,
+        q_strideB,
+        q_strideS,
+        q_strideH,
+        k_strideB,
+        k_strideS,
+        k_strideH,
+        v_strideB,
+        v_strideS,
+        v_strideH,
+        use_qk_l2norm_in_kernel,
+        softplus_threshold);
+  });
+  return core_attn_out;
+}
+
+// A_log: [num_v_heads]
+// a: [batch, num_v_heads]
+// b: [batch, num_v_heads]
+// dt_bias: [num_v_heads]
+// Computes g = -A_log.float().exp() * F.softplus(a.float() + dt_bias) and beta = b.sigmoid().
+std::tuple<at::Tensor, at::Tensor>
+fused_gdn_gating_cpu(const at::Tensor& A_log, const at::Tensor& a, const at::Tensor& b, const at::Tensor& dt_bias) {
+  RECORD_FUNCTION("sgl-kernel::fused_gdn_gating_cpu", std::vector<c10::IValue>({A_log, a, b, dt_bias}));
+  CHECK_DIM(1, A_log);
+  CHECK_DIM(2, a);
+  CHECK_DIM(2, b);
+  CHECK_DIM(1, dt_bias);
+  CHECK_CONTIGUOUS(a);
+  CHECK_EQ(A_log.size(0), a.size(1));
+  CHECK_EQ(A_log.size(0), dt_bias.size(0));
+  int batch = a.size(0);
+  int num_heads = a.size(1);
+  CHECK_EQ(b.size(0), batch);
+  CHECK_EQ(b.size(1), num_heads);
+  at::Tensor out = at::empty({1, batch, num_heads}, a.options().dtype(at::kFloat));
+  at::Tensor beta = at::empty({1, batch, num_heads}, b.options());
+  AT_DISPATCH_REDUCED_FLOATING_TYPES(a.scalar_type(), "fused_gdn_gating_kernel", [&] {
+    fused_gdn_gating_kernel_impl<scalar_t>(
+        A_log.data_ptr<float>(),
+        a.data_ptr<scalar_t>(),
+        b.data_ptr<scalar_t>(),
+        dt_bias.data_ptr<scalar_t>(),
+        out.data_ptr<float>(),
+        beta.data_ptr<scalar_t>(),
+        batch,
+        num_heads);
+  });
+  return std::make_tuple(out, beta);
+}
diff --git a/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp b/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp
index b39850a696a7..aede839fa2fe 100644
--- a/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp
+++ b/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp
@@ -284,6 +284,25 @@ std::tuple<at::Tensor, at::Tensor> rotary_embedding_cpu(
 // CPU and memory binding
 std::string init_cpu_threads_env(const std::string& cpu_ids);
 
+// fused_sigmoid_gating_delta_rule_update
+at::Tensor fused_sigmoid_gating_delta_rule_update_cpu(
+    const at::Tensor& A_log,
+    const at::Tensor& dt_bias,
+    const at::Tensor& q,
+    const at::Tensor& k,
+    const at::Tensor& v,
+    const at::Tensor& a,
+    const at::Tensor& b,
+    at::Tensor& initial_state_source,
+    const at::Tensor& initial_state_indices,
+    const at::Tensor& cu_seqlens,
+    bool use_qk_l2norm_in_kernel,
+    double softplus_beta = 1.0,
+    double softplus_threshold = 20.0);
+
+// fused_gdn_gating
+std::tuple<at::Tensor, at::Tensor>
+fused_gdn_gating_cpu(const at::Tensor& A_log, const at::Tensor& a, const at::Tensor& b, const at::Tensor& dt_bias);
+
 // fused_qkvzba_split_reshape_cat_cpu
 std::tuple fused_qkvzba_split_reshape_cat_cpu(
     const at::Tensor& mixed_qkvz,
@@ -453,6 +472,15 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   // CPU and memory binding
   m.def("init_cpu_threads_env(str cpu_ids) -> str");
 
+  // fused_sigmoid_gating_delta_rule_update
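+  // initial_state_source is annotated Tensor(a!) in the schema below: the op
+  // mutates the recurrent state cache in place.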
"fused_sigmoid_gating_delta_rule_update_cpu(Tensor A_log, Tensor dt_bias, Tensor q, Tensor k, Tensor v, Tensor " + "a, Tensor b, Tensor(a!) initial_state_source, Tensor initial_state_indices, Tensor cu_seqlens, bool " + "use_qk_l2norm_in_kernel, float softplus_beta=1.0, float softplus_threshold=20.0) -> Tensor"); + m.impl("fused_sigmoid_gating_delta_rule_update_cpu", torch::kCPU, &fused_sigmoid_gating_delta_rule_update_cpu); + // fused_gdn_gating + m.def("fused_gdn_gating_cpu(Tensor A_log, Tensor a, Tensor b, Tensor dt_bias) -> (Tensor, Tensor)"); + m.impl("fused_gdn_gating_cpu", torch::kCPU, &fused_gdn_gating_cpu); // fused_qkvzba_split_reshape_cat_cpu m.def( "fused_qkvzba_split_reshape_cat_cpu(Tensor mixed_qkvz, Tensor mixed_ba, int num_heads_qk, int num_heads_v, int " diff --git a/test/srt/cpu/test_mamba.py b/test/srt/cpu/test_mamba.py index 707f76f54102..f64c3713aba2 100644 --- a/test/srt/cpu/test_mamba.py +++ b/test/srt/cpu/test_mamba.py @@ -2,6 +2,7 @@ import torch import torch.nn.functional as F +from torch.nn.functional import softplus from utils import precision from sglang.test.test_utils import CustomTestCase @@ -145,6 +146,92 @@ def chunk_gated_delta_rule_update( return output, final_state +def torch_recurrent_gated_delta_rule( + query, + key, + value, + g, + beta, + initial_state, + output_final_state, + use_qk_l2norm_in_kernel=False, +): + initial_dtype = query.dtype + if use_qk_l2norm_in_kernel: + query = l2norm(query, dim=-1, eps=1e-6) + key = l2norm(key, dim=-1, eps=1e-6) + query, key, value, beta, g = [ + x.transpose(1, 2).contiguous().to(torch.float32) + for x in (query, key, value, beta, g) + ] + + batch_size, num_heads, sequence_length, k_head_dim = key.shape + v_head_dim = value.shape[-1] + scale = 1 / (query.shape[-1] ** 0.5) + query = query * scale + + core_attn_out = torch.zeros(batch_size, num_heads, sequence_length, v_head_dim).to( + value + ) + last_recurrent_state = ( + torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value) + if initial_state is None + else initial_state.to(value) + ) + + for i in range(sequence_length): + q_t = query[:, :, i] + k_t = key[:, :, i] + v_t = value[:, :, i] + g_t = g[:, :, i].exp().unsqueeze(-1).unsqueeze(-1) + beta_t = beta[:, :, i].unsqueeze(-1) + + last_recurrent_state = last_recurrent_state * g_t + kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2) + delta = (v_t - kv_mem) * beta_t + last_recurrent_state = last_recurrent_state + k_t.unsqueeze( + -1 + ) * delta.unsqueeze(-2) + core_attn_out[:, :, i] = (last_recurrent_state * q_t.unsqueeze(-1)).sum(dim=-2) + + if not output_final_state: + last_recurrent_state = None + core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype) + return core_attn_out, last_recurrent_state + + +def sigmoid_gating_delta_rule_update( + query, + key, + value, + A_log, + a, + dt_bias, + b, + initial_state, + output_final_state, + use_qk_l2norm_in_kernel=False, +): + beta = b.sigmoid() + g = -A_log.float().exp() * softplus(a.float() + dt_bias) + return torch_recurrent_gated_delta_rule( + query, + key, + value, + g.unsqueeze(0), + beta.unsqueeze(0), + initial_state, + output_final_state, + use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, + ) + + +def torch_gdn_gating(A_log, a, b, dt_bias): + return -A_log.float().exp() * softplus(a.float() + dt_bias).unsqueeze( + 0 + ), b.sigmoid().unsqueeze(0) + + class TestMambaAttention(CustomTestCase): def test_chunk_gated_delta_rule(self): B, L, HK, HV, EK, EV, N = 1, 100, 3, 6, 64, 64, 4 @@ -201,6 +288,100 @@ def 
             last_recurrent_state, last_recurrent_state_ref, atol=atol, rtol=rtol
         )
 
+    def test_fused_gdn_gating(self):
+        dims = [6, 32]
+        for dim in dims:
+            A_log = torch.rand(dim)
+            a = torch.rand(1024, dim, dtype=torch.bfloat16)
+            b = torch.rand(1024, dim, dtype=torch.bfloat16)
+            dt_bias = torch.rand(dim, dtype=torch.bfloat16)
+
+            g, beta = torch_gdn_gating(A_log, a, b, dt_bias)
+            g_sgl, beta_sgl = torch.ops.sgl_kernel.fused_gdn_gating_cpu(
+                A_log, a, b, dt_bias
+            )
+            atol = rtol = precision[g.dtype]
+            atol2 = rtol2 = precision[beta.dtype]
+            torch.testing.assert_close(g, g_sgl, atol=atol, rtol=rtol)
+            torch.testing.assert_close(beta, beta_sgl, atol=atol2, rtol=rtol2)
+
+    def test_fused_sigmoid_gating_delta_rule_update(self):
+        batch_size = 1
+        num_value_heads = 32
+        head_k_dim = 128
+        head_v_dim = 128
+        num_heads = 16
+        seq_len = 1
+        attn_tp_size = 1
+        key_dim = head_k_dim * num_heads
+        value_dim = head_v_dim * num_value_heads
+        mixed_qkv_dim = (key_dim * 2 + value_dim) // attn_tp_size
+        mixed_qkv = torch.rand(
+            seq_len * batch_size, mixed_qkv_dim, dtype=torch.bfloat16
+        )
+        query, key, value = torch.split(
+            mixed_qkv,
+            [
+                key_dim // attn_tp_size,
+                key_dim // attn_tp_size,
+                value_dim // attn_tp_size,
+            ],
+            dim=-1,
+        )
+        query = query.view(1, seq_len, num_heads, head_k_dim)
+        key = key.view(1, seq_len, num_heads, head_k_dim)
+        value = value.view(1, seq_len, num_value_heads, head_v_dim)
+        A_log = torch.rand(num_value_heads, dtype=torch.float32)
+        a = torch.rand(batch_size, num_value_heads, dtype=torch.bfloat16)
+        b = torch.rand(batch_size, num_value_heads, dtype=torch.bfloat16)
+        dt_bias = torch.rand(num_value_heads, dtype=torch.bfloat16)
+        ssm_states = torch.rand(
+            513, num_value_heads, head_k_dim, head_v_dim, dtype=torch.float32
+        )
+        cache_indices = torch.randint(0, 513, (batch_size,), dtype=torch.int32)
+        query_start_loc = torch.tensor([0, 1], dtype=torch.int32)
+        use_qk_l2norm_in_kernel = True
+        query_ref = query.clone()
+        key_ref = key.clone()
+        if num_value_heads // num_heads > 1:
+            query_ref = query_ref.repeat_interleave(num_value_heads // num_heads, dim=2)
+            key_ref = key_ref.repeat_interleave(num_value_heads // num_heads, dim=2)
+        # Reference runs first: ssm_states[cache_indices] copies the pristine
+        # state, while the fused op below updates ssm_states in place.
+        core_attn_out_ref, last_recurrent_state_ref = sigmoid_gating_delta_rule_update(
+            query_ref.transpose(0, 1),
+            key_ref.transpose(0, 1),
+            value.transpose(0, 1),
+            A_log,
+            a,
+            dt_bias,
+            b,
+            initial_state=ssm_states[cache_indices],
+            output_final_state=True,
+            use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
+        )
+        core_attn_out = torch.ops.sgl_kernel.fused_sigmoid_gating_delta_rule_update_cpu(
+            A_log=A_log,
+            dt_bias=dt_bias,
+            q=query,
+            k=key,
+            v=value,
+            a=a,
+            b=b,
+            initial_state_source=ssm_states,
+            initial_state_indices=cache_indices,
+            cu_seqlens=query_start_loc,
+            use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
+            softplus_beta=1.0,
+            softplus_threshold=20.0,
+        )
+        last_recurrent_state = ssm_states[cache_indices]
+        atol = rtol = precision[core_attn_out.dtype]
+        torch.testing.assert_close(
+            core_attn_out, core_attn_out_ref, atol=atol, rtol=rtol
+        )
+        torch.testing.assert_close(
+            last_recurrent_state, last_recurrent_state_ref, atol=atol, rtol=rtol
+        )
+
 
 if __name__ == "__main__":
     unittest.main()