Skip to content
Merged
115 changes: 115 additions & 0 deletions sgl-kernel/csrc/cpu/model/qwen3.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
#include "common.h"
#include "vec.h"
namespace {

template <typename scalar_t>
inline void copy_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ src, int64_t size) {
using bVec = at::vec::Vectorized<scalar_t>;
constexpr int kVecSize = bVec::size();
int64_t d = 0;
#pragma GCC unroll 4
for (; d <= size - kVecSize; d += kVecSize) {
bVec out_bvec = bVec::loadu(src + d);
out_bvec.store(out + d);
}
for (; d < size; ++d) {
out[d] = src[d];
}
}

template <typename scalar_t>
void fused_qkvzba_split_reshape_cat_impl(
const scalar_t* __restrict__ mixed_qkvz,
const scalar_t* __restrict__ mixed_ba,
scalar_t* __restrict__ mixed_qkv,
scalar_t* __restrict__ z,
scalar_t* __restrict__ b,
scalar_t* __restrict__ a,
int64_t batch,
int64_t num_heads_qk,
int64_t num_heads_v,
int64_t head_qk,
int64_t group,
int64_t head_v,
int64_t qkv_strideB,
int64_t qkvz_strideB,
int64_t ba_strideB) {
int64_t qkvz_stride_per_head = head_qk * 2 + head_v * 2 * group;
at::parallel_for(0, batch * num_heads_qk, 0, [&](int64_t begin, int64_t end) {
int64_t bi{0}, hi{0};
data_index_init(begin, bi, batch, hi, num_heads_qk);
for (int64_t i = begin; i < end; ++i) {
scalar_t* __restrict__ q_out_ptr = mixed_qkv + bi * qkv_strideB + hi * head_qk;
const scalar_t* __restrict__ q_in_ptr = mixed_qkvz + bi * qkvz_strideB + hi * qkvz_stride_per_head;
scalar_t* __restrict__ k_out_ptr = q_out_ptr + num_heads_qk * head_qk;
const scalar_t* __restrict__ k_in_ptr = q_in_ptr + head_qk;
scalar_t* __restrict__ v_out_ptr = k_out_ptr + num_heads_qk * head_qk + hi * head_qk * (group - 1);
const scalar_t* __restrict__ v_in_ptr = k_in_ptr + head_qk;
scalar_t* __restrict__ z_out_ptr = z + bi * num_heads_v * head_v + hi * group * head_v;
const scalar_t* __restrict__ z_in_ptr = v_in_ptr + head_qk * group;
copy_stub(q_out_ptr, q_in_ptr, head_qk);
copy_stub(k_out_ptr, k_in_ptr, head_qk);
copy_stub(v_out_ptr, v_in_ptr, head_qk * group);
copy_stub(z_out_ptr, z_in_ptr, head_qk * group);
scalar_t* __restrict__ b_out_ptr = b + bi * num_heads_v + hi * group;
const scalar_t* __restrict__ b_in_ptr = mixed_ba + bi * ba_strideB + hi * group * 2;
scalar_t* __restrict__ a_out_ptr = a + bi * num_heads_v + hi * group;
const scalar_t* __restrict__ a_in_ptr = b_in_ptr + group;
copy_stub(b_out_ptr, b_in_ptr, group);
copy_stub(a_out_ptr, a_in_ptr, group);
data_index_step(bi, batch, hi, num_heads_qk);
}
});
}
} // anonymous namespace

