Skip to content
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 101 additions & 61 deletions cpp/include/tensorrt_llm/batch_manager/kvCacheUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,119 +17,159 @@
#pragma once

#include "tensorrt_llm/batch_manager/kvCacheManager.h"
#include "tensorrt_llm/runtime/iTensor.h"

namespace tensorrt_llm::batch_manager::kv_cache_manager
{

class BlockIterator;

class BlockRange
class BlockRangeForWindow
{
public:
// C++20 std::default_sentinel_t equivalent
// Constructs a range from a list of block ids and a pool tensor.
// Both arguments are taken by value and moved into mBlockIds / mPool, so callers may pass
// temporaries or std::move their own copies to avoid an extra copy.
BlockRangeForWindow(std::vector<SizeType32> blockIds, runtime::ITensor::SharedPtr pool)
: mBlockIds(std::move(blockIds))
, mPool(std::move(pool))
{
}

// Empty tag type returned by end(); BlockIterator is compared against it to detect
// that iteration has reached the end of the range.
struct Sentinel
{
};

static BlockRange fromAllBlockIds(BaseKVCacheManager const& cacheManager, LlmRequest::RequestIdType requestId,
SizeType32 beam = kFIRST_AND_ONLY_BEAM)
{
assert(kFIRST_AND_ONLY_BEAM == beam);
auto const windowSize = firstWindowSize(cacheManager);
auto const blockIds = cacheManager.getSequence(requestId).getCacheBlockIds(windowSize).at(kFIRST_AND_ONLY_BEAM);
return BlockRange(cacheManager, blockIds, requestId);
}
friend class BlockIterator;
BlockIterator begin() const;

static BlockRange fromNewlyAllocatedBlockIds(
BaseKVCacheManager const& cacheManager, LlmRequest::RequestIdType requestId)
[[nodiscard]] Sentinel end() const
{
auto const windowSize = firstWindowSize(cacheManager);
auto const blockIds = cacheManager.getNewlyAllocatedBlockIds(requestId, windowSize);
return BlockRange(cacheManager, blockIds, requestId);
return {};
}

BlockRange(runtime::ITensor::SharedPtr pool, std::vector<SizeType32> const& blockIds) // Only used in tests
: mManager{nullptr}
, mPool{std::move(pool)}
, mWindowSize{0}
, mRequestId{0}
, mBlockIds{blockIds}
[[nodiscard]] size_t size() const
{
TLLM_CHECK(mPool);
return mBlockIds.size();
}

[[nodiscard]] BlockIterator begin() const;
private:
std::vector<SizeType32> mBlockIds;
runtime::ITensor::SharedPtr mPool;
};

[[nodiscard]] Sentinel end() const
class BlockRange
{
public:
static BlockRange fromAllBlockIds(BaseKVCacheManager const& cacheManager, LlmRequest::RequestIdType requestId)
{
return {};

return BlockRange(cacheManager, requestId);
}

[[nodiscard]] size_t size() const
static BlockRange fromNewlyAllocatedBlockIds(
BaseKVCacheManager const& cacheManager, LlmRequest::RequestIdType requestId)
{
return mBlockIds.size();
std::unordered_map<SizeType32, std::vector<SizeType32>> blockIdsPerWindow;

auto windowsMetadata = cacheManager.getBlockManager().getWindowSizesMetadata();
for (auto const& [windowSize, metadata] : windowsMetadata)
{
blockIdsPerWindow[windowSize] = cacheManager.getNewlyAllocatedBlockIds(requestId, windowSize);
}
return BlockRange(cacheManager, std::move(blockIdsPerWindow), requestId);
}

[[nodiscard]] std::vector<SizeType32> const& getBlockIds() const
void setBlockIdsForWindow(SizeType32 windowSize, std::vector<SizeType32> blockIds)
{
return mBlockIds;
TLLM_CHECK_WITH_INFO(mBlockIdsPerWindow.find(windowSize) != mBlockIdsPerWindow.end(),
"Window size %d should exists", windowSize);
mBlockIdsPerWindow[windowSize] = std::move(blockIds);
}

void setBlockIds(std::vector<SizeType32> blockIds)
void setBlockIdsForAllWindows(std::unordered_map<SizeType32, std::vector<SizeType32>> blockIdsPerWindow)
{
mBlockIds = std::move(blockIds);
for (auto const& [windowSize, blockIds] : blockIdsPerWindow)
{
TLLM_CHECK_WITH_INFO(
mPoolsPerWindow.find(windowSize) != mPoolsPerWindow.end(), "Window size %d should exists", windowSize);
}
mBlockIdsPerWindow = std::move(blockIdsPerWindow);
}

[[nodiscard]] std::vector<size_t> getBlockHashes() const
[[nodiscard]] std::unordered_map<SizeType32, std::vector<size_t>> getBlockHashesPerWindow() const
{
TLLM_CHECK(mManager);
std::vector<size_t> blockHashes;
blockHashes.reserve(mBlockIds.size());
std::unordered_map<SizeType32, std::vector<size_t>> blockHashesPerWindow;
auto& blockManager = mManager->getBlockManager();
for (auto id : mBlockIds)
for (auto const& [windowSize, blockIds] : mBlockIdsPerWindow)
{
blockHashes.emplace_back(blockManager.getBlockById(id, mWindowSize)->getHash());
for (auto const& blockId : blockIds)
{
blockHashesPerWindow[windowSize].emplace_back(
blockManager.getBlockById(blockId, windowSize)->getHash());
}
}
return blockHashes;
return blockHashesPerWindow;
}

void updatePoolIdx(SizeType32 poolIdx)
BlockRangeForWindow getBlockRangeForWindow(SizeType32 windowSize) const
{
TLLM_CHECK(mManager);
mPool = mManager->getBlockManager().getPrimaryPool(poolIdx);
auto const newWindowSize = mManager->getBlockManager().getPoolWindowSize(poolIdx);
if (newWindowSize != mWindowSize)
TLLM_CHECK_WITH_INFO(
mPoolsPerWindow.find(windowSize) != mPoolsPerWindow.end(), "Window size %d not found", windowSize);
auto pool = mPoolsPerWindow.at(windowSize).front();
auto blockIds = mBlockIdsPerWindow.at(windowSize);
return BlockRangeForWindow(std::move(blockIds), std::move(pool));
}

std::vector<SizeType32> getWindowSizes() const
{
std::vector<SizeType32> windowSizes;
for (auto const& [windowSize, _] : mPoolsPerWindow)
{
mWindowSize = newWindowSize;
mBlockIds = mManager->getSequence(mRequestId).getCacheBlockIds(mWindowSize).at(kFIRST_AND_ONLY_BEAM);
windowSizes.push_back(windowSize);
}
return windowSizes;
}
Comment on lines +122 to 130
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Return window sizes in deterministic order

The current implementation iterates an unordered_map, whose iteration order is unspecified and may differ across processes and builds. Sorting the result makes the order deterministic, which eliminates test flakiness and keeps sender/receiver blockIdx alignment.

