-
Couldn't load subscription status.
- Fork 1.8k
Draft:[TRTLLM-7078][chore] optimal kvcache transfer for VWSA #6861
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -17,119 +17,159 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| #pragma once | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| #include "tensorrt_llm/batch_manager/kvCacheManager.h" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| #include "tensorrt_llm/runtime/iTensor.h" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| namespace tensorrt_llm::batch_manager::kv_cache_manager | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| class BlockIterator; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| class BlockRange | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| class BlockRangeForWindow | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| public: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // C++20 std::default_sentinel_t equivalent | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| BlockRangeForWindow(std::vector<SizeType32> blockIds, runtime::ITensor::SharedPtr pool) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| : mBlockIds(std::move(blockIds)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| , mPool(std::move(pool)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| struct Sentinel | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| }; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| static BlockRange fromAllBlockIds(BaseKVCacheManager const& cacheManager, LlmRequest::RequestIdType requestId, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| SizeType32 beam = kFIRST_AND_ONLY_BEAM) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert(kFIRST_AND_ONLY_BEAM == beam); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| auto const windowSize = firstWindowSize(cacheManager); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| auto const blockIds = cacheManager.getSequence(requestId).getCacheBlockIds(windowSize).at(kFIRST_AND_ONLY_BEAM); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return BlockRange(cacheManager, blockIds, requestId); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| friend class BlockIterator; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| BlockIterator begin() const; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| static BlockRange fromNewlyAllocatedBlockIds( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| BaseKVCacheManager const& cacheManager, LlmRequest::RequestIdType requestId) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| [[nodiscard]] Sentinel end() const | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| auto const windowSize = firstWindowSize(cacheManager); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| auto const blockIds = cacheManager.getNewlyAllocatedBlockIds(requestId, windowSize); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return BlockRange(cacheManager, blockIds, requestId); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return {}; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| BlockRange(runtime::ITensor::SharedPtr pool, std::vector<SizeType32> const& blockIds) // Only used in tests | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| : mManager{nullptr} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| , mPool{std::move(pool)} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| , mWindowSize{0} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| , mRequestId{0} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| , mBlockIds{blockIds} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| [[nodiscard]] size_t size() const | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| TLLM_CHECK(mPool); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return mBlockIds.size(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| [[nodiscard]] BlockIterator begin() const; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| private: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| std::vector<SizeType32> mBlockIds; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| runtime::ITensor::SharedPtr mPool; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| }; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| [[nodiscard]] Sentinel end() const | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| class BlockRange | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| public: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| static BlockRange fromAllBlockIds(BaseKVCacheManager const& cacheManager, LlmRequest::RequestIdType requestId) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return {}; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return BlockRange(cacheManager, requestId); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| [[nodiscard]] size_t size() const | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| static BlockRange fromNewlyAllocatedBlockIds( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| BaseKVCacheManager const& cacheManager, LlmRequest::RequestIdType requestId) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return mBlockIds.size(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| std::unordered_map<SizeType32, std::vector<SizeType32>> blockIdsPerWindow; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| auto windowsMetadata = cacheManager.getBlockManager().getWindowSizesMetadata(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for (auto const& [windowSize, metadata] : windowsMetadata) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| blockIdsPerWindow[windowSize] = cacheManager.getNewlyAllocatedBlockIds(requestId, windowSize); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return BlockRange(cacheManager, std::move(blockIdsPerWindow), requestId); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| [[nodiscard]] std::vector<SizeType32> const& getBlockIds() const | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| void setBlockIdsForWindow(SizeType32 windowSize, std::vector<SizeType32> blockIds) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return mBlockIds; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| TLLM_CHECK_WITH_INFO(mBlockIdsPerWindow.find(windowSize) != mBlockIdsPerWindow.end(), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "Window size %d should exists", windowSize); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| mBlockIdsPerWindow[windowSize] = std::move(blockIds); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| void setBlockIds(std::vector<SizeType32> blockIds) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| void setBlockIdsForAllWindows(std::unordered_map<SizeType32, std::vector<SizeType32>> blockIdsPerWindow) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| mBlockIds = std::move(blockIds); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for (auto const& [windowSize, blockIds] : blockIdsPerWindow) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| TLLM_CHECK_WITH_INFO( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| mPoolsPerWindow.find(windowSize) != mPoolsPerWindow.end(), "Window size %d should exists", windowSize); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| mBlockIdsPerWindow = std::move(blockIdsPerWindow); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
coderabbitai[bot] marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| [[nodiscard]] std::vector<size_t> getBlockHashes() const | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| [[nodiscard]] std::unordered_map<SizeType32, std::vector<size_t>> getBlockHashesPerWindow() const | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| TLLM_CHECK(mManager); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| std::vector<size_t> blockHashes; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| blockHashes.reserve(mBlockIds.size()); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| std::unordered_map<SizeType32, std::vector<size_t>> blockHashesPerWindow; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| auto& blockManager = mManager->getBlockManager(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for (auto id : mBlockIds) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for (auto const& [windowSize, blockIds] : mBlockIdsPerWindow) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| blockHashes.emplace_back(blockManager.getBlockById(id, mWindowSize)->getHash()); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for (auto const& blockId : blockIds) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| blockHashesPerWindow[windowSize].emplace_back( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| blockManager.getBlockById(blockId, windowSize)->getHash()); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return blockHashes; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return blockHashesPerWindow; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| void updatePoolIdx(SizeType32 poolIdx) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| BlockRangeForWindow getBlockRangeForWindow(SizeType32 windowSize) const | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| TLLM_CHECK(mManager); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| mPool = mManager->getBlockManager().getPrimaryPool(poolIdx); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| auto const newWindowSize = mManager->getBlockManager().getPoolWindowSize(poolIdx); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (newWindowSize != mWindowSize) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| TLLM_CHECK_WITH_INFO( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| mPoolsPerWindow.find(windowSize) != mPoolsPerWindow.end(), "Window size %d not found", windowSize); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| auto pool = mPoolsPerWindow.at(windowSize).front(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| auto blockIds = mBlockIdsPerWindow.at(windowSize); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return BlockRangeForWindow(std::move(blockIds), std::move(pool)); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
coderabbitai[bot] marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| std::vector<SizeType32> getWindowSizes() const | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| std::vector<SizeType32> windowSizes; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for (auto const& [windowSize, _] : mPoolsPerWindow) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| mWindowSize = newWindowSize; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| mBlockIds = mManager->getSequence(mRequestId).getCacheBlockIds(mWindowSize).at(kFIRST_AND_ONLY_BEAM); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| windowSizes.push_back(windowSize); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return windowSizes; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+122
to
130
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion Return window sizes in deterministic order Current implementation iterates an std::vector<SizeType32> getWindowSizes() const
{
- std::vector<SizeType32> windowSizes;
- for (auto const& [windowSize, _] : mPoolsPerWindow)
- {
- windowSizes.push_back(windowSize);
- }
- return windowSizes;
+ std::vector<SizeType32> windowSizes;
+ windowSizes.reserve(mPoolsPerWindow.size());
+ for (auto const& [windowSize, _] : mPoolsPerWindow)
+ {
+ windowSizes.push_back(windowSize);
+ }
+ std::sort(windowSizes.begin(), windowSizes.end());
+ return windowSizes;
}Add the missing header include as well: #pragma once
#include "tensorrt_llm/batch_manager/kvCacheManager.h"
#include "tensorrt_llm/runtime/iTensor.h"
+#include <algorithm>📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| friend class BlockIterator; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| std::unordered_map<SizeType32, std::vector<SizeType32>> const& getBlockIdsPerWindow() const | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return mBlockIdsPerWindow; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| private: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| BlockRange( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| BaseKVCacheManager const& cacheManager, std::vector<SizeType32> blockIds, LlmRequest::RequestIdType requestId) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| BlockRange(BaseKVCacheManager const& cacheManager, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| std::unordered_map<SizeType32, std::vector<SizeType32>> blockIdsPerWindow, LlmRequest::RequestIdType requestId) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| : mManager(&cacheManager) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| , mPool(cacheManager.getBlockManager().getPrimaryPool(kFIRST_POOL_INDEX)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| , mWindowSize(firstWindowSize(cacheManager)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| , mRequestId(requestId) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| , mBlockIds(std::move(blockIds)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| , mBlockIdsPerWindow(std::move(blockIdsPerWindow)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // cacheManager.getBlockManager.getPrimaryPool(0); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| auto poolNum = mManager->getNumPools(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for (SizeType32 poolIdx = 0; poolIdx < poolNum; ++poolIdx) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| auto windowSize = cacheManager.getBlockManager().getPoolWindowSize(poolIdx); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| mPoolsPerWindow[windowSize].push_back(cacheManager.getBlockManager().getPrimaryPool(poolIdx)); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| static SizeType32 firstWindowSize(BaseKVCacheManager const& cacheManager) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| BlockRange(BaseKVCacheManager const& cacheManager, LlmRequest::RequestIdType requestId) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| : mManager(&cacheManager) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| , mRequestId(requestId) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| constexpr SizeType32 FIRST_POOL_IDX = 0; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return cacheManager.getBlockManager().getPoolWindowSize(FIRST_POOL_IDX); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| auto poolNum = mManager->getNumPools(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for (SizeType32 poolIdx = 0; poolIdx < poolNum; ++poolIdx) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| auto windowSize = cacheManager.getBlockManager().getPoolWindowSize(poolIdx); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| mPoolsPerWindow[windowSize].push_back(cacheManager.getBlockManager().getPrimaryPool(poolIdx)); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| mBlockIdsPerWindow[windowSize] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| = cacheManager.getSequence(mRequestId).getCacheBlockIds(windowSize).at(kFIRST_AND_ONLY_BEAM); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| private: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| BaseKVCacheManager const* mManager; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| runtime::ITensor::SharedPtr mPool; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| SizeType32 mWindowSize; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| const LlmRequest::RequestIdType mRequestId; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| std::vector<SizeType32> mBlockIds; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| LlmRequest::RequestIdType const mRequestId; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| std::unordered_map<SizeType32, std::vector<SizeType32>> mBlockIdsPerWindow; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| std::unordered_map<SizeType32, std::vector<runtime::ITensor::SharedPtr>> mPoolsPerWindow; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| static constexpr SizeType32 kFIRST_AND_ONLY_BEAM = 0; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| static constexpr SizeType32 kFIRST_POOL_INDEX = 0; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -144,7 +184,7 @@ class BlockIterator | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| using reference = value_type&; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| using SizeType32 = tensorrt_llm::runtime::SizeType32; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| BlockIterator(BlockRange const* range, size_t idx) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| BlockIterator(BlockRangeForWindow const* range, size_t idx) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| : mRange{range} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| , mIdx{idx} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -187,7 +227,7 @@ class BlockIterator | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return mIdx == other.mIdx && mRange == other.mRange; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| [[nodiscard]] bool operator==(BlockRange::Sentinel other) const | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| [[nodiscard]] bool operator==(BlockRangeForWindow::Sentinel other) const | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return mIdx == mRange->mBlockIds.size(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -207,12 +247,12 @@ class BlockIterator | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| BlockRange const* mRange; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| BlockRangeForWindow const* mRange; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| runtime::ITensor::SharedPtr mCurrent; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| size_t mIdx; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| }; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| inline BlockIterator BlockRange::begin() const | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| inline BlockIterator BlockRangeForWindow::begin() const | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return {this, 0}; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Uh oh!
There was an error while loading. Please reload this page.