19 changes: 14 additions & 5 deletions cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h
@@ -477,7 +477,6 @@ class KVCacheBlockPool
SizeType32 numKvHeads;
SizeType32 sizePerHead;
SizeType32 tokensPerBlock;
-SizeType32 quantSize;
SizeType32 blockSize;

// Memory pools. Primary is fast memory, secondary is slower memory used for offloading.
@@ -488,15 +487,14 @@
bool containsBlockScales;

KVCacheBlockPool(SizeType32 numLayers, SizeType32 kvFactor, SizeType32 numKvHeads, SizeType32 sizePerHead,
-SizeType32 tokensPerBlock, SizeType32 quantSize, runtime::ITensor::SharedPtr primaryPtr = nullptr,
+SizeType32 tokensPerBlock, runtime::ITensor::SharedPtr primaryPtr = nullptr,
runtime::ITensor::SharedPtr secondaryPtr = nullptr, bool containsBlockScales = false)
: numLayers(numLayers)
, kvFactor(kvFactor)
, numKvHeads(numKvHeads)
, sizePerHead(sizePerHead)
, tokensPerBlock(tokensPerBlock)
-, quantSize(quantSize)
-, blockSize((numKvHeads * sizePerHead * tokensPerBlock) / quantSize)
+, blockSize(numKvHeads * sizePerHead * tokensPerBlock)
, primaryPtr(std::move(primaryPtr))
, secondaryPtr(std::move(secondaryPtr))
, containsBlockScales(containsBlockScales)
@@ -644,6 +642,15 @@ class WindowBlockManager
return mPools.at(poolIdx).blockSize;
}

+[[nodiscard]] SizeType32 getNumEltsPerContainer() const
+{
+#ifdef ENABLE_FP4
+return mDataType == nvinfer1::DataType::kFP4 ? 2 : 1;
+#else
+return 1;
+#endif
+}
+
[[nodiscard]] SizeType32 getNumPools(bool includeBlockScalePools = true) const noexcept
{
if (includeBlockScalePools)
@@ -1236,6 +1243,8 @@ class BaseKVCacheManager

[[nodiscard]] virtual runtime::ITensor::SharedPtr getBlockPoolPointers() const = 0;

+[[nodiscard]] virtual runtime::ITensor::SharedPtr getBlockScalePoolPointers() const = 0;
+
[[nodiscard]] virtual runtime::ITensor::SharedPtr getLayerToPoolMapping() const = 0;

virtual void getBlockOffsetsOfBatch(
@@ -1540,7 +1549,7 @@ class KVCacheManager : public BaseKVCacheManager
return mLayerToPoolMapping;
}

-[[nodiscard]] runtime::ITensor::SharedPtr getBlockScalePoolPointers() const
+[[nodiscard]] runtime::ITensor::SharedPtr getBlockScalePoolPointers() const override
{
// TODO: add a new optional model input so the attention plugin can access these
return mBlockScalePoolPointers;
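Note on the kvCacheManager.h change: quantSize is gone because an FP4 pool now shrinks its per-head dimension up front. Two 4-bit values share one byte-sized container, so the pool stores containers rather than logical elements, and blockSize is a plain product again. A minimal sketch of the arithmetic under that assumption (illustrative values and names, not the TensorRT-LLM API):

#include <cstdint>
#include <iostream>

int main()
{
    int32_t const numKvHeads = 8, sizePerHead = 128, tokensPerBlock = 32;
    int32_t const numEltsPerContainer = 2; // 2 for FP4, 1 for every other dtype

    // The pool is built on the container dimension, so no post-hoc
    // division by quantSize is needed.
    int32_t const containerSizePerHead = sizePerHead / numEltsPerContainer; // 64
    int32_t const blockSize = numKvHeads * containerSizePerHead * tokensPerBlock;

    std::cout << blockSize << '\n'; // 16384 uint8 containers per K or V block
    return 0;
}

For FP8 and INT8 caches numEltsPerContainer is 1 and the expression reduces to the previous formula.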
25 changes: 15 additions & 10 deletions cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp
@@ -605,6 +605,14 @@ WindowBlockManager::WindowBlockManager(nvinfer1::DataType dtype, SizeType32 wind
mLayerToIndexWithinPool[layerIdx] = layerIndexWithinPool;
}

+auto numEltsPerContainer = getNumEltsPerContainer();
+#ifdef ENABLE_FP4
+if (numEltsPerContainer == 2)
+{
+TLLM_CHECK_WITH_INFO(sizePerHead % 2 == 0, "sizePerHead must be divisible by 2 for 4-bit KV cache.");
+}
+#endif
+
size_t poolIndex = 0;
for (auto const [numKvHeads, numLayers] : numLayersPerPool)
{
@@ -615,7 +623,7 @@
mLayerToPoolIndex[layerIdx] = poolIndex;
}
}
-mPools.emplace_back(numLayers, mKVFactor, numKvHeads, sizePerHead, tokensPerBlock, 1);
+mPools.emplace_back(numLayers, mKVFactor, numKvHeads, sizePerHead / numEltsPerContainer, tokensPerBlock);
++poolIndex;
}

@@ -700,15 +708,16 @@ void BlockManager::storeContextBlocks(GenerationRequest& sequence, LlmRequest co

void WindowBlockManager::createBlockScalePools(SizeType32 quantBlockSize)
{
+SizeType32 const numEltsPerContainer = getNumEltsPerContainer();
auto num_pools = mPools.size();
for (size_t i = 0; i < num_pools; ++i)
{
auto& kv_pool = mPools[i];
-TLLM_CHECK_WITH_INFO(kv_pool.blockSize % quantBlockSize == 0,
-"Cannot use FP4 quantization since kv_pool.blockSize is not divisible by FP4 quantBlockSize.");
-
-mPools.emplace_back(kv_pool.numLayers, kv_pool.kvFactor, kv_pool.numKvHeads, kv_pool.sizePerHead,
-kv_pool.tokensPerBlock, quantBlockSize,
+TLLM_CHECK_WITH_INFO((kv_pool.sizePerHead * numEltsPerContainer) % quantBlockSize == 0,
+"Cannot use FP4 quantization since kv_pool.sizePerHead is not divisible by FP4 quantBlockSize.");
+auto blockScaleSizePerHead = kv_pool.sizePerHead * numEltsPerContainer / quantBlockSize;
+mPools.emplace_back(kv_pool.numLayers, kv_pool.kvFactor, kv_pool.numKvHeads, blockScaleSizePerHead,
+kv_pool.tokensPerBlock,
/*primaryPool=*/nullptr,
/*secondaryPool=*/nullptr,
/*containsBlockScales=*/true);
@@ -742,10 +751,6 @@ void WindowBlockManager::allocatePools(bool useUvm)

if (poolIsFP4)
{
-TLLM_CHECK_WITH_INFO(blockSize % 2 == 0, "Block size must be divisible by 2 for FP4 KV cache.");
-// Divide by 2. We can't create FP4 buffers directly, so we'll have to create a uint8 buffer with
-// half the expected number of elements.
-blockSize /= 2;
poolDtype = nvinfer1::DataType::kINT8;
}

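Note on createBlockScalePools: the block-scale pool mirrors the KV pool geometry, holding one scale factor per quantBlockSize logical elements. A sketch of the sizing rule, assuming NVFP4's group size of 16 (hypothetical helper, not the real API):

#include <cassert>
#include <cstdint>

int32_t blockScaleSizePerHead(
    int32_t containerSizePerHead, int32_t numEltsPerContainer, int32_t quantBlockSize)
{
    // Recover the logical element count first; the KV pool stores containers.
    int32_t const logicalSizePerHead = containerSizePerHead * numEltsPerContainer;
    assert(logicalSizePerHead % quantBlockSize == 0);
    return logicalSizePerHead / quantBlockSize; // scales per head per token
}

int main()
{
    // 64 containers x 2 elements = 128 elements -> 8 scales per head per token.
    return blockScaleSizePerHead(64, 2, 16) == 8 ? 0 : 1;
}

This also explains the removed halving in allocatePools: the uint8 workaround now lives in the pool shape itself instead of being patched in at allocation time.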
83 changes: 70 additions & 13 deletions cpp/tensorrt_llm/common/attentionOp.cpp
@@ -214,6 +214,10 @@ bool AttentionOp::convertMMHAParamsToXQAParams(tensorrt_llm::kernels::XQAParams&
}
xqaParams.kv_cache_data_type = DATA_TYPE_E4M3;
}
+else if (mKVCacheQuantMode.hasFp4KvCache())
+{
+xqaParams.kv_cache_data_type = DATA_TYPE_E2M1;
+}
else
{
xqaParams.kv_cache_data_type = xqaParams.data_type;
@@ -959,6 +963,9 @@ int AttentionOp::mlaGeneration(
generation_params.can_use_one_more_block, generation_params.host_primary_pool_pointer,
generation_params.host_secondary_pool_pointer, generation_params.block_offsets);

+// Currently NVFP4 KV cache is not supported for MLA. An empty placeholder is provided.
+auto kv_scale_cache_buffer = KVBlockArray();
+
// Workspace pointer shift
int8_t* workspace_byte_ptr = reinterpret_cast<int8_t*>(params.workspace);
size_t offset = 0;
@@ -1234,7 +1241,7 @@
{
TLLM_LOG_DEBUG("XQA kernels are selected in the generation phase.");
xqaParams.stream = stream;
-mXqaDispatcher->run(xqaParams, kv_cache_buffer);
+mXqaDispatcher->run(xqaParams, kv_cache_buffer, kv_scale_cache_buffer);
return 0;
}
else if (mIsSpecDecodingEnabled && mUseSpecDecoding)
@@ -1308,8 +1315,10 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
float const q_scaling = mQScaling;

KVCacheBuffer kv_cache_buffer;
-auto const elemSize = mKVCacheQuantMode.hasKvCacheQuant() ? sizeof(int8_t) : sizeof(T);
-auto sizePerToken = mNumAttnKVHeads * headSize * elemSize;
+KVCacheBuffer kv_scale_cache_buffer;
+
+auto sizePerToken = mNumAttnKVHeads * headSize * getKvCacheElemSizeInBits<T>() / 8 /*bits*/;
+
if (useKVCache())
{
if constexpr (std::is_same_v<KVCacheBuffer, KVBlockArray>)
@@ -1318,6 +1327,14 @@
sizePerToken, params.cyclic_attention_window_size, params.max_cyclic_attention_window_size,
params.sink_token_length, params.can_use_one_more_block, params.host_primary_pool_pointer,
params.host_secondary_pool_pointer, params.block_offsets);
+if (mKVCacheQuantMode.hasFp4KvCache())
+{
+kv_scale_cache_buffer = KVBlockArray(params.batch_size, params.max_blocks_per_sequence, mTokensPerBlock,
+sizePerToken / 8, params.cyclic_attention_window_size, params.max_cyclic_attention_window_size,
+params.sink_token_length, params.can_use_one_more_block,
+params.host_primary_block_scale_pool_pointer, params.host_secondary_block_scale_pool_pointer,
+params.block_offsets);
+}
}
else if constexpr (std::is_same_v<KVCacheBuffer, KVLinearBuffer>)
{
Expand All @@ -1326,6 +1343,7 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
isCrossAttention() ? params.cross_kv_length : params.max_attention_window_size, sizePerToken,
params.cyclic_attention_window_size, params.sink_token_length, false,
reinterpret_cast<BufferDataType*>(params.key_value_cache));
+TLLM_CHECK_WITH_INFO(!(mKVCacheQuantMode.hasFp4KvCache()), "FP4 KV cache only supports paged KV.");
}
}

@@ -1490,8 +1508,8 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
decoder_params.blockSparseParams = mBlockSparseParams;
decoder_params.fmhaTileCounter = fmha_tile_counter_ptr;
decoder_params.quantScaleO = params.attention_output_orig_quant;
-decoder_params.dequantScaleQ = params.kv_scale_quant_orig;
-decoder_params.dequantScaleKv = params.kv_scale_quant_orig;
+decoder_params.dequantScaleQkv = params.kv_scale_quant_orig;
+decoder_params.separateQkvScales = mKVCacheQuantMode.hasFp4KvCache();
decoder_params.fmhaHostBmm1Scale = 1.0f / (sqrtf(getHeadSize() * 1.0f) * q_scaling);
decoder_params.fmhaBmm1Scale = fmha_bmm1_scale_ptr;
decoder_params.fmhaBmm2Scale = fmha_bmm2_scale_ptr;
@@ -1549,9 +1567,19 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
sync_check_cuda_error(stream);
}

-KvCacheDataType const cache_type = mKVCacheQuantMode.hasInt8KvCache()
-? KvCacheDataType::INT8
-: (mKVCacheQuantMode.hasFp8KvCache() ? KvCacheDataType::FP8 : KvCacheDataType::BASE);
+KvCacheDataType cache_type{KvCacheDataType::BASE};
+if (mKVCacheQuantMode.hasInt8KvCache())
+{
+cache_type = KvCacheDataType::INT8;
+}
+else if (mKVCacheQuantMode.hasFp8KvCache())
+{
+cache_type = KvCacheDataType::FP8;
+}
+else if (mKVCacheQuantMode.hasFp4KvCache())
+{
+cache_type = KvCacheDataType::NVFP4;
+}

cudaDataType_t const gemm_data_type = tc::CudaDataType<T>::value;
int const attention_seq_len_1 = params.input_seq_length; // q length
@@ -1600,6 +1628,7 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
preprocessingParams.quantized_qkv_output = fp8_qkv_buffer;
preprocessingParams.q_output = q_buf_2_;
preprocessingParams.kv_cache_buffer = kv_cache_buffer;
+preprocessingParams.kv_cache_block_scales_buffer = kv_scale_cache_buffer;
preprocessingParams.qkv_bias = params.qkv_bias;
preprocessingParams.tokens_info = decoder_params.tokensInfo;
preprocessingParams.seq_lens = params.context_lengths;
@@ -1612,7 +1641,7 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
preprocessingParams.rotary_embedding_inv_freq = rotary_inv_freq_buf;
preprocessingParams.rotary_coef_cache_buffer = params.rotary_cos_sin;
preprocessingParams.mrope_rotary_cos_sin = params.mrope_rotary_cos_sin;
-preprocessingParams.kvScaleOrigQuant = params.kv_scale_orig_quant;
+preprocessingParams.qkv_scale_orig_quant = params.kv_scale_orig_quant;
preprocessingParams.spec_decoding_position_offsets = nullptr;
preprocessingParams.logn_scaling = params.logn_scaling_ptr;

@@ -1781,6 +1810,7 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
if constexpr (std::is_same_v<KVCacheBuffer, KVBlockArray>)
{
fmhaParams.pagedKvCache = kv_cache_buffer;
+fmhaParams.pagedKvSfCache = kv_scale_cache_buffer;
}
fmhaParams.cuQSeqLenPtr = cu_q_seqlens;
fmhaParams.kvSeqLenPtr = decoder_params.seqKVLengths;
@@ -2126,8 +2156,10 @@ int AttentionOp::enqueueGeneration(EnqueueGenerationParams<T> const& params, cud
int32_t const batch_beam = params.beam_width * params.num_requests;

KVCacheBuffer kv_cache_buffer;
-auto const elemSize = mKVCacheQuantMode.hasKvCacheQuant() ? sizeof(int8_t) : sizeof(T);
-auto const sizePerToken = mNumAttnKVHeads * headSize * elemSize;
+KVCacheBuffer kv_scale_cache_buffer;
+
+auto const sizePerToken = mNumAttnKVHeads * headSize * getKvCacheElemSizeInBits<T>() / 8 /*bits*/;
+
if (useKVCache())
{
if constexpr (std::is_same_v<KVCacheBuffer, KVBlockArray>)
@@ -2137,13 +2169,22 @@
params.cyclic_attention_window_size, params.max_cyclic_attention_window_size, params.sink_token_length,
params.can_use_one_more_block, params.host_primary_pool_pointer, params.host_secondary_pool_pointer,
reinterpret_cast<BufferDataType*>(params.block_offsets));
+if (mKVCacheQuantMode.hasFp4KvCache())
+{
+kv_scale_cache_buffer = KVBlockArray(batch_beam, params.max_blocks_per_sequence, mTokensPerBlock,
+sizePerToken / 8, params.cyclic_attention_window_size, params.max_cyclic_attention_window_size,
+params.sink_token_length, params.can_use_one_more_block,
+params.host_primary_block_scale_pool_pointer, params.host_secondary_block_scale_pool_pointer,
+reinterpret_cast<BufferDataType*>(params.block_offsets));
+}
}
else if constexpr (std::is_same_v<KVCacheBuffer, KVLinearBuffer>)
{
using BufferDataType = typename KVCacheBuffer::DataType;
kv_cache_buffer = KVLinearBuffer(batch_beam, params.max_attention_window_size, sizePerToken,
params.cyclic_attention_window_size, params.sink_token_length, false,
reinterpret_cast<BufferDataType*>(params.key_value_cache));
+TLLM_CHECK_WITH_INFO(!(mKVCacheQuantMode.hasFp4KvCache()), "FP4 KV cache only supports paged KV.");
}
}
sync_check_cuda_error(stream);
@@ -2215,7 +2256,7 @@ int AttentionOp::enqueueGeneration(EnqueueGenerationParams<T> const& params, cud
xqaParams.output = mhaOutput;
xqaParams.qkv = attention_input;
}
-mXqaDispatcher->run(xqaParams, kv_cache_buffer);
+mXqaDispatcher->run(xqaParams, kv_cache_buffer, kv_scale_cache_buffer);
if (mCpSize > 1 && mAttnTpSize > 1 && mAttnCpSize == 1)
{
this->template ulyssesGenerationPostprocess<T>(
@@ -2232,6 +2273,10 @@ int AttentionOp::enqueueGeneration(EnqueueGenerationParams<T> const& params, cud
{
TLLM_CHECK_WITH_INFO(false, "No available kernels are found for FP4 output.");
}
+else if (mKVCacheQuantMode.hasFp4KvCache())
+{
+TLLM_CHECK_WITH_INFO(false, "No available kernels are found for FP4 KV cache.");
+}
else
{
TLLM_LOG_DEBUG("XQA kernels are not selected in the generation phase.");
@@ -2503,6 +2548,10 @@ int AttentionOp::initialize() noexcept
TLLM_CHECK_WITH_INFO(!mFuseFp4Quant || mSM == 100 || mSM == 120 || mSM == 121,
"fuse_fp4_quant only supports SM100 or SM120 or SM121 devices.");

+// Check requirements for FP4 KV cache.
+TLLM_CHECK_WITH_INFO(!mKVCacheQuantMode.hasFp4KvCache() || mFP8ContextFMHA,
+"mFP8ContextFMHA must be enabled if FP4 KV cache is enabled.");
+
TLLM_CHECK(isRoPE() == (mRotaryEmbeddingDim != 0));
TLLM_CHECK_WITH_INFO((mSM >= 80) || (mType != nvinfer1::DataType::kBF16),
"Unsupported data type, pre SM 80 GPUs do not support bfloat16");
@@ -2579,7 +2628,10 @@
{
fmhaParams.dataTypeKv = DATA_TYPE_E4M3;
}
-// TODO: add FP4 KV cache support.
+else if (mKVCacheQuantMode.hasFp4KvCache())
+{
+fmhaParams.dataTypeKv = DATA_TYPE_E2M1;
+}
}
// The output dtype.
fmhaParams.dataTypeOut = data_type;
@@ -2789,6 +2841,11 @@ int AttentionOp::initialize() noexcept
fixedParams.kvDataType = DATA_TYPE_E4M3;
fixedParams.mathDataType = DATA_TYPE_E4M3;
}
+else if (mKVCacheQuantMode.hasFp4KvCache())
+{
+fixedParams.kvDataType = DATA_TYPE_E2M1;
+fixedParams.mathDataType = DATA_TYPE_E4M3;
+}
else
{
fixedParams.kvDataType = fixedParams.inputDataType;
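Note on the sizePerToken math in attentionOp.cpp: expressing the element width in bits keeps the byte count exact for sub-byte types, and the scale cache comes out at one eighth of the FP4 data size (one FP8 scale per 16 elements). A worked example with hypothetical dimensions:

#include <cassert>
#include <cstdint>

int main()
{
    int32_t const numAttnKvHeads = 8, headSize = 128;
    int32_t const elemBits = 4;                           // FP4 KV cache
    int32_t const numElems = numAttnKvHeads * headSize;   // 1024 elements per token
    int32_t const sizePerToken = numElems * elemBits / 8; // 512 bytes of packed FP4
    int32_t const scaleBytes = sizePerToken / 8;          // 64 bytes of scales
    assert(scaleBytes == numElems / 16); // one 8-bit scale per 16 elements
    return 0;
}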
16 changes: 16 additions & 0 deletions cpp/tensorrt_llm/common/attentionOp.h
@@ -94,6 +94,8 @@ class AttentionOp
kernels::KVBlockArray::DataType* block_offsets = nullptr;
void* host_primary_pool_pointer = nullptr;
void* host_secondary_pool_pointer = nullptr;
+void* host_primary_block_scale_pool_pointer = nullptr;
+void* host_secondary_block_scale_pool_pointer = nullptr;
int32_t num_tokens = 0;
int32_t total_kv_len = 0;
int32_t max_blocks_per_sequence = 0;
@@ -233,6 +235,20 @@
return num_sm_parts;
}

+template <typename T>
+int getKvCacheElemSizeInBits() const
+{
+if (mKVCacheQuantMode.hasInt8KvCache() || mKVCacheQuantMode.hasFp8KvCache())
+{
+return 8;
+}
+else if (mKVCacheQuantMode.hasFp4KvCache())
+{
+return 4;
+}
+return sizeof(T) * 8;
+}
+
// Called in configurePlugin().
template <typename T, typename KVCacheBuffer>
void prepareEnqueueGeneration(EnqueueGenerationParams<T> const& params);
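The new helper's contract, restated as a standalone sanity check (the KvQuant enum below is hypothetical; the real code queries mKVCacheQuantMode):

#include <cassert>

enum class KvQuant { kNone, kInt8, kFp8, kFp4 };

template <typename T>
constexpr int kvCacheElemSizeInBits(KvQuant q)
{
    if (q == KvQuant::kInt8 || q == KvQuant::kFp8) { return 8; } // byte-wide quantized caches
    if (q == KvQuant::kFp4) { return 4; }                        // packed two-per-byte
    return static_cast<int>(sizeof(T)) * 8;                      // unquantized: native width of T
}

static_assert(kvCacheElemSizeInBits<float>(KvQuant::kNone) == 32, "");
static_assert(kvCacheElemSizeInBits<float>(KvQuant::kFp4) == 4, "");

int main()
{
    return 0;
}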
@@ -267,6 +267,8 @@ struct MHARunnerParams
void const* vPtr;
// The paged kv cache array.
KVBlockArray pagedKvCache;
+// The paged kv cache array for scaling factors.
+KVBlockArray pagedKvSfCache;
// The output buffer ptr.
void* outputPtr;
// The output scaling factor buffer ptr. (only used for FP4 output)
@@ -303,8 +303,7 @@ void DecoderXQAImplJIT::runImpl(XQAParams const& xqaParams, KVCacheBuffer const&
preprocessingParams.cu_seq_lens = xqaParams.multi_query_tokens ? launchParams.cu_seq_lens : nullptr;
preprocessingParams.rotary_embedding_inv_freq = rotary_inv_freq_buf;
preprocessingParams.rotary_coef_cache_buffer = xqaParams.rotary_cos_sin;
-preprocessingParams.kvScaleOrigQuant = xqaParams.kv_scale_orig_quant;
-preprocessingParams.kv_cache_scale_factors = nullptr;
+preprocessingParams.qkv_scale_orig_quant = xqaParams.kv_scale_orig_quant;
preprocessingParams.spec_decoding_position_offsets = xqaParams.spec_decoding_position_offsets;
preprocessingParams.mrope_position_deltas = xqaParams.mrope_position_deltas;
// Scalar parameters.
@@ -224,8 +224,7 @@ class XQAKernelList
preprocessingParams.cu_seq_lens = xqaParams.multi_query_tokens ? launchParams.cu_seq_lens : nullptr;
preprocessingParams.rotary_embedding_inv_freq = rotary_inv_freq_buf;
preprocessingParams.rotary_coef_cache_buffer = xqaParams.rotary_cos_sin;
-preprocessingParams.kvScaleOrigQuant = xqaParams.kv_scale_orig_quant;
-preprocessingParams.kv_cache_scale_factors = nullptr;
+preprocessingParams.qkv_scale_orig_quant = xqaParams.kv_scale_orig_quant;
preprocessingParams.spec_decoding_position_offsets = xqaParams.spec_decoding_position_offsets;
preprocessingParams.mrope_position_deltas = xqaParams.mrope_position_deltas;
// Scalar parameters.