
Commit cf10093

eopXD and tomeras91 authored
[TRTLLM-6341][feature] Support SWA KV cache reuse (#6768)
This merge request adds broader SWA KV cache functionality to the KV cache manager. Previously, the KV cache for sliding window attention (SWA) held only "window size" blocks and reused them cyclically. That design cannot take advantage of additional GPU memory, which caps the maximum batch size and therefore throughput, and it cannot support KV cache reuse. This MR changes the manager to write blocks linearly: as the attention window advances, out-of-window (OOW) blocks are detached. For now, for the sake of a correct feature first, we offload OOW blocks directly from the primary block pool (GPU memory) to the secondary block pool (host memory). We will improve this in the future by delegating the block movement to the eviction policy. KV cache reuse for SWA is not developed in this merge request and will be added in a follow-up merge request.

With blocks written linearly, the maximum number of blocks allocated for a sequence (`GenerationRequest`) is determined by the specified "max sequence length", and the `GenerationRequest` that stores the cache block bookkeeping structure now keeps blocks for "max sequence length" tokens.

Given the above, the main changes are (more context in the MR):
- Remove the "cyclic" concept from the KV cache manager; it previously guarded block reuse there.
- Add a detach mechanism, invoked from `KVCacheManager::addToken`. Note that detach is still guarded off for SWA when reuse is enabled; a follow-up merge request will improve this.
- Enforce "max sequence length" as a non-optional parameter of `KVCacheManager`/`BlockManager`.
- Let all window-size resource pools get an identical proportion of memory.
- Fix the free-memory calculation in `resource_manager.py`.

Signed-off-by: eopXD <[email protected]>
Co-authored-by: Tomer Asida <[email protected]>
1 parent 5ccb2de commit cf10093

File tree

16 files changed, +874 −527 lines changed


cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h

Lines changed: 65 additions & 47 deletions
@@ -57,6 +57,10 @@ static constexpr SizeType32 kPrimaryLevel = 0;
 
 static constexpr SizeType32 kSecondaryLevel = 1;
 
+// Extra block buffer allocated for SWA to be able to always keep "window size"
+// tokens held in the blocks.
+static constexpr SizeType32 kSWAExtraBlock = 1;
+
 class KVCacheBlock;
 class BlockManager;
 class KVCacheManager;
@@ -93,8 +97,8 @@ struct WindowSizeMetadata
     SizeType32 allottedSecondaryBlocks;  // Number of secondary blocks allotted to the windowSize
     SizeType32 absolutePoolsOffset;      // cumulative number of pools up to manager
     SizeType32 numPools;                 // number of managed pools
-    SizeType32 maxTokenNum;              // Maximum token length (including bubble)
-    SizeType32 maxBlocksPerSeq;
+    SizeType32 maxTokenNum;              // Maximum token length per sequence (TODO: account for streamLLM)
+    SizeType32 maxBlocksPerSeq;          // Maximum number of blocks per sequence
     SizeType32 maxNumBlocks;             // Number of primary+secondary blocks allotted to the windowSize
     SizeType32 temporaryAttentionWindow; // Temporary kv cache length per sequence.
                                          // Only needed when chunked context + sliding window attention are used
@@ -344,14 +348,7 @@ class GenerationRequest
         , mNumTokens(numTokens)
         , mBeamWidth(beamWidth)
         , mKvCacheRetentionConfig(std::move(kvCacheRetentionConfig))
-        // min window size + sink bubble length
-        // Why use the minimum window size:
-        // Chunked Prefill + Reuse calls `setPrepopulatedPromptLen()` which sets
-        // `mContextCurrentPosition` - this cannot be done for some windows sizes and
-        // not for others, the state needs to remain identical for all window sizes. So
-        // we currently resort to strictly disabling the reuse code path for all window
-        // sizes at once or enable it for all window sizes at once.
-        , mCyclicThreshold(windowSizeToMetadata.cbegin()->second.maxTokenNum)
+        , mNumFrontBlocksRemoved(0)
     {
         auto const numWindowSizes = windowSizeToMetadata.size();
         mCacheBlockIds.reserve(numWindowSizes);
@@ -394,6 +391,11 @@ class GenerationRequest
         return mNumTokens;
     }
 
+    [[nodiscard]] SizeType32 getNumFrontBlocksRemoved() const
+    {
+        return mNumFrontBlocksRemoved;
+    }
+
     [[nodiscard]] SizeType32 getBeamWidth() const
     {
         return mBeamWidth;
@@ -431,6 +433,12 @@ class GenerationRequest
         {
             beamBlockIds.clear();
         }
+        mNumFrontBlocksRemoved = 0;
+    }
+
+    void removeFrontBlock(SizeType32 windowSize)
+    {
+        ++mNumFrontBlocksRemoved;
     }
 
     void removeLastBlock(SizeType32 windowSize)
@@ -461,14 +469,6 @@ class GenerationRequest
         return mKvCacheRetentionConfig.getDirectory();
     }
 
