Skip to content

Commit 209052a

Browse files
committed
Don't pass connector manager through add_sequence
Signed-off-by: jthomson04 <[email protected]>
1 parent 0b210f0 commit 209052a

File tree

5 files changed

+75
-71
lines changed

5 files changed

+75
-71
lines changed

cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -537,7 +537,8 @@ class WindowBlockManager
537537
SizeType32 sizePerHead, SizeType32 tokensPerBlock, SizeType32 blocksInPrimaryPool,
538538
SizeType32 blocksInSecondaryPool, SizeType32 maxNumSequences, std::shared_ptr<runtime::CudaStream> stream,
539539
bool onboardBlocks, CacheType cacheType, std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
540-
std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse);
540+
std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse,
541+
std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager);
541542

542543
~WindowBlockManager();
543544

@@ -548,8 +549,8 @@ class WindowBlockManager
548549
void startScheduling();
549550

550551
//! \brief Assign blocks for new sequence. Try to reuse blocks.
551-
void addSequence(GenerationRequest& sequence, SizeType32 inputLength, SizeType32 numContextBlocks,
552-
LlmRequest& llmRequest, OptionalRef<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager);
552+
void addSequence(
553+
GenerationRequest& sequence, SizeType32 inputLength, SizeType32 numContextBlocks, LlmRequest& llmRequest);
553554

554555
//! \brief Assign blocks for new sequence. Does not try to reuse blocks.
555556
void addSequence(GenerationRequest& sequence, SizeType32 numBlocks, SizeType32 unsharedBlockIdx);
@@ -834,6 +835,8 @@ class WindowBlockManager
834835
bool mEnablePartialReuse;
835836
// Whether partially matched blocks that are already in use should be copied and reused.
836837
bool mCopyOnPartialReuse;
838+
// The kv cache connector manager
839+
std::shared_ptr<kv_connector::KvCacheConnectorManager> mKvCacheConnectorManager;
837840
};
838841

839842
class BlockManager
@@ -851,7 +854,8 @@ class BlockManager
851854
SizeType32 sinkBubbleLength, bool onboardBlocks, CacheType cacheType = CacheType::kSELF,
852855
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority = std::nullopt,
853856
std::shared_ptr<KVCacheEventManager> eventManager = nullptr, bool enablePartialReuse = true,
854-
bool copyOnPartialReuse = true);
857+
bool copyOnPartialReuse = true,
858+
std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager = nullptr);
855859

856860
BlockManager(BlockManager const&) = delete;
857861
BlockManager& operator=(BlockManager const&) = delete;
@@ -868,8 +872,7 @@ class BlockManager
868872
void allocatePools(bool useUvm);
869873

870874
void addSequence(GenerationRequest& sequence, SizeType32 inputLength, SizeType32 numContextBlocks,
871-
LlmRequest& llmRequest, OptionalRef<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager,
872-
SizeType32 windowSize);
875+
LlmRequest& llmRequest, SizeType32 windowSize);
873876

874877
void addSequence(
875878
GenerationRequest& sequence, SizeType32 numBlocks, SizeType32 unsharedBlockIdx, SizeType32 windowSize);
@@ -1213,8 +1216,7 @@ class BaseKVCacheManager
12131216
/// @details If llmRequest is supplied and KV cache reuse is enabled, try to recover KV cache blocks for
12141217
/// inputLength - 1 tokens and populate prepopulatedPromptLen.
12151218
virtual void addSequence(LlmRequest::RequestIdType requestId, SizeType32 inputLength, SizeType32 beamWidth,
1216-
OptionalRef<LlmRequest> llmRequest = std::nullopt,
1217-
OptionalRef<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager = std::nullopt)
1219+
OptionalRef<LlmRequest> llmRequest = std::nullopt)
12181220
= 0;
12191221

