Skip to content
Merged
143 changes: 118 additions & 25 deletions cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
#include <optional>
#include <set>
#include <unordered_map>
#include <utility>
#include <vector>

namespace kvc = tensorrt_llm::executor::kv_cache;
Expand Down Expand Up @@ -84,6 +85,32 @@ using MmKey = std::pair<std::array<uint8_t, 32>, SizeType32>;
template <typename T>
using OptionalRef = tensorrt_llm::common::OptionalRef<T>;

//! \brief Split vector into list of blocks of given size.
//! \param vec vector to split
//! \param usableSize part of the vector that is processed
//! \param elementsPerBlock desired size of blocks
//! \param allowPartial whether to append a block smaller than `elementsPerBlock` at the end
//! \return list of blocks
template <typename T>
std::list<std::vector<T>> chopVectorIntoBlocks(
std::vector<T> const& vec, SizeType32 usableSize, SizeType32 elementsPerBlock, bool allowPartial)
{
TLLM_CHECK_WITH_INFO(
usableSize <= static_cast<SizeType32>(vec.size()), "usableSize=%d > %ld=vec.size()", usableSize, vec.size());
std::list<std::vector<T>> blockedVectors;
auto const vecEnd = vec.begin() + usableSize;
for (auto begin = vec.begin(); begin < vecEnd; begin += elementsPerBlock)
{
auto blockSize = std::min(elementsPerBlock, static_cast<SizeType32>(std::distance(begin, vecEnd)));
auto end = begin + blockSize;
if (blockSize == elementsPerBlock || allowPartial)
{
blockedVectors.emplace_back(begin, end);
}
}
return blockedVectors;
}

struct TempAttentionWindowInputs
{
bool pagedContextFMHA;
Expand Down Expand Up @@ -114,6 +141,9 @@ struct WindowSizeMetadata
}
};

std::vector<MmKey> generateBlockHashExtraKeys(
tensorrt_llm::batch_manager::LlmRequest const& llmRequest, SizeType32 startTokenIdx, SizeType32 endTokenIdx);

struct BlockKey
{
bool usesExtraIds = false;
Expand Down Expand Up @@ -147,11 +177,7 @@ struct BlockKey
{
}

bool operator==(BlockKey const& other) const noexcept
{
return (usesExtraIds == other.usesExtraIds && loraTaskId == other.loraTaskId
&& uniqueTokens == other.uniqueTokens && extraKeys == other.extraKeys && cacheSaltID == other.cacheSaltID);
}
bool operator==(BlockKey const& other) const noexcept;

int partialMatch(BlockKey const& other) const noexcept
{
Expand All @@ -166,6 +192,8 @@ struct BlockKey
}
};

std::vector<BlockKey> buildBlockKeys(std::list<VecUniqueTokens>& blockedUniqueTokens, LlmRequest const& llmRequest);

// Implement hash functor for BlockKey.
// This allows us to use unordered_map with BlockKey as key.
// Based on https://stackoverflow.com/questions/20511347/a-good-hash-function-for-a-vector/72073933#72073933
Expand Down Expand Up @@ -577,14 +605,18 @@ class WindowBlockManager

void replaceSharedBlock(GenerationRequest& sequence, SizeType32 blockIdx);

//! \brief Get the ids of all newly allocated (not reused) blocks for the sequence.
std::vector<KVCacheBlock::IdType> getNewlyAllocatedBlockIds(GenerationRequest const& sequence) const;
[[nodiscard]] std::optional<KVCacheBlock::IdType> storeBlocksForReuse(
GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest, bool pinBlocks = false);

void storeNewBlock(GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest);

//! \brief Pin blocks associated with a sequence to prevent eviction.
void pinBlocks(GenerationRequest& sequence);

//! \brief Release blocks of the sequence.
//! \details When llmRequest is provided and reuse is enabled, blocks will be stored.
void releaseBlocks(GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest);
std::optional<KVCacheBlock::IdType> releaseBlocks(
GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest);

