diff --git a/cpp/tensorrt_llm/kernels/IndexerKCacheScatter.h b/cpp/tensorrt_llm/kernels/IndexerKCacheScatter.h
new file mode 100644
index 00000000000..b0ac689d38b
--- /dev/null
+++ b/cpp/tensorrt_llm/kernels/IndexerKCacheScatter.h
@@ -0,0 +1,30 @@
+/*
+ * Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include "tensorrt_llm/common/cudaUtils.h"
+
+namespace tensorrt_llm::kernels
+{
+
+void invokeIndexerKCacheScatter(uint8_t const* k_fp8_bytes, uint8_t const* k_scale_bytes, uint8_t* k_cache,
+    int64_t const* slot_mapping_fp8, int64_t const* slot_mapping_scale, int32_t num_tokens, int32_t head_dim,
+    int32_t scale_size, int32_t cache_dim_0, int32_t cache_dim_1, int32_t cache_dim_2, int32_t cache_dim_3,
+    int64_t cache_stride_0, int64_t cache_stride_1, int64_t cache_stride_2, int64_t cache_stride_3,
+    cudaStream_t stream = 0);
+
+}
diff --git a/cpp/tensorrt_llm/kernels/indexerKCacheScatter.cu b/cpp/tensorrt_llm/kernels/indexerKCacheScatter.cu
new file mode 100644
index 00000000000..3cb35273a94
--- /dev/null
+++ b/cpp/tensorrt_llm/kernels/indexerKCacheScatter.cu
@@ -0,0 +1,152 @@
+/*
+ * Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "IndexerKCacheScatter.h"
+#include "tensorrt_llm/common/assert.h"
+#include "tensorrt_llm/common/cudaUtils.h"
+
+namespace tensorrt_llm::kernels
+{
+
+namespace
+{
+/**
+ * Given a flat element index and tensor shape [d0, d1, d2, d3] with strides [s0, s1, s2, s3],
+ * find the actual memory offset within the given k cache pool using the strides.
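+ *
+ * Worked example (illustrative numbers only): with shape [4, 64, 1, 132] and flat_idx = 8452,
+ * unravelling yields i3 = 4, i2 = 0, i1 = 0, i0 = 1, so the returned offset is
+ * 1 * s0 + 0 * s1 + 0 * s2 + 4 * s3.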
+ */
+__device__ __forceinline__ int64_t flatIndexToMemoryOffset(
+    int64_t flat_idx, int32_t d0, int32_t d1, int32_t d2, int32_t d3, int64_t s0, int64_t s1, int64_t s2, int64_t s3)
+{
+    // Unravel from innermost to outermost dimension
+    int32_t i3 = flat_idx % d3;
+    flat_idx /= d3;
+
+    int32_t i2 = flat_idx % d2;
+    flat_idx /= d2;
+
+    int32_t i1 = flat_idx % d1;
+    flat_idx /= d1;
+
+    int32_t i0 = flat_idx;
+
+    // Compute memory offset using strides
+    return i0 * s0 + i1 * s1 + i2 * s2 + i3 * s3;
+}
+
+} // anonymous namespace
+
+/**
+ * CUDA kernel to scatter both FP8 K values and scales into the indexer k cache pool
+ *
+ * @param k_fp8_bytes Quantized FP8 data [num_tokens, 128]
+ * @param k_scale_bytes Quantized scales (1 per token) [num_tokens, 4]
+ * @param k_cache Indexer k cache pool with shape [num_blocks, block_size, 1, per_token_size] (can be
+ * non-contiguous)
+ * @param slot_mapping_fp8 Flat element index for FP8 data start position [num_tokens]
+ * @param slot_mapping_scale Flat element index for scale data start position [num_tokens]
+ * @param num_tokens Number of tokens
+ * @param head_dim Head dimension (must be 128)
+ * @param scale_size Scale size in bytes (must be 4)
+ * @param cache_stride_0 Stride for k_cache dimension 0 (in bytes)
+ * @param cache_stride_1 Stride for k_cache dimension 1 (in bytes)
+ * @param cache_stride_2 Stride for k_cache dimension 2 (in bytes)
+ * @param cache_stride_3 Stride for k_cache dimension 3 (in bytes)
+ * @param cache_dim_0 Size of k_cache dimension 0
+ * @param cache_dim_1 Size of k_cache dimension 1
+ * @param cache_dim_2 Size of k_cache dimension 2
+ * @param cache_dim_3 Size of k_cache dimension 3
+ */
+__global__ void indexerKCacheScatterUnifiedKernel(uint8_t const* __restrict__ k_fp8_bytes,
+    uint8_t const* __restrict__ k_scale_bytes, uint8_t* __restrict__ k_cache,
+    int64_t const* __restrict__ slot_mapping_fp8, int64_t const* __restrict__ slot_mapping_scale, int32_t num_tokens,
+    int32_t head_dim, int32_t scale_size, int64_t cache_stride_0, int64_t cache_stride_1, int64_t cache_stride_2,
+    int64_t cache_stride_3, int32_t cache_dim_0, int32_t cache_dim_1, int32_t cache_dim_2, int32_t cache_dim_3)
+{
+    // For head_dim=128, each thread handles 4 bytes/elements per read/write instruction
+    constexpr int VEC_SIZE = 4;
+
+    // Token index from block.x
+    int32_t token_idx = blockIdx.x;
+
+    if (token_idx >= num_tokens)
+    {
+        return;
+    }
+
+    int64_t flat_idx_fp8_base = slot_mapping_fp8[token_idx];
+    int64_t flat_idx_scale_base = slot_mapping_scale[token_idx];
+
+    if (flat_idx_fp8_base < 0 || flat_idx_scale_base < 0)
+    {
+        return;
+    }
+
+    int32_t head_dim_idx = threadIdx.x * VEC_SIZE;
+    int64_t flat_idx = flat_idx_fp8_base + head_dim_idx;
+
+    // Convert flat index to memory offset using strides (k cache pool from cpp kv cache manager is non-contiguous)
+    int64_t dst_offset = flatIndexToMemoryOffset(flat_idx, cache_dim_0, cache_dim_1, cache_dim_2, cache_dim_3,
+        cache_stride_0, cache_stride_1, cache_stride_2, cache_stride_3);
+    int64_t src_offset = token_idx * head_dim + head_dim_idx;
+
+    // 4 bytes write
+    *reinterpret_cast<uint32_t*>(&k_cache[dst_offset]) = *reinterpret_cast<uint32_t const*>(&k_fp8_bytes[src_offset]);
+
+    // Only thread 0 writes the single 4 bytes scale value
+    if (threadIdx.x == 0)
+    {
+        int64_t dst_offset_scale = flatIndexToMemoryOffset(flat_idx_scale_base, cache_dim_0, cache_dim_1, cache_dim_2,
+            cache_dim_3, cache_stride_0, cache_stride_1, cache_stride_2, cache_stride_3);
+        int64_t src_offset_scale = token_idx * scale_size; // scale_size = 4
+
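+        // Note: the 4-byte vectorized loads/stores in this kernel assume 4-byte-aligned offsets,
+        // which the per-token layout of 128 FP8 bytes followed by a 4-byte scale is expected to satisfy.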
+        // 4 bytes write for scale
+        *reinterpret_cast<uint32_t*>(&k_cache[dst_offset_scale])
+            = *reinterpret_cast<uint32_t const*>(&k_scale_bytes[src_offset_scale]);
+    }
+}
+
+void invokeIndexerKCacheScatter(uint8_t const* k_fp8_bytes, uint8_t const* k_scale_bytes, uint8_t* k_cache,
+    int64_t const* slot_mapping_fp8, int64_t const* slot_mapping_scale, int32_t num_tokens, int32_t head_dim,
+    int32_t scale_size, int32_t cache_dim_0, int32_t cache_dim_1, int32_t cache_dim_2, int32_t cache_dim_3,
+    int64_t cache_stride_0, int64_t cache_stride_1, int64_t cache_stride_2, int64_t cache_stride_3, cudaStream_t stream)
+{
+    if (num_tokens == 0)
+    {
+        return;
+    }
+
+    // Assertions for DeepSeek-V3.2 configuration
+    constexpr int32_t QUANT_BLOCK_SIZE = 128;
+    TLLM_CHECK_WITH_INFO(
+        head_dim == QUANT_BLOCK_SIZE, "head_dim must equal 128 for DeepSeek-V3 indexer cache (got %d)", head_dim);
+    TLLM_CHECK_WITH_INFO(
+        scale_size == 4, "scale_size must equal 4 bytes (1 float32 scale per token, got %d)", scale_size);
+
+    // For head_dim=128, we use 32 threads to handle 128 bytes per token and extra 4 bytes for scale
+    constexpr int32_t THREADS_PER_BLOCK = 32;
+
+    dim3 block(THREADS_PER_BLOCK);
+    dim3 grid(num_tokens);
+
+    indexerKCacheScatterUnifiedKernel<<<grid, block, 0, stream>>>(k_fp8_bytes, k_scale_bytes, k_cache, slot_mapping_fp8,
+        slot_mapping_scale, num_tokens, head_dim, scale_size, cache_stride_0, cache_stride_1, cache_stride_2,
+        cache_stride_3, cache_dim_0, cache_dim_1, cache_dim_2, cache_dim_3);
+
+    // Check for kernel launch errors
+    TLLM_CUDA_CHECK(cudaGetLastError());
+}
+
+} // namespace tensorrt_llm::kernels
diff --git a/cpp/tensorrt_llm/thop/CMakeLists.txt b/cpp/tensorrt_llm/thop/CMakeLists.txt
index e472e3e133e..9354669468b 100644
--- a/cpp/tensorrt_llm/thop/CMakeLists.txt
+++ b/cpp/tensorrt_llm/thop/CMakeLists.txt
@@ -83,6 +83,7 @@ add_library(
   fp8PerTensorScaleMoe.cpp
   fp4BlockScaleMoe.cpp
   noAuxTcOp.cpp
+  IndexerKCacheScatterOp.cpp
   ncclCommunicatorOp.cpp
   parallelDecodeKVCacheUpdateOp.cpp
   redrafterCurandOp.cpp
diff --git a/cpp/tensorrt_llm/thop/IndexerKCacheScatterOp.cpp b/cpp/tensorrt_llm/thop/IndexerKCacheScatterOp.cpp
new file mode 100644
index 00000000000..b94674f1ca7
--- /dev/null
+++ b/cpp/tensorrt_llm/thop/IndexerKCacheScatterOp.cpp
@@ -0,0 +1,106 @@
+/*
+ * Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "tensorrt_llm/common/opUtils.h"
+#include "tensorrt_llm/runtime/torchUtils.h"
+
+#include "tensorrt_llm/kernels/IndexerKCacheScatter.h"
+
+namespace th = torch;
+namespace tl = tensorrt_llm;
+namespace tk = tensorrt_llm::kernels;
+
+namespace torch_ext
+{
+
+void indexer_k_cache_scatter_op(th::Tensor const& k_fp8_bytes, th::Tensor const& k_scale_bytes, th::Tensor& k_cache,
+    th::Tensor const& slot_mapping_fp8, th::Tensor const& slot_mapping_scale)
+{
+    // Validate all tensors are CUDA tensors
+    TORCH_CHECK(k_fp8_bytes.is_cuda() && k_scale_bytes.is_cuda() && k_cache.is_cuda() && slot_mapping_fp8.is_cuda()
+            && slot_mapping_scale.is_cuda(),
+        "All tensors must be CUDA tensors");
+
+    // Validate tensor dimensions
+    TORCH_CHECK(k_fp8_bytes.dim() == 2, "k_fp8_bytes must be a 2D Tensor [num_tokens, head_dim]");
+    TORCH_CHECK(k_scale_bytes.dim() == 2, "k_scale_bytes must be a 2D Tensor [num_tokens, scale_size]");
+    TORCH_CHECK(slot_mapping_fp8.dim() == 1, "slot_mapping_fp8 must be a 1D Tensor [num_tokens]");
+    TORCH_CHECK(slot_mapping_scale.dim() == 1, "slot_mapping_scale must be a 1D Tensor [num_tokens]");
+
+    // Enforce k_cache is 4D tensor
+    TORCH_CHECK(k_cache.dim() == 4,
+        "k_cache must be a 4D Tensor [num_blocks, block_size, 1, per_token_size], got ", k_cache.dim(), " dimensions");
+
+    // Validate tensor dtypes
+    TORCH_CHECK(k_fp8_bytes.scalar_type() == torch::kUInt8, "k_fp8_bytes must be uint8");
+    TORCH_CHECK(k_scale_bytes.scalar_type() == torch::kUInt8, "k_scale_bytes must be uint8");
+    TORCH_CHECK(slot_mapping_fp8.scalar_type() == torch::kInt64, "slot_mapping_fp8 must be int64");
+    TORCH_CHECK(slot_mapping_scale.scalar_type() == torch::kInt64, "slot_mapping_scale must be int64");
+
+    // Validate tensor shapes are consistent
+    auto num_tokens = static_cast<int32_t>(k_fp8_bytes.size(0));
+    TORCH_CHECK(
+        k_scale_bytes.size(0) == num_tokens, "k_scale_bytes first dimension must equal k_fp8_bytes first dimension");
+    TORCH_CHECK(slot_mapping_fp8.size(0) == num_tokens, "slot_mapping_fp8 length must equal num_tokens");
+    TORCH_CHECK(slot_mapping_scale.size(0) == num_tokens, "slot_mapping_scale length must equal num_tokens");
+
+    // Validate tensors are contiguous (except k_cache which may be non-contiguous)
+    TORCH_CHECK(k_fp8_bytes.is_contiguous(), "k_fp8_bytes must be contiguous");
+    TORCH_CHECK(k_scale_bytes.is_contiguous(), "k_scale_bytes must be contiguous");
+    // k_cache can be non-contiguous - we handle this via strides
+    TORCH_CHECK(slot_mapping_fp8.is_contiguous(), "slot_mapping_fp8 must be contiguous");
+    TORCH_CHECK(slot_mapping_scale.is_contiguous(), "slot_mapping_scale must be contiguous");
+
+    int32_t head_dim = static_cast<int32_t>(k_fp8_bytes.size(1));     // head_dim = quant_block_size = 128
+    int32_t scale_size = static_cast<int32_t>(k_scale_bytes.size(1)); // scale_size = 4 bytes
+
+    int32_t cache_dim_0 = static_cast<int32_t>(k_cache.size(0)); // num_blocks
+    int32_t cache_dim_1 = static_cast<int32_t>(k_cache.size(1)); // block_size
+    int32_t cache_dim_2 = static_cast<int32_t>(k_cache.size(2)); // num_kv_heads
+    int32_t cache_dim_3 = static_cast<int32_t>(k_cache.size(3)); // per_token_size
+
+    // Validation for indexer k cache pool for DeepSeek-V3.2 constraints
+    TORCH_CHECK(cache_dim_2 == 1, "k_cache dimension 2 must be 1 for DeepSeek-V3.2, got ", cache_dim_2);
+    TORCH_CHECK(head_dim == 128, "k_fp8_bytes head_dim must be 128 for DeepSeek-V3.2, got ", head_dim);
+    TORCH_CHECK(scale_size == 4, "k_scale_bytes scale_size must be 4 bytes for DeepSeek-V3.2, got ", scale_size);
+
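+    // Note: k_cache is a uint8 (byte) tensor, so the element strides queried below are already
+    // byte strides and can be passed to the kernel unchanged.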
+    int64_t cache_stride_0 = static_cast<int64_t>(k_cache.stride(0));
+    int64_t cache_stride_1 = static_cast<int64_t>(k_cache.stride(1));
+    int64_t cache_stride_2 = static_cast<int64_t>(k_cache.stride(2));
+    int64_t cache_stride_3 = static_cast<int64_t>(k_cache.stride(3));
+
+    auto stream = at::cuda::getCurrentCUDAStream(k_fp8_bytes.get_device());
+
+    tk::invokeIndexerKCacheScatter(k_fp8_bytes.data_ptr<uint8_t>(), k_scale_bytes.data_ptr<uint8_t>(),
+        k_cache.data_ptr<uint8_t>(), slot_mapping_fp8.data_ptr<int64_t>(), slot_mapping_scale.data_ptr<int64_t>(),
+        num_tokens, head_dim, scale_size, cache_dim_0, cache_dim_1, cache_dim_2, cache_dim_3, cache_stride_0,
+        cache_stride_1, cache_stride_2, cache_stride_3, stream);
+}
+
+} // namespace torch_ext
+
+TORCH_LIBRARY_FRAGMENT(trtllm, m)
+{
+    m.def(
+        "indexer_k_cache_scatter_op(Tensor k_fp8_bytes, Tensor k_scale_bytes, Tensor(a!) k_cache, "
+        "Tensor slot_mapping_fp8, Tensor slot_mapping_scale) -> ()");
+}
+
+TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
+{
+    m.impl("indexer_k_cache_scatter_op", &torch_ext::indexer_k_cache_scatter_op);
+}
diff --git a/tensorrt_llm/_torch/attention_backend/sparse/dsa.py b/tensorrt_llm/_torch/attention_backend/sparse/dsa.py
index 8fa57e20135..3a1e1169c18 100644
--- a/tensorrt_llm/_torch/attention_backend/sparse/dsa.py
+++ b/tensorrt_llm/_torch/attention_backend/sparse/dsa.py
@@ -872,24 +872,12 @@ def _update_k_cache(self, k_fp8: torch.Tensor, k_scale: torch.Tensor,
         k_scale_bytes = k_scale_flat.view(torch.uint8).view(
             num_tokens, scale_size)
 
-        # Scatter FP8 data
+        # Use CUDA kernel to scatter FP8 and scale bytes into cache
         flat_indices_fp8 = metadata.slot_mapping_fp8[:num_tokens]
-        byte_offsets = torch.arange(head_dim, device=k_cache.device).unsqueeze(
-            0)  # [1, head_dim]
-        scatter_indices_fp8 = flat_indices_fp8.unsqueeze(
-            1) + byte_offsets  # [num_tokens, head_dim]
-        scatter_indices_fp8 = _unravel_indices(scatter_indices_fp8,
-                                               k_cache.shape)
-        k_cache[scatter_indices_fp8] = k_fp8_bytes
-
         flat_indices_scale = metadata.slot_mapping_scale[:num_tokens]
-        byte_offsets = torch.arange(
-            scale_size, device=k_cache.device).unsqueeze(0)  # [1, scale_size]
-        scatter_indices_scale = flat_indices_scale.unsqueeze(
-            1) + byte_offsets  # [num_tokens, scale_size]
-        scatter_indices_scale = _unravel_indices(scatter_indices_scale,
-                                                 k_cache.shape)
-        k_cache[scatter_indices_scale] = k_scale_bytes
+        torch.ops.trtllm.indexer_k_cache_scatter_op(k_fp8_bytes, k_scale_bytes,
+                                                    k_cache, flat_indices_fp8,
+                                                    flat_indices_scale)
 
     def _gather_k_cache_for_chunk(
         self,
diff --git a/tests/unittest/_torch/attention/sparse/test_dsa_indexer.py b/tests/unittest/_torch/attention/sparse/test_dsa_indexer.py
index e921fd891b0..8414fde36d0 100644
--- a/tests/unittest/_torch/attention/sparse/test_dsa_indexer.py
+++ b/tests/unittest/_torch/attention/sparse/test_dsa_indexer.py
@@ -12,7 +12,7 @@
 import pytest
 import torch
 
-from utils.util import check_accuracy, getSMVersion
+from utils.util import check_accuracy, skip_pre_hopper
 
 from tensorrt_llm import deep_gemm
 from tensorrt_llm._torch.attention_backend.interface import (
@@ -70,12 +70,9 @@ def __init__(self, index_head_dim, index_n_heads, index_topk):
                                        index_topk=2048)
 
     # Create KV cache config
-    # Note: max_attention_window expects list[int] (one per layer)
     kv_cache_config = KvCacheConfig(
         enable_block_reuse=False,
         max_tokens=max_seq_len * batch_size,
-        max_attention_window=[max_seq_len] *
-        num_layers,  # List of max window per layer
     )
 
     # Create mapping (single GPU, no parallelism)
@@ -303,8 +300,7 @@ def _ref_fp8_mqa_logits(
 
 
 @pytest.mark.skipif(not has_deep_gemm(), reason="DeepGEMM not available")
-@pytest.mark.skipif(getSMVersion() < 90,
-                    reason="fp8_mqa_logits is only supported in SM90 and SM100")
+@skip_pre_hopper
 def test_deepgemm_fp8_mqa_logits_basic():
     """
     Basic test for deepgemm.fp8_mqa_logits kernel.
@@ -477,7 +473,179 @@ def __init__(self):
 
 
 @pytest.mark.skipif(not has_deep_gemm(), reason="DeepGEMM not available")
-@pytest.mark.skipif(getSMVersion() < 90, reason="FP8 operations require SM90+")
+@skip_pre_hopper
+def test_indexer_k_cache_scatter_custom_op():
+    """
+    Direct comparison: CUDA kernel vs Python reference for k_cache scatter.
+
+    This test ensures the new CUDA kernel indexer_k_cache_scatter_op produces
+    exactly the same results as the Python scatter implementation.
+    """
+    torch.manual_seed(123)
+
+    # Test parameters
+    head_dim = 128
+    block_size = 64
+    batch_size = 3
+    num_tokens = 96  # 3 requests × 32 tokens each
+    max_seq_len = 512
+
+    # Use different layers for CUDA vs Python to test non-contiguous handling
+    layer_idx_cuda = 1  # CUDA kernel writes to layer 1
+    layer_idx_python = 2  # Python reference writes to layer 2
+
+    # Create cache manager with multiple layers
+    cache_manager, sparse_attn_config = create_dsa_cache_manager(
+        batch_size=batch_size,
+        head_dim=head_dim,
+        tokens_per_block=block_size,
+        max_seq_len=max_seq_len,
+        num_layers=3)  # Multi-layer pool for non-contiguous test
+
+    # Allocate blocks
+    request_ids = list(range(batch_size))
+    tokens_per_req = [32, 32, 32]
+    cache_manager.add_dummy_requests(request_ids,
+                                     tokens_per_req,
+                                     is_gen=False,
+                                     prepare_resource=True)
+
+    # Create metadata
+    metadata = _create_mock_metadata(
+        request_ids,
+        batch_size,
+        num_contexts=batch_size,
+        num_generations=0,
+        seq_lens=torch.tensor(tokens_per_req, dtype=torch.int32),
+        kv_lens=torch.tensor(tokens_per_req, dtype=torch.int32),
+        num_cached_tokens=[0] * batch_size,
+        cache_manager=cache_manager,
+        num_ctx_tokens=num_tokens,
+        num_tokens=num_tokens,
+    )
+
+    from tensorrt_llm._torch.attention_backend.sparse.dsa import Indexer
+    Indexer.prepare(metadata)
+
+    # Generate test data
+    k_original = torch.randn((num_tokens, head_dim),
+                             device="cuda",
+                             dtype=torch.bfloat16)
+    k_fp8, k_scale = fp8_utils.fp8_quantize_1x128_sf_transpose(k_original)
+
+    # Prepare byte-level data
+    scale_size = k_scale.shape[1] * 4
+    k_fp8_bytes = k_fp8.view(-1).view(torch.uint8).view(num_tokens, head_dim)
+    k_scale_flat = k_scale.view(-1)
+    if k_scale_flat.stride(-1) != 1:
+        k_scale_flat = torch.as_strided(k_scale_flat.contiguous(),
+                                        size=(k_scale_flat.numel(), ),
+                                        stride=(1, ))
+    k_scale_bytes = k_scale_flat.view(torch.uint8).view(num_tokens, scale_size)
+
+    flat_indices_fp8 = metadata.slot_mapping_fp8[:num_tokens]
+    flat_indices_scale = metadata.slot_mapping_scale[:num_tokens]
+
+    # ========== Use Different Layers for CUDA vs Python ==========
+    # Simple approach: use layer 1 for CUDA, layer 2 for Python
+    # Both get the same input data, but write to different layers
+    # Then we extract and compare the outputs from each layer
+
+    # Get k_cache for CUDA path (layer 1)
+    k_cache_cuda = cache_manager.get_indexer_k_cache_buffers(layer_idx_cuda)
+    k_cache_cuda.zero_()
+
+    # Get k_cache for Python path (layer 2)
+    k_cache_python = cache_manager.get_indexer_k_cache_buffers(layer_idx_python)
+    k_cache_python.zero_()
+
+    # Print cache properties
+    print(f"\n=== Cache Properties ===")
+    print(f"  CUDA (layer {layer_idx_cuda}):")
+    print(f"    Shape: {k_cache_cuda.shape}")
+    print(f"    Stride: {k_cache_cuda.stride()}")
+    print(f"    is_contiguous: {k_cache_cuda.is_contiguous()}")
+    print(f"  Python (layer {layer_idx_python}):")
+    print(f"    Shape: {k_cache_python.shape}")
+    print(f"    Stride: {k_cache_python.stride()}")
+    print(f"    is_contiguous: {k_cache_python.is_contiguous()}")
+
+    # ========== Path 1: CUDA Kernel ==========
+    print(f"\n=== Path 1: CUDA Kernel ===")
+    torch.ops.trtllm.indexer_k_cache_scatter_op(k_fp8_bytes, k_scale_bytes,
+                                                k_cache_cuda, flat_indices_fp8,
+                                                flat_indices_scale)
+    torch.cuda.synchronize()
+    print(f"✓ CUDA kernel completed")
+
+    # ========== Path 2: Python Reference ==========
+    print(f"\n=== Path 2: Python Reference ===")
+
+    def _unravel_indices(flat_indices, shape):
+        d3 = shape[3]
+        i3 = flat_indices % d3
+        flat_indices = flat_indices // d3
+        d2 = shape[2]
+        i2 = flat_indices % d2
+        flat_indices = flat_indices // d2
+        d1 = shape[1]
+        i1 = flat_indices % d1
+        flat_indices = flat_indices // d1
+        i0 = flat_indices
+        return i0, i1, i2, i3
+
+    # Scatter FP8 data
+    byte_offsets = torch.arange(head_dim,
+                                device=k_cache_python.device).unsqueeze(0)
+    scatter_indices_fp8 = flat_indices_fp8.unsqueeze(1) + byte_offsets
+    scatter_indices_fp8 = _unravel_indices(scatter_indices_fp8,
+                                           k_cache_python.shape)
+    k_cache_python[scatter_indices_fp8] = k_fp8_bytes
+
+    # Scatter scale data
+    byte_offsets = torch.arange(scale_size,
+                                device=k_cache_python.device).unsqueeze(0)
+    scatter_indices_scale = flat_indices_scale.unsqueeze(1) + byte_offsets
+    scatter_indices_scale = _unravel_indices(scatter_indices_scale,
+                                             k_cache_python.shape)
+    k_cache_python[scatter_indices_scale] = k_scale_bytes
+
+    # ========== Validation: Byte-for-Byte Comparison ==========
+    print(f"\n=== Validation ===")
+
+    total_bytes = k_cache_cuda.numel()
+
+    # Compare entire cache tensors
+    if torch.equal(k_cache_cuda, k_cache_python):
+        print(f"✅ PERFECT MATCH! CUDA and Python produce identical cache")
+        print(f"   Total bytes compared: {total_bytes}")
+        print(
+            f"   Tokens: {num_tokens}, head_dim: {head_dim}, block_size: {block_size}"
+        )
+    else:
+        # Find differences
+        diff_mask = k_cache_cuda != k_cache_python
+        num_diffs = diff_mask.sum().item()
+
+        print(
+            f"⚠️ Found {num_diffs}/{total_bytes} byte differences ({100*num_diffs/total_bytes:.4f}%)"
+        )
+
+        # Show first few differences
+        diff_indices = torch.nonzero(diff_mask.view(-1))[:5]
+        for idx in diff_indices:
+            flat_idx = idx.item()
+            print(
+                f"   Byte {flat_idx}: CUDA={k_cache_cuda.view(-1)[flat_idx].item()}, "
+                f"Python={k_cache_python.view(-1)[flat_idx].item()}")
+
+        # Fail the test
+        raise AssertionError(
+            "CUDA kernel produced different results than Python reference")
+
+
+@pytest.mark.skipif(not has_deep_gemm(), reason="DeepGEMM not available")
+@skip_pre_hopper
 def test_fp8_k_cache_roundtrip():
     """Verify FP8 quantization scales survive write/read cycle for multiple requests."""
     torch.manual_seed(42)
@@ -562,9 +730,7 @@ def test_fp8_k_cache_roundtrip():
 
 
 @pytest.mark.skipif(not has_deep_gemm(), reason="DeepGEMM not available")
-@pytest.mark.skipif(
-    getSMVersion() < 90,
-    reason="fp8_paged_mqa_logits is only supported in SM90 and SM100")
+@skip_pre_hopper
 @pytest.mark.parametrize("batch_size,next_n", [(4, 1), (2, 2)])
 def test_indexer_decode_with_paged_kv_cache(batch_size, next_n):
     """
@@ -859,7 +1025,7 @@ def test_split_prefill_chunks(max_chunk_size, seq_lens, start_idx,
 
 
 @pytest.mark.skipif(not has_deep_gemm(), reason="DeepGEMM not available")
-@pytest.mark.skipif(getSMVersion() < 90, reason="FP8 operations require SM90+")
+@skip_pre_hopper
 @pytest.mark.parametrize(
     "chunk_size,seq_lens_list,chunking_type",
     [
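
Reviewer note (not part of the patch): a minimal, self-contained sketch of how the new custom op can be driven in isolation. The 2-block pool shape, the 128 + 4 byte per-token record, and the contiguous slot mappings below are made up for illustration; in TensorRT-LLM the pool layout and slot mappings come from the KV cache manager and the indexer metadata.

    # Assumes the trtllm torch extension has been loaded (e.g. by importing tensorrt_llm).
    import torch

    num_tokens, head_dim, scale_size = 4, 128, 4
    per_token_size = head_dim + scale_size  # 132 bytes per token record (assumed layout)

    # Byte-typed cache pool: [num_blocks, block_size, 1, per_token_size]
    k_cache = torch.zeros((2, 64, 1, per_token_size), dtype=torch.uint8, device="cuda")

    k_fp8_bytes = torch.randint(0, 256, (num_tokens, head_dim), dtype=torch.uint8, device="cuda")
    k_scale_bytes = torch.randint(0, 256, (num_tokens, scale_size), dtype=torch.uint8, device="cuda")

    # Flat element indices into k_cache where each token's FP8 bytes / scale bytes begin.
    token_base = torch.arange(num_tokens, dtype=torch.int64, device="cuda") * per_token_size
    slot_mapping_fp8 = token_base               # FP8 region starts at byte 0 of the record
    slot_mapping_scale = token_base + head_dim  # scale follows the 128 FP8 bytes

    torch.ops.trtllm.indexer_k_cache_scatter_op(k_fp8_bytes, k_scale_bytes, k_cache,
                                                slot_mapping_fp8, slot_mapping_scale)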