Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion python/sglang/srt/layers/rotary_embedding/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@

import torch

from sglang.srt.utils import get_compiler_backend, is_npu
from sglang.srt.utils import cpu_has_amx_support, get_compiler_backend, is_cpu, is_npu

_is_npu = is_npu()
_is_cpu = is_cpu()
_is_cpu_amx_available = cpu_has_amx_support()

if _is_npu:
import torch_npu
Expand Down Expand Up @@ -128,5 +130,7 @@ def apply_rotary_pos_emb_npu(

if _is_npu:
apply_rotary_pos_emb = apply_rotary_pos_emb_npu
elif _is_cpu and _is_cpu_amx_available:
apply_rotary_pos_emb = torch.ops.sgl_kernel.apply_rotary_pos_emb_cpu
else:
apply_rotary_pos_emb = apply_rotary_pos_emb_native
270 changes: 270 additions & 0 deletions sgl-kernel/csrc/cpu/rope.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,214 @@ void rotary_embedding_neox_4D_kernel_impl(
}
}

// Rotary position embedding (half-rotated / NeoX layout) applied in place to
// query and key, with float32 cos/sin tables.
//
// Per head, the first embed_dim = head_size / 2 lanes are the "x" half and the
// last embed_dim lanes are the "y" half; each pair is rotated as
//   x' = x * cos[x_idx] - y * sin[x_idx]
//   y' = y * cos[y_idx] + x * sin[y_idx]
//
// query: [num_tokens, num_heads, head_size], token stride query_stride_s
// key:   [num_tokens, num_kv_heads, head_size], token stride key_stride_s
// cos/sin: [num_tokens, head_size] float32 (full head_size wide — the x and y
//          halves read separate table entries).
// NOTE(review): heads are assumed packed head_size apart inside a token
// (stride(1) == head_size); only the token stride is parameterized.
template <typename scalar_t>
void apply_rotary_pos_emb_kernel_impl(
    scalar_t* __restrict__ query,
    scalar_t* __restrict__ key,
    float* __restrict__ cos,
    float* __restrict__ sin,
    int64_t query_stride_s,
    int64_t key_stride_s,
    int64_t num_heads,
    int64_t num_kv_heads,
    int64_t head_size,
    int64_t num_tokens) {
  using bVec = at::vec::Vectorized<scalar_t>;
  using fVec = at::vec::Vectorized<float>;
  constexpr int64_t bVecSize = bVec::size();
  constexpr int64_t fVecSize = fVec::size();

  int64_t embed_dim = head_size / 2;
  // When embed_dim is a multiple of the reduced-precision vector width the
  // SIMD loop covers the whole half; otherwise stop one vector early and
  // finish the remainder with the scalar tail below.
  const bool vec_covers_all = (embed_dim % bVecSize == 0);
  const int64_t vec_upper = vec_covers_all ? embed_dim : embed_dim - bVecSize;

  auto compute_loop = [&](int64_t token_head, float* cos_ptr, float* sin_ptr, scalar_t* qk) {
    int64_t j = 0;
    // Vectorized main loop: one bVec (= two fVec) of x lanes and y lanes per step.
    for (; j < vec_upper; j += bVecSize) {
      int64_t x_index = j;
      int64_t y_index = embed_dim + j;

      int64_t out_x = token_head + x_index;
      int64_t out_y = token_head + y_index;

      fVec _cos_x_0 = fVec::loadu(cos_ptr + x_index);
      fVec _sin_x_0 = fVec::loadu(sin_ptr + x_index);
      fVec _cos_x_1 = fVec::loadu(cos_ptr + x_index + fVecSize);
      fVec _sin_x_1 = fVec::loadu(sin_ptr + x_index + fVecSize);

      fVec _cos_y_0 = fVec::loadu(cos_ptr + y_index);
      fVec _sin_y_0 = fVec::loadu(sin_ptr + y_index);
      fVec _cos_y_1 = fVec::loadu(cos_ptr + y_index + fVecSize);
      fVec _sin_y_1 = fVec::loadu(sin_ptr + y_index + fVecSize);

      bVec _q_x = bVec::loadu(qk + out_x);
      bVec _q_y = bVec::loadu(qk + out_y);
      fVec _q_x_0, _q_x_1;
      std::tie(_q_x_0, _q_x_1) = at::vec::convert_to_float(_q_x);
      fVec _q_y_0, _q_y_1;
      std::tie(_q_y_0, _q_y_1) = at::vec::convert_to_float(_q_y);

      // x' = x*cos - y*sin
      auto out1_0 = _q_x_0 * _cos_x_0 - _q_y_0 * _sin_x_0;
      auto out1_1 = _q_x_1 * _cos_x_1 - _q_y_1 * _sin_x_1;
      auto out1 = convert_from_float_ext<scalar_t>(out1_0, out1_1);
      out1.store(qk + out_x);

      // y' = y*cos + x*sin
      auto out2_0 = _q_y_0 * _cos_y_0 + _q_x_0 * _sin_y_0;
      auto out2_1 = _q_y_1 * _cos_y_1 + _q_x_1 * _sin_y_1;
      auto out2 = convert_from_float_ext<scalar_t>(out2_0, out2_1);
      out2.store(qk + out_y);
    }
    if (!vec_covers_all) {
      // Scalar tail for the lanes the vectorized loop did not reach.
      for (; j < embed_dim; ++j) {
        int64_t x_index = j;
        int64_t y_index = embed_dim + j;

        int64_t out_x = token_head + x_index;
        int64_t out_y = token_head + y_index;

        float _cos_x = cos_ptr[x_index];
        float _sin_x = sin_ptr[x_index];
        float _cos_y = cos_ptr[y_index];
        float _sin_y = sin_ptr[y_index];

        float _q_x = qk[out_x];
        float _q_y = qk[out_y];

        qk[out_x] = _q_x * _cos_x - _q_y * _sin_x;
        qk[out_y] = _q_y * _cos_y + _q_x * _sin_y;
      }
    }
  };

  at::parallel_for(0, num_tokens, 0, [&](int64_t begin, int64_t end) {
    int64_t token_idx = {0};
    data_index_init(begin, token_idx, num_tokens);
    // Distinct loop variables: the original declared the token counter as
    // `int i` (narrowing int64_t `begin`) and shadowed it with the inner
    // head-loop `int64_t i`.
    for (int64_t t = begin; t < end; ++t) {
      float* cos_ptr = cos + token_idx * head_size;
      float* sin_ptr = sin + token_idx * head_size;

      for (int64_t head_idx = 0; head_idx < num_heads; ++head_idx) {
        int64_t token_head = token_idx * query_stride_s + head_idx * head_size;
        compute_loop(token_head, cos_ptr, sin_ptr, query);
      }

      for (int64_t head_idx = 0; head_idx < num_kv_heads; ++head_idx) {
        int64_t token_head = token_idx * key_stride_s + head_idx * head_size;
        compute_loop(token_head, cos_ptr, sin_ptr, key);
      }
      data_index_step(token_idx, num_tokens);
    }
  });
}