//! \brief Simulate freeing all blocks for that sequence to check impact on number of free blocks
void schedulingReleaseBlocks(LlmRequest::RequestIdType requestId);
Expand Down Expand Up @@ -757,8 +789,11 @@ class WindowBlockManager
//! \brief Store blocks in cached blocks.
//! \param blockKeys Key of each block.
//! \param blockIds Id of each block.
//! \return Number of actual blocks stored.
SizeType32 storeBlocks(std::vector<BlockKey> const& blockKeys, std::vector<KVCacheBlock::IdType> const& blockIds);
//! \param pinBlocks If true, increment ref count for blocks while storing (pin on store).
//! \return Pair of (num blocks stored for reuse, id of the last block stored if any).
[[nodiscard]] std::pair<SizeType32, std::optional<KVCacheBlock::IdType>> storeBlocks(
std::vector<BlockKey> const& blockKeys, std::vector<KVCacheBlock::IdType> const& blockIds,
bool pinBlocks = false);

[[nodiscard]] bool verifyQueueIntegrity();

Expand Down Expand Up @@ -786,6 +821,11 @@ class WindowBlockManager
return mIsSWA;
}

[[nodiscard]] std::shared_ptr<KVCacheBlock> findBlocksInReuseTreeByBlockKey(BlockKey const& blockKey);

//! \brief Unpin blocks by starting from a block id and walking prev pointers.
void unpinBlocksById(KVCacheBlock::IdType blockId);

private:
//! \brief Add single block to beam of sequence and mAllocatedBlocksPerSeq.
void addBlockToBeam(BlockPtr& block, GenerationRequest& sequence, SizeType32 beamIdx);
Expand Down Expand Up @@ -890,6 +930,9 @@ class WindowBlockManager
bool mCopyOnPartialReuse;
// The kv cache connector manager
std::shared_ptr<kv_connector::KvCacheConnectorManager> mKvCacheConnectorManager;

// Mutex for the cached blocks root
std::mutex mCachedBlocksRootMutex;
};

class BlockManager
Expand Down Expand Up @@ -940,13 +983,20 @@ class BlockManager

void replaceSharedBlock(GenerationRequest& sequence, SizeType32 windowSize, SizeType32 blockIdx);

std::vector<KVCacheBlock::IdType> getNewlyAllocatedBlockIds(
GenerationRequest const& sequence, SizeType32 windowSize) const;
std::optional<KVCacheBlock::IdType> releaseBlocks(
GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest = std::nullopt, bool pinBlocks = false);

void releaseBlocks(GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest = std::nullopt);
[[nodiscard]] std::optional<KVCacheBlock::IdType> storeBlocksForReuse(
GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest = std::nullopt, bool pinBlocks = false);

void schedulingReleaseBlocks(LlmRequest::RequestIdType requestId);

/// @brief Pin all blocks associated with a sequence across all window managers.
/// @param sequence The generation request whose blocks should be pinned.
void pinBlocks(GenerationRequest& sequence);

void unpinBlocksById(KVCacheBlock::IdType blockId);

void releaseLastBlock(GenerationRequest& sequence, SizeType32 windowSize);

void setOffsets(kernels::KVCacheIndex* offsetsPtr, nvinfer1::Dims const& offsetsShape, SizeType32 beamIdx,
Expand All @@ -966,10 +1016,11 @@ class BlockManager
void offloadBlock(BlockPtr const& block, SizeType32 windowSize,
executor::KvCacheTransferMode mode = executor::KvCacheTransferMode::DRAM, std::string const& directory = "");

void storeBlocks(std::vector<BlockKey> const& blockKeys, std::vector<KVCacheBlock::IdType> const& blockIds,
SizeType32 windowSize)
[[nodiscard]] std::pair<SizeType32, std::optional<KVCacheBlock::IdType>> storeBlocks(
std::vector<BlockKey> const& blockKeys, std::vector<KVCacheBlock::IdType> const& blockIds,
SizeType32 windowSize, bool pinBlocks = false)
{
mWindowBlockManagers.at(windowSize).storeBlocks(blockKeys, blockIds);
return mWindowBlockManagers.at(windowSize).storeBlocks(blockKeys, blockIds, pinBlocks);
}

[[nodiscard]] bool verifyQueueIntegrity(SizeType32 windowSize);
Expand Down Expand Up @@ -1003,6 +1054,15 @@ class BlockManager
return sumWindows([](auto const& manager) { return manager.getNumAllocTotalBlocks(); });
}

