Skip to content

Commit e956fbf

Browse files
committed
[KV cache manager] Simplify shared/unshared last context block logic
Signed-off-by: eopXD <[email protected]>
1 parent 1f4b819 commit e956fbf

File tree

2 files changed

+17
-26
lines changed

2 files changed

+17
-26
lines changed

cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -546,7 +546,7 @@ class WindowBlockManager
546546
GenerationRequest& sequence, SizeType32 inputLength, SizeType32 numContextBlocks, LlmRequest& llmRequest);
547547

548548
//! \brief Assign blocks for new sequence. Does not try to reuse blocks.
549-
void addSequence(GenerationRequest& sequence, SizeType32 numBlocks, SizeType32 unsharedBlockIdx);
549+
void addSequence(GenerationRequest& sequence, SizeType32 numContextBlocks, bool isShareLastContextBlock);
550550

551551
//! \brief Allocate new block for each beam of the sequence.
552552
//! \details Might free cached blocks if no free blocks are available.
@@ -876,7 +876,7 @@ class BlockManager
876876
LlmRequest& llmRequest, SizeType32 windowSize);
877877

878878
void addSequence(
879-
GenerationRequest& sequence, SizeType32 numBlocks, SizeType32 unsharedBlockIdx, SizeType32 windowSize);
879+
GenerationRequest& sequence, SizeType32 numContextBlocks, SizeType32 windowSize, bool isShareLastContextBlock);
880880

881881
void allocateBlock(GenerationRequest& sequence, SizeType32 windowSize);
882882

cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp

Lines changed: 15 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1241,25 +1241,25 @@ void WindowBlockManager::addSequence(
12411241
// There are two versions of BlockManager::addSequence function.
12421242
// This is called when block reuse is disabled.
12431243
void BlockManager::addSequence(
1244-
GenerationRequest& sequence, SizeType32 numBlocks, SizeType32 unsharedBlockIdx, SizeType32 windowSize)
1244+
GenerationRequest& sequence, SizeType32 numContextBlocks, SizeType32 windowSize, bool isShareLastContextBlock)
12451245
{
1246-
mWindowBlockManagers.at(windowSize).addSequence(sequence, numBlocks, unsharedBlockIdx);
1246+
mWindowBlockManagers.at(windowSize).addSequence(sequence, numContextBlocks, isShareLastContextBlock);
12471247
}
12481248

12491249
// There are two versions of WindowBlockManager::addSequence function.
12501250
// This is called when block reuse is disabled.
1251-
void WindowBlockManager::addSequence(GenerationRequest& sequence, SizeType32 numBlocks, SizeType32 unsharedBlockIdx)
1251+
void WindowBlockManager::addSequence(
1252+
GenerationRequest& sequence, SizeType32 numContextBlocks, bool isShareLastContextBlock)
12521253
{
12531254
auto const requestId = sequence.getRequestId();
12541255
auto const [seqIt, emplaceDone] = mAllocatedBlocksPerSeq.emplace(requestId, std::vector<BlockPtr>{});
12551256
TLLM_CHECK(emplaceDone);
12561257

1257-
// Allocate blocks
1258-
for (SizeType32 bi = 0; bi < numBlocks; ++bi)
1258+
for (SizeType32 bi = 0; bi < numContextBlocks - 1; ++bi)
12591259
{
1260-
bool shareAmongBeams = bi != unsharedBlockIdx;
1261-
allocateBlock(sequence, shareAmongBeams);
1260+
allocateBlock(sequence, /*shareAmongBeams=*/true);
12621261
}
1262+
allocateBlock(sequence, /*shareAmongBeams=*/isShareLastContextBlock);
12631263
}
12641264

12651265
void WindowBlockManager::addBlockToBeam(BlockPtr& block, GenerationRequest& sequence, SizeType32 beamIdx)
@@ -1919,8 +1919,8 @@ std::optional<BlockKey> KVCacheManager::findNewContextBlock(
19191919
void KVCacheManager::addSequence(
19201920
RequestIdType requestId, SizeType32 inputLength, SizeType32 beamWidth, OptionalRef<LlmRequest> llmRequest)
19211921
{
1922-
// Need to add the bubble after the sink tokens to use even block size
1923-
inputLength += mSinkBubbleLength;
1922+
TLLM_CHECK_WITH_INFO(
1923+
mSinkBlockTokenLength == 0 && mSinkBubbleLength == 0, "streamLLM is not supported at the moment");
19241924

19251925
auto kvCacheRetentionConfig = llmRequest
19261926
? llmRequest->getKvCacheRetentionConfig().value_or(executor::KvCacheRetentionConfig())
@@ -1935,6 +1935,10 @@ void KVCacheManager::addSequence(
19351935
TLLM_CHECK(emplaceDone);
19361936
auto& sequence = seqIt->second;
19371937

1938+
if (sequence.getBeamWidth() > 1)
1939+
{
1940+
TLLM_LOG_WARNING("[KV cache manager] Beam search is not supported at the moment");
1941+
}
19381942
// Get statistics for block allocations/reuse before this request is added (pre-request snapshot).
19391943
SizeType32 const numAllocTotalBlocksPreRequest = mBlockManager.getNumAllocTotalBlocks();
19401944
SizeType32 const numAllocNewBlocksPreRequest = mBlockManager.getNumAllocNewBlocks();
@@ -1946,20 +1950,6 @@ void KVCacheManager::addSequence(
19461950
auto const maxTokenNum = metadata.maxTokenNum;
19471951
auto const temporaryAttentionWindow = metadata.temporaryAttentionWindow;
19481952

1949-
// Get the final token index in kv cache
1950-
SizeType32 const finalTokenKVIdx = mSinkBlockTokenLength
1951-
+ ((inputLength - 1 - mSinkBlockTokenLength) % (maxTokenNum - mSinkBlockTokenLength));
1952-
1953-
// Get the index of the block that is allocated with shareAmongBeams=false.
1954-
// For cross kv cache in encoder-decoder models, always shareAmongBeams=True.
1955-
SizeType32 unsharedBlockIdx = -1;
1956-
if ((!sequence.isCyclic() || beamWidth > 1 || finalTokenKVIdx % getTokensPerBlock() > 0) && !isCrossKv())
1957-
{
1958-
unsharedBlockIdx = ((finalTokenKVIdx + 1) % getTokensPerBlock() == 0)
1959-
? finalTokenKVIdx / getTokensPerBlock() + 1
1960-
: finalTokenKVIdx / getTokensPerBlock();
1961-
}
1962-
19631953
// Consider the temporaryAttentionWindow when allocating blocks.
19641954
auto const effectiveInputLength = std::min(inputLength, maxTokenNum + temporaryAttentionWindow);
19651955
auto const numContextBlocks = tc::ceilDiv(effectiveInputLength, getTokensPerBlock());
@@ -1978,7 +1968,8 @@ void KVCacheManager::addSequence(
19781968
"have no effect.",
19791969
llmRequest->mRequestId);
19801970
}
1981-
mBlockManager.addSequence(sequence, numContextBlocks, unsharedBlockIdx, windowSize);
1971+
bool const isShareLastContextBlock = isCrossKv() || effectiveInputLength % getTokensPerBlock() == 0;
1972+
mBlockManager.addSequence(sequence, numContextBlocks, windowSize, isShareLastContextBlock);
19821973
if (mEnableHashKey && llmRequest.has_value() && beamWidth == 1)
19831974
{
19841975
constexpr SizeType32 beamIdx = 0;

0 commit comments

Comments
 (0)