diff --git a/python/sglang/jit_kernel/benchmark/bench_fused_qknorm_rope.py b/python/sglang/jit_kernel/benchmark/bench_fused_qknorm_rope.py new file mode 100644 index 000000000000..e6ef6bc77a36 --- /dev/null +++ b/python/sglang/jit_kernel/benchmark/bench_fused_qknorm_rope.py @@ -0,0 +1,183 @@ +""" +Benchmark: fused_qknorm_rope JIT vs AOT (sgl_kernel) + +Measures throughput (us) for fused_qk_norm_rope across typical +LLM configurations (head_dim x num_heads x num_tokens). + +Run: + python python/sglang/jit_kernel/benchmark/bench_fused_qknorm_rope.py +""" + +import itertools + +import torch +import triton +import triton.testing + +from sglang.jit_kernel.benchmark.utils import get_benchmark_range, run_benchmark +from sglang.jit_kernel.fused_qknorm_rope import ( + fused_qk_norm_rope as fused_qk_norm_rope_jit, +) + +try: + from sgl_kernel import fused_qk_norm_rope as fused_qk_norm_rope_aot + + AOT_AVAILABLE = True +except ImportError: + fused_qk_norm_rope_aot = None + AOT_AVAILABLE = False + +# --------------------------------------------------------------------------- +# Benchmark configuration +# --------------------------------------------------------------------------- + +NUM_TOKENS_RANGE = get_benchmark_range( + full_range=[1, 64, 256, 1024, 4096], + ci_range=[64, 512], +) + +# (head_dim, num_heads_q, num_heads_k, num_heads_v) - typical MoE/dense configs +MODEL_CONFIGS = get_benchmark_range( + full_range=[ + (64, 32, 8, 8), # small + (128, 32, 8, 8), # typical (e.g. Qwen3-8B) + (256, 16, 4, 4), # large head_dim + ], + ci_range=[(128, 32, 8, 8)], +) + +LINE_VALS = ["jit", "aot"] if AOT_AVAILABLE else ["jit"] +LINE_NAMES = ["JIT (new)", "AOT sgl_kernel"] if AOT_AVAILABLE else ["JIT (new)"] +STYLES = [("blue", "--"), ("orange", "-")] if AOT_AVAILABLE else [("blue", "--")] + + +# --------------------------------------------------------------------------- +# Benchmark: fused_qk_norm_rope (interleave style, no YaRN) +# --------------------------------------------------------------------------- + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["num_tokens", "head_dim", "num_heads_q", "num_heads_k", "num_heads_v"], + x_vals=[ + (nt, hd, nq, nk, nv) + for nt, (hd, nq, nk, nv) in itertools.product( + NUM_TOKENS_RANGE, MODEL_CONFIGS + ) + ], + line_arg="provider", + line_vals=LINE_VALS, + line_names=LINE_NAMES, + styles=STYLES, + ylabel="us", + plot_name="fused-qknorm-rope-performance", + args={}, + ) +) +def bench_fused_qknorm_rope( + num_tokens: int, + head_dim: int, + num_heads_q: int, + num_heads_k: int, + num_heads_v: int, + provider: str, +): + device = "cuda" + total_heads = num_heads_q + num_heads_k + num_heads_v + + qkv = torch.randn( + (num_tokens, total_heads * head_dim), dtype=torch.bfloat16, device=device + ) + q_weight = torch.ones(head_dim, dtype=torch.bfloat16, device=device) + k_weight = torch.ones(head_dim, dtype=torch.bfloat16, device=device) + position_ids = torch.arange(num_tokens, dtype=torch.int32, device=device) + + common_kwargs = dict( + num_heads_q=num_heads_q, + num_heads_k=num_heads_k, + num_heads_v=num_heads_v, + head_dim=head_dim, + eps=1e-5, + q_weight=q_weight, + k_weight=k_weight, + base=10000.0, + is_neox=False, + position_ids=position_ids, + factor=1.0, + low=1.0, + high=32.0, + attention_factor=1.0, + rotary_dim=head_dim, + ) + + if provider == "jit": + fn = lambda: fused_qk_norm_rope_jit(qkv.clone(), **common_kwargs) + elif provider == "aot": + fn = lambda: fused_qk_norm_rope_aot(qkv.clone(), **common_kwargs) + else: + raise ValueError(f"Unknown provider: {provider}") + + return run_benchmark(fn) + + +# --------------------------------------------------------------------------- +# Quick correctness diff +# --------------------------------------------------------------------------- + + +def calculate_diff(): + if not AOT_AVAILABLE: + print("sgl_kernel not available - skipping AOT diff check") + return + + device = "cuda" + print("Correctness diff (JIT vs AOT):") + + for head_dim, is_neox in [(64, False), (128, False), (128, True), (256, False)]: + num_tokens = 32 + num_heads_q, num_heads_k, num_heads_v = 4, 2, 2 + total_heads = num_heads_q + num_heads_k + num_heads_v + + qkv = torch.randn( + (num_tokens, total_heads * head_dim), dtype=torch.bfloat16, device=device + ) + q_weight = torch.ones(head_dim, dtype=torch.bfloat16, device=device) + k_weight = torch.ones(head_dim, dtype=torch.bfloat16, device=device) + position_ids = torch.arange(num_tokens, dtype=torch.int32, device=device) + + common = dict( + num_heads_q=num_heads_q, + num_heads_k=num_heads_k, + num_heads_v=num_heads_v, + head_dim=head_dim, + eps=1e-5, + q_weight=q_weight, + k_weight=k_weight, + base=10000.0, + is_neox=is_neox, + position_ids=position_ids, + factor=1.0, + low=1.0, + high=32.0, + attention_factor=1.0, + rotary_dim=head_dim, + ) + + qkv_jit = qkv.clone() + fused_qk_norm_rope_jit(qkv_jit, **common) + qkv_aot = qkv.clone() + fused_qk_norm_rope_aot(qkv_aot, **common) + + match = torch.allclose(qkv_jit.float(), qkv_aot.float(), atol=1e-2, rtol=1e-2) + status = "OK" if match else "MISMATCH" + max_err = (qkv_jit.float() - qkv_aot.float()).abs().max().item() + print( + f" head_dim={head_dim:3d} is_neox={str(is_neox):5s} " + f"max_err={max_err:.2e} [{status}]" + ) + + +if __name__ == "__main__": + calculate_diff() + print() + bench_fused_qknorm_rope.run(print_data=True) diff --git a/python/sglang/jit_kernel/csrc/elementwise/fused_qknorm_rope.cuh b/python/sglang/jit_kernel/csrc/elementwise/fused_qknorm_rope.cuh new file mode 100644 index 000000000000..40401572b3b8 --- /dev/null +++ b/python/sglang/jit_kernel/csrc/elementwise/fused_qknorm_rope.cuh @@ -0,0 +1,307 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +// Adapted from +// https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/fusedQKNormRopeKernel.cu + +#include +#include + +#include +#include +#include +#include + +#include + +#include +#include +#include + +namespace { + +// --------------------------------------------------------------------------- +// YaRN-aware frequency computation +// +// When factor == 1.0, reduces to standard RoPE: base^(-2*half_dim/rotary_dim) +// When factor != 1.0, blends interpolated and extrapolated frequencies. +// --------------------------------------------------------------------------- + +__device__ inline float +compute_freq_yarn(float base, int rotary_dim, int half_dim, float factor, float low, float high) { + float freq = powf(base, -2.0f * half_dim / static_cast(rotary_dim)); + + if (factor != 1.0f) { + float inv_freq_extrapolation = freq; + float inv_freq_interpolation = freq / factor; + + float high_adj = high; + if (fabsf(low - high_adj) <= 1e-6f) { + high_adj += 0.001f; + } + + float linear_func = (static_cast(half_dim) - low) / (high_adj - low); + float ramp_func = fminf(fmaxf(linear_func, 0.0f), 1.0f); + float inv_freq_extrapolation_factor = 1.0f - ramp_func; + + freq = inv_freq_interpolation * (1.0f - inv_freq_extrapolation_factor) + + inv_freq_extrapolation * inv_freq_extrapolation_factor; + } + + return freq; +} + +// --------------------------------------------------------------------------- +// Fused QK-Norm + RoPE kernel +// +// Each warp processes one (token, head) pair. +// head_dim: compile-time head dimension (64, 128, or 256) +// interleave: true -> interleave / GPT-J style RoPE (!is_neox) +// false -> NeoX style RoPE (is_neox) +// --------------------------------------------------------------------------- + +template +__global__ void fusedQKNormRopeKernel( + __nv_bfloat16* qkv, // [num_tokens, (nq+nk+nv)*head_dim], in-place + int const num_heads_q, + int const num_heads_k, + int const num_heads_v, + float const eps, + __nv_bfloat16 const* q_weight, // [head_dim] + __nv_bfloat16 const* k_weight, // [head_dim] + float const base, + int const* position_ids, // [num_tokens] + int const num_tokens, + float factor, + float low, + float high, + float attention_factor, + int const rotary_dim) { + int const warpsPerBlock = blockDim.x / 32; + int const warpId = threadIdx.x / 32; + int const laneId = threadIdx.x % 32; + + int const globalWarpIdx = blockIdx.x * warpsPerBlock + warpId; + int const total_qk_heads = num_heads_q + num_heads_k; + + int const tokenIdx = globalWarpIdx / total_qk_heads; + int const localHeadIdx = globalWarpIdx % total_qk_heads; + + if (tokenIdx >= num_tokens) return; + + bool const isQ = localHeadIdx < num_heads_q; + int const headIdx = isQ ? localHeadIdx : localHeadIdx - num_heads_q; + int const num_heads = num_heads_q + num_heads_k + num_heads_v; + + static_assert(head_dim % (32 * 2) == 0, "head_dim must be divisible by 64 (each warp handles one head)"); + constexpr int numElemsPerThread = head_dim / 32; + float elements[numElemsPerThread]; + using vec_T = device::AlignedVector; + + // Compute flat offset of this warp's head in qkv + int offsetWarp; + if (isQ) { + offsetWarp = tokenIdx * num_heads * head_dim + headIdx * head_dim; + } else { + offsetWarp = tokenIdx * num_heads * head_dim + num_heads_q * head_dim + headIdx * head_dim; + } + int offsetThread = offsetWarp + laneId * numElemsPerThread; + + // ------------------------------------------------------------------- + // Load and compute sum-of-squares for RMSNorm + // ------------------------------------------------------------------- + float sumOfSquares = 0.0f; + { + vec_T vec; + vec.load(qkv + offsetThread); + for (int i = 0; i < numElemsPerThread; i++) { + float val = device::cast(vec[i]); + sumOfSquares += val * val; + elements[i] = val; + } + } + + sumOfSquares = device::warp::reduce_sum(sumOfSquares); + + // ------------------------------------------------------------------- + // Apply RMSNorm + // ------------------------------------------------------------------- + float rms_rcp = rsqrtf(sumOfSquares / static_cast(head_dim) + eps); + for (int i = 0; i < numElemsPerThread; i++) { + int dim = laneId * numElemsPerThread + i; + float weight = isQ ? device::cast(q_weight[dim]) : device::cast(k_weight[dim]); + elements[i] *= rms_rcp * weight; + } + + // ------------------------------------------------------------------- + // Apply RoPE to the first rotary_dim elements + // ------------------------------------------------------------------- + float elements2[numElemsPerThread]; + float cos_vals[numElemsPerThread]; + float sin_vals[numElemsPerThread]; + float pos_id = static_cast(position_ids[tokenIdx]); + int const rotary_lanes = rotary_dim / numElemsPerThread; + bool const applyRotary = (laneId < rotary_lanes); + + if (applyRotary) { + if constexpr (interleave) { + // Interleave (GPT-J) style: pairs of consecutive elements share a frequency + for (int i = 0; i < numElemsPerThread; i++) { + elements2[i] = (i % 2 == 0) ? -elements[i + 1] : elements[i - 1]; + + int dim_idx = laneId * numElemsPerThread + i; + int half_dim = dim_idx / 2; + float freq = compute_freq_yarn(base, rotary_dim, half_dim, factor, low, high); + float theta = pos_id * freq; + __sincosf(theta, &sin_vals[i], &cos_vals[i]); + } + } else { + // NeoX style: first and second halves of the rotary region are paired + __syncwarp(); + int const half_rotary_lanes = rotary_lanes / 2; + // Avoid UB from (1u << 32) when rotary_lanes == 32 + unsigned int active_mask = 0xffffffffu >> (32 - rotary_lanes); + for (int i = 0; i < numElemsPerThread; i++) { + elements2[i] = __shfl_xor_sync(active_mask, elements[i], half_rotary_lanes); + if (laneId < half_rotary_lanes) { + elements2[i] = -elements2[i]; + } + + int dim_idx = laneId * numElemsPerThread + i; + // Remap so that both halves use the same set of frequencies + dim_idx = (dim_idx * 2) % rotary_dim; + int half_dim = dim_idx / 2; + float freq = compute_freq_yarn(base, rotary_dim, half_dim, factor, low, high); + float theta = pos_id * freq; + __sincosf(theta, &sin_vals[i], &cos_vals[i]); + } + __syncwarp(); + } + + for (int i = 0; i < numElemsPerThread; i++) { + elements[i] = (elements[i] * cos_vals[i] + elements2[i] * sin_vals[i]) * attention_factor; + } + } + + // ------------------------------------------------------------------- + // Store (all elements: rotated + pass-through normalized) + // ------------------------------------------------------------------- + { + vec_T vec; + for (int i = 0; i < numElemsPerThread; i++) { + vec[i] = device::cast(elements[i]); + } + vec.store(qkv + offsetThread); + } +} + +// --------------------------------------------------------------------------- +// Host-side tvm-ffi entry point +// +// HEAD_DIM and INTERLEAVE are compile-time template parameters, passed as +// template arguments from Python via the cuda_wrappers specialisation in +// fused_qknorm_rope.py (e.g. fused_qk_norm_rope<128, false>). This avoids +// both runtime dispatch and macro-based specialisation. +// --------------------------------------------------------------------------- + +template +void fused_qk_norm_rope( + tvm::ffi::TensorView qkv, // [num_tokens, (nq+nk+nv)*head_dim] bf16 + tvm::ffi::TensorView q_weight, // [head_dim] bf16 + tvm::ffi::TensorView k_weight, // [head_dim] bf16 + tvm::ffi::TensorView position_ids, // [num_tokens] int32 + int num_heads_q, + int num_heads_k, + int num_heads_v, + float eps, + float base, + float factor, + float low, + float high, + float attention_factor, + int rotary_dim) { + using namespace host; + + static_assert(HEAD_DIM == 64 || HEAD_DIM == 128 || HEAD_DIM == 256, "HEAD_DIM must be 64, 128, or 256"); + + RuntimeCheck(qkv.device().device_type == kDLCUDA, "qkv must be a CUDA tensor"); + RuntimeCheck(qkv.is_contiguous(), "qkv must be contiguous"); + RuntimeCheck(qkv.dtype().code == kDLBfloat && qkv.dtype().bits == 16, "qkv must be bfloat16"); + RuntimeCheck(qkv.ndim() == 2, "qkv must be 2D: [num_tokens, (nq+nk+nv)*head_dim]"); + + RuntimeCheck(q_weight.is_contiguous(), "q_weight must be contiguous"); + RuntimeCheck(q_weight.dtype().code == kDLBfloat && q_weight.dtype().bits == 16, "q_weight must be bfloat16"); + RuntimeCheck( + q_weight.ndim() == 1 && static_cast(q_weight.size(0)) == HEAD_DIM, "q_weight must be 1D of size head_dim"); + + RuntimeCheck(k_weight.is_contiguous(), "k_weight must be contiguous"); + RuntimeCheck(k_weight.dtype().code == kDLBfloat && k_weight.dtype().bits == 16, "k_weight must be bfloat16"); + RuntimeCheck( + k_weight.ndim() == 1 && static_cast(k_weight.size(0)) == HEAD_DIM, "k_weight must be 1D of size head_dim"); + + RuntimeCheck(position_ids.device().device_type == kDLCUDA, "position_ids must be a CUDA tensor"); + RuntimeCheck(position_ids.is_contiguous(), "position_ids must be contiguous"); + RuntimeCheck(position_ids.dtype().code == kDLInt && position_ids.dtype().bits == 32, "position_ids must be int32"); + RuntimeCheck(position_ids.ndim() == 1, "position_ids must be 1D: [num_tokens]"); + + int num_tokens = static_cast(qkv.size(0)); + int total_heads = num_heads_q + num_heads_k + num_heads_v; + RuntimeCheck( + static_cast(qkv.size(1)) == total_heads * HEAD_DIM, "qkv.size(1) must equal (nq + nk + nv) * head_dim"); + RuntimeCheck(static_cast(position_ids.size(0)) == num_tokens, "position_ids must have num_tokens elements"); + + constexpr int numElemsPerThread = HEAD_DIM / 32; + RuntimeCheck(rotary_dim % numElemsPerThread == 0, "rotary_dim must be divisible by (head_dim / 32)"); + + if constexpr (!INTERLEAVE) { + // NeoX uses __shfl_xor_sync which requires half_rotary_lanes to be a power of 2 + int rotary_lanes = rotary_dim / numElemsPerThread; + int half_rotary_lanes = rotary_lanes / 2; + bool is_pow2 = (half_rotary_lanes >= 1) && ((half_rotary_lanes & (half_rotary_lanes - 1)) == 0); + RuntimeCheck(is_pow2, "half_rotary_lanes must be a power of 2 for NeoX style RoPE"); + } + + cudaStream_t stream = LaunchKernel::resolve_device(qkv.device()); + + constexpr int blockSize = 256; + int warpsPerBlock = blockSize / 32; + int totalQKHeads = num_heads_q + num_heads_k; + int totalWarps = num_tokens * totalQKHeads; + int gridSize = host::div_ceil(totalWarps, warpsPerBlock); + + auto* qkv_ptr = reinterpret_cast<__nv_bfloat16*>(qkv.data_ptr()); + auto const* qw_ptr = reinterpret_cast<__nv_bfloat16 const*>(q_weight.data_ptr()); + auto const* kw_ptr = reinterpret_cast<__nv_bfloat16 const*>(k_weight.data_ptr()); + auto const* pos_ptr = reinterpret_cast(position_ids.data_ptr()); + + fusedQKNormRopeKernel<<>>( + qkv_ptr, + num_heads_q, + num_heads_k, + num_heads_v, + eps, + qw_ptr, + kw_ptr, + base, + pos_ptr, + num_tokens, + factor, + low, + high, + attention_factor, + rotary_dim); +} + +} // namespace diff --git a/python/sglang/jit_kernel/fused_qknorm_rope.py b/python/sglang/jit_kernel/fused_qknorm_rope.py new file mode 100644 index 000000000000..92ea1f4350ad --- /dev/null +++ b/python/sglang/jit_kernel/fused_qknorm_rope.py @@ -0,0 +1,181 @@ +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Optional + +import torch + +from sglang.jit_kernel.utils import cache_once, load_jit +from sglang.srt.utils.custom_op import register_custom_op + +if TYPE_CHECKING: + from tvm_ffi.module import Module + + +@cache_once +def _jit_fused_qknorm_rope_module(head_dim: int, is_neox: bool) -> Module: + interleave = "false" if is_neox else "true" + return load_jit( + "fused_qknorm_rope", + head_dim, + int(is_neox), + cuda_files=["elementwise/fused_qknorm_rope.cuh"], + cuda_wrappers=[ + ("fused_qk_norm_rope", f"fused_qk_norm_rope<{head_dim}, {interleave}>") + ], + extra_cuda_cflags=["--use_fast_math"], + ) + + +@register_custom_op( + op_name="fused_qk_norm_rope_out", + mutates_args=["qkv"], +) +def fused_qk_norm_rope_out( + qkv: torch.Tensor, + q_weight: torch.Tensor, + k_weight: torch.Tensor, + position_ids: torch.Tensor, + num_heads_q: int, + num_heads_k: int, + num_heads_v: int, + head_dim: int, + eps: float, + base: float, + is_neox: bool, + factor: float, + low: float, + high: float, + attention_factor: float, + rotary_dim: int, +) -> None: + """ + Fused QK RMSNorm + RoPE applied in-place on the QKV tensor. + + Matches the call signature of ``sgl_kernel.fused_qk_norm_rope``. + + Args: + qkv: [num_tokens, (nq+nk+nv)*head_dim] bfloat16 -modified in-place + q_weight: [head_dim] bfloat16 -RMSNorm weights for Q + k_weight: [head_dim] bfloat16 -RMSNorm weights for K + position_ids: [num_tokens] int32 + num_heads_q: number of query heads + num_heads_k: number of key heads + num_heads_v: number of value heads + head_dim: head dimension; must be 64, 128, or 256 + eps: epsilon for RMSNorm + base: RoPE base frequency + is_neox: True ->NeoX style, False ->interleave (GPT-J) style + factor: YaRN scaling factor (1.0 = standard RoPE) + low: YaRN low-frequency threshold + high: YaRN high-frequency threshold + attention_factor: scale applied to the rotary component + rotary_dim: number of elements per head to apply RoPE to + """ + module = _jit_fused_qknorm_rope_module(head_dim, is_neox) + module.fused_qk_norm_rope( + qkv, + q_weight, + k_weight, + position_ids, + num_heads_q, + num_heads_k, + num_heads_v, + eps, + base, + factor, + low, + high, + attention_factor, + rotary_dim, + ) + + +@cache_once +def can_use_fused_qk_norm_rope( + head_dim: int, is_neox: bool, dtype: torch.dtype +) -> bool: + """Return True if the JIT fused QK-Norm + RoPE kernel can be used. + + Args: + head_dim: head dimension; supported values are 64, 128, 256 + dtype: tensor dtype; only bfloat16 is supported + """ + logger = logging.getLogger(__name__) + if head_dim not in (64, 128, 256): + logger.warning( + f"Unsupported head_dim={head_dim} for JIT fused_qk_norm_rope kernel" + ) + return False + if dtype != torch.bfloat16: + logger.warning(f"Unsupported dtype={dtype} for JIT fused_qk_norm_rope kernel") + return False + try: + _jit_fused_qknorm_rope_module(head_dim, is_neox) + return True + except Exception as e: + logger.warning(f"Failed to load JIT fused_qk_norm_rope kernel: {e}") + return False + + +def fused_qk_norm_rope( + qkv: torch.Tensor, + num_heads_q: int, + num_heads_k: int, + num_heads_v: int, + head_dim: int, + eps: float, + q_weight: torch.Tensor, + k_weight: torch.Tensor, + base: float, + is_neox: bool, + position_ids: torch.Tensor, + factor: float, + low: float, + high: float, + attention_factor: float, + rotary_dim: Optional[int] = None, +) -> None: + """ + Fused QK RMSNorm + RoPE applied in-place on the QKV tensor. + + Matches the call signature of ``sgl_kernel.fused_qk_norm_rope``. + + Args: + qkv: [num_tokens, (nq+nk+nv)*head_dim] bfloat16 -modified in-place + num_heads_q: number of query heads + num_heads_k: number of key heads + num_heads_v: number of value heads + head_dim: head dimension; must be 64, 128, or 256 + eps: epsilon for RMSNorm + q_weight: [head_dim] bfloat16 -RMSNorm weights for Q + k_weight: [head_dim] bfloat16 -RMSNorm weights for K + base: RoPE base frequency + is_neox: True ->NeoX style, False ->interleave (GPT-J) style + position_ids: [num_tokens] int32 + factor: YaRN scaling factor (1.0 = standard RoPE) + low: YaRN low-frequency threshold + high: YaRN high-frequency threshold + attention_factor: scale applied to the rotary component + rotary_dim: elements per head to rotate; defaults to head_dim + """ + if rotary_dim is None: + rotary_dim = head_dim + fused_qk_norm_rope_out( + qkv, + q_weight, + k_weight, + position_ids, + num_heads_q, + num_heads_k, + num_heads_v, + head_dim, + eps, + base, + is_neox, + factor, + low, + high, + attention_factor, + rotary_dim, + ) diff --git a/python/sglang/jit_kernel/tests/test_fused_qknorm_rope.py b/python/sglang/jit_kernel/tests/test_fused_qknorm_rope.py new file mode 100644 index 000000000000..10c6572900d3 --- /dev/null +++ b/python/sglang/jit_kernel/tests/test_fused_qknorm_rope.py @@ -0,0 +1,444 @@ +""" +Correctness tests for the fused_qknorm_rope JIT kernel. + +Validates fused_qk_norm_rope against a pure-PyTorch reference and (when +available) the sgl_kernel AOT implementation. +""" + +import pytest +import torch + +from sglang.jit_kernel.fused_qknorm_rope import fused_qk_norm_rope + +try: + from sgl_kernel import fused_qk_norm_rope as fused_qk_norm_rope_aot + + AOT_AVAILABLE = True +except ImportError: + AOT_AVAILABLE = False + +HEAD_DIMS = [64, 128, 256] +NUM_TOKENS = [1, 16, 128] + + +# --------------------------------------------------------------------------- +# Pure-PyTorch reference +# --------------------------------------------------------------------------- + + +def _compute_inv_freq_yarn(base, rotary_dim, factor, low, high, device): + """Compute YaRN-adjusted inverse frequencies for rotary_dim//2 positions.""" + half_dims = torch.arange(rotary_dim // 2, dtype=torch.float32, device=device) + inv_freq = base ** (-2.0 * half_dims / rotary_dim) + + if factor != 1.0: + inv_freq_interp = inv_freq / factor + inv_freq_extrap = inv_freq + high_adj = high if abs(high - low) > 1e-6 else high + 0.001 + linear = (half_dims - low) / (high_adj - low) + ramp = linear.clamp(0.0, 1.0) + extrap_factor = 1.0 - ramp + inv_freq = ( + inv_freq_interp * (1 - extrap_factor) + inv_freq_extrap * extrap_factor + ) + + return inv_freq + + +def fused_qk_norm_rope_ref( + qkv, + num_heads_q, + num_heads_k, + num_heads_v, + head_dim, + eps, + q_weight, + k_weight, + base, + is_neox, + position_ids, + factor, + low, + high, + attention_factor, + rotary_dim, +): + """ + Pure-PyTorch reference: RMSNorm per head, then RoPE on Q and K. + + Returns a new tensor (same shape as qkv) with the transformation applied. + """ + num_tokens = qkv.shape[0] + total_heads = num_heads_q + num_heads_k + num_heads_v + + qkv_f = qkv.float() + qw = q_weight.float() + kw = k_weight.float() + + # Reshape to [num_tokens, total_heads, head_dim] + qkv_3d = qkv_f.view(num_tokens, total_heads, head_dim) + q = qkv_3d[:, :num_heads_q].clone() # [num_tokens, nq, head_dim] + k = qkv_3d[:, num_heads_q : num_heads_q + num_heads_k].clone() + + # RMSNorm per head + def rms_norm_heads(x, w): + # x: [num_tokens, n_heads, head_dim], w: [head_dim] + rms = (x**2).mean(-1, keepdim=True) + return x * torch.rsqrt(rms + eps) * w + + q = rms_norm_heads(q, qw) + k = rms_norm_heads(k, kw) + + # Compute frequencies + inv_freq = _compute_inv_freq_yarn(base, rotary_dim, factor, low, high, qkv.device) + # theta: [num_tokens, rotary_dim//2] + theta = position_ids.float().unsqueeze(1) * inv_freq.unsqueeze(0) + cos = torch.cos(theta) # [num_tokens, rotary_dim//2] + sin = torch.sin(theta) + # Broadcast across heads: [num_tokens, 1, rotary_dim//2] + c = cos.unsqueeze(1) + s = sin.unsqueeze(1) + + if not is_neox: + # Interleave (GPT-J) style: rotate pairs (x[2i], x[2i+1]) + def apply_interleave(x): + # x: [num_tokens, n_heads, head_dim] + x_rot = x[:, :, :rotary_dim] # [num_tokens, n_heads, rotary_dim] + x_pairs = x_rot.view(num_tokens, -1, rotary_dim // 2, 2) + x0, x1 = x_pairs[..., 0], x_pairs[..., 1] + x0_new = x0 * c - x1 * s + x1_new = x1 * c + x0 * s + x_rot_new = torch.stack([x0_new, x1_new], dim=-1).view( + num_tokens, -1, rotary_dim + ) + result = x.clone() + result[:, :, :rotary_dim] = x_rot_new * attention_factor + return result + + q = apply_interleave(q) + k = apply_interleave(k) + else: + # NeoX style: first half * cos - second half * sin (and vice versa) + def apply_neox(x): + # x: [num_tokens, n_heads, head_dim] + x1 = x[:, :, : rotary_dim // 2] + x2 = x[:, :, rotary_dim // 2 : rotary_dim] + x1_new = x1 * c - x2 * s + x2_new = x2 * c + x1 * s + result = x.clone() + result[:, :, : rotary_dim // 2] = x1_new * attention_factor + result[:, :, rotary_dim // 2 : rotary_dim] = x2_new * attention_factor + return result + + q = apply_neox(q) + k = apply_neox(k) + + # Write back into a copy of the full QKV + result_3d = qkv_f.view(num_tokens, total_heads, head_dim).clone() + result_3d[:, :num_heads_q] = q + result_3d[:, num_heads_q : num_heads_q + num_heads_k] = k + return result_3d.view(num_tokens, -1).bfloat16() + + +# --------------------------------------------------------------------------- +# Tests: correctness vs PyTorch reference +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("head_dim", HEAD_DIMS) +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("is_neox", [False, True]) +def test_fused_qknorm_rope_vs_ref(head_dim, num_tokens, is_neox): + torch.manual_seed(head_dim * num_tokens + int(is_neox)) + device = "cuda" + num_heads_q, num_heads_k, num_heads_v = 4, 2, 2 + total_heads = num_heads_q + num_heads_k + num_heads_v + rotary_dim = head_dim # full rotary + + qkv = torch.randn( + (num_tokens, total_heads * head_dim), dtype=torch.bfloat16, device=device + ) + q_weight = torch.ones(head_dim, dtype=torch.bfloat16, device=device) + k_weight = torch.ones(head_dim, dtype=torch.bfloat16, device=device) + position_ids = torch.arange(num_tokens, dtype=torch.int32, device=device) + + eps = 1e-5 + base = 10000.0 + factor = 1.0 # no YaRN + low, high = 1.0, 32.0 + attention_factor = 1.0 + + ref = fused_qk_norm_rope_ref( + qkv, + num_heads_q, + num_heads_k, + num_heads_v, + head_dim, + eps, + q_weight, + k_weight, + base, + is_neox, + position_ids, + factor, + low, + high, + attention_factor, + rotary_dim, + ) + + qkv_jit = qkv.clone() + fused_qk_norm_rope( + qkv_jit, + num_heads_q, + num_heads_k, + num_heads_v, + head_dim, + eps, + q_weight, + k_weight, + base, + is_neox, + position_ids, + factor, + low, + high, + attention_factor, + rotary_dim, + ) + + assert torch.allclose(qkv_jit.float(), ref.float(), atol=5e-3, rtol=1e-2), ( + f"mismatch: head_dim={head_dim}, num_tokens={num_tokens}, " + f"is_neox={is_neox}, " + f"max_err={( qkv_jit.float() - ref.float()).abs().max().item():.4e}" + ) + + +@pytest.mark.parametrize("head_dim", HEAD_DIMS) +@pytest.mark.parametrize("is_neox", [False, True]) +def test_fused_qknorm_rope_partial_rotary(head_dim, is_neox): + """Test with rotary_dim < head_dim: non-rotary elements should be RMSNorm-only.""" + torch.manual_seed(42 + head_dim + int(is_neox)) + device = "cuda" + num_tokens = 16 + num_heads_q, num_heads_k, num_heads_v = 2, 2, 2 + total_heads = num_heads_q + num_heads_k + num_heads_v + rotary_dim = head_dim // 2 # half of head_dim + + # NeoX requires half_rotary_lanes to be power of 2. + # half_rotary_lanes = rotary_dim / (head_dim / 32) / 2 = (head_dim//2) / (head_dim/32) / 2 + # = 16 / 2 = 8 -> power of 2, OK for all supported head_dims. + + qkv = torch.randn( + (num_tokens, total_heads * head_dim), dtype=torch.bfloat16, device=device + ) + q_weight = torch.ones(head_dim, dtype=torch.bfloat16, device=device) + k_weight = torch.ones(head_dim, dtype=torch.bfloat16, device=device) + position_ids = torch.arange(num_tokens, dtype=torch.int32, device=device) + + ref = fused_qk_norm_rope_ref( + qkv, + num_heads_q, + num_heads_k, + num_heads_v, + head_dim, + 1e-5, + q_weight, + k_weight, + 10000.0, + is_neox, + position_ids, + 1.0, + 1.0, + 32.0, + 1.0, + rotary_dim, + ) + + qkv_jit = qkv.clone() + fused_qk_norm_rope( + qkv_jit, + num_heads_q, + num_heads_k, + num_heads_v, + head_dim, + 1e-5, + q_weight, + k_weight, + 10000.0, + is_neox, + position_ids, + 1.0, + 1.0, + 32.0, + 1.0, + rotary_dim, + ) + + assert torch.allclose(qkv_jit.float(), ref.float(), atol=5e-3, rtol=1e-2), ( + f"partial rotary mismatch: head_dim={head_dim}, is_neox={is_neox}, " + f"max_err={(qkv_jit.float() - ref.float()).abs().max().item():.4e}" + ) + + +@pytest.mark.parametrize("head_dim", HEAD_DIMS) +def test_fused_qknorm_rope_yarn_scaling(head_dim): + """Test with YaRN scaling (factor != 1.0).""" + torch.manual_seed(99 + head_dim) + device = "cuda" + num_tokens = 32 + num_heads_q, num_heads_k, num_heads_v = 2, 2, 2 + total_heads = num_heads_q + num_heads_k + num_heads_v + rotary_dim = head_dim + + qkv = torch.randn( + (num_tokens, total_heads * head_dim), dtype=torch.bfloat16, device=device + ) + q_weight = torch.ones(head_dim, dtype=torch.bfloat16, device=device) + k_weight = torch.ones(head_dim, dtype=torch.bfloat16, device=device) + position_ids = torch.arange(num_tokens, dtype=torch.int32, device=device) + + factor = 2.5 + low, high = 4.0, 32.0 + attention_factor = 0.9 + is_neox = False # test with interleave; NeoX also tested in other tests + + ref = fused_qk_norm_rope_ref( + qkv, + num_heads_q, + num_heads_k, + num_heads_v, + head_dim, + 1e-5, + q_weight, + k_weight, + 500000.0, + is_neox, + position_ids, + factor, + low, + high, + attention_factor, + rotary_dim, + ) + + qkv_jit = qkv.clone() + fused_qk_norm_rope( + qkv_jit, + num_heads_q, + num_heads_k, + num_heads_v, + head_dim, + 1e-5, + q_weight, + k_weight, + 500000.0, + is_neox, + position_ids, + factor, + low, + high, + attention_factor, + rotary_dim, + ) + + assert torch.allclose(qkv_jit.float(), ref.float(), atol=5e-3, rtol=1e-2), ( + f"YaRN mismatch: head_dim={head_dim}, " + f"max_err={(qkv_jit.float() - ref.float()).abs().max().item():.4e}" + ) + + +def test_fused_qknorm_rope_default_rotary_dim(): + """rotary_dim=None should default to head_dim.""" + device = "cuda" + num_tokens = 8 + num_heads_q, num_heads_k, num_heads_v = 2, 2, 2 + head_dim = 128 + total_heads = num_heads_q + num_heads_k + num_heads_v + + torch.manual_seed(0) + qkv1 = torch.randn( + (num_tokens, total_heads * head_dim), dtype=torch.bfloat16, device=device + ) + qkv2 = qkv1.clone() + q_weight = torch.ones(head_dim, dtype=torch.bfloat16, device=device) + k_weight = torch.ones(head_dim, dtype=torch.bfloat16, device=device) + position_ids = torch.zeros(num_tokens, dtype=torch.int32, device=device) + + common_kwargs = dict( + num_heads_q=num_heads_q, + num_heads_k=num_heads_k, + num_heads_v=num_heads_v, + head_dim=head_dim, + eps=1e-5, + q_weight=q_weight, + k_weight=k_weight, + base=10000.0, + is_neox=False, + position_ids=position_ids, + factor=1.0, + low=1.0, + high=32.0, + attention_factor=1.0, + ) + + fused_qk_norm_rope(qkv1, **common_kwargs, rotary_dim=None) + fused_qk_norm_rope(qkv2, **common_kwargs, rotary_dim=head_dim) + + assert torch.equal(qkv1, qkv2), "rotary_dim=None must equal rotary_dim=head_dim" + + +# --------------------------------------------------------------------------- +# Cross-validation against AOT sgl_kernel +# --------------------------------------------------------------------------- + + +@pytest.mark.skipif(not AOT_AVAILABLE, reason="sgl_kernel not available") +@pytest.mark.parametrize("head_dim", [64, 128, 256]) +@pytest.mark.parametrize("is_neox", [False, True]) +def test_fused_qknorm_rope_vs_aot(head_dim, is_neox): + torch.manual_seed(head_dim * 7 + int(is_neox)) + device = "cuda" + num_tokens = 32 + num_heads_q, num_heads_k, num_heads_v = 4, 2, 2 + total_heads = num_heads_q + num_heads_k + num_heads_v + + qkv = torch.randn( + (num_tokens, total_heads * head_dim), dtype=torch.bfloat16, device=device + ) + q_weight = torch.randn(head_dim, dtype=torch.bfloat16, device=device).abs() + 0.5 + k_weight = torch.randn(head_dim, dtype=torch.bfloat16, device=device).abs() + 0.5 + position_ids = torch.arange(num_tokens, dtype=torch.int32, device=device) + + common = dict( + num_heads_q=num_heads_q, + num_heads_k=num_heads_k, + num_heads_v=num_heads_v, + head_dim=head_dim, + eps=1e-5, + q_weight=q_weight, + k_weight=k_weight, + base=10000.0, + is_neox=is_neox, + position_ids=position_ids, + factor=1.0, + low=1.0, + high=32.0, + attention_factor=1.0, + rotary_dim=head_dim, + ) + + qkv_jit = qkv.clone() + fused_qk_norm_rope(qkv_jit, **common) + + qkv_aot = qkv.clone() + fused_qk_norm_rope_aot(qkv_aot, **common) + + assert torch.allclose(qkv_jit.float(), qkv_aot.float(), atol=1e-2, rtol=1e-2), ( + f"JIT vs AOT mismatch: head_dim={head_dim}, is_neox={is_neox}, " + f"max_err={(qkv_jit.float() - qkv_aot.float()).abs().max().item():.4e}" + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/python/sglang/srt/models/qwen3_moe.py b/python/sglang/srt/models/qwen3_moe.py index b4099bfad5fd..71dc37e086fd 100644 --- a/python/sglang/srt/models/qwen3_moe.py +++ b/python/sglang/srt/models/qwen3_moe.py @@ -91,7 +91,10 @@ _is_cuda = is_cuda() if _is_cuda: - from sgl_kernel import fused_qk_norm_rope + from sglang.jit_kernel.fused_qknorm_rope import ( + can_use_fused_qk_norm_rope, + fused_qk_norm_rope, + ) TConfig = TypeVar("TConfig", bound=PretrainedConfig) @@ -503,12 +506,18 @@ def __init__( self.compatible_with_fused_kv_buffer = ( False if isinstance(self.rotary_emb, MRotaryEmbedding) else True ) - self.compatible_with_fused_qk_norm_rope = ( - not isinstance(self.rotary_emb, MRotaryEmbedding) + self.compatible_with_fused_qk_norm_rope = not isinstance( + self.rotary_emb, MRotaryEmbedding ) and self.head_dim in (64, 128, 256) self.use_fused_qk_norm_rope = ( get_global_server_args().enable_fused_qk_norm_rope and self.compatible_with_fused_qk_norm_rope + and _is_cuda + and can_use_fused_qk_norm_rope( + self.head_dim, + self.rotary_emb.is_neox_style, + torch.bfloat16, + ) ) self._used_fused_qk_norm_rope_last_call = False