From 2d1bf100873399e76f647be2fe88865777de98c7 Mon Sep 17 00:00:00 2001 From: Tian Zheng <29906817+Tom-Zheng@users.noreply.github.com> Date: Sat, 7 Jun 2025 09:40:21 +0000 Subject: [PATCH 1/6] Add NVFP4 KV cache support Signed-off-by: Tian Zheng <29906817+Tom-Zheng@users.noreply.github.com> --- .../batch_manager/kvCacheManager.h | 19 +++-- .../batch_manager/kvCacheManager.cpp | 25 +++--- cpp/tensorrt_llm/common/attentionOp.cpp | 83 ++++++++++++++++--- cpp/tensorrt_llm/common/attentionOp.h | 16 ++++ .../fused_multihead_attention_common.h | 2 + .../decoderXQAImplJIT/decoderXQAImplJIT.cpp | 3 +- .../decoderXQAImplPrecompiled.cpp | 3 +- cpp/tensorrt_llm/kernels/fmhaDispatcher.cpp | 3 + cpp/tensorrt_llm/kernels/gptKernels.cu | 13 ++- cpp/tensorrt_llm/kernels/gptKernels.h | 6 +- .../trtllmGenKernels/fmha/fmhaRunner.cpp | 3 +- .../trtllmGenKernels/fmha/fmhaRunnerParams.h | 9 +- .../trtllmGenKernels/fmha/kernelParams.h | 67 +++++++++------ .../kernels/unfusedAttentionKernels.h | 12 +-- .../unfusedAttentionKernels_2_template.h | 58 +++++++++---- cpp/tensorrt_llm/kernels/xqaDispatcher.cpp | 60 +++++++++----- cpp/tensorrt_llm/kernels/xqaDispatcher.h | 9 +- cpp/tensorrt_llm/nanobind/bindings.cpp | 1 + .../pybind/batch_manager/kvCacheManager.cpp | 12 +++ cpp/tensorrt_llm/pybind/bindings.cpp | 2 + cpp/tensorrt_llm/thop/attentionOp.cpp | 62 +++++++++++--- .../batch_manager/kvCacheManagerTest.cpp | 8 +- cpp/tests/unit_tests/kernels/ropeTest.cu | 15 +--- .../_torch/attention_backend/trtllm.py | 29 ++++--- tensorrt_llm/_torch/modules/attention.py | 9 ++ tensorrt_llm/_torch/modules/linear.py | 67 ++++++++++++++- tensorrt_llm/_torch/pyexecutor/_util.py | 3 + .../_torch/pyexecutor/model_engine.py | 2 +- .../_torch/pyexecutor/resource_manager.py | 34 ++++++-- tensorrt_llm/_utils.py | 27 +++--- tensorrt_llm/llmapi/llm_utils.py | 4 +- .../defs/accuracy/references/gsm8k.yaml | 3 + .../defs/accuracy/references/mmlu.yaml | 3 + .../defs/accuracy/test_llm_api_pytorch.py | 23 +++++ .../test_lists/test-db/l0_b200.yml | 1 + 35 files changed, 516 insertions(+), 180 deletions(-) diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h index df526a5dfbe..310f8eef913 100644 --- a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h +++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h @@ -477,7 +477,6 @@ class KVCacheBlockPool SizeType32 numKvHeads; SizeType32 sizePerHead; SizeType32 tokensPerBlock; - SizeType32 quantSize; SizeType32 blockSize; // Memory pools. Primary is fast memory, secondary is slower memory used for offloading. 
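The hunks below drop the explicit quantSize member from KVCacheBlockPool: with NVFP4 the cache stores two 4-bit values per uint8 container, so sizePerHead is divided by the container factor before the pool is constructed and blockSize stays a plain product. A minimal standalone sketch of that convention (the helper names here are illustrative only, not part of the patch):

    // Sketch: pool block sizing in container elements, assuming FP4 packs 2 elements per byte.
    #include <cstdint>

    constexpr int32_t numEltsPerContainer(bool isFp4)
    {
        return isFp4 ? 2 : 1; // two 4-bit values share one uint8 container
    }

    constexpr int32_t poolBlockSize(int32_t numKvHeads, int32_t sizePerHead, int32_t tokensPerBlock, bool isFp4)
    {
        // The pool receives sizePerHead already divided by the container factor,
        // so blockSize is a plain product with no extra division.
        return numKvHeads * (sizePerHead / numEltsPerContainer(isFp4)) * tokensPerBlock;
    }

    static_assert(poolBlockSize(8, 128, 32, /*isFp4=*/true) == 8 * 64 * 32, "FP4 halves the per-head container count");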
@@ -488,15 +487,14 @@ class KVCacheBlockPool bool containsBlockScales; KVCacheBlockPool(SizeType32 numLayers, SizeType32 kvFactor, SizeType32 numKvHeads, SizeType32 sizePerHead, - SizeType32 tokensPerBlock, SizeType32 quantSize, runtime::ITensor::SharedPtr primaryPtr = nullptr, + SizeType32 tokensPerBlock, runtime::ITensor::SharedPtr primaryPtr = nullptr, runtime::ITensor::SharedPtr secondaryPtr = nullptr, bool containsBlockScales = false) : numLayers(numLayers) , kvFactor(kvFactor) , numKvHeads(numKvHeads) , sizePerHead(sizePerHead) , tokensPerBlock(tokensPerBlock) - , quantSize(quantSize) - , blockSize((numKvHeads * sizePerHead * tokensPerBlock) / quantSize) + , blockSize(numKvHeads * sizePerHead * tokensPerBlock) , primaryPtr(std::move(primaryPtr)) , secondaryPtr(std::move(secondaryPtr)) , containsBlockScales(containsBlockScales) @@ -644,6 +642,15 @@ class WindowBlockManager return mPools.at(poolIdx).blockSize; } + [[nodiscard]] SizeType32 getNumEltsPerContainer() const + { +#ifdef ENABLE_FP4 + return mDataType == nvinfer1::DataType::kFP4 ? 2 : 1; +#else + return 1; +#endif + } + [[nodiscard]] SizeType32 getNumPools(bool includeBlockScalePools = true) const noexcept { if (includeBlockScalePools) @@ -1236,6 +1243,8 @@ class BaseKVCacheManager [[nodiscard]] virtual runtime::ITensor::SharedPtr getBlockPoolPointers() const = 0; + [[nodiscard]] virtual runtime::ITensor::SharedPtr getBlockScalePoolPointers() const = 0; + [[nodiscard]] virtual runtime::ITensor::SharedPtr getLayerToPoolMapping() const = 0; virtual void getBlockOffsetsOfBatch( @@ -1540,7 +1549,7 @@ class KVCacheManager : public BaseKVCacheManager return mLayerToPoolMapping; } - [[nodiscard]] runtime::ITensor::SharedPtr getBlockScalePoolPointers() const + [[nodiscard]] runtime::ITensor::SharedPtr getBlockScalePoolPointers() const override { // TODO: add a new optional model input so the attention plugin can access these return mBlockScalePoolPointers; diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp index 0b793a041aa..02a9bcb7d0f 100644 --- a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp @@ -605,6 +605,14 @@ WindowBlockManager::WindowBlockManager(nvinfer1::DataType dtype, SizeType32 wind mLayerToIndexWithinPool[layerIdx] = layerIndexWithinPool; } + auto numEltsPerContainer = getNumEltsPerContainer(); +#ifdef ENABLE_FP4 + if (numEltsPerContainer == 2) + { + TLLM_CHECK_WITH_INFO(sizePerHead % 2 == 0, "sizePerHead must be divisible by 2 for 4-bit KV cache."); + } +#endif + size_t poolIndex = 0; for (auto const [numKvHeads, numLayers] : numLayersPerPool) { @@ -615,7 +623,7 @@ WindowBlockManager::WindowBlockManager(nvinfer1::DataType dtype, SizeType32 wind mLayerToPoolIndex[layerIdx] = poolIndex; } } - mPools.emplace_back(numLayers, mKVFactor, numKvHeads, sizePerHead, tokensPerBlock, 1); + mPools.emplace_back(numLayers, mKVFactor, numKvHeads, sizePerHead / numEltsPerContainer, tokensPerBlock); ++poolIndex; } @@ -700,15 +708,16 @@ void BlockManager::storeContextBlocks(GenerationRequest& sequence, LlmRequest co void WindowBlockManager::createBlockScalePools(SizeType32 quantBlockSize) { + SizeType32 const numEltsPerContainer = getNumEltsPerContainer(); auto num_pools = mPools.size(); for (size_t i = 0; i < num_pools; ++i) { auto& kv_pool = mPools[i]; - TLLM_CHECK_WITH_INFO(kv_pool.blockSize % quantBlockSize == 0, - "Cannot use FP4 quantization since kv_pool.blockSize is not divisible by FP4 quantBlockSize."); - - 
mPools.emplace_back(kv_pool.numLayers, kv_pool.kvFactor, kv_pool.numKvHeads, kv_pool.sizePerHead, - kv_pool.tokensPerBlock, quantBlockSize, + TLLM_CHECK_WITH_INFO((kv_pool.sizePerHead * numEltsPerContainer) % quantBlockSize == 0, + "Cannot use FP4 quantization since kv_pool.sizePerHead is not divisible by FP4 quantBlockSize."); + auto blockScaleSizePerHead = kv_pool.sizePerHead * numEltsPerContainer / quantBlockSize; + mPools.emplace_back(kv_pool.numLayers, kv_pool.kvFactor, kv_pool.numKvHeads, blockScaleSizePerHead, + kv_pool.tokensPerBlock, /*primaryPool=*/nullptr, /*secondaryPool=*/nullptr, /*containsBlockScales=*/true); @@ -742,10 +751,6 @@ void WindowBlockManager::allocatePools(bool useUvm) if (poolIsFP4) { - TLLM_CHECK_WITH_INFO(blockSize % 2 == 0, "Block size must be divisible by 2 for FP4 KV cache."); - // Divide by 2. We can't create FP4 buffers directly, so we'll have to create a uint8 buffer with - // half the expected number of elements. - blockSize /= 2; poolDtype = nvinfer1::DataType::kINT8; } diff --git a/cpp/tensorrt_llm/common/attentionOp.cpp b/cpp/tensorrt_llm/common/attentionOp.cpp index 866a0d71c8f..f00c25273bf 100644 --- a/cpp/tensorrt_llm/common/attentionOp.cpp +++ b/cpp/tensorrt_llm/common/attentionOp.cpp @@ -214,6 +214,10 @@ bool AttentionOp::convertMMHAParamsToXQAParams(tensorrt_llm::kernels::XQAParams& } xqaParams.kv_cache_data_type = DATA_TYPE_E4M3; } + else if (mKVCacheQuantMode.hasFp4KvCache()) + { + xqaParams.kv_cache_data_type = DATA_TYPE_E2M1; + } else { xqaParams.kv_cache_data_type = xqaParams.data_type; @@ -959,6 +963,9 @@ int AttentionOp::mlaGeneration( generation_params.can_use_one_more_block, generation_params.host_primary_pool_pointer, generation_params.host_secondary_pool_pointer, generation_params.block_offsets); + // Currently NVFP4 KV cache is not supported for MLA. An empty placeholder is provided. + auto kv_scale_cache_buffer = KVBlockArray(); + // Workspace pointer shift int8_t* workspace_byte_ptr = reinterpret_cast(params.workspace); size_t offset = 0; @@ -1234,7 +1241,7 @@ int AttentionOp::mlaGeneration( { TLLM_LOG_DEBUG("XQA kernels are selected in the generation phase."); xqaParams.stream = stream; - mXqaDispatcher->run(xqaParams, kv_cache_buffer); + mXqaDispatcher->run(xqaParams, kv_cache_buffer, kv_scale_cache_buffer); return 0; } else if (mIsSpecDecodingEnabled && mUseSpecDecoding) @@ -1308,8 +1315,10 @@ int AttentionOp::enqueueContext(EnqueueContextParams const& params, cudaStrea float const q_scaling = mQScaling; KVCacheBuffer kv_cache_buffer; - auto const elemSize = mKVCacheQuantMode.hasKvCacheQuant() ? 
sizeof(int8_t) : sizeof(T); - auto sizePerToken = mNumAttnKVHeads * headSize * elemSize; + KVCacheBuffer kv_scale_cache_buffer; + + auto sizePerToken = mNumAttnKVHeads * headSize * getKvCacheElemSizeInBits() / 8 /*bits*/; + if (useKVCache()) { if constexpr (std::is_same_v) @@ -1318,6 +1327,14 @@ int AttentionOp::enqueueContext(EnqueueContextParams const& params, cudaStrea sizePerToken, params.cyclic_attention_window_size, params.max_cyclic_attention_window_size, params.sink_token_length, params.can_use_one_more_block, params.host_primary_pool_pointer, params.host_secondary_pool_pointer, params.block_offsets); + if (mKVCacheQuantMode.hasFp4KvCache()) + { + kv_scale_cache_buffer = KVBlockArray(params.batch_size, params.max_blocks_per_sequence, mTokensPerBlock, + sizePerToken / 8, params.cyclic_attention_window_size, params.max_cyclic_attention_window_size, + params.sink_token_length, params.can_use_one_more_block, + params.host_primary_block_scale_pool_pointer, params.host_secondary_block_scale_pool_pointer, + params.block_offsets); + } } else if constexpr (std::is_same_v) { @@ -1326,6 +1343,7 @@ int AttentionOp::enqueueContext(EnqueueContextParams const& params, cudaStrea isCrossAttention() ? params.cross_kv_length : params.max_attention_window_size, sizePerToken, params.cyclic_attention_window_size, params.sink_token_length, false, reinterpret_cast(params.key_value_cache)); + TLLM_CHECK_WITH_INFO(!(mKVCacheQuantMode.hasFp4KvCache()), "FP4 KV cache only supports paged KV."); } } @@ -1490,8 +1508,8 @@ int AttentionOp::enqueueContext(EnqueueContextParams const& params, cudaStrea decoder_params.blockSparseParams = mBlockSparseParams; decoder_params.fmhaTileCounter = fmha_tile_counter_ptr; decoder_params.quantScaleO = params.attention_output_orig_quant; - decoder_params.dequantScaleQ = params.kv_scale_quant_orig; - decoder_params.dequantScaleKv = params.kv_scale_quant_orig; + decoder_params.dequantScaleQkv = params.kv_scale_quant_orig; + decoder_params.separateQkvScales = mKVCacheQuantMode.hasFp4KvCache(); decoder_params.fmhaHostBmm1Scale = 1.0f / (sqrtf(getHeadSize() * 1.0f) * q_scaling); decoder_params.fmhaBmm1Scale = fmha_bmm1_scale_ptr; decoder_params.fmhaBmm2Scale = fmha_bmm2_scale_ptr; @@ -1549,9 +1567,19 @@ int AttentionOp::enqueueContext(EnqueueContextParams const& params, cudaStrea sync_check_cuda_error(stream); } - KvCacheDataType const cache_type = mKVCacheQuantMode.hasInt8KvCache() - ? KvCacheDataType::INT8 - : (mKVCacheQuantMode.hasFp8KvCache() ? 
KvCacheDataType::FP8 : KvCacheDataType::BASE); + KvCacheDataType cache_type{KvCacheDataType::BASE}; + if (mKVCacheQuantMode.hasInt8KvCache()) + { + cache_type = KvCacheDataType::INT8; + } + else if (mKVCacheQuantMode.hasFp8KvCache()) + { + cache_type = KvCacheDataType::FP8; + } + else if (mKVCacheQuantMode.hasFp4KvCache()) + { + cache_type = KvCacheDataType::NVFP4; + } cudaDataType_t const gemm_data_type = tc::CudaDataType::value; int const attention_seq_len_1 = params.input_seq_length; // q length @@ -1600,6 +1628,7 @@ int AttentionOp::enqueueContext(EnqueueContextParams const& params, cudaStrea preprocessingParams.quantized_qkv_output = fp8_qkv_buffer; preprocessingParams.q_output = q_buf_2_; preprocessingParams.kv_cache_buffer = kv_cache_buffer; + preprocessingParams.kv_cache_block_scales_buffer = kv_scale_cache_buffer; preprocessingParams.qkv_bias = params.qkv_bias; preprocessingParams.tokens_info = decoder_params.tokensInfo; preprocessingParams.seq_lens = params.context_lengths; @@ -1612,7 +1641,7 @@ int AttentionOp::enqueueContext(EnqueueContextParams const& params, cudaStrea preprocessingParams.rotary_embedding_inv_freq = rotary_inv_freq_buf; preprocessingParams.rotary_coef_cache_buffer = params.rotary_cos_sin; preprocessingParams.mrope_rotary_cos_sin = params.mrope_rotary_cos_sin; - preprocessingParams.kvScaleOrigQuant = params.kv_scale_orig_quant; + preprocessingParams.qkv_scale_orig_quant = params.kv_scale_orig_quant; preprocessingParams.spec_decoding_position_offsets = nullptr; preprocessingParams.logn_scaling = params.logn_scaling_ptr; @@ -1781,6 +1810,7 @@ int AttentionOp::enqueueContext(EnqueueContextParams const& params, cudaStrea if constexpr (std::is_same_v) { fmhaParams.pagedKvCache = kv_cache_buffer; + fmhaParams.pagedKvSfCache = kv_scale_cache_buffer; } fmhaParams.cuQSeqLenPtr = cu_q_seqlens; fmhaParams.kvSeqLenPtr = decoder_params.seqKVLengths; @@ -2126,8 +2156,10 @@ int AttentionOp::enqueueGeneration(EnqueueGenerationParams const& params, cud int32_t const batch_beam = params.beam_width * params.num_requests; KVCacheBuffer kv_cache_buffer; - auto const elemSize = mKVCacheQuantMode.hasKvCacheQuant() ? 
sizeof(int8_t) : sizeof(T); - auto const sizePerToken = mNumAttnKVHeads * headSize * elemSize; + KVCacheBuffer kv_scale_cache_buffer; + + auto const sizePerToken = mNumAttnKVHeads * headSize * getKvCacheElemSizeInBits() / 8 /*bits*/; + if (useKVCache()) { if constexpr (std::is_same_v) { @@ -2137,6 +2169,14 @@ int AttentionOp::enqueueGeneration(EnqueueGenerationParams const& params, cud params.cyclic_attention_window_size, params.max_cyclic_attention_window_size, params.sink_token_length, params.can_use_one_more_block, params.host_primary_pool_pointer, params.host_secondary_pool_pointer, reinterpret_cast(params.block_offsets)); + if (mKVCacheQuantMode.hasFp4KvCache()) + { + kv_scale_cache_buffer = KVBlockArray(batch_beam, params.max_blocks_per_sequence, mTokensPerBlock, + sizePerToken / 8, params.cyclic_attention_window_size, params.max_cyclic_attention_window_size, + params.sink_token_length, params.can_use_one_more_block, + params.host_primary_block_scale_pool_pointer, params.host_secondary_block_scale_pool_pointer, + reinterpret_cast(params.block_offsets)); + } } else if constexpr (std::is_same_v) { @@ -2144,6 +2184,7 @@ int AttentionOp::enqueueGeneration(EnqueueGenerationParams const& params, cud kv_cache_buffer = KVLinearBuffer(batch_beam, params.max_attention_window_size, sizePerToken, params.cyclic_attention_window_size, params.sink_token_length, false, reinterpret_cast(params.key_value_cache)); + TLLM_CHECK_WITH_INFO(!(mKVCacheQuantMode.hasFp4KvCache()), "FP4 KV cache only supports paged KV."); } } sync_check_cuda_error(stream); @@ -2215,7 +2256,7 @@ int AttentionOp::enqueueGeneration(EnqueueGenerationParams const& params, cud xqaParams.output = mhaOutput; xqaParams.qkv = attention_input; } - mXqaDispatcher->run(xqaParams, kv_cache_buffer); + mXqaDispatcher->run(xqaParams, kv_cache_buffer, kv_scale_cache_buffer); if (mCpSize > 1 && mAttnTpSize > 1 && mAttnCpSize == 1) { this->template ulyssesGenerationPostprocess( @@ -2232,6 +2273,10 @@ int AttentionOp::enqueueGeneration(EnqueueGenerationParams const& params, cud { TLLM_CHECK_WITH_INFO(false, "No available kernels are found for FP4 output."); } + else if (mKVCacheQuantMode.hasFp4KvCache()) + { + TLLM_CHECK_WITH_INFO(false, "No available kernels are found for FP4 KV cache."); + } else { TLLM_LOG_DEBUG("XQA kernels are not selected in the generation phase."); @@ -2503,6 +2548,10 @@ int AttentionOp::initialize() noexcept TLLM_CHECK_WITH_INFO(!mFuseFp4Quant || mSM == 100 || mSM == 120 || mSM == 121, "fuse_fp4_quant only supports SM100 or SM120 or SM121 devices."); + // Check requirements for FP4 KV cache. + TLLM_CHECK_WITH_INFO(!mKVCacheQuantMode.hasFp4KvCache() || mFP8ContextFMHA, + "mFP8ContextFMHA must be enabled if FP4 KV cache is enabled"); + TLLM_CHECK(isRoPE() == (mRotaryEmbeddingDim != 0)); TLLM_CHECK_WITH_INFO((mSM >= 80) || (mType != nvinfer1::DataType::kBF16), "Unsupported data type, pre SM 80 GPUs do not support bfloat16"); @@ -2579,7 +2628,10 @@ int AttentionOp::initialize() noexcept { fmhaParams.dataTypeKv = DATA_TYPE_E4M3; } - // TODO: add FP4 KV cache support. + else if (mKVCacheQuantMode.hasFp4KvCache()) + { + fmhaParams.dataTypeKv = DATA_TYPE_E2M1; + } } // The output dtype. 
fmhaParams.dataTypeOut = data_type; @@ -2789,6 +2841,11 @@ int AttentionOp::initialize() noexcept fixedParams.kvDataType = DATA_TYPE_E4M3; fixedParams.mathDataType = DATA_TYPE_E4M3; } + else if (mKVCacheQuantMode.hasFp4KvCache()) + { + fixedParams.kvDataType = DATA_TYPE_E2M1; + fixedParams.mathDataType = DATA_TYPE_E4M3; + } else { fixedParams.kvDataType = fixedParams.inputDataType; diff --git a/cpp/tensorrt_llm/common/attentionOp.h b/cpp/tensorrt_llm/common/attentionOp.h index 39745a3957e..8cd27cd2d2e 100644 --- a/cpp/tensorrt_llm/common/attentionOp.h +++ b/cpp/tensorrt_llm/common/attentionOp.h @@ -94,6 +94,8 @@ class AttentionOp kernels::KVBlockArray::DataType* block_offsets = nullptr; void* host_primary_pool_pointer = nullptr; void* host_secondary_pool_pointer = nullptr; + void* host_primary_block_scale_pool_pointer = nullptr; + void* host_secondary_block_scale_pool_pointer = nullptr; int32_t num_tokens = 0; int32_t total_kv_len = 0; int32_t max_blocks_per_sequence = 0; @@ -233,6 +235,20 @@ class AttentionOp return num_sm_parts; } + template + int getKvCacheElemSizeInBits() const + { + if (mKVCacheQuantMode.hasInt8KvCache() || mKVCacheQuantMode.hasFp8KvCache()) + { + return 8; + } + else if (mKVCacheQuantMode.hasFp4KvCache()) + { + return 4; + } + return sizeof(T) * 8; + } + // Called in configurePlugin(). template void prepareEnqueueGeneration(EnqueueGenerationParams const& params); diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fused_multihead_attention_common.h b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fused_multihead_attention_common.h index 77bbbba876a..8b19962a8a6 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fused_multihead_attention_common.h +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fused_multihead_attention_common.h @@ -267,6 +267,8 @@ struct MHARunnerParams void const* vPtr; // The paged kv cache array. KVBlockArray pagedKvCache; + // The paged kv cache array for scaling factor. + KVBlockArray pagedKvSfCache; // The output buffer ptr. void* outputPtr; // The output scaling factor buffer ptr. (only used for FP4 output) diff --git a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/decoderXQAImplJIT.cpp b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/decoderXQAImplJIT.cpp index ac331ac33fe..83ba3662262 100644 --- a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/decoderXQAImplJIT.cpp +++ b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/decoderXQAImplJIT.cpp @@ -303,8 +303,7 @@ void DecoderXQAImplJIT::runImpl(XQAParams const& xqaParams, KVCacheBuffer const& preprocessingParams.cu_seq_lens = xqaParams.multi_query_tokens ? launchParams.cu_seq_lens : nullptr; preprocessingParams.rotary_embedding_inv_freq = rotary_inv_freq_buf; preprocessingParams.rotary_coef_cache_buffer = xqaParams.rotary_cos_sin; - preprocessingParams.kvScaleOrigQuant = xqaParams.kv_scale_orig_quant; - preprocessingParams.kv_cache_scale_factors = nullptr; + preprocessingParams.qkv_scale_orig_quant = xqaParams.kv_scale_orig_quant; preprocessingParams.spec_decoding_position_offsets = xqaParams.spec_decoding_position_offsets; preprocessingParams.mrope_position_deltas = xqaParams.mrope_position_deltas; // Scalar parameters. 
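The changes above collapse the separate kvScaleOrigQuant and kv_cache_scale_factors pointers into a single qkv_scale_orig_quant array and introduce a separateQkvScales flag: when the NVFP4 KV cache path is active the array carries three entries ordered {Q, K, V}, otherwise every read falls back to the single shared scale at index 0 (the gptKernels.cu hunk further down spells out the indexing). A hedged sketch of that convention; QkvScaleView is an illustrative helper, not a type introduced by this patch:

    // Sketch: Q/K/V dequant-scale indexing with and without separate scales.
    struct QkvScaleView
    {
        float const* scales;    // points at 1 float (shared) or 3 floats {Q, K, V}
        bool separateQkvScales; // true when NVFP4 KV cache is enabled

        float q() const { return scales ? scales[0] : 1.f; }
        float k() const { return scales ? scales[separateQkvScales ? 1 : 0] : 1.f; }
        float v() const { return scales ? scales[separateQkvScales ? 2 : 0] : 1.f; }
    };

In the patch, the FMHA bmm1 scale combines the Q and K dequant scales while the bmm2 scale combines the output quant scale with the V dequant scale, which is why K and V need distinct entries.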
diff --git a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplPrecompiled.cpp b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplPrecompiled.cpp index ebe6722ac71..c0ae76263d8 100644 --- a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplPrecompiled.cpp +++ b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplPrecompiled.cpp @@ -224,8 +224,7 @@ class XQAKernelList preprocessingParams.cu_seq_lens = xqaParams.multi_query_tokens ? launchParams.cu_seq_lens : nullptr; preprocessingParams.rotary_embedding_inv_freq = rotary_inv_freq_buf; preprocessingParams.rotary_coef_cache_buffer = xqaParams.rotary_cos_sin; - preprocessingParams.kvScaleOrigQuant = xqaParams.kv_scale_orig_quant; - preprocessingParams.kv_cache_scale_factors = nullptr; + preprocessingParams.qkv_scale_orig_quant = xqaParams.kv_scale_orig_quant; preprocessingParams.spec_decoding_position_offsets = xqaParams.spec_decoding_position_offsets; preprocessingParams.mrope_position_deltas = xqaParams.mrope_position_deltas; // Scalar parameters. diff --git a/cpp/tensorrt_llm/kernels/fmhaDispatcher.cpp b/cpp/tensorrt_llm/kernels/fmhaDispatcher.cpp index 2db0317adc1..d6dd07d4d95 100644 --- a/cpp/tensorrt_llm/kernels/fmhaDispatcher.cpp +++ b/cpp/tensorrt_llm/kernels/fmhaDispatcher.cpp @@ -139,6 +139,7 @@ void FmhaDispatcher::run(MHARunnerParams runnerParams) TLLM_CHECK_WITH_INFO(mTllmGenFMHARunner.get(), "mTllmGenFMHARunner not initialized."); // Convert from MHAFixedParams + MHARunnerParams to TllmGenFmhaRunnerParams void const* kvPoolPtr = nullptr; + void const* kvSfPoolPtr = nullptr; void const* kvPageIdxPtr = nullptr; auto qkvLayout = kernels::QkvLayout::PackedQkv; int32_t maxBlocksPerSeq = 0; @@ -148,6 +149,7 @@ void FmhaDispatcher::run(MHARunnerParams runnerParams) qkvLayout = kernels::QkvLayout::PagedKv; auto pagedKvCache = runnerParams.pagedKvCache.copyKVBlockArrayForContextFMHA(); kvPoolPtr = pagedKvCache.mPrimaryPoolPtr; + kvSfPoolPtr = runnerParams.pagedKvSfCache.mPrimaryPoolPtr; kvPageIdxPtr = reinterpret_cast(pagedKvCache.data); maxBlocksPerSeq = pagedKvCache.mMaxBlocksPerSeq; numTokensPerBlock = pagedKvCache.mTokensPerBlock; @@ -172,6 +174,7 @@ void FmhaDispatcher::run(MHARunnerParams runnerParams) tllmRunnerParams.kPtr = runnerParams.kPtr; tllmRunnerParams.vPtr = runnerParams.vPtr; tllmRunnerParams.kvPtr = kvPoolPtr; + tllmRunnerParams.kvSfPtr = kvSfPoolPtr; tllmRunnerParams.qkvPtr = runnerParams.qkvPtr; tllmRunnerParams.attentionSinksPtr = runnerParams.attentionSinksPtr; tllmRunnerParams.cumSeqLensQPtr = reinterpret_cast(runnerParams.cuQSeqLenPtr); diff --git a/cpp/tensorrt_llm/kernels/gptKernels.cu b/cpp/tensorrt_llm/kernels/gptKernels.cu index f79c7af6a63..7d6332d1a4d 100644 --- a/cpp/tensorrt_llm/kernels/gptKernels.cu +++ b/cpp/tensorrt_llm/kernels/gptKernels.cu @@ -276,13 +276,18 @@ __global__ __launch_bounds__(THREADS_PER_BLOCK) void computeSeqAndPaddingOffsets params.fmhaTileCounter[0] = 0u; } // Take the quantization scales into consideration. - float dequantScaleQ = params.dequantScaleQ ? params.dequantScaleQ[0] : 1.f; - float dequantScaleKv = params.dequantScaleKv ? params.dequantScaleKv[0] : 1.f; + int const q_scale_idx = 0; + int const k_scale_idx = params.separateQkvScales ? 1 : 0; + int const v_scale_idx = params.separateQkvScales ? 2 : 0; + float dequantScaleQ = params.dequantScaleQkv ? params.dequantScaleQkv[q_scale_idx] : 1.f; + float dequantScaleK = params.dequantScaleQkv ? 
params.dequantScaleQkv[k_scale_idx] : 1.f; + float dequantScaleV = params.dequantScaleQkv ? params.dequantScaleQkv[v_scale_idx] : 1.f; + float quantScaleO = params.quantScaleO ? params.quantScaleO[0] : 1.f; if (params.fmhaBmm1Scale) { // The scale after fmha bmm1. - params.fmhaBmm1Scale[0] = dequantScaleQ * dequantScaleKv * params.fmhaHostBmm1Scale; + params.fmhaBmm1Scale[0] = dequantScaleQ * dequantScaleK * params.fmhaHostBmm1Scale; // The scale prepared for log2 optimization. constexpr float kLog2e = 1.4426950408889634074f; params.fmhaBmm1Scale[1] = params.fmhaBmm1Scale[0] * kLog2e; @@ -290,7 +295,7 @@ __global__ __launch_bounds__(THREADS_PER_BLOCK) void computeSeqAndPaddingOffsets if (params.fmhaBmm2Scale) { // The scale after fmha bmm2. - params.fmhaBmm2Scale[0] = quantScaleO * dequantScaleKv; + params.fmhaBmm2Scale[0] = quantScaleO * dequantScaleV; } } } diff --git a/cpp/tensorrt_llm/kernels/gptKernels.h b/cpp/tensorrt_llm/kernels/gptKernels.h index d659d6e3ec7..38c56be9026 100644 --- a/cpp/tensorrt_llm/kernels/gptKernels.h +++ b/cpp/tensorrt_llm/kernels/gptKernels.h @@ -144,8 +144,10 @@ struct BuildDecoderInfoParams // Scales for fmha only. // The scale to dequant Q/Kv input. - float const* dequantScaleQ; - float const* dequantScaleKv; + float const* dequantScaleQkv; + // Whether to use separate scales for Q/K/V. + bool separateQkvScales; + // The scale to quant O output. float const* quantScaleO; // The fmha bmm1 host scale (1.0f / sqrt(headSize) by default). diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaRunner.cpp b/cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaRunner.cpp index 9ff85d9d7ce..54eb24ae766 100644 --- a/cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaRunner.cpp +++ b/cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaRunner.cpp @@ -37,7 +37,8 @@ TllmGenFmhaRunner::TllmGenFmhaRunner(Data_type dtypeQ, Data_type dtypeKv, Data_t TLLM_CHECK_WITH_INFO(mSM == kSM_100, "Unsupported architecture"); TLLM_CHECK_WITH_INFO( mDtypeQ == DATA_TYPE_E4M3 || mDtypeQ == DATA_TYPE_FP16 || mDtypeQ == DATA_TYPE_BF16, "Unsupported Q data type"); - TLLM_CHECK_WITH_INFO(mDtypeKv == DATA_TYPE_E4M3 || mDtypeKv == DATA_TYPE_FP16 || mDtypeKv == DATA_TYPE_BF16, + TLLM_CHECK_WITH_INFO(mDtypeKv == DATA_TYPE_E2M1 || mDtypeKv == DATA_TYPE_E4M3 || mDtypeKv == DATA_TYPE_FP16 + || mDtypeKv == DATA_TYPE_BF16, "Unsupported Kv data type"); TLLM_CHECK_WITH_INFO(mDtypeOut == DATA_TYPE_E2M1 || mDtypeOut == DATA_TYPE_E4M3 || mDtypeOut == DATA_TYPE_FP16 || mDtypeOut == DATA_TYPE_BF16, diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaRunnerParams.h b/cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaRunnerParams.h index 0f042ec74b9..63d0d24bdc4 100644 --- a/cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaRunnerParams.h +++ b/cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaRunnerParams.h @@ -192,12 +192,10 @@ struct TllmGenFmhaRunnerParams void const* vPtr; // Packed KV buffer void const* kvPtr; + // Packed KV scaling factor buffer + void const* kvSfPtr; // Packed QKV buffer void const* qkvPtr; - // The scaling factor pointer of K. - void const* kSfBasePtr; - // The scaling factor pointer of V. - void const* vSfBasePtr; // The attention sinks pointer (additional value per head in the denominator of the softmax). float const* attentionSinksPtr; // The custom mask ptr. @@ -271,9 +269,6 @@ struct TllmGenFmhaRunnerParams // The start token index in SF tensor. Used for FP4 SF offset calculation in generation phase kernel when inflight // batching is enabled. 
int mSfStartTokenIdx; - - // The SF scale for Kv. - float mScaleSfKv; // The cuda stream. cudaStream_t stream; diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/kernelParams.h b/cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/kernelParams.h index 811cd30427a..8917496befd 100644 --- a/cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/kernelParams.h +++ b/cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/kernelParams.h @@ -152,7 +152,7 @@ struct KernelParams float mScaleSfO; // The start token index in SF tensor. Used for FP4 SF offset calculation in generation phase // kernel when inflight batching is enabled in TRT-LLM. - int32_t mStartTokenIdxSfO; + int32_t mStartTokenIdx; // The sum of sequence lengths for Q and K/V. int32_t mSumOfSeqLensQ, mSumOfSeqLensKv; @@ -397,8 +397,8 @@ struct KernelParams // Create the TMA shape/stride for K. template - static auto makeTmaShapeStrideKv( - FmhaOptions const& options, KernelParams const& params, Data_type dtypeKv, bool isK) + static auto makeTmaShapeStrideKv(FmhaOptions const& options, KernelParams const& params, Data_type dtypeKv, + bool isK, bool storeTransformedKvInTmem) { // The shape elements. auto [numKeys, numHeadsQPerKv, batchSize] = makeShapeKv(options, params); @@ -425,9 +425,13 @@ struct KernelParams // Note that for FP4 KV input, elements are stored as uint8_t, each packs 2 FP4 elements. // The column index and strides needs to divide by 2. auto const colIdxDivisor = dtypeKv == DATA_TYPE_E2M1 ? 2 : 1; + // When storeTransformedKvInTmem is true, the dimensions reflect FP4 element dimensions, thus + // no need to divide. + auto shape - = std::vector{static_cast(headDim / colIdxDivisor), static_cast(numKeys), - static_cast(options.mNumHeadsKv), static_cast(batchSize)}; + = std::vector{static_cast(storeTransformedKvInTmem ? headDim : headDim / colIdxDivisor), + static_cast(numKeys), static_cast(options.mNumHeadsKv), + static_cast(batchSize)}; auto stride = std::vector{1, static_cast(strideKeys / colIdxDivisor), static_cast(strideHeads / colIdxDivisor), static_cast(strideBatch / colIdxDivisor)}; @@ -474,7 +478,7 @@ struct KernelParams // Prepare pointers for TMA descriptors. static std::tuple getDevicePtrs( - TllmGenFmhaRunnerParams const& runnerParams, int32_t bytesPerElt) + TllmGenFmhaRunnerParams const& runnerParams, int32_t bitsPerElt) { // Declare the q, k, v ptrs. void const *qPtr{runnerParams.qPtr}, *kPtr{runnerParams.kPtr}, *vPtr{runnerParams.vPtr}; @@ -484,9 +488,10 @@ struct KernelParams { qPtr = runnerParams.qkvPtr; kPtr = reinterpret_cast(reinterpret_cast(runnerParams.qkvPtr) - + runnerParams.mNumHeadsQ * runnerParams.mHeadDimQk * bytesPerElt); + + runnerParams.mNumHeadsQ * runnerParams.mHeadDimQk * bitsPerElt / 8 /*bits*/); vPtr = reinterpret_cast(reinterpret_cast(runnerParams.qkvPtr) - + (runnerParams.mNumHeadsQ + runnerParams.mNumHeadsKv) * runnerParams.mHeadDimQk * bytesPerElt); + + (runnerParams.mNumHeadsQ + runnerParams.mNumHeadsKv) * runnerParams.mHeadDimQk * bitsPerElt + / 8 /*bits*/); } // Set K and V pointer from pagedKv tensor. else if (isPagedKv(runnerParams.mQkvLayout)) @@ -503,7 +508,7 @@ struct KernelParams // Note that contiguousKv or pagedKv will pad K and V to maxHeadDimKv. 
int32_t const maxHeadDimKv{std::max(runnerParams.mHeadDimQk, runnerParams.mHeadDimV)}; vPtr = reinterpret_cast(reinterpret_cast(runnerParams.kvPtr) - + runnerParams.mNumHeadsKv * runnerParams.mMaxSeqLenCacheKv * maxHeadDimKv * bytesPerElt); + + runnerParams.mNumHeadsKv * runnerParams.mMaxSeqLenCacheKv * maxHeadDimKv * bitsPerElt / 8 /*bits*/); } // Return the pointers. @@ -514,12 +519,16 @@ struct KernelParams template static CUtensorMap buildNdTmaDescriptor(FmhaOptions const& options, Data_type dtypeElt, std::vector const& shapes, std::vector const& strides, - std::vector const& tileShapes, void* gmemAddr, bool swizzled = true) + std::vector const& tileShapes, void* gmemAddr, bool swizzled = true, bool unpack4b = false) { CUtensorMap desc{}; // The data type. CUtensorMapDataType tmaDataFormat; - if (dtypeElt == DATA_TYPE_E2M1 || dtypeElt == DATA_TYPE_E4M3) + if (dtypeElt == DATA_TYPE_E2M1) + { + tmaDataFormat = unpack4b ? CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN16B : CU_TENSOR_MAP_DATA_TYPE_UINT8; + } + else if (dtypeElt == DATA_TYPE_E4M3) { tmaDataFormat = CU_TENSOR_MAP_DATA_TYPE_UINT8; } @@ -543,6 +552,10 @@ struct KernelParams { swizzleType = CU_TENSOR_MAP_SWIZZLE_NONE; } + else if (tmaDataFormat == CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN16B) + { + swizzleType = CU_TENSOR_MAP_SWIZZLE_128B; + } else if ((numBytesInLeadingDim % 128) == 0) { swizzleType = CU_TENSOR_MAP_SWIZZLE_128B; @@ -629,7 +642,7 @@ struct KernelParams KernelParams params; // Get the device pointers for TMA descriptors. - auto [qPtr, kPtr, vPtr] = getDevicePtrs(options, get_size_in_bytes(kernelMeta.mDataTypeKv)); + auto [qPtr, kPtr, vPtr] = getDevicePtrs(options, get_size_in_bits(kernelMeta.mDataTypeKv)); // The maximum headDim of K and V. // Note that contiguousKv or pagedKv will pad K and V to maxHeadDimKv. @@ -664,14 +677,22 @@ struct KernelParams // The number of head elts (per token) in each block of shared memory (see above explanation). int32_t numEltsInClampedHeadDimKv = std::min(numEltsIn128BKv, maxHeadDimKv); - // Shape/stride for gmem tensor Kv. - auto [shapeK, strideK] = makeTmaShapeStrideKv(options, params, kernelMeta.mDataTypeKv, /*isK*/ true); - auto [shapeV, strideV] = makeTmaShapeStrideKv(options, params, kernelMeta.mDataTypeKv, /*isK*/ false); - // Build tma descriptor for K. // Do we have to transform K/V before MMA? bool const transformsKv{kernelMeta.mDataTypeKv != kernelMeta.mDataTypeQ}; + // Whether store transformed K/V in TMEM. + bool const isSwapsMmaAb = isSwapsMmaAbForGenerationKernel(static_cast(kernelMeta.mKernelType)); + bool const storeTransformedKvInTmem{kernelMeta.mDataTypeKv == DATA_TYPE_E2M1 + && kernelMeta.mDataTypeQ == DATA_TYPE_E4M3 && maxHeadDimKv == 128 && isSwapsMmaAb}; + + // Shape/stride for gmem tensor Kv. + auto [shapeK, strideK] + = makeTmaShapeStrideKv(options, params, kernelMeta.mDataTypeKv, /*isK*/ true, storeTransformedKvInTmem); + auto [shapeV, strideV] + = makeTmaShapeStrideKv(options, params, kernelMeta.mDataTypeKv, /*isK*/ false, storeTransformedKvInTmem); + // Whether swizzle is needed for K/V. + bool const swizzleKv{storeTransformedKvInTmem ? true : !transformsKv}; // Note that for FP4 KV input, elements are stored as uint8_t, each packs 2 FP4 elements. - auto const numEltsDivisor = kernelMeta.mDataTypeKv == DATA_TYPE_E2M1 ? 2 : 1; + auto const numEltsDivisor = kernelMeta.mDataTypeKv == DATA_TYPE_E2M1 && !storeTransformedKvInTmem ? 2 : 1; // The tileShapes for K/V. 
std::vector tileShapeKv(shapeK.size(), 1); tileShapeKv[0] = numEltsInClampedHeadDimKv / numEltsDivisor; @@ -679,12 +700,11 @@ struct KernelParams // Build tma descriptor for K. params.tmaK_ = buildNdTmaDescriptor(options, kernelMeta.mDataTypeKv, shapeK, strideK, tileShapeKv, const_cast(kPtr), - /*swizzled = */ !transformsKv); + /*swizzled = */ swizzleKv, /*unpack4b = */ storeTransformedKvInTmem); // Build tma descriptor for V. params.tmaV_ = buildNdTmaDescriptor(options, kernelMeta.mDataTypeKv, shapeV, strideV, tileShapeKv, const_cast(vPtr), - /*swizzled = */ !transformsKv); - + /*swizzled = */ swizzleKv, /*unpack4b = */ storeTransformedKvInTmem); // If the KV dtype is E2m1, additional scaling factors are needed for dequant. if (kernelMeta.mDataTypeKv == DATA_TYPE_E2M1) { @@ -704,12 +724,12 @@ struct KernelParams // headDim / NumEltsPerSf / 16). See makeTmaShapeStrideKvSf for details. Build tma descriptor // for K SF. params.tmaKSf_ = buildNdTmaDescriptor(options, DATA_TYPE_E4M3, shapeKvSf, strideKvSf, tileShapeKvSf, - const_cast(options.kSfBasePtr), + const_cast(options.kvSfPtr), /*swizzled = */ false); // Build tma descriptor for V SF. params.tmaVSf_ = buildNdTmaDescriptor(options, DATA_TYPE_E4M3, shapeKvSf, strideKvSf, tileShapeKvSf, - const_cast(options.vSfBasePtr), + const_cast(options.kvSfPtr), /*swizzled = */ false); } @@ -789,8 +809,7 @@ struct KernelParams params.mNumHiddenEltsO = options.mNumHeadsQ * options.mHeadDimQk; params.mOutputScale = 1.f; params.mScaleSoftmaxLog2 = (1.f / (std::sqrt((float) (options.mHeadDimQk)) * options.mScaleQ)) * M_LOG2E; - params.mStartTokenIdxSfO = options.mSfStartTokenIdx; - params.mScaleSfKv = options.mScaleSfKv; + params.mStartTokenIdx = options.mSfStartTokenIdx; return params; } diff --git a/cpp/tensorrt_llm/kernels/unfusedAttentionKernels.h b/cpp/tensorrt_llm/kernels/unfusedAttentionKernels.h index 71178ba0f76..22fa7d464ed 100644 --- a/cpp/tensorrt_llm/kernels/unfusedAttentionKernels.h +++ b/cpp/tensorrt_llm/kernels/unfusedAttentionKernels.h @@ -122,8 +122,8 @@ struct QKVPreprocessingParams // Fuse the computation of FMHA quantization scales into the preprocessing kernels. // This can also be done in gptKernels.h if there is no preprocessing kernels. // The scale to dequant Q/Kv input. - float const* q_scale_quant_orig{nullptr}; - float const* kv_scale_quant_orig{nullptr}; + float const* qkv_scale_quant_orig{nullptr}; + float const* qkv_scale_orig_quant{nullptr}; // The scale to quant O output. float const* o_scale_orig_quant{nullptr}; // The scale after fmha bmm1. @@ -154,10 +154,6 @@ struct QKVPreprocessingParams // the pre-computed RoPE factors. computed at model build time, stored in the engine // shape is {rotary_embedding_max_positions, rotary_embedding_dim}. eg (2048, 128) float2 const* rotary_coef_cache_buffer{nullptr}; - float const* kvScaleOrigQuant{nullptr}; - // Pair of floats on the GPU corresponding to the second level K/V scale for - // FP4 KV cache quantization. 
- float const* kv_cache_scale_factors{nullptr}; int const* spec_decoding_position_offsets{nullptr}; float2 const* mrope_rotary_cos_sin{nullptr}; @@ -239,7 +235,7 @@ struct QKVPreprocessingParams << *(runtime::ITensor::wrap((void*) rotary_embedding_inv_freq, nvinfer1::DataType::kFLOAT, runtime::ITensor::makeShape({batch_size, rotary_embedding_dim / 2}))); ss << "rotary_coef_cache_buffer: " << rotary_coef_cache_buffer << std::endl; - ss << "kvScaleOrigQuant: " << kvScaleOrigQuant << std::endl; + ss << "qkv_scale_orig_quant: " << qkv_scale_orig_quant << std::endl; ss << "spec_decoding_position_offsets: " << spec_decoding_position_offsets << std::endl; ss << "batch_size: " << batch_size << std::endl; ss << "max_input_seq_len: " << max_input_seq_len << std::endl; @@ -345,8 +341,6 @@ void invokeQKVPreprocessing(QKVPreprocessingParams params, cud { TLLM_CHECK_WITH_INFO(params.kv_cache_block_scales_buffer.data != nullptr, "Cannot append to FP4 KV cache without block scales pool"); - TLLM_CHECK_WITH_INFO( - params.kv_cache_scale_factors != nullptr, "Cannot append to FP4 KV cache without KV cache scale factors"); if constexpr (std::is_same_v) { // TODO: needs special quantization logic. The existing quantization functions diff --git a/cpp/tensorrt_llm/kernels/unfusedAttentionKernels/unfusedAttentionKernels_2_template.h b/cpp/tensorrt_llm/kernels/unfusedAttentionKernels/unfusedAttentionKernels_2_template.h index 4442e2c2369..fb86622031a 100644 --- a/cpp/tensorrt_llm/kernels/unfusedAttentionKernels/unfusedAttentionKernels_2_template.h +++ b/cpp/tensorrt_llm/kernels/unfusedAttentionKernels/unfusedAttentionKernels_2_template.h @@ -559,7 +559,7 @@ __global__ void applyBiasRopeUpdateKVCache(QKVPreprocessingParams( params.kv_cache_block_scales_buffer.getVBlockPtr(batch_idx, token_idx_in_kv_cache)); - float kSecondLevelSF = params.kv_cache_scale_factors[0]; - float vSecondLevelSF = params.kv_cache_scale_factors[1]; - + float kSecondLevelSF = params.qkv_scale_orig_quant[1]; + float vSecondLevelSF = params.qkv_scale_orig_quant[2]; auto& kPacked = reinterpret_cast&>(k_to_cache); auto& vPacked = reinterpret_cast&>(v); quantizeAndWriteFP4KVCache(kBlockScales, vBlockScales, reinterpret_cast(kDst), @@ -640,13 +639,24 @@ __global__ void applyBiasRopeUpdateKVCache(QKVPreprocessingParams::Type; TScale scaleOrigQuant; - mmha::convert_from_float(&scaleOrigQuant, params.kvScaleOrigQuant[0]); + mmha::convert_from_float(&scaleOrigQuant, params.qkv_scale_orig_quant[0]); // Store 8bits kv cache. 
mmha::store_8bits_vec(kDst, k, inBlockIdx, scaleOrigQuant); mmha::store_8bits_vec(vDst, v, inBlockIdx, scaleOrigQuant); @@ -1007,9 +1018,8 @@ __global__ void applyBiasRopeUpdateKVCacheV2(QKVPreprocessingParams( params.kv_cache_block_scales_buffer.getVBlockPtr(batch_idx, token_idx_in_kv_cache)); - float kSecondLevelSF = params.kv_cache_scale_factors[0]; - float vSecondLevelSF = params.kv_cache_scale_factors[1]; - + float kSecondLevelSF = params.qkv_scale_orig_quant[1]; + float vSecondLevelSF = params.qkv_scale_orig_quant[2]; auto& kPacked = reinterpret_cast&>(k); auto& vPacked = reinterpret_cast&>(v); quantizeAndWriteFP4KVCache(kBlockScales, vBlockScales, reinterpret_cast(kDst), @@ -1035,13 +1045,24 @@ __global__ void applyBiasRopeUpdateKVCacheV2(QKVPreprocessingParams= 900)) @@ -1378,7 +1399,8 @@ __global__ void updateKVCacheForCrossAttention(QKVPreprocessingParams QKVPreprocessingParams makeQKVPreprocessingParams(XQAParams const& params, XQALaunchParam const& launchParams, void* xqa_q_input_ptr, Data_type QDataType, KvCacheDataType cache_type, int32_t batch_beam_size, KVCacheBuffer const& kv_cache_buffer, - int32_t const* cu_seqlens, int32_t const* cu_kv_seqlens, float const* rotary_inv_freq_buf, int multiProcessorCount) + KVCacheBuffer const& kv_cache_block_scales_buffer, int32_t const* cu_seqlens, int32_t const* cu_kv_seqlens, + float const* rotary_inv_freq_buf, int multiProcessorCount) { QKVPreprocessingParams preprocessingParms; memset(&preprocessingParms, 0, sizeof(preprocessingParms)); @@ -55,16 +56,15 @@ QKVPreprocessingParams makeQKVPreprocessingParams(XQAParams co preprocessingParms.qkv_input = static_cast(const_cast(params.qkv)); preprocessingParms.q_output = static_cast(xqa_q_input_ptr); preprocessingParms.kv_cache_buffer = kv_cache_buffer; - preprocessingParms.kv_cache_block_scales_buffer = {}; + preprocessingParms.kv_cache_block_scales_buffer = kv_cache_block_scales_buffer; preprocessingParms.qkv_bias = static_cast(params.qkv_bias); // Prepare values for fmha. preprocessingParms.fmha_bmm1_scale = launchParams.bmm1_scale_ptr; preprocessingParms.fmha_bmm2_scale = launchParams.bmm2_scale_ptr; bool const is_fp8_q_input = (QDataType == DATA_TYPE_E4M3); - if (params.kv_cache_quant_mode.hasFp8KvCache()) + if (params.kv_cache_quant_mode.hasFp8KvCache() || params.kv_cache_quant_mode.hasFp4KvCache()) { - preprocessingParms.q_scale_quant_orig = params.kv_scale_quant_orig; - preprocessingParms.kv_scale_quant_orig = params.kv_scale_quant_orig; + preprocessingParms.qkv_scale_quant_orig = params.kv_scale_quant_orig; } if (params.is_fp8_output) { @@ -77,8 +77,7 @@ QKVPreprocessingParams makeQKVPreprocessingParams(XQAParams co preprocessingParms.cu_seq_lens = cu_seqlens; preprocessingParms.rotary_embedding_inv_freq = rotary_inv_freq_buf; preprocessingParms.rotary_coef_cache_buffer = params.rotary_cos_sin; - preprocessingParms.kvScaleOrigQuant = params.kv_scale_orig_quant; - preprocessingParms.kv_cache_scale_factors = nullptr; + preprocessingParms.qkv_scale_orig_quant = params.kv_scale_orig_quant; preprocessingParms.spec_decoding_position_offsets = params.cross_attention ? nullptr : params.spec_decoding_position_offsets; preprocessingParms.mrope_position_deltas = params.mrope_position_deltas; @@ -131,8 +130,11 @@ XqaDispatcher::XqaDispatcher(XqaFixedParams fixedParams) { if (mUseTllmGen) { - // The preprocessing kernel will convert Q from inputDataType to fp8 if the kv cache dtype is also e4m3. - mQDataType = (mFixedParams.kvDataType == DATA_TYPE_E4M3) ? 
DATA_TYPE_E4M3 : mFixedParams.inputDataType; + // The preprocessing kernel will convert Q from inputDataType to fp8 if the kv cache dtype is e4m3 or e2m1, + // as both the NVFP4 KV kernels and FP8 KV kernels use FP8 input for Q. + mQDataType = (mFixedParams.kvDataType == DATA_TYPE_E4M3 || mFixedParams.kvDataType == DATA_TYPE_E2M1) + ? DATA_TYPE_E4M3 + : mFixedParams.inputDataType; mTllmGenFMHARunner.reset( new TllmGenFmhaRunner(mQDataType, mFixedParams.kvDataType, mFixedParams.outputDataType)); } @@ -251,7 +253,7 @@ bool XqaDispatcher::isSupported() if (mUseTllmGen) { // TODO (perkzz): add the support of fp8-kv fp16/bf16-mma fmha. - if ((mFixedParams.kvDataType != mFixedParams.mathDataType) || (mQDataType != mFixedParams.mathDataType)) + if (mQDataType != mFixedParams.mathDataType) { TLLM_LOG_WARNING("Unsupported data type combination."); return false; @@ -309,7 +311,8 @@ bool XqaDispatcher::isSupported() //////////////////////////////////////////////////////////////////////////////////////////////////// template -void XqaDispatcher::runImpl(XQAParams params, KVCacheBuffer const& kv_cache_buffer) +void XqaDispatcher::runImpl( + XQAParams params, KVCacheBuffer const& kv_cache_buffer, KVCacheBuffer const& kv_cache_block_scales_buffer) { if (mUseTllmGen) { @@ -323,9 +326,19 @@ void XqaDispatcher::runImpl(XQAParams params, KVCacheBuffer const& kv_cache_buff unsigned int beam_width = params.beam_width; unsigned int batch_beam_size = params.batch_size * beam_width; - const KvCacheDataType cache_type = params.kv_cache_quant_mode.hasInt8KvCache() - ? KvCacheDataType::INT8 - : (params.kv_cache_quant_mode.hasFp8KvCache() ? KvCacheDataType::FP8 : KvCacheDataType::BASE); + KvCacheDataType cache_type{KvCacheDataType::BASE}; + if (params.kv_cache_quant_mode.hasInt8KvCache()) + { + cache_type = KvCacheDataType::INT8; + } + else if (params.kv_cache_quant_mode.hasFp8KvCache()) + { + cache_type = KvCacheDataType::FP8; + } + else if (params.kv_cache_quant_mode.hasFp4KvCache()) + { + cache_type = KvCacheDataType::NVFP4; + } XQALaunchParam launchParams; void* inputScratch = nullptr; @@ -373,8 +386,8 @@ void XqaDispatcher::runImpl(XQAParams params, KVCacheBuffer const& kv_cache_buff // The preprocessing kernel that applies RoPE and updates kv cache. 
auto preprocessingParms = makeQKVPreprocessingParams(params, launchParams, xqa_q_input_ptr, - mQDataType, cache_type, batch_beam_size, kv_cache_buffer, cu_seqlens, cu_kv_seqlens, rotary_inv_freq_buf, - mMultiProcessorCount); + mQDataType, cache_type, batch_beam_size, kv_cache_buffer, kv_cache_block_scales_buffer, cu_seqlens, + cu_kv_seqlens, rotary_inv_freq_buf, mMultiProcessorCount); invokeQKVPreprocessing(preprocessingParms, params.stream); sync_check_cuda_error(params.stream); @@ -399,6 +412,7 @@ void XqaDispatcher::runImpl(XQAParams params, KVCacheBuffer const& kv_cache_buff // Paged KV tllmRunnerParams.mQkvLayout = QkvLayout::PagedKv; tllmRunnerParams.kvPtr = kv_cache_buffer.mPrimaryPoolPtr; + tllmRunnerParams.kvSfPtr = kv_cache_block_scales_buffer.mPrimaryPoolPtr; tllmRunnerParams.kvPageIdxPtr = reinterpret_cast(kv_cache_buffer.data); tllmRunnerParams.mMaxNumPagesPerSeqKv = kv_cache_buffer.mMaxBlocksPerSeq; tllmRunnerParams.mNumTokensPerPage = kv_cache_buffer.mTokensPerBlock; @@ -461,31 +475,33 @@ void XqaDispatcher::runImpl(XQAParams params, KVCacheBuffer const& kv_cache_buff } } -void XqaDispatcher::run(XQAParams const& params, KVLinearBuffer const& kv_cache_buffer) +void XqaDispatcher::run( + XQAParams const& params, KVLinearBuffer const& kv_cache_buffer, KVLinearBuffer const& kv_cache_block_scales_buffer) { TLLM_CHECK_WITH_INFO((mFixedParams.inputDataType == DATA_TYPE_FP16 || mFixedParams.inputDataType == DATA_TYPE_BF16), "The input Qkv tensor must be fp16/bf16."); if (mFixedParams.inputDataType == DATA_TYPE_FP16) { - this->runImpl<__half, KVLinearBuffer>(params, kv_cache_buffer); + this->runImpl<__half, KVLinearBuffer>(params, kv_cache_buffer, kv_cache_block_scales_buffer); } else { - this->runImpl<__nv_bfloat16, KVLinearBuffer>(params, kv_cache_buffer); + this->runImpl<__nv_bfloat16, KVLinearBuffer>(params, kv_cache_buffer, kv_cache_block_scales_buffer); } } -void XqaDispatcher::run(XQAParams const& params, KVBlockArray const& kv_cache_buffer) +void XqaDispatcher::run( + XQAParams const& params, KVBlockArray const& kv_cache_buffer, KVBlockArray const& kv_cache_block_scales_buffer) { TLLM_CHECK_WITH_INFO((mFixedParams.inputDataType == DATA_TYPE_FP16 || mFixedParams.inputDataType == DATA_TYPE_BF16), "The input Qkv tensor must be fp16/bf16."); if (mFixedParams.inputDataType == DATA_TYPE_FP16) { - this->runImpl<__half, KVBlockArray>(params, kv_cache_buffer); + this->runImpl<__half, KVBlockArray>(params, kv_cache_buffer, kv_cache_block_scales_buffer); } else { - this->runImpl<__nv_bfloat16, KVBlockArray>(params, kv_cache_buffer); + this->runImpl<__nv_bfloat16, KVBlockArray>(params, kv_cache_buffer, kv_cache_block_scales_buffer); } } diff --git a/cpp/tensorrt_llm/kernels/xqaDispatcher.h b/cpp/tensorrt_llm/kernels/xqaDispatcher.h index 1fc65150a5e..784b30eda86 100644 --- a/cpp/tensorrt_llm/kernels/xqaDispatcher.h +++ b/cpp/tensorrt_llm/kernels/xqaDispatcher.h @@ -78,9 +78,11 @@ class XqaDispatcher bool isSupported(); // Run the XQA kernel. 
- void run(XQAParams const& params, KVLinearBuffer const& kv_cache_buffer); + void run(XQAParams const& params, KVLinearBuffer const& kv_cache_buffer, + KVLinearBuffer const& kv_cache_block_scales_buffer); - void run(XQAParams const& params, KVBlockArray const& kv_cache_buffer); + void run( + XQAParams const& params, KVBlockArray const& kv_cache_buffer, KVBlockArray const& kv_cache_block_scales_buffer); int getWorkspaceAlignment(); @@ -104,7 +106,8 @@ class XqaDispatcher protected: template - void runImpl(XQAParams params, KVCacheBuffer const& kv_cache_buffer); + void runImpl( + XQAParams params, KVCacheBuffer const& kv_cache_buffer, KVCacheBuffer const& kv_cache_block_scales_buffer); }; constexpr uint32_t xqaMlaCgaXBufSize = 8704 * 2; diff --git a/cpp/tensorrt_llm/nanobind/bindings.cpp b/cpp/tensorrt_llm/nanobind/bindings.cpp index c951f967c20..357b77da819 100644 --- a/cpp/tensorrt_llm/nanobind/bindings.cpp +++ b/cpp/tensorrt_llm/nanobind/bindings.cpp @@ -166,6 +166,7 @@ NB_MODULE(TRTLLM_NB_MODULE, m) .value("FP8", nvinfer1::DataType::kFP8) .value("BF16", nvinfer1::DataType::kBF16) .value("INT64", nvinfer1::DataType::kINT64) + .value("NVFP4", nvinfer1::DataType::kFP4) .export_values(); nb::enum_(m, "GptModelVariant") diff --git a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp index 54835e81d7f..41cd3850a1a 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp @@ -357,6 +357,18 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(py::module_& m) } return block_pool_pointers; }) + .def("get_block_scale_pool_pointers", + [](tbk::BaseKVCacheManager& self) + { + std::optional block_scale_pool_pointers{std::nullopt}; + auto tensor = self.getBlockScalePoolPointers(); + if (tensor) + { + std::shared_ptr _tensor = std::move(tensor); + block_scale_pool_pointers = tr::Torch::tensor(_tensor); + } + return block_scale_pool_pointers; + }) .def("get_layer_to_pool_mapping", [](tbk::BaseKVCacheManager& self) { diff --git a/cpp/tensorrt_llm/pybind/bindings.cpp b/cpp/tensorrt_llm/pybind/bindings.cpp index cdc9736db09..d233ca59ed9 100644 --- a/cpp/tensorrt_llm/pybind/bindings.cpp +++ b/cpp/tensorrt_llm/pybind/bindings.cpp @@ -158,6 +158,7 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m) .value("FP8", nvinfer1::DataType::kFP8) .value("BF16", nvinfer1::DataType::kBF16) .value("INT64", nvinfer1::DataType::kINT64) + .value("NVFP4", nvinfer1::DataType::kFP4) .export_values(); py::enum_(m, "GptModelVariant") @@ -236,6 +237,7 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m) .def_property_readonly("has_per_group_scaling", &tc::QuantMode::hasPerGroupScaling) .def_property_readonly("has_static_activation_scaling", &tc::QuantMode::hasStaticActivationScaling) .def_property_readonly("has_int8_kv_cache", &tc::QuantMode::hasInt8KvCache) + .def_property_readonly("has_fp4_kv_cache", &tc::QuantMode::hasFp4KvCache) .def_property_readonly("has_fp8_kv_cache", &tc::QuantMode::hasFp8KvCache) .def_property_readonly("has_fp8_qdq", &tc::QuantMode::hasFp8Qdq) .def_property_readonly("has_nvfp4", &tc::QuantMode::hasNvfp4) diff --git a/cpp/tensorrt_llm/thop/attentionOp.cpp b/cpp/tensorrt_llm/thop/attentionOp.cpp index 8dcc9c1443a..8ef64464b63 100644 --- a/cpp/tensorrt_llm/thop/attentionOp.cpp +++ b/cpp/tensorrt_llm/thop/attentionOp.cpp @@ -248,22 +248,55 @@ class Runner : public RunnerBase ? 
host_kv_cache_block_offsets.value().index({pool_index, seq_offset}).data_ptr() : nullptr); - auto const cache_elem_size = (op.mKVCacheQuantMode.hasKvCacheQuant() ? 1 : sizeof(T)); + // The cache element size in bits. + int cache_elem_bits = op.getKvCacheElemSizeInBits(); auto const block_size = op.mTokensPerBlock * op.mNumKVHeads * op.mHeadSize; - auto const bytes_per_block = block_size * cache_elem_size; + auto const bytes_per_block = block_size * cache_elem_bits / 8 /*bits*/; int32_t const kv_factor = op.isMLAEnabled() ? 1 : 2; auto const intra_pool_offset = layer_idx_in_cache_pool * kv_factor * bytes_per_block; - void* host_primary_pool_pointer = op.useKVCache() && host_kv_cache_pool_pointers.has_value() - ? reinterpret_cast( + // Prepare block pool pointers for NVFP4 KV cache. + void* host_primary_pool_pointer{nullptr}; + void* host_secondary_pool_pointer{nullptr}; + void* host_primary_block_scale_pool_pointer{nullptr}; + void* host_secondary_block_scale_pool_pointer{nullptr}; + + // Whether NVFP4 KV cache is used. + bool const use_kv_cache = op.useKVCache() && host_kv_cache_pool_pointers.has_value(); + bool const use_nvfp4_kv_cache = use_kv_cache && op.mKVCacheQuantMode.hasFp4KvCache(); + if (use_nvfp4_kv_cache) + { + // For NVFP4 KV cache, extra block scales are stored in separate pools. + // The layout of host_kv_cache_pool_pointers is [num_pools, 2 (primary and secondary), 2 (data and scale)]. + TORCH_CHECK(host_kv_cache_pool_pointers.value().dim() == 3); + host_primary_pool_pointer = reinterpret_cast( + reinterpret_cast(host_kv_cache_pool_pointers.value().index({pool_index, 0, 0}).item()) + + intra_pool_offset); + host_secondary_pool_pointer = reinterpret_cast( + reinterpret_cast(host_kv_cache_pool_pointers.value().index({pool_index, 1, 0}).item()) + + intra_pool_offset); + // Calculate the intra-pool offset for scaling factors. + // Note that NVFP4 block scaling use a fixed vector size of 16. + auto constexpr vector_size = 16; + auto const bytes_per_block_sf = block_size / vector_size * 1 /*bytes per E4M3 sf*/; + auto const intra_pool_offset_sf = layer_idx_in_cache_pool * kv_factor * bytes_per_block_sf; + host_primary_block_scale_pool_pointer = reinterpret_cast( + reinterpret_cast(host_kv_cache_pool_pointers.value().index({pool_index, 0, 1}).item()) + + intra_pool_offset_sf); + host_secondary_block_scale_pool_pointer = reinterpret_cast( + reinterpret_cast(host_kv_cache_pool_pointers.value().index({pool_index, 1, 1}).item()) + + intra_pool_offset_sf); + } + else if (use_kv_cache) + { + TORCH_CHECK(host_kv_cache_pool_pointers.value().dim() == 2); + host_primary_pool_pointer = reinterpret_cast( reinterpret_cast(host_kv_cache_pool_pointers.value().index({pool_index, 0}).item()) - + intra_pool_offset) - : nullptr; - void* host_secondary_pool_pointer = op.useKVCache() && host_kv_cache_pool_pointers.has_value() - ? 
reinterpret_cast( + + intra_pool_offset); + host_secondary_pool_pointer = reinterpret_cast( reinterpret_cast(host_kv_cache_pool_pointers.value().index({pool_index, 1}).item()) - + intra_pool_offset) - : nullptr; + + intra_pool_offset); + } float const* kv_scale_orig_quant_ptr = nullptr; float const* kv_scale_quant_orig_ptr = nullptr; @@ -272,6 +305,11 @@ class Runner : public RunnerBase { kv_scale_orig_quant_ptr = kv_scale_orig_quant.value().data_ptr(); kv_scale_quant_orig_ptr = kv_scale_quant_orig.value().data_ptr(); + if (op.mKVCacheQuantMode.hasFp4KvCache()) + { + TORCH_CHECK(kv_scale_orig_quant.value().size(0) == 3); + TORCH_CHECK(kv_scale_quant_orig.value().size(0) == 3); + } } // For FP8 output, out_scale represents the output scale. float const* out_scale_ptr = (op.mFP8ContextFMHA && !op.mFuseFp4Quant && out_scale.has_value()) @@ -310,6 +348,8 @@ class Runner : public RunnerBase common_enqueue_params.block_offsets = block_offsets; common_enqueue_params.host_primary_pool_pointer = host_primary_pool_pointer; common_enqueue_params.host_secondary_pool_pointer = host_secondary_pool_pointer; + common_enqueue_params.host_primary_block_scale_pool_pointer = host_primary_block_scale_pool_pointer; + common_enqueue_params.host_secondary_block_scale_pool_pointer = host_secondary_block_scale_pool_pointer; common_enqueue_params.num_tokens = num_tokens; common_enqueue_params.total_kv_len = total_kv_len; common_enqueue_params.max_blocks_per_sequence = max_blocks_per_sequence; @@ -738,7 +778,7 @@ bool attention_supports_nvfp4_output(int64_t const num_heads, int64_t const num_ op->mHeadSize = head_size; op->mMaskType = static_cast(int32_t(mask_type)); op->mKVCacheQuantMode = tensorrt_llm::common::QuantMode(uint32_t(quant_mode)); - op->mFP8ContextFMHA = op->mKVCacheQuantMode.hasFp8KvCache(); + op->mFP8ContextFMHA = op->mKVCacheQuantMode.hasFp8KvCache() || op->mKVCacheQuantMode.hasFp4KvCache(); op->mUseKVCache = true; op->mPagedKVCache = true; op->mTokensPerBlock = tokens_per_block.value_or(0); diff --git a/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp b/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp index a52cca097a3..0a52ae84852 100644 --- a/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp +++ b/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp @@ -539,6 +539,8 @@ TEST_F(KVCacheManagerTest, FP4BlockScaleManagementTest) auto constexpr maxNumSequences = 8; auto constexpr blocksInPrimaryPool = 16; auto constexpr blocksInSecondaryPool = 16; + auto constexpr numFp4EltsPerContainer = 2; + auto constexpr vectorSize = 16; auto constexpr onboardBlocks = true; auto const stream = std::make_shared(); auto constexpr beamWidth = 1; @@ -560,7 +562,9 @@ TEST_F(KVCacheManagerTest, FP4BlockScaleManagementTest) auto const& blockManager = kvCacheManager.getBlockManager(); EXPECT_TRUE(blockManager.containsBlockScales(1)); - EXPECT_EQ(blockManager.getBlockSize(0) / 16, blockManager.getBlockSize(1)); + // Block size of pool 0 reflects the number of container elements. It is number of FP4 elements / 2. + // The expected block size of pool 1 should be the number of FP4 elements / vectorSize. 
+ EXPECT_EQ(blockManager.getBlockSize(0) * numFp4EltsPerContainer / vectorSize, blockManager.getBlockSize(1)); } #endif @@ -3464,7 +3468,7 @@ TEST_F(KVCacheManagerTest, KVCacheTransferManagerConcurrencyTest) auto bufferManager = tensorrt_llm::runtime::BufferManager(std::make_shared()); auto transferManager = KVCacheTransferManager(bufferManager); - auto pool = KVCacheBlockPool(0, 2, 0, 0, 0, 1); + auto pool = KVCacheBlockPool(0, 2, 0, 0, 0); pool.primaryPtr = bufferManager.gpu(tr::ITensor::makeShape({1, blockSize}), nvinfer1::DataType::kFLOAT); bufferManager.setZero(*pool.primaryPtr); diff --git a/cpp/tests/unit_tests/kernels/ropeTest.cu b/cpp/tests/unit_tests/kernels/ropeTest.cu index 3cd52f7c1cb..517b006e4fd 100644 --- a/cpp/tests/unit_tests/kernels/ropeTest.cu +++ b/cpp/tests/unit_tests/kernels/ropeTest.cu @@ -274,11 +274,6 @@ void computeReferenceBiasRope(QKVPreprocessingParams para float kGlobalScale = 0.f; float vGlobalScale = 0.f; - if (params.kv_cache_scale_factors) - { - kGlobalScale = params.kv_cache_scale_factors[0]; - vGlobalScale = params.kv_cache_scale_factors[1]; - } // the size of a (Q)/K/V matrix TODO(dblanaru) separate this into q and kv sizes @@ -448,7 +443,6 @@ protected: int32_t* q_seq_lengths{nullptr}; int32_t* kv_seq_lengths{nullptr}; float* kv_scale_orig_quant{nullptr}; - float* kv_cache_scale_factors{nullptr}; KVBlockArray::DataType* block_offsets{nullptr}; void* host_primary_pool_pointer{nullptr}; @@ -687,9 +681,9 @@ protected: if (global_scale_tensor) { - kv_cache_scale_factors = bufferCast(*(global_scale_tensor)); - kv_cache_scale_factors[0] = 5.1f; - kv_cache_scale_factors[1] = 0.25f; + kv_scale_orig_quant = bufferCast(*(global_scale_tensor)); + kv_scale_orig_quant[0] = 5.1f; + kv_scale_orig_quant[1] = 0.25f; } qkv_size = num_tokens * 3 * mNumHeads * mHeadSize; @@ -715,8 +709,7 @@ protected: preprocessingParams.cu_kv_seq_lens = nullptr; // Only used by cross attention. preprocessingParams.rotary_embedding_inv_freq = rotary_inv_freq_buf; preprocessingParams.rotary_coef_cache_buffer = rotary_cos_sin; - preprocessingParams.kvScaleOrigQuant = kv_scale_orig_quant; - preprocessingParams.kv_cache_scale_factors = kv_cache_scale_factors; + preprocessingParams.qkv_scale_orig_quant = kv_scale_orig_quant; preprocessingParams.spec_decoding_position_offsets = nullptr; // Cast to int* if necessary preprocessingParams.batch_size = batch_size; preprocessingParams.max_input_seq_len = input_seq_length; diff --git a/tensorrt_llm/_torch/attention_backend/trtllm.py b/tensorrt_llm/_torch/attention_backend/trtllm.py index 5a4d157a70e..42cb70c86c6 100644 --- a/tensorrt_llm/_torch/attention_backend/trtllm.py +++ b/tensorrt_llm/_torch/attention_backend/trtllm.py @@ -171,6 +171,8 @@ def plan( kv_scale_quant_orig: Optional[torch.Tensor] = None, out_scale: Optional[torch.Tensor] = None, out_scale_sf: Optional[torch.Tensor] = None, + kv_scales_sf: Optional[torch.Tensor] = None, + kv_scales_sf_inv: Optional[torch.Tensor] = None, use_nvfp4_output: bool = False, use_paged_context_fmha: bool = False, attention_input_type: AttentionInputType = AttentionInputType.mixed, @@ -216,6 +218,8 @@ def plan( kv_scale_quant_orig (torch.Tensor): The tensor to store the scaling factor for dequantization from INT8/FP8 in the KV cache, with shape (1) on GPU. out_scale (torch.Tensor): The tensor to store the scaling factor to quantize output, with shape (1) on GPU. out_scale_sf (torch.Tensor): The tensor to store the global scale for NVFP4 scaling factors, with shape (1) on GPU. 
+ kv_scales_sf (torch.Tensor): The tensor to store the global scale for KV NVFP4 scaling factors, with shape (2) on GPU. + kv_scales_sf_inv (torch.Tensor): The tensor to store the inverse of the global scale for KV NVFP4 scaling factors, with shape (2) on GPU. use_paged_context_fmha (bool): Sets the mPagedContextFMHA attribute in the op runner. mrope_config (dict): The dictionary containing the mRope configuration. softmax_stats_tensor (torch.Tensor): The tensor to store the softmax statistics (max/sum) @@ -240,8 +244,8 @@ def plan( self.host_kv_cache_pool_mapping = host_kv_cache_pool_mapping self.workspace = workspace self.cache_indirection = cache_indirection - self.kv_scale_orig_quant = kv_scale_orig_quant - self.kv_scale_quant_orig = kv_scale_quant_orig + self.kv_scale_orig_quant = kv_scale_orig_quant if kv_scales_sf_inv is None else kv_scales_sf_inv + self.kv_scale_quant_orig = kv_scale_quant_orig if kv_scales_sf is None else kv_scales_sf self.out_scale = out_scale self.out_scale_sf = out_scale_sf self.use_paged_context_fmha = use_paged_context_fmha @@ -478,7 +482,6 @@ def run( spec_decoding_bool_params, spec_decoding_tensor_params, ) - # reset the planned states (especially tensors) to avoid memory leak self.plan() return output, output_sf @@ -1051,14 +1054,11 @@ def __init__( self.is_mla_enable = mla_params is not None self.mla_params = mla_params or MLAParams() self.v_head_dim = self.mla_params.v_head_dim if self.is_mla_enable else head_dim - - self.kv_cache_scaling_factor = torch.tensor( - [1.0], - dtype=torch.float32, - device='cuda', - ) + self.kv_cache_scaling_factor = torch.ones(1, + dtype=torch.float32, + device='cuda') self.kv_scale_quant_orig = self.kv_cache_scaling_factor - self.kv_scale_orig_quant = 1.0 / self.kv_scale_quant_orig + self.kv_scale_orig_quant = 1.0 / self.kv_cache_scaling_factor if not skip_create_weights_in_init: self.update_quant_config(self.quant_config) @@ -1070,6 +1070,8 @@ def update_quant_config(self, new_quant_config: Optional[QuantConfig]): if self.quant_config is not None: self.has_fp8_kv_cache = self.quant_config.layer_quant_mode.has_fp8_kv_cache( ) + self.has_fp4_kv_cache = self.quant_config.layer_quant_mode.has_fp4_kv_cache( + ) self.has_fp8_qdq = self.quant_config.layer_quant_mode.has_fp8_qdq() self.has_fp8_block_wise = self.quant_config.layer_quant_mode.has_fp8_block_scales( @@ -1092,6 +1094,8 @@ def forward( metadata: TrtllmAttentionMetadata, out_scale: Optional[torch.Tensor] = None, out_scale_sf: Optional[torch.Tensor] = None, + kv_scales_sf: Optional[torch.Tensor] = None, + kv_scales_sf_inv: Optional[torch.Tensor] = None, *, attention_mask: AttentionMask = PredefinedAttentionMask.CAUSAL, attention_input_type: AttentionInputType = AttentionInputType.mixed, @@ -1159,6 +1163,8 @@ def forward( kv_scale_quant_orig=self.kv_scale_quant_orig, out_scale=out_scale, out_scale_sf=out_scale_sf, + kv_scales_sf=kv_scales_sf, + kv_scales_sf_inv=kv_scales_sf_inv, use_nvfp4_output=use_nvfp4_output, use_paged_context_fmha=use_paged_context_fmha, attention_input_type=attention_input_type, @@ -1182,7 +1188,8 @@ def forward( # Use UINT8 as the container dtype for NVFP4. 
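             # Two FP4 (E2M1) values are packed into each uint8 container, so the packed
             # dimension holds half as many containers as FP4 elements.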
out_dtype = torch.uint8 elif (self.has_fp8_qdq or self.has_nvfp4 or self.has_fp8_block_wise - or self.has_fp8_rowwise) and self.has_fp8_kv_cache: + or self.has_fp8_rowwise) and (self.has_fp8_kv_cache + or self.has_fp4_kv_cache): # TODO(qijun): revisit fp8_context_fmha logic out_dtype = torch.float8_e4m3fn diff --git a/tensorrt_llm/_torch/modules/attention.py b/tensorrt_llm/_torch/modules/attention.py index aa5183c1c14..f67f047c7a4 100644 --- a/tensorrt_llm/_torch/modules/attention.py +++ b/tensorrt_llm/_torch/modules/attention.py @@ -348,6 +348,13 @@ def _attn_impl( if self.o_proj.has_nvfp4 and self.support_nvfp4_output and enable_attn_nvfp4_output: out_scale_sf = self.o_proj.input_scale + kv_scales_sf = None + kv_scales_sf_inv = None + if self.quant_config is not None and self.quant_config.layer_quant_mode.has_fp4_kv_cache( + ): + kv_scales_sf = self.qkv_proj.kv_scales + kv_scales_sf_inv = self.qkv_proj.inv_kv_scales + mrope_config = None if mrope_rotary_cos_sin is not None or mrope_position_deltas is not None: mrope_config = dict() @@ -363,6 +370,8 @@ def _attn_impl( attn_metadata, out_scale=out_scale, out_scale_sf=out_scale_sf, + kv_scales_sf=kv_scales_sf, + kv_scales_sf_inv=kv_scales_sf_inv, attention_mask=attention_mask, mrope_config=mrope_config, attention_window_size=attention_window_size, diff --git a/tensorrt_llm/_torch/modules/linear.py b/tensorrt_llm/_torch/modules/linear.py index 67d49b3d945..def85dddcad 100644 --- a/tensorrt_llm/_torch/modules/linear.py +++ b/tensorrt_llm/_torch/modules/linear.py @@ -2,6 +2,7 @@ import enum import math +import os from abc import ABC, abstractmethod from dataclasses import dataclass from typing import Dict, List, Optional, Union @@ -327,7 +328,12 @@ def create_weights(self, module: Linear, in_features: int, module.inv_input_scale = Parameter(torch.tensor(1., dtype=torch.float32), requires_grad=False) - + # K, V scales for NVFP4 KV cache + module.kv_scales = Parameter(torch.ones(3, dtype=torch.float32), + requires_grad=False) + # K, V scales for NVFP4 KV cache + module.inv_kv_scales = Parameter(torch.ones(3, dtype=torch.float32), + requires_grad=False) if bias: module.bias = Parameter(torch.empty((out_features), dtype=dtype), requires_grad=False) @@ -375,6 +381,15 @@ def apply(self, module: Linear, input: torch.Tensor, output = output + bias return output + def load_kv_scales(self, weights: List[Dict]): + k_scale, v_scale = [], [] + for w in weights: + if "k_scale" in w: + k_scale.append(w["k_scale"][...].reshape([])) + if "v_scale" in w: + v_scale.append(w["v_scale"][...].reshape([])) + return k_scale, v_scale + def load_weight_scales(self, weights: List[Dict]): input_scale, weight_scale = [], [] for w in weights: @@ -409,6 +424,7 @@ def load_weights_fused_qkv_linear(self, module: Linear, else: # Dynamic quantization module.input_scale = None + copy_weight(module.weight_scale, max(weight_scale)) q_weight = q_weight.to(module.dtype) * weight_scale[0] @@ -423,6 +439,22 @@ def load_weights_fused_qkv_linear(self, module: Linear, torch.float8_e4m3fn) copy_weight(module.weight, fused_weight) + # Load k and v scales, used for NVFP4 KV cache + k_scale, v_scale = self.load_kv_scales(weights) + # NOTE: Currently the calibrated kv scales may cause overflow for certain input, disabling by default. 
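+        # Layout of module.kv_scales once populated below: [1.0, K scale, V scale]; the first entry is
+        # left at 1.0, and the attention op checks that this tensor has exactly 3 elements for NVFP4 KV cache.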
+ if os.environ.get("TRTLLM_LOAD_KV_SCALES", "0") == "1": + if len(k_scale) != 0: + assert len(v_scale) != 0 + # The calibrated KV scales are amax / (6 * 448), but the requested KV scales are amax / 448, + # to avoid overflow when dequantizing NVFP4 in attention kernels. + copy_weight( + module.kv_scales, + torch.tensor( + [1.0, max(k_scale) * 6.0, + max(v_scale) * 6.0], + dtype=torch.float32)) + module.inv_kv_scales.data = 1.0 / module.kv_scales + def load_weights_fused_gate_up_linear(self, module: Linear, weights: List[Dict]) -> None: input_scale, weight_scale = self.load_weight_scales(weights) @@ -695,6 +727,13 @@ def create_weights(self, module: Linear, in_features: int, module.alpha = Parameter(torch.empty([1], dtype=torch.float32), requires_grad=False) + # K, V scales for NVFP4 KV cache + module.kv_scales = Parameter(torch.ones(3, dtype=torch.float32), + requires_grad=False) + # K, V scales for NVFP4 KV cache + module.inv_kv_scales = Parameter(torch.ones(3, dtype=torch.float32), + requires_grad=False) + if bias: module.bias = Parameter(torch.empty((out_features), dtype=dtype), requires_grad=False) @@ -718,6 +757,15 @@ def apply(self, module: Linear, input: torch.Tensor, output = output + bias return output + def load_kv_scales(self, weights: List[Dict]): + k_scale, v_scale = [], [] + for w in weights: + if "k_scale" in w: + k_scale.append(w["k_scale"][...].reshape([])) + if "v_scale" in w: + v_scale.append(w["v_scale"][...].reshape([])) + return k_scale, v_scale + def load_weight_scales(self, weights: List[Dict], tp_size: int = 1, @@ -796,10 +844,25 @@ def load_weights_fused_qkv_linear(self, module: Linear, copy_weight(module.input_scale, input_scale) copy_weight(module.weight_scale, weight_scale) copy_weight(module.alpha, alpha) - fused_weight = torch.cat((q_weight, k_weight, v_weight)) copy_weight(module.weight, fused_weight) + # Load k and v scales, used for NVFP4 KV cache + k_scale, v_scale = self.load_kv_scales(weights) + # NOTE: Currently the calibrated kv scales may cause overflow for certain input, disabling by default. + if os.environ.get("TRTLLM_LOAD_KV_SCALES", "0") == "1": + if len(k_scale) != 0: + assert len(v_scale) != 0 + # The calibrated KV scales are amax / (6 * 448), but the requested KV scales are amax / 448, + # to avoid overflow when dequantizing NVFP4 in attention kernels using FP8 math. 
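+                # Hypothetical example: with amax = 268.8, the checkpoint stores k_scale = 268.8 / (6 * 448) = 0.1,
+                # and the value loaded here becomes 0.1 * 6.0 = 0.6 = 268.8 / 448.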
+ copy_weight( + module.kv_scales, + torch.tensor( + [1.0, max(k_scale) * 6.0, + max(v_scale) * 6.0], + dtype=torch.float32)) + module.inv_kv_scales.data = 1.0 / module.kv_scales + def load_weights_fused_gate_up_linear(self, module: Linear, weights: List[Dict]) -> None: gate_weight, up_weight = load_weights_fused_gate_up_helper( diff --git a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py index 2f0753ed31a..677399e90ed 100644 --- a/tensorrt_llm/_torch/pyexecutor/_util.py +++ b/tensorrt_llm/_torch/pyexecutor/_util.py @@ -293,6 +293,9 @@ def _create_kv_cache_manager( if quant_config is not None and quant_config.quant_mode.has_fp8_kv_cache( ): kv_cache_dtype = tensorrt_llm.bindings.DataType.FP8 + elif quant_config is not None and quant_config.quant_mode.has_fp4_kv_cache( + ): + kv_cache_dtype = tensorrt_llm.bindings.DataType.NVFP4 else: kv_cache_dtype = str_dtype_to_binding( torch_dtype_to_str(model_engine.dtype)) diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index d9f180c0fc3..85510f4c67e 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -98,7 +98,7 @@ def warmup(self, resource_manager: ResourceManager) -> None: "nvfp4": QuantAlgo.NVFP4.value, "auto": "auto" } -_VALID_KV_CACHE_DTYPES = ("fp8", "auto") +_VALID_KV_CACHE_DTYPES = ("fp8", "nvfp4", "auto") def validate_and_set_mamba_ssm_cache_dtype(config: ModelConfig, diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py index 9a5b42166dc..5e63ff735b8 100644 --- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py +++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py @@ -14,7 +14,7 @@ from tensorrt_llm.lora_manager import LoraManager, LoraModelConfig from tensorrt_llm.sampling_params import SamplingParams -from ..._utils import binding_dtype_size, binding_to_str_dtype, nvtx_range +from ..._utils import binding_to_str_dtype, get_size_in_bytes, nvtx_range from ...logger import logger from ...mapping import CpType, Mapping from .llm_request import (LlmRequest, LlmRequestState, SamplingConfig, @@ -347,6 +347,14 @@ def append_to_kv_heads_per_layer(num_kv_heads_per_layer: List[int], self.impl.allocate_pools(False) self.kv_cache_pool_pointers = self.impl.get_block_pool_pointers() + kv_cache_block_scale_pool_pointers = self.impl.get_block_scale_pool_pointers( + ) + if kv_cache_block_scale_pool_pointers.numel() > 0: + self.kv_cache_pool_pointers = torch.stack([ + self.kv_cache_pool_pointers, kv_cache_block_scale_pool_pointers + ], + dim=-1) + self.kv_cache_pool_mapping = self.impl.get_layer_to_pool_mapping() self.num_pools = self.impl.num_pools self.max_blocks_per_seq = self.impl.max_blocks_per_seq @@ -515,11 +523,15 @@ def calculate_max_num_blocks(self, self.num_kv_heads_per_layer) * head_dim if dtype not in (DataType.FP8, DataType.HALF, DataType.BF16, - DataType.FLOAT): + DataType.FLOAT, DataType.NVFP4): raise ValueError(f'Cannot support {dtype} KV cache.') - kv_cache_dtype_bytes = binding_dtype_size(dtype) - cache_size_bytes_per_token = cache_size_per_token * kv_cache_dtype_bytes + cache_size_bytes_per_token = get_size_in_bytes(cache_size_per_token, + dtype) + if dtype == DataType.NVFP4: + # NVFP4 needs additional block scales. Vector Size is 16. Each scaling factor is 1 byte. 
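+            # Hypothetical example: cache_size_per_token = 2048 FP4 elements per token
+            # -> 2048 * 4 bits / 8 = 1024 bytes of packed data plus 2048 / 16 = 128 bytes of scales.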
+ cache_size_bytes_per_token += cache_size_per_token / 16 + free_mem, total_mem = torch.cuda.mem_get_info() assert free_mem_fraction < 1.0, f"Invalid freeMemFraction, freeMemFraction {free_mem_fraction} must be smaller than 1.0" @@ -709,8 +721,11 @@ def calculate_cache_size_per_token(layers: Set[int]) -> int: for window_size in sorted(window_size_to_layers): layers = window_size_to_layers[window_size] cache_size_per_token = calculate_cache_size_per_token(layers) - cache_size_bytes_per_token = cache_size_per_token * binding_dtype_size( - dtype) + cache_size_bytes_per_token = get_size_in_bytes( + cache_size_per_token, dtype) + if dtype == DataType.NVFP4: + # NVFP4 needs additional block scales. Vector Size is 16. Each scaling factor is 1 byte. + cache_size_bytes_per_token += cache_size_per_token / 16 required_mem_bytes_per_seq += window_size * cache_size_bytes_per_token logger.debug( f'Required memory per sequence: {required_mem_bytes_per_seq} bytes') @@ -737,8 +752,11 @@ def calculate_cache_size_per_token(layers: Set[int]) -> int: # Calculate cache size per token for remaining layers only cache_size_per_token = calculate_cache_size_per_token( remaining_layers) - cache_size_bytes_per_token = cache_size_per_token * binding_dtype_size( - dtype) + cache_size_bytes_per_token = get_size_in_bytes( + cache_size_per_token, dtype) + if dtype == DataType.NVFP4: + # NVFP4 needs additional block scales. Vector Size is 16. Each scaling factor is 1 byte. + cache_size_bytes_per_token += cache_size_per_token / 16 logger.debug( f'Cache size per token for {len(remaining_layers)} layers: ' f'{cache_size_bytes_per_token} bytes') diff --git a/tensorrt_llm/_utils.py b/tensorrt_llm/_utils.py index c68777b96a7..c468b6487d4 100644 --- a/tensorrt_llm/_utils.py +++ b/tensorrt_llm/_utils.py @@ -183,16 +183,17 @@ def str_dtype_to_torch(dtype): ) _binding_to_str_dtype = {v: k for k, v in _str_to_binding_dtype_dict.items()} -_binding_dtype_size = { - DataType.INT64: 8, - DataType.FLOAT: 4, - DataType.INT32: 4, - DataType.BF16: 2, - DataType.HALF: 2, - DataType.BOOL: 1, - DataType.FP8: 1, - DataType.INT8: 1, - DataType.UINT8: 1, +_binding_dtype_bits = { + DataType.INT64: 64, + DataType.FLOAT: 32, + DataType.INT32: 32, + DataType.BF16: 16, + DataType.HALF: 16, + DataType.BOOL: 8, + DataType.FP8: 8, + DataType.INT8: 8, + DataType.UINT8: 8, + DataType.NVFP4: 4, } @@ -206,6 +207,12 @@ def binding_dtype_size(dtype: DataType): return _binding_dtype_size[dtype] +def get_size_in_bytes(num_elements: int, dtype: DataType): + total_num_bits = _binding_dtype_bits[dtype] * num_elements + assert total_num_bits % 8 == 0, f"Total number of bits {total_num_bits} must be divisible by 8" + return total_num_bits // 8 + + def str_dtype_to_binding(dtype): ret = _str_to_binding_dtype_dict.get(dtype) assert ret is not None, f'Unsupported dtype: {dtype}' diff --git a/tensorrt_llm/llmapi/llm_utils.py b/tensorrt_llm/llmapi/llm_utils.py index b2145ac7935..e55735043e9 100644 --- a/tensorrt_llm/llmapi/llm_utils.py +++ b/tensorrt_llm/llmapi/llm_utils.py @@ -395,10 +395,10 @@ def _update_from_hf_quant_config(self) -> bool: ) else: if quant_config.kv_cache_quant_algo not in [ - None, QuantAlgo.FP8 + None, QuantAlgo.FP8, QuantAlgo.NVFP4 ]: raise ValueError( - f"Only kv_cache_quant_algo={QuantAlgo.FP8} is allowed for pre-quantized checkpoint, got {quant_config.kv_cache_quant_algo}." + f"Only kv_cache_quant_algo={QuantAlgo.FP8} or {QuantAlgo.NVFP4} is allowed for pre-quantized checkpoint, got {quant_config.kv_cache_quant_algo}." 
) for key, value in hf_quant_config.items(): diff --git a/tests/integration/defs/accuracy/references/gsm8k.yaml b/tests/integration/defs/accuracy/references/gsm8k.yaml index 33c264b9e47..77cac5b1d54 100644 --- a/tests/integration/defs/accuracy/references/gsm8k.yaml +++ b/tests/integration/defs/accuracy/references/gsm8k.yaml @@ -9,6 +9,9 @@ meta-llama/Llama-3.1-8B-Instruct: - quant_algo: FP8 kv_cache_quant_algo: FP8 accuracy: 72.85 + - quant_algo: FP8 + kv_cache_quant_algo: NVFP4 + accuracy: 69.75 meta-llama/Llama-3.3-70B-Instruct: - accuracy: 83.78 - quant_algo: NVFP4 diff --git a/tests/integration/defs/accuracy/references/mmlu.yaml b/tests/integration/defs/accuracy/references/mmlu.yaml index 9dd1c25d3c3..62c48ddf8a1 100644 --- a/tests/integration/defs/accuracy/references/mmlu.yaml +++ b/tests/integration/defs/accuracy/references/mmlu.yaml @@ -32,6 +32,9 @@ meta-llama/Llama-3.1-8B-Instruct: - quant_algo: FP8 kv_cache_quant_algo: FP8 accuracy: 67.87 + - quant_algo: FP8 + kv_cache_quant_algo: NVFP4 + accuracy: 66.45 meta-llama/Llama-3.2-1B: - quant_algo: W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN accuracy: 32.72 diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py index f0a8e923289..443266a4c98 100644 --- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py +++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py @@ -301,6 +301,29 @@ def test_ngram(self): task = GSM8K(self.MODEL_NAME) task.evaluate(llm) + @skip_pre_blackwell + @parametrize_with_ids("torch_compile", [False, True]) + @parametrize_with_ids("attn_backend", ["TRTLLM"]) + def test_nvfp4_kv(self, attn_backend, torch_compile): + torch_compile_config = TorchCompileConfig( + enable_fullgraph=True, + enable_piecewise_cuda_graph=True, + max_num_streams=3) if torch_compile else None + pytorch_config = dict( + torch_compile_config=torch_compile_config, + cuda_graph_config=CudaGraphConfig(enable_padding=torch_compile, + batch_sizes=[4]), + attn_backend=attn_backend, + disable_overlap_scheduler=torch_compile, + ) + pytorch_config["kv_cache_config"] = KvCacheConfig(dtype="nvfp4") + with LLM(f"{llm_models_root()}/Llama-3_1-8B-Instruct_nvfp4_fp8_hf", + **pytorch_config) as llm: + assert llm.args.quant_config.quant_algo == QuantAlgo.FP8 + assert llm.args.quant_config.kv_cache_quant_algo == QuantAlgo.NVFP4 + task = GSM8K(self.MODEL_NAME) + task.evaluate(llm) + @pytest.mark.parametrize("backend", ["xgrammar", "llguidance"]) def test_guided_decoding(self, backend: str, mocker): mocker.patch.dict(os.environ, {"TRTLLM_XGUIDANCE_LENIENT": "1"}) diff --git a/tests/integration/test_lists/test-db/l0_b200.yml b/tests/integration/test_lists/test-db/l0_b200.yml index c66da8305f8..c883cee1a92 100644 --- a/tests/integration/test_lists/test-db/l0_b200.yml +++ b/tests/integration/test_lists/test-db/l0_b200.yml @@ -17,6 +17,7 @@ l0_b200: - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8B::test_nvfp4 - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8B::test_nvfp4_streaming[stream_interval_4] - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8B::test_nvfp4_streaming[stream_interval_64] + - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_nvfp4_kv[attn_backend=TRTLLM-torch_compile=False] - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False-enable_chunked_prefill=False] - 
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False-enable_chunked_prefill=False] - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False-enable_chunked_prefill=True] From 8a1227a45e1fbf062f3111dba5f4f806e88535d9 Mon Sep 17 00:00:00 2001 From: Tian Zheng <29906817+Tom-Zheng@users.noreply.github.com> Date: Thu, 28 Aug 2025 09:35:25 +0000 Subject: [PATCH 2/6] review Signed-off-by: Tian Zheng <29906817+Tom-Zheng@users.noreply.github.com> --- cpp/tensorrt_llm/nanobind/bindings.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/cpp/tensorrt_llm/nanobind/bindings.cpp b/cpp/tensorrt_llm/nanobind/bindings.cpp index 357b77da819..9bb6ea33768 100644 --- a/cpp/tensorrt_llm/nanobind/bindings.cpp +++ b/cpp/tensorrt_llm/nanobind/bindings.cpp @@ -245,6 +245,7 @@ NB_MODULE(TRTLLM_NB_MODULE, m) .def_prop_ro("has_per_group_scaling", &tc::QuantMode::hasPerGroupScaling) .def_prop_ro("has_static_activation_scaling", &tc::QuantMode::hasStaticActivationScaling) .def_prop_ro("has_int8_kv_cache", &tc::QuantMode::hasInt8KvCache) + .def_prop_ro("has_fp4_kv_cache", &tc::QuantMode::hasFp4KvCache) .def_prop_ro("has_fp8_kv_cache", &tc::QuantMode::hasFp8KvCache) .def_prop_ro("has_fp8_qdq", &tc::QuantMode::hasFp8Qdq) .def_prop_ro("has_nvfp4", &tc::QuantMode::hasNvfp4) From 21928e4d42733aa2723c5f7e48c40e578ddb69b9 Mon Sep 17 00:00:00 2001 From: Tian Zheng <29906817+Tom-Zheng@users.noreply.github.com> Date: Fri, 29 Aug 2025 03:10:05 +0000 Subject: [PATCH 3/6] review Signed-off-by: Tian Zheng <29906817+Tom-Zheng@users.noreply.github.com> --- tests/integration/defs/accuracy/test_llm_api_pytorch.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py index 443266a4c98..e0c65f7f8fc 100644 --- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py +++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py @@ -305,6 +305,8 @@ def test_ngram(self): @parametrize_with_ids("torch_compile", [False, True]) @parametrize_with_ids("attn_backend", ["TRTLLM"]) def test_nvfp4_kv(self, attn_backend, torch_compile): + if torch_compile: + pytest.skip("NVFP4 KV does not support torch compile currently.") torch_compile_config = TorchCompileConfig( enable_fullgraph=True, enable_piecewise_cuda_graph=True, @@ -321,6 +323,8 @@ def test_nvfp4_kv(self, attn_backend, torch_compile): **pytorch_config) as llm: assert llm.args.quant_config.quant_algo == QuantAlgo.FP8 assert llm.args.quant_config.kv_cache_quant_algo == QuantAlgo.NVFP4 + task = MMLU(self.MODEL_NAME) + task.evaluate(llm) task = GSM8K(self.MODEL_NAME) task.evaluate(llm) From 17906ad71f8844691424626635fc4ff588207dff Mon Sep 17 00:00:00 2001 From: Tian Zheng <29906817+Tom-Zheng@users.noreply.github.com> Date: Fri, 29 Aug 2025 07:04:45 +0000 Subject: [PATCH 4/6] minor fix Signed-off-by: Tian Zheng <29906817+Tom-Zheng@users.noreply.github.com> --- .../_torch/pyexecutor/resource_manager.py | 25 ++++++++++++++----- 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py index 5e63ff735b8..c570df85e73 100644 --- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py +++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py @@ -508,6 +508,13 @@ def update_resources(self, 
scheduled_batch: ScheduledRequests): def free_resources(self, request: LlmRequest): self.impl.remove_sequence(request.py_request_id, request) + def calculate_scaling_factor_size_bytes( + self, cache_size: int, quant_vector_size: int, + scaling_factor_dtype: DataType) -> int: + assert cache_size % quant_vector_size == 0, "NVFP4 cache size must be divisible by quant vector size" + return get_size_in_bytes(cache_size // quant_vector_size, + scaling_factor_dtype) + def calculate_max_num_blocks(self, kv_cache_config: KvCacheConfigCpp, head_dim: int, @@ -529,8 +536,10 @@ def calculate_max_num_blocks(self, cache_size_bytes_per_token = get_size_in_bytes(cache_size_per_token, dtype) if dtype == DataType.NVFP4: - # NVFP4 needs additional block scales. Vector Size is 16. Each scaling factor is 1 byte. - cache_size_bytes_per_token += cache_size_per_token / 16 + cache_size_bytes_per_token += self.calculate_scaling_factor_size_bytes( + cache_size_per_token, + quant_vector_size=16, + scaling_factor_dtype=DataType.FP8) free_mem, total_mem = torch.cuda.mem_get_info() @@ -724,8 +733,10 @@ def calculate_cache_size_per_token(layers: Set[int]) -> int: cache_size_bytes_per_token = get_size_in_bytes( cache_size_per_token, dtype) if dtype == DataType.NVFP4: - # NVFP4 needs additional block scales. Vector Size is 16. Each scaling factor is 1 byte. - cache_size_bytes_per_token += cache_size_per_token / 16 + cache_size_bytes_per_token += self.calculate_scaling_factor_size_bytes( + cache_size_per_token, + quant_vector_size=16, + scaling_factor_dtype=DataType.FP8) required_mem_bytes_per_seq += window_size * cache_size_bytes_per_token logger.debug( f'Required memory per sequence: {required_mem_bytes_per_seq} bytes') @@ -755,8 +766,10 @@ def calculate_cache_size_per_token(layers: Set[int]) -> int: cache_size_bytes_per_token = get_size_in_bytes( cache_size_per_token, dtype) if dtype == DataType.NVFP4: - # NVFP4 needs additional block scales. Vector Size is 16. Each scaling factor is 1 byte. 
- cache_size_bytes_per_token += cache_size_per_token / 16 + cache_size_bytes_per_token += self.calculate_scaling_factor_size_bytes( + cache_size_per_token, + quant_vector_size=16, + scaling_factor_dtype=DataType.FP8) logger.debug( f'Cache size per token for {len(remaining_layers)} layers: ' f'{cache_size_bytes_per_token} bytes') From 4b8eb9048ff0f45d5f57ec0b3187c846ca16e807 Mon Sep 17 00:00:00 2001 From: Tian Zheng <29906817+Tom-Zheng@users.noreply.github.com> Date: Fri, 29 Aug 2025 07:21:29 +0000 Subject: [PATCH 5/6] fix torch compile Signed-off-by: Tian Zheng <29906817+Tom-Zheng@users.noreply.github.com> --- tensorrt_llm/_torch/modules/attention.py | 3 ++- tests/integration/defs/accuracy/test_llm_api_pytorch.py | 2 -- tests/integration/test_lists/test-db/l0_b200.yml | 1 + 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tensorrt_llm/_torch/modules/attention.py b/tensorrt_llm/_torch/modules/attention.py index f67f047c7a4..e2859da5aa4 100644 --- a/tensorrt_llm/_torch/modules/attention.py +++ b/tensorrt_llm/_torch/modules/attention.py @@ -316,7 +316,8 @@ def create_output(self, q: torch.Tensor): has_quant_scale = (self.o_proj.has_fp8_qdq or self.o_proj.has_nvfp4 or self.o_proj.has_fp8_block_scales or self.o_proj.has_fp8_rowwise) - if has_quant_scale and self.attn.has_fp8_kv_cache: + if has_quant_scale and (self.attn.has_fp8_kv_cache + or self.attn.has_fp4_kv_cache): out_dtype = torch.float8_e4m3fn output = q.new_empty([num_tokens, hidden_size], dtype=out_dtype) return output diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py index e0c65f7f8fc..bdbd8a7375e 100644 --- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py +++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py @@ -305,8 +305,6 @@ def test_ngram(self): @parametrize_with_ids("torch_compile", [False, True]) @parametrize_with_ids("attn_backend", ["TRTLLM"]) def test_nvfp4_kv(self, attn_backend, torch_compile): - if torch_compile: - pytest.skip("NVFP4 KV does not support torch compile currently.") torch_compile_config = TorchCompileConfig( enable_fullgraph=True, enable_piecewise_cuda_graph=True, diff --git a/tests/integration/test_lists/test-db/l0_b200.yml b/tests/integration/test_lists/test-db/l0_b200.yml index c883cee1a92..dcc81713ff7 100644 --- a/tests/integration/test_lists/test-db/l0_b200.yml +++ b/tests/integration/test_lists/test-db/l0_b200.yml @@ -18,6 +18,7 @@ l0_b200: - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8B::test_nvfp4_streaming[stream_interval_4] - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8B::test_nvfp4_streaming[stream_interval_64] - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_nvfp4_kv[attn_backend=TRTLLM-torch_compile=False] + - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_nvfp4_kv[attn_backend=TRTLLM-torch_compile=True] - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False-enable_chunked_prefill=False] - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False-enable_chunked_prefill=False] - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False-enable_chunked_prefill=True] From 8e5d33a82092fb5f62f1fc990c51e33e1c2470e0 Mon Sep 17 00:00:00 2001 From: Tian Zheng 
<29906817+Tom-Zheng@users.noreply.github.com> Date: Fri, 29 Aug 2025 07:26:54 +0000 Subject: [PATCH 6/6] minor fix Signed-off-by: Tian Zheng <29906817+Tom-Zheng@users.noreply.github.com> --- tensorrt_llm/_torch/pyexecutor/resource_manager.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py index c570df85e73..ca3fe99c93b 100644 --- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py +++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py @@ -508,8 +508,9 @@ def update_resources(self, scheduled_batch: ScheduledRequests): def free_resources(self, request: LlmRequest): self.impl.remove_sequence(request.py_request_id, request) + @staticmethod def calculate_scaling_factor_size_bytes( - self, cache_size: int, quant_vector_size: int, + cache_size: int, quant_vector_size: int, scaling_factor_dtype: DataType) -> int: assert cache_size % quant_vector_size == 0, "NVFP4 cache size must be divisible by quant vector size" return get_size_in_bytes(cache_size // quant_vector_size, @@ -733,7 +734,7 @@ def calculate_cache_size_per_token(layers: Set[int]) -> int: cache_size_bytes_per_token = get_size_in_bytes( cache_size_per_token, dtype) if dtype == DataType.NVFP4: - cache_size_bytes_per_token += self.calculate_scaling_factor_size_bytes( + cache_size_bytes_per_token += KVCacheManager.calculate_scaling_factor_size_bytes( cache_size_per_token, quant_vector_size=16, scaling_factor_dtype=DataType.FP8) @@ -766,7 +767,7 @@ def calculate_cache_size_per_token(layers: Set[int]) -> int: cache_size_bytes_per_token = get_size_in_bytes( cache_size_per_token, dtype) if dtype == DataType.NVFP4: - cache_size_bytes_per_token += self.calculate_scaling_factor_size_bytes( + cache_size_bytes_per_token += KVCacheManager.calculate_scaling_factor_size_bytes( cache_size_per_token, quant_vector_size=16, scaling_factor_dtype=DataType.FP8)
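
For reference, a minimal standalone sketch of the NVFP4 KV-cache sizing and pool-pointer layout introduced by this series, assuming the 4-bit element size, the 16-element quantization vector, and the [num_pools, 2, 2] pointer layout described in the hunks above. The helper name bytes_per_token_nvfp4 and the concrete sizes and addresses are illustrative only, not part of the TensorRT-LLM API.

import torch

NVFP4_BITS = 4          # bits per packed KV cache element
QUANT_VECTOR_SIZE = 16  # FP4 elements covered by one FP8 (E4M3) block scale


def bytes_per_token_nvfp4(num_elements: int) -> int:
    # Packed FP4 data bytes plus one byte of block scale per 16 elements,
    # mirroring the NVFP4 branch added to the resource manager above.
    assert (num_elements * NVFP4_BITS) % 8 == 0
    assert num_elements % QUANT_VECTOR_SIZE == 0
    data_bytes = num_elements * NVFP4_BITS // 8
    scale_bytes = num_elements // QUANT_VECTOR_SIZE
    return data_bytes + scale_bytes


# Hypothetical model: 8 KV heads, head_dim 128, K and V -> 2048 elements per token.
print(bytes_per_token_nvfp4(8 * 128 * 2))  # 1024 data bytes + 128 scale bytes = 1152

# Stacking the data-pool and block-scale-pool pointers along the last dimension gives the
# [num_pools, 2 (primary/secondary), 2 (data/scale)] layout that the attention op checks
# for NVFP4 KV cache; dummy addresses are used here.
data_ptrs = torch.tensor([[0x1000, 0x2000]], dtype=torch.int64)   # [num_pools, 2]
scale_ptrs = torch.tensor([[0x3000, 0x4000]], dtype=torch.int64)  # [num_pools, 2]
pool_ptrs = torch.stack([data_ptrs, scale_ptrs], dim=-1)
print(tuple(pool_ptrs.shape))  # (1, 2, 2)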