diffusion: rotary embedding kernel#14302
Conversation
… benchmarking infrastructure.
Summary of Changes

Hello @RubiaCx, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request enhances the SGLang framework by integrating a new, high-performance CUDA kernel for rotary embeddings. The optimization matters for diffusion models and other applications that rely heavily on RoPE, offering improved speed and broader compatibility across model architectures and head sizes, including previously unsupported non-standard configurations.
Code Review
The pull request introduces a new rotary embedding CUDA kernel, which is a significant improvement for diffusion models. The implementation supports various layouts and head sizes, and benchmarks show promising speedups over existing solutions. The integration into the sgl-kernel and Python API is well-handled, and comprehensive tests ensure correctness across a wide range of configurations. The overall structure and approach are solid, demonstrating a good understanding of performance optimization for deep learning workloads.
```python
except torch.AcceleratorError:
    torch.cuda.empty_cache()
    pass
```
The try-except torch.AcceleratorError block around torch.cuda.synchronize() is unusual. torch.cuda.synchronize() typically doesn't raise torch.AcceleratorError. If the intent is to clear the cache after a potential CUDA error from a previous operation, it might be clearer to structure this differently, or to ensure that torch.cuda.synchronize() is only called when the CUDA context is valid. However, given the context of benchmarking, this might be a defensive measure.
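One possible restructuring, as a hedged stdlib-only sketch: `guarded_sync`, `fake_sync`, and `fake_empty_cache` are names invented here to stand in for `torch.cuda.synchronize` / `torch.cuda.empty_cache`, and torch is deliberately not imported.

```python
# Hedged sketch: make the intent of the guard explicit. On a device-side
# error, release cached memory and re-raise so the failing benchmark
# config is skipped explicitly rather than swallowed by a bare `pass`.
def guarded_sync(sync, empty_cache, error_types):
    try:
        sync()
    except error_types:
        empty_cache()
        raise

calls = []

def fake_sync():
    # Stand-in for torch.cuda.synchronize() hitting a prior CUDA error.
    calls.append("sync")
    raise RuntimeError("device-side assert triggered")

def fake_empty_cache():
    # Stand-in for torch.cuda.empty_cache().
    calls.append("empty_cache")

try:
    guarded_sync(fake_sync, fake_empty_cache, (RuntimeError,))
except RuntimeError:
    calls.append("skipped")

print(calls)  # ['sync', 'empty_cache', 'skipped']
```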
```python
def compute_cos_sin_cache(
    max_seq_len: int,
    rotary_dim: int,
    base: float = 10000.0,
    dtype: torch.dtype = torch.float32,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Compute separate cos and sin caches.
    """
    inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float32) / rotary_dim))
    t = torch.arange(max_seq_len, dtype=torch.float32)
    freqs = torch.einsum("i,j->ij", t, inv_freq)
    cos = freqs.cos().to(dtype)
    sin = freqs.sin().to(dtype)
    return cos, sin
```
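For cross-checking, the same cache can be computed with the standard library only. This is an illustrative reference, not the PR's code; `compute_cos_sin_cache_ref` is a name invented here.

```python
import math

# Dependency-free reference mirroring the torch cache computation:
# inv_freq[j] = base ** (-2j / rotary_dim), which matches
# torch.arange(0, rotary_dim, 2) / rotary_dim in the exponent.
def compute_cos_sin_cache_ref(max_seq_len, rotary_dim, base=10000.0):
    inv_freq = [base ** (-(2.0 * j) / rotary_dim) for j in range(rotary_dim // 2)]
    cos = [[math.cos(t * f) for f in inv_freq] for t in range(max_seq_len)]
    sin = [[math.sin(t * f) for f in inv_freq] for t in range(max_seq_len)]
    return cos, sin

cos, sin = compute_cos_sin_cache_ref(max_seq_len=4, rotary_dim=8)
# Position 0 rotates by angle 0 for every frequency.
print(cos[0])  # [1.0, 1.0, 1.0, 1.0]
print(sin[0])  # [0.0, 0.0, 0.0, 0.0]
```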
```
# Conflicts:
#	sgl-kernel/csrc/common_extension.cc
```
```cpp
x_index = 2 * rot_offset;
y_index = x_index + 1;

const float cos_val = static_cast<float>(SGLANG_LDG(cos_ptr + rot_offset));
```
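For context, the per-pair rotation this indexing feeds can be written out in plain Python. This is an illustrative sketch, not the PR's code; the `interleaved` flag mirrors the kernel's template parameter, and mapping "non-interleaved" to the half-split (NeoX-style) layout is an assumption here.

```python
# Apply RoPE to one head vector: interleaved pairs elements (2i, 2i+1);
# the half-split (NeoX-style) layout pairs elements (i, i + d/2).
def rope_rotate(head, cos, sin, interleaved):
    d = len(head)
    half = d // 2
    out = list(head)
    for i in range(half):
        if interleaved:
            xi, yi = 2 * i, 2 * i + 1
        else:
            xi, yi = i, i + half
        x, y = head[xi], head[yi]
        out[xi] = x * cos[i] - y * sin[i]  # standard 2D rotation
        out[yi] = y * cos[i] + x * sin[i]
    return out

# A 90-degree rotation maps the pair (1, 0) to (0, 1) in either layout:
print(rope_rotate([1.0, 0.0], [0.0], [1.0], True))   # [0.0, 1.0]
print(rope_rotate([1.0, 0.0], [0.0], [1.0], False))  # [0.0, 1.0]
```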
Can we use float4 to vectorize the global memory reads?
nits: If
@BBuf I addressed the low-num_tokens under-utilization by updating the diffusion RoPE kernel.
Results (Qwen-image): average kernel time improved from 29.551 µs to 24.557 µs (1.203× speedup). Additionally, compared to the positions API (FlashInfer-based path), the diffusion RoPE kernel path is ~1.24× faster on average (weighted across configs).
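The quoted numbers follow the usual benchmark-loop shape. Below is a hedged stdlib-only sketch of that structure (the real benchmark times device-side work, e.g. with CUDA events; this version only illustrates warmup plus averaged timing, and the speedup arithmetic):

```python
import time

# Warm up, then report mean wall-clock seconds per call.
def bench(fn, warmup=10, iters=100):
    for _ in range(warmup):      # warm up caches / lazy-init paths
        fn()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    return (time.perf_counter() - start) / iters

mean_s = bench(lambda: sum(range(1000)))
assert mean_s > 0

# Speedup quoted above: old mean time divided by new mean time.
speedup = 29.551 / 24.557
print(round(speedup, 3))  # 1.203
```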
@BBuf Hi, #16287 already fixed the AMD multimodal CI failures introduced by the refactor in #15812 / #15813 by restoring a proper normalization fallback on ROCm. I removed my local fallback implementation to avoid duplicating logic and diverging from main. Separately, I added a FlashInfer RoPE benchmark comparison. FlashInfer currently supports only a limited set of shapes (e.g., it errors on head_dim=96, so those entries are NaN), and in the supported cases it is not faster than our implementations.
Signed-off-by: Ther-LF <2639852836@qq.com>
Please fix the lint.
Hi @DarkSharpness. Thanks for the review. Besides addressing the API/dispatch issues, I also updated the benchmark and tests for better correctness and fairness.
python/sglang/multimodal_gen/runtime/layers/rotary_embedding.py (outdated review threads, resolved)
/tag-and-rerun-ci
```cpp
  return __float2half_rn(v);
}
template <>
__device__ __forceinline__ nv_bfloat16 from_float<nv_bfloat16>(float v) {
```
In #16884 we introduced type.cuh, which includes some type conversion primitives in the device namespace. You may try replacing these type casts with cast<To>(value).
```cpp
union VecU {
  float4 v;
  scalar_t e[kElePerVec];
} tmp;
```
Try replacing this with device::AlignedVector, which can also generate efficient aligned load/store instructions (just like float4).
```cpp
// RoPE kernel with 1D grid: one block per token.
template <typename scalar_t, bool interleaved, bool vectorized, bool aligned_qk, int ROT_EMBED_DIM>
__launch_bounds__(512) __global__ void rotary_embedding_kernel_1d(
```
I'm not sure about the performance here. In my qknorm implementation, we use a persistent 2D kernel, where each warp handles one q/k head dimension at a time. The code logic is as follows:
- Each warp determines which q head or k head it will process at this iteration.
- Set up the pointer information (sglang/python/sglang/jit_kernel/csrc/elementwise/qknorm.cuh, lines 54 to 59 in e7df8bd).
- Load and perform the algorithm on the input (lines 60 to 63 in e7df8bd).

This can hopefully reduce code size and improve readability. You may refer to this as an example. I would suggest you take a similar approach (persistent 2D kernel + unified code logic) if it won't bring too much regression in performance.
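The persistent work distribution described above can be sketched in plain Python as a grid-stride schedule. All names here are illustrative, not the PR's or qknorm's actual code:

```python
# A fixed pool of "warps" loops over (token, head) work items; each
# iteration handles one q head or one k head of one token.
def persistent_schedule(num_warps, num_tokens, num_q_heads, num_k_heads):
    heads_per_token = num_q_heads + num_k_heads
    total = num_tokens * heads_per_token
    assignment = []
    for warp in range(num_warps):
        # Grid-stride loop: warp w handles items w, w+num_warps, ...
        for work in range(warp, total, num_warps):
            token = work // heads_per_token
            head = work % heads_per_token
            if head < num_q_heads:
                assignment.append((warp, token, "q", head))
            else:
                assignment.append((warp, token, "k", head - num_q_heads))
    return assignment

sched = persistent_schedule(num_warps=4, num_tokens=3, num_q_heads=2, num_k_heads=1)
print(len(sched))  # 9 -- every (token, head) pair is covered exactly once
```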
```cpp
if ((reinterpret_cast<uintptr_t>(sin.data_ptr()) % kVecBytes) != 0) can_vec_compute = false;
if (((r * elem_bytes) % kVecBytes) != 0) can_vec_compute = false;

bool qk_aligned16 = true;
```
Do we really have to consider the unaligned cases? In most cases, I would expect the head bytes to be at least 64-byte aligned. Do not optimize too early for rare cases.
```cpp
const int kElePerVec = kVecBytes / elem_bytes;
const int pairs_per_step = interleaved ? (kElePerVec / 2) : kElePerVec;

bool can_vec_compute = true;
```
Do we really have to consider the unaligned cases? In most cases, I would expect the head bytes to be at least 64-byte aligned. Do not optimize too early for rare cases.
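The alignment predicate under discussion can be stated compactly. A hedged sketch of the decision, with invented names (`can_vectorize` is not the PR's API); the 16-byte width corresponds to float4:

```python
K_VEC_BYTES = 16  # width of a float4 vectorized load/store

# The vectorized path requires both the per-head byte stride and the
# base pointer to be multiples of the vector width.
def can_vectorize(head_size, elem_bytes, base_addr):
    head_bytes = head_size * elem_bytes
    return head_bytes % K_VEC_BYTES == 0 and base_addr % K_VEC_BYTES == 0

print(can_vectorize(64, 2, 0))   # True  -- fp16, 128-byte heads
print(can_vectorize(80, 4, 0))   # True  -- fp32, 320-byte heads
print(can_vectorize(64, 2, 8))   # False -- base pointer not 16B-aligned
```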
Motivation

Add a high-performance rotary embedding CUDA kernel to sgl-kernel.

Modifications

- Add a rotary embedding CUDA kernel in sgl-kernel/csrc/multimodal/rotary_embedding.cu, supporting separate cos/sin caches, NeoX and non-NeoX layouts, 2D/3D Q/K, and GQA/MQA.
- Register it in sgl-kernel: declare in include/sgl_kernel_ops.h and csrc/common_extension.cc, and add it to CMakeLists.txt.
- Add a Python wrapper in python/sgl_kernel/rotary_embedding.py and export it from python/sgl_kernel/__init__.py.
- Add tests/test_mm_rotary_embedding.py and benchmark/bench_mm_rotary_embedding.py.

Accuracy Tests

tests/test_mm_rotary_embedding.py covers (batch_size, seq_len, num_heads, num_kv_heads, head_size, dtype) combinations.

Benchmarking and Profiling

benchmark/bench_mm_rotary_embedding.py: across batch_size ∈ {1, 32}, heads ∈ {32/8, 64/8, 8/8, 32/1}, and head_size ∈ {64, 80, 128, 256}, the new kernel achieves roughly 1.05–1.25× speedup over the baseline and up to 2×+ over FlashAttention's rotary.

Checklist