// mixed_qkvz: [batch, num_heads_qk * head_qk * 2 + num_heads_v * head_v * 2]
// mixed_ba: [batch, num_heads_v * 2]
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> fused_qkvzba_split_reshape_cat_cpu(
const at::Tensor& mixed_qkvz,
const at::Tensor& mixed_ba,
int64_t num_heads_qk,
int64_t num_heads_v,
int64_t head_qk,
int64_t head_v) {
RECORD_FUNCTION("sgl-kernel::fused_qkvzba_split_reshape_cat_cpu", std::vector<c10::IValue>({mixed_qkvz, mixed_ba}));
CHECK_DIM(2, mixed_qkvz);
CHECK_DIM(2, mixed_ba);
CHECK_INPUT(mixed_qkvz);
CHECK_INPUT(mixed_ba);
int64_t batch = mixed_qkvz.size(0);
int64_t qkv_dim = num_heads_qk * head_qk * 2 + num_heads_v * head_v;
int64_t ba_dim = num_heads_v * 2;
int64_t expected_dim = qkv_dim + num_heads_v * head_v;
CHECK_EQ(mixed_qkvz.size(1), expected_dim);
CHECK_EQ(mixed_ba.size(0), batch);
CHECK_EQ(mixed_ba.size(1), ba_dim);
CHECK_EQ(num_heads_v % num_heads_qk, 0);
at::Tensor mixed_qkv = at::empty({batch, qkv_dim}, mixed_qkvz.options());
at::Tensor z = at::empty({batch, num_heads_v, head_v}, mixed_qkvz.options());
at::Tensor b = at::empty({batch, num_heads_v}, mixed_ba.options());
at::Tensor a = at::empty({batch, num_heads_v}, mixed_ba.options());
int64_t group = num_heads_v / num_heads_qk;
int64_t qkvz_strideB = mixed_qkvz.size(1);
int64_t qkv_strideB = mixed_qkv.size(1);
int64_t ba_strideB = mixed_ba.size(1);
AT_DISPATCH_REDUCED_FLOATING_TYPES(mixed_qkvz.scalar_type(), "fused_qkvzba_split_reshape_cat_impl", [&] {
fused_qkvzba_split_reshape_cat_impl<scalar_t>(
mixed_qkvz.data_ptr<scalar_t>(),
mixed_ba.data_ptr<scalar_t>(),
mixed_qkv.data_ptr<scalar_t>(),
z.data_ptr<scalar_t>(),
b.data_ptr<scalar_t>(),
a.data_ptr<scalar_t>(),
batch,
num_heads_qk,
num_heads_v,
head_qk,
group,
head_v,
qkv_strideB,
qkvz_strideB,
ba_strideB);
});
return std::make_tuple(mixed_qkv, z, b, a);
}
15 changes: 15 additions & 0 deletions sgl-kernel/csrc/cpu/torch_extension_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,15 @@ std::tuple<at::Tensor, at::Tensor> rotary_embedding_cpu(
// CPU and memory binding
std::string init_cpu_threads_env(const std::string& cpu_ids);

// fused_qkvzba_split_reshape_cat_cpu
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> fused_qkvzba_split_reshape_cat_cpu(
const at::Tensor& mixed_qkvz,
const at::Tensor& mixed_ba,
int64_t num_heads_qk,
int64_t num_heads_v,
int64_t head_qk,
int64_t head_v);

TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
// activation
m.def("silu_and_mul_cpu(Tensor input) -> Tensor");
Expand Down Expand Up @@ -389,6 +398,12 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {

// CPU and memory binding
m.def("init_cpu_threads_env(str cpu_ids) -> str");

// fused_qkvzba_split_reshape_cat_cpu
m.def(
"fused_qkvzba_split_reshape_cat_cpu(Tensor mixed_qkvz, Tensor mixed_ba, int num_heads_qk, int num_heads_v, int "
"head_qk, int head_v) -> (Tensor, Tensor, Tensor, Tensor)");
m.impl("fused_qkvzba_split_reshape_cat_cpu", torch::kCPU, &fused_qkvzba_split_reshape_cat_cpu);
}

TORCH_LIBRARY_IMPL(sgl_kernel, CatchAll, m) {
Expand Down
87 changes: 87 additions & 0 deletions test/srt/cpu/test_qwen3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import unittest

import torch
from utils import precision

from sglang.test.test_utils import CustomTestCase

torch.manual_seed(1234)


def fix_query_key_value_ordering_reshape_cat(
mixed_qkvz, mixed_ba, num_k_heads, num_v_heads, attn_tp_size, head_k_dim, head_v_dim
):
new_tensor_shape_qkvz = mixed_qkvz.size()[:-1] + (
num_k_heads // attn_tp_size,
(
head_k_dim
+ head_k_dim
+ (head_v_dim + head_v_dim) * num_v_heads // num_k_heads
),
)
new_tensor_shape_ba = mixed_ba.size()[:-1] + (
num_k_heads // attn_tp_size,
2 * num_v_heads // num_k_heads,
)

mixed_qkvz = mixed_qkvz.view(*new_tensor_shape_qkvz)
mixed_ba = mixed_ba.view(*new_tensor_shape_ba)

split_arg_list_qkvz = [
head_k_dim,
head_k_dim,
(num_v_heads // num_k_heads * head_v_dim),
(num_v_heads // num_k_heads * head_v_dim),
]
split_arg_list_ba = [
num_v_heads // num_k_heads,
num_v_heads // num_k_heads,
]
# [b, sq, ng, (hn + hn + np/ng * hn + np/ng + np/ng)]
# --> [b, sq, ng, hn], [b, sq, ng, hn], [b, sq, ng, np/ng * hn], [b, sq, ng, np/ng * hn], [b, sq, ng, np/ng], [b, sq, ng, np/ng]
(query, key, value, z) = torch.split(mixed_qkvz, split_arg_list_qkvz, dim=2)
(b, a) = torch.split(mixed_ba, split_arg_list_ba, dim=2)

# [b, sq, ng, np/ng * hn] -> [b, sq, np, hn]
value = value.reshape(value.size(0), -1, head_v_dim)
z = z.reshape(z.size(0), -1, head_v_dim)
b = b.reshape(b.size(0), num_v_heads // attn_tp_size)
a = a.reshape(a.size(0), num_v_heads // attn_tp_size)
query, key, value = map(lambda x: x.reshape(x.shape[0], -1), (query, key, value))
mixed_qkv = torch.cat((query, key, value), dim=-1)

return mixed_qkv, z, b, a


class TestQwen3(CustomTestCase):
def test_fused_qkvzba_split_reshape_cat(self):
mixed_qkvz = torch.rand(1024, 12288, dtype=torch.bfloat16)
mixed_ba = torch.rand(1024, 64, dtype=torch.bfloat16)
head_k_dim = 128
head_v_dim = 128
num_v_heads = 32
num_k_heads = 16
attn_tp_size = 1
mixed_qkv_ref, z_ref, b_ref, a_ref = fix_query_key_value_ordering_reshape_cat(
mixed_qkvz,
mixed_ba,
num_k_heads,
num_v_heads,
attn_tp_size,
head_k_dim,
head_v_dim,
)
num_heads_qk = num_k_heads // attn_tp_size
num_heads_v = num_v_heads // attn_tp_size
mixed_qkv, z, b, a = torch.ops.sgl_kernel.fused_qkvzba_split_reshape_cat_cpu(
mixed_qkvz, mixed_ba, num_heads_qk, num_heads_v, head_k_dim, head_v_dim
)
atol = rtol = precision[mixed_qkv.dtype]
torch.testing.assert_close(mixed_qkv, mixed_qkv_ref, atol=atol, rtol=rtol)
torch.testing.assert_close(z, z_ref, atol=atol, rtol=rtol)
torch.testing.assert_close(b, b_ref, atol=atol, rtol=rtol)
torch.testing.assert_close(a, a_ref, atol=atol, rtol=rtol)


if __name__ == "__main__":
unittest.main()
1 change: 1 addition & 0 deletions test/srt/run_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,7 @@
TestFile("cpu/test_moe.py"),
TestFile("cpu/test_norm.py"),
TestFile("cpu/test_qkv_proj_with_rope.py"),
TestFile("cpu/test_qwen3.py"),
TestFile("cpu/test_rope.py"),
TestFile("cpu/test_shared_expert.py"),
TestFile("cpu/test_topk.py"),
Expand Down
Loading