Skip to content

Commit 3328235

Browse files
authored
[TRTLLM-6106][feat] Add support for KVCache transfer from KVCache reuse path (#6348)
Signed-off-by: Iman Tabrizian <[email protected]>
1 parent a36b48b commit 3328235

File tree

34 files changed

+1046
-360
lines changed

34 files changed

+1046
-360
lines changed

cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h

Lines changed: 118 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
#include <optional>
4242
#include <set>
4343
#include <unordered_map>
44+
#include <utility>
4445
#include <vector>
4546

4647
namespace kvc = tensorrt_llm::executor::kv_cache;
@@ -84,6 +85,32 @@ using MmKey = std::pair<std::array<uint8_t, 32>, SizeType32>;
8485
template <typename T>
8586
using OptionalRef = tensorrt_llm::common::OptionalRef<T>;
8687

88+
//! \brief Split vector into list of blocks of given size.
89+
//! \param vec vector to split
90+
//! \param usableSize part of the vector that is processed
91+
//! \param elementsPerBlock desired size of blocks
92+
//! \param allowPartial whether to append a block smaller than `elementsPerBlock` at the end
93+
//! \return list of blocks
94+
template <typename T>
95+
std::list<std::vector<T>> chopVectorIntoBlocks(
96+
std::vector<T> const& vec, SizeType32 usableSize, SizeType32 elementsPerBlock, bool allowPartial)
97+
{
98+
TLLM_CHECK_WITH_INFO(
99+
usableSize <= static_cast<SizeType32>(vec.size()), "usableSize=%d > %ld=vec.size()", usableSize, vec.size());
100+
std::list<std::vector<T>> blockedVectors;
101+
auto const vecEnd = vec.begin() + usableSize;
102+
for (auto begin = vec.begin(); begin < vecEnd; begin += elementsPerBlock)
103+
{
104+
auto blockSize = std::min(elementsPerBlock, static_cast<SizeType32>(std::distance(begin, vecEnd)));
105+
auto end = begin + blockSize;
106+
if (blockSize == elementsPerBlock || allowPartial)
107+
{
108+
blockedVectors.emplace_back(begin, end);
109+
}
110+
}
111+
return blockedVectors;
112+
}
113+
87114
struct TempAttentionWindowInputs
88115
{
89116
bool pagedContextFMHA;
@@ -114,6 +141,9 @@ struct WindowSizeMetadata
114141
}
115142
};
116143

144+
std::vector<MmKey> generateBlockHashExtraKeys(
145+
tensorrt_llm::batch_manager::LlmRequest const& llmRequest, SizeType32 startTokenIdx, SizeType32 endTokenIdx);
146+
117147
struct BlockKey
118148
{
119149
bool usesExtraIds = false;
@@ -147,11 +177,7 @@ struct BlockKey
147177
{
148178
}
149179

150-
bool operator==(BlockKey const& other) const noexcept
151-
{
152-
return (usesExtraIds == other.usesExtraIds && loraTaskId == other.loraTaskId
153-
&& uniqueTokens == other.uniqueTokens && extraKeys == other.extraKeys && cacheSaltID == other.cacheSaltID);
154-
}
180+
bool operator==(BlockKey const& other) const noexcept;
155181

156182
int partialMatch(BlockKey const& other) const noexcept
157183
{
@@ -166,6 +192,8 @@ struct BlockKey
166192
}
167193
};
168194

195+
std::vector<BlockKey> buildBlockKeys(std::list<VecUniqueTokens>& blockedUniqueTokens, LlmRequest const& llmRequest);
196+
169197
// Implement hash functor for BlockKey.
170198
// This allows us to use unordered_map with BlockKey as key.
171199
// Based on https://stackoverflow.com/questions/20511347/a-good-hash-function-for-a-vector/72073933#72073933
@@ -577,14 +605,18 @@ class WindowBlockManager
577605

578606
void replaceSharedBlock(GenerationRequest& sequence, SizeType32 blockIdx);
579607

