diff --git a/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp b/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp
index 4a5ddb89286..06fa145ecbb 100644
--- a/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp
+++ b/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp
@@ -83,6 +83,40 @@ namespace tk = tensorrt_llm::kernels;
 
 namespace tensorrt_llm::batch_manager
 {
+std::map<SizeType32, SizeType32> TrtGptModelInflightBatching::calculateCacheSizePerTokenForDisagg(
+    ModelConfig const& modelConfig, WorldConfig const& worldConfig,
+    std::vector<SizeType32> const& maxAttentionWindowVec, bool isCrossAttention, SizeType32 kvFactor)
+{
+    // Number of attention layers on this PP rank.
+    auto const numLocalAttnLayers
+        = modelConfig.getNbAttentionLayers(worldConfig.getPipelineParallelism(), worldConfig.getPipelineParallelRank());
+    // Number of attention layers on all previous PP ranks.
+    auto const numLowerRankAttnLayers = modelConfig.countLowerRankLayers(ModelConfig::LayerType::kATTENTION,
+        worldConfig.getPipelineParallelism(), worldConfig.getPipelineParallelRank());
+    // Use the global ids of the attention layers to look up entries in maxAttentionWindowVec.
+    auto const startAttnLayerId = numLowerRankAttnLayers;
+    auto const endAttnLayerId = numLowerRankAttnLayers + numLocalAttnLayers;
+    auto const numNonUniqueWindowSizes = static_cast<SizeType32>(maxAttentionWindowVec.size());
+    std::map<SizeType32, std::vector<SizeType32>> uniqueWindowSizeToLayers;
+    for (SizeType32 layerIdx = startAttnLayerId; layerIdx < endAttnLayerId; layerIdx++)
+    {
+        // maxAttentionWindowVec may or may not be stretched to the length of numLayers yet.
+        // If it is not stretched yet, we cycle through the window sizes.
+        auto const windowSize = maxAttentionWindowVec.at(layerIdx % numNonUniqueWindowSizes);
+        uniqueWindowSizeToLayers[windowSize].push_back(layerIdx);
+    }
+    std::map<SizeType32, SizeType32> cacheSizeBytesPerTokenPerWindow;
+    for (auto const& [windowSize, globalLayerIds] : uniqueWindowSizeToLayers)
+    {
+        auto const cacheSizePerToken = BaseKVCacheManager::calculateCacheSizePerTokenForSingleWindowSize(
+            modelConfig, globalLayerIds, isCrossAttention, kvFactor);
+        auto const cacheSizeBytesPerToken = cacheSizePerToken * BufferDataType(modelConfig.getKvDataType()).getSize();
+        cacheSizeBytesPerTokenPerWindow[windowSize] = cacheSizeBytesPerToken;
+    }
+
+    return cacheSizeBytesPerTokenPerWindow;
+}
+
 bool TrtGptModelInflightBatching::executorConfigIsValid(
     ModelConfig const& modelConfig, executor::ExecutorConfig const& executorConfig)
 {
@@ -264,32 +298,10 @@ TrtGptModelInflightBatching::TrtGptModelInflightBatching(std::shared_ptr<nvinfer1::ILogger> logger,
-    auto const calculateCacheSizePerToken = [](ModelConfig const& modelConfig, WorldConfig const& worldConfig,
-        std::vector<SizeType32> const& maxAttentionWindowVec, bool isCrossAttention, SizeType32 kvFactor)
-    {
-        auto [numKvHeadsPerLayerBegin, numKvHeadsPerLayerEnd] = modelConfig.getNumKvHeadsPerLayerLocalRange(
-            worldConfig.getPipelineParallelism(), worldConfig.getPipelineParallelRank(), isCrossAttention);
-        auto numKvHeadsPerLayer = std::vector<SizeType32>(numKvHeadsPerLayerBegin, numKvHeadsPerLayerEnd);
-        auto windowSizeLayers
-            = BaseKVCacheManager::groupLayersByWindowSize(maxAttentionWindowVec, modelConfig.getNbLayers());
-        std::map<SizeType32, SizeType32> cacheSizeBytesPerTokenPerWindow;
-        for (auto const& [windowSize, managedLayers] : windowSizeLayers)
-        {
-            auto const cacheSizePerToken = BaseKVCacheManager::calculateCacheSizePerTokenForSingleWindowSize(
-                modelConfig, managedLayers, isCrossAttention, kvFactor);
-            auto const cacheSizeBytesPerToken
-                = cacheSizePerToken * BufferDataType(modelConfig.getKvDataType()).getSize();
-            cacheSizeBytesPerTokenPerWindow[windowSize] = cacheSizeBytesPerToken;
-        }
-
-        return cacheSizeBytesPerTokenPerWindow;
-    };
 
     auto cacheTransceiverConfig
         = executorConfig.getCacheTransceiverConfig().value_or(executor::CacheTransceiverConfig());
-    auto const cacheSizeBytesPerTokenPerWindow = calculateCacheSizePerToken(
+    auto const cacheSizeBytesPerTokenPerWindow = calculateCacheSizePerTokenForDisagg(
         mModelConfig, mWorldConfig, getMaxAttentionWindowVec(), mModelConfig.useCrossAttention(), 2);
     auto cacheTransPreAllocaSize = kv_cache_manager::CacheTransBufferManager::preAllocBufferSize(
         cacheSizeBytesPerTokenPerWindow, cacheTransceiverConfig);
diff --git a/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.h b/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.h
index 28d1767525c..61f5ccfb882 100644
--- a/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.h
+++ b/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.h
@@ -148,6 +148,19 @@ class TrtGptModelInflightBatching : public TrtGptModel
 
     ~TrtGptModelInflightBatching() override;
 
+    /// @brief Calculate the cache size per token for disaggregated serving.
+    /// @param modelConfig Model configuration.
+    /// @param worldConfig World configuration.
+    /// @param maxAttentionWindowVec Maximum attention window vector (may have fewer elements than numLayers, in which
+    /// case it is cycled).
+    /// @param isCrossAttention Whether the attention is cross attention.
+    /// @param kvFactor KV factor.
+    /// @return Cache size in bytes per token for each window size. Note that the window size itself is not included
+    /// in the result.
+    [[nodiscard]] static std::map<SizeType32, SizeType32> calculateCacheSizePerTokenForDisagg(
+        runtime::ModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig,
+        std::vector<SizeType32> const& maxAttentionWindowVec, bool isCrossAttention, SizeType32 kvFactor);
+
     void terminateRequest(LlmRequestPtr const& llmRequest, bool pause = false) override;
 
     /// @brief Terminate request in the next forwardSync call that includes the request.
diff --git a/cpp/tests/unit_tests/executor/executorTestSmall.cpp b/cpp/tests/unit_tests/executor/executorTestSmall.cpp
index 472f56e8a6a..2987509f16a 100644
--- a/cpp/tests/unit_tests/executor/executorTestSmall.cpp
+++ b/cpp/tests/unit_tests/executor/executorTestSmall.cpp
@@ -11,6 +11,7 @@
 #include
 #include
+#include
 
 namespace tensorrt_llm::testing
 {
@@ -201,4 +202,88 @@ INSTANTIATE_TEST_SUITE_P(Float, DecoderFloatTest, paramGenerator,
         return nameStringStream.str();
     });
 
+// Helper function to test calculateCacheSizePerTokenForDisagg with given parameters.
+std::map<runtime::SizeType32, runtime::SizeType32> calculateCacheSizePerTokenHelper(
+    std::vector<runtime::SizeType32> const& maxAttentionWindowVec, runtime::SizeType32 kvFactor = 2,
+    runtime::SizeType32 vocabSize = 32, runtime::SizeType32 nbLayers = 4, runtime::SizeType32 nbAttentionLayers = 4,
+    runtime::SizeType32 nbRnnLayers = 0, runtime::SizeType32 nbHeads = 8, runtime::SizeType32 hiddenSize = 512,
+    bool isCrossAttention = false)
+{
+    // Create a minimal ModelConfig for testing.
+    auto modelConfig = runtime::ModelConfig(
+        vocabSize, nbLayers, nbAttentionLayers, nbRnnLayers, nbHeads, hiddenSize, nvinfer1::DataType::kFLOAT);
+    modelConfig.useGptAttentionPlugin(true);
+    modelConfig.setModelVariant(runtime::ModelConfig::ModelVariant::kGpt);
+    modelConfig.setKVCacheType(runtime::ModelConfig::KVCacheType::kPAGED);
+
+    auto const worldConfig = runtime::WorldConfig();
+
+    return batch_manager::TrtGptModelInflightBatching::calculateCacheSizePerTokenForDisagg(
+        modelConfig, worldConfig, maxAttentionWindowVec, isCrossAttention, kvFactor);
+}
+
+// Test TrtGptModelInflightBatching::calculateCacheSizePerTokenForDisagg with different layer types.
+TEST(TrtInflightBatchingTest, CalculateCacheSizePerTokenForDisagg)
+{
+    // Common parameters.
+    constexpr runtime::SizeType32 nbLayers = 5;
+    constexpr runtime::SizeType32 hiddenSize = 512;
+    constexpr runtime::SizeType32 kvFactor = 2;
+    constexpr runtime::SizeType32 vocabSize = 32;
+    constexpr runtime::SizeType32 nbHeads = 8;
+
+    // Test case 1: Single attention window size - attention layers only.
+    {
+        std::vector<runtime::SizeType32> maxAttentionWindowVec = {128};
+        constexpr runtime::SizeType32 nbAttentionLayers = 5;
+        constexpr runtime::SizeType32 numBytesPerFloatElement = 4;
+        constexpr runtime::SizeType32 nbRnnLayers = 0;
+        auto result = calculateCacheSizePerTokenHelper(maxAttentionWindowVec, kvFactor, vocabSize, nbLayers,
+            nbAttentionLayers, nbRnnLayers, nbHeads, hiddenSize, false);
+        EXPECT_EQ(result.size(), 1);
+        EXPECT_EQ(result.at(128), nbAttentionLayers * kvFactor * hiddenSize * numBytesPerFloatElement);
+    }
+
+    // Test case 2: Multiple attention window sizes - attention layers only.
+    {
+        std::vector<runtime::SizeType32> maxAttentionWindowVec = {128, 256};
+        constexpr runtime::SizeType32 nbAttentionLayers = 5;
+        constexpr runtime::SizeType32 numBytesPerFloatElement = 4;
+        constexpr runtime::SizeType32 nbRnnLayers = 0;
+        auto result = calculateCacheSizePerTokenHelper(maxAttentionWindowVec, kvFactor, vocabSize, nbLayers,
+            nbAttentionLayers, nbRnnLayers, nbHeads, hiddenSize, false);
+        EXPECT_EQ(result.size(), 2);
+        auto const nbAttentionLayersIn128Window = 3;
+        auto const nbAttentionLayersIn256Window = 2;
+        EXPECT_EQ(result.at(128), nbAttentionLayersIn128Window * kvFactor * hiddenSize * numBytesPerFloatElement);
+        EXPECT_EQ(result.at(256), nbAttentionLayersIn256Window * kvFactor * hiddenSize * numBytesPerFloatElement);
+    }
+
+    // Test case 3: Single attention window size - attention and RNN layers.
+    {
+        std::vector<runtime::SizeType32> maxAttentionWindowVec = {128};
+        constexpr runtime::SizeType32 nbAttentionLayers = 3;
+        constexpr runtime::SizeType32 numBytesPerFloatElement = 4;
+        constexpr runtime::SizeType32 nbRnnLayers = 2;
+        auto result = calculateCacheSizePerTokenHelper(maxAttentionWindowVec, kvFactor, vocabSize, nbLayers,
+            nbAttentionLayers, nbRnnLayers, nbHeads, hiddenSize, false);
+        EXPECT_EQ(result.size(), 1);
+        EXPECT_EQ(result.at(128), nbAttentionLayers * kvFactor * hiddenSize * numBytesPerFloatElement);
+    }
+
+    // Test case 4: Multiple attention window sizes - attention and RNN layers.
+    {
+        std::vector<runtime::SizeType32> maxAttentionWindowVec = {128, 256};
+        constexpr runtime::SizeType32 nbAttentionLayers = 3;
+        constexpr runtime::SizeType32 numBytesPerFloatElement = 4;
+        constexpr runtime::SizeType32 nbRnnLayers = 2;
+        auto result = calculateCacheSizePerTokenHelper(maxAttentionWindowVec, kvFactor, vocabSize, nbLayers,
+            nbAttentionLayers, nbRnnLayers, nbHeads, hiddenSize, false);
+        EXPECT_EQ(result.size(), 2);
+        auto const nbAttentionLayersIn128Window = 2;
+        auto const nbAttentionLayersIn256Window = 1;
+        EXPECT_EQ(result.at(128), nbAttentionLayersIn128Window * kvFactor * hiddenSize * numBytesPerFloatElement);
+        EXPECT_EQ(result.at(256), nbAttentionLayersIn256Window * kvFactor * hiddenSize * numBytesPerFloatElement);
+    }
+}
+
 } // namespace tensorrt_llm::testing
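
For reviewers who want to sanity-check the values asserted in the tests above, the following standalone sketch (not part of the patch) reproduces the window-size cycling and the per-token byte arithmetic with plain integers. The helper name groupByWindowSize is made up for illustration; only the modulo-cycling lookup and the "layers * kvFactor * hiddenSize * bytesPerElement" math mirror the patch (in the tests, numKvHeads * sizePerHead equals hiddenSize).

#include <cstdint>
#include <iostream>
#include <map>
#include <vector>

// Mirrors the cycling lookup in calculateCacheSizePerTokenForDisagg:
// windowSize = maxAttentionWindowVec.at(layerIdx % numNonUniqueWindowSizes).
std::map<int, std::vector<int>> groupByWindowSize(
    std::vector<int> const& windowSizes, int startLayer, int endLayer)
{
    std::map<int, std::vector<int>> grouped;
    for (int layerIdx = startLayer; layerIdx < endLayer; ++layerIdx)
    {
        grouped[windowSizes.at(layerIdx % windowSizes.size())].push_back(layerIdx);
    }
    return grouped;
}

int main()
{
    // Mirrors test case 2: 5 attention layers, window sizes {128, 256}.
    auto const grouped = groupByWindowSize({128, 256}, /*startLayer=*/0, /*endLayer=*/5);

    // Bytes per token for one layer: kvFactor (K and V) * hiddenSize elements
    // * bytes per float32 element.
    constexpr std::int64_t kvFactor = 2, hiddenSize = 512, bytesPerElement = 4;
    for (auto const& [windowSize, layers] : grouped)
    {
        auto const bytesPerToken = static_cast<std::int64_t>(layers.size()) * kvFactor * hiddenSize * bytesPerElement;
        std::cout << "window " << windowSize << ": " << layers.size() << " layers, " << bytesPerToken
                  << " bytes per token\n";
    }
    // Prints: window 128: 3 layers, 12288 bytes per token
    //         window 256: 2 layers, 8192 bytes per token
    return 0;
}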