diff --git a/python/sglang/jit_kernel/benchmark/bench_fused_qknorm_rope.py b/python/sglang/jit_kernel/benchmark/bench_fused_qknorm_rope.py index 8c120cd0f962..905d41d5d6b0 100644 --- a/python/sglang/jit_kernel/benchmark/bench_fused_qknorm_rope.py +++ b/python/sglang/jit_kernel/benchmark/bench_fused_qknorm_rope.py @@ -1,8 +1,8 @@ """ Benchmark: fused_qknorm_rope JIT vs AOT (sgl_kernel) -Measures throughput (us) for fused_qk_norm_rope across typical -LLM configurations (head_dim x num_heads x num_tokens). +Measures latency (µs) for fused_qk_norm_rope across typical +LLM configurations (head_dim × num_heads × num_tokens). Run: python python/sglang/jit_kernel/benchmark/bench_fused_qknorm_rope.py @@ -39,7 +39,7 @@ ci_range=[64, 512], ) -# (head_dim, num_heads_q, num_heads_k, num_heads_v) - typical MoE/dense configs +# (head_dim, num_heads_q, num_heads_k, num_heads_v) — typical MoE/dense configs MODEL_CONFIGS = get_benchmark_range( full_range=[ (64, 32, 8, 8), # small @@ -49,6 +49,16 @@ ci_range=[(128, 32, 8, 8)], ) +# Real production shapes (self-attention; num_heads_k == num_heads_v == num_heads_q). 
+# Format: (name, num_tokens, num_heads_q, num_heads_k, num_heads_v, head_dim, rotary_dim) +PRODUCTION_SHAPES = [ + ("flux_1024", 4096, 24, 24, 24, 128, 128), + ("qwen_image_1024", 4096, 32, 32, 32, 128, 128), + ("qwen_image_partial", 4096, 32, 32, 32, 128, 64), + ("zimage_1024", 4096, 30, 30, 30, 128, 128), + ("batch2_medium", 4096, 24, 24, 24, 128, 128), # B=2, T=2048 +] + LINE_VALS = ["jit", "aot"] if AOT_AVAILABLE else ["jit"] LINE_NAMES = ["JIT (new)", "AOT sgl_kernel"] if AOT_AVAILABLE else ["JIT (new)"] STYLES = [("blue", "--"), ("orange", "-")] if AOT_AVAILABLE else [("blue", "--")] @@ -123,6 +133,75 @@ def bench_fused_qknorm_rope( return run_benchmark(fn) +# --------------------------------------------------------------------------- +# Benchmark: fused_qk_norm_rope — real production shapes (with speedup column) +# --------------------------------------------------------------------------- + + +def bench_fused_qknorm_rope_production(): + device = "cuda" + header = f"{'name':<22} {'tokens':>6} {'nq':>4} {'nk':>4} {'nv':>4} {'hd':>4} {'rdim':>5} {'JIT(us)':>9} {'AOT(us)':>9} {'speedup':>8}" + sep = "-" * len(header) + print("\nfused-qknorm-rope-production-shapes:") + print(sep) + print(header) + print(sep) + + for ( + name, + num_tokens, + num_heads_q, + num_heads_k, + num_heads_v, + head_dim, + rotary_dim, + ) in PRODUCTION_SHAPES: + total_heads = num_heads_q + num_heads_k + num_heads_v + qkv = torch.randn( + (num_tokens, total_heads * head_dim), dtype=torch.bfloat16, device=device + ) + q_weight = torch.ones(head_dim, dtype=torch.bfloat16, device=device) + k_weight = torch.ones(head_dim, dtype=torch.bfloat16, device=device) + position_ids = torch.arange(num_tokens, dtype=torch.int32, device=device) + + common_kwargs = dict( + num_heads_q=num_heads_q, + num_heads_k=num_heads_k, + num_heads_v=num_heads_v, + head_dim=head_dim, + eps=1e-5, + q_weight=q_weight, + k_weight=k_weight, + base=10000.0, + is_neox=False, + position_ids=position_ids, + factor=1.0, + 
low=1.0, + high=32.0, + attention_factor=1.0, + rotary_dim=rotary_dim, + ) + + jit_us, _, _ = run_benchmark( + lambda: fused_qk_norm_rope_jit(qkv.clone(), **common_kwargs) + ) + if AOT_AVAILABLE: + aot_us, _, _ = run_benchmark( + lambda: fused_qk_norm_rope_aot(qkv.clone(), **common_kwargs) + ) + speedup = f"{aot_us / jit_us:.2f}x" + aot_str = f"{aot_us:9.3f}" + else: + aot_str = f"{'N/A':>9}" + speedup = "N/A" + + print( + f"{name:<22} {num_tokens:>6} {num_heads_q:>4} {num_heads_k:>4} {num_heads_v:>4}" + f" {head_dim:>4} {rotary_dim:>5} {jit_us:9.3f} {aot_str} {speedup:>8}" + ) + print(sep) + + # --------------------------------------------------------------------------- # Quick correctness diff # --------------------------------------------------------------------------- @@ -130,7 +209,7 @@ def bench_fused_qknorm_rope( def calculate_diff(): if not AOT_AVAILABLE: - print("sgl_kernel not available - skipping AOT diff check") + print("sgl_kernel not available — skipping AOT diff check") return device = "cuda" @@ -184,3 +263,5 @@ def calculate_diff(): calculate_diff() print() bench_fused_qknorm_rope.run(print_data=True) + print() + bench_fused_qknorm_rope_production() diff --git a/python/sglang/jit_kernel/csrc/elementwise/fused_qknorm_rope.cuh b/python/sglang/jit_kernel/csrc/elementwise/fused_qknorm_rope.cuh index 40401572b3b8..1c1f41dccf47 100644 --- a/python/sglang/jit_kernel/csrc/elementwise/fused_qknorm_rope.cuh +++ b/python/sglang/jit_kernel/csrc/elementwise/fused_qknorm_rope.cuh @@ -39,11 +39,11 @@ namespace { // When factor != 1.0, blends interpolated and extrapolated frequencies. 
// --------------------------------------------------------------------------- -__device__ inline float -compute_freq_yarn(float base, int rotary_dim, int half_dim, float factor, float low, float high) { +template +__device__ inline float compute_freq(float base, int rotary_dim, int half_dim, float factor, float low, float high) { float freq = powf(base, -2.0f * half_dim / static_cast(rotary_dim)); - if (factor != 1.0f) { + if constexpr (yarn) { float inv_freq_extrapolation = freq; float inv_freq_interpolation = freq / factor; @@ -68,11 +68,14 @@ compute_freq_yarn(float base, int rotary_dim, int half_dim, float factor, float // // Each warp processes one (token, head) pair. // head_dim: compile-time head dimension (64, 128, or 256) -// interleave: true -> interleave / GPT-J style RoPE (!is_neox) -// false -> NeoX style RoPE (is_neox) +// interleave: true → interleave / GPT-J style RoPE (!is_neox) +// false → NeoX style RoPE (is_neox) // --------------------------------------------------------------------------- -template +// interleave (GPT-J) pairs (2k,2k+1) share the same freq/theta, +// so sin/cos is computed once per pair and copied to the odd element, +// halving powf + __sincosf calls vs a naive per-element approach. +template __global__ void fusedQKNormRopeKernel( __nv_bfloat16* qkv, // [num_tokens, (nq+nk+nv)*head_dim], in-place int const num_heads_q, @@ -139,36 +142,65 @@ __global__ void fusedQKNormRopeKernel( // Apply RMSNorm // ------------------------------------------------------------------- float rms_rcp = rsqrtf(sumOfSquares / static_cast(head_dim) + eps); - for (int i = 0; i < numElemsPerThread; i++) { - int dim = laneId * numElemsPerThread + i; - float weight = isQ ? device::cast(q_weight[dim]) : device::cast(k_weight[dim]); - elements[i] *= rms_rcp * weight; + { + vec_T wvec; + wvec.load((isQ ? 
q_weight : k_weight) + offsetThread - offsetWarp); + for (int i = 0; i < numElemsPerThread; i++) { + elements[i] *= rms_rcp * device::cast(wvec[i]); + } } // ------------------------------------------------------------------- // Apply RoPE to the first rotary_dim elements // ------------------------------------------------------------------- - float elements2[numElemsPerThread]; - float cos_vals[numElemsPerThread]; - float sin_vals[numElemsPerThread]; float pos_id = static_cast(position_ids[tokenIdx]); int const rotary_lanes = rotary_dim / numElemsPerThread; bool const applyRotary = (laneId < rotary_lanes); if (applyRotary) { if constexpr (interleave) { - // Interleave (GPT-J) style: pairs of consecutive elements share a frequency - for (int i = 0; i < numElemsPerThread; i++) { - elements2[i] = (i % 2 == 0) ? -elements[i + 1] : elements[i - 1]; + // Pairs (2k, 2k+1) share the same half_dim → same freq/theta. + // numElemsPerThread is always even (head_dim/32, head_dim in {64,128,256}), + // so we step by 2 and handle both elements of each pair per iteration. + // + // freq follows a geometric series across pairs: freq[k] = freq[0] * ratio^k, + // where ratio = base^(-2/rotary_dim). Pre-compute both outside the loop to + // replace all but the first powf call with a single multiply per iteration. + // + // sin/cos are applied immediately to e0/e1, eliminating the elements2, + // cos_vals, sin_vals intermediate arrays and reducing register pressure. + int const half_dim_start = laneId * numElemsPerThread / 2; + float freq = powf(base, -2.0f * static_cast(half_dim_start) / static_cast(rotary_dim)); + float const freq_ratio = powf(base, -2.0f / static_cast(rotary_dim)); + + for (int i = 0; i < numElemsPerThread; i += 2) { + float e0 = elements[i]; + float e1 = elements[i + 1]; + + float f = freq; + if constexpr (yarn) { + int half_dim = half_dim_start + i / 2; + float inv_freq_interpolation = freq / factor; + float high_adj = (fabsf(low - high) <= 1e-6f) ? 
high + 0.001f : high; + float linear_func = (static_cast(half_dim) - low) / (high_adj - low); + float ramp_func = fminf(fmaxf(linear_func, 0.0f), 1.0f); + float extrap_factor = 1.0f - ramp_func; + f = inv_freq_interpolation * (1.0f - extrap_factor) + freq * extrap_factor; + } - int dim_idx = laneId * numElemsPerThread + i; - int half_dim = dim_idx / 2; - float freq = compute_freq_yarn(base, rotary_dim, half_dim, factor, low, high); - float theta = pos_id * freq; - __sincosf(theta, &sin_vals[i], &cos_vals[i]); + float s, c; + __sincosf(pos_id * f, &s, &c); + elements[i] = (e0 * c - e1 * s) * attention_factor; + elements[i + 1] = (e1 * c + e0 * s) * attention_factor; + + freq *= freq_ratio; } } else { // NeoX style: first and second halves of the rotary region are paired + float elements2[numElemsPerThread]; + float cos_vals[numElemsPerThread]; + float sin_vals[numElemsPerThread]; + __syncwarp(); int const half_rotary_lanes = rotary_lanes / 2; // Avoid UB from (1u << 32) when rotary_lanes == 32 @@ -183,15 +215,15 @@ __global__ void fusedQKNormRopeKernel( // Remap so that both halves use the same set of frequencies dim_idx = (dim_idx * 2) % rotary_dim; int half_dim = dim_idx / 2; - float freq = compute_freq_yarn(base, rotary_dim, half_dim, factor, low, high); + float freq = compute_freq(base, rotary_dim, half_dim, factor, low, high); float theta = pos_id * freq; __sincosf(theta, &sin_vals[i], &cos_vals[i]); } __syncwarp(); - } - for (int i = 0; i < numElemsPerThread; i++) { - elements[i] = (elements[i] * cos_vals[i] + elements2[i] * sin_vals[i]) * attention_factor; + for (int i = 0; i < numElemsPerThread; i++) { + elements[i] = (elements[i] * cos_vals[i] + elements2[i] * sin_vals[i]) * attention_factor; + } } } @@ -209,14 +241,8 @@ __global__ void fusedQKNormRopeKernel( // --------------------------------------------------------------------------- // Host-side tvm-ffi entry point -// -// HEAD_DIM and INTERLEAVE are compile-time template parameters, passed as -// 
template arguments from Python via the cuda_wrappers specialisation in -// fused_qknorm_rope.py (e.g. fused_qk_norm_rope<128, false>). This avoids -// both runtime dispatch and macro-based specialisation. // --------------------------------------------------------------------------- -template void fused_qk_norm_rope( tvm::ffi::TensorView qkv, // [num_tokens, (nq+nk+nv)*head_dim] bf16 tvm::ffi::TensorView q_weight, // [head_dim] bf16 @@ -225,8 +251,10 @@ void fused_qk_norm_rope( int num_heads_q, int num_heads_k, int num_heads_v, + int head_dim, float eps, float base, + int is_neox, // 0 = interleave style, 1 = NeoX style float factor, float low, float high, @@ -234,8 +262,6 @@ void fused_qk_norm_rope( int rotary_dim) { using namespace host; - static_assert(HEAD_DIM == 64 || HEAD_DIM == 128 || HEAD_DIM == 256, "HEAD_DIM must be 64, 128, or 256"); - RuntimeCheck(qkv.device().device_type == kDLCUDA, "qkv must be a CUDA tensor"); RuntimeCheck(qkv.is_contiguous(), "qkv must be contiguous"); RuntimeCheck(qkv.dtype().code == kDLBfloat && qkv.dtype().bits == 16, "qkv must be bfloat16"); @@ -244,12 +270,12 @@ void fused_qk_norm_rope( RuntimeCheck(q_weight.is_contiguous(), "q_weight must be contiguous"); RuntimeCheck(q_weight.dtype().code == kDLBfloat && q_weight.dtype().bits == 16, "q_weight must be bfloat16"); RuntimeCheck( - q_weight.ndim() == 1 && static_cast(q_weight.size(0)) == HEAD_DIM, "q_weight must be 1D of size head_dim"); + q_weight.ndim() == 1 && static_cast(q_weight.size(0)) == head_dim, "q_weight must be 1D of size head_dim"); RuntimeCheck(k_weight.is_contiguous(), "k_weight must be contiguous"); RuntimeCheck(k_weight.dtype().code == kDLBfloat && k_weight.dtype().bits == 16, "k_weight must be bfloat16"); RuntimeCheck( - k_weight.ndim() == 1 && static_cast(k_weight.size(0)) == HEAD_DIM, "k_weight must be 1D of size head_dim"); + k_weight.ndim() == 1 && static_cast(k_weight.size(0)) == head_dim, "k_weight must be 1D of size head_dim"); 
RuntimeCheck(position_ids.device().device_type == kDLCUDA, "position_ids must be a CUDA tensor"); RuntimeCheck(position_ids.is_contiguous(), "position_ids must be contiguous"); @@ -259,13 +285,20 @@ void fused_qk_norm_rope( int num_tokens = static_cast(qkv.size(0)); int total_heads = num_heads_q + num_heads_k + num_heads_v; RuntimeCheck( - static_cast(qkv.size(1)) == total_heads * HEAD_DIM, "qkv.size(1) must equal (nq + nk + nv) * head_dim"); + static_cast(qkv.size(1)) == total_heads * head_dim, "qkv.size(1) must equal (nq + nk + nv) * head_dim"); RuntimeCheck(static_cast(position_ids.size(0)) == num_tokens, "position_ids must have num_tokens elements"); - constexpr int numElemsPerThread = HEAD_DIM / 32; + static_assert( + JIT_HEAD_DIM == 64 || JIT_HEAD_DIM == 128 || JIT_HEAD_DIM == 256, "JIT_HEAD_DIM must be 64, 128, or 256"); + static_assert(JIT_INTERLEAVE == 0 || JIT_INTERLEAVE == 1, "JIT_INTERLEAVE must be 0 or 1"); + static_assert(JIT_YARN == 0 || JIT_YARN == 1, "JIT_YARN must be 0 or 1"); + RuntimeCheck(head_dim == JIT_HEAD_DIM, "head_dim mismatch with JIT-compiled kernel"); + + int numElemsPerThread = head_dim / 32; RuntimeCheck(rotary_dim % numElemsPerThread == 0, "rotary_dim must be divisible by (head_dim / 32)"); - if constexpr (!INTERLEAVE) { + bool neox = static_cast(is_neox); + if (neox) { // NeoX uses __shfl_xor_sync which requires half_rotary_lanes to be a power of 2 int rotary_lanes = rotary_dim / numElemsPerThread; int half_rotary_lanes = rotary_lanes / 2; @@ -273,35 +306,41 @@ void fused_qk_norm_rope( RuntimeCheck(is_pow2, "half_rotary_lanes must be a power of 2 for NeoX style RoPE"); } + bool interleave = !neox; + RuntimeCheck(interleave == static_cast(JIT_INTERLEAVE), "interleave mismatch with JIT-compiled kernel"); + bool use_yarn = (factor != 1.0f); + RuntimeCheck(use_yarn == static_cast(JIT_YARN), "yarn mismatch with JIT-compiled kernel"); + cudaStream_t stream = LaunchKernel::resolve_device(qkv.device()); constexpr int blockSize = 256; int 
warpsPerBlock = blockSize / 32; int totalQKHeads = num_heads_q + num_heads_k; int totalWarps = num_tokens * totalQKHeads; - int gridSize = host::div_ceil(totalWarps, warpsPerBlock); + int gridSize = div_ceil(totalWarps, warpsPerBlock); auto* qkv_ptr = reinterpret_cast<__nv_bfloat16*>(qkv.data_ptr()); auto const* qw_ptr = reinterpret_cast<__nv_bfloat16 const*>(q_weight.data_ptr()); auto const* kw_ptr = reinterpret_cast<__nv_bfloat16 const*>(k_weight.data_ptr()); auto const* pos_ptr = reinterpret_cast(position_ids.data_ptr()); - fusedQKNormRopeKernel<<>>( - qkv_ptr, - num_heads_q, - num_heads_k, - num_heads_v, - eps, - qw_ptr, - kw_ptr, - base, - pos_ptr, - num_tokens, - factor, - low, - high, - attention_factor, - rotary_dim); + fusedQKNormRopeKernel(JIT_INTERLEAVE), static_cast(JIT_YARN)> + <<>>( + qkv_ptr, + num_heads_q, + num_heads_k, + num_heads_v, + eps, + qw_ptr, + kw_ptr, + base, + pos_ptr, + num_tokens, + factor, + low, + high, + attention_factor, + rotary_dim); } } // namespace diff --git a/python/sglang/jit_kernel/fused_qknorm_rope.py b/python/sglang/jit_kernel/fused_qknorm_rope.py index 92ea1f4350ad..00e872020709 100644 --- a/python/sglang/jit_kernel/fused_qknorm_rope.py +++ b/python/sglang/jit_kernel/fused_qknorm_rope.py @@ -13,17 +13,20 @@ @cache_once -def _jit_fused_qknorm_rope_module(head_dim: int, is_neox: bool) -> Module: - interleave = "false" if is_neox else "true" +def _jit_fused_qknorm_rope_module(head_dim: int, is_neox: bool, yarn: bool) -> Module: return load_jit( "fused_qknorm_rope", head_dim, int(is_neox), + int(yarn), cuda_files=["elementwise/fused_qknorm_rope.cuh"], - cuda_wrappers=[ - ("fused_qk_norm_rope", f"fused_qk_norm_rope<{head_dim}, {interleave}>") + cuda_wrappers=[("fused_qk_norm_rope", "fused_qk_norm_rope")], + extra_cuda_cflags=[ + "--use_fast_math", + f"-DJIT_HEAD_DIM={head_dim}", + f"-DJIT_INTERLEAVE={0 if is_neox else 1}", + f"-DJIT_YARN={1 if yarn else 0}", ], - extra_cuda_cflags=["--use_fast_math"], ) @@ -55,9 +58,9 @@ def 
fused_qk_norm_rope_out( Matches the call signature of ``sgl_kernel.fused_qk_norm_rope``. Args: - qkv: [num_tokens, (nq+nk+nv)*head_dim] bfloat16 -modified in-place - q_weight: [head_dim] bfloat16 -RMSNorm weights for Q - k_weight: [head_dim] bfloat16 -RMSNorm weights for K + qkv: [num_tokens, (nq+nk+nv)*head_dim] bfloat16 — modified in-place + q_weight: [head_dim] bfloat16 — RMSNorm weights for Q + k_weight: [head_dim] bfloat16 — RMSNorm weights for K position_ids: [num_tokens] int32 num_heads_q: number of query heads num_heads_k: number of key heads @@ -65,14 +68,15 @@ def fused_qk_norm_rope_out( head_dim: head dimension; must be 64, 128, or 256 eps: epsilon for RMSNorm base: RoPE base frequency - is_neox: True ->NeoX style, False ->interleave (GPT-J) style + is_neox: True → NeoX style, False → interleave (GPT-J) style factor: YaRN scaling factor (1.0 = standard RoPE) low: YaRN low-frequency threshold high: YaRN high-frequency threshold attention_factor: scale applied to the rotary component rotary_dim: number of elements per head to apply RoPE to """ - module = _jit_fused_qknorm_rope_module(head_dim, is_neox) + yarn = factor != 1.0 + module = _jit_fused_qknorm_rope_module(head_dim, is_neox, yarn) module.fused_qk_norm_rope( qkv, q_weight, @@ -81,8 +85,10 @@ def fused_qk_norm_rope_out( num_heads_q, num_heads_k, num_heads_v, + head_dim, eps, base, + 1 if is_neox else 0, factor, low, high, @@ -93,13 +99,16 @@ def fused_qk_norm_rope_out( @cache_once def can_use_fused_qk_norm_rope( - head_dim: int, is_neox: bool, dtype: torch.dtype + head_dim: int, is_neox: bool, dtype: torch.dtype, yarn: bool = False ) -> bool: """Return True if the JIT fused QK-Norm + RoPE kernel can be used. Args: head_dim: head dimension; supported values are 64, 128, 256 dtype: tensor dtype; only bfloat16 is supported + yarn: whether YaRN scaling is active (factor != 1.0); prebuilds the + correct kernel variant so no extra JIT compile occurs on the + first real call. 
""" logger = logging.getLogger(__name__) if head_dim not in (64, 128, 256): @@ -111,7 +120,7 @@ def can_use_fused_qk_norm_rope( logger.warning(f"Unsupported dtype={dtype} for JIT fused_qk_norm_rope kernel") return False try: - _jit_fused_qknorm_rope_module(head_dim, is_neox) + _jit_fused_qknorm_rope_module(head_dim, is_neox, yarn) return True except Exception as e: logger.warning(f"Failed to load JIT fused_qk_norm_rope kernel: {e}") @@ -142,16 +151,16 @@ def fused_qk_norm_rope( Matches the call signature of ``sgl_kernel.fused_qk_norm_rope``. Args: - qkv: [num_tokens, (nq+nk+nv)*head_dim] bfloat16 -modified in-place + qkv: [num_tokens, (nq+nk+nv)*head_dim] bfloat16 — modified in-place num_heads_q: number of query heads num_heads_k: number of key heads num_heads_v: number of value heads head_dim: head dimension; must be 64, 128, or 256 eps: epsilon for RMSNorm - q_weight: [head_dim] bfloat16 -RMSNorm weights for Q - k_weight: [head_dim] bfloat16 -RMSNorm weights for K + q_weight: [head_dim] bfloat16 — RMSNorm weights for Q + k_weight: [head_dim] bfloat16 — RMSNorm weights for K base: RoPE base frequency - is_neox: True ->NeoX style, False ->interleave (GPT-J) style + is_neox: True → NeoX style, False → interleave (GPT-J) style position_ids: [num_tokens] int32 factor: YaRN scaling factor (1.0 = standard RoPE) low: YaRN low-frequency threshold diff --git a/python/sglang/jit_kernel/tests/test_fused_qknorm_rope.py b/python/sglang/jit_kernel/tests/test_fused_qknorm_rope.py index 898683d15c3c..0843db13f217 100644 --- a/python/sglang/jit_kernel/tests/test_fused_qknorm_rope.py +++ b/python/sglang/jit_kernel/tests/test_fused_qknorm_rope.py @@ -122,7 +122,7 @@ def apply_interleave(x): q = apply_interleave(q) k = apply_interleave(k) else: - # NeoX style: first half * cos - second half * sin (and vice versa) + # NeoX style: first half × cos − second half × sin (and vice versa) def apply_neox(x): # x: [num_tokens, n_heads, head_dim] x1 = x[:, :, : rotary_dim // 2] @@ -231,7 
+231,7 @@ def test_fused_qknorm_rope_partial_rotary(head_dim, is_neox): # NeoX requires half_rotary_lanes to be power of 2. # half_rotary_lanes = rotary_dim / (head_dim / 32) / 2 = (head_dim//2) / (head_dim/32) / 2 - # = 16 / 2 = 8 -> power of 2, OK for all supported head_dims. + # = 16 / 2 = 8 → power of 2, OK for all supported head_dims. qkv = torch.randn( (num_tokens, total_heads * head_dim), dtype=torch.bfloat16, device=device diff --git a/python/sglang/srt/models/qwen3_moe.py b/python/sglang/srt/models/qwen3_moe.py index 010a73074759..912891b6a7eb 100644 --- a/python/sglang/srt/models/qwen3_moe.py +++ b/python/sglang/srt/models/qwen3_moe.py @@ -513,6 +513,7 @@ def __init__( self.compatible_with_fused_qk_norm_rope = not isinstance( self.rotary_emb, MRotaryEmbedding ) and self.head_dim in (64, 128, 256) + _yarn_factor, _, _, _ = compute_yarn_parameters(config) self.use_fused_qk_norm_rope = ( get_global_server_args().enable_fused_qk_norm_rope and self.compatible_with_fused_qk_norm_rope @@ -521,6 +522,7 @@ def __init__( self.head_dim, self.rotary_emb.is_neox_style, torch.bfloat16, + _yarn_factor != 1.0, ) ) self._used_fused_qk_norm_rope_last_call = False