24 changes: 16 additions & 8 deletions cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h
@@ -551,7 +551,7 @@ class WindowBlockManager
GenerationRequest& sequence, SizeType32 inputLength, SizeType32 numContextBlocks, LlmRequest& llmRequest);

//! \brief Assign blocks for new sequence. Does not try to reuse blocks.
void addSequence(GenerationRequest& sequence, SizeType32 numBlocks, SizeType32 unsharedBlockIdx);
void addSequence(GenerationRequest& sequence, SizeType32 numContextBlocks, bool isShareLastContextBlock);

//! \brief Allocate new block for each beam of the sequence.
//! \details Might free cached blocks if no free blocks are available.
@@ -869,8 +869,13 @@ class BlockManager
void addSequence(GenerationRequest& sequence, SizeType32 inputLength, SizeType32 numContextBlocks,
LlmRequest& llmRequest, SizeType32 windowSize);

//! \brief Assign blocks for a new sequence.
//! \param sequence The GenerationRequest to process.
//! \param numContextBlocks Number of context blocks to allocate.
//! \param windowSize Attention window size.
//! \param isShareLastContextBlock If true, the last context block is shared among beams.
void addSequence(
GenerationRequest& sequence, SizeType32 numBlocks, SizeType32 unsharedBlockIdx, SizeType32 windowSize);
GenerationRequest& sequence, SizeType32 numContextBlocks, SizeType32 windowSize, bool isShareLastContextBlock);
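
Not part of the diff: a hedged caller-side sketch of how the two addSequence overloads above might be selected. The reuse-enabled overload needs the prompt length and the LlmRequest so cached blocks can be matched; the reuse-disabled overload only needs the block count and the sharing flag. Parameter types, names, and the surrounding state are assumptions for illustration, not part of the API.

```cpp
// Not part of the diff: caller-side sketch only (C++20 abbreviated templates,
// namespace qualifiers omitted, surrounding state assumed). It shows how the two
// BlockManager::addSequence overloads above might be selected.
void addSequenceForWindow(auto& blockManager, auto& sequence, auto& llmRequest, int inputLength,
    int numContextBlocks, int windowSize, bool enableBlockReuse, bool isShareLastContextBlock)
{
    if (enableBlockReuse)
    {
        // Reuse-enabled overload: takes the prompt length and the request so that
        // previously cached blocks can be matched and reused.
        blockManager.addSequence(sequence, inputLength, numContextBlocks, llmRequest, windowSize);
    }
    else
    {
        // Reuse-disabled overload: always allocates fresh blocks; the last context block
        // is shared among beams only when isShareLastContextBlock is true.
        blockManager.addSequence(sequence, numContextBlocks, windowSize, isShareLastContextBlock);
    }
}
```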

void allocateBlock(GenerationRequest& sequence, SizeType32 windowSize);

Expand Down Expand Up @@ -1106,6 +1111,15 @@ class BlockManager
return mWindowBlockManagers.at(windowSize).getPool(relativePoolIndex);
}

//! \brief Update cache offsets for all blocks of the sequence
void updateSequenceCacheBlockOffsets(GenerationRequest& seq, SizeType32 windowSize);

//! \brief Update cache offsets for the last block of the sequence
void updateLastCacheBlockOffsets(GenerationRequest& seq, SizeType32 windowSize);

//! \brief Update cache offsets for the block at index blockIdx
void updateCacheBlockOffsetsAtIdx(GenerationRequest& seq, SizeType32 windowSize, SizeType32 blockIdx);

private:
[[nodiscard]] WindowBlockManager const& windowManagerByLayer(SizeType32 layerIdx) const
{
Expand Down Expand Up @@ -1637,12 +1651,6 @@ class KVCacheManager : public BaseKVCacheManager
[[nodiscard]] static SizeType32 calculateMaxAttentionWindow(SizeType32 inputLength, SizeType32 outputLength,
SizeType32 sinkTokenLength, SizeType32 blockCapacity, SizeType32 beamWidth, SizeType32 tokensPerBlock);

private:
void cacheBlockOffsets(GenerationRequest& seq, SizeType32 windowSize);
void cacheNewBlockOffsets(GenerationRequest& seq, SizeType32 windowSize);
void updateNewBlockPointer(GenerationRequest& seq, SizeType32 windowSize, SizeType32 blockIdx);
void updateToken(GenerationRequest& sequence, bool addToken);

private:
// Maximum number of sequences
SizeType32 mMaxNumSequences;
143 changes: 54 additions & 89 deletions cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp
@@ -1146,12 +1146,16 @@ void WindowBlockManager::refreshBlocks()
mTransferManager->syncTransfers();
}

// There are two versions of the BlockManager::addSequence function.
// This overload is called when block reuse is enabled.
void BlockManager::addSequence(GenerationRequest& sequence, SizeType32 inputLength, SizeType32 numContextBlocks,
LlmRequest& llmRequest, SizeType32 windowSize)
{
mWindowBlockManagers.at(windowSize).addSequence(sequence, inputLength, numContextBlocks, llmRequest);
}

// There are two versions of the WindowBlockManager::addSequence function.
// This overload is called when block reuse is enabled.
void WindowBlockManager::addSequence(
GenerationRequest& sequence, SizeType32 inputLength, SizeType32 numContextBlocks, LlmRequest& llmRequest)
{
@@ -1189,24 +1193,29 @@ void WindowBlockManager::addSequence(
inputLength, prepopulatedPromptLen);
}

// There are two versions of the BlockManager::addSequence function.
// This overload is called when block reuse is disabled.
void BlockManager::addSequence(
GenerationRequest& sequence, SizeType32 numBlocks, SizeType32 unsharedBlockIdx, SizeType32 windowSize)
GenerationRequest& sequence, SizeType32 numContextBlocks, SizeType32 windowSize, bool isShareLastContextBlock)
{
mWindowBlockManagers.at(windowSize).addSequence(sequence, numBlocks, unsharedBlockIdx);
mWindowBlockManagers.at(windowSize).addSequence(sequence, numContextBlocks, isShareLastContextBlock);
}

void WindowBlockManager::addSequence(GenerationRequest& sequence, SizeType32 numBlocks, SizeType32 unsharedBlockIdx)
// There are two versions of the WindowBlockManager::addSequence function.
// This overload is called when block reuse is disabled.
void WindowBlockManager::addSequence(
GenerationRequest& sequence, SizeType32 numContextBlocks, bool isShareLastContextBlock)
{
auto const requestId = sequence.getRequestId();
auto const [seqIt, emplaceDone] = mAllocatedBlocksPerSeq.emplace(requestId, std::vector<BlockPtr>{});
TLLM_CHECK(emplaceDone);

// Allocate blocks
for (SizeType32 bi = 0; bi < numBlocks; ++bi)
TLLM_CHECK_WITH_INFO(numContextBlocks > 0, "numContextBlocks must be greater than 0");
for (SizeType32 bi = 0; bi < numContextBlocks - 1; ++bi)
{
bool shareAmongBeams = bi != unsharedBlockIdx;
allocateBlock(sequence, shareAmongBeams);
allocateBlock(sequence, /*shareAmongBeams=*/true);
}
allocateBlock(sequence, /*shareAmongBeams=*/isShareLastContextBlock);
}
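
Not part of the diff: a small standalone program that mirrors the allocation pattern of the reuse-disabled WindowBlockManager::addSequence above. The first numContextBlocks - 1 blocks are always shared among beams, and only the last context block honors isShareLastContextBlock; the concrete values are illustrative.

```cpp
// Standalone sketch (not TensorRT-LLM code) of the allocation pattern when block
// reuse is disabled: numContextBlocks - 1 shared blocks, then one last block whose
// sharing depends on isShareLastContextBlock.
#include <cstdio>

static void addSequenceSketch(int numContextBlocks, bool isShareLastContextBlock)
{
    for (int bi = 0; bi < numContextBlocks - 1; ++bi)
    {
        std::printf("block %d: shared among beams\n", bi);
    }
    std::printf("block %d: %s\n", numContextBlocks - 1,
        isShareLastContextBlock ? "shared among beams" : "one copy per beam");
}

int main()
{
    // E.g. a prompt needing 4 context blocks whose last block is only partially
    // filled, so beams must not share it.
    addSequenceSketch(4, /*isShareLastContextBlock=*/false);
    return 0;
}
```

Keeping every block except the last one shared avoids duplicating identical prompt KV entries across beams.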

void WindowBlockManager::addBlockToBeam(BlockPtr& block, GenerationRequest& sequence, SizeType32 beamIdx)
@@ -1639,6 +1648,8 @@ KVCacheManager::KVCacheManager(std::vector<SizeType32> const& numKvHeadsPerLayer
// disable block reuse for sink bubble since chopVectorIntoBlocks does not match KV cache blocks in this case
, mEnableBlockReuse{mSinkBubbleLength > 0 ? false : enableBlockReuse}
{
TLLM_CHECK_WITH_INFO(mSinkBlockTokenLength == 0 && mSinkBubbleLength == 0,
"[kv cache manager] streamLLM is not supported at the moment");
TLLM_CHECK_DEBUG(std::find(maxAttentionWindowVec.begin(), maxAttentionWindowVec.end(), mMaxAttentionWindow)
!= maxAttentionWindowVec.end());
// The sink tokens are stored in blocks separate from other tokens.
@@ -1751,7 +1762,6 @@ SizeType32 KVCacheManager::getNeededBlocksOneStep(
{
auto const maxTokensToAddToKVCache = req.mMaxNewTokens;
auto const maxDraftTokensToAdd = std::min(req.getNumDraftTokens(), maxTokensToAddToKVCache);
// Assumes shared among beam = True
auto const promptCacheLen
= std::min((isCrossKv() ? req.getEncoderOutputLen() : req.mPromptLen) + maxDraftTokensToAdd, windowSize)
+ mSinkBubbleLength;
@@ -1834,7 +1844,7 @@ SizeType32 KVCacheManager::getRemainingBlocksToCompletion(LlmRequest const& req,
return (numTotalBlocksPerBeam - numAllocBlocksPerBeam) * req.mSamplingConfig.beamWidth;
}

void KVCacheManager::cacheBlockOffsets(GenerationRequest& sequence, SizeType32 windowSize)
void BlockManager::updateSequenceCacheBlockOffsets(GenerationRequest& sequence, SizeType32 windowSize)
{
auto const& cacheBlocks = sequence.getCacheBlockIds(windowSize);
auto& cacheBlocksTensor = sequence.getCacheBlockIndices(windowSize);
@@ -1849,12 +1859,12 @@ void KVCacheManager::cacheBlockOffsets(GenerationRequest& sequence, SizeType32 w
for (SizeType32 blockIdx = 0; blockIdx < static_cast<SizeType32>(beamCacheBlock.size()); ++blockIdx)
{
auto const blockId = beamCacheBlock.at(blockIdx);
mBlockManager.setOffsets(offsetsPtr, offsetsShape, beamIdx, blockIdx, blockId, windowSize);
mWindowBlockManagers.at(windowSize).setOffsets(offsetsPtr, offsetsShape, beamIdx, blockIdx, blockId);
}
}
}

void KVCacheManager::cacheNewBlockOffsets(GenerationRequest& sequence, SizeType32 windowSize)
void BlockManager::updateLastCacheBlockOffsets(GenerationRequest& sequence, SizeType32 windowSize)
{
auto const& cacheBlocks = sequence.getCacheBlockIds(windowSize);
auto& cacheBlocksTensor = sequence.getCacheBlockIndices(windowSize);
@@ -1868,11 +1878,11 @@ void KVCacheManager::cacheNewBlockOffsets(GenerationRequest& sequence, SizeType3
auto const& beamCacheBlock = cacheBlocks[beamIdx];
auto const blockId = beamCacheBlock.back();
auto const blockIdx = static_cast<SizeType32>(beamCacheBlock.size() - 1);
mBlockManager.setOffsets(offsetsPtr, offsetsShape, beamIdx, blockIdx, blockId, windowSize);
mWindowBlockManagers.at(windowSize).setOffsets(offsetsPtr, offsetsShape, beamIdx, blockIdx, blockId);
}
}

void KVCacheManager::updateNewBlockPointer(GenerationRequest& sequence, SizeType32 windowSize, SizeType32 blockIdx)
void BlockManager::updateCacheBlockOffsetsAtIdx(GenerationRequest& sequence, SizeType32 windowSize, SizeType32 blockIdx)
{
auto const& cacheBlocks = sequence.getCacheBlockIds(windowSize);
auto& cacheBlocksTensor = sequence.getCacheBlockIndices(windowSize);
@@ -1885,76 +1895,37 @@ void KVCacheManager::updateNewBlockPointer(GenerationRequest& sequence, SizeType
{
auto const& beamCacheBlock = cacheBlocks[beamIdx];
auto const blockId = beamCacheBlock.at(blockIdx);
mBlockManager.setOffsets(offsetsPtr, offsetsShape, beamIdx, blockIdx, blockId, windowSize);
mWindowBlockManagers.at(windowSize).setOffsets(offsetsPtr, offsetsShape, beamIdx, blockIdx, blockId);
}
}

void KVCacheManager::updateToken(GenerationRequest& sequence, bool addToken)
void KVCacheManager::addToken(RequestIdType requestId)
{
auto currNumTokens = sequence.getNumTokens();

if (addToken)
{
sequence.addNewTokens(1);
}
else
{
sequence.removeTokens(1);
}

auto newNumTokens = sequence.getNumTokens();

if (!addToken)
{
std::swap(currNumTokens, newNumTokens);
}

// TODO: add streamLLM support
auto& sequence = getSequence(requestId);
sequence.addNewTokens(1);
for (auto const [windowSize, metadata] : mBlockManager.getWindowSizesMetadata())
{
auto const maxTokenNum = metadata.maxTokenNum;
SizeType32 const cyclicTokenNum = maxTokenNum - mSinkBlockTokenLength;
SizeType32 const nextTokenIdxInCycle = (currNumTokens - mSinkBlockTokenLength) % cyclicTokenNum;
SizeType32 const nextTokenIdxInCache = mSinkBlockTokenLength + nextTokenIdxInCycle;

// (nextTokenIdxInCache - mSinkBlockTokenLength) % cyclicTokenNum == 0)
// <=> nextTokenIdxInCycle == 0
// <=> nextTokenIdxInCache == mSinkBlockTokenLength
// => nextTokenIdxInCache % getTokensPerBlock() == 0

// Check if require a new block
if (nextTokenIdxInCache % getTokensPerBlock() == 0)
if ((sequence.getNumTokens() - 1) % getTokensPerBlock() == 0)
{
if (newNumTokens <= maxTokenNum)
if (sequence.getNumTokens() <= windowSize)
{
if (addToken)
{
mBlockManager.allocateBlock(sequence, windowSize);
cacheNewBlockOffsets(sequence, windowSize);
}
else
{
mBlockManager.releaseLastBlock(sequence, windowSize);
}
// Allocate new unshared blocks until the window can always
// accommodate "window size" number of tokens.
mBlockManager.allocateBlock(sequence, windowSize);
mBlockManager.updateLastCacheBlockOffsets(sequence, windowSize);
}
else if (sequence.getBeamWidth() > 1)
{
TLLM_CHECK_WITH_INFO(addToken, "Remove token is not supported with beam search");
// Get next block index
SizeType32 nextBlockIdx = nextTokenIdxInCache / getTokensPerBlock();
// Replace the shared block with the unshared ones
// For beam search, the shared block is replaced with unshared ones
auto const nextBlockIdx = (sequence.getNumTokens() - 1) / getTokensPerBlock();
mBlockManager.replaceSharedBlock(sequence, windowSize, nextBlockIdx);
updateNewBlockPointer(sequence, windowSize, nextBlockIdx);
mBlockManager.updateCacheBlockOffsetsAtIdx(sequence, windowSize, nextBlockIdx);
}
}
}
}

void KVCacheManager::addToken(RequestIdType requestId)
{
auto& sequence = getSequence(requestId);
updateToken(sequence, true);
}
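
Not part of the diff: a standalone sketch of the block-boundary check in the new addToken above. After adding one token, a new block is needed exactly when the tokens stored before it filled whole blocks, i.e. (numTokens - 1) % tokensPerBlock == 0, and only while the sequence still fits in the attention window. tokensPerBlock and windowSize below are illustrative values, not defaults.

```cpp
// Standalone sketch (not TensorRT-LLM code) of the boundary check used by
// KVCacheManager::addToken: allocate when the previous tokens exactly filled the
// allocated blocks and the sequence still fits in the window.
#include <cstdio>

int main()
{
    int const tokensPerBlock = 64; // illustrative value
    int const windowSize = 256;    // illustrative attention window
    for (int numTokens = 1; numTokens <= 130; ++numTokens)
    {
        bool const startsNewBlock = (numTokens - 1) % tokensPerBlock == 0;
        if (startsNewBlock && numTokens <= windowSize)
        {
            std::printf("token %d starts a new block\n", numTokens); // prints for 1, 65, 129
        }
    }
    return 0;
}
```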

std::optional<BlockKey> KVCacheManager::findNewContextBlock(
VecUniqueTokens const& uniqueTokens, LlmRequest const& llmRequest) const
{
@@ -1965,9 +1936,7 @@ std::optional<BlockKey> KVCacheManager::findNewContextBlock(
void KVCacheManager::addSequence(
RequestIdType requestId, SizeType32 inputLength, SizeType32 beamWidth, OptionalRef<LlmRequest> llmRequest)
{
// Need to add the bubble after the sink tokens to use even block size
inputLength += mSinkBubbleLength;

// TODO: add streamLLM support
auto kvCacheRetentionConfig = llmRequest
? llmRequest->getKvCacheRetentionConfig().value_or(executor::KvCacheRetentionConfig())
: executor::KvCacheRetentionConfig();
@@ -1992,20 +1961,6 @@ void KVCacheManager::addSequence(
auto const maxTokenNum = metadata.maxTokenNum;
auto const temporaryAttentionWindow = metadata.temporaryAttentionWindow;

// Get the final token index in kv cache
SizeType32 const finalTokenKVIdx = mSinkBlockTokenLength
+ ((inputLength - 1 - mSinkBlockTokenLength) % (maxTokenNum - mSinkBlockTokenLength));

// Get block index that with shareAmongBeams=False.
// For cross kv cache in encoder-decoder models, always shareAmongBeams=True.
SizeType32 unsharedBlockIdx = -1;
if ((!sequence.isCyclic() || beamWidth > 1 || finalTokenKVIdx % getTokensPerBlock() > 0) && !isCrossKv())
{
unsharedBlockIdx = ((finalTokenKVIdx + 1) % getTokensPerBlock() == 0)
? finalTokenKVIdx / getTokensPerBlock() + 1
: finalTokenKVIdx / getTokensPerBlock();
}

// Consider the temporaryAttentionWindow when allocating blocks.
auto const effectiveInputLength = std::min(inputLength, maxTokenNum + temporaryAttentionWindow);
auto const numContextBlocks = tc::ceilDiv(effectiveInputLength, getTokensPerBlock());
@@ -2024,9 +1979,11 @@ void KVCacheManager::addSequence(
"have no effect.",
llmRequest->mRequestId);
}
mBlockManager.addSequence(sequence, numContextBlocks, unsharedBlockIdx, windowSize);
bool isShareLastContextBlock = isCrossKv() || (sequence.isCyclic() && beamWidth == 1)
|| effectiveInputLength % getTokensPerBlock() == 0;
mBlockManager.addSequence(sequence, numContextBlocks, windowSize, isShareLastContextBlock);
}
cacheBlockOffsets(sequence, windowSize);
mBlockManager.updateSequenceCacheBlockOffsets(sequence, windowSize);
}
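
Not part of the diff: a worked example of the context-block math above for the plain self-attention, non-cyclic, beam-search case (the isCrossKv() and isCyclic() terms are omitted here). With illustrative values, a 100-token prompt and 64 tokens per block give two context blocks whose last block is only partially filled, so it is left unshared, presumably because beams start diverging inside that partially filled block.

```cpp
// Standalone sketch (not TensorRT-LLM code) of the context-block computation:
// numContextBlocks = ceilDiv(effectiveInputLength, tokensPerBlock), and the last
// context block is shareable among beams only when the prompt fills it completely.
#include <cstdio>

static int ceilDiv(int a, int b)
{
    return (a + b - 1) / b;
}

int main()
{
    int const tokensPerBlock = 64;        // illustrative value
    int const effectiveInputLength = 100; // illustrative prompt length
    int const numContextBlocks = ceilDiv(effectiveInputLength, tokensPerBlock); // = 2
    bool const lastBlockFull = effectiveInputLength % tokensPerBlock == 0;      // false
    std::printf("numContextBlocks=%d, last block %s -> isShareLastContextBlock=%s\n",
        numContextBlocks, lastBlockFull ? "full" : "partially filled",
        lastBlockFull ? "true" : "false");
    return 0;
}
```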

if (llmRequest)
@@ -2353,15 +2310,23 @@ BlocksPerWindow BaseKVCacheManager::calculateMaxNumBlocks(executor::KvCacheConfi

void KVCacheManager::removeToken(RequestIdType requestId)
{
// TODO: add streamLLM support
auto& sequence = getSequence(requestId);
auto const beamWidth = sequence.getBeamWidth();

TLLM_CHECK_WITH_INFO(beamWidth == 1, "removeToken does not support beamWidth > 1");
if (sequence.getNumTokens() == 0)
{
return;
}
updateToken(sequence, false);
TLLM_CHECK_WITH_INFO(sequence.getBeamWidth() == 1, "[kv cache manager] removeToken does not support beamWidth > 1");
sequence.removeTokens(1);
for (auto const [windowSize, metadata] : mBlockManager.getWindowSizesMetadata())
{
SizeType32 const maxTokensInWindow = metadata.maxTokenNum;
SizeType32 const tokensInWindow = sequence.getNumTokens() % maxTokensInWindow;
if (tokensInWindow % getTokensPerBlock() == 0 && tokensInWindow <= maxTokensInWindow)
{
mBlockManager.releaseLastBlock(sequence, windowSize);
}
}
}
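
Not part of the diff: a standalone sketch of the release condition in the new removeToken above. After dropping a token, the trailing block is released when the remaining tokens within the window fill whole blocks exactly, leaving the last block empty; the values below are illustrative.

```cpp
// Standalone sketch (not TensorRT-LLM code) of the release check used by
// KVCacheManager::removeToken: release the last block when the remaining tokens in
// the window land exactly on a block boundary.
#include <cstdio>
#include <initializer_list>

int main()
{
    int const tokensPerBlock = 64;     // illustrative value
    int const maxTokensInWindow = 256; // illustrative window capacity in tokens
    for (int numTokensAfterRemoval : {127, 128, 129})
    {
        int const tokensInWindow = numTokensAfterRemoval % maxTokensInWindow;
        bool const releaseLastBlock = tokensInWindow % tokensPerBlock == 0;
        std::printf("numTokens=%d -> %s\n", numTokensAfterRemoval,
            releaseLastBlock ? "release last block" : "keep last block");
    }
    return 0;
}
```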

void KVCacheManager::rewindKVCache(RequestIdType requestId, SizeType32 rewindLengths)