12201222
virtual void removeSequence(
@@ -1361,7 +1363,8 @@ class KVCacheManager : public BaseKVCacheManager
13611363
bool enableBlockReuse = false, bool onboardBlocks = true, CacheType cacheType = CacheType::kSELF,
13621364
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority = std::nullopt,
13631365
std::shared_ptr<KVCacheEventManager> eventManager = nullptr, bool enablePartialReuse = true,
1364-
bool copyOnpartialReuse = true);
1366+
bool copyOnpartialReuse = true,
1367+
std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager = nullptr);
13651368

13661369
KVCacheManager(std::vector<SizeType32> const& numKvHeadsPerLayer, SizeType32 sizePerHead, SizeType32 tokensPerBlock,
13671370
BlocksPerWindow const& blocksPerWindow, SizeType32 maxNumSequences, SizeType32 maxBeamWidth,
@@ -1371,7 +1374,8 @@ class KVCacheManager : public BaseKVCacheManager
13711374
bool enableBlockReuse = false, bool onboardBlocks = true, CacheType cacheType = CacheType::kSELF,
13721375
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority = std::nullopt,
13731376
std::shared_ptr<KVCacheEventManager> eventManager = nullptr, bool enablePartialReuse = true,
1374-
bool copyOnpartialReuse = true);
1377+
bool copyOnpartialReuse = true,
1378+
std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager = nullptr);
13751379

13761380
KVCacheManager(SizeType32 numLayers, SizeType32 numKvHeads, SizeType32 sizePerHead, SizeType32 tokensPerBlock,
13771381
BlocksPerWindow const& blocksPerWindow, SizeType32 maxNumSequences, SizeType32 maxBeamWidth,
@@ -1381,7 +1385,8 @@ class KVCacheManager : public BaseKVCacheManager
13811385
bool enableBlockReuse = true, bool onboardBlocks = true, CacheType cacheType = CacheType::kSELF,
13821386
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority = std::nullopt,
13831387
std::shared_ptr<KVCacheEventManager> eventManager = nullptr, bool enablePartialReuse = true,
1384-
bool copyOnpartialReuse = true);
1388+
bool copyOnpartialReuse = true,
1389+
std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager = nullptr);
13851390

13861391
KVCacheManager(SizeType32 numLayers, SizeType32 numKvHeads, SizeType32 sizePerHead, SizeType32 tokensPerBlock,
13871392
BlocksPerWindow const& blocksPerWindow, SizeType32 maxNumSequences, SizeType32 maxBeamWidth,
@@ -1513,8 +1518,7 @@ class KVCacheManager : public BaseKVCacheManager
15131518
/// @details If llmRequest is supplied and KV cache reuse is enabled, try to recover KV cache blocks for
15141519
/// inputLength - 1 tokens and populate prepopulatedPromptLen.
15151520
void addSequence(LlmRequest::RequestIdType requestId, SizeType32 inputLength, SizeType32 beamWidth,
1516-
OptionalRef<LlmRequest> llmRequest = std::nullopt,
1517-
OptionalRef<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager = std::nullopt) override;
1521+
OptionalRef<LlmRequest> llmRequest = std::nullopt) override;
15181522

15191523
void removeSequence(
15201524
LlmRequest::RequestIdType requestId, OptionalRef<LlmRequest const> llmRequest = std::nullopt) override;

cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp

