diff --git a/benchmark/kernels/bench_flashmla_fused_kv.py b/benchmark/kernels/bench_flashmla_fused_kv.py new file mode 100644 index 00000000000..202e896c466 --- /dev/null +++ b/benchmark/kernels/bench_flashmla_fused_kv.py @@ -0,0 +1,114 @@ +# -*- coding: utf-8 -*- +""" +Microbenchmark: fused vs baseline (emulated) for MLA RoPE + FP8 + KV write. +Uses the sgl_kernel.mla_rope_quantize_fp8_fused extension. +""" +import time + +import torch +from sgl_kernel import mla_rope_quantize_fp8_fused + + +def run_one(nnz=1024, Dn=512, Dr=64, iters=200, warmup=20, device="cuda"): + torch.manual_seed(0) + + q_nope = torch.randn(nnz, Dn, device=device, dtype=torch.float16) + q_rope = torch.randn(nnz, Dr, device=device, dtype=torch.float16) + k_nope = torch.randn(nnz, Dn, device=device, dtype=torch.float16) + k_rope = torch.randn(nnz, Dr, device=device, dtype=torch.float16) + + max_seq = max(2048, nnz) + t = torch.linspace(0, 1, steps=max_seq, device=device, dtype=torch.float32)[:, None] + idx = torch.arange(Dr, device=device, dtype=torch.float32)[None, :] + freqs = 0.1 * (idx + 1.0) + cos = torch.cos(t * freqs) + sin = torch.sin(t * freqs) + cos_sin = torch.cat([cos, sin], dim=1) + pos_ids = torch.randint( + low=0, high=max_seq, size=(nnz,), device=device, dtype=torch.long + ) + + slots = nnz + 8 + loc = torch.arange(nnz, device=device, dtype=torch.long) + + q_out = torch.empty(nnz, Dn + Dr, device=device, dtype=torch.uint8) + k_nope_out = torch.empty(nnz, Dn, device=device, dtype=torch.uint8) + k_rope_out = torch.empty(nnz, Dr, device=device, dtype=torch.uint8) + kv_base = torch.zeros(slots, Dn + Dr, device=device, dtype=torch.uint8) + + # baselines + def baseline(): + mla_rope_quantize_fp8_fused( + q_nope, + q_rope, + k_nope, + k_rope, + cos_sin, + pos_ids, + False, + q_out, + k_nope_out, + k_rope_out, + None, + None, + ) + kv_base.zero_() + kv_base[loc, :Dn] = k_nope_out + kv_base[loc, Dn:] = k_rope_out + + kv_fused = torch.zeros_like(kv_base) + + def fused(): + mla_rope_quantize_fp8_fused( + q_nope, + q_rope, + k_nope, + k_rope, + cos_sin, + pos_ids, + False, + q_out, + None, + None, + kv_fused, + loc, + ) + + # warmup + for _ in range(warmup): + baseline() + torch.cuda.synchronize() + t0 = time.time() + for _ in range(iters): + baseline() + torch.cuda.synchronize() + t1 = time.time() + baseline_ms = (t1 - t0) * 1000.0 / iters + + for _ in range(warmup): + fused() + torch.cuda.synchronize() + t0 = time.time() + for _ in range(iters): + fused() + torch.cuda.synchronize() + t1 = time.time() + fused_ms = (t1 - t0) * 1000.0 / iters + + return baseline_ms, fused_ms, baseline_ms / fused_ms + + +if __name__ == "__main__": + print("MLA RoPE + FP8 Quantization + KV Cache Write Fusion Benchmark") + print("=" * 70) + print("Config: Dn=512, Dr=64, iters=1000, warmup=100") + print("=" * 70) + + # Test larger batch sizes and more iterations for stable measurements + for nnz in [1024, 4096, 8192, 16384, 32768]: + b, f, s = run_one(nnz=nnz, iters=1000, warmup=100) + if b > 0: + speedup_pct = (s - 1.0) * 100 + print( + f"nnz={nnz:5d} | baseline={b:7.3f} ms | fused={f:7.3f} ms | speedup x{s:4.2f} ({speedup_pct:+5.1f}%)" + ) diff --git a/python/sglang/srt/layers/attention/trtllm_mla_backend.py b/python/sglang/srt/layers/attention/trtllm_mla_backend.py index 65ae9c4dc6e..657d7cfe7e9 100755 --- a/python/sglang/srt/layers/attention/trtllm_mla_backend.py +++ b/python/sglang/srt/layers/attention/trtllm_mla_backend.py @@ -39,6 +39,11 @@ if _is_cuda: from sgl_kernel import concat_mla_absorb_q + try: + from sgl_kernel import 
mla_rope_quantize_fp8_fused + except ImportError: + mla_rope_quantize_fp8_fused = None # Will use non-fused path + # Constants DEFAULT_WORKSPACE_SIZE_MB = 128 # Memory workspace size in MB @@ -634,13 +639,19 @@ def quantize_and_rope_for_fp8( forward_batch: ForwardBatch, cos_sin_cache: torch.Tensor, is_neox: bool, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + *, + layer: Optional["RadixAttention"] = None, + save_kv_cache: bool = False, + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: """Quantize and apply RoPE for FP8 attention path. This function handles the FP8 quantization and RoPE application for MLA attention. It takes separate query/key nope and rope components, applies RoPE to the rope parts, quantizes all components to FP8, and merges the query components into a single tensor. + When layer and save_kv_cache are provided, it uses the fused kernel to directly write + K into the KV cache, eliminating intermediate global memory operations. + Args: q_nope: Query no-position-encoding component [seq_len, num_heads, kv_lora_rank] - expected dtype: torch.bfloat16 @@ -654,12 +665,14 @@ def quantize_and_rope_for_fp8( cos_sin_cache: Precomputed cosine/sine cache for RoPE - expected dtype: matches q_/k_ input dtype (torch.bfloat16) is_neox: Whether to use NeoX-style RoPE (interleaved) or GPT-style (half rotation) + layer: Optional RadixAttention layer for fused KV cache write + save_kv_cache: If True and layer is provided, use fused kernel to directly write KV cache Returns: tuple: (merged_q_out, k_nope_out, k_rope_out) quantized to FP8 - merged_q_out: [seq_len, num_heads, kv_lora_rank + qk_rope_head_dim], dtype=torch.float8_e4m3fn - - k_nope_out: [seq_len, num_heads, kv_lora_rank], dtype=torch.float8_e4m3fn - - k_rope_out: [seq_len, num_heads, qk_rope_head_dim], dtype=torch.float8_e4m3fn + - k_nope_out: [seq_len, num_heads, kv_lora_rank], dtype=torch.float8_e4m3fn or None (if fused) + - k_rope_out: [seq_len, num_heads, qk_rope_head_dim], dtype=torch.float8_e4m3fn or None (if fused) """ attn_dtype = torch.float8_e4m3fn q_len, num_heads = q_rope.shape[0], q_rope.shape[1] @@ -673,35 +686,82 @@ def quantize_and_rope_for_fp8( dtype=attn_dtype, ) - # Key outputs maintain original shapes but with FP8 dtype - k_rope_out = k_rope.new_empty(k_rope.shape, dtype=attn_dtype) - k_nope_out = k_nope.new_empty(k_nope.shape, dtype=attn_dtype) - - # Apply RoPE and quantize all components in a single fused kernel call - # This kernel handles: - # 1. RoPE application to q_rope and k_rope using cos_sin_cache and positions - # 2. Quantization of all components to FP8 format - # 3. 
Output placement into pre-allocated tensors - flashinfer.rope.mla_rope_quantize_fp8( - q_rope=q_rope, - k_rope=k_rope, - q_nope=q_nope, - k_nope=k_nope, - cos_sin_cache=cos_sin_cache, - pos_ids=forward_batch.positions, - is_neox=is_neox, - quantize_dtype=attn_dtype, - # Output tensor slicing: q_out contains [nope_part, rope_part] - q_rope_out=q_out[..., self.kv_lora_rank :], # RoPE part goes to end - k_rope_out=k_rope_out, - q_nope_out=q_out[..., : self.kv_lora_rank], # Nope part goes to beginning - k_nope_out=k_nope_out, - # Quantization scales (set to 1.0 for no additional scaling) - quant_scale_q=1.0, - quant_scale_kv=1.0, + # Determine if we can use the fused kernel (RoPE + quantization + direct KV write) + # Conditions: save_kv_cache, layer provided, not NSA packed format, kernel available + compute_cap = torch.cuda.get_device_capability() + kernel_available = mla_rope_quantize_fp8_fused is not None + has_layer = layer is not None + nsa_flag = getattr( + forward_batch.token_to_kv_pool, "nsa_kv_cache_store_fp8", False + ) + + # Only enable fusion on SM90+ (H100/B200) due to FP8 hardware support + use_fused_kv_write = ( + kernel_available + and (compute_cap[0] >= 9) + and save_kv_cache + and has_layer + and not nsa_flag ) - return q_out, k_nope_out, k_rope_out + if use_fused_kv_write: + # Fused path: RoPE + quantization + direct KV cache write + # This eliminates the intermediate write/read of k_nope_out/k_rope_out + + # Get KV buffer (supports both 2D and 3D formats) + kv_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id) + + # Prepare cache locations and positions as int64 contiguous + kv_loc = forward_batch.out_cache_loc.to(torch.int64).contiguous() + positions = forward_batch.positions.to(torch.int64).contiguous() + + # Call fused kernel: RoPE + quantize + write KV cache + mla_rope_quantize_fp8_fused( + q_nope, + q_rope, + k_nope, + k_rope, + cos_sin_cache, + positions, + is_neox, + q_out, + None, # k_nope_out - skip intermediate output + None, # k_rope_out - skip intermediate output + kv_buffer, + kv_loc, + ) + + # Return Q output and None for K outputs (already written to cache) + return q_out, None, None + else: + # Standard path: RoPE + quantization only (backward compatible) + # Key outputs maintain original shapes but with FP8 dtype + k_rope_out = k_rope.new_empty(k_rope.shape, dtype=attn_dtype) + k_nope_out = k_nope.new_empty(k_nope.shape, dtype=attn_dtype) + + # Apply RoPE and quantize all components in a single kernel call + flashinfer.rope.mla_rope_quantize_fp8( + q_rope=q_rope, + k_rope=k_rope, + q_nope=q_nope, + k_nope=k_nope, + cos_sin_cache=cos_sin_cache, + pos_ids=forward_batch.positions, + is_neox=is_neox, + quantize_dtype=attn_dtype, + # Output tensor slicing: q_out contains [nope_part, rope_part] + q_rope_out=q_out[..., self.kv_lora_rank :], # RoPE part goes to end + k_rope_out=k_rope_out, + q_nope_out=q_out[ + ..., : self.kv_lora_rank + ], # Nope part goes to beginning + k_nope_out=k_nope_out, + # Quantization scales (set to 1.0 for no additional scaling) + quant_scale_q=1.0, + quant_scale_kv=1.0, + ) + + return q_out, k_nope_out, k_rope_out def pad_draft_extend_query( self, @@ -798,12 +858,17 @@ def forward_decode( ) -> torch.Tensor: """Run forward for decode using TRTLLM MLA kernel.""" merge_query = q_rope is not None + fused_kv = False # Track if using fused KV write path + if self.data_type == torch.float8_e4m3fn: # For FP8 path, we quantize the query and rope parts and merge them into a single tensor # Note: rope application in 
deepseek_v2.py:forward_absorb_prepare is skipped for FP8 decode path of this trtllm_mla backend assert all( x is not None for x in [q_rope, k_rope, cos_sin_cache] ), "For FP8 path and using flashinfer.rope.mla_rope_quantize we need all of q_rope, k_rope and cos_sin_cache to be not None." + + # Call quantize_and_rope_for_fp8 with layer and save_kv_cache + # This enables the fused kernel path when conditions are met q, k, k_rope = self.quantize_and_rope_for_fp8( q, q_rope, @@ -812,11 +877,15 @@ def forward_decode( forward_batch, cos_sin_cache, is_neox, + layer=layer, + save_kv_cache=save_kv_cache, ) merge_query = False + # Check if fused path was used (K outputs are None when directly written to KV cache) + fused_kv = k is None and k_rope is None - # Save KV cache if requested - if save_kv_cache: + # Save KV cache if requested (only if not using fused path) + if save_kv_cache and not fused_kv: assert ( k is not None and k_rope is not None ), "For populating trtllm_mla kv cache, both k_nope and k_rope should be not None." @@ -899,15 +968,22 @@ def forward_extend( ) -> torch.Tensor: # TODO refactor to avoid code duplication merge_query = q_rope is not None - if ( + fused_kv = False # Track if we used fused KV write path + + # FP8 quantize path: only enable when all required inputs are present + # If any parameter is None, gracefully fall back to standard path + # TODO: Verify removing explicit assert doesn't mask bugs where params should be present but are None + use_fp8_quantize = ( self.data_type == torch.float8_e4m3fn - ) and forward_batch.forward_mode.is_target_verify(): - # For FP8 path, we quantize the query and rope parts and merge them into a single tensor - # Note: rope application in deepseek_v2.py:forward_absorb_prepare is skipped for FP8 decode path of this trtllm_mla backend - assert all( - x is not None for x in [q_rope, k_rope, cos_sin_cache] - ), "For FP8 path and using flashinfer.rope.mla_rope_quantize we need all of q_rope, k_rope and cos_sin_cache to be not None." - q, k, k_rope = self.quantize_and_rope_for_fp8( + and ( + forward_batch.forward_mode.is_target_verify() + or forward_batch.forward_mode.is_draft_extend(include_v2=True) + ) + and all(x is not None for x in [q_rope, k_rope, cos_sin_cache]) + ) + + if use_fp8_quantize: + q, k_fp8_nope, k_fp8_rope = self.quantize_and_rope_for_fp8( q, q_rope, k.squeeze(1), @@ -915,11 +991,17 @@ def forward_extend( forward_batch, cos_sin_cache, is_neox, + layer=layer, + save_kv_cache=save_kv_cache, ) merge_query = False + # Fused path: K already written to KV cache, function returns None for K outputs + fused_kv = k_fp8_nope is None and k_fp8_rope is None + k = k_fp8_nope + k_rope = k_fp8_rope - # Save KV cache if requested - if save_kv_cache: + # Save KV cache if requested (only if not using fused path) + if save_kv_cache and not fused_kv: assert ( k is not None and k_rope is not None ), "For populating trtllm_mla kv cache, both k_nope and k_rope should be not None." 
@@ -989,7 +1071,6 @@ def forward_extend( q, padded_q, metadata.seq_lens_q, metadata.cu_seqlens_q ) - # TODO may use `mla_rope_quantize_fp8` fusion q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim) assert kv_cache.dtype == self.data_type @@ -1018,11 +1099,13 @@ def forward_extend( output = raw_out.view(-1, layer.tp_q_head_num * layer.v_head_dim) return output - if k_rope is not None: - k = torch.cat([k, k_rope], dim=-1) - k = k.view(-1, layer.tp_k_head_num, layer.head_dim) - - v = v.view(-1, layer.tp_k_head_num, layer.v_head_dim) + # Only prepare k/v tensors if NOT using fused path (for MHA prefill) + # In fused path, K is already in KV cache and we use decode kernel + if not fused_kv: + if k_rope is not None: + k = torch.cat([k, k_rope], dim=-1) + k = k.view(-1, layer.tp_k_head_num, layer.head_dim) + v = v.view(-1, layer.tp_k_head_num, layer.v_head_dim) if forward_batch.attn_attend_prefix_cache: # MHA for chunked prefix kv cache when running model with MLA diff --git a/sgl-kernel/CMakeLists.txt b/sgl-kernel/CMakeLists.txt index 0c8bb02a486..553b849e9f2 100644 --- a/sgl-kernel/CMakeLists.txt +++ b/sgl-kernel/CMakeLists.txt @@ -287,6 +287,7 @@ set(SOURCES "csrc/elementwise/concat_mla.cu" "csrc/elementwise/copy.cu" "csrc/elementwise/fused_add_rms_norm_kernel.cu" + "csrc/elementwise/mla_rope_fp8_kv_fused.cu" "csrc/elementwise/rope.cu" "csrc/elementwise/topk.cu" "csrc/expert_specialization/es_fp8_blockwise.cu" diff --git a/sgl-kernel/csrc/common_extension.cc b/sgl-kernel/csrc/common_extension.cc index 03a7ec0151f..bca10cfe91a 100644 --- a/sgl-kernel/csrc/common_extension.cc +++ b/sgl-kernel/csrc/common_extension.cc @@ -94,6 +94,12 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { "Tensor? v, Tensor!? k_buffer, Tensor!? v_buffer, Tensor? kv_cache_loc) -> ()"); m.impl("apply_rope_pos_ids_cos_sin_cache", torch::kCUDA, &apply_rope_pos_ids_cos_sin_cache); + m.def( + "mla_rope_quantize_fp8_fused(Tensor q_nope, Tensor q_rope, Tensor k_nope, Tensor k_rope, " + "Tensor cos_sin_cache, Tensor pos_ids, bool is_neox, Tensor! q_out, " + "Tensor!? k_nope_out, Tensor!? k_rope_out, Tensor? kv_buffer, Tensor? kv_cache_loc) -> ()"); + m.impl("mla_rope_quantize_fp8_fused", torch::kCUDA, &mla_rope_quantize_fp8_fused); + m.def( "downcast_fp8(Tensor k, Tensor v, Tensor k_out, Tensor v_out, Tensor k_scale, Tensor v_scale, Tensor loc, " "int mult, int offset) -> ()"); diff --git a/sgl-kernel/csrc/elementwise/mla_rope_fp8_kv_fused.cu b/sgl-kernel/csrc/elementwise/mla_rope_fp8_kv_fused.cu new file mode 100644 index 00000000000..845397bb741 --- /dev/null +++ b/sgl-kernel/csrc/elementwise/mla_rope_fp8_kv_fused.cu @@ -0,0 +1,684 @@ +/* Copyright 2024 SGLang Team. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+// MLA RoPE + FP8 Quantization + KV Cache Write Fusion Kernel
+// Fuses RoPE application, FP8 quantization, and direct KV cache write
+
+#ifdef TORCH_EXTENSION_NAME
+#include <torch/extension.h>
+#else
+#include <ATen/ATen.h>
+#include <c10/cuda/CUDAStream.h>
+#endif
+
+#include <cstdint>
+#include <cuda_bf16.h>
+#include <cuda_fp16.h>
+#include <cuda_runtime.h>
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+#include <cuda_fp8.h>
+#endif
+
+// TODO: Use pytorch_extension_utils.h when it's available in sgl-kernel/include
+#define CHECK_INPUT(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_DIM(d, x) TORCH_CHECK(x.dim() == d, #x " must be " #d "D")
+#define CHECK_EQ(a, b) TORCH_CHECK((a) == (b), #a " must equal " #b)
+#define CHECK_LAST_DIM_CONTIGUOUS(x) TORCH_CHECK(x.stride(x.dim() - 1) == 1, #x " last dim must be contiguous")
+
+namespace {
+
+template <typename T>
+struct Vec2Traits;
+
+template <>
+struct Vec2Traits<__half> {
+  using v2 = __half2;
+  __device__ static inline float2 to_float2(v2 h2) {
+    return __half22float2(h2);
+  }
+  __device__ static inline float to_float(const __half& h) {
+    return __half2float(h);
+  }
+};
+
+template <>
+struct Vec2Traits<nv_bfloat16> {
+  using v2 = nv_bfloat162;
+  __device__ static inline float2 to_float2(v2 h2) {
+    return __bfloat1622float2(h2);
+  }
+  __device__ static inline float to_float(const nv_bfloat16& h) {
+    return __bfloat162float(h);
+  }
+};
+
+__device__ inline uint8_t float_to_e4m3fn_byte(float x) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
+  __nv_fp8_storage_t byte = __nv_cvt_float_to_fp8(x, __NV_SATFINITE, __NV_E4M3);
+  return static_cast<uint8_t>(byte);
+#else
+  x = fmaxf(-448.0f, fminf(448.0f, x));
+  union {
+    float f;
+    uint32_t u;
+  } conv;
+  conv.f = x;
+  uint32_t sign = (conv.u >> 31) & 0x1;
+  if (x == 0.0f) return 0;
+  int exp = ((conv.u >> 23) & 0xFF) - 127;
+  exp = max(-6, min(8, exp));
+  uint32_t mant = (conv.u >> 20) & 0x7;
+  uint8_t result = (sign << 7) | ((exp + 7) << 3) | mant;
+  return result;
+#endif
+}
+
+__device__ inline uint32_t pack4(uint8_t a0, uint8_t a1, uint8_t a2, uint8_t a3) {
+  return (uint32_t)a0 | ((uint32_t)a1 << 8) | ((uint32_t)a2 << 16) | ((uint32_t)a3 << 24);
+}
+
+__device__ inline void rope_rotate(float& xr, float& xi, float c, float s) {
+  float xr_new = xr * c - xi * s;
+  float xi_new = xr * s + xi * c;
+  xr = xr_new;
+  xi = xi_new;
+}
+
+template <typename T, int WARPS_PER_CTA>
+__global__ void FusedRopeQuantizeKernelVec(
+    const T* __restrict__ q_nope,
+    const T* __restrict__ q_rope,
+    int64_t qn_stride_tok,
+    int64_t qn_stride_head,
+    int64_t qr_stride_tok,
+    int64_t qr_stride_head,
+    const T* __restrict__ k_nope,
+    const T* __restrict__ k_rope,
+    int64_t kn_stride_tok,
+    int64_t kr_stride_tok,
+    const float* __restrict__ cos_sin,
+    const int64_t* __restrict__ pos_ids,
+    int nnz,
+    int num_heads,
+    int Dn,
+    int Dr,
+    bool is_neox,
+    uint8_t* __restrict__ q_out_fp8,
+    int64_t qout_stride_tok_bytes,
+    int64_t qout_stride_head_bytes,
+    uint8_t* __restrict__ k_nope_out_fp8,
+    uint8_t* __restrict__ k_rope_out_fp8,
+    uint8_t* __restrict__ kv_buffer_bytes,
+    int64_t kv_stride_row_bytes,
+    const int64_t* __restrict__ kv_cache_loc) {
+  constexpr int WARP_SIZE = 32;
+  int warp_in_block = threadIdx.x / WARP_SIZE;
+  int lane = threadIdx.x & (WARP_SIZE - 1);
+
+  int global_row = blockIdx.x * WARPS_PER_CTA + warp_in_block;
+  if (global_row >= nnz * num_heads) return;
+
+  int token_id = global_row / num_heads;
+  int head_id = global_row % num_heads;
+
+  const T* qn = q_nope + size_t(token_id) * qn_stride_tok + size_t(head_id) * qn_stride_head;
+  const T* qr = q_rope + size_t(token_id) * qr_stride_tok + size_t(head_id) * qr_stride_head;
+  const T* kn = k_nope + size_t(token_id) * kn_stride_tok;
+  const T* kr = k_rope + size_t(token_id) * kr_stride_tok;
+
+  uint8_t* qdst = q_out_fp8 + size_t(token_id) * qout_stride_tok_bytes + size_t(head_id) * qout_stride_head_bytes;
+  uint8_t* kndst = k_nope_out_fp8 ? (k_nope_out_fp8 + size_t(token_id) * Dn) : nullptr;
+  uint8_t* krdst = k_rope_out_fp8 ? (k_rope_out_fp8 + size_t(token_id) * Dr) : nullptr;
+
+  int pos = static_cast<int>(pos_ids[token_id]);
+
+  uint8_t* kvdst = nullptr;
+  if (kv_buffer_bytes && kv_cache_loc) {
+    int64_t flat_row = kv_cache_loc[token_id];
+    kvdst = kv_buffer_bytes + flat_row * kv_stride_row_bytes;
+  }
+
+  const float* cos_ptr = cos_sin + size_t(pos) * Dr;
+  const float* sin_ptr = cos_ptr + (Dr / 2);
+
+  using V2 = typename Vec2Traits<T>::v2;
+
+  // Process Q_nope: vectorized quantize + write
+  for (int c = lane * 4; c < Dn; c += WARP_SIZE * 4) {
+    V2 h0 = *reinterpret_cast<const V2*>(qn + c + 0);
+    V2 h1 = *reinterpret_cast<const V2*>(qn + c + 2);
+    float2 f0 = Vec2Traits<T>::to_float2(h0);
+    float2 f1 = Vec2Traits<T>::to_float2(h1);
+
+    uint32_t packed = pack4(
+        float_to_e4m3fn_byte(f0.x), float_to_e4m3fn_byte(f0.y), float_to_e4m3fn_byte(f1.x), float_to_e4m3fn_byte(f1.y));
+    *reinterpret_cast<uint32_t*>(qdst + c) = packed;
+  }
+
+  for (int c = lane * 4; c < Dr; c += WARP_SIZE * 4) {
+    V2 h0 = *reinterpret_cast<const V2*>(qr + c + 0);
+    V2 h1 = *reinterpret_cast<const V2*>(qr + c + 2);
+    float2 f0 = Vec2Traits<T>::to_float2(h0);
+    float2 f1 = Vec2Traits<T>::to_float2(h1);
+
+    int base0 = (c + 0) >> 1;
+    int base1 = (c + 2) >> 1;
+    float c0 = cos_ptr[base0], s0 = sin_ptr[base0];
+    float c1 = cos_ptr[base1], s1 = sin_ptr[base1];
+
+    rope_rotate(f0.x, f0.y, c0, s0);
+    rope_rotate(f1.x, f1.y, c1, s1);
+
+    uint32_t packed = pack4(
+        float_to_e4m3fn_byte(f0.x), float_to_e4m3fn_byte(f0.y), float_to_e4m3fn_byte(f1.x), float_to_e4m3fn_byte(f1.y));
+    *reinterpret_cast<uint32_t*>(qdst + Dn + c) = packed;
+  }
+
+  if (head_id == 0) {
+    for (int c = lane * 4; c < Dn; c += WARP_SIZE * 4) {
+      V2 h0 = *reinterpret_cast<const V2*>(kn + c + 0);
+      V2 h1 = *reinterpret_cast<const V2*>(kn + c + 2);
+      float2 f0 = Vec2Traits<T>::to_float2(h0);
+      float2 f1 = Vec2Traits<T>::to_float2(h1);
+
+      uint32_t packed = pack4(
+          float_to_e4m3fn_byte(f0.x),
+          float_to_e4m3fn_byte(f0.y),
+          float_to_e4m3fn_byte(f1.x),
+          float_to_e4m3fn_byte(f1.y));
+
+      if (kndst) *reinterpret_cast<uint32_t*>(kndst + c) = packed;
+      if (kvdst) *reinterpret_cast<uint32_t*>(kvdst + c) = packed;
+    }
+
+    for (int c = lane * 4; c < Dr; c += WARP_SIZE * 4) {
+      V2 h0 = *reinterpret_cast<const V2*>(kr + c + 0);
+      V2 h1 = *reinterpret_cast<const V2*>(kr + c + 2);
+      float2 f0 = Vec2Traits<T>::to_float2(h0);
+      float2 f1 = Vec2Traits<T>::to_float2(h1);
+
+      int base0 = (c + 0) >> 1;
+      int base1 = (c + 2) >> 1;
+      float c0 = cos_ptr[base0], s0 = sin_ptr[base0];
+      float c1 = cos_ptr[base1], s1 = sin_ptr[base1];
+
+      rope_rotate(f0.x, f0.y, c0, s0);
+      rope_rotate(f1.x, f1.y, c1, s1);
+
+      uint32_t packed = pack4(
+          float_to_e4m3fn_byte(f0.x),
+          float_to_e4m3fn_byte(f0.y),
+          float_to_e4m3fn_byte(f1.x),
+          float_to_e4m3fn_byte(f1.y));
+
+      if (krdst) *reinterpret_cast<uint32_t*>(krdst + c) = packed;
+      if (kvdst) *reinterpret_cast<uint32_t*>(kvdst + Dn + c) = packed;
+    }
+  }
+}
+
+// ============================================================================
+// Scalar fallback kernel: for dimensions not divisible by 4
+// Template supports both FP16 (__half) and BF16 (nv_bfloat16)
+// ============================================================================
+template <typename T, int BLOCK_THREADS>
+__global__ void FusedRopeQuantizeKernelScalar(
+    const T* __restrict__ q_nope,
+    const T* __restrict__ q_rope,
+    int64_t qn_stride_tok,
+    int64_t qn_stride_head,
+    int64_t qr_stride_tok,
+    int64_t qr_stride_head,
+    const T* __restrict__ k_nope,
+    const T* __restrict__ k_rope,
+    int64_t kn_stride_tok,  // NEW: K_nope stride(0) in elements
+    int64_t kr_stride_tok,  // NEW: K_rope stride(0) in elements
+    const float* __restrict__ cos_sin,
+    const int64_t* __restrict__ pos_ids,
+    int nnz,
+    int num_heads,
+    int Dn,
+    int Dr,
+    bool is_neox,
+    uint8_t* __restrict__ q_out_fp8,
+    int64_t qout_stride_tok_bytes,
+    int64_t qout_stride_head_bytes,
+    uint8_t* __restrict__ k_nope_out_fp8,
+    uint8_t* __restrict__ k_rope_out_fp8,
+    uint8_t* __restrict__ kv_buffer_bytes,
+    int64_t kv_stride_row_bytes,  // 2D: row stride in bytes
+    const int64_t* __restrict__ kv_cache_loc) {
+  for (int global_row = blockIdx.x * BLOCK_THREADS + threadIdx.x; global_row < nnz * num_heads;
+       global_row += gridDim.x * BLOCK_THREADS) {
+    int token_id = global_row / num_heads;
+    int head_id = global_row % num_heads;
+
+    int pos = static_cast<int>(pos_ids[token_id]);
+    const float* cos_ptr = cos_sin + size_t(pos) * Dr;
+    const float* sin_ptr = cos_ptr + (Dr / 2);
+
+    {
+      const T* qn = q_nope + size_t(token_id) * qn_stride_tok + size_t(head_id) * qn_stride_head;
+      const T* qr = q_rope + size_t(token_id) * qr_stride_tok + size_t(head_id) * qr_stride_head;
+      uint8_t* qdst = q_out_fp8 + size_t(token_id) * qout_stride_tok_bytes + size_t(head_id) * qout_stride_head_bytes;
+
+      for (int i = 0; i < Dn; ++i) {
+        qdst[i] = float_to_e4m3fn_byte(Vec2Traits<T>::to_float(qn[i]));
+      }
+
+      if (!is_neox) {
+        for (int i = 0; i < Dr; i += 2) {
+          int base = i >> 1;
+          float xr = Vec2Traits<T>::to_float(qr[i + 0]);
+          float xi = (i + 1 < Dr) ? Vec2Traits<T>::to_float(qr[i + 1]) : 0.0f;
+          rope_rotate(xr, xi, cos_ptr[base], sin_ptr[base]);
+          qdst[Dn + i] = float_to_e4m3fn_byte(xr);
+          if (i + 1 < Dr) qdst[Dn + i + 1] = float_to_e4m3fn_byte(xi);
+        }
+      } else {
+        int half = Dr / 2;
+        for (int i = 0; i < half; ++i) {
+          float xr = Vec2Traits<T>::to_float(qr[i]);
+          float xi = Vec2Traits<T>::to_float(qr[i + half]);
+          rope_rotate(xr, xi, cos_ptr[i], sin_ptr[i]);
+          qdst[Dn + i] = float_to_e4m3fn_byte(xr);
+          qdst[Dn + i + half] = float_to_e4m3fn_byte(xi);
+        }
+      }
+    }
+
+    const T* kn = k_nope + size_t(token_id) * kn_stride_tok;
+    const T* kr = k_rope + size_t(token_id) * kr_stride_tok;
+
+    if (head_id == 0) {
+      if (k_nope_out_fp8) {
+        uint8_t* knd = k_nope_out_fp8 + size_t(token_id) * Dn;
+        for (int i = 0; i < Dn; ++i) {
+          knd[i] = float_to_e4m3fn_byte(Vec2Traits<T>::to_float(kn[i]));
+        }
+      }
+      if (k_rope_out_fp8) {
+        uint8_t* krd = k_rope_out_fp8 + size_t(token_id) * Dr;
+        if (!is_neox) {
+          for (int i = 0; i < Dr; i += 2) {
+            int base = i >> 1;
+            float xr = Vec2Traits<T>::to_float(kr[i]);
+            float xi = (i + 1 < Dr) ? Vec2Traits<T>::to_float(kr[i + 1]) : 0.0f;
+            rope_rotate(xr, xi, cos_ptr[base], sin_ptr[base]);
+            krd[i] = float_to_e4m3fn_byte(xr);
+            if (i + 1 < Dr) krd[i + 1] = float_to_e4m3fn_byte(xi);
+          }
+        } else {
+          int half = Dr / 2;
+          for (int i = 0; i < half; ++i) {
+            float xr = Vec2Traits<T>::to_float(kr[i]);
+            float xi = Vec2Traits<T>::to_float(kr[i + half]);
+            rope_rotate(xr, xi, cos_ptr[i], sin_ptr[i]);
+            krd[i] = float_to_e4m3fn_byte(xr);
+            krd[i + half] = float_to_e4m3fn_byte(xi);
+          }
+        }
+      }
+
+      if (kv_buffer_bytes && kv_cache_loc) {
+        int64_t flat_row = kv_cache_loc[token_id];
+        uint8_t* dst = kv_buffer_bytes + flat_row * kv_stride_row_bytes;
+        for (int i = 0; i < Dn; ++i) {
+          dst[i] = float_to_e4m3fn_byte(Vec2Traits<T>::to_float(kn[i]));
+        }
+        if (!is_neox) {
+          for (int i = 0; i < Dr; i += 2) {
+            int base = i >> 1;
+            float xr = Vec2Traits<T>::to_float(kr[i]);
+            float xi = (i + 1 < Dr) ? Vec2Traits<T>::to_float(kr[i + 1]) : 0.0f;
+            rope_rotate(xr, xi, cos_ptr[base], sin_ptr[base]);
+            dst[Dn + i] = float_to_e4m3fn_byte(xr);
+            if (i + 1 < Dr) dst[Dn + i + 1] = float_to_e4m3fn_byte(xi);
+          }
+        } else {
+          int half = Dr / 2;
+          for (int i = 0; i < half; ++i) {
+            float xr = Vec2Traits<T>::to_float(kr[i]);
+            float xi = Vec2Traits<T>::to_float(kr[i + half]);
+            rope_rotate(xr, xi, cos_ptr[i], sin_ptr[i]);
+            dst[Dn + i] = float_to_e4m3fn_byte(xr);
+            dst[Dn + i + half] = float_to_e4m3fn_byte(xi);
+          }
+        }
+      }
+    }
+  }
+}
+
+}  // namespace
+
+void mla_rope_quantize_fp8_fused(
+    at::Tensor q_nope,
+    at::Tensor q_rope,
+    at::Tensor k_nope,
+    at::Tensor k_rope,
+    at::Tensor cos_sin_cache,
+    at::Tensor pos_ids,
+    bool is_neox,
+    at::Tensor q_out,
+    c10::optional<at::Tensor> k_nope_out,
+    c10::optional<at::Tensor> k_rope_out,
+    c10::optional<at::Tensor> kv_buffer,
+    c10::optional<at::Tensor> kv_cache_loc) {
+  CHECK_INPUT(q_nope);
+  CHECK_INPUT(q_rope);
+  CHECK_INPUT(k_nope);
+  CHECK_INPUT(k_rope);
+  CHECK_INPUT(cos_sin_cache);
+  CHECK_INPUT(pos_ids);
+  CHECK_INPUT(q_out);
+
+  auto device = q_nope.device();
+  CHECK_EQ(q_rope.device(), device);
+  CHECK_EQ(k_nope.device(), device);
+  CHECK_EQ(k_rope.device(), device);
+  CHECK_EQ(cos_sin_cache.device(), device);
+  CHECK_EQ(pos_ids.device(), device);
+  CHECK_EQ(q_out.device(), device);
+
+  TORCH_CHECK(q_nope.dim() == 2 || q_nope.dim() == 3, "q_nope must be 2D or 3D");
+  TORCH_CHECK(q_rope.dim() == 2 || q_rope.dim() == 3, "q_rope must be 2D or 3D");
+  CHECK_DIM(2, k_nope);
+  CHECK_DIM(2, k_rope);
+  CHECK_DIM(1, pos_ids);
+  CHECK_DIM(2, cos_sin_cache);
+
+  // Determine dimensions and strides based on Q shape
+  int nnz_tokens, num_heads, Dn, Dr;
+  int64_t qn_stride_tok, qn_stride_head, qr_stride_tok, qr_stride_head;
+  int64_t qout_stride_tok_bytes, qout_stride_head_bytes;
+
+  if (q_nope.dim() == 3) {
+    nnz_tokens = q_nope.size(0);
+    num_heads = q_nope.size(1);
+    Dn = q_nope.size(2);
+    Dr = q_rope.size(2);
+
+    CHECK_EQ(q_rope.size(0), nnz_tokens);
+    CHECK_EQ(q_rope.size(1), num_heads);
+    CHECK_EQ(q_out.dim(), 3);
+    CHECK_EQ(q_out.size(0), nnz_tokens);
+    CHECK_EQ(q_out.size(1), num_heads);
+    CHECK_EQ(q_out.size(2), Dn + Dr);
+
+    qn_stride_tok = q_nope.stride(0);
+    qn_stride_head = q_nope.stride(1);
+    qr_stride_tok = q_rope.stride(0);
+    qr_stride_head = q_rope.stride(1);
+    qout_stride_tok_bytes = q_out.stride(0);
+    qout_stride_head_bytes = q_out.stride(1);
+  } else {
+    nnz_tokens = q_nope.size(0);
+    Dn = q_nope.size(1);
+    Dr = q_rope.size(1);
+    num_heads = 1;
+
+    CHECK_EQ(q_rope.size(0), nnz_tokens);
+    CHECK_EQ(q_out.dim(), 2);
+    CHECK_EQ(q_out.size(0), nnz_tokens);
+    CHECK_EQ(q_out.size(1), Dn + Dr);
+
+    qn_stride_tok = q_nope.stride(0);
+    qn_stride_head = 0;
+    qr_stride_tok = q_rope.stride(0);
+    qr_stride_head = 0;
+    qout_stride_tok_bytes = q_out.stride(0);
+    qout_stride_head_bytes = 0;
+  }
+
+  int nnz_k = k_rope.size(0);
+  CHECK_EQ(k_nope.size(0), nnz_k);
+  CHECK_EQ(k_nope.size(1), Dn);
+  CHECK_EQ(k_rope.size(0), nnz_k);
+  CHECK_EQ(k_rope.size(1), Dr);
+  CHECK_EQ(nnz_k, nnz_tokens);
+
+  int64_t kn_stride_tok = k_nope.stride(0);
+  int64_t kr_stride_tok = k_rope.stride(0);
+
+  CHECK_LAST_DIM_CONTIGUOUS(k_nope);
+  CHECK_LAST_DIM_CONTIGUOUS(k_rope);
+  CHECK_LAST_DIM_CONTIGUOUS(q_nope);
+  CHECK_LAST_DIM_CONTIGUOUS(q_rope);
+  CHECK_LAST_DIM_CONTIGUOUS(q_out);
+
+  uint8_t* k_nope_out_ptr = nullptr;
+  uint8_t* k_rope_out_ptr = nullptr;
+  if (k_nope_out.has_value()) {
+    auto t = k_nope_out.value();
+    CHECK_INPUT(t);
+    CHECK_DIM(2, t);
+    CHECK_EQ(t.size(0), nnz_k);
+    CHECK_EQ(t.size(1), Dn);
+    k_nope_out_ptr = reinterpret_cast<uint8_t*>(t.data_ptr());
+  }
+  if (k_rope_out.has_value()) {
+    auto t = k_rope_out.value();
+    CHECK_INPUT(t);
+    CHECK_DIM(2, t);
+    CHECK_EQ(t.size(0), nnz_k);
+    CHECK_EQ(t.size(1), Dr);
+    k_rope_out_ptr = reinterpret_cast<uint8_t*>(t.data_ptr());
+  }
+
+  uint8_t* kv_buf_ptr = nullptr;
+  int64_t kv_stride_row_bytes = 0;
+  const int64_t* kv_loc_ptr = nullptr;
+  if (kv_buffer.has_value() || kv_cache_loc.has_value()) {
+    TORCH_CHECK(kv_buffer.has_value() && kv_cache_loc.has_value(), "kv_buffer and kv_cache_loc must be both provided");
+    auto kv = kv_buffer.value();
+    auto loc = kv_cache_loc.value();
+    CHECK_INPUT(kv);
+    CHECK_INPUT(loc);
+    CHECK_DIM(1, loc);
+
+    TORCH_CHECK(kv.dim() == 2 || (kv.dim() == 3 && kv.size(1) == 1), "kv_buffer must be 2D or 3D with middle dim=1");
+
+    int kv_dim_actual = (kv.dim() == 3) ? kv.size(2) : kv.size(1);
+    CHECK_EQ(kv_dim_actual, Dn + Dr);
+    CHECK_EQ(loc.size(0), nnz_k);
+    CHECK_LAST_DIM_CONTIGUOUS(kv);
+
+    kv_buf_ptr = reinterpret_cast<uint8_t*>(kv.data_ptr());
+    kv_stride_row_bytes = kv.stride(0) * kv.element_size();
+    kv_loc_ptr = loc.data_ptr<int64_t>();
+  }
+
+  const float* cs_ptr = cos_sin_cache.data_ptr<float>();
+  const int64_t* pos_ptr = pos_ids.data_ptr<int64_t>();
+  uint8_t* q_out_ptr = reinterpret_cast<uint8_t*>(q_out.data_ptr());
+
+  cudaStream_t stream = c10::cuda::getCurrentCUDAStream().stream();
+  int total_rows = nnz_tokens * num_heads;
+
+  bool can_vectorize = ((Dn & 3) == 0) && ((Dr & 3) == 0) && !is_neox;
+  if (can_vectorize) {
+    bool strides_aligned =
+        (qout_stride_tok_bytes % 4 == 0) && (num_heads > 1 ? (qout_stride_head_bytes % 4 == 0) : true);
+    if (!strides_aligned) {
+      can_vectorize = false;
+    }
+  }
+
+  auto dtype = q_nope.scalar_type();
+
+  if (dtype == at::kHalf) {
+    const __half* qn_ptr = reinterpret_cast<const __half*>(q_nope.data_ptr());
+    const __half* qr_ptr = reinterpret_cast<const __half*>(q_rope.data_ptr());
+    const __half* kn_ptr = reinterpret_cast<const __half*>(k_nope.data_ptr());
+    const __half* kr_ptr = reinterpret_cast<const __half*>(k_rope.data_ptr());
+
+    if (can_vectorize) {
+      constexpr int WARPS_PER_CTA = 4;
+      dim3 vecBlock(WARPS_PER_CTA * 32);
+      dim3 vecGrid((total_rows + WARPS_PER_CTA - 1) / WARPS_PER_CTA);
+
+      FusedRopeQuantizeKernelVec<__half, WARPS_PER_CTA><<<vecGrid, vecBlock, 0, stream>>>(
+          qn_ptr,
+          qr_ptr,
+          qn_stride_tok,
+          qn_stride_head,
+          qr_stride_tok,
+          qr_stride_head,
+          kn_ptr,
+          kr_ptr,
+          kn_stride_tok,
+          kr_stride_tok,
+          cs_ptr,
+          pos_ptr,
+          nnz_tokens,
+          num_heads,
+          Dn,
+          Dr,
+          is_neox,
+          q_out_ptr,
+          qout_stride_tok_bytes,
+          qout_stride_head_bytes,
+          k_nope_out_ptr,
+          k_rope_out_ptr,
+          kv_buf_ptr,
+          kv_stride_row_bytes,
+          kv_loc_ptr);
+    } else {
+      constexpr int BLOCK_THREADS = 256;
+      dim3 grid((total_rows + BLOCK_THREADS - 1) / BLOCK_THREADS);
+
+      FusedRopeQuantizeKernelScalar<__half, BLOCK_THREADS><<<grid, BLOCK_THREADS, 0, stream>>>(
+          qn_ptr,
+          qr_ptr,
+          qn_stride_tok,
+          qn_stride_head,
+          qr_stride_tok,
+          qr_stride_head,
+          kn_ptr,
+          kr_ptr,
+          kn_stride_tok,
+          kr_stride_tok,
+          cs_ptr,
+          pos_ptr,
+          nnz_tokens,
+          num_heads,
+          Dn,
+          Dr,
+          is_neox,
+          q_out_ptr,
+          qout_stride_tok_bytes,
+          qout_stride_head_bytes,
+          k_nope_out_ptr,
+          k_rope_out_ptr,
+          kv_buf_ptr,
+          kv_stride_row_bytes,
+          kv_loc_ptr);
+    }
+  } else if (dtype == at::kBFloat16) {
+    const nv_bfloat16* qn_ptr = reinterpret_cast<const nv_bfloat16*>(q_nope.data_ptr());
+    const nv_bfloat16* qr_ptr = reinterpret_cast<const nv_bfloat16*>(q_rope.data_ptr());
+    const nv_bfloat16* kn_ptr = reinterpret_cast<const nv_bfloat16*>(k_nope.data_ptr());
+    const nv_bfloat16* kr_ptr = reinterpret_cast<const nv_bfloat16*>(k_rope.data_ptr());
+
+    if (can_vectorize) {
+      constexpr int WARPS_PER_CTA = 4;
+      dim3 vecBlock(WARPS_PER_CTA * 32);
+      dim3 vecGrid((total_rows + WARPS_PER_CTA - 1) / WARPS_PER_CTA);
+
+      FusedRopeQuantizeKernelVec<nv_bfloat16, WARPS_PER_CTA><<<vecGrid, vecBlock, 0, stream>>>(
+          qn_ptr,
+          qr_ptr,
+          qn_stride_tok,
+          qn_stride_head,
+          qr_stride_tok,
+          qr_stride_head,
+          kn_ptr,
+          kr_ptr,
+          kn_stride_tok,
+          kr_stride_tok,
+          cs_ptr,
+          pos_ptr,
+          nnz_tokens,
+          num_heads,
+          Dn,
+          Dr,
+          is_neox,
+          q_out_ptr,
+          qout_stride_tok_bytes,
+          qout_stride_head_bytes,
+          k_nope_out_ptr,
+          k_rope_out_ptr,
+          kv_buf_ptr,
+          kv_stride_row_bytes,
+          kv_loc_ptr);
+    } else {
+      constexpr int BLOCK_THREADS = 256;
+      dim3 grid((total_rows + BLOCK_THREADS - 1) / BLOCK_THREADS);
+
+      FusedRopeQuantizeKernelScalar<nv_bfloat16, BLOCK_THREADS><<<grid, BLOCK_THREADS, 0, stream>>>(
+          qn_ptr,
+          qr_ptr,
+          qn_stride_tok,
+          qn_stride_head,
+          qr_stride_tok,
+          qr_stride_head,
+          kn_ptr,
+          kr_ptr,
+          kn_stride_tok,
+          kr_stride_tok,
+          cs_ptr,
+          pos_ptr,
+          nnz_tokens,
+          num_heads,
+          Dn,
+          Dr,
+          is_neox,
+          q_out_ptr,
+          qout_stride_tok_bytes,
+          qout_stride_head_bytes,
+          k_nope_out_ptr,
+          k_rope_out_ptr,
+          kv_buf_ptr,
+          kv_stride_row_bytes,
+          kv_loc_ptr);
+    }
+  } else {
+    TORCH_CHECK(false, "Unsupported dtype for fused kernel. Only FP16 and BF16 are supported.");
+  }
+
+  TORCH_CHECK(cudaGetLastError() == cudaSuccess, "Kernel launch failed");
+}
+
+#ifdef TORCH_EXTENSION_NAME
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def(
+      "mla_rope_quantize_fp8_fused",
+      &mla_rope_quantize_fp8_fused,
+      "Fused MLA RoPE + FP8 quantization with optional KV cache write",
+      py::arg("q_nope"),
+      py::arg("q_rope"),
+      py::arg("k_nope"),
+      py::arg("k_rope"),
+      py::arg("cos_sin_cache"),
+      py::arg("pos_ids"),
+      py::arg("is_neox"),
+      py::arg("q_out"),
+      py::arg("k_nope_out") = py::none(),
+      py::arg("k_rope_out") = py::none(),
+      py::arg("kv_buffer") = py::none(),
+      py::arg("kv_cache_loc") = py::none());
+}
+#endif
diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h
index e95bcc2ff8c..e3edfe2374b 100644
--- a/sgl-kernel/include/sgl_kernel_ops.h
+++ b/sgl-kernel/include/sgl_kernel_ops.h
@@ -157,6 +157,20 @@ void apply_rope_pos_ids_cos_sin_cache(
     const std::optional<at::Tensor>& v_buffer,
     const std::optional<at::Tensor>& kv_cache_loc);
 
