Skip to content

Commit 44643d2

Browse files
committed
[KV cache manager] Support SWA KV cache reuse
Signed-off-by: eopXD <[email protected]>
1 parent 6eda9e8 commit 44643d2

File tree

5 files changed

+968
-329
lines changed

5 files changed

+968
-329
lines changed

cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h

Lines changed: 65 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,10 @@ static constexpr SizeType32 kPrimaryLevel = 0;
5353

5454
static constexpr SizeType32 kSecondaryLevel = 1;
5555

56+
// Extra block buffer allocated for SWA to be able to always keep "window size"
57+
// tokens held in the blocks.
58+
static constexpr SizeType32 kSWAExtraBlock = 1;
59+
5660
class KVCacheBlock;
5761
class BlockManager;
5862
class KVCacheManager;
@@ -88,8 +92,8 @@ struct WindowSizeMetadata
8892
SizeType32 allottedSecondaryBlocks; // Number of secondary blocks allotted to the windowSize
8993
SizeType32 absolutePoolsOffset; // cumulative number of pools up to manager
9094
SizeType32 numPools; // number of managed pools
91-
SizeType32 maxTokenNum; // Maximum token length (including bubble)
92-
SizeType32 maxBlocksPerSeq;
95+
SizeType32 maxTokensPerSeq; // Maximum token length per sequence (TODO: account for streamLLM)
96+
SizeType32 maxBlocksPerSeq; // Maximum number of blocks per sequence
9397
SizeType32 maxNumBlocks; // Number of primary+secondary blocks allotted to the windowSize
9498
SizeType32 temporaryAttentionWindow; // Temporary kv cache length per sequence.
9599
// Only needed when chunked context + sliding window attention are used
@@ -99,9 +103,9 @@ struct WindowSizeMetadata
99103
{
100104
return tensorrt_llm::common::fmtstr(
101105
"WindowSizeMetadata{ .allottedPrimaryBlocks=%d, .allottedSecondaryBlocks=%d, .absolutePoolsOffset=%d, "
102-
".numPools=%d, .maxTokenNum=%d, .maxBlocksPerSeq=%d, .maxNumBlocks=%d, .temporaryAttentionWindow=%d }",
103-
allottedPrimaryBlocks, allottedSecondaryBlocks, absolutePoolsOffset, numPools, maxTokenNum, maxBlocksPerSeq,
104-
maxNumBlocks, temporaryAttentionWindow);
106+
".numPools=%d, .maxTokensPerSeq=%d, .maxBlocksPerSeq=%d, .maxNumBlocks=%d, .temporaryAttentionWindow=%d }",
107+
allottedPrimaryBlocks, allottedSecondaryBlocks, absolutePoolsOffset, numPools, maxTokensPerSeq,
108+
maxBlocksPerSeq, maxNumBlocks, temporaryAttentionWindow);
105109
}
106110
};
107111

@@ -203,6 +207,7 @@ class KVCacheBlock
203207
using IdType = std::int32_t;
204208

205209
static constexpr IdType kCachedBlocksRootId = -1;
210+
static constexpr IdType kInvalidBlockId = -2;
206211

207212
explicit KVCacheBlock(IdType blockId, kernels::KVCacheIndex blockIdx);
208213

@@ -335,14 +340,7 @@ class GenerationRequest
335340
, mNumTokens(numTokens)
336341
, mBeamWidth(beamWidth)
337342
, mKvCacheRetentionConfig(std::move(kvCacheRetentionConfig))
338-
// min window size + sink bubble length
339-
// Why use the minimum window size:
340-
// Chunked Prefill + Reuse calls `setPrepopulatedPromptLen()` which sets
341-
// `mContextCurrentPosition` - this cannot be done for some windows sizes and
342-
// not for others, the state needs to remain identical for all window sizes. So
343-
// we currently resort to strictly disabling the reuse code path for all window
344-
// sizes at once or enable it for all window sizes at once.
345-
, mCyclicThreshold(windowSizeToMetadata.cbegin()->second.maxTokenNum)
343+
, mNumFrontBlocksRemoved(0)
346344
{
347345
auto const numWindowSizes = windowSizeToMetadata.size();
348346
mCacheBlockIds.reserve(numWindowSizes);
@@ -385,6 +383,11 @@ class GenerationRequest
385383
return mNumTokens;
386384
}
387385