Lines changed: 35 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -504,7 +504,8 @@ BlockManager::BlockManager(std::vector<SizeType32> const& numKvHeadsPerLayer, Si
504504
std::optional<TempAttentionWindowInputs> const& tempAttentionWindowInputs, nvinfer1::DataType dtype,
505505
SizeType32 sinkBubbleLength, bool onboardBlocks, CacheType cacheType,
506506
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
507-
std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse)
507+
std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse,
508+
std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager)
508509
: mNumLayers{static_cast<SizeType32>(numKvHeadsPerLayer.size())}
509510
, mTokensPerBlock{tokensPerBlock}
510511
, mEventManager{std::move(eventManager)}
@@ -513,6 +514,10 @@ BlockManager::BlockManager(std::vector<SizeType32> const& numKvHeadsPerLayer, Si
513514
{
514515
auto const uniqueWindowSizeToLayers
515516
= BaseKVCacheManager::groupLayersByWindowSize(maxAttentionWindowVec, mNumLayers);
517+
518+
TLLM_CHECK_WITH_INFO(kvCacheConnectorManager == nullptr || uniqueWindowSizeToLayers.size() == 1,
519+
"KV Cache Connector is not supported with multiple window sizes");
520+
516521
auto const numUniqueWindowSizes = static_cast<SizeType32>(uniqueWindowSizeToLayers.size());
517522

518523
mIsVariableWindow = numUniqueWindowSizes > 1;
@@ -530,7 +535,7 @@ BlockManager::BlockManager(std::vector<SizeType32> const& numKvHeadsPerLayer, Si
530535
mWindowBlockManagers.try_emplace(windowSize, dtype, windowSize, layersWithWindowSize, numKvHeadsPerLayer,
531536
sizePerHead, tokensPerBlock, allottedPrimaryBlocks, allottedSecondaryBlocks, maxNumSequences, stream,
532537
onboardBlocks, cacheType, secondaryOffloadMinPriority, mEventManager, enablePartialReuse,
533-
copyOnPartialReuse);
538+
copyOnPartialReuse, kvCacheConnectorManager);
534539
}
535540

536541
auto const numAllPools = getNumPools();
@@ -572,7 +577,8 @@ WindowBlockManager::WindowBlockManager(nvinfer1::DataType dtype, SizeType32 wind
572577
SizeType32 sizePerHead, SizeType32 tokensPerBlock, SizeType32 blocksInPrimaryPool, SizeType32 blocksInSecondaryPool,
573578
SizeType32 maxNumSequences, std::shared_ptr<runtime::CudaStream> stream, bool onboardBlocks, CacheType cacheType,
574579
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
575-
std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse)
580+
std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse,
581+
std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager)
576582
: mDataType{dtype}
577583
, mWindowSize{windowSize}
578584
, mNumPrimaryBlocks{blocksInPrimaryPool}
@@ -596,6 +602,7 @@ WindowBlockManager::WindowBlockManager(nvinfer1::DataType dtype, SizeType32 wind
596602
, mTotalInputTokens{0.0}
597603
, mEnablePartialReuse{enablePartialReuse}
598604
, mCopyOnPartialReuse{copyOnPartialReuse}
605+
, mKvCacheConnectorManager{std::move(kvCacheConnectorManager)}
599606
{
600607
std::map<SizeType32, SizeType32> numLayersPerPool;
601608

@@ -1147,15 +1154,13 @@ void WindowBlockManager::refreshBlocks()
11471154
}
11481155

11491156
void BlockManager::addSequence(GenerationRequest& sequence, SizeType32 inputLength, SizeType32 numContextBlocks,
1150-
LlmRequest& llmRequest, OptionalRef<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager,
1151-
SizeType32 windowSize)
1157+
LlmRequest& llmRequest, SizeType32 windowSize)
11521158
{
1153-
mWindowBlockManagers.at(windowSize)
1154-
.addSequence(sequence, inputLength, numContextBlocks, llmRequest, kvCacheConnectorManager);
1159+
mWindowBlockManagers.at(windowSize).addSequence(sequence, inputLength, numContextBlocks, llmRequest);
11551160
}
11561161

