From 9a823c5e44539fd1d443a89976c4a97c66860b09 Mon Sep 17 00:00:00 2001 From: Mick Date: Sun, 18 May 2025 12:51:32 +0800 Subject: [PATCH 01/11] initial --- python/sglang/srt/_custom_ops.py | 3 +- python/sglang/srt/layers/attention/vision.py | 19 +- python/sglang/srt/layers/rotary_embedding.py | 41 ++- sgl-kernel/CMakeLists.txt | 1 + sgl-kernel/csrc/common_extension.cc | 12 + .../csrc/multimodal/rotary_embedding.cu | 268 ++++++++++++++++++ sgl-kernel/include/sgl_kernel_ops.h | 8 + sgl-kernel/include/utils.h | 14 + sgl-kernel/python/sgl_kernel/__init__.py | 1 + .../python/sgl_kernel/rotary_embedding.py | 35 +++ 10 files changed, 387 insertions(+), 15 deletions(-) create mode 100644 sgl-kernel/csrc/multimodal/rotary_embedding.cu create mode 100644 sgl-kernel/python/sgl_kernel/rotary_embedding.py diff --git a/python/sglang/srt/_custom_ops.py b/python/sglang/srt/_custom_ops.py index 07c087bf6c42..d506fbb7342e 100644 --- a/python/sglang/srt/_custom_ops.py +++ b/python/sglang/srt/_custom_ops.py @@ -1,6 +1,6 @@ # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/_custom_ops.py import logging -from typing import List, Tuple +from typing import List, Optional, Tuple import torch @@ -24,7 +24,6 @@ except ImportError as e: logger.warning("Failed to import from custom_ar with %r", e) - if not is_hip(): if use_vllm_custom_allreduce: custom_op = torch.ops._C_custom_ar diff --git a/python/sglang/srt/layers/attention/vision.py b/python/sglang/srt/layers/attention/vision.py index f1f45e27ab96..43d4c630003e 100644 --- a/python/sglang/srt/layers/attention/vision.py +++ b/python/sglang/srt/layers/attention/vision.py @@ -27,7 +27,7 @@ RowParallelLinear, ) from sglang.srt.layers.quantization import QuantizationConfig -from sglang.srt.layers.rotary_embedding import apply_rotary_pos_emb +from sglang.srt.layers.rotary_embedding import RotaryEmbedding, apply_rotary_pos_emb from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.utils import add_prefix, logger @@ -361,6 +361,17 @@ def __init__( softmax_in_single_precision=softmax_in_single_precision, ) + print(f"{self.head_size=}") + self.rotary_emb = RotaryEmbedding( + head_size=self.head_size, + rotary_dim=self.head_size, + max_position_embeddings=2048, + base=10000, + is_neox_style=False, + # dtype=rope_dtype + dtype=torch.get_default_dtype(), + ) + self.use_qkv_parallel = use_qkv_parallel if use_qkv_parallel: self.qkv_proj = QKVParallelLinear( @@ -444,7 +455,11 @@ def forward( q = q.view(-1, head, self.head_size) k = k.view(-1, head, self.head_size) - q, k = apply_rotary_pos_emb(q, k, cos, sin) + print(f"{cos.dtype}") + print(f"{q.dtype}") + + q, k = self.rotary_emb(position_embeddings, q, k) + # q, k = apply_rotary_pos_emb(q, k, cos, sin) q = q.view(original_shape) k = k.view(original_shape) diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py index c5c285ca0fc4..9a8c692f1258 100644 --- a/python/sglang/srt/layers/rotary_embedding.py +++ b/python/sglang/srt/layers/rotary_embedding.py @@ -84,9 +84,9 @@ def __init__( cache = cache.to(dtype) if not _is_cuda or self.head_size not in [64, 128, 256, 512]: - from vllm._custom_ops import rotary_embedding + from sgl_kernel import rotary_embedding - self.vllm_rotary_embedding = rotary_embedding + self.sglang_rotary_embedding = rotary_embedding self.cos_sin_cache: torch.Tensor self.register_buffer("cos_sin_cache", cache, persistent=False) @@ -118,7 +118,7 @@ def _compute_cos_sin_cache(self) -> torch.Tensor: def forward_native( self, - positions: torch.Tensor, + positions: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], query: torch.Tensor, key: torch.Tensor, offsets: Optional[torch.Tensor] = None, @@ -126,10 +126,15 @@ def forward_native( """A PyTorch-native implementation of forward().""" if offsets is not None: positions = positions + offsets - positions = positions.flatten() - num_tokens = positions.shape[0] - cos_sin = self.cos_sin_cache.index_select(0, positions) - cos, sin = cos_sin.chunk(2, dim=-1) + + if isinstance(positions, torch.Tensor): + positions = positions.flatten() + cos_sin = self.cos_sin_cache.index_select(0, positions) + cos, sin = cos_sin.chunk(2, dim=-1) + num_tokens = positions.shape[0] + else: + cos, sin = positions + num_tokens = cos.shape[0] query_shape = query.shape query = query.view(num_tokens, -1, self.head_size) @@ -148,7 +153,7 @@ def forward_native( def forward_cuda( self, - positions: torch.Tensor, + positions: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], query: torch.Tensor, key: torch.Tensor, offsets: Optional[torch.Tensor] = None, @@ -163,15 +168,29 @@ def forward_cuda( is_neox=self.is_neox_style, ) else: + orig_q_dtype = query.dtype + orig_k_dtype = key.dtype self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype) - self.vllm_rotary_embedding( - positions, + cos, sin = positions + query = query.to(dtype=cos.dtype) + key = key.to(dtype=cos.dtype) + + print(f"{type(cos)=}") + print(f"{cos.shape=}") + print(f"{query.shape=}") + print(f"{key.shape=}") + self.sglang_rotary_embedding( + cos, + sin, query, key, self.head_size, - self.cos_sin_cache, + # self.cos_sin_cache, self.is_neox_style, ) + + query = query.to(dtype=orig_q_dtype) + key = key.to(dtype=orig_k_dtype) return query, key def extra_repr(self) -> str: diff --git a/sgl-kernel/CMakeLists.txt b/sgl-kernel/CMakeLists.txt index 71f77d51bc5a..2df005bea592 100644 --- a/sgl-kernel/CMakeLists.txt +++ b/sgl-kernel/CMakeLists.txt @@ -210,6 +210,7 @@ set(SOURCES "csrc/moe/moe_topk_softmax_kernels.cu" "csrc/moe/fp8_blockwise_moe_kernel.cu" "csrc/moe/prepare_moe_input.cu" + "csrc/multimodal/rotary_embedding.cu" "csrc/speculative/eagle_utils.cu" "csrc/speculative/speculative_sampling.cu" "csrc/speculative/packbit.cu" diff --git a/sgl-kernel/csrc/common_extension.cc b/sgl-kernel/csrc/common_extension.cc index d83944b566a0..4b4296d6536a 100644 --- a/sgl-kernel/csrc/common_extension.cc +++ b/sgl-kernel/csrc/common_extension.cc @@ -278,6 +278,18 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { "qserve_w4a8_per_group_gemm(Tensor _in_feats, Tensor _kernel, Tensor _zeros, Tensor _scales_i8, Tensor _wscales, " "Tensor _ascales, Tensor! _out_feats) -> ()"); m.impl("qserve_w4a8_per_group_gemm", torch::kCUDA, &qserve_w4a8_per_group_gemm); + + /* + */ + m.def( + "rotary_embedding(" + " Tensor! cos_cache, " + " Tensor! sin_cache, " + " Tensor! query," + " Tensor? key, " + " int head_size, " + " bool is_neox) -> ()"); + m.impl("rotary_embedding", torch::kCUDA, &rotary_embedding); } REGISTER_EXTENSION(common_ops) diff --git a/sgl-kernel/csrc/multimodal/rotary_embedding.cu b/sgl-kernel/csrc/multimodal/rotary_embedding.cu new file mode 100644 index 000000000000..de5ec819826a --- /dev/null +++ b/sgl-kernel/csrc/multimodal/rotary_embedding.cu @@ -0,0 +1,268 @@ +/* + * Copyright (c) 2025 by SGLang team. + * Copyright (c) 2025 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +#include +#include + +#include "utils.h" +// #include +#include +#include +#include + +template +inline __device__ void apply_token_rotary_embedding( + scalar_t* __restrict__ arr, + const scalar_t* __restrict__ cos_ptr, + const scalar_t* __restrict__ sin_ptr, + int rot_offset, + int embed_dim) { + int x_index, y_index; + scalar_t cos_val, sin_val; // Renamed to avoid conflict with input ptrs + if (IS_NEOX) { + // GPT-NeoX style rotary embedding. + x_index = rot_offset; + y_index = embed_dim + rot_offset; // This was embed_dim + rot_offset, which is correct for NeoX's interleaved pairs + cos_val = SGLANG_LDG(cos_ptr + x_index); + sin_val = SGLANG_LDG(sin_ptr + x_index); + } else { + // GPT-J style rotary embedding. + x_index = 2 * rot_offset; + y_index = 2 * rot_offset + 1; + cos_val = SGLANG_LDG(cos_ptr + rot_offset); // x_index / 2 = rot_offset + sin_val = SGLANG_LDG(sin_ptr + rot_offset); // x_index / 2 = rot_offset + } + + const scalar_t x = arr[x_index]; + const scalar_t y = arr[y_index]; + arr[x_index] = x * cos_val - y * sin_val; + arr[y_index] = y * cos_val + x * sin_val; +} + +template +inline __device__ void apply_rotary_embedding( + scalar_t* __restrict__ query, // [num_heads, head_size] for current token + scalar_t* __restrict__ key, // nullptr or [num_kv_heads, head_size] for current token + const scalar_t* __restrict__ current_token_cos_ptr, // [rot_dim/2] for current token + const scalar_t* __restrict__ current_token_sin_ptr, // [rot_dim/2] for current token + const int head_size, + const int num_heads, + const int num_kv_heads, + const int rot_dim, + const int64_t head_stride_query, // Stride to get to next head in query + const int64_t head_stride_key // Stride to get to next head in key +) { + const int embed_dim = rot_dim / 2; // Number of elements in cos/sin arrays for one token + + // No need to offset current_token_cos_ptr and current_token_sin_ptr further here, + // they already point to the start of cos/sin values for the current token. + + const int nq_pairs = num_heads * embed_dim; // Total pairs to rotate for query + for (int i = threadIdx.x; i < nq_pairs; i += blockDim.x) { + const int head_idx = i / embed_dim; + const int rot_offset = i % embed_dim; // Offset within the head's part to be rotated + + // query_for_token_head points to the start of the specific head for the current token + scalar_t* query_for_token_head = query + head_idx * (int)head_stride_query; + + apply_token_rotary_embedding( + query_for_token_head, current_token_cos_ptr, current_token_sin_ptr, rot_offset, embed_dim); + } + + if (key != nullptr) { + const int nk_pairs = num_kv_heads * embed_dim; // Total pairs to rotate for key + for (int i = threadIdx.x; i < nk_pairs; i += blockDim.x) { + const int head_idx = i / embed_dim; + const int rot_offset = i % embed_dim; + + // key_for_token_head points to the start of the specific head for the current token + scalar_t* key_for_token_head = key + head_idx * (int)head_stride_key; + + apply_token_rotary_embedding( + key_for_token_head, current_token_cos_ptr, current_token_sin_ptr, rot_offset, embed_dim); + } + } +} + +template +__global__ void rotary_embedding_kernel( + const scalar_t* __restrict__ cos_data, // [num_tokens, rot_dim / 2] + const scalar_t* __restrict__ sin_data, // [num_tokens, rot_dim / 2] + scalar_t* __restrict__ query_total, // [num_tokens, num_heads, head_size] or [num_tokens, num_heads * head_size] + scalar_t* __restrict__ key_total, // nullptr or similar shape to query_total + const int rot_dim, + const int64_t query_token_stride, // Elements to skip to get to next token in query_total + const int64_t key_token_stride, // Elements to skip to get to next token in key_total + const int64_t head_stride_query, // Elements to skip to get to next head within a token's query data + const int64_t head_stride_key, // Elements to skip to get to next head within a token's key data + const int num_heads, + const int num_kv_heads, + const int head_size) { + // Each thread block is responsible for one token. + const int token_idx = blockIdx.x; + const int embed_dim = rot_dim / 2; + + const scalar_t* current_token_cos_ptr = cos_data + token_idx * embed_dim; + const scalar_t* current_token_sin_ptr = sin_data + token_idx * embed_dim; + + scalar_t* query_for_token = query_total + token_idx * (int)query_token_stride; + scalar_t* key_for_token = (key_total != nullptr) ? (key_total + token_idx * (int)key_token_stride) : nullptr; + + apply_rotary_embedding( + query_for_token, + key_for_token, + current_token_cos_ptr, + current_token_sin_ptr, + head_size, + num_heads, + num_kv_heads, + rot_dim, + head_stride_query, + head_stride_key); +} + +void rotary_embedding( + at::Tensor& cos_cache, // [num_tokens, rot_dim / 2] + at::Tensor& sin_cache, // [num_tokens, rot_dim / 2] + at::Tensor& query, // [num_tokens, num_heads * head_size] + const std::optional& key, // null or similar to query + int64_t head_size, + bool is_neox) { + TORCH_CHECK( + query.dim() == 2 || query.dim() == 3, + "query must have shape [num_tokens, hidden_size] or [num_tokens, num_heads, head_size]"); + if (key.has_value()) { + TORCH_CHECK( + key->dim() == 2 || key->dim() == 3, + "key must have shape [num_tokens, hidden_size] or [num_tokens, num_kv_heads, head_size]"); + } + + int64_t num_tokens = query.size(0); + + TORCH_CHECK(cos_cache.dim() == 2, "cos_cache must have shape [num_tokens, rot_dim/2]"); + TORCH_CHECK(sin_cache.dim() == 2, "sin_cache must have shape [num_tokens, rot_dim/2]"); + TORCH_CHECK(cos_cache.size(0) == num_tokens, "cos_cache num_tokens mismatch with query"); + TORCH_CHECK(sin_cache.size(0) == num_tokens, "sin_cache num_tokens mismatch with query"); + TORCH_CHECK(cos_cache.size(1) == sin_cache.size(1), "cos_cache and sin_cache rot_dim/2 mismatch"); + + TORCH_CHECK(cos_cache.scalar_type() == query.scalar_type(), "cos_cache dtype mismatch"); + TORCH_CHECK(sin_cache.scalar_type() == query.scalar_type(), "sin_cache dtype mismatch"); + TORCH_CHECK(cos_cache.is_cuda() && sin_cache.is_cuda() && query.is_cuda(), "All tensors must be on CUDA"); + if (key.has_value()) { + TORCH_CHECK(key->is_cuda(), "Key tensor must be on CUDA if provided"); + TORCH_CHECK(key->scalar_type() == query.scalar_type(), "Key dtype mismatch"); + } + + // hidden_size = num_heads * head_size + int query_hidden_size_calculated; + if (query.dim() == 2) { // [num_tokens, hidden_size] + query_hidden_size_calculated = (int)query.size(1); + } else { // [num_tokens, num_heads, head_size] + query_hidden_size_calculated = (int)query.size(1) * (int)query.size(2); + TORCH_CHECK(query.size(2) == head_size, "Query head_size mismatch in 3D tensor"); + } + TORCH_CHECK(query_hidden_size_calculated % head_size == 0, "query_hidden_size not divisible by head_size"); + int num_heads = (int)query_hidden_size_calculated / (int)head_size; + + int key_hidden_size_calculated = 0; + int num_kv_heads = num_heads; // Default if key is not present or GQA not used + if (key.has_value()) { + TORCH_CHECK((int)key->size(0) == num_tokens, "Key num_tokens mismatch"); + if (key->dim() == 2) { // [num_tokens, kv_hidden_size] + key_hidden_size_calculated = (int)key->size(1); + } else { // [num_tokens, num_kv_heads, head_size] + key_hidden_size_calculated = (int)key->size(1) * (int)key->size(2); + TORCH_CHECK((int)key->size(2) == head_size, "Key head_size mismatch in 3D tensor"); + } + TORCH_CHECK(key_hidden_size_calculated % head_size == 0, "key_hidden_size not divisible by head_size"); + num_kv_heads = key_hidden_size_calculated / (int)head_size; + } + TORCH_CHECK(num_heads % num_kv_heads == 0, "num_heads must be divisible by num_kv_heads"); + + int rot_dim = (int)cos_cache.size(1) * 2; + // TORCH_CHECK(rot_dim <= head_size, "rot_dim must be <= head_size"); + + // Strides to get to the next token's data + int64_t query_token_stride = query_hidden_size_calculated; + int64_t key_token_stride = key.has_value() ? key_hidden_size_calculated : 0; + + // Strides to get to the next head's data *within* a token + // If query is [num_tokens, num_heads, head_size], stride is query.stride(1) + // If query is [num_tokens, num_heads * head_size], stride is head_size + int64_t head_stride_query; + if (query.dim() == 3 && query.size(1) == num_heads && query.size(2) == head_size) { + head_stride_query = query.stride(1); + } else { // Assumed to be [num_tokens, num_heads * head_size] or will be viewed as such + head_stride_query = head_size; + } + + int64_t head_stride_key = head_size; // Default for key + if (key.has_value()) { + if (key->dim() == 3 && key->size(1) == num_kv_heads && key->size(2) == head_size) { + head_stride_key = key->stride(1); + } else { + head_stride_key = head_size; + } + } + + dim3 grid((int)num_tokens); + // Max threads per block is usually 1024. + // Each thread handles one pair (x,y) to rotate. + // Total pairs for query for one token: num_heads * (rot_dim / 2) + // We want enough threads to cover these pairs, up to a limit. + // The loop inside apply_rotary_embedding handles thread stride. + int max_pairs_to_rotate_per_token = std::max(num_heads * (rot_dim / 2), num_kv_heads * (rot_dim / 2)); + dim3 block(std::min(max_pairs_to_rotate_per_token, 512L)); // 512L to ensure long comparison + if (block.x == 0 && num_tokens > 0) block.x = 1; // Ensure at least one thread if there's work + + const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + SGLANG_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] { + if (is_neox) { + rotary_embedding_kernel<<>>( + cos_cache.data_ptr(), + sin_cache.data_ptr(), + query.data_ptr(), + key.has_value() ? key->data_ptr() : nullptr, + rot_dim, + query_token_stride, + key_token_stride, + head_stride_query, + head_stride_key, + num_heads, + num_kv_heads, + head_size); + } else { + rotary_embedding_kernel<<>>( + cos_cache.data_ptr(), + sin_cache.data_ptr(), + query.data_ptr(), + key.has_value() ? key->data_ptr() : nullptr, + rot_dim, + query_token_stride, + key_token_stride, + head_stride_query, + head_stride_key, + num_heads, + num_kv_heads, + head_size); + } + }); + // C10_CUDA_KERNEL_LAUNCH_CHECK(); +} diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h index b5e376dc809b..de0e5317f190 100644 --- a/sgl-kernel/include/sgl_kernel_ops.h +++ b/sgl-kernel/include/sgl_kernel_ops.h @@ -324,6 +324,14 @@ void top_p_sampling_from_probs( bool deterministic, std::optional gen); +void rotary_embedding( + at::Tensor& cos_cache, // [num_tokens, rot_dim / 2] + at::Tensor& sin_cache, // [num_tokens, rot_dim / 2] + at::Tensor& query, // [num_tokens, num_heads * head_size] + const std::optional& key, // null or similar to query + int64_t head_size, + bool is_neox); + namespace flash { /* * From fa2 sparse diff --git a/sgl-kernel/include/utils.h b/sgl-kernel/include/utils.h index 229c6e9c4b8e..3a098de31384 100644 --- a/sgl-kernel/include/utils.h +++ b/sgl-kernel/include/utils.h @@ -279,9 +279,23 @@ inline int getSMVersion() { #define DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \ AT_DISPATCH_SWITCH(TYPE, NAME, DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__)) +#define SGLANG_DISPATCH_CASE_FLOATING_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) + +#define SGLANG_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, SGLANG_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) + #define CEILDIV(x, y) (((x) + (y) - 1) / (y)) #define WARP_SIZE 32 +#ifndef USE_ROCM +#define SGLANG_LDG(arg) __ldg(arg) +#else +#define SGLANG_LDG(arg) *(arg) +#endif + #ifndef USE_ROCM #include using FP8_TYPE = c10::Float8_e4m3fn; diff --git a/sgl-kernel/python/sgl_kernel/__init__.py b/sgl-kernel/python/sgl_kernel/__init__.py index ec97fa4b591f..4fdcaf0bc175 100755 --- a/sgl-kernel/python/sgl_kernel/__init__.py +++ b/sgl-kernel/python/sgl_kernel/__init__.py @@ -52,6 +52,7 @@ prepare_moe_input, topk_softmax, ) +from sgl_kernel.rotary_embedding import rotary_embedding from sgl_kernel.sampling import ( min_p_sampling_from_probs, top_k_renorm_prob, diff --git a/sgl-kernel/python/sgl_kernel/rotary_embedding.py b/sgl-kernel/python/sgl_kernel/rotary_embedding.py new file mode 100644 index 000000000000..11abfa2b259e --- /dev/null +++ b/sgl-kernel/python/sgl_kernel/rotary_embedding.py @@ -0,0 +1,35 @@ +from typing import Optional + +import torch + + +# Adapted from https://github.com/vllm-project/vllm/blob/9214e60631a79506e7669650de87806a123e0b0b/vllm/_custom_ops.py#L249 +# pos encoding ops +def rotary_embedding( + cos: torch.Tensor, + sin: torch.Tensor, + query: torch.Tensor, + key: Optional[torch.Tensor], + head_size: int, + # cos_sin_cache: torch.Tensor, + is_neox: bool, +) -> None: + torch.ops.sgl_kernel.rotary_embedding.default( + cos, + sin, + query, + key, + head_size, + # cos_sin_cache, + is_neox, + ) + + +# def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor, +# key: Optional[torch.Tensor], head_size: int, +# cos_sin_cache: torch.Tensor, is_neox: bool, +# rot_dim: int, +# cos_sin_cache_offsets: torch.Tensor) -> None: +# torch.ops._C.batched_rotary_embedding(positions, query, key, head_size, +# cos_sin_cache, is_neox, rot_dim, +# cos_sin_cache_offsets) From 86ac8ebe867375b31873e71692e5ef9fb739100c Mon Sep 17 00:00:00 2001 From: Mick Date: Sun, 18 May 2025 16:16:03 +0800 Subject: [PATCH 02/11] add test --- python/sglang/srt/layers/attention/vision.py | 18 +- python/sglang/srt/layers/rotary_embedding.py | 2 - .../csrc/multimodal/rotary_embedding.cu | 44 +++-- .../python/sgl_kernel/rotary_embedding.py | 12 -- sgl-kernel/tests/test_mm_rotary_embedding.py | 172 ++++++++++++++++++ 5 files changed, 212 insertions(+), 36 deletions(-) create mode 100644 sgl-kernel/tests/test_mm_rotary_embedding.py diff --git a/python/sglang/srt/layers/attention/vision.py b/python/sglang/srt/layers/attention/vision.py index 43d4c630003e..1824a2a56d40 100644 --- a/python/sglang/srt/layers/attention/vision.py +++ b/python/sglang/srt/layers/attention/vision.py @@ -368,7 +368,6 @@ def __init__( max_position_embeddings=2048, base=10000, is_neox_style=False, - # dtype=rope_dtype dtype=torch.get_default_dtype(), ) @@ -459,7 +458,22 @@ def forward( print(f"{q.dtype}") q, k = self.rotary_emb(position_embeddings, q, k) - # q, k = apply_rotary_pos_emb(q, k, cos, sin) + q_old, k_old = apply_rotary_pos_emb(q, k, cos, sin) + + torch.testing.assert_close( + q, + q_old, + rtol=5, + atol=5, + msg="", + ) + torch.testing.assert_close( + k, + k_old, + rtol=5, + atol=5, + msg="", + ) q = q.view(original_shape) k = k.view(original_shape) diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py index 9a8c692f1258..871d66d14ed5 100644 --- a/python/sglang/srt/layers/rotary_embedding.py +++ b/python/sglang/srt/layers/rotary_embedding.py @@ -185,10 +185,8 @@ def forward_cuda( query, key, self.head_size, - # self.cos_sin_cache, self.is_neox_style, ) - query = query.to(dtype=orig_q_dtype) key = key.to(dtype=orig_k_dtype) return query, key diff --git a/sgl-kernel/csrc/multimodal/rotary_embedding.cu b/sgl-kernel/csrc/multimodal/rotary_embedding.cu index de5ec819826a..2356124b05da 100644 --- a/sgl-kernel/csrc/multimodal/rotary_embedding.cu +++ b/sgl-kernel/csrc/multimodal/rotary_embedding.cu @@ -33,25 +33,29 @@ inline __device__ void apply_token_rotary_embedding( int rot_offset, int embed_dim) { int x_index, y_index; - scalar_t cos_val, sin_val; // Renamed to avoid conflict with input ptrs + scalar_t cos_val_x, sin_val_x; if (IS_NEOX) { - // GPT-NeoX style rotary embedding. + // NEOX-specific case (unchanged) x_index = rot_offset; - y_index = embed_dim + rot_offset; // This was embed_dim + rot_offset, which is correct for NeoX's interleaved pairs - cos_val = SGLANG_LDG(cos_ptr + x_index); - sin_val = SGLANG_LDG(sin_ptr + x_index); + y_index = embed_dim + rot_offset; + cos_val_x = SGLANG_LDG(cos_ptr + rot_offset); + sin_val_x = SGLANG_LDG(sin_ptr + rot_offset); } else { - // GPT-J style rotary embedding. - x_index = 2 * rot_offset; - y_index = 2 * rot_offset + 1; - cos_val = SGLANG_LDG(cos_ptr + rot_offset); // x_index / 2 = rot_offset - sin_val = SGLANG_LDG(sin_ptr + rot_offset); // x_index / 2 = rot_offset + // GPT-J style - modified to match Python implementation + x_index = rot_offset; + y_index = rot_offset + embed_dim; + cos_val_x = SGLANG_LDG(cos_ptr + rot_offset); + sin_val_x = SGLANG_LDG(sin_ptr + rot_offset); } const scalar_t x = arr[x_index]; const scalar_t y = arr[y_index]; - arr[x_index] = x * cos_val - y * sin_val; - arr[y_index] = y * cos_val + x * sin_val; + + // Modified to match Python implementation + // Python: q_embed = (q * cos) + (rotate_half(q) * sin) + // Where rotate_half negates the second half + arr[x_index] = x * cos_val_x - y * sin_val_x; // First half: q[i]*cos[i] - q[i+half]*sin[i] + arr[y_index] = y * cos_val_x + x * sin_val_x; // Second half: q[i+half]*cos[i] + q[i]*sin[i] } template @@ -101,8 +105,8 @@ inline __device__ void apply_rotary_embedding( template __global__ void rotary_embedding_kernel( - const scalar_t* __restrict__ cos_data, // [num_tokens, rot_dim / 2] - const scalar_t* __restrict__ sin_data, // [num_tokens, rot_dim / 2] + const scalar_t* __restrict__ cos_data, // [num_tokens, rot_dim] + const scalar_t* __restrict__ sin_data, // [num_tokens, rot_dim] scalar_t* __restrict__ query_total, // [num_tokens, num_heads, head_size] or [num_tokens, num_heads * head_size] scalar_t* __restrict__ key_total, // nullptr or similar shape to query_total const int rot_dim, @@ -117,8 +121,8 @@ __global__ void rotary_embedding_kernel( const int token_idx = blockIdx.x; const int embed_dim = rot_dim / 2; - const scalar_t* current_token_cos_ptr = cos_data + token_idx * embed_dim; - const scalar_t* current_token_sin_ptr = sin_data + token_idx * embed_dim; + const scalar_t* current_token_cos_ptr = cos_data + token_idx * rot_dim; + const scalar_t* current_token_sin_ptr = sin_data + token_idx * rot_dim; scalar_t* query_for_token = query_total + token_idx * (int)query_token_stride; scalar_t* key_for_token = (key_total != nullptr) ? (key_total + token_idx * (int)key_token_stride) : nullptr; @@ -137,9 +141,9 @@ __global__ void rotary_embedding_kernel( } void rotary_embedding( - at::Tensor& cos_cache, // [num_tokens, rot_dim / 2] - at::Tensor& sin_cache, // [num_tokens, rot_dim / 2] - at::Tensor& query, // [num_tokens, num_heads * head_size] + at::Tensor& cos_cache, // [num_tokens, rot_dim] + at::Tensor& sin_cache, // [num_tokens, rot_dim] + at::Tensor& query, // [num_tokens, num_heads, head_size] const std::optional& key, // null or similar to query int64_t head_size, bool is_neox) { @@ -194,7 +198,7 @@ void rotary_embedding( } TORCH_CHECK(num_heads % num_kv_heads == 0, "num_heads must be divisible by num_kv_heads"); - int rot_dim = (int)cos_cache.size(1) * 2; + int rot_dim = (int)cos_cache.size(1); // TORCH_CHECK(rot_dim <= head_size, "rot_dim must be <= head_size"); // Strides to get to the next token's data diff --git a/sgl-kernel/python/sgl_kernel/rotary_embedding.py b/sgl-kernel/python/sgl_kernel/rotary_embedding.py index 11abfa2b259e..6d5072e68461 100644 --- a/sgl-kernel/python/sgl_kernel/rotary_embedding.py +++ b/sgl-kernel/python/sgl_kernel/rotary_embedding.py @@ -11,7 +11,6 @@ def rotary_embedding( query: torch.Tensor, key: Optional[torch.Tensor], head_size: int, - # cos_sin_cache: torch.Tensor, is_neox: bool, ) -> None: torch.ops.sgl_kernel.rotary_embedding.default( @@ -20,16 +19,5 @@ def rotary_embedding( query, key, head_size, - # cos_sin_cache, is_neox, ) - - -# def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor, -# key: Optional[torch.Tensor], head_size: int, -# cos_sin_cache: torch.Tensor, is_neox: bool, -# rot_dim: int, -# cos_sin_cache_offsets: torch.Tensor) -> None: -# torch.ops._C.batched_rotary_embedding(positions, query, key, head_size, -# cos_sin_cache, is_neox, rot_dim, -# cos_sin_cache_offsets) diff --git a/sgl-kernel/tests/test_mm_rotary_embedding.py b/sgl-kernel/tests/test_mm_rotary_embedding.py new file mode 100644 index 000000000000..fcc65129fe62 --- /dev/null +++ b/sgl-kernel/tests/test_mm_rotary_embedding.py @@ -0,0 +1,172 @@ +from typing import Tuple, Union + +import pytest +import torch +from sgl_kernel import rotary_embedding + + +# torch native +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb( + q: torch.Tensor, + k: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + unsqueeze_dim=1, +) -> Tuple[torch.Tensor, torch.Tensor]: + orig_q_dtype = q.dtype + orig_k_dtype = k.dtype + q, k = q.float(), k.float() + + # embedding is performed in float + cos = cos.unsqueeze(unsqueeze_dim).float() + sin = sin.unsqueeze(unsqueeze_dim).float() + print(f"{cos.shape=}") + print(f"{q.shape=}") + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + + q_embed = q_embed.to(orig_q_dtype) + k_embed = k_embed.to(orig_k_dtype) + + return q_embed, k_embed + + +class RotaryEmbedding(torch.nn.Module): + # Reference: https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/rotary_embedding.py + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + dtype: torch.dtype, + is_neox_style: bool = False, + ) -> None: + super().__init__() + self.head_size = head_size + self.rotary_dim = rotary_dim + self.max_position_embeddings = max_position_embeddings + self.base = base + self.is_neox_style = is_neox_style + self.dtype = dtype + + cache = self._compute_cos_sin_cache() + self.cos_sin_cache: torch.Tensor + self.register_buffer("cos_sin_cache", cache, persistent=False) + + def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: + inv_freq = 1.0 / ( + base + ** ( + torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim + ) + ) + return inv_freq + + def _compute_cos_sin_cache(self) -> torch.Tensor: + """Compute the cos and sin cache.""" + inv_freq = self._compute_inv_freq(self.base) + t = torch.arange(self.max_position_embeddings, dtype=torch.float) + + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() + sin = freqs.sin() + cache = torch.cat((cos, sin), dim=-1) + return cache + + def forward_native( + self, + cos: torch.Tensor, + sin: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """A PyTorch-native implementation of forward().""" + query, key = apply_rotary_pos_emb(query, key, cos, sin) + + query = query.to(self.dtype) + key = key.to(self.dtype) + return query, key + + def forward_kernel_inplace( + self, + cos: torch.Tensor, + sin: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """A PyTorch-native implementation of forward().""" + rotary_embedding( + cos, + sin, + query, + key, + self.head_size, + self.is_neox_style, + ) + query = query.to(self.dtype) + key = key.to(self.dtype) + return query, key + + +@pytest.mark.parametrize( + "head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype, device, batch_size, seq_len, num_q_heads, num_kv_heads", + [ + (80, 80, 1e6, 1e6, False, torch.bfloat16, "cuda", 32, 32, 16, 16), + ], +) +def test_correctness( + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + dtype: torch.dtype, + device: str, + batch_size: int, + seq_len: int, + num_q_heads: int, + num_kv_heads: int, +): + rope_ref = RotaryEmbedding( + head_size, + rotary_dim, + max_position_embeddings, + base, + dtype=dtype, + is_neox_style=is_neox_style, + ).to(device) + + query = torch.randn( + batch_size * seq_len, num_q_heads, head_size, dtype=dtype, device=device + ) + + key = torch.randn( + batch_size * seq_len, num_kv_heads, head_size, dtype=dtype, device=device + ) + + cos = torch.randn(batch_size * seq_len, head_size, dtype=dtype, device=device) + sin = torch.randn(batch_size * seq_len, head_size, dtype=dtype, device=device) + + print(f"{query.shape=}") + print(f"{cos.shape=}") + + # Modification: float32 is required for the rotary embedding to work correctly + query_native_out, key_native_out = rope_ref.forward_native(cos, sin, query, key) + + # in-place + rope_ref.forward_kernel_inplace(cos, sin, query, key) + + torch.testing.assert_close(query_native_out, query, atol=1e-1, rtol=1e-1) + torch.testing.assert_close(key_native_out, key, atol=1e-2, rtol=1e-2) + + +if __name__ == "__main__": + pytest.main([__file__]) From 2c7c71c3cd53bf6e71e412f4ff10d9466269edf2 Mon Sep 17 00:00:00 2001 From: Mick Date: Sun, 18 May 2025 19:24:41 +0800 Subject: [PATCH 03/11] test passed --- python/sglang/srt/layers/attention/vision.py | 37 ++-- python/sglang/srt/layers/rotary_embedding.py | 9 +- python/sglang/srt/models/qwen2_5_vl.py | 2 +- .../csrc/multimodal/rotary_embedding.cu | 199 ++++++++++-------- sgl-kernel/tests/test_mm_rotary_embedding.py | 21 +- 5 files changed, 144 insertions(+), 124 deletions(-) diff --git a/python/sglang/srt/layers/attention/vision.py b/python/sglang/srt/layers/attention/vision.py index 1824a2a56d40..52f31bfd8297 100644 --- a/python/sglang/srt/layers/attention/vision.py +++ b/python/sglang/srt/layers/attention/vision.py @@ -361,7 +361,6 @@ def __init__( softmax_in_single_precision=softmax_in_single_precision, ) - print(f"{self.head_size=}") self.rotary_emb = RotaryEmbedding( head_size=self.head_size, rotary_dim=self.head_size, @@ -454,26 +453,26 @@ def forward( q = q.view(-1, head, self.head_size) k = k.view(-1, head, self.head_size) - print(f"{cos.dtype}") - print(f"{q.dtype}") + # print(f"{cos.dtype}") + # print(f"{q.dtype}") q, k = self.rotary_emb(position_embeddings, q, k) - q_old, k_old = apply_rotary_pos_emb(q, k, cos, sin) - - torch.testing.assert_close( - q, - q_old, - rtol=5, - atol=5, - msg="", - ) - torch.testing.assert_close( - k, - k_old, - rtol=5, - atol=5, - msg="", - ) + # q_old, k_old = apply_rotary_pos_emb(q, k, cos, sin) + + # torch.testing.assert_close( + # q, + # q_old, + # rtol=5, + # atol=5, + # msg="", + # ) + # torch.testing.assert_close( + # k, + # k_old, + # rtol=5, + # atol=5, + # msg="", + # ) q = q.view(original_shape) k = k.view(original_shape) diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py index 871d66d14ed5..55e6794b4ae5 100644 --- a/python/sglang/srt/layers/rotary_embedding.py +++ b/python/sglang/srt/layers/rotary_embedding.py @@ -170,15 +170,10 @@ def forward_cuda( else: orig_q_dtype = query.dtype orig_k_dtype = key.dtype - self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype) cos, sin = positions - query = query.to(dtype=cos.dtype) - key = key.to(dtype=cos.dtype) + cos, sin = cos.float(), sin.float() + query, key = query.float(), key.float() - print(f"{type(cos)=}") - print(f"{cos.shape=}") - print(f"{query.shape=}") - print(f"{key.shape=}") self.sglang_rotary_embedding( cos, sin, diff --git a/python/sglang/srt/models/qwen2_5_vl.py b/python/sglang/srt/models/qwen2_5_vl.py index 420216c7bb0d..0f28002abdbe 100644 --- a/python/sglang/srt/models/qwen2_5_vl.py +++ b/python/sglang/srt/models/qwen2_5_vl.py @@ -398,7 +398,7 @@ def forward( seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1 ) rotary_pos_emb = rotary_pos_emb[window_index, :, :] - rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) + rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1).float() emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) position_embeddings = (emb.cos(), emb.sin()) diff --git a/sgl-kernel/csrc/multimodal/rotary_embedding.cu b/sgl-kernel/csrc/multimodal/rotary_embedding.cu index 2356124b05da..01c0e321055d 100644 --- a/sgl-kernel/csrc/multimodal/rotary_embedding.cu +++ b/sgl-kernel/csrc/multimodal/rotary_embedding.cu @@ -31,98 +31,118 @@ inline __device__ void apply_token_rotary_embedding( const scalar_t* __restrict__ cos_ptr, const scalar_t* __restrict__ sin_ptr, int rot_offset, - int embed_dim) { + int embed_dim /* this is rot_dim / 2 */) { int x_index, y_index; - scalar_t cos_val_x, sin_val_x; + // scalar_t cos_val, sin_val; // Will be split for IS_NEOX = false case + if (IS_NEOX) { - // NEOX-specific case (unchanged) + // Assuming NeoX pairs (2k, 2k+1) and cos/sin are indexed by k (pair index) + // And rot_offset is the pair index k, embed_dim is total number of pairs for this head. + // This part might need further review if NeoX is critical and cos/sin have full head_size. + // The original code's pairing (k, k + embed_dim) is LLaMA/GPT-J style. + // If IS_NEOX=true means actual NeoX pairing (2*rot_offset, 2*rot_offset+1) + // and cos/sin are [num_tokens, rot_dim/2], then this would be: + // x_index = rot_offset * 2; + // y_index = rot_offset * 2 + 1; + // cos_val = SGLANG_LDG(cos_ptr + rot_offset); + // sin_val = SGLANG_LDG(sin_ptr + rot_offset); + // For now, keeping original logic for IS_NEOX=true as it wasn't the focus. + // The original comment "CORRECTION: We need to ensure this pairs correctly" is important. x_index = rot_offset; - y_index = embed_dim + rot_offset; - cos_val_x = SGLANG_LDG(cos_ptr + rot_offset); - sin_val_x = SGLANG_LDG(sin_ptr + rot_offset); + y_index = embed_dim + rot_offset; // embed_dim is half of the feature dim being rotated + + scalar_t cos_val = SGLANG_LDG(cos_ptr + rot_offset); + scalar_t sin_val = SGLANG_LDG(sin_ptr + rot_offset); + + const scalar_t x = arr[x_index]; + const scalar_t y = arr[y_index]; + arr[x_index] = x * cos_val - y * sin_val; + arr[y_index] = y * cos_val + x * sin_val; + } else { - // GPT-J style - modified to match Python implementation - x_index = rot_offset; - y_index = rot_offset + embed_dim; - cos_val_x = SGLANG_LDG(cos_ptr + rot_offset); - sin_val_x = SGLANG_LDG(sin_ptr + rot_offset); - } + // GPT-J style / LLaMA style, matching the Python if cos/sin are [..., head_size] + x_index = rot_offset; // e.g., 0 to 39 if head_size=80, embed_dim=40 + y_index = rot_offset + embed_dim; // e.g., 40 to 79 - const scalar_t x = arr[x_index]; - const scalar_t y = arr[y_index]; + // cos_ptr and sin_ptr point to the start of the current token's head_size dimension + // rot_offset is the index within the first half of head_size + // embed_dim is head_size / 2 + const scalar_t cos_val_x = SGLANG_LDG(cos_ptr + rot_offset); + const scalar_t sin_val_x = SGLANG_LDG(sin_ptr + rot_offset); + const scalar_t cos_val_y = SGLANG_LDG(cos_ptr + rot_offset + embed_dim); // Index for the second half + const scalar_t sin_val_y = SGLANG_LDG(sin_ptr + rot_offset + embed_dim); // Index for the second half - // Modified to match Python implementation - // Python: q_embed = (q * cos) + (rotate_half(q) * sin) - // Where rotate_half negates the second half - arr[x_index] = x * cos_val_x - y * sin_val_x; // First half: q[i]*cos[i] - q[i+half]*sin[i] - arr[y_index] = y * cos_val_x + x * sin_val_x; // Second half: q[i+half]*cos[i] + q[i]*sin[i] + const scalar_t x = arr[x_index]; + const scalar_t y = arr[y_index]; + arr[x_index] = x * cos_val_x - y * sin_val_x; + arr[y_index] = y * cos_val_y + x * sin_val_y; // Matches Python: k_half * cos_half + q_first_half * sin_half + } } template inline __device__ void apply_rotary_embedding( scalar_t* __restrict__ query, // [num_heads, head_size] for current token scalar_t* __restrict__ key, // nullptr or [num_kv_heads, head_size] for current token - const scalar_t* __restrict__ current_token_cos_ptr, // [rot_dim/2] for current token - const scalar_t* __restrict__ current_token_sin_ptr, // [rot_dim/2] for current token + const scalar_t* __restrict__ current_token_cos_ptr, // [rot_dim] for current token (rot_dim is head_size here) + const scalar_t* __restrict__ current_token_sin_ptr, // [rot_dim] for current token (rot_dim is head_size here) const int head_size, const int num_heads, const int num_kv_heads, - const int rot_dim, - const int64_t head_stride_query, // Stride to get to next head in query - const int64_t head_stride_key // Stride to get to next head in key -) { - const int embed_dim = rot_dim / 2; // Number of elements in cos/sin arrays for one token - - // No need to offset current_token_cos_ptr and current_token_sin_ptr further here, - // they already point to the start of cos/sin values for the current token. + const int rot_dim, // This rot_dim is the one from cos_cache.size(1), assumed to be head_size + const int64_t head_stride_query, + const int64_t head_stride_key) { + // If rot_dim from cache is full head_size, then embed_dim here is head_size / 2 + // This embed_dim is the number of pairs to rotate if using LLaMA/GPT-J style pairing, + // or the number of elements in the first half of the rotation. + const int embed_dim_for_rotation = rot_dim / 2; - const int nq_pairs = num_heads * embed_dim; // Total pairs to rotate for query + const int nq_pairs = num_heads * embed_dim_for_rotation; for (int i = threadIdx.x; i < nq_pairs; i += blockDim.x) { - const int head_idx = i / embed_dim; - const int rot_offset = i % embed_dim; // Offset within the head's part to be rotated + const int head_idx = i / embed_dim_for_rotation; + const int rot_offset = i % embed_dim_for_rotation; // Offset within the first half of features to be rotated - // query_for_token_head points to the start of the specific head for the current token scalar_t* query_for_token_head = query + head_idx * (int)head_stride_query; apply_token_rotary_embedding( - query_for_token_head, current_token_cos_ptr, current_token_sin_ptr, rot_offset, embed_dim); + query_for_token_head, current_token_cos_ptr, current_token_sin_ptr, rot_offset, embed_dim_for_rotation); } if (key != nullptr) { - const int nk_pairs = num_kv_heads * embed_dim; // Total pairs to rotate for key + const int nk_pairs = num_kv_heads * embed_dim_for_rotation; for (int i = threadIdx.x; i < nk_pairs; i += blockDim.x) { - const int head_idx = i / embed_dim; - const int rot_offset = i % embed_dim; + const int head_idx = i / embed_dim_for_rotation; + const int rot_offset = i % embed_dim_for_rotation; - // key_for_token_head points to the start of the specific head for the current token scalar_t* key_for_token_head = key + head_idx * (int)head_stride_key; apply_token_rotary_embedding( - key_for_token_head, current_token_cos_ptr, current_token_sin_ptr, rot_offset, embed_dim); + key_for_token_head, current_token_cos_ptr, current_token_sin_ptr, rot_offset, embed_dim_for_rotation); } } } template __global__ void rotary_embedding_kernel( - const scalar_t* __restrict__ cos_data, // [num_tokens, rot_dim] - const scalar_t* __restrict__ sin_data, // [num_tokens, rot_dim] - scalar_t* __restrict__ query_total, // [num_tokens, num_heads, head_size] or [num_tokens, num_heads * head_size] - scalar_t* __restrict__ key_total, // nullptr or similar shape to query_total - const int rot_dim, - const int64_t query_token_stride, // Elements to skip to get to next token in query_total - const int64_t key_token_stride, // Elements to skip to get to next token in key_total - const int64_t head_stride_query, // Elements to skip to get to next head within a token's query data - const int64_t head_stride_key, // Elements to skip to get to next head within a token's key data + const scalar_t* __restrict__ cos_data, // [num_tokens, rot_dim_arg] where rot_dim_arg is from cos_cache.size(1) + const scalar_t* __restrict__ sin_data, // [num_tokens, rot_dim_arg] + scalar_t* __restrict__ query_total, + scalar_t* __restrict__ key_total, + const int rot_dim_arg, // This is cos_cache.size(1). Per clarification, this is head_size. + const int64_t query_token_stride, + const int64_t key_token_stride, + const int64_t head_stride_query, + const int64_t head_stride_key, const int num_heads, const int num_kv_heads, - const int head_size) { - // Each thread block is responsible for one token. + const int head_size) { // head_size of q/k tensors const int token_idx = blockIdx.x; - const int embed_dim = rot_dim / 2; + // const int embed_dim = rot_dim_arg / 2; // This is head_size / 2 - const scalar_t* current_token_cos_ptr = cos_data + token_idx * rot_dim; - const scalar_t* current_token_sin_ptr = sin_data + token_idx * rot_dim; + // MODIFICATION 1: + // If cos_data is [num_tokens, rot_dim_arg], then stride to next token is rot_dim_arg. + // Original code used 'embed_dim' which would be rot_dim_arg / 2. + const scalar_t* current_token_cos_ptr = cos_data + token_idx * rot_dim_arg; + const scalar_t* current_token_sin_ptr = sin_data + token_idx * rot_dim_arg; scalar_t* query_for_token = query_total + token_idx * (int)query_token_stride; scalar_t* key_for_token = (key_total != nullptr) ? (key_total + token_idx * (int)key_token_stride) : nullptr; @@ -132,20 +152,20 @@ __global__ void rotary_embedding_kernel( key_for_token, current_token_cos_ptr, current_token_sin_ptr, - head_size, + head_size, // actual head_size of q/k num_heads, num_kv_heads, - rot_dim, + rot_dim_arg, // rot_dim from cos_cache, passed to apply_token_rotary_embedding head_stride_query, head_stride_key); } void rotary_embedding( - at::Tensor& cos_cache, // [num_tokens, rot_dim] - at::Tensor& sin_cache, // [num_tokens, rot_dim] - at::Tensor& query, // [num_tokens, num_heads, head_size] - const std::optional& key, // null or similar to query - int64_t head_size, + at::Tensor& cos_cache, // Per clarification: [num_tokens, head_size] + at::Tensor& sin_cache, // Per clarification: [num_tokens, head_size] + at::Tensor& query, + const std::optional& key, + int64_t head_size, // head_size of q/k bool is_neox) { TORCH_CHECK( query.dim() == 2 || query.dim() == 3, @@ -158,11 +178,15 @@ void rotary_embedding( int64_t num_tokens = query.size(0); - TORCH_CHECK(cos_cache.dim() == 2, "cos_cache must have shape [num_tokens, rot_dim/2]"); - TORCH_CHECK(sin_cache.dim() == 2, "sin_cache must have shape [num_tokens, rot_dim/2]"); + // The original check assumed cos_cache's last dim is rot_dim/2. + // Given clarification, cos_cache.size(1) is effectively the 'rot_dim' for the cache, + // which you stated is head_size. + // So, if cos_cache is [num_tokens, D_cos], then D_cos is passed as rot_dim to kernel. + TORCH_CHECK(cos_cache.dim() == 2, "cos_cache must have shape [num_tokens, D_cos]"); + TORCH_CHECK(sin_cache.dim() == 2, "sin_cache must have shape [num_tokens, D_sin]"); TORCH_CHECK(cos_cache.size(0) == num_tokens, "cos_cache num_tokens mismatch with query"); TORCH_CHECK(sin_cache.size(0) == num_tokens, "sin_cache num_tokens mismatch with query"); - TORCH_CHECK(cos_cache.size(1) == sin_cache.size(1), "cos_cache and sin_cache rot_dim/2 mismatch"); + TORCH_CHECK(cos_cache.size(1) == sin_cache.size(1), "cos_cache and sin_cache D_cos/D_sin mismatch"); TORCH_CHECK(cos_cache.scalar_type() == query.scalar_type(), "cos_cache dtype mismatch"); TORCH_CHECK(sin_cache.scalar_type() == query.scalar_type(), "sin_cache dtype mismatch"); @@ -172,11 +196,10 @@ void rotary_embedding( TORCH_CHECK(key->scalar_type() == query.scalar_type(), "Key dtype mismatch"); } - // hidden_size = num_heads * head_size int query_hidden_size_calculated; - if (query.dim() == 2) { // [num_tokens, hidden_size] + if (query.dim() == 2) { query_hidden_size_calculated = (int)query.size(1); - } else { // [num_tokens, num_heads, head_size] + } else { query_hidden_size_calculated = (int)query.size(1) * (int)query.size(2); TORCH_CHECK(query.size(2) == head_size, "Query head_size mismatch in 3D tensor"); } @@ -184,12 +207,12 @@ void rotary_embedding( int num_heads = (int)query_hidden_size_calculated / (int)head_size; int key_hidden_size_calculated = 0; - int num_kv_heads = num_heads; // Default if key is not present or GQA not used + int num_kv_heads = num_heads; if (key.has_value()) { TORCH_CHECK((int)key->size(0) == num_tokens, "Key num_tokens mismatch"); - if (key->dim() == 2) { // [num_tokens, kv_hidden_size] + if (key->dim() == 2) { key_hidden_size_calculated = (int)key->size(1); - } else { // [num_tokens, num_kv_heads, head_size] + } else { key_hidden_size_calculated = (int)key->size(1) * (int)key->size(2); TORCH_CHECK((int)key->size(2) == head_size, "Key head_size mismatch in 3D tensor"); } @@ -198,24 +221,24 @@ void rotary_embedding( } TORCH_CHECK(num_heads % num_kv_heads == 0, "num_heads must be divisible by num_kv_heads"); - int rot_dim = (int)cos_cache.size(1); - // TORCH_CHECK(rot_dim <= head_size, "rot_dim must be <= head_size"); + // This rot_dim_from_cache is what's passed to the kernel as rot_dim_arg. + // Per your clarification, this is effectively head_size. + int rot_dim_from_cache = (int)cos_cache.size(1); + // The check `rot_dim <= head_size` is still generally good. + // If rot_dim_from_cache is indeed head_size, then this becomes `head_size <= head_size`. + // TORCH_CHECK(rot_dim_from_cache <= head_size, "rot_dim from cache must be <= head_size of q/k"); - // Strides to get to the next token's data int64_t query_token_stride = query_hidden_size_calculated; int64_t key_token_stride = key.has_value() ? key_hidden_size_calculated : 0; - // Strides to get to the next head's data *within* a token - // If query is [num_tokens, num_heads, head_size], stride is query.stride(1) - // If query is [num_tokens, num_heads * head_size], stride is head_size int64_t head_stride_query; if (query.dim() == 3 && query.size(1) == num_heads && query.size(2) == head_size) { head_stride_query = query.stride(1); - } else { // Assumed to be [num_tokens, num_heads * head_size] or will be viewed as such + } else { head_stride_query = head_size; } - int64_t head_stride_key = head_size; // Default for key + int64_t head_stride_key = head_size; if (key.has_value()) { if (key->dim() == 3 && key->size(1) == num_kv_heads && key->size(2) == head_size) { head_stride_key = key->stride(1); @@ -225,14 +248,12 @@ void rotary_embedding( } dim3 grid((int)num_tokens); - // Max threads per block is usually 1024. - // Each thread handles one pair (x,y) to rotate. - // Total pairs for query for one token: num_heads * (rot_dim / 2) - // We want enough threads to cover these pairs, up to a limit. - // The loop inside apply_rotary_embedding handles thread stride. - int max_pairs_to_rotate_per_token = std::max(num_heads * (rot_dim / 2), num_kv_heads * (rot_dim / 2)); - dim3 block(std::min(max_pairs_to_rotate_per_token, 512L)); // 512L to ensure long comparison - if (block.x == 0 && num_tokens > 0) block.x = 1; // Ensure at least one thread if there's work + // embed_dim_for_block_calc is head_size / 2 if rot_dim_from_cache is head_size + int embed_dim_for_block_calc = rot_dim_from_cache / 2; + int max_pairs_to_rotate_per_token = + std::max(num_heads * embed_dim_for_block_calc, num_kv_heads * embed_dim_for_block_calc); + dim3 block(std::min(max_pairs_to_rotate_per_token, 512L)); + if (block.x == 0 && num_tokens > 0) block.x = 1; const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); @@ -244,28 +265,28 @@ void rotary_embedding( sin_cache.data_ptr(), query.data_ptr(), key.has_value() ? key->data_ptr() : nullptr, - rot_dim, + rot_dim_from_cache, // Pass the dimension from cos_cache query_token_stride, key_token_stride, head_stride_query, head_stride_key, num_heads, num_kv_heads, - head_size); + (int)head_size); // Pass the actual head_size of q/k } else { rotary_embedding_kernel<<>>( cos_cache.data_ptr(), sin_cache.data_ptr(), query.data_ptr(), key.has_value() ? key->data_ptr() : nullptr, - rot_dim, + rot_dim_from_cache, // Pass the dimension from cos_cache query_token_stride, key_token_stride, head_stride_query, head_stride_key, num_heads, num_kv_heads, - head_size); + (int)head_size); // Pass the actual head_size of q/k } }); // C10_CUDA_KERNEL_LAUNCH_CHECK(); diff --git a/sgl-kernel/tests/test_mm_rotary_embedding.py b/sgl-kernel/tests/test_mm_rotary_embedding.py index fcc65129fe62..2ace678578d4 100644 --- a/sgl-kernel/tests/test_mm_rotary_embedding.py +++ b/sgl-kernel/tests/test_mm_rotary_embedding.py @@ -27,11 +27,9 @@ def apply_rotary_pos_emb( # embedding is performed in float cos = cos.unsqueeze(unsqueeze_dim).float() sin = sin.unsqueeze(unsqueeze_dim).float() - print(f"{cos.shape=}") - print(f"{q.shape=}") q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) - + print(f"perform in {cos.dtype=}") q_embed = q_embed.to(orig_q_dtype) k_embed = k_embed.to(orig_k_dtype) @@ -103,6 +101,9 @@ def forward_kernel_inplace( key: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: """A PyTorch-native implementation of forward().""" + cos, sin = cos.float(), sin.float() + query, key = query.float(), key.float() + print(f"kernel: perform in {cos.dtype=}") rotary_embedding( cos, sin, @@ -159,14 +160,18 @@ def test_correctness( print(f"{cos.shape=}") # Modification: float32 is required for the rotary embedding to work correctly - query_native_out, key_native_out = rope_ref.forward_native(cos, sin, query, key) + query_native_out, key_native_out = rope_ref.forward_native( + cos, sin, query.clone(), key.clone() + ) # in-place - rope_ref.forward_kernel_inplace(cos, sin, query, key) + query_kernel_out, key_kernel_out = rope_ref.forward_kernel_inplace( + cos, sin, query, key + ) - torch.testing.assert_close(query_native_out, query, atol=1e-1, rtol=1e-1) - torch.testing.assert_close(key_native_out, key, atol=1e-2, rtol=1e-2) + torch.testing.assert_close(query_native_out, query_kernel_out, atol=1e-3, rtol=1e-3) + torch.testing.assert_close(key_native_out, key_kernel_out, atol=1e-3, rtol=1e-3) if __name__ == "__main__": - pytest.main([__file__]) + pytest.main([__file__, "--capture=no"]) From 64e45e775c500d6959b127ba10e7707a7ff492d6 Mon Sep 17 00:00:00 2001 From: Mick Date: Sun, 18 May 2025 19:44:24 +0800 Subject: [PATCH 04/11] cherry-pick update-test --- test/srt/models/test_vlm_models.py | 118 ++++++++++++++++++++++++----- 1 file changed, 98 insertions(+), 20 deletions(-) diff --git a/test/srt/models/test_vlm_models.py b/test/srt/models/test_vlm_models.py index c55e98da2272..ec9b672d4ed8 100644 --- a/test/srt/models/test_vlm_models.py +++ b/test/srt/models/test_vlm_models.py @@ -1,34 +1,48 @@ +""" + python test_vlm_models.py --batch-size 1 +""" + import argparse import glob import json +import logging import os import random import subprocess import sys +import time import unittest +from collections import defaultdict from types import SimpleNamespace +from typing import Optional +from sglang.bench_serving import async_request_profile from sglang.srt.utils import kill_process_tree from sglang.test.test_utils import ( DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, - CustomTestCase, is_in_ci, popen_launch_server, ) # VLM models for testing MODELS = [ - SimpleNamespace(model="google/gemma-3-27b-it", mmmu_accuracy=0.45), + SimpleNamespace( + model="google/gemma-3-4b-it", chat_template="gemma-it", mmmu_accuracy=0.384 + ), SimpleNamespace( model="Qwen/Qwen2.5-VL-3B-Instruct", - mmmu_accuracy=0.4, + mmmu_accuracy=0.466, + ), + SimpleNamespace( + model="openbmb/MiniCPM-V-2_6", + chat_template="minicpmv", + mmmu_accuracy=0.3867, ), - SimpleNamespace(model="openbmb/MiniCPM-V-2_6", mmmu_accuracy=0.4), ] -class TestVLMModels(CustomTestCase): +class TestVLMModels(unittest.IsolatedAsyncioTestCase): parsed_args = None # Class variable to store args @classmethod @@ -41,11 +55,25 @@ def setUpClass(cls): # Set OpenAI API key and base URL environment variables. Needed for lmm-evals to work. os.environ["OPENAI_API_KEY"] = cls.api_key os.environ["OPENAI_API_BASE"] = f"{cls.base_url}/v1" + cmd = ["python3", "-m", "pip", "show", "lmms_eval"] + + ret = subprocess.run( + cmd, + timeout=3600, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + + assert ( + ret.returncode == 0 + ), "please install lmms_eval by `pip install git+https://github.com/EvolvingLMMs-Lab/lmms-eval.git`" def run_mmmu_eval( self, model_version: str, + batch_size: int, output_path: str, + limit: Optional[str] = None, *, env: dict | None = None, ): @@ -58,7 +86,6 @@ def run_mmmu_eval( model = "openai_compatible" tp = 1 tasks = "mmmu_val" - batch_size = 2 log_suffix = "openai_compatible" os.makedirs(output_path, exist_ok=True) @@ -85,32 +112,36 @@ def run_mmmu_eval( str(output_path), ] + if limit is not None: + cmd += [ + "--limit", + limit, + ] + subprocess.run( cmd, check=True, timeout=3600, ) - def test_vlm_mmmu_benchmark(self): + async def test_vlm_mmmu_benchmark(self): """Test VLM models against MMMU benchmark.""" models_to_test = MODELS if is_in_ci(): models_to_test = [random.choice(MODELS)] - + results = defaultdict(dict) for model in models_to_test: print(f"\nTesting model: {model.model}") process = None mmmu_accuracy = 0 # Initialize to handle potential exceptions - try: # Launch server for testing process = popen_launch_server( model.model, base_url=self.base_url, timeout=self.time_out, - api_key=self.api_key, other_args=[ "--trust-remote-code", "--cuda-graph-max-bs", @@ -121,11 +152,38 @@ def test_vlm_mmmu_benchmark(self): ], ) + if args.profile: + print("Starting profiler...") + profile_output = await async_request_profile( + api_url=self.base_url + "/start_profile" + ) + if profile_output.success: + print("Profiler started") + # Run evaluation - self.run_mmmu_eval(model.model, "./logs") + self.run_mmmu_eval( + model.model, + self.parsed_args.batch_size, + "./logs", + limit=str(1) if self.parsed_args.profile else None, + ) + + if args.profile: + profile_output = await async_request_profile( + api_url=self.base_url + "/stop_profile" + ) + if profile_output.success: + print("Profiler stopped") + print( + "You should kill the process manually until profiling actually stopped" + ) + while True: + time.sleep(10) # Get the result file - result_file_path = glob.glob("./logs/*.json")[0] + files = glob.glob("./logs/*.json") + + result_file_path = max(files, key=os.path.getmtime) with open(result_file_path, "r") as f: result = json.load(f) @@ -133,19 +191,23 @@ def test_vlm_mmmu_benchmark(self): # Process the result mmmu_accuracy = result["results"]["mmmu_val"]["mmmu_acc,none"] print(f"Model {model.model} achieved accuracy: {mmmu_accuracy:.4f}") - + print(f"Evaluation time:", result["total_evaluation_time_seconds"]) + results[model.model] = { + "accu": mmmu_accuracy, + "time": result["total_evaluation_time_seconds"], + } # Assert performance meets expected threshold - self.assertGreaterEqual( - mmmu_accuracy, - model.mmmu_accuracy, - f"Model {model.model} accuracy ({mmmu_accuracy:.4f}) below expected threshold ({model.mmmu_accuracy:.4f})", - ) - + # self.assertGreaterEqual( + # mmmu_accuracy, + # model.mmmu_accuracy, + # f"Model {model.model} accuracy ({mmmu_accuracy:.4f}) below expected threshold ({model.mmmu_accuracy:.4f})", + # ) except Exception as e: print(f"Error testing {model.model}: {e}") self.fail(f"Test failed for {model.model}: {e}") finally: + print(json.dumps(dict(results), indent=4)) # Ensure process cleanup happens regardless of success/failure if process is not None and process.poll() is None: print(f"Cleaning up process {process.pid}") @@ -153,6 +215,7 @@ def test_vlm_mmmu_benchmark(self): kill_process_tree(process.pid) except Exception as e: print(f"Error killing process: {e}") + print(json.dumps(dict(results), indent=4)) if __name__ == "__main__": @@ -164,10 +227,25 @@ def test_vlm_mmmu_benchmark(self): help="Static memory fraction for the model", default=0.8, ) - + parser.add_argument( + "--batch-size", + type=int, + default=1, + ) + parser.add_argument( + "--profile", + action="store_true", + help="Use Torch Profiler. The endpoint must be launched with " + "SGLANG_TORCH_PROFILER_DIR to enable profiler", + default=False, + ) # Parse args intended for unittest args = parser.parse_args() + if args.profile: + log_level = os.getenv("LOG_LEVEL", "WARNING").upper() + logging.basicConfig(level="INFO") + # Store the parsed args object on the class TestVLMModels.parsed_args = args From 4cf093d06e67a5b7ef499efcca75615d8d187412 Mon Sep 17 00:00:00 2001 From: Mick Date: Mon, 19 May 2025 08:52:55 +0800 Subject: [PATCH 05/11] cleanups --- python/sglang/srt/layers/attention/vision.py | 20 --- python/sglang/srt/layers/rotary_embedding.py | 1 - .../csrc/multimodal/rotary_embedding.cu | 126 +++++++----------- 3 files changed, 46 insertions(+), 101 deletions(-) diff --git a/python/sglang/srt/layers/attention/vision.py b/python/sglang/srt/layers/attention/vision.py index 52f31bfd8297..9e4d934ea98b 100644 --- a/python/sglang/srt/layers/attention/vision.py +++ b/python/sglang/srt/layers/attention/vision.py @@ -447,32 +447,12 @@ def forward( ] if position_embeddings is not None: - cos, sin = position_embeddings original_shape = q.shape # [total_tokens, head, head_size] q = q.view(-1, head, self.head_size) k = k.view(-1, head, self.head_size) - # print(f"{cos.dtype}") - # print(f"{q.dtype}") - q, k = self.rotary_emb(position_embeddings, q, k) - # q_old, k_old = apply_rotary_pos_emb(q, k, cos, sin) - - # torch.testing.assert_close( - # q, - # q_old, - # rtol=5, - # atol=5, - # msg="", - # ) - # torch.testing.assert_close( - # k, - # k_old, - # rtol=5, - # atol=5, - # msg="", - # ) q = q.view(original_shape) k = k.view(original_shape) diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py index 55e6794b4ae5..f372873877f2 100644 --- a/python/sglang/srt/layers/rotary_embedding.py +++ b/python/sglang/srt/layers/rotary_embedding.py @@ -173,7 +173,6 @@ def forward_cuda( cos, sin = positions cos, sin = cos.float(), sin.float() query, key = query.float(), key.float() - self.sglang_rotary_embedding( cos, sin, diff --git a/sgl-kernel/csrc/multimodal/rotary_embedding.cu b/sgl-kernel/csrc/multimodal/rotary_embedding.cu index 01c0e321055d..105c58e99bc5 100644 --- a/sgl-kernel/csrc/multimodal/rotary_embedding.cu +++ b/sgl-kernel/csrc/multimodal/rotary_embedding.cu @@ -15,15 +15,14 @@ * limitations under the License. */ #include +#include +#include +#include #include #include #include "utils.h" -// #include -#include -#include -#include template inline __device__ void apply_token_rotary_embedding( @@ -31,25 +30,12 @@ inline __device__ void apply_token_rotary_embedding( const scalar_t* __restrict__ cos_ptr, const scalar_t* __restrict__ sin_ptr, int rot_offset, - int embed_dim /* this is rot_dim / 2 */) { + int embed_dim) { int x_index, y_index; - // scalar_t cos_val, sin_val; // Will be split for IS_NEOX = false case if (IS_NEOX) { - // Assuming NeoX pairs (2k, 2k+1) and cos/sin are indexed by k (pair index) - // And rot_offset is the pair index k, embed_dim is total number of pairs for this head. - // This part might need further review if NeoX is critical and cos/sin have full head_size. - // The original code's pairing (k, k + embed_dim) is LLaMA/GPT-J style. - // If IS_NEOX=true means actual NeoX pairing (2*rot_offset, 2*rot_offset+1) - // and cos/sin are [num_tokens, rot_dim/2], then this would be: - // x_index = rot_offset * 2; - // y_index = rot_offset * 2 + 1; - // cos_val = SGLANG_LDG(cos_ptr + rot_offset); - // sin_val = SGLANG_LDG(sin_ptr + rot_offset); - // For now, keeping original logic for IS_NEOX=true as it wasn't the focus. - // The original comment "CORRECTION: We need to ensure this pairs correctly" is important. x_index = rot_offset; - y_index = embed_dim + rot_offset; // embed_dim is half of the feature dim being rotated + y_index = embed_dim + rot_offset; scalar_t cos_val = SGLANG_LDG(cos_ptr + rot_offset); scalar_t sin_val = SGLANG_LDG(sin_ptr + rot_offset); @@ -61,45 +47,39 @@ inline __device__ void apply_token_rotary_embedding( } else { // GPT-J style / LLaMA style, matching the Python if cos/sin are [..., head_size] - x_index = rot_offset; // e.g., 0 to 39 if head_size=80, embed_dim=40 - y_index = rot_offset + embed_dim; // e.g., 40 to 79 + x_index = rot_offset; // first half + y_index = rot_offset + embed_dim; // second half - // cos_ptr and sin_ptr point to the start of the current token's head_size dimension - // rot_offset is the index within the first half of head_size - // embed_dim is head_size / 2 const scalar_t cos_val_x = SGLANG_LDG(cos_ptr + rot_offset); const scalar_t sin_val_x = SGLANG_LDG(sin_ptr + rot_offset); - const scalar_t cos_val_y = SGLANG_LDG(cos_ptr + rot_offset + embed_dim); // Index for the second half - const scalar_t sin_val_y = SGLANG_LDG(sin_ptr + rot_offset + embed_dim); // Index for the second half + const scalar_t cos_val_y = SGLANG_LDG(cos_ptr + rot_offset + embed_dim); + const scalar_t sin_val_y = SGLANG_LDG(sin_ptr + rot_offset + embed_dim); const scalar_t x = arr[x_index]; const scalar_t y = arr[y_index]; arr[x_index] = x * cos_val_x - y * sin_val_x; - arr[y_index] = y * cos_val_y + x * sin_val_y; // Matches Python: k_half * cos_half + q_first_half * sin_half + arr[y_index] = y * cos_val_y + x * sin_val_y; } } template inline __device__ void apply_rotary_embedding( - scalar_t* __restrict__ query, // [num_heads, head_size] for current token - scalar_t* __restrict__ key, // nullptr or [num_kv_heads, head_size] for current token - const scalar_t* __restrict__ current_token_cos_ptr, // [rot_dim] for current token (rot_dim is head_size here) - const scalar_t* __restrict__ current_token_sin_ptr, // [rot_dim] for current token (rot_dim is head_size here) + scalar_t* __restrict__ query, // [num_heads, head_size] + scalar_t* __restrict__ key, // [num_kv_heads, head_size] + const scalar_t* __restrict__ current_token_cos_ptr, // [rot_dim] + const scalar_t* __restrict__ current_token_sin_ptr, // [rot_dim] const int head_size, const int num_heads, const int num_kv_heads, - const int rot_dim, // This rot_dim is the one from cos_cache.size(1), assumed to be head_size + const int rot_dim, const int64_t head_stride_query, const int64_t head_stride_key) { - // If rot_dim from cache is full head_size, then embed_dim here is head_size / 2 - // This embed_dim is the number of pairs to rotate if using LLaMA/GPT-J style pairing, - // or the number of elements in the first half of the rotation. const int embed_dim_for_rotation = rot_dim / 2; const int nq_pairs = num_heads * embed_dim_for_rotation; for (int i = threadIdx.x; i < nq_pairs; i += blockDim.x) { const int head_idx = i / embed_dim_for_rotation; - const int rot_offset = i % embed_dim_for_rotation; // Offset within the first half of features to be rotated + const int rot_offset = i % embed_dim_for_rotation; scalar_t* query_for_token_head = query + head_idx * (int)head_stride_query; @@ -123,24 +103,19 @@ inline __device__ void apply_rotary_embedding( template __global__ void rotary_embedding_kernel( - const scalar_t* __restrict__ cos_data, // [num_tokens, rot_dim_arg] where rot_dim_arg is from cos_cache.size(1) + const scalar_t* __restrict__ cos_data, // [num_tokens, rot_dim_arg] const scalar_t* __restrict__ sin_data, // [num_tokens, rot_dim_arg] scalar_t* __restrict__ query_total, scalar_t* __restrict__ key_total, - const int rot_dim_arg, // This is cos_cache.size(1). Per clarification, this is head_size. + const int rot_dim_arg, const int64_t query_token_stride, const int64_t key_token_stride, const int64_t head_stride_query, const int64_t head_stride_key, const int num_heads, const int num_kv_heads, - const int head_size) { // head_size of q/k tensors + const int head_size) { const int token_idx = blockIdx.x; - // const int embed_dim = rot_dim_arg / 2; // This is head_size / 2 - - // MODIFICATION 1: - // If cos_data is [num_tokens, rot_dim_arg], then stride to next token is rot_dim_arg. - // Original code used 'embed_dim' which would be rot_dim_arg / 2. const scalar_t* current_token_cos_ptr = cos_data + token_idx * rot_dim_arg; const scalar_t* current_token_sin_ptr = sin_data + token_idx * rot_dim_arg; @@ -152,45 +127,41 @@ __global__ void rotary_embedding_kernel( key_for_token, current_token_cos_ptr, current_token_sin_ptr, - head_size, // actual head_size of q/k + head_size, num_heads, num_kv_heads, - rot_dim_arg, // rot_dim from cos_cache, passed to apply_token_rotary_embedding + rot_dim_arg, head_stride_query, head_stride_key); } void rotary_embedding( - at::Tensor& cos_cache, // Per clarification: [num_tokens, head_size] - at::Tensor& sin_cache, // Per clarification: [num_tokens, head_size] + at::Tensor& cos, + at::Tensor& sin, at::Tensor& query, const std::optional& key, - int64_t head_size, // head_size of q/k + int64_t head_size, bool is_neox) { TORCH_CHECK( query.dim() == 2 || query.dim() == 3, - "query must have shape [num_tokens, hidden_size] or [num_tokens, num_heads, head_size]"); + "query must be in shape [num_tokens, hidden_size] or [num_tokens, num_heads, head_size]"); if (key.has_value()) { TORCH_CHECK( key->dim() == 2 || key->dim() == 3, - "key must have shape [num_tokens, hidden_size] or [num_tokens, num_kv_heads, head_size]"); + "key must be in shape [num_tokens, hidden_size] or [num_tokens, num_kv_heads, head_size]"); } int64_t num_tokens = query.size(0); - // The original check assumed cos_cache's last dim is rot_dim/2. - // Given clarification, cos_cache.size(1) is effectively the 'rot_dim' for the cache, - // which you stated is head_size. - // So, if cos_cache is [num_tokens, D_cos], then D_cos is passed as rot_dim to kernel. - TORCH_CHECK(cos_cache.dim() == 2, "cos_cache must have shape [num_tokens, D_cos]"); - TORCH_CHECK(sin_cache.dim() == 2, "sin_cache must have shape [num_tokens, D_sin]"); - TORCH_CHECK(cos_cache.size(0) == num_tokens, "cos_cache num_tokens mismatch with query"); - TORCH_CHECK(sin_cache.size(0) == num_tokens, "sin_cache num_tokens mismatch with query"); - TORCH_CHECK(cos_cache.size(1) == sin_cache.size(1), "cos_cache and sin_cache D_cos/D_sin mismatch"); - - TORCH_CHECK(cos_cache.scalar_type() == query.scalar_type(), "cos_cache dtype mismatch"); - TORCH_CHECK(sin_cache.scalar_type() == query.scalar_type(), "sin_cache dtype mismatch"); - TORCH_CHECK(cos_cache.is_cuda() && sin_cache.is_cuda() && query.is_cuda(), "All tensors must be on CUDA"); + TORCH_CHECK(cos.dim() == 2, "cos must be in shape [num_tokens, D_cos]"); + TORCH_CHECK(sin.dim() == 2, "sin must be in shape [num_tokens, D_sin]"); + TORCH_CHECK(cos.size(0) == num_tokens, "cos num_tokens mismatch with query"); + TORCH_CHECK(sin.size(0) == num_tokens, "sin num_tokens mismatch with query"); + TORCH_CHECK(cos.size(1) == sin.size(1), "cos and sin D_cos/D_sin mismatch"); + + TORCH_CHECK(cos.scalar_type() == query.scalar_type(), "cos dtype mismatch"); + TORCH_CHECK(sin.scalar_type() == query.scalar_type(), "sin dtype mismatch"); + TORCH_CHECK(cos.is_cuda() && sin.is_cuda() && query.is_cuda(), "All tensors must be on CUDA"); if (key.has_value()) { TORCH_CHECK(key->is_cuda(), "Key tensor must be on CUDA if provided"); TORCH_CHECK(key->scalar_type() == query.scalar_type(), "Key dtype mismatch"); @@ -221,12 +192,7 @@ void rotary_embedding( } TORCH_CHECK(num_heads % num_kv_heads == 0, "num_heads must be divisible by num_kv_heads"); - // This rot_dim_from_cache is what's passed to the kernel as rot_dim_arg. - // Per your clarification, this is effectively head_size. - int rot_dim_from_cache = (int)cos_cache.size(1); - // The check `rot_dim <= head_size` is still generally good. - // If rot_dim_from_cache is indeed head_size, then this becomes `head_size <= head_size`. - // TORCH_CHECK(rot_dim_from_cache <= head_size, "rot_dim from cache must be <= head_size of q/k"); + int rot_dim_from_cache = (int)cos.size(1); int64_t query_token_stride = query_hidden_size_calculated; int64_t key_token_stride = key.has_value() ? key_hidden_size_calculated : 0; @@ -248,11 +214,12 @@ void rotary_embedding( } dim3 grid((int)num_tokens); - // embed_dim_for_block_calc is head_size / 2 if rot_dim_from_cache is head_size + int embed_dim_for_block_calc = rot_dim_from_cache / 2; int max_pairs_to_rotate_per_token = std::max(num_heads * embed_dim_for_block_calc, num_kv_heads * embed_dim_for_block_calc); dim3 block(std::min(max_pairs_to_rotate_per_token, 512L)); + if (block.x == 0 && num_tokens > 0) block.x = 1; const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); @@ -261,33 +228,32 @@ void rotary_embedding( SGLANG_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] { if (is_neox) { rotary_embedding_kernel<<>>( - cos_cache.data_ptr(), - sin_cache.data_ptr(), + cos.data_ptr(), + sin.data_ptr(), query.data_ptr(), key.has_value() ? key->data_ptr() : nullptr, - rot_dim_from_cache, // Pass the dimension from cos_cache + rot_dim_from_cache, query_token_stride, key_token_stride, head_stride_query, head_stride_key, num_heads, num_kv_heads, - (int)head_size); // Pass the actual head_size of q/k + (int)head_size); } else { rotary_embedding_kernel<<>>( - cos_cache.data_ptr(), - sin_cache.data_ptr(), + cos.data_ptr(), + sin.data_ptr(), query.data_ptr(), key.has_value() ? key->data_ptr() : nullptr, - rot_dim_from_cache, // Pass the dimension from cos_cache + rot_dim_from_cache, query_token_stride, key_token_stride, head_stride_query, head_stride_key, num_heads, num_kv_heads, - (int)head_size); // Pass the actual head_size of q/k + (int)head_size); } }); - // C10_CUDA_KERNEL_LAUNCH_CHECK(); } From 2b958b3315c409dad5766459b7ec1a29e9b625f1 Mon Sep 17 00:00:00 2001 From: Mick Date: Thu, 22 May 2025 14:35:20 +0800 Subject: [PATCH 06/11] update test --- python/sglang/srt/layers/attention/vision.py | 3 +- python/sglang/srt/layers/rotary_embedding.py | 2 +- sgl-kernel/tests/test_mm_rotary_embedding.py | 61 ++++++++++++++++++-- test/srt/models/test_vlm_models.py | 23 +++++--- 4 files changed, 72 insertions(+), 17 deletions(-) diff --git a/python/sglang/srt/layers/attention/vision.py b/python/sglang/srt/layers/attention/vision.py index 9e4d934ea98b..ced68788cecb 100644 --- a/python/sglang/srt/layers/attention/vision.py +++ b/python/sglang/srt/layers/attention/vision.py @@ -151,8 +151,6 @@ def forward( Returns: [b * s, h, head_size] """ - if self.flatten_batch: - assert bsz == 1, "flatten_batch is True, bsz must be 1" assert q.dim() == 3, q.shape @@ -452,6 +450,7 @@ def forward( q = q.view(-1, head, self.head_size) k = k.view(-1, head, self.head_size) + print(f"{q.shape=}") q, k = self.rotary_emb(position_embeddings, q, k) q = q.view(original_shape) diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py index f372873877f2..fbd5e0a903f6 100644 --- a/python/sglang/srt/layers/rotary_embedding.py +++ b/python/sglang/srt/layers/rotary_embedding.py @@ -171,7 +171,7 @@ def forward_cuda( orig_q_dtype = query.dtype orig_k_dtype = key.dtype cos, sin = positions - cos, sin = cos.float(), sin.float() + cos, sin = cos.float().contiguous(), sin.float().contiguous() query, key = query.float(), key.float() self.sglang_rotary_embedding( cos, diff --git a/sgl-kernel/tests/test_mm_rotary_embedding.py b/sgl-kernel/tests/test_mm_rotary_embedding.py index 2ace678578d4..ae0f30cfd6dc 100644 --- a/sgl-kernel/tests/test_mm_rotary_embedding.py +++ b/sgl-kernel/tests/test_mm_rotary_embedding.py @@ -9,7 +9,7 @@ def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] + x2 = x[..., x.shape[-1] // 2:] return torch.cat((-x2, x1), dim=-1) @@ -117,10 +117,13 @@ def forward_kernel_inplace( return query, key +@pytest.mark.benchmark(group="rotary_embedding") @pytest.mark.parametrize( "head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype, device, batch_size, seq_len, num_q_heads, num_kv_heads", [ (80, 80, 1e6, 1e6, False, torch.bfloat16, "cuda", 32, 32, 16, 16), + (320, 230, 1e6, 1e6, False, torch.bfloat16, "cuda", 32, 32, 16, 16), + (80, 80, 1e6, 1e6, True, torch.bfloat16, "cuda", 32, 32, 16, 16), ], ) def test_correctness( @@ -156,10 +159,6 @@ def test_correctness( cos = torch.randn(batch_size * seq_len, head_size, dtype=dtype, device=device) sin = torch.randn(batch_size * seq_len, head_size, dtype=dtype, device=device) - print(f"{query.shape=}") - print(f"{cos.shape=}") - - # Modification: float32 is required for the rotary embedding to work correctly query_native_out, key_native_out = rope_ref.forward_native( cos, sin, query.clone(), key.clone() ) @@ -173,5 +172,57 @@ def test_correctness( torch.testing.assert_close(key_native_out, key_kernel_out, atol=1e-3, rtol=1e-3) +@pytest.mark.benchmark(group="rotary_embedding") +@pytest.mark.parametrize( + "head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype, device, batch_size, seq_len, num_q_heads, num_kv_heads", + [ + (80, 80, 1e6, 1e6, False, torch.bfloat16, "cuda", 1, 8840, 16, 16), + (80, 80, 1e6, 1e6, False, torch.bfloat16, "cuda", 1, 4000, 16, 16), + (80, 80, 1e6, 1e6, True, torch.bfloat16, "cuda", 8, 8840, 16, 16), + ], +) +def test_rotary_embedding_benchmark( + benchmark, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + dtype: torch.dtype, + device: str, + batch_size: int, + seq_len: int, + num_q_heads: int, + num_kv_heads: int, +): + rope_ref = RotaryEmbedding( + head_size, + rotary_dim, + max_position_embeddings, + base, + dtype=dtype, + is_neox_style=is_neox_style, + ).to(device) + query = torch.randn( + batch_size * seq_len, num_q_heads, head_size, dtype=dtype, device=device + ) + key = torch.randn( + batch_size * seq_len, num_kv_heads, head_size, dtype=dtype, device=device + ) + cos = torch.randn(batch_size * seq_len, head_size, dtype=dtype, device=device) + sin = torch.randn(batch_size * seq_len, head_size, dtype=dtype, device=device) + + def run_kernel(): + rope_ref.forward_kernel_inplace( + cos, + sin, + query, + key, + ) + torch.cuda.synchronize() + + benchmark.pedantic(run_kernel, rounds=20000, warmup_rounds=5) + + if __name__ == "__main__": pytest.main([__file__, "--capture=no"]) diff --git a/test/srt/models/test_vlm_models.py b/test/srt/models/test_vlm_models.py index ec9b672d4ed8..de9e43490831 100644 --- a/test/srt/models/test_vlm_models.py +++ b/test/srt/models/test_vlm_models.py @@ -27,18 +27,23 @@ # VLM models for testing MODELS = [ + # SimpleNamespace( + # model="google/gemma-3-4b-it", chat_template="gemma-it", mmmu_accuracy=0.384 + # ), SimpleNamespace( - model="google/gemma-3-4b-it", chat_template="gemma-it", mmmu_accuracy=0.384 - ), - SimpleNamespace( - model="Qwen/Qwen2.5-VL-3B-Instruct", + model="Qwen/Qwen2-VL-7B-Instruct", mmmu_accuracy=0.466, ), - SimpleNamespace( - model="openbmb/MiniCPM-V-2_6", - chat_template="minicpmv", - mmmu_accuracy=0.3867, - ), + # SimpleNamespace( + # model="Qwen/Qwen2.5-VL-3B-Instruct", + # mmmu_accuracy=0.466, + # ), + # + # SimpleNamespace( + # model="openbmb/MiniCPM-V-2_6", + # chat_template="minicpmv", + # mmmu_accuracy=0.3867, + # ), ] From 226c2109f0839b77e7ae615019add488658c3a2b Mon Sep 17 00:00:00 2001 From: Mick Date: Thu, 22 May 2025 14:35:34 +0800 Subject: [PATCH 07/11] update test --- sgl-kernel/tests/test_mm_rotary_embedding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sgl-kernel/tests/test_mm_rotary_embedding.py b/sgl-kernel/tests/test_mm_rotary_embedding.py index ae0f30cfd6dc..14d23ced72fe 100644 --- a/sgl-kernel/tests/test_mm_rotary_embedding.py +++ b/sgl-kernel/tests/test_mm_rotary_embedding.py @@ -9,7 +9,7 @@ def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2:] + x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) From 68b68448fba8152466dbbfac82a310fbd0922fdc Mon Sep 17 00:00:00 2001 From: Mick Date: Thu, 22 May 2025 14:35:46 +0800 Subject: [PATCH 08/11] update kernel with sram --- python/sglang/srt/layers/attention/vision.py | 1 - .../csrc/multimodal/rotary_embedding.cu | 38 ++++++++++++------- 2 files changed, 25 insertions(+), 14 deletions(-) diff --git a/python/sglang/srt/layers/attention/vision.py b/python/sglang/srt/layers/attention/vision.py index ced68788cecb..280e17a6aebb 100644 --- a/python/sglang/srt/layers/attention/vision.py +++ b/python/sglang/srt/layers/attention/vision.py @@ -450,7 +450,6 @@ def forward( q = q.view(-1, head, self.head_size) k = k.view(-1, head, self.head_size) - print(f"{q.shape=}") q, k = self.rotary_emb(position_embeddings, q, k) q = q.view(original_shape) diff --git a/sgl-kernel/csrc/multimodal/rotary_embedding.cu b/sgl-kernel/csrc/multimodal/rotary_embedding.cu index 105c58e99bc5..613c402e78c2 100644 --- a/sgl-kernel/csrc/multimodal/rotary_embedding.cu +++ b/sgl-kernel/csrc/multimodal/rotary_embedding.cu @@ -37,8 +37,8 @@ inline __device__ void apply_token_rotary_embedding( x_index = rot_offset; y_index = embed_dim + rot_offset; - scalar_t cos_val = SGLANG_LDG(cos_ptr + rot_offset); - scalar_t sin_val = SGLANG_LDG(sin_ptr + rot_offset); + scalar_t cos_val = cos_ptr[rot_offset]; + scalar_t sin_val = sin_ptr[rot_offset]; const scalar_t x = arr[x_index]; const scalar_t y = arr[y_index]; @@ -47,13 +47,13 @@ inline __device__ void apply_token_rotary_embedding( } else { // GPT-J style / LLaMA style, matching the Python if cos/sin are [..., head_size] - x_index = rot_offset; // first half - y_index = rot_offset + embed_dim; // second half + x_index = rot_offset; + y_index = rot_offset + embed_dim; - const scalar_t cos_val_x = SGLANG_LDG(cos_ptr + rot_offset); - const scalar_t sin_val_x = SGLANG_LDG(sin_ptr + rot_offset); - const scalar_t cos_val_y = SGLANG_LDG(cos_ptr + rot_offset + embed_dim); - const scalar_t sin_val_y = SGLANG_LDG(sin_ptr + rot_offset + embed_dim); + const scalar_t cos_val_x = cos_ptr[rot_offset]; + const scalar_t sin_val_x = sin_ptr[rot_offset]; + const scalar_t cos_val_y = cos_ptr[rot_offset + embed_dim]; + const scalar_t sin_val_y = sin_ptr[rot_offset + embed_dim]; const scalar_t x = arr[x_index]; const scalar_t y = arr[y_index]; @@ -78,8 +78,8 @@ inline __device__ void apply_rotary_embedding( const int nq_pairs = num_heads * embed_dim_for_rotation; for (int i = threadIdx.x; i < nq_pairs; i += blockDim.x) { - const int head_idx = i / embed_dim_for_rotation; const int rot_offset = i % embed_dim_for_rotation; + const int head_idx = i / embed_dim_for_rotation; scalar_t* query_for_token_head = query + head_idx * (int)head_stride_query; @@ -115,18 +115,29 @@ __global__ void rotary_embedding_kernel( const int num_heads, const int num_kv_heads, const int head_size) { + extern __shared__ char smem[]; + scalar_t* shared_cos = reinterpret_cast(smem); + scalar_t* shared_sin = shared_cos + rot_dim_arg; + const int token_idx = blockIdx.x; const scalar_t* current_token_cos_ptr = cos_data + token_idx * rot_dim_arg; const scalar_t* current_token_sin_ptr = sin_data + token_idx * rot_dim_arg; + for (int i = threadIdx.x; i < rot_dim_arg; i += blockDim.x) { + shared_cos[i] = current_token_cos_ptr[i]; + shared_sin[i] = current_token_sin_ptr[i]; + } + + __syncthreads(); + scalar_t* query_for_token = query_total + token_idx * (int)query_token_stride; scalar_t* key_for_token = (key_total != nullptr) ? (key_total + token_idx * (int)key_token_stride) : nullptr; apply_rotary_embedding( query_for_token, key_for_token, - current_token_cos_ptr, - current_token_sin_ptr, + shared_cos, + shared_sin, head_size, num_heads, num_kv_heads, @@ -226,8 +237,9 @@ void rotary_embedding( const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); SGLANG_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] { + const size_t shared_mem_size = 2 * rot_dim_from_cache * sizeof(scalar_t); if (is_neox) { - rotary_embedding_kernel<<>>( + rotary_embedding_kernel<<>>( cos.data_ptr(), sin.data_ptr(), query.data_ptr(), @@ -241,7 +253,7 @@ void rotary_embedding( num_kv_heads, (int)head_size); } else { - rotary_embedding_kernel<<>>( + rotary_embedding_kernel<<>>( cos.data_ptr(), sin.data_ptr(), query.data_ptr(), From a783bbbcc031742c78ea47d7a837efe843b429b6 Mon Sep 17 00:00:00 2001 From: Mick Date: Thu, 22 May 2025 19:14:53 +0800 Subject: [PATCH 09/11] reduce dynamic cast --- python/sglang/srt/layers/rotary_embedding.py | 7 +++++-- python/sglang/srt/models/qwen2_vl.py | 6 ++++-- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py index fbd5e0a903f6..67bd574c2989 100644 --- a/python/sglang/srt/layers/rotary_embedding.py +++ b/python/sglang/srt/layers/rotary_embedding.py @@ -168,11 +168,14 @@ def forward_cuda( is_neox=self.is_neox_style, ) else: + + cos, sin = positions + assert cos.dtype == torch.float and cos.is_contiguous() + assert sin.dtype == torch.float and sin.is_contiguous() orig_q_dtype = query.dtype orig_k_dtype = key.dtype - cos, sin = positions - cos, sin = cos.float().contiguous(), sin.float().contiguous() query, key = query.float(), key.float() + self.sglang_rotary_embedding( cos, sin, diff --git a/python/sglang/srt/models/qwen2_vl.py b/python/sglang/srt/models/qwen2_vl.py index b4421290edea..ad2ce4e7f915 100644 --- a/python/sglang/srt/models/qwen2_vl.py +++ b/python/sglang/srt/models/qwen2_vl.py @@ -402,7 +402,10 @@ def forward( # compute position embedding rotary_pos_emb = self.rot_pos_emb(grid_thw) emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) - position_embeddings = (emb.cos(), emb.sin()) + position_embeddings = ( + emb.cos().float().contiguous(), + emb.sin().float().contiguous(), + ) # compute cu_seqlens cu_seqlens = torch.repeat_interleave( grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0] @@ -413,7 +416,6 @@ def forward( x = x.unsqueeze(1) for blk in self.blocks: x = blk(x, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings) - # adapter x = self.merger(x) return x From de64889d5c838a24367943363e96a95f75c61310 Mon Sep 17 00:00:00 2001 From: Mick Date: Thu, 22 May 2025 20:02:59 +0800 Subject: [PATCH 10/11] remove python parts --- python/sglang/srt/_custom_ops.py | 3 +- python/sglang/srt/layers/attention/vision.py | 16 +-- python/sglang/srt/layers/rotary_embedding.py | 38 ++---- python/sglang/srt/models/qwen2_5_vl.py | 2 +- python/sglang/srt/models/qwen2_vl.py | 6 +- .../csrc/multimodal/rotary_embedding.cu | 38 ++---- test/srt/models/test_vlm_models.py | 125 +++--------------- 7 files changed, 56 insertions(+), 172 deletions(-) diff --git a/python/sglang/srt/_custom_ops.py b/python/sglang/srt/_custom_ops.py index d506fbb7342e..07c087bf6c42 100644 --- a/python/sglang/srt/_custom_ops.py +++ b/python/sglang/srt/_custom_ops.py @@ -1,6 +1,6 @@ # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/_custom_ops.py import logging -from typing import List, Optional, Tuple +from typing import List, Tuple import torch @@ -24,6 +24,7 @@ except ImportError as e: logger.warning("Failed to import from custom_ar with %r", e) + if not is_hip(): if use_vllm_custom_allreduce: custom_op = torch.ops._C_custom_ar diff --git a/python/sglang/srt/layers/attention/vision.py b/python/sglang/srt/layers/attention/vision.py index 280e17a6aebb..f1f45e27ab96 100644 --- a/python/sglang/srt/layers/attention/vision.py +++ b/python/sglang/srt/layers/attention/vision.py @@ -27,7 +27,7 @@ RowParallelLinear, ) from sglang.srt.layers.quantization import QuantizationConfig -from sglang.srt.layers.rotary_embedding import RotaryEmbedding, apply_rotary_pos_emb +from sglang.srt.layers.rotary_embedding import apply_rotary_pos_emb from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.utils import add_prefix, logger @@ -151,6 +151,8 @@ def forward( Returns: [b * s, h, head_size] """ + if self.flatten_batch: + assert bsz == 1, "flatten_batch is True, bsz must be 1" assert q.dim() == 3, q.shape @@ -359,15 +361,6 @@ def __init__( softmax_in_single_precision=softmax_in_single_precision, ) - self.rotary_emb = RotaryEmbedding( - head_size=self.head_size, - rotary_dim=self.head_size, - max_position_embeddings=2048, - base=10000, - is_neox_style=False, - dtype=torch.get_default_dtype(), - ) - self.use_qkv_parallel = use_qkv_parallel if use_qkv_parallel: self.qkv_proj = QKVParallelLinear( @@ -445,12 +438,13 @@ def forward( ] if position_embeddings is not None: + cos, sin = position_embeddings original_shape = q.shape # [total_tokens, head, head_size] q = q.view(-1, head, self.head_size) k = k.view(-1, head, self.head_size) - q, k = self.rotary_emb(position_embeddings, q, k) + q, k = apply_rotary_pos_emb(q, k, cos, sin) q = q.view(original_shape) k = k.view(original_shape) diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py index 67bd574c2989..c5c285ca0fc4 100644 --- a/python/sglang/srt/layers/rotary_embedding.py +++ b/python/sglang/srt/layers/rotary_embedding.py @@ -84,9 +84,9 @@ def __init__( cache = cache.to(dtype) if not _is_cuda or self.head_size not in [64, 128, 256, 512]: - from sgl_kernel import rotary_embedding + from vllm._custom_ops import rotary_embedding - self.sglang_rotary_embedding = rotary_embedding + self.vllm_rotary_embedding = rotary_embedding self.cos_sin_cache: torch.Tensor self.register_buffer("cos_sin_cache", cache, persistent=False) @@ -118,7 +118,7 @@ def _compute_cos_sin_cache(self) -> torch.Tensor: def forward_native( self, - positions: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], + positions: torch.Tensor, query: torch.Tensor, key: torch.Tensor, offsets: Optional[torch.Tensor] = None, @@ -126,15 +126,10 @@ def forward_native( """A PyTorch-native implementation of forward().""" if offsets is not None: positions = positions + offsets - - if isinstance(positions, torch.Tensor): - positions = positions.flatten() - cos_sin = self.cos_sin_cache.index_select(0, positions) - cos, sin = cos_sin.chunk(2, dim=-1) - num_tokens = positions.shape[0] - else: - cos, sin = positions - num_tokens = cos.shape[0] + positions = positions.flatten() + num_tokens = positions.shape[0] + cos_sin = self.cos_sin_cache.index_select(0, positions) + cos, sin = cos_sin.chunk(2, dim=-1) query_shape = query.shape query = query.view(num_tokens, -1, self.head_size) @@ -153,7 +148,7 @@ def forward_native( def forward_cuda( self, - positions: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], + positions: torch.Tensor, query: torch.Tensor, key: torch.Tensor, offsets: Optional[torch.Tensor] = None, @@ -168,24 +163,15 @@ def forward_cuda( is_neox=self.is_neox_style, ) else: - - cos, sin = positions - assert cos.dtype == torch.float and cos.is_contiguous() - assert sin.dtype == torch.float and sin.is_contiguous() - orig_q_dtype = query.dtype - orig_k_dtype = key.dtype - query, key = query.float(), key.float() - - self.sglang_rotary_embedding( - cos, - sin, + self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype) + self.vllm_rotary_embedding( + positions, query, key, self.head_size, + self.cos_sin_cache, self.is_neox_style, ) - query = query.to(dtype=orig_q_dtype) - key = key.to(dtype=orig_k_dtype) return query, key def extra_repr(self) -> str: diff --git a/python/sglang/srt/models/qwen2_5_vl.py b/python/sglang/srt/models/qwen2_5_vl.py index 0f28002abdbe..420216c7bb0d 100644 --- a/python/sglang/srt/models/qwen2_5_vl.py +++ b/python/sglang/srt/models/qwen2_5_vl.py @@ -398,7 +398,7 @@ def forward( seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1 ) rotary_pos_emb = rotary_pos_emb[window_index, :, :] - rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1).float() + rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) position_embeddings = (emb.cos(), emb.sin()) diff --git a/python/sglang/srt/models/qwen2_vl.py b/python/sglang/srt/models/qwen2_vl.py index ad2ce4e7f915..b4421290edea 100644 --- a/python/sglang/srt/models/qwen2_vl.py +++ b/python/sglang/srt/models/qwen2_vl.py @@ -402,10 +402,7 @@ def forward( # compute position embedding rotary_pos_emb = self.rot_pos_emb(grid_thw) emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) - position_embeddings = ( - emb.cos().float().contiguous(), - emb.sin().float().contiguous(), - ) + position_embeddings = (emb.cos(), emb.sin()) # compute cu_seqlens cu_seqlens = torch.repeat_interleave( grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0] @@ -416,6 +413,7 @@ def forward( x = x.unsqueeze(1) for blk in self.blocks: x = blk(x, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings) + # adapter x = self.merger(x) return x diff --git a/sgl-kernel/csrc/multimodal/rotary_embedding.cu b/sgl-kernel/csrc/multimodal/rotary_embedding.cu index 613c402e78c2..105c58e99bc5 100644 --- a/sgl-kernel/csrc/multimodal/rotary_embedding.cu +++ b/sgl-kernel/csrc/multimodal/rotary_embedding.cu @@ -37,8 +37,8 @@ inline __device__ void apply_token_rotary_embedding( x_index = rot_offset; y_index = embed_dim + rot_offset; - scalar_t cos_val = cos_ptr[rot_offset]; - scalar_t sin_val = sin_ptr[rot_offset]; + scalar_t cos_val = SGLANG_LDG(cos_ptr + rot_offset); + scalar_t sin_val = SGLANG_LDG(sin_ptr + rot_offset); const scalar_t x = arr[x_index]; const scalar_t y = arr[y_index]; @@ -47,13 +47,13 @@ inline __device__ void apply_token_rotary_embedding( } else { // GPT-J style / LLaMA style, matching the Python if cos/sin are [..., head_size] - x_index = rot_offset; - y_index = rot_offset + embed_dim; + x_index = rot_offset; // first half + y_index = rot_offset + embed_dim; // second half - const scalar_t cos_val_x = cos_ptr[rot_offset]; - const scalar_t sin_val_x = sin_ptr[rot_offset]; - const scalar_t cos_val_y = cos_ptr[rot_offset + embed_dim]; - const scalar_t sin_val_y = sin_ptr[rot_offset + embed_dim]; + const scalar_t cos_val_x = SGLANG_LDG(cos_ptr + rot_offset); + const scalar_t sin_val_x = SGLANG_LDG(sin_ptr + rot_offset); + const scalar_t cos_val_y = SGLANG_LDG(cos_ptr + rot_offset + embed_dim); + const scalar_t sin_val_y = SGLANG_LDG(sin_ptr + rot_offset + embed_dim); const scalar_t x = arr[x_index]; const scalar_t y = arr[y_index]; @@ -78,8 +78,8 @@ inline __device__ void apply_rotary_embedding( const int nq_pairs = num_heads * embed_dim_for_rotation; for (int i = threadIdx.x; i < nq_pairs; i += blockDim.x) { - const int rot_offset = i % embed_dim_for_rotation; const int head_idx = i / embed_dim_for_rotation; + const int rot_offset = i % embed_dim_for_rotation; scalar_t* query_for_token_head = query + head_idx * (int)head_stride_query; @@ -115,29 +115,18 @@ __global__ void rotary_embedding_kernel( const int num_heads, const int num_kv_heads, const int head_size) { - extern __shared__ char smem[]; - scalar_t* shared_cos = reinterpret_cast(smem); - scalar_t* shared_sin = shared_cos + rot_dim_arg; - const int token_idx = blockIdx.x; const scalar_t* current_token_cos_ptr = cos_data + token_idx * rot_dim_arg; const scalar_t* current_token_sin_ptr = sin_data + token_idx * rot_dim_arg; - for (int i = threadIdx.x; i < rot_dim_arg; i += blockDim.x) { - shared_cos[i] = current_token_cos_ptr[i]; - shared_sin[i] = current_token_sin_ptr[i]; - } - - __syncthreads(); - scalar_t* query_for_token = query_total + token_idx * (int)query_token_stride; scalar_t* key_for_token = (key_total != nullptr) ? (key_total + token_idx * (int)key_token_stride) : nullptr; apply_rotary_embedding( query_for_token, key_for_token, - shared_cos, - shared_sin, + current_token_cos_ptr, + current_token_sin_ptr, head_size, num_heads, num_kv_heads, @@ -237,9 +226,8 @@ void rotary_embedding( const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); SGLANG_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] { - const size_t shared_mem_size = 2 * rot_dim_from_cache * sizeof(scalar_t); if (is_neox) { - rotary_embedding_kernel<<>>( + rotary_embedding_kernel<<>>( cos.data_ptr(), sin.data_ptr(), query.data_ptr(), @@ -253,7 +241,7 @@ void rotary_embedding( num_kv_heads, (int)head_size); } else { - rotary_embedding_kernel<<>>( + rotary_embedding_kernel<<>>( cos.data_ptr(), sin.data_ptr(), query.data_ptr(), diff --git a/test/srt/models/test_vlm_models.py b/test/srt/models/test_vlm_models.py index de9e43490831..c55e98da2272 100644 --- a/test/srt/models/test_vlm_models.py +++ b/test/srt/models/test_vlm_models.py @@ -1,53 +1,34 @@ -""" - python test_vlm_models.py --batch-size 1 -""" - import argparse import glob import json -import logging import os import random import subprocess import sys -import time import unittest -from collections import defaultdict from types import SimpleNamespace -from typing import Optional -from sglang.bench_serving import async_request_profile from sglang.srt.utils import kill_process_tree from sglang.test.test_utils import ( DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, + CustomTestCase, is_in_ci, popen_launch_server, ) # VLM models for testing MODELS = [ - # SimpleNamespace( - # model="google/gemma-3-4b-it", chat_template="gemma-it", mmmu_accuracy=0.384 - # ), + SimpleNamespace(model="google/gemma-3-27b-it", mmmu_accuracy=0.45), SimpleNamespace( - model="Qwen/Qwen2-VL-7B-Instruct", - mmmu_accuracy=0.466, + model="Qwen/Qwen2.5-VL-3B-Instruct", + mmmu_accuracy=0.4, ), - # SimpleNamespace( - # model="Qwen/Qwen2.5-VL-3B-Instruct", - # mmmu_accuracy=0.466, - # ), - # - # SimpleNamespace( - # model="openbmb/MiniCPM-V-2_6", - # chat_template="minicpmv", - # mmmu_accuracy=0.3867, - # ), + SimpleNamespace(model="openbmb/MiniCPM-V-2_6", mmmu_accuracy=0.4), ] -class TestVLMModels(unittest.IsolatedAsyncioTestCase): +class TestVLMModels(CustomTestCase): parsed_args = None # Class variable to store args @classmethod @@ -60,25 +41,11 @@ def setUpClass(cls): # Set OpenAI API key and base URL environment variables. Needed for lmm-evals to work. os.environ["OPENAI_API_KEY"] = cls.api_key os.environ["OPENAI_API_BASE"] = f"{cls.base_url}/v1" - cmd = ["python3", "-m", "pip", "show", "lmms_eval"] - - ret = subprocess.run( - cmd, - timeout=3600, - stdout=subprocess.DEVNULL, - stderr=subprocess.DEVNULL, - ) - - assert ( - ret.returncode == 0 - ), "please install lmms_eval by `pip install git+https://github.com/EvolvingLMMs-Lab/lmms-eval.git`" def run_mmmu_eval( self, model_version: str, - batch_size: int, output_path: str, - limit: Optional[str] = None, *, env: dict | None = None, ): @@ -91,6 +58,7 @@ def run_mmmu_eval( model = "openai_compatible" tp = 1 tasks = "mmmu_val" + batch_size = 2 log_suffix = "openai_compatible" os.makedirs(output_path, exist_ok=True) @@ -117,36 +85,32 @@ def run_mmmu_eval( str(output_path), ] - if limit is not None: - cmd += [ - "--limit", - limit, - ] - subprocess.run( cmd, check=True, timeout=3600, ) - async def test_vlm_mmmu_benchmark(self): + def test_vlm_mmmu_benchmark(self): """Test VLM models against MMMU benchmark.""" models_to_test = MODELS if is_in_ci(): models_to_test = [random.choice(MODELS)] - results = defaultdict(dict) + for model in models_to_test: print(f"\nTesting model: {model.model}") process = None mmmu_accuracy = 0 # Initialize to handle potential exceptions + try: # Launch server for testing process = popen_launch_server( model.model, base_url=self.base_url, timeout=self.time_out, + api_key=self.api_key, other_args=[ "--trust-remote-code", "--cuda-graph-max-bs", @@ -157,38 +121,11 @@ async def test_vlm_mmmu_benchmark(self): ], ) - if args.profile: - print("Starting profiler...") - profile_output = await async_request_profile( - api_url=self.base_url + "/start_profile" - ) - if profile_output.success: - print("Profiler started") - # Run evaluation - self.run_mmmu_eval( - model.model, - self.parsed_args.batch_size, - "./logs", - limit=str(1) if self.parsed_args.profile else None, - ) - - if args.profile: - profile_output = await async_request_profile( - api_url=self.base_url + "/stop_profile" - ) - if profile_output.success: - print("Profiler stopped") - print( - "You should kill the process manually until profiling actually stopped" - ) - while True: - time.sleep(10) + self.run_mmmu_eval(model.model, "./logs") # Get the result file - files = glob.glob("./logs/*.json") - - result_file_path = max(files, key=os.path.getmtime) + result_file_path = glob.glob("./logs/*.json")[0] with open(result_file_path, "r") as f: result = json.load(f) @@ -196,23 +133,19 @@ async def test_vlm_mmmu_benchmark(self): # Process the result mmmu_accuracy = result["results"]["mmmu_val"]["mmmu_acc,none"] print(f"Model {model.model} achieved accuracy: {mmmu_accuracy:.4f}") - print(f"Evaluation time:", result["total_evaluation_time_seconds"]) - results[model.model] = { - "accu": mmmu_accuracy, - "time": result["total_evaluation_time_seconds"], - } + # Assert performance meets expected threshold - # self.assertGreaterEqual( - # mmmu_accuracy, - # model.mmmu_accuracy, - # f"Model {model.model} accuracy ({mmmu_accuracy:.4f}) below expected threshold ({model.mmmu_accuracy:.4f})", - # ) + self.assertGreaterEqual( + mmmu_accuracy, + model.mmmu_accuracy, + f"Model {model.model} accuracy ({mmmu_accuracy:.4f}) below expected threshold ({model.mmmu_accuracy:.4f})", + ) + except Exception as e: print(f"Error testing {model.model}: {e}") self.fail(f"Test failed for {model.model}: {e}") finally: - print(json.dumps(dict(results), indent=4)) # Ensure process cleanup happens regardless of success/failure if process is not None and process.poll() is None: print(f"Cleaning up process {process.pid}") @@ -220,7 +153,6 @@ async def test_vlm_mmmu_benchmark(self): kill_process_tree(process.pid) except Exception as e: print(f"Error killing process: {e}") - print(json.dumps(dict(results), indent=4)) if __name__ == "__main__": @@ -232,25 +164,10 @@ async def test_vlm_mmmu_benchmark(self): help="Static memory fraction for the model", default=0.8, ) - parser.add_argument( - "--batch-size", - type=int, - default=1, - ) - parser.add_argument( - "--profile", - action="store_true", - help="Use Torch Profiler. The endpoint must be launched with " - "SGLANG_TORCH_PROFILER_DIR to enable profiler", - default=False, - ) + # Parse args intended for unittest args = parser.parse_args() - if args.profile: - log_level = os.getenv("LOG_LEVEL", "WARNING").upper() - logging.basicConfig(level="INFO") - # Store the parsed args object on the class TestVLMModels.parsed_args = args From 4094c4cd50b9054781258ee9104a14c61d98483b Mon Sep 17 00:00:00 2001 From: "Xiang (Kevin) Li" Date: Fri, 5 Sep 2025 13:43:50 -0700 Subject: [PATCH 11/11] Support SGL rotary_embedding kernel in RotaryEmbedding and by extension VisionAttention --- python/sglang/srt/layers/attention/vision.py | 21 +++++++++++++---- python/sglang/srt/layers/rotary_embedding.py | 24 ++++++++++++++------ 2 files changed, 34 insertions(+), 11 deletions(-) diff --git a/python/sglang/srt/layers/attention/vision.py b/python/sglang/srt/layers/attention/vision.py index 2be3e450b2d4..4bc4fcfe31f4 100644 --- a/python/sglang/srt/layers/attention/vision.py +++ b/python/sglang/srt/layers/attention/vision.py @@ -39,7 +39,7 @@ RowParallelLinear, ) from sglang.srt.layers.quantization import QuantizationConfig -from sglang.srt.layers.rotary_embedding import apply_rotary_pos_emb +from sglang.srt.layers.rotary_embedding import RotaryEmbedding, apply_rotary_pos_emb from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.utils import add_prefix @@ -428,6 +428,15 @@ def __init__( softmax_in_single_precision=softmax_in_single_precision, ) + self.rotary_emb = RotaryEmbedding( + head_size=self.head_size, + rotary_dim=self.head_size, + max_position_embeddings=2048, + base=10000, + is_neox_style=False, + dtype=torch.get_default_dtype(), + ) + self.use_qkv_parallel = use_qkv_parallel if use_qkv_parallel: self.qkv_proj = QKVParallelLinear( @@ -568,13 +577,17 @@ def forward( q = q.view(original_shape) k = k.view(original_shape) else: - cos, sin = position_embeddings - # [total_tokens, head, head_size] q = q.view(-1, head, self.head_size) k = k.view(-1, head, self.head_size) - q, k = apply_rotary_pos_emb(q, k, cos, sin) + (cos, sin) = position_embeddings + position_embeddings = ( + cos.float().contiguous(), + sin.float().contiguous(), + ) + + q, k = self.rotary_emb(position_embeddings, q, k) q = q.view(original_shape) k = k.view(original_shape) diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py index 05f06855725a..09fb6fd48f7d 100644 --- a/python/sglang/srt/layers/rotary_embedding.py +++ b/python/sglang/srt/layers/rotary_embedding.py @@ -104,9 +104,9 @@ def __init__( if ( not (_is_cuda or _is_npu) or self.head_size not in [64, 128, 256, 512] ) and not (_is_cpu and _is_cpu_amx_available): - from vllm._custom_ops import rotary_embedding + from sgl_kernel import rotary_embedding - self.vllm_rotary_embedding = rotary_embedding + self.sglang_rotary_embedding = rotary_embedding self.cos_sin_cache: torch.Tensor self.register_buffer("cos_sin_cache", cache, persistent=False) @@ -242,16 +242,26 @@ def forward_cuda( else: assert ( fused_set_kv_buffer_arg is None - ), "save kv cache is not supported for vllm_rotary_embedding." - self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype) - self.vllm_rotary_embedding( - positions, + ), "save kv cache is not supported for sglang_rotary_embedding." + + cos, sin = positions + assert cos.dtype == torch.float and cos.is_contiguous() + assert sin.dtype == torch.float and sin.is_contiguous() + orig_q_dtype = query.dtype + orig_k_dtype = key.dtype + query, key = query.float(), key.float() + + self.sglang_rotary_embedding( + cos, + sin, query, key, self.head_size, - self.cos_sin_cache, self.is_neox_style, ) + + query = query.to(dtype=orig_q_dtype) + key = key.to(dtype=orig_k_dtype) return query, key def extra_repr(self) -> str: