diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h index cc227e75ca1..d97b87086f5 100644 --- a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h +++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h @@ -57,6 +57,10 @@ static constexpr SizeType32 kPrimaryLevel = 0; static constexpr SizeType32 kSecondaryLevel = 1; +// Extra block buffer allocated for SWA to be able to always keep "window size" +// tokens held in the blocks. +static constexpr SizeType32 kSWAExtraBlock = 1; + class KVCacheBlock; class BlockManager; class KVCacheManager; @@ -93,8 +97,8 @@ struct WindowSizeMetadata SizeType32 allottedSecondaryBlocks; // Number of secondary blocks allotted to the windowSize SizeType32 absolutePoolsOffset; // cumulative number of pools up to manager SizeType32 numPools; // number of managed pools - SizeType32 maxTokenNum; // Maximum token length (including bubble) - SizeType32 maxBlocksPerSeq; + SizeType32 maxTokenNum; // Maximum token length per sequence (TODO: account for streamLLM) + SizeType32 maxBlocksPerSeq; // Maximum number of blocks per sequence SizeType32 maxNumBlocks; // Number of primary+secondary blocks allotted to the windowSize SizeType32 temporaryAttentionWindow; // Temporary kv cache length per sequence. // Only needed when chunked context + sliding window attention are used @@ -344,14 +348,7 @@ class GenerationRequest , mNumTokens(numTokens) , mBeamWidth(beamWidth) , mKvCacheRetentionConfig(std::move(kvCacheRetentionConfig)) - // min window size + sink bubble length - // Why use the minimum window size: - // Chunked Prefill + Reuse calls `setPrepopulatedPromptLen()` which sets - // `mContextCurrentPosition` - this cannot be done for some windows sizes and - // not for others, the state needs to remain identical for all window sizes. So - // we currently resort to strictly disabling the reuse code path for all window - // sizes at once or enable it for all window sizes at once. - , mCyclicThreshold(windowSizeToMetadata.cbegin()->second.maxTokenNum) + , mNumFrontBlocksRemoved(0) { auto const numWindowSizes = windowSizeToMetadata.size(); mCacheBlockIds.reserve(numWindowSizes); @@ -394,6 +391,11 @@ class GenerationRequest return mNumTokens; } + [[nodiscard]] SizeType32 getNumFrontBlocksRemoved() const + { + return mNumFrontBlocksRemoved; + } + [[nodiscard]] SizeType32 getBeamWidth() const { return mBeamWidth; @@ -431,6 +433,12 @@ class GenerationRequest { beamBlockIds.clear(); } + mNumFrontBlocksRemoved = 0; + } + + void removeFrontBlock(SizeType32 windowSize) + { + ++mNumFrontBlocksRemoved; } void removeLastBlock(SizeType32 windowSize) @@ -461,14 +469,6 @@ class GenerationRequest return mKvCacheRetentionConfig.getDirectory(); } - // @brief Check whether the sequence uses cyclic KV cache. - // @return `true` if we have begun overwriting the beginning of the sequence's KV cache. - // @details If `true`, we cannot store the sequence's KV cache for reuse. - [[nodiscard]] bool isCyclic() const - { - return mNumTokens >= mCyclicThreshold; - } - private: // Request id of the sequence LlmRequest::RequestIdType mRequestId; @@ -482,9 +482,8 @@ class GenerationRequest std::unordered_map mCacheBlockIndices; // The retention priority to assign to decode blocks executor::KvCacheRetentionConfig mKvCacheRetentionConfig; - - // Number of tokens at which the KV Cache begins sliding [for the minimum attention window] - SizeType32 mCyclicThreshold; + // Number of front blocks removed from the sequence + SizeType32 mNumFrontBlocksRemoved; }; // attach metadata to a pool pointer @@ -550,7 +549,7 @@ class WindowBlockManager explicit WindowBlockManager(nvinfer1::DataType dtype, SizeType32 windowSize, std::vector const& managedLayers, std::vector const& numKvHeadsPerLayer, - SizeType32 sizePerHead, SizeType32 tokensPerBlock, SizeType32 blocksInPrimaryPool, + SizeType32 sizePerHead, SizeType32 tokensPerBlock, bool isSWA, SizeType32 blocksInPrimaryPool, SizeType32 blocksInSecondaryPool, SizeType32 maxNumSequences, std::shared_ptr stream, bool onboardBlocks, CacheType cacheType, std::optional secondaryOffloadMinPriority, std::shared_ptr eventManager, bool enablePartialReuse, bool copyOnPartialReuse, @@ -581,19 +580,32 @@ class WindowBlockManager //! \brief Get the ids of all newly allocated (not reused) blocks for the sequence. std::vector getNewlyAllocatedBlockIds(GenerationRequest const& sequence) const; - void storeBlocksForReuse(GenerationRequest& sequence, OptionalRef llmRequest); - void storeNewBlock(GenerationRequest& sequence, OptionalRef llmRequest); //! \brief Release blocks of the sequence. - void releaseBlocks(GenerationRequest& sequence); + //! \details When llmRequest is provided and reuse is enabled, blocks will be stored. + void releaseBlocks(GenerationRequest& sequence, OptionalRef llmRequest); //! \brief Simulate freeing all blocks for that sequence to check impact on number of free blocks void schedulingReleaseBlocks(LlmRequest::RequestIdType requestId); + //! \brief Update cache offsets for last block + void updateLastCacheBlockOffsets(GenerationRequest& seq); + //! \brief Release last block in the sequence void releaseLastBlock(GenerationRequest& sequence); + //! \brief Detach front block from the sequence + void detachFrontBlock(GenerationRequest& sequence, bool isEnableBlockReuse); + + //! \brief Add/detach block(s) to/from the sequence if needed + //! \details When we need a new block, we add it. For sliding window + //! attention (SWA), when a block goes out-of-window (OOW), we detach it + //! and store it if reuse is enabled. If this called in the first step of + //! the generation phase, we may detach more than a single block since + //! there may be more than one context block that goes OOW. + void adjustBlocksIfNeeded(GenerationRequest& sequence, bool isEnableBlockReuse); + [[nodiscard]] SizeType32 getWindowSize() const noexcept { return mWindowSize; @@ -745,7 +757,8 @@ class WindowBlockManager //! \brief Store blocks in cached blocks. //! \param blockKeys Key of each block. //! \param blockIds Id of each block. - void storeBlocks(std::vector const& blockKeys, std::vector const& blockIds); + //! \return Number of actual blocks stored. + SizeType32 storeBlocks(std::vector const& blockKeys, std::vector const& blockIds); [[nodiscard]] bool verifyQueueIntegrity(); @@ -767,6 +780,12 @@ class WindowBlockManager return 0; } + //! \brief Return whether this window is SWA. + [[nodiscard]] bool isSWA() const + { + return mIsSWA; + } + private: //! \brief Add single block to beam of sequence and mAllocatedBlocksPerSeq. void addBlockToBeam(BlockPtr& block, GenerationRequest& sequence, SizeType32 beamIdx); @@ -828,6 +847,8 @@ class WindowBlockManager SizeType32 mSchedulingNumFreeBlocks; // Number of tokens per one block SizeType32 mTokensPerBlock; + // Whether this window is sliding window attention/full attention + bool mIsSWA; // List of all blocks by idx std::vector mAllBlocksById; // Dummy block acting as root for BlockToken searches @@ -880,7 +901,7 @@ class BlockManager explicit BlockManager(std::vector const& numKvHeadsPerLayer, SizeType32 sizePerHead, SizeType32 tokensPerBlock, BlocksPerWindow const& blocksPerWindow, SizeType32 maxNumSequences, - CudaStreamPtr stream, std::optional maxSequenceLength, SizeType32 maxBeamWidth, + CudaStreamPtr stream, SizeType32 maxSequenceLength, SizeType32 maxBeamWidth, std::vector const& maxAttentionWindowVec, std::optional const& tempAttentionWindowInputs, nvinfer1::DataType dtype, SizeType32 sinkBubbleLength, bool onboardBlocks, CacheType cacheType = CacheType::kSELF, @@ -1128,14 +1149,6 @@ class BlockManager //! \brief Store newest block for reuse void storeNewBlock(GenerationRequest& sequence, OptionalRef llmRequest); - [[nodiscard]] static bool isUseOneMoreBlock( - SizeType32 windowSize, std::optional maxSequenceLength, SizeType32 maxBeamWidth) - { - bool const isCyclicWindowSize = maxSequenceLength.has_value() && maxSequenceLength.value() > windowSize; - bool const isBeamSearch = maxBeamWidth > 1; - return isCyclicWindowSize && isBeamSearch; - } - //! \brief Perform per-request bookkeeping void refreshBlocks(); @@ -1154,12 +1167,17 @@ class BlockManager //! \brief Update cache offsets for blocks initiated from sequence void updateSequenceCacheBlockOffsets(GenerationRequest& seq, SizeType32 windowSize); - //! \brief Update cache offsets for last block - void updateLastCacheBlockOffsets(GenerationRequest& seq, SizeType32 windowSize); - //! \brief Update cache offsets for block at index void updateCacheBlockOffsetsAtIdx(GenerationRequest& seq, SizeType32 windowSize, SizeType32 blockIdx); + //! \brief Add/detach block(s) to/from the sequence if needed + //! \details When we need a new block, we add it. For sliding window + //! attention (SWA), when a block goes out-of-window (OOW), we detach it + //! and store it if reuse is enabled. If this called in the first step of + //! the generation phase, we may detach more than a single block since + //! there may be more than one context block that goes OOW. + void adjustBlocksIfNeeded(GenerationRequest& sequence, bool isEnableBlockReuse); + private: [[nodiscard]] WindowBlockManager const& windowManagerByLayer(SizeType32 layerIdx) const { @@ -1411,8 +1429,8 @@ class KVCacheManager : public BaseKVCacheManager BlocksPerWindow const& blocksPerWindow, SizeType32 maxNumSequences, SizeType32 maxBeamWidth, std::vector const& maxAttentionWindowVec, std::optional const& tempAttentionWindowInputs, nvinfer1::DataType dtype, - SizeType32 sinkTokenLength, CudaStreamPtr stream, std::optional maxSequenceLength, - bool enableBlockReuse = false, bool onboardBlocks = true, CacheType cacheType = CacheType::kSELF, + SizeType32 sinkTokenLength, CudaStreamPtr stream, SizeType32 maxSequenceLength, bool enableBlockReuse = false, + bool onboardBlocks = true, CacheType cacheType = CacheType::kSELF, std::optional secondaryOffloadMinPriority = std::nullopt, std::shared_ptr eventManager = nullptr, bool enablePartialReuse = true, bool copyOnpartialReuse = true, @@ -1422,8 +1440,8 @@ class KVCacheManager : public BaseKVCacheManager BlocksPerWindow const& blocksPerWindow, SizeType32 maxNumSequences, SizeType32 maxBeamWidth, std::vector const& maxAttentionWindowVec, std::optional const& tempAttentionWindowInputs, nvinfer1::DataType dtype, - SizeType32 sinkTokenLength, int64_t stream, std::optional maxSequenceLength, - bool enableBlockReuse = false, bool onboardBlocks = true, CacheType cacheType = CacheType::kSELF, + SizeType32 sinkTokenLength, int64_t stream, SizeType32 maxSequenceLength, bool enableBlockReuse = false, + bool onboardBlocks = true, CacheType cacheType = CacheType::kSELF, std::optional secondaryOffloadMinPriority = std::nullopt, std::shared_ptr eventManager = nullptr, bool enablePartialReuse = true, bool copyOnpartialReuse = true, @@ -1433,8 +1451,8 @@ class KVCacheManager : public BaseKVCacheManager BlocksPerWindow const& blocksPerWindow, SizeType32 maxNumSequences, SizeType32 maxBeamWidth, std::vector const& maxAttentionWindowVec, std::optional const& tempAttentionWindowInputs, nvinfer1::DataType dtype, - SizeType32 sinkTokenLength, CudaStreamPtr stream, std::optional maxSequenceLength, - bool enableBlockReuse = true, bool onboardBlocks = true, CacheType cacheType = CacheType::kSELF, + SizeType32 sinkTokenLength, CudaStreamPtr stream, SizeType32 maxSequenceLength, bool enableBlockReuse = true, + bool onboardBlocks = true, CacheType cacheType = CacheType::kSELF, std::optional secondaryOffloadMinPriority = std::nullopt, std::shared_ptr eventManager = nullptr, bool enablePartialReuse = true, bool copyOnpartialReuse = true, @@ -1444,9 +1462,9 @@ class KVCacheManager : public BaseKVCacheManager BlocksPerWindow const& blocksPerWindow, SizeType32 maxNumSequences, SizeType32 maxBeamWidth, std::vector const& maxAttentionWindowVec, std::optional const& tempAttentionWindowInputs, nvinfer1::DataType dtype, - SizeType32 sinkTokenLength, int64_t stream, std::optional maxSequenceLength, - bool enableBlockReuse = false, bool onboardBlocks = true, CacheType cacheType = CacheType::kSELF, - bool enablePartialReuse = true, bool copyOnpartialReuse = true); + SizeType32 sinkTokenLength, int64_t stream, SizeType32 maxSequenceLength, bool enableBlockReuse = false, + bool onboardBlocks = true, CacheType cacheType = CacheType::kSELF, bool enablePartialReuse = true, + bool copyOnpartialReuse = true); ~KVCacheManager() override = default; diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp index 81a4746467d..d41a373adfa 100644 --- a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp @@ -157,6 +157,39 @@ std::vector buildBlockKeys( return blockKeys; } +//! \brief Get all blocks in a sequence by traversing backwards from the last block. +//! \param lastBlock is a BlockPtr to the last block in the sequence to start traversal from +//! \return Vector of BlockPtr-s in sequence order +std::vector getAllSequenceBlocks(BlockPtr lastBlock) +{ + // First count the number of blocks to pre-allocate the vector + auto currentBlock = lastBlock; + size_t blockCount = 0; + while (currentBlock != nullptr && currentBlock->getBlockId() != KVCacheBlock::kCachedBlocksRootId) + { + blockCount++; + currentBlock = currentBlock->getPrevBlockInSeq(); + } + + if (blockCount == 0) + { + return {}; + } + // Create and pre-allocate the vector with the correct size + std::vector sequenceBlocks(blockCount); + + // Now traverse backwards and fill from the end + currentBlock = lastBlock; + size_t currentIndex = blockCount - 1; + while (currentBlock != nullptr && currentBlock->getBlockId() != KVCacheBlock::kCachedBlocksRootId) + { + sequenceBlocks[currentIndex--] = currentBlock; + currentBlock = currentBlock->getPrevBlockInSeq(); + } + + return sequenceBlocks; +} + } // namespace namespace tensorrt_llm::batch_manager::kv_cache_manager @@ -468,6 +501,19 @@ bool KVCacheBlock::isLeaf() const return mNextBlocks.empty(); } +// This function calculates the number of block a layer should have, given +// the total free memory and the window size of each layer. +// For example, if we have 1 layer of window size 1024, and 2 layer of window +// size 2048, and 3 layers of 4096. +// Each layer of window size 1024 should have +// 1024 / (1024 + 2048 * 2 + 4096 * 3) proportion of the total blocks. +// Each layer of window size 2048 should have +// 2048 / (1024 + 2048 * 2 + 4096 * 3) proportion of the total blocks. +// Each layer of window size 4096 should have +// 4096 / (1024 + 2048 * 2 + 4096 * 3) proportion of the total blocks. +// NOTE: Currently the use of this function is not used for +// BaseKVCacheManager::calculateMaxNumBlocks because the we want to first +// achieve identical performance as assuming all layers as full attention. std::map BlockManager::calculateWindowSizeToShare( std::map> const& windowSizeToLayers, std::map const& windowSizeToCacheSizePerToken) @@ -510,7 +556,7 @@ std::map BlockManager::calculateWindowSizeToShare( BlockManager::BlockManager(std::vector const& numKvHeadsPerLayer, SizeType32 sizePerHead, SizeType32 tokensPerBlock, BlocksPerWindow const& blocksPerWindow, SizeType32 maxNumSequences, - std::shared_ptr stream, std::optional maxSequenceLength, SizeType32 maxBeamWidth, + std::shared_ptr stream, SizeType32 maxSequenceLength, SizeType32 maxBeamWidth, std::vector const& maxAttentionWindowVec, std::optional const& tempAttentionWindowInputs, nvinfer1::DataType dtype, SizeType32 sinkBubbleLength, bool onboardBlocks, CacheType cacheType, @@ -543,6 +589,11 @@ BlockManager::BlockManager(std::vector const& numKvHeadsPerLayer, Si mLayerToWindowSize.resize(mNumLayers); for (auto const& [windowSize, layersWithWindowSize] : uniqueWindowSizeToLayers) { + if (windowSize > maxSequenceLength) + { + TLLM_LOG_WARNING("[kv cache manager] window size %d is greater than max sequence length %d", windowSize, + maxSequenceLength); + } for (auto& layerIdx : layersWithWindowSize) { mLayerToWindowSize.at(layerIdx) = windowSize; @@ -550,9 +601,9 @@ BlockManager::BlockManager(std::vector const& numKvHeadsPerLayer, Si auto const [allottedPrimaryBlocks, allottedSecondaryBlocks] = blocksPerWindow.at(windowSize); TLLM_CHECK(allottedPrimaryBlocks > 0); // You can't have a model with negative primary blocks... mWindowBlockManagers.try_emplace(windowSize, dtype, windowSize, layersWithWindowSize, numKvHeadsPerLayer, - sizePerHead, tokensPerBlock, allottedPrimaryBlocks, allottedSecondaryBlocks, maxNumSequences, stream, - onboardBlocks, cacheType, secondaryOffloadMinPriority, mEventManager, enablePartialReuse, - copyOnPartialReuse, kvCacheConnectorManager, mLoopbackAgent); + sizePerHead, tokensPerBlock, /*isSWA=*/windowSize < maxSequenceLength, allottedPrimaryBlocks, + allottedSecondaryBlocks, maxNumSequences, stream, onboardBlocks, cacheType, secondaryOffloadMinPriority, + mEventManager, enablePartialReuse, copyOnPartialReuse, kvCacheConnectorManager, mLoopbackAgent); } auto const numAllPools = getNumPools(); @@ -567,15 +618,27 @@ BlockManager::BlockManager(std::vector const& numKvHeadsPerLayer, Si mAbsolutePoolToWindowSize.push_back(windowSize); mAbsolutePoolToRelativePoolIndex.push_back(i); } - auto const maxTokenNum = windowSize + sinkBubbleLength - + (isUseOneMoreBlock(windowSize, maxSequenceLength, maxBeamWidth) ? tokensPerBlock : 0); + // (eop) SWA allocates blocks linearly, and we need as many blocks as full attention, + // where full attention has windowSize = maxSequenceLength. + auto const maxTokenNum = std::max(windowSize, maxSequenceLength) + sinkBubbleLength; auto const temporaryAttentionWindow = manager.calculateTemporaryAttentionWindow(tempAttentionWindowInputs); // Consider the temporaryAttentionWindow when allocating blocks. - auto const maxBlocksPerSeq = tc::ceilDiv(maxTokenNum + temporaryAttentionWindow, tokensPerBlock); + // (eop) Current tempAttentionWindow calculation does not consider the + // concept of SWA right now at most occupying maxSequenceLength of + // blocks. So the calculation of maxToken + tempAttention will exceed + // maxSequenceLength. A temporary resolution here is to cap the + // calculation to maxSequenceLength. I will proceed with a follow-up + // MR to remove the tempAttentionWindow concept. + auto const maxBlocksPerSeq + = tc::ceilDiv(std::min(maxSequenceLength, maxTokenNum + temporaryAttentionWindow), tokensPerBlock); auto const [allottedPrimaryBlocks, allottedSecondaryBlocks] = blocksPerWindow.at(windowSize); mWindowSizeToMetadata[windowSize] = WindowSizeMetadata{allottedPrimaryBlocks, allottedSecondaryBlocks, absolutePoolsOffset, numPools, maxTokenNum, maxBlocksPerSeq, manager.getMaxNumBlocks(), temporaryAttentionWindow}; + TLLM_LOG_INFO( + "Max KV cache blocks per sequence: %d [window size=%d], tokens per block=%d, primary blocks=%d, secondary " + "blocks=%d", + maxBlocksPerSeq, windowSize, tokensPerBlock, allottedPrimaryBlocks, allottedSecondaryBlocks); TLLM_LOG_DEBUG( "%s Metadata: %s", manager.getLogPrefix().c_str(), mWindowSizeToMetadata[windowSize].toString().c_str()); absolutePoolsOffset += numPools; @@ -591,9 +654,9 @@ BlockManager::BlockManager(std::vector const& numKvHeadsPerLayer, Si WindowBlockManager::WindowBlockManager(nvinfer1::DataType dtype, SizeType32 windowSize, std::vector const& managedLayers, std::vector const& numKvHeadsPerLayer, - SizeType32 sizePerHead, SizeType32 tokensPerBlock, SizeType32 blocksInPrimaryPool, SizeType32 blocksInSecondaryPool, - SizeType32 maxNumSequences, std::shared_ptr stream, bool onboardBlocks, CacheType cacheType, - std::optional secondaryOffloadMinPriority, + SizeType32 sizePerHead, SizeType32 tokensPerBlock, bool isSWA, SizeType32 blocksInPrimaryPool, + SizeType32 blocksInSecondaryPool, SizeType32 maxNumSequences, std::shared_ptr stream, + bool onboardBlocks, CacheType cacheType, std::optional secondaryOffloadMinPriority, std::shared_ptr eventManager, bool enablePartialReuse, bool copyOnPartialReuse, std::shared_ptr kvCacheConnectorManager, std::shared_ptr loopbackAgent) @@ -605,6 +668,7 @@ WindowBlockManager::WindowBlockManager(nvinfer1::DataType dtype, SizeType32 wind , mBufferManager{std::move(stream)} , mSchedulingNumFreeBlocks{0} , mTokensPerBlock{tokensPerBlock} + , mIsSWA{isSWA} , mCachedBlocksRoot{std::make_shared(KVCacheBlock::kCachedBlocksRootId, tk::KVCacheIndex{0})} , mCacheType{cacheType} , mEventManager(std::move(eventManager)) @@ -986,6 +1050,12 @@ void BlockManager::offloadBlock( void WindowBlockManager::offloadBlock( BlockPtr const& block, executor::KvCacheTransferMode mode, std::string const& directory) { + // The current default behavior is to offload the out-of-window block + // to secondary block pool to allow more free primary blocks for reuse. + // However, such behavior does not take account whether the offloaded + // block is useful or not and may just lead to more traffic instead. + // The ideal way of this is to dedicate the offloading of the block + // to the eviction policy. if (mOnboardBlocks && block->isPrimary()) { // Offload block in primary memory before repurposing @@ -1252,6 +1322,35 @@ void WindowBlockManager::addSequence( llmRequest.mRequestId, inputLength, prepopulatedPromptLen, numConnectorMatchedTokens); } +void BlockManager::adjustBlocksIfNeeded(GenerationRequest& sequence, bool isEnableBlockReuse) +{ + for (auto& [windowSize, manager] : mWindowBlockManagers) + { + mWindowBlockManagers.at(windowSize).adjustBlocksIfNeeded(sequence, isEnableBlockReuse); + } +} + +void WindowBlockManager::adjustBlocksIfNeeded(GenerationRequest& sequence, bool isEnableBlockReuse) +{ + auto const minTokensForBlockDetach = mWindowSize + mTokensPerBlock; + while ( + sequence.getNumTokens() - sequence.getNumFrontBlocksRemoved() * getTokensPerBlock() >= minTokensForBlockDetach + && !isEnableBlockReuse) + { + // Detaching block for SWA is non-trivial due to the radix tree structure. + // For now, when reuse is enabled, we do not detach blocks for SWA. + TLLM_CHECK_WITH_INFO(mIsSWA, "A block only go out-of-window in SWA"); + detachFrontBlock(sequence, isEnableBlockReuse); + } + + if ((sequence.getNumTokens() - 1) % getTokensPerBlock() == 0) + { + // Allocating a new block when the last token is a block boundary + allocateBlock(sequence, /*shareAmongBeams=*/sequence.getBeamWidth() == 1); + updateLastCacheBlockOffsets(sequence); + } +} + // There are two versions of BlockManager::addSequence function. // This is called when block reuse is disabled. void BlockManager::addSequence( @@ -1344,9 +1443,10 @@ void WindowBlockManager::allocateBlock(GenerationRequest& sequence, bool shareAm } } -void WindowBlockManager::storeBlocks( +SizeType32 WindowBlockManager::storeBlocks( std::vector const& blockKeys, std::vector const& blockIds) { + SizeType32 numBlocksStoredForReuse = 0; TLLM_LOG_DEBUG( "%s::storeBlocks - %zu blockKeys, %zu blockIds", mLogPrefix.c_str(), blockKeys.size(), blockIds.size()); @@ -1398,12 +1498,14 @@ void WindowBlockManager::storeBlocks( block->setHash(newHash); } searchRoot = block; + numBlocksStoredForReuse++; } } if (mEventManager) { mEventManager->enqueueStoredEvent(storedBlocks, mWindowSize); } + return numBlocksStoredForReuse; } void BlockManager::replaceSharedBlock(GenerationRequest& sequence, SizeType32 windowSize, SizeType32 blockIdx) @@ -1516,44 +1618,29 @@ std::deque BlockManager::getLatestEvents(std::optional llmRequest) { - // When releasing the blocks for a sequence, we store those blocks for potential reuse only if: - // - Block reuse is enabled. - // - A request was provided to this function call to identify which tokens these blocks cover - // - Beam search is NOT enabled <=> beam width == 1 - // - The sequence was not marked for use with cyclic kv-cache when it was added (when its context is too long to fit - // the max attention window). - // - The sequence did not switch to cyclic kv-cache during generation phase. - // A sequence is cyclic if its *minimum window size* is crossed, even if other window sizes were not reached. - // - The sequence is not a dummy request. - bool const storeBlocksForReuse = sequence.getBeamWidth() == 1 && llmRequest.has_value() && !sequence.isCyclic() - && !llmRequest->isDummyRequest(); + // Released block will be stored when reuse is enabled. + // Reuse is implied to be enabled if llmRequest is provided. for (auto& [_, manager] : mWindowBlockManagers) { - if (storeBlocksForReuse) + if (!llmRequest.has_value() || llmRequest->isDummyRequest() || sequence.getBeamWidth() > 1) { - manager.storeBlocksForReuse(sequence, llmRequest); + manager.releaseBlocks(sequence, std::nullopt); + } + else + { + manager.releaseBlocks(sequence, llmRequest); } - manager.releaseBlocks(sequence); } } void BlockManager::storeNewBlock(GenerationRequest& sequence, OptionalRef llmRequest) { - // we store newest block for potential reuse only if: - // - Block reuse is enabled. - // - A request was provided to this function call to identify which tokens these blocks cover - // - Beam search is NOT enabled <=> beam width == 1 - // - The sequence was not marked for use with cyclic kv-cache when it was added (when its context is too long to fit - // the max attention window). - // - The sequence did not switch to cyclic kv-cache during generation phase. - // A sequence is cyclic if its *minimum window size* is crossed, even if other window sizes were not reached. - bool const storeBlocksForReuse = sequence.getBeamWidth() == 1 && llmRequest.has_value() && !sequence.isCyclic(); - if (!storeBlocksForReuse) - { - return; - } for (auto& [_, manager] : mWindowBlockManagers) { + if (manager.isSWA()) + { + continue; + } manager.storeNewBlock(sequence, llmRequest); } } @@ -1608,33 +1695,46 @@ void WindowBlockManager::storeNewBlock(GenerationRequest& sequence, OptionalRef< storeBlocks(std::move(blockKeys), cacheBlockIds[beamIdx]); } -void WindowBlockManager::storeBlocksForReuse(GenerationRequest& sequence, OptionalRef llmRequest) -{ - auto constexpr beamIdx = 0; - auto const& uniqueTokens = llmRequest->getUniqueTokens(beamIdx); - auto const& cacheBlockIds = sequence.getCacheBlockIds(mWindowSize); - - // TODO: get the caller to mark tokens as filled / not filled, so that the kv-cache manager doesn't - // have to guess. Only (length - 1) tokens of the sequence have their kv-state recorded in kv-cache. We assume - // the last token's state is not filled yet. - auto const usableSize = static_cast(uniqueTokens.size()) - 1; - auto blockedUniqueTokens = chopVectorIntoBlocks(uniqueTokens, usableSize, mTokensPerBlock, true); - auto blockKeys = buildBlockKeys(blockedUniqueTokens, *llmRequest); - storeBlocks(std::move(blockKeys), cacheBlockIds[beamIdx]); -} - -void WindowBlockManager::releaseBlocks(GenerationRequest& sequence) +void WindowBlockManager::releaseBlocks(GenerationRequest& sequence, OptionalRef llmRequest) { auto const requestId = sequence.getRequestId(); auto node = mAllocatedBlocksPerSeq.extract(requestId); TLLM_CHECK(node); auto& allocatedBlocks = node.mapped(); - for (auto it = allocatedBlocks.rbegin(); it != allocatedBlocks.rend(); ++it) + if (mIsSWA) + { + // For SWA, get all blocks in the sequence. + allocatedBlocks = getAllSequenceBlocks(allocatedBlocks.back()); + } + if (llmRequest.has_value()) + { + // If llmRequest is provided, store the blocks for reuse. + auto const& uniqueTokens = llmRequest->getUniqueTokens(/*beamIdx=*/0); + // Only (length - 1) tokens of the sequence have their kv-state + // recorded in kv-cache. We assume the last token's state is not filled yet. + auto const usableSize = static_cast(uniqueTokens.size()) - 1; + auto blockedUniqueTokens + = chopVectorIntoBlocks(uniqueTokens, usableSize, mTokensPerBlock, /*allowPartial=*/true); + auto blockKeys = buildBlockKeys(blockedUniqueTokens, *llmRequest); + + std::vector cacheBlockIds(allocatedBlocks.size()); + std::transform(allocatedBlocks.begin(), allocatedBlocks.end(), cacheBlockIds.begin(), + [](BlockPtr const& block) { return block->getBlockId(); }); + + auto numBlocksStoredForReuse = storeBlocks(std::move(blockKeys), cacheBlockIds); + TLLM_LOG_DEBUG("%s::releaseBlocks Request %lu, %d blocks stored for reuse", mLogPrefix.c_str(), + sequence.getRequestId(), numBlocksStoredForReuse); + } + for (auto it = allocatedBlocks.rbegin(); it != allocatedBlocks.rend() - sequence.getNumFrontBlocksRemoved(); ++it) { auto& block = *it; // Decrease ref count - block->decRefCount(); + if (block->hasRefs()) + { + // An out-of-window block may not have any ref count. + block->decRefCount(); + } // If ref count is zero, move block to free blocks if (!block->hasRefs()) { @@ -1671,8 +1771,8 @@ KVCacheManager::KVCacheManager(SizeType32 numLayers, SizeType32 numKvHeads, Size SizeType32 tokensPerBlock, BlocksPerWindow const& blocksPerWindow, SizeType32 maxNumSequences, SizeType32 maxBeamWidth, std::vector const& maxAttentionWindowVec, std::optional const& tempAttentionWindowInputs, nvinfer1::DataType dtype, - SizeType32 sinkTokenLength, int64_t stream, std::optional maxSequenceLength, - bool enableBlockReuse, bool onboardBlocks, CacheType cacheType, bool enablePartialReuse, bool copyOnPartialReuse) + SizeType32 sinkTokenLength, int64_t stream, runtime::SizeType32 maxSequenceLength, bool enableBlockReuse, + bool onboardBlocks, CacheType cacheType, bool enablePartialReuse, bool copyOnPartialReuse) : KVCacheManager(std::vector(numLayers, numKvHeads), sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, maxBeamWidth, maxAttentionWindowVec, tempAttentionWindowInputs, dtype, sinkTokenLength, std::make_shared(reinterpret_cast(stream)), maxSequenceLength, @@ -1684,9 +1784,8 @@ KVCacheManager::KVCacheManager(std::vector const& numKvHeadsPerLayer SizeType32 tokensPerBlock, BlocksPerWindow const& blocksPerWindow, SizeType32 maxNumSequences, SizeType32 maxBeamWidth, std::vector const& maxAttentionWindowVec, std::optional const& tempAttentionWindowInputs, nvinfer1::DataType dtype, - SizeType32 sinkTokenLength, int64_t stream, std::optional maxSequenceLength, - bool enableBlockReuse, bool onboardBlocks, CacheType cacheType, - std::optional secondaryOffloadMinPriority, + SizeType32 sinkTokenLength, int64_t stream, runtime::SizeType32 maxSequenceLength, bool enableBlockReuse, + bool onboardBlocks, CacheType cacheType, std::optional secondaryOffloadMinPriority, std::shared_ptr eventManager, bool enablePartialReuse, bool copyOnPartialReuse, std::shared_ptr kvCacheConnectorManager) : KVCacheManager(numKvHeadsPerLayer, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, maxBeamWidth, @@ -1701,9 +1800,8 @@ KVCacheManager::KVCacheManager(std::vector const& numKvHeadsPerLayer SizeType32 tokensPerBlock, BlocksPerWindow const& blocksPerWindow, SizeType32 maxNumSequences, SizeType32 maxBeamWidth, std::vector const& maxAttentionWindowVec, std::optional const& tempAttentionWindowInputs, nvinfer1::DataType dtype, - SizeType32 sinkTokenLength, CudaStreamPtr stream, std::optional maxSequenceLength, - bool enableBlockReuse, bool onboardBlocks, CacheType cacheType, - std::optional secondaryOffloadMinPriority, + SizeType32 sinkTokenLength, CudaStreamPtr stream, runtime::SizeType32 maxSequenceLength, bool enableBlockReuse, + bool onboardBlocks, CacheType cacheType, std::optional secondaryOffloadMinPriority, std::shared_ptr eventManager, bool enablePartialReuse, bool copyOnPartialReuse, std::shared_ptr kvCacheConnectorManager) : mMaxBeamWidth(maxBeamWidth) @@ -1736,9 +1834,8 @@ KVCacheManager::KVCacheManager(SizeType32 numLayers, SizeType32 numKvHeads, Size SizeType32 tokensPerBlock, BlocksPerWindow const& blocksPerWindow, SizeType32 maxNumSequences, SizeType32 maxBeamWidth, std::vector const& maxAttentionWindowVec, std::optional const& tempAttentionWindowInputs, nvinfer1::DataType dtype, - SizeType32 sinkTokenLength, CudaStreamPtr stream, std::optional maxSequenceLength, - bool enableBlockReuse, bool onboardBlocks, CacheType cacheType, - std::optional secondaryOffloadMinPriority, + SizeType32 sinkTokenLength, CudaStreamPtr stream, runtime::SizeType32 maxSequenceLength, bool enableBlockReuse, + bool onboardBlocks, CacheType cacheType, std::optional secondaryOffloadMinPriority, std::shared_ptr eventManager, bool enablePartialReuse, bool copyOnPartialReuse, std::shared_ptr kvCacheConnectorManager) : KVCacheManager(std::vector(numLayers, numKvHeads), sizePerHead, tokensPerBlock, blocksPerWindow, @@ -1834,10 +1931,11 @@ SizeType32 KVCacheManager::getNeededBlocksOneStep( { if ((req.isContextInitState() && req.isFirstContextChunk()) || req.isDisaggGenerationInitState()) { - auto const maxTokensToAddToKVCache = req.mMaxNewTokens; - auto const maxDraftTokensToAdd = std::min(req.getNumDraftTokens(), maxTokensToAddToKVCache); + auto const chunkSize = req.mMaxNewTokens; + auto const maxDraftTokensToAdd = req.getNumDraftTokens(); auto const promptCacheLen - = std::min((isCrossKv() ? req.getEncoderOutputLen() : req.mPromptLen) + maxDraftTokensToAdd, windowSize) + = std::min((isCrossKv() ? req.getEncoderOutputLen() : req.mPromptLen) + maxDraftTokensToAdd, + windowSize + chunkSize) + mSinkBubbleLength; auto const numSharedBlocks = promptCacheLen / getTokensPerBlock(); auto const numUnSharedTokens = promptCacheLen % getTokensPerBlock(); @@ -1910,12 +2008,26 @@ SizeType32 KVCacheManager::getRemainingBlocksToCompletion(LlmRequest const& req, } } - if (numAllocBlocksPerBeam < numContextBlocks) + // In case of sliding window attention, a new block is allocated when the + // window slides (and then the out-of-window block is detached). So we + // need an extra block for generation if the diff between the max sequence + // length and the current sequence length crosses both a block boundary + // and a window boundary. + auto const isSlidingWindow = (req.mPromptLen + req.mMaxNewTokens) > windowSize; + SizeType32 const currentSeqlenInBlocks = tc::ceilDiv(req.getNumTokens(0), getTokensPerBlock()); + SizeType32 const maxSeqlenInBlocks = tc::ceilDiv(req.mPromptLen + req.mMaxNewTokens, getTokensPerBlock()); + auto const willCrossBlockBoundary = maxSeqlenInBlocks > currentSeqlenInBlocks; + auto const willCrossWindowBlockBoundary = maxSeqlenInBlocks > numTotalBlocksPerBeam; + SizeType32 numExtraBlocksPerBeam + = isSlidingWindow && willCrossBlockBoundary && willCrossWindowBlockBoundary ? 1 : 0; + + if (numAllocBlocksPerBeam < numContextBlocks) // Still haven't allocated all context blocks { - return numContextBlocks - numAllocBlocksPerBeam + numGenBlocksPerBeam * req.mSamplingConfig.beamWidth; + return numContextBlocks - numAllocBlocksPerBeam + + (numGenBlocksPerBeam + numExtraBlocksPerBeam) * req.mSamplingConfig.beamWidth; } - return (numTotalBlocksPerBeam - numAllocBlocksPerBeam) * req.mSamplingConfig.beamWidth; + return (numTotalBlocksPerBeam - numAllocBlocksPerBeam + numExtraBlocksPerBeam) * req.mSamplingConfig.beamWidth; } void BlockManager::updateSequenceCacheBlockOffsets(GenerationRequest& sequence, SizeType32 windowSize) @@ -1938,10 +2050,10 @@ void BlockManager::updateSequenceCacheBlockOffsets(GenerationRequest& sequence, } } -void BlockManager::updateLastCacheBlockOffsets(GenerationRequest& sequence, SizeType32 windowSize) +void WindowBlockManager::updateLastCacheBlockOffsets(GenerationRequest& sequence) { - auto const& cacheBlocks = sequence.getCacheBlockIds(windowSize); - auto& cacheBlocksTensor = sequence.getCacheBlockIndices(windowSize); + auto const& cacheBlocks = sequence.getCacheBlockIds(mWindowSize); + auto& cacheBlocksTensor = sequence.getCacheBlockIndices(mWindowSize); auto const beamWidth = sequence.getBeamWidth(); auto* offsetsPtr = bufferCast(cacheBlocksTensor); @@ -1952,7 +2064,7 @@ void BlockManager::updateLastCacheBlockOffsets(GenerationRequest& sequence, Size auto const& beamCacheBlock = cacheBlocks[beamIdx]; auto const blockId = beamCacheBlock.back(); auto const blockIdx = static_cast(beamCacheBlock.size() - 1); - mWindowBlockManagers.at(windowSize).setOffsets(offsetsPtr, offsetsShape, beamIdx, blockIdx, blockId); + setOffsets(offsetsPtr, offsetsShape, beamIdx, blockIdx, blockId); } } @@ -1978,26 +2090,38 @@ void KVCacheManager::addToken(RequestIdType requestId) // TODO: add streamLLM support auto& sequence = getSequence(requestId); sequence.addNewTokens(1); - for (auto const [windowSize, metadata] : mBlockManager.getWindowSizesMetadata()) + mBlockManager.adjustBlocksIfNeeded(sequence, mEnableBlockReuse); +} + +void WindowBlockManager::detachFrontBlock(GenerationRequest& sequence, bool const isEnableBlockReuse) +{ + // streamLLM is not supported at the moment. The out of window block will + // always be the 0th block. + TLLM_CHECK_WITH_INFO( + sequence.getBeamWidth() == 1, "[kv cache manager] detachBlock does not support beamWidth > 1 now."); + + auto const requestId = sequence.getRequestId(); + auto const beamWidth = sequence.getBeamWidth(); + auto& allocatedBlocks = mAllocatedBlocksPerSeq.at(requestId); + SizeType32 outOfWindowBlockIdx = 0; + + for (auto beamIdx = 0; beamIdx < beamWidth; ++beamIdx) { - if ((sequence.getNumTokens() - 1) % getTokensPerBlock() == 0) + auto outOfWindowBlock = allocatedBlocks.at(outOfWindowBlockIdx * beamWidth + beamIdx); + + outOfWindowBlock->decRefCount(); + + if (!outOfWindowBlock->hasRefs()) { - if (sequence.getNumTokens() <= windowSize) - { - // Allocate new unshared blocks until the window can always - // accommodate "window size" number of tokens. - mBlockManager.allocateBlock(sequence, windowSize); - mBlockManager.updateLastCacheBlockOffsets(sequence, windowSize); - } - else if (sequence.getBeamWidth() > 1) - { - // For beam search, shared block is replaced with unshared ones - auto const nextBlockIdx = (sequence.getNumTokens() - 1) / getTokensPerBlock(); - mBlockManager.replaceSharedBlock(sequence, windowSize, nextBlockIdx); - mBlockManager.updateCacheBlockOffsetsAtIdx(sequence, windowSize, nextBlockIdx); - } + // For now, OOW block is not released when reused is enabled. + mEvictionPolicy->releaseBlock(outOfWindowBlock); } } + + // Disconnect first block from sequence and remove it from allocated blocks + sequence.removeFrontBlock(mWindowSize); + allocatedBlocks.erase(allocatedBlocks.begin() + outOfWindowBlockIdx * beamWidth, + allocatedBlocks.begin() + (outOfWindowBlockIdx + 1) * beamWidth); } std::optional KVCacheManager::findNewContextBlock( @@ -2032,13 +2156,14 @@ void KVCacheManager::addSequence( for (auto const [windowSize, metadata] : mBlockManager.getWindowSizesMetadata()) { + // NOTE: Caller to KVCacheManager::addSequence should deal with the chunking auto const maxTokenNum = metadata.maxTokenNum; auto const temporaryAttentionWindow = metadata.temporaryAttentionWindow; // Consider the temporaryAttentionWindow when allocating blocks. auto const effectiveInputLength = std::min(inputLength, maxTokenNum + temporaryAttentionWindow); auto const numContextBlocks = tc::ceilDiv(effectiveInputLength, getTokensPerBlock()); - if (!sequence.isCyclic() && mEnableBlockReuse) + if (mEnableBlockReuse) { mBlockManager.addSequence(sequence, effectiveInputLength, numContextBlocks, *llmRequest, windowSize); } @@ -2053,8 +2178,7 @@ void KVCacheManager::addSequence( "have no effect.", llmRequest->mRequestId); } - bool isShareLastContextBlock = isCrossKv() || (sequence.isCyclic() && beamWidth == 1) - || effectiveInputLength % getTokensPerBlock() == 0; + bool isShareLastContextBlock = isCrossKv() || effectiveInputLength % getTokensPerBlock() == 0; mBlockManager.addSequence(sequence, numContextBlocks, windowSize, isShareLastContextBlock); } mBlockManager.updateSequenceCacheBlockOffsets(sequence, windowSize); @@ -2077,22 +2201,29 @@ void KVCacheManager::storeContextBlocks(LlmRequest const& llmRequest) if (mSequences.find(requestId) != mSequences.end()) { auto& sequence = getSequence(requestId); - if (mEnableBlockReuse && !sequence.isCyclic() && !llmRequest.isDummyRequest()) + if (mEnableBlockReuse && !llmRequest.isDummyRequest()) { mBlockManager.storeContextBlocks(sequence, llmRequest); } } + else + { + TLLM_LOG_WARNING("[kv cache manager] storeContextBlocks: Can not find sequence for request %lu", requestId); + } } void KVCacheManager::storeNewBlock(LlmRequest const& llmRequest) { + // We store newest block for potential reuse only if: + // - Beam search is NOT enabled + // - Block reuse is enabled. auto const requestId = llmRequest.mRequestId; auto& sequence = getSequence(requestId); - bool const storeBlocksForReuse = sequence.getBeamWidth() == 1 && !sequence.isCyclic(); - if (mEnableBlockReuse && storeBlocksForReuse) + if (sequence.getBeamWidth() > 1 || !mEnableBlockReuse) { - mBlockManager.storeNewBlock(sequence, llmRequest); + return; } + mBlockManager.storeNewBlock(sequence, llmRequest); } void KVCacheManager::removeSequence(RequestIdType requestId, OptionalRef llmRequest) @@ -2105,7 +2236,6 @@ void KVCacheManager::removeSequence(RequestIdType requestId, OptionalRef windowSizeToShare; + // NOTE: Righteously, blocks allocated should be proportional with + // regard to window size. Currently, we are first allocating identical + // number of blocks for all layers to achieve identical performance. + for (auto const& [windowSize, _] : windowSizeToLayers) + { + windowSizeToShare[windowSize] = 1.0f / windowSizeToLayers.size(); + } std::vector blocksPrimary; std::vector blocksSecondary; @@ -2394,9 +2530,8 @@ void KVCacheManager::removeToken(RequestIdType requestId) sequence.removeTokens(1); for (auto const [windowSize, metadata] : mBlockManager.getWindowSizesMetadata()) { - SizeType32 const maxTokensInWindow = metadata.maxTokenNum; - SizeType32 const tokensInWindow = sequence.getNumTokens() % maxTokensInWindow; - if (tokensInWindow % getTokensPerBlock() == 0 && tokensInWindow <= maxTokensInWindow) + SizeType32 const tokensInWindow = sequence.getNumTokens() % windowSize; + if (tokensInWindow % getTokensPerBlock() == 0) { mBlockManager.releaseLastBlock(sequence, windowSize); } @@ -2488,7 +2623,12 @@ SizeType32 KVCacheManager::calculateMaxBlockRequirementsPerBeam( auto const sinkBubbleLength = BaseKVCacheManager::getSinkBubbleLength(sinkTokenLength, tokensPerBlock); auto const actualSeqLen = std::min(sequenceLength, windowSize); auto actualMaxTokenNum = actualSeqLen + sinkBubbleLength; - return tc::ceilDiv(actualMaxTokenNum, tokensPerBlock); + auto numBlocks = tc::ceilDiv(actualMaxTokenNum, tokensPerBlock); + if (sequenceLength > windowSize) + { + numBlocks += kSWAExtraBlock; + } + return numBlocks; } SizeType32 KVCacheManager::calculateMaxBlockRequirements(SizeType32 inputLength, SizeType32 outputLength, @@ -2502,12 +2642,11 @@ SizeType32 KVCacheManager::calculateMaxBlockRequirements(SizeType32 inputLength, wholeSequenceLength, sinkTokenLength, windowSize, tokensPerBlock); } - // If the whole attention window can fit in the output, then we can simply multiply the cost of a sequence of - // length max attention window by the beam width. if (windowSize <= outputLength) { + // We at most will need outputLength of distinct blocks for SWA return KVCacheManager::calculateMaxBlockRequirementsPerBeam( - windowSize, sinkTokenLength, windowSize, tokensPerBlock) + outputLength, sinkTokenLength, windowSize, tokensPerBlock) * beamWidth; } @@ -2518,7 +2657,11 @@ SizeType32 KVCacheManager::calculateMaxBlockRequirements(SizeType32 inputLength, auto const sinkBubbleLength = BaseKVCacheManager::getSinkBubbleLength(sinkTokenLength, tokensPerBlock); auto const numContextBlocks = (numContextTokensInAttentionWindow + sinkBubbleLength) / tokensPerBlock; auto const leftoverContextToken = numContextTokensInAttentionWindow - numContextBlocks * tokensPerBlock; - auto const numOutputBlocks = tc::ceilDiv(outputLength + leftoverContextToken, tokensPerBlock); + auto numOutputBlocks = tc::ceilDiv(outputLength + leftoverContextToken, tokensPerBlock); + if (wholeSequenceLength > windowSize) + { + numOutputBlocks += kSWAExtraBlock; + } return numContextBlocks + numOutputBlocks * beamWidth; } diff --git a/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp b/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp index 116091670d1..c54e02642ca 100644 --- a/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp +++ b/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp @@ -352,15 +352,13 @@ TrtGptModelInflightBatching::TrtGptModelInflightBatching(std::shared_ptr TrtGptModelInflightBatching::c auto const tokensPerBlock = mModelConfig.getTokensPerBlock(); auto const kvDtype = mModelConfig.getKvDataType(); - bool enableCyclicKvCache = false; - for (SizeType32 maxAttenWin : getMaxAttentionWindowVec()) - { - if (maxAttenWin != getMaxSequenceLen()) - { - enableCyclicKvCache = true; - break; - } - } - // Below assertion should be removed once SWA/VSWA is no longer cyclic. - TLLM_CHECK_WITH_INFO( - getMaxBeamWidth() == 1 || !enableCyclicKvCache, "Can't support cyclic kv cache with beam search."); - // init KV cache block manager auto [numKvHeadsPerLayerBegin, numKvHeadsPerLayerEnd] = mModelConfig.getNumKvHeadsPerLayerLocalRange( mWorldConfig.getPipelineParallelism(), mWorldConfig.getPipelineParallelRank(), isCrossAttention); @@ -702,7 +687,8 @@ std::unique_ptr TrtGptModelInflightBatching::c auto kvCacheManager = std::make_unique(numKvHeadsPerLayer, sizePerHead, tokensPerBlock, blocksPerWindow, getMaxNumSequences(), getMaxBeamWidth(), maxAttentionWindowVec, tempAttentionWindowInputs, - kvDtype, getSinkTokenLen(), mRuntime->getStreamPtr(), std::nullopt, enableBlockReuse, + kvDtype, getSinkTokenLen(), mRuntime->getStreamPtr(), + kvCacheType == KvCacheType::kCROSS ? mModelConfig.getMaxEncoderLen() : getMaxSequenceLen(), enableBlockReuse, kvCacheConfig.getOnboardBlocks(), kvCacheType, kvCacheConfig.getSecondaryOffloadMinPriority(), kvCacheConfig.getEventBufferMaxSize() > 0 ? std::make_unique(kvCacheConfig.getEventBufferMaxSize()) diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp index c3bccf87b47..68c719fb687 100644 --- a/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp @@ -483,13 +483,12 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(nb::module_& m) .value("SELFKONLY", tbk::CacheType::kSELFKONLY); nb::class_(m, "KVCacheManager") - .def( - nb::init const&, SizeType32, SizeType32, - std::map> const&, SizeType32, SizeType32, - std::vector const&, std::optional const&, - nvinfer1::DataType, SizeType32, int64_t, std::optional, bool, bool, tbk::CacheType, - std::optional, std::shared_ptr, - bool, bool, std::shared_ptr>(), + .def(nb::init const&, SizeType32, SizeType32, + std::map> const&, SizeType32, SizeType32, + std::vector const&, std::optional const&, + nvinfer1::DataType, SizeType32, int64_t, runtime::SizeType32, bool, bool, tbk::CacheType, + std::optional, std::shared_ptr, + bool, bool, std::shared_ptr>(), nb::arg("num_kv_heads_per_layer"), nb::arg("size_per_head"), nb::arg("tokens_per_block"), nb::arg("blocks_per_window"), nb::arg("max_num_sequences"), nb::arg("max_beam_width"), nb::arg("max_attention_window_vec"), nb::arg("temp_attention_window_inputs").none(), nb::arg("dtype"), diff --git a/cpp/tests/unit_tests/batch_manager/cacheTransBufferTest.cpp b/cpp/tests/unit_tests/batch_manager/cacheTransBufferTest.cpp index d5e57797b77..1bc13959940 100644 --- a/cpp/tests/unit_tests/batch_manager/cacheTransBufferTest.cpp +++ b/cpp/tests/unit_tests/batch_manager/cacheTransBufferTest.cpp @@ -58,7 +58,7 @@ class CacheTransBufferTest : public ::testing::Test mCacheManager = std::make_unique(numLayers, numHeads, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, maxBeamWidth, std::vector{maxAttentionWindow}, - std::nullopt, dataType, sinkTokenLength, stream, std::nullopt, enableBlockReuse, onboardBlocks, cacheType, + std::nullopt, dataType, sinkTokenLength, stream, kvMaxNumTokens, enableBlockReuse, onboardBlocks, cacheType, std::nullopt, nullptr, true); mCacheManager->allocatePools(false); diff --git a/cpp/tests/unit_tests/batch_manager/capacitySchedulerTest.cpp b/cpp/tests/unit_tests/batch_manager/capacitySchedulerTest.cpp index 0942b716c15..58f1f7d4fa0 100644 --- a/cpp/tests/unit_tests/batch_manager/capacitySchedulerTest.cpp +++ b/cpp/tests/unit_tests/batch_manager/capacitySchedulerTest.cpp @@ -135,7 +135,7 @@ class CapacitySchedulerTest : public ::testing::Test // NOLINT(cppcoreguidelines // init KV cache block manager return std::make_shared(numLayers, nbKvHeads, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumRequests, 1, std::vector{maxNumTokensPerSeq}, std::nullopt, kvDtype, - sinkTokenLength, streamPtr, std::nullopt, enableReuse, onboardBlocks, cacheType); + sinkTokenLength, streamPtr, maxNumTokensPerSeq, enableReuse, onboardBlocks, cacheType); } static std::shared_ptr getPeftCacheManager() diff --git a/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp b/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp index ce572f93360..3e266d0cd14 100644 --- a/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp +++ b/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp @@ -615,7 +615,7 @@ TEST_F(KVCacheManagerTest, FP4BlockScaleManagementTest) KVCacheManager kvCacheManager(numLayers, numHeads, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, beamWidth, std::vector{maxAttentionWindow}, std::nullopt, nvinfer1::DataType::kFP4, - false, stream, true, onboardBlocks); + false, stream, maxAttentionWindow, true, onboardBlocks); kvCacheManager.allocatePools(/*useUvm=*/false); @@ -705,8 +705,9 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseTest) blockManager.addSequence(seq1, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow); EXPECT_EQ(llmRequest1->getContextCurrentPosition(), 2 * tokensPerBlock); EXPECT_THAT(seq1.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 3})); - llmRequest1->addNewToken(9, beamIdx); // block 3 contains [8] - llmRequest1->addNewToken(10, beamIdx); // block 3 contains [8, 9] + // at this point, block 3 contains [8] + llmRequest1->addNewToken(9, beamIdx); // block 3 contains [8, 9] + llmRequest1->addNewToken(10, beamIdx); // block 3 contains [8, 9, 10] EXPECT_EQ(blockManager.getNumAllocatedBlocks(), numBlocks); EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool - numBlocks); @@ -1958,6 +1959,7 @@ TEST_F(KVCacheManagerTest, KVCacheManagerPerRequestStatsTest) auto constexpr sizePerHead = 16; auto constexpr tokensPerBlock = 4; auto constexpr maxBlocksPerSeq = 4; + auto constexpr maxSequenceLength = maxBlocksPerSeq * tokensPerBlock; auto constexpr maxNumSequences = 8; auto constexpr blocksInPrimaryPool = 16; auto constexpr blocksInSecondaryPool = 0; @@ -1975,7 +1977,7 @@ TEST_F(KVCacheManagerTest, KVCacheManagerPerRequestStatsTest) KVCacheManager kvCacheManager(numLayers, numHeads, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, beamWidth, std::vector{maxAttentionWindow}, std::nullopt, nvinfer1::DataType::kHALF, - 0, stream, std::nullopt, true, onboardBlocks); + 0, stream, maxSequenceLength, true, onboardBlocks); kvCacheManager.allocatePools(false); auto inputTokens = std::make_shared(VecTokens{0, 1, 2, 3, 4, 5, 6, 7, 8}); @@ -2118,6 +2120,7 @@ TEST_F(KVCacheManagerTest, KVCacheManagerDecodeBlockPriorityTest) auto constexpr sizePerHead = 16; auto constexpr tokensPerBlock = 4; auto constexpr maxBlocksPerSeq = 8; + auto constexpr maxSequenceLength = tokensPerBlock * maxBlocksPerSeq; auto constexpr maxNumSequences = 8; auto constexpr blocksInPrimaryPool = 8; auto constexpr blocksInSecondaryPool = 0; @@ -2135,7 +2138,7 @@ TEST_F(KVCacheManagerTest, KVCacheManagerDecodeBlockPriorityTest) KVCacheManager kvCacheManager(numLayers, numHeads, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, beamWidth, std::vector{maxAttentionWindow}, std::nullopt, nvinfer1::DataType::kHALF, - 0, stream, std::nullopt, true, onboardBlocks); + 0, stream, maxSequenceLength, true, onboardBlocks); kvCacheManager.allocatePools(false); auto const& blockManager = kvCacheManager.getBlockManager(); @@ -2224,6 +2227,7 @@ TEST_F(KVCacheManagerTest, KVCacheManagerTimedEvictionTest) auto constexpr sizePerHead = 16; auto constexpr tokensPerBlock = 4; auto constexpr maxBlocksPerSeq = 4; + auto constexpr maxSequenceLength = tokensPerBlock * maxBlocksPerSeq; auto constexpr maxNumSequences = 8; auto constexpr blocksInPrimaryPool = 8; auto constexpr blocksInSecondaryPool = 0; @@ -2241,7 +2245,7 @@ TEST_F(KVCacheManagerTest, KVCacheManagerTimedEvictionTest) KVCacheManager kvCacheManager(numLayers, numHeads, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, beamWidth, std::vector{maxAttentionWindow}, std::nullopt, nvinfer1::DataType::kHALF, - 0, stream, std::nullopt, true, onboardBlocks); + 0, stream, maxSequenceLength, true, onboardBlocks); kvCacheManager.allocatePools(false); auto inputTokens0 = std::make_shared(VecTokens{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}); @@ -2292,6 +2296,7 @@ TEST_F(KVCacheManagerTest, KVCacheManagerDecodeTimedEvictionTest) auto constexpr sizePerHead = 16; auto constexpr tokensPerBlock = 4; auto constexpr maxBlocksPerSeq = 4; + auto constexpr maxSequenceLength = tokensPerBlock * maxBlocksPerSeq; auto constexpr maxNumSequences = 8; auto constexpr blocksInPrimaryPool = 8; auto constexpr blocksInSecondaryPool = 0; @@ -2309,7 +2314,7 @@ TEST_F(KVCacheManagerTest, KVCacheManagerDecodeTimedEvictionTest) KVCacheManager kvCacheManager(numLayers, numHeads, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, beamWidth, std::vector{maxAttentionWindow}, std::nullopt, nvinfer1::DataType::kHALF, - 0, stream, std::nullopt, true, onboardBlocks); + 0, stream, maxSequenceLength, true, onboardBlocks); kvCacheManager.allocatePools(false); { auto inputTokens0 = std::make_shared(VecTokens{1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}); @@ -2370,6 +2375,7 @@ TEST_F(KVCacheManagerTest, KVCacheManagerSecondaryBlockPrimaryChildTest) auto constexpr tokensPerBlock = 4; auto constexpr maxBlocksPerSeq = 4; auto constexpr maxNumSequences = 8; + auto constexpr maxSequenceLength = tokensPerBlock * maxBlocksPerSeq; auto constexpr blocksInPrimaryPool = 4; auto constexpr blocksInSecondaryPool = 4; auto constexpr onboardBlocks = true; @@ -2386,7 +2392,7 @@ TEST_F(KVCacheManagerTest, KVCacheManagerSecondaryBlockPrimaryChildTest) KVCacheManager kvCacheManager(numLayers, numHeads, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, beamWidth, std::vector{maxAttentionWindow}, std::nullopt, nvinfer1::DataType::kHALF, - false, stream, true, onboardBlocks); + 0, stream, maxSequenceLength, true, onboardBlocks); kvCacheManager.allocatePools(false); auto inputTokens0 = std::make_shared(VecTokens{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}); @@ -2445,6 +2451,7 @@ TEST_F(KVCacheManagerTest, KVCacheManagerLeafBlockTest) auto constexpr sizePerHead = 16; auto constexpr tokensPerBlock = 4; auto constexpr maxBlocksPerSeq = 4; + auto constexpr maxSequenceLength = tokensPerBlock * maxBlocksPerSeq; auto constexpr maxNumSequences = 8; auto constexpr blocksInPrimaryPool = 4; auto constexpr blocksInSecondaryPool = 0; @@ -2461,7 +2468,7 @@ TEST_F(KVCacheManagerTest, KVCacheManagerLeafBlockTest) KVCacheManager kvCacheManager(numLayers, numHeads, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, beamWidth, std::vector{maxAttentionWindow}, std::nullopt, nvinfer1::DataType::kHALF, - false, stream, true, onboardBlocks); + 0, stream, maxSequenceLength, true, onboardBlocks); kvCacheManager.allocatePools(false); auto inputTokens0 = std::make_shared(VecTokens{0, 1, 2, 3}); @@ -2522,6 +2529,7 @@ TEST_F(KVCacheManagerTest, KVCacheManagerLeafBlockWithDependentTest) auto constexpr sizePerHead = 16; auto constexpr tokensPerBlock = 4; auto constexpr maxBlocksPerSeq = 4; + auto constexpr maxSequenceLength = tokensPerBlock * maxBlocksPerSeq; auto constexpr maxNumSequences = 8; auto constexpr blocksInPrimaryPool = 4; auto constexpr blocksInSecondaryPool = 1; @@ -2540,7 +2548,7 @@ TEST_F(KVCacheManagerTest, KVCacheManagerLeafBlockWithDependentTest) KVCacheManager kvCacheManager(numLayers, numHeads, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, beamWidth, std::vector{maxAttentionWindow}, std::nullopt, nvinfer1::DataType::kHALF, - false, stream, true, onboardBlocks); + 0, stream, maxSequenceLength, true, onboardBlocks); kvCacheManager.allocatePools(false); // Create sequence with one block worth of context tokens @@ -2638,9 +2646,9 @@ TEST_P(KVCacheManagerTest, DISABLED_KVCacheManagerAllocationTest) auto constexpr dtype = nvinfer1::DataType::kHALF; auto const stream = std::make_shared(); - auto constexpr maxNumTokens = tokensPerBlock * maxBlocksPerSeq; - auto constexpr maxAttentionWindow = maxNumTokens; - auto constexpr inputLength = maxNumTokens - tokensPerBlock - 1; + auto constexpr maxSequenceLength = tokensPerBlock * maxBlocksPerSeq; + auto constexpr maxAttentionWindow = maxSequenceLength; + auto constexpr inputLength = maxSequenceLength - tokensPerBlock - 1; auto constexpr numSharedBlocks = inputLength / tokensPerBlock; auto constexpr numBlocksPerSeq = numSharedBlocks + (maxBlocksPerSeq - numSharedBlocks) * maxBeamWidth; @@ -2659,10 +2667,10 @@ TEST_P(KVCacheManagerTest, DISABLED_KVCacheManagerAllocationTest) KVCacheManager kvCacheManager = homogeneousLayers ? KVCacheManager(numLayers, numHeads, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, maxBeamWidth, std::vector{maxAttentionWindow}, std::nullopt, - nvinfer1::DataType::kHALF, sinkTokenLength, stream, std::nullopt, enableBlockReuse, onboardBlocks) + nvinfer1::DataType::kHALF, sinkTokenLength, stream, maxSequenceLength, enableBlockReuse, onboardBlocks) : KVCacheManager(std::vector(numLayers, numHeads), sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, maxBeamWidth, std::vector{maxAttentionWindow}, - std::nullopt, nvinfer1::DataType::kHALF, sinkTokenLength, stream, std::nullopt, enableBlockReuse, + std::nullopt, nvinfer1::DataType::kHALF, sinkTokenLength, stream, maxSequenceLength, enableBlockReuse, onboardBlocks); auto const& blockManager = kvCacheManager.getBlockManager(); @@ -2701,9 +2709,9 @@ TEST_P(KVCacheManagerTest, KVCacheManagerTest) auto const stream = std::make_shared(); auto constexpr requestId = 7; - auto constexpr maxNumTokens = tokensPerBlock * maxBlocksPerSeq; - auto constexpr maxAttentionWindow = maxNumTokens; - auto constexpr inputLength = maxNumTokens - tokensPerBlock - 1; + auto constexpr maxSequenceLength = tokensPerBlock * maxBlocksPerSeq; + auto constexpr maxAttentionWindow = maxSequenceLength; + auto constexpr inputLength = maxSequenceLength - tokensPerBlock - 1; auto constexpr numSharedBlocks = inputLength / tokensPerBlock; auto constexpr numBlocksPerSeq = numSharedBlocks + (maxBlocksPerSeq - numSharedBlocks) * maxBeamWidth; @@ -2720,10 +2728,10 @@ TEST_P(KVCacheManagerTest, KVCacheManagerTest) KVCacheManager kvCacheManager = homogeneousLayers ? KVCacheManager(numLayers, numHeads, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, maxBeamWidth, std::vector{maxAttentionWindow}, std::nullopt, - nvinfer1::DataType::kHALF, sinkTokenLength, stream, std::nullopt, enableBlockReuse, onboardBlocks) + nvinfer1::DataType::kHALF, sinkTokenLength, stream, maxSequenceLength, enableBlockReuse, onboardBlocks) : KVCacheManager(numHeadsPerLayer, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, maxBeamWidth, std::vector{maxAttentionWindow}, std::nullopt, nvinfer1::DataType::kHALF, - sinkTokenLength, stream, std::nullopt, enableBlockReuse, onboardBlocks); + sinkTokenLength, stream, maxSequenceLength, enableBlockReuse, onboardBlocks); kvCacheManager.allocatePools(false); EXPECT_EQ(kvCacheManager.getOffsetTableDimensions().maxBlocksPerSeq, maxBlocksPerSeq); @@ -2816,7 +2824,7 @@ TEST_P(KVCacheManagerTest, KVCacheManagerTest) EXPECT_NO_THROW(kvCacheManager.addToken(requestId)); EXPECT_EQ(blockManager.getNumFreeBlocks(), totalNumBlocks - numSharedBlocks - maxBeamWidth); EXPECT_NO_THROW(kvCacheManager.addToken(requestId)); - EXPECT_EQ(blockManager.getNumFreeBlocks(), totalNumBlocks - numBlocksPerSeq); + EXPECT_EQ(blockManager.getNumFreeBlocks(), totalNumBlocks - numSharedBlocks - maxBeamWidth * 2); EXPECT_NO_THROW(kvCacheManager.removeSequence(requestId)); EXPECT_EQ(blockManager.getNumFreeBlocks(), totalNumBlocks); @@ -2850,9 +2858,9 @@ TEST_P(KVCacheManagerTest, KVCacheManagerRewindTokensTest) auto const stream = std::make_shared(); auto constexpr requestId = 7; - auto constexpr maxNumTokens = tokensPerBlock * maxBlocksPerSeq; - auto constexpr maxAttentionWindow = maxNumTokens; - auto constexpr inputLength = maxNumTokens - tokensPerBlock - 1; + auto constexpr maxSequenceLength = tokensPerBlock * maxBlocksPerSeq; + auto constexpr maxAttentionWindow = maxSequenceLength; + auto constexpr inputLength = maxSequenceLength - tokensPerBlock - 1; auto constexpr numSharedBlocks = inputLength / tokensPerBlock; auto constexpr numBlocksPerSeq = numSharedBlocks + (maxBlocksPerSeq - numSharedBlocks) * maxBeamWidth; @@ -2868,10 +2876,10 @@ TEST_P(KVCacheManagerTest, KVCacheManagerRewindTokensTest) KVCacheManager kvCacheManager = homogeneousLayers ? KVCacheManager(numLayers, numHeads, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, maxBeamWidth, std::vector{maxAttentionWindow}, std::nullopt, - nvinfer1::DataType::kHALF, sinkTokenLength, stream, std::nullopt, enableBlockReuse, onboardBlocks) + nvinfer1::DataType::kHALF, sinkTokenLength, stream, maxSequenceLength, enableBlockReuse, onboardBlocks) : KVCacheManager(std::vector(numLayers, numHeads), sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, maxBeamWidth, std::vector{maxAttentionWindow}, - std::nullopt, nvinfer1::DataType::kHALF, sinkTokenLength, stream, std::nullopt, enableBlockReuse, + std::nullopt, nvinfer1::DataType::kHALF, sinkTokenLength, stream, maxSequenceLength, enableBlockReuse, onboardBlocks); kvCacheManager.allocatePools(false); @@ -2925,7 +2933,6 @@ TEST_P(KVCacheManagerTest, KVCacheManagerMaxAttentionWindowTest) std::map const expectedHeadsPerPool({{0, 1}, {1, 2}, {2, 3}}); std::map const expectedLayersPerPool({{0, 1}, {1, 2}, {2, 1}}); auto constexpr sizePerHead = 64; - auto constexpr hiddenSize = numHeads * sizePerHead; auto constexpr tokensPerBlock = 64; auto constexpr blockLengthPerSeq = 10; auto constexpr maxNumSequences = 8; @@ -2935,14 +2942,13 @@ TEST_P(KVCacheManagerTest, KVCacheManagerMaxAttentionWindowTest) auto const stream = std::make_shared(); auto constexpr requestId = 7; - auto constexpr maxNumTokens = tokensPerBlock * blockLengthPerSeq; + auto constexpr maxSequenceLength = tokensPerBlock * blockLengthPerSeq; + auto constexpr maxBlocksPerSeq = tc::ceilDiv(maxSequenceLength, tokensPerBlock); - auto constexpr inputLength = maxNumTokens - tokensPerBlock - 1; - // Enable cyclic kv cache for all new generated tokens. - auto constexpr maxAttentionWindow = maxNumTokens; + auto constexpr inputLength = maxSequenceLength - tokensPerBlock - 1; + auto constexpr maxAttentionWindow = inputLength; // sliding window attention auto constexpr numSharedBlocks = inputLength / tokensPerBlock; - auto constexpr maxBlocksPerSeq = tc::ceilDiv(maxAttentionWindow, tokensPerBlock); - auto constexpr numBlocksPerSeq = numSharedBlocks + (maxBlocksPerSeq - numSharedBlocks) * maxBeamWidth; + auto constexpr numBlocksPerSeq = numSharedBlocks + (blockLengthPerSeq - numSharedBlocks) * maxBeamWidth; auto constexpr totalNumBlocks = maxNumSequences * numBlocksPerSeq; auto constexpr blocksInSecondaryPool = 0; @@ -2957,10 +2963,10 @@ TEST_P(KVCacheManagerTest, KVCacheManagerMaxAttentionWindowTest) KVCacheManager kvCacheManager = homogeneousLayers ? KVCacheManager(numLayers, numHeads, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, maxBeamWidth, std::vector{maxAttentionWindow}, std::nullopt, - nvinfer1::DataType::kHALF, sinkTokenLength, stream, std::nullopt, enableBlockReuse, onboardBlocks) + nvinfer1::DataType::kHALF, sinkTokenLength, stream, maxSequenceLength, enableBlockReuse, onboardBlocks) : KVCacheManager(numHeadsPerLayer, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, maxBeamWidth, std::vector{maxAttentionWindow}, std::nullopt, nvinfer1::DataType::kHALF, - sinkTokenLength, stream, std::nullopt, enableBlockReuse, onboardBlocks); + sinkTokenLength, stream, maxSequenceLength, enableBlockReuse, onboardBlocks); kvCacheManager.allocatePools(false); EXPECT_EQ(kvCacheManager.getOffsetTableDimensions().maxBlocksPerSeq, maxBlocksPerSeq); @@ -3050,7 +3056,7 @@ TEST_P(KVCacheManagerTest, KVCacheManagerMaxAttentionWindowTest) EXPECT_EQ(blockManager.getNumFreeBlocks(), maxNumSequences); } -TEST_F(KVCacheManagerTest, KVCacheManagerMaxAttentionWindowWithReuseTest) +TEST_F(KVCacheManagerTest, KVCacheManagerMaxAttentionWindowSmallerThanBlockSizeTest) { auto constexpr numLayers = 2; auto constexpr numHeads = 2; @@ -3060,27 +3066,29 @@ TEST_F(KVCacheManagerTest, KVCacheManagerMaxAttentionWindowWithReuseTest) auto constexpr maxBeamWidth = 1; auto constexpr sinkTokenLength = 0; auto const stream = std::make_shared(); + auto constexpr maxSequenceLength = 128; - // Enable cyclic kv cache for long input tokens. - auto constexpr maxAttentionWindow = 16; - auto constexpr maxBlocksPerSeq = tc::ceilDiv(maxAttentionWindow, tokensPerBlock); + // Enable sliding window kv cache for long input tokens. + auto constexpr maxAttentionWindow = 3; + auto constexpr maxBlocksPerSeq = tc::ceilDiv(maxSequenceLength, tokensPerBlock); auto constexpr blocksInPrimaryPool = 16; auto constexpr blocksInSecondaryPool = 0; - auto constexpr enableBlockReuse = true; + auto constexpr enableBlockReuse = false; auto constexpr onboardBlocks = true; auto const blocksPerWindow = BlocksPerWindow{{maxAttentionWindow, {blocksInPrimaryPool, blocksInSecondaryPool}}}; - KVCacheManager kvCacheManager(numLayers, numHeads, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, maxBeamWidth, std::vector{maxAttentionWindow}, std::nullopt, - nvinfer1::DataType::kHALF, sinkTokenLength, stream, std::nullopt, enableBlockReuse, onboardBlocks); + nvinfer1::DataType::kHALF, sinkTokenLength, stream, maxSequenceLength, enableBlockReuse, onboardBlocks); kvCacheManager.allocatePools(false); auto const& blockManager = kvCacheManager.getBlockManager(); - SizeType32 constexpr maxNewTokens = 4; + auto const onlyWindowSize = theOnlyWindowSize(kvCacheManager); + + SizeType32 constexpr maxNewTokens = 40; // prepare tokens with token[i] = 1000 + i TokenIdType constexpr firstToken = 1000; @@ -3089,243 +3097,69 @@ TEST_F(KVCacheManagerTest, KVCacheManagerMaxAttentionWindowWithReuseTest) tr::SamplingConfig const samplingConfig{beamWidth}; bool constexpr isStreaming{false}; + /////////////////////////////////////////////////////////////////////////// + // add a request that starts shorter and gets longer than the max attention window and then remove it SizeType32 requestId = 0; - int inputLength = 16; + int inputLength = 2; auto inputTokens = std::make_shared(inputLength); std::iota(inputTokens->begin(), inputTokens->end(), firstToken); auto llmRequest = std::make_shared(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming); auto constexpr beamIdx = 0; - /////////////////////////////////////////////////////////////////////////// - // add a long request and then remove it kvCacheManager.addSequence(requestId, inputLength, beamWidth, llmRequest); GenerationRequest const& seq0 = kvCacheManager.getSequence(requestId); EXPECT_EQ(llmRequest->getContextCurrentPosition(), 0); - EXPECT_THAT(seq0.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 2, 3})); + EXPECT_THAT(seq0.getCacheBlockIds(onlyWindowSize).at(beamIdx), ::testing::ElementsAreArray({0})); - // add tokens to enable cyclic kv cache - llmRequest->addNewToken(1016, beamIdx); + // add tokens, reaching max attention window + llmRequest->addNewToken(1002, beamIdx); kvCacheManager.addToken(requestId); - llmRequest->addNewToken(1017, beamIdx); - kvCacheManager.addToken(requestId); - auto numTokens = llmRequest->getNumTokens(beamIdx); - auto numBlocks = seq0.getCacheBlockIds(maxAttentionWindow)[beamIdx].size(); - EXPECT_EQ(numBlocks, maxBlocksPerSeq); + auto numBlocks = seq0.getCacheBlockIds(onlyWindowSize)[beamIdx].size(); + EXPECT_EQ(numBlocks, 1); EXPECT_EQ(blockManager.getNumAllocatedBlocks(), numBlocks); EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool - numBlocks); - EXPECT_NO_THROW(kvCacheManager.removeSequence(requestId, llmRequest)); - // no blocks stored because cyclic KV cache was enabled - EXPECT_EQ(blockManager.getNumAllocatedBlocks(), 0); - EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool); - - /////////////////////////////////////////////////////////////////////////// - // add a short request and then remove it - requestId = 1; - inputLength = 7; - inputTokens->resize(inputLength); - llmRequest = std::make_shared(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming); - kvCacheManager.addSequence(requestId, inputLength, beamWidth, llmRequest); - GenerationRequest const& seq1 = kvCacheManager.getSequence(requestId); - EXPECT_EQ(llmRequest->getContextCurrentPosition(), 0); - EXPECT_THAT(seq1.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({4, 5})); + EXPECT_THAT(seq0.getCacheBlockIds(onlyWindowSize).at(beamIdx), ::testing::ElementsAreArray({0})); - llmRequest->addNewToken(1007, beamIdx); + // add new tokens exceeding max attention window, but not enough to allocate another block + llmRequest->addNewToken(1003, beamIdx); kvCacheManager.addToken(requestId); - llmRequest->addNewToken(1008, beamIdx); - kvCacheManager.addToken(requestId); - numTokens = llmRequest->getNumTokens(beamIdx); - numBlocks = seq1.getCacheBlockIds(maxAttentionWindow)[beamIdx].size(); - EXPECT_EQ(numBlocks, 3); - EXPECT_NO_THROW(kvCacheManager.removeSequence(requestId, llmRequest)); - // store blocks 4, 5 for reuse ([1000,1001,1002,1003], [1004,1005,1006,1007]) - - /////////////////////////////////////////////////////////////////////////// - // add a medium request and then remove it - // reuse first 2 blocks {4, 5} in previous request, and get new block 7 - requestId = 2; - inputLength = 10; - inputTokens->resize(inputLength); - std::iota(inputTokens->begin(), inputTokens->end(), firstToken); - llmRequest = std::make_shared(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming); - kvCacheManager.addSequence(requestId, inputLength, beamWidth, llmRequest); - GenerationRequest const& seq2 = kvCacheManager.getSequence(requestId); - EXPECT_EQ(llmRequest->getContextCurrentPosition(), 2 * tokensPerBlock); - EXPECT_THAT(seq2.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({4, 5, 7})); - - numTokens = llmRequest->getNumTokens(beamIdx); - numBlocks = tc::ceilDiv(numTokens, tokensPerBlock); - EXPECT_EQ(numBlocks, 3); - EXPECT_NO_THROW(kvCacheManager.removeSequence(requestId, llmRequest)); - // store block 7 for reuse ([1008]) - - /////////////////////////////////////////////////////////////////////////// - // add a longer request within attention window and try to reuse - // reuse blocks 4, 5, 7(p) and get new block 8 - // upon reaching the attention window, the block ids shouldn't change - requestId = 3; - inputLength = 15; - inputTokens->resize(inputLength); - std::iota(inputTokens->begin(), inputTokens->end(), firstToken); - llmRequest = std::make_shared(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming); - kvCacheManager.addSequence(requestId, inputLength, beamWidth, llmRequest); - GenerationRequest const& seq3 = kvCacheManager.getSequence(requestId); - EXPECT_EQ(llmRequest->getContextCurrentPosition(), 9); - EXPECT_THAT(seq3.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({4, 5, 7, 8})); + numBlocks = seq0.getCacheBlockIds(onlyWindowSize)[beamIdx].size(); + EXPECT_EQ(numBlocks, 1); + EXPECT_EQ(blockManager.getNumAllocatedBlocks(), numBlocks); + EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool - numBlocks); + EXPECT_THAT(seq0.getCacheBlockIds(onlyWindowSize).at(beamIdx), ::testing::ElementsAreArray({0})); - llmRequest->addNewToken(1015, beamIdx); - kvCacheManager.addToken(requestId); - llmRequest->addNewToken(1016, beamIdx); + // add more new tokens, enough to allocate a new block but not enough to detach block + llmRequest->addNewToken(1004, beamIdx); kvCacheManager.addToken(requestId); - // FIXME: This means that reuse will break here - the window will start writing to a reused block, and the following - // sequence that tries to reuse the block will read garbage. This will be fixed by removing the cyclic kv cache. - EXPECT_THAT(seq3.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({4, 5, 7, 8})); - EXPECT_NO_THROW(kvCacheManager.removeSequence(requestId, llmRequest)); - - /////////////////////////////////////////////////////////////////////////// - // add a long request that exceeded attention window, no reuse - requestId = 4; - inputLength = 20; - inputTokens->resize(inputLength); - std::iota(inputTokens->begin(), inputTokens->end(), firstToken); - llmRequest = std::make_shared(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming); - kvCacheManager.addSequence(requestId, inputLength, beamWidth, llmRequest); - EXPECT_EQ(llmRequest->getContextCurrentPosition(), 0); - GenerationRequest const& seq4 = kvCacheManager.getSequence(requestId); - EXPECT_THAT(seq4.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({9, 10, 11, 12})); -} - -TEST_F(KVCacheManagerTest, KVCacheManagerVariableWindowAttentionWithReuseTest) -{ - auto constexpr numLayers = 2; - auto constexpr numHeads = 2; - auto constexpr sizePerHead = 64; - auto constexpr tokensPerBlock = 4; - auto constexpr maxNumSequences = 8; - auto constexpr maxBeamWidth = 1; - auto constexpr sinkTokenLength = 0; - auto constexpr dtype = nvinfer1::DataType::kHALF; - auto const stream = std::make_shared(); - - // Enable cyclic kv cache for long input tokens. - auto constexpr minAttentionWindow = 8; - auto constexpr maxAttentionWindow = 16; - auto const maxAttentionWindowVec = std::vector{maxAttentionWindow, minAttentionWindow}; - auto constexpr maxBlocksPerSeq = tc::ceilDiv(maxAttentionWindow, tokensPerBlock); - - auto constexpr blocksInPrimaryPool = 16; - auto constexpr blocksInSecondaryPool = 0; - - auto constexpr enableBlockReuse = true; - auto constexpr onboardBlocks = true; - - auto const blocksPerWindow = BlocksPerWindow{{maxAttentionWindow, {blocksInPrimaryPool, blocksInSecondaryPool}}, - {minAttentionWindow, {blocksInPrimaryPool, blocksInSecondaryPool}}}; - - KVCacheManager kvCacheManager(numLayers, numHeads, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, - maxBeamWidth, maxAttentionWindowVec, std::nullopt, dtype, sinkTokenLength, stream, std::nullopt, - enableBlockReuse, onboardBlocks); - kvCacheManager.allocatePools(false); - - auto const& blockManager = kvCacheManager.getBlockManager(); - - auto const allBlocksInPrimaryPools = blockManager.getNumPrimaryBlocks(); - EXPECT_THAT(allBlocksInPrimaryPools, blocksInPrimaryPool * 2); - - ASSERT_EQ(blockManager.isVariableWindow(), true); - ASSERT_EQ(blockManager.isVariableGQA(), false); - - SizeType32 constexpr maxNewTokens = 4; - - // prepare tokens with token[i] = 1000 + i - TokenIdType constexpr firstToken = 1000; - - auto constexpr beamWidth = maxBeamWidth; - tr::SamplingConfig const samplingConfig{beamWidth}; - bool constexpr isStreaming{false}; - - SizeType32 requestId = 0; - int inputLength = 7; - auto inputTokens = std::make_shared(inputLength); - std::iota(inputTokens->begin(), inputTokens->end(), firstToken); - auto llmRequest = std::make_shared(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming); - auto constexpr beamIdx = 0; - - /////////////////////////////////////////////////////////////////////////// - // add a request that will exceed the min attention window *after context*, and then remove it. - kvCacheManager.addSequence(requestId, inputLength, beamWidth, llmRequest); - GenerationRequest const& seq0 = kvCacheManager.getSequence(requestId); - EXPECT_EQ(llmRequest->getContextCurrentPosition(), 0); - - auto const assertBlocks - = [minAttentionWindow, maxAttentionWindow, beamIdx](GenerationRequest seq, - std::initializer_list expectedBlocksMin, std::initializer_list expectedBlocksMax) - { - auto blocksMin = seq.getCacheBlockIds(minAttentionWindow).at(beamIdx); - auto blocksMax = seq.getCacheBlockIds(maxAttentionWindow).at(beamIdx); - EXPECT_THAT(blocksMin, ::testing::ElementsAreArray(expectedBlocksMin)); - EXPECT_THAT(blocksMax, ::testing::ElementsAreArray(expectedBlocksMax)); - return blocksMin.size() + blocksMax.size(); - }; - - assertBlocks(seq0, {0, 1}, {0, 1}); + numBlocks = seq0.getCacheBlockIds(onlyWindowSize)[beamIdx].size(); + EXPECT_EQ(numBlocks, 2); + EXPECT_EQ(blockManager.getNumAllocatedBlocks(), numBlocks); + EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool - numBlocks); + EXPECT_THAT(seq0.getCacheBlockIds(onlyWindowSize).at(beamIdx), ::testing::ElementsAreArray({0, 1})); - // add tokens to enable cyclic kv cache for minimum but not maximum - llmRequest->addNewToken(1016, beamIdx); + // add more new tokens, enough to detach block without allocating a new one + llmRequest->addNewToken(1005, beamIdx); kvCacheManager.addToken(requestId); - llmRequest->addNewToken(1017, beamIdx); + llmRequest->addNewToken(1006, beamIdx); kvCacheManager.addToken(requestId); - auto const numBlocks = assertBlocks(seq0, {0, 1}, {0, 1, 2}); - EXPECT_EQ(blockManager.getNumFreeBlocks(), allBlocksInPrimaryPools - numBlocks); - EXPECT_NO_THROW(kvCacheManager.removeSequence(requestId, llmRequest)); - // no blocks stored because cyclic KV cache was enabled - EXPECT_EQ(blockManager.getNumAllocatedBlocks(), 0); - EXPECT_EQ(blockManager.getNumFreeBlocks(), allBlocksInPrimaryPools); + numBlocks = seq0.getCacheBlockIds(onlyWindowSize)[beamIdx].size(); + EXPECT_EQ(numBlocks, 2); + EXPECT_THAT(seq0.getCacheBlockIds(onlyWindowSize).at(beamIdx), ::testing::ElementsAreArray({0, 1})); - /////////////////////////////////////////////////////////////////////////// - // add a short request that is between the min and max attention window - requestId = 1; - inputLength = 9; - inputTokens->resize(inputLength); - llmRequest = std::make_shared(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming); - kvCacheManager.addSequence(requestId, inputLength, beamWidth, llmRequest); - GenerationRequest const& seq1 = kvCacheManager.getSequence(requestId); - EXPECT_EQ(llmRequest->getContextCurrentPosition(), 0); - assertBlocks(seq1, {2, 3}, {3, 4, 5}); + // add more new tokens, to allocate a new block llmRequest->addNewToken(1007, beamIdx); kvCacheManager.addToken(requestId); llmRequest->addNewToken(1008, beamIdx); kvCacheManager.addToken(requestId); - assertBlocks(seq1, {2, 3}, {3, 4, 5}); - EXPECT_NO_THROW(kvCacheManager.removeSequence(requestId, llmRequest)); - - /////////////////////////////////////////////////////////////////////////// - // add a request that won't reach the min attention window, so a block can be reused. - requestId = 2; - inputLength = 4; - inputTokens->resize(inputLength); - std::iota(inputTokens->begin(), inputTokens->end(), firstToken); - llmRequest = std::make_shared(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming); - kvCacheManager.addSequence(requestId, inputLength, beamWidth, llmRequest); - GenerationRequest const& seq2 = kvCacheManager.getSequence(requestId); - EXPECT_EQ(llmRequest->getContextCurrentPosition(), 0); - assertBlocks(seq2, {4}, {6}); + numBlocks = seq0.getCacheBlockIds(onlyWindowSize)[beamIdx].size(); + EXPECT_EQ(numBlocks, 3); + EXPECT_THAT(seq0.getCacheBlockIds(onlyWindowSize).at(beamIdx), ::testing::ElementsAreArray({0, 1, 2})); - auto const numTokens = llmRequest->getNumTokens(beamIdx); - EXPECT_EQ(tc::ceilDiv(numTokens, tokensPerBlock), 1); EXPECT_NO_THROW(kvCacheManager.removeSequence(requestId, llmRequest)); - // store block 6 for reuse - - /////////////////////////////////////////////////////////////////////////// - // add a request that won't reach the min attention window, so a block 6 from previous request will be reused. - requestId = 3; - inputLength = 4; - inputTokens->resize(inputLength); - std::iota(inputTokens->begin(), inputTokens->end(), firstToken); - llmRequest = std::make_shared(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming); - kvCacheManager.addSequence(requestId, inputLength, beamWidth, llmRequest); - GenerationRequest const& seq3 = kvCacheManager.getSequence(requestId); - EXPECT_EQ(llmRequest->getContextCurrentPosition(), 3); - assertBlocks(seq3, {4}, {6}); + // no blocks stored because reuse is disabled + EXPECT_EQ(blockManager.getNumAllocatedBlocks(), 0); + EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool); } TEST_F(KVCacheManagerTest, KVCacheManagerEventStream) @@ -3347,13 +3181,14 @@ TEST_F(KVCacheManagerTest, KVCacheManagerEventStream) tr::SamplingConfig const samplingConfig{beamWidth}; bool constexpr isStreaming{false}; - auto const maxAttentionWindow = tokensPerBlock * maxBlocksPerSeq; + auto const maxSequenceLength = tokensPerBlock * maxBlocksPerSeq; + auto const maxAttentionWindow = maxSequenceLength; auto const blocksPerWindow = BlocksPerWindow{{maxAttentionWindow, {blocksInPrimaryPool, blocksInSecondaryPool}}}; KVCacheManager kvCacheManager(numLayers, numHeads, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, beamWidth, std::vector{maxAttentionWindow}, std::nullopt, dtype, 0, stream, - std::nullopt, true, onboardBlocks, CacheType::kSELF, std::nullopt, + maxSequenceLength, true, onboardBlocks, CacheType::kSELF, std::nullopt, std::make_unique(1024)); kvCacheManager.allocatePools(false); @@ -3503,13 +3338,14 @@ TEST_F(KVCacheManagerTest, KVCacheManagerEventStreamOverflow) tr::SamplingConfig const samplingConfig{beamWidth}; bool constexpr isStreaming{false}; - auto const maxAttentionWindow = tokensPerBlock * maxBlocksPerSeq; + auto const maxSequenceLength = tokensPerBlock * maxBlocksPerSeq; + auto const maxAttentionWindow = maxSequenceLength; auto const blocksPerWindow = BlocksPerWindow{{maxAttentionWindow, {blocksInPrimaryPool, blocksInSecondaryPool}}}; KVCacheManager kvCacheManager(numLayers, numHeads, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, beamWidth, std::vector{maxAttentionWindow}, std::nullopt, dtype, 0, stream, - std::nullopt, true, onboardBlocks, CacheType::kSELF, std::nullopt, + maxSequenceLength, true, onboardBlocks, CacheType::kSELF, std::nullopt, std::make_unique(1)); kvCacheManager.allocatePools(false); @@ -3561,13 +3397,14 @@ TEST_F(KVCacheManagerTest, KVCacheManagerEventStreamPriority) tr::SamplingConfig const samplingConfig{beamWidth}; bool constexpr isStreaming{false}; - auto const maxAttentionWindow = tokensPerBlock * maxBlocksPerSeq; + auto const maxSequenceLength = tokensPerBlock * maxBlocksPerSeq; + auto const maxAttentionWindow = maxSequenceLength; auto const blocksPerWindow = BlocksPerWindow{{maxAttentionWindow, {blocksInPrimaryPool, blocksInSecondaryPool}}}; KVCacheManager kvCacheManager(numLayers, numHeads, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, beamWidth, std::vector{maxAttentionWindow}, std::nullopt, dtype, 0, stream, - std::nullopt, true, onboardBlocks, CacheType::kSELF, std::nullopt, + maxSequenceLength, true, onboardBlocks, CacheType::kSELF, std::nullopt, std::make_unique(1024)); kvCacheManager.allocatePools(false); @@ -3636,19 +3473,20 @@ TEST_F(KVCacheManagerTest, KVCacheManagerEventStreamBlocking) tr::SamplingConfig const samplingConfig{beamWidth}; bool constexpr isStreaming{false}; - auto const maxAttentionWindow = tokensPerBlock * maxBlocksPerSeq; + auto const maxSequenceLength = tokensPerBlock * maxBlocksPerSeq; + auto const maxAttentionWindow = maxSequenceLength; auto const blocksPerWindow = BlocksPerWindow{{maxAttentionWindow, {blocksInPrimaryPool, blocksInSecondaryPool}}}; KVCacheManager kvCacheManagerTest(numLayers, numHeads, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, beamWidth, std::vector{maxAttentionWindow}, std::nullopt, dtype, 0, - stream, std::nullopt, true, onboardBlocks, CacheType::kSELF, std::nullopt); + stream, maxSequenceLength, true, onboardBlocks, CacheType::kSELF, std::nullopt); EXPECT_EQ(getEvents(kvCacheManagerTest).size(), 0); KVCacheManager kvCacheManager(numLayers, numHeads, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, beamWidth, std::vector{maxAttentionWindow}, std::nullopt, nvinfer1::DataType::kHALF, - 0, stream, std::nullopt, true, onboardBlocks, CacheType::kSELF, std::nullopt, + 0, stream, maxSequenceLength, true, onboardBlocks, CacheType::kSELF, std::nullopt, std::make_unique(1024)); kvCacheManager.allocatePools(false); @@ -3689,7 +3527,8 @@ TEST_F(KVCacheManagerTest, KVCacheManagerEventStreamWindowSize) tr::SamplingConfig const samplingConfig{beamWidth}; bool constexpr isStreaming{false}; - auto const maxAttentionWindow = tokensPerBlock * maxBlocksPerSeq; + auto const maxSequenceLength = tokensPerBlock * maxBlocksPerSeq; + auto const maxAttentionWindow = maxSequenceLength; auto const slidingWindow = tokensPerBlock * (maxBlocksPerSeq - 1); auto const blocksPerWindow = BlocksPerWindow{{maxAttentionWindow, {blocksInPool[0], blocksInPool[1]}}, @@ -3697,7 +3536,7 @@ TEST_F(KVCacheManagerTest, KVCacheManagerEventStreamWindowSize) KVCacheManager kvCacheManager(numLayers, numHeads, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, beamWidth, std::vector{maxAttentionWindow, slidingWindow}, std::nullopt, dtype, 0, - stream, std::nullopt, true, onboardBlocks, CacheType::kSELF, std::nullopt, + stream, maxSequenceLength, true, onboardBlocks, CacheType::kSELF, std::nullopt, std::make_unique(1024)); kvCacheManager.allocatePools(false); @@ -3929,10 +3768,12 @@ TEST_P(KVCacheManagerTest, KVCacheManagerBatchTest) auto constexpr sinkTokenLength = 0; auto const stream = std::make_shared(); - auto constexpr maxNumTokens = tokensPerBlock * maxBlocksPerSeq; - auto constexpr maxAttentionWindow = maxNumTokens; - auto constexpr inputLength = maxNumTokens - 2; - auto constexpr numBlocksPerSeq = maxBlocksPerSeq - 1 + maxBeamWidth; + auto constexpr maxSequenceLength = tokensPerBlock * maxBlocksPerSeq; + auto constexpr maxAttentionWindow = maxSequenceLength; + + auto constexpr inputLength = maxSequenceLength - 2; + auto constexpr numSharedBlocks = inputLength / tokensPerBlock; + auto constexpr numBlocksPerSeq = numSharedBlocks + (maxBlocksPerSeq - numSharedBlocks) * maxBeamWidth; auto constexpr totalNumBlocks = maxNumSequences * numBlocksPerSeq; auto constexpr blocksInSecondaryPool = 0; @@ -3947,10 +3788,10 @@ TEST_P(KVCacheManagerTest, KVCacheManagerBatchTest) KVCacheManager kvCacheManager = homogeneousLayers ? KVCacheManager(numLayers, numHeads, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, maxBeamWidth, std::vector{maxAttentionWindow}, std::nullopt, - nvinfer1::DataType::kHALF, sinkTokenLength, stream, std::nullopt, enableBlockReuse, onboardBlocks) + nvinfer1::DataType::kHALF, sinkTokenLength, stream, maxSequenceLength, enableBlockReuse, onboardBlocks) : KVCacheManager(numHeadsPerLayer, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, maxBeamWidth, std::vector{maxAttentionWindow}, std::nullopt, nvinfer1::DataType::kHALF, - sinkTokenLength, stream, std::nullopt, enableBlockReuse, onboardBlocks); + sinkTokenLength, stream, maxSequenceLength, enableBlockReuse, onboardBlocks); kvCacheManager.allocatePools(false); EXPECT_EQ(kvCacheManager.getOffsetTableDimensions().maxBlocksPerSeq, maxBlocksPerSeq); @@ -3963,7 +3804,7 @@ TEST_P(KVCacheManagerTest, KVCacheManagerBatchTest) for (auto requestId = 0; requestId < maxNumSequences; ++requestId) { EXPECT_NO_THROW(kvCacheManager.addSequence(requestId, inputLength, maxBeamWidth)); - auto const currentNumBlocks = totalNumBlocks - (requestId + 1) * numBlocksPerSeq; + auto const currentNumBlocks = totalNumBlocks - (requestId + 1) * (numSharedBlocks + maxBeamWidth); EXPECT_EQ(blockManager.getNumFreeBlocks(), currentNumBlocks); } @@ -4001,7 +3842,8 @@ TEST_P(KVCacheManagerTest, KVCacheManagerBatchTest) tk::KVCacheIndex::UnderlyingType runningSum{0}; for (auto requestId = 0; requestId < maxNumSequences; ++requestId) { - for (auto block = 0; block < maxBlocksPerSeq - 1; ++block) + // Shared blocks + for (auto block = 0; block < numSharedBlocks; ++block) { for (auto beam = 0; beam < maxBeamWidth; ++beam) { @@ -4016,7 +3858,8 @@ TEST_P(KVCacheManagerTest, KVCacheManagerBatchTest) } runningSum += offsetBetweenBlocks; } - auto const block = maxBlocksPerSeq - 1; + // Unshared blocks + auto const block = numSharedBlocks; { for (auto beam = 0; beam < maxBeamWidth; ++beam) { @@ -4038,6 +3881,7 @@ TEST_P(KVCacheManagerTest, KVCacheManagerBatchTest) namespace { +// beam search with SWA is not supported for now void testNeededBlocksOneStep(bool kv_cache_block_reuse, int beamWidth, int draftLen, bool homogeneousLayers) { using DType = half; @@ -4050,13 +3894,10 @@ void testNeededBlocksOneStep(bool kv_cache_block_reuse, int beamWidth, int draft auto constexpr sizePerHead = 64; auto constexpr hiddenSize = numHeads * sizePerHead; auto constexpr tokensPerBlock = 8; - auto constexpr maxBlocksPerSeq = 10; auto constexpr maxNumSequences = 8; auto constexpr sinkTokenLength = 0; auto const stream = std::make_shared(); - auto constexpr totalNumBlocks = maxNumSequences * maxBlocksPerSeq; - TLLM_CHECK(draftLen == 0 || beamWidth == 1); // Deal with one sequence for now @@ -4070,30 +3911,28 @@ void testNeededBlocksOneStep(bool kv_cache_block_reuse, int beamWidth, int draft for (int maxBeamWidth = 1; maxBeamWidth <= maxMaxBeamWidth; ++maxBeamWidth) { tr::SamplingConfig const samplingConfig{maxBeamWidth}; - for (int inputLength = 1; inputLength < maxInputLength; ++inputLength) + for (int inputLength = 44; inputLength < 45; ++inputLength) { - auto constexpr maxNumTokens = tokensPerBlock * maxBlocksPerSeq; - // auto constexpr maxAttentionWindow = maxNumTokens / 2; auto constexpr maxAttentionWindow = 46; - auto constexpr totalNumBlocks = maxNumSequences * maxBlocksPerSeq; auto constexpr blocksInSecondaryPool = 0; auto constexpr onboardBlocks = true; - + auto constexpr maxSequenceLength = 256; + auto constexpr maxBlocksPerSeq = tc::ceilDiv(maxSequenceLength, tokensPerBlock); + auto constexpr totalNumBlocks = maxNumSequences * maxBlocksPerSeq; auto const blocksPerWindow = BlocksPerWindow{{maxAttentionWindow, {totalNumBlocks, blocksInSecondaryPool}}}; KVCacheManager kvCacheManager = homogeneousLayers ? KVCacheManager(numLayers, numHeads, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, maxBeamWidth, std::vector{maxAttentionWindow}, std::nullopt, - nvinfer1::DataType::kHALF, sinkTokenLength, stream, std::nullopt, kv_cache_block_reuse, + nvinfer1::DataType::kHALF, sinkTokenLength, stream, maxSequenceLength, kv_cache_block_reuse, onboardBlocks) : KVCacheManager(numHeadsPerLayer, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, maxBeamWidth, std::vector{maxAttentionWindow}, std::nullopt, - nvinfer1::DataType::kHALF, sinkTokenLength, stream, std::nullopt, kv_cache_block_reuse, + nvinfer1::DataType::kHALF, sinkTokenLength, stream, maxSequenceLength, kv_cache_block_reuse, onboardBlocks); kvCacheManager.allocatePools(false); - EXPECT_EQ(kvCacheManager.getOffsetTableDimensions().maxBlocksPerSeq, - tc::ceilDiv(maxAttentionWindow, tokensPerBlock)); + EXPECT_EQ(kvCacheManager.getOffsetTableDimensions().maxBlocksPerSeq, maxBlocksPerSeq); auto inputTokens = std::make_shared(VecTokens(inputLength, 0)); @@ -4107,9 +3946,10 @@ void testNeededBlocksOneStep(bool kv_cache_block_reuse, int beamWidth, int draft auto remainingBlocksToCompletion = kvCacheManager.getRemainingBlocksToCompletion(*llmRequest, onlyWindowSize); auto neededBlocksOneStep = kvCacheManager.getNeededBlocksOneStep(*llmRequest, false, onlyWindowSize); + auto currentNumAllocTotalBlocks = kvCacheManager.getNumAllocTotalBlocks(); EXPECT_NO_THROW(kvCacheManager.addSequence(requestId, inputLength, maxBeamWidth, llmRequest)); - for (int di = 0; di < draftLen && di < maxNewTokens && (inputLength + di) < maxAttentionWindow; ++di) + for (int di = 0; di < draftLen && di < maxNewTokens; ++di) { for (int beam = 0; beam < maxBeamWidth; beam++) { @@ -4118,20 +3958,20 @@ void testNeededBlocksOneStep(bool kv_cache_block_reuse, int beamWidth, int draft EXPECT_NO_THROW(kvCacheManager.addToken(requestId)); } - auto numUsedBlocksThisStep = kvCacheManager.getUsedNumBlocks(); + auto numUsedBlocksThisStep = kvCacheManager.getNumAllocTotalBlocks() - currentNumAllocTotalBlocks; EXPECT_EQ(numUsedBlocksThisStep, neededBlocksOneStep); // Simulate adding new tokens during generation llmRequest->setState(LlmRequestState::kGENERATION_IN_PROGRESS); for (int i = draftLen; i < maxNewTokens && (inputLength + i) < maxAttentionWindow; i += (draftLen + 1)) { - auto numCurrentlyUsedBlocks = kvCacheManager.getUsedNumBlocks(); for (int beam = 0; beam < maxBeamWidth; beam++) { llmRequest->addNewToken(1, beam); } neededBlocksOneStep = kvCacheManager.getNeededBlocksOneStep(*llmRequest, false, onlyWindowSize); + currentNumAllocTotalBlocks = kvCacheManager.getNumAllocTotalBlocks(); for (int beam = 0; beam < maxBeamWidth; beam++) { @@ -4147,14 +3987,27 @@ void testNeededBlocksOneStep(bool kv_cache_block_reuse, int beamWidth, int draft { EXPECT_NO_THROW(kvCacheManager.addToken(requestId)); } - numUsedBlocksThisStep = kvCacheManager.getUsedNumBlocks() - numCurrentlyUsedBlocks; + numUsedBlocksThisStep = kvCacheManager.getNumAllocTotalBlocks() - currentNumAllocTotalBlocks; - EXPECT_EQ(numUsedBlocksThisStep, neededBlocksOneStep); + if (inputLength + i + draftLen + 1 < maxAttentionWindow) + { + EXPECT_EQ(numUsedBlocksThisStep, neededBlocksOneStep); + } + else + { + // This test calculates neededBlocksOneStep for the entire step (which may exceed + // maxAttentionWindow), but adds tokens only up to maxAttentionWindow. In this case, + // numUsedBlocksThisStep may be smaller than neededBlocksOneStep by 1 block. + ASSERT_THAT(numUsedBlocksThisStep, + testing::AnyOf(testing::Eq(neededBlocksOneStep), testing::Eq(neededBlocksOneStep - 1))); + } } - // After adding all tokens, we should match remainingBlocksToCompletion - EXPECT_EQ(remainingBlocksToCompletion, kvCacheManager.getUsedNumBlocks()); - EXPECT_EQ(kvCacheManager.getRemainingBlocksToCompletion(*llmRequest, onlyWindowSize), 0); + // After adding tokens, initial remainingBlocksToCompletion should match current state + new + // remainingBlocksToCompletion + EXPECT_EQ(remainingBlocksToCompletion, + kvCacheManager.getNumAllocTotalBlocks() + + kvCacheManager.getRemainingBlocksToCompletion(*llmRequest, onlyWindowSize)); } } } @@ -4199,11 +4052,11 @@ TEST_P(BlockRequirementsParamTest, TestCaculateMaxBlocksRequirement) INSTANTIATE_TEST_SUITE_P(CalculateMaxBlockRequirementsPerBeam, BlockRequirementsParamTest, testing::Values(std::make_tuple(512, 0, 1024, 64, 8), std::make_tuple(513, 0, 1024, 64, 9), - std::make_tuple(512, 0, 256, 64, 4), std::make_tuple(512, 0, 257, 64, 5), std::make_tuple(512, 64, 1024, 64, 8), - std::make_tuple(513, 64, 1024, 64, 9), std::make_tuple(512, 64, 256, 64, 4), - std::make_tuple(512, 64, 257, 64, 5), std::make_tuple(512, 65, 1024, 64, 9), - std::make_tuple(513, 65, 1024, 64, 9), std::make_tuple(512, 65, 256, 64, 5), - std::make_tuple(512, 65, 257, 64, 5))); + std::make_tuple(512, 0, 256, 64, 5), std::make_tuple(512, 0, 257, 64, 6), std::make_tuple(512, 64, 1024, 64, 8), + std::make_tuple(513, 64, 1024, 64, 9), std::make_tuple(512, 64, 256, 64, 5), + std::make_tuple(512, 64, 257, 64, 6), std::make_tuple(512, 65, 1024, 64, 9), + std::make_tuple(513, 65, 1024, 64, 9), std::make_tuple(512, 65, 256, 64, 6), + std::make_tuple(512, 65, 257, 64, 6))); // calculateMaxBlockRequirements TEST(CalculateMaxBlockRequirements, BeamWidthOneEqualRequirementsPerBeam) @@ -4233,7 +4086,7 @@ TEST(CalculateMaxBlockRequirements, AttentionWindowOverlapsInputAndOutputReferen auto const numContextBlocks = 2; // (412 - 255) / 64 // There are 29 context tokens left over to be put in output blocks, so 284 tokens to fit in output blocks in // total: 5 blocks - auto const numOutputBlocks = 5 * beamWidth; + auto const numOutputBlocks = (5 + kSWAExtraBlock) * beamWidth; ASSERT_EQ(result, numContextBlocks + numOutputBlocks); } @@ -4304,6 +4157,7 @@ std::shared_ptr createKvCacheManager( auto const temporaryKvCacheInputs = TempAttentionWindowInputs{true, maxInputLength, kvCacheInstantiationParameters.maxNumTokens}; + auto const maxSequenceLength = kvCacheInstantiationParameters.maxNumTokens; auto const maxAttentionWindow = kvCacheInstantiationParameters.maxAttentionWindow; auto const [numBlocksInPrimaryPool, _] = kvCacheInstantiationParameters.blocksPerWindow.at(maxAttentionWindow); @@ -4316,8 +4170,8 @@ std::shared_ptr createKvCacheManager( kvCacheInstantiationParameters.tokensPerBlock, kvCacheInstantiationParameters.blocksPerWindow, numBlocksInPrimaryPool, kvCacheInstantiationParameters.maxBeamWidth, std::vector{kvCacheInstantiationParameters.maxAttentionWindow}, temporaryKvCacheInputs, - kvCacheInstantiationParameters.dtype, kvCacheInstantiationParameters.sinkTokenLength, stream, std::nullopt, - kvCacheInstantiationParameters.kvCacheBlockReuse, true, CacheType::kSELF); + kvCacheInstantiationParameters.dtype, kvCacheInstantiationParameters.sinkTokenLength, stream, + maxSequenceLength, kvCacheInstantiationParameters.kvCacheBlockReuse, true, CacheType::kSELF); } if (std::holds_alternative>(kvCacheInstantiationParameters.numHeadsPerLayer)) { @@ -4327,8 +4181,8 @@ std::shared_ptr createKvCacheManager( kvCacheInstantiationParameters.tokensPerBlock, kvCacheInstantiationParameters.blocksPerWindow, numBlocksInPrimaryPool, kvCacheInstantiationParameters.maxBeamWidth, std::vector{kvCacheInstantiationParameters.maxAttentionWindow}, temporaryKvCacheInputs, - kvCacheInstantiationParameters.dtype, kvCacheInstantiationParameters.sinkTokenLength, stream, std::nullopt, - kvCacheInstantiationParameters.kvCacheBlockReuse, true, CacheType::kSELF); + kvCacheInstantiationParameters.dtype, kvCacheInstantiationParameters.sinkTokenLength, stream, + maxSequenceLength, kvCacheInstantiationParameters.kvCacheBlockReuse, true, CacheType::kSELF); } TLLM_THROW("Unhandled type of num heads per layer provided."); } @@ -4801,3 +4655,316 @@ auto const paramValues = ::testing::Values( }); INSTANTIATE_TEST_SUITE_P(FillKvCacheAndCompleteRequestsTest, FillKvCacheAndCompleteRequestsTest, paramValues); + +namespace +{ +struct GetNeededBlocksOneStepOneRequestParameters +{ + KvCacheManagerInstantiationParameters kvCacheManagerInstantiationParameters; + SizeType32 promptLength; + SizeType32 draftLength; + bool contextStep; + SizeType32 previousGeneratedTokens; + bool twoStepsLookAhead; + SizeType32 expectedNeededBlocksOneStep; +}; +} // namespace + +class NeededBlocksOneStepTest : public ::testing::TestWithParam +{ +protected: + void SetUp() override + { + auto const stream = std::make_shared(); + auto const params = GetParam(); + kvCacheManager = createKvCacheManager(params.kvCacheManagerInstantiationParameters, stream); + kvCacheManager->allocatePools(/*useUvm=*/false); + } + + void TearDown() override {} + + std::shared_ptr kvCacheManager; +}; + +TEST_P(NeededBlocksOneStepTest, NeededBlocksOneStepTestCorrectlyEstimated) +{ + auto const params = GetParam(); + auto const onlyWindowSize = theOnlyWindowSize(*kvCacheManager); + auto const requestId = 0; + auto const inputTokens = std::make_shared>(static_cast(params.promptLength)); + auto llmRequest = LlmRequest{ + requestId, + params.kvCacheManagerInstantiationParameters.maxNumTokens, + inputTokens, + tensorrt_llm::runtime::SamplingConfig{params.kvCacheManagerInstantiationParameters.maxBeamWidth}, + true, + }; + auto draftTokens = std::make_shared>(params.draftLength); + llmRequest.setDraftTokens(draftTokens); + if (params.contextStep) + { + auto neededBlocksOneStep = kvCacheManager->getNeededBlocksOneStep(llmRequest, false, onlyWindowSize); + ASSERT_EQ(neededBlocksOneStep, params.expectedNeededBlocksOneStep); + } + else + { + kvCacheManager->addSequence( + requestId, params.promptLength, params.kvCacheManagerInstantiationParameters.maxBeamWidth, llmRequest); + llmRequest.setState(LlmRequestState::kGENERATION_IN_PROGRESS); + for (int beam = 0; beam < params.kvCacheManagerInstantiationParameters.maxBeamWidth; beam++) + { + for (SizeType32 i = 0; i < params.previousGeneratedTokens; i++) + { + llmRequest.addNewToken(0, beam); + kvCacheManager->addToken(llmRequest.mRequestId); + } + } + + auto neededBlocksOneStep + = kvCacheManager->getNeededBlocksOneStep(llmRequest, params.twoStepsLookAhead, onlyWindowSize); + ASSERT_EQ(neededBlocksOneStep, params.expectedNeededBlocksOneStep); + } +} + +INSTANTIATE_TEST_SUITE_P(NeededBlocksOneStepTestCorrectlyEstimated, NeededBlocksOneStepTest, + ::testing::Values( + GetNeededBlocksOneStepOneRequestParameters{ + KvCacheManagerInstantiationParameters{ + /* numLayers */ 1, + /* numHeads */ 1, + /* sizePerHead */ 1, + /* tokensPerBlock */ 16, + /* blocksPerWindow */ blocksAndWindow(/* numPrimaryBlocks */ 256, /* windowSize */ 512), + /* sinkTokenLength */ 0, + /* maxAttentionWindow */ 512, + /* maxBeamWidth */ 1, + /* maxNumTokens */ 513, + /* kvCacheBlockReuse */ false, + }, + /* promptLength */ 136, + /* draftLength */ 0, + /* contextStep */ true, + /* previousGeneratedTokens */ 0, + /* twoStepsLookAhead */ false, + /* expectedNeededBlocksOneStep */ 9, + }, + GetNeededBlocksOneStepOneRequestParameters{ + KvCacheManagerInstantiationParameters{ + /* numLayers */ 1, + /* numHeads */ 1, + /* sizePerHead */ 1, + /* tokensPerBlock */ 16, + /* blocksPerWindow */ blocksAndWindow(/* numPrimaryBlocks */ 256, /* windowSize */ 512), + /* sinkTokenLength */ 0, + /* maxAttentionWindow */ 512, + /* maxBeamWidth */ 1, + /* maxNumTokens */ 513, + /* kvCacheBlockReuse */ false, + }, + /* promptLength */ 512, + /* draftLength */ 0, + /* contextStep */ true, + /* previousGeneratedTokens */ 0, + /* twoStepsLookAhead */ false, + /* expectedNeededBlocksOneStep */ 32, + }, + GetNeededBlocksOneStepOneRequestParameters{ + KvCacheManagerInstantiationParameters{ + /* numLayers */ 1, + /* numHeads */ 1, + /* sizePerHead */ 1, + /* tokensPerBlock */ 16, + /* blocksPerWindow */ blocksAndWindow(/* numPrimaryBlocks */ 256, /* windowSize */ 512), + /* sinkTokenLength */ 0, + /* maxAttentionWindow */ 512, + /* maxBeamWidth */ 1, + /* maxNumTokens */ 513, + /* kvCacheBlockReuse */ false, + }, + /* promptLength */ 1024, + /* draftLength */ 0, + /* contextStep */ true, + /* previousGeneratedTokens */ 0, + /* twoStepsLookAhead */ false, + /* expectedNeededBlocksOneStep */ 64, + }, + GetNeededBlocksOneStepOneRequestParameters{ + KvCacheManagerInstantiationParameters{ + /* numLayers */ 1, + /* numHeads */ 1, + /* sizePerHead */ 1, + /* tokensPerBlock */ 16, + /* blocksPerWindow */ blocksAndWindow(/* numPrimaryBlocks */ 256, /* windowSize */ 512), + /* sinkTokenLength */ 0, + /* maxAttentionWindow */ 512, + /* maxBeamWidth */ 1, + /* maxNumTokens */ 513, + /* kvCacheBlockReuse */ false, + }, + /* promptLength */ 512, + /* draftLength */ 0, + /* contextStep */ false, + /* previousGeneratedTokens */ 0, + /* twoStepsLookAhead */ false, + /* expectedNeededBlocksOneStep */ 1, + }, + GetNeededBlocksOneStepOneRequestParameters{ + KvCacheManagerInstantiationParameters{ + /* numLayers */ 1, + /* numHeads */ 1, + /* sizePerHead */ 1, + /* tokensPerBlock */ 16, + /* blocksPerWindow */ blocksAndWindow(/* numPrimaryBlocks */ 256, /* windowSize */ 512), + /* sinkTokenLength */ 0, + /* maxAttentionWindow */ 512, + /* maxBeamWidth */ 1, + /* maxNumTokens */ 513, + /* kvCacheBlockReuse */ false, + }, + /* promptLength */ 512, + /* draftLength */ 0, + /* contextStep */ false, + /* previousGeneratedTokens */ 8, + /* twoStepsLookAhead */ false, + /* expectedNeededBlocksOneStep */ 0, + }, + GetNeededBlocksOneStepOneRequestParameters{ + KvCacheManagerInstantiationParameters{ + /* numLayers */ 1, + /* numHeads */ 1, + /* sizePerHead */ 1, + /* tokensPerBlock */ 16, + /* blocksPerWindow */ blocksAndWindow(/* numPrimaryBlocks */ 256, /* windowSize */ 512), + /* sinkTokenLength */ 0, + /* maxAttentionWindow */ 512, + /* maxBeamWidth */ 1, + /* maxNumTokens */ 513, + /* kvCacheBlockReuse */ false, + }, + /* promptLength */ 518, + /* draftLength */ 0, + /* contextStep */ false, + /* previousGeneratedTokens */ 0, + /* twoStepsLookAhead */ false, + /* expectedNeededBlocksOneStep */ 0, + }, + GetNeededBlocksOneStepOneRequestParameters{ + KvCacheManagerInstantiationParameters{ + /* numLayers */ 1, + /* numHeads */ 1, + /* sizePerHead */ 1, + /* tokensPerBlock */ 16, + /* blocksPerWindow */ blocksAndWindow(/* numPrimaryBlocks */ 256, /* windowSize */ 512), + /* sinkTokenLength */ 0, + /* maxAttentionWindow */ 512, + /* maxBeamWidth */ 1, + /* maxNumTokens */ 530, + /* kvCacheBlockReuse */ false, + }, + /* promptLength */ 512, + /* draftLength */ 0, + /* contextStep */ false, + /* previousGeneratedTokens */ 16, + /* twoStepsLookAhead */ false, + /* expectedNeededBlocksOneStep */ 1, + }, + GetNeededBlocksOneStepOneRequestParameters{ + KvCacheManagerInstantiationParameters{ + /* numLayers */ 1, + /* numHeads */ 1, + /* sizePerHead */ 1, + /* tokensPerBlock */ 16, + /* blocksPerWindow */ blocksAndWindow(/* numPrimaryBlocks */ 256, /* windowSize */ 512), + /* sinkTokenLength */ 0, + /* maxAttentionWindow */ 512, + /* maxBeamWidth */ 1, + /* maxNumTokens */ 513, + /* kvCacheBlockReuse */ false, + }, + /* promptLength */ 128, + /* draftLength */ 0, + /* contextStep */ false, + /* previousGeneratedTokens */ 15, + /* twoStepsLookAhead */ false, + /* expectedNeededBlocksOneStep */ 0, + }, + GetNeededBlocksOneStepOneRequestParameters{ + KvCacheManagerInstantiationParameters{ + /* numLayers */ 1, + /* numHeads */ 1, + /* sizePerHead */ 1, + /* tokensPerBlock */ 16, + /* blocksPerWindow */ blocksAndWindow(/* numPrimaryBlocks */ 256, /* windowSize */ 512), + /* sinkTokenLength */ 0, + /* maxAttentionWindow */ 512, + /* maxBeamWidth */ 1, + /* maxNumTokens */ 513, + /* kvCacheBlockReuse */ false, + }, + /* promptLength */ 128, + /* draftLength */ 0, + /* contextStep */ false, + /* previousGeneratedTokens */ 15, + /* twoStepsLookAhead */ true, + /* expectedNeededBlocksOneStep */ 1, + }, + GetNeededBlocksOneStepOneRequestParameters{ + KvCacheManagerInstantiationParameters{ + /* numLayers */ 1, + /* numHeads */ 1, + /* sizePerHead */ 1, + /* tokensPerBlock */ 16, + /* blocksPerWindow */ blocksAndWindow(/* numPrimaryBlocks */ 256, /* windowSize */ 512), + /* sinkTokenLength */ 0, + /* maxAttentionWindow */ 512, + /* maxBeamWidth */ 1, + /* maxNumTokens */ 513, + /* kvCacheBlockReuse */ false, + }, + /* promptLength */ 128, + /* draftLength */ 0, + /* contextStep */ false, + /* previousGeneratedTokens */ 15, + /* twoStepsLookAhead */ true, + /* expectedNeededBlocksOneStep */ 1, + }, + GetNeededBlocksOneStepOneRequestParameters{ + KvCacheManagerInstantiationParameters{ + /* numLayers */ 1, + /* numHeads */ 1, + /* sizePerHead */ 1, + /* tokensPerBlock */ 16, + /* blocksPerWindow */ blocksAndWindow(/* numPrimaryBlocks */ 256, /* windowSize */ 512), + /* sinkTokenLength */ 0, + /* maxAttentionWindow */ 512, + /* maxBeamWidth */ 1, + /* maxNumTokens */ 513, + /* kvCacheBlockReuse */ false, + }, + /* promptLength */ 302, // 14 tokens in last block + /* draftLength */ 3, + /* contextStep */ false, + /* previousGeneratedTokens */ 0, + /* twoStepsLookAhead */ false, + /* expectedNeededBlocksOneStep */ 1, + }, + GetNeededBlocksOneStepOneRequestParameters{ + KvCacheManagerInstantiationParameters{ + /* numLayers */ 1, + /* numHeads */ 1, + /* sizePerHead */ 1, + /* tokensPerBlock */ 16, + /* blocksPerWindow */ blocksAndWindow(/* numPrimaryBlocks */ 256, /* windowSize */ 512), + /* sinkTokenLength */ 0, + /* maxAttentionWindow */ 512, + /* maxBeamWidth */ 1, + /* maxNumTokens */ 513, + /* kvCacheBlockReuse */ false, + }, + /* promptLength */ 298, // 10 tokens in last block + /* draftLength */ 3, + /* contextStep */ false, + /* previousGeneratedTokens */ 0, + /* twoStepsLookAhead */ false, + /* expectedNeededBlocksOneStep */ 0, + })); diff --git a/cpp/tests/unit_tests/executor/agentCommTest.cpp b/cpp/tests/unit_tests/executor/agentCommTest.cpp index ee561ca816b..fd7a7a23de7 100644 --- a/cpp/tests/unit_tests/executor/agentCommTest.cpp +++ b/cpp/tests/unit_tests/executor/agentCommTest.cpp @@ -83,7 +83,7 @@ class AgentCommTest : public ::testing::Test mCacheManager = std::make_unique(numLayers, numHeads, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, maxBeamWidth, std::vector{maxAttentionWindow}, - std::nullopt, dataType, sinkTokenLength, stream, std::nullopt, enableBlockReuse, onboardBlocks, cacheType, + std::nullopt, dataType, sinkTokenLength, stream, kvMaxNumTokens, enableBlockReuse, onboardBlocks, cacheType, std::nullopt, nullptr, true); mCacheManager->allocatePools(false); diff --git a/cpp/tests/unit_tests/multi_gpu/cacheTransceiverTest.cpp b/cpp/tests/unit_tests/multi_gpu/cacheTransceiverTest.cpp index 898b7ddbd40..21852a4e498 100644 --- a/cpp/tests/unit_tests/multi_gpu/cacheTransceiverTest.cpp +++ b/cpp/tests/unit_tests/multi_gpu/cacheTransceiverTest.cpp @@ -217,7 +217,7 @@ class SymmetricalCacheTest : public ::testing::Test // NOLINT(cppcoreguidelines- mManager = std::make_unique(numLayers, numHeads, sizePerHead, tokensPerBlock, blocksPerWindow, mMaxNumSequences, maxBeamWidth, std::vector{maxAttentionWindow}, std::nullopt, - dataType, sinkTokenLength, stream, std::nullopt, enableBlockReuse, onboardBlocks, CacheType::kSELF, + dataType, sinkTokenLength, stream, maxNumTokens, enableBlockReuse, onboardBlocks, CacheType::kSELF, std::nullopt, nullptr, true); auto attentionLayerNumPerPP = std::vector{numLayers}; mCacheState = std::make_unique( @@ -619,7 +619,7 @@ class AsymmetricalCacheTest : public ::testing::TestWithParam(layerNumthisRank, numHeadsPerRank, sizePerHead, tokensPerBlock, blocksPerWindow, mMaxNumSequences, maxBeamWidth, maxAttentionWindowVec, std::nullopt, dataType, - sinkTokenLength, stream, std::nullopt, enableBlockReuse, onboardBlocks, cacheType, std::nullopt, nullptr, + sinkTokenLength, stream, maxNumTokens, enableBlockReuse, onboardBlocks, cacheType, std::nullopt, nullptr, true); texec::kv_cache::CacheState::AttentionType attentionType = isMLA ? texec::kv_cache::CacheState::AttentionType::kMLA @@ -1313,23 +1313,27 @@ TEST_P(AsymmetricalCacheTestWithDP, TestCase) tensorrt_llm::mpi::MpiComm::world().barrier(); } +// (eop) Waive off isWindow test for now INSTANTIATE_TEST_CASE_P(AsymmetricCaseTest0, AsymmetricalCacheTest, testing::Combine(testing::Values(1, 2), testing::Values(1, 2), testing::Values(1), testing::Values(1, 2), testing::Values(1, 2), testing::Values(1), testing::Values(4), testing::Values(4), testing::Values(4), testing::Values(16), testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(2), - testing::Values(false), testing::Values(false), testing::Values(false), testing::Values(true, false))); + testing::Values(false), testing::Values(false), testing::Values(false), testing::Values(/*true,*/ false))); -INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithWindow, AsymmetricalCacheTest, - testing::Combine(testing::Values(1), testing::Values(1), testing::Values(1), testing::Values(1), testing::Values(1), - testing::Values(1), testing::Values(5), testing::Values(4), testing::Values(4), testing::Values(8), - testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(2), - testing::Values(false), testing::Values(false), testing::Values(false), testing::Values(true))); +// (eop) Waive off isWindow test for now +// INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithWindow, AsymmetricalCacheTest, +// testing::Combine(testing::Values(1), testing::Values(1), testing::Values(1), testing::Values(1), +// testing::Values(1), +// testing::Values(1), testing::Values(5), testing::Values(4), testing::Values(4), testing::Values(8), +// testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(2), +// testing::Values(false), testing::Values(false), testing::Values(false), testing::Values(true))); +// (eop) Waive off isWindow test for now INSTANTIATE_TEST_CASE_P(AsymmetricCaseTest1, AsymmetricalCacheTest, testing::Combine(testing::Values(4), testing::Values(1), testing::Values(1), testing::Values(1), testing::Values(4), testing::Values(1), testing::Values(8), testing::Values(4), testing::Values(4), testing::Values(8), testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(2), - testing::Values(false), testing::Values(false), testing::Values(false), testing::Values(false, true))); + testing::Values(false), testing::Values(false), testing::Values(false), testing::Values(false /*, true*/))); INSTANTIATE_TEST_CASE_P(AsymmetricCaseTest1EvenLayer, AsymmetricalCacheTest, testing::Combine(testing::Values(1), testing::Values(4), testing::Values(1), testing::Values(1), testing::Values(4), diff --git a/examples/models/core/llama/summarize_long.py b/examples/models/core/llama/summarize_long.py index cee2e07fdd5..45558587180 100644 --- a/examples/models/core/llama/summarize_long.py +++ b/examples/models/core/llama/summarize_long.py @@ -45,7 +45,7 @@ def parse_args(): type=int, default=4096, help= - 'The attention window size that controls the sliding window attention / cyclic kv cache behavior' + 'The attention window size that controls the sliding window attention kv cache behavior' ) parser.add_argument( '--max_input_len', diff --git a/examples/models/core/qwen2audio/utils.py b/examples/models/core/qwen2audio/utils.py index 607d2fc3989..3252beebbf7 100644 --- a/examples/models/core/qwen2audio/utils.py +++ b/examples/models/core/qwen2audio/utils.py @@ -38,7 +38,7 @@ def add_common_args(parser): default=None, nargs="+", help= - 'The attention window size that controls the sliding window attention / cyclic kv cache behavior' + 'The attention window size that controls the sliding window attention kv cache behavior' ) parser.add_argument( '--multi_block_mode', diff --git a/examples/utils.py b/examples/utils.py index 509b734ebea..8956e4979e0 100644 --- a/examples/utils.py +++ b/examples/utils.py @@ -358,7 +358,7 @@ def add_common_args(parser): default=None, nargs="+", help= - 'The attention window size that controls the sliding window attention / cyclic kv cache behavior' + 'The attention window size that controls the sliding window attention kv cache behavior' ) parser.add_argument( '--multi_block_mode', diff --git a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py index e8d68a59381..3b83dcb72ed 100644 --- a/tensorrt_llm/_torch/pyexecutor/_util.py +++ b/tensorrt_llm/_torch/pyexecutor/_util.py @@ -332,13 +332,9 @@ def configure_kv_cache_capacity(self, py_executor: PyExecutor) -> None: logger.info( f"max_tokens={self._max_kv_tokens_in} is provided, max_memory is set to {kv_cache_max_memory / (GB):.2f} GiB" ) - if is_vswa: - # For VSWA KvCacheManager now it can only use max_gpu_total_bytes - self._kv_cache_config.max_tokens = None - else: - # For non-VSWA KvCacheManager, its logic still relies on max_tokens, need to improve in the future. - self._kv_cache_config.max_tokens = int( - kv_cache_max_memory // self._get_kv_size_per_token()) + # For KvCacheManager, its logic still relies on max_tokens, need to improve in the future. + self._kv_cache_config.max_tokens = int(kv_cache_max_memory // + self._get_kv_size_per_token()) # ---------------------------handle max_tokens--------------------------------- # ---------------------------handle max_gpu_total_bytes--------------------------------- diff --git a/tensorrt_llm/functional.py b/tensorrt_llm/functional.py index 2492eb6a61b..685cd469f64 100755 --- a/tensorrt_llm/functional.py +++ b/tensorrt_llm/functional.py @@ -5358,7 +5358,7 @@ def gpt_attention( An INT32 tensor of shape [1]. by default, the max_attention_window_size is determined by the shape of cache_indir_table. And we support independent max_attention_window_size for each layer. - This controls the sliding-window-attention/cyclic-kv-cache features. + This controls the sliding-window-attention kv-cache features. context_lengths: Tensor (On GPU) The tensor that stores the context-phase sequence length of each request. Its shape diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py index 2e295473bf9..6e6d939f7bb 100644 --- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py +++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py @@ -1057,28 +1057,40 @@ def test_fp8_prequantized(self): task = MMLU(self.MODEL_NAME) task.evaluate(llm) - @pytest.mark.skip( - reason= - "Skipped because cyclic kv cache is disabled on the feature branch") - def test_auto_dtype_vswa(self): - # # NOTE: Test with VSWA kv cache config. - # self.kv_cache_config.max_attention_window = [ - # 512, 512, 512, 512, 512, 32768 - # ] # Gemma3 1B attention window size pattern - # # TODO: uncomment to use the real window pattern when optimal KV cache allocation is supported + def test_auto_dtype_vswa_without_reuse(self): + # NOTE: Test with VSWA kv cache config. + kv_cache_config = KvCacheConfig( + enable_block_reuse=False, + enable_partial_reuse=False, + max_attention_window=[512, 512, 512, 512, 512, 32768], + ) + + with LLM(self.MODEL_PATH, kv_cache_config=kv_cache_config) as llm: + task = GSM8K(self.MODEL_NAME) + task.evaluate(llm) + task = MMLU(self.MODEL_NAME) + task.evaluate(llm) - with LLM(self.MODEL_PATH, kv_cache_config=self.kv_cache_config) as llm: + def test_auto_dtype_vswa_reuse(self): + # NOTE: Test with VSWA kv cache config. + kv_cache_config = KvCacheConfig( + enable_block_reuse=True, + max_attention_window=[512, 512, 512, 512, 512, 32768], + ) + + with LLM(self.MODEL_PATH, kv_cache_config=kv_cache_config) as llm: task = GSM8K(self.MODEL_NAME) task.evaluate(llm) task = MMLU(self.MODEL_NAME) task.evaluate(llm) - def test_auto_dtype_chunked_prefill(self): - # # NOTE: Test with VSWA kv cache config. - # self.kv_cache_config.max_attention_window = [ - # 512, 512, 512, 512, 512, 32768 - # ] # Gemma3 1B attention window size pattern - # # TODO: uncomment to use the real window pattern when optimal KV cache allocation is supported + def test_auto_dtype_vswa_chunked_prefill_without_reuse(self): + # NOTE: Test with VSWA kv cache config. + kv_cache_config = KvCacheConfig( + enable_block_reuse=False, + enable_partial_reuse=False, + max_attention_window=[512, 512, 512, 512, 512, 32768], + ) # chunked prefill case or more features extra_llm_config = dict( @@ -1086,7 +1098,27 @@ def test_auto_dtype_chunked_prefill(self): max_num_tokens=1024, ) with LLM(self.MODEL_PATH, - kv_cache_config=self.kv_cache_config, + kv_cache_config=kv_cache_config, + **extra_llm_config) as llm: + task = GSM8K(self.MODEL_NAME) + task.evaluate(llm) + task = MMLU(self.MODEL_NAME) + task.evaluate(llm) + + def test_auto_dtype_vswa_chunked_prefill_reuse(self): + # NOTE: Test with VSWA kv cache config. + kv_cache_config = KvCacheConfig( + enable_block_reuse=True, + max_attention_window=[512, 512, 512, 512, 512, 32768], + ) + + # chunked prefill case or more features + extra_llm_config = dict( + enable_chunked_prefill=True, + max_num_tokens=1024, + ) + with LLM(self.MODEL_PATH, + kv_cache_config=kv_cache_config, **extra_llm_config) as llm: task = GSM8K(self.MODEL_NAME) task.evaluate(llm) diff --git a/tests/integration/test_lists/test-db/l0_h100.yml b/tests/integration/test_lists/test-db/l0_h100.yml index a153971cdc8..d70b2542d27 100644 --- a/tests/integration/test_lists/test-db/l0_h100.yml +++ b/tests/integration/test_lists/test-db/l0_h100.yml @@ -36,8 +36,10 @@ l0_h100: - unittest/disaggregated/test_router.py - unittest/disaggregated/test_remoteDictionary.py - accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype - - accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype_vswa - - accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype_chunked_prefill + - accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype_vswa_without_reuse + - accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype_vswa_reuse + - accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype_vswa_chunked_prefill_without_reuse + - accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype_vswa_chunked_prefill_reuse - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16[attn_backend=TRTLLM-torch_compile=False] - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16[attn_backend=TRTLLM-torch_compile=True] - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_chunked_prefill[attn_backend=TRTLLM] TIMEOUT (90)