30 changes: 30 additions & 0 deletions cpp/tensorrt_llm/kernels/IndexerKCacheScatter.h
@@ -0,0 +1,30 @@
/*
* Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include "tensorrt_llm/common/cudaUtils.h"

namespace tensorrt_llm::kernels
{

void invokeIndexerKCacheScatter(uint8_t const* k_fp8_bytes, uint8_t const* k_scale_bytes, uint8_t* k_cache,
int64_t const* slot_mapping_fp8, int64_t const* slot_mapping_scale, int32_t num_tokens, int32_t head_dim,
int32_t scale_size, int32_t cache_dim_0, int32_t cache_dim_1, int32_t cache_dim_2, int32_t cache_dim_3,
int64_t cache_stride_0, int64_t cache_stride_1, int64_t cache_stride_2, int64_t cache_stride_3,
cudaStream_t stream = 0);

}
152 changes: 152 additions & 0 deletions cpp/tensorrt_llm/kernels/indexerKCacheScatter.cu
@@ -0,0 +1,152 @@
/*
* Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "IndexerKCacheScatter.h"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaUtils.h"

namespace tensorrt_llm::kernels
{

namespace
{
/**
* Given a flat element index and tensor shape [d0, d1, d2, d3] with strides [s0, s1, s2, s3],
* find the actual memory offset within the given k cache pool using the strides.
*/
__device__ __forceinline__ int64_t flatIndexToMemoryOffset(
int64_t flat_idx, int32_t d0, int32_t d1, int32_t d2, int32_t d3, int64_t s0, int64_t s1, int64_t s2, int64_t s3)
{
// Unravel from innermost to outermost dimension
int32_t i3 = flat_idx % d3;
flat_idx /= d3;

int32_t i2 = flat_idx % d2;
flat_idx /= d2;

int32_t i1 = flat_idx % d1;
flat_idx /= d1;

int32_t i0 = flat_idx;

// Compute memory offset using strides
return i0 * s0 + i1 * s1 + i2 * s2 + i3 * s3;
}

} // anonymous namespace

/**
* CUDA kernel to scatter both FP8 K values and scales into the indexer k cache pool
*
* @param k_fp8_bytes Quantized FP8 data [num_tokens, 128]
* @param k_scale_bytes Quantized scales (1 per token) [num_tokens, 4]
* @param k_cache Indexer k cache pool with shape [num_blocks, block_size, 1, per_token_size] (can be
* non-contiguous)
* @param slot_mapping_fp8 Flat element index for FP8 data start position [num_tokens]
* @param slot_mapping_scale Flat element index for scale data start position [num_tokens]
* @param num_tokens Number of tokens
* @param head_dim Head dimension (must be 128)
* @param scale_size Scale size in bytes (must be 4)
* @param cache_stride_0 Stride for k_cache dimension 0 (in bytes)
* @param cache_stride_1 Stride for k_cache dimension 1 (in bytes)
* @param cache_stride_2 Stride for k_cache dimension 2 (in bytes)
* @param cache_stride_3 Stride for k_cache dimension 3 (in bytes)
* @param cache_dim_0 Size of k_cache dimension 0
* @param cache_dim_1 Size of k_cache dimension 1
* @param cache_dim_2 Size of k_cache dimension 2
* @param cache_dim_3 Size of k_cache dimension 3
*/
__global__ void indexerKCacheScatterUnifiedKernel(uint8_t const* __restrict__ k_fp8_bytes,
uint8_t const* __restrict__ k_scale_bytes, uint8_t* __restrict__ k_cache,
int64_t const* __restrict__ slot_mapping_fp8, int64_t const* __restrict__ slot_mapping_scale, int32_t num_tokens,
int32_t head_dim, int32_t scale_size, int64_t cache_stride_0, int64_t cache_stride_1, int64_t cache_stride_2,
int64_t cache_stride_3, int32_t cache_dim_0, int32_t cache_dim_1, int32_t cache_dim_2, int32_t cache_dim_3)
{
// For head_dim=128, each thread handles 4 bytes/elements per read/write instruction
constexpr int VEC_SIZE = 4;

    // Token index from blockIdx.x
int32_t token_idx = blockIdx.x;

if (token_idx >= num_tokens)
{
return;
}

int64_t flat_idx_fp8_base = slot_mapping_fp8[token_idx];
int64_t flat_idx_scale_base = slot_mapping_scale[token_idx];

if (flat_idx_fp8_base < 0 || flat_idx_scale_base < 0)
{
return;
}

int32_t head_dim_idx = threadIdx.x * VEC_SIZE;
int64_t flat_idx = flat_idx_fp8_base + head_dim_idx;

    // Convert the flat index to a memory offset using strides (the k cache pool from the C++ KV cache manager may be non-contiguous)
int64_t dst_offset = flatIndexToMemoryOffset(flat_idx, cache_dim_0, cache_dim_1, cache_dim_2, cache_dim_3,
cache_stride_0, cache_stride_1, cache_stride_2, cache_stride_3);
int64_t src_offset = token_idx * head_dim + head_dim_idx;

    // 4-byte write of the FP8 payload
*reinterpret_cast<uint32_t*>(&k_cache[dst_offset]) = *reinterpret_cast<uint32_t const*>(&k_fp8_bytes[src_offset]);

    // Only thread 0 writes the single 4-byte scale value
if (threadIdx.x == 0)
{
int64_t dst_offset_scale = flatIndexToMemoryOffset(flat_idx_scale_base, cache_dim_0, cache_dim_1, cache_dim_2,
cache_dim_3, cache_stride_0, cache_stride_1, cache_stride_2, cache_stride_3);
int64_t src_offset_scale = token_idx * scale_size; // scale_size = 4

        // 4-byte write for the scale
*reinterpret_cast<uint32_t*>(&k_cache[dst_offset_scale])
= *reinterpret_cast<uint32_t const*>(&k_scale_bytes[src_offset_scale]);
}
}

void invokeIndexerKCacheScatter(uint8_t const* k_fp8_bytes, uint8_t const* k_scale_bytes, uint8_t* k_cache,
int64_t const* slot_mapping_fp8, int64_t const* slot_mapping_scale, int32_t num_tokens, int32_t head_dim,
int32_t scale_size, int32_t cache_dim_0, int32_t cache_dim_1, int32_t cache_dim_2, int32_t cache_dim_3,
int64_t cache_stride_0, int64_t cache_stride_1, int64_t cache_stride_2, int64_t cache_stride_3, cudaStream_t stream)
{
if (num_tokens == 0)
{
return;
}

// Assertions for DeepSeek-V3.2 configuration
constexpr int32_t QUANT_BLOCK_SIZE = 128;
TLLM_CHECK_WITH_INFO(
head_dim == QUANT_BLOCK_SIZE, "head_dim must equal 128 for DeepSeek-V3 indexer cache (got %d)", head_dim);
TLLM_CHECK_WITH_INFO(
scale_size == 4, "scale_size must equal 4 bytes (1 float32 scale per token, got %d)", scale_size);

    // For head_dim=128, use 32 threads per token: 32 x 4-byte stores cover the 128 FP8 bytes, and thread 0 issues one extra 4-byte store for the scale
constexpr int32_t THREADS_PER_BLOCK = 32;

dim3 block(THREADS_PER_BLOCK);
dim3 grid(num_tokens);

indexerKCacheScatterUnifiedKernel<<<grid, block, 0, stream>>>(k_fp8_bytes, k_scale_bytes, k_cache, slot_mapping_fp8,
slot_mapping_scale, num_tokens, head_dim, scale_size, cache_stride_0, cache_stride_1, cache_stride_2,
cache_stride_3, cache_dim_0, cache_dim_1, cache_dim_2, cache_dim_3);

// Check for kernel launch errors
TLLM_CUDA_CHECK(cudaGetLastError());
}

} // namespace tensorrt_llm::kernels
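
For reference, the unraveling done by `flatIndexToMemoryOffset` can be mirrored on the host. The sketch below is illustrative only (the shapes and the sliced, non-contiguous cache are made up for the check); it verifies that a flat element index over the logical 4D shape, combined with the tensor's strides, lands on the same byte of the underlying storage. Because the cache is `uint8`, element strides coincide with byte strides, which is why the kernel can index bytes directly.

```python
import torch

def flat_index_to_memory_offset(flat_idx, dims, strides):
    # Unravel from innermost to outermost dimension, applying each axis' stride.
    offset = 0
    for d, s in zip(reversed(dims), reversed(strides)):
        offset += (flat_idx % d) * s
        flat_idx //= d
    return offset

# Illustrative non-contiguous 4D uint8 "cache": a slice of a wider buffer.
full = torch.arange(2 * 3 * 1 * 16, dtype=torch.uint8).reshape(2, 3, 1, 16)
k_cache = full[..., :12]          # shape [2, 3, 1, 12]; slicing the last dim makes it non-contiguous
dims, strides = list(k_cache.shape), list(k_cache.stride())  # for uint8, element stride == byte stride

flat_idx = 50                     # arbitrary flat index over the logical shape
offset = flat_index_to_memory_offset(flat_idx, dims, strides)

# The logical element at flat_idx lives at storage_offset() + offset in the underlying buffer.
logical = k_cache.reshape(-1)[flat_idx]    # reshape copies, giving the logical (contiguous) order
storage = full.reshape(-1)[k_cache.storage_offset() + offset]
assert logical == storage
```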
1 change: 1 addition & 0 deletions cpp/tensorrt_llm/thop/CMakeLists.txt
@@ -83,6 +83,7 @@ add_library(
fp8PerTensorScaleMoe.cpp
fp4BlockScaleMoe.cpp
noAuxTcOp.cpp
IndexerKCacheScatterOp.cpp
ncclCommunicatorOp.cpp
parallelDecodeKVCacheUpdateOp.cpp
redrafterCurandOp.cpp
106 changes: 106 additions & 0 deletions cpp/tensorrt_llm/thop/IndexerKCacheScatterOp.cpp
@@ -0,0 +1,106 @@
/*
* Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "tensorrt_llm/common/opUtils.h"
#include "tensorrt_llm/runtime/torchUtils.h"

#include "tensorrt_llm/kernels/IndexerKCacheScatter.h"

namespace th = torch;
namespace tl = tensorrt_llm;
namespace tk = tensorrt_llm::kernels;

namespace torch_ext
{

void indexer_k_cache_scatter_op(th::Tensor const& k_fp8_bytes, th::Tensor const& k_scale_bytes, th::Tensor& k_cache,
th::Tensor const& slot_mapping_fp8, th::Tensor const& slot_mapping_scale)
{
// Validate all tensors are CUDA tensors
TORCH_CHECK(k_fp8_bytes.is_cuda() && k_scale_bytes.is_cuda() && k_cache.is_cuda() && slot_mapping_fp8.is_cuda()
&& slot_mapping_scale.is_cuda(),
"All tensors must be CUDA tensors");

// Validate tensor dimensions
TORCH_CHECK(k_fp8_bytes.dim() == 2, "k_fp8_bytes must be a 2D Tensor [num_tokens, head_dim]");
TORCH_CHECK(k_scale_bytes.dim() == 2, "k_scale_bytes must be a 2D Tensor [num_tokens, scale_size]");
TORCH_CHECK(slot_mapping_fp8.dim() == 1, "slot_mapping_fp8 must be a 1D Tensor [num_tokens]");
TORCH_CHECK(slot_mapping_scale.dim() == 1, "slot_mapping_scale must be a 1D Tensor [num_tokens]");

// Enforce k_cache is 4D tensor
    TORCH_CHECK(k_cache.dim() == 4,
        "k_cache must be a 4D Tensor [num_blocks, block_size, 1, per_token_size], got ", k_cache.dim(),
        " dimensions");

// Validate tensor dtypes
TORCH_CHECK(k_fp8_bytes.scalar_type() == torch::kUInt8, "k_fp8_bytes must be uint8");
TORCH_CHECK(k_scale_bytes.scalar_type() == torch::kUInt8, "k_scale_bytes must be uint8");
TORCH_CHECK(slot_mapping_fp8.scalar_type() == torch::kInt64, "slot_mapping_fp8 must be int64");
TORCH_CHECK(slot_mapping_scale.scalar_type() == torch::kInt64, "slot_mapping_scale must be int64");

// Validate tensor shapes are consistent
auto num_tokens = static_cast<int32_t>(k_fp8_bytes.size(0));
TORCH_CHECK(
k_scale_bytes.size(0) == num_tokens, "k_scale_bytes first dimension must equal k_fp8_bytes first dimension");
TORCH_CHECK(slot_mapping_fp8.size(0) == num_tokens, "slot_mapping_fp8 length must equal num_tokens");
TORCH_CHECK(slot_mapping_scale.size(0) == num_tokens, "slot_mapping_scale length must equal num_tokens");

// Validate tensors are contiguous (except k_cache which may be non-contiguous)
TORCH_CHECK(k_fp8_bytes.is_contiguous(), "k_fp8_bytes must be contiguous");
TORCH_CHECK(k_scale_bytes.is_contiguous(), "k_scale_bytes must be contiguous");
// k_cache can be non-contiguous - we handle this via strides
TORCH_CHECK(slot_mapping_fp8.is_contiguous(), "slot_mapping_fp8 must be contiguous");
TORCH_CHECK(slot_mapping_scale.is_contiguous(), "slot_mapping_scale must be contiguous");

int32_t head_dim = static_cast<int32_t>(k_fp8_bytes.size(1)); // head_dim = quant_block_size = 128
int32_t scale_size = static_cast<int32_t>(k_scale_bytes.size(1)); // scale_size = 4 bytes

int32_t cache_dim_0 = static_cast<int32_t>(k_cache.size(0)); // num_blocks
int32_t cache_dim_1 = static_cast<int32_t>(k_cache.size(1)); // block_size
int32_t cache_dim_2 = static_cast<int32_t>(k_cache.size(2)); // num_kv_heads
int32_t cache_dim_3 = static_cast<int32_t>(k_cache.size(3)); // per_token_size

// Validation for indexer k cache pool for DeepSeek-V3.2 constraints
    TORCH_CHECK(cache_dim_2 == 1, "k_cache dimension 2 must be 1 for DeepSeek-V3.2, got ", cache_dim_2);
    TORCH_CHECK(head_dim == 128, "k_fp8_bytes head_dim must be 128 for DeepSeek-V3.2, got ", head_dim);
    TORCH_CHECK(scale_size == 4, "k_scale_bytes scale_size must be 4 bytes for DeepSeek-V3.2, got ", scale_size);

int64_t cache_stride_0 = static_cast<int64_t>(k_cache.stride(0));
int64_t cache_stride_1 = static_cast<int64_t>(k_cache.stride(1));
int64_t cache_stride_2 = static_cast<int64_t>(k_cache.stride(2));
int64_t cache_stride_3 = static_cast<int64_t>(k_cache.stride(3));

auto stream = at::cuda::getCurrentCUDAStream(k_fp8_bytes.get_device());

tk::invokeIndexerKCacheScatter(k_fp8_bytes.data_ptr<uint8_t>(), k_scale_bytes.data_ptr<uint8_t>(),
k_cache.data_ptr<uint8_t>(), slot_mapping_fp8.data_ptr<int64_t>(), slot_mapping_scale.data_ptr<int64_t>(),
num_tokens, head_dim, scale_size, cache_dim_0, cache_dim_1, cache_dim_2, cache_dim_3, cache_stride_0,
cache_stride_1, cache_stride_2, cache_stride_3, stream);
}

} // namespace torch_ext

TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
m.def(
"indexer_k_cache_scatter_op(Tensor k_fp8_bytes, Tensor k_scale_bytes, Tensor(a!) k_cache, "
"Tensor slot_mapping_fp8, Tensor slot_mapping_scale) -> ()");
}

TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
m.impl("indexer_k_cache_scatter_op", &torch_ext::indexer_k_cache_scatter_op);
}
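
A minimal usage sketch of the registered op follows, assuming the TensorRT-LLM torch extension is loaded (e.g. via `import tensorrt_llm`) and that each token's 128 FP8 bytes and 4 scale bytes are packed back-to-back in the cache's last dimension. The shapes, `per_token_size`, and slot mappings are hypothetical stand-ins for what the KV-cache manager actually produces.

```python
import torch
import tensorrt_llm  # noqa: F401  -- assumed to register the trtllm torch ops

num_tokens, head_dim, scale_size = 8, 128, 4
num_blocks, block_size = 4, 64
per_token_size = head_dim + scale_size  # assumed packing: 128 FP8 bytes + 4 scale bytes per token

k_fp8_bytes = torch.randint(0, 256, (num_tokens, head_dim), dtype=torch.uint8, device="cuda")
k_scale_bytes = torch.randint(0, 256, (num_tokens, scale_size), dtype=torch.uint8, device="cuda")
k_cache = torch.zeros(num_blocks, block_size, 1, per_token_size, dtype=torch.uint8, device="cuda")

# Flat element indices into the logical [num_blocks, block_size, 1, per_token_size] layout:
# token t goes to block 0, slot t; its FP8 payload starts at byte 0 of the slot, its scale at byte 128.
slot_base = torch.arange(num_tokens, dtype=torch.int64, device="cuda") * per_token_size
slot_mapping_fp8 = slot_base
slot_mapping_scale = slot_base + head_dim

torch.ops.trtllm.indexer_k_cache_scatter_op(k_fp8_bytes, k_scale_bytes, k_cache,
                                            slot_mapping_fp8, slot_mapping_scale)
```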
20 changes: 4 additions & 16 deletions tensorrt_llm/_torch/attention_backend/sparse/dsa.py
@@ -872,24 +872,12 @@ def _update_k_cache(self, k_fp8: torch.Tensor, k_scale: torch.Tensor,
k_scale_bytes = k_scale_flat.view(torch.uint8).view(
num_tokens, scale_size)

-        # Scatter FP8 data
+        # Use CUDA kernel to scatter FP8 and scale bytes into cache
         flat_indices_fp8 = metadata.slot_mapping_fp8[:num_tokens]
-        byte_offsets = torch.arange(head_dim, device=k_cache.device).unsqueeze(
-            0)  # [1, head_dim]
-        scatter_indices_fp8 = flat_indices_fp8.unsqueeze(
-            1) + byte_offsets  # [num_tokens, head_dim]
-        scatter_indices_fp8 = _unravel_indices(scatter_indices_fp8,
-                                               k_cache.shape)
-        k_cache[scatter_indices_fp8] = k_fp8_bytes
-
         flat_indices_scale = metadata.slot_mapping_scale[:num_tokens]
-        byte_offsets = torch.arange(
-            scale_size, device=k_cache.device).unsqueeze(0)  # [1, scale_size]
-        scatter_indices_scale = flat_indices_scale.unsqueeze(
-            1) + byte_offsets  # [num_tokens, scale_size]
-        scatter_indices_scale = _unravel_indices(scatter_indices_scale,
-                                                 k_cache.shape)
-        k_cache[scatter_indices_scale] = k_scale_bytes
+        torch.ops.trtllm.indexer_k_cache_scatter_op(k_fp8_bytes, k_scale_bytes,
+                                                    k_cache, flat_indices_fp8,
+                                                    flat_indices_scale)

def _gather_k_cache_for_chunk(
self,
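
For unit testing, a pure-torch reference of the scatter semantics can be written directly from the kernel's contract: write 128 FP8 bytes starting at `slot_mapping_fp8[t]` and 4 scale bytes starting at `slot_mapping_scale[t]`, both interpreted as flat indices over `k_cache`'s logical 4D shape, skipping negative slots. The function name and the loop-based implementation below are my own illustrative sketch (not part of the change); it uses `torch.unravel_index` (PyTorch >= 2.2) and is intentionally slow and simple.

```python
import torch

def indexer_k_cache_scatter_ref(k_fp8_bytes, k_scale_bytes, k_cache,
                                slot_mapping_fp8, slot_mapping_scale):
    """Slow reference of the CUDA scatter, for comparison in tests."""
    num_tokens, head_dim = k_fp8_bytes.shape
    scale_size = k_scale_bytes.shape[1]
    for t in range(num_tokens):
        base_fp8 = int(slot_mapping_fp8[t])
        base_scale = int(slot_mapping_scale[t])
        if base_fp8 < 0 or base_scale < 0:
            continue  # skipped/padded token, mirroring the kernel's early return
        for j in range(head_dim):
            i0, i1, i2, i3 = (int(x) for x in torch.unravel_index(
                torch.tensor(base_fp8 + j), k_cache.shape))
            k_cache[i0, i1, i2, i3] = k_fp8_bytes[t, j]
        for j in range(scale_size):
            i0, i1, i2, i3 = (int(x) for x in torch.unravel_index(
                torch.tensor(base_scale + j), k_cache.shape))
            k_cache[i0, i1, i2, i3] = k_scale_bytes[t, j]
```

Running this against a clone of the cache and comparing with the op's output is one way to validate the kernel, including on non-contiguous cache views.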