4141#include < optional>
4242#include < set>
4343#include < unordered_map>
44+ #include < utility>
4445#include < vector>
4546
4647namespace kvc = tensorrt_llm::executor::kv_cache;
@@ -84,6 +85,32 @@ using MmKey = std::pair<std::array<uint8_t, 32>, SizeType32>;
8485template <typename T>
8586using OptionalRef = tensorrt_llm::common::OptionalRef<T>;
8687
88+ // ! \brief Split vector into list of blocks of given size.
89+ // ! \param vec vector to split
90+ // ! \param usableSize part of the vector that is processed
91+ // ! \param elementsPerBlock desired size of blocks
92+ // ! \param allowPartial whether to append a block smaller than `elementsPerBlock` at the end
93+ // ! \return list of blocks
94+ template <typename T>
95+ std::list<std::vector<T>> chopVectorIntoBlocks (
96+ std::vector<T> const & vec, SizeType32 usableSize, SizeType32 elementsPerBlock, bool allowPartial)
97+ {
98+ TLLM_CHECK_WITH_INFO (
99+ usableSize <= static_cast <SizeType32>(vec.size ()), " usableSize=%d > %ld=vec.size()" , usableSize, vec.size ());
100+ std::list<std::vector<T>> blockedVectors;
101+ auto const vecEnd = vec.begin () + usableSize;
102+ for (auto begin = vec.begin (); begin < vecEnd; begin += elementsPerBlock)
103+ {
104+ auto blockSize = std::min (elementsPerBlock, static_cast <SizeType32>(std::distance (begin, vecEnd)));
105+ auto end = begin + blockSize;
106+ if (blockSize == elementsPerBlock || allowPartial)
107+ {
108+ blockedVectors.emplace_back (begin, end);
109+ }
110+ }
111+ return blockedVectors;
112+ }
113+
87114struct TempAttentionWindowInputs
88115{
89116 bool pagedContextFMHA;
@@ -114,6 +141,9 @@ struct WindowSizeMetadata
114141 }
115142};
116143
144+ std::vector<MmKey> generateBlockHashExtraKeys (
145+ tensorrt_llm::batch_manager::LlmRequest const & llmRequest, SizeType32 startTokenIdx, SizeType32 endTokenIdx);
146+
117147struct BlockKey
118148{
119149 bool usesExtraIds = false ;
@@ -147,11 +177,7 @@ struct BlockKey
147177 {
148178 }
149179
150- bool operator ==(BlockKey const & other) const noexcept
151- {
152- return (usesExtraIds == other.usesExtraIds && loraTaskId == other.loraTaskId
153- && uniqueTokens == other.uniqueTokens && extraKeys == other.extraKeys && cacheSaltID == other.cacheSaltID );
154- }
180+ bool operator ==(BlockKey const & other) const noexcept ;
155181
156182 int partialMatch (BlockKey const & other) const noexcept
157183 {
@@ -166,6 +192,8 @@ struct BlockKey
166192 }
167193};
168194
195+ std::vector<BlockKey> buildBlockKeys (std::list<VecUniqueTokens>& blockedUniqueTokens, LlmRequest const & llmRequest);
196+
169197// Implement hash functor for BlockKey.
170198// This allows us to use unordered_map with BlockKey as key.
171199// Based on https://stackoverflow.com/questions/20511347/a-good-hash-function-for-a-vector/72073933#72073933
@@ -577,14 +605,18 @@ class WindowBlockManager
577605
578606 void replaceSharedBlock (GenerationRequest& sequence, SizeType32 blockIdx);
579607
580- // ! \brief Get the ids of all newly allocated (not reused) blocks for the sequence.
581- std::vector<KVCacheBlock::IdType> getNewlyAllocatedBlockIds (GenerationRequest const & sequence) const ;
608+ [[nodiscard]] std::optional<KVCacheBlock::IdType> storeBlocksForReuse (
609+ GenerationRequest & sequence, OptionalRef<LlmRequest const > llmRequest, bool pinBlocks = false ) ;
582610
583611 void storeNewBlock (GenerationRequest& sequence, OptionalRef<LlmRequest const > llmRequest);
584612
613+ // ! \brief Pin blocks associated with a sequence to prevent eviction.
614+ void pinBlocks (GenerationRequest& sequence);
615+
585616 // ! \brief Release blocks of the sequence.
586617 // ! \details When llmRequest is provided and reuse is enabled, blocks will be stored.
587- void releaseBlocks (GenerationRequest& sequence, OptionalRef<LlmRequest const > llmRequest);
618+ std::optional<KVCacheBlock::IdType> releaseBlocks (
619+ GenerationRequest& sequence, OptionalRef<LlmRequest const > llmRequest);
588620
589621 // ! \brief Simulate freeing all blocks for that sequence to check impact on number of free blocks
590622 void schedulingReleaseBlocks (LlmRequest::RequestIdType requestId);
@@ -757,8 +789,11 @@ class WindowBlockManager
757789 // ! \brief Store blocks in cached blocks.
758790 // ! \param blockKeys Key of each block.
759791 // ! \param blockIds Id of each block.
760- // ! \return Number of actual blocks stored.
761- SizeType32 storeBlocks (std::vector<BlockKey> const & blockKeys, std::vector<KVCacheBlock::IdType> const & blockIds);
792+ // ! \param pinBlocks If true, increment ref count for blocks while storing (pin on store).
793+ // ! \return Pair of (num blocks stored for reuse, id of the last block stored if any).
794+ [[nodiscard]] std::pair<SizeType32, std::optional<KVCacheBlock::IdType>> storeBlocks (
795+ std::vector<BlockKey> const & blockKeys, std::vector<KVCacheBlock::IdType> const & blockIds,
796+ bool pinBlocks = false );
762797
763798 [[nodiscard]] bool verifyQueueIntegrity ();
764799
@@ -786,6 +821,11 @@ class WindowBlockManager
786821 return mIsSWA ;
787822 }
788823
824+ [[nodiscard]] std::shared_ptr<KVCacheBlock> findBlocksInReuseTreeByBlockKey (BlockKey const & blockKey);
825+
826+ // ! \brief Unpin blocks by starting from a block id and walking prev pointers.
827+ void unpinBlocksById (KVCacheBlock::IdType blockId);
828+
789829private:
790830 // ! \brief Add single block to beam of sequence and mAllocatedBlocksPerSeq.
791831 void addBlockToBeam (BlockPtr& block, GenerationRequest& sequence, SizeType32 beamIdx);
@@ -890,6 +930,9 @@ class WindowBlockManager
890930 bool mCopyOnPartialReuse ;
891931 // The kv cache connector manager
892932 std::shared_ptr<kv_connector::KvCacheConnectorManager> mKvCacheConnectorManager ;
933+
934+ // Mutex for the cached blocks root
935+ std::mutex mCachedBlocksRootMutex ;
893936};
894937
895938class BlockManager
@@ -940,13 +983,20 @@ class BlockManager
940983
941984 void replaceSharedBlock (GenerationRequest& sequence, SizeType32 windowSize, SizeType32 blockIdx);
942985
943- std::vector <KVCacheBlock::IdType> getNewlyAllocatedBlockIds (
944- GenerationRequest const & sequence, SizeType32 windowSize) const ;
986+ std::optional <KVCacheBlock::IdType> releaseBlocks (
987+ GenerationRequest& sequence, OptionalRef<LlmRequest const > llmRequest = std:: nullopt , bool pinBlocks = false ) ;
945988
946- void releaseBlocks (GenerationRequest& sequence, OptionalRef<LlmRequest const > llmRequest = std::nullopt );
989+ [[nodiscard]] std::optional<KVCacheBlock::IdType> storeBlocksForReuse (
990+ GenerationRequest& sequence, OptionalRef<LlmRequest const > llmRequest = std::nullopt , bool pinBlocks = false );
947991
948992 void schedulingReleaseBlocks (LlmRequest::RequestIdType requestId);
949993
994+ // / @brief Pin all blocks associated with a sequence across all window managers.
995+ // / @param sequence The generation request whose blocks should be pinned.
996+ void pinBlocks (GenerationRequest& sequence);
997+
998+ void unpinBlocksById (KVCacheBlock::IdType blockId);
999+
9501000 void releaseLastBlock (GenerationRequest& sequence, SizeType32 windowSize);
9511001
9521002 void setOffsets (kernels::KVCacheIndex* offsetsPtr, nvinfer1::Dims const & offsetsShape, SizeType32 beamIdx,
@@ -966,10 +1016,11 @@ class BlockManager
9661016 void offloadBlock (BlockPtr const & block, SizeType32 windowSize,
9671017 executor::KvCacheTransferMode mode = executor::KvCacheTransferMode::DRAM, std::string const & directory = " " );
9681018
969- void storeBlocks (std::vector<BlockKey> const & blockKeys, std::vector<KVCacheBlock::IdType> const & blockIds,
970- SizeType32 windowSize)
1019+ [[nodiscard]] std::pair<SizeType32, std::optional<KVCacheBlock::IdType>> storeBlocks (
1020+ std::vector<BlockKey> const & blockKeys, std::vector<KVCacheBlock::IdType> const & blockIds,
1021+ SizeType32 windowSize, bool pinBlocks = false )
9711022 {
972- mWindowBlockManagers .at (windowSize).storeBlocks (blockKeys, blockIds);
1023+ return mWindowBlockManagers .at (windowSize).storeBlocks (blockKeys, blockIds, pinBlocks );
9731024 }
9741025
9751026 [[nodiscard]] bool verifyQueueIntegrity (SizeType32 windowSize);
@@ -1003,6 +1054,15 @@ class BlockManager
10031054 return sumWindows ([](auto const & manager) { return manager.getNumAllocTotalBlocks (); });
10041055 }
10051056
1057+ [[nodiscard]] SizeType32 getFirstWindowSize () const
1058+ {
1059+ if (mWindowBlockManagers .empty ())
1060+ {
1061+ return 0 ;
1062+ }
1063+ return mWindowBlockManagers .begin ()->first ;
1064+ }
1065+
10061066 [[nodiscard]] SizeType32 getNumAllocNewBlocks () const
10071067 {
10081068 return sumWindows ([](auto const & manager) { return manager.getNumAllocNewBlocks (); });
@@ -1133,6 +1193,12 @@ class BlockManager
11331193 return mWindowBlockManagers .at (windowSize).getBlockById (blockId);
11341194 }
11351195
1196+ [[nodiscard]] std::shared_ptr<KVCacheBlock> findBlocksInReuseTreeByBlockKey (
1197+ BlockKey const & blockKey, SizeType32 windowSize)
1198+ {
1199+ return mWindowBlockManagers .at (windowSize).findBlocksInReuseTreeByBlockKey (blockKey);
1200+ }
1201+
11361202 [[nodiscard]] SizeType32 getNumPrimaryBlocks () const
11371203 {
11381204 return sumWindows ([](auto const & manager) { return manager.getNumPrimaryBlocks (); });
@@ -1274,6 +1340,10 @@ class BaseKVCacheManager
12741340 [[nodiscard]] virtual SizeType32 getRemainingBlocksToCompletion (LlmRequest const & req, SizeType32 windowSize) const
12751341 = 0;
12761342
1343+ // / @brief Pin blocks associated with a request to prevent eviction.
1344+ // / @param requestId The ID of the request whose blocks should be pinned.
1345+ virtual void pinBlocks (LlmRequest::RequestIdType requestId) = 0;
1346+
12771347 // / @brief Increase size for request at seqSlotIdx. Allocate new KV cache block(s) if needed.
12781348 virtual void addToken (LlmRequest::RequestIdType requestId) = 0;
12791349
@@ -1287,8 +1357,8 @@ class BaseKVCacheManager
12871357 OptionalRef<LlmRequest> llmRequest = std::nullopt )
12881358 = 0;
12891359
1290- virtual void removeSequence (
1291- LlmRequest::RequestIdType requestId, OptionalRef<LlmRequest const > llmRequest = std::nullopt )
1360+ [[nodiscard]] virtual std::optional<KVCacheBlock::IdType> removeSequence (LlmRequest::RequestIdType requestId,
1361+ OptionalRef<LlmRequest const > llmRequest = std::nullopt , bool pinOnRelease = false )
12921362 = 0;
12931363
12941364 virtual void schedulingRemoveSequence (LlmRequest::RequestIdType requestId) = 0;
@@ -1332,6 +1402,11 @@ class BaseKVCacheManager
13321402 // ! \details This block become reusable from next step.
13331403 virtual void storeNewBlock (LlmRequest const & llmRequest) = 0;
13341404
1405+ // / \brief Store blocks for reuse for a given request id
1406+ [[nodiscard]] virtual std::optional<KVCacheBlock::IdType> storeBlocksForReuse (
1407+ LlmRequest::RequestIdType requestId, OptionalRef<LlmRequest const > llmRequest, bool pinBlocks = false )
1408+ = 0;
1409+
13351410 // ! \brief Get the block ids of a request [per beam] **for a given window size block manager**
13361411 [[nodiscard]] virtual std::vector<std::vector<SizeType32>> const & getCacheBlockIds (
13371412 LlmRequest::RequestIdType requestId, SizeType32 windowSize) const
@@ -1342,8 +1417,8 @@ class BaseKVCacheManager
13421417 std::vector<LlmRequest::RequestIdType> const & requestIds, SizeType32 windowSize) const
13431418 = 0;
13441419
1345- [[nodiscard]] virtual std::vector<KVCacheBlock::IdType> getNewlyAllocatedBlockIds (
1346- LlmRequest::RequestIdType requestId, SizeType32 windowSize ) const
/// @brief Get the last block id (beam 0) for a given request.
1421+ [[nodiscard]] virtual std::optional<KVCacheBlock::IdType> getLastBlockId ( LlmRequest::RequestIdType requestId) const
13471422 = 0;
13481423
13491424 [[nodiscard]] virtual runtime::ITensor::SharedPtr getUniquePrimaryPool () const = 0;
@@ -1414,6 +1489,12 @@ class BaseKVCacheManager
14141489 [[nodiscard]] virtual SizeType32 getMaxCapacityBatchSize (SizeType32 inputLength, SizeType32 outputLength) const = 0;
14151490
14161491 [[nodiscard]] virtual CacheType getCacheType () const = 0;
1492+
1493+ [[nodiscard]] virtual std::shared_ptr<KVCacheBlock> findBlocksInReuseTreeByBlockKey (
1494+ BlockKey const & blockKey, SizeType32 windowSize)
1495+ = 0;
1496+
1497+ virtual void unpinBlocksById (KVCacheBlock::IdType blockId) = 0;
14171498};
14181499
14191500class KVCacheManager : public BaseKVCacheManager
@@ -1591,8 +1672,8 @@ class KVCacheManager : public BaseKVCacheManager
15911672 void addSequence (LlmRequest::RequestIdType requestId, SizeType32 inputLength, SizeType32 beamWidth,
15921673 OptionalRef<LlmRequest> llmRequest = std::nullopt ) override ;
15931674
1594- void removeSequence (
1595- LlmRequest::RequestIdType requestId, OptionalRef<LlmRequest const > llmRequest = std::nullopt ) override ;
1675+ [[nodiscard]] std::optional<KVCacheBlock::IdType> removeSequence (LlmRequest::RequestIdType requestId,
1676+ OptionalRef<LlmRequest const > llmRequest = std::nullopt , bool pinOnRelease = false ) override ;
15961677
15971678 void schedulingRemoveSequence (LlmRequest::RequestIdType requestId) override ;
15981679
@@ -1652,6 +1733,9 @@ class KVCacheManager : public BaseKVCacheManager
16521733 // ! \brief Store newest blocks for reuse
16531734 void storeNewBlock (LlmRequest const & llmRequest) override ;
16541735
1736+ [[nodiscard]] std::optional<KVCacheBlock::IdType> storeBlocksForReuse (
1737+ LlmRequest::RequestIdType requestId, OptionalRef<LlmRequest const > llmRequest, bool pinBlocks = false ) override ;
1738+
16551739 [[nodiscard]] static SizeType32 getSinkBubbleLength (SizeType32 sinkTokenLen, SizeType32 tokensPerBlock);
16561740
16571741 [[nodiscard]] SizeType32 getMaxCapacityBatchSize (SizeType32 inputLength, SizeType32 outputLength) const override ;
@@ -1668,6 +1752,12 @@ class KVCacheManager : public BaseKVCacheManager
16681752 [[nodiscard]] static SizeType32 calculateMaxBlockRequirements (SizeType32 inputLength, SizeType32 outputLength,
16691753 SizeType32 sinkTokenLength, SizeType32 windowSize, SizeType32 beamWidth, SizeType32 tokensPerBlock);
16701754
1755+ void pinBlocks (LlmRequest::RequestIdType requestId) override ;
1756+
1757+ void unpinBlocksById (KVCacheBlock::IdType blockId) override ;
1758+
1759+ std::optional<KVCacheBlock::IdType> getLastBlockId (LlmRequest::RequestIdType requestId) const override ;
1760+
16711761 // / @brief Calculates the number of kv-cache blocks that a sequence will require, for a single beam.
16721762 // /
16731763 // / @param sequenceLength The total length of the sequence (input and output).
@@ -1684,9 +1774,6 @@ class KVCacheManager : public BaseKVCacheManager
16841774 std::vector<std::vector<std::vector<SizeType32>>> getBatchCacheBlockIds (
16851775 std::vector<LlmRequest::RequestIdType> const & requestIds, SizeType32 windowSize) const override ;
16861776
1687- std::vector<SizeType32> getNewlyAllocatedBlockIds (
1688- LlmRequest::RequestIdType requestId, SizeType32 windowSize) const override ;
1689-
16901777 runtime::ITensor::SharedPtr getUniquePrimaryPool () const override ;
16911778 runtime::ITensor::SharedPtr getPrimaryPool (SizeType32 layer_idx) const override ;
16921779
@@ -1706,6 +1793,12 @@ class KVCacheManager : public BaseKVCacheManager
17061793 mBlockManager .flushIterationEvents ();
17071794 }
17081795
1796+ std::shared_ptr<KVCacheBlock> findBlocksInReuseTreeByBlockKey (
1797+ BlockKey const & blockKey, SizeType32 windowSize) override
1798+ {
1799+ return mBlockManager .findBlocksInReuseTreeByBlockKey (blockKey, windowSize);
1800+ }
1801+
17091802 // / @brief Finds the maximum attention window that can be used on a sequence, given some kv-cache block capacity.
17101803 // /
17111804 // / @param inputLength The number of input tokens in the sequence.
0 commit comments