580-
//! \brief Get the ids of all newly allocated (not reused) blocks for the sequence.
581-
std::vector<KVCacheBlock::IdType> getNewlyAllocatedBlockIds(GenerationRequest const& sequence) const;
608+
[[nodiscard]] std::optional<KVCacheBlock::IdType> storeBlocksForReuse(
609+
GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest, bool pinBlocks = false);
582610

583611
void storeNewBlock(GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest);
584612

613+
//! \brief Pin blocks associated with a sequence to prevent eviction.
614+
void pinBlocks(GenerationRequest& sequence);
615+
585616
//! \brief Release blocks of the sequence.
586617
//! \details When llmRequest is provided and reuse is enabled, blocks will be stored.
587-
void releaseBlocks(GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest);
618+
std::optional<KVCacheBlock::IdType> releaseBlocks(
619+
GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest);
588620

589621
//! \brief Simulate freeing all blocks for that sequence to check impact on number of free blocks
590622
void schedulingReleaseBlocks(LlmRequest::RequestIdType requestId);
@@ -757,8 +789,11 @@ class WindowBlockManager
757789
//! \brief Store blocks in cached blocks.
758790
//! \param blockKeys Key of each block.
759791
//! \param blockIds Id of each block.
760-
//! \return Number of actual blocks stored.
761-
SizeType32 storeBlocks(std::vector<BlockKey> const& blockKeys, std::vector<KVCacheBlock::IdType> const& blockIds);
792+
//! \param pinBlocks If true, increment ref count for blocks while storing (pin on store).
793+
//! \return Pair of (num blocks stored for reuse, id of the last block stored if any).
794+
[[nodiscard]] std::pair<SizeType32, std::optional<KVCacheBlock::IdType>> storeBlocks(
795+
std::vector<BlockKey> const& blockKeys, std::vector<KVCacheBlock::IdType> const& blockIds,
796+
bool pinBlocks = false);
762797

763798
[[nodiscard]] bool verifyQueueIntegrity();
764799

@@ -786,6 +821,11 @@ class WindowBlockManager
786821
return mIsSWA;
787822
}
788823

824+
[[nodiscard]] std::shared_ptr<KVCacheBlock> findBlocksInReuseTreeByBlockKey(BlockKey const& blockKey);
825+
826+
//! \brief Unpin blocks by starting from a block id and walking prev pointers.
827+
void unpinBlocksById(KVCacheBlock::IdType blockId);
828+
789829
private:
790830
//! \brief Add single block to beam of sequence and mAllocatedBlocksPerSeq.
791831
void addBlockToBeam(BlockPtr& block, GenerationRequest& sequence, SizeType32 beamIdx);
@@ -890,6 +930,9 @@ class WindowBlockManager
890930
bool mCopyOnPartialReuse;
891931
// The kv cache connector manager
892932
std::shared_ptr<kv_connector::KvCacheConnectorManager> mKvCacheConnectorManager;
933+
934+
// Mutex for the cached blocks root
935+
std::mutex mCachedBlocksRootMutex;
893936
};
894937

895938
class BlockManager
@@ -940,13 +983,20 @@ class BlockManager
940983

941984
void replaceSharedBlock(GenerationRequest& sequence, SizeType32 windowSize, SizeType32 blockIdx);
942985

943-
std::vector<KVCacheBlock::IdType> getNewlyAllocatedBlockIds(
944-
GenerationRequest const& sequence, SizeType32 windowSize) const;
986+
std::optional<KVCacheBlock::IdType> releaseBlocks(
987+
GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest = std::nullopt, bool pinBlocks = false);
945988

946-
void releaseBlocks(GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest = std::nullopt);
989+
[[nodiscard]] std::optional<KVCacheBlock::IdType> storeBlocksForReuse(
990+
GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest = std::nullopt, bool pinBlocks = false);
947991

948992
void schedulingReleaseBlocks(LlmRequest::RequestIdType requestId);
949993