// Overload of the half-rotated rope kernel for cos/sin tables stored in the
// same reduced-precision dtype as query/key (e.g. bf16/fp16). Identical math
// to the float32-table overload: per head the first embed_dim = head_size / 2
// lanes (x) and last embed_dim lanes (y) are rotated as
//   x' = x * cos[x_idx] - y * sin[x_idx]
//   y' = y * cos[y_idx] + x * sin[y_idx]
// with all arithmetic performed in float32 after widening the loads.
//
// query: [num_tokens, num_heads, head_size], token stride query_stride_s
// key:   [num_tokens, num_kv_heads, head_size], token stride key_stride_s
// cos/sin: [num_tokens, head_size] in scalar_t.
// NOTE(review): heads are assumed packed head_size apart inside a token
// (stride(1) == head_size); only the token stride is parameterized.
template <typename scalar_t>
void apply_rotary_pos_emb_kernel_impl(
    scalar_t* __restrict__ query,
    scalar_t* __restrict__ key,
    scalar_t* __restrict__ cos,
    scalar_t* __restrict__ sin,
    int64_t query_stride_s,
    int64_t key_stride_s,
    int64_t num_heads,
    int64_t num_kv_heads,
    int64_t head_size,
    int64_t num_tokens) {
  using bVec = at::vec::Vectorized<scalar_t>;
  using fVec = at::vec::Vectorized<float>;
  constexpr int64_t bVecSize = bVec::size();

  int64_t embed_dim = head_size / 2;
  // When embed_dim is a multiple of the vector width the SIMD loop covers the
  // whole half; otherwise stop one vector early and use the scalar tail.
  const bool vec_covers_all = (embed_dim % bVecSize == 0);
  const int64_t vec_upper = vec_covers_all ? embed_dim : embed_dim - bVecSize;

  auto compute_loop = [&](int64_t token_head, scalar_t* cos_ptr, scalar_t* sin_ptr, scalar_t* qk) {
    int64_t j = 0;
    // Vectorized main loop: widen cos/sin and q/k lanes to float, rotate,
    // then narrow back to scalar_t on store.
    for (; j < vec_upper; j += bVecSize) {
      int64_t x_index = j;
      int64_t y_index = embed_dim + j;

      int64_t out_x = token_head + x_index;
      int64_t out_y = token_head + y_index;

      bVec _cos_x = bVec::loadu(cos_ptr + x_index);
      bVec _sin_x = bVec::loadu(sin_ptr + x_index);
      bVec _cos_y = bVec::loadu(cos_ptr + y_index);
      bVec _sin_y = bVec::loadu(sin_ptr + y_index);
      fVec _cos_x_0, _cos_x_1;
      std::tie(_cos_x_0, _cos_x_1) = at::vec::convert_to_float(_cos_x);
      fVec _sin_x_0, _sin_x_1;
      std::tie(_sin_x_0, _sin_x_1) = at::vec::convert_to_float(_sin_x);
      fVec _cos_y_0, _cos_y_1;
      std::tie(_cos_y_0, _cos_y_1) = at::vec::convert_to_float(_cos_y);
      fVec _sin_y_0, _sin_y_1;
      std::tie(_sin_y_0, _sin_y_1) = at::vec::convert_to_float(_sin_y);

      bVec _q_x = bVec::loadu(qk + out_x);
      bVec _q_y = bVec::loadu(qk + out_y);
      fVec _q_x_0, _q_x_1;
      std::tie(_q_x_0, _q_x_1) = at::vec::convert_to_float(_q_x);
      fVec _q_y_0, _q_y_1;
      std::tie(_q_y_0, _q_y_1) = at::vec::convert_to_float(_q_y);

      // x' = x*cos - y*sin
      auto out1_0 = _q_x_0 * _cos_x_0 - _q_y_0 * _sin_x_0;
      auto out1_1 = _q_x_1 * _cos_x_1 - _q_y_1 * _sin_x_1;
      auto out1 = convert_from_float_ext<scalar_t>(out1_0, out1_1);
      out1.store(qk + out_x);

      // y' = y*cos + x*sin
      auto out2_0 = _q_y_0 * _cos_y_0 + _q_x_0 * _sin_y_0;
      auto out2_1 = _q_y_1 * _cos_y_1 + _q_x_1 * _sin_y_1;
      auto out2 = convert_from_float_ext<scalar_t>(out2_0, out2_1);
      out2.store(qk + out_y);
    }
    if (!vec_covers_all) {
      // Scalar tail; scalar_t table entries widen to float implicitly.
      for (; j < embed_dim; ++j) {
        int64_t x_index = j;
        int64_t y_index = embed_dim + j;

        int64_t out_x = token_head + x_index;
        int64_t out_y = token_head + y_index;

        float _cos_x = cos_ptr[x_index];
        float _sin_x = sin_ptr[x_index];
        float _cos_y = cos_ptr[y_index];
        float _sin_y = sin_ptr[y_index];

        float _q_x = qk[out_x];
        float _q_y = qk[out_y];

        qk[out_x] = _q_x * _cos_x - _q_y * _sin_x;
        qk[out_y] = _q_y * _cos_y + _q_x * _sin_y;
      }
    }
  };

  at::parallel_for(0, num_tokens, 0, [&](int64_t begin, int64_t end) {
    int64_t token_idx = {0};
    data_index_init(begin, token_idx, num_tokens);
    // Distinct loop variables: the original declared the token counter as
    // `int i` (narrowing int64_t `begin`) and shadowed it with the inner
    // head-loop `int64_t i`.
    for (int64_t t = begin; t < end; ++t) {
      scalar_t* cos_ptr = cos + token_idx * head_size;
      scalar_t* sin_ptr = sin + token_idx * head_size;

      for (int64_t head_idx = 0; head_idx < num_heads; ++head_idx) {
        int64_t token_head = token_idx * query_stride_s + head_idx * head_size;
        compute_loop(token_head, cos_ptr, sin_ptr, query);
      }

      for (int64_t head_idx = 0; head_idx < num_kv_heads; ++head_idx) {
        int64_t token_head = token_idx * key_stride_s + head_idx * head_size;
        compute_loop(token_head, cos_ptr, sin_ptr, key);
      }
      data_index_step(token_idx, num_tokens);
    }
  });
}