+void mla_rope_quantize_fp8_fused(
+    at::Tensor q_nope,
+    at::Tensor q_rope,
+    at::Tensor k_nope,
+    at::Tensor k_rope,
+    at::Tensor cos_sin_cache,
+    at::Tensor pos_ids,
+    bool is_neox,
+    at::Tensor q_out,
+    c10::optional<at::Tensor> k_nope_out,
+    c10::optional<at::Tensor> k_rope_out,
+    c10::optional<at::Tensor> kv_buffer,
+    c10::optional<at::Tensor> kv_cache_loc);
+
 void downcast_fp8(
     at::Tensor& k,
     at::Tensor& v,
diff --git a/sgl-kernel/python/sgl_kernel/__init__.py b/sgl-kernel/python/sgl_kernel/__init__.py
index 1c36681e189..19add301dea 100644
--- a/sgl-kernel/python/sgl_kernel/__init__.py
+++ b/sgl-kernel/python/sgl_kernel/__init__.py
@@ -30,6 +30,7 @@
     gelu_tanh_and_mul,
     gemma_fused_add_rmsnorm,
     gemma_rmsnorm,
+    mla_rope_quantize_fp8_fused,
     rmsnorm,
     silu_and_mul,
 )
diff --git a/sgl-kernel/python/sgl_kernel/elementwise.py b/sgl-kernel/python/sgl_kernel/elementwise.py
index 5684800bee4..bb645426a75 100644
--- a/sgl-kernel/python/sgl_kernel/elementwise.py
+++ b/sgl-kernel/python/sgl_kernel/elementwise.py
@@ -391,3 +391,50 @@ def concat_mla_absorb_q(
     )
     torch.ops.sgl_kernel.concat_mla_absorb_q(a, b, out)
     return out