386+
[[nodiscard]] SizeType32 getNumFrontBlocksRemoved() const
387+
{
388+
return mNumFrontBlocksRemoved;
389+
}
390+
388391
[[nodiscard]] SizeType32 getBeamWidth() const
389392
{
390393
return mBeamWidth;
@@ -422,6 +425,26 @@ class GenerationRequest
422425
{
423426
beamBlockIds.clear();
424427
}
428+
mNumFrontBlocksRemoved = 0;
429+
}
430+
431+
void removeFrontBlock(SizeType32 windowSize)
432+
{
433+
for (auto& beamBlockIds : mCacheBlockIds.at(windowSize))
434+
{
435+
if (mNumFrontBlocksRemoved < static_cast<SizeType32>(beamBlockIds.size()))
436+
{
437+
// Doesn't actually remove from mCacheBlockIds like removeLastBlock,
438+
// block id is set to -1 instead because we preserve the blocks
439+
// for reuse when reuse is enabled.
440+
beamBlockIds[mNumFrontBlocksRemoved] = KVCacheBlock::kInvalidBlockId;
441+
}
442+
else
443+
{
444+
TLLM_LOG_WARNING("RequestID %d: removeFrontBlock called but nothing to remove", mRequestId);
445+
}
446+
}
447+
++mNumFrontBlocksRemoved;
425448
}
426449

427450
void removeLastBlock(SizeType32 windowSize)
@@ -442,14 +465,6 @@ class GenerationRequest
442465
return mKvCacheRetentionConfig.getDecodeDurationMs();
443466
}
444467

445-
// @brief Check whether the sequence uses cyclic KV cache.
446-
// @return `true` if we have begun overwriting the beginning of the sequence's KV cache.
447-
// @details If `true`, we cannot store the sequence's KV cache for reuse.
448-
[[nodiscard]] bool isCyclic() const
449-
{
450-
return mNumTokens >= mCyclicThreshold;
451-
}
452-
453468
private:
454469
// Request id of the sequence
455470
LlmRequest::RequestIdType mRequestId;
@@ -463,9 +478,8 @@ class GenerationRequest
463478
std::unordered_map<SizeType32, runtime::ITensor::SharedPtr> mCacheBlockIndices;
464479
// The retention priority to assign to decode blocks
465480
executor::KvCacheRetentionConfig mKvCacheRetentionConfig;
466-
467-
// Number of tokens at which the KV Cache begins sliding [for the minimum attention window]
468-
SizeType32 mCyclicThreshold;
481+
// Number of front blocks removed from the sequence
482+
SizeType32 mNumFrontBlocksRemoved;
469483
};
470484

471485
// attach metadata to a pool pointer
@@ -533,7 +547,7 @@ class WindowBlockManager
533547

534548
explicit WindowBlockManager(nvinfer1::DataType dtype, SizeType32 windowSize,
535549
std::vector<SizeType32> const& managedLayers, std::vector<SizeType32> const& numKvHeadsPerLayer,
536-
SizeType32 sizePerHead, SizeType32 tokensPerBlock, SizeType32 blocksInPrimaryPool,
550+
SizeType32 sizePerHead, SizeType32 tokensPerBlock, bool isSWA, SizeType32 blocksInPrimaryPool,
537551
SizeType32 blocksInSecondaryPool, SizeType32 maxNumSequences, std::shared_ptr<runtime::CudaStream> stream,
538552
bool onboardBlocks, CacheType cacheType, std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
539553
std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse);
@@ -567,14 +581,26 @@ class WindowBlockManager
567581
void storeNewBlock(GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest);
568582

569583
//! \brief Release blocks of the sequence.
570-
void releaseBlocks(GenerationRequest& sequence);
584+
//! \details When llmRequest is provided and reuse is enabled, blocks will be stored.
585+
void releaseBlocks(GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest = std::nullopt);
571586

572587
//! \brief Simulate freeing all blocks for that sequence to check impact on number of free blocks
573588
void schedulingReleaseBlocks(LlmRequest::RequestIdType requestId);
574589

590+
//! \brief Update cache offsets for last block
591+
void updateLastCacheBlockOffsets(GenerationRequest& seq);
592+
575593
//! \brief Release last block in the sequence
576594
void releaseLastBlock(GenerationRequest& sequence);
577595

596+
//! \brief Detach block from the sequence
597+
void detachBlock(GenerationRequest& sequence, bool isEnableBlockReuse);
598+
599+
//! \brief Check and add a block to the sequence if needed.
600+
//! \details Out-of-window blocks will be detached. If reuse is enabled,
601+
//! the detached block will be stored via offload.
602+
void addBlockIfNeeded(GenerationRequest& sequence, bool isEnableBlockReuse);
603+
578604
[[nodiscard]] SizeType32 getWindowSize() const noexcept
579605
{
580606
return mWindowSize;
@@ -585,7 +611,7 @@ class WindowBlockManager
585611
return mLogPrefix;
586612
}
587613