1157-
void WindowBlockManager::addSequence(GenerationRequest& sequence, SizeType32 inputLength, SizeType32 numContextBlocks,
1158-
LlmRequest& llmRequest, OptionalRef<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager)
1162+
void WindowBlockManager::addSequence(
1163+
GenerationRequest& sequence, SizeType32 inputLength, SizeType32 numContextBlocks, LlmRequest& llmRequest)
11591164
{
11601165
auto const requestId = sequence.getRequestId();
11611166
auto const [seqIt, emplaceDone] = mAllocatedBlocksPerSeq.emplace(requestId, std::vector<BlockPtr>{});
@@ -1190,9 +1195,9 @@ void WindowBlockManager::addSequence(GenerationRequest& sequence, SizeType32 inp
11901195
SizeType32 numConnectorMatchedTokens = 0;
11911196

11921197
// If we're using a KV cache connector, check if any additional blocks can be loaded.
1193-
if (kvCacheConnectorManager)
1198+
if (mKvCacheConnectorManager && !llmRequest.isDummyRequest())
11941199
{
1195-
numConnectorMatchedTokens = kvCacheConnectorManager->getNumNewMatchedTokens(llmRequest, prepopulatedPromptLen);
1200+
numConnectorMatchedTokens = mKvCacheConnectorManager->getNumNewMatchedTokens(llmRequest, prepopulatedPromptLen);
11961201
}
11971202

11981203
llmRequest.setPrepopulatedPromptLen(prepopulatedPromptLen + numConnectorMatchedTokens, getTokensPerBlock());
@@ -1208,6 +1213,13 @@ void BlockManager::addSequence(
12081213

12091214
void WindowBlockManager::addSequence(GenerationRequest& sequence, SizeType32 numBlocks, SizeType32 unsharedBlockIdx)
12101215
{
1216+
if (mKvCacheConnectorManager)
1217+
{
1218+
TLLM_LOG_WARNING(
1219+
"KV Cache Connector specified when block reuse is disabled. The KV Cache Connector will be "
1220+
"ignored.");
1221+
}
1222+
12111223
auto const requestId = sequence.getRequestId();
12121224
auto const [seqIt, emplaceDone] = mAllocatedBlocksPerSeq.emplace(requestId, std::vector<BlockPtr>{});
12131225
TLLM_CHECK(emplaceDone);
@@ -1620,12 +1632,13 @@ KVCacheManager::KVCacheManager(std::vector<SizeType32> const& numKvHeadsPerLayer
16201632
SizeType32 sinkTokenLength, int64_t stream, std::optional<runtime::SizeType32> maxSequenceLength,
16211633
bool enableBlockReuse, bool onboardBlocks, CacheType cacheType,
16221634
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
1623-
std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse)
1635+
std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse,
1636+
std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager)
16241637
: KVCacheManager(numKvHeadsPerLayer, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, maxBeamWidth,
16251638
maxAttentionWindowVec, tempAttentionWindowInputs, dtype, sinkTokenLength,
16261639
std::make_shared<runtime::CudaStream>(reinterpret_cast<cudaStream_t>(stream)), maxSequenceLength,
16271640
enableBlockReuse, onboardBlocks, cacheType, secondaryOffloadMinPriority, eventManager, enablePartialReuse,
1628-
copyOnPartialReuse)
1641+
copyOnPartialReuse, kvCacheConnectorManager)
16291642
{
16301643
}
16311644

@@ -1636,7 +1649,8 @@ KVCacheManager::KVCacheManager(std::vector<SizeType32> const& numKvHeadsPerLayer
16361649
SizeType32 sinkTokenLength, CudaStreamPtr stream, std::optional<runtime::SizeType32> maxSequenceLength,
16371650
bool enableBlockReuse, bool onboardBlocks, CacheType cacheType,
16381651
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
1639-
std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse)
1652+
std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse,
1653+
std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager)
16401654
: mMaxBeamWidth(maxBeamWidth)
16411655
, mDataType(dtype)
16421656
, mMaxAttentionWindow(*std::max_element(maxAttentionWindowVec.begin(), maxAttentionWindowVec.end()))
@@ -1646,7 +1660,7 @@ KVCacheManager::KVCacheManager(std::vector<SizeType32> const& numKvHeadsPerLayer
16461660
, mBlockManager(numKvHeadsPerLayer, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences,
16471661
std::move(stream), maxSequenceLength, maxBeamWidth, maxAttentionWindowVec, tempAttentionWindowInputs, dtype,
16481662
mSinkBubbleLength, onboardBlocks, cacheType, secondaryOffloadMinPriority, std::move(eventManager),
1649-
enablePartialReuse, copyOnPartialReuse)
1663+
enablePartialReuse, copyOnPartialReuse, std::move(kvCacheConnectorManager))
16501664
// disable block reuse for sink bubble since chopVectorIntoBlocks does not match KV cache blocks in this case
16511665
, mEnableBlockReuse{mSinkBubbleLength > 0 ? false : enableBlockReuse}
16521666
{
@@ -1668,11 +1682,12 @@ KVCacheManager::KVCacheManager(SizeType32 numLayers, SizeType32 numKvHeads, Size
16681682
SizeType32 sinkTokenLength, CudaStreamPtr stream, std::optional<runtime::SizeType32> maxSequenceLength,
16691683
bool enableBlockReuse, bool onboardBlocks, CacheType cacheType,
16701684
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
1671-
std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse)
1685+
std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse,
1686+
std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager)
16721687
: KVCacheManager(std::vector<SizeType32>(numLayers, numKvHeads), sizePerHead, tokensPerBlock, blocksPerWindow,
16731688
maxNumSequences, maxBeamWidth, maxAttentionWindowVec, tempAttentionWindowInputs, dtype, sinkTokenLength,
16741689
std::move(stream), maxSequenceLength, enableBlockReuse, onboardBlocks, cacheType, secondaryOffloadMinPriority,
1675-
std::move(eventManager), enablePartialReuse, copyOnPartialReuse)
1690+
std::move(eventManager), enablePartialReuse, copyOnPartialReuse, std::move(kvCacheConnectorManager))
16761691
{
16771692
}
16781693