[[nodiscard]] SizeType32 getFirstWindowSize() const
{
if (mWindowBlockManagers.empty())
{
return 0;
}
return mWindowBlockManagers.begin()->first;
}

[[nodiscard]] SizeType32 getNumAllocNewBlocks() const
{
return sumWindows([](auto const& manager) { return manager.getNumAllocNewBlocks(); });
Expand Down Expand Up @@ -1133,6 +1193,12 @@ class BlockManager
return mWindowBlockManagers.at(windowSize).getBlockById(blockId);
}

[[nodiscard]] std::shared_ptr<KVCacheBlock> findBlocksInReuseTreeByBlockKey(
BlockKey const& blockKey, SizeType32 windowSize)
{
return mWindowBlockManagers.at(windowSize).findBlocksInReuseTreeByBlockKey(blockKey);
}

[[nodiscard]] SizeType32 getNumPrimaryBlocks() const
{
return sumWindows([](auto const& manager) { return manager.getNumPrimaryBlocks(); });
Expand Down Expand Up @@ -1274,6 +1340,10 @@ class BaseKVCacheManager
[[nodiscard]] virtual SizeType32 getRemainingBlocksToCompletion(LlmRequest const& req, SizeType32 windowSize) const
= 0;

/// @brief Pin blocks associated with a request to prevent eviction.
/// @param requestId The ID of the request whose blocks should be pinned.
virtual void pinBlocks(LlmRequest::RequestIdType requestId) = 0;

/// @brief Increase size for request at seqSlotIdx. Allocate new KV cache block(s) if needed.
virtual void addToken(LlmRequest::RequestIdType requestId) = 0;

Expand All @@ -1287,8 +1357,8 @@ class BaseKVCacheManager
OptionalRef<LlmRequest> llmRequest = std::nullopt)
= 0;

virtual void removeSequence(
LlmRequest::RequestIdType requestId, OptionalRef<LlmRequest const> llmRequest = std::nullopt)
[[nodiscard]] virtual std::optional<KVCacheBlock::IdType> removeSequence(LlmRequest::RequestIdType requestId,
OptionalRef<LlmRequest const> llmRequest = std::nullopt, bool pinOnRelease = false)
= 0;

virtual void schedulingRemoveSequence(LlmRequest::RequestIdType requestId) = 0;
Expand Down Expand Up @@ -1332,6 +1402,11 @@ class BaseKVCacheManager
//! \details This block become reusable from next step.
virtual void storeNewBlock(LlmRequest const& llmRequest) = 0;

/// \brief Store blocks for reuse for a given request id
[[nodiscard]] virtual std::optional<KVCacheBlock::IdType> storeBlocksForReuse(
LlmRequest::RequestIdType requestId, OptionalRef<LlmRequest const> llmRequest, bool pinBlocks = false)
= 0;

//! \brief Get the block ids of a request [per beam] **for a given window size block manager**
[[nodiscard]] virtual std::vector<std::vector<SizeType32>> const& getCacheBlockIds(
LlmRequest::RequestIdType requestId, SizeType32 windowSize) const
Expand All @@ -1342,8 +1417,8 @@ class BaseKVCacheManager
std::vector<LlmRequest::RequestIdType> const& requestIds, SizeType32 windowSize) const
= 0;

[[nodiscard]] virtual std::vector<KVCacheBlock::IdType> getNewlyAllocatedBlockIds(
LlmRequest::RequestIdType requestId, SizeType32 windowSize) const
/// @brief Get the last block id (beam 0) for a given sequence and window size
[[nodiscard]] virtual std::optional<KVCacheBlock::IdType> getLastBlockId(LlmRequest::RequestIdType requestId) const
= 0;