994+
/// @brief Pin all blocks associated with a sequence across all window managers.
995+
/// @param sequence The generation request whose blocks should be pinned.
996+
void pinBlocks(GenerationRequest& sequence);
997+
998+
void unpinBlocksById(KVCacheBlock::IdType blockId);
999+
9501000
void releaseLastBlock(GenerationRequest& sequence, SizeType32 windowSize);
9511001

9521002
void setOffsets(kernels::KVCacheIndex* offsetsPtr, nvinfer1::Dims const& offsetsShape, SizeType32 beamIdx,
@@ -966,10 +1016,11 @@ class BlockManager
9661016
void offloadBlock(BlockPtr const& block, SizeType32 windowSize,
9671017
executor::KvCacheTransferMode mode = executor::KvCacheTransferMode::DRAM, std::string const& directory = "");
9681018

969-
void storeBlocks(std::vector<BlockKey> const& blockKeys, std::vector<KVCacheBlock::IdType> const& blockIds,
970-
SizeType32 windowSize)
1019+
//! \brief Forward a store request to the window block manager registered for `windowSize`.
//! \param blockKeys Key of each block.
//! \param blockIds Id of each block.
//! \param windowSize Key used to look up the window block manager via `mWindowBlockManagers.at`.
//! \param pinBlocks Passed through unchanged to the window block manager.
//! \return Whatever the window block manager's `storeBlocks` returns.
[[nodiscard]] std::pair<SizeType32, std::optional<KVCacheBlock::IdType>> storeBlocks(
    std::vector<BlockKey> const& blockKeys, std::vector<KVCacheBlock::IdType> const& blockIds,
    SizeType32 windowSize, bool pinBlocks = false)
{
    auto& windowManager = mWindowBlockManagers.at(windowSize);
    return windowManager.storeBlocks(blockKeys, blockIds, pinBlocks);
}
9741025

9751026
[[nodiscard]] bool verifyQueueIntegrity(SizeType32 windowSize);
@@ -1003,6 +1054,15 @@ class BlockManager
10031054
return sumWindows([](auto const& manager) { return manager.getNumAllocTotalBlocks(); });
10041055
}
10051056

1057+
//! \brief Return the key of the first entry in `mWindowBlockManagers`, or 0 when there are no entries.
[[nodiscard]] SizeType32 getFirstWindowSize() const
{
    auto const firstEntry = mWindowBlockManagers.begin();
    return firstEntry == mWindowBlockManagers.end() ? 0 : firstEntry->first;
}
1065+
10061066
[[nodiscard]] SizeType32 getNumAllocNewBlocks() const
10071067
{
10081068
return sumWindows([](auto const& manager) { return manager.getNumAllocNewBlocks(); });
@@ -1133,6 +1193,12 @@ class BlockManager
11331193
return mWindowBlockManagers.at(windowSize).getBlockById(blockId);
11341194
}
11351195

1196+
//! \brief Look up a block by key in the window block manager registered for `windowSize`.
//! \param blockKey Key handed to the window block manager's lookup.
//! \param windowSize Key used to select the window block manager via `mWindowBlockManagers.at`.
//! \return Whatever the selected window block manager's lookup returns.
[[nodiscard]] std::shared_ptr<KVCacheBlock> findBlocksInReuseTreeByBlockKey(
    BlockKey const& blockKey, SizeType32 windowSize)
{
    auto& windowManager = mWindowBlockManagers.at(windowSize);
    return windowManager.findBlocksInReuseTreeByBlockKey(blockKey);
}
1201+
11361202
[[nodiscard]] SizeType32 getNumPrimaryBlocks() const
11371203
{
11381204
return sumWindows([](auto const& manager) { return manager.getNumPrimaryBlocks(); });
@@ -1274,6 +1340,10 @@ class BaseKVCacheManager
12741340
[[nodiscard]] virtual SizeType32 getRemainingBlocksToCompletion(LlmRequest const& req, SizeType32 windowSize) const
12751341
= 0;
12761342

1343+
/// @brief Pin blocks associated with a request to prevent eviction.
1344+
/// @param requestId The ID of the request whose blocks should be pinned.
1345+
virtual void pinBlocks(LlmRequest::RequestIdType requestId) = 0;
1346+
12771347
/// @brief Increase size for request at seqSlotIdx. Allocate new KV cache block(s) if needed.
12781348
virtual void addToken(LlmRequest::RequestIdType requestId) = 0;
12791349

@@ -1287,8 +1357,8 @@ class BaseKVCacheManager
12871357
OptionalRef<LlmRequest> llmRequest = std::nullopt)
12881358
= 0;
12891359