@@ -1973,17 +1988,12 @@ std::optional<BlockKey> KVCacheManager::findNewContextBlock(
19731988
return newContextBlockOpt;
19741989
}
19751990

1976-
void KVCacheManager::addSequence(RequestIdType requestId, SizeType32 inputLength, SizeType32 beamWidth,
1977-
OptionalRef<LlmRequest> llmRequest, OptionalRef<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager)
1991+
void KVCacheManager::addSequence(
1992+
RequestIdType requestId, SizeType32 inputLength, SizeType32 beamWidth, OptionalRef<LlmRequest> llmRequest)
19781993
{
19791994
// Need to add the bubble after the sink tokens to use even block size
19801995
inputLength += mSinkBubbleLength;
19811996

1982-
if (kvCacheConnectorManager)
1983-
{
1984-
TLLM_CHECK_WITH_INFO(beamWidth == 1, "KV Cache Connector is not supported with beam search");
1985-
}
1986-
19871997
auto kvCacheRetentionConfig = llmRequest
19881998
? llmRequest->getKvCacheRetentionConfig().value_or(executor::KvCacheRetentionConfig())
19891999
: executor::KvCacheRetentionConfig();
@@ -2027,8 +2037,7 @@ void KVCacheManager::addSequence(RequestIdType requestId, SizeType32 inputLength
20272037
auto const numContextBlocks = tc::ceilDiv(effectiveInputLength, getTokensPerBlock());
20282038
if (!sequence.isCyclic() && mEnableBlockReuse)
20292039
{
2030-
mBlockManager.addSequence(
2031-
sequence, effectiveInputLength, numContextBlocks, *llmRequest, kvCacheConnectorManager, windowSize);
2040+
mBlockManager.addSequence(sequence, effectiveInputLength, numContextBlocks, *llmRequest, windowSize);
20322041
}
20332042
else
20342043
{
@@ -2040,12 +2049,6 @@ void KVCacheManager::addSequence(RequestIdType requestId, SizeType32 inputLength
20402049
"will "
20412050
"have no effect.",
20422051
llmRequest->mRequestId);
2043-
if (kvCacheConnectorManager.has_value())
2044-
{
2045-
TLLM_LOG_WARNING(
2046-
"KV Cache Connector specified when block reuse is disabled. The KV Cache Connector will be "
2047-
"ignored.");
2048-
}
20492052
}
20502053
mBlockManager.addSequence(sequence, numContextBlocks, unsharedBlockIdx, windowSize);
20512054
}

0 commit comments

Comments
 (0)