588-
[[nodiscard]] SizeType32 getNumFreeBlocks() const noexcept;
614+
[[nodiscard]] SizeType32 getNumFreeBlocks(SizeType32 cacheLevel = kPrimaryLevel) const noexcept;
589615

590616
[[nodiscard]] SizeType32 getNumAllocTotalBlocks() const
591617
{
@@ -715,7 +741,8 @@ class WindowBlockManager
715741
//! \brief Store blocks in cached blocks.
716742
//! \param blockKeys Key of each block.
717743
//! \param blockIds Id of each block.
718-
void storeBlocks(std::vector<BlockKey> const& blockKeys, std::vector<KVCacheBlock::IdType> const& blockIds);
744+
//! \return Number of actual blocks stored.
745+
SizeType32 storeBlocks(std::vector<BlockKey> const& blockKeys, std::vector<KVCacheBlock::IdType> const& blockIds);
719746

720747
[[nodiscard]] bool verifyQueueIntegrity();
721748

@@ -796,6 +823,8 @@ class WindowBlockManager
796823
SizeType32 mSchedulingNumFreeBlocks;
797824
// Number of tokens per one block
798825
SizeType32 mTokensPerBlock;
826+
// Whether this window is SWA
827+
bool mIsSWA;
799828
// List of all blocks by idx
800829
std::vector<BlockPtr> mAllBlocksById;
801830
// Dummy block acting as root for BlockToken searches
@@ -917,19 +946,20 @@ class BlockManager
917946

918947
void startScheduling();
919948

920-
[[nodiscard]] std::map<SizeType32, SizeType32> getNumFreeBlocksPerWindowSize() const
949+
[[nodiscard]] std::map<SizeType32, SizeType32> getNumFreeBlocksPerWindowSize(
950+
SizeType32 cacheLevel = kPrimaryLevel) const
921951
{
922952
std::map<SizeType32, SizeType32> numFreeBlocksPerWindowSize;
923953
for (auto const& [windowSize, manager] : mWindowBlockManagers)
924954
{
925-
numFreeBlocksPerWindowSize[windowSize] = manager.getNumFreeBlocks();
955+
numFreeBlocksPerWindowSize[windowSize] = manager.getNumFreeBlocks(cacheLevel);
926956
}
927957
return numFreeBlocksPerWindowSize;
928958
}
929959

930-
[[nodiscard]] SizeType32 getNumFreeBlocks() const
960+
[[nodiscard]] SizeType32 getNumFreeBlocks(SizeType32 cacheLevel = kPrimaryLevel) const
931961
{
932-
return sumWindows([](auto const& manager) { return manager.getNumFreeBlocks(); });
962+
return sumWindows([cacheLevel](auto const& manager) { return manager.getNumFreeBlocks(cacheLevel); });
933963
}
934964

935965
[[nodiscard]] bool schedulingHasFreeBlocks(SizeType32 numRequired, SizeType32 windowSize) const
@@ -1088,14 +1118,6 @@ class BlockManager
10881118
//! \brief Store newest block for reuse
10891119
void storeNewBlock(GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest);
10901120

1091-
[[nodiscard]] static bool isUseOneMoreBlock(
1092-
SizeType32 windowSize, std::optional<SizeType32> maxSequenceLength, SizeType32 maxBeamWidth)
1093-
{
1094-
bool const isCyclicWindowSize = maxSequenceLength.has_value() && maxSequenceLength.value() > windowSize;
1095-
bool const isBeamSearch = maxBeamWidth > 1;
1096-
return isCyclicWindowSize && isBeamSearch;
1097-
}
1098-
10991121
//! \brief Perform per-request bookkeeping
11001122
void refreshBlocks();
11011123

@@ -1114,12 +1136,12 @@ class BlockManager
11141136
//! \brief Update cache offsets for blocks initiated from sequence
11151137
void updateSequenceCacheBlockOffsets(GenerationRequest& seq, SizeType32 windowSize);
11161138

1117-
//! \brief Update cache offsets for last block
1118-
void updateLastCacheBlockOffsets(GenerationRequest& seq, SizeType32 windowSize);
1119-
11201139
//! \brief Update cache offsets for block at index
11211140
void updateCacheBlockOffsetsAtIdx(GenerationRequest& seq, SizeType32 windowSize, SizeType32 blockIdx);
11221141

1142+
//! \brief Add block to the sequence if needed
1143+
void addBlockIfNeeded(GenerationRequest& sequence, bool isEnableBlockReuse);
1144+
11231145
private:
11241146
[[nodiscard]] WindowBlockManager const& windowManagerByLayer(SizeType32 layerIdx) const
11251147
{

0 commit comments

Comments
 (0)