     std::vector<SizeType32> getWindowSizes() const
     {
-        std::vector<SizeType32> windowSizes;
-        for (auto const& [windowSize, _] : mPoolsPerWindow)
-        {
-            windowSizes.push_back(windowSize);
-        }
-        return windowSizes;
+        std::vector<SizeType32> windowSizes;
+        windowSizes.reserve(mPoolsPerWindow.size());
+        for (auto const& [windowSize, _] : mPoolsPerWindow)
+        {
+            windowSizes.push_back(windowSize);
+        }
+        std::sort(windowSizes.begin(), windowSizes.end());
+        return windowSizes;
     }

Add the missing header include as well:

 #pragma once

 #include "tensorrt_llm/batch_manager/kvCacheManager.h"
 #include "tensorrt_llm/runtime/iTensor.h"
+#include <algorithm>
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
std::vector<SizeType32> getWindowSizes() const
{
std::vector<SizeType32> windowSizes;
for (auto const& [windowSize, _] : mPoolsPerWindow)
{
mWindowSize = newWindowSize;
mBlockIds = mManager->getSequence(mRequestId).getCacheBlockIds(mWindowSize).at(kFIRST_AND_ONLY_BEAM);
windowSizes.push_back(windowSize);
}
return windowSizes;
}
#pragma once
#include "tensorrt_llm/batch_manager/kvCacheManager.h"
#include "tensorrt_llm/runtime/iTensor.h"
#include <algorithm>
std::vector<SizeType32> getWindowSizes() const
{
std::vector<SizeType32> windowSizes;
windowSizes.reserve(mPoolsPerWindow.size());
for (auto const& [windowSize, _] : mPoolsPerWindow)
{
windowSizes.push_back(windowSize);
}
std::sort(windowSizes.begin(), windowSizes.end());
return windowSizes;
}
🤖 Prompt for AI Agents
In cpp/include/tensorrt_llm/batch_manager/kvCacheUtils.h around lines 122 to
130, the function returns window sizes by iterating an unordered_map which
yields a non-deterministic order; collect the window sizes as you do now then
sort the vector before returning (use std::sort on the collected windowSizes) to
ensure deterministic ordering, and also add the missing include for <algorithm>
at the top of the header so std::sort is available.


friend class BlockIterator;
// Read-only accessor for the block ids stored per window size.
// Returns a const reference to the member map (no copy is made).
std::unordered_map<SizeType32, std::vector<SizeType32>> const& getBlockIdsPerWindow() const
{
return mBlockIdsPerWindow;
}

private:
BlockRange(
BaseKVCacheManager const& cacheManager, std::vector<SizeType32> blockIds, LlmRequest::RequestIdType requestId)
BlockRange(BaseKVCacheManager const& cacheManager,
std::unordered_map<SizeType32, std::vector<SizeType32>> blockIdsPerWindow, LlmRequest::RequestIdType requestId)
: mManager(&cacheManager)
, mPool(cacheManager.getBlockManager().getPrimaryPool(kFIRST_POOL_INDEX))
, mWindowSize(firstWindowSize(cacheManager))
, mRequestId(requestId)
, mBlockIds(std::move(blockIds))
, mBlockIdsPerWindow(std::move(blockIdsPerWindow))
{

// cacheManager.getBlockManager.getPrimaryPool(0);
auto poolNum = mManager->getNumPools();
for (SizeType32 poolIdx = 0; poolIdx < poolNum; ++poolIdx)
{
auto windowSize = cacheManager.getBlockManager().getPoolWindowSize(poolIdx);
mPoolsPerWindow[windowSize].push_back(cacheManager.getBlockManager().getPrimaryPool(poolIdx));
}
}

static SizeType32 firstWindowSize(BaseKVCacheManager const& cacheManager)
BlockRange(BaseKVCacheManager const& cacheManager, LlmRequest::RequestIdType requestId)
: mManager(&cacheManager)
, mRequestId(requestId)
{
constexpr SizeType32 FIRST_POOL_IDX = 0;
return cacheManager.getBlockManager().getPoolWindowSize(FIRST_POOL_IDX);
auto poolNum = mManager->getNumPools();
for (SizeType32 poolIdx = 0; poolIdx < poolNum; ++poolIdx)
{
auto windowSize = cacheManager.getBlockManager().getPoolWindowSize(poolIdx);
mPoolsPerWindow[windowSize].push_back(cacheManager.getBlockManager().getPrimaryPool(poolIdx));
mBlockIdsPerWindow[windowSize]
= cacheManager.getSequence(mRequestId).getCacheBlockIds(windowSize).at(kFIRST_AND_ONLY_BEAM);
}
}

private:
BaseKVCacheManager const* mManager;
runtime::ITensor::SharedPtr mPool;
SizeType32 mWindowSize;
const LlmRequest::RequestIdType mRequestId;
std::vector<SizeType32> mBlockIds;
LlmRequest::RequestIdType const mRequestId;
std::unordered_map<SizeType32, std::vector<SizeType32>> mBlockIdsPerWindow;
std::unordered_map<SizeType32, std::vector<runtime::ITensor::SharedPtr>> mPoolsPerWindow;

static constexpr SizeType32 kFIRST_AND_ONLY_BEAM = 0;
static constexpr SizeType32 kFIRST_POOL_INDEX = 0;
Expand All @@ -144,7 +184,7 @@ class BlockIterator
using reference = value_type&;
using SizeType32 = tensorrt_llm::runtime::SizeType32;

BlockIterator(BlockRange const* range, size_t idx)
BlockIterator(BlockRangeForWindow const* range, size_t idx)
: mRange{range}
, mIdx{idx}
{
Expand Down Expand Up @@ -187,7 +227,7 @@ class BlockIterator
return mIdx == other.mIdx && mRange == other.mRange;
}

[[nodiscard]] bool operator==(BlockRange::Sentinel other) const
[[nodiscard]] bool operator==(BlockRangeForWindow::Sentinel other) const
{
return mIdx == mRange->mBlockIds.size();
}
Expand All @@ -207,12 +247,12 @@ class BlockIterator
}
}