+
+
+def mla_rope_quantize_fp8_fused(
+    q_nope: torch.Tensor,
+    q_rope: torch.Tensor,
+    k_nope: torch.Tensor,
+    k_rope: torch.Tensor,
+    cos_sin_cache: torch.Tensor,
+    pos_ids: torch.Tensor,
+    is_neox: bool,
+    q_out: torch.Tensor,
+    k_nope_out: Optional[torch.Tensor],
+    k_rope_out: Optional[torch.Tensor],
+    kv_buffer: Optional[torch.Tensor],
+    kv_cache_loc: Optional[torch.Tensor],
+):
+    """
+    Fused MLA RoPE + FP8 quantization + KV cache write kernel.
+ + Args: + q_nope: Query nope part + q_rope: Query rope part + k_nope: Key nope part + k_rope: Key rope part + cos_sin_cache: Precomputed cos/sin cache + pos_ids: Position IDs + is_neox: Whether to use NeoX-style RoPE + q_out: Output buffer for quantized Q + k_nope_out: Optional output buffer for quantized K nope + k_rope_out: Optional output buffer for quantized K rope + kv_buffer: Optional KV cache buffer for direct write + kv_cache_loc: Optional cache locations for KV buffer write + """ + torch.ops.sgl_kernel.mla_rope_quantize_fp8_fused( + q_nope, + q_rope, + k_nope, + k_rope, + cos_sin_cache, + pos_ids, + is_neox, + q_out, + k_nope_out, + k_rope_out, + kv_buffer, + kv_cache_loc, + ) diff --git a/sgl-kernel/tests/test_mla_rope_fp8_fused.py b/sgl-kernel/tests/test_mla_rope_fp8_fused.py new file mode 100644 index 00000000000..a54f3920eb1 --- /dev/null +++ b/sgl-kernel/tests/test_mla_rope_fp8_fused.py @@ -0,0 +1,261 @@ +# -*- coding: utf-8 -*- +""" +PyTest: correctness of fused MLA RoPE + FP8 quantization + KV write. +Tests the mla_rope_quantize_fp8_fused kernel from sgl_kernel extension. + +Tests both: +1. Baseline path (kernel writes to k_nope_out/k_rope_out) +2. Fused path (kernel directly writes to KV cache buffer) +""" +import itertools + +import pytest +import torch +from sgl_kernel import mla_rope_quantize_fp8_fused + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.parametrize( + "nnz,num_heads,Dn,Dr,dtype", + list( + itertools.product( + [64, 256], # nnz: number of tokens + [1, 8], # num_heads: 1 for 2D Q, 8 for 3D Q + [512], # Dn: nope dimension + [64], # Dr: rope dimension + [torch.float16, torch.bfloat16], # dtypes + ) + ), +) +def test_fused_matches_baseline(nnz, num_heads, Dn, Dr, dtype): + """Test that fused KV write produces same results as baseline.""" + device = "cuda" + torch.manual_seed(42) + + # Create inputs based on whether we're testing 2D or 3D Q + if num_heads == 1: + # 2D case: [nnz, dim] + q_nope = torch.randn(nnz, Dn, device=device, dtype=dtype) + q_rope = torch.randn(nnz, Dr, device=device, dtype=dtype) + q_out_shape = (nnz, Dn + Dr) + else: + # 3D case: [nnz, num_heads, dim] + q_nope = torch.randn(nnz, num_heads, Dn, device=device, dtype=dtype) + q_rope = torch.randn(nnz, num_heads, Dr, device=device, dtype=dtype) + q_out_shape = (nnz, num_heads, Dn + Dr) + + # K is always 2D regardless of Q shape + k_nope = torch.randn(nnz, Dn, device=device, dtype=dtype) + k_rope = torch.randn(nnz, Dr, device=device, dtype=dtype) + + # Create cos/sin cache + max_seq = max(2048, nnz) + t = torch.linspace(0, 1, steps=max_seq, device=device, dtype=torch.float32)[:, None] + idx = torch.arange(Dr, device=device, dtype=torch.float32)[None, :] + freqs = 0.1 * (idx + 1.0) # Small frequencies to avoid overflow + cos = torch.cos(t * freqs) + sin = torch.sin(t * freqs) + cos_sin = torch.cat([cos, sin], dim=1) # [max_seq, 2*Dr] + + # Random position IDs + pos_ids = torch.randint( + low=0, high=max_seq, size=(nnz,), device=device, dtype=torch.long + ) + + # ======================================================================== + # BASELINE PATH: Write to k_nope_out/k_rope_out, then concat manually + # ======================================================================== + q_out_base = torch.empty(q_out_shape, device=device, dtype=torch.uint8) + k_nope_out = torch.empty(nnz, Dn, device=device, dtype=torch.uint8) + k_rope_out = torch.empty(nnz, Dr, device=device, dtype=torch.uint8) + + mla_rope_quantize_fp8_fused( + q_nope, + q_rope, + 
k_nope, + k_rope, + cos_sin, + pos_ids, + False, # is_neox + q_out_base, + k_nope_out, + k_rope_out, + None, # kv_buffer + None, # kv_cache_loc + ) + + # Manually concat K parts into KV buffer (simulating set_mla_kv_buffer) + slots = nnz + 8 # Add some extra slots + kv_base = torch.zeros(slots, Dn + Dr, device=device, dtype=torch.uint8) # 2D format + loc = torch.arange(nnz, device=device, dtype=torch.long) + kv_base[loc, :Dn] = k_nope_out + kv_base[loc, Dn:] = k_rope_out + + # ======================================================================== + # FUSED PATH: Direct KV write, skip separate K outputs + # ======================================================================== + q_out_fused = torch.empty_like(q_out_base) + kv_fused = torch.zeros_like(kv_base) + + mla_rope_quantize_fp8_fused( + q_nope, + q_rope, + k_nope, + k_rope, + cos_sin, + pos_ids, + False, # is_neox + q_out_fused, + None, # k_nope_out + None, # k_rope_out + kv_fused, # Direct write to KV buffer + loc, # kv_cache_loc + ) + + # ======================================================================== + # ASSERTIONS + # ======================================================================== + # For FP8 quantized outputs, we can't expect perfect match due to + # rounding, but they should be very close when converted back to float + torch.testing.assert_close( + q_out_base.float(), + q_out_fused.float(), + rtol=1e-3, + atol=1e-3, + msg=f"q_out mismatch for {dtype=}, {num_heads=}", + ) + + torch.testing.assert_close( + kv_base.float(), + kv_fused.float(), + rtol=1e-3, + atol=1e-3, + msg=f"KV buffer mismatch for {dtype=}, {num_heads=}", + ) + + # For stricter check on used slots (should be exact) + assert torch.equal( + kv_base[loc], kv_fused[loc] + ), f"Used KV slots must match exactly for {dtype=}, {num_heads=}" + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.parametrize("nnz,Dn,Dr", [(128, 512, 64), (1024, 512, 64)]) +def test_baseline_only_path(nnz, Dn, Dr): + """Test that baseline path (without KV buffer) works correctly.""" + device = "cuda" + dtype = torch.float16 + torch.manual_seed(42) + + # 2D inputs + q_nope = torch.randn(nnz, Dn, device=device, dtype=dtype) + q_rope = torch.randn(nnz, Dr, device=device, dtype=dtype) + k_nope = torch.randn(nnz, Dn, device=device, dtype=dtype) + k_rope = torch.randn(nnz, Dr, device=device, dtype=dtype) + + # Create cos/sin cache + max_seq = 2048 + t = torch.linspace(0, 1, steps=max_seq, device=device, dtype=torch.float32)[:, None] + idx = torch.arange(Dr, device=device, dtype=torch.float32)[None, :] + freqs = 0.1 * (idx + 1.0) + cos = torch.cos(t * freqs) + sin = torch.sin(t * freqs) + cos_sin = torch.cat([cos, sin], dim=1) + + pos_ids = torch.randint( + low=0, high=max_seq, size=(nnz,), device=device, dtype=torch.long + ) + + # Allocate outputs + q_out = torch.empty(nnz, Dn + Dr, device=device, dtype=torch.uint8) + k_nope_out = torch.empty(nnz, Dn, device=device, dtype=torch.uint8) + k_rope_out = torch.empty(nnz, Dr, device=device, dtype=torch.uint8) + + # Call kernel (baseline path only) + mla_rope_quantize_fp8_fused( + q_nope, + q_rope, + k_nope, + k_rope, + cos_sin, + pos_ids, + False, + q_out, + k_nope_out, + k_rope_out, + None, # No KV buffer + None, # No kv_cache_loc + ) + + # Basic sanity checks + assert q_out.shape == (nnz, Dn + Dr) + assert k_nope_out.shape == (nnz, Dn) + assert k_rope_out.shape == (nnz, Dr) + + # Check that outputs are not all zeros (actual quantization happened) + assert q_out.abs().sum() > 0, "q_out 
should not be all zeros" + assert k_nope_out.abs().sum() > 0, "k_nope_out should not be all zeros" + assert k_rope_out.abs().sum() > 0, "k_rope_out should not be all zeros" + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_fused_only_path(): + """Test that fused path (only KV buffer, no separate K outputs) works.""" + device = "cuda" + dtype = torch.float16 + nnz, Dn, Dr = 128, 512, 64 + torch.manual_seed(42) + + # 2D inputs + q_nope = torch.randn(nnz, Dn, device=device, dtype=dtype) + q_rope = torch.randn(nnz, Dr, device=device, dtype=dtype) + k_nope = torch.randn(nnz, Dn, device=device, dtype=dtype) + k_rope = torch.randn(nnz, Dr, device=device, dtype=dtype) + + # Create cos/sin cache + max_seq = 2048 + t = torch.linspace(0, 1, steps=max_seq, device=device, dtype=torch.float32)[:, None] + idx = torch.arange(Dr, device=device, dtype=torch.float32)[None, :] + freqs = 0.1 * (idx + 1.0) + cos = torch.cos(t * freqs) + sin = torch.sin(t * freqs) + cos_sin = torch.cat([cos, sin], dim=1) + + pos_ids = torch.randint( + low=0, high=max_seq, size=(nnz,), device=device, dtype=torch.long + ) + + # Allocate only Q output and KV buffer + q_out = torch.empty(nnz, Dn + Dr, device=device, dtype=torch.uint8) + slots = nnz + 16 + kv_buffer = torch.zeros( + slots, Dn + Dr, device=device, dtype=torch.uint8 + ) # 2D format + loc = torch.arange(nnz, device=device, dtype=torch.long) + + # Call kernel (fused path only) + mla_rope_quantize_fp8_fused( + q_nope, + q_rope, + k_nope, + k_rope, + cos_sin, + pos_ids, + False, + q_out, + None, # No k_nope_out + None, # No k_rope_out + kv_buffer, # Direct KV write + loc, + ) + + # Check that KV buffer was written + assert kv_buffer[loc].abs().sum() > 0, "KV buffer should have been written" + # Check unused slots are still zero + unused = torch.ones(slots, dtype=torch.bool, device=device) + unused[loc] = False + assert kv_buffer[unused].abs().sum() == 0, "Unused KV slots should remain zero" + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])