[[nodiscard]] virtual runtime::ITensor::SharedPtr getUniquePrimaryPool() const = 0;
Expand Down Expand Up @@ -1414,6 +1489,12 @@ class BaseKVCacheManager
[[nodiscard]] virtual SizeType32 getMaxCapacityBatchSize(SizeType32 inputLength, SizeType32 outputLength) const = 0;

[[nodiscard]] virtual CacheType getCacheType() const = 0;

[[nodiscard]] virtual std::shared_ptr<KVCacheBlock> findBlocksInReuseTreeByBlockKey(
BlockKey const& blockKey, SizeType32 windowSize)
= 0;

virtual void unpinBlocksById(KVCacheBlock::IdType blockId) = 0;
};

class KVCacheManager : public BaseKVCacheManager
Expand Down Expand Up @@ -1591,8 +1672,8 @@ class KVCacheManager : public BaseKVCacheManager
void addSequence(LlmRequest::RequestIdType requestId, SizeType32 inputLength, SizeType32 beamWidth,
OptionalRef<LlmRequest> llmRequest = std::nullopt) override;

void removeSequence(
LlmRequest::RequestIdType requestId, OptionalRef<LlmRequest const> llmRequest = std::nullopt) override;
[[nodiscard]] std::optional<KVCacheBlock::IdType> removeSequence(LlmRequest::RequestIdType requestId,
OptionalRef<LlmRequest const> llmRequest = std::nullopt, bool pinOnRelease = false) override;

void schedulingRemoveSequence(LlmRequest::RequestIdType requestId) override;

Expand Down Expand Up @@ -1652,6 +1733,9 @@ class KVCacheManager : public BaseKVCacheManager
//! \brief Store newest blocks for reuse
void storeNewBlock(LlmRequest const& llmRequest) override;

[[nodiscard]] std::optional<KVCacheBlock::IdType> storeBlocksForReuse(
LlmRequest::RequestIdType requestId, OptionalRef<LlmRequest const> llmRequest, bool pinBlocks = false) override;

[[nodiscard]] static SizeType32 getSinkBubbleLength(SizeType32 sinkTokenLen, SizeType32 tokensPerBlock);

[[nodiscard]] SizeType32 getMaxCapacityBatchSize(SizeType32 inputLength, SizeType32 outputLength) const override;
Expand All @@ -1668,6 +1752,12 @@ class KVCacheManager : public BaseKVCacheManager
[[nodiscard]] static SizeType32 calculateMaxBlockRequirements(SizeType32 inputLength, SizeType32 outputLength,
SizeType32 sinkTokenLength, SizeType32 windowSize, SizeType32 beamWidth, SizeType32 tokensPerBlock);

void pinBlocks(LlmRequest::RequestIdType requestId) override;

void unpinBlocksById(KVCacheBlock::IdType blockId) override;

std::optional<KVCacheBlock::IdType> getLastBlockId(LlmRequest::RequestIdType requestId) const override;

/// @brief Calculates the number of kv-cache blocks that a sequence will require, for a single beam.
///
/// @param sequenceLength The total length of the sequence (input and output).
Expand All @@ -1684,9 +1774,6 @@ class KVCacheManager : public BaseKVCacheManager
std::vector<std::vector<std::vector<SizeType32>>> getBatchCacheBlockIds(
std::vector<LlmRequest::RequestIdType> const& requestIds, SizeType32 windowSize) const override;

std::vector<SizeType32> getNewlyAllocatedBlockIds(
LlmRequest::RequestIdType requestId, SizeType32 windowSize) const override;

runtime::ITensor::SharedPtr getUniquePrimaryPool() const override;
runtime::ITensor::SharedPtr getPrimaryPool(SizeType32 layer_idx) const override;

Expand All @@ -1706,6 +1793,12 @@ class KVCacheManager : public BaseKVCacheManager
mBlockManager.flushIterationEvents();
}

std::shared_ptr<KVCacheBlock> findBlocksInReuseTreeByBlockKey(
BlockKey const& blockKey, SizeType32 windowSize) override
{
return mBlockManager.findBlocksInReuseTreeByBlockKey(blockKey, windowSize);
}