1290-
virtual void removeSequence(
1291-
LlmRequest::RequestIdType requestId, OptionalRef<LlmRequest const> llmRequest = std::nullopt)
1360+
[[nodiscard]] virtual std::optional<KVCacheBlock::IdType> removeSequence(LlmRequest::RequestIdType requestId,
1361+
OptionalRef<LlmRequest const> llmRequest = std::nullopt, bool pinOnRelease = false)
12921362
= 0;
12931363

12941364
virtual void schedulingRemoveSequence(LlmRequest::RequestIdType requestId) = 0;
@@ -1332,6 +1402,11 @@ class BaseKVCacheManager
13321402
//! \details This block becomes reusable from the next step.
13331403
virtual void storeNewBlock(LlmRequest const& llmRequest) = 0;
13341404

1405+
/// \brief Store blocks for reuse for a given request id
1406+
[[nodiscard]] virtual std::optional<KVCacheBlock::IdType> storeBlocksForReuse(
1407+
LlmRequest::RequestIdType requestId, OptionalRef<LlmRequest const> llmRequest, bool pinBlocks = false)
1408+
= 0;
1409+
13351410
//! \brief Get the block ids of a request [per beam] **for a given window size block manager**
13361411
[[nodiscard]] virtual std::vector<std::vector<SizeType32>> const& getCacheBlockIds(
13371412
LlmRequest::RequestIdType requestId, SizeType32 windowSize) const
@@ -1342,8 +1417,8 @@ class BaseKVCacheManager
13421417
std::vector<LlmRequest::RequestIdType> const& requestIds, SizeType32 windowSize) const
13431418
= 0;
13441419

1345-
[[nodiscard]] virtual std::vector<KVCacheBlock::IdType> getNewlyAllocatedBlockIds(
1346-
LlmRequest::RequestIdType requestId, SizeType32 windowSize) const
1420+
/// @brief Get the last block id (beam 0) for a given sequence
1421+
[[nodiscard]] virtual std::optional<KVCacheBlock::IdType> getLastBlockId(LlmRequest::RequestIdType requestId) const
13471422
= 0;
13481423

13491424
[[nodiscard]] virtual runtime::ITensor::SharedPtr getUniquePrimaryPool() const = 0;
@@ -1414,6 +1489,12 @@ class BaseKVCacheManager
14141489
[[nodiscard]] virtual SizeType32 getMaxCapacityBatchSize(SizeType32 inputLength, SizeType32 outputLength) const = 0;
14151490

14161491
[[nodiscard]] virtual CacheType getCacheType() const = 0;
1492+
1493+
[[nodiscard]] virtual std::shared_ptr<KVCacheBlock> findBlocksInReuseTreeByBlockKey(
1494+
BlockKey const& blockKey, SizeType32 windowSize)
1495+
= 0;
1496+
1497+
virtual void unpinBlocksById(KVCacheBlock::IdType blockId) = 0;
14171498
};
14181499

14191500
class KVCacheManager : public BaseKVCacheManager
@@ -1591,8 +1672,8 @@ class KVCacheManager : public BaseKVCacheManager
15911672
void addSequence(LlmRequest::RequestIdType requestId, SizeType32 inputLength, SizeType32 beamWidth,
15921673
OptionalRef<LlmRequest> llmRequest = std::nullopt) override;
15931674

1594-
void removeSequence(
1595-
LlmRequest::RequestIdType requestId, OptionalRef<LlmRequest const> llmRequest = std::nullopt) override;
1675+
[[nodiscard]] std::optional<KVCacheBlock::IdType> removeSequence(LlmRequest::RequestIdType requestId,
1676+
OptionalRef<LlmRequest const> llmRequest = std::nullopt, bool pinOnRelease = false) override;
15961677

