1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -198,6 +198,7 @@ if(ONEDNN_FOUND)
file(GLOB _ONEDNN_SRC csrc/xpu/onednn/*.cpp)
list(APPEND VLLM_EXT_XPU_SRC
${_ONEDNN_SRC}
"csrc/xpu/sycl/deepseek_scaling_rope.cpp"
)
include_directories(${ONEDNN_INCLUDE_DIR})
link_libraries(${ONEDNN_LIBRARY})
7 changes: 6 additions & 1 deletion csrc/xpu/ops.h
@@ -5,4 +5,9 @@
torch::Tensor fp8_gemm_w8a16(const torch::Tensor& A, const torch::Tensor& B,
bool trans_B,
const std::optional<torch::Tensor>& B_scale_,
const std::optional<torch::Tensor>& bias_);

std::tuple<at::Tensor, at::Tensor> deepseek_scaling_rope(
const at::Tensor& positions, const at::Tensor& query, const at::Tensor& key,
const c10::optional<at::Tensor>& offsets_opt,
const at::Tensor& cos_sin_cache, int64_t rotary_dim, bool is_neox);
257 changes: 257 additions & 0 deletions csrc/xpu/sycl/deepseek_scaling_rope.cpp
@@ -0,0 +1,257 @@
#include <sycl/sycl.hpp>
#include "utils.h"
#include "dispatch_utils.h"
#include <cmath>
#include <c10/macros/Macros.h>

namespace vllm {

template <typename T, int64_t rotary_dim, bool is_neox>
class deepseek_scaling_rope_kernel {
public:
static constexpr int sg_size = 16;
deepseek_scaling_rope_kernel(
const int64_t* positions, const T* query, const T* key,
const int64_t* offsets, const T* cos_sin_cache, T* query_out, T* key_out,
const int64_t batch, const int64_t q_num_head, const int64_t k_num_head,
const int64_t head_size, const int64_t q_num_head_d,
const int64_t q_batch_d, const int64_t k_num_head_d,
const int64_t k_batch_d)
: positions(positions),
query(query),
key(key),
offsets(offsets),
cos_sin_cache(cos_sin_cache),
query_out(query_out),
key_out(key_out),
batch(batch),
q_num_head(q_num_head),
k_num_head(k_num_head),
head_size(head_size),
q_num_head_d(q_num_head_d),
q_batch_d(q_batch_d),
k_num_head_d(k_num_head_d),
k_batch_d(k_batch_d) {}

void rotary_embedding_kernel(const int64_t position, const T* pe,
const T* cos_sin_cache, T* res) const {
constexpr int64_t half_rotary_dim = rotary_dim / 2;
constexpr int64_t vec_2_len = 2;
using v2_type = sycl::vec<T, vec_2_len>;
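// each cos_sin_cache row holds rotary_dim values: first half cos, second half sin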
const int64_t cache_idx = position * rotary_dim;
const T* cos_cache_offset = &cos_sin_cache[cache_idx];
const T* sin_cache_offset = cos_cache_offset + half_rotary_dim;
if constexpr (is_neox) {
// repeat & rotate mul add
for (int64_t i = 0; i < half_rotary_dim; ++i) {
int64_t j = i + half_rotary_dim;
T cv = cos_cache_offset[i];
T sv = sin_cache_offset[i];
res[i] = pe[i] * cv - pe[j] * sv;
res[j] = pe[j] * cv + pe[i] * sv;
}
} else {
// interleaved rotate multiply-add (no prefetch hint available in SYCL)
const v2_type* pe_2 = reinterpret_cast<const v2_type*>(pe);
v2_type* res_2 = reinterpret_cast<v2_type*>(res);
for (int64_t h = 0; h < half_rotary_dim; ++h) {
T c = cos_cache_offset[h];
T s = sin_cache_offset[h];
v2_type c2 = {c, c};
v2_type s2 = {s, s};
v2_type t = pe_2[h];
v2_type* dst = &res_2[h];
v2_type tr = {-t[1], t[0]};
*dst = t * c2 + tr * s2;
}
}
}

[[sycl::reqd_sub_group_size(sg_size)]] void operator()(
sycl::nd_item<3> idx) const {
int64_t batch_idx = idx.get_global_id(0);
int64_t sg_idx = idx.get_local_id(1);
int64_t local_id = idx.get_global_id(2);
int64_t head_idx = sg_idx * sg_size + local_id;
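// head_idx maps this work-item to one head: indices in [0, q_num_head)
// rotate a query head, the next k_num_head rotate a key head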
int64_t qo_idx = batch_idx * q_num_head * head_size + head_idx * head_size;
int64_t ko_idx = batch_idx * k_num_head * head_size +
(head_idx - q_num_head) * head_size;
int64_t qi_idx = batch_idx * q_batch_d + head_idx * q_num_head_d;
int64_t ki_idx =
batch_idx * k_batch_d + (head_idx - q_num_head) * k_num_head_d;
if (head_idx < q_num_head) {
rotary_embedding_kernel(positions[batch_idx], &query[qi_idx],
cos_sin_cache, &query_out[qo_idx]);
} else if (head_idx < q_num_head + k_num_head) {
rotary_embedding_kernel(positions[batch_idx], &key[ki_idx], cos_sin_cache,
&key_out[ko_idx]);
}
}

private:
const int64_t* positions;
const T* query;
const T* key;
const int64_t* offsets;
const T* cos_sin_cache;
T* query_out;
T* key_out;
const int64_t batch;
const int64_t q_num_head;
const int64_t k_num_head;
const int64_t head_size;
const int64_t q_num_head_d;
const int64_t q_batch_d;
const int64_t k_num_head_d;
const int64_t k_batch_d;
};

} // namespace vllm
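For intuition, here is a minimal NumPy sketch of the two rotation styles the kernel implements (the function names are illustrative, not part of this PR; the cache row is assumed to be the cos half followed by the sin half, as the kernel reads it):

```python
import numpy as np

def rotate_neox(pe, cos, sin):
    # NeoX style: rotate across the two halves of the head.
    half = pe.shape[-1] // 2
    lo, hi = pe[:half], pe[half:]
    return np.concatenate([lo * cos - hi * sin,
                           hi * cos + lo * sin])

def rotate_interleaved(pe, cos, sin):
    # Interleaved (GPT-J) style: rotate adjacent (even, odd) pairs in place.
    out = np.empty_like(pe)
    even, odd = pe[0::2], pe[1::2]
    out[0::2] = even * cos - odd * sin
    out[1::2] = odd * cos + even * sin
    return out
```

Both match the per-head loop in `rotary_embedding_kernel` above: `cos` and `sin` each carry `rotary_dim / 2` entries taken from a single cache row.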

template <typename T>
void call_deepseek_scaling_rope(const int64_t* positions, const T* query,
const T* key, const int64_t* offsets,
const T* cos_sin_cache, T* query_out,
T* key_out, int64_t batch, int64_t q_num_head,
int64_t k_num_head, int64_t head_size,
int64_t rotary_dim, bool is_neox,
int64_t q_num_head_d, int64_t q_batch_d,
int64_t k_num_head_d, int64_t k_batch_d) {
static constexpr std::array<int, 5> allowed_dims = {32, 64, 96, 128, 256};
auto it = std::find(allowed_dims.begin(), allowed_dims.end(), rotary_dim);

TORCH_CHECK(it != allowed_dims.end(), "Invalid rotary_dim (", rotary_dim,
"). Supported: 32,64,96,128,256");
TORCH_CHECK(rotary_dim == head_size, "rotary_dim (", rotary_dim,
") must equal head_size (", head_size, ")");

const int rot_idx = std::distance(allowed_dims.begin(), it);
const int neox_idx = is_neox ? 1 : 0;
const int func_idx = neox_idx * allowed_dims.size() + rot_idx;

using LaunchFn =
void (*)(sycl::queue&, const int64_t*, const T*, const T*, const int64_t*,
const T*, T*, T*, int64_t, int64_t, int64_t, int64_t, int64_t,
int64_t, int64_t, int64_t);

// Table builder macro
#define REGISTER_CASE(dim, neox) \
[](sycl::queue& q, const int64_t* pos, const T* q_in, const T* k_in, \
const int64_t* off, const T* cache, T* q_out, T* k_out, int64_t b, \
int64_t qh, int64_t kh, int64_t hs, int64_t qhd, int64_t qbd, \
int64_t khd, int64_t kbd) { \
constexpr int64_t sg_size = 16; \
int64_t sg_per_heads = (qh + kh + sg_size - 1) / sg_size; \
sycl::range<3> local(1, sg_per_heads, sg_size); \
sycl::range<3> global(b, sg_per_heads, sg_size); \
at::DeviceGuard dg(at::Device(at::kXPU, at::xpu::current_device())); \
q.submit([&](sycl::handler& cgh) { \
cgh.parallel_for(sycl::nd_range<3>(global, local), \
vllm::deepseek_scaling_rope_kernel<T, dim, neox>{ \
pos, q_in, k_in, off, cache, q_out, k_out, b, qh, \
kh, hs, qhd, qbd, khd, kbd}); \
}); \
}

static constexpr std::array<LaunchFn, allowed_dims.size() * 2> table = {
REGISTER_CASE(32, false), REGISTER_CASE(64, false),
REGISTER_CASE(96, false), REGISTER_CASE(128, false),
REGISTER_CASE(256, false), REGISTER_CASE(32, true),
REGISTER_CASE(64, true), REGISTER_CASE(96, true),
REGISTER_CASE(128, true), REGISTER_CASE(256, true),
};

auto& queue = vllm::xpu::vllmGetQueue();
table[func_idx](queue, positions, query, key, offsets, cos_sin_cache,
query_out, key_out, batch, q_num_head, k_num_head, head_size,
q_num_head_d, q_batch_d, k_num_head_d, k_batch_d);

#undef REGISTER_CASE
}
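As a sanity check on the flat launch table, the index arithmetic resolves like this (a small sketch mirroring the C++ above, not shipped with the PR):

```python
ALLOWED_DIMS = (32, 64, 96, 128, 256)

def launch_index(rotary_dim: int, is_neox: bool) -> int:
    # neox entries occupy the upper half of the flat table
    return int(is_neox) * len(ALLOWED_DIMS) + ALLOWED_DIMS.index(rotary_dim)

assert launch_index(32, False) == 0  # table[0] == REGISTER_CASE(32, false)
assert launch_index(64, True) == 6   # table[6] == REGISTER_CASE(64, true)
```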

/**
 * @brief Apply DeepSeek-scaling rotary embedding to query and key.
 * @param positions position index per token [batch]
 * @param query query tensor to be processed [batch, num_head, head_dim]
 * @param key key tensor to be processed [batch, num_head, head_dim]
 * @param offsets_opt optional position offsets (not read by the kernel yet)
 * @param cos_sin_cache per-position cache; each row is [cos half | sin half]
 * @param rotary_dim rotary dimension; must equal head_dim
 * @param is_neox true selects NeoX-style half rotation, false interleaved
 * @return A tuple of tensors (query_out, key_out).
 */
std::tuple<torch::Tensor, torch::Tensor> deepseek_scaling_rope(
const torch::Tensor& positions, const torch::Tensor& query,
const torch::Tensor& key, const c10::optional<torch::Tensor>& offsets_opt,
const torch::Tensor& cos_sin_cache, int64_t rotary_dim, bool is_neox) {
auto query_out = at::empty_like(query);
auto key_out = at::empty_like(key);

auto q_shape = query.sizes();
auto q_stride = query.strides();
int64_t head_size = q_shape[2];
int64_t q_num_head = q_shape[1];
int64_t batch = q_shape[0];
int64_t q_num_head_d = q_stride[1];
int64_t q_batch_d = q_stride[0];
auto k_shape = key.sizes();
auto k_stride = key.strides();
int64_t k_num_head = k_shape[1];
int64_t k_num_head_d = k_stride[1];
int64_t k_batch_d = k_stride[0];
if (is_neox) {
query_out = query_out.reshape({1, batch, q_num_head, head_size});
key_out = key_out.reshape({1, batch, k_num_head, head_size});
}
TORCH_CHECK(cos_sin_cache.sizes()[1] == head_size,
"Rotary dim doesn't match query head_size");
TORCH_CHECK(cos_sin_cache.sizes()[1] == k_shape[2],
"Rotary dim doesn't match key head_size");
const c10::MaybeOwned<torch::Tensor> offsets_maybe_owned =
at::borrow_from_optional_tensor(offsets_opt);
const torch::Tensor& offsets = *offsets_maybe_owned;
auto offsets_ptr = offsets.defined() ? offsets.data_ptr() : nullptr;
switch (query.scalar_type()) {
case torch::kFloat:
call_deepseek_scaling_rope<float>(
reinterpret_cast<int64_t*>(positions.data_ptr()),
reinterpret_cast<float*>(query.data_ptr()),
reinterpret_cast<float*>(key.data_ptr()),
reinterpret_cast<int64_t*>(offsets_ptr),
reinterpret_cast<float*>(cos_sin_cache.data_ptr()),
reinterpret_cast<float*>(query_out.data_ptr()),
reinterpret_cast<float*>(key_out.data_ptr()), batch, q_num_head,
k_num_head, head_size, rotary_dim, is_neox, q_num_head_d, q_batch_d,
k_num_head_d, k_batch_d);
break;
case torch::kFloat16:
call_deepseek_scaling_rope<sycl::half>(
reinterpret_cast<int64_t*>(positions.data_ptr()),
reinterpret_cast<sycl::half*>(query.data_ptr()),
reinterpret_cast<sycl::half*>(key.data_ptr()),
reinterpret_cast<int64_t*>(offsets_ptr),
reinterpret_cast<sycl::half*>(cos_sin_cache.data_ptr()),
reinterpret_cast<sycl::half*>(query_out.data_ptr()),
reinterpret_cast<sycl::half*>(key_out.data_ptr()), batch, q_num_head,
k_num_head, head_size, rotary_dim, is_neox, q_num_head_d, q_batch_d,
k_num_head_d, k_batch_d);
break;
case torch::kBFloat16:
call_deepseek_scaling_rope<sycl::ext::oneapi::bfloat16>(
reinterpret_cast<int64_t*>(positions.data_ptr()),
reinterpret_cast<sycl::ext::oneapi::bfloat16*>(query.data_ptr()),
reinterpret_cast<sycl::ext::oneapi::bfloat16*>(key.data_ptr()),
reinterpret_cast<int64_t*>(offsets_ptr),
reinterpret_cast<sycl::ext::oneapi::bfloat16*>(
cos_sin_cache.data_ptr()),
reinterpret_cast<sycl::ext::oneapi::bfloat16*>(query_out.data_ptr()),
reinterpret_cast<sycl::ext::oneapi::bfloat16*>(key_out.data_ptr()),
batch, q_num_head, k_num_head, head_size, rotary_dim, is_neox,
q_num_head_d, q_batch_d, k_num_head_d, k_batch_d);
break;
default:
throw std::invalid_argument(
"Invalid dtype, only supports float32, float16, and bfloat16");
break;
}
return {query_out, key_out};
}
7 changes: 7 additions & 0 deletions csrc/xpu/torch_bindings.cpp
@@ -11,6 +11,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, xpu_ops) {
"fp8_gemm_w8a16(Tensor! A, Tensor! B, bool trans_B, Tensor? B_scale_, "
"Tensor? bias_) -> Tensor");
xpu_ops.impl("fp8_gemm_w8a16", torch::kXPU, &fp8_gemm_w8a16);

xpu_ops.def(
"deepseek_scaling_rope(Tensor! positions, Tensor! query, Tensor! key, "
"Tensor? offsets_opt, Tensor! cos_sin_cache, int rotary_dim, bool "
"is_neox_style) "
"-> (Tensor, Tensor)");
xpu_ops.impl("deepseek_scaling_rope", torch::kXPU, &deepseek_scaling_rope);
}

REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
15 changes: 15 additions & 0 deletions tests/register_ops.py
@@ -5,6 +5,7 @@
from typing import Optional
import vllm_xpu_kernels._C # noqa: F401
import vllm_xpu_kernels._moe_C # noqa: F401
import vllm_xpu_kernels._xpu_C # noqa: F401


# layer norm ops
@@ -61,6 +62,20 @@ def rotary_embedding(
cos_sin_cache, is_neox)


def deepseek_scaling_rope(
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets_opt: Optional[torch.Tensor],
cos_sin_cache: torch.Tensor,
rotary_dim: int,
is_neox_style: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
return torch.ops._xpu_C.deepseek_scaling_rope(positions, query, key,
offsets_opt, cos_sin_cache,
rotary_dim, is_neox_style)
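A hedged usage sketch for the wrapper above (assumes the `_xpu_C` extension is built and an XPU device is available; shapes follow the kernel's doc comment):

```python
import torch

batch, q_heads, k_heads, head_dim, max_pos = 4, 16, 1, 64, 4096
dev, dtype = "xpu", torch.float16

positions = torch.arange(batch, dtype=torch.int64, device=dev)
query = torch.randn(batch, q_heads, head_dim, dtype=dtype, device=dev)
key = torch.randn(batch, k_heads, head_dim, dtype=dtype, device=dev)
# one row per position: first half cos, second half sin
cos_sin_cache = torch.randn(max_pos, head_dim, dtype=dtype, device=dev)

# rotary_dim must equal head_dim; offsets are optional (pass None)
q_out, k_out = deepseek_scaling_rope(positions, query, key, None,
                                     cos_sin_cache, head_dim, False)
# note: with is_neox_style=True the outputs come back reshaped to
# [1, batch, num_head, head_dim] (see deepseek_scaling_rope.cpp)
```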


def reshape_and_cache(
key: torch.Tensor,
value: torch.Tensor,