6 changes: 4 additions & 2 deletions cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h
@@ -483,7 +483,7 @@ class KVCacheBlockPool
, sizePerHead(sizePerHead)
, tokensPerBlock(tokensPerBlock)
, quantSize(quantSize)
, blockSize((numKvHeads * sizePerHead * tokensPerBlock) / quantSize)
, blockSize(numKvHeads * sizePerHead * tokensPerBlock)
, primaryPtr(std::move(primaryPtr))
, secondaryPtr(std::move(secondaryPtr))
, containsBlockScales(containsBlockScales)
@@ -1230,6 +1230,8 @@ class BaseKVCacheManager

[[nodiscard]] virtual runtime::ITensor::SharedPtr getBlockPoolPointers() const = 0;

[[nodiscard]] virtual runtime::ITensor::SharedPtr getBlockScalePoolPointers() const = 0;

[[nodiscard]] virtual runtime::ITensor::SharedPtr getLayerToPoolMapping() const = 0;

virtual void getBlockOffsetsOfBatch(
@@ -1524,7 +1526,7 @@ class KVCacheManager : public BaseKVCacheManager
return mLayerToPoolMapping;
}

[[nodiscard]] runtime::ITensor::SharedPtr getBlockScalePoolPointers() const
[[nodiscard]] runtime::ITensor::SharedPtr getBlockScalePoolPointers() const override
{
// TODO: add a new optional model input so the attention plugin can access these
return mBlockScalePoolPointers;
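A note on the block-size change above: `blockSize` is now expressed in storage containers rather than logical elements, and the division by `quantSize` moves to the call sites, which pass a container-based `sizePerHead` instead. A minimal sketch of the resulting arithmetic, with assumed dimensions (not the actual class):

```cpp
#include <cassert>
#include <cstdint>

// Stand-in for the pool bookkeeping above; only the geometry fields are kept.
struct PoolGeometrySketch
{
    int32_t numKvHeads;
    int32_t sizePerHead;    // for NVFP4 this is given in uint8 containers (2 elements each)
    int32_t tokensPerBlock;
    int32_t quantSize;      // elements covered by one block scale; 1 for plain pools
    int32_t blockSize;

    PoolGeometrySketch(int32_t heads, int32_t headDim, int32_t tokens, int32_t quant)
        : numKvHeads(heads)
        , sizePerHead(headDim)
        , tokensPerBlock(tokens)
        , quantSize(quant)
        , blockSize(numKvHeads * sizePerHead * tokensPerBlock) // no longer divided by quantSize
    {
    }
};

int main()
{
    // Assumed example: 8 KV heads, head dim 128 stored as 64 containers, 32 tokens per block.
    PoolGeometrySketch pool(/*heads=*/8, /*headDim=*/128 / 2, /*tokens=*/32, /*quant=*/1);
    assert(pool.blockSize == 8 * 64 * 32);
    return 0;
}
```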
30 changes: 21 additions & 9 deletions cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp
@@ -509,6 +509,16 @@ WindowBlockManager::WindowBlockManager(nvinfer1::DataType dtype, SizeType32 wind
mLayerToIndexWithinPool[layerIdx] = layerIndexWithinPool;
}

#ifdef ENABLE_FP4
SizeType32 const numEltsPerContainer = mDataType == nvinfer1::DataType::kFP4 ? 2 : 1;
if (numEltsPerContainer == 2)
{
TLLM_CHECK_WITH_INFO(sizePerHead % 2 == 0, "sizePerHead must be divisible by 2 for 4-bit KV cache.");
}
#else
SizeType32 const numEltsPerContainer = 1;
#endif

size_t poolIndex = 0;
for (auto const [numKvHeads, numLayers] : numLayersPerPool)
{
@@ -519,7 +529,7 @@ WindowBlockManager::WindowBlockManager(nvinfer1::DataType dtype, SizeType32 wind
mLayerToPoolIndex[layerIdx] = poolIndex;
}
}
mPools.emplace_back(numLayers, mKVFactor, numKvHeads, sizePerHead, tokensPerBlock, 1);
mPools.emplace_back(numLayers, mKVFactor, numKvHeads, sizePerHead / numEltsPerContainer, tokensPerBlock, 1);
++poolIndex;
}

@@ -604,14 +614,20 @@ void BlockManager::storeContextBlocks(GenerationRequest& sequence, LlmRequest co

void WindowBlockManager::createBlockScalePools(SizeType32 quantBlockSize)
{

#ifdef ENABLE_FP4
SizeType32 const numEltsPerContainer = mDataType == nvinfer1::DataType::kFP4 ? 2 : 1;
#else
SizeType32 const numEltsPerContainer = 1;
#endif
auto num_pools = mPools.size();
for (size_t i = 0; i < num_pools; ++i)
{
auto& kv_pool = mPools[i];
TLLM_CHECK_WITH_INFO(kv_pool.blockSize % quantBlockSize == 0,
"Cannot use FP4 quantization since kv_pool.blockSize is not divisible by FP4 quantBlockSize.");

mPools.emplace_back(kv_pool.numLayers, kv_pool.kvFactor, kv_pool.numKvHeads, kv_pool.sizePerHead,
TLLM_CHECK_WITH_INFO((kv_pool.sizePerHead * numEltsPerContainer) % quantBlockSize == 0,
"Cannot use FP4 quantization since kv_pool.sizePerHead is not divisible by FP4 quantBlockSize.");
auto blockScaleSizePerHead = kv_pool.sizePerHead * numEltsPerContainer / quantBlockSize;
mPools.emplace_back(kv_pool.numLayers, kv_pool.kvFactor, kv_pool.numKvHeads, blockScaleSizePerHead,
kv_pool.tokensPerBlock, quantBlockSize,
/*primaryPool=*/nullptr,
/*secondaryPool=*/nullptr,
@@ -646,10 +662,6 @@ void WindowBlockManager::allocatePools(bool useUvm)

if (poolIsFP4)
{
TLLM_CHECK_WITH_INFO(blockSize % 2 == 0, "Block size must be divisible by 2 for FP4 KV cache.");
// Divide by 2. We can't create FP4 buffers directly, so we'll have to create a uint8 buffer with
// half the expected number of elements.
blockSize /= 2;
poolDtype = nvinfer1::DataType::kINT8;
}

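Putting the constructor and `createBlockScalePools` changes together, a small worked example of the NVFP4 sizing, with assumed values (head dim 128, 16-element quantization blocks):

```cpp
#include <cassert>
#include <cstdint>

int main()
{
    // Assumed example values, not taken from the PR.
    int32_t const numEltsPerContainer = 2;                  // two 4-bit values per uint8
    int32_t const sizePerHead = 128 / numEltsPerContainer;  // containers per head: 64
    int32_t const quantBlockSize = 16;                      // elements sharing one block scale

    // Mirrors the TLLM_CHECK above: the element count per head must be divisible
    // by the quantization block size.
    assert((sizePerHead * numEltsPerContainer) % quantBlockSize == 0);

    // One block scale per quantBlockSize elements: 128 / 16 = 8 scales per head per token.
    int32_t const blockScaleSizePerHead = sizePerHead * numEltsPerContainer / quantBlockSize;
    assert(blockScaleSizePerHead == 8);
    return 0;
}
```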
109 changes: 98 additions & 11 deletions cpp/tensorrt_llm/common/attentionOp.cpp
@@ -211,6 +211,10 @@ bool AttentionOp::convertMMHAParamsToXQAParams(tensorrt_llm::kernels::XQAParams&
}
xqaParams.kv_cache_data_type = DATA_TYPE_E4M3;
}
else if (mKVCacheQuantMode.hasFp4KvCache())
{
xqaParams.kv_cache_data_type = DATA_TYPE_E2M1;
}
else
{
xqaParams.kv_cache_data_type = xqaParams.data_type;
@@ -913,6 +917,9 @@ int AttentionOp::mlaGeneration(
generation_params.can_use_one_more_block, generation_params.host_primary_pool_pointer,
generation_params.host_secondary_pool_pointer, generation_params.block_offsets);

// NVFP4 KV cache is not currently supported for MLA, so an empty placeholder is provided.
auto kv_scale_cache_buffer = KVBlockArray();

// Workspace pointer shift
int8_t* workspace_byte_ptr = reinterpret_cast<int8_t*>(params.workspace);
size_t offset = 0;
@@ -1189,7 +1196,7 @@ int AttentionOp::mlaGeneration(
{
TLLM_LOG_DEBUG("XQA kernels are selected in the generation phase.");
xqaParams.stream = stream;
mXqaDispatcher->run(xqaParams, kv_cache_buffer);
mXqaDispatcher->run(xqaParams, kv_cache_buffer, kv_scale_cache_buffer);
return 0;
}
else if (mIsSpecDecodingEnabled && mUseSpecDecoding)
@@ -1263,8 +1270,23 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
float const q_scaling = mQScaling;

KVCacheBuffer kv_cache_buffer;
auto const elemSize = mKVCacheQuantMode.hasKvCacheQuant() ? sizeof(int8_t) : sizeof(T);
auto sizePerToken = mNumAttnKVHeads * headSize * elemSize;
KVCacheBuffer kv_scale_cache_buffer;

int elemBits;
if (mKVCacheQuantMode.hasInt8KvCache() || mKVCacheQuantMode.hasFp8KvCache())
{
elemBits = 8;
}
else if (mKVCacheQuantMode.hasFp4KvCache())
{
elemBits = 4;
}
else
{
elemBits = sizeof(T) * 8;
}
auto sizePerToken = mNumKVHeads * headSize * elemBits / 8 /*bits*/;

if (useKVCache())
{
if constexpr (std::is_same_v<KVCacheBuffer, KVBlockArray>)
@@ -1273,6 +1295,14 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
sizePerToken, params.cyclic_attention_window_size, params.max_cyclic_attention_window_size,
params.sink_token_length, params.can_use_one_more_block, params.host_primary_pool_pointer,
params.host_secondary_pool_pointer, params.block_offsets);
if (mKVCacheQuantMode.hasFp4KvCache())
{
kv_scale_cache_buffer = KVBlockArray(params.batch_size, params.max_blocks_per_sequence, mTokensPerBlock,
sizePerToken / 8, params.cyclic_attention_window_size, params.max_cyclic_attention_window_size,
params.sink_token_length, params.can_use_one_more_block,
params.host_primary_block_scale_pool_pointer, params.host_secondary_block_scale_pool_pointer,
params.block_offsets);
}
}
else if constexpr (std::is_same_v<KVCacheBuffer, KVLinearBuffer>)
{
@@ -1281,6 +1311,7 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
isCrossAttention() ? params.cross_kv_length : params.max_attention_window_size, sizePerToken,
params.cyclic_attention_window_size, params.sink_token_length, false,
reinterpret_cast<BufferDataType*>(params.key_value_cache));
TLLM_CHECK_WITH_INFO(!(mKVCacheQuantMode.hasFp4KvCache()), "FP4 KV cache only supports paged KV.");
}
}

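The `sizePerToken / 8` passed to the scale-cache `KVBlockArray` above follows from the 4-bit data layout: with one 1-byte scale per 16 elements, the scale plane is one eighth of the data plane. A sketch of the byte math under assumed dimensions:

```cpp
#include <cassert>
#include <cstdint>

int main()
{
    // Assumed dimensions: 8 KV heads, head size 128.
    int32_t const numKVHeads = 8;
    int32_t const headSize = 128;

    int32_t const elemBits = 4;                                        // FP4 KV cache
    int32_t const sizePerToken = numKVHeads * headSize * elemBits / 8; // data bytes per token
    assert(sizePerToken == 512);

    // The scale buffer is created with sizePerToken / 8: one 1-byte scale per
    // 16 FP4 elements, i.e. numKVHeads * headSize / 16 bytes per token.
    int32_t const scaleBytesPerToken = sizePerToken / 8;
    assert(scaleBytesPerToken == numKVHeads * headSize / 16);
    return 0;
}
```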
@@ -1478,9 +1509,19 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
sync_check_cuda_error(stream);
}

KvCacheDataType const cache_type = mKVCacheQuantMode.hasInt8KvCache()
? KvCacheDataType::INT8
: (mKVCacheQuantMode.hasFp8KvCache() ? KvCacheDataType::FP8 : KvCacheDataType::BASE);
KvCacheDataType cache_type{KvCacheDataType::BASE};
if (mKVCacheQuantMode.hasInt8KvCache())
{
cache_type = KvCacheDataType::INT8;
}
else if (mKVCacheQuantMode.hasFp8KvCache())
{
cache_type = KvCacheDataType::FP8;
}
else if (mKVCacheQuantMode.hasFp4KvCache())
{
cache_type = KvCacheDataType::NVFP4;
}

cudaDataType_t const gemm_data_type = tc::CudaDataType<T>::value;
int const attention_seq_len_1 = params.input_seq_length; // q length
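The new `if`/`else if` chain replaces the old nested ternary for `cache_type`. Factored out as a standalone helper, it reads like this; the `QuantMode` stand-in below is an assumption, while the enum names mirror the diff:

```cpp
#include <cassert>

// Minimal stand-in for the QuantMode predicates used here (assumed shape).
struct QuantModeSketch
{
    bool int8Kv = false;
    bool fp8Kv = false;
    bool fp4Kv = false;
    bool hasInt8KvCache() const { return int8Kv; }
    bool hasFp8KvCache() const { return fp8Kv; }
    bool hasFp4KvCache() const { return fp4Kv; }
};

// Enum values mirror the KvCacheDataType names in the diff.
enum class KvCacheDataType { BASE, INT8, FP8, NVFP4 };

KvCacheDataType selectCacheType(QuantModeSketch const& mode)
{
    if (mode.hasInt8KvCache()) { return KvCacheDataType::INT8; }
    if (mode.hasFp8KvCache())  { return KvCacheDataType::FP8; }
    if (mode.hasFp4KvCache())  { return KvCacheDataType::NVFP4; }
    return KvCacheDataType::BASE;
}

int main()
{
    QuantModeSketch mode;
    mode.fp4Kv = true;
    assert(selectCacheType(mode) == KvCacheDataType::NVFP4);
    return 0;
}
```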
@@ -1529,6 +1570,7 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
preprocessingParams.quantized_qkv_output = fp8_qkv_buffer;
preprocessingParams.q_output = q_buf_2_;
preprocessingParams.kv_cache_buffer = kv_cache_buffer;
preprocessingParams.kv_cache_block_scales_buffer = kv_scale_cache_buffer;
preprocessingParams.qkv_bias = params.qkv_bias;
preprocessingParams.tokens_info = decoder_params.tokensInfo;
preprocessingParams.seq_lens = params.context_lengths;
@@ -1541,7 +1583,10 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
preprocessingParams.rotary_embedding_inv_freq = rotary_inv_freq_buf;
preprocessingParams.rotary_coef_cache_buffer = params.rotary_cos_sin;
preprocessingParams.mrope_rotary_cos_sin = params.mrope_rotary_cos_sin;
preprocessingParams.kvScaleOrigQuant = params.kv_scale_orig_quant;
preprocessingParams.kv_scale_orig_quant = params.kv_scale_orig_quant;
// The second-level scale for the NVFP4 KV cache. It points to an array of 2 floats holding separate scales for K and V.
// If set to nullptr, kv_scale_orig_quant will be used instead for both K and V.
preprocessingParams.kv_cache_scale_factors = nullptr;
preprocessingParams.spec_decoding_position_offsets = nullptr;
preprocessingParams.logn_scaling = params.logn_scaling_ptr;

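To make the new `kv_cache_scale_factors` field concrete: if distinct K and V second-level scales were supplied, the pointer would reference two device floats, along these lines (hypothetical usage, not something this PR does):

```cpp
#include <cuda_runtime.h>

int main()
{
    // Hypothetical K and V second-level scales (values assumed).
    float const hostKvScales[2] = {0.5f, 0.25f};
    float* deviceKvScales = nullptr;
    cudaMalloc(reinterpret_cast<void**>(&deviceKvScales), sizeof(hostKvScales));
    cudaMemcpy(deviceKvScales, hostKvScales, sizeof(hostKvScales), cudaMemcpyHostToDevice);

    // preprocessingParams.kv_cache_scale_factors = deviceKvScales;
    // Left as nullptr (as in this PR), kv_scale_orig_quant is used for both K and V.

    cudaFree(deviceKvScales);
    return 0;
}
```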
@@ -1691,6 +1736,7 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
else
{
fmhaParams.pagedKvCache = kv_cache_buffer;
fmhaParams.pagedKvSfCache = kv_scale_cache_buffer;
}
}
fmhaParams.cuQSeqLenPtr = cu_q_seqlens;
@@ -2040,8 +2086,24 @@ int AttentionOp::enqueueGeneration(EnqueueGenerationParams<T> const& params, cud
int32_t const batch_beam = params.beam_width * params.num_requests;

KVCacheBuffer kv_cache_buffer;
auto const elemSize = mKVCacheQuantMode.hasKvCacheQuant() ? sizeof(int8_t) : sizeof(T);
auto const sizePerToken = mNumAttnKVHeads * headSize * elemSize;
KVCacheBuffer kv_scale_cache_buffer;

int elemBits;
if (mKVCacheQuantMode.hasInt8KvCache() || mKVCacheQuantMode.hasFp8KvCache())
{
elemBits = 8;
}
else if (mKVCacheQuantMode.hasFp4KvCache())
{
elemBits = 4;
}
else
{
elemBits = sizeof(T) * 8;
}

auto const sizePerToken = mNumKVHeads * headSize * elemBits / 8 /*bits*/;

if (useKVCache())
{
if constexpr (std::is_same_v<KVCacheBuffer, KVBlockArray>)
@@ -2051,13 +2113,22 @@ int AttentionOp::enqueueGeneration(EnqueueGenerationParams<T> const& params, cud
params.cyclic_attention_window_size, params.max_cyclic_attention_window_size, params.sink_token_length,
params.can_use_one_more_block, params.host_primary_pool_pointer, params.host_secondary_pool_pointer,
reinterpret_cast<BufferDataType*>(params.block_offsets));
if (mKVCacheQuantMode.hasFp4KvCache())
{
kv_scale_cache_buffer = KVBlockArray(batch_beam, params.max_blocks_per_sequence, mTokensPerBlock,
sizePerToken / 8, params.cyclic_attention_window_size, params.max_cyclic_attention_window_size,
params.sink_token_length, params.can_use_one_more_block,
params.host_primary_block_scale_pool_pointer, params.host_secondary_block_scale_pool_pointer,
reinterpret_cast<BufferDataType*>(params.block_offsets));
}
}
else if constexpr (std::is_same_v<KVCacheBuffer, KVLinearBuffer>)
{
using BufferDataType = typename KVCacheBuffer::DataType;
kv_cache_buffer = KVLinearBuffer(batch_beam, params.max_attention_window_size, sizePerToken,
params.cyclic_attention_window_size, params.sink_token_length, false,
reinterpret_cast<BufferDataType*>(params.key_value_cache));
TLLM_CHECK_WITH_INFO(!(mKVCacheQuantMode.hasFp4KvCache()), "FP4 KV cache only supports paged KV.");
}
}
sync_check_cuda_error(stream);
@@ -2133,7 +2204,7 @@ int AttentionOp::enqueueGeneration(EnqueueGenerationParams<T> const& params, cud
xqaParams.output = mhaOutput;
xqaParams.qkv = attention_input;
}
mXqaDispatcher->run(xqaParams, kv_cache_buffer);
mXqaDispatcher->run(xqaParams, kv_cache_buffer, kv_scale_cache_buffer);
if (mCpSize > 1 && mAttnTpSize > 1 && mAttnCpSize == 1)
{
this->template ulyssesGenerationPostprocess<T>(
@@ -2150,6 +2221,10 @@ int AttentionOp::enqueueGeneration(EnqueueGenerationParams<T> const& params, cud
{
TLLM_CHECK_WITH_INFO(false, "No available kernels are found for FP4 output.");
}
else if (mKVCacheQuantMode.hasFp4KvCache())
{
TLLM_CHECK_WITH_INFO(false, "No available kernels are found for FP4 KV cache.");
}
}

// This is the number of kv tokens that q needs to visit, but excluding one as it will be processed before the kv
@@ -2414,6 +2489,10 @@ int AttentionOp::initialize() noexcept
TLLM_CHECK_WITH_INFO(!mFuseFp4Quant || mEnableContextFMHA, "Context FMHA must enable if fuse_fp4_quant is enabled");
TLLM_CHECK_WITH_INFO(!mFuseFp4Quant || mSM == 100, "fuse_fp4_quant only supports SM100 devices.");

// Check requirements for FP4 KV cache.
TLLM_CHECK_WITH_INFO(!mKVCacheQuantMode.hasFp4KvCache() || mFP8ContextFMHA,
"mFP8ContextFMHA must enable if FP4 KV cache is enabled");

TLLM_CHECK(isRoPE() == (mRotaryEmbeddingDim != 0));
TLLM_CHECK_WITH_INFO((mSM >= 80) || (mType != nvinfer1::DataType::kBF16),
"Unsupported data type, pre SM 80 GPUs do not support bfloat16");
@@ -2490,7 +2569,10 @@ int AttentionOp::initialize() noexcept
{
fmhaParams.dataTypeKv = DATA_TYPE_E4M3;
}
// TODO: add FP4 KV cache support.
else if (mKVCacheQuantMode.hasFp4KvCache())
{
fmhaParams.dataTypeKv = DATA_TYPE_E2M1;
}
}
// The output dtype.
fmhaParams.dataTypeOut = data_type;
@@ -2692,6 +2774,11 @@ int AttentionOp::initialize() noexcept
fixedParams.kvDataType = DATA_TYPE_E4M3;
fixedParams.mathDataType = DATA_TYPE_E4M3;
}
else if (mKVCacheQuantMode.hasFp4KvCache())
{
fixedParams.kvDataType = DATA_TYPE_E2M1;
fixedParams.mathDataType = DATA_TYPE_E4M3;
}
else
{
fixedParams.kvDataType = fixedParams.inputDataType;
2 changes: 2 additions & 0 deletions cpp/tensorrt_llm/common/attentionOp.h
@@ -92,6 +92,8 @@ class AttentionOp
kernels::KVBlockArray::DataType* block_offsets = nullptr;
void* host_primary_pool_pointer = nullptr;
void* host_secondary_pool_pointer = nullptr;
void* host_primary_block_scale_pool_pointer = nullptr;
void* host_secondary_block_scale_pool_pointer = nullptr;
int32_t num_tokens = 0;
int32_t max_blocks_per_sequence = 0;
int32_t const* sequence_lengths = nullptr;
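A sketch of how a caller might populate the two new block-scale pool fields. The field names come from the struct above; the helper and its arguments are assumptions for illustration:

```cpp
// Minimal stand-in for the params struct above, keeping only the pool pointers.
struct EnqueueParamsSketch
{
    void* host_primary_pool_pointer = nullptr;
    void* host_secondary_pool_pointer = nullptr;
    void* host_primary_block_scale_pool_pointer = nullptr;
    void* host_secondary_block_scale_pool_pointer = nullptr;
};

void wirePools(EnqueueParamsSketch& params, bool hasFp4KvCache, void* dataPrimary,
    void* dataSecondary, void* scalePrimary, void* scaleSecondary)
{
    params.host_primary_pool_pointer = dataPrimary;
    params.host_secondary_pool_pointer = dataSecondary;
    if (hasFp4KvCache)
    {
        // The block-scale pools exist only when the KV cache carries NVFP4 block scales.
        params.host_primary_block_scale_pool_pointer = scalePrimary;
        params.host_secondary_block_scale_pool_pointer = scaleSecondary;
    }
}

int main()
{
    EnqueueParamsSketch params;
    wirePools(params, /*hasFp4KvCache=*/false, nullptr, nullptr, nullptr, nullptr);
    return 0;
}
```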
Original file line number Diff line number Diff line change
@@ -257,6 +257,8 @@ struct MHARunnerParams
void const* kvPtr;
// The paged kv cache array.
KVBlockArray pagedKvCache;
// The paged kv cache array for scaling factor.
KVBlockArray pagedKvSfCache;
// The output buffer ptr.
void* outputPtr;
// The output scaling factor buffer ptr. (only used for FP4 output)
Original file line number Diff line number Diff line change
@@ -301,7 +301,7 @@ void DecoderXQAImplJIT::runImpl(XQAParams const& xqaParams, KVCacheBuffer const&
preprocessingParams.cu_seq_lens = xqaParams.multi_query_tokens ? launchParams.cu_seq_lens : nullptr;
preprocessingParams.rotary_embedding_inv_freq = rotary_inv_freq_buf;
preprocessingParams.rotary_coef_cache_buffer = xqaParams.rotary_cos_sin;
preprocessingParams.kvScaleOrigQuant = xqaParams.kv_scale_orig_quant;
preprocessingParams.kv_scale_orig_quant = xqaParams.kv_scale_orig_quant;
preprocessingParams.kv_cache_scale_factors = nullptr;
preprocessingParams.spec_decoding_position_offsets = xqaParams.spec_decoding_position_offsets;
preprocessingParams.mrope_position_deltas = xqaParams.mrope_position_deltas;
Original file line number Diff line number Diff line change
@@ -219,7 +219,7 @@ class XQAKernelList
preprocessingParams.cu_seq_lens = xqaParams.multi_query_tokens ? launchParams.cu_seq_lens : nullptr;
preprocessingParams.rotary_embedding_inv_freq = rotary_inv_freq_buf;
preprocessingParams.rotary_coef_cache_buffer = xqaParams.rotary_cos_sin;
preprocessingParams.kvScaleOrigQuant = xqaParams.kv_scale_orig_quant;
preprocessingParams.kv_scale_orig_quant = xqaParams.kv_scale_orig_quant;
preprocessingParams.kv_cache_scale_factors = nullptr;
preprocessingParams.spec_decoding_position_offsets = xqaParams.spec_decoding_position_offsets;
preprocessingParams.mrope_position_deltas = xqaParams.mrope_position_deltas;
3 changes: 3 additions & 0 deletions cpp/tensorrt_llm/kernels/fmhaDispatcher.cpp
Original file line number Diff line number Diff line change
@@ -134,6 +134,7 @@ void FmhaDispatcher::run(MHARunnerParams runnerParams)
TLLM_CHECK_WITH_INFO(mTllmGenFMHARunner.get(), "mTllmGenFMHARunner not initialized.");
// Convert from MHAFixedParams + MHARunnerParams to TllmGenFmhaRunnerParams
void const* kvPoolPtr = nullptr;
void const* kvSfPoolPtr = nullptr;
void const* kvPageIdxPtr = nullptr;
auto qkvLayout = kernels::QkvLayout::PackedQkv;
int32_t maxBlocksPerSeq = 0;
@@ -143,6 +144,7 @@ void FmhaDispatcher::run(MHARunnerParams runnerParams)
qkvLayout = kernels::QkvLayout::PagedKv;
auto pagedKvCache = runnerParams.pagedKvCache.copyKVBlockArrayForContextFMHA();
kvPoolPtr = pagedKvCache.mPrimaryPoolPtr;
kvSfPoolPtr = runnerParams.pagedKvSfCache.mPrimaryPoolPtr;
kvPageIdxPtr = reinterpret_cast<int const*>(pagedKvCache.data);
maxBlocksPerSeq = pagedKvCache.mMaxBlocksPerSeq;
numTokensPerBlock = pagedKvCache.mTokensPerBlock;
@@ -163,6 +165,7 @@ void FmhaDispatcher::run(MHARunnerParams runnerParams)
tllmRunnerParams.kPtr = nullptr;
tllmRunnerParams.vPtr = nullptr;
tllmRunnerParams.kvPtr = kvPoolPtr;
tllmRunnerParams.kvSfPtr = kvSfPoolPtr;
tllmRunnerParams.qkvPtr = runnerParams.qkvPtr;
tllmRunnerParams.cumSeqLensQPtr = reinterpret_cast<int const*>(runnerParams.cuQSeqLenPtr);
tllmRunnerParams.cumSeqLensKvPtr = reinterpret_cast<int const*>(runnerParams.cuKvSeqLenPtr);
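Note that `kvSfPtr` is read unconditionally from `pagedKvSfCache`, so it stays null whenever the scale cache was never populated (for example, when FP4 KV cache is disabled in `enqueueContext`). A behavior sketch, assuming a default-constructed `KVBlockArray` zero-initializes its primary pool pointer:

```cpp
#include <cassert>

// Stand-in for KVBlockArray: a default-constructed instance has a null pool pointer.
struct KVBlockArraySketch
{
    void* mPrimaryPoolPtr = nullptr;
};

int main()
{
    KVBlockArraySketch pagedKvSfCache;  // the empty placeholder case
    void const* kvSfPoolPtr = pagedKvSfCache.mPrimaryPoolPtr;
    assert(kvSfPoolPtr == nullptr);     // the FMHA runner then sees "no scale cache"
    return 0;
}
```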
Git LFS file not shown
Git LFS file not shown