15971678
void schedulingRemoveSequence(LlmRequest::RequestIdType requestId) override;
15981679

@@ -1652,6 +1733,9 @@ class KVCacheManager : public BaseKVCacheManager
16521733
//! \brief Store newest blocks for reuse
16531734
void storeNewBlock(LlmRequest const& llmRequest) override;
16541735

1736+
[[nodiscard]] std::optional<KVCacheBlock::IdType> storeBlocksForReuse(
1737+
LlmRequest::RequestIdType requestId, OptionalRef<LlmRequest const> llmRequest, bool pinBlocks = false) override;
1738+
16551739
[[nodiscard]] static SizeType32 getSinkBubbleLength(SizeType32 sinkTokenLen, SizeType32 tokensPerBlock);
16561740

16571741
[[nodiscard]] SizeType32 getMaxCapacityBatchSize(SizeType32 inputLength, SizeType32 outputLength) const override;
@@ -1668,6 +1752,12 @@ class KVCacheManager : public BaseKVCacheManager
16681752
[[nodiscard]] static SizeType32 calculateMaxBlockRequirements(SizeType32 inputLength, SizeType32 outputLength,
16691753
SizeType32 sinkTokenLength, SizeType32 windowSize, SizeType32 beamWidth, SizeType32 tokensPerBlock);
16701754

1755+
void pinBlocks(LlmRequest::RequestIdType requestId) override;
1756+
1757+
void unpinBlocksById(KVCacheBlock::IdType blockId) override;
1758+
1759+
std::optional<KVCacheBlock::IdType> getLastBlockId(LlmRequest::RequestIdType requestId) const override;
1760+
16711761
/// @brief Calculates the number of kv-cache blocks that a sequence will require, for a single beam.
16721762
///
16731763
/// @param sequenceLength The total length of the sequence (input and output).
@@ -1684,9 +1774,6 @@ class KVCacheManager : public BaseKVCacheManager
16841774
std::vector<std::vector<std::vector<SizeType32>>> getBatchCacheBlockIds(
16851775
std::vector<LlmRequest::RequestIdType> const& requestIds, SizeType32 windowSize) const override;
16861776

1687-
std::vector<SizeType32> getNewlyAllocatedBlockIds(
1688-
LlmRequest::RequestIdType requestId, SizeType32 windowSize) const override;
1689-
16901777
runtime::ITensor::SharedPtr getUniquePrimaryPool() const override;
16911778
runtime::ITensor::SharedPtr getPrimaryPool(SizeType32 layer_idx) const override;
16921779

@@ -1706,6 +1793,12 @@ class KVCacheManager : public BaseKVCacheManager
17061793
mBlockManager.flushIterationEvents();
17071794
}
17081795

1796+
//! \brief Delegate the reuse-tree lookup to `mBlockManager`.
//! \param blockKey Key of the block to look up.
//! \param windowSize Window size forwarded to the block manager.
//! \return The block returned by the block manager's lookup.
std::shared_ptr<KVCacheBlock> findBlocksInReuseTreeByBlockKey(
    BlockKey const& blockKey, SizeType32 windowSize) override
{
    auto foundBlock = mBlockManager.findBlocksInReuseTreeByBlockKey(blockKey, windowSize);
    return foundBlock;
}
1801+
17091802
/// @brief Finds the maximum attention window that can be used on a sequence, given some kv-cache block capacity.
17101803
///
17111804
/// @param inputLength The number of input tokens in the sequence.

cpp/include/tensorrt_llm/batch_manager/kvCacheUtils.h

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,32 @@ class BlockRange
4040
return BlockRange(cacheManager, blockIds, requestId);
4141
}
4242