-    // @brief Check whether the sequence uses cyclic KV cache.
-    // @return `true` if we have begun overwriting the beginning of the sequence's KV cache.
-    // @details If `true`, we cannot store the sequence's KV cache for reuse.
-    [[nodiscard]] bool isCyclic() const
-    {
-        return mNumTokens >= mCyclicThreshold;
-    }
-
 private:
     // Request id of the sequence
     LlmRequest::RequestIdType mRequestId;
@@ -482,9 +482,8 @@ class GenerationRequest
     std::unordered_map<SizeType32, runtime::ITensor::SharedPtr> mCacheBlockIndices;
     // The retention priority to assign to decode blocks
    executor::KvCacheRetentionConfig mKvCacheRetentionConfig;
-
-    // Number of tokens at which the KV Cache begins sliding [for the minimum attention window]
-    SizeType32 mCyclicThreshold;
+    // Number of front blocks removed from the sequence
+    SizeType32 mNumFrontBlocksRemoved;
 };
 
 // attach metadata to a pool pointer
@@ -550,7 +549,7 @@ class WindowBlockManager
 
     explicit WindowBlockManager(nvinfer1::DataType dtype, SizeType32 windowSize,
         std::vector<SizeType32> const& managedLayers, std::vector<SizeType32> const& numKvHeadsPerLayer,
-        SizeType32 sizePerHead, SizeType32 tokensPerBlock, SizeType32 blocksInPrimaryPool,
+        SizeType32 sizePerHead, SizeType32 tokensPerBlock, bool isSWA, SizeType32 blocksInPrimaryPool,
         SizeType32 blocksInSecondaryPool, SizeType32 maxNumSequences, std::shared_ptr<runtime::CudaStream> stream,
         bool onboardBlocks, CacheType cacheType, std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
         std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse,
@@ -581,19 +580,32 @@ class WindowBlockManager
     //! \brief Get the ids of all newly allocated (not reused) blocks for the sequence.
     std::vector<KVCacheBlock::IdType> getNewlyAllocatedBlockIds(GenerationRequest const& sequence) const;
 
-    void storeBlocksForReuse(GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest);
-
     void storeNewBlock(GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest);
 
     //! \brief Release blocks of the sequence.
-    void releaseBlocks(GenerationRequest& sequence);
+    //! \details When llmRequest is provided and reuse is enabled, blocks will be stored.
+    void releaseBlocks(GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest);
 
     //! \brief Simulate freeing all blocks for that sequence to check impact on number of free blocks
     void schedulingReleaseBlocks(LlmRequest::RequestIdType requestId);
 
+    //! \brief Update cache offsets for last block
+    void updateLastCacheBlockOffsets(GenerationRequest& seq);
+
     //! \brief Release last block in the sequence
     void releaseLastBlock(GenerationRequest& sequence);
 
+    //! \brief Detach front block from the sequence
+    void detachFrontBlock(GenerationRequest& sequence, bool isEnableBlockReuse);
+
+    //! \brief Add/detach block(s) to/from the sequence if needed
+    //! \details When we need a new block, we add it. For sliding window
+    //! attention (SWA), when a block goes out-of-window (OOW), we detach it
+    //! and store it if reuse is enabled. If this is called in the first step of
+    //! the generation phase, we may detach more than a single block since
+    //! there may be more than one context block that goes OOW.
+    void adjustBlocksIfNeeded(GenerationRequest& sequence, bool isEnableBlockReuse);
+
     [[nodiscard]] SizeType32 getWindowSize() const noexcept
     {
         return mWindowSize;
@@ -745,7 +757,8 @@ class WindowBlockManager
     //! \brief Store blocks in cached blocks.
     //! \param blockKeys Key of each block.
     //! \param blockIds Id of each block.
-    void storeBlocks(std::vector<BlockKey> const& blockKeys, std::vector<KVCacheBlock::IdType> const& blockIds);
+    //! \return Number of actual blocks stored.
+    SizeType32 storeBlocks(std::vector<BlockKey> const& blockKeys, std::vector<KVCacheBlock::IdType> const& blockIds);
 
     [[nodiscard]] bool verifyQueueIntegrity();
 
@@ -767,6 +780,12 @@ class WindowBlockManager
         return 0;
     }
 
+    //! \brief Return whether this window is SWA.
+    [[nodiscard]] bool isSWA() const
+    {
+        return mIsSWA;
+    }
+
 private:
     //! \brief Add single block to beam of sequence and mAllocatedBlocksPerSeq.
     void addBlockToBeam(BlockPtr& block, GenerationRequest& sequence, SizeType32 beamIdx);
@@ -828,6 +847,8 @@ class WindowBlockManager
     SizeType32 mSchedulingNumFreeBlocks;
     // Number of tokens per one block
     SizeType32 mTokensPerBlock;
+    // Whether this window is sliding window attention/full attention
+    bool mIsSWA;
     // List of all blocks by idx
     std::vector<BlockPtr> mAllBlocksById;
     // Dummy block acting as root for BlockToken searches
@@ -880,7 +901,7 @@ class BlockManager
 
     explicit BlockManager(std::vector<SizeType32> const& numKvHeadsPerLayer, SizeType32 sizePerHead,
         SizeType32 tokensPerBlock, BlocksPerWindow const& blocksPerWindow, SizeType32 maxNumSequences,
-        CudaStreamPtr stream, std::optional<SizeType32> maxSequenceLength, SizeType32 maxBeamWidth,
+        CudaStreamPtr stream, SizeType32 maxSequenceLength, SizeType32 maxBeamWidth,
         std::vector<SizeType32> const& maxAttentionWindowVec,
         std::optional<TempAttentionWindowInputs> const& tempAttentionWindowInputs, nvinfer1::DataType dtype,
         SizeType32 sinkBubbleLength, bool onboardBlocks, CacheType cacheType = CacheType::kSELF,
@@ -1128,14 +1149,6 @@ class BlockManager
     //! \brief Store newest block for reuse
     void storeNewBlock(GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest);
 
-    [[nodiscard]] static bool isUseOneMoreBlock(
-        SizeType32 windowSize, std::optional<SizeType32> maxSequenceLength, SizeType32 maxBeamWidth)
-    {
-        bool const isCyclicWindowSize = maxSequenceLength.has_value() && maxSequenceLength.value() > windowSize;
-        bool const isBeamSearch = maxBeamWidth > 1;
-        return isCyclicWindowSize && isBeamSearch;
-    }
-
     //! \brief Perform per-request bookkeeping
     void refreshBlocks();
 
@@ -1154,12 +1167,17 @@ class BlockManager
     //! \brief Update cache offsets for blocks initiated from sequence
     void updateSequenceCacheBlockOffsets(GenerationRequest& seq, SizeType32 windowSize);
 
-    //! \brief Update cache offsets for last block
-    void updateLastCacheBlockOffsets(GenerationRequest& seq, SizeType32 windowSize);
-
     //! \brief Update cache offsets for block at index
     void updateCacheBlockOffsetsAtIdx(GenerationRequest& seq, SizeType32 windowSize, SizeType32 blockIdx);
 
+    //! \brief Add/detach block(s) to/from the sequence if needed
+    //! \details When we need a new block, we add it. For sliding window
+    //! attention (SWA), when a block goes out-of-window (OOW), we detach it
+    //! and store it if reuse is enabled. If this is called in the first step of
+    //! the generation phase, we may detach more than a single block since
+    //! there may be more than one context block that goes OOW.
+    void adjustBlocksIfNeeded(GenerationRequest& sequence, bool isEnableBlockReuse);
+
 private:
     [[nodiscard]] WindowBlockManager const& windowManagerByLayer(SizeType32 layerIdx) const
     {
@@ -1411,8 +1429,8 @@ class KVCacheManager : public BaseKVCacheManager
         BlocksPerWindow const& blocksPerWindow, SizeType32 maxNumSequences, SizeType32 maxBeamWidth,
         std::vector<SizeType32> const& maxAttentionWindowVec,
         std::optional<TempAttentionWindowInputs> const& tempAttentionWindowInputs, nvinfer1::DataType dtype,
-        SizeType32 sinkTokenLength, CudaStreamPtr stream, std::optional<SizeType32> maxSequenceLength,
-        bool enableBlockReuse = false, bool onboardBlocks = true, CacheType cacheType = CacheType::kSELF,
+        SizeType32 sinkTokenLength, CudaStreamPtr stream, SizeType32 maxSequenceLength, bool enableBlockReuse = false,
+        bool onboardBlocks = true, CacheType cacheType = CacheType::kSELF,
         std::optional<executor::RetentionPriority> secondaryOffloadMinPriority = std::nullopt,
         std::shared_ptr<KVCacheEventManager> eventManager = nullptr, bool enablePartialReuse = true,
         bool copyOnpartialReuse = true,
@@ -1422,8 +1440,8 @@ class KVCacheManager : public BaseKVCacheManager
         BlocksPerWindow const& blocksPerWindow, SizeType32 maxNumSequences, SizeType32 maxBeamWidth,
         std::vector<SizeType32> const& maxAttentionWindowVec,
         std::optional<TempAttentionWindowInputs> const& tempAttentionWindowInputs, nvinfer1::DataType dtype,
-        SizeType32 sinkTokenLength, int64_t stream, std::optional<SizeType32> maxSequenceLength,
-        bool enableBlockReuse = false, bool onboardBlocks = true, CacheType cacheType = CacheType::kSELF,
+        SizeType32 sinkTokenLength, int64_t stream, SizeType32 maxSequenceLength, bool enableBlockReuse = false,
+        bool onboardBlocks = true, CacheType cacheType = CacheType::kSELF,
         std::optional<executor::RetentionPriority> secondaryOffloadMinPriority = std::nullopt,
         std::shared_ptr<KVCacheEventManager> eventManager = nullptr, bool enablePartialReuse = true,
         bool copyOnpartialReuse = true,
@@ -1433,8 +1451,8 @@ class KVCacheManager : public BaseKVCacheManager
         BlocksPerWindow const& blocksPerWindow, SizeType32 maxNumSequences, SizeType32 maxBeamWidth,
         std::vector<SizeType32> const& maxAttentionWindowVec,
         std::optional<TempAttentionWindowInputs> const& tempAttentionWindowInputs, nvinfer1::DataType dtype,
-        SizeType32 sinkTokenLength, CudaStreamPtr stream, std::optional<SizeType32> maxSequenceLength,
-        bool enableBlockReuse = true, bool onboardBlocks = true, CacheType cacheType = CacheType::kSELF,
+        SizeType32 sinkTokenLength, CudaStreamPtr stream, SizeType32 maxSequenceLength, bool enableBlockReuse = true,
+        bool onboardBlocks = true, CacheType cacheType = CacheType::kSELF,
         std::optional<executor::RetentionPriority> secondaryOffloadMinPriority = std::nullopt,
         std::shared_ptr<KVCacheEventManager> eventManager = nullptr, bool enablePartialReuse = true,
         bool copyOnpartialReuse = true,
@@ -1444,9 +1462,9 @@ class KVCacheManager : public BaseKVCacheManager
         BlocksPerWindow const& blocksPerWindow, SizeType32 maxNumSequences, SizeType32 maxBeamWidth,
         std::vector<SizeType32> const& maxAttentionWindowVec,
         std::optional<TempAttentionWindowInputs> const& tempAttentionWindowInputs, nvinfer1::DataType dtype,
-        SizeType32 sinkTokenLength, int64_t stream, std::optional<SizeType32> maxSequenceLength,
-        bool enableBlockReuse = false, bool onboardBlocks = true, CacheType cacheType = CacheType::kSELF,
-        bool enablePartialReuse = true, bool copyOnpartialReuse = true);
+        SizeType32 sinkTokenLength, int64_t stream, SizeType32 maxSequenceLength, bool enableBlockReuse = false,
+        bool onboardBlocks = true, CacheType cacheType = CacheType::kSELF, bool enablePartialReuse = true,
+        bool copyOnpartialReuse = true);
 
     ~KVCacheManager() override = default;
 