From 56d2a1b58675f15b510564882452572f63bf9e07 Mon Sep 17 00:00:00 2001
From: eopXD
Date: Tue, 5 Aug 2025 00:51:41 -0700
Subject: [PATCH 1/4] [KV cache manager] No functional change intended,
 separate `KVCacheManager::updateToken` into `addToken` and `removeToken`

Given that streamLLM is currently broken, the related code logic has been
removed and an assertion added to guard against sink attention settings. This
allows us to revisit the computation in the future.

Main reasons for this change:
- Token addition and removal logic should be decoupled.
- It flattens the call stack and the nested ifs.
- The `std::swap` use is bizarre.
- Broken features should be guarded safely.

Rename the cache offset bookkeeping utilities for clarity. These utilities are
pushed down from `KVCacheManager` to `BlockManager`.
- `cacheBlockOffsets` --> `updateSequenceCacheBlockOffsets`
- `cacheNewBlockOffsets` --> `updateLastCacheBlockOffsets`
- `updateNewBlockPointer` --> `updateCacheBlockOffsetsAtIdx`

Comments have been added to test cases as reminders of future test coverage
needed for feature support.

Signed-off-by: eopXD
---
 .../batch_manager/kvCacheManager.h          | 15 +--
 .../batch_manager/kvCacheManager.cpp        | 97 +++++++------------
 .../batch_manager/capacitySchedulerTest.cpp | 44 ++++++---
 .../batch_manager/kvCacheManagerTest.cpp    | 21 ++--
 .../batch_manager/kvCacheUtilsTest.cpp      |  1 +
 .../batch_manager/llmRequestTest.cpp        |  6 +-
 6 files changed, 92 insertions(+), 92 deletions(-)

diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h
index a49527a6157..80c9e42623a 100644
--- a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h
+++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h
@@ -1106,6 +1106,15 @@ class BlockManager
         return mWindowBlockManagers.at(windowSize).getPool(relativePoolIndex);
     }
 
+    //! \brief Update cache offsets for blocks initiated from sequence
+    void updateSequenceCacheBlockOffsets(GenerationRequest& seq, SizeType32 windowSize);
+
+    //! \brief Update cache offsets for last block
+    void updateLastCacheBlockOffsets(GenerationRequest& seq, SizeType32 windowSize);
+
+    //! \brief Update cache offsets for block at index
+    void updateCacheBlockOffsetsAtIdx(GenerationRequest& seq, SizeType32 windowSize, SizeType32 blockIdx);
+
 private:
     [[nodiscard]] WindowBlockManager const& windowManagerByLayer(SizeType32 layerIdx) const
     {
@@ -1637,12 +1646,6 @@ class KVCacheManager : public BaseKVCacheManager
     [[nodiscard]] static SizeType32 calculateMaxAttentionWindow(SizeType32 inputLength, SizeType32 outputLength,
         SizeType32 sinkTokenLength, SizeType32 blockCapacity, SizeType32 beamWidth, SizeType32 tokensPerBlock);
 
-private:
-    void cacheBlockOffsets(GenerationRequest& seq, SizeType32 windowSize);
-    void cacheNewBlockOffsets(GenerationRequest& seq, SizeType32 windowSize);
-    void updateNewBlockPointer(GenerationRequest& seq, SizeType32 windowSize, SizeType32 blockIdx);
-    void updateToken(GenerationRequest& sequence, bool addToken);
-
 private:
     // Maximum number of sequences
     SizeType32 mMaxNumSequences;
diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp
index d5fa982a37a..6d807464084 100644
--- a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp
+++ b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp
@@ -1639,6 +1639,8 @@ KVCacheManager::KVCacheManager(std::vector<SizeType32> const& numKvHeadsPerLayer
     // disable block reuse for sink bubble since chopVectorIntoBlocks does not match KV cache blocks in this case
    , mEnableBlockReuse{mSinkBubbleLength > 0 ? false : enableBlockReuse}
 {
+    TLLM_CHECK_WITH_INFO(mSinkBlockTokenLength == 0 && mSinkBubbleLength == 0,
+        "[kv cache manager] streamLLM is not supported at the moment");
     TLLM_CHECK_DEBUG(std::find(maxAttentionWindowVec.begin(), maxAttentionWindowVec.end(), mMaxAttentionWindow)
         != maxAttentionWindowVec.end());
     // The sink tokens are stored in blocks separate from other tokens.
@@ -1834,7 +1836,7 @@ SizeType32 KVCacheManager::getRemainingBlocksToCompletion(LlmRequest const& req,
     return (numTotalBlocksPerBeam - numAllocBlocksPerBeam) * req.mSamplingConfig.beamWidth;
 }
 
-void KVCacheManager::cacheBlockOffsets(GenerationRequest& sequence, SizeType32 windowSize)
+void BlockManager::updateSequenceCacheBlockOffsets(GenerationRequest& sequence, SizeType32 windowSize)
 {
     auto const& cacheBlocks = sequence.getCacheBlockIds(windowSize);
     auto& cacheBlocksTensor = sequence.getCacheBlockIndices(windowSize);
@@ -1849,12 +1851,12 @@ void KVCacheManager::cacheBlockOffsets(GenerationRequest& sequence, SizeType32 w
         for (SizeType32 blockIdx = 0; blockIdx < static_cast<SizeType32>(beamCacheBlock.size()); ++blockIdx)
         {
             auto const blockId = beamCacheBlock.at(blockIdx);
-            mBlockManager.setOffsets(offsetsPtr, offsetsShape, beamIdx, blockIdx, blockId, windowSize);
+            mWindowBlockManagers.at(windowSize).setOffsets(offsetsPtr, offsetsShape, beamIdx, blockIdx, blockId);
         }
     }
 }
 
-void KVCacheManager::cacheNewBlockOffsets(GenerationRequest& sequence, SizeType32 windowSize)
+void BlockManager::updateLastCacheBlockOffsets(GenerationRequest& sequence, SizeType32 windowSize)
 {
     auto const& cacheBlocks = sequence.getCacheBlockIds(windowSize);
     auto& cacheBlocksTensor = sequence.getCacheBlockIndices(windowSize);
@@ -1868,11 +1870,11 @@
         auto const& beamCacheBlock = cacheBlocks[beamIdx];
         auto const blockId = beamCacheBlock.back();
         auto const blockIdx = static_cast<SizeType32>(beamCacheBlock.size() - 1);
-        mBlockManager.setOffsets(offsetsPtr, offsetsShape, beamIdx, blockIdx, blockId, windowSize);
+        mWindowBlockManagers.at(windowSize).setOffsets(offsetsPtr, offsetsShape, beamIdx, blockIdx, blockId);
     }
 }
 
-void KVCacheManager::updateNewBlockPointer(GenerationRequest& sequence, SizeType32 windowSize, SizeType32 blockIdx)
+void BlockManager::updateCacheBlockOffsetsAtIdx(GenerationRequest& sequence, SizeType32 windowSize, SizeType32 blockIdx)
 {
     auto const& cacheBlocks = sequence.getCacheBlockIds(windowSize);
     auto& cacheBlocksTensor = sequence.getCacheBlockIndices(windowSize);
@@ -1885,76 +1887,37 @@
     {
         auto const& beamCacheBlock = cacheBlocks[beamIdx];
         auto const blockId = beamCacheBlock.at(blockIdx);
-        mBlockManager.setOffsets(offsetsPtr, offsetsShape, beamIdx, blockIdx, blockId, windowSize);
+        mWindowBlockManagers.at(windowSize).setOffsets(offsetsPtr, offsetsShape, beamIdx, blockIdx, blockId);
     }
 }
 
-void KVCacheManager::updateToken(GenerationRequest& sequence, bool addToken)
+void KVCacheManager::addToken(RequestIdType requestId)
 {
-    auto currNumTokens = sequence.getNumTokens();
-
-    if (addToken)
-    {
-        sequence.addNewTokens(1);
-    }
-    else
-    {
-        sequence.removeTokens(1);
-    }
-
-    auto newNumTokens = sequence.getNumTokens();
-
-    if (!addToken)
-    {
-        std::swap(currNumTokens, newNumTokens);
-    }
-
+    // TODO: add streamLLM support
+    auto& sequence = getSequence(requestId);
+    sequence.addNewTokens(1);
     for (auto const [windowSize, metadata] : mBlockManager.getWindowSizesMetadata())
     {
-        auto const maxTokenNum = metadata.maxTokenNum;
-        SizeType32 const cyclicTokenNum = maxTokenNum - mSinkBlockTokenLength;
-        SizeType32 const nextTokenIdxInCycle = (currNumTokens - mSinkBlockTokenLength) % cyclicTokenNum;
-        SizeType32 const nextTokenIdxInCache = mSinkBlockTokenLength + nextTokenIdxInCycle;
-
-        // (nextTokenIdxInCache - mSinkBlockTokenLength) % cyclicTokenNum == 0)
-        // <=> nextTokenIdxInCycle == 0
-        // <=> nextTokenIdxInCache == mSinkBlockTokenLength
-        // => nextTokenIdxInCache % getTokensPerBlock() == 0
-
-        // Check if require a new block
-        if (nextTokenIdxInCache % getTokensPerBlock() == 0)
+        if ((sequence.getNumTokens() - 1) % getTokensPerBlock() == 0)
         {
-            if (newNumTokens <= maxTokenNum)
+            if (sequence.getNumTokens() <= windowSize)
             {
-                if (addToken)
-                {
-                    mBlockManager.allocateBlock(sequence, windowSize);
-                    cacheNewBlockOffsets(sequence, windowSize);
-                }
-                else
-                {
-                    mBlockManager.releaseLastBlock(sequence, windowSize);
-                }
+                // Allocate new unshared blocks until the window can always
+                // accommodate "window size" number of tokens.
+                mBlockManager.allocateBlock(sequence, windowSize);
+                mBlockManager.updateLastCacheBlockOffsets(sequence, windowSize);
             }
             else if (sequence.getBeamWidth() > 1)
             {
-                TLLM_CHECK_WITH_INFO(addToken, "Remove token is not supported with beam search");
-                // Get next block index
-                SizeType32 nextBlockIdx = nextTokenIdxInCache / getTokensPerBlock();
-                // Replace the shared block with the unshared ones
+                // For beam search, shared block is replaced with unshared ones
+                auto const nextBlockIdx = (sequence.getNumTokens() - 1) / getTokensPerBlock();
                 mBlockManager.replaceSharedBlock(sequence, windowSize, nextBlockIdx);
-                updateNewBlockPointer(sequence, windowSize, nextBlockIdx);
+                mBlockManager.updateCacheBlockOffsetsAtIdx(sequence, windowSize, nextBlockIdx);
             }
         }
    }
 }
 
-void KVCacheManager::addToken(RequestIdType requestId)
-{
-    auto& sequence = getSequence(requestId);
-    updateToken(sequence, true);
-}
-
 std::optional<BlockKey> KVCacheManager::findNewContextBlock(
     VecUniqueTokens const& uniqueTokens, LlmRequest const& llmRequest) const
@@ -2026,7 +1989,7 @@ void KVCacheManager::addSequence(
             }
             mBlockManager.addSequence(sequence, numContextBlocks, unsharedBlockIdx, windowSize);
         }
-        cacheBlockOffsets(sequence, windowSize);
+        mBlockManager.updateSequenceCacheBlockOffsets(sequence, windowSize);
     }
 
     if (llmRequest)
@@ -2353,15 +2316,23 @@ BlocksPerWindow BaseKVCacheManager::calculateMaxNumBlocks(executor::KvCacheConfi
 
 void KVCacheManager::removeToken(RequestIdType requestId)
 {
+    // TODO: add streamLLM support
     auto& sequence = getSequence(requestId);
-    auto const beamWidth = sequence.getBeamWidth();
-
-    TLLM_CHECK_WITH_INFO(beamWidth == 1, "removeToken does not support beamWidth > 1");
     if (sequence.getNumTokens() == 0)
     {
        return;
     }
-    updateToken(sequence, false);
+    TLLM_CHECK_WITH_INFO(sequence.getBeamWidth() == 1, "[kv cache manager] removeToken does not support beamWidth > 1");
+    sequence.removeTokens(1);
+    for (auto const [windowSize, metadata] : mBlockManager.getWindowSizesMetadata())
+    {
+        SizeType32 const maxTokensInWindow = metadata.maxTokenNum;
+        SizeType32 const tokensInWindow = sequence.getNumTokens() % maxTokensInWindow;
+        if (tokensInWindow % getTokensPerBlock() == 0 && tokensInWindow <= maxTokensInWindow)
+        {
+            mBlockManager.releaseLastBlock(sequence, windowSize);
+        }
+    }
 }
 
 void KVCacheManager::rewindKVCache(RequestIdType requestId, SizeType32 rewindLengths)
diff --git a/cpp/tests/unit_tests/batch_manager/capacitySchedulerTest.cpp b/cpp/tests/unit_tests/batch_manager/capacitySchedulerTest.cpp
index 1e90017c8b1..24125c267d4 100644
--- a/cpp/tests/unit_tests/batch_manager/capacitySchedulerTest.cpp
+++ b/cpp/tests/unit_tests/batch_manager/capacitySchedulerTest.cpp
@@ -453,7 +453,8 @@ TEST_F(CapacitySchedulerTest, SimpleShouldFit)
     auto capacitySchedulerPolicies = std::vector{CapacitySchedulerPolicy::kGUARANTEED_NO_EVICT,
         CapacitySchedulerPolicy::kMAX_UTILIZATION, CapacitySchedulerPolicy::kSTATIC_BATCH};
 
-    auto sinkTokenLens = std::vector{0, 4};
+    // TODO: Support and add coverage for sinkTokenLen > 0. (e.g. 4)
+    auto sinkTokenLens = std::vector{0};
     for (auto capacitySchedulerPolicy : capacitySchedulerPolicies)
     {
         for (auto sinkTokenLen : sinkTokenLens)
@@ -506,7 +507,8 @@ TEST_F(CapacitySchedulerTest, SimpleShouldFitWithCrossBlocks)
     auto capacitySchedulerPolicies = std::vector{CapacitySchedulerPolicy::kGUARANTEED_NO_EVICT};
 
-    auto sinkTokenLens = std::vector{0, 4};
+    // TODO: Support and add coverage for sinkTokenLen > 0. (e.g. 4)
+    auto sinkTokenLens = std::vector{0};
     for (auto capacitySchedulerPolicy : capacitySchedulerPolicies)
     {
         for (auto sinkTokenLen : sinkTokenLens)
@@ -550,7 +552,8 @@ TEST_F(CapacitySchedulerTest, SimpleLoraFitsDuplicateTask)
     auto capacitySchedulerPolicies = std::vector{CapacitySchedulerPolicy::kGUARANTEED_NO_EVICT,
         CapacitySchedulerPolicy::kMAX_UTILIZATION, CapacitySchedulerPolicy::kSTATIC_BATCH};
 
-    auto sinkTokenLens = std::vector{0, 4};
+    // TODO: Support and add coverage for sinkTokenLen > 0. (e.g. 4)
+    auto sinkTokenLens = std::vector{0};
     for (auto capacitySchedulerPolicy : capacitySchedulerPolicies)
     {
         for (auto sinkTokenLen : sinkTokenLens)
@@ -594,7 +597,8 @@ TEST_F(CapacitySchedulerTest, SimpleLoraDoesntFitDuplicateTask)
     auto capacitySchedulerPolicies = std::vector{CapacitySchedulerPolicy::kGUARANTEED_NO_EVICT,
         CapacitySchedulerPolicy::kMAX_UTILIZATION, CapacitySchedulerPolicy::kSTATIC_BATCH};
 
-    auto sinkTokenLens = std::vector{0, 4};
+    // TODO: Support and add coverage for sinkTokenLen > 0. (e.g. 4)
+    auto sinkTokenLens = std::vector{0};
 
     for (auto capacitySchedulerPolicy : capacitySchedulerPolicies)
     {
@@ -704,7 +708,8 @@ TEST_F(CapacitySchedulerTest, SimpleDoesntFitMaxUtilization)
     SizeType32 maxNumRequests = 2;
     SizeType32 maxInputLen = 1000;
 
-    auto sinkTokenLens = std::vector{0, 4};
+    // TODO: Support and add coverage for sinkTokenLen > 0. (e.g. 4)
+    auto sinkTokenLens = std::vector{0};
     for (auto sinkTokenLen : sinkTokenLens)
     {
         auto kvCacheManager = getKvCacheManager(
@@ -829,9 +834,13 @@ TEST_F(CapacitySchedulerTest, SimpleDoesntFitPriorities)
     SizeType32 maxNumRequests = 2;
     SizeType32 maxInputLen = 1000;
 
+    // TODO: Support and add coverage for sinkTokenLen > 0
+    // Removed configuration:
+    // {CapacitySchedulerPolicy::kMAX_UTILIZATION, 4, 125}
+    // {CapacitySchedulerPolicy::kGUARANTEED_NO_EVICT, 1, 160}
+    auto sinkTokenLens = std::vector{0};
     auto configurations = std::vector<std::tuple<CapacitySchedulerPolicy, SizeType32, SizeType32>>{
-        {CapacitySchedulerPolicy::kMAX_UTILIZATION, 0, 119}, {CapacitySchedulerPolicy::kMAX_UTILIZATION, 4, 125},
-        {CapacitySchedulerPolicy::kGUARANTEED_NO_EVICT, 1, 160}};
+        {CapacitySchedulerPolicy::kMAX_UTILIZATION, 0, 119}};
 
     for (auto [capacitySchedulerPolicy, sinkTokenLen, expectedNumIters] : configurations)
     {
@@ -1027,7 +1036,8 @@ TEST_F(CapacitySchedulerTest, SimpleDoesntFitGuaranteedCompletion)
     SizeType32 maxNumRequests = 2;
     SizeType32 maxInputLen = 1000;
 
-    auto sinkTokenLens = std::vector{0, 4};
+    // TODO: Support and add coverage for sinkTokenLen > 0. (e.g. 4)
+    auto sinkTokenLens = std::vector{0};
     for (auto sinkTokenLen : sinkTokenLens)
     {
         auto kvCacheManager = getKvCacheManager(
@@ -1080,7 +1090,8 @@ TEST_F(CapacitySchedulerTest, SimpleDoesntFitWithCrossBlocks)
     auto capacitySchedulerPolicies = std::vector{CapacitySchedulerPolicy::kGUARANTEED_NO_EVICT};
 
-    auto sinkTokenLens = std::vector{0, 4};
+    // TODO: Support and add coverage for sinkTokenLen > 0
+    auto sinkTokenLens = std::vector{0};
 
     for (auto capacitySchedulerPolicy : capacitySchedulerPolicies)
     {
@@ -1172,7 +1183,8 @@ TEST_F(CapacitySchedulerTest, SimpleDoesntFitAddingNewRequestsMaxUtilization)
     SizeType32 maxNumRequests = 4;
     SizeType32 maxInputLen = 1000;
 
-    auto sinkTokenLens = std::vector{0, 4};
+    // TODO: Support and add coverage for sinkTokenLen > 0
+    auto sinkTokenLens = std::vector{0};
     for (auto sinkTokenLen : sinkTokenLens)
     {
         auto kvCacheManager = getKvCacheManager(
@@ -1357,7 +1369,8 @@ TEST_F(CapacitySchedulerTest, SimpleDoesntFitAddingNewRequestsMaxUtilizationPrio
     SizeType32 maxNumRequests = 4;
     SizeType32 maxInputLen = 1000;
 
-    auto sinkTokenLens = std::vector{0, 4};
+    // TODO: Support and add coverage for sinkTokenLen > 0
+    auto sinkTokenLens = std::vector{0};
     for (auto sinkTokenLen : sinkTokenLens)
     {
         auto kvCacheManager = getKvCacheManager(
@@ -1504,7 +1517,8 @@ TEST_F(CapacitySchedulerTest, SimpleDoesntFitAddingNewRequestsGuaranteedCompleti
     SizeType32 maxNumRequests = 4;
     SizeType32 maxInputLen = 1000;
 
-    auto sinkTokenLens = std::vector{0, 4};
+    // TODO: Support and add coverage for sinkTokenLen > 0
+    auto sinkTokenLens = std::vector{0};
     for (auto sinkTokenLen : sinkTokenLens)
     {
         auto kvCacheManager = getKvCacheManager(
@@ -1555,7 +1569,8 @@ TEST_F(CapacitySchedulerTest, SimpleDoesntFitAddingNewRequestsGuaranteedCompleti
     SizeType32 maxNumRequests = 4;
     SizeType32 maxInputLen = 1000;
 
-    auto sinkTokenLens = std::vector{0, 4};
+    // TODO: Support and add coverage for sinkTokenLen > 0
+    auto sinkTokenLens = std::vector{0};
     for (auto sinkTokenLen : sinkTokenLens)
     {
         auto kvCacheManager = getKvCacheManager(
@@ -1835,7 +1850,8 @@ TEST_F(CapacitySchedulerTest, SimpleFitsStaticBatch)
     SizeType32 maxNumRequests = 2;
     SizeType32 maxInputLen = 1000;
 
-    auto sinkTokenLens = std::vector{0, 4};
+    // TODO: Support and add coverage for sinkTokenLen > 0
+    auto sinkTokenLens = std::vector{0};
     for (auto sinkTokenLen : sinkTokenLens)
     {
         auto kvCacheManager = getKvCacheManager(
diff --git a/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp b/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp
index 8e58ee77f45..514575600d9 100644
--- a/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp
+++ b/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp
@@ -2408,6 +2408,7 @@ TEST_P(KVCacheManagerTest, DISABLED_KVCacheManagerAllocationTest)
 
 TEST_P(KVCacheManagerTest, KVCacheManagerTest)
 {
+    // Full attention
     using DType = half;
     using SizeType32 = KVCacheManager::SizeType32;
@@ -2655,6 +2656,7 @@ TEST_P(KVCacheManagerTest, KVCacheManagerMaxAttentionWindowTest)
     auto constexpr tokensPerBlock = 64;
     auto constexpr blockLengthPerSeq = 10;
     auto constexpr maxNumSequences = 8;
+    // TODO: Support and add coverage for beamWidth > 1
     auto constexpr maxBeamWidth = 1;
     auto constexpr sinkTokenLength = 0;
     auto const stream = std::make_shared();
@@ -2664,10 +2666,10 @@
     auto constexpr inputLength = maxNumTokens - tokensPerBlock - 1;
     // Enable cyclic kv cache for all new generated tokens.
-    auto constexpr maxAttentionWindow = inputLength;
-    auto constexpr numSharedBlocks = std::min(inputLength, maxAttentionWindow) / tokensPerBlock;
-    auto constexpr numBlocksPerSeq = numSharedBlocks + (blockLengthPerSeq - numSharedBlocks) * maxBeamWidth;
+    auto constexpr maxAttentionWindow = maxNumTokens;
+    auto constexpr numSharedBlocks = inputLength / tokensPerBlock;
     auto constexpr maxBlocksPerSeq = tc::ceilDiv(maxAttentionWindow, tokensPerBlock);
+    auto constexpr numBlocksPerSeq = numSharedBlocks + (maxBlocksPerSeq - numSharedBlocks) * maxBeamWidth;
     auto constexpr totalNumBlocks = maxNumSequences * numBlocksPerSeq;
     auto constexpr blocksInSecondaryPool = 0;
@@ -2757,9 +2759,9 @@
     }
 
     EXPECT_NO_THROW(kvCacheManager.addToken(requestId));
-    EXPECT_EQ(blockManager.getNumFreeBlocks(), totalNumBlocks - numBlocksPerSeq + 1);
+    EXPECT_EQ(blockManager.getNumFreeBlocks(), totalNumBlocks - numSharedBlocks - maxBeamWidth);
     EXPECT_NO_THROW(kvCacheManager.addToken(requestId));
-    EXPECT_EQ(blockManager.getNumFreeBlocks(), totalNumBlocks - numBlocksPerSeq + 1);
+    EXPECT_EQ(blockManager.getNumFreeBlocks(), totalNumBlocks - numSharedBlocks - maxBeamWidth * 2);
 
     EXPECT_NO_THROW(kvCacheManager.removeSequence(requestId));
     EXPECT_EQ(blockManager.getNumFreeBlocks(), totalNumBlocks);
@@ -3488,8 +3490,10 @@ TEST_F(KVCacheManagerTest, KVCacheTransferManagerConcurrencyTest)
     }
 }
 
-TEST_P(KVCacheManagerTest, KVCacheManagerSinkTokenLengthTest)
+TEST_P(KVCacheManagerTest, DISABLED_KVCacheManagerSinkTokenLengthTest)
 {
+    // TODO: Support sink attention and add coverage
+    // TODO: Support and add coverage for beamWidth > 1
     using DType = half;
     using SizeType32 = KVCacheManager::SizeType32;
@@ -3633,6 +3637,7 @@
 
 TEST_P(KVCacheManagerTest, KVCacheManagerBatchTest)
 {
+    // Full attention
     using DType = half;
     using SizeType32 = KVCacheManager::SizeType32;
@@ -4124,6 +4129,8 @@ TEST_P(RemainingBlocksToCompletionTest, RemainingBlocksToCompletionCorrectlyEsti
     ASSERT_EQ(result, params.expectedRemainingBlocksToCompletion);
 }
 
+// TODO: Support and add coverage for beamWidth > 1
+// TODO: Support and add coverage for sink attention
 INSTANTIATE_TEST_SUITE_P(RemainingBlocksToCompletionCorrectlyEstimated, RemainingBlocksToCompletionTest,
     ::testing::Values(
         GetRemainingBlocksToCompletionOneRequestParameters{
@@ -4228,6 +4235,8 @@ TEST_P(FillKvCacheAndCompleteRequestsTest, FillKvCacheAndCompleteInParallel)
     }
 }
 
+// TODO: Support and add coverage for beamWidth > 1
+// TODO: Support and add coverage for sink attention
 auto const paramValues = ::testing::Values(
     FillKvCacheAndCompleteRequestsParameters{
         KvCacheManagerInstantiationParameters{
diff --git a/cpp/tests/unit_tests/batch_manager/kvCacheUtilsTest.cpp b/cpp/tests/unit_tests/batch_manager/kvCacheUtilsTest.cpp
index 191c4fa2624..c4de5b6a8c6 100644
--- a/cpp/tests/unit_tests/batch_manager/kvCacheUtilsTest.cpp
+++ b/cpp/tests/unit_tests/batch_manager/kvCacheUtilsTest.cpp
@@ -84,6 +84,7 @@ TEST_F(BlockIteratorTest, CacheManagerTest)
     auto const stream = std::make_shared();
     auto constexpr onboardBlocks = true;
+    // TODO: Support and add coverage for beamWidth > 1
     auto constexpr beamWidth = 1;
     auto constexpr numBlocksPerBeam = blocksInPrimaryPool / beamWidth;
     auto constexpr maxSequenceLength = tokensPerBlock * numBlocksPerBeam;
diff --git a/cpp/tests/unit_tests/batch_manager/llmRequestTest.cpp b/cpp/tests/unit_tests/batch_manager/llmRequestTest.cpp
index a7b10c66256..d08d2f5bc47 100644
--- a/cpp/tests/unit_tests/batch_manager/llmRequestTest.cpp
+++ b/cpp/tests/unit_tests/batch_manager/llmRequestTest.cpp
@@ -748,13 +748,13 @@ TEST_P(ParamTest, createResponse)
 
 INSTANTIATE_TEST_SUITE_P(LlmRequestTest, ParamTest,
     testing::Combine(
-        // streaming
-        testing::Values(false, true),
+        // TODO: Support and add coverage for streamLLM
+        testing::Values(false),
         // excludeInputFromOutput
         testing::Values(false, true),
         // returnAllGeneratedTokens
         testing::Values(false, true),
-        // beamWdith
+        // beamWidth
         testing::Values(1, 2),
         // tokensPerIteration
         testing::Values(1, 3),

From 5a968394c075bdf900d85acae1dee3060d32ba51 Mon Sep 17 00:00:00 2001
From: eopXD
Date: Tue, 5 Aug 2025 02:03:16 -0700
Subject: [PATCH 2/4] [KV cache manager] No functional change intended, add
 comments for more context to distinguish `BlockManager::addSequence` and
 `WindowBlockManager::addSequence`

Signed-off-by: eopXD
---
 cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp
index 6d807464084..93882427615 100644
--- a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp
+++ b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp
@@ -1146,12 +1146,16 @@ void WindowBlockManager::refreshBlocks()
     mTransferManager->syncTransfers();
 }
 
+// There are two versions of the BlockManager::addSequence function.
+// This is called when block reuse is enabled.
 void BlockManager::addSequence(GenerationRequest& sequence, SizeType32 inputLength, SizeType32 numContextBlocks,
     LlmRequest& llmRequest, SizeType32 windowSize)
 {
     mWindowBlockManagers.at(windowSize).addSequence(sequence, inputLength, numContextBlocks, llmRequest);
 }
 
+// There are two versions of the WindowBlockManager::addSequence function.
+// This is called when block reuse is enabled.
 void WindowBlockManager::addSequence(
     GenerationRequest& sequence, SizeType32 inputLength, SizeType32 numContextBlocks, LlmRequest& llmRequest)
 {
@@ -1189,12 +1193,16 @@ void WindowBlockManager::addSequence(
         inputLength, prepopulatedPromptLen);
 }
 
+// There are two versions of the BlockManager::addSequence function.
+// This is called when block reuse is disabled.
 void BlockManager::addSequence(
     GenerationRequest& sequence, SizeType32 numBlocks, SizeType32 unsharedBlockIdx, SizeType32 windowSize)
 {
     mWindowBlockManagers.at(windowSize).addSequence(sequence, numBlocks, unsharedBlockIdx);
 }
 
+// There are two versions of the WindowBlockManager::addSequence function.
+// This is called when block reuse is disabled.
 void WindowBlockManager::addSequence(GenerationRequest& sequence, SizeType32 numBlocks, SizeType32 unsharedBlockIdx)
 {
     auto const requestId = sequence.getRequestId();

From 892906cbec9ebf5aebd85eee71afad50a00961dd Mon Sep 17 00:00:00 2001
From: eopXD
Date: Mon, 11 Aug 2025 07:43:57 -0700
Subject: [PATCH 3/4] [KV cache manager] No functional change intended,
 simplify shared/unshared last context block logic under
 `KVCacheManager::addSequence`

The last context block is shared when (see the sketch below):
- Operating on a cyclic KV cache with no beam search
- Operating on a cross KV cache
- The last context block is full

Prune an incorrect comment left under `KVCacheManager::getNeededBlocksOneStep`.
No need to complicate things.
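As a minimal sketch (not part of the diff itself), the three conditions above
collapse into the single predicate this patch introduces in
`KVCacheManager::addSequence`; the parameter names here are illustrative
stand-ins for the locals at that call site:

    // Sketch: decide whether the last context block can be shared among
    // beams. Mirrors the predicate introduced by this patch.
    static bool shareLastContextBlock(
        bool isCrossKv, bool isCyclic, int beamWidth, int effectiveInputLength, int tokensPerBlock)
    {
        return isCrossKv || (isCyclic && beamWidth == 1) || effectiveInputLength % tokensPerBlock == 0;
    }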
Signed-off-by: eopXD
---
 .../batch_manager/kvCacheManager.h       |  9 ++++-
 .../batch_manager/kvCacheManager.cpp     | 38 ++++++-------------
 .../batch_manager/kvCacheManagerTest.cpp | 16 +++++---
 3 files changed, 29 insertions(+), 34 deletions(-)

diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h
index 80c9e42623a..df526a5dfbe 100644
--- a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h
+++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h
@@ -551,7 +551,7 @@ class WindowBlockManager
         GenerationRequest& sequence, SizeType32 inputLength, SizeType32 numContextBlocks, LlmRequest& llmRequest);
 
     //! \brief Assign blocks for new sequence. Does not try to reuse blocks.
-    void addSequence(GenerationRequest& sequence, SizeType32 numBlocks, SizeType32 unsharedBlockIdx);
+    void addSequence(GenerationRequest& sequence, SizeType32 numContextBlocks, bool isShareLastContextBlock);
 
     //! \brief Allocate new block for each beam of the sequence.
     //! \details Might free cached blocks if no free blocks are available.
@@ -869,8 +869,13 @@ class BlockManager
     void addSequence(GenerationRequest& sequence, SizeType32 inputLength, SizeType32 numContextBlocks,
         LlmRequest& llmRequest, SizeType32 windowSize);
 
+    //! \brief Assign blocks for a new sequence.
+    //! \param sequence The GenerationRequest to process.
+    //! \param numContextBlocks Number of context blocks to allocate.
+    //! \param windowSize Attention window size.
+    //! \param isShareLastContextBlock If true, the last context block is shared among beams.
     void addSequence(
-        GenerationRequest& sequence, SizeType32 numBlocks, SizeType32 unsharedBlockIdx, SizeType32 windowSize);
+        GenerationRequest& sequence, SizeType32 numContextBlocks, SizeType32 windowSize, bool isShareLastContextBlock);
 
     void allocateBlock(GenerationRequest& sequence, SizeType32 windowSize);
 
diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp
index 93882427615..0b793a041aa 100644
--- a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp
+++ b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp
@@ -1196,25 +1196,26 @@ void BlockManager::addSequence(
-    GenerationRequest& sequence, SizeType32 numBlocks, SizeType32 unsharedBlockIdx, SizeType32 windowSize)
+    GenerationRequest& sequence, SizeType32 numContextBlocks, SizeType32 windowSize, bool isShareLastContextBlock)
 {
-    mWindowBlockManagers.at(windowSize).addSequence(sequence, numBlocks, unsharedBlockIdx);
+    mWindowBlockManagers.at(windowSize).addSequence(sequence, numContextBlocks, isShareLastContextBlock);
 }
 
 // There are two versions of the WindowBlockManager::addSequence function.
 // This is called when block reuse is disabled.
-void WindowBlockManager::addSequence(GenerationRequest& sequence, SizeType32 numBlocks, SizeType32 unsharedBlockIdx)
+void WindowBlockManager::addSequence(
+    GenerationRequest& sequence, SizeType32 numContextBlocks, bool isShareLastContextBlock)
 {
     auto const requestId = sequence.getRequestId();
     auto const [seqIt, emplaceDone] = mAllocatedBlocksPerSeq.emplace(requestId, std::vector<BlockPtr>{});
     TLLM_CHECK(emplaceDone);
 
-    // Allocate blocks
-    for (SizeType32 bi = 0; bi < numBlocks; ++bi)
+    TLLM_CHECK_WITH_INFO(numContextBlocks > 0, "numContextBlocks must be greater than 0");
+    for (SizeType32 bi = 0; bi < numContextBlocks - 1; ++bi)
     {
-        bool shareAmongBeams = bi != unsharedBlockIdx;
-        allocateBlock(sequence, shareAmongBeams);
+        allocateBlock(sequence, /*shareAmongBeams=*/true);
     }
+    allocateBlock(sequence, /*shareAmongBeams=*/isShareLastContextBlock);
 }
 
 void WindowBlockManager::addBlockToBeam(BlockPtr& block, GenerationRequest& sequence, SizeType32 beamIdx)
@@ -1761,7 +1762,6 @@ SizeType32 KVCacheManager::getNeededBlocksOneStep(
     {
         auto const maxTokensToAddToKVCache = req.mMaxNewTokens;
         auto const maxDraftTokensToAdd = std::min(req.getNumDraftTokens(), maxTokensToAddToKVCache);
-        // Assumes shared among beam = True
         auto const promptCacheLen
             = std::min((isCrossKv() ? req.getEncoderOutputLen() : req.mPromptLen) + maxDraftTokensToAdd, windowSize)
             + mSinkBubbleLength;
@@ -1936,9 +1936,7 @@ std::optional<BlockKey> KVCacheManager::findNewContextBlock(
 void KVCacheManager::addSequence(
     RequestIdType requestId, SizeType32 inputLength, SizeType32 beamWidth, OptionalRef<LlmRequest> llmRequest)
 {
-    // Need to add the bubble after the sink tokens to use even block size
-    inputLength += mSinkBubbleLength;
-
+    // TODO: add streamLLM support
     auto kvCacheRetentionConfig = llmRequest
         ? llmRequest->getKvCacheRetentionConfig().value_or(executor::KvCacheRetentionConfig())
         : executor::KvCacheRetentionConfig();
@@ -1963,20 +1961,6 @@
         auto const maxTokenNum = metadata.maxTokenNum;
         auto const temporaryAttentionWindow = metadata.temporaryAttentionWindow;
 
-        // Get the final token index in kv cache
-        SizeType32 const finalTokenKVIdx = mSinkBlockTokenLength
-            + ((inputLength - 1 - mSinkBlockTokenLength) % (maxTokenNum - mSinkBlockTokenLength));
-
-        // Get block index that with shareAmongBeams=False.
-        // For cross kv cache in encoder-decoder models, always shareAmongBeams=True.
-        SizeType32 unsharedBlockIdx = -1;
-        if ((!sequence.isCyclic() || beamWidth > 1 || finalTokenKVIdx % getTokensPerBlock() > 0) && !isCrossKv())
-        {
-            unsharedBlockIdx = ((finalTokenKVIdx + 1) % getTokensPerBlock() == 0)
-                ? finalTokenKVIdx / getTokensPerBlock() + 1
-                : finalTokenKVIdx / getTokensPerBlock();
-        }
-
         // Consider the temporaryAttentionWindow when allocating blocks.
         auto const effectiveInputLength = std::min(inputLength, maxTokenNum + temporaryAttentionWindow);
         auto const numContextBlocks = tc::ceilDiv(effectiveInputLength, getTokensPerBlock());
@@ -1995,7 +1979,9 @@
                 "have no effect.",
                 llmRequest->mRequestId);
         }
-        mBlockManager.addSequence(sequence, numContextBlocks, unsharedBlockIdx, windowSize);
+        bool isShareLastContextBlock = isCrossKv() || (sequence.isCyclic() && beamWidth == 1)
+            || effectiveInputLength % getTokensPerBlock() == 0;
+        mBlockManager.addSequence(sequence, numContextBlocks, windowSize, isShareLastContextBlock);
     }
     mBlockManager.updateSequenceCacheBlockOffsets(sequence, windowSize);
 }
diff --git a/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp b/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp
index 514575600d9..a52cca097a3 100644
--- a/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp
+++ b/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp
@@ -136,7 +136,7 @@ TEST_F(KVCacheManagerTest, BlockManagerTest)
     auto constexpr requestId = 42;
     GenerationRequest seq0{requestId, numTokens, beamWidth, blockManager.getWindowSizesMetadata()};
 
-    blockManager.addSequence(seq0, numBlocksPerBeam, numBlocksPerBeam - 1, maxAttentionWindow);
+    blockManager.addSequence(seq0, numBlocksPerBeam, maxAttentionWindow, /*isShareLastContextBlock=*/false);
     auto constexpr occupiedBlocks = (numBlocksPerBeam - 1) + beamWidth;
     EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool - occupiedBlocks);
     auto const& ids = seq0.getCacheBlockIds(maxAttentionWindow);
@@ -151,7 +151,7 @@ TEST_F(KVCacheManagerTest, BlockManagerTest)
     blockManager.releaseBlocks(seq0);
     EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool);
 
-    blockManager.addSequence(seq0, numBlocksPerBeam, -1, maxAttentionWindow);
+    blockManager.addSequence(seq0, numBlocksPerBeam, maxAttentionWindow, /*isShareLastContextBlock=*/true);
     EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool - numBlocksPerBeam);
     EXPECT_EQ(ids.size(), beamWidth);
     for (std::size_t i = 0u; i < ids.front().size(); ++i)
@@ -165,17 +165,21 @@ TEST_F(KVCacheManagerTest, BlockManagerTest)
     EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool);
 
     // occupy 22/24 blocks
-    EXPECT_NO_THROW(blockManager.addSequence(seq0, numBlocksPerBeam, numBlocksPerBeam - 1, maxAttentionWindow));
+    EXPECT_NO_THROW(
+        blockManager.addSequence(seq0, numBlocksPerBeam, maxAttentionWindow, /*isShareLastContextBlock=*/false));
     GenerationRequest seq1{requestId + 1, numTokens, beamWidth, blockManager.getWindowSizesMetadata()};
-    EXPECT_NO_THROW(blockManager.addSequence(seq1, numBlocksPerBeam, numBlocksPerBeam - 1, maxAttentionWindow));
+    EXPECT_NO_THROW(
+        blockManager.addSequence(seq1, numBlocksPerBeam, maxAttentionWindow, /*isShareLastContextBlock=*/false));
     // same requestId not allowed
     GenerationRequest seq2{requestId, numTokens, beamWidth, blockManager.getWindowSizesMetadata()};
     EXPECT_THROW(
-        blockManager.addSequence(seq2, numBlocksPerBeam, numBlocksPerBeam - 1, maxAttentionWindow), std::runtime_error);
+        blockManager.addSequence(seq2, numBlocksPerBeam, maxAttentionWindow, /*isShareLastContextBlock=*/false),
+        std::runtime_error);
     // no more blocks
     GenerationRequest seq3{requestId + 2, numTokens, beamWidth, blockManager.getWindowSizesMetadata()};
     EXPECT_THROW(
-        blockManager.addSequence(seq3, numBlocksPerBeam, numBlocksPerBeam - 1, maxAttentionWindow), std::runtime_error);
+        blockManager.addSequence(seq3, numBlocksPerBeam, maxAttentionWindow, /*isShareLastContextBlock=*/false),
+        std::runtime_error);
 }

From 7e56c2cbfe02b38d4dd8a1ad0a228a1e8414b172 Mon Sep 17 00:00:00 2001
From: eopXD
Date: Thu, 14 Aug 2025 23:53:58 -0700
Subject: [PATCH 4/4] [batch manager][unit test] No functional change intended,
 remove kv cache block allocation on token addition for cross kv manager

For an encoder-decoder model, generation tokens are added only to the decoder
and not the encoder.

Signed-off-by: eopXD
---
 cpp/tests/unit_tests/batch_manager/capacitySchedulerTest.cpp | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/cpp/tests/unit_tests/batch_manager/capacitySchedulerTest.cpp b/cpp/tests/unit_tests/batch_manager/capacitySchedulerTest.cpp
index 24125c267d4..0942b716c15 100644
--- a/cpp/tests/unit_tests/batch_manager/capacitySchedulerTest.cpp
+++ b/cpp/tests/unit_tests/batch_manager/capacitySchedulerTest.cpp
@@ -408,10 +408,6 @@ int runTest(CapacityScheduler& capacityScheduler,
         else
         {
             kvCacheManager->addToken(llmReq->mRequestId);
-            if (crossKvCacheManager)
-            {
-                crossKvCacheManager->addToken(llmReq->mRequestId);
-            }
             llmReq->addNewTokens({itCount});
         }
         if (llmReq->getNumTokens(0) == promptLen + llmReq->mMaxNewTokens)