diff --git a/python/sglang/jit_kernel/csrc/elementwise/pos_enc.cuh b/python/sglang/jit_kernel/csrc/elementwise/pos_enc.cuh new file mode 100644 index 000000000000..9272e6248243 --- /dev/null +++ b/python/sglang/jit_kernel/csrc/elementwise/pos_enc.cuh @@ -0,0 +1,313 @@ +// Adapted from +// https://github.com/vllm-project/vllm/blob/014ece97c7aa49084a1119dca792af081a18dbc1/csrc/pos_encoding_kernels.cu + +#include +#include + +#include + +#include + +#include +#include + +namespace { + +template +inline __device__ void apply_token_rotary_embedding( + scalar_t* __restrict__ arr, + const scalar_t* __restrict__ cos_ptr, + const scalar_t* __restrict__ sin_ptr, + int rot_offset, + int embed_dim) { + int x_index, y_index; + scalar_t cos, sin; + if (IS_NEOX) { + // GPT-NeoX style rotary embedding. + x_index = rot_offset; + y_index = embed_dim + rot_offset; + cos = SGLANG_LDG(cos_ptr + x_index); + sin = SGLANG_LDG(sin_ptr + x_index); + } else { + // GPT-J style rotary embedding. + x_index = 2 * rot_offset; + y_index = 2 * rot_offset + 1; + cos = SGLANG_LDG(cos_ptr + x_index / 2); + sin = SGLANG_LDG(sin_ptr + x_index / 2); + } + + const scalar_t x = arr[x_index]; + const scalar_t y = arr[y_index]; + arr[x_index] = x * cos - y * sin; + arr[y_index] = y * cos + x * sin; +} + +template +inline __device__ void apply_rotary_embedding( + scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, + // head_size] or [num_tokens, num_heads, + // head_size] + scalar_t* __restrict__ key, // nullptr or + // [batch_size, seq_len, num_kv_heads, + // head_size] or [num_tokens, num_kv_heads, + // head_size] + const scalar_t* cache_ptr, + const int head_size, + const int num_heads, + const int num_kv_heads, + const int rot_dim, + const int token_idx, + const int64_t query_stride, + const int64_t key_stride, + const int64_t head_stride) { + const int embed_dim = rot_dim / 2; + const scalar_t* cos_ptr = cache_ptr; + const scalar_t* sin_ptr = cache_ptr + embed_dim; + + const int nq = num_heads * embed_dim; + for (int i = threadIdx.x; i < nq; i += blockDim.x) { + const int head_idx = i / embed_dim; + const int64_t token_head = token_idx * query_stride + head_idx * head_stride; + const int rot_offset = i % embed_dim; + apply_token_rotary_embedding(query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim); + } + + if (key != nullptr) { + const int nk = num_kv_heads * embed_dim; + for (int i = threadIdx.x; i < nk; i += blockDim.x) { + const int head_idx = i / embed_dim; + const int64_t token_head = token_idx * key_stride + head_idx * head_stride; + const int rot_offset = i % embed_dim; + apply_token_rotary_embedding(key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim); + } + } +} + +template +__global__ void rotary_embedding_kernel( + const int64_t* __restrict__ positions, // [batch_size, seq_len] or + // [num_tokens] + scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, + // head_size] or [num_tokens, num_heads, + // head_size] + scalar_t* __restrict__ key, // nullptr or + // [batch_size, seq_len, num_kv_heads, + // head_size] or [num_tokens, num_kv_heads, + // head_size] + const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // + // 2] + const int rot_dim, + const int64_t query_stride, + const int64_t key_stride, + const int64_t head_stride, + const int num_heads, + const int num_kv_heads, + const int head_size) { + // Each thread block is responsible for one token. + const int token_idx = blockIdx.x; + int64_t pos = positions[token_idx]; + const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim; + + apply_rotary_embedding( + query, + key, + cache_ptr, + head_size, + num_heads, + num_kv_heads, + rot_dim, + token_idx, + query_stride, + key_stride, + head_stride); +} + +// Helper struct to launch kernel +template +void launch_kernel( + const int64_t* positions_data_ptr, + void* query_ptr, + void* key_ptr, + const void* cos_sin_cache_ptr, + int rot_dim, + int64_t query_stride, + int64_t key_stride, + int64_t head_stride, + int num_heads, + int num_kv_heads, + int head_size, + dim3 grid, + dim3 block, + const cudaStream_t stream) { + rotary_embedding_kernel<<>>( + positions_data_ptr, + static_cast(query_ptr), + static_cast(key_ptr), + static_cast(cos_sin_cache_ptr), + rot_dim, + query_stride, + key_stride, + head_stride, + num_heads, + num_kv_heads, + head_size); +}; + +// Helper macro to reduce repetition +#define DISPATCH_DTYPE(DTYPE_CODE, DTYPE_BITS, IS_NEOX, ...) \ + if (DTYPE_CODE == kDLFloat && DTYPE_BITS == 32) { \ + launch_kernel(__VA_ARGS__); \ + } else if (DTYPE_CODE == kDLFloat && DTYPE_BITS == 16) { \ + launch_kernel(__VA_ARGS__); \ + } else if (DTYPE_CODE == kDLBfloat && DTYPE_BITS == 16) { \ + launch_kernel(__VA_ARGS__); \ + } else { \ + RuntimeCheck( \ + false, "Unsupported data type for rotary embedding. Only float32, float16, and bfloat16 are supported."); \ + } + +// Helper function to dispatch based on data type +template +void dispatch_by_dtype( + const int64_t* positions_data_ptr, + DLDataType query_dtype, + void* query_ptr, + void* key_ptr, + void* cos_sin_cache_ptr, + int rot_dim, + int64_t query_stride, + int64_t key_stride, + int64_t head_stride, + int num_heads, + int num_kv_heads, + int head_size, + dim3 grid, + dim3 block, + const cudaStream_t stream) { + using namespace host; + DISPATCH_DTYPE( + query_dtype.code, + query_dtype.bits, + IS_NEOX, + positions_data_ptr, + query_ptr, + key_ptr, + cos_sin_cache_ptr, + rot_dim, + query_stride, + key_stride, + head_stride, + num_heads, + num_kv_heads, + head_size, + grid, + block, + stream); +} + +struct RotaryEmbeddingKernel { + static void + run(tvm::ffi::TensorView positions, // [batch_size, seq_len] or [num_tokens] + tvm::ffi::TensorView query, // [batch_size, seq_len, num_heads * head_size] or + // [num_tokens, num_heads * head_size] or + // [batch_size, seq_len, num_heads, head_size] or + // [num_tokens, num_heads, head_size] + tvm::ffi::Optional key, + // null or + // [batch_size, seq_len, num_kv_heads * head_size] or + // [num_tokens, num_kv_heads * head_size] or + // [batch_size, seq_len, num_heads, head_size] or + // [num_tokens, num_heads, head_size] + int64_t head_size, + tvm::ffi::TensorView cos_sin_cache, // [max_position, rot_dim] + bool is_neox) { + using namespace host; + + // num_tokens = batch_size * seq_len + int64_t num_tokens = positions.numel(); + int32_t positions_ndim = positions.ndim(); + + // Make sure num_tokens dim is consistent across positions, query, and key + RuntimeCheck( + positions_ndim == 1 || positions_ndim == 2, "positions must have shape [num_tokens] or [batch_size, seq_len]"); + if (positions_ndim == 1) { + RuntimeCheck( + query.size(0) == positions.size(0) && (!key.has_value() || key.value().size(0) == positions.size(0)), + "query, key and positions must have the same number of tokens"); + } + if (positions_ndim == 2) { + RuntimeCheck( + query.size(0) == positions.size(0) && (!key.has_value() || key.value().size(0) == positions.size(0)) && + query.size(1) == positions.size(1) && (!key.has_value() || key.value().size(1) == positions.size(1)), + "query, key and positions must have the same batch_size and seq_len"); + } + + // Make sure head_size is valid for query and key + // hidden_size = num_heads * head_size + int query_hidden_size = query.numel() / num_tokens; + int key_hidden_size = key.has_value() ? key.value().numel() / num_tokens : 0; + RuntimeCheck(query_hidden_size % head_size == 0); + RuntimeCheck(key_hidden_size % head_size == 0); + + // Make sure query and key have consistent number of heads + int num_heads = query_hidden_size / head_size; + int num_kv_heads = key.has_value() ? key_hidden_size / head_size : num_heads; + RuntimeCheck(num_heads % num_kv_heads == 0); + + int rot_dim = cos_sin_cache.size(1); + int seq_dim_idx = positions_ndim - 1; + int64_t query_stride = query.stride(seq_dim_idx); + int64_t key_stride = key.has_value() ? key.value().stride(seq_dim_idx) : 0; + // Determine head stride: for [*, heads, head_size] use stride of last dim; + // for flat [*, heads*head_size], heads blocks are contiguous of size + // head_size + int query_ndim = query.dim(); + int64_t head_stride = (query_ndim == positions_ndim + 2) ? query.stride(-2) : head_size; + + dim3 grid(num_tokens); + dim3 block(std::min(num_heads * rot_dim / 2, 512)); + + auto device = query.device(); + const cudaStream_t stream = LaunchKernel::resolve_device(device); + + auto positions_data_ptr = static_cast(positions.data_ptr()); + + if (is_neox) { + dispatch_by_dtype( + positions_data_ptr, + query.dtype(), + query.data_ptr(), + key.has_value() ? key.value().data_ptr() : nullptr, + cos_sin_cache.data_ptr(), + rot_dim, + query_stride, + key_stride, + head_stride, + num_heads, + num_kv_heads, + head_size, + grid, + block, + stream); + } else { + dispatch_by_dtype( + positions_data_ptr, + query.dtype(), + query.data_ptr(), + key.has_value() ? key.value().data_ptr() : nullptr, + cos_sin_cache.data_ptr(), + rot_dim, + query_stride, + key_stride, + head_stride, + num_heads, + num_kv_heads, + head_size, + grid, + block, + stream); + } + } +}; + +} // namespace diff --git a/python/sglang/jit_kernel/include/sgl_kernel/utils.cuh b/python/sglang/jit_kernel/include/sgl_kernel/utils.cuh index 01ce21a7a813..235936d2e971 100644 --- a/python/sglang/jit_kernel/include/sgl_kernel/utils.cuh +++ b/python/sglang/jit_kernel/include/sgl_kernel/utils.cuh @@ -28,6 +28,15 @@ using fp8x2_e5m2_t = __nv_fp8x2_e5m2; using fp32x4_t = float4; #endif +/* + * LDG Support + */ +#ifndef USE_ROCM +#define SGLANG_LDG(arg) __ldg(arg) +#else +#define SGLANG_LDG(arg) *(arg) +#endif + namespace device { #define SGL_DEVICE __forceinline__ __device__ diff --git a/python/sglang/jit_kernel/pos_enc.py b/python/sglang/jit_kernel/pos_enc.py new file mode 100644 index 000000000000..16f8ac37e930 --- /dev/null +++ b/python/sglang/jit_kernel/pos_enc.py @@ -0,0 +1,86 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import torch + +from sglang.jit_kernel.utils import cache_once, load_jit +from sglang.srt.utils.custom_op import register_custom_op + +if TYPE_CHECKING: + from tvm_ffi.module import Module + + +@cache_once +def _jit_rotary_embedding_module() -> Module: + return load_jit( + "rotary_embedding", + cuda_files=["elementwise/pos_enc.cuh"], + cuda_wrappers=[("rotary_embedding", "RotaryEmbeddingKernel::run")], + ) + + +@register_custom_op( + op_name="rotary_embedding_with_key", + mutates_args=["query", "key"], +) +def rotary_embedding_with_key( + positions: torch.Tensor, # [batch_size, seq_len] or [num_tokens] + query: torch.Tensor, # [batch_size, seq_len, num_heads * head_size] or + # [num_tokens, num_heads * head_size] or + # [batch_size, seq_len, num_heads, head_size] or + # [num_tokens, num_heads, head_size] + key: torch.Tensor, # [batch_size, seq_len, num_kv_heads * head_size] or + # [num_tokens, num_kv_heads * head_size] or + # [batch_size, seq_len, num_heads, head_size] or + # [num_tokens, num_heads, head_size] + head_size: int, + cos_sin_cache: torch.Tensor, # [max_position, rot_dim] + is_neox: bool = True, +) -> None: + """ + Apply rotary embedding to query and key tensors. + + Args: + positions: Position indices of shape [num_tokens] or [batch_size, seq_len] + query: Query tensor of shape [num_tokens, num_heads, head_size] or [num_tokens, num_heads * head_size] + key: Key tensor of shape [num_tokens, num_kv_heads, head_size] or [num_tokens, num_kv_heads * head_size] + cos_sin_cache: Cosine and sine cache of shape [max_position, rot_dim] + is_neox: Whether to use GPT-NeoX style rotary embedding (True) or GPT-J style (False) + """ + module = _jit_rotary_embedding_module() + module.rotary_embedding(positions, query, key, head_size, cos_sin_cache, is_neox) + + +@register_custom_op( + op_name="rotary_embedding_without_key", + mutates_args=["query"], +) +def rotary_embedding_without_key( + positions: torch.Tensor, + query: torch.Tensor, + head_size: int, + cos_sin_cache: torch.Tensor, + is_neox: bool = True, +) -> None: + module = _jit_rotary_embedding_module() + module.rotary_embedding(positions, query, None, head_size, cos_sin_cache, is_neox) + + +def rotary_embedding( + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + head_size: int, + cos_sin_cache: torch.Tensor, + is_neox: bool = True, +): + if key is None: + rotary_embedding_without_key( + positions, query, head_size, cos_sin_cache, is_neox + ) + else: + rotary_embedding_with_key( + positions, query, key, head_size, cos_sin_cache, is_neox + ) + return query, key diff --git a/python/sglang/jit_kernel/tests/test_pos_enc.py b/python/sglang/jit_kernel/tests/test_pos_enc.py new file mode 100644 index 000000000000..3656a6c2fc87 --- /dev/null +++ b/python/sglang/jit_kernel/tests/test_pos_enc.py @@ -0,0 +1,485 @@ +import time +from typing import Optional, Tuple, Union + +import pytest +import torch +import triton +import triton.language as tl + +from sglang.jit_kernel.pos_enc import rotary_embedding + + +@triton.jit +def burn_kernel(out_ptr, iters: tl.constexpr): + pid = tl.program_id(0) + x = tl.full((), pid + 1, dtype=tl.uint32) + + a = tl.full((), 1664525, dtype=tl.uint32) + c = tl.full((), 1013904223, dtype=tl.uint32) + sh = tl.full((), 13, dtype=tl.uint32) + + for _ in range(iters): + x = x * a + c + x = x ^ (x >> sh) + + if pid == 0: + tl.store(out_ptr, x) + + +def triton_burn(ms: float, grid=(256,)): + iters = int(ms * 20000) + out = torch.empty((), device="cuda", dtype=torch.uint32) + burn_kernel[grid](out, iters=iters) + return out + + +def create_test_inputs( + head_size, batch_size, seq_len, device, dtype, num_q_heads, num_kv_heads +): + """Create test inputs.""" + total_tokens = batch_size * seq_len + + query = torch.randn( + batch_size, seq_len, num_q_heads, head_size, dtype=dtype, device=device + ) + key = torch.randn( + batch_size, seq_len, num_kv_heads, head_size, dtype=dtype, device=device + ) + + pos_ids = torch.randint( + 0, min(seq_len * 2, 100), (total_tokens,), dtype=torch.long, device=device + ) + + query = query.view(total_tokens, num_q_heads, head_size) + key = key.view(total_tokens, num_kv_heads, head_size) + + return query, key, pos_ids + + +def create_cos_sin_cache(rotary_dim, max_position_embeddings, base, dtype, device): + """Create cos/sin cache for rotary embedding.""" + max_pos = max_position_embeddings + extended_max_pos = max(max_pos, 100) + cos_sin_cache = torch.zeros( + extended_max_pos, rotary_dim, dtype=dtype, device=device + ) + + inv_freq = 1.0 / ( + base + ** ( + torch.arange(0, rotary_dim, 2, dtype=torch.float32, device=device) + / rotary_dim + ) + ) + t = torch.arange(extended_max_pos, dtype=torch.float32, device=device) + freqs = torch.outer(t, inv_freq) + cos_cache = torch.cos(freqs).to(dtype) + sin_cache = torch.sin(freqs).to(dtype) + + cos_sin_cache[:, : rotary_dim // 2] = cos_cache + cos_sin_cache[:, rotary_dim // 2 :] = sin_cache + + return cos_sin_cache + + +# vLLM torch native +def _apply_rotary_emb( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + is_neox_style: bool, +) -> torch.Tensor: + """ + Args: + x: [num_tokens, num_heads, head_size] + cos: [num_tokens, head_size // 2] + sin: [num_tokens, head_size // 2] + is_neox_style: Whether to use the Neox-style or GPT-J-style rotary + positional embeddings. + """ + cos = cos.unsqueeze(-2).to(x.dtype) + sin = sin.unsqueeze(-2).to(x.dtype) + if is_neox_style: + x1, x2 = torch.chunk(x, 2, dim=-1) + else: + x1 = x[..., ::2] + x2 = x[..., 1::2] + o1 = x1 * cos - x2 * sin + o2 = x2 * cos + x1 * sin + if is_neox_style: + return torch.cat((o1, o2), dim=-1) + else: + return torch.stack((o1, o2), dim=-1).flatten(-2) + + +class RotaryEmbedding(torch.nn.Module): + # Reference: https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/rotary_embedding.py + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + dtype: torch.dtype, + ) -> None: + super().__init__() + self.head_size = head_size + self.rotary_dim = rotary_dim + self.max_position_embeddings = max_position_embeddings + self.base = base + self.is_neox_style = is_neox_style + self.dtype = dtype + + cache = self._compute_cos_sin_cache() + self.cos_sin_cache: torch.Tensor + self.register_buffer("cos_sin_cache", cache, persistent=False) + + def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: + inv_freq = 1.0 / ( + base + ** ( + torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim + ) + ) + return inv_freq + + def _compute_cos_sin_cache(self) -> torch.Tensor: + """Compute the cos and sin cache.""" + inv_freq = self._compute_inv_freq(self.base) + t = torch.arange(self.max_position_embeddings, dtype=torch.float) + + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() + sin = freqs.sin() + cache = torch.cat((cos, sin), dim=-1) + return cache + + def forward_native( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: Optional[torch.Tensor] = None, + offsets: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """A PyTorch-native implementation of forward().""" + + if offsets is not None: + positions = positions + offsets + + positions = positions.flatten() + num_tokens = positions.shape[0] + cos_sin = self.cos_sin_cache.index_select(0, positions) + + cos, sin = cos_sin.chunk(2, dim=-1) + + query_shape = query.shape + query = query.view(num_tokens, -1, self.head_size) + query_rot = query[..., : self.rotary_dim] + query_pass = query[..., self.rotary_dim :] + query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style) + query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) + + # Modification: convert to the correct dtype + query = query.to(self.dtype) + + if key is not None: + key_shape = key.shape + key = key.view(num_tokens, -1, self.head_size) + key_rot = key[..., : self.rotary_dim] + key_pass = key[..., self.rotary_dim :] + key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style) + key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) + + key = key.to(self.dtype) + + return query, key + + +def get_torch_rotary_embedding( + head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype, device +): + """Initialize Torch Native RotaryEmbedding based on vLLM implementation.""" + return RotaryEmbedding( + head_size=head_size, + rotary_dim=rotary_dim, + max_position_embeddings=max_position_embeddings, + base=base, + is_neox_style=is_neox_style, + dtype=dtype, + ).to(device) + + +def get_sgl_rotary_embedding( + head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype, device +): + """Initialize SglKernelRotaryEmbedding.""" + try: + from sgl_kernel.testing.rotary_embedding import SglKernelRotaryEmbedding + except ImportError: + pytest.skip( + "SglKernelRotaryEmbedding is not available. Test case can be removed." + ) + + return SglKernelRotaryEmbedding( + head_size=head_size, + rotary_dim=rotary_dim, + max_position_embeddings=max_position_embeddings, + base=base, + is_neox_style=is_neox_style, + dtype=dtype, + ).to(device) + + +def compare_results(jit_out, sgl_out, dtype): + """Compare results between JIT and SGL implementations.""" + if jit_out is None: + assert sgl_out is None + return + + assert sgl_out is not None + + # Check for NaN values + assert not torch.isnan(jit_out).any(), "NaN in JIT results" + assert not torch.isnan(sgl_out).any(), "NaN in SGL results" + + # Compare results + atol = 1e-2 if dtype != torch.float32 else 1e-5 + rtol = 1e-2 if dtype != torch.float32 else 1e-5 + + torch.testing.assert_close(jit_out, sgl_out, atol=atol, rtol=rtol) + + +@pytest.mark.parametrize( + "head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype, device, batch_size, seq_len, num_q_heads, num_kv_heads", + [ + # GPT-OSS cases + *[ + (64, 64, 4096, 8000, True, torch.bfloat16, "cuda", bs, sl, 8, 8) + for bs, sl in [(1, 1), (32, 1), (128, 1), (512, 1), (2, 512), (4, 4096)] + ], + # Other cases + (64, 64, 32, 8000, True, torch.bfloat16, "cuda", 32, 32, 1, 1), + (256, 128, 4096, 10000, True, torch.bfloat16, "cuda", 2, 512, 4, 2), + (512, 128, 311, 10000, True, torch.bfloat16, "cuda", 3, 39, 4, 2), + (128, 128, 2048, 10000, False, torch.bfloat16, "cuda", 2, 512, 32, 8), + (128, 128, 2048, 10000, False, torch.bfloat16, "cuda", 2, 512, 16, 4), + (512, 128, 311, 10000, False, torch.bfloat16, "cuda", 3, 39, 4, 2), + (64, 64, 32, 8000, True, torch.float32, "cuda", 32, 32, 1, 1), + (256, 128, 4096, 10000, True, torch.float32, "cuda", 2, 512, 4, 2), + (512, 128, 311, 10000, True, torch.float32, "cuda", 3, 39, 4, 2), + (128, 128, 2048, 10000, False, torch.float32, "cuda", 2, 512, 32, 8), + (128, 128, 2048, 10000, False, torch.float32, "cuda", 2, 512, 16, 4), + (512, 128, 311, 10000, False, torch.float32, "cuda", 3, 39, 4, 2), + # Additional test cases for different head sizes and dtypes + (64, 32, 1024, 10000, True, torch.float16, "cuda", 16, 64, 8, 4), + (128, 64, 2048, 10000, True, torch.float16, "cuda", 8, 128, 16, 8), + (256, 128, 4096, 10000, True, torch.float16, "cuda", 4, 256, 8, 4), + ], +) +@pytest.mark.parametrize( + "key_is_none", + [True, False], +) +def test_correctness( + head_size, + rotary_dim, + max_position_embeddings, + base, + is_neox_style, + dtype, + device, + batch_size, + seq_len, + num_q_heads, + num_kv_heads, + key_is_none, +): + """Test correctness of JIT rotary embedding implementation.""" + # Create inputs and caches + query, key, pos_ids = create_test_inputs( + head_size, batch_size, seq_len, device, dtype, num_q_heads, num_kv_heads + ) + cos_sin_cache = create_cos_sin_cache( + rotary_dim, max_position_embeddings, base, dtype, device + ) + + # Initialize torch kernel + torch_rotary_emb = get_torch_rotary_embedding( + head_size, + rotary_dim, + max_position_embeddings, + base, + is_neox_style, + dtype, + device, + ) + torch_rotary_emb.cos_sin_cache = cos_sin_cache + r = torch.randn_like(query) + + # Apply rotary embeddings + query_jit, key_jit = query.clone(), key.clone() + query_torch, key_torch = query.clone(), key.clone() + stream_jit = torch.get_device_module("cuda").Stream() + stream_kernel = torch.get_device_module("cuda").Stream() + + if key_is_none: + key_jit = None + key_torch = None + triton_burn(100.0, grid=(1024,)) + + r_jit, r_torch = r.clone(), r.clone() + torch.cuda.synchronize() + + with torch.cuda.stream(stream_jit): + # Test if rotary_embedding runs on stream_jit + triton_burn(100.0, grid=(1024,)) + query_jit = query_jit + r_jit + query_jit_out, key_jit_out = rotary_embedding( + positions=pos_ids, + query=query_jit, + key=key_jit, + head_size=head_size, + cos_sin_cache=cos_sin_cache, + is_neox=is_neox_style, + ) + + with torch.cuda.stream(stream_kernel): + triton_burn(100.0, grid=(1024,)) + query_torch = query_torch + r_torch + query_torch_out, key_torch_out = torch_rotary_emb.forward_native( + positions=pos_ids, query=query_torch, key=key_torch + ) + + torch.cuda.synchronize() + compare_results(query_jit_out, query_torch_out, dtype) + compare_results(key_jit_out, key_torch_out, dtype) + + +@pytest.mark.parametrize( + "head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype, device, batch_size, seq_len, num_q_heads, num_kv_heads", + [ + # Small scale + (64, 64, 4096, 8000, True, torch.bfloat16, "cuda", 1, 1, 8, 8), + (64, 64, 4096, 8000, True, torch.bfloat16, "cuda", 4, 16, 8, 8), + # Medium scale + (64, 64, 4096, 8000, True, torch.bfloat16, "cuda", 8, 64, 8, 8), + (64, 64, 4096, 8000, True, torch.bfloat16, "cuda", 16, 128, 8, 8), + # Large scale + (64, 64, 4096, 8000, True, torch.bfloat16, "cuda", 32, 512, 8, 8), + (64, 64, 4096, 8000, True, torch.bfloat16, "cuda", 64, 1024, 8, 8), + ], +) +def test_performance( + head_size, + rotary_dim, + max_position_embeddings, + base, + is_neox_style, + dtype, + device, + batch_size, + seq_len, + num_q_heads, + num_kv_heads, +): + """Performance test comparing JIT and SGL implementations with accuracy validation.""" + # Create inputs and caches + query, key, pos_ids = create_test_inputs( + head_size, batch_size, seq_len, device, dtype, num_q_heads, num_kv_heads + ) + cos_sin_cache = create_cos_sin_cache( + rotary_dim, max_position_embeddings, base, dtype, device + ) + + # Initialize SGL kernel + sgl_rotary_emb = get_sgl_rotary_embedding( + head_size, + rotary_dim, + max_position_embeddings, + base, + is_neox_style, + dtype, + device, + ) + sgl_rotary_emb.cos_sin_cache = cos_sin_cache + + warmup = 3 + + # Warmup runs + for _ in range(warmup): + query_warm, key_warm = query.clone(), key.clone() + rotary_embedding( + positions=pos_ids, + query=query_warm, + key=key_warm, + head_size=head_size, + cos_sin_cache=cos_sin_cache, + is_neox=is_neox_style, + ) + + query_sgl_warm, key_sgl_warm = query.clone(), key.clone() + sgl_rotary_emb.forward_cuda( + positions=pos_ids, query=query_sgl_warm, key=key_sgl_warm + ) + + iteration = 100 + + # Time JIT implementation + torch.cuda.synchronize() + start_time = time.time() + for _ in range(iteration): + query_jit, key_jit = query.clone(), key.clone() + rotary_embedding( + positions=pos_ids, + query=query_jit, + key=key_jit, + head_size=head_size, + cos_sin_cache=cos_sin_cache, + is_neox=is_neox_style, + ) + torch.cuda.synchronize() + jit_time = (time.time() - start_time) / iteration + + # Time SGL implementation + torch.cuda.synchronize() + start_time = time.time() + for _ in range(iteration): + query_sgl, key_sgl = query.clone(), key.clone() + sgl_rotary_emb.forward_cuda(positions=pos_ids, query=query_sgl, key=key_sgl) + torch.cuda.synchronize() + sgl_time = (time.time() - start_time) / iteration + + # Accuracy validation during performance test + # Run one more time to get outputs for comparison + query_jit_final, key_jit_final = query.clone(), key.clone() + query_sgl_final, key_sgl_final = query.clone(), key.clone() + + query_jit_out, key_jit_out = rotary_embedding( + positions=pos_ids, + query=query_jit_final, + key=key_jit_final, + head_size=head_size, + cos_sin_cache=cos_sin_cache, + is_neox=is_neox_style, + ) + + query_sgl_out, key_sgl_out = sgl_rotary_emb.forward_cuda( + positions=pos_ids, query=query_sgl_final, key=key_sgl_final + ) + + # Validate accuracy + compare_results(query_jit_out, query_sgl_out, dtype) + compare_results(key_jit_out, key_sgl_out, dtype) + + # Print results + total_tokens = batch_size * seq_len + print( + f"\nPerformance Test - Batch={batch_size}, SeqLen={seq_len}, Tokens={total_tokens}" + ) + print(f"JIT: {jit_time*1000:.9f}ms, SGL: {sgl_time*1000:.9f}ms") + if sgl_time > 0: + speedup = sgl_time / jit_time if jit_time > 0 else float("inf") + print(f"Speedup (SGL/JIT): {speedup:.2f}x") + + assert jit_time >= 0 and sgl_time >= 0 diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py index b38c980f178d..5222a9e5a8ba 100644 --- a/python/sglang/srt/layers/rotary_embedding.py +++ b/python/sglang/srt/layers/rotary_embedding.py @@ -120,8 +120,10 @@ def __init__( and not (_is_xpu) and not (_is_npu) ): + # rotary_embedding from sglang.jit_kernel.pos_enc and vllm._custom_ops has the same implementation. + # TODO: Test on different devices and remove this conditional. if _is_cuda or _is_hip: - from sgl_kernel import rotary_embedding + from sglang.jit_kernel.pos_enc import rotary_embedding else: from vllm._custom_ops import rotary_embedding @@ -396,6 +398,7 @@ def forward_xpu( fused_set_kv_buffer_arg is None ), "fused_set_kv_buffer_arg is not supported for xpu implementation" positions = torch.add(positions, offsets) if offsets is not None else positions + return torch.ops.sgl_kernel.rotary_embedding( positions, query,