diff --git a/sgl-kernel/csrc/cpu/model/qwen3.cpp b/sgl-kernel/csrc/cpu/model/qwen3.cpp
new file mode 100644
index 000000000000..3a2ce6d6a3a1
--- /dev/null
+++ b/sgl-kernel/csrc/cpu/model/qwen3.cpp
@@ -0,0 +1,127 @@
+#include "common.h"
+#include "vec.h"
+
+namespace {
+
+// Vectorized contiguous copy of `size` elements from `src` to `out`.
+template <typename scalar_t>
+inline void copy_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ src, int64_t size) {
+  using bVec = at::vec::Vectorized<scalar_t>;
+  constexpr int kVecSize = bVec::size();
+  int64_t d = 0;
+#pragma GCC unroll 4
+  for (; d <= size - kVecSize; d += kVecSize) {
+    bVec out_bvec = bVec::loadu(src + d);
+    out_bvec.store(out + d);
+  }
+  // scalar tail for the remainder that does not fill a full vector lane
+  for (; d < size; ++d) {
+    out[d] = src[d];
+  }
+}
+
+// Split mixed_qkvz / mixed_ba and concatenate q/k/v into mixed_qkv.
+//
+// Input layout of mixed_qkvz per qk-head:
+//   [q(head_qk), k(head_qk), v(group * head_v), z(group * head_v)]
+// Output layout of mixed_qkv per batch row:
+//   [q(num_heads_qk * head_qk), k(num_heads_qk * head_qk), v(num_heads_v * head_v)]
+template <typename scalar_t>
+void fused_qkvzba_split_reshape_cat_impl(
+    const scalar_t* __restrict__ mixed_qkvz,
+    const scalar_t* __restrict__ mixed_ba,
+    scalar_t* __restrict__ mixed_qkv,
+    scalar_t* __restrict__ z,
+    scalar_t* __restrict__ b,
+    scalar_t* __restrict__ a,
+    int64_t batch,
+    int64_t num_heads_qk,
+    int64_t num_heads_v,
+    int64_t head_qk,
+    int64_t group,
+    int64_t head_v,
+    int64_t qkv_strideB,
+    int64_t qkvz_strideB,
+    int64_t ba_strideB) {
+  int64_t qkvz_stride_per_head = head_qk * 2 + head_v * 2 * group;
+  at::parallel_for(0, batch * num_heads_qk, 0, [&](int64_t begin, int64_t end) {
+    int64_t bi{0}, hi{0};
+    data_index_init(begin, bi, batch, hi, num_heads_qk);
+    for (int64_t i = begin; i < end; ++i) {
+      scalar_t* __restrict__ q_out_ptr = mixed_qkv + bi * qkv_strideB + hi * head_qk;
+      const scalar_t* __restrict__ q_in_ptr = mixed_qkvz + bi * qkvz_strideB + hi * qkvz_stride_per_head;
+      scalar_t* __restrict__ k_out_ptr = q_out_ptr + num_heads_qk * head_qk;
+      const scalar_t* __restrict__ k_in_ptr = q_in_ptr + head_qk;
+      // v rows are head_v wide (not head_qk), so offsets/sizes below must use head_v
+      scalar_t* __restrict__ v_out_ptr =
+          mixed_qkv + bi * qkv_strideB + 2 * num_heads_qk * head_qk + hi * group * head_v;
+      const scalar_t* __restrict__ v_in_ptr = k_in_ptr + head_qk;
+      scalar_t* __restrict__ z_out_ptr = z + bi * num_heads_v * head_v + hi * group * head_v;
+      const scalar_t* __restrict__ z_in_ptr = v_in_ptr + head_v * group;
+      copy_stub(q_out_ptr, q_in_ptr, head_qk);
+      copy_stub(k_out_ptr, k_in_ptr, head_qk);
+      copy_stub(v_out_ptr, v_in_ptr, head_v * group);
+      copy_stub(z_out_ptr, z_in_ptr, head_v * group);
+      scalar_t* __restrict__ b_out_ptr = b + bi * num_heads_v + hi * group;
+      const scalar_t* __restrict__ b_in_ptr = mixed_ba + bi * ba_strideB + hi * group * 2;
+      scalar_t* __restrict__ a_out_ptr = a + bi * num_heads_v + hi * group;
+      const scalar_t* __restrict__ a_in_ptr = b_in_ptr + group;
+      copy_stub(b_out_ptr, b_in_ptr, group);
+      copy_stub(a_out_ptr, a_in_ptr, group);
+      data_index_step(bi, batch, hi, num_heads_qk);
+    }
+  });
+}
+
+}  // anonymous namespace
+
+// mixed_qkvz: [batch, num_heads_qk * head_qk * 2 + num_heads_v * head_v * 2]
+// mixed_ba: [batch, num_heads_v * 2]
+std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> fused_qkvzba_split_reshape_cat_cpu(
+    const at::Tensor& mixed_qkvz,
+    const at::Tensor& mixed_ba,
+    int64_t num_heads_qk,
+    int64_t num_heads_v,
+    int64_t head_qk,
+    int64_t head_v) {
+  RECORD_FUNCTION("sgl-kernel::fused_qkvzba_split_reshape_cat_cpu", std::vector<c10::IValue>({mixed_qkvz, mixed_ba}));
+  CHECK_DIM(2, mixed_qkvz);
+  CHECK_DIM(2, mixed_ba);
+  CHECK_INPUT(mixed_qkvz);
+  CHECK_INPUT(mixed_ba);
+  int64_t batch = mixed_qkvz.size(0);
+  int64_t qkv_dim = num_heads_qk * head_qk * 2 + num_heads_v * head_v;
+  int64_t ba_dim = num_heads_v * 2;
+  int64_t expected_dim = qkv_dim + num_heads_v * head_v;
+  CHECK_EQ(mixed_qkvz.size(1), expected_dim);
+  CHECK_EQ(mixed_ba.size(0), batch);
+  CHECK_EQ(mixed_ba.size(1), ba_dim);
+  CHECK_EQ(num_heads_v % num_heads_qk, 0);
+  at::Tensor mixed_qkv = at::empty({batch, qkv_dim}, mixed_qkvz.options());
+  at::Tensor z = at::empty({batch, num_heads_v, head_v}, mixed_qkvz.options());
+  at::Tensor b = at::empty({batch, num_heads_v}, mixed_ba.options());
+  at::Tensor a = at::empty({batch, num_heads_v}, mixed_ba.options());
+  int64_t group = num_heads_v / num_heads_qk;
+  int64_t qkvz_strideB = mixed_qkvz.size(1);
+  int64_t qkv_strideB = mixed_qkv.size(1);
+  int64_t ba_strideB = mixed_ba.size(1);
+  AT_DISPATCH_REDUCED_FLOATING_TYPES(mixed_qkvz.scalar_type(), "fused_qkvzba_split_reshape_cat_impl", [&] {
+    fused_qkvzba_split_reshape_cat_impl<scalar_t>(
+        mixed_qkvz.data_ptr<scalar_t>(),
+        mixed_ba.data_ptr<scalar_t>(),
+        mixed_qkv.data_ptr<scalar_t>(),
+        z.data_ptr<scalar_t>(),
+        b.data_ptr<scalar_t>(),
+        a.data_ptr<scalar_t>(),
+        batch,
+        num_heads_qk,
+        num_heads_v,
+        head_qk,
+        group,
+        head_v,
+        qkv_strideB,
+        qkvz_strideB,
+        ba_strideB);
+  });
+  return std::make_tuple(mixed_qkv, z, b, a);
+}
diff --git a/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp b/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp
index 23f68c403d53..a620b930caa5 100644
--- a/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp
+++ b/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp
@@ -250,6 +250,15 @@ std::tuple rotary_embedding_cpu(
 // CPU and memory binding
 std::string init_cpu_threads_env(const std::string& cpu_ids);
 
+// fused_qkvzba_split_reshape_cat_cpu
+std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> fused_qkvzba_split_reshape_cat_cpu(
+    const at::Tensor& mixed_qkvz,
+    const at::Tensor& mixed_ba,
+    int64_t num_heads_qk,
+    int64_t num_heads_v,
+    int64_t head_qk,
+    int64_t head_v);
+
 TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   // activation
   m.def("silu_and_mul_cpu(Tensor input) -> Tensor");
@@ -389,6 +398,12 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
 
   // CPU and memory binding
   m.def("init_cpu_threads_env(str cpu_ids) -> str");
+
+  // fused_qkvzba_split_reshape_cat_cpu
+  m.def(
+      "fused_qkvzba_split_reshape_cat_cpu(Tensor mixed_qkvz, Tensor mixed_ba, int num_heads_qk, int num_heads_v, int "
+      "head_qk, int head_v) -> (Tensor, Tensor, Tensor, Tensor)");
+  m.impl("fused_qkvzba_split_reshape_cat_cpu", torch::kCPU, &fused_qkvzba_split_reshape_cat_cpu);
 }
 
 TORCH_LIBRARY_IMPL(sgl_kernel, CatchAll, m) {
diff --git a/test/srt/cpu/test_qwen3.py b/test/srt/cpu/test_qwen3.py
new file mode 100644
index 000000000000..4c122c1bb394
--- /dev/null
+++ b/test/srt/cpu/test_qwen3.py
@@ -0,0 +1,88 @@
+import unittest
+
+import torch
+from utils import precision
+
+from sglang.test.test_utils import CustomTestCase
+
+torch.manual_seed(1234)
+
+
+def fix_query_key_value_ordering_reshape_cat(
+    mixed_qkvz, mixed_ba, num_k_heads, num_v_heads, attn_tp_size, head_k_dim, head_v_dim
+):
+    new_tensor_shape_qkvz = mixed_qkvz.size()[:-1] + (
+        num_k_heads // attn_tp_size,
+        (
+            head_k_dim
+            + head_k_dim
+            + (head_v_dim + head_v_dim) * num_v_heads // num_k_heads
+        ),
+    )
+    new_tensor_shape_ba = mixed_ba.size()[:-1] + (
+        num_k_heads // attn_tp_size,
+        2 * num_v_heads // num_k_heads,
+    )
+
+    mixed_qkvz = mixed_qkvz.view(*new_tensor_shape_qkvz)
+    mixed_ba = mixed_ba.view(*new_tensor_shape_ba)
+
+    split_arg_list_qkvz = [
+        head_k_dim,
+        head_k_dim,
+        (num_v_heads // num_k_heads * head_v_dim),
+        (num_v_heads // num_k_heads * head_v_dim),
+    ]
+    split_arg_list_ba = [
+        num_v_heads // num_k_heads,
+        num_v_heads // num_k_heads,
+    ]
+    # [b, sq, ng, (hn + hn + np/ng * hn + np/ng + np/ng)]
+    # --> [b, sq, ng, hn], [b, sq, ng, hn], [b, sq, ng, np/ng * hn], [b, sq, ng, np/ng * hn], [b, sq, ng, np/ng], [b, sq, ng, np/ng]
+    (query, key, value, z) = torch.split(mixed_qkvz, split_arg_list_qkvz, dim=2)
+    (b, a) = torch.split(mixed_ba, split_arg_list_ba, dim=2)
+
+    # [b, sq, ng, np/ng * hn] -> [b, sq, np, hn]
+    value = value.reshape(value.size(0), -1, head_v_dim)
+    z = z.reshape(z.size(0), -1, head_v_dim)
+    b = b.reshape(b.size(0), num_v_heads // attn_tp_size)
+    a = a.reshape(a.size(0), num_v_heads // attn_tp_size)
+    query, key, value = map(lambda x: x.reshape(x.shape[0], -1), (query, key, value))
+    mixed_qkv = torch.cat((query, key, value), dim=-1)
+
+    return mixed_qkv, z, b, a
+
+
+class TestQwen3(CustomTestCase):
+    def test_fused_qkvzba_split_reshape_cat(self):
+        # head_k_dim != head_v_dim on purpose, to catch kernels that conflate the two
+        mixed_qkvz = torch.rand(1024, 10240, dtype=torch.bfloat16)
+        mixed_ba = torch.rand(1024, 64, dtype=torch.bfloat16)
+        head_k_dim = 64
+        head_v_dim = 128
+        num_v_heads = 32
+        num_k_heads = 16
+        attn_tp_size = 1
+        mixed_qkv_ref, z_ref, b_ref, a_ref = fix_query_key_value_ordering_reshape_cat(
+            mixed_qkvz,
+            mixed_ba,
+            num_k_heads,
+            num_v_heads,
+            attn_tp_size,
+            head_k_dim,
+            head_v_dim,
+        )
+        num_heads_qk = num_k_heads // attn_tp_size
+        num_heads_v = num_v_heads // attn_tp_size
+        mixed_qkv, z, b, a = torch.ops.sgl_kernel.fused_qkvzba_split_reshape_cat_cpu(
+            mixed_qkvz, mixed_ba, num_heads_qk, num_heads_v, head_k_dim, head_v_dim
+        )
+        atol = rtol = precision[mixed_qkv.dtype]
+        torch.testing.assert_close(mixed_qkv, mixed_qkv_ref, atol=atol, rtol=rtol)
+        torch.testing.assert_close(z, z_ref, atol=atol, rtol=rtol)
+        torch.testing.assert_close(b, b_ref, atol=atol, rtol=rtol)
+        torch.testing.assert_close(a, a_ref, atol=atol, rtol=rtol)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py
index 70d4514bba68..2a1cec01679c 100644
--- a/test/srt/run_suite.py
+++ b/test/srt/run_suite.py
@@ -337,6 +337,7 @@
         TestFile("cpu/test_moe.py"),
         TestFile("cpu/test_norm.py"),
         TestFile("cpu/test_qkv_proj_with_rope.py"),
+        TestFile("cpu/test_qwen3.py"),
         TestFile("cpu/test_rope.py"),
         TestFile("cpu/test_shared_expert.py"),
         TestFile("cpu/test_topk.py"),