diff --git a/python/sglang/srt/layers/attention/trtllm_fp8_kv_kernel.py b/python/sglang/srt/layers/attention/trtllm_fp8_kv_kernel.py new file mode 100644 index 000000000000..e10b2f9bc684 --- /dev/null +++ b/python/sglang/srt/layers/attention/trtllm_fp8_kv_kernel.py @@ -0,0 +1,467 @@ +""" +Fused FP8 quantization + paged KV cache write kernel for TRTLLM MHA backend. + +This kernel fuses the following operations: +1. FP8 quantization of K and V tensors (from BF16/FP16 to FP8) +2. Per-token or per-page scale computation +3. Writing quantized K/V to paged KV cache layout + +Performance benefits: +- Eliminates intermediate FP8 tensors in memory +- Reduces kernel launch overhead +- Better memory bandwidth utilization +""" + +import logging +from typing import Optional + +import torch +import triton +import triton.language as tl + +logger = logging.getLogger(__name__) + + +@triton.jit +def _process_kv_tensor( + token_id, + head_block_id, + page_id, + page_offset, + input_ptr, + cache_ptr, + inv_scale, + use_provided_scale: tl.constexpr, + num_kv_heads: tl.constexpr, + head_dim: tl.constexpr, + input_stride_token: tl.constexpr, + input_stride_head: tl.constexpr, + input_stride_dim: tl.constexpr, + cache_stride_page: tl.constexpr, + cache_stride_offset: tl.constexpr, + cache_stride_head: tl.constexpr, + cache_stride_dim: tl.constexpr, + BLOCK_HEAD: tl.constexpr, + BLOCK_DIM: tl.constexpr, +): + """Process a block of heads for a single K or V tensor.""" + head_idx = head_block_id * BLOCK_HEAD + num_heads_in_block = min(BLOCK_HEAD, num_kv_heads - head_idx) + + for dim_idx in range(0, head_dim, BLOCK_DIM): + num_dims_in_block = min(BLOCK_DIM, head_dim - dim_idx) + + head_offsets = head_idx + tl.arange(0, BLOCK_HEAD) + dim_offsets = dim_idx + tl.arange(0, BLOCK_DIM) + + head_mask = head_offsets < (head_idx + num_heads_in_block) + dim_mask = dim_offsets < (dim_idx + num_dims_in_block) + + # Load from input using 3D strides + input_offsets = ( + token_id * input_stride_token + + 
head_offsets[:, None] * input_stride_head + + dim_offsets[None, :] * input_stride_dim + ) + mask = head_mask[:, None] & dim_mask[None, :] + + block = tl.load(input_ptr + input_offsets, mask=mask, other=0.0) + + # Quantize to FP8 + if use_provided_scale: + block_fp8 = (block * inv_scale).to(tl.float8e4nv) + else: + block_fp8 = block.to(tl.float8e4nv) + + # Write to cache at [page_id, page_offset, head, dim] + cache_offsets = ( + page_id * cache_stride_page + + page_offset * cache_stride_offset + + head_offsets[:, None] * cache_stride_head + + dim_offsets[None, :] * cache_stride_dim + ) + + tl.store(cache_ptr + cache_offsets, block_fp8, mask=mask) + + +@triton.jit +def _fused_fp8_set_kv_buffer_kernel( + # Input tensors (post-RoPE K and V in FP16/BF16) + k_ptr, # [num_tokens, num_kv_heads, head_dim] + v_ptr, # [num_tokens, num_kv_heads, head_dim] + # Output KV cache buffers (FP8 paged layout) + k_cache_ptr, # [total_slots, num_kv_heads, head_dim] + v_cache_ptr, # [total_slots, num_kv_heads, head_dim] + # Cache location indices + cache_loc_ptr, # [num_tokens] -> token to cache location mapping + # Scalar scale (if provided, will be used; otherwise computed per-token) + k_scale, # scalar float + v_scale, # scalar float + use_provided_scale: tl.constexpr, # whether to use provided scale + # Tensor dimensions + num_kv_heads: tl.constexpr, + head_dim: tl.constexpr, + page_size: tl.constexpr, + # Strides for K input [num_tokens, num_kv_heads, head_dim] + k_stride_token: tl.constexpr, + k_stride_head: tl.constexpr, + k_stride_dim: tl.constexpr, + # Strides for K cache [total_slots, num_kv_heads, head_dim] (logically paged) + k_cache_stride_page: tl.constexpr, + k_cache_stride_offset: tl.constexpr, + k_cache_stride_head: tl.constexpr, + k_cache_stride_dim: tl.constexpr, + # Strides for V input [num_tokens, num_kv_heads, head_dim] + v_stride_token: tl.constexpr, + v_stride_head: tl.constexpr, + v_stride_dim: tl.constexpr, + # Strides for V cache [total_slots, num_kv_heads, 
head_dim] (logically paged) + v_cache_stride_page: tl.constexpr, + v_cache_stride_offset: tl.constexpr, + v_cache_stride_head: tl.constexpr, + v_cache_stride_dim: tl.constexpr, + # Block sizes + BLOCK_HEAD: tl.constexpr, # Number of heads per block + BLOCK_DIM: tl.constexpr, # Head dimension block size +): + """ + Fused FP8 quantization + paged KV cache write kernel. + + Each program processes one token-head_block-kv combination, quantizing and writing + to the appropriate page in the KV cache. + + Grid: (num_tokens, num_head_blocks, 2) where dim2: 0=K, 1=V + """ + # Get program IDs + token_id = tl.program_id(0) + head_block_id = tl.program_id(1) + kv_idx = tl.program_id(2) # 0 for K, 1 for V + + # Get cache location for this token + cache_loc = tl.load(cache_loc_ptr + token_id) + + # Compute page_id and offset within page + page_id = cache_loc // page_size + page_offset = cache_loc % page_size + + # Select K or V based on kv_idx + if kv_idx == 0: + # Process K tensor + inv_scale = 1.0 / k_scale if use_provided_scale else 1.0 + _process_kv_tensor( + token_id, + head_block_id, + page_id, + page_offset, + k_ptr, + k_cache_ptr, + inv_scale, + use_provided_scale, + num_kv_heads, + head_dim, + k_stride_token, + k_stride_head, + k_stride_dim, + k_cache_stride_page, + k_cache_stride_offset, + k_cache_stride_head, + k_cache_stride_dim, + BLOCK_HEAD, + BLOCK_DIM, + ) + else: + # Process V tensor + inv_scale = 1.0 / v_scale if use_provided_scale else 1.0 + _process_kv_tensor( + token_id, + head_block_id, + page_id, + page_offset, + v_ptr, + v_cache_ptr, + inv_scale, + use_provided_scale, + num_kv_heads, + head_dim, + v_stride_token, + v_stride_head, + v_stride_dim, + v_cache_stride_page, + v_cache_stride_offset, + v_cache_stride_head, + v_cache_stride_dim, + BLOCK_HEAD, + BLOCK_DIM, + ) + + +def fused_fp8_set_kv_buffer( + k: torch.Tensor, # [num_tokens, num_kv_heads, head_dim] or [num_tokens, num_kv_heads * head_dim] + v: torch.Tensor, # [num_tokens, num_kv_heads, head_dim] 
or [num_tokens, num_kv_heads * head_dim] + k_cache: torch.Tensor, # [total_slots, num_kv_heads, head_dim] or [num_pages, page_size, num_kv_heads, head_dim] + v_cache: torch.Tensor, # [total_slots, num_kv_heads, head_dim] or [num_pages, page_size, num_kv_heads, head_dim] + cache_loc: torch.Tensor, # [num_tokens], dtype=int32 + k_scale: Optional[ + float + ] = None, # Scalar scale (matching original set_kv_buffer signature) + v_scale: Optional[float] = None, + page_size: int = 16, + use_triton: bool = True, # Whether to use Triton kernel (set to False to force naive fallback) +) -> None: + """ + Python wrapper for the fused FP8 quantization + paged KV cache write kernel. + + This function replicates the exact behavior of the original set_kv_buffer but with + a fused kernel that combines FP8 quantization and cache write. + + Args: + k: Key tensor after RoPE, can be 2D or 3D + v: Value tensor, can be 2D or 3D + k_cache: Paged K cache buffer in FP8 + v_cache: Paged V cache buffer in FP8 + cache_loc: Cache location for each token, shape [num_tokens] + k_scale: Optional scalar scale for K (matching original set_kv_buffer) + v_scale: Optional scalar scale for V (matching original set_kv_buffer) + page_size: Number of tokens per page + use_triton: Whether to use optimized Triton kernel + """ + num_tokens = k.shape[0] + + # Step 1: Infer num_kv_heads and head_dim from cache shape + if k_cache.ndim == 3: + # 3D cache layout: [total_slots, num_kv_heads, head_dim] + total_slots, num_kv_heads, head_dim = k_cache.shape + assert ( + total_slots % page_size == 0 + ), f"total_slots ({total_slots}) must be divisible by page_size ({page_size})" + num_pages = total_slots // page_size + elif k_cache.ndim == 4: + # 4D cache layout: [num_pages, page_size, num_kv_heads, head_dim] + num_pages, ps, num_kv_heads, head_dim = k_cache.shape + assert ( + ps == page_size + ), f"page_size mismatch: cache has {ps}, expected {page_size}" + total_slots = num_pages * page_size + else: + raise 
ValueError(f"Unsupported k_cache.ndim={k_cache.ndim}, expected 3 or 4") + + # Step 2: Validate k, v shapes and normalize + # Store original 3D shape for Triton path + k_3d = None + v_3d = None + + if k.ndim == 3: + # Input is [num_tokens, num_kv_heads, head_dim] + assert ( + k.shape[1] == num_kv_heads + ), f"num_kv_heads mismatch: k.shape[1]={k.shape[1]} vs cache={num_kv_heads}" + assert ( + k.shape[2] == head_dim + ), f"head_dim mismatch: k.shape[2]={k.shape[2]} vs cache={head_dim}" + assert v.shape[1] == num_kv_heads and v.shape[2] == head_dim, "v shape mismatch" + + # Keep 3D for Triton kernel + k_3d = k + v_3d = v + # Create 2D view for naive fallback (will be used only if use_triton=False) + k_2d = k.reshape(num_tokens, num_kv_heads * head_dim) + v_2d = v.reshape(num_tokens, num_kv_heads * head_dim) + elif k.ndim == 2: + # Input is already [num_tokens, num_kv_heads * head_dim] + assert ( + k.shape[1] == num_kv_heads * head_dim + ), f"k.shape[1]={k.shape[1]} != {num_kv_heads * head_dim}" + assert ( + v.shape[1] == num_kv_heads * head_dim + ), f"v.shape[1]={v.shape[1]} != {num_kv_heads * head_dim}" + + # Create 3D view for Triton kernel + k_3d = k.view(num_tokens, num_kv_heads, head_dim) + v_3d = v.view(num_tokens, num_kv_heads, head_dim) + # Keep 2D for naive + k_2d = k + v_2d = v + else: + raise ValueError(f"Unsupported k.ndim={k.ndim}, expected 2 or 3") + + # Step 3: Compute cache strides based on layout + if k_cache.ndim == 3: + # 3D cache: [total_slots, num_kv_heads, head_dim] + stride_slot = k_cache.stride(0) + stride_head = k_cache.stride(1) + stride_dim = k_cache.stride(2) + + k_cache_stride_page = stride_slot * page_size + k_cache_stride_offset = stride_slot + k_cache_stride_head = stride_head + k_cache_stride_dim = stride_dim + + v_stride_slot = v_cache.stride(0) + v_stride_head = v_cache.stride(1) + v_stride_dim = v_cache.stride(2) + + v_cache_stride_page = v_stride_slot * page_size + v_cache_stride_offset = v_stride_slot + v_cache_stride_head = 
v_stride_head + v_cache_stride_dim = v_stride_dim + else: + # 4D cache: [num_pages, page_size, num_kv_heads, head_dim] + k_cache_stride_page = k_cache.stride(0) + k_cache_stride_offset = k_cache.stride(1) + k_cache_stride_head = k_cache.stride(2) + k_cache_stride_dim = k_cache.stride(3) + + v_cache_stride_page = v_cache.stride(0) + v_cache_stride_offset = v_cache.stride(1) + v_cache_stride_head = v_cache.stride(2) + v_cache_stride_dim = v_cache.stride(3) + + # Decide whether to use provided scale + use_provided_scale = k_scale is not None and v_scale is not None + + if use_triton and num_tokens > 0: + # Use optimized Triton kernel + # Compute input strides for 3D k, v: [num_tokens, num_kv_heads, head_dim] + k_stride_token = k_3d.stride(0) + k_stride_head = k_3d.stride(1) + k_stride_dim = k_3d.stride(2) + + v_stride_token = v_3d.stride(0) + v_stride_head = v_3d.stride(1) + v_stride_dim = v_3d.stride(2) + + # Block sizes for tiling (tunable) + BLOCK_HEAD = min(num_kv_heads, 8) # Process up to 8 heads at once + BLOCK_DIM = min(head_dim, 128) # Process up to 128 dims at once + + # Compute number of head blocks + num_head_blocks = (num_kv_heads + BLOCK_HEAD - 1) // BLOCK_HEAD + + # Grid: (num_tokens, num_head_blocks, 2) + # - dim 0: tokens + # - dim 1: head blocks + # - dim 2: K/V (0=K, 1=V) + grid = (num_tokens, num_head_blocks, 2) + + # Launch Triton kernel + _fused_fp8_set_kv_buffer_kernel[grid]( + k_3d, + v_3d, + k_cache, + v_cache, + cache_loc, + k_scale if k_scale is not None else 1.0, + v_scale if v_scale is not None else 1.0, + use_provided_scale, + num_kv_heads, + head_dim, + page_size, + k_stride_token, + k_stride_head, + k_stride_dim, + k_cache_stride_page, + k_cache_stride_offset, + k_cache_stride_head, + k_cache_stride_dim, + v_stride_token, + v_stride_head, + v_stride_dim, + v_cache_stride_page, + v_cache_stride_offset, + v_cache_stride_head, + v_cache_stride_dim, + BLOCK_HEAD=BLOCK_HEAD, + BLOCK_DIM=BLOCK_DIM, + ) + else: + # Fallback to naive 
implementation + _naive_fp8_set_kv_buffer( + k_2d, v_2d, k_cache, v_cache, cache_loc, k_scale, v_scale, page_size + ) + + +def _naive_fp8_set_kv_buffer( + k: torch.Tensor, + v: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + cache_loc: torch.Tensor, + k_scale: Optional[float], + v_scale: Optional[float], + page_size: int, +) -> None: + """ + Naive fallback implementation that mimics the original set_kv_buffer logic. + + This directly replicates the behavior of MHATokenToKVPool.set_kv_buffer: + 1. Apply scale (if k.dtype != cache.dtype and scale is provided) + 2. Convert to FP8 + 3. Write to cache at cache_loc + + Args: + k: [num_tokens, num_kv_heads * head_dim], already reshaped to 2D + v: [num_tokens, num_kv_heads * head_dim], already reshaped to 2D + k_cache: [total_slots, num_kv_heads, head_dim] or [num_pages, page_size, num_kv_heads, head_dim] + v_cache: Same shape as k_cache + cache_loc: [num_tokens] + k_scale: Optional scale for K + v_scale: Optional scale for V + page_size: Tokens per page + """ + num_tokens = k.shape[0] + + # Infer dimensions from cache + if k_cache.ndim == 3: + num_kv_heads = k_cache.shape[1] + head_dim = k_cache.shape[2] + elif k_cache.ndim == 4: + num_kv_heads = k_cache.shape[2] + head_dim = k_cache.shape[3] + else: + raise ValueError(f"Unsupported k_cache.ndim={k_cache.ndim}") + + # Determine target dtype and storage dtype + # See: python/sglang/srt/mem_cache/memory_pool.py:445-449 + store_dtype = k_cache.dtype + if store_dtype == torch.uint8: + # Cache is stored as uint8 for FP8 (due to index_put limitation) + dtype = torch.float8_e4m3fn # Logical dtype + else: + dtype = store_dtype # Cache dtype is the logical dtype + + # Replicate the original set_kv_buffer behavior + # See: python/sglang/srt/mem_cache/memory_pool.py:777-799 + if k.dtype != dtype: + # Need quantization - clone first to avoid modifying input + k = k.clone() + v = v.clone() + + if k_scale is not None: + k.div_(k_scale) # In-place division + if v_scale 
is not None: + v.div_(v_scale) # In-place division + + k = k.to(dtype) + v = v.to(dtype) + + # View FP8 as uint8 if needed (for index_put compatibility) + if store_dtype == torch.uint8 and dtype in (torch.float8_e5m2, torch.float8_e4m3fn): + k = k.view(torch.uint8) + v = v.view(torch.uint8) + + # Reshape from [T, H*D] to [T, H, D] + k = k.view(num_tokens, num_kv_heads, head_dim) + v = v.view(num_tokens, num_kv_heads, head_dim) + + # Write to cache using advanced indexing (same as original) + if k_cache.ndim == 3: + # 3D cache: [total_slots, H, D] + k_cache[cache_loc] = k + v_cache[cache_loc] = v + else: + # 4D cache: [num_pages, page_size, H, D] + # Decompose loc into page_id and page_offset (vectorized) + page_ids = cache_loc // page_size + page_offsets = cache_loc % page_size + k_cache[page_ids, page_offsets] = k + v_cache[page_ids, page_offsets] = v diff --git a/python/sglang/srt/layers/attention/trtllm_mha_backend.py b/python/sglang/srt/layers/attention/trtllm_mha_backend.py index 9c85e8113583..5dcbfb1b205f 100644 --- a/python/sglang/srt/layers/attention/trtllm_mha_backend.py +++ b/python/sglang/srt/layers/attention/trtllm_mha_backend.py @@ -5,6 +5,7 @@ The kernel supports sm100 only, with sliding window and attention sink features. 
""" +import logging from dataclasses import dataclass from typing import TYPE_CHECKING, Optional @@ -14,9 +15,12 @@ FlashInferAttnBackend, FlashInferMultiStepDraftBackend, ) +from sglang.srt.layers.attention.trtllm_fp8_kv_kernel import fused_fp8_set_kv_buffer from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.utils import is_flashinfer_available +logger = logging.getLogger(__name__) + if is_flashinfer_available(): import flashinfer @@ -411,6 +415,36 @@ def get_cuda_graph_seq_len_fill_value(self) -> int: """Get the fill value for sequence lengths in CUDA graph.""" return 1 + def _should_use_fused_fp8_path(self, save_kv_cache: bool, k: torch.Tensor) -> bool: + """Check if we should use the fused FP8 KV cache write path.""" + return save_kv_cache and k is not None and self.data_type == torch.float8_e4m3fn + + def _fused_fp8_set_kv_buffer( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer: RadixAttention, + forward_batch: ForwardBatch, + **kwargs, + ): + """Fused FP8 quantization and KV cache write.""" + cache_loc = forward_batch.out_cache_loc + + # Get K/V cache buffers from token_to_kv_pool + k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id) + + fused_fp8_set_kv_buffer( + k=k, + v=v, + k_cache=k_cache, + v_cache=v_cache, + cache_loc=cache_loc, + k_scale=layer.k_scale, # May be None + v_scale=layer.v_scale, # May be None + page_size=self.page_size, + ) + def init_forward_metadata(self, forward_batch: ForwardBatch): """Initialize the metadata for a forward pass.""" @@ -524,10 +558,26 @@ def forward_decode( ) -> torch.Tensor: """Run forward for decode using TRTLLM MHA kernel.""" cache_loc = forward_batch.out_cache_loc - if save_kv_cache and k is not None: - forward_batch.token_to_kv_pool.set_kv_buffer( - layer, cache_loc, k, v, layer.k_scale, layer.v_scale + + use_fused_fp8_path = self._should_use_fused_fp8_path(save_kv_cache, k) + + if use_fused_fp8_path: + # Use fused 
FP8 quantization + KV cache write path + self._fused_fp8_set_kv_buffer( + q=q, + k=k, + v=v, + layer=layer, + forward_batch=forward_batch, ) + k = None + v = None + else: + # Use original set_kv_buffer path + if save_kv_cache and k is not None: + forward_batch.token_to_kv_pool.set_kv_buffer( + layer, cache_loc, k, v, layer.k_scale, layer.v_scale + ) if self.data_type == torch.float8_e4m3fn: q = q.to(torch.float8_e4m3fn) @@ -585,10 +635,26 @@ def forward_extend( **kwargs, ): cache_loc = forward_batch.out_cache_loc - if save_kv_cache and k is not None: - forward_batch.token_to_kv_pool.set_kv_buffer( - layer, cache_loc, k, v, layer.k_scale, layer.v_scale + + use_fused_fp8_path = self._should_use_fused_fp8_path(save_kv_cache, k) + + if use_fused_fp8_path: + # Use fused FP8 quantization + KV cache write path + self._fused_fp8_set_kv_buffer( + q=q, + k=k, + v=v, + layer=layer, + forward_batch=forward_batch, ) + k = None + v = None + else: + # Use original set_kv_buffer path + if save_kv_cache and k is not None: + forward_batch.token_to_kv_pool.set_kv_buffer( + layer, cache_loc, k, v, layer.k_scale, layer.v_scale + ) if self.data_type == torch.float8_e4m3fn: q = q.to(torch.float8_e4m3fn) diff --git a/python/sglang/srt/models/qwen3_moe.py b/python/sglang/srt/models/qwen3_moe.py index 3fbe81257290..7caec156e700 100644 --- a/python/sglang/srt/models/qwen3_moe.py +++ b/python/sglang/srt/models/qwen3_moe.py @@ -422,6 +422,7 @@ def forward_prepare_npu( q_bias=getattr(self.q_norm, "bias", None), k_bias=getattr(self.k_norm, "bias", None), ) + inner_state = q, k, v, forward_batch return None, forward_batch, inner_state @@ -449,6 +450,7 @@ def forward_prepare_native( else None ), ) + inner_state = q, k, v, forward_batch return None, forward_batch, inner_state @@ -477,8 +479,14 @@ def forward_core(self, intermediate_state): hidden_states, forward_batch, inner_state = intermediate_state if inner_state is None: return hidden_states + + q, k, v, fb = inner_state + attn_output = 
"""
Unit tests for TRTLLM FP8 KV cache fusion kernel.
"""

import unittest

import torch

from sglang.srt.layers.attention.trtllm_fp8_kv_kernel import fused_fp8_set_kv_buffer
from sglang.test.test_utils import CustomTestCase


class TestTRTLLMFP8KVKernel(CustomTestCase):
    """Test fused FP8 KV cache write kernel correctness."""

    @classmethod
    def setUpClass(cls):
        if not torch.cuda.is_available():
            raise unittest.SkipTest("CUDA not available")

        # NOTE(review): FP8 is also supported from SM 8.9 (Ada); this gate is
        # stricter than the dtype requires — confirm whether that is intended.
        if torch.cuda.get_device_capability()[0] < 9:
            raise unittest.SkipTest("FP8 requires compute capability >= 9.0")

    @staticmethod
    def _make_kv(num_tokens, num_kv_heads, head_dim, input_ndim, device, dtype):
        """Create random K/V inputs in either 3D or flattened 2D layout."""
        if input_ndim == 3:
            shape = (num_tokens, num_kv_heads, head_dim)
        else:
            shape = (num_tokens, num_kv_heads * head_dim)
        k = torch.randn(*shape, device=device, dtype=dtype)
        v = torch.randn(*shape, device=device, dtype=dtype)
        return k, v

    @staticmethod
    def _make_caches(cache_ndim, num_pages, page_size, num_kv_heads, head_dim, device):
        """Allocate four zeroed FP8 cache buffers: (K-triton, V-triton, K-naive, V-naive)."""
        if cache_ndim == 3:
            shape = (num_pages * page_size, num_kv_heads, head_dim)
        else:
            shape = (num_pages, page_size, num_kv_heads, head_dim)
        return tuple(
            torch.zeros(shape, device=device, dtype=torch.float8_e4m3fn)
            for _ in range(4)
        )

    def _test_kernel_correctness(
        self,
        num_tokens,
        num_kv_heads,
        head_dim,
        page_size,
        use_scale,
        input_ndim,
        cache_ndim,
    ):
        """Compare Triton kernel output against the naive implementation."""
        device = torch.device("cuda")
        k, v = self._make_kv(
            num_tokens, num_kv_heads, head_dim, input_ndim, device, torch.bfloat16
        )

        num_pages = 128
        total_slots = num_pages * page_size
        k_cache_triton, v_cache_triton, k_cache_naive, v_cache_naive = (
            self._make_caches(
                cache_ndim, num_pages, page_size, num_kv_heads, head_dim, device
            )
        )

        # Unique cache slots avoid write races between tokens.
        cache_loc = torch.randperm(total_slots, device=device, dtype=torch.int32)[
            :num_tokens
        ]

        k_scale = 0.5 if use_scale else None
        v_scale = 0.75 if use_scale else None

        # Run both paths on identical inputs.
        for use_triton, k_cache, v_cache in (
            (True, k_cache_triton, v_cache_triton),
            (False, k_cache_naive, v_cache_naive),
        ):
            fused_fp8_set_kv_buffer(
                k.clone(),
                v.clone(),
                k_cache,
                v_cache,
                cache_loc,
                k_scale,
                v_scale,
                page_size,
                use_triton=use_triton,
            )

        # NOTE(review): the Triton path multiplies by a reciprocal scale in
        # fp32 while the naive path divides in the input dtype, so bit-exact
        # equality may be fragile when use_scale=True — confirm on hardware.
        self.assertTrue(
            torch.equal(k_cache_triton, k_cache_naive),
            "K cache mismatch between Triton and naive",
        )
        self.assertTrue(
            torch.equal(v_cache_triton, v_cache_naive),
            "V cache mismatch between Triton and naive",
        )

    def test_basic_3d_input_3d_cache(self):
        """Test basic case: 3D input, 3D cache, no scale."""
        self._test_kernel_correctness(
            num_tokens=16,
            num_kv_heads=8,
            head_dim=128,
            page_size=16,
            use_scale=False,
            input_ndim=3,
            cache_ndim=3,
        )

    def test_basic_3d_input_4d_cache(self):
        """Test basic case: 3D input, 4D cache, no scale."""
        self._test_kernel_correctness(
            num_tokens=16,
            num_kv_heads=8,
            head_dim=128,
            page_size=16,
            use_scale=False,
            input_ndim=3,
            cache_ndim=4,
        )

    def test_with_scale_3d_cache(self):
        """Test with scale: 3D input, 3D cache."""
        self._test_kernel_correctness(
            num_tokens=16,
            num_kv_heads=8,
            head_dim=128,
            page_size=16,
            use_scale=True,
            input_ndim=3,
            cache_ndim=3,
        )

    def test_with_scale_4d_cache(self):
        """Test with scale: 3D input, 4D cache."""
        self._test_kernel_correctness(
            num_tokens=16,
            num_kv_heads=8,
            head_dim=128,
            page_size=16,
            use_scale=True,
            input_ndim=3,
            cache_ndim=4,
        )

    def test_2d_input_3d_cache(self):
        """Test 2D input (flattened): 2D input, 3D cache."""
        self._test_kernel_correctness(
            num_tokens=16,
            num_kv_heads=8,
            head_dim=128,
            page_size=16,
            use_scale=False,
            input_ndim=2,
            cache_ndim=3,
        )

    def test_2d_input_4d_cache(self):
        """Test 2D input (flattened): 2D input, 4D cache."""
        self._test_kernel_correctness(
            num_tokens=16,
            num_kv_heads=8,
            head_dim=128,
            page_size=16,
            use_scale=False,
            input_ndim=2,
            cache_ndim=4,
        )

    def test_single_token(self):
        """Test edge case: single token."""
        self._test_kernel_correctness(
            num_tokens=1,
            num_kv_heads=8,
            head_dim=128,
            page_size=16,
            use_scale=True,
            input_ndim=3,
            cache_ndim=3,
        )

    def test_large_batch(self):
        """Test larger batch size."""
        self._test_kernel_correctness(
            num_tokens=128,
            num_kv_heads=16,
            head_dim=64,
            page_size=16,
            use_scale=True,
            input_ndim=3,
            cache_ndim=4,
        )

    def test_different_head_dims(self):
        """Test different head dimensions."""
        for head_dim in [64, 128]:
            self._test_kernel_correctness(
                num_tokens=16,
                num_kv_heads=8,
                head_dim=head_dim,
                page_size=16,
                use_scale=False,
                input_ndim=3,
                cache_ndim=3,
            )

    def test_empty_input(self):
        """Test edge case: empty input (0 tokens) — must not crash."""
        device = torch.device("cuda")
        num_kv_heads, head_dim, page_size = 8, 128, 16
        total_slots = 128

        k = torch.randn(0, num_kv_heads, head_dim, device=device, dtype=torch.bfloat16)
        v = torch.randn(0, num_kv_heads, head_dim, device=device, dtype=torch.bfloat16)
        k_cache = torch.zeros(
            total_slots, num_kv_heads, head_dim, device=device,
            dtype=torch.float8_e4m3fn,
        )
        v_cache = torch.zeros(
            total_slots, num_kv_heads, head_dim, device=device,
            dtype=torch.float8_e4m3fn,
        )
        cache_loc = torch.empty(0, device=device, dtype=torch.int32)

        fused_fp8_set_kv_buffer(
            k,
            v,
            k_cache,
            v_cache,
            cache_loc,
            k_scale=None,
            v_scale=None,
            page_size=page_size,
        )


if __name__ == "__main__":
    unittest.main()