@@ -57,6 +57,10 @@ static constexpr SizeType32 kPrimaryLevel = 0;
5757
5858static constexpr SizeType32 kSecondaryLevel = 1 ;
5959
60+ // Extra block buffer allocated for SWA to be able to always keep "window size"
61+ // tokens held in the blocks.
62+ static constexpr SizeType32 kSWAExtraBlock = 1 ;
63+
6064class KVCacheBlock ;
6165class BlockManager ;
6266class KVCacheManager ;
@@ -93,8 +97,8 @@ struct WindowSizeMetadata
9397 SizeType32 allottedSecondaryBlocks; // Number of secondary blocks allotted to the windowSize
9498 SizeType32 absolutePoolsOffset; // cumulative number of pools up to manager
9599 SizeType32 numPools; // number of managed pools
96- SizeType32 maxTokenNum; // Maximum token length (including bubble )
97- SizeType32 maxBlocksPerSeq;
100+ SizeType32 maxTokenNum; // Maximum token length per sequence (TODO: account for streamLLM )
101+ SizeType32 maxBlocksPerSeq; // Maximum number of blocks per sequence
98102 SizeType32 maxNumBlocks; // Number of primary+secondary blocks allotted to the windowSize
99103 SizeType32 temporaryAttentionWindow; // Temporary kv cache length per sequence.
100104 // Only needed when chunked context + sliding window attention are used
@@ -344,14 +348,7 @@ class GenerationRequest
344348 , mNumTokens(numTokens)
345349 , mBeamWidth(beamWidth)
346350 , mKvCacheRetentionConfig(std::move(kvCacheRetentionConfig))
347- // min window size + sink bubble length
348- // Why use the minimum window size:
349- // Chunked Prefill + Reuse calls `setPrepopulatedPromptLen()` which sets
350- // `mContextCurrentPosition` - this cannot be done for some windows sizes and
351- // not for others, the state needs to remain identical for all window sizes. So
352- // we currently resort to strictly disabling the reuse code path for all window
353- // sizes at once or enable it for all window sizes at once.
354- , mCyclicThreshold(windowSizeToMetadata.cbegin()->second.maxTokenNum)
351+ , mNumFrontBlocksRemoved(0 )
355352 {
356353 auto const numWindowSizes = windowSizeToMetadata.size ();
357354 mCacheBlockIds .reserve (numWindowSizes);
@@ -394,6 +391,11 @@ class GenerationRequest
394391 return mNumTokens ;
395392 }
396393
394+ [[nodiscard]] SizeType32 getNumFrontBlocksRemoved () const
395+ {
396+ return mNumFrontBlocksRemoved ;
397+ }
398+
397399 [[nodiscard]] SizeType32 getBeamWidth () const
398400 {
399401 return mBeamWidth ;
@@ -431,6 +433,12 @@ class GenerationRequest
431433 {
432434 beamBlockIds.clear ();
433435 }
436+ mNumFrontBlocksRemoved = 0 ;
437+ }
438+
439+ void removeFrontBlock (SizeType32 windowSize)
440+ {
441+ ++mNumFrontBlocksRemoved ;
434442 }
435443
436444 void removeLastBlock (SizeType32 windowSize)
@@ -461,14 +469,6 @@ class GenerationRequest
461469 return mKvCacheRetentionConfig .getDirectory ();
462470 }
463471
464- // @brief Check whether the sequence uses cyclic KV cache.
465- // @return `true` if we have begun overwriting the beginning of the sequence's KV cache.
466- // @details If `true`, we cannot store the sequence's KV cache for reuse.
467- [[nodiscard]] bool isCyclic () const
468- {
469- return mNumTokens >= mCyclicThreshold ;
470- }
471-
472472private:
473473 // Request id of the sequence
474474 LlmRequest::RequestIdType mRequestId ;
@@ -482,9 +482,8 @@ class GenerationRequest
482482 std::unordered_map<SizeType32, runtime::ITensor::SharedPtr> mCacheBlockIndices ;
483483 // The retention priority to assign to decode blocks
484484 executor::KvCacheRetentionConfig mKvCacheRetentionConfig ;
485-
486- // Number of tokens at which the KV Cache begins sliding [for the minimum attention window]
487- SizeType32 mCyclicThreshold ;
485+ // Number of front blocks removed from the sequence
486+ SizeType32 mNumFrontBlocksRemoved ;
488487};
489488
490489// attach metadata to a pool pointer
@@ -550,7 +549,7 @@ class WindowBlockManager
550549
551550 explicit WindowBlockManager (nvinfer1::DataType dtype, SizeType32 windowSize,
552551 std::vector<SizeType32> const & managedLayers, std::vector<SizeType32> const & numKvHeadsPerLayer,
553- SizeType32 sizePerHead, SizeType32 tokensPerBlock, SizeType32 blocksInPrimaryPool,
552+ SizeType32 sizePerHead, SizeType32 tokensPerBlock, bool isSWA, SizeType32 blocksInPrimaryPool,
554553 SizeType32 blocksInSecondaryPool, SizeType32 maxNumSequences, std::shared_ptr<runtime::CudaStream> stream,
555554 bool onboardBlocks, CacheType cacheType, std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
556555 std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse,
@@ -581,19 +580,32 @@ class WindowBlockManager
581580 // ! \brief Get the ids of all newly allocated (not reused) blocks for the sequence.
582581 std::vector<KVCacheBlock::IdType> getNewlyAllocatedBlockIds (GenerationRequest const & sequence) const ;
583582
584- void storeBlocksForReuse (GenerationRequest& sequence, OptionalRef<LlmRequest const > llmRequest);
585-
586583 void storeNewBlock (GenerationRequest& sequence, OptionalRef<LlmRequest const > llmRequest);
587584
588585 // ! \brief Release blocks of the sequence.
589- void releaseBlocks (GenerationRequest& sequence);
586+ // ! \details When llmRequest is provided and reuse is enabled, blocks will be stored.
587+ void releaseBlocks (GenerationRequest& sequence, OptionalRef<LlmRequest const > llmRequest);
590588
591589 // ! \brief Simulate freeing all blocks for that sequence to check impact on number of free blocks
592590 void schedulingReleaseBlocks (LlmRequest::RequestIdType requestId);
593591
592+ // ! \brief Update cache offsets for last block
593+ void updateLastCacheBlockOffsets (GenerationRequest& seq);
594+
594595 // ! \brief Release last block in the sequence
595596 void releaseLastBlock (GenerationRequest& sequence);
596597
598+ // ! \brief Detach front block from the sequence
599+ void detachFrontBlock (GenerationRequest& sequence, bool isEnableBlockReuse);
600+
601+ // ! \brief Add/detach block(s) to/from the sequence if needed
602+ // ! \details When we need a new block, we add it. For sliding window
603+ // ! attention (SWA), when a block goes out-of-window (OOW), we detach it
604+ // ! and store it if reuse is enabled. If this called in the first step of
605+ // ! the generation phase, we may detach more than a single block since
606+ // ! there may be more than one context block that goes OOW.
607+ void adjustBlocksIfNeeded (GenerationRequest& sequence, bool isEnableBlockReuse);
608+
597609 [[nodiscard]] SizeType32 getWindowSize () const noexcept
598610 {
599611 return mWindowSize ;
@@ -745,7 +757,8 @@ class WindowBlockManager
745757 // ! \brief Store blocks in cached blocks.
746758 // ! \param blockKeys Key of each block.
747759 // ! \param blockIds Id of each block.
748- void storeBlocks (std::vector<BlockKey> const & blockKeys, std::vector<KVCacheBlock::IdType> const & blockIds);
760+ // ! \return Number of actual blocks stored.
761+ SizeType32 storeBlocks (std::vector<BlockKey> const & blockKeys, std::vector<KVCacheBlock::IdType> const & blockIds);
749762
750763 [[nodiscard]] bool verifyQueueIntegrity ();
751764
@@ -767,6 +780,12 @@ class WindowBlockManager
767780 return 0 ;
768781 }
769782
783+ // ! \brief Return whether this window is SWA.
784+ [[nodiscard]] bool isSWA () const
785+ {
786+ return mIsSWA ;
787+ }
788+
770789private:
771790 // ! \brief Add single block to beam of sequence and mAllocatedBlocksPerSeq.
772791 void addBlockToBeam (BlockPtr& block, GenerationRequest& sequence, SizeType32 beamIdx);
@@ -828,6 +847,8 @@ class WindowBlockManager
828847 SizeType32 mSchedulingNumFreeBlocks ;
829848 // Number of tokens per one block
830849 SizeType32 mTokensPerBlock ;
850+ // Whether this window is sliding window attention/full attention
851+ bool mIsSWA ;
831852 // List of all blocks by idx
832853 std::vector<BlockPtr> mAllBlocksById ;
833854 // Dummy block acting as root for BlockToken searches
@@ -880,7 +901,7 @@ class BlockManager
880901
881902 explicit BlockManager (std::vector<SizeType32> const & numKvHeadsPerLayer, SizeType32 sizePerHead,
882903 SizeType32 tokensPerBlock, BlocksPerWindow const & blocksPerWindow, SizeType32 maxNumSequences,
883- CudaStreamPtr stream, std::optional< SizeType32> maxSequenceLength, SizeType32 maxBeamWidth,
904+ CudaStreamPtr stream, SizeType32 maxSequenceLength, SizeType32 maxBeamWidth,
884905 std::vector<SizeType32> const & maxAttentionWindowVec,
885906 std::optional<TempAttentionWindowInputs> const & tempAttentionWindowInputs, nvinfer1::DataType dtype,
886907 SizeType32 sinkBubbleLength, bool onboardBlocks, CacheType cacheType = CacheType::kSELF ,
@@ -1128,14 +1149,6 @@ class BlockManager
11281149 // ! \brief Store newest block for reuse
11291150 void storeNewBlock (GenerationRequest& sequence, OptionalRef<LlmRequest const > llmRequest);
11301151
1131- [[nodiscard]] static bool isUseOneMoreBlock (
1132- SizeType32 windowSize, std::optional<SizeType32> maxSequenceLength, SizeType32 maxBeamWidth)
1133- {
1134- bool const isCyclicWindowSize = maxSequenceLength.has_value () && maxSequenceLength.value () > windowSize;
1135- bool const isBeamSearch = maxBeamWidth > 1 ;
1136- return isCyclicWindowSize && isBeamSearch;
1137- }
1138-
11391152 // ! \brief Perform per-request bookkeeping
11401153 void refreshBlocks ();
11411154
@@ -1154,12 +1167,17 @@ class BlockManager
11541167 // ! \brief Update cache offsets for blocks initiated from sequence
11551168 void updateSequenceCacheBlockOffsets (GenerationRequest& seq, SizeType32 windowSize);
11561169
1157- // ! \brief Update cache offsets for last block
1158- void updateLastCacheBlockOffsets (GenerationRequest& seq, SizeType32 windowSize);
1159-
11601170 // ! \brief Update cache offsets for block at index
11611171 void updateCacheBlockOffsetsAtIdx (GenerationRequest& seq, SizeType32 windowSize, SizeType32 blockIdx);
11621172
1173+ // ! \brief Add/detach block(s) to/from the sequence if needed
1174+ // ! \details When we need a new block, we add it. For sliding window
1175+ // ! attention (SWA), when a block goes out-of-window (OOW), we detach it
1176+ // ! and store it if reuse is enabled. If this called in the first step of
1177+ // ! the generation phase, we may detach more than a single block since
1178+ // ! there may be more than one context block that goes OOW.
1179+ void adjustBlocksIfNeeded (GenerationRequest& sequence, bool isEnableBlockReuse);
1180+
11631181private:
11641182 [[nodiscard]] WindowBlockManager const & windowManagerByLayer (SizeType32 layerIdx) const
11651183 {
@@ -1411,8 +1429,8 @@ class KVCacheManager : public BaseKVCacheManager
14111429 BlocksPerWindow const & blocksPerWindow, SizeType32 maxNumSequences, SizeType32 maxBeamWidth,
14121430 std::vector<SizeType32> const & maxAttentionWindowVec,
14131431 std::optional<TempAttentionWindowInputs> const & tempAttentionWindowInputs, nvinfer1::DataType dtype,
1414- SizeType32 sinkTokenLength, CudaStreamPtr stream, std::optional< SizeType32> maxSequenceLength,
1415- bool enableBlockReuse = false , bool onboardBlocks = true , CacheType cacheType = CacheType::kSELF ,
1432+ SizeType32 sinkTokenLength, CudaStreamPtr stream, SizeType32 maxSequenceLength, bool enableBlockReuse = false ,
1433+ bool onboardBlocks = true , CacheType cacheType = CacheType::kSELF ,
14161434 std::optional<executor::RetentionPriority> secondaryOffloadMinPriority = std::nullopt ,
14171435 std::shared_ptr<KVCacheEventManager> eventManager = nullptr , bool enablePartialReuse = true ,
14181436 bool copyOnpartialReuse = true ,
@@ -1422,8 +1440,8 @@ class KVCacheManager : public BaseKVCacheManager
14221440 BlocksPerWindow const & blocksPerWindow, SizeType32 maxNumSequences, SizeType32 maxBeamWidth,
14231441 std::vector<SizeType32> const & maxAttentionWindowVec,
14241442 std::optional<TempAttentionWindowInputs> const & tempAttentionWindowInputs, nvinfer1::DataType dtype,
1425- SizeType32 sinkTokenLength, int64_t stream, std::optional< SizeType32> maxSequenceLength,
1426- bool enableBlockReuse = false , bool onboardBlocks = true , CacheType cacheType = CacheType::kSELF ,
1443+ SizeType32 sinkTokenLength, int64_t stream, SizeType32 maxSequenceLength, bool enableBlockReuse = false ,
1444+ bool onboardBlocks = true , CacheType cacheType = CacheType::kSELF ,
14271445 std::optional<executor::RetentionPriority> secondaryOffloadMinPriority = std::nullopt ,
14281446 std::shared_ptr<KVCacheEventManager> eventManager = nullptr , bool enablePartialReuse = true ,
14291447 bool copyOnpartialReuse = true ,
@@ -1433,8 +1451,8 @@ class KVCacheManager : public BaseKVCacheManager
14331451 BlocksPerWindow const & blocksPerWindow, SizeType32 maxNumSequences, SizeType32 maxBeamWidth,
14341452 std::vector<SizeType32> const & maxAttentionWindowVec,
14351453 std::optional<TempAttentionWindowInputs> const & tempAttentionWindowInputs, nvinfer1::DataType dtype,
1436- SizeType32 sinkTokenLength, CudaStreamPtr stream, std::optional< SizeType32> maxSequenceLength,
1437- bool enableBlockReuse = true , bool onboardBlocks = true , CacheType cacheType = CacheType::kSELF ,
1454+ SizeType32 sinkTokenLength, CudaStreamPtr stream, SizeType32 maxSequenceLength, bool enableBlockReuse = true ,
1455+ bool onboardBlocks = true , CacheType cacheType = CacheType::kSELF ,
14381456 std::optional<executor::RetentionPriority> secondaryOffloadMinPriority = std::nullopt ,
14391457 std::shared_ptr<KVCacheEventManager> eventManager = nullptr , bool enablePartialReuse = true ,
14401458 bool copyOnpartialReuse = true ,
@@ -1444,9 +1462,9 @@ class KVCacheManager : public BaseKVCacheManager
14441462 BlocksPerWindow const & blocksPerWindow, SizeType32 maxNumSequences, SizeType32 maxBeamWidth,
14451463 std::vector<SizeType32> const & maxAttentionWindowVec,
14461464 std::optional<TempAttentionWindowInputs> const & tempAttentionWindowInputs, nvinfer1::DataType dtype,
1447- SizeType32 sinkTokenLength, int64_t stream, std::optional< SizeType32> maxSequenceLength,
1448- bool enableBlockReuse = false , bool onboardBlocks = true , CacheType cacheType = CacheType::kSELF ,
1449- bool enablePartialReuse = true , bool copyOnpartialReuse = true );
1465+ SizeType32 sinkTokenLength, int64_t stream, SizeType32 maxSequenceLength, bool enableBlockReuse = false ,
1466+ bool onboardBlocks = true , CacheType cacheType = CacheType::kSELF , bool enablePartialReuse = true ,
1467+ bool copyOnpartialReuse = true );
14501468
14511469 ~KVCacheManager () override = default ;
14521470
0 commit comments