diff --git a/python/sglang/srt/layers/attention/vision.py b/python/sglang/srt/layers/attention/vision.py
index 489b8248b692..3696666a95bf 100644
--- a/python/sglang/srt/layers/attention/vision.py
+++ b/python/sglang/srt/layers/attention/vision.py
@@ -44,7 +44,7 @@
     RowParallelLinear,
 )
 from sglang.srt.layers.quantization import QuantizationConfig
-from sglang.srt.layers.rotary_embedding import apply_rotary_pos_emb
+from sglang.srt.layers.rotary_embedding import RotaryEmbedding, apply_rotary_pos_emb
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.utils import add_prefix
 
@@ -486,6 +486,15 @@ def __init__(
             softmax_in_single_precision=softmax_in_single_precision,
         )
 
+        self.rotary_emb = RotaryEmbedding(
+            head_size=self.head_size,
+            rotary_dim=self.head_size,
+            max_position_embeddings=2048,
+            base=10000,
+            is_neox_style=False,
+            dtype=torch.get_default_dtype(),
+        )
+
         self.use_qkv_parallel = use_qkv_parallel
         if use_qkv_parallel:
             self.qkv_proj = QKVParallelLinear(
@@ -626,13 +635,17 @@ def forward(
                 q = q.view(original_shape)
                 k = k.view(original_shape)
             else:
-                cos, sin = position_embeddings
-
                 # [total_tokens, head, head_size]
                 q = q.view(-1, head, self.head_size)
                 k = k.view(-1, head, self.head_size)
 
-                q, k = apply_rotary_pos_emb(q, k, cos, sin)
+                (cos, sin) = position_embeddings
+                position_embeddings = (
+                    cos.float().contiguous(),
+                    sin.float().contiguous(),
+                )
+
+                q, k = self.rotary_emb(position_embeddings, q, k)
 
                 q = q.view(original_shape)
                 k = k.view(original_shape)
diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py
index 91e58f6a085e..0fee8a946f3a 100644
--- a/python/sglang/srt/layers/rotary_embedding.py
+++ b/python/sglang/srt/layers/rotary_embedding.py
@@ -111,9 +111,9 @@ def __init__(
         if (
             not (_is_cuda or _is_npu) or self.head_size not in [64, 128, 256, 512]
         ) and not (_is_cpu and _is_cpu_amx_available):
-            from vllm._custom_ops import rotary_embedding
+            from sgl_kernel import rotary_embedding
 
-            self.vllm_rotary_embedding = rotary_embedding
+            self.sglang_rotary_embedding = rotary_embedding
 
         self.cos_sin_cache: torch.Tensor
         self.register_buffer("cos_sin_cache", cache, persistent=False)
@@ -266,16 +266,28 @@ def forward_cuda(
         else:
             assert (
                 fused_set_kv_buffer_arg is None
-            ), "save kv cache is not supported for vllm_rotary_embedding."
-            self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
-            self.vllm_rotary_embedding(
-                positions,
+            ), "save kv cache is not supported for sglang_rotary_embedding."
+
+            # NOTE(review): this fallback unpacks `positions` as a (cos, sin) pair of
+            # contiguous float32 tensors; callers passing position ids must not reach it.
+            cos, sin = positions
+            assert cos.dtype == torch.float and cos.is_contiguous()
+            assert sin.dtype == torch.float and sin.is_contiguous()
+            orig_q_dtype = query.dtype
+            orig_k_dtype = key.dtype
+            query, key = query.float(), key.float()
+
+            self.sglang_rotary_embedding(
+                cos,
+                sin,
                 query,
                 key,
                 self.head_size,
-                self.cos_sin_cache,
                 self.is_neox_style,
             )
+
+            query = query.to(dtype=orig_q_dtype)
+            key = key.to(dtype=orig_k_dtype)
         return query, key
 
     def extra_repr(self) -> str:
diff --git a/sgl-kernel/CMakeLists.txt b/sgl-kernel/CMakeLists.txt
index f3a6a94fb08d..90ca504ea6e4 100644
--- a/sgl-kernel/CMakeLists.txt
+++ b/sgl-kernel/CMakeLists.txt
@@ -310,10 +310,10 @@ set(SOURCES
   "csrc/moe/nvfp4_blockwise_moe.cu"
   "csrc/moe/fp8_blockwise_moe_kernel.cu"
   "csrc/moe/prepare_moe_input.cu"
+  "csrc/multimodal/rotary_embedding.cu"
 
   "csrc/memory/store.cu"
   "csrc/kvcacheio/transfer.cu"
-
   "csrc/speculative/eagle_utils.cu"
   "csrc/speculative/ngram_utils.cu"
   "csrc/speculative/packbit.cu"
diff --git a/sgl-kernel/csrc/common_extension.cc b/sgl-kernel/csrc/common_extension.cc
index 48968a64cf3b..21fd0e1f809f 100644
--- a/sgl-kernel/csrc/common_extension.cc
+++ b/sgl-kernel/csrc/common_extension.cc
@@ -470,6 +470,17 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
       "Tensor _ascales, Tensor! _out_feats) -> ()");
   m.impl("qserve_w4a8_per_group_gemm", torch::kCUDA, &qserve_w4a8_per_group_gemm);
 
+  // Rotary embedding kernel: cos/sin are read-only, query and key are rotated in place.
+  m.def(
+      "rotary_embedding("
+      " Tensor cos_cache, "
+      " Tensor sin_cache, "
+      " Tensor! query,"
+      " Tensor!? key, "
+      " int head_size, "
+      " bool is_neox) -> ()");
+  m.impl("rotary_embedding", torch::kCUDA, &rotary_embedding);
+
   /*
    * From csrc/mamba
    */
diff --git a/sgl-kernel/csrc/multimodal/rotary_embedding.cu b/sgl-kernel/csrc/multimodal/rotary_embedding.cu
new file mode 100644
index 000000000000..105c58e99bc5
--- /dev/null
+++ b/sgl-kernel/csrc/multimodal/rotary_embedding.cu
@@ -0,0 +1,264 @@
+/*
+ * Copyright (c) 2025 by SGLang team.
+ * Copyright (c) 2025 by FlashInfer team.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <cuda_runtime.h>
+#include <torch/all.h>
+
+#include <algorithm>
+#include <optional>
+
+#include "utils.h"
+
+template <typename scalar_t, bool IS_NEOX>
+inline __device__ void apply_token_rotary_embedding(
+    scalar_t* __restrict__ arr,
+    const scalar_t* __restrict__ cos_ptr,
+    const scalar_t* __restrict__ sin_ptr,
+    int rot_offset,
+    int embed_dim) {
+  int x_index, y_index;
+
+  if (IS_NEOX) {
+    x_index = rot_offset;
+    y_index = embed_dim + rot_offset;
+
+    scalar_t cos_val = SGLANG_LDG(cos_ptr + rot_offset);
+    scalar_t sin_val = SGLANG_LDG(sin_ptr + rot_offset);
+
+    const scalar_t x = arr[x_index];
+    const scalar_t y = arr[y_index];
+    arr[x_index] = x * cos_val - y * sin_val;
+    arr[y_index] = y * cos_val + x * sin_val;
+
+  } else {
+    // Rotate-half layout with full-width cos/sin rows: the x and y halves each read their own cos/sin entry.
+    x_index = rot_offset;              // first half
+    y_index = rot_offset + embed_dim;  // second half
+
+    const scalar_t cos_val_x = SGLANG_LDG(cos_ptr + rot_offset);
+    const scalar_t sin_val_x = SGLANG_LDG(sin_ptr + rot_offset);
+    const scalar_t cos_val_y = SGLANG_LDG(cos_ptr + rot_offset + embed_dim);
+    const scalar_t sin_val_y = SGLANG_LDG(sin_ptr + rot_offset + embed_dim);
+
+    const scalar_t x = arr[x_index];
+    const scalar_t y = arr[y_index];
+    arr[x_index] = x * cos_val_x - y * sin_val_x;
+    arr[y_index] = y * cos_val_y + x * sin_val_y;
+  }
+}
+
+template <typename scalar_t, bool IS_NEOX>
+inline __device__ void apply_rotary_embedding(
+    scalar_t* __restrict__ query,                        // [num_heads, head_size]
+    scalar_t* __restrict__ key,                          // [num_kv_heads, head_size]
+    const scalar_t* __restrict__ current_token_cos_ptr,  // [rot_dim]
+    const scalar_t* __restrict__ current_token_sin_ptr,  // [rot_dim]
+    const int head_size,
+    const int num_heads,
+    const int num_kv_heads,
+    const int rot_dim,
+    const int64_t head_stride_query,
+    const int64_t head_stride_key) {
+  const int embed_dim_for_rotation = rot_dim / 2;
+
+  const int nq_pairs = num_heads * embed_dim_for_rotation;
+  for (int i = threadIdx.x; i < nq_pairs; i += blockDim.x) {
+    const int head_idx = i / embed_dim_for_rotation;
+    const int rot_offset = i % embed_dim_for_rotation;
+
+    scalar_t* query_for_token_head = query + head_idx * head_stride_query;  // 64-bit offset
+
+    apply_token_rotary_embedding<scalar_t, IS_NEOX>(
+        query_for_token_head, current_token_cos_ptr, current_token_sin_ptr, rot_offset, embed_dim_for_rotation);
+  }
+
+  if (key != nullptr) {
+    const int nk_pairs = num_kv_heads * embed_dim_for_rotation;
+    for (int i = threadIdx.x; i < nk_pairs; i += blockDim.x) {
+      const int head_idx = i / embed_dim_for_rotation;
+      const int rot_offset = i % embed_dim_for_rotation;
+
+      scalar_t* key_for_token_head = key + head_idx * head_stride_key;  // 64-bit offset
+
+      apply_token_rotary_embedding<scalar_t, IS_NEOX>(
+          key_for_token_head, current_token_cos_ptr, current_token_sin_ptr, rot_offset, embed_dim_for_rotation);
+    }
+  }
+}
+
+template <typename scalar_t, bool IS_NEOX>
+__global__ void rotary_embedding_kernel(
+    const scalar_t* __restrict__ cos_data,  // [num_tokens, rot_dim_arg]
+    const scalar_t* __restrict__ sin_data,  // [num_tokens, rot_dim_arg]
+    scalar_t* __restrict__ query_total,
+    scalar_t* __restrict__ key_total,
+    const int rot_dim_arg,
+    const int64_t query_token_stride,
+    const int64_t key_token_stride,
+    const int64_t head_stride_query,
+    const int64_t head_stride_key,
+    const int num_heads,
+    const int num_kv_heads,
+    const int head_size) {
+  const int64_t token_idx = blockIdx.x;  // 64-bit so token offsets below cannot overflow int
+  const scalar_t* current_token_cos_ptr = cos_data + token_idx * rot_dim_arg;
+  const scalar_t* current_token_sin_ptr = sin_data + token_idx * rot_dim_arg;
+
+  scalar_t* query_for_token = query_total + token_idx * query_token_stride;
+  scalar_t* key_for_token = (key_total != nullptr) ? (key_total + token_idx * key_token_stride) : nullptr;
+
+  apply_rotary_embedding<scalar_t, IS_NEOX>(
+      query_for_token,
+      key_for_token,
+      current_token_cos_ptr,
+      current_token_sin_ptr,
+      head_size,
+      num_heads,
+      num_kv_heads,
+      rot_dim_arg,
+      head_stride_query,
+      head_stride_key);
+}
+
+void rotary_embedding(
+    at::Tensor& cos,
+    at::Tensor& sin,
+    at::Tensor& query,
+    const std::optional<at::Tensor>& key,
+    int64_t head_size,
+    bool is_neox) {
+  TORCH_CHECK(
+      query.dim() == 2 || query.dim() == 3,
+      "query must be in shape [num_tokens, hidden_size] or [num_tokens, num_heads, head_size]");
+  if (key.has_value()) {
+    TORCH_CHECK(
+        key->dim() == 2 || key->dim() == 3,
+        "key must be in shape [num_tokens, hidden_size] or [num_tokens, num_kv_heads, head_size]");
+  }
+
+  int64_t num_tokens = query.size(0);
+
+  TORCH_CHECK(cos.dim() == 2, "cos must be in shape [num_tokens, D_cos]");
+  TORCH_CHECK(sin.dim() == 2, "sin must be in shape [num_tokens, D_sin]");
+  TORCH_CHECK(cos.size(0) == num_tokens, "cos num_tokens mismatch with query");
+  TORCH_CHECK(sin.size(0) == num_tokens, "sin num_tokens mismatch with query");
+  TORCH_CHECK(cos.size(1) == sin.size(1), "cos and sin D_cos/D_sin mismatch");
+  TORCH_CHECK(cos.is_contiguous() && sin.is_contiguous(), "cos and sin must be contiguous");
+  TORCH_CHECK(cos.size(1) <= head_size, "cos/sin last dim must not exceed head_size");
+
+  TORCH_CHECK(cos.scalar_type() == query.scalar_type(), "cos dtype mismatch");
+  TORCH_CHECK(sin.scalar_type() == query.scalar_type(), "sin dtype mismatch");
+  TORCH_CHECK(cos.is_cuda() && sin.is_cuda() && query.is_cuda(), "All tensors must be on CUDA");
+  if (key.has_value()) {
+    TORCH_CHECK(key->is_cuda(), "Key tensor must be on CUDA if provided");
+    TORCH_CHECK(key->scalar_type() == query.scalar_type(), "Key dtype mismatch");
+  }
+
+  int query_hidden_size_calculated;
+  if (query.dim() == 2) {
+    query_hidden_size_calculated = (int)query.size(1);
+  } else {
+    query_hidden_size_calculated = (int)query.size(1) * (int)query.size(2);
+    TORCH_CHECK(query.size(2) == head_size, "Query head_size mismatch in 3D tensor");
+  }
+  TORCH_CHECK(query_hidden_size_calculated % head_size == 0, "query_hidden_size not divisible by head_size");
+  int num_heads = (int)query_hidden_size_calculated / (int)head_size;
+
+  int key_hidden_size_calculated = 0;
+  int num_kv_heads = num_heads;
+  if (key.has_value()) {
+    TORCH_CHECK((int)key->size(0) == num_tokens, "Key num_tokens mismatch");
+    if (key->dim() == 2) {
+      key_hidden_size_calculated = (int)key->size(1);
+    } else {
+      key_hidden_size_calculated = (int)key->size(1) * (int)key->size(2);
+      TORCH_CHECK((int)key->size(2) == head_size, "Key head_size mismatch in 3D tensor");
+    }
+    TORCH_CHECK(key_hidden_size_calculated % head_size == 0, "key_hidden_size not divisible by head_size");
+    num_kv_heads = key_hidden_size_calculated / (int)head_size;
+  }
+  TORCH_CHECK(num_heads % num_kv_heads == 0, "num_heads must be divisible by num_kv_heads");
+
+  int rot_dim_from_cache = (int)cos.size(1);
+
+  int64_t query_token_stride = query.stride(0);  // real stride, so sliced/strided views address correctly
+  int64_t key_token_stride = key.has_value() ? key->stride(0) : 0;
+
+  int64_t head_stride_query;
+  if (query.dim() == 3 && query.size(1) == num_heads && query.size(2) == head_size) {
+    head_stride_query = query.stride(1);
+  } else {
+    head_stride_query = head_size;
+  }
+
+  int64_t head_stride_key = head_size;
+  if (key.has_value()) {
+    if (key->dim() == 3 && key->size(1) == num_kv_heads && key->size(2) == head_size) {
+      head_stride_key = key->stride(1);
+    } else {
+      head_stride_key = head_size;
+    }
+  }
+
+  // Nothing to rotate; a zero-sized grid is not a valid launch configuration.
+  if (num_tokens == 0) return;
+
+  dim3 grid((int)num_tokens);
+
+  int embed_dim_for_block_calc = rot_dim_from_cache / 2;
+  int max_pairs_to_rotate_per_token =
+      std::max(num_heads * embed_dim_for_block_calc, num_kv_heads * embed_dim_for_block_calc);
+  dim3 block(std::min(max_pairs_to_rotate_per_token, 512));
+
+  if (block.x == 0 && num_tokens > 0) block.x = 1;
+
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  SGLANG_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] {
+    if (is_neox) {
+      rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>(
+          cos.data_ptr<scalar_t>(),
+          sin.data_ptr<scalar_t>(),
+          query.data_ptr<scalar_t>(),
+          key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
+          rot_dim_from_cache,
+          query_token_stride,
+          key_token_stride,
+          head_stride_query,
+          head_stride_key,
+          num_heads,
+          num_kv_heads,
+          (int)head_size);
+    } else {
+      rotary_embedding_kernel<scalar_t, false><<<grid, block, 0, stream>>>(
+          cos.data_ptr<scalar_t>(),
+          sin.data_ptr<scalar_t>(),
+          query.data_ptr<scalar_t>(),
+          key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
+          rot_dim_from_cache,
+          query_token_stride,
+          key_token_stride,
+          head_stride_query,
+          head_stride_key,
+          num_heads,
+          num_kv_heads,
+          (int)head_size);
+    }
+  });
+}
diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h
index d316e4248de4..0113ae9bec8c 100644
--- a/sgl-kernel/include/sgl_kernel_ops.h
+++ b/sgl-kernel/include/sgl_kernel_ops.h
@@ -654,6 +654,13 @@ void top_p_sampling_from_probs(
 
 void top_k_mask_logits(
     at::Tensor logits, at::Tensor mask_logits, std::optional<at::Tensor> maybe_top_k_arr, int64_t top_k_val);
+void rotary_embedding(
+    at::Tensor& cos_cache,                 // [num_tokens, rot_dim], read-only
+    at::Tensor& sin_cache,                 // [num_tokens, rot_dim], read-only
+    at::Tensor& query,                     // [num_tokens, num_heads * head_size] or 3D; rotated in place
+    const std::optional<at::Tensor>& key,  // null or laid out like query; rotated in place
+    int64_t head_size,
+    bool is_neox);
 
 namespace flash {
 /*
diff --git a/sgl-kernel/include/utils.h b/sgl-kernel/include/utils.h
index 5cab0786c4d1..501e8dc19577 100644
--- a/sgl-kernel/include/utils.h
+++ b/sgl-kernel/include/utils.h
@@ -326,11 +326,21 @@ inline bool getEnvEnablePDL() {
 #define DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
   AT_DISPATCH_SWITCH(TYPE, NAME, DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
 
+#define SGLANG_DISPATCH_CASE_FLOATING_TYPES(...)       \
+  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
+  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)  \
+  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
+
+#define SGLANG_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
+  AT_DISPATCH_SWITCH(TYPE, NAME, SGLANG_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
+
 #define CEILDIV(x, y) (((x) + (y) - 1) / (y))
 
 #ifndef USE_ROCM
+#define SGLANG_LDG(arg) __ldg(arg)
 #define WARP_SIZE 32
 #else
+#define SGLANG_LDG(arg) (*(arg))
 #if defined(__GFX9__) || !defined(__HIP_DEVICE_COMPILE__)
 #define WARP_SIZE 64
 #else
diff --git a/sgl-kernel/python/sgl_kernel/__init__.py b/sgl-kernel/python/sgl_kernel/__init__.py
index a53b0256788f..260d51509043 100644
--- a/sgl-kernel/python/sgl_kernel/__init__.py
+++ b/sgl-kernel/python/sgl_kernel/__init__.py
@@ -293,6 +293,7 @@ def _find_cuda_home():
     prepare_moe_input,
     topk_softmax,
 )
+from sgl_kernel.rotary_embedding import rotary_embedding
 from sgl_kernel.sampling import (
     min_p_sampling_from_probs,
     top_k_mask_logits,
diff --git a/sgl-kernel/python/sgl_kernel/rotary_embedding.py b/sgl-kernel/python/sgl_kernel/rotary_embedding.py
new file mode 100644
index 000000000000..6d5072e68461
--- /dev/null
+++ b/sgl-kernel/python/sgl_kernel/rotary_embedding.py
@@ -0,0 +1,33 @@
+from typing import Optional
+
+import torch
+
+
+# Adapted from https://github.com/vllm-project/vllm/blob/9214e60631a79506e7669650de87806a123e0b0b/vllm/_custom_ops.py#L249
+# pos encoding ops
+def rotary_embedding(
+    cos: torch.Tensor,
+    sin: torch.Tensor,
+    query: torch.Tensor,
+    key: Optional[torch.Tensor],
+    head_size: int,
+    is_neox: bool,
+) -> None:
+    """Apply rotary position embedding to ``query`` (and ``key``) in place.
+
+    Args:
+        cos: Per-token cosine table of shape ``[num_tokens, rot_dim]``.
+        sin: Per-token sine table with the same shape as ``cos``.
+        query: ``[num_tokens, num_heads * head_size]`` or
+            ``[num_tokens, num_heads, head_size]``; overwritten in place.
+        key: Optional tensor laid out like ``query``; overwritten in place.
+        head_size: Size of one attention head.
+        is_neox: Selects the kernel's NeoX code path when true.
+    """
+    torch.ops.sgl_kernel.rotary_embedding.default(
+        cos,
+        sin,
+        query,
+        key,
+        head_size,
+        is_neox,
+    )
diff --git a/sgl-kernel/tests/test_mm_rotary_embedding.py b/sgl-kernel/tests/test_mm_rotary_embedding.py
new file mode 100644
index 000000000000..14d23ced72fe
--- /dev/null
+++ b/sgl-kernel/tests/test_mm_rotary_embedding.py
@@ -0,0 +1,252 @@
+from typing import Tuple, Union
+
+import pytest
+import torch
+from sgl_kernel import rotary_embedding
+
+
+# torch native
+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    cos: torch.Tensor,
+    sin: torch.Tensor,
+    unsqueeze_dim=1,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    orig_q_dtype = q.dtype
+    orig_k_dtype = k.dtype
+    q, k = q.float(), k.float()
+
+    # embedding is performed in float
+    cos = cos.unsqueeze(unsqueeze_dim).float()
+    sin = sin.unsqueeze(unsqueeze_dim).float()
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    q_embed = q_embed.to(orig_q_dtype)
+    k_embed = k_embed.to(orig_k_dtype)
+
+    return q_embed, k_embed
+
+
+def make_cos_sin(
+    num_tokens: int,
+    head_size: int,
+    dtype: torch.dtype,
+    device: str,
+    duplicate_halves: bool,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Build random cos/sin tables of shape [num_tokens, head_size].
+
+    The kernel's NeoX path reads only the first half of each row and reuses it
+    for both rotated halves, so it matches the rotate-half reference only when
+    the two halves of every row are identical.
+    """
+    if not duplicate_halves:
+        cos = torch.randn(num_tokens, head_size, dtype=dtype, device=device)
+        sin = torch.randn(num_tokens, head_size, dtype=dtype, device=device)
+        return cos, sin
+    half_cos = torch.randn(num_tokens, head_size // 2, dtype=dtype, device=device)
+    half_sin = torch.randn(num_tokens, head_size // 2, dtype=dtype, device=device)
+    cos = torch.cat((half_cos, half_cos), dim=-1)
+    sin = torch.cat((half_sin, half_sin), dim=-1)
+    return cos, sin
+
+
+class RotaryEmbedding(torch.nn.Module):
+    # Reference: https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/rotary_embedding.py
+    def __init__(
+        self,
+        head_size: int,
+        rotary_dim: int,
+        max_position_embeddings: int,
+        base: int,
+        dtype: torch.dtype,
+        is_neox_style: bool = False,
+    ) -> None:
+        super().__init__()
+        self.head_size = head_size
+        self.rotary_dim = rotary_dim
+        self.max_position_embeddings = max_position_embeddings
+        self.base = base
+        self.is_neox_style = is_neox_style
+        self.dtype = dtype
+
+        cache = self._compute_cos_sin_cache()
+        self.cos_sin_cache: torch.Tensor
+        self.register_buffer("cos_sin_cache", cache, persistent=False)
+
+    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
+        inv_freq = 1.0 / (
+            base
+            ** (
+                torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim
+            )
+        )
+        return inv_freq
+
+    def _compute_cos_sin_cache(self) -> torch.Tensor:
+        """Compute the cos and sin cache."""
+        inv_freq = self._compute_inv_freq(self.base)
+        t = torch.arange(self.max_position_embeddings, dtype=torch.float)
+
+        freqs = torch.einsum("i,j -> ij", t, inv_freq)
+        cos = freqs.cos()
+        sin = freqs.sin()
+        cache = torch.cat((cos, sin), dim=-1)
+        return cache
+
+    def forward_native(
+        self,
+        cos: torch.Tensor,
+        sin: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """A PyTorch-native implementation of forward()."""
+        query, key = apply_rotary_pos_emb(query, key, cos, sin)
+
+        query = query.to(self.dtype)
+        key = key.to(self.dtype)
+        return query, key
+
+    def forward_kernel_inplace(
+        self,
+        cos: torch.Tensor,
+        sin: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Run the sgl_kernel rotary_embedding op on float32 copies of the inputs."""
+        cos, sin = cos.float(), sin.float()
+        query, key = query.float(), key.float()
+        rotary_embedding(
+            cos,
+            sin,
+            query,
+            key,
+            self.head_size,
+            self.is_neox_style,
+        )
+        query = query.to(self.dtype)
+        key = key.to(self.dtype)
+        return query, key
+
+
+@pytest.mark.benchmark(group="rotary_embedding")
+@pytest.mark.parametrize(
+    "head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype, device, batch_size, seq_len, num_q_heads, num_kv_heads",
+    [
+        (80, 80, 1e6, 1e6, False, torch.bfloat16, "cuda", 32, 32, 16, 16),
+        (320, 230, 1e6, 1e6, False, torch.bfloat16, "cuda", 32, 32, 16, 16),
+        (80, 80, 1e6, 1e6, True, torch.bfloat16, "cuda", 32, 32, 16, 16),
+    ],
+)
+def test_correctness(
+    head_size: int,
+    rotary_dim: int,
+    max_position_embeddings: int,
+    base: int,
+    is_neox_style: bool,
+    dtype: torch.dtype,
+    device: str,
+    batch_size: int,
+    seq_len: int,
+    num_q_heads: int,
+    num_kv_heads: int,
+):
+    rope_ref = RotaryEmbedding(
+        head_size,
+        rotary_dim,
+        max_position_embeddings,
+        base,
+        dtype=dtype,
+        is_neox_style=is_neox_style,
+    ).to(device)
+
+    query = torch.randn(
+        batch_size * seq_len, num_q_heads, head_size, dtype=dtype, device=device
+    )
+
+    key = torch.randn(
+        batch_size * seq_len, num_kv_heads, head_size, dtype=dtype, device=device
+    )
+
+    cos, sin = make_cos_sin(
+        batch_size * seq_len, head_size, dtype, device, duplicate_halves=is_neox_style
+    )
+
+    query_native_out, key_native_out = rope_ref.forward_native(
+        cos, sin, query.clone(), key.clone()
+    )
+
+    # in-place
+    query_kernel_out, key_kernel_out = rope_ref.forward_kernel_inplace(
+        cos, sin, query, key
+    )
+
+    torch.testing.assert_close(query_native_out, query_kernel_out, atol=1e-3, rtol=1e-3)
+    torch.testing.assert_close(key_native_out, key_kernel_out, atol=1e-3, rtol=1e-3)
+
+
+@pytest.mark.benchmark(group="rotary_embedding")
+@pytest.mark.parametrize(
+    "head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype, device, batch_size, seq_len, num_q_heads, num_kv_heads",
+    [
+        (80, 80, 1e6, 1e6, False, torch.bfloat16, "cuda", 1, 8840, 16, 16),
+        (80, 80, 1e6, 1e6, False, torch.bfloat16, "cuda", 1, 4000, 16, 16),
+        (80, 80, 1e6, 1e6, True, torch.bfloat16, "cuda", 8, 8840, 16, 16),
+    ],
+)
+def test_rotary_embedding_benchmark(
+    benchmark,
+    head_size: int,
+    rotary_dim: int,
+    max_position_embeddings: int,
+    base: int,
+    is_neox_style: bool,
+    dtype: torch.dtype,
+    device: str,
+    batch_size: int,
+    seq_len: int,
+    num_q_heads: int,
+    num_kv_heads: int,
+):
+    rope_ref = RotaryEmbedding(
+        head_size,
+        rotary_dim,
+        max_position_embeddings,
+        base,
+        dtype=dtype,
+        is_neox_style=is_neox_style,
+    ).to(device)
+    query = torch.randn(
+        batch_size * seq_len, num_q_heads, head_size, dtype=dtype, device=device
+    )
+    key = torch.randn(
+        batch_size * seq_len, num_kv_heads, head_size, dtype=dtype, device=device
+    )
+    cos, sin = make_cos_sin(
+        batch_size * seq_len, head_size, dtype, device, duplicate_halves=is_neox_style
+    )
+
+    def run_kernel():
+        rope_ref.forward_kernel_inplace(
+            cos,
+            sin,
+            query,
+            key,
+        )
+        torch.cuda.synchronize()
+
+    benchmark.pedantic(run_kernel, rounds=20000, warmup_rounds=5)
+
+
+if __name__ == "__main__":
+    pytest.main([__file__, "--capture=no"])