Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 85 additions & 4 deletions python/sglang/jit_kernel/benchmark/bench_fused_qknorm_rope.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
"""
Benchmark: fused_qknorm_rope JIT vs AOT (sgl_kernel)

Measures throughput (us) for fused_qk_norm_rope across typical
LLM configurations (head_dim x num_heads x num_tokens).
Measures throughput (µs) for fused_qk_norm_rope across typical
LLM configurations (head_dim × num_heads × num_tokens).

Run:
python python/sglang/jit_kernel/benchmark/bench_fused_qknorm_rope.py
Expand Down Expand Up @@ -39,7 +39,7 @@
ci_range=[64, 512],
)

# (head_dim, num_heads_q, num_heads_k, num_heads_v) - typical MoE/dense configs
# (head_dim, num_heads_q, num_heads_k, num_heads_v) — typical MoE/dense configs
MODEL_CONFIGS = get_benchmark_range(
full_range=[
(64, 32, 8, 8), # small
Expand All @@ -49,6 +49,16 @@
ci_range=[(128, 32, 8, 8)],
)

# Real production shapes (self-attention; num_heads_k == num_heads_v == num_heads_q).
# Format: (name, num_tokens, num_heads_q, num_heads_k, num_heads_v, head_dim, rotary_dim)
# num_tokens = batch * sequence length; rotary_dim < head_dim means partial RoPE.
PRODUCTION_SHAPES = [
    ("flux_1024", 4096, 24, 24, 24, 128, 128),
    ("qwen_image_1024", 4096, 32, 32, 32, 128, 128),
    ("qwen_image_partial", 4096, 32, 32, 32, 128, 64),  # RoPE on half the head dim
    ("zimage_1024", 4096, 30, 30, 30, 128, 128),
    ("batch2_medium", 4096, 24, 24, 24, 128, 128),  # B=2, T=2048
]

LINE_VALS = ["jit", "aot"] if AOT_AVAILABLE else ["jit"]
LINE_NAMES = ["JIT (new)", "AOT sgl_kernel"] if AOT_AVAILABLE else ["JIT (new)"]
STYLES = [("blue", "--"), ("orange", "-")] if AOT_AVAILABLE else [("blue", "--")]
Expand Down Expand Up @@ -123,14 +133,83 @@ def bench_fused_qknorm_rope(
return run_benchmark(fn)


# ---------------------------------------------------------------------------
# Benchmark: fused_qk_norm_rope — real production shapes (with speedup column)
# ---------------------------------------------------------------------------


def bench_fused_qknorm_rope_production():
    """Benchmark fused_qk_norm_rope on real production shapes.

    Prints a fixed-width table with one row per entry in PRODUCTION_SHAPES:
    JIT latency, AOT latency (when sgl_kernel is available), and the AOT/JIT
    speedup ratio. All tensors are allocated on the default CUDA device.
    """
    device = "cuda"
    header = (
        f"{'name':<22} {'tokens':>6} {'nq':>4} {'nk':>4} {'nv':>4}"
        f" {'hd':>4} {'rdim':>5} {'JIT(us)':>9} {'AOT(us)':>9} {'speedup':>8}"
    )
    rule = "-" * len(header)
    print("\nfused-qknorm-rope-production-shapes:")
    print(rule)
    print(header)
    print(rule)

    for shape in PRODUCTION_SHAPES:
        name, num_tokens, num_heads_q, num_heads_k, num_heads_v, head_dim, rotary_dim = shape
        total_heads = num_heads_q + num_heads_k + num_heads_v

        # Inputs: packed QKV activations plus unit RMSNorm weights.
        qkv = torch.randn(
            (num_tokens, total_heads * head_dim), dtype=torch.bfloat16, device=device
        )
        q_weight = torch.ones(head_dim, dtype=torch.bfloat16, device=device)
        k_weight = torch.ones(head_dim, dtype=torch.bfloat16, device=device)
        position_ids = torch.arange(num_tokens, dtype=torch.int32, device=device)

        # Shared keyword arguments for both the JIT and AOT entry points.
        common_kwargs = {
            "num_heads_q": num_heads_q,
            "num_heads_k": num_heads_k,
            "num_heads_v": num_heads_v,
            "head_dim": head_dim,
            "eps": 1e-5,
            "q_weight": q_weight,
            "k_weight": k_weight,
            "base": 10000.0,
            "is_neox": False,
            "position_ids": position_ids,
            "factor": 1.0,
            "low": 1.0,
            "high": 32.0,
            "attention_factor": 1.0,
            "rotary_dim": rotary_dim,
        }

        # Clone per call: the kernel operates on qkv in place.
        jit_us, _, _ = run_benchmark(
            lambda: fused_qk_norm_rope_jit(qkv.clone(), **common_kwargs)
        )
        if not AOT_AVAILABLE:
            aot_str = f"{'N/A':>9}"
            speedup = "N/A"
        else:
            aot_us, _, _ = run_benchmark(
                lambda: fused_qk_norm_rope_aot(qkv.clone(), **common_kwargs)
            )
            aot_str = f"{aot_us:9.3f}"
            speedup = f"{aot_us / jit_us:.2f}x"

        print(
            f"{name:<22} {num_tokens:>6} {num_heads_q:>4} {num_heads_k:>4} {num_heads_v:>4}"
            f" {head_dim:>4} {rotary_dim:>5} {jit_us:9.3f} {aot_str} {speedup:>8}"
        )
    print(rule)


# ---------------------------------------------------------------------------
# Quick correctness diff
# ---------------------------------------------------------------------------


def calculate_diff():
if not AOT_AVAILABLE:
print("sgl_kernel not available - skipping AOT diff check")
print("sgl_kernel not available skipping AOT diff check")
return

device = "cuda"
Expand Down Expand Up @@ -184,3 +263,5 @@ def calculate_diff():
calculate_diff()
print()
bench_fused_qknorm_rope.run(print_data=True)
print()
bench_fused_qknorm_rope_production()
149 changes: 94 additions & 55 deletions python/sglang/jit_kernel/csrc/elementwise/fused_qknorm_rope.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,11 @@ namespace {
// When factor != 1.0, blends interpolated and extrapolated frequencies.
// ---------------------------------------------------------------------------

__device__ inline float
compute_freq_yarn(float base, int rotary_dim, int half_dim, float factor, float low, float high) {
template <bool yarn>
__device__ inline float compute_freq(float base, int rotary_dim, int half_dim, float factor, float low, float high) {
float freq = powf(base, -2.0f * half_dim / static_cast<float>(rotary_dim));

if (factor != 1.0f) {
if constexpr (yarn) {
float inv_freq_extrapolation = freq;
float inv_freq_interpolation = freq / factor;

Expand All @@ -68,11 +68,14 @@ compute_freq_yarn(float base, int rotary_dim, int half_dim, float factor, float
//
// Each warp processes one (token, head) pair.
// head_dim: compile-time head dimension (64, 128, or 256)
// interleave: true -> interleave / GPT-J style RoPE (!is_neox)
// false -> NeoX style RoPE (is_neox)
// interleave: true  -> interleave / GPT-J style RoPE (!is_neox)
//             false -> NeoX style RoPE (is_neox)
// ---------------------------------------------------------------------------

template <int head_dim, bool interleave>
// interleave (GPT-J) pairs (2k,2k+1) share the same freq/theta,
// so sin/cos is computed once per pair and copied to the odd element,
// halving powf + __sincosf calls vs a naive per-element approach.
template <int head_dim, bool interleave, bool yarn>
__global__ void fusedQKNormRopeKernel(
__nv_bfloat16* qkv, // [num_tokens, (nq+nk+nv)*head_dim], in-place
int const num_heads_q,
Expand Down Expand Up @@ -139,36 +142,65 @@ __global__ void fusedQKNormRopeKernel(
// Apply RMSNorm
// -------------------------------------------------------------------
float rms_rcp = rsqrtf(sumOfSquares / static_cast<float>(head_dim) + eps);
for (int i = 0; i < numElemsPerThread; i++) {
int dim = laneId * numElemsPerThread + i;
float weight = isQ ? device::cast<float>(q_weight[dim]) : device::cast<float>(k_weight[dim]);
elements[i] *= rms_rcp * weight;
{
vec_T wvec;
wvec.load((isQ ? q_weight : k_weight) + offsetThread - offsetWarp);
for (int i = 0; i < numElemsPerThread; i++) {
elements[i] *= rms_rcp * device::cast<float>(wvec[i]);
}
}

// -------------------------------------------------------------------
// Apply RoPE to the first rotary_dim elements
// -------------------------------------------------------------------
float elements2[numElemsPerThread];
float cos_vals[numElemsPerThread];
float sin_vals[numElemsPerThread];
float pos_id = static_cast<float>(position_ids[tokenIdx]);
int const rotary_lanes = rotary_dim / numElemsPerThread;
bool const applyRotary = (laneId < rotary_lanes);

if (applyRotary) {
if constexpr (interleave) {
// Interleave (GPT-J) style: pairs of consecutive elements share a frequency
for (int i = 0; i < numElemsPerThread; i++) {
elements2[i] = (i % 2 == 0) ? -elements[i + 1] : elements[i - 1];
// Pairs (2k, 2k+1) share the same half_dim → same freq/theta.
// numElemsPerThread is always even (head_dim/32, head_dim in {64,128,256}),
// so we step by 2 and handle both elements of each pair per iteration.
//
// freq follows a geometric series across pairs: freq[k] = freq[0] * ratio^k,
// where ratio = base^(-2/rotary_dim). Pre-compute both outside the loop to
// replace all but the first powf call with a single multiply per iteration.
//
// sin/cos are applied immediately to e0/e1, eliminating the elements2,
// cos_vals, sin_vals intermediate arrays and reducing register pressure.
int const half_dim_start = laneId * numElemsPerThread / 2;
float freq = powf(base, -2.0f * static_cast<float>(half_dim_start) / static_cast<float>(rotary_dim));
float const freq_ratio = powf(base, -2.0f / static_cast<float>(rotary_dim));

for (int i = 0; i < numElemsPerThread; i += 2) {
float e0 = elements[i];
float e1 = elements[i + 1];

float f = freq;
if constexpr (yarn) {
int half_dim = half_dim_start + i / 2;
float inv_freq_interpolation = freq / factor;
float high_adj = (fabsf(low - high) <= 1e-6f) ? high + 0.001f : high;
float linear_func = (static_cast<float>(half_dim) - low) / (high_adj - low);
float ramp_func = fminf(fmaxf(linear_func, 0.0f), 1.0f);
float extrap_factor = 1.0f - ramp_func;
f = inv_freq_interpolation * (1.0f - extrap_factor) + freq * extrap_factor;
}

int dim_idx = laneId * numElemsPerThread + i;
int half_dim = dim_idx / 2;
float freq = compute_freq_yarn(base, rotary_dim, half_dim, factor, low, high);
float theta = pos_id * freq;
__sincosf(theta, &sin_vals[i], &cos_vals[i]);
float s, c;
__sincosf(pos_id * f, &s, &c);
elements[i] = (e0 * c - e1 * s) * attention_factor;
elements[i + 1] = (e1 * c + e0 * s) * attention_factor;

freq *= freq_ratio;
}
} else {
// NeoX style: first and second halves of the rotary region are paired
float elements2[numElemsPerThread];
float cos_vals[numElemsPerThread];
float sin_vals[numElemsPerThread];

__syncwarp();
int const half_rotary_lanes = rotary_lanes / 2;
// Avoid UB from (1u << 32) when rotary_lanes == 32
Expand All @@ -183,15 +215,15 @@ __global__ void fusedQKNormRopeKernel(
// Remap so that both halves use the same set of frequencies
dim_idx = (dim_idx * 2) % rotary_dim;
int half_dim = dim_idx / 2;
float freq = compute_freq_yarn(base, rotary_dim, half_dim, factor, low, high);
float freq = compute_freq<yarn>(base, rotary_dim, half_dim, factor, low, high);
float theta = pos_id * freq;
__sincosf(theta, &sin_vals[i], &cos_vals[i]);
}
__syncwarp();
}

for (int i = 0; i < numElemsPerThread; i++) {
elements[i] = (elements[i] * cos_vals[i] + elements2[i] * sin_vals[i]) * attention_factor;
for (int i = 0; i < numElemsPerThread; i++) {
elements[i] = (elements[i] * cos_vals[i] + elements2[i] * sin_vals[i]) * attention_factor;
}
}
}

Expand All @@ -209,14 +241,8 @@ __global__ void fusedQKNormRopeKernel(

// ---------------------------------------------------------------------------
// Host-side tvm-ffi entry point
//
// HEAD_DIM and INTERLEAVE are compile-time template parameters, passed as
// template arguments from Python via the cuda_wrappers specialisation in
// fused_qknorm_rope.py (e.g. fused_qk_norm_rope<128, false>). This avoids
// both runtime dispatch and macro-based specialisation.
// ---------------------------------------------------------------------------

template <int HEAD_DIM, bool INTERLEAVE>
void fused_qk_norm_rope(
tvm::ffi::TensorView qkv, // [num_tokens, (nq+nk+nv)*head_dim] bf16
tvm::ffi::TensorView q_weight, // [head_dim] bf16
Expand All @@ -225,17 +251,17 @@ void fused_qk_norm_rope(
int num_heads_q,
int num_heads_k,
int num_heads_v,
int head_dim,
float eps,
float base,
int is_neox, // 0 = interleave style, 1 = NeoX style
float factor,
float low,
float high,
float attention_factor,
int rotary_dim) {
using namespace host;

static_assert(HEAD_DIM == 64 || HEAD_DIM == 128 || HEAD_DIM == 256, "HEAD_DIM must be 64, 128, or 256");

RuntimeCheck(qkv.device().device_type == kDLCUDA, "qkv must be a CUDA tensor");
RuntimeCheck(qkv.is_contiguous(), "qkv must be contiguous");
RuntimeCheck(qkv.dtype().code == kDLBfloat && qkv.dtype().bits == 16, "qkv must be bfloat16");
Expand All @@ -244,12 +270,12 @@ void fused_qk_norm_rope(
RuntimeCheck(q_weight.is_contiguous(), "q_weight must be contiguous");
RuntimeCheck(q_weight.dtype().code == kDLBfloat && q_weight.dtype().bits == 16, "q_weight must be bfloat16");
RuntimeCheck(
q_weight.ndim() == 1 && static_cast<int>(q_weight.size(0)) == HEAD_DIM, "q_weight must be 1D of size head_dim");
q_weight.ndim() == 1 && static_cast<int>(q_weight.size(0)) == head_dim, "q_weight must be 1D of size head_dim");

RuntimeCheck(k_weight.is_contiguous(), "k_weight must be contiguous");
RuntimeCheck(k_weight.dtype().code == kDLBfloat && k_weight.dtype().bits == 16, "k_weight must be bfloat16");
RuntimeCheck(
k_weight.ndim() == 1 && static_cast<int>(k_weight.size(0)) == HEAD_DIM, "k_weight must be 1D of size head_dim");
k_weight.ndim() == 1 && static_cast<int>(k_weight.size(0)) == head_dim, "k_weight must be 1D of size head_dim");

RuntimeCheck(position_ids.device().device_type == kDLCUDA, "position_ids must be a CUDA tensor");
RuntimeCheck(position_ids.is_contiguous(), "position_ids must be contiguous");
Expand All @@ -259,49 +285,62 @@ void fused_qk_norm_rope(
int num_tokens = static_cast<int>(qkv.size(0));
int total_heads = num_heads_q + num_heads_k + num_heads_v;
RuntimeCheck(
static_cast<int>(qkv.size(1)) == total_heads * HEAD_DIM, "qkv.size(1) must equal (nq + nk + nv) * head_dim");
static_cast<int>(qkv.size(1)) == total_heads * head_dim, "qkv.size(1) must equal (nq + nk + nv) * head_dim");
RuntimeCheck(static_cast<int>(position_ids.size(0)) == num_tokens, "position_ids must have num_tokens elements");

constexpr int numElemsPerThread = HEAD_DIM / 32;
static_assert(
JIT_HEAD_DIM == 64 || JIT_HEAD_DIM == 128 || JIT_HEAD_DIM == 256, "JIT_HEAD_DIM must be 64, 128, or 256");
static_assert(JIT_INTERLEAVE == 0 || JIT_INTERLEAVE == 1, "JIT_INTERLEAVE must be 0 or 1");
static_assert(JIT_YARN == 0 || JIT_YARN == 1, "JIT_YARN must be 0 or 1");
RuntimeCheck(head_dim == JIT_HEAD_DIM, "head_dim mismatch with JIT-compiled kernel");

int numElemsPerThread = head_dim / 32;
RuntimeCheck(rotary_dim % numElemsPerThread == 0, "rotary_dim must be divisible by (head_dim / 32)");

if constexpr (!INTERLEAVE) {
bool neox = static_cast<bool>(is_neox);
if (neox) {
// NeoX uses __shfl_xor_sync which requires half_rotary_lanes to be a power of 2
int rotary_lanes = rotary_dim / numElemsPerThread;
int half_rotary_lanes = rotary_lanes / 2;
bool is_pow2 = (half_rotary_lanes >= 1) && ((half_rotary_lanes & (half_rotary_lanes - 1)) == 0);
RuntimeCheck(is_pow2, "half_rotary_lanes must be a power of 2 for NeoX style RoPE");
}

bool interleave = !neox;
RuntimeCheck(interleave == static_cast<bool>(JIT_INTERLEAVE), "interleave mismatch with JIT-compiled kernel");
bool use_yarn = (factor != 1.0f);
RuntimeCheck(use_yarn == static_cast<bool>(JIT_YARN), "yarn mismatch with JIT-compiled kernel");

cudaStream_t stream = LaunchKernel::resolve_device(qkv.device());

constexpr int blockSize = 256;
int warpsPerBlock = blockSize / 32;
int totalQKHeads = num_heads_q + num_heads_k;
int totalWarps = num_tokens * totalQKHeads;
int gridSize = host::div_ceil(totalWarps, warpsPerBlock);
int gridSize = div_ceil(totalWarps, warpsPerBlock);

auto* qkv_ptr = reinterpret_cast<__nv_bfloat16*>(qkv.data_ptr());
auto const* qw_ptr = reinterpret_cast<__nv_bfloat16 const*>(q_weight.data_ptr());
auto const* kw_ptr = reinterpret_cast<__nv_bfloat16 const*>(k_weight.data_ptr());
auto const* pos_ptr = reinterpret_cast<int const*>(position_ids.data_ptr());

fusedQKNormRopeKernel<HEAD_DIM, INTERLEAVE><<<gridSize, blockSize, 0, stream>>>(
qkv_ptr,
num_heads_q,
num_heads_k,
num_heads_v,
eps,
qw_ptr,
kw_ptr,
base,
pos_ptr,
num_tokens,
factor,
low,
high,
attention_factor,
rotary_dim);
fusedQKNormRopeKernel<JIT_HEAD_DIM, static_cast<bool>(JIT_INTERLEAVE), static_cast<bool>(JIT_YARN)>
<<<gridSize, blockSize, 0, stream>>>(
qkv_ptr,
num_heads_q,
num_heads_k,
num_heads_v,
eps,
qw_ptr,
kw_ptr,
base,
pos_ptr,
num_tokens,
factor,
low,
high,
attention_factor,
rotary_dim);
}

} // namespace
Loading
Loading