BlockRange const* mRange;
BlockRangeForWindow const* mRange;
runtime::ITensor::SharedPtr mCurrent;
size_t mIdx;
};

inline BlockIterator BlockRange::begin() const
inline BlockIterator BlockRangeForWindow::begin() const
{
return {this, 0};
}
Expand Down
10 changes: 5 additions & 5 deletions cpp/include/tensorrt_llm/batch_manager/llmRequest.h
Original file line number Diff line number Diff line change
Expand Up @@ -1831,14 +1831,14 @@ class GenericLlmRequest
}
}

void setRequestedBlockHashes(std::vector<size_t> hashes)
void setRequestedBlockHashes(std::unordered_map<SizeType32, std::vector<size_t>>&& hashesPerWindow)
{
mRequestedBlockHashes = std::move(hashes);
mRequestedBlockHashesPerWindow = std::move(hashesPerWindow);
}

[[nodiscard]] std::vector<size_t> const& getRequestedBlockHashes() const
[[nodiscard]] std::unordered_map<SizeType32, std::vector<size_t>> const& getRequestedBlockHashesPerWindow() const
{
return mRequestedBlockHashes;
return mRequestedBlockHashesPerWindow;
}

void setIsDummyRequest(bool isDummyRequest)
Expand Down Expand Up @@ -2033,7 +2033,7 @@ class GenericLlmRequest
TensorMap mAdditionalGenerationOutputTensors;

// Context request only. The hashes of the blocks that are requested by the corresponding generation request.
std::vector<size_t> mRequestedBlockHashes;
std::unordered_map<SizeType32, std::vector<size_t>> mRequestedBlockHashesPerWindow;

bool mIsDummyRequest{false};

Expand Down
Loading
Loading