diff --git a/CMakeLists.txt b/CMakeLists.txt
index 4cb1db8ad..30e5a458b 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -198,6 +198,7 @@ if(ONEDNN_FOUND)
   file(GLOB _ONEDNN_SRC csrc/xpu/onednn/*.cpp)
   list(APPEND VLLM_EXT_XPU_SRC
     ${_ONEDNN_SRC}
+    "csrc/xpu/sycl/deepseek_scaling_rope.cpp"
   )
   include_directories(${ONEDNN_INCLUDE_DIR})
   link_libraries(${ONEDNN_LIBRARY})
diff --git a/csrc/xpu/ops.h b/csrc/xpu/ops.h
index bd529b176..524199efd 100644
--- a/csrc/xpu/ops.h
+++ b/csrc/xpu/ops.h
@@ -5,4 +5,9 @@
 torch::Tensor fp8_gemm_w8a16(const torch::Tensor& A, const torch::Tensor& B,
                              bool trans_B,
                              const std::optional<torch::Tensor>& B_scale_,
-                             const std::optional<torch::Tensor>& bias_);
\ No newline at end of file
+                             const std::optional<torch::Tensor>& bias_);
+
+std::tuple<at::Tensor, at::Tensor> deepseek_scaling_rope(
+    const at::Tensor& positions, const at::Tensor& query, const at::Tensor& key,
+    const c10::optional<at::Tensor>& offsets_opt,
+    const at::Tensor& cos_sin_cache, int64_t rotary_dim, bool is_neox);
diff --git a/csrc/xpu/sycl/deepseek_scaling_rope.cpp b/csrc/xpu/sycl/deepseek_scaling_rope.cpp
new file mode 100644
index 000000000..0a7fa310a
--- /dev/null
+++ b/csrc/xpu/sycl/deepseek_scaling_rope.cpp
@@ -0,0 +1,257 @@
+#include <sycl/sycl.hpp>
+#include "utils.h"
+#include "dispatch_utils.h"
+#include <torch/all.h>
+#include <algorithm>
+
+namespace vllm {
+
+template <typename T, int64_t rotary_dim, bool is_neox>
+class deepseek_scaling_rope_kernel {
+ public:
+  static constexpr int sg_size = 16;
+  deepseek_scaling_rope_kernel(
+      const int64_t* positions, const T* query, const T* key,
+      const int64_t* offsets, const T* cos_sin_cache, T* query_out, T* key_out,
+      const int64_t batch, const int64_t q_num_head, const int64_t k_num_head,
+      const int64_t head_size, const int64_t q_num_head_d,
+      const int64_t q_batch_d, const int64_t k_num_head_d,
+      const int64_t k_batch_d)
+      : positions(positions),
+        query(query),
+        key(key),
+        offsets(offsets),
+        cos_sin_cache(cos_sin_cache),
+        query_out(query_out),
+        key_out(key_out),
+        batch(batch),
+        q_num_head(q_num_head),
+        k_num_head(k_num_head),
+        head_size(head_size),
+        q_num_head_d(q_num_head_d),
+        q_batch_d(q_batch_d),
+        k_num_head_d(k_num_head_d),
+        k_batch_d(k_batch_d) {}
+
+  void rotary_embedding_kernel(const int64_t position, const T* pe,
+                               const T* cos_sin_cache, T* res) const {
+    constexpr int64_t half_rotary_dim = rotary_dim / 2;
+    constexpr int64_t vec_2_len = 2;
+    using v2_type = sycl::vec<T, vec_2_len>;
+    const int64_t cache_idx = position * rotary_dim;
+    const T* cos_cache_offset = &cos_sin_cache[cache_idx];
+    const T* sin_cache_offset = cos_cache_offset + half_rotary_dim;
+    if constexpr (is_neox) {
+      // repeat & rotate mul add
+      for (int64_t i = 0; i < half_rotary_dim; ++i) {
+        int64_t j = i + half_rotary_dim;
+        T cv = cos_cache_offset[i];
+        T sv = sin_cache_offset[i];
+        res[i] = pe[i] * cv - pe[j] * sv;
+        res[j] = pe[j] * cv + pe[i] * sv;
+      }
+    } else {
+      // interleave & rotate mul add; unfortunately there is no prefetch in
+      // SYCL
+      const v2_type* pe_2 = reinterpret_cast<const v2_type*>(pe);
+      v2_type* res_2 = reinterpret_cast<v2_type*>(res);
+      for (int64_t h = 0; h < half_rotary_dim; ++h) {
+        T c = cos_cache_offset[h];
+        T s = sin_cache_offset[h];
+        v2_type c2 = {c, c};
+        v2_type s2 = {s, s};
+        v2_type t = pe_2[h];
+        v2_type* dst = &res_2[h];
+        v2_type tr = {-t[1], t[0]};
+        *dst = t * c2 + tr * s2;
+      }
+    }
+  }
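+
+  // Work-item mapping used by operator() below: global dim 0 walks the
+  // batch, and head_idx = sub_group_index * sg_size + lane selects a head.
+  // Heads [0, q_num_head) are query heads, heads
+  // [q_num_head, q_num_head + k_num_head) are key heads, and any padding
+  // lanes beyond the last key head fall through both branches.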
+
+  [[sycl::reqd_sub_group_size(sg_size)]] void operator()(
+      sycl::nd_item<3> idx) const {
+    int64_t batch_idx = idx.get_global_id(0);
+    int64_t sg_idx = idx.get_local_id(1);
+    int64_t local_id = idx.get_global_id(2);
+    int64_t head_idx = sg_idx * sg_size + local_id;
+    int64_t qo_idx = batch_idx * q_num_head * head_size + head_idx * head_size;
+    int64_t ko_idx = batch_idx * k_num_head * head_size +
+                     (head_idx - q_num_head) * head_size;
+    int64_t qi_idx = batch_idx * q_batch_d + head_idx * q_num_head_d;
+    int64_t ki_idx =
+        batch_idx * k_batch_d + (head_idx - q_num_head) * k_num_head_d;
+    if (head_idx < q_num_head) {
+      rotary_embedding_kernel(positions[batch_idx], &query[qi_idx],
+                              cos_sin_cache, &query_out[qo_idx]);
+    } else if (head_idx < q_num_head + k_num_head) {
+      rotary_embedding_kernel(positions[batch_idx], &key[ki_idx],
+                              cos_sin_cache, &key_out[ko_idx]);
+    }
+  }
+
+ private:
+  const int64_t* positions;
+  const T* query;
+  const T* key;
+  const int64_t* offsets;
+  const T* cos_sin_cache;
+  T* query_out;
+  T* key_out;
+  const int64_t batch;
+  const int64_t q_num_head;
+  const int64_t k_num_head;
+  const int64_t head_size;
+  const int64_t q_num_head_d;
+  const int64_t q_batch_d;
+  const int64_t k_num_head_d;
+  const int64_t k_batch_d;
+};
+
+}  // namespace vllm
+
+template <typename T>
+void call_deepseek_scaling_rope(const int64_t* positions, const T* query,
+                                const T* key, const int64_t* offsets,
+                                const T* cos_sin_cache, T* query_out,
+                                T* key_out, int64_t batch, int64_t q_num_head,
+                                int64_t k_num_head, int64_t head_size,
+                                int64_t rotary_dim, bool is_neox,
+                                int64_t q_num_head_d, int64_t q_batch_d,
+                                int64_t k_num_head_d, int64_t k_batch_d) {
+  static constexpr std::array<int64_t, 5> allowed_dims = {32, 64, 96, 128,
+                                                          256};
+  auto it = std::find(allowed_dims.begin(), allowed_dims.end(), rotary_dim);
+
+  TORCH_CHECK(it != allowed_dims.end(), "Invalid rotary_dim (", rotary_dim,
+              "). Supported: 32,64,96,128,256");
+  TORCH_CHECK(rotary_dim == head_size, "rotary_dim (", rotary_dim,
+              ") must equal head_size (", head_size, ")");
+
+  const int rot_idx = std::distance(allowed_dims.begin(), it);
+  const int neox_idx = is_neox ? 1 : 0;
+  const int func_idx = neox_idx * allowed_dims.size() + rot_idx;
+
+  using LaunchFn = void (*)(sycl::queue&, const int64_t*, const T*, const T*,
+                            const int64_t*, const T*, T*, T*, int64_t, int64_t,
+                            int64_t, int64_t, int64_t, int64_t, int64_t,
+                            int64_t);
+
+// Table builder macro
+#define REGISTER_CASE(dim, neox)                                          \
+  [](sycl::queue& q, const int64_t* pos, const T* q_in, const T* k_in,    \
+     const int64_t* off, const T* cache, T* q_out, T* k_out, int64_t b,   \
+     int64_t qh, int64_t kh, int64_t hs, int64_t qhd, int64_t qbd,        \
+     int64_t khd, int64_t kbd) {                                          \
+    constexpr int64_t sg_size = 16;                                       \
+    int64_t sg_per_heads = (qh + kh + sg_size - 1) / sg_size;             \
+    sycl::range<3> local(1, sg_per_heads, sg_size);                       \
+    sycl::range<3> global(b, sg_per_heads, sg_size);                      \
+    at::DeviceGuard dg(at::Device(at::kXPU, at::xpu::current_device()));  \
+    q.submit([&](sycl::handler& cgh) {                                    \
+      cgh.parallel_for(sycl::nd_range<3>(global, local),                  \
+                       vllm::deepseek_scaling_rope_kernel<T, dim, neox>{  \
+                           pos, q_in, k_in, off, cache, q_out, k_out, b,  \
+                           qh, kh, hs, qhd, qbd, khd, kbd});              \
+    });                                                                   \
+  }
+
+  static constexpr std::array<LaunchFn, 10> table = {
+      REGISTER_CASE(32, false),  REGISTER_CASE(64, false),
+      REGISTER_CASE(96, false),  REGISTER_CASE(128, false),
+      REGISTER_CASE(256, false), REGISTER_CASE(32, true),
+      REGISTER_CASE(64, true),   REGISTER_CASE(96, true),
+      REGISTER_CASE(128, true),  REGISTER_CASE(256, true),
+  };
+
+  auto& queue = vllm::xpu::vllmGetQueue();
+  table[func_idx](queue, positions, query, key, offsets, cos_sin_cache,
+                  query_out, key_out, batch, q_num_head, k_num_head, head_size,
+                  q_num_head_d, q_batch_d, k_num_head_d, k_batch_d);
+
+#undef REGISTER_CASE
+}
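+
+// Dispatch note: rotary_dim and the NeoX flag are compile-time template
+// parameters of deepseek_scaling_rope_kernel, so REGISTER_CASE instantiates
+// one launcher per (rotary_dim, is_neox) pair, and
+// func_idx = neox_idx * allowed_dims.size() + rot_idx selects the matching
+// entry from the flat table at runtime.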
+
+/**
+ * @brief Apply DeepSeek scaling rotary position embedding to query and key.
+ *
+ * @param positions position index of each token [batch]
+ * @param query query to be processed [batch, q_num_head, head_size]
+ * @param key key to be processed [batch, k_num_head, head_size]
+ * @param offsets optional tensor of offsets applied to positions
+ * @param cos_sin_cache shared cache of cos/sin values, one row per position
+ * @param rotary_dim rotary dimension; must equal head_size
+ * @param is_neox true for the NeoX (rotate-half) style, false for the GPT-J
+ *                (interleaved) style
+ * @return A tuple of tensors (query_out, key_out).
+ */
+std::tuple<at::Tensor, at::Tensor> deepseek_scaling_rope(
+    const torch::Tensor& positions, const torch::Tensor& query,
+    const torch::Tensor& key, const c10::optional<torch::Tensor>& offsets_opt,
+    const torch::Tensor& cos_sin_cache, int64_t rotary_dim, bool is_neox) {
+  auto query_out = at::empty_like(query);
+  auto key_out = at::empty_like(key);
+
+  auto q_shape = query.sizes();
+  auto q_stride = query.strides();
+  int64_t head_size = q_shape[2];
+  int64_t q_num_head = q_shape[1];
+  int64_t batch = q_shape[0];
+  int64_t q_num_head_d = q_stride[1];
+  int64_t q_batch_d = q_stride[0];
+  auto k_shape = key.sizes();
+  auto k_stride = key.strides();
+  int64_t k_num_head = k_shape[1];
+  int64_t k_num_head_d = k_stride[1];
+  int64_t k_batch_d = k_stride[0];
+  if (is_neox) {
+    query_out = query_out.reshape({1, batch, q_num_head, head_size});
+    key_out = key_out.reshape({1, batch, k_num_head, head_size});
+  }
+  TORCH_CHECK(cos_sin_cache.sizes()[1] == head_size,
+              "Rotary dim doesn't match query head_size");
+  TORCH_CHECK(cos_sin_cache.sizes()[1] == k_shape[2],
+              "Rotary dim doesn't match key head_size");
+  const c10::MaybeOwned<torch::Tensor> offsets_maybe_owned =
+      at::borrow_from_optional_tensor(offsets_opt);
+  const torch::Tensor& offsets = *offsets_maybe_owned;
+  auto offsets_ptr = offsets.defined() ? offsets.data_ptr() : nullptr;
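+  // Dtype dispatch: float32 / float16 / bfloat16 torch buffers are
+  // reinterpreted as float, sycl::half and sycl::ext::oneapi::bfloat16 so a
+  // single templated launcher serves all three element types.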
+  switch (query.scalar_type()) {
+    case torch::kFloat:
+      call_deepseek_scaling_rope<float>(
+          reinterpret_cast<const int64_t*>(positions.data_ptr()),
+          reinterpret_cast<const float*>(query.data_ptr()),
+          reinterpret_cast<const float*>(key.data_ptr()),
+          reinterpret_cast<const int64_t*>(offsets_ptr),
+          reinterpret_cast<const float*>(cos_sin_cache.data_ptr()),
+          reinterpret_cast<float*>(query_out.data_ptr()),
+          reinterpret_cast<float*>(key_out.data_ptr()), batch, q_num_head,
+          k_num_head, head_size, rotary_dim, is_neox, q_num_head_d, q_batch_d,
+          k_num_head_d, k_batch_d);
+      break;
+    case torch::kFloat16:
+      call_deepseek_scaling_rope<sycl::half>(
+          reinterpret_cast<const int64_t*>(positions.data_ptr()),
+          reinterpret_cast<const sycl::half*>(query.data_ptr()),
+          reinterpret_cast<const sycl::half*>(key.data_ptr()),
+          reinterpret_cast<const int64_t*>(offsets_ptr),
+          reinterpret_cast<const sycl::half*>(cos_sin_cache.data_ptr()),
+          reinterpret_cast<sycl::half*>(query_out.data_ptr()),
+          reinterpret_cast<sycl::half*>(key_out.data_ptr()), batch, q_num_head,
+          k_num_head, head_size, rotary_dim, is_neox, q_num_head_d, q_batch_d,
+          k_num_head_d, k_batch_d);
+      break;
+    case torch::kBFloat16:
+      call_deepseek_scaling_rope<sycl::ext::oneapi::bfloat16>(
+          reinterpret_cast<const int64_t*>(positions.data_ptr()),
+          reinterpret_cast<const sycl::ext::oneapi::bfloat16*>(
+              query.data_ptr()),
+          reinterpret_cast<const sycl::ext::oneapi::bfloat16*>(key.data_ptr()),
+          reinterpret_cast<const int64_t*>(offsets_ptr),
+          reinterpret_cast<const sycl::ext::oneapi::bfloat16*>(
+              cos_sin_cache.data_ptr()),
+          reinterpret_cast<sycl::ext::oneapi::bfloat16*>(query_out.data_ptr()),
+          reinterpret_cast<sycl::ext::oneapi::bfloat16*>(key_out.data_ptr()),
+          batch, q_num_head, k_num_head, head_size, rotary_dim, is_neox,
+          q_num_head_d, q_batch_d, k_num_head_d, k_batch_d);
+      break;
+    default:
+      throw std::invalid_argument(
+          "Invalid dtype, only supports float32, float16, and bfloat16");
+      break;
+  }
+  return {query_out, key_out};
+}
diff --git a/csrc/xpu/torch_bindings.cpp b/csrc/xpu/torch_bindings.cpp
index 5906365d0..c34a8d6f3 100644
--- a/csrc/xpu/torch_bindings.cpp
+++ b/csrc/xpu/torch_bindings.cpp
@@ -11,6 +11,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, xpu_ops) {
       "fp8_gemm_w8a16(Tensor! A, Tensor! B, bool trans_B, Tensor? B_scale_, "
       "Tensor? bias_) -> Tensor");
   xpu_ops.impl("fp8_gemm_w8a16", torch::kXPU, &fp8_gemm_w8a16);
+
+  xpu_ops.def(
+      "deepseek_scaling_rope(Tensor! positions, Tensor! query, Tensor! key, "
+      "Tensor? offsets_opt, Tensor! cos_sin_cache, int rotary_dim, bool "
+      "is_neox_style) "
+      "-> (Tensor, Tensor)");
+  xpu_ops.impl("deepseek_scaling_rope", torch::kXPU, &deepseek_scaling_rope);
 }
 
 REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
diff --git a/tests/register_ops.py b/tests/register_ops.py
index 88d3830cc..105defb4c 100644
--- a/tests/register_ops.py
+++ b/tests/register_ops.py
@@ -5,6 +5,7 @@ from typing import Optional
 
 import vllm_xpu_kernels._C  # noqa: F401
 import vllm_xpu_kernels._moe_C  # noqa: F401
+import vllm_xpu_kernels._xpu_C  # noqa: F401
 
 
 # layer norm ops
@@ -61,6 +62,20 @@ def rotary_embedding(
                          cos_sin_cache, is_neox)
 
 
+def deepseek_scaling_rope(
+    positions: torch.Tensor,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    offsets_opt: Optional[torch.Tensor],
+    cos_sin_cache: torch.Tensor,
+    rotary_dim: int,
+    is_neox_style: bool,
+) -> tuple[torch.Tensor, torch.Tensor]:
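+    """Apply DeepSeek scaling RoPE on XPU via the _xpu_C extension op.
+
+    positions indexes rows of cos_sin_cache; is_neox_style selects the
+    NeoX rotate-half layout (True) or the GPT-J interleaved layout (False).
+    Returns the rotated (query, key) pair.
+    """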
+    return torch.ops._xpu_C.deepseek_scaling_rope(positions, query, key,
+                                                  offsets_opt, cos_sin_cache,
+                                                  rotary_dim, is_neox_style)
+
+
 def reshape_and_cache(
     key: torch.Tensor,
     value: torch.Tensor,
diff --git a/tests/test_deepseek_scaling_rope.py b/tests/test_deepseek_scaling_rope.py
new file mode 100644
index 000000000..14fb36e2a
--- /dev/null
+++ b/tests/test_deepseek_scaling_rope.py
@@ -0,0 +1,115 @@
+# SPDX-License-Identifier: Apache-2.0
+import pytest
+import torch
+
+from tests.register_ops import deepseek_scaling_rope
+
+DEVICE = torch.device("xpu")
+
+
+def _rotate_neox(x):
+    x1 = x[..., :x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2:]
+    return torch.cat((-x2, x1), dim=-1)
+
+
+def _rotate_gptj(x):
+    x1 = x[..., ::2]
+    x2 = x[..., 1::2]
+    x = torch.stack((-x2, x1), dim=-1)
+    return x.flatten(-2)
+
+
+class TestTorchMethod:
+
+    def ref_deepseek_scaling_rope(
+        self,
+        positions,
+        query,
+        key,
+        cos_sin_cache,
+        rotary_dim,
+        head_size,
+        is_neox_style,
+        offsets=None,
+    ):
+        query_rot = query[..., :rotary_dim]
+        key_rot = key[..., :rotary_dim]
+        if rotary_dim < head_size:
+            query_pass = query[..., rotary_dim:]
+            key_pass = key[..., rotary_dim:]
+
+        cos_sin_cache = cos_sin_cache.to(positions.device)
+        idx = torch.add(positions,
+                        offsets) if offsets is not None else positions
+        cos_sin = cos_sin_cache[idx]
+        cos, sin = cos_sin.chunk(2, dim=-1)
+        if is_neox_style:
+            # NOTE(woosuk): Here we assume that the positions tensor has the
+            # shape [batch_size, seq_len].
+            cos = cos.repeat(1, 1, 2).unsqueeze(-2)
+            sin = sin.repeat(1, 1, 2).unsqueeze(-2)
+        else:
+            cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
+            sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
+        rotate_fn = _rotate_neox if is_neox_style else _rotate_gptj
+
+        query_rot = query_rot * cos + rotate_fn(query_rot) * sin
+        key_rot = key_rot * cos + rotate_fn(key_rot) * sin
+
+        if rotary_dim < head_size:
+            query = torch.cat((query_rot, query_pass), dim=-1)
+            key = torch.cat((key_rot, key_pass), dim=-1)
+        else:
+            query = query_rot
+            key = key_rot
+        return query, key
+
+    @pytest.mark.parametrize("seed", [123, 356, 478])
+    @pytest.mark.parametrize("dtype",
+                             [torch.float16, torch.bfloat16, torch.float32])
+    @pytest.mark.parametrize("batch", [1, 2, 16, 32])
+    @pytest.mark.parametrize("q_num_head,k_num_head", [(16, 1), (32, 1)])
+    @pytest.mark.parametrize("rotary_dim", [64, 128])
+    @pytest.mark.parametrize("q_head_pad,k_head_pad", [(0, 0), (128, 512)])
+    @pytest.mark.parametrize("is_neox", [True, False])
+    def test_deepseek_scaling_rope(
+        self,
+        seed,
+        dtype,
+        batch,
+        q_num_head,
+        k_num_head,
+        rotary_dim,
+        q_head_pad,
+        k_head_pad,
+        is_neox,
+    ):
+        torch.manual_seed(seed)
+        # The kernel requires rotary_dim == head_size; the reference
+        # implementation misbehaves when rotary_dim < head_size, so that
+        # case is intentionally not exercised here.
+        head_size = rotary_dim
+        positions = torch.randint(0, batch * 10000, (batch, ), device=DEVICE)
+        cos_sin_cache = torch.randn(batch * 10000, rotary_dim,
+                                    device=DEVICE).to(dtype)
+        q_head_size_pad = q_head_pad + head_size
+        k_head_size_pad = k_head_pad + head_size
+        query_pad = torch.randn(batch,
+                                q_num_head,
+                                q_head_size_pad,
+                                device=DEVICE).to(dtype)
+        key_pad = torch.randn(batch,
+                              k_num_head,
+                              k_head_size_pad,
+                              device=DEVICE).to(dtype)
+        query = query_pad[..., :head_size]
+        key = key_pad[..., :head_size]
+        ref_query, ref_key = self.ref_deepseek_scaling_rope(
+            positions, query, key, cos_sin_cache, rotary_dim, head_size,
+            is_neox)
+        query_out, key_out = deepseek_scaling_rope(positions, query, key,
+                                                   None, cos_sin_cache,
+                                                   rotary_dim, is_neox)
+        torch.testing.assert_close(ref_query, query_out, atol=5e-3, rtol=1e-3)
+        torch.testing.assert_close(ref_key, key_out, atol=5e-3, rtol=1e-3)
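
For reference, a minimal sketch of how the new op is driven end to end, mirroring tests/test_deepseek_scaling_rope.py. This assumes a machine with an XPU device and an installed vllm_xpu_kernels build; it is an illustration, not part of the patch:

    import torch
    from tests.register_ops import deepseek_scaling_rope

    batch, q_heads, k_heads, rotary_dim = 4, 16, 1, 64
    dev = torch.device("xpu")
    positions = torch.randint(0, 1024, (batch, ), device=dev)
    # Each cache row (width rotary_dim) holds the cos half then the sin half.
    cos_sin_cache = torch.randn(1024, rotary_dim, device=dev,
                                dtype=torch.float16)
    query = torch.randn(batch, q_heads, rotary_dim, device=dev,
                        dtype=torch.float16)
    key = torch.randn(batch, k_heads, rotary_dim, device=dev,
                      dtype=torch.float16)

    # GPT-J (interleaved) style; outputs keep the [batch, heads, head_size]
    # shape of the inputs. With is_neox_style=True the op instead returns
    # tensors reshaped to [1, batch, heads, head_size].
    q_out, k_out = deepseek_scaling_rope(positions, query, key, None,
                                         cos_sin_cache, rotary_dim, False)
    print(q_out.shape, k_out.shape)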