@@ -53,6 +53,10 @@ static constexpr SizeType32 kPrimaryLevel = 0;
5353
5454static constexpr SizeType32 kSecondaryLevel = 1 ;
5555
56+ // Number of extra blocks allocated for sliding window attention (SWA) so that
57+ // "window size" tokens can always be kept in the blocks.
58+ static constexpr SizeType32 kSWAExtraBlock = 1 ;
59+
5660class KVCacheBlock ;
5761class BlockManager ;
5862class KVCacheManager ;
@@ -88,8 +92,8 @@ struct WindowSizeMetadata
8892 SizeType32 allottedSecondaryBlocks; // Number of secondary blocks allotted to the windowSize
8993 SizeType32 absolutePoolsOffset; // cumulative number of pools up to manager
9094 SizeType32 numPools; // number of managed pools
91- SizeType32 maxTokenNum ; // Maximum token length (including bubble )
92- SizeType32 maxBlocksPerSeq;
95+ SizeType32 maxTokensPerSeq ; // Maximum token length per sequence (TODO: account for streamLLM )
96+ SizeType32 maxBlocksPerSeq; // Maximum number of blocks per sequence
9397 SizeType32 maxNumBlocks; // Number of primary+secondary blocks allotted to the windowSize
9498 SizeType32 temporaryAttentionWindow; // Temporary kv cache length per sequence.
9599 // Only needed when chunked context + sliding window attention are used
@@ -99,9 +103,9 @@ struct WindowSizeMetadata
99103 {
100104 return tensorrt_llm::common::fmtstr (
101105 " WindowSizeMetadata{ .allottedPrimaryBlocks=%d, .allottedSecondaryBlocks=%d, .absolutePoolsOffset=%d, "
102- " .numPools=%d, .maxTokenNum =%d, .maxBlocksPerSeq=%d, .maxNumBlocks=%d, .temporaryAttentionWindow=%d }" ,
103- allottedPrimaryBlocks, allottedSecondaryBlocks, absolutePoolsOffset, numPools, maxTokenNum, maxBlocksPerSeq ,
104- maxNumBlocks, temporaryAttentionWindow);
106+ " .numPools=%d, .maxTokensPerSeq =%d, .maxBlocksPerSeq=%d, .maxNumBlocks=%d, .temporaryAttentionWindow=%d }" ,
107+ allottedPrimaryBlocks, allottedSecondaryBlocks, absolutePoolsOffset, numPools, maxTokensPerSeq ,
108+ maxBlocksPerSeq, maxNumBlocks, temporaryAttentionWindow);
105109 }
106110};
107111
@@ -203,6 +207,7 @@ class KVCacheBlock
203207 using IdType = std::int32_t ;
204208
205209 static constexpr IdType kCachedBlocksRootId = -1 ;
210+ static constexpr IdType kInvalidBlockId = -2 ;
206211
207212 explicit KVCacheBlock (IdType blockId, kernels::KVCacheIndex blockIdx);
208213
@@ -335,14 +340,7 @@ class GenerationRequest
335340 , mNumTokens(numTokens)
336341 , mBeamWidth(beamWidth)
337342 , mKvCacheRetentionConfig(std::move(kvCacheRetentionConfig))
338- // min window size + sink bubble length
339- // Why use the minimum window size:
340- // Chunked Prefill + Reuse calls `setPrepopulatedPromptLen()` which sets
341- // `mContextCurrentPosition` - this cannot be done for some windows sizes and
342- // not for others, the state needs to remain identical for all window sizes. So
343- // we currently resort to strictly disabling the reuse code path for all window
344- // sizes at once or enable it for all window sizes at once.
345- , mCyclicThreshold(windowSizeToMetadata.cbegin()->second.maxTokenNum)
343+ , mNumFrontBlocksRemoved(0 )
346344 {
347345 auto const numWindowSizes = windowSizeToMetadata.size ();
348346 mCacheBlockIds .reserve (numWindowSizes);
@@ -385,6 +383,11 @@ class GenerationRequest
385383 return mNumTokens ;
386384 }
387385
386+ [[nodiscard]] SizeType32 getNumFrontBlocksRemoved () const
387+ {
388+ return mNumFrontBlocksRemoved ;
389+ }
390+
388391 [[nodiscard]] SizeType32 getBeamWidth () const
389392 {
390393 return mBeamWidth ;
@@ -422,6 +425,26 @@ class GenerationRequest
422425 {
423426 beamBlockIds.clear ();
424427 }
428+ mNumFrontBlocksRemoved = 0 ;
429+ }
430+
431+ void removeFrontBlock (SizeType32 windowSize)
432+ {
433+ for (auto & beamBlockIds : mCacheBlockIds .at (windowSize))
434+ {
435+ if (mNumFrontBlocksRemoved < static_cast <SizeType32>(beamBlockIds.size ()))
436+ {
437+ // Unlike removeLastBlock, this does not erase the entry from mCacheBlockIds;
438+ // the block id is set to KVCacheBlock::kInvalidBlockId (-2) instead, because
439+ // the blocks are preserved for reuse when reuse is enabled.
440+ beamBlockIds[mNumFrontBlocksRemoved ] = KVCacheBlock::kInvalidBlockId ;
441+ }
442+ else
443+ {
444+ TLLM_LOG_WARNING (" RequestID %lu: removeFrontBlock called but nothing to remove" , mRequestId );
445+ }
446+ }
447+ ++mNumFrontBlocksRemoved ;
425448 }
426449
427450 void removeLastBlock (SizeType32 windowSize)
@@ -442,14 +465,6 @@ class GenerationRequest
442465 return mKvCacheRetentionConfig .getDecodeDurationMs ();
443466 }
444467
445- // @brief Check whether the sequence uses cyclic KV cache.
446- // @return `true` if we have begun overwriting the beginning of the sequence's KV cache.
447- // @details If `true`, we cannot store the sequence's KV cache for reuse.
448- [[nodiscard]] bool isCyclic () const
449- {
450- return mNumTokens >= mCyclicThreshold ;
451- }
452-
453468private:
454469 // Request id of the sequence
455470 LlmRequest::RequestIdType mRequestId ;
@@ -463,9 +478,8 @@ class GenerationRequest
463478 std::unordered_map<SizeType32, runtime::ITensor::SharedPtr> mCacheBlockIndices ;
464479 // The retention priority to assign to decode blocks
465480 executor::KvCacheRetentionConfig mKvCacheRetentionConfig ;
466-
467- // Number of tokens at which the KV Cache begins sliding [for the minimum attention window]
468- SizeType32 mCyclicThreshold ;
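481+ // Number of front blocks removed from the sequence (one counter shared by all window sizes; NOTE(review): removeFrontBlock is per-window but increments this on every call - confirm this is intended when several window sizes are in use)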
482+ SizeType32 mNumFrontBlocksRemoved ;
469483};
470484
471485// attach metadata to a pool pointer
@@ -533,7 +547,7 @@ class WindowBlockManager
533547
534548 explicit WindowBlockManager (nvinfer1::DataType dtype, SizeType32 windowSize,
535549 std::vector<SizeType32> const & managedLayers, std::vector<SizeType32> const & numKvHeadsPerLayer,
536- SizeType32 sizePerHead, SizeType32 tokensPerBlock, SizeType32 blocksInPrimaryPool,
550+ SizeType32 sizePerHead, SizeType32 tokensPerBlock, bool isSWA, SizeType32 blocksInPrimaryPool,
537551 SizeType32 blocksInSecondaryPool, SizeType32 maxNumSequences, std::shared_ptr<runtime::CudaStream> stream,
538552 bool onboardBlocks, CacheType cacheType, std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
539553 std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse);
@@ -567,14 +581,26 @@ class WindowBlockManager
567581 void storeNewBlock (GenerationRequest& sequence, OptionalRef<LlmRequest const > llmRequest);
568582
569583 // ! \brief Release blocks of the sequence.
570- void releaseBlocks (GenerationRequest& sequence);
584+ // ! \details When llmRequest is provided and reuse is enabled, blocks will be stored.
585+ void releaseBlocks (GenerationRequest& sequence, OptionalRef<LlmRequest const > llmRequest = std::nullopt );
571586
572587 // ! \brief Simulate freeing all blocks for that sequence to check impact on number of free blocks
573588 void schedulingReleaseBlocks (LlmRequest::RequestIdType requestId);
574589
590+ // ! \brief Update cache offsets for last block
591+ void updateLastCacheBlockOffsets (GenerationRequest& seq);
592+
575593 // ! \brief Release last block in the sequence
576594 void releaseLastBlock (GenerationRequest& sequence);
577595
596+ // ! \brief Detach block from the sequence
597+ void detachBlock (GenerationRequest& sequence, bool isEnableBlockReuse);
598+
599+ // ! \brief Check and add a block to the sequence if needed.
600+ // ! \details Out-of-window blocks will be detached. If reuse is enabled,
601+ // ! the detached block will be stored via offload.
602+ void addBlockIfNeeded (GenerationRequest& sequence, bool isEnableBlockReuse);
603+
578604 [[nodiscard]] SizeType32 getWindowSize () const noexcept
579605 {
580606 return mWindowSize ;
@@ -585,7 +611,7 @@ class WindowBlockManager
585611 return mLogPrefix ;
586612 }
587613
588- [[nodiscard]] SizeType32 getNumFreeBlocks () const noexcept ;
614+ [[nodiscard]] SizeType32 getNumFreeBlocks (SizeType32 cacheLevel = kPrimaryLevel ) const noexcept ;
589615
590616 [[nodiscard]] SizeType32 getNumAllocTotalBlocks () const
591617 {
@@ -715,7 +741,8 @@ class WindowBlockManager
715741 // ! \brief Store blocks in cached blocks.
716742 // ! \param blockKeys Key of each block.
717743 // ! \param blockIds Id of each block.
718- void storeBlocks (std::vector<BlockKey> const & blockKeys, std::vector<KVCacheBlock::IdType> const & blockIds);
744+ // ! \return Number of blocks actually stored.
745+ SizeType32 storeBlocks (std::vector<BlockKey> const & blockKeys, std::vector<KVCacheBlock::IdType> const & blockIds);
719746
720747 [[nodiscard]] bool verifyQueueIntegrity ();
721748
@@ -796,6 +823,8 @@ class WindowBlockManager
796823 SizeType32 mSchedulingNumFreeBlocks ;
797824 // Number of tokens per one block
798825 SizeType32 mTokensPerBlock ;
826+ // True if this window uses sliding window attention (SWA), false if it uses full attention
827+ bool mIsSWA ;
799828 // List of all blocks by idx
800829 std::vector<BlockPtr> mAllBlocksById ;
801830 // Dummy block acting as root for BlockToken searches
@@ -917,19 +946,20 @@ class BlockManager
917946
918947 void startScheduling ();
919948
920- [[nodiscard]] std::map<SizeType32, SizeType32> getNumFreeBlocksPerWindowSize () const
949+ [[nodiscard]] std::map<SizeType32, SizeType32> getNumFreeBlocksPerWindowSize (
950+ SizeType32 cacheLevel = kPrimaryLevel ) const
921951 {
922952 std::map<SizeType32, SizeType32> numFreeBlocksPerWindowSize;
923953 for (auto const & [windowSize, manager] : mWindowBlockManagers )
924954 {
925- numFreeBlocksPerWindowSize[windowSize] = manager.getNumFreeBlocks ();
955+ numFreeBlocksPerWindowSize[windowSize] = manager.getNumFreeBlocks (cacheLevel );
926956 }
927957 return numFreeBlocksPerWindowSize;
928958 }
929959
930- [[nodiscard]] SizeType32 getNumFreeBlocks () const
960+ [[nodiscard]] SizeType32 getNumFreeBlocks (SizeType32 cacheLevel = kPrimaryLevel ) const
931961 {
932- return sumWindows ([](auto const & manager) { return manager.getNumFreeBlocks (); });
962+ return sumWindows ([cacheLevel ](auto const & manager) { return manager.getNumFreeBlocks (cacheLevel ); });
933963 }
934964
935965 [[nodiscard]] bool schedulingHasFreeBlocks (SizeType32 numRequired, SizeType32 windowSize) const
@@ -1088,14 +1118,6 @@ class BlockManager
10881118 // ! \brief Store newest block for reuse
10891119 void storeNewBlock (GenerationRequest& sequence, OptionalRef<LlmRequest const > llmRequest);
10901120
1091- [[nodiscard]] static bool isUseOneMoreBlock (
1092- SizeType32 windowSize, std::optional<SizeType32> maxSequenceLength, SizeType32 maxBeamWidth)
1093- {
1094- bool const isCyclicWindowSize = maxSequenceLength.has_value () && maxSequenceLength.value () > windowSize;
1095- bool const isBeamSearch = maxBeamWidth > 1 ;
1096- return isCyclicWindowSize && isBeamSearch;
1097- }
1098-
10991121 // ! \brief Perform per-request bookkeeping
11001122 void refreshBlocks ();
11011123
@@ -1114,12 +1136,12 @@ class BlockManager
11141136 // ! \brief Update cache offsets for blocks initiated from sequence
11151137 void updateSequenceCacheBlockOffsets (GenerationRequest& seq, SizeType32 windowSize);
11161138
1117- // ! \brief Update cache offsets for last block
1118- void updateLastCacheBlockOffsets (GenerationRequest& seq, SizeType32 windowSize);
1119-
11201139 // ! \brief Update cache offsets for block at index
11211140 void updateCacheBlockOffsetsAtIdx (GenerationRequest& seq, SizeType32 windowSize, SizeType32 blockIdx);
11221141
1142+ // ! \brief Add block to the sequence if needed
1143+ void addBlockIfNeeded (GenerationRequest& sequence, bool isEnableBlockReuse);
1144+
11231145private:
11241146 [[nodiscard]] WindowBlockManager const & windowManagerByLayer (SizeType32 layerIdx) const
11251147 {
0 commit comments