template <typename scalar_t>
inline scalar_t* get_cache_ptr(
int64_t j,
Expand Down Expand Up @@ -561,6 +769,68 @@ std::tuple<at::Tensor, at::Tensor> rotary_embedding_cpu(
return std::make_tuple(query_out, key_out);
}

// Apply rotary position embedding to query and key in place (the returned
// tensors alias the inputs).
//
// query: [num_tokens, num_heads, head_size]     (last dim contiguous)
// key:   [num_tokens, num_kv_heads, head_size]  (last dim contiguous;
//        num_kv_heads may differ from num_heads, e.g. GQA/MQA)
// cos:   [num_tokens, head_size]
// sin:   [num_tokens, head_size]
// cos/sin must either both be float32 or both match the query/key dtype.
std::tuple<at::Tensor, at::Tensor>
apply_rotary_pos_emb_cpu(at::Tensor& query, at::Tensor& key, at::Tensor& cos, at::Tensor& sin) {
  RECORD_FUNCTION("sgl-kernel::apply_rotary_pos_emb_cpu", std::vector<c10::IValue>({query, key}));
  CHECK_LAST_DIM_CONTIGUOUS_INPUT(query);
  CHECK_LAST_DIM_CONTIGUOUS_INPUT(key);
  CHECK_INPUT(cos);
  CHECK_INPUT(sin);
  CHECK_DIM(3, query);
  CHECK_DIM(3, key);
  CHECK_DIM(2, cos);
  CHECK_DIM(2, sin);
  const auto input_dtype = query.scalar_type();
  int64_t num_tokens = query.size(0);
  CHECK_EQ(num_tokens, key.size(0));
  CHECK_EQ(num_tokens, cos.size(0));
  CHECK_EQ(num_tokens, sin.size(0));
  int64_t num_heads = query.size(1);
  // Generalized from requiring key.size(1) == num_heads: the kernel takes
  // independent head counts, so grouped-query attention works unchanged.
  int64_t num_kv_heads = key.size(1);
  int64_t head_size = query.size(2);
  CHECK_EQ(head_size, key.size(2));
  CHECK_EQ(head_size, cos.size(1));
  CHECK_EQ(head_size, sin.size(1));
  int64_t q_stride_s = query.stride(0);
  int64_t k_stride_s = key.stride(0);
  TORCH_CHECK(input_dtype == key.scalar_type(), "query and key must have the same data type");
  AT_DISPATCH_REDUCED_FLOATING_TYPES(query.scalar_type(), "apply_rotary_pos_emb_cpu", [&] {
    if (cos.scalar_type() == at::kFloat && sin.scalar_type() == at::kFloat) {
      // float32 cos/sin tables.
      apply_rotary_pos_emb_kernel_impl<scalar_t>(
          query.data_ptr<scalar_t>(),
          key.data_ptr<scalar_t>(),
          cos.data_ptr<float>(),
          sin.data_ptr<float>(),
          q_stride_s,
          k_stride_s,
          num_heads,
          num_kv_heads,
          head_size,
          num_tokens);
    } else if (cos.scalar_type() == input_dtype && sin.scalar_type() == input_dtype) {
      // Reduced-precision cos/sin tables matching query/key.
      apply_rotary_pos_emb_kernel_impl<scalar_t>(
          query.data_ptr<scalar_t>(),
          key.data_ptr<scalar_t>(),
          cos.data_ptr<scalar_t>(),
          sin.data_ptr<scalar_t>(),
          q_stride_s,
          k_stride_s,
          num_heads,
          num_kv_heads,
          head_size,
          num_tokens);
    } else {
      TORCH_CHECK(
          false, "cos and sin must have the same data type, and must be either float or the same type as query/key");
    }
  });
  return std::make_tuple(query, key);
}

// positions: [num_tokens] (text only) or [3, num_tokens] (T/H/W positions with multimodal inputs)
// query: [num_tokens, num_heads * head_size]
// key: [num_tokens, num_kv_heads * head_size]
Expand Down
5 changes: 5 additions & 0 deletions sgl-kernel/csrc/cpu/torch_extension_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,8 @@ std::tuple<at::Tensor, at::Tensor> rotary_embedding_cpu(
int64_t head_size,
at::Tensor& cos_sin_cache,
bool is_neox);
std::tuple<at::Tensor, at::Tensor>
apply_rotary_pos_emb_cpu(at::Tensor& query, at::Tensor& key, at::Tensor& cos, at::Tensor& sin);

// mrope
std::tuple<at::Tensor, at::Tensor> multimodal_rotary_embedding_cpu(
Expand Down Expand Up @@ -572,6 +574,9 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
"rotary_embedding_cpu(Tensor positions, Tensor query, Tensor key, int head_size, Tensor cos_sin_cache, "
"bool is_neox) -> (Tensor, Tensor)");
m.impl("rotary_embedding_cpu", torch::kCPU, &rotary_embedding_cpu);
m.def("apply_rotary_pos_emb_cpu(Tensor query, Tensor key, Tensor cos, Tensor sin) -> (Tensor, Tensor)");
m.impl("apply_rotary_pos_emb_cpu", torch::kCPU, &apply_rotary_pos_emb_cpu);

// multimodal rope
m.def(
"multimodal_rotary_embedding_cpu(Tensor positions, Tensor query, Tensor key, int head_size, Tensor "
Expand Down
23 changes: 23 additions & 0 deletions test/srt/cpu/test_rope.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
)
from sglang.srt.layers.rotary_embedding.rope_variant import (
DeepseekScalingRotaryEmbedding,
apply_rotary_pos_emb_native,
)
from sglang.srt.server_args import ServerArgs, set_global_server_args_for_scheduler
from sglang.test.test_utils import CustomTestCase
Expand All @@ -18,6 +19,7 @@

class TestROPE(CustomTestCase):
def test_mrope(self):
torch.manual_seed(100)
head_size = 128
seq_len = 512
num_heads = 16
Expand Down Expand Up @@ -254,6 +256,27 @@ def single_test(
num_kv_heads,
)

def test_apply_rotary_pos_emb(self):
num_tokens = 1024
num_heads = 8
head_size = 72
qkv = torch.randn(num_tokens, num_heads * head_size * 3).to(torch.bfloat16)
query, key, _ = qkv.split(
[num_heads * head_size, num_heads * head_size, num_heads * head_size],
dim=-1,
)
query = query.view(num_tokens, num_heads, head_size)
key = key.view(num_tokens, num_heads, head_size)
for sincos_dtype in [torch.float32, torch.bfloat16]:
cos = torch.rand(num_tokens, head_size).to(sincos_dtype)
sin = torch.rand(num_tokens, head_size).to(sincos_dtype)
q_out_ref, k_out_ref = apply_rotary_pos_emb_native(query, key, cos, sin)
q_out_sgl, k_out_sgl = torch.ops.sgl_kernel.apply_rotary_pos_emb_cpu(
query, key, cos, sin
)
torch.testing.assert_close(q_out_ref, q_out_sgl, atol=1e-2, rtol=1e-2)
torch.testing.assert_close(k_out_ref, k_out_sgl, atol=1e-2, rtol=1e-2)


if __name__ == "__main__":
unittest.main()
Loading