diffusion: rotary embedding kernel #14302

Open

RubiaCx wants to merge 45 commits into sgl-project:main from RubiaCx:cx/rotary

Conversation

@RubiaCx
Collaborator

@RubiaCx RubiaCx commented Dec 2, 2025

Motivation

  • close diffusion: rotary embedding kernel #12985
  • Diffusion models rely heavily on RoPE in both text and vision branches. This PR adds a rotary embedding CUDA kernel to sgl-kernel.
  • Previous kernels had limited support for non‑standard head sizes (e.g., 80). This work follows the direction of the fast rotary embedding work in #10527, and generalizes earlier rotary kernel efforts in #6530 to cover more layouts and Diffusion‑style use cases.

Modifications

  • Add a generic rotary embedding CUDA kernel in sgl-kernel/csrc/multimodal/rotary_embedding.cu, supporting:
    • Separate cos / sin caches, NeoX and non‑NeoX layouts, 2D/3D Q/K, and GQA/MQA.
  • Wire up the kernel into sgl-kernel:
    • Register C++ op in include/sgl_kernel_ops.h and csrc/common_extension.cc, add to CMakeLists.txt.
    • Expose a Python API via python/sgl_kernel/rotary_embedding.py and export from python/sgl_kernel/__init__.py.
  • Add/update tests and benchmarks:
    • Add tests/test_mm_rotary_embedding.py and benchmark/bench_mm_rotary_embedding.py.
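For context, the rotation the kernel performs can be sketched per head in plain Python — a simplified reference, not the kernel code; the NeoX and interleaved layouts differ only in how elements are paired:

```python
def rotate_pair(x1, x2, cos, sin):
    # The 2D rotation RoPE applies to one (x1, x2) pair.
    return x1 * cos - x2 * sin, x2 * cos + x1 * sin


def apply_rope(head, cos_row, sin_row, neox=True):
    """Rotate one head vector of even length using precomputed cos/sin
    rows of length len(head) // 2.

    NeoX layout pairs element i with i + d/2; the interleaved
    (non-NeoX) layout pairs 2*i with 2*i + 1.
    """
    d = len(head)
    half = d // 2
    out = list(head)
    for i in range(half):
        if neox:
            a, b = i, i + half
        else:
            a, b = 2 * i, 2 * i + 1
        out[a], out[b] = rotate_pair(head[a], head[b], cos_row[i], sin_row[i])
    return out
```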

Accuracy Tests

  • tests/test_mm_rotary_embedding.py:
    • Parametrized coverage over multiple (batch_size, seq_len, num_heads, num_kv_heads, head_size, dtype) combinations.
    • Compare Q/K outputs against a float32 reference implementation; all cases pass within the expected tolerance.
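The comparison pattern is the standard one: compute a float32 reference, then bound the maximum absolute difference by a per-dtype tolerance. A pure-Python sketch — the tolerance values here are illustrative, not the ones in the test file:

```python
def max_abs_diff(a, b):
    """Largest elementwise absolute difference between two flat sequences."""
    return max(abs(x - y) for x, y in zip(a, b))


# Hypothetical per-dtype tolerances, not the values used in the PR.
TOLERANCE = {"float16": 1e-3, "bfloat16": 1e-2, "float32": 1e-5}


def assert_close(out, ref, dtype):
    diff = max_abs_diff(out, ref)
    assert diff <= TOLERANCE[dtype], f"max diff {diff} exceeds {TOLERANCE[dtype]}"
```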

Benchmarking and Profiling

  • benchmark/bench_mm_rotary_embedding.py:
    • Benchmarks our kernel vs. vLLM implementation and FlashAttention’s rotary path.
    • On representative configs (e.g., batch_size ∈ {1, 32}, heads ∈ {32/8, 64/8, 8/8, 32/1}, head_size ∈ {64, 80, 128, 256}), the new kernel achieves roughly 1.05–1.25× speedup over the baseline and up to 2×+ over FlashAttention’s rotary.
  • Potential follow‑up optimizations can be introduced in future PRs without changing the Python API.
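The benchmark loop follows the usual warmup-then-measure pattern; a CPU-timer sketch (the actual script presumably synchronizes the device and may use CUDA events for kernel timing):

```python
import time


def bench(fn, *args, warmup: int = 10, iters: int = 100) -> float:
    """Return mean wall-clock seconds per call of fn(*args)."""
    for _ in range(warmup):       # warm caches / JIT before timing
        fn(*args)
    t0 = time.perf_counter()
    for _ in range(iters):
        fn(*args)
    return (time.perf_counter() - t0) / iters
```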

Checklist

Ruby Chen (Engrg-Hardware 1) and others added 2 commits December 2, 2025 07:02
@gemini-code-assist
Contributor

Summary of Changes

Hello @RubiaCx, I'm Gemini Code Assist [1]! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the SGLang framework by integrating a new, highly performant CUDA kernel for rotary embeddings. This optimization is critical for diffusion models and other applications heavily relying on RoPE, offering improved speed and broader compatibility with diverse model architectures and head sizes, including previously challenging non-standard configurations.

Highlights

  • New Rotary Embedding CUDA Kernel: Introduced a highly optimized CUDA kernel for rotary embedding within sgl-kernel to improve performance in diffusion models.
  • Broadened Support: The new kernel supports various configurations including separate cos/sin caches, NeoX and non-NeoX layouts, 2D/3D Query/Key tensors, and Grouped-Query Attention (GQA)/Multi-Query Attention (MQA), addressing limitations with non-standard head sizes.
  • Significant Performance Gains: Achieves 1.05–1.25x speedup over vLLM's implementation and up to 2x+ over FlashAttention's rotary path on representative configurations.
  • Comprehensive Testing & Benchmarking: Includes new accuracy tests comparing against a float32 reference implementation and detailed benchmarks to validate performance and correctness.

Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature              | Command             | Description
Code Review          | /gemini review      | Performs a code review for the current pull request in its current state.
Pull Request Summary | /gemini summary     | Provides a summary of the current pull request in its current state.
Comment              | @gemini-code-assist | Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help                 | /gemini help        | Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

The pull request introduces a new rotary embedding CUDA kernel, which is a significant improvement for diffusion models. The implementation supports various layouts and head sizes, and benchmarks show promising speedups over existing solutions. The integration into the sgl-kernel and Python API is well-handled, and comprehensive tests ensure correctness across a wide range of configurations. The overall structure and approach are solid, demonstrating a good understanding of performance optimization for deep learning workloads.

Comment on lines +77 to +79
except torch.AcceleratorError:
    torch.cuda.empty_cache()
    pass
Contributor


medium

The try-except torch.AcceleratorError block around torch.cuda.synchronize() is unusual. torch.cuda.synchronize() typically doesn't raise torch.AcceleratorError. If the intent is to clear the cache after a potential CUDA error from a previous operation, it might be clearer to structure this differently, or to ensure that torch.cuda.synchronize() is only called when the CUDA context is valid. However, given the context of benchmarking, this might be a defensive measure.

Comment on lines +35 to +49
def compute_cos_sin_cache(
    max_seq_len: int,
    rotary_dim: int,
    base: float = 10000.0,
    dtype: torch.dtype = torch.float32,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Compute separate cos and sin caches.
    """
    inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float32) / rotary_dim))
    t = torch.arange(max_seq_len, dtype=torch.float32)
    freqs = torch.einsum("i,j->ij", t, inv_freq)
    cos = freqs.cos().to(dtype)
    sin = freqs.sin().to(dtype)
    return cos, sin
Contributor


medium

The compute_cos_sin_cache function is duplicated in both the benchmark and test files. To improve maintainability and avoid code duplication, consider moving this utility function to a shared helper module that can be imported by both the benchmark and test scripts.

x_index = 2 * rot_offset;
y_index = x_index + 1;

const float cos_val = static_cast<float>(SGLANG_LDG(cos_ptr + rot_offset));
Collaborator


Can we use float4 to vectorize global memory read?

@BBuf
Collaborator

BBuf commented Dec 11, 2025

nits: If num_tokens is small, the current threading model may not fully utilize all SMs, because block size is num_tokens. Therefore, when num_tokens is relatively small, parallelization can be applied across both tokens and heads.

@RubiaCx
Collaborator Author

RubiaCx commented Dec 12, 2025

@BBuf I addressed the low-num_tokens under-utilization by updating the diffusion RoPE kernel.

  • Launch policy: when num_tokens <= 4 and blocks_per_token > 1, switch to a 2D launch (token × pair-tile) to increase parallelism for tiny-token regimes. Otherwise keep a 1D launch (one block per token) for general workloads.
  • Vectorization: added a float2 fast path on the hot loop to improve memory throughput.
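The launch-policy decision described above can be sketched in Python — the names and ceil-division helper are illustrative, not the kernel's actual host code:

```python
def choose_grid(num_tokens, pairs_per_token, pairs_per_block, small_token_cutoff=4):
    """Pick a 1D or 2D grid shape for the RoPE kernel launch.

    blocks_per_token is how many blocks would be needed to cover all
    rotation pairs of one token at the chosen block size; when there
    are very few tokens but more than one block's worth of pairs,
    spread the pair tiles across a second grid dimension.
    """
    blocks_per_token = -(-pairs_per_token // pairs_per_block)  # ceil division
    if num_tokens <= small_token_cutoff and blocks_per_token > 1:
        return (num_tokens, blocks_per_token)  # 2D: token x pair-tile
    return (num_tokens,)                       # 1D: one block per token
```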

Results (Qwen-image): average kernel time improved from 29.551 µs to 24.557 µs (1.203× speedup).

Additionally, compared to the positions API (FlashInfer-based path), the diffusion RoPE kernel path is ~1.24× faster on average (weighted across configs).

@RubiaCx
Collaborator Author

RubiaCx commented Jan 3, 2026

@BBuf Hi, #16287 already fixed the AMD multimodal CI failures introduced by the refactor in #15812 / #15813 by restoring a proper normalization fallback on ROCm. I removed my local fallback implementation to avoid duplicating logic and diverging from main.

Separately, I added a FlashInfer RoPE benchmark comparison. FlashInfer currently supports only a limited set of shapes (e.g., it errors on head_dim=96, so those entries are NaN), and in the supported cases it is not faster than our implementations.

@RubiaCx
Collaborator Author

RubiaCx commented Jan 3, 2026

E2E qwen-image

We added a small cache for the JIT RoPE path to avoid per-block/step overhead (repeated cast/contiguous on the RoPE cache). In microbenchmarks, FlashInfer RoPE is still slower than our JIT RoPE for most tested shapes, but end-to-end latency is essentially unchanged (within normal variance), so we reverted Qwen-image’s default RoPE changes to minimize behavioral differences.

Metric      | Baseline    | New         | Diff               | Status
E2E Latency | 14600.17 ms | 14716.08 ms | +115.92 ms (+0.8%) | ⚪️
Throughput  | 0.07 req/s  | 0.07 req/s  | -                  | -

Stage Name               | Baseline (ms) | New (ms) | Diff (ms) | Diff (%) | Status
InputValidationStage     | 0.03          | 0.04     | +0.00     | +5.3%    | ⚪️
TextEncodingStage        | 907.69        | 892.53   | -15.15    | -1.7%    | ⚪️
ConditioningStage        | 0.01          | 0.01     | +0.00     | +5.7%    | ⚪️
TimestepPreparationStage | 0.49          | 0.49     | -0.00     | -0.4%    | ⚪️
LatentPreparationStage   | 0.17          | 0.18     | +0.00     | +1.7%    | ⚪️
DenoisingStage           | 13393.81      | 13516.01 | +122.20   | +0.9%    | ⚪️
DecodingStage            | 291.99        | 300.82   | +8.83     | +3.0%    | ⚪️

Benchmarking

The attached table is the benchmarking report. Each row is a (batch_size, seq_len, head_size, interleaved) config; values are speedup vs our JIT baseline (1.0 = same, <1.0 = slower, >1.0 = faster).

FlashInfer is <1.0 for most shapes; NaN indicates unsupported configs (e.g., head_dim/cache requirements).
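The per-config values can be derived as baseline time divided by the other implementation's time, with NaN propagated for unsupported shapes — a hypothetical sketch of the aggregation, not code from the benchmark script:

```python
import math


def speedup_table(baseline_us, other_us):
    """Per-config speedup of `other` vs the JIT baseline.

    Values > 1.0 mean the other implementation is faster; NaN marks
    shapes the other implementation does not support.
    """
    out = {}
    for cfg, base in baseline_us.items():
        t = other_us.get(cfg, math.nan)
        out[cfg] = base / t if not math.isnan(t) else math.nan
    return out
```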

[benchmark table image]

@yuan-luo
Collaborator

yuan-luo commented Jan 5, 2026

Please fix the lint.
BTW, I believe this JIT kernel is also helpful for non-diffusion inference.

@RubiaCx
Collaborator Author

RubiaCx commented Jan 5, 2026

Hi @DarkSharpness. Thanks for the review. Besides addressing the API/dispatch issues, I also updated the benchmark and tests for better correctness and fairness.

@mickqian
Collaborator

mickqian commented Jan 9, 2026

/tag-and-rerun-ci

  return __float2half_rn(v);
}
template <>
__device__ __forceinline__ nv_bfloat16 from_float<nv_bfloat16>(float v) {
Collaborator

@DarkSharpness DarkSharpness Jan 15, 2026


In #16884 we introduce type.cuh, which include some type conversion primitives in namespace device. You may try to replace these type casting with cast<To>(value).

union VecU {
  float4 v;
  scalar_t e[kElePerVec];
} tmp;
Collaborator


Try to replace this with device::AlignedVector. This can also generate some effective aligned load/store instructions (just like float4).


// RoPE kernel with 1D grid: one block per token.
template <typename scalar_t, bool interleaved, bool vectorized, bool aligned_qk, int ROT_EMBED_DIM>
__launch_bounds__(512) __global__ void rotary_embedding_kernel_1d(
Collaborator


I'm not sure about the performance here. In my qknorm implementation, we use a persistent 2D kernel, where each warp handles one q/k head dimension at a time. The code logic is as follows:

  1. Each warp determines which q head or k head it will process at this iteration.
  2. Set up the pointer information:

     const int64_t token_id = idx / num_q_and_k_heads;
     const int64_t head_id = idx % num_q_and_k_heads;
     const auto load_q = head_id < num_qo_heads;
     const auto input = load_q ? pointer::offset(q, 2 * (token_id * q_stride + head_id * kHeadDim))
                               : pointer::offset(k, 2 * (token_id * k_stride + head_id * kHeadDim));
     const auto weight = load_q ? q_weight : k_weight;

  3. Load and perform the algorithm on the input:

     const auto input_vec = gmem.load(input);
     const auto weight_vec = gmem.load(weight);
     const auto output_vec = norm::apply_norm_warp<kHeadDim>(input_vec, weight_vec, eps);
     gmem.store(input, output_vec);

This can hopefully reduce code size and improve readability. You may refer to this as an example. I would suggest you take a similar approach (persistent 2D kernel + unified code logic) if it won't bring too much regression to performance.
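In pseudo-Python, the flat-index decomposition in step 2 of the sketch above maps a warp's work item to (token, head, tensor) — illustrative names, not code from the PR:

```python
def decode_work_item(idx, num_qo_heads, num_kv_heads):
    """Map a flat warp index to (token_id, head-in-tensor, which tensor).

    All q heads of a token come first, then its k heads, mirroring the
    reviewer's head_id < num_qo_heads test.
    """
    num_q_and_k = num_qo_heads + num_kv_heads
    token_id, head_id = divmod(idx, num_q_and_k)
    load_q = head_id < num_qo_heads
    head_in_tensor = head_id if load_q else head_id - num_qo_heads
    return token_id, head_in_tensor, "q" if load_q else "k"
```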

if ((reinterpret_cast<uintptr_t>(sin.data_ptr()) % kVecBytes) != 0) can_vec_compute = false;
if (((r * elem_bytes) % kVecBytes) != 0) can_vec_compute = false;

bool qk_aligned16 = true;
Collaborator


Do we really have to consider the unaligned cases? In most cases, I would expect the head bytes to be at least 64 bytes aligned. Do not optimize too early for rare cases.

const int kElePerVec = kVecBytes / elem_bytes;
const int pairs_per_step = interleaved ? (kElePerVec / 2) : kElePerVec;

bool can_vec_compute = true;
Collaborator


Do we really have to consider the unaligned cases? In most cases, I would expect the head bytes to be at least 64 bytes aligned. Do not optimize too early for rare cases.
