diff --git a/csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu b/csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu index 2f2e7ecc1829..589e5f7bac04 100644 --- a/csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu +++ b/csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu @@ -87,6 +87,12 @@ constexpr int kScaleBytesPerToken = kNumQuantBlocks + 1; // 8 (7 real + 1 pad) constexpr int kTokenDataBytes = kNopeDim + kRopeDim * 2; // 448 + 128 = 576 constexpr float kFp8Max = 448.0f; +#ifndef USE_ROCM +// When num_tokens is less than this threshold, +// run the reduced grid variant on cuda +constexpr float NUM_TOKEN_CUTOFF = 1024; +#endif + // Per-warp layout: 32 lanes × 16 elems/lane = 512 elems = HEAD_DIM. constexpr int kNumLanes = 32; constexpr int kElemsPerLane = kHeadDim / kNumLanes; // 16 @@ -112,6 +118,197 @@ __device__ __forceinline__ float warpSum(float val) { return val; } +// ──────────────────────────────────────────────────────────────────────────── +// Per-slot inner pipeline +// ──────────────────────────────────────────────────────────────────────────── +// Shared by both kernel variants: 1 CTA per (token, head) pair vs. 1 CTA per +// token +template +__device__ __forceinline__ void processDeepseekV4Slot( + uint4 v0, uint4 v1, int const tokenIdx, int const slotIdx, + int const dim_base, int const laneId, int const num_heads_q, + float const eps, scalar_t_in* __restrict__ q_inout, + uint8_t* __restrict__ k_cache, int64_t const* __restrict__ slot_mapping, + int64_t const* __restrict__ position_ids, + float const* __restrict__ cos_sin_cache, int const cache_block_size, + int const kv_block_stride) { + using Converter = vllm::_typeConvert; + bool const isKV = (slotIdx == num_heads_q); + + // ── Decode the bf16 → 16 fp32 registers ───────────────────────────── + float elements[kElemsPerLane]; + { + typename Converter::packed_hip_type const* p0 = + reinterpret_cast(&v0); + typename Converter::packed_hip_type const* p1 = + reinterpret_cast(&v1); +#pragma unroll + for (int i = 0; i < 4; i++) { + float2 f2 = Converter::convert(p0[i]); + elements[2 * i] = f2.x; + elements[2 * i + 1] = f2.y; + } +#pragma unroll + for (int i = 0; i < 4; i++) { + float2 f2 = Converter::convert(p1[i]); + elements[8 + 2 * i] = f2.x; + elements[8 + 2 * i + 1] = f2.y; + } + } + + // ── Q branch: RMSNorm (no weight) ─────────────────────────────────── + if (!isKV) { + float sumOfSquares = 0.0f; +#pragma unroll + for (int i = 0; i < kElemsPerLane; i++) { + sumOfSquares += elements[i] * elements[i]; + } + sumOfSquares = warpSum(sumOfSquares); + float const rms_rcp = + rsqrtf(sumOfSquares / static_cast(kHeadDim) + eps); +#pragma unroll + for (int i = 0; i < kElemsPerLane; i++) { + elements[i] = elements[i] * rms_rcp; + } + } + + // ── GPT-J RoPE on dims [NOPE_DIM, HEAD_DIM) ───────────────────────────── + // All math in fp32. cos_sin_cache is loaded as fp32 (its native storage). + bool const is_rope_lane = dim_base >= kNopeDim; + if (is_rope_lane) { + int64_t const pos = position_ids[tokenIdx]; + constexpr int kHalfRope = kRopeDim / 2; + float const* cos_ptr = cos_sin_cache + pos * kRopeDim; + float const* sin_ptr = cos_ptr + kHalfRope; + + int const rope_local_base = dim_base - kNopeDim; + int const half_base = rope_local_base >> 1; + + // Load phase: 4 vectorized LDGs issue back-to-back. + float4 const c0 = *reinterpret_cast(cos_ptr + half_base); + float4 const c1 = *reinterpret_cast(cos_ptr + half_base + 4); + float4 const s0 = *reinterpret_cast(sin_ptr + half_base); + float4 const s1 = *reinterpret_cast(sin_ptr + half_base + 4); + float const cos_arr[8] = {c0.x, c0.y, c0.z, c0.w, c1.x, c1.y, c1.z, c1.w}; + float const sin_arr[8] = {s0.x, s0.y, s0.z, s0.w, s1.x, s1.y, s1.z, s1.w}; + +#pragma unroll + for (int p = 0; p < kElemsPerLane / 2; p++) { + float const x_even = elements[2 * p]; + float const x_odd = elements[2 * p + 1]; + elements[2 * p] = x_even * cos_arr[p] - x_odd * sin_arr[p]; + elements[2 * p + 1] = x_even * sin_arr[p] + x_odd * cos_arr[p]; + } + } + + // ═══════════════════════════════════════════════════════════════════ + // Q / KV branch dispatch. Restructured as if/else (no early `return`) + // so every code path lands at the same exit point — callers own PDL + // triggering and per-iteration buffer rotation. + // ═══════════════════════════════════════════════════════════════════ + if (!isKV) { + // ── Q: cast back to bf16 and store. ──────────────────────────── + uint4 out0, out1; + typename Converter::packed_hip_type* po0 = + reinterpret_cast(&out0); + typename Converter::packed_hip_type* po1 = + reinterpret_cast(&out1); +#pragma unroll + for (int i = 0; i < 4; i++) { + po0[i] = + Converter::convert(make_float2(elements[2 * i], elements[2 * i + 1])); + } +#pragma unroll + for (int i = 0; i < 4; i++) { + po1[i] = Converter::convert( + make_float2(elements[8 + 2 * i], elements[8 + 2 * i + 1])); + } + scalar_t_in* dst = + q_inout + + (static_cast(tokenIdx) * num_heads_q + slotIdx) * kHeadDim + + dim_base; + *reinterpret_cast(dst) = out0; + *reinterpret_cast(dst + 8) = out1; + } else { + // ── KV: FP8 quant on NoPE + bf16 store on RoPE + cache insert. + int64_t const slot_id = slot_mapping[tokenIdx]; + if (slot_id >= 0) { + int64_t const block_idx = slot_id / cache_block_size; + int64_t const pos_in_block = slot_id % cache_block_size; + uint8_t* block_base = + k_cache + block_idx * static_cast(kv_block_stride); + uint8_t* token_fp8_ptr = block_base + pos_in_block * kTokenDataBytes; + uint8_t* token_bf16_ptr = token_fp8_ptr + kNopeDim; + uint8_t* token_scale_ptr = + block_base + + static_cast(cache_block_size) * kTokenDataBytes + + pos_in_block * kScaleBytesPerToken; + +#pragma unroll + for (int i = 0; i < kElemsPerLane; i++) { + elements[i] = Converter::convert(Converter::convert(elements[i])); + } + + float local_absmax = 0.0f; +#pragma unroll + for (int i = 0; i < kElemsPerLane; i++) { + local_absmax = fmaxf(local_absmax, fabsf(elements[i])); + } + float const absmax = fmaxf(warp4MaxAbs(local_absmax), 1e-4f); + float const exponent = ceilf(log2f(absmax / kFp8Max)); + float const inv_scale = exp2f(-exponent); + + if (!is_rope_lane) { + uint8_t out_bytes[kElemsPerLane]; +#pragma unroll + for (int i = 0; i < kElemsPerLane; i++) { + float scaled = elements[i] * inv_scale; + scaled = fminf(fmaxf(scaled, -kFp8Max), kFp8Max); +#ifndef USE_ROCM + __nv_fp8_storage_t s = + __nv_cvt_float_to_fp8(scaled, __NV_SATFINITE, __NV_E4M3); + out_bytes[i] = static_cast(s); +#else + out_bytes[i] = rocm_cvt_float_to_fp8_e4m3(scaled); +#endif + } + *reinterpret_cast(token_fp8_ptr + dim_base) = + *reinterpret_cast(out_bytes); + + if ((laneId & 3) == 0) { + int const q_block_idx = laneId >> 2; + float encoded = fmaxf(fminf(exponent + 127.0f, 255.0f), 0.0f); + token_scale_ptr[q_block_idx] = static_cast(encoded); + } + if (laneId == 0) { + token_scale_ptr[kNumQuantBlocks] = 0; + } + } else { + uint4 out0, out1; + typename Converter::packed_hip_type* po0 = + reinterpret_cast(&out0); + typename Converter::packed_hip_type* po1 = + reinterpret_cast(&out1); +#pragma unroll + for (int i = 0; i < 4; i++) { + po0[i] = Converter::convert( + make_float2(elements[2 * i], elements[2 * i + 1])); + } +#pragma unroll + for (int i = 0; i < 4; i++) { + po1[i] = Converter::convert( + make_float2(elements[8 + 2 * i], elements[8 + 2 * i + 1])); + } + int const rope_local_base = dim_base - kNopeDim; + scalar_t_in* bf16_dst = + reinterpret_cast(token_bf16_ptr) + rope_local_base; + *reinterpret_cast(bf16_dst) = out0; + *reinterpret_cast(bf16_dst + 8) = out1; + } + } + } +} + // ──────────────────────────────────────────────────────────────────────────── // Kernel // ──────────────────────────────────────────────────────────────────────────── @@ -149,8 +346,6 @@ __global__ void fusedDeepseekV4QNormRopeKVRopeQuantInsertKernel( return; } else { #endif - using Converter = vllm::_typeConvert; - int const warpsPerBlock = blockDim.x / 32; int const warpId = threadIdx.x / 32; int const laneId = threadIdx.x % 32; @@ -176,10 +371,8 @@ __global__ void fusedDeepseekV4QNormRopeKVRopeQuantInsertKernel( // Dim range this lane owns within the 512-wide head. int const dim_base = laneId * kElemsPerLane; // in [0, 512) step 16 - // ── Load 16 bf16 → 16 fp32 registers (one 16-byte + one 16-byte LDG) ──── - float elements[kElemsPerLane]; - float sumOfSquares = 0.0f; - + // Two 16-byte loads per thread (8 bf16 each). Use uint4 as the vector + // type; the shared per-slot helper bitcasts to scalar_t_in packed pairs. scalar_t_in const* src_ptr; if (isKV) { src_ptr = kv_in + static_cast(tokenIdx) * kHeadDim + dim_base; @@ -189,196 +382,103 @@ __global__ void fusedDeepseekV4QNormRopeKVRopeQuantInsertKernel( dim_base; src_ptr = q_inout + q_row_offset; } + uint4 const v0 = *reinterpret_cast(src_ptr); + uint4 const v1 = *reinterpret_cast(src_ptr + 8); - // Two 16-byte loads per thread (8 bf16 each). Use uint4 as the vector - // type and bitcast to scalar_t_in packed pairs for conversion. - uint4 v0 = *reinterpret_cast(src_ptr); - uint4 v1 = *reinterpret_cast(src_ptr + 8); - - { - typename Converter::packed_hip_type const* p0 = - reinterpret_cast(&v0); - typename Converter::packed_hip_type const* p1 = - reinterpret_cast(&v1); -// Each packed_hip_type holds 2 bf16 → 4 packed = 8 elems per uint4. -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 f2 = Converter::convert(p0[i]); - elements[2 * i] = f2.x; - elements[2 * i + 1] = f2.y; - } -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 f2 = Converter::convert(p1[i]); - elements[8 + 2 * i] = f2.x; - elements[8 + 2 * i + 1] = f2.y; - } - } + processDeepseekV4Slot( + v0, v1, tokenIdx, slotIdx, dim_base, laneId, num_heads_q, eps, q_inout, + k_cache, slot_mapping, position_ids, cos_sin_cache, cache_block_size, + kv_block_stride); - // ── Q branch: RMSNorm with no weight (has_weight=False) ───────────────── - // Variance + rsqrt + multiply all in fp32, no intermediate bf16 round. - // The downstream bf16 round only happens at the final store. - if (!isKV) { -#pragma unroll - for (int i = 0; i < kElemsPerLane; i++) { - sumOfSquares += elements[i] * elements[i]; - } - sumOfSquares = warpSum(sumOfSquares); - float const rms_rcp = - rsqrtf(sumOfSquares / static_cast(kHeadDim) + eps); -#pragma unroll - for (int i = 0; i < kElemsPerLane; i++) { - elements[i] = elements[i] * rms_rcp; - } - } - - // ── GPT-J RoPE on dims [NOPE_DIM, HEAD_DIM) ───────────────────────────── - // All math in fp32. cos_sin_cache is loaded as fp32 (its native storage). - bool const is_rope_lane = dim_base >= kNopeDim; - if (is_rope_lane) { - int64_t const pos = position_ids[tokenIdx]; - constexpr int kHalfRope = kRopeDim / 2; // 32 - float const* cos_ptr = cos_sin_cache + pos * kRopeDim; - float const* sin_ptr = cos_ptr + kHalfRope; - - int const rope_local_base = dim_base - kNopeDim; // in [0, 64) step 16 -#pragma unroll - for (int p = 0; p < kElemsPerLane / 2; p++) { - int const pair_dim = rope_local_base + 2 * p; - int const half_idx = pair_dim / 2; - float const cos_v = VLLM_LDG(cos_ptr + half_idx); - float const sin_v = VLLM_LDG(sin_ptr + half_idx); - float const x_even = elements[2 * p]; - float const x_odd = elements[2 * p + 1]; - elements[2 * p] = x_even * cos_v - x_odd * sin_v; - elements[2 * p + 1] = x_even * sin_v + x_odd * cos_v; - } - } - - // ═══════════════════════════════════════════════════════════════════════ - // Q branch: cast to bf16 and store back in place. - // ═══════════════════════════════════════════════════════════════════════ - if (!isKV) { - uint4 out0, out1; - typename Converter::packed_hip_type* po0 = - reinterpret_cast(&out0); - typename Converter::packed_hip_type* po1 = - reinterpret_cast(&out1); -#pragma unroll - for (int i = 0; i < 4; i++) { - po0[i] = Converter::convert( - make_float2(elements[2 * i], elements[2 * i + 1])); - } -#pragma unroll - for (int i = 0; i < 4; i++) { - po1[i] = Converter::convert( - make_float2(elements[8 + 2 * i], elements[8 + 2 * i + 1])); - } - scalar_t_in* dst = - q_inout + - (static_cast(tokenIdx) * num_heads_q + slotIdx) * kHeadDim + - dim_base; - *reinterpret_cast(dst) = out0; - *reinterpret_cast(dst + 8) = out1; #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) - cudaTriggerProgrammaticLaunchCompletion(); + cudaTriggerProgrammaticLaunchCompletion(); #endif - return; - } - - // ═══════════════════════════════════════════════════════════════════════ - // KV branch. - // ═══════════════════════════════════════════════════════════════════════ - int64_t const slot_id = slot_mapping[tokenIdx]; - if (slot_id < 0) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) - cudaTriggerProgrammaticLaunchCompletion(); +#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800) && !defined(USE_ROCM) + } #endif - return; - } +} - int64_t const block_idx = slot_id / cache_block_size; - int64_t const pos_in_block = slot_id % cache_block_size; - uint8_t* block_base = - k_cache + block_idx * static_cast(kv_block_stride); - uint8_t* token_fp8_ptr = block_base + pos_in_block * kTokenDataBytes; - uint8_t* token_bf16_ptr = token_fp8_ptr + kNopeDim; - uint8_t* token_scale_ptr = - block_base + static_cast(cache_block_size) * kTokenDataBytes + - pos_in_block * kScaleBytesPerToken; - - // Round K to bf16 first, matching the unfused reference path where K is - // materialized as bf16 before K quantization. absmax, clamp, and FP8 - // quant below all run on these bf16-rounded values. -#pragma unroll - for (int i = 0; i < kElemsPerLane; i++) { - elements[i] = Converter::convert(Converter::convert(elements[i])); - } +// ──────────────────────────────────────────────────────────────────────────── +// Kernel +// ──────────────────────────────────────────────────────────────────────────── +// +// Grid: 1D, gridDim.x = num_tokens_full +// Block: blockDim.x = 256 threads (8 warps per block) Each +// warp handles one token, iterating over each head. +// Q branch (RMSNorm + RoPE, in place) head_slot == num_heads_q +// KV branch (RoPE + UE8M0 quant + insert) +// +template +__global__ void fusedDeepseekV4QNormRopeKVRopeQuantInsertKernelReducedGrid( + scalar_t_in* __restrict__ q_inout, // [N, H, 512] bf16, in place + scalar_t_in const* __restrict__ kv_in, uint8_t* __restrict__ k_cache, + int64_t const* __restrict__ slot_mapping, + int64_t const* __restrict__ position_ids, + float const* __restrict__ cos_sin_cache, float const eps, + int const num_tokens_full, int const num_tokens_insert, + int const num_heads_q, int const cache_block_size, + int const kv_block_stride) { +#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800) && !defined(USE_ROCM) + if constexpr (std::is_same_v) { + return; + } else { +#endif + int const warpsPerBlock = blockDim.x / 32; + int const warpId = threadIdx.x / 32; + int const laneId = threadIdx.x % 32; - // Per-quant-block absmax must be computed by ALL 32 lanes (warp-collective - // shuffle requires full participation). RoPE lanes contribute garbage, - // but their values are gated out below via `!is_rope_lane`. - float local_absmax = 0.0f; -#pragma unroll - for (int i = 0; i < kElemsPerLane; i++) { - local_absmax = fmaxf(local_absmax, fabsf(elements[i])); - } - float const absmax = fmaxf(warp4MaxAbs(local_absmax), 1e-4f); - float const exponent = ceilf(log2f(absmax / kFp8Max)); - float const inv_scale = exp2f(-exponent); + int const tokenIdx = blockIdx.x; + if (tokenIdx >= num_tokens_full) return; - if (!is_rope_lane) { - // ── NoPE lane: UE8M0 FP8 quant ─────────────────────────────────────── - uint8_t out_bytes[kElemsPerLane]; -#pragma unroll - for (int i = 0; i < kElemsPerLane; i++) { - float scaled = elements[i] * inv_scale; - scaled = fminf(fmaxf(scaled, -kFp8Max), kFp8Max); -#ifndef USE_ROCM - __nv_fp8_storage_t s = - __nv_cvt_float_to_fp8(scaled, __NV_SATFINITE, __NV_E4M3); - out_bytes[i] = static_cast(s); -#else - out_bytes[i] = rocm_cvt_float_to_fp8_e4m3(scaled); +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) + cudaGridDependencySynchronize(); #endif + + int const dim_base = laneId * kElemsPerLane; // in [0, 512) step 16 + int const slot_end = + (tokenIdx >= num_tokens_insert) ? num_heads_q : (num_heads_q + 1); + + auto src_for_slot = [&](int s) -> scalar_t_in const* { + if (s == num_heads_q) { + return kv_in + static_cast(tokenIdx) * kHeadDim + dim_base; } - // One 16-byte STG per lane. - *reinterpret_cast(token_fp8_ptr + dim_base) = - *reinterpret_cast(out_bytes); - - // Lane (4k) of each 4-lane group writes the scale byte for block k<7. - if ((laneId & 3) == 0) { - int const q_block_idx = laneId >> 2; // 0..6 for NoPE lanes - float encoded = fmaxf(fminf(exponent + 127.0f, 255.0f), 0.0f); - token_scale_ptr[q_block_idx] = static_cast(encoded); - } - // Lane 0 also writes the padding byte at index 7. - if (laneId == 0) { - token_scale_ptr[kNumQuantBlocks] = 0; // pad - } - } else { - // ── RoPE lane: cast back to bf16 and store to cache bf16 tail ──────── - uint4 out0, out1; - typename Converter::packed_hip_type* po0 = - reinterpret_cast(&out0); - typename Converter::packed_hip_type* po1 = - reinterpret_cast(&out1); -#pragma unroll - for (int i = 0; i < 4; i++) { - po0[i] = Converter::convert( - make_float2(elements[2 * i], elements[2 * i + 1])); - } -#pragma unroll - for (int i = 0; i < 4; i++) { - po1[i] = Converter::convert( - make_float2(elements[8 + 2 * i], elements[8 + 2 * i + 1])); - } - int const rope_local_base = dim_base - kNopeDim; // in [0, 64) - scalar_t_in* bf16_dst = - reinterpret_cast(token_bf16_ptr) + rope_local_base; - *reinterpret_cast(bf16_dst) = out0; - *reinterpret_cast(bf16_dst + 8) = out1; - } + return q_inout + + (static_cast(tokenIdx) * num_heads_q + + static_cast(s)) * + kHeadDim + + dim_base; + }; + + if (warpId < slot_end) { + int curr_slot = warpId; + scalar_t_in const* src_curr = src_for_slot(curr_slot); + uint4 v0_curr = *reinterpret_cast(src_curr); + uint4 v1_curr = *reinterpret_cast(src_curr + 8); + + while (curr_slot < slot_end) { + int const next_slot = curr_slot + warpsPerBlock; + bool const has_next = (next_slot < slot_end); + + // Prefetch src for the next slot + uint4 v0_next, v1_next; + if (has_next) { + scalar_t_in const* src_next = src_for_slot(next_slot); + v0_next = *reinterpret_cast(src_next); + v1_next = *reinterpret_cast(src_next + 8); + } + + processDeepseekV4Slot( + v0_curr, v1_curr, tokenIdx, curr_slot, dim_base, laneId, + num_heads_q, eps, q_inout, k_cache, slot_mapping, position_ids, + cos_sin_cache, cache_block_size, kv_block_stride); + + // ── Buffer rotation: hand the prefetched LDGs to the next iter. + v0_curr = v0_next; + v1_curr = v1_next; + curr_slot = next_slot; + } // while + } // if (warpId < slot_end) + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) cudaTriggerProgrammaticLaunchCompletion(); #endif @@ -430,11 +530,22 @@ void launchFusedDeepseekV4QNormRopeKVRopeQuantInsert( config.attrs = attrs; config.numAttrs = (sm_version >= 90) ? 1 : 0; - cudaLaunchKernelEx( - &config, fusedDeepseekV4QNormRopeKVRopeQuantInsertKernel, - q_inout, kv_in, k_cache, slot_mapping, position_ids, cos_sin_cache, eps, - num_tokens_full, num_tokens_insert, num_heads_q, cache_block_size, - kv_block_stride); + if (num_tokens_full < NUM_TOKEN_CUTOFF) { + cudaLaunchKernelEx( + &config, fusedDeepseekV4QNormRopeKVRopeQuantInsertKernel, + q_inout, kv_in, k_cache, slot_mapping, position_ids, cos_sin_cache, eps, + num_tokens_full, num_tokens_insert, num_heads_q, cache_block_size, + kv_block_stride); + } else { + config.gridDim = dim3(num_tokens_full); + cudaLaunchKernelEx( + &config, + fusedDeepseekV4QNormRopeKVRopeQuantInsertKernelReducedGrid, + q_inout, kv_in, k_cache, slot_mapping, position_ids, cos_sin_cache, eps, + num_tokens_full, num_tokens_insert, num_heads_q, cache_block_size, + kv_block_stride); + } + #else // ROCm: use standard kernel launch syntax (no PDL/stream serialization) // clang-format off @@ -508,4 +619,4 @@ void fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert( num_tokens_full, num_tokens_insert, num_heads_q, cache_block_size_i, kv_block_stride, stream); }); -} +} \ No newline at end of file diff --git a/tests/kernels/test_fused_deepseek_v4_qnorm_rope_kv_insert.py b/tests/kernels/test_fused_deepseek_v4_qnorm_rope_kv_insert.py index 9706778ac86f..13010540d973 100644 --- a/tests/kernels/test_fused_deepseek_v4_qnorm_rope_kv_insert.py +++ b/tests/kernels/test_fused_deepseek_v4_qnorm_rope_kv_insert.py @@ -67,29 +67,26 @@ def apply_rope_gptj_last_k( head_dim = x.shape[-1] nope_dim = head_dim - rope_dim - # Gather cos/sin for each token position: [num_tokens, rope_dim] - cs = cos_sin_cache[positions].to(torch.float32) # [N, rope_dim] - cos = cs[..., :half] # [N, half] - sin = cs[..., half:] # [N, half] - - # Reshape leading dims so we can broadcast: x shape [..., head_dim]. - # Bring token dim to front; assume x is [num_tokens, ..., head_dim]. - # We rely on positions being per-token and all other dims sharing the same pos. - rope = x[..., nope_dim:].float() # [..., rope_dim] - # Make rope pairs: reshape last dim to [half, 2] + cs = cos_sin_cache[positions].to(torch.float32) + cos = cs[..., :half] + sin = cs[..., half:] + + rope = x[..., nope_dim:].float() shape = rope.shape rope = rope.reshape(*shape[:-1], half, 2) - even = rope[..., 0] # [..., half] + even = rope[..., 0] odd = rope[..., 1] - # Broadcast cos/sin over any heads dim in between. cos/sin are [N, half]. - # Add singleton dims for intermediate axes. for _ in range(rope.ndim - 3): cos = cos.unsqueeze(1) sin = sin.unsqueeze(1) - new_even = even * cos - odd * sin - new_odd = even * sin + odd * cos + # Use addcmul (compiles to FMA on CUDA) for the 2x2 rotation. nvcc lowers + # the kernel's `e*c - o*s` to fma(e, c, -o*s); matching that here keeps + # near-cancellation pairs on the same bf16 grid as the kernel output and + # avoids spurious 1-ULP boundary flips at high num_tokens. + new_even = torch.addcmul(-odd * sin, even, cos) + new_odd = torch.addcmul(odd * cos, even, sin) rope_rotated = torch.stack((new_even, new_odd), dim=-1).reshape(shape) out = x.clone().float() @@ -99,11 +96,15 @@ def apply_rope_gptj_last_k( def rmsnorm_no_weight(x: torch.Tensor, eps: float) -> torch.Tensor: """RMSNorm with no learnable weight, matching - `RMSNorm(head_dim, has_weight=False)`.""" - orig_dtype = x.dtype + `RMSNorm(head_dim, has_weight=False)`. + + Returns fp32 so callers can chain RoPE without an intermediate bf16 round + (the kernel keeps the whole RMSNorm→RoPE pipeline in fp32 and rounds once + at the final store). + """ xf = x.float() variance = xf.pow(2).mean(dim=-1, keepdim=True) - return (xf * torch.rsqrt(variance + eps)).to(orig_dtype) + return xf * torch.rsqrt(variance + eps) # ── Dispatch to the CUDA op (skip test cleanly if it isn't built in) ───────── @@ -128,7 +129,7 @@ def _call_fused(q, kv, k_cache, slot_mapping, positions, cos_sin_cache, eps, bs) # ── Test 1: Q path numerical parity ────────────────────────────────────────── -@pytest.mark.parametrize("num_tokens", [1, 4, 17, 64]) +@pytest.mark.parametrize("num_tokens", [1, 4, 17, 64, 2048]) @pytest.mark.parametrize("n_heads", [8, 64]) def test_q_path_matches_reference(num_tokens: int, n_heads: int): torch.manual_seed(0) @@ -142,8 +143,10 @@ def test_q_path_matches_reference(num_tokens: int, n_heads: int): cos_sin_cache = make_cos_sin_cache(max_pos, ROPE_DIM, torch.float32, device) # Reference: RMSNorm (no weight) per head, then GPT-J RoPE on last 64. + # Keep the chain in fp32 (rmsnorm_no_weight returns fp32) and round to + # bf16 once at the end, matching the kernel. q_ref = rmsnorm_no_weight(q, eps) - q_ref = apply_rope_gptj_last_k(q_ref, positions, cos_sin_cache) + q_ref = apply_rope_gptj_last_k(q_ref, positions, cos_sin_cache).to(dtype) # Fused call with dummy KV tensors (KV branch will write slot_mapping=-1 → noop). num_blocks = 2 @@ -173,7 +176,7 @@ def _ue8m0_per_block_scales(kv_roped_nope_f32: torch.Tensor, qblock: int): return torch.pow(2.0, exponent) # [n_tok, n_blocks] -@pytest.mark.parametrize("num_tokens", [1, 4, 17, 64]) +@pytest.mark.parametrize("num_tokens", [1, 4, 17, 64, 2048]) @pytest.mark.parametrize("block_size", [16, 64]) def test_kv_path_matches_reference(num_tokens: int, block_size: int): torch.manual_seed(1) @@ -261,7 +264,7 @@ def _dequant(k_cache_2d): # ── Test 2b: DP padding (slot_mapping shorter than q/kv) ───────────────────── -@pytest.mark.parametrize("num_tokens", [4, 17]) +@pytest.mark.parametrize("num_tokens", [4, 17, 2048]) @pytest.mark.parametrize("pad", [1, 5]) @pytest.mark.parametrize("block_size", [16, 64]) def test_kv_path_with_dp_padding(num_tokens: int, pad: int, block_size: int): @@ -312,7 +315,7 @@ def test_kv_path_with_dp_padding(num_tokens: int, pad: int, block_size: int): # ── Test 3: combined single-call Q + KV parity ─────────────────────────────── -@pytest.mark.parametrize("num_tokens", [1, 4, 17]) +@pytest.mark.parametrize("num_tokens", [1, 4, 17, 2048]) @pytest.mark.parametrize("n_heads", [8, 64]) @pytest.mark.parametrize("block_size", [16, 64]) def test_combined_q_and_kv(num_tokens: int, n_heads: int, block_size: int): @@ -332,7 +335,7 @@ def test_combined_q_and_kv(num_tokens: int, n_heads: int, block_size: int): # Reference. q_ref = rmsnorm_no_weight(q, eps) - q_ref = apply_rope_gptj_last_k(q_ref, positions, cos_sin_cache) + q_ref = apply_rope_gptj_last_k(q_ref, positions, cos_sin_cache).to(dtype) kv_ref = apply_rope_gptj_last_k(kv, positions, cos_sin_cache) k_cache_ref = torch.zeros( num_blocks, block_size * HEAD_BYTES, dtype=torch.uint8, device=device