@@ -53,6 +53,10 @@ static constexpr SizeType32 kPrimaryLevel = 0;
 
 static constexpr SizeType32 kSecondaryLevel = 1;
 
+// Extra block allocated for SWA so that "window size" tokens can always be
+// held in the blocks.
+static constexpr SizeType32 kSWAExtraBlock = 1;
+
 class KVCacheBlock;
 class BlockManager;
 class KVCacheManager;
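The arithmetic behind the extra block: a window of `windowSize` tokens that is not aligned to block boundaries can straddle one block more than `ceil(windowSize / tokensPerBlock)`. A small self-contained sketch of that reasoning; the `ceilDiv`/`blocksForWindow` helpers are illustrative names, not part of this header:

```cpp
#include <cstdint>

using SizeType32 = std::int32_t;

constexpr SizeType32 kSWAExtraBlock = 1; // mirrors the constant added above

constexpr SizeType32 ceilDiv(SizeType32 a, SizeType32 b)
{
    return (a + b - 1) / b;
}

// Blocks needed so the last `windowSize` tokens always stay resident, even in
// the worst case where the oldest in-window token sits at the end of a block.
constexpr SizeType32 blocksForWindow(SizeType32 windowSize, SizeType32 tokensPerBlock)
{
    return ceilDiv(windowSize, tokensPerBlock) + kSWAExtraBlock;
}

int main()
{
    // windowSize = 80, tokensPerBlock = 32: the window covers ceil(80/32) = 3
    // blocks when aligned, but 4 blocks once it straddles a block boundary.
    static_assert(blocksForWindow(80, 32) == 4);
    // An aligned window keeps the extra block too, so the next generated token
    // has somewhere to go before the front block is detached.
    static_assert(blocksForWindow(64, 32) == 3);
    return 0;
}
```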
@@ -88,8 +92,8 @@ struct WindowSizeMetadata
     SizeType32 allottedSecondaryBlocks;  // Number of secondary blocks allotted to the windowSize
     SizeType32 absolutePoolsOffset;      // cumulative number of pools up to manager
     SizeType32 numPools;                 // number of managed pools
-    SizeType32 maxTokenNum;              // Maximum token length (including bubble)
-    SizeType32 maxBlocksPerSeq;
+    SizeType32 maxTokensPerSeq;          // Maximum token length per sequence (TODO: account for streamLLM)
+    SizeType32 maxBlocksPerSeq;          // Maximum number of blocks per sequence
     SizeType32 maxNumBlocks;             // Number of primary+secondary blocks allotted to the windowSize
     SizeType32 temporaryAttentionWindow; // Temporary kv cache length per sequence.
                                          // Only needed when chunked context + sliding window attention are used
@@ -99,9 +103,9 @@ struct WindowSizeMetadata
     {
         return tensorrt_llm::common::fmtstr(
             "WindowSizeMetadata{ .allottedPrimaryBlocks=%d, .allottedSecondaryBlocks=%d, .absolutePoolsOffset=%d, "
-            ".numPools=%d, .maxTokenNum=%d, .maxBlocksPerSeq=%d, .maxNumBlocks=%d, .temporaryAttentionWindow=%d }",
-            allottedPrimaryBlocks, allottedSecondaryBlocks, absolutePoolsOffset, numPools, maxTokenNum, maxBlocksPerSeq,
-            maxNumBlocks, temporaryAttentionWindow);
+            ".numPools=%d, .maxTokensPerSeq=%d, .maxBlocksPerSeq=%d, .maxNumBlocks=%d, .temporaryAttentionWindow=%d }",
+            allottedPrimaryBlocks, allottedSecondaryBlocks, absolutePoolsOffset, numPools, maxTokensPerSeq,
+            maxBlocksPerSeq, maxNumBlocks, temporaryAttentionWindow);
     }
 };
 
@@ -203,6 +207,7 @@ class KVCacheBlock
     using IdType = std::int32_t;
 
     static constexpr IdType kCachedBlocksRootId = -1;
+    static constexpr IdType kInvalidBlockId = -2;
 
     explicit KVCacheBlock(IdType blockId, kernels::KVCacheIndex blockIdx);
 
@@ -335,14 +340,7 @@ class GenerationRequest
         , mNumTokens(numTokens)
         , mBeamWidth(beamWidth)
         , mKvCacheRetentionConfig(std::move(kvCacheRetentionConfig))
-        // min window size + sink bubble length
-        // Why use the minimum window size:
-        // Chunked Prefill + Reuse calls `setPrepopulatedPromptLen()` which sets
-        // `mContextCurrentPosition` - this cannot be done for some window sizes and
-        // not for others, the state needs to remain identical for all window sizes. So
-        // we currently resort to strictly disabling the reuse code path for all window
-        // sizes at once or enable it for all window sizes at once.
-        , mCyclicThreshold(windowSizeToMetadata.cbegin()->second.maxTokenNum)
+        , mNumFrontBlocksRemoved(0)
     {
         auto const numWindowSizes = windowSizeToMetadata.size();
         mCacheBlockIds.reserve(numWindowSizes);
@@ -385,6 +383,11 @@ class GenerationRequest
         return mNumTokens;
     }
 
+    [[nodiscard]] SizeType32 getNumFrontBlocksRemoved() const
+    {
+        return mNumFrontBlocksRemoved;
+    }
+
     [[nodiscard]] SizeType32 getBeamWidth() const
     {
         return mBeamWidth;
@@ -422,6 +425,26 @@ class GenerationRequest
         {
             beamBlockIds.clear();
         }
+        mNumFrontBlocksRemoved = 0;
+    }
+
+    void removeFrontBlock(SizeType32 windowSize)
+    {
+        for (auto& beamBlockIds : mCacheBlockIds.at(windowSize))
+        {
+            if (mNumFrontBlocksRemoved < static_cast<SizeType32>(beamBlockIds.size()))
+            {
+                // Doesn't actually remove from mCacheBlockIds like removeLastBlock does;
+                // the block id is set to kInvalidBlockId instead, because the blocks are
+                // preserved for reuse when reuse is enabled.
+                beamBlockIds[mNumFrontBlocksRemoved] = KVCacheBlock::kInvalidBlockId;
+            }
+            else
+            {
+                TLLM_LOG_WARNING("RequestID %lu: removeFrontBlock called but nothing to remove", mRequestId);
+            }
+        }
+        ++mNumFrontBlocksRemoved;
     }
 
     void removeLastBlock(SizeType32 windowSize)
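To illustrate what `removeFrontBlock` does to the per-beam bookkeeping, here is a small toy reproduction of the idea: the vector of block ids keeps its length, and slots that slide out of the window are only masked with the invalid-id sentinel. `ToyBeam` below is purely illustrative, not the real `GenerationRequest`:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

using SizeType32 = std::int32_t;
constexpr SizeType32 kInvalidBlockId = -2;

// Toy bookkeeping for one beam's block ids.
struct ToyBeam
{
    std::vector<SizeType32> blockIds{10, 11, 12, 13};
    SizeType32 numFrontBlocksRemoved{0};

    void removeFrontBlock()
    {
        if (numFrontBlocksRemoved < static_cast<SizeType32>(blockIds.size()))
        {
            // Keep the vector length stable; only mark the slot invalid so the
            // remaining offsets do not shift and the block can still be reused.
            blockIds[numFrontBlocksRemoved] = kInvalidBlockId;
        }
        ++numFrontBlocksRemoved;
    }
};

int main()
{
    ToyBeam beam;
    beam.removeFrontBlock();
    beam.removeFrontBlock();
    assert(beam.blockIds.size() == 4);           // nothing erased
    assert(beam.blockIds[0] == kInvalidBlockId); // oldest blocks masked out
    assert(beam.blockIds[1] == kInvalidBlockId);
    assert(beam.blockIds[2] == 12);              // in-window blocks untouched
    return 0;
}
```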
@@ -442,14 +465,6 @@ class GenerationRequest
         return mKvCacheRetentionConfig.getDecodeDurationMs();
     }
 
-    // @brief Check whether the sequence uses cyclic KV cache.
-    // @return `true` if we have begun overwriting the beginning of the sequence's KV cache.
-    // @details If `true`, we cannot store the sequence's KV cache for reuse.
-    [[nodiscard]] bool isCyclic() const
-    {
-        return mNumTokens >= mCyclicThreshold;
-    }
-
 private:
     // Request id of the sequence
     LlmRequest::RequestIdType mRequestId;
@@ -463,9 +478,8 @@ class GenerationRequest
     std::unordered_map<SizeType32, runtime::ITensor::SharedPtr> mCacheBlockIndices;
     // The retention priority to assign to decode blocks
     executor::KvCacheRetentionConfig mKvCacheRetentionConfig;
-
-    // Number of tokens at which the KV Cache begins sliding [for the minimum attention window]
-    SizeType32 mCyclicThreshold;
+    // Number of front blocks removed from the sequence
+    SizeType32 mNumFrontBlocksRemoved;
 };
 
 // attach metadata to a pool pointer
@@ -533,7 +547,7 @@ class WindowBlockManager
 
     explicit WindowBlockManager(nvinfer1::DataType dtype, SizeType32 windowSize,
         std::vector<SizeType32> const& managedLayers, std::vector<SizeType32> const& numKvHeadsPerLayer,
-        SizeType32 sizePerHead, SizeType32 tokensPerBlock, SizeType32 blocksInPrimaryPool,
+        SizeType32 sizePerHead, SizeType32 tokensPerBlock, bool isSWA, SizeType32 blocksInPrimaryPool,
         SizeType32 blocksInSecondaryPool, SizeType32 maxNumSequences, std::shared_ptr<runtime::CudaStream> stream,
         bool onboardBlocks, CacheType cacheType, std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
         std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse);
@@ -567,14 +581,26 @@ class WindowBlockManager
 
     void storeNewBlock(GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest);
 
     //! \brief Release blocks of the sequence.
-    void releaseBlocks(GenerationRequest& sequence);
+    //! \details When llmRequest is provided and reuse is enabled, blocks will be stored.
+    void releaseBlocks(GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest = std::nullopt);
 
     //! \brief Simulate freeing all blocks for that sequence to check impact on number of free blocks
     void schedulingReleaseBlocks(LlmRequest::RequestIdType requestId);
 
+    //! \brief Update cache offsets for last block
+    void updateLastCacheBlockOffsets(GenerationRequest& seq);
+
     //! \brief Release last block in the sequence
     void releaseLastBlock(GenerationRequest& sequence);
 
+    //! \brief Detach block from the sequence
+    void detachBlock(GenerationRequest& sequence, bool isEnableBlockReuse);
+
+    //! \brief Check and add a block to the sequence if needed.
+    //! \details Out-of-window blocks will be detached. If reuse is enabled,
+    //! the detached block will be stored via offload.
+    void addBlockIfNeeded(GenerationRequest& sequence, bool isEnableBlockReuse);
+
     [[nodiscard]] SizeType32 getWindowSize() const noexcept
     {
         return mWindowSize;
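The `detachBlock`/`addBlockIfNeeded` pair suggests a per-step flow for SWA: when the incoming token opens a new block and/or pushes the front block entirely out of the window, the front block is detached (and stored via offload when reuse is enabled). A toy decision function sketching that logic under assumed semantics; `stepSWA` and `StepDecision` are illustrative, not the actual implementation:

```cpp
#include <cassert>
#include <cstdint>

using SizeType32 = std::int32_t;

struct StepDecision
{
    bool needNewBlock;
    bool detachFrontBlock;
};

// `numTokens` is the count of tokens already in the sequence; the token being
// added has index `numTokens`.
StepDecision stepSWA(SizeType32 numTokens, SizeType32 windowSize, SizeType32 tokensPerBlock)
{
    // The incoming token opens a new block when it is the first token of that block.
    bool const needNewBlock = numTokens % tokensPerBlock == 0;
    // After adding it, only the last `windowSize` tokens are needed. The front
    // block becomes detachable exactly when the count of out-of-window tokens
    // reaches a block boundary.
    SizeType32 const tokensOutOfWindow = numTokens + 1 - windowSize;
    bool const detachFrontBlock = tokensOutOfWindow > 0 && tokensOutOfWindow % tokensPerBlock == 0;
    return {needNewBlock, detachFrontBlock};
}

int main()
{
    // windowSize = 5, tokensPerBlock = 2: adding token 6 both opens block 3 and
    // pushes tokens 0-1 (the whole front block) out of the window.
    auto const d = stepSWA(/*numTokens=*/6, /*windowSize=*/5, /*tokensPerBlock=*/2);
    assert(d.needNewBlock && d.detachFrontBlock);
    // One step earlier nothing happens: token 5 lands in an existing block and
    // only token 0 is out of the window, so the front block is still needed.
    auto const e = stepSWA(5, 5, 2);
    assert(!e.needNewBlock && !e.detachFrontBlock);
    return 0;
}
```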
@@ -585,7 +611,7 @@ class WindowBlockManager
         return mLogPrefix;
     }
 
-    [[nodiscard]] SizeType32 getNumFreeBlocks() const noexcept;
+    [[nodiscard]] SizeType32 getNumFreeBlocks(SizeType32 cacheLevel = kPrimaryLevel) const noexcept;
 
     [[nodiscard]] SizeType32 getNumAllocTotalBlocks() const
     {
@@ -715,7 +741,8 @@ class WindowBlockManager
     //! \brief Store blocks in cached blocks.
     //! \param blockKeys Key of each block.
     //! \param blockIds Id of each block.
-    void storeBlocks(std::vector<BlockKey> const& blockKeys, std::vector<KVCacheBlock::IdType> const& blockIds);
+    //! \return Number of blocks actually stored.
+    SizeType32 storeBlocks(std::vector<BlockKey> const& blockKeys, std::vector<KVCacheBlock::IdType> const& blockIds);
 
     [[nodiscard]] bool verifyQueueIntegrity();
 
@@ -796,6 +823,8 @@ class WindowBlockManager
     SizeType32 mSchedulingNumFreeBlocks;
     // Number of tokens per one block
     SizeType32 mTokensPerBlock;
+    // Whether this window uses sliding window attention (SWA) or full attention
+    bool mIsSWA;
     // List of all blocks by idx
     std::vector<BlockPtr> mAllBlocksById;
     // Dummy block acting as root for BlockToken searches
@@ -917,19 +946,20 @@ class BlockManager
 
     void startScheduling();
 
-    [[nodiscard]] std::map<SizeType32, SizeType32> getNumFreeBlocksPerWindowSize() const
+    [[nodiscard]] std::map<SizeType32, SizeType32> getNumFreeBlocksPerWindowSize(
+        SizeType32 cacheLevel = kPrimaryLevel) const
     {
         std::map<SizeType32, SizeType32> numFreeBlocksPerWindowSize;
         for (auto const& [windowSize, manager] : mWindowBlockManagers)
         {
-            numFreeBlocksPerWindowSize[windowSize] = manager.getNumFreeBlocks();
+            numFreeBlocksPerWindowSize[windowSize] = manager.getNumFreeBlocks(cacheLevel);
         }
         return numFreeBlocksPerWindowSize;
     }
 
-    [[nodiscard]] SizeType32 getNumFreeBlocks() const
+    [[nodiscard]] SizeType32 getNumFreeBlocks(SizeType32 cacheLevel = kPrimaryLevel) const
     {
-        return sumWindows([](auto const& manager) { return manager.getNumFreeBlocks(); });
+        return sumWindows([cacheLevel](auto const& manager) { return manager.getNumFreeBlocks(cacheLevel); });
     }
 
     [[nodiscard]] bool schedulingHasFreeBlocks(SizeType32 numRequired, SizeType32 windowSize) const
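With the `cacheLevel` parameter, free-block counts can be queried separately for the primary and secondary pools. A toy mirror of the aggregation pattern in `getNumFreeBlocks`/`getNumFreeBlocksPerWindowSize`; `ToyWindowManager` is illustrative, only the signatures above come from this change:

```cpp
#include <cassert>
#include <cstdint>
#include <map>

using SizeType32 = std::int32_t;
static constexpr SizeType32 kPrimaryLevel = 0;
static constexpr SizeType32 kSecondaryLevel = 1;

// Toy stand-in for a per-window manager exposing free blocks per cache level.
struct ToyWindowManager
{
    SizeType32 freePrimary;
    SizeType32 freeSecondary;

    SizeType32 getNumFreeBlocks(SizeType32 cacheLevel = kPrimaryLevel) const
    {
        return cacheLevel == kPrimaryLevel ? freePrimary : freeSecondary;
    }
};

int main()
{
    std::map<SizeType32, ToyWindowManager> windows{{1024, {8, 2}}, {4096, {3, 5}}};

    // Mirrors the BlockManager aggregation: sum one cache level across windows.
    auto sumLevel = [&](SizeType32 cacheLevel)
    {
        SizeType32 total = 0;
        for (auto const& [windowSize, manager] : windows)
        {
            total += manager.getNumFreeBlocks(cacheLevel);
        }
        return total;
    };

    assert(sumLevel(kPrimaryLevel) == 11);
    assert(sumLevel(kSecondaryLevel) == 7);
    return 0;
}
```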
@@ -1088,14 +1118,6 @@ class BlockManager
     //! \brief Store newest block for reuse
     void storeNewBlock(GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest);
 
-    [[nodiscard]] static bool isUseOneMoreBlock(
-        SizeType32 windowSize, std::optional<SizeType32> maxSequenceLength, SizeType32 maxBeamWidth)
-    {
-        bool const isCyclicWindowSize = maxSequenceLength.has_value() && maxSequenceLength.value() > windowSize;
-        bool const isBeamSearch = maxBeamWidth > 1;
-        return isCyclicWindowSize && isBeamSearch;
-    }
-
     //! \brief Perform per-request bookkeeping
     void refreshBlocks();
 
@@ -1114,12 +1136,12 @@ class BlockManager
     //! \brief Update cache offsets for blocks initiated from sequence
     void updateSequenceCacheBlockOffsets(GenerationRequest& seq, SizeType32 windowSize);
 
-    //! \brief Update cache offsets for last block
-    void updateLastCacheBlockOffsets(GenerationRequest& seq, SizeType32 windowSize);
-
     //! \brief Update cache offsets for block at index
     void updateCacheBlockOffsetsAtIdx(GenerationRequest& seq, SizeType32 windowSize, SizeType32 blockIdx);
 
+    //! \brief Add block to the sequence if needed
+    void addBlockIfNeeded(GenerationRequest& sequence, bool isEnableBlockReuse);
+
 private:
     [[nodiscard]] WindowBlockManager const& windowManagerByLayer(SizeType32 layerIdx) const
     {