
Commit d751a9b

[KV cache manager] Support SWA kv cache reuse
Signed-off-by: eopXD <[email protected]>
1 parent 0202754 commit d751a9b

File tree: 4 files changed, +982 −333 lines

cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h

Lines changed: 54 additions & 38 deletions
@@ -53,6 +53,10 @@ static constexpr SizeType32 kPrimaryLevel = 0;
 
 static constexpr SizeType32 kSecondaryLevel = 1;
 
+// Extra block buffer allocated for SWA to be able to always keep "window size"
+// tokens held in the blocks.
+static constexpr SizeType32 kSWAExtraBlock = 1;
+
 class KVCacheBlock;
 class BlockManager;
 class KVCacheManager;
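
Note (illustration, not part of the diff): a minimal, self-contained sketch of why one extra block is reserved for SWA. With window size W and tokensPerBlock T, the last W tokens fit in W/T blocks only while the window is aligned to a block boundary; once it straddles a boundary they touch W/T + 1 blocks, which is what kSWAExtraBlock accounts for. The helper below is hypothetical and only illustrates the arithmetic.

#include <cassert>

// Hypothetical helper, for illustration only: number of blocks touched by the
// last `windowSize` tokens of a `numTokens`-long sequence.
int blocksNeededForWindow(int windowSize, int tokensPerBlock, int numTokens)
{
    int const firstTokenInWindow = numTokens > windowSize ? numTokens - windowSize : 0;
    int const firstBlock = firstTokenInWindow / tokensPerBlock;
    int const lastBlock = (numTokens - 1) / tokensPerBlock;
    return lastBlock - firstBlock + 1;
}

int main()
{
    // windowSize = 8, tokensPerBlock = 4: an aligned window fits in 8/4 = 2 blocks,
    assert(blocksNeededForWindow(8, 4, 8) == 2);
    // but a misaligned window touches 3 blocks, i.e. 8/4 + kSWAExtraBlock.
    assert(blocksNeededForWindow(8, 4, 10) == 3);
    return 0;
}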
@@ -88,8 +92,8 @@ struct WindowSizeMetadata
     SizeType32 allottedSecondaryBlocks;  // Number of secondary blocks allotted to the windowSize
     SizeType32 absolutePoolsOffset;      // cumulative number of pools up to manager
     SizeType32 numPools;                 // number of managed pools
-    SizeType32 maxTokenNum;              // Maximum token length (including bubble)
-    SizeType32 maxBlocksPerSeq;
+    SizeType32 maxTokensPerSeq;          // Maximum token length per sequence
+    SizeType32 maxBlocksPerSeq;          // Maximum number of blocks per sequence
     SizeType32 maxNumBlocks;             // Number of primary+secondary blocks allotted to the windowSize
     SizeType32 temporaryAttentionWindow; // Temporary kv cache length per sequence.
                                          // Only needed when chunked context + sliding window attention are used
@@ -99,9 +103,9 @@
     {
         return tensorrt_llm::common::fmtstr(
             "WindowSizeMetadata{ .allottedPrimaryBlocks=%d, .allottedSecondaryBlocks=%d, .absolutePoolsOffset=%d, "
-            ".numPools=%d, .maxTokenNum=%d, .maxBlocksPerSeq=%d, .maxNumBlocks=%d, .temporaryAttentionWindow=%d }",
-            allottedPrimaryBlocks, allottedSecondaryBlocks, absolutePoolsOffset, numPools, maxTokenNum, maxBlocksPerSeq,
-            maxNumBlocks, temporaryAttentionWindow);
+            ".numPools=%d, .maxTokensPerSeq=%d, .maxBlocksPerSeq=%d, .maxNumBlocks=%d, .temporaryAttentionWindow=%d }",
+            allottedPrimaryBlocks, allottedSecondaryBlocks, absolutePoolsOffset, numPools, maxTokensPerSeq,
+            maxBlocksPerSeq, maxNumBlocks, temporaryAttentionWindow);
     }
 };
 
@@ -335,14 +339,7 @@ class GenerationRequest
         , mNumTokens(numTokens)
         , mBeamWidth(beamWidth)
         , mKvCacheRetentionConfig(std::move(kvCacheRetentionConfig))
-        // min window size + sink bubble length
-        // Why use the minimum window size:
-        // Chunked Prefill + Reuse calls `setPrepopulatedPromptLen()` which sets
-        // `mContextCurrentPosition` - this cannot be done for some windows sizes and
-        // not for others, the state needs to remain identical for all window sizes. So
-        // we currently resort to strictly disabling the reuse code path for all window
-        // sizes at once or enable it for all window sizes at once.
-        , mCyclicThreshold(windowSizeToMetadata.cbegin()->second.maxTokenNum)
+        , mNumFrontBlocksRemoved(0)
     {
         auto const numWindowSizes = windowSizeToMetadata.size();
         mCacheBlockIds.reserve(numWindowSizes);
@@ -385,6 +382,11 @@
         return mNumTokens;
     }
 
+    [[nodiscard]] SizeType32 getNumFrontBlocksRemoved() const
+    {
+        return mNumFrontBlocksRemoved;
+    }
+
     [[nodiscard]] SizeType32 getBeamWidth() const
     {
         return mBeamWidth;
@@ -418,6 +420,17 @@
         }
     }
 
+    void removeFrontBlock(SizeType32 windowSize)
+    {
+        for (auto& beamBlockIds : mCacheBlockIds.at(windowSize))
+        {
+            // Does not actually remove from mCacheBlockIds like removeLastBlock
+            // Id is set to -1 instead.
+            beamBlockIds[mNumFrontBlocksRemoved] = -1;
+        }
+        ++mNumFrontBlocksRemoved;
+    }
+
     void removeLastBlock(SizeType32 windowSize)
     {
         for (auto& beamBlockIds : mCacheBlockIds.at(windowSize))
@@ -436,14 +449,6 @@
         return mKvCacheRetentionConfig.getDecodeDurationMs();
     }
 
-    // @brief Check whether the sequence uses cyclic KV cache.
-    // @return `true` if we have begun overwriting the beginning of the sequence's KV cache.
-    // @details If `true`, we cannot store the sequence's KV cache for reuse.
-    [[nodiscard]] bool isCyclic() const
-    {
-        return mNumTokens >= mCyclicThreshold;
-    }
-
 private:
     // Request id of the sequence
     LlmRequest::RequestIdType mRequestId;
@@ -457,9 +462,8 @@
     std::unordered_map<SizeType32, runtime::ITensor::SharedPtr> mCacheBlockIndices;
     // The retention priority to assign to decode blocks
     executor::KvCacheRetentionConfig mKvCacheRetentionConfig;
-
-    // Number of tokens at which the KV Cache begins sliding [for the minimum attention window]
-    SizeType32 mCyclicThreshold;
+    // Number of front blocks removed from the sequence
+    SizeType32 mNumFrontBlocksRemoved;
 };
 
 // attach metadata to a pool pointer
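
Note (illustration, not part of the diff): the front-block bookkeeping added above keeps the per-beam block-id vectors at full length and only overwrites detached entries with -1, so the indices of the remaining blocks stay stable for offset caching. A stripped-down sketch of that behavior, using a hypothetical MiniSequence stand-in for GenerationRequest:

#include <cassert>
#include <vector>

struct MiniSequence
{
    std::vector<std::vector<int>> beamBlockIds; // one block-id vector per beam
    int numFrontBlocksRemoved = 0;

    // Mirrors removeFrontBlock(): mark the oldest still-attached block as detached (-1)
    // instead of erasing it, then advance the front counter.
    void removeFrontBlock()
    {
        for (auto& ids : beamBlockIds)
        {
            ids[numFrontBlocksRemoved] = -1;
        }
        ++numFrontBlocksRemoved;
    }
};

int main()
{
    MiniSequence seq;
    seq.beamBlockIds = {{3, 7, 9}}; // single beam holding three blocks
    seq.removeFrontBlock();
    assert(seq.beamBlockIds[0][0] == -1); // oldest block detached in place
    assert(seq.beamBlockIds[0][1] == 7);  // later blocks keep their positions
    assert(seq.numFrontBlocksRemoved == 1);
    return 0;
}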
@@ -560,14 +564,26 @@ class WindowBlockManager
     void storeNewBlock(GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest);
 
     //! \brief Release blocks of the sequence.
-    void releaseBlocks(GenerationRequest& sequence);
+    //! \details When llmRequest is provided and reuse is enabled, blocks will be stored.
+    void releaseBlocks(GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest = std::nullopt);
 
     //! \brief Simulate freeing all blocks for that sequence to check impact on number of free blocks
     void schedulingReleaseBlocks(LlmRequest::RequestIdType requestId);
 
     //! \brief Release last block in the sequence
     void releaseLastBlock(GenerationRequest& sequence);
 
+    //! \brief Detach block from the sequence
+    void detachBlock(GenerationRequest& sequence, bool isEnableBlockReuse);
+
+    //! \brief Check and add a block to the sequence if needed.
+    //! \details Out-of-window blocks will be detached. If reuse is enabled,
+    //!          the detached block will be stored via offload.
+    void addBlockIfNeeded(GenerationRequest& sequence, bool isEnableBlockReuse);
+
+    //! \brief Cache offsets for new block
+    void cacheNewBlockOffset(GenerationRequest& sequence);
+
     [[nodiscard]] SizeType32 getWindowSize() const noexcept
     {
         return mWindowSize;
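
Note (hedged sketch, not the real implementation): the header only declares detachBlock and addBlockIfNeeded; their bodies live in the .cpp file. Based solely on the \brief/\details comments above, the decode-time flow is roughly the following. All names below (SwaStepHooks and its callbacks) are hypothetical stand-ins, not the WindowBlockManager API.

#include <functional>

// Hypothetical hooks standing in for WindowBlockManager/GenerationRequest state.
struct SwaStepHooks
{
    std::function<bool()> needsNewBlock;         // the sequence has outgrown its last block
    std::function<bool()> frontBlockOutOfWindow; // the oldest block slid fully out of the window
    std::function<void(bool)> detachFrontBlock;  // detach; optionally store (offload) for reuse
    std::function<void()> allocateAndCacheNewBlockOffset;
};

// Rough shape of the flow suggested by the addBlockIfNeeded() doc comment.
void addBlockIfNeededSketch(SwaStepHooks const& hooks, bool isEnableBlockReuse)
{
    if (!hooks.needsNewBlock())
    {
        return;
    }
    if (hooks.frontBlockOutOfWindow())
    {
        // Out-of-window blocks are detached; with reuse enabled the detached
        // block is stored via offload rather than simply freed.
        hooks.detachFrontBlock(isEnableBlockReuse);
    }
    hooks.allocateAndCacheNewBlockOffset();
}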
@@ -578,7 +594,7 @@
         return mLogPrefix;
     }
 
-    [[nodiscard]] SizeType32 getNumFreeBlocks() const noexcept;
+    [[nodiscard]] SizeType32 getNumFreeBlocks(SizeType32 cacheLevel = kPrimaryLevel) const noexcept;
 
     [[nodiscard]] SizeType32 getNumAllocTotalBlocks() const
     {
@@ -713,7 +729,8 @@
     //! \brief Store blocks in cached blocks.
     //! \param blockKeys Key of each block.
     //! \param blockIds Id of each block.
-    void storeBlocks(std::vector<BlockKey> const& blockKeys, std::vector<KVCacheBlock::IdType> const& blockIds);
+    //! \return Number of actual blocks stored.
+    SizeType32 storeBlocks(std::vector<BlockKey> const& blockKeys, std::vector<KVCacheBlock::IdType> const& blockIds);
 
     void addBlockToHashMap(BlockPtr const& block);
 
@@ -916,19 +933,20 @@ class BlockManager
 
     void startScheduling();
 
-    [[nodiscard]] std::map<SizeType32, SizeType32> getNumFreeBlocksPerWindowSize() const
+    [[nodiscard]] std::map<SizeType32, SizeType32> getNumFreeBlocksPerWindowSize(
+        SizeType32 cacheLevel = kPrimaryLevel) const
     {
         std::map<SizeType32, SizeType32> numFreeBlocksPerWindowSize;
         for (auto const& [windowSize, manager] : mWindowBlockManagers)
         {
-            numFreeBlocksPerWindowSize[windowSize] = manager.getNumFreeBlocks();
+            numFreeBlocksPerWindowSize[windowSize] = manager.getNumFreeBlocks(cacheLevel);
         }
         return numFreeBlocksPerWindowSize;
     }
 
-    [[nodiscard]] SizeType32 getNumFreeBlocks() const
+    [[nodiscard]] SizeType32 getNumFreeBlocks(SizeType32 cacheLevel = kPrimaryLevel) const
     {
-        return sumWindows([](auto const& manager) { return manager.getNumFreeBlocks(); });
+        return sumWindows([cacheLevel](auto const& manager) { return manager.getNumFreeBlocks(cacheLevel); });
     }
 
     [[nodiscard]] bool schedulingHasFreeBlocks(SizeType32 numRequired, SizeType32 windowSize) const
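
Note (hypothetical call site, not from the diff): with the new cacheLevel parameter, free-block counts can be queried per level using the kPrimaryLevel/kSecondaryLevel constants declared at the top of this header, e.g. to check the secondary (offload) pool separately:

// `blockManager` is a hypothetical BlockManager instance.
auto const freePrimary = blockManager.getNumFreeBlocks();                  // defaults to kPrimaryLevel
auto const freeSecondary = blockManager.getNumFreeBlocks(kSecondaryLevel);
auto const freeSecondaryPerWindow = blockManager.getNumFreeBlocksPerWindowSize(kSecondaryLevel);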
@@ -1102,12 +1120,10 @@
     //! \brief Store newest block for reuse
     void storeNewBlock(GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest);
 
-    [[nodiscard]] static bool isUseOneMoreBlock(
-        SizeType32 windowSize, std::optional<SizeType32> maxSequenceLength, SizeType32 maxBeamWidth)
+    [[nodiscard]] static bool isUseOneMoreBlock()
     {
-        bool const isCyclicWindowSize = maxSequenceLength.has_value() && maxSequenceLength.value() > windowSize;
-        bool const isBeamSearch = maxBeamWidth > 1;
-        return isCyclicWindowSize && isBeamSearch;
+        //
+        return false;
     }
 
     //! \brief Perform per-request bookkeeping
@@ -1128,8 +1144,8 @@
     //! \brief Cache offsets for blocks initiated from sequence
     void cacheSequenceBlockOffsets(GenerationRequest& sequence, SizeType32 windowSize);
 
-    //! \brief Cache offsets for new block
-    void cacheNewBlockOffset(GenerationRequest& sequence, SizeType32 windowSize);
+    //! \brief Add block to the sequence if needed
+    void addBlockIfNeeded(GenerationRequest& sequence, bool isEnableBlockReuse);
 
 private:
     [[nodiscard]] WindowBlockManager const& windowManagerByLayer(SizeType32 layerIdx) const