/// @brief Finds the maximum attention window that can be used on a sequence, given some kv-cache block capacity.
///
/// @param inputLength The number of input tokens in the sequence.
Expand Down
41 changes: 24 additions & 17 deletions cpp/include/tensorrt_llm/batch_manager/kvCacheUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,32 @@ class BlockRange
return BlockRange(cacheManager, blockIds, requestId);
}

static BlockRange fromNewlyAllocatedBlockIds(
BaseKVCacheManager const& cacheManager, LlmRequest::RequestIdType requestId)
static BlockRange fromReuseTree(
BaseKVCacheManager& cacheManager, BlockKey const& lastBlockKey, int32_t indexFromEnd)
{
auto const windowSize = firstWindowSize(cacheManager);
auto const blockIds = cacheManager.getNewlyAllocatedBlockIds(requestId, windowSize);
return BlockRange(cacheManager, blockIds, requestId);
// Find the last block in the reuse tree for the provided full sequence of block keys
auto lastBlock = cacheManager.findBlocksInReuseTreeByBlockKey(lastBlockKey, windowSize);
// TODO: handle the case where the last block is not found
TLLM_CHECK_WITH_INFO(lastBlock, "Couldn't find the requested block in the reuse tree");
int32_t const numBlocksToCollect = indexFromEnd + 1;

std::vector<SizeType32> blockIds;
blockIds.reserve(numBlocksToCollect);
for (int32_t i = 0; i < numBlocksToCollect; ++i)
{
TLLM_CHECK_WITH_INFO(
lastBlock->getBlockId() != KVCacheBlock::kCachedBlocksRootId, "last block has no block id");
blockIds.push_back(lastBlock->getBlockId());
if (i + 1 < numBlocksToCollect)
{
TLLM_CHECK_WITH_INFO(lastBlock->getPrevBlock(), "last block has no prev block");
lastBlock = lastBlock->getPrevBlock();
}
}
// Reverse to chronological order: oldest to newest
std::reverse(blockIds.begin(), blockIds.end());
return BlockRange(cacheManager, blockIds, 0);
}

BlockRange(runtime::ITensor::SharedPtr pool, std::vector<SizeType32> const& blockIds) // Only used in tests
Expand Down Expand Up @@ -80,19 +100,6 @@ class BlockRange
mBlockIds = std::move(blockIds);
}

[[nodiscard]] std::vector<size_t> getBlockHashes() const
{
TLLM_CHECK(mManager);
std::vector<size_t> blockHashes;
blockHashes.reserve(mBlockIds.size());
auto& blockManager = mManager->getBlockManager();
for (auto id : mBlockIds)
{
blockHashes.emplace_back(blockManager.getBlockById(id, mWindowSize)->getHash());
}
return blockHashes;
}

void updatePoolIdx(SizeType32 poolIdx)
{
TLLM_CHECK(mManager);
Expand Down
13 changes: 0 additions & 13 deletions cpp/include/tensorrt_llm/batch_manager/llmRequest.h
Original file line number Diff line number Diff line change
Expand Up @@ -1843,16 +1843,6 @@ class GenericLlmRequest
}
}

void setRequestedBlockHashes(std::vector<size_t> hashes)
{
mRequestedBlockHashes = std::move(hashes);
}

[[nodiscard]] std::vector<size_t> const& getRequestedBlockHashes() const
{
return mRequestedBlockHashes;
}

void setIsDummyRequest(bool isDummyRequest)
{
mIsDummyRequest = isDummyRequest;
Expand Down Expand Up @@ -2044,9 +2034,6 @@ class GenericLlmRequest
// Tensors containing the additional generation output.
TensorMap mAdditionalGenerationOutputTensors;

// Context request only. The hashes of the blocks that are requested by the corresponding generation request.
std::vector<size_t> mRequestedBlockHashes;

bool mIsDummyRequest{false};

bool mUseDraftModel{false};
Expand Down
Loading