43-
static BlockRange fromNewlyAllocatedBlockIds(
44-
BaseKVCacheManager const& cacheManager, LlmRequest::RequestIdType requestId)
43+
//! \brief Build a BlockRange by walking backwards through the reuse tree from a given block.
//! \param cacheManager Cache manager used for the lookup; the window size is `firstWindowSize(cacheManager)`.
//! \param lastBlockKey Key of the block the walk starts from (the newest block of the range).
//! \param indexFromEnd Number of prev-block links to follow; `indexFromEnd + 1` block ids are collected.
//! \return BlockRange over the collected ids, ordered oldest to newest, constructed with 0 as its last argument.
static BlockRange fromReuseTree(
    BaseKVCacheManager& cacheManager, BlockKey const& lastBlockKey, int32_t indexFromEnd)
{
    // Find the last block in the reuse tree for the provided full sequence of block keys
    auto cursor = cacheManager.findBlocksInReuseTreeByBlockKey(lastBlockKey, firstWindowSize(cacheManager));
    // TODO: handle the case where the last block is not found
    TLLM_CHECK_WITH_INFO(cursor, "Couldn't find the requested block in the reuse tree");

    int32_t const blockCount = indexFromEnd + 1;
    std::vector<SizeType32> collectedIds;
    collectedIds.reserve(blockCount);

    int32_t remaining = blockCount;
    while (remaining > 0)
    {
        auto const currentId = cursor->getBlockId();
        TLLM_CHECK_WITH_INFO(currentId != KVCacheBlock::kCachedBlocksRootId, "last block has no block id");
        collectedIds.push_back(currentId);
        --remaining;
        if (remaining == 0)
        {
            // The oldest requested block has been recorded; its predecessor is not needed.
            break;
        }
        auto previous = cursor->getPrevBlock();
        TLLM_CHECK_WITH_INFO(previous, "last block has no prev block");
        cursor = previous;
    }

    // Ids were gathered newest first; flip them into chronological order (oldest to newest).
    std::reverse(collectedIds.begin(), collectedIds.end());
    return BlockRange(cacheManager, collectedIds, 0);
}
5070

5171
BlockRange(runtime::ITensor::SharedPtr pool, std::vector<SizeType32> const& blockIds) // Only used in tests
@@ -80,19 +100,6 @@ class BlockRange
80100
mBlockIds = std::move(blockIds);
81101
}
82102

83-
[[nodiscard]] std::vector<size_t> getBlockHashes() const
84-
{
85-
TLLM_CHECK(mManager);
86-
std::vector<size_t> blockHashes;
87-
blockHashes.reserve(mBlockIds.size());
88-
auto& blockManager = mManager->getBlockManager();
89-
for (auto id : mBlockIds)
90-
{
91-
blockHashes.emplace_back(blockManager.getBlockById(id, mWindowSize)->getHash());
92-
}
93-
return blockHashes;
94-
}
95-
96103
void updatePoolIdx(SizeType32 poolIdx)
97104
{
98105
TLLM_CHECK(mManager);

cpp/include/tensorrt_llm/batch_manager/llmRequest.h

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1843,16 +1843,6 @@ class GenericLlmRequest
18431843
}
18441844
}
18451845

1846-
void setRequestedBlockHashes(std::vector<size_t> hashes)
1847-
{
1848-
mRequestedBlockHashes = std::move(hashes);
1849-
}
1850-
1851-
[[nodiscard]] std::vector<size_t> const& getRequestedBlockHashes() const
1852-
{
1853-
return mRequestedBlockHashes;
1854-
}
1855-
18561846
void setIsDummyRequest(bool isDummyRequest)
18571847
{
18581848
mIsDummyRequest = isDummyRequest;
@@ -2044,9 +2034,6 @@ class GenericLlmRequest
20442034
// Tensors containing the additional generation output.
20452035
TensorMap mAdditionalGenerationOutputTensors;
20462036

2047-
// Context request only. The hashes of the blocks that are requested by the corresponding generation request.
2048-
std::vector<size_t> mRequestedBlockHashes;
2049-
20502037
bool mIsDummyRequest{false};
20512038

20522039
bool mUseDraftModel{false};

0 commit comments

Comments
 (0)