From f892467c14e9c5d8923513ca14c27ca6d8831df4 Mon Sep 17 00:00:00 2001
From: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com>
Date: Tue, 12 Aug 2025 13:05:33 -0700
Subject: [PATCH 01/17] Initial iteration for supporting block hash transfer

Signed-off-by: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com>

Add unittest for findBlocksInReuseTreeByHashes

Signed-off-by: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com>

fixes

Signed-off-by: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com>

Switch from hash id to block key

Signed-off-by: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com>

Add support for blockKeys

Signed-off-by: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com>

Fix bugs

Signed-off-by: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com>

Fix accuracy bug and add tests

Signed-off-by: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com>
---
 .../batch_manager/kvCacheManager.h            |  95 +++++++-
 .../tensorrt_llm/batch_manager/kvCacheUtils.h |  39 ++--
 .../tensorrt_llm/batch_manager/llmRequest.h   |  13 --
 .../executor/dataTransceiverState.h           |  16 +-
 .../tensorrt_llm/executor/serialization.h     |   5 +
 .../batch_manager/cacheFormatter.cpp          |  46 ++--
 .../batch_manager/cacheFormatter.h            |  79 ++++++-
 .../batch_manager/cacheTransceiver.cpp        |  11 +-
 .../batch_manager/dataTransceiver.cpp         | 103 +++++----
 .../batch_manager/dataTransceiver.h           |  39 +++-
 .../batch_manager/kvCacheManager.cpp          | 218 +++++++++++++-----
 .../batch_manager/mlaCacheFormatter.cpp       |  11 +-
 cpp/tensorrt_llm/common/envUtils.cpp          |   6 -
 .../agent_utils/connection.cpp                |   1 -
 cpp/tensorrt_llm/executor/serialization.cpp   |  40 +++-
 cpp/tensorrt_llm/executor/serializeUtils.h    |  90 ++++++++
 .../nanobind/batch_manager/kvCacheManager.cpp |   4 +-
 .../pybind/batch_manager/kvCacheManager.cpp   |   5 +-
 .../batch_manager/kvCacheManagerTest.cpp      |  91 ++++++++
 tensorrt_llm/_torch/pyexecutor/py_executor.py |  19 +-
 .../_torch/pyexecutor/resource_manager.py     |   6 +
 .../accuracy/test_disaggregated_serving.py    |  26 ++-
 .../test_lists/test-db/l0_dgx_h100.yml        |  19 ++
 23 files changed, 782 insertions(+), 200 deletions(-)

diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h
index d97b87086f5..b670a18c90c 100644
--- a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h
+++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h
@@ -84,6 +84,32 @@ using MmKey = std::pair<std::array<uint8_t, 32>, SizeType32>;
 template <typename T>
 using OptionalRef = tensorrt_llm::common::OptionalRef<T>;
 
+//! \brief Split vector into list of blocks of given size.
+//! \param vec vector to split
+//! \param usableSize part of the vector that is processed
+//! \param elementsPerBlock desired size of blocks
+//! \param allowPartial whether to append a block smaller than `elementsPerBlock` at the end
+//! 
\return list of blocks
+template <typename T>
+std::list<std::vector<T>> chopVectorIntoBlocks(
+    std::vector<T> const& vec, SizeType32 usableSize, SizeType32 elementsPerBlock, bool allowPartial)
+{
+    TLLM_CHECK_WITH_INFO(
+        usableSize <= static_cast<SizeType32>(vec.size()), "usableSize=%d > %ld=vec.size()", usableSize, vec.size());
+    std::list<std::vector<T>> blockedVectors;
+    auto const vecEnd = vec.begin() + usableSize;
+    for (auto begin = vec.begin(); begin < vecEnd; begin += elementsPerBlock)
+    {
+        auto blockSize = std::min(elementsPerBlock, static_cast<SizeType32>(std::distance(begin, vecEnd)));
+        auto end = begin + blockSize;
+        if (blockSize == elementsPerBlock || allowPartial)
+        {
+            blockedVectors.emplace_back(begin, end);
+        }
+    }
+    return blockedVectors;
+}
+
 struct TempAttentionWindowInputs
 {
     bool pagedContextFMHA;
@@ -114,6 +140,9 @@ struct WindowSizeMetadata
     }
 };
 
+std::vector<MmKey> generateBlockHashExtraKeys(
+    tensorrt_llm::batch_manager::LlmRequest const& llmRequest, SizeType32 startTokenIdx, SizeType32 endTokenIdx);
+
 struct BlockKey
 {
     bool usesExtraIds = false;
@@ -147,11 +176,7 @@ struct BlockKey
     {
     }
 
-    bool operator==(BlockKey const& other) const noexcept
-    {
-        return (usesExtraIds == other.usesExtraIds && loraTaskId == other.loraTaskId
-            && uniqueTokens == other.uniqueTokens && extraKeys == other.extraKeys && cacheSaltID == other.cacheSaltID);
-    }
+    bool operator==(BlockKey const& other) const noexcept;
 
     int partialMatch(BlockKey const& other) const noexcept
     {
@@ -166,6 +191,8 @@ struct BlockKey
     }
 };
 
+std::vector<BlockKey> buildBlockKeys(std::list<VecUniqueTokens>& blockedUniqueTokens, LlmRequest const& llmRequest);
+
 // Implement hash functor for BlockKey.
 // This allows us to use unordered_map with BlockKey as key.
 // Based on https://stackoverflow.com/questions/20511347/a-good-hash-function-for-a-vector/72073933#72073933
@@ -582,6 +609,9 @@ class WindowBlockManager
 
     void storeNewBlock(GenerationRequest& sequence, OptionalRef<LlmRequest> llmRequest);
 
+    //! \brief Pin blocks associated with a sequence to prevent eviction.
+    void pinBlocks(GenerationRequest& sequence);
+
     //! \brief Release blocks of the sequence.
     //! \details When llmRequest is provided and reuse is enabled, blocks will be stored.
     void releaseBlocks(GenerationRequest& sequence, OptionalRef<LlmRequest> llmRequest);
@@ -785,6 +815,11 @@ class WindowBlockManager
     {
         return mIsSWA;
     }
+    [[nodiscard]] std::optional<BlockPtr> findBlocksInReuseTreeByBlockKey(
+        BlockKey const& blockKey);
+
+    //! \brief Unpin blocks by starting from a block id and walking prev pointers.
+    void unpinBlocksById(KVCacheBlock::IdType blockId);
 
 private:
     //! \brief Add single block to beam of sequence and mAllocatedBlocksPerSeq.
@@ -890,6 +925,9 @@ class WindowBlockManager
     bool mCopyOnPartialReuse;
     // The kv cache connector manager
    std::shared_ptr<KvCacheConnectorManager> mKvCacheConnectorManager;
+
+    // Mutex for the cached blocks root
+    std::mutex mCachedBlocksRootMutex;
 };
 
 class BlockManager
@@ -947,6 +985,12 @@ class BlockManager
 
     void schedulingReleaseBlocks(LlmRequest::RequestIdType requestId);
 
+    /// @brief Pin all blocks associated with a sequence across all window managers.
+    /// @param sequence The generation request whose blocks should be pinned.
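+    /// @note Illustrative pairing with unpinBlocksById, inferred from how this
+    ///       patch uses the API (variable names are hypothetical):
+    ///
+    ///         blockManager.pinBlocks(sequence);          // ref-count every allocated block
+    ///         // ... cache transceiver reads the pinned blocks ...
+    ///         blockManager.unpinBlocksById(lastBlockId); // walk prev pointers, release pins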
+ void pinBlocks(GenerationRequest& sequence); + + void unpinBlocksById(KVCacheBlock::IdType blockId); + void releaseLastBlock(GenerationRequest& sequence, SizeType32 windowSize); void setOffsets(kernels::KVCacheIndex* offsetsPtr, nvinfer1::Dims const& offsetsShape, SizeType32 beamIdx, @@ -1003,6 +1047,15 @@ class BlockManager return sumWindows([](auto const& manager) { return manager.getNumAllocTotalBlocks(); }); } + [[nodiscard]] SizeType32 getFirstWindowSize() const + { + if (mWindowBlockManagers.empty()) + { + return 0; + } + return mWindowBlockManagers.begin()->first; + } + [[nodiscard]] SizeType32 getNumAllocNewBlocks() const { return sumWindows([](auto const& manager) { return manager.getNumAllocNewBlocks(); }); @@ -1133,6 +1186,12 @@ class BlockManager return mWindowBlockManagers.at(windowSize).getBlockById(blockId); } + [[nodiscard]] std::optional> findBlocksInReuseTreeByBlockKey( + BlockKey const& blockKey, SizeType32 windowSize) + { + return mWindowBlockManagers.at(windowSize).findBlocksInReuseTreeByBlockKey(blockKey); + } + [[nodiscard]] SizeType32 getNumPrimaryBlocks() const { return sumWindows([](auto const& manager) { return manager.getNumPrimaryBlocks(); }); @@ -1274,6 +1333,10 @@ class BaseKVCacheManager [[nodiscard]] virtual SizeType32 getRemainingBlocksToCompletion(LlmRequest const& req, SizeType32 windowSize) const = 0; + /// @brief Pin blocks associated with a request to prevent eviction. + /// @param requestId The ID of the request whose blocks should be pinned. + virtual void pinBlocks(LlmRequest::RequestIdType requestId) = 0; + /// @brief Increase size for request at seqSlotIdx. Allocate new KV cache block(s) if needed. virtual void addToken(LlmRequest::RequestIdType requestId) = 0; @@ -1346,6 +1409,10 @@ class BaseKVCacheManager LlmRequest::RequestIdType requestId, SizeType32 windowSize) const = 0; + /// @brief Get the last block id (beam 0) for a given sequence and window size + [[nodiscard]] virtual std::optional getLastBlockId(LlmRequest::RequestIdType requestId) const + = 0; + [[nodiscard]] virtual runtime::ITensor::SharedPtr getUniquePrimaryPool() const = 0; [[nodiscard]] virtual runtime::ITensor::SharedPtr getPrimaryPool(SizeType32 layer_idx) const = 0; [[nodiscard]] virtual SizeType32 getPoolLayerIdx(SizeType32 layer_idx) const = 0; @@ -1414,6 +1481,12 @@ class BaseKVCacheManager [[nodiscard]] virtual SizeType32 getMaxCapacityBatchSize(SizeType32 inputLength, SizeType32 outputLength) const = 0; [[nodiscard]] virtual CacheType getCacheType() const = 0; + + [[nodiscard]] virtual std::optional> findBlocksInReuseTreeByBlockKey( + BlockKey const& blockKey, SizeType32 windowSize) + = 0; + + virtual void unpinBlocksById(KVCacheBlock::IdType blockId) = 0; }; class KVCacheManager : public BaseKVCacheManager @@ -1668,6 +1741,12 @@ class KVCacheManager : public BaseKVCacheManager [[nodiscard]] static SizeType32 calculateMaxBlockRequirements(SizeType32 inputLength, SizeType32 outputLength, SizeType32 sinkTokenLength, SizeType32 windowSize, SizeType32 beamWidth, SizeType32 tokensPerBlock); + void pinBlocks(LlmRequest::RequestIdType requestId) override; + + void unpinBlocksById(KVCacheBlock::IdType blockId) override; + + std::optional getLastBlockId(LlmRequest::RequestIdType requestId) const override; + /// @brief Calculates the number of kv-cache blocks that a sequence will require, for a single beam. /// /// @param sequenceLength The total length of the sequence (input and output). 
@@ -1706,6 +1785,12 @@ class KVCacheManager : public BaseKVCacheManager mBlockManager.flushIterationEvents(); } + std::optional> findBlocksInReuseTreeByBlockKey( + BlockKey const& blockKey, SizeType32 windowSize) override + { + return mBlockManager.findBlocksInReuseTreeByBlockKey(blockKey, windowSize); + } + /// @brief Finds the maximum attention window that can be used on a sequence, given some kv-cache block capacity. /// /// @param inputLength The number of input tokens in the sequence. diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheUtils.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheUtils.h index 2aebf77b96d..0e7bad3a585 100644 --- a/cpp/include/tensorrt_llm/batch_manager/kvCacheUtils.h +++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheUtils.h @@ -48,6 +48,32 @@ class BlockRange return BlockRange(cacheManager, blockIds, requestId); } + static BlockRange fromReuseTree( + BaseKVCacheManager& cacheManager, BlockKey const& lastBlockKey, int32_t indexFromEnd) + { + auto const windowSize = firstWindowSize(cacheManager); + // Find the last block in the reuse tree for the provided full sequence of block keys + auto lastBlock = *cacheManager.findBlocksInReuseTreeByBlockKey(lastBlockKey, windowSize); + // TODO: handle the case where the last block is not found + TLLM_CHECK_WITH_INFO(lastBlock, "Couldn't find the requested block in the reuse tree"); + int32_t const numBlocksToCollect = indexFromEnd + 1; + + std::vector blockIds; + blockIds.reserve(numBlocksToCollect); + for (int32_t i = 0; i < numBlocksToCollect; ++i) + { + blockIds.push_back(lastBlock->getBlockId()); + if (i + 1 < numBlocksToCollect) + { + lastBlock = lastBlock->getPrevBlock(); + TLLM_CHECK_WITH_INFO(lastBlock, "Previous block not found while traversing reuse tree"); + } + } + // Reverse to chronological order: oldest to newest + std::reverse(blockIds.begin(), blockIds.end()); + return BlockRange(cacheManager, blockIds, 0); + } + BlockRange(runtime::ITensor::SharedPtr pool, std::vector const& blockIds) // Only used in tests : mManager{nullptr} , mPool{std::move(pool)} @@ -80,19 +106,6 @@ class BlockRange mBlockIds = std::move(blockIds); } - [[nodiscard]] std::vector getBlockHashes() const - { - TLLM_CHECK(mManager); - std::vector blockHashes; - blockHashes.reserve(mBlockIds.size()); - auto& blockManager = mManager->getBlockManager(); - for (auto id : mBlockIds) - { - blockHashes.emplace_back(blockManager.getBlockById(id, mWindowSize)->getHash()); - } - return blockHashes; - } - void updatePoolIdx(SizeType32 poolIdx) { TLLM_CHECK(mManager); diff --git a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h index 85c9a3ac942..275bc75721a 100644 --- a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h +++ b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h @@ -1843,16 +1843,6 @@ class GenericLlmRequest } } - void setRequestedBlockHashes(std::vector hashes) - { - mRequestedBlockHashes = std::move(hashes); - } - - [[nodiscard]] std::vector const& getRequestedBlockHashes() const - { - return mRequestedBlockHashes; - } - void setIsDummyRequest(bool isDummyRequest) { mIsDummyRequest = isDummyRequest; @@ -2044,9 +2034,6 @@ class GenericLlmRequest // Tensors containing the additional generation output. TensorMap mAdditionalGenerationOutputTensors; - // Context request only. The hashes of the blocks that are requested by the corresponding generation request. 
- std::vector mRequestedBlockHashes; - bool mIsDummyRequest{false}; bool mUseDraftModel{false}; diff --git a/cpp/include/tensorrt_llm/executor/dataTransceiverState.h b/cpp/include/tensorrt_llm/executor/dataTransceiverState.h index d49447a09a0..6e2c9d40c9d 100644 --- a/cpp/include/tensorrt_llm/executor/dataTransceiverState.h +++ b/cpp/include/tensorrt_llm/executor/dataTransceiverState.h @@ -50,7 +50,7 @@ class CacheState final CacheState(ModelConfig modelConfig, runtime::WorldConfig const& worldConfig, std::vector const& attentionLayerNumPerPP, nvinfer1::DataType dataType, - AttentionType attentionType = AttentionType::kDEFAULT, int kvFactor = 2) + AttentionType attentionType = AttentionType::kDEFAULT, int kvFactor = 2, bool enableBlockReuse = false) : mModelConfig(std::move(modelConfig)) , mParallelConfig{worldConfig.getTensorParallelism(), worldConfig.getPipelineParallelism(), worldConfig.getContextParallelism(), worldConfig.enableAttentionDP(), worldConfig.getTensorParallelRank(), @@ -58,32 +58,35 @@ class CacheState final , mDataType{dataType} , mAttentionConfig(attentionType, kvFactor) { + mEnableBlockReuse = enableBlockReuse; } CacheState(std::vector nbKvHeadPerLayer, SizeType32 sizePerHead, SizeType32 tokensPerBlock, SizeType32 tensorParallelism, SizeType32 pipelineParallelism, SizeType32 contextParallelism, std::vector const& attentionLayerNumPerPP, nvinfer1::DataType dataType, AttentionType attentionType = AttentionType::kDEFAULT, int kvFactor = 2, bool enableAttentionDP = false, - int DPrank = 0, int DPsize = 0) + int DPrank = 0, int DPsize = 0, bool enableBlockReuse = false) : mModelConfig{std::move(nbKvHeadPerLayer), sizePerHead, tokensPerBlock} , mParallelConfig{tensorParallelism, pipelineParallelism, contextParallelism, enableAttentionDP, DPrank, DPsize, attentionLayerNumPerPP} , mDataType{dataType} , mAttentionConfig(attentionType, kvFactor) { + mEnableBlockReuse = enableBlockReuse; } CacheState(SizeType32 nbAttentionLayers, SizeType32 nbKvHeads, SizeType32 sizePerHead, SizeType32 tokensPerBlock, SizeType32 tensorParallelism, SizeType32 pipelineParallelism, SizeType32 contextParallelism, std::vector const& attentionLayerNumPerPP, nvinfer1::DataType dataType, AttentionType attentionType = AttentionType::kDEFAULT, int kvFactor = 2, bool enableAttentionDP = false, - int DPrank = 0, int DPsize = 0) + int DPrank = 0, int DPsize = 0, bool enableBlockReuse = false) : mModelConfig{std::vector(nbAttentionLayers, nbKvHeads), sizePerHead, tokensPerBlock} , mParallelConfig{tensorParallelism, pipelineParallelism, contextParallelism, enableAttentionDP, DPrank, DPsize, attentionLayerNumPerPP} , mDataType{dataType} , mAttentionConfig(attentionType, kvFactor) { + mEnableBlockReuse = enableBlockReuse; } [[nodiscard]] bool operator==(kv_cache::CacheState const& other) const noexcept @@ -166,6 +169,11 @@ class CacheState final return mDataType; } + [[nodiscard]] bool getEnableBlockReuse() const + { + return mEnableBlockReuse; + } + [[nodiscard]] std::string toString() const { std::stringstream sstring; @@ -185,6 +193,7 @@ class CacheState final sstring << "kvFactor:" << mAttentionConfig.mKvFactor << "\n"; sstring << "dpRank:" << mParallelConfig.mDPrank << "\n"; sstring << "dpSize:" << mParallelConfig.mDPsize << "\n"; + sstring << "enableBlockReuse:" << mEnableBlockReuse << "\n"; return sstring.str(); } @@ -194,6 +203,7 @@ class CacheState final ParallelConfig mParallelConfig; nvinfer1::DataType mDataType; AttentionConfig mAttentionConfig; + bool mEnableBlockReuse{false}; }; struct MpiState diff 
--git a/cpp/include/tensorrt_llm/executor/serialization.h b/cpp/include/tensorrt_llm/executor/serialization.h index c370a652350..1d30da2027c 100644 --- a/cpp/include/tensorrt_llm/executor/serialization.h +++ b/cpp/include/tensorrt_llm/executor/serialization.h @@ -16,6 +16,7 @@ #pragma once +#include "tensorrt_llm/batch_manager/kvCacheManager.h" #include "tensorrt_llm/executor/dataTransceiverState.h" #include "tensorrt_llm/executor/executor.h" #include "tensorrt_llm/executor/tensor.h" @@ -36,6 +37,10 @@ struct SocketState; class Serialization { public: + // BlockKey (KV cache) + static size_t serializedSize(tensorrt_llm::batch_manager::kv_cache_manager::BlockKey const& key); + static void serialize(tensorrt_llm::batch_manager::kv_cache_manager::BlockKey const& key, std::ostream& os); + static tensorrt_llm::batch_manager::kv_cache_manager::BlockKey deserializeBlockKey(std::istream& is); // TimePoint [[nodiscard]] static RequestPerfMetrics::TimePoint deserializeTimePoint(std::istream& is); static void serialize(RequestPerfMetrics::TimePoint const& tp, std::ostream& os); diff --git a/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp b/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp index 168ea89693f..cecfa62df41 100644 --- a/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp +++ b/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp @@ -40,37 +40,34 @@ namespace tensorrt_llm::batch_manager::kv_cache_manager { -BlockRange getBlockRangeForSending(BaseKVCacheManager* cacheManager, LlmRequest const& llmRequest) +BlockRange getBlockRangeForSending( + BaseKVCacheManager* cacheManager, LlmRequest const& llmRequest, BlockKey const& lastBlockKey, int32_t indexFromEnd) { - size_t requestBlockNum = llmRequest.getRequestedBlockHashes().size(); constexpr SizeType32 beam{0}; - auto blockRange = BlockRange::fromAllBlockIds(*cacheManager, llmRequest.mRequestId, beam); auto poolNum = cacheManager->getBlockManager().getNumPools(); - if (poolNum > 1 || common::getEnvDisableSelectiveCacheTransfer()) + if (poolNum > 1 || !cacheManager->isEnableBlockReuse() || lastBlockKey.uniqueTokens.size() == 0) { - // disable selective cache transfer for poolNum > 1 + auto blockRange = BlockRange::fromAllBlockIds(*cacheManager, llmRequest.mRequestId, beam); return blockRange; } - if (requestBlockNum < blockRange.size() && requestBlockNum > 0) - { - // handle block reuse, the prefix blocks are reused - // TODO(zhengd): pass the hashes directly instead of from llmRequest; use hash instead of block num - auto const& ids = blockRange.getBlockIds(); - blockRange.setBlockIds({ids.end() - requestBlockNum, ids.end()}); - } - return blockRange; + TLLM_CHECK_WITH_INFO(lastBlockKey.uniqueTokens.size() > 0, "lastBlockKey must be non-empty when reuse is enabled"); + return BlockRange::fromReuseTree(*cacheManager, lastBlockKey, indexFromEnd); } -BlockRange getBlockRangeForReceiving(BaseKVCacheManager* cacheManager, LlmRequest const& llmRequest) +BlockRange getBlockRangeForReceiving( + BaseKVCacheManager* cacheManager, LlmRequest const& llmRequest, bool srcEnableBlockReuse) { auto poolNum = cacheManager->getBlockManager().getNumPools(); - if (poolNum > 1 || common::getEnvDisableSelectiveCacheTransfer()) + if (poolNum == 1 && cacheManager->isEnableBlockReuse() && srcEnableBlockReuse) + { + return BlockRange::fromNewlyAllocatedBlockIds(*cacheManager, llmRequest.mRequestId); + } + else { constexpr SizeType32 beam{0}; return BlockRange::fromAllBlockIds(*cacheManager, llmRequest.mRequestId, beam); } - return 
BlockRange::fromNewlyAllocatedBlockIds(*cacheManager, llmRequest.mRequestId); } bool CacheFormatter::needSendCache( @@ -168,6 +165,7 @@ void CacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& sessio auto const& selfConfig = session.getSelfState().getCacheState().value(); auto const& destConfig = session.getOtherState().getCacheState().value(); auto const selfIdx = session.getSelfState().getCommState().value().getSelfIdx(); + auto indexFromEnd = session.getIndexFromEnd(); auto& bufferManager = session.getBufferManager(); // Some TP rank don't need to send cache since duplicate header is not needed. if (!needSendCache(selfConfig, destConfig, selfIdx)) @@ -175,8 +173,8 @@ void CacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& sessio return; } auto& blockManager = mCacheManager->getBlockManager(); - auto blockRange = getBlockRangeForSending(mCacheManager, llmRequest); - + auto const& lastBlockKey = session.getLastBlockKey(); + auto blockRange = getBlockRangeForSending(mCacheManager, llmRequest, lastBlockKey, indexFromEnd); auto const numPools = blockManager.getNumPools(); // TODO(oargov): are we sure the other side has the same number of pools? this might not hold for pp_size>1... @@ -225,7 +223,10 @@ void CacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& sessio std::map> inputKvCacheBlocksPerWindow; for (auto poolIdx = 0; poolIdx < numPools; poolIdx++) { - blockRange.updatePoolIdx(poolIdx); + if (numPools > 1) + { + blockRange.updatePoolIdx(poolIdx); + } SizeType32 window = mCacheManager->getBlockManager().getPoolWindowSize(poolIdx); TLLM_CHECK_WITH_INFO(inputKvCacheBlocksPerWindow.find(window) == inputKvCacheBlocksPerWindow.end(), "window size already exists, which is not supported"); @@ -482,7 +483,7 @@ void CacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& sess auto const& destConfig = session.getOtherState().getCacheState().value(); auto const selfIdx = session.getSelfState().getCommState().value().getSelfIdx(); auto& bufferManager = session.getBufferManager(); - auto blockRange = getBlockRangeForReceiving(mCacheManager, llmRequest); + auto blockRange = getBlockRangeForReceiving(mCacheManager, llmRequest, destConfig.getEnableBlockReuse()); auto arrivalTime = llmRequest.getPerfMetrics().timingMetrics.arrivalTime; bool recordDelay = arrivalTime != std::chrono::steady_clock::time_point(); @@ -498,7 +499,10 @@ void CacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& sess size_t cacheBlockSizeSum = 0; for (auto poolIdx = 0; poolIdx < numPools; poolIdx++) { - blockRange.updatePoolIdx(poolIdx); + if (numPools > 1) + { + blockRange.updatePoolIdx(poolIdx); + } SizeType32 window = mCacheManager->getBlockManager().getPoolWindowSize(poolIdx); TLLM_CHECK_WITH_INFO(outputBuffersPerWindow.find(window) == outputBuffersPerWindow.end(), "window size already exists, which is not supported"); diff --git a/cpp/tensorrt_llm/batch_manager/cacheFormatter.h b/cpp/tensorrt_llm/batch_manager/cacheFormatter.h index 0071627af67..beca72696ea 100644 --- a/cpp/tensorrt_llm/batch_manager/cacheFormatter.h +++ b/cpp/tensorrt_llm/batch_manager/cacheFormatter.h @@ -42,6 +42,8 @@ class TransferSession; namespace tensorrt_llm::batch_manager::kv_cache_manager { +BlockRange getBlockRangeForSending(BaseKVCacheManager* cacheManager, LlmRequest const& llmRequest, + BlockKey const& lastBlockKey, SizeType32 indexFromEnd); using DataContext = tensorrt_llm::executor::kv_cache::DataContext; using Connection = 
tensorrt_llm::executor::kv_cache::Connection;
@@ -51,8 +53,83 @@ using CacheTransBufferManager = kv_cache_manager::CacheTransBufferManager;
 using BlockRange = kv_cache_manager::BlockRange;
 
 BlockRange getBlockRangeForSending(BaseKVCacheManager* cacheManager, LlmRequest const& llmRequest);
+BlockRange getBlockRangeForReceiving(
+    BaseKVCacheManager* cacheManager, LlmRequest const& llmRequest, bool srcEnableBlockReuse);
 
-BlockRange getBlockRangeForReceiving(BaseKVCacheManager* cacheManager, LlmRequest const& llmRequest);
+class KvCacheMeasureHelper
+{
+public:
+    struct Measure
+    {
+        double delay;     // from last token (ctx) or arrival time (gen), in ms
+        double duration;  // in ms
+        double bandwidth; // in Gbps
+    };
+
+    KvCacheMeasureHelper(std::string outputPath)
+        : mOutputPath(std::move(outputPath))
+    {
+    }
+
+    void markAsSender(bool isSender)
+    {
+        mIsSender = isSender;
+    }
+
+    void appendKVCacheTransfer(LlmRequest::RequestIdType requestId, double delay, double duration, size_t size)
+    {
+        auto bandwidth = size * 8 / (duration / 1000) / 1e9;
+        if (mOutputPath.empty())
+        {
+            return;
+        }
+
+        std::lock_guard<std::mutex> lock(mMutex);
+        mRequestKVCacheTransferMeasure[requestId].emplace_back(Measure{delay, duration, bandwidth});
+    }
+
+    ~KvCacheMeasureHelper()
+    {
+        if (!mRequestKVCacheTransferMeasure.empty() && !mOutputPath.empty())
+        {
+            TLLM_CHECK(mIsSender.has_value());
+            auto rank = mpi::MpiComm::world().getRank();
+            std::string outFilePath
+                = mOutputPath + "rank_" + std::to_string(rank) + "_" + (mIsSender.value() ? "send" : "recv") + ".csv";
+            std::ofstream outFile(outFilePath);
+
+            TLLM_CHECK_WITH_INFO(outFile.is_open(), "Cannot write to file " + outFilePath);
+
+            size_t numTransferMeasure = mRequestKVCacheTransferMeasure.begin()->second.size();
+
+            outFile << "RequestID";
+            for (size_t i = 0; i < numTransferMeasure; i++)
+            {
+                outFile << ",Delay(ms),Duration(ms),Bandwidth(Gbps)";
+            }
+            outFile << '\n';
+
+            for (auto const& [requestID, measures] : mRequestKVCacheTransferMeasure)
+            {
+                outFile << requestID;
+
+                for (auto const& measure : measures)
+                {
+                    outFile << "," << measure.delay << "," << measure.duration << "," << measure.bandwidth;
+                }
+                outFile << '\n';
+            }
+
+            outFile.close();
+        }
+    }
+
+private:
+    std::map<LlmRequest::RequestIdType, std::vector<Measure>> mRequestKVCacheTransferMeasure;
+    std::string mOutputPath;
+    std::mutex mMutex;
+    std::optional<bool> mIsSender;
+};
 
 // Used to support the cache transmission with different layouts and different protocols.
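+// Send-side sketch with block reuse enabled (illustrative only; mirrors what
+// CacheFormatter::format() does with the session fields added in this patch):
+//
+//   auto range = getBlockRangeForSending(
+//       cacheManager, llmRequest, session.getLastBlockKey(), session.getIndexFromEnd());
+//   for (auto it = range.begin(); it != range.end(); ++it)
+//   {
+//       // *it is one block's cache tensor, ordered oldest to newest
+//   }
+//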
class BaseCacheFormatter diff --git a/cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp b/cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp index 2d0031c1bea..d832a80b358 100644 --- a/cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp +++ b/cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp @@ -117,11 +117,7 @@ CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheMa : mMpiGroupComm(std::addressof(tensorrt_llm::mpi::MpiComm::session())) , mCacheTransceiverConfig{cacheTransceiverConfig} { - if (worldConfig.isPipelineParallel()) - { - mMpiGroupPipeParaComm = std::make_shared( - mMpiGroupComm->split(worldConfig.getTensorParallelRank(), worldConfig.getPipelineParallelRank())); - } + using tensorrt_llm::batch_manager::kv_cache_manager::CacheFormatter; if (worldConfig.isTensorParallel()) { mMpiGroupTensorParaComm = std::make_shared( @@ -132,8 +128,8 @@ CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheMa { kvFactor = 1; } - mCacheState = std::make_unique( - cacheStateModelCfg, worldConfig, attentionLayerNumPerPP, dataType, attentionType, kvFactor); + mCacheState = std::make_unique(cacheStateModelCfg, worldConfig, + attentionLayerNumPerPP, dataType, attentionType, kvFactor, cacheManager->isEnableBlockReuse()); if (mCacheState->getParallelConfig().mEnableAttentionDP) { @@ -311,7 +307,6 @@ std::vector gatherRequestIds( int localSize = static_cast(requestIds.size()); std::vector sizes(mpiComm.getSize()); mpiComm.allgather(&localSize, sizes.data(), 1, mpi::MpiType::kINT32); - // std::vector all_data(total_size); std::vector displs(mpiComm.getSize()); int totalSize = 0; for (int i = 0; i < mpiComm.getSize(); i++) diff --git a/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp b/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp index 527291b220b..58007d1ac04 100644 --- a/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp +++ b/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp @@ -141,11 +141,6 @@ void TransferSession::exportMeasure(std::ofstream& outFile, bool isContext) cons outFile << '\n' << std::flush; } -std::vector const& RequestInfo::getBlockHashes() const noexcept -{ - return mBlockHashes; -} - using runtime::SizeType32; using AgentConnectionManager = tensorrt_llm::executor::kv_cache::AgentConnectionManager; using DataContext = tensorrt_llm::executor::kv_cache::DataContext; @@ -189,17 +184,19 @@ RequestInfo::RequestInfo(LlmRequest::RequestIdType requestId, executor::DataTran { } -RequestInfo::RequestInfo( - LlmRequest::RequestIdType requestId, std::vector blockHashes, executor::DataTransceiverState transState) +RequestInfo::RequestInfo(LlmRequest::RequestIdType requestId, executor::DataTransceiverState transState, + int32_t indexFromEnd, BlockKey const& lastBlockKey) : mRequestId{requestId} - , mBlockHashes{std::move(blockHashes)} + , mIndexFromEnd{indexFromEnd} + , mLastBlockKey{lastBlockKey} , mTransState{std::move(transState)} { } bool RequestInfo::operator==(RequestInfo const& rhs) const { - return mRequestId == rhs.mRequestId && mBlockHashes == rhs.mBlockHashes && mTransState == rhs.mTransState; + return mRequestId == rhs.mRequestId && mIndexFromEnd == rhs.mIndexFromEnd && mLastBlockKey == rhs.mLastBlockKey + && mTransState == rhs.mTransState; } LlmRequest::RequestIdType RequestInfo::getRequestId() const noexcept @@ -216,7 +213,8 @@ void RequestInfo::serialize(RequestInfo const& requestInfo, std::ostream& os) { namespace su = executor::serialize_utils; su::serialize(requestInfo.mRequestId, os); - 
su::serialize(requestInfo.mBlockHashes, os); + su::serialize(requestInfo.mIndexFromEnd, os); + su::serialize(requestInfo.mLastBlockKey, os); su::serialize(requestInfo.mTransState, os); } @@ -224,9 +222,10 @@ RequestInfo RequestInfo::deserialize(std::istream& is) { namespace su = executor::serialize_utils; auto requestId = su::deserialize(is); - auto blockHashes = su::deserialize(is); + auto indexFromEnd = su::deserialize(is); + auto lastBlockKey = su::deserialize(is); auto transState = su::deserialize(is); - return RequestInfo{requestId, std::move(blockHashes), std::move(transState)}; + return RequestInfo{requestId, std::move(transState), indexFromEnd, lastBlockKey}; } std::size_t RequestInfo::serializedSize(RequestInfo const& requestInfo) @@ -234,7 +233,8 @@ std::size_t RequestInfo::serializedSize(RequestInfo const& requestInfo) namespace su = executor::serialize_utils; std::size_t totalSize = 0; totalSize += su::serializedSize(requestInfo.mRequestId); - totalSize += su::serializedSize(requestInfo.mBlockHashes); + totalSize += su::serializedSize(requestInfo.mIndexFromEnd); + totalSize += su::serializedSize(requestInfo.mLastBlockKey); totalSize += su::serializedSize(requestInfo.mTransState); return totalSize; } @@ -244,6 +244,12 @@ class CacheSender::Impl public: using RequestIdType = LlmRequest::RequestIdType; + struct Response + { + LlmRequest* mRequest; + std::promise mPromise; + }; + Impl(executor::kv_cache::ConnectionManager* manager, executor::kv_cache::CacheState selfCacheState, SizeType32 selfIndex, std::unique_ptr formatter) : mManager{manager} @@ -322,25 +328,18 @@ class CacheSender::Impl auto* agentConnectionManager = dynamic_cast(mManager); bool isAgent = agentConnectionManager != nullptr; - auto agentRecvFun = [&](RequestInfo& requestInfo) - { - auto const* connection = agentConnectionManager->recvConnectionAndRequestInfo(requestInfo); - return connection; - }; TransceiverTag::Id id; RequestInfo info; - auto const* connection = isAgent ? agentRecvFun(info) + auto const* connection = isAgent ? 
agentConnectionManager->recvConnectionAndRequestInfo(info) : mManager->recvConnect(DataContext{TransceiverTag::kID_TAG}, &id, sizeof(id)); if (!isAgent) { TLLM_CHECK(id == TransceiverTag::Id::REQUEST_SEND); std::uint64_t infoSize{0}; - connection->recv( - executor::kv_cache::DataContext{TransceiverTag::kINFO_SIZE_TAG}, &infoSize, sizeof(infoSize)); + connection->recv(DataContext{TransceiverTag::kINFO_SIZE_TAG}, &infoSize, sizeof(infoSize)); std::string serializedInfo; serializedInfo.resize(infoSize); - connection->recv( - executor::kv_cache::DataContext{TransceiverTag::kINFO_TAG}, serializedInfo.data(), infoSize); + connection->recv(DataContext{TransceiverTag::kINFO_TAG}, serializedInfo.data(), infoSize); std::istringstream iss(serializedInfo); info = RequestInfo::deserialize(iss); } @@ -363,8 +362,8 @@ class CacheSender::Impl if (it == mRequestToSession.end()) { auto session = TransferSession(std::vector(peerRelativeRanks.size(), nullptr), - DataContext{tagFromRequestId(requestId)}, mSelfState, info.getTransState(), mBufferManager, nullptr, - !common::getEnvKVCacheTransferOutputPath().empty()); + DataContext{tagFromRequestId(requestId)}, mSelfState, info.getTransState(), mBufferManager, + info.getIndexFromEnd(), info.getLastBlockKey(), nullptr, !common::getEnvKVCacheTransferOutputPath().empty()); it = mRequestToSession.emplace(requestId, std::move(session)).first; } it->second.setConnection(peerIdx, connection); @@ -448,7 +447,7 @@ class CacheSender::Impl } catch (std::exception const& e) { - TLLM_LOG_ERROR("Exception in sendAndRemoveResponse: %s ", e.what()); + TLLM_LOG_ERROR("Exception in sendAndRemoveResponse: %s request id: %ld", e.what(), id); resp.mPromise.set_exception(std::current_exception()); } } @@ -460,7 +459,7 @@ class CacheSender::Impl mAsyncSendResource.mCVforQueue.notify_one(); } - void sendResponse(std::vector const& blockHashes, std::map::iterator it) + void sendResponse(std::map::iterator it) { auto reqId = mCurrentRequest.value(); auto count = --mRemainSendCount[reqId]; @@ -469,10 +468,6 @@ class CacheSender::Impl { mRemainSendCount.erase(reqId); - // TODO(zhengd): pass the hashes directly instead of update llmRequest - auto llmRequest = it->second.mRequest; - llmRequest->setRequestedBlockHashes(std::move(blockHashes)); - asyncSendAndRemoveResponse(it->first, std::move(it->second)); removeResponse(it); } @@ -496,12 +491,10 @@ class CacheSender::Impl { break; } - std::vector blockHashes; if (!mReadyResponses.empty()) { auto const& requestInfo = recvRequestInfo(); auto reqId = requestInfo.getRequestId(); - blockHashes = requestInfo.getBlockHashes(); mCurrentRequest = reqId; if (mRemainSendCount.find(reqId) == mRemainSendCount.end()) @@ -512,7 +505,7 @@ class CacheSender::Impl auto it = getCurrentResponse(); if (it != mReadyResponses.end()) { - sendResponse(blockHashes, it); + sendResponse(it); } else { @@ -527,7 +520,7 @@ class CacheSender::Impl } it = getCurrentResponse(); } - sendResponse(blockHashes, it); + sendResponse(it); } } } @@ -587,7 +580,7 @@ class CacheSender::Impl std::map mReadyResponses; std::mutex mSenderMutex, mCondMutex; std::atomic mAnyReady{false}, mTerminate{false}; - std::condition_variable mSenderCv; + std::condition_variable mSenderCv, mResponderCv; std::future mResponseFuture; std::unordered_map mRemainSendCount; AsyncSendResource mAsyncSendResource; @@ -686,14 +679,31 @@ class CacheReceiver::Impl RequestInfo requestInfo(requestId, mSelfState); - auto disableSelectiveCacheTransfer = common::getEnvDisableSelectiveCacheTransfer() - || 
(mFormatter->getCacheManager()->getBlockManager().getNumPools() > 1); - if (!disableSelectiveCacheTransfer) + if (mFormatter->getCacheManager()->getBlockManager().getNumPools() == 1) { auto* cacheManager = mFormatter->getCacheManager(); - auto blockRange - = kv_cache_manager::BlockRange::fromNewlyAllocatedBlockIds(*cacheManager, llmRequest.mRequestId); - requestInfo = RequestInfo(requestId, blockRange.getBlockHashes(), mSelfState); + auto beam = 0; + auto requestedBlockRange + = getBlockRangeForReceiving(cacheManager, llmRequest, destCacheState.getEnableBlockReuse()); + + auto const& uniqueTokens = llmRequest.getUniqueTokens(beam); + auto lastBlockKey + = BlockKey(llmRequest.getInputTokensExtraIds().has_value(), llmRequest.getLoraTaskId(), uniqueTokens); + if (llmRequest.getInputTokensExtraIds().has_value()) + { + auto tokensPerBlock = cacheManager->getBlockManager().getTokensPerBlock(); + SizeType32 startTokenIdx + = static_cast(uniqueTokens.size() / tokensPerBlock) * tokensPerBlock; + SizeType32 endTokenIdx = static_cast(uniqueTokens.size()); + auto extraKeys = kv_cache_manager::generateBlockHashExtraKeys(llmRequest, startTokenIdx, endTokenIdx); + lastBlockKey.extraKeys = std::move(extraKeys); + } + // Compute indexFromEnd from the number of requested blocks + size_t requestedBlockSize = requestedBlockRange.getBlockIds().size(); + TLLM_CHECK_WITH_INFO(requestedBlockSize > 0, "requestedBlockSize must be > 0"); + int32_t indexFromEnd = static_cast(requestedBlockSize - 1); + + requestInfo = RequestInfo(requestId, mSelfState, indexFromEnd, lastBlockKey); } auto* agentConnectionManager = dynamic_cast(mManager); @@ -738,7 +748,8 @@ class CacheReceiver::Impl } auto const& resource = getReceiveCacheResource(llmRequest); return TransferSession(std::move(counterPartConnections), DataContext{tagFromRequestId(requestId)}, mSelfState, - contextState, resource->mBufferManager, &llmRequest, !common::getEnvKVCacheTransferOutputPath().empty()); + contextState, resource->mBufferManager, requestInfo.getIndexFromEnd(), requestInfo.getLastBlockKey(), + &llmRequest, !common::getEnvKVCacheTransferOutputPath().empty()); } std::unique_ptr const& getReceiveCacheResource(LlmRequest const& llmRequest) @@ -766,9 +777,9 @@ class CacheReceiver::Impl auto const& serializedInfo = oss.str(); std::size_t const infoSize = serializedInfo.size(); TransceiverTag::Id id{TransceiverTag::Id::REQUEST_SEND}; - connection->send(executor::kv_cache::DataContext{TransceiverTag::kID_TAG}, &id, sizeof(id)); - connection->send(executor::kv_cache::DataContext{TransceiverTag::kINFO_SIZE_TAG}, &infoSize, sizeof(infoSize)); - connection->send(executor::kv_cache::DataContext{TransceiverTag::kINFO_TAG}, serializedInfo.data(), infoSize); + connection->send(DataContext{TransceiverTag::kID_TAG}, &id, sizeof(id)); + connection->send(DataContext{TransceiverTag::kINFO_SIZE_TAG}, &infoSize, sizeof(infoSize)); + connection->send(DataContext{TransceiverTag::kINFO_TAG}, serializedInfo.data(), infoSize); } ~Impl() diff --git a/cpp/tensorrt_llm/batch_manager/dataTransceiver.h b/cpp/tensorrt_llm/batch_manager/dataTransceiver.h index 2de48dc0bc3..a41c7eba480 100644 --- a/cpp/tensorrt_llm/batch_manager/dataTransceiver.h +++ b/cpp/tensorrt_llm/batch_manager/dataTransceiver.h @@ -42,6 +42,7 @@ class BaseCacheFormatter; } using BaseCacheFormatter = kv_cache_manager::BaseCacheFormatter; +using BlockKey = kv_cache_manager::BlockKey; // TODO: unify the following class into a namespace like tensorrt_llm::transmission using DataContext = 
tensorrt_llm::executor::kv_cache::DataContext; @@ -61,7 +62,8 @@ class TransferSession TransferSession(std::vector connections, DataContext dataContext, executor::DataTransceiverState const& selfState, executor::DataTransceiverState otherState, - runtime::BufferManager const& bufferManager, LlmRequest const* llmRequest = nullptr, bool recordMeasure = false) + runtime::BufferManager const& bufferManager, int32_t indexFromEnd, BlockKey const& lastBlockKey, + LlmRequest const* llmRequest = nullptr, bool recordMeasure = false) : mConnections(std::move(connections)) , mDataContext(dataContext) , mSelfState(&selfState) @@ -69,6 +71,8 @@ class TransferSession , mBufferManager(&bufferManager) , mRequest(llmRequest) , mRecordMeasure(recordMeasure) + , mIndexFromEnd(indexFromEnd) + , mLastBlockKey(lastBlockKey) { TLLM_CHECK(!mConnections.empty()); } @@ -100,6 +104,16 @@ class TransferSession // TODO: 1. use global id instead of context request id; 2. export to llm metrics instead of file void exportMeasure(std::ofstream& outFile, bool isContext) const; + [[nodiscard]] int32_t getIndexFromEnd() const noexcept + { + return mIndexFromEnd; + } + + [[nodiscard]] BlockKey const& getLastBlockKey() const noexcept + { + return mLastBlockKey; + } + private: std::vector mConnections; DataContext mDataContext; @@ -109,7 +123,11 @@ class TransferSession LlmRequest const* mRequest; std::vector mMeasures; bool mRecordMeasure{false}; + int32_t mIndexFromEnd{0}; + BlockKey mLastBlockKey{}; }; +using UniqueToken = tensorrt_llm::runtime::UniqueToken; +using BlockKey = tensorrt_llm::batch_manager::kv_cache_manager::BlockKey; struct TransceiverTag { @@ -134,8 +152,8 @@ class RequestInfo /// @param transState The state of the data transceiver. RequestInfo(LlmRequest::RequestIdType requestId, executor::DataTransceiverState transState); - RequestInfo(LlmRequest::RequestIdType requestId, std::vector blockHashes, - executor::DataTransceiverState transState); + RequestInfo(LlmRequest::RequestIdType requestId, executor::DataTransceiverState transState, int32_t indexFromEnd, + BlockKey const& lastBlockKey); RequestInfo() = default; /// @brief Equality comparison operator. @@ -146,12 +164,20 @@ class RequestInfo /// @return The request ID. [[nodiscard]] LlmRequest::RequestIdType getRequestId() const noexcept; - [[nodiscard]] std::vector const& getBlockHashes() const noexcept; + [[nodiscard]] int32_t getIndexFromEnd() const noexcept + { + return mIndexFromEnd; + } /// @brief Return the state of the data transceiver. /// @return The state of the data transceiver. [[nodiscard]] executor::DataTransceiverState const& getTransState() const noexcept; + [[nodiscard]] BlockKey const& getLastBlockKey() const noexcept + { + return mLastBlockKey; + } + /// @brief Serialization. /// @param requestInfo Request information to be serialized. /// @param os The output stream to which the serialization result points. @@ -169,8 +195,11 @@ class RequestInfo private: // The ID used in the context phase of the current request. LlmRequest::RequestIdType mRequestId; + // Index from end indicating how many trailing blocks to transfer (index+1) + int32_t mIndexFromEnd{0}; - std::vector mBlockHashes; + // Last block key, used to derive other block keys on receiver + BlockKey mLastBlockKey{}; // The state of the data transceiver. 
executor::DataTransceiverState mTransState; diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp index b2997d70c8f..19048f7767d 100644 --- a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp @@ -51,37 +51,50 @@ using BlocksPerWindow = std::map> namespace { -//! \brief Split vector into list of blocks of given size. -//! \param vec vector to split -//! \param usableSize part of the vector that is processed -//! \param elementsPerBlock desired size of blocks -//! \param allowPartial whether to append a block smaller than `elementsPerBlock` at the end -//! \return list of blocks -template -std::list> chopVectorIntoBlocks( - std::vector const& vec, SizeType32 usableSize, SizeType32 elementsPerBlock, bool allowPartial) +inline uint8_t getNthByte(SizeType32 hashPart, uint8_t byteIdx) noexcept { - TLLM_CHECK_WITH_INFO( - usableSize <= static_cast(vec.size()), "usableSize=%d > %ld=vec.size()", usableSize, vec.size()); - std::list> blockedVectors; - auto const vecEnd = vec.begin() + usableSize; - for (auto begin = vec.begin(); begin < vecEnd; begin += elementsPerBlock) - { - auto blockSize = std::min(elementsPerBlock, static_cast(std::distance(begin, vecEnd))); - auto end = begin + blockSize; - if (blockSize == elementsPerBlock || allowPartial) - { - blockedVectors.emplace_back(begin, end); - } - } - return blockedVectors; + return static_cast((hashPart >> (24 - byteIdx * 8)) & 0xFF); } -inline uint8_t getNthByte(SizeType32 hashPart, uint8_t byteIdx) noexcept +//! \brief Get all blocks in a sequence by traversing backwards from the last block. +//! \param lastBlock is a BlockPtr to the last block in the sequence to start traversal from +//! \return Vector of BlockPtr-s in sequence order +std::vector getAllSequenceBlocks(BlockPtr lastBlock) { - return static_cast((hashPart >> (24 - byteIdx * 8)) & 0xFF); + // First count the number of blocks to pre-allocate the vector + auto currentBlock = lastBlock; + size_t blockCount = 0; + while (currentBlock != nullptr && currentBlock->getBlockId() != KVCacheBlock::kCachedBlocksRootId) + { + blockCount++; + currentBlock = currentBlock->getPrevBlockInSeq(); + } + + if (blockCount == 0) + { + return {}; + } + // Create and pre-allocate the vector with the correct size + std::vector sequenceBlocks(blockCount); + + // Now traverse backwards and fill from the end + currentBlock = lastBlock; + size_t currentIndex = blockCount - 1; + while (currentBlock != nullptr && currentBlock->getBlockId() != KVCacheBlock::kCachedBlocksRootId) + { + sequenceBlocks[currentIndex--] = currentBlock; + currentBlock = currentBlock->getPrevBlockInSeq(); + } + + return sequenceBlocks; } + + +} // namespace + +namespace tensorrt_llm::batch_manager::kv_cache_manager +{ std::vector generateBlockHashExtraKeys( tensorrt_llm::batch_manager::LlmRequest const& llmRequest, SizeType32 startTokenIdx, SizeType32 endTokenIdx) { @@ -157,43 +170,12 @@ std::vector buildBlockKeys( return blockKeys; } -//! \brief Get all blocks in a sequence by traversing backwards from the last block. -//! \param lastBlock is a BlockPtr to the last block in the sequence to start traversal from -//! 
\return Vector of BlockPtr-s in sequence order -std::vector getAllSequenceBlocks(BlockPtr lastBlock) +bool BlockKey::operator==(BlockKey const& other) const noexcept { - // First count the number of blocks to pre-allocate the vector - auto currentBlock = lastBlock; - size_t blockCount = 0; - while (currentBlock != nullptr && currentBlock->getBlockId() != KVCacheBlock::kCachedBlocksRootId) - { - blockCount++; - currentBlock = currentBlock->getPrevBlockInSeq(); - } - - if (blockCount == 0) - { - return {}; - } - // Create and pre-allocate the vector with the correct size - std::vector sequenceBlocks(blockCount); - - // Now traverse backwards and fill from the end - currentBlock = lastBlock; - size_t currentIndex = blockCount - 1; - while (currentBlock != nullptr && currentBlock->getBlockId() != KVCacheBlock::kCachedBlocksRootId) - { - sequenceBlocks[currentIndex--] = currentBlock; - currentBlock = currentBlock->getPrevBlockInSeq(); - } - - return sequenceBlocks; + return (usesExtraIds == other.usesExtraIds && loraTaskId == other.loraTaskId && uniqueTokens == other.uniqueTokens + && extraKeys == other.extraKeys && cacheSaltID == other.cacheSaltID); } -} // namespace - -namespace tensorrt_llm::batch_manager::kv_cache_manager -{ size_t BlockKeyHasher::hash(BlockKey const& blockKey, std::size_t parentHash) noexcept { // Hashing algorithm adapted from StackOverflow: @@ -1115,10 +1097,39 @@ bool WindowBlockManager::blockInRadixTree(BlockPtr const& block) return !block->getUniqueTokens().empty() && block->getPrevBlock() != nullptr; } +std::optional> WindowBlockManager::findBlocksInReuseTreeByBlockKey( + BlockKey const& blockKey) +{ + std::lock_guard lock(mCachedBlocksRootMutex); + auto blockedUniqueTokens + = chopVectorIntoBlocks(blockKey.uniqueTokens, blockKey.uniqueTokens.size(), mTokensPerBlock, true); + std::vector blockKeys; + for (auto const& blockedUniqueTokens : blockedUniqueTokens) + { + blockKeys.push_back(blockKey); + blockKeys.back().uniqueTokens = blockedUniqueTokens; + } + auto searchRoot = mCachedBlocksRoot; + for (auto const& blockKey : blockKeys) + { + auto [partialMatch, numMatched, matchingBlock] = searchRoot != nullptr + ? 
searchRoot->findMatchingBlock(blockKey, true, true) + : std::make_tuple(false, 0, nullptr); + if (matchingBlock == nullptr) + { + return std::nullopt; + } + + searchRoot = std::move(matchingBlock); + } + return searchRoot; +} + SizeType32 WindowBlockManager::loadOrAllocateBlocks(std::vector const& blockKeys, SizeType32 numContextBlocks, GenerationRequest& sequence, std::vector const& perBlockRetentions, executor::KvCacheTransferMode mode, std::string const& directory) { + std::lock_guard lock(mCachedBlocksRootMutex); SizeType32 numMatchedTokens{0}; auto searchRoot = mCachedBlocksRoot; @@ -1156,6 +1167,12 @@ SizeType32 WindowBlockManager::loadOrAllocateBlocks(std::vector const& mTransferManager->onboard(matchingBlock, newBlock, mPools, numMatched, mode, directory); // TODO: (optional) Send out event matchingBlock = newBlock; + if (blockItr != blockKeys.end()) + { + matchingBlock->setBlockKey( + *blockItr, blockItr->uniqueTokens.size() == static_cast(mTokensPerBlock)); + } + matchingBlock->setHash(); TLLM_LOG_DEBUG("%s::loadOrAllocateBlocks - Copied partially filled block %d", mLogPrefix.c_str(), matchingBlockId); } @@ -1447,6 +1464,7 @@ SizeType32 WindowBlockManager::storeBlocks( std::vector const& blockKeys, std::vector const& blockIds) { SizeType32 numBlocksStoredForReuse = 0; + std::lock_guard lock(mCachedBlocksRootMutex); TLLM_LOG_DEBUG( "%s::storeBlocks - %zu blockKeys, %zu blockIds", mLogPrefix.c_str(), blockKeys.size(), blockIds.size()); @@ -1478,6 +1496,8 @@ SizeType32 WindowBlockManager::storeBlocks( // No match TLLM_LOG_DEBUG("%s::storeBlocks - No match, inserting block %d into search structure", mLogPrefix.c_str(), block->getBlockId()); + TLLM_CHECK_WITH_INFO(block->getBlockId() == bid, + "Block id mismatch " + std::to_string(block->getBlockId()) + " != " + std::to_string(bid)); needMatch = false; // no matching needed for following blocks block->setBlockKey(blockKey, static_cast(blockKey.uniqueTokens.size()) == mTokensPerBlock); block->setPrevBlock(searchRoot); @@ -1633,6 +1653,53 @@ void BlockManager::releaseBlocks(GenerationRequest& sequence, OptionalRefsecond; + firstManager.unpinBlocksById(blockId); +} + +void WindowBlockManager::pinBlocks(GenerationRequest& sequence) +{ + auto const requestId = sequence.getRequestId(); + auto& allocatedBlocks = mAllocatedBlocksPerSeq.at(requestId); + for (auto& block : allocatedBlocks) + { + block->incRefCount(); + } +} + +void WindowBlockManager::unpinBlocksById(KVCacheBlock::IdType blockId) +{ + if (blockId < 0 || static_cast(blockId) >= mAllBlocksById.size()) + { + return; + } + auto block = mAllBlocksById[blockId]; + while (block && block->getBlockId() != KVCacheBlock::kCachedBlocksRootId) + { + block->decRefCount(); + if (!block->hasRefs()) + { + mEvictionPolicy->releaseBlock(block); + } + block = std::move(block->getPrevBlock()); + } +} + void BlockManager::storeNewBlock(GenerationRequest& sequence, OptionalRef llmRequest) { for (auto& [_, manager] : mWindowBlockManagers) @@ -2229,6 +2296,11 @@ void KVCacheManager::storeNewBlock(LlmRequest const& llmRequest) void KVCacheManager::removeSequence(RequestIdType requestId, OptionalRef llmRequest) { TLLM_LOG_TRACE("[%s]::%s start", isCrossKv() ? 
"CROSS" : "SELF", __PRETTY_FUNCTION__); + if (mBlockManager.getNumPools() == 1 + && llmRequest->getLlmRequestType() == LlmRequestType::LLMREQUEST_TYPE_CONTEXT_ONLY && mEnableBlockReuse) + { + pinBlocks(requestId); + } auto sequenceNode = [this, requestId] { std::scoped_lock lock(mSequencesMtx); @@ -2254,6 +2326,17 @@ void KVCacheManager::schedulingRemoveSequence(RequestIdType requestId) mBlockManager.schedulingReleaseBlocks(requestId); } +void KVCacheManager::pinBlocks(RequestIdType requestId) +{ + auto& sequence = getSequence(requestId); + mBlockManager.pinBlocks(sequence); +} + +void KVCacheManager::unpinBlocksById(KVCacheBlock::IdType blockId) +{ + mBlockManager.unpinBlocksById(blockId); +} + SizeType32 KVCacheManager::copyBlockOffsets(ITensor& output, SizeType32 outputSlotOffset, RequestIdType requestId) const { auto const& sequence = getSequence(requestId); @@ -2633,6 +2716,23 @@ std::vector KVCacheManager::getNewlyAllocatedBlockIds( return mBlockManager.getNewlyAllocatedBlockIds(getSequence(requestId), windowSize); } +std::optional KVCacheManager::getLastBlockId(LlmRequest::RequestIdType requestId) const +{ + auto const& seq = getSequence(requestId); + // Use the first window size + auto firstWindowSize = mBlockManager.getFirstWindowSize(); + if (firstWindowSize == 0) + { + return std::nullopt; + } + auto const& perBeam = seq.getCacheBlockIds(firstWindowSize); + if (perBeam.empty() || perBeam[0].empty()) + { + return std::nullopt; + } + return perBeam[0].back(); +} + runtime::ITensor::SharedPtr KVCacheManager::getUniquePrimaryPool() const { TLLM_CHECK_WITH_INFO(mBlockManager.getWindowSizesMetadata().size() == 1, diff --git a/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp b/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp index aa45c241aae..6e3093cd452 100644 --- a/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp +++ b/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp @@ -128,6 +128,8 @@ void MLACacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& ses auto const& selfConfig = session.getSelfState().getCacheState().value(); auto const& destConfig = session.getOtherState().getCacheState().value(); auto const selfIdx = session.getSelfState().getCommState().value().getSelfIdx(); + auto indexFromEnd = session.getIndexFromEnd(); + auto const& lastBlockKey = session.getLastBlockKey(); auto const& connections = session.getConnections(); auto& bufferManager = session.getBufferManager(); TLLM_CHECK_WITH_INFO(llmRequest.mSamplingConfig.beamWidth == 1, "Currently only supports beam width 1."); @@ -138,7 +140,7 @@ void MLACacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& ses } auto const numPools = mCacheManager->getBlockManager().getNumPools(); - auto blockRange = getBlockRangeForSending(mCacheManager, llmRequest); + auto blockRange = getBlockRangeForSending(mCacheManager, llmRequest, lastBlockKey, indexFromEnd); auto lastTokenTime = llmRequest.getPerfMetrics().timingMetrics.lastTokenTime; bool recordDelay = lastTokenTime != std::chrono::steady_clock::time_point(); @@ -147,7 +149,10 @@ void MLACacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& ses std::vector inputKvCacheBlocks; for (auto poolIdx = 0; poolIdx < numPools; poolIdx++) { - blockRange.updatePoolIdx(poolIdx); + if (numPools > 1) + { + blockRange.updatePoolIdx(poolIdx); + } for (auto it = blockRange.begin(); it != blockRange.end(); ++it) { blockNum++; @@ -346,7 +351,7 @@ void MLACacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& s auto arrivalTime 
= llmRequest.getPerfMetrics().timingMetrics.arrivalTime; bool recordDelay = arrivalTime != std::chrono::steady_clock::time_point(); auto pickUpConnections = pickRecvConnections(connections.size(), selfConfig, selfIdx, destConfig); - auto blockRange = getBlockRangeForReceiving(mCacheManager, llmRequest); + auto blockRange = getBlockRangeForReceiving(mCacheManager, llmRequest, destConfig.getEnableBlockReuse()); std::vector recvBufferTmps; std::vector outputBuffers; auto const numPools = mCacheManager->getBlockManager().getNumPools(); diff --git a/cpp/tensorrt_llm/common/envUtils.cpp b/cpp/tensorrt_llm/common/envUtils.cpp index 80be36c30c7..39af4c984a5 100644 --- a/cpp/tensorrt_llm/common/envUtils.cpp +++ b/cpp/tensorrt_llm/common/envUtils.cpp @@ -318,12 +318,6 @@ bool getEnvDisaggLayerwise() return disaggLayerwise; } -bool getEnvDisableSelectiveCacheTransfer() -{ - static bool const disableSelectiveCacheTransfer = getBoolEnv("TRTLLM_DISABLE_SELECTIVE_CACHE_TRANSFER"); - return disableSelectiveCacheTransfer; -} - bool getEnvRequestKVCacheConcurrent() { static bool const requestKVCacheConcurrent = getBoolEnv("TRTLLM_REQUEST_KV_CACHE_CONCURRENT"); diff --git a/cpp/tensorrt_llm/executor/cache_transmission/agent_utils/connection.cpp b/cpp/tensorrt_llm/executor/cache_transmission/agent_utils/connection.cpp index 6ee50ab8e49..f2148401f3e 100644 --- a/cpp/tensorrt_llm/executor/cache_transmission/agent_utils/connection.cpp +++ b/cpp/tensorrt_llm/executor/cache_transmission/agent_utils/connection.cpp @@ -261,7 +261,6 @@ AgentConnection const* AgentConnectionManager::recvConnectionAndRequestInfo(batc while (true) { - updateUnhandledNotifications(); std::scoped_lock lock(mNotificationMutex); auto it = mUnhandledNotifications.begin(); diff --git a/cpp/tensorrt_llm/executor/serialization.cpp b/cpp/tensorrt_llm/executor/serialization.cpp index b3726029ed5..1786a43bdbe 100644 --- a/cpp/tensorrt_llm/executor/serialization.cpp +++ b/cpp/tensorrt_llm/executor/serialization.cpp @@ -16,6 +16,7 @@ */ #include "tensorrt_llm/executor/serialization.h" +#include "tensorrt_llm/batch_manager/kvCacheManager.h" #include "tensorrt_llm/executor/dataTransceiverState.h" #include "tensorrt_llm/executor/executor.h" #include "tensorrt_llm/executor/requestImpl.h" @@ -539,9 +540,10 @@ kv_cache::CacheState Serialization::deserializeCacheState(std::istream& is) auto dataType = su::deserialize(is); auto attentionType = su::deserialize(is); auto kvFactor = su::deserialize(is); + auto enableBlockReuse = su::deserialize(is); return CacheState{nbKvHeadsPerLayer, sizePerHead, tokensPerBlock, tensorParallelism, pipelineParallelism, contextParallelism, attentionLayerNumPerPP, dataType, attentionType, kvFactor, enableAttentionDP, DPrank, - DPsize}; + DPsize, enableBlockReuse}; } void Serialization::serialize(kv_cache::CacheState const& state, std::ostream& os) @@ -559,6 +561,7 @@ void Serialization::serialize(kv_cache::CacheState const& state, std::ostream& o su::serialize(state.mDataType, os); su::serialize(state.mAttentionConfig.mAttentionType, os); su::serialize(state.mAttentionConfig.mKvFactor, os); + su::serialize(state.mEnableBlockReuse, os); } size_t Serialization::serializedSize(kv_cache::CacheState const& state) @@ -577,6 +580,7 @@ size_t Serialization::serializedSize(kv_cache::CacheState const& state) totalSize += su::serializedSize(state.mDataType); totalSize += su::serializedSize(state.mAttentionConfig.mAttentionType); totalSize += su::serializedSize(state.mAttentionConfig.mKvFactor); + totalSize += 
su::serializedSize(state.mEnableBlockReuse); return totalSize; } @@ -2444,4 +2448,38 @@ ModelType Serialization::deserializeModelType(std::istream& is) return su::deserialize(is); } +// BlockKey (KV cache) +size_t Serialization::serializedSize(tensorrt_llm::batch_manager::kv_cache_manager::BlockKey const& key) +{ + size_t totalSize = 0; + totalSize += su::serializedSize(key.usesExtraIds); + totalSize += su::serializedSize(key.loraTaskId); + totalSize += su::serializedSize(key.uniqueTokens); + // std::vector where MmKey is pair, SizeType32> + totalSize += su::serializedSize(key.extraKeys); + return totalSize; +} + +void Serialization::serialize(tensorrt_llm::batch_manager::kv_cache_manager::BlockKey const& key, std::ostream& os) +{ + su::serialize(key.usesExtraIds, os); + su::serialize(key.loraTaskId, os); + su::serialize(key.uniqueTokens, os); + su::serialize(key.extraKeys, os); +} + +tensorrt_llm::batch_manager::kv_cache_manager::BlockKey Serialization::deserializeBlockKey(std::istream& is) +{ + auto usesExtraIds = su::deserialize(is); + auto loraTaskId = su::deserialize>(is); + auto uniqueTokens = su::deserialize>(is); + auto extraKeys = su::deserialize>(is); + tensorrt_llm::batch_manager::kv_cache_manager::BlockKey key; + key.usesExtraIds = usesExtraIds; + key.loraTaskId = std::move(loraTaskId); + key.uniqueTokens = std::move(uniqueTokens); + key.extraKeys = std::move(extraKeys); + return key; +} + } // namespace tensorrt_llm::executor diff --git a/cpp/tensorrt_llm/executor/serializeUtils.h b/cpp/tensorrt_llm/executor/serializeUtils.h index 40b50f92309..1f1e90e0a3b 100644 --- a/cpp/tensorrt_llm/executor/serializeUtils.h +++ b/cpp/tensorrt_llm/executor/serializeUtils.h @@ -21,12 +21,14 @@ #include "tensorrt_llm/executor/executor.h" #include "tensorrt_llm/executor/serialization.h" #include "tensorrt_llm/executor/types.h" +#include #include #include #include #include #include #include +#include #include #include @@ -74,6 +76,44 @@ struct is_variant> : std::true_type template constexpr bool is_variant_v = is_variant::value; +// Detect std::array +template +struct is_std_array : std::false_type +{ +}; + +template +struct is_std_array> : std::true_type +{ + using value_type = U; + static constexpr std::size_t size = N; +}; + +template +constexpr bool is_std_array_v = is_std_array::value; + +template +using array_value_type_t = typename is_std_array::value_type; + +template +constexpr std::size_t array_size_v = is_std_array::size; + +// Detect std::pair +template +struct is_std_pair : std::false_type +{ +}; + +template +struct is_std_pair> : std::true_type +{ + using first_type = A; + using second_type = B; +}; + +template +constexpr bool is_std_pair_v = is_std_pair::value; + // SerializedSize template bool constexpr hasSerializedSize(...) 
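For reference, the detection idiom introduced above can be exercised in isolation. Below is a minimal standalone sketch, assuming nothing beyond the standard library; `toySerializedSize` and the trait spellings here are illustrative stand-ins, not the actual `su::serializedSize` overloads, which additionally cover vectors, optionals, variants, and the executor types:

    #include <array>
    #include <cstddef>
    #include <iostream>
    #include <type_traits>
    #include <utility>

    template <typename T>
    struct is_std_pair : std::false_type {};

    template <typename A, typename B>
    struct is_std_pair<std::pair<A, B>> : std::true_type {};

    template <typename T>
    struct is_std_array : std::false_type {};

    template <typename U, std::size_t N>
    struct is_std_array<std::array<U, N>> : std::true_type {};

    // Toy analogue of serializedSize(): byte count of a flat encoding,
    // recursing through pairs and arrays via if constexpr dispatch.
    template <typename T>
    std::size_t toySerializedSize(T const& v)
    {
        if constexpr (is_std_pair<T>::value)
        {
            return toySerializedSize(v.first) + toySerializedSize(v.second);
        }
        else if constexpr (is_std_array<T>::value)
        {
            std::size_t n = 0;
            for (auto const& e : v)
            {
                n += toySerializedSize(e);
            }
            return n;
        }
        else
        {
            static_assert(std::is_trivially_copyable_v<T>, "toy handles PODs only");
            return sizeof(T);
        }
    }

    int main()
    {
        std::pair<std::array<int, 4>, double> p{{1, 2, 3, 4}, 0.5};
        std::cout << toySerializedSize(p) << '\n'; // 4 * sizeof(int) + sizeof(double)
    }

This is the same shape as the MmKey case in the BlockKey serializer: a pair of a fixed-size hash array and a SizeType32 decomposes into element-wise sizes with no extra framing.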
@@ -161,6 +201,21 @@ size_t serializedSize(T const& data) } return size; } + // std::array + else if constexpr (is_std_array_v) + { + size_t size = 0; + for (auto const& elem : data) + { + size += serializedSize(elem); + } + return size; + } + // std::pair + else if constexpr (is_std_pair_v) + { + return serializedSize(data.first) + serializedSize(data.second); + } // Optional else if constexpr (std::is_same_v::type>>) { @@ -266,6 +321,20 @@ void serialize(T const& data, std::ostream& os) serialize(element, os); } } + // std::array + else if constexpr (is_std_array_v) + { + for (auto const& element : data) + { + serialize(element, os); + } + } + // std::pair + else if constexpr (is_std_pair_v) + { + serialize(data.first, os); + serialize(data.second, os); + } // Optional else if constexpr (std::is_same_v::type>>) { @@ -575,6 +644,10 @@ T deserialize(std::istream& is) { return Serialization::deserializeUniqueToken(is); } + else if constexpr (std::is_same_v) + { + return Serialization::deserializeBlockKey(is); + } // Optional else if constexpr (std::is_same_v::type>>) { @@ -604,6 +677,23 @@ T deserialize(std::istream& is) } return container; } + // std::array + else if constexpr (is_std_array_v) + { + T container{}; + for (std::size_t i = 0; i < array_size_v; ++i) + { + container[i] = deserialize>(is); + } + return container; + } + // std::pair + else if constexpr (is_std_pair_v) + { + auto first = deserialize::first_type>(is); + auto second = deserialize::second_type>(is); + return T{std::move(first), std::move(second)}; + } // std::variant else if constexpr (is_variant_v) { diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp index 68c719fb687..6a0493bc592 100644 --- a/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp @@ -465,7 +465,9 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(nb::module_& m) .def("get_newly_allocated_block_ids", &BaseKVCacheManager::getNewlyAllocatedBlockIds, nb::call_guard()) .def("flush_iteration_events", &BaseKVCacheManager::flushIterationEvents, - nb::call_guard()); + nb::call_guard()) + .def("get_last_block_id", &BaseKVCacheManager::getLastBlockId, nb::call_guard()) + .def("unpin_blocks_by_id", &BaseKVCacheManager::unpinBlocksById, nb::call_guard()); nb::bind_vector(m, "CacheBlockIds") .def("__getstate__", [](CacheBlockIds const& v) { return nb::make_tuple(v); }) diff --git a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp index 320659a1d09..f83b4137304 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp @@ -355,6 +355,7 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(py::module_& m) .def("add_token", &BaseKVCacheManager::addToken, py::call_guard()) .def("add_sequence", &BaseKVCacheManager::addSequence, py::call_guard()) .def("remove_sequence", &BaseKVCacheManager::removeSequence, py::call_guard()) + .def("pin_blocks", &BaseKVCacheManager::pinBlocks, py::call_guard()) .def("scheduling_remove_sequence", &BaseKVCacheManager::schedulingRemoveSequence, py::call_guard()) .def( @@ -467,7 +468,9 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(py::module_& m) .def("get_newly_allocated_block_ids", &BaseKVCacheManager::getNewlyAllocatedBlockIds, py::call_guard()) .def("flush_iteration_events", 
&BaseKVCacheManager::flushIterationEvents, - py::call_guard()); + py::call_guard()) + .def("get_last_block_id", &BaseKVCacheManager::getLastBlockId, py::call_guard()) + .def("unpin_blocks_by_id", &BaseKVCacheManager::unpinBlocksById, py::call_guard()); py::enum_(m, "CacheType") .value("SELF", tbk::CacheType::kSELF) diff --git a/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp b/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp index 3e266d0cd14..6639a4e49bf 100644 --- a/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp +++ b/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp @@ -14,6 +14,7 @@ #include "tensorrt_llm/batch_manager/common.h" #include "tensorrt_llm/batch_manager/kvCacheEventManager.h" #include "tensorrt_llm/batch_manager/kvCacheTransferManager.h" +#include "tensorrt_llm/batch_manager/kvCacheUtils.h" #include "tensorrt_llm/batch_manager/llmRequest.h" #include "tensorrt_llm/common/assert.h" #include "tensorrt_llm/common/cudaUtils.h" @@ -592,6 +593,96 @@ TEST_F(KVCacheManagerTest, BlockManagerTestWindowSizeToShare) } } +TEST_F(KVCacheManagerTest, FindBlocksInReuseTreeByBlockKeysTest) +{ + auto constexpr numLayers = 12; + auto constexpr numKvHeads = 6; + auto constexpr sizePerHead = 128; + auto constexpr tokensPerBlock = 8; + auto constexpr blocksInPrimaryPool = 4; + auto constexpr blocksInSecondaryPool = 4; + auto constexpr maxNumSequences = 8; + auto const stream = std::make_shared(); + auto constexpr onboardBlocks = true; + + auto constexpr batchSize = 1; + auto constexpr maxBlocksPerSeq = 10; + auto constexpr bytesPerToken = 4; + auto constexpr maxAttentionWindow = 4096; + auto constexpr maxAttentionWindowAllLayer = 4096; + auto constexpr sinkTokenLen = 0; + auto constexpr canUseOneMoreBlock = true; + + SizeType32 constexpr maxNewTokens{0}; + auto constexpr beamWidth = 1; + auto constexpr beamIdx = 0; + tr::SamplingConfig const samplingConfig{beamWidth}; + bool constexpr isStreaming{false}; + + auto const blocksPerWindow = BlocksPerWindow{{maxAttentionWindow, {blocksInPrimaryPool, blocksInSecondaryPool}}}; + KVCacheManager kvCacheManager(numLayers, numKvHeads, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, + beamWidth, std::vector{maxAttentionWindow}, std::nullopt, nvinfer1::DataType::kHALF, + false, stream, true, onboardBlocks); + + // Add sequence [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16] (17 tokens, three blocks) + auto inputTokens = std::make_shared(VecTokens{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); + auto const inputLength = static_cast(inputTokens->size()); + LlmRequest::RequestIdType requestId{0}; + auto llmRequest0 = std::make_shared(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming); + kvCacheManager.addSequence(requestId, inputLength, beamWidth, llmRequest0); + EXPECT_EQ(llmRequest0->getContextCurrentPosition(), 0); + auto cacheBlockIds = kvCacheManager.getSequence(requestId).getCacheBlockIds(maxAttentionWindow).at(beamIdx); + EXPECT_THAT(cacheBlockIds, ::testing::ElementsAreArray({0, 1, 2})); + + // Print all the block ids fromAllBlockIds + auto blockRange = BlockRange::fromAllBlockIds(kvCacheManager, requestId); + for (auto& block : blockRange.getBlockIds()) + { + std::cout << block << " "; + } + std::cout << std::endl; + kvCacheManager.removeSequence(requestId, llmRequest0); + std::cout << "Removed sequence 0" << std::endl; + + requestId = 1; + auto llmRequest1 = std::make_shared(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming); + 
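The lookup this test ultimately targets, findBlocksInReuseTreeByBlockKey, chops the key's unique tokens into per-block keys and then descends the cached-blocks tree one child per key. A toy descent over a hypothetical `Node` type (illustration only, not the real `KVCacheBlock`) captures the shape of that walk:

    #include <map>
    #include <memory>
    #include <optional>
    #include <vector>

    // One node per stored block; children are keyed by that block's tokens.
    struct Node
    {
        int blockId;
        std::map<std::vector<int>, std::shared_ptr<Node>> children;
    };

    // Walk root -> child -> child, one step per block key; nullopt on any miss.
    std::optional<std::shared_ptr<Node>> findByBlockKeys(
        std::shared_ptr<Node> root, std::vector<std::vector<int>> const& blockKeys)
    {
        auto cur = root;
        for (auto const& key : blockKeys)
        {
            auto it = cur->children.find(key);
            if (it == cur->children.end())
            {
                return std::nullopt;
            }
            cur = it->second;
        }
        return cur;
    }

    int main()
    {
        auto root = std::make_shared<Node>(Node{-1, {}});
        auto b0 = std::make_shared<Node>(Node{0, {}});
        auto b1 = std::make_shared<Node>(Node{1, {}});
        root->children[{0, 1, 2, 3}] = b0;
        b0->children[{4, 5, 6, 7}] = b1;
        auto hit = findByBlockKeys(root, {{0, 1, 2, 3}, {4, 5, 6, 7}});
        return (hit && (*hit)->blockId == 1) ? 0 : 1;
    }

The deepest matched node is what the sender later walks backwards from via prev pointers, which is why re-adding the same 17 tokens below is expected to land on the same blocks {0, 1, 2}.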
kvCacheManager.addSequence(requestId, inputLength, beamWidth, llmRequest1); + std::cout << "Added sequence 1" << std::endl; + cacheBlockIds = kvCacheManager.getSequence(requestId).getCacheBlockIds(maxAttentionWindow).at(beamIdx); + std::cout << "Cache Block IDs: "; + for (auto& block : cacheBlockIds) + { + std::cout << block << " "; + } + std::cout << std::endl; + EXPECT_THAT(cacheBlockIds, ::testing::ElementsAreArray({0, 1, 2})); + auto blockRange2 = BlockRange::fromAllBlockIds(kvCacheManager, requestId); + std::cout << "All Block IDs: "; + for (auto& block : blockRange2.getBlockIds()) + { + std::cout << block << " "; + } + std::cout << std::endl; + + auto blockRange3 = BlockRange::fromNewlyAllocatedBlockIds(kvCacheManager, requestId); + std::cout << "Newly Allocated Block IDs: "; + for (auto& block : blockRange2.getBlockIds()) + { + std::cout << block << " "; + } + std::cout << std::endl; + + // BlockKey emptyKey{}; + // auto result = blockManager.findBlocksInReuseTreeByBlockKey(emptyKey, maxAttentionWindow); + // ASSERT_TRUE(result.has_value()); + // EXPECT_EQ((*result)->getBlockId(), KVCacheBlock::kCachedBlocksRootId); + + // BlockKey key{BlockKey{block0->getHash()}}; + // result = blockManager.findBlocksInReuseTreeByBlockKey(key, maxAttentionWindow); + // ASSERT_TRUE(result.has_value()); + // EXPECT_EQ((*result)->getBlockId(), block0->getBlockId()); +} + #ifdef ENABLE_FP4 TEST_F(KVCacheManagerTest, FP4BlockScaleManagementTest) { diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index aa0902484a5..03981210e46 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -216,6 +216,7 @@ def __init__(self, # kv cache events self.kv_cache_manager = self.resource_manager.resource_managers.get( ResourceManagerType.KV_CACHE_MANAGER) + self.block_reuse_enabled = self.kv_cache_manager.enable_block_reuse self.enable_kv_cache_events = self.kv_cache_manager is not None and self.kv_cache_manager.event_buffer_max_size > 0 self.enable_kv_cache_reuse = self.kv_cache_manager is not None and self.kv_cache_manager.enable_block_reuse @@ -1921,24 +1922,32 @@ def _handle_responses(self): if request_done: if request.is_disagg_context_transmission_state: - self.ctx_in_transmission_requests.append(request) + if self.block_reuse_enabled and not self.kv_cache_manager.is_vswa: + requests_to_terminate.append(request) + self.ctx_in_transmission_requests.append( + (request, + self.kv_cache_manager.get_last_block_id( + request.py_request_id))) else: requests_to_terminate.append(request) else: new_active_requests.append(request) self.active_requests.clear() self.active_requests.extend(new_active_requests) - self._enqueue_responses(new_responses) for request in requests_to_terminate: self._terminate_request(request) + self._enqueue_responses(new_responses) return requests_to_terminate @nvtx_range("_terminate_ctx_finished_requests") def _terminate_ctx_finished_requests(self): - for request in self.ctx_in_transmission_requests[:]: + for request, block_id in self.ctx_in_transmission_requests[:]: if request.is_disagg_context_complete_state: - self._terminate_request(request) - self.ctx_in_transmission_requests.remove(request) + if not self.block_reuse_enabled or self.kv_cache_manager.is_vswa: + self._terminate_request(request) + else: + self.kv_cache_manager.unpin_blocks_by_id(block_id) + self.ctx_in_transmission_requests.remove((request, block_id)) def _handle_logits_communication(self, previous_batch, 
prev_microbatch_id): """Handle logits communication between pipeline parallel ranks. diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py index 619831c4b4b..3e201b1a8ab 100644 --- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py +++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py @@ -644,6 +644,12 @@ def get_cache_indices(self, assert len(result) == 1 return result[0] + def unpin_blocks_by_id(self, kv_cache_block_id: int): + self.impl.unpin_blocks_by_id(kv_cache_block_id) + + def get_last_block_id(self, request_id: int) -> int: + return self.impl.get_last_block_id(request_id) + def get_batch_cache_indices( self, request_ids: List[int], diff --git a/tests/integration/defs/accuracy/test_disaggregated_serving.py b/tests/integration/defs/accuracy/test_disaggregated_serving.py index 733122c3229..49c612872bf 100644 --- a/tests/integration/defs/accuracy/test_disaggregated_serving.py +++ b/tests/integration/defs/accuracy/test_disaggregated_serving.py @@ -1,7 +1,3 @@ -# I want to create accuracy tests for disaggregated serving. -# I need to to this by creating a new class that mimics LLM class. Instead of implementing the -# actual methods it will send OAI requests to the disaggregated serving endpoint. -# Please take a look at the existing test_llm_api_pytorch.py file for reference. import concurrent import contextlib import itertools @@ -359,12 +355,26 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness): @pytest.mark.skip_less_device(2) @pytest.mark.skip_less_device_memory(32000) @pytest.mark.parametrize("disable_overlap_scheduler", [False, True]) - def test_auto_dtype(self, disable_overlap_scheduler): - ctx_server_config = {"disable_overlap_scheduler": True} - gen_server_config = { - "disable_overlap_scheduler": disable_overlap_scheduler + @pytest.mark.parametrize("ctx_enable_block_reuse", [True, False]) + @pytest.mark.parametrize("gen_enable_block_reuse", [True, False]) + def test_auto_dtype(self, disable_overlap_scheduler, ctx_enable_block_reuse, + gen_enable_block_reuse): + ctx_server_config = { + "disable_overlap_scheduler": True, + "kv_cache_config": { + "enable_block_reuse": ctx_enable_block_reuse + } } ctx_server_config["cache_transceiver_config"] = {"backend": "DEFAULT"} + gen_server_config = { + "disable_overlap_scheduler": disable_overlap_scheduler, + "kv_cache_config": { + "enable_block_reuse": gen_enable_block_reuse + }, + "cache_transceiver_config": { + "backend": "DEFAULT" + } + } gen_server_config["cache_transceiver_config"] = {"backend": "DEFAULT"} disaggregated_server_config = { "hostname": "localhost", diff --git a/tests/integration/test_lists/test-db/l0_dgx_h100.yml b/tests/integration/test_lists/test-db/l0_dgx_h100.yml index 535b6e02d13..a2ecf915259 100644 --- a/tests/integration/test_lists/test-db/l0_dgx_h100.yml +++ b/tests/integration/test_lists/test-db/l0_dgx_h100.yml @@ -68,6 +68,25 @@ l0_dgx_h100: - disaggregated/test_disaggregated.py::test_disaggregated_ctxpp2_gentp2[TinyLlama-1.1B-Chat-v1.0] - disaggregated/test_disaggregated.py::test_disaggregated_ctxpp4_gentp4[TinyLlama-1.1B-Chat-v1.0] - disaggregated/test_disaggregated.py::test_disaggregated_genbs1[TinyLlama-1.1B-Chat-v1.0] + - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False-False-False] + - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False-False-True] + - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False-True-False] + - 
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False-True-True] + - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-False-False] + - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-False-True] + - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-True-False] + - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-True-True] + - accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[False] + - accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[True] + - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ngram + - accuracy/test_disaggregated_serving.py::TestQwen3_8B::test_auto_dtype[False] + - accuracy/test_disaggregated_serving.py::TestQwen3_8B::test_auto_dtype[True] + - accuracy/test_disaggregated_serving.py::TestQwen3_8B::test_chunked_prefill + - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_eagle3[eagle3_one_model=False-overlap_scheduler=False] + - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_eagle3[eagle3_one_model=True-overlap_scheduler=True] + - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_guided_decoding[xgrammar] + - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_eagle3[xgrammar-eagle3_one_model=True] + - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_eagle3[xgrammar-eagle3_one_model=False] - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_tp_pp_symmetric[GSM8K-tp1pp2] - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_tp_pp_symmetric[MMLU-tp1pp2] - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_tp_pp_symmetric[GSM8K-tp2pp1] From fa4225f61a14540652375b1734f5c54979d44fc1 Mon Sep 17 00:00:00 2001 From: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com> Date: Tue, 16 Sep 2025 12:12:30 -0700 Subject: [PATCH 02/17] Address review comments + testing Signed-off-by: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com> --- .../batch_manager/cacheFormatter.h | 75 ------------------ .../batch_manager/dataTransceiver.cpp | 4 +- .../batch_manager/dataTransceiver.h | 9 ++- .../batch_manager/kvCacheManagerTest.cpp | 79 +++++++++++++++++++ .../executor/serializeUtilsTest.cpp | 34 ++++++++ .../test_lists/qa/llm_function_core.txt | 4 +- .../qa/llm_function_core_sanity.txt | 4 +- .../test_lists/test-db/l0_dgx_h100.yml | 29 ++----- tests/integration/test_lists/waives.txt | 4 +- 9 files changed, 135 insertions(+), 107 deletions(-) diff --git a/cpp/tensorrt_llm/batch_manager/cacheFormatter.h b/cpp/tensorrt_llm/batch_manager/cacheFormatter.h index beca72696ea..ee9a33a80bd 100644 --- a/cpp/tensorrt_llm/batch_manager/cacheFormatter.h +++ b/cpp/tensorrt_llm/batch_manager/cacheFormatter.h @@ -56,81 +56,6 @@ BlockRange getBlockRangeForSending(BaseKVCacheManager* cacheManager, LlmRequest BlockRange getBlockRangeForReceiving( BaseKVCacheManager* cacheManager, LlmRequest const& llmRequest, bool srcEnableBlockReuse); -class KvCacheMeasureHelper -{ -public: - struct Measure - { - double delay; // from last token (ctx) or arrival time (gen), in ms - double duration; // in ms - double bandwidth; // in Gbps - }; - - KvCacheMeasureHelper(std::string output_path) - : mOutputPath(std::move(output_path)) - { - } - - void 
markAsSender(bool isSender) - { - mIsSender = isSender; - } - - void appendKVCacheTransfer(LlmRequest::RequestIdType requestId, double delay, double duration, size_t size) - { - auto bandwidth = size * 8 / (duration / 1000) / 1e9; - if (mOutputPath.empty()) - { - return; - } - - std::lock_guard lock(mMutex); - mRequestKVCacheTranfserMeasure[requestId].emplace_back(Measure{delay, duration, bandwidth}); - } - - ~KvCacheMeasureHelper() - { - if (!mRequestKVCacheTranfserMeasure.empty() && !mOutputPath.empty()) - { - TLLM_CHECK(mIsSender.has_value()); - auto rank = mpi::MpiComm::world().getRank(); - std::string outFilePath - = mOutputPath + "rank_" + std::to_string(rank) + "_" + (mIsSender.value() ? "send" : "recv") + ".csv"; - std::ofstream outFile(outFilePath); - - TLLM_CHECK_WITH_INFO(outFile.is_open(), "Cannot write to file " + outFilePath); - - size_t numTransferMeasure = mRequestKVCacheTranfserMeasure.begin()->second.size(); - - outFile << "RequestID"; - for (size_t i = 0; i < numTransferMeasure; i++) - { - outFile << ",Delay(ms),Duration(ms),Bandwidth(Gbps)"; - } - outFile << '\n'; - - for (auto const& [requestID, measures] : mRequestKVCacheTranfserMeasure) - { - outFile << requestID; - - for (auto const& measure : measures) - { - outFile << "," << measure.delay << "," << measure.duration << "," << measure.bandwidth; - } - outFile << '\n'; - } - - outFile.close(); - } - } - -private: - std::map> mRequestKVCacheTranfserMeasure; - std::string mOutputPath; - std::mutex mMutex; - std::optional mIsSender; -}; - // Used to support the cache transmission with different layouts and different protocols. class BaseCacheFormatter { diff --git a/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp b/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp index 58007d1ac04..c92cddba263 100644 --- a/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp +++ b/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp @@ -734,12 +734,12 @@ class CacheReceiver::Impl if (agentConnectionManager != nullptr) { // TODO: index -> validConnectionIdx conversion - auto valideConnectionIdx = std::find(pickUpIdx.begin(), pickUpIdx.end(), i) - pickUpIdx.begin(); + auto validConnectionIdx = std::find(pickUpIdx.begin(), pickUpIdx.end(), i) - pickUpIdx.begin(); auto* agentConnection = dynamic_cast(connection); TLLM_CHECK(agentConnection != nullptr); TLLM_CHECK(cacheBufferId.has_value()); const_cast(agentConnection) - ->sendRequestAndBufferInfo(requestInfo, cacheBufferId, valideConnectionIdx); + ->sendRequestAndBufferInfo(requestInfo, cacheBufferId, validConnectionIdx); } else { diff --git a/cpp/tensorrt_llm/batch_manager/dataTransceiver.h b/cpp/tensorrt_llm/batch_manager/dataTransceiver.h index a41c7eba480..8056f683a40 100644 --- a/cpp/tensorrt_llm/batch_manager/dataTransceiver.h +++ b/cpp/tensorrt_llm/batch_manager/dataTransceiver.h @@ -49,6 +49,8 @@ using DataContext = tensorrt_llm::executor::kv_cache::DataContext; using Connection = tensorrt_llm::executor::kv_cache::Connection; using ConnectionManager = tensorrt_llm::executor::kv_cache::ConnectionManager; using SizeType32 = tensorrt_llm::runtime::SizeType32; +using BlockKey = tensorrt_llm::batch_manager::kv_cache_manager::BlockKey; +using UniqueToken = tensorrt_llm::runtime::UniqueToken; class TransferSession { @@ -65,11 +67,12 @@ class TransferSession runtime::BufferManager const& bufferManager, int32_t indexFromEnd, BlockKey const& lastBlockKey, LlmRequest const* llmRequest = nullptr, bool recordMeasure = false) : mConnections(std::move(connections)) - , mDataContext(dataContext) + 
, mDataContext(std::move(dataContext)) , mSelfState(&selfState) , mOtherState(std::move(otherState)) , mBufferManager(&bufferManager) , mRequest(llmRequest) + , mMeasures() , mRecordMeasure(recordMeasure) , mIndexFromEnd(indexFromEnd) , mLastBlockKey(lastBlockKey) @@ -104,12 +107,12 @@ class TransferSession // TODO: 1. use global id instead of context request id; 2. export to llm metrics instead of file void exportMeasure(std::ofstream& outFile, bool isContext) const; - [[nodiscard]] int32_t getIndexFromEnd() const noexcept + [[nodiscard]] int32_t getIndexFromEnd() const { return mIndexFromEnd; } - [[nodiscard]] BlockKey const& getLastBlockKey() const noexcept + [[nodiscard]] BlockKey const& getLastBlockKey() const { return mLastBlockKey; } diff --git a/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp b/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp index 6639a4e49bf..fdc5a175b0f 100644 --- a/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp +++ b/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp @@ -3545,6 +3545,85 @@ TEST_F(KVCacheManagerTest, KVCacheManagerEventStreamPriority) } } +TEST(KVCacheManagerHelpersTest, ChopVectorIntoBlocksBasicNoPartial) +{ + using namespace tensorrt_llm::batch_manager::kv_cache_manager; + std::vector vec{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; + auto blocks = chopVectorIntoBlocks(vec, 10, 4, false); + std::vector> got(blocks.begin(), blocks.end()); + ASSERT_EQ(got.size(), 2); + EXPECT_THAT(got[0], ::testing::ElementsAreArray({0, 1, 2, 3})); + EXPECT_THAT(got[1], ::testing::ElementsAreArray({4, 5, 6, 7})); +} + +TEST(KVCacheManagerHelpersTest, ChopVectorIntoBlocksBasicWithPartial) +{ + using namespace tensorrt_llm::batch_manager::kv_cache_manager; + std::vector vec{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; + auto blocks = chopVectorIntoBlocks(vec, 10, 4, true); + std::vector> got(blocks.begin(), blocks.end()); + ASSERT_EQ(got.size(), 3); + EXPECT_THAT(got[0], ::testing::ElementsAreArray({0, 1, 2, 3})); + EXPECT_THAT(got[1], ::testing::ElementsAreArray({4, 5, 6, 7})); + EXPECT_THAT(got[2], ::testing::ElementsAreArray({8, 9})); +} + +TEST(KVCacheManagerHelpersTest, ChopVectorIntoBlocksWithUsableSize) +{ + using namespace tensorrt_llm::batch_manager::kv_cache_manager; + std::vector vec{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; + auto blocks = chopVectorIntoBlocks(vec, 7, 4, true); + std::vector> got(blocks.begin(), blocks.end()); + ASSERT_EQ(got.size(), 2); + EXPECT_THAT(got[0], ::testing::ElementsAreArray({0, 1, 2, 3})); + EXPECT_THAT(got[1], ::testing::ElementsAreArray({4, 5, 6})); +} + +TEST_F(KVCacheManagerTest, PinAndUnpinBlocksById) +{ + using namespace tensorrt_llm::batch_manager::kv_cache_manager; + auto constexpr numLayers = 2; + auto constexpr numKvHeads = 2; + auto constexpr sizePerHead = 16; + auto constexpr tokensPerBlock = 4; + auto constexpr blocksInPrimaryPool = 4; + auto constexpr blocksInSecondaryPool = 0; + auto constexpr maxNumSequences = 8; + auto const stream = std::make_shared(); + auto constexpr onboardBlocks = true; + auto constexpr beamWidth = 1; + auto const maxAttentionWindow = tokensPerBlock * blocksInPrimaryPool; + + BlocksPerWindow const blocksPerWindow{{maxAttentionWindow, {blocksInPrimaryPool, blocksInSecondaryPool}}}; + + KVCacheManager kvCacheManager(numLayers, numKvHeads, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, + beamWidth, std::vector{maxAttentionWindow}, std::nullopt, nvinfer1::DataType::kHALF, + 0, stream, std::nullopt, true, onboardBlocks); + kvCacheManager.allocatePools(false); + + 
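What the assertions in this test reduce to is a small reference-count invariant: a pinned block survives removeSequence, and dropping the pin returns it to the free pool. A toy model, with hypothetical names and none of the real eviction or reuse-tree machinery, makes the accounting explicit:

    #include <cassert>
    #include <unordered_map>

    // Toy free-list with pinning, mirroring only the observable invariant
    // checked below; not the real BlockManager.
    struct ToyPool
    {
        int freeBlocks;
        std::unordered_map<int, int> pinCount; // blockId -> active pins

        explicit ToyPool(int n) : freeBlocks(n) {}

        int allocate()
        {
            assert(freeBlocks > 0);
            return --freeBlocks; // toy scheme: use the new count as the block id
        }

        void release(int id)
        {
            if (pinCount[id] == 0)
            {
                ++freeBlocks; // only unpinned blocks return to the pool
            }
        }

        void unpin(int id)
        {
            if (--pinCount[id] == 0)
            {
                ++freeBlocks;
            }
        }
    };

    int main()
    {
        ToyPool pool(4);
        int b = pool.allocate();   // 3 free
        ++pool.pinCount[b];        // pinBlocks
        pool.release(b);           // removeSequence: still 3 free, block is pinned
        assert(pool.freeBlocks == 3);
        pool.unpin(b);             // unpinBlocksById: back to 4 free
        assert(pool.freeBlocks == 4);
    }

The two EXPECT_LT / EXPECT_EQ checks in PinAndUnpinBlocksById are exactly the two asserts above, phrased against getNumFreeBlocks.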
LlmRequest::RequestIdType requestId{0}; + auto inputTokens = std::make_shared(VecTokens{0, 1, 2, 3, 4, 5, 6, 7}); + tr::SamplingConfig const samplingConfig{beamWidth}; + bool constexpr isStreaming{false}; + auto llmRequest = std::make_shared(requestId, 0, inputTokens, samplingConfig, isStreaming); + + kvCacheManager.addSequence(requestId, static_cast(inputTokens->size()), beamWidth, llmRequest); + auto const totalBlocks = kvCacheManager.getMaxNumBlocks(); + auto const freeAfterAlloc = kvCacheManager.getNumFreeBlocks(); + EXPECT_LT(freeAfterAlloc, totalBlocks); + + kvCacheManager.pinBlocks(requestId); + auto lastBlockIdOpt = kvCacheManager.getLastBlockId(requestId); + ASSERT_TRUE(lastBlockIdOpt.has_value()); + kvCacheManager.removeSequence(requestId, llmRequest); + auto const freeAfterRemovePinned = kvCacheManager.getNumFreeBlocks(); + EXPECT_LT(freeAfterRemovePinned, totalBlocks); + + kvCacheManager.unpinBlocksById(lastBlockIdOpt.value()); + auto const freeAfterUnpin = kvCacheManager.getNumFreeBlocks(); + EXPECT_EQ(freeAfterUnpin, totalBlocks); +} + TEST_F(KVCacheManagerTest, KVCacheManagerEventStreamBlocking) { auto constexpr numLayers = 12; diff --git a/cpp/tests/unit_tests/executor/serializeUtilsTest.cpp b/cpp/tests/unit_tests/executor/serializeUtilsTest.cpp index 1faf9540760..597077e5191 100644 --- a/cpp/tests/unit_tests/executor/serializeUtilsTest.cpp +++ b/cpp/tests/unit_tests/executor/serializeUtilsTest.cpp @@ -11,6 +11,7 @@ */ #include "tensorrt_llm/executor/serializeUtils.h" +#include "tensorrt_llm/batch_manager/kvCacheManager.h" #include "tensorrt_llm/common/logger.h" #include "tensorrt_llm/executor/dataTransceiverState.h" #include "tensorrt_llm/executor/executor.h" @@ -1031,3 +1032,36 @@ TEST(SerializeUtilsTest, CacheTransceiverConfig) EXPECT_EQ(cacheTransceiverConfig.getBackendType(), cacheTransceiverConfig2.getBackendType()); EXPECT_EQ(cacheTransceiverConfig.getMaxTokensInBuffer(), cacheTransceiverConfig2.getMaxTokensInBuffer()); } + +TEST(SerializeUtilsTest, BlockKeyBasic) +{ + using namespace tensorrt_llm::batch_manager::kv_cache_manager; + + VecUniqueTokens uniqueTokens{UniqueToken{1, 0}, UniqueToken{2, 0}, UniqueToken{3, 0}}; + BlockKey key(false, std::nullopt, uniqueTokens, {}); + + testSerializeDeserialize(key); +} + +TEST(SerializeUtilsTest, BlockKeyWithExtras) +{ + using namespace tensorrt_llm::batch_manager::kv_cache_manager; + + // Prepare multimodal extra keys + std::array h1{}; + std::array h2{}; + for (size_t i = 0; i < h1.size(); ++i) + { + h1[i] = static_cast(i); + h2[i] = static_cast(255 - i); + } + std::vector extraKeys{{h1, SizeType32{0}}, {h2, SizeType32{5}}}; + + VecUniqueTokens uniqueTokens{UniqueToken{10, 100}, UniqueToken{20, 200}}; + std::optional loraTaskId = LoraTaskIdType{42}; + + // Note: cacheSaltID is intentionally not set since it is not serialized + BlockKey key(true, loraTaskId, uniqueTokens, extraKeys); + + testSerializeDeserialize(key); +} diff --git a/tests/integration/test_lists/qa/llm_function_core.txt b/tests/integration/test_lists/qa/llm_function_core.txt index bdd6bf4c90e..8dbf7c8ea45 100644 --- a/tests/integration/test_lists/qa/llm_function_core.txt +++ b/tests/integration/test_lists/qa/llm_function_core.txt @@ -552,8 +552,8 @@ accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_chunked_prefill[cutlass-au accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_chunked_prefill[trtllm-auto] accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_chunked_prefill[triton-auto] 
accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_chunked_prefill[trtllm-fp8] -accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False] -accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True] +accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False-False-False] +accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-True-True] accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ngram accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_eagle3[eagle3_one_model=True-overlap_scheduler=True] accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_eagle3[eagle3_one_model=False-overlap_scheduler=False] diff --git a/tests/integration/test_lists/qa/llm_function_core_sanity.txt b/tests/integration/test_lists/qa/llm_function_core_sanity.txt index ad941b95dc1..1e8f7c8b6cf 100644 --- a/tests/integration/test_lists/qa/llm_function_core_sanity.txt +++ b/tests/integration/test_lists/qa/llm_function_core_sanity.txt @@ -7,8 +7,8 @@ accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[F accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[True] accuracy/test_llm_api_pytorch.py::TestGemma3_27BInstruct::test_fp8_prequantized accuracy/test_llm_api_pytorch.py::TestGemma3_27BInstruct::test_auto_dtype -accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False] -accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True] +accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False-False-False] +accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-True-True] accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[GSM8K-gen_tp=1-ctx_pp=2] accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[GSM8K-gen_tp=1-ctx_pp=4] accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[GSM8K-gen_tp=2-ctx_pp=2] diff --git a/tests/integration/test_lists/test-db/l0_dgx_h100.yml b/tests/integration/test_lists/test-db/l0_dgx_h100.yml index a2ecf915259..a07955a3c17 100644 --- a/tests/integration/test_lists/test-db/l0_dgx_h100.yml +++ b/tests/integration/test_lists/test-db/l0_dgx_h100.yml @@ -31,8 +31,14 @@ l0_dgx_h100: - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ngram - accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[False] - accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[True] - - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False] - - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True] + - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False-False-False] + - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False-False-True] + - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False-True-False] + - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False-True-True] + - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-False-False] + - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-False-True] + - 
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-True-False] + - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-True-True] # ------------- AutoDeploy tests --------------- - accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype - condition: @@ -68,25 +74,6 @@ l0_dgx_h100: - disaggregated/test_disaggregated.py::test_disaggregated_ctxpp2_gentp2[TinyLlama-1.1B-Chat-v1.0] - disaggregated/test_disaggregated.py::test_disaggregated_ctxpp4_gentp4[TinyLlama-1.1B-Chat-v1.0] - disaggregated/test_disaggregated.py::test_disaggregated_genbs1[TinyLlama-1.1B-Chat-v1.0] - - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False-False-False] - - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False-False-True] - - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False-True-False] - - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False-True-True] - - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-False-False] - - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-False-True] - - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-True-False] - - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-True-True] - - accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[False] - - accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[True] - - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ngram - - accuracy/test_disaggregated_serving.py::TestQwen3_8B::test_auto_dtype[False] - - accuracy/test_disaggregated_serving.py::TestQwen3_8B::test_auto_dtype[True] - - accuracy/test_disaggregated_serving.py::TestQwen3_8B::test_chunked_prefill - - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_eagle3[eagle3_one_model=False-overlap_scheduler=False] - - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_eagle3[eagle3_one_model=True-overlap_scheduler=True] - - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_guided_decoding[xgrammar] - - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_eagle3[xgrammar-eagle3_one_model=True] - - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_eagle3[xgrammar-eagle3_one_model=False] - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_tp_pp_symmetric[GSM8K-tp1pp2] - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_tp_pp_symmetric[MMLU-tp1pp2] - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_tp_pp_symmetric[GSM8K-tp2pp1] diff --git a/tests/integration/test_lists/waives.txt b/tests/integration/test_lists/waives.txt index bdc82cece59..53f62804e1f 100644 --- a/tests/integration/test_lists/waives.txt +++ b/tests/integration/test_lists/waives.txt @@ -300,8 +300,8 @@ accuracy/test_cli_flow.py::TestLongAlpaca7B::test_auto_dtype SKIP (https://nvbug accuracy/test_llm_api.py::TestPhi4MiniInstruct::test_fp8 SKIP (https://nvbugs/5465143) accuracy/test_llm_api_pytorch.py::TestEXAONE4::test_auto_dtype SKIP (https://nvbugs/5481090) accuracy/test_llm_api_pytorch.py::TestLlama4MaverickInstruct::test_fp8_eagle3[tp8-torch_compile=False] SKIP (https://nvbugs/5483534) 
-accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False] SKIP (https://nvbugs/5488118) -accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True] SKIP (https://nvbugs/5488118) +accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_online_eplb[mtp_nextn=2] SKIP (https://nvbugs/5444687) +accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus_online_eplb[fp8kv=True] SKIP (https://nvbugs/5444687) accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ngram SKIP (https://nvbugs/5488118) test_e2e.py::test_trtllm_bench_iteration_log[TRT-streaming-meta-llama/Llama-3.1-8B-llama-3.1-model/Meta-Llama-3.1-8B] SKIP (https://nvbugs/5448523) cpp/test_unit_tests.py::test_unit_tests[kernels-80] SKIP (https://nvbugs/5504078) From 705b71776cc0aa1fc1ac0968669650104ed65193 Mon Sep 17 00:00:00 2001 From: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com> Date: Tue, 16 Sep 2025 17:44:32 -0700 Subject: [PATCH 03/17] Fix TRT path Signed-off-by: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com> --- cpp/tensorrt_llm/executor/executorImpl.cpp | 33 +++++++++++++++---- cpp/tensorrt_llm/executor/executorImpl.h | 12 +++++-- tensorrt_llm/_torch/pyexecutor/py_executor.py | 2 +- 3 files changed, 38 insertions(+), 9 deletions(-) diff --git a/cpp/tensorrt_llm/executor/executorImpl.cpp b/cpp/tensorrt_llm/executor/executorImpl.cpp index c2118bbf771..cba4cecf5f3 100644 --- a/cpp/tensorrt_llm/executor/executorImpl.cpp +++ b/cpp/tensorrt_llm/executor/executorImpl.cpp @@ -2169,15 +2169,25 @@ void Executor::Impl::terminateCancelledRequests(RequestList& activeRequests) } } -void Executor::Impl::terminateContextFinishedRequests(RequestList& inTransmissionRequests) +void Executor::Impl::terminateContextFinishedRequests(InTransList& inTransmissionRequests) { NVTX3_SCOPED_RANGE(terminateContextFinishedRequests); for (auto it = inTransmissionRequests.begin(); it != inTransmissionRequests.end();) { - auto req = *it; + auto& item = *it; + auto req = item.request; if (req->isDisaggContextCompleteState()) { - mModel->terminateRequest(req); + // If lastBlockId was tracked, unpin it. Otherwise, just terminate. 
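+                // (When lastBlockId is set, the request was already terminated in
+                // populateNewResponses at response time; only the pinned blocks
+                // remain to be released here, and they stay in the reuse tree for
+                // later prefix hits rather than being freed outright.)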
+ auto kvMgr = mModel->getKVCacheManager(); + if (kvMgr && item.lastBlockId.has_value()) + { + kvMgr->unpinBlocksById(item.lastBlockId.value()); + } + else + { + mModel->terminateRequest(req); + } it = inTransmissionRequests.erase(it); } else @@ -2200,7 +2210,7 @@ void Executor::Impl::appendNewResponses(std::vector&& newResponses) } Executor::Impl::RequestList Executor::Impl::populateNewResponses( - RequestList& activeRequests, RequestList& inTransmissionRequests, std::vector& newResponses) + RequestList& activeRequests, InTransList& inTransmissionRequests, std::vector& newResponses) { NVTX3_SCOPED_RANGE(populateNewResponses); RequestList finishedRequests; @@ -2223,7 +2233,18 @@ Executor::Impl::RequestList Executor::Impl::populateNewResponses( // move the in transmission requests to another tracker if (llmReq->isDisaggContextTransmissionState()) { - inTransmissionRequests.push_back(*it); + // Save either lastBlockId (reuse enabled and no VSWA) or just the request + std::optional lastBlockId{}; + auto kvMgr = mModel->getKVCacheManager(); + if (kvMgr && kvMgr->isEnableBlockReuse() && !kvMgr->getBlockManager().isVariableWindow()) + { + if (auto last = kvMgr->getLastBlockId(llmReq->mRequestId)) + { + lastBlockId = last.value(); + } + mModel->terminateRequest(llmReq); + } + inTransmissionRequests.push_back(InTransmissionItem{*it, lastBlockId}); } finishedRequests.push_back(*it); it = activeRequests.erase(it); @@ -2252,7 +2273,7 @@ void Executor::Impl::executionLoop() std::chrono::time_point iterEnd; bool firstIteration{true}; RequestList activeRequests; - RequestList inTransmissionRequests; + InTransList inTransmissionRequests; std::vector newResponses; while (!mShutdown || !activeRequests.empty()) { diff --git a/cpp/tensorrt_llm/executor/executorImpl.h b/cpp/tensorrt_llm/executor/executorImpl.h index 7d34cbdf382..5a30e0c8a0a 100644 --- a/cpp/tensorrt_llm/executor/executorImpl.h +++ b/cpp/tensorrt_llm/executor/executorImpl.h @@ -79,6 +79,14 @@ class Executor::Impl using LlmRequestPtr = std::shared_ptr; using RequestList = std::list; + struct InTransmissionItem + { + LlmRequestPtr request; + std::optional lastBlockId; // present when reuse enabled and not variable window + }; + + using InTransList = std::list; + public: Impl(std::filesystem::path const& modelPath, std::optional const& encoderModelPath, [[maybe_unused]] ModelType modelType, ExecutorConfig const& executorConfig); @@ -206,7 +214,7 @@ class Executor::Impl void terminateCancelledRequests(RequestList& activeRequests); - void terminateContextFinishedRequests(RequestList& inTransmissionRequests); + void terminateContextFinishedRequests(InTransList& inTransmissionRequests); void appendNewResponses(std::vector&& newResponses); @@ -215,7 +223,7 @@ class Executor::Impl /// and returned for bookkeeping. /// @return A list of requests that have completed. 
RequestList populateNewResponses( - RequestList& activeRequests, RequestList& inTransmissionRequests, std::vector& newResponses); + RequestList& activeRequests, InTransList& inTransmissionRequests, std::vector& newResponses); void executionLoop(); diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index 03981210e46..39f26df671c 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -216,7 +216,7 @@ def __init__(self, # kv cache events self.kv_cache_manager = self.resource_manager.resource_managers.get( ResourceManagerType.KV_CACHE_MANAGER) - self.block_reuse_enabled = self.kv_cache_manager.enable_block_reuse + self.block_reuse_enabled = True if self.kv_cache_manager is not None and self.kv_cache_manager.enable_block_reuse else False self.enable_kv_cache_events = self.kv_cache_manager is not None and self.kv_cache_manager.event_buffer_max_size > 0 self.enable_kv_cache_reuse = self.kv_cache_manager is not None and self.kv_cache_manager.enable_block_reuse From 372a8432a4176ca356d35fc96c684d1ab3911a7f Mon Sep 17 00:00:00 2001 From: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com> Date: Thu, 18 Sep 2025 07:15:59 -0700 Subject: [PATCH 04/17] Review comment Signed-off-by: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com> --- cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp index 19048f7767d..95e226ae380 100644 --- a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp @@ -1104,10 +1104,10 @@ std::optional> WindowBlockManager::findBlocksInReu auto blockedUniqueTokens = chopVectorIntoBlocks(blockKey.uniqueTokens, blockKey.uniqueTokens.size(), mTokensPerBlock, true); std::vector blockKeys; - for (auto const& blockedUniqueTokens : blockedUniqueTokens) + for (auto const& blockedUniqueTokensList : blockedUniqueTokens) { blockKeys.push_back(blockKey); - blockKeys.back().uniqueTokens = blockedUniqueTokens; + blockKeys.back().uniqueTokens = blockedUniqueTokensList; } auto searchRoot = mCachedBlocksRoot; for (auto const& blockKey : blockKeys) From 832202ca93bbf39cc5e41a2ada88bd1b9e7439bc Mon Sep 17 00:00:00 2001 From: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com> Date: Thu, 18 Sep 2025 09:14:06 -0700 Subject: [PATCH 05/17] Fix tests Signed-off-by: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com> --- .../batch_manager/kvCacheManager.cpp | 5 +- .../batch_manager/kvCacheManagerTest.cpp | 58 +++++-------------- 2 files changed, 16 insertions(+), 47 deletions(-) diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp index 95e226ae380..d9dcc9a20e2 100644 --- a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp @@ -2296,8 +2296,9 @@ void KVCacheManager::storeNewBlock(LlmRequest const& llmRequest) void KVCacheManager::removeSequence(RequestIdType requestId, OptionalRef llmRequest) { TLLM_LOG_TRACE("[%s]::%s start", isCrossKv() ? 
"CROSS" : "SELF", __PRETTY_FUNCTION__); - if (mBlockManager.getNumPools() == 1 - && llmRequest->getLlmRequestType() == LlmRequestType::LLMREQUEST_TYPE_CONTEXT_ONLY && mEnableBlockReuse) + if (mBlockManager.getNumPools() == 1 && llmRequest + && llmRequest->getLlmRequestType() == LlmRequestType::LLMREQUEST_TYPE_CONTEXT_ONLY && mEnableBlockReuse + && !mBlockManager.isVariableWindow()) { pinBlocks(requestId); } diff --git a/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp b/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp index fdc5a175b0f..c42be64cb75 100644 --- a/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp +++ b/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp @@ -634,53 +634,21 @@ TEST_F(KVCacheManagerTest, FindBlocksInReuseTreeByBlockKeysTest) auto cacheBlockIds = kvCacheManager.getSequence(requestId).getCacheBlockIds(maxAttentionWindow).at(beamIdx); EXPECT_THAT(cacheBlockIds, ::testing::ElementsAreArray({0, 1, 2})); - // Print all the block ids fromAllBlockIds - auto blockRange = BlockRange::fromAllBlockIds(kvCacheManager, requestId); - for (auto& block : blockRange.getBlockIds()) - { - std::cout << block << " "; - } - std::cout << std::endl; kvCacheManager.removeSequence(requestId, llmRequest0); - std::cout << "Removed sequence 0" << std::endl; - - requestId = 1; - auto llmRequest1 = std::make_shared(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming); - kvCacheManager.addSequence(requestId, inputLength, beamWidth, llmRequest1); - std::cout << "Added sequence 1" << std::endl; - cacheBlockIds = kvCacheManager.getSequence(requestId).getCacheBlockIds(maxAttentionWindow).at(beamIdx); - std::cout << "Cache Block IDs: "; - for (auto& block : cacheBlockIds) - { - std::cout << block << " "; - } - std::cout << std::endl; - EXPECT_THAT(cacheBlockIds, ::testing::ElementsAreArray({0, 1, 2})); - auto blockRange2 = BlockRange::fromAllBlockIds(kvCacheManager, requestId); - std::cout << "All Block IDs: "; - for (auto& block : blockRange2.getBlockIds()) - { - std::cout << block << " "; - } - std::cout << std::endl; - - auto blockRange3 = BlockRange::fromNewlyAllocatedBlockIds(kvCacheManager, requestId); - std::cout << "Newly Allocated Block IDs: "; - for (auto& block : blockRange2.getBlockIds()) - { - std::cout << block << " "; - } - std::cout << std::endl; - - // BlockKey emptyKey{}; - // auto result = blockManager.findBlocksInReuseTreeByBlockKey(emptyKey, maxAttentionWindow); - // ASSERT_TRUE(result.has_value()); - // EXPECT_EQ((*result)->getBlockId(), KVCacheBlock::kCachedBlocksRootId); - // BlockKey key{BlockKey{block0->getHash()}}; - // result = blockManager.findBlocksInReuseTreeByBlockKey(key, maxAttentionWindow); - // ASSERT_TRUE(result.has_value()); - // EXPECT_EQ((*result)->getBlockId(), block0->getBlockId()); + inputTokens->pop_back(); + BlockKey fullKey{*inputTokens}; + auto const foundFull = kvCacheManager.findBlocksInReuseTreeByBlockKey(fullKey, maxAttentionWindow); + ASSERT_TRUE(foundFull.has_value()); + ASSERT_NE(foundFull.value(), nullptr); + auto const& lastBlock = foundFull.value(); + + // Check the chain back to previous blocks + auto const prev2 = lastBlock->getPrevBlock(); + ASSERT_NE(prev2, nullptr); + auto const prev1 = prev2->getPrevBlock(); + ASSERT_NE(prev1, nullptr); + EXPECT_EQ(prev1->getPrevBlock(), nullptr); } #ifdef ENABLE_FP4 From 66cbba0c78726d45a70c88cfa6931daad9b155a3 Mon Sep 17 00:00:00 2001 From: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com> Date: Fri, 19 Sep 2025 10:46:08 -0700 Subject: [PATCH 
06/17] Fix pp bugs Signed-off-by: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com> --- tensorrt_llm/_torch/pyexecutor/py_executor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index 39f26df671c..e251bca9549 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -1942,7 +1942,7 @@ def _handle_responses(self): @nvtx_range("_terminate_ctx_finished_requests") def _terminate_ctx_finished_requests(self): for request, block_id in self.ctx_in_transmission_requests[:]: - if request.is_disagg_context_complete_state: + if request.is_disagg_context_complete_state and request.is_finished: if not self.block_reuse_enabled or self.kv_cache_manager.is_vswa: self._terminate_request(request) else: From 6eb8abbd6af176fb46de13419b769b4bac4157f2 Mon Sep 17 00:00:00 2001 From: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com> Date: Sun, 21 Sep 2025 15:46:46 -0700 Subject: [PATCH 07/17] Fix eagle Signed-off-by: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com> --- .../tensorrt_llm/batch_manager/kvCacheUtils.h | 10 +----- .../batch_manager/cacheFormatter.cpp | 33 ++++++++++++++++--- .../batch_manager/dataTransceiver.cpp | 4 +-- .../batch_manager/kvCacheManager.cpp | 7 +--- .../nanobind/batch_manager/kvCacheManager.cpp | 1 + tensorrt_llm/_torch/pyexecutor/py_executor.py | 6 ++-- .../_torch/pyexecutor/resource_manager.py | 3 ++ 7 files changed, 40 insertions(+), 24 deletions(-) diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheUtils.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheUtils.h index 0e7bad3a585..3a5fb4f41dc 100644 --- a/cpp/include/tensorrt_llm/batch_manager/kvCacheUtils.h +++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheUtils.h @@ -40,14 +40,6 @@ class BlockRange return BlockRange(cacheManager, blockIds, requestId); } - static BlockRange fromNewlyAllocatedBlockIds( - BaseKVCacheManager const& cacheManager, LlmRequest::RequestIdType requestId) - { - auto const windowSize = firstWindowSize(cacheManager); - auto const blockIds = cacheManager.getNewlyAllocatedBlockIds(requestId, windowSize); - return BlockRange(cacheManager, blockIds, requestId); - } - static BlockRange fromReuseTree( BaseKVCacheManager& cacheManager, BlockKey const& lastBlockKey, int32_t indexFromEnd) { @@ -62,11 +54,11 @@ class BlockRange blockIds.reserve(numBlocksToCollect); for (int32_t i = 0; i < numBlocksToCollect; ++i) { + TLLM_CHECK_WITH_INFO(lastBlock->getPrevBlock(), "last block has no prev block"); blockIds.push_back(lastBlock->getBlockId()); if (i + 1 < numBlocksToCollect) { lastBlock = lastBlock->getPrevBlock(); - TLLM_CHECK_WITH_INFO(lastBlock, "Previous block not found while traversing reuse tree"); } } // Reverse to chronological order: oldest to newest diff --git a/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp b/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp index cecfa62df41..e2ab9641079 100644 --- a/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp +++ b/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp @@ -32,6 +32,7 @@ #include "tensorrt_llm/executor/executor.h" #include "tensorrt_llm/runtime/iTensor.h" #include "tensorrt_llm/runtime/utils/mpiUtils.h" +#include #include #include #include @@ -57,11 +58,34 @@ BlockRange getBlockRangeForSending( BlockRange getBlockRangeForReceiving( BaseKVCacheManager* cacheManager, LlmRequest const& llmRequest, bool srcEnableBlockReuse) { - auto poolNum = 
cacheManager->getBlockManager().getNumPools(); - if (poolNum == 1 && cacheManager->isEnableBlockReuse() && srcEnableBlockReuse) + if (poolNum == 1 && srcEnableBlockReuse) { - return BlockRange::fromNewlyAllocatedBlockIds(*cacheManager, llmRequest.mRequestId); + // Build from all block ids, then slice off the reused blocks so we only transfer newly allocated ones. + constexpr SizeType32 beam{0}; + auto range = BlockRange::fromAllBlockIds(*cacheManager, llmRequest.mRequestId, beam); + auto const& allBlockIds = range.getBlockIds(); + auto const totalBlocks = static_cast(allBlockIds.size()); + // Derive reused blocks count from number of unique prepopulated tokens + auto const tokensPerBlock = cacheManager->getBlockManager().getTokensPerBlock(); + auto const prepopulatedTokens = llmRequest.getPrepopulatedPromptLen(); + auto const totalUniqueTokens = llmRequest.getPromptLen(); + auto const usedBlocks = std::min( + static_cast((totalUniqueTokens + tokensPerBlock - 1) / tokensPerBlock), totalBlocks); + auto const reusedBlocks = std::min( + static_cast((prepopulatedTokens + tokensPerBlock - 1) / tokensPerBlock), usedBlocks); + + std::vector newBlockIds; + if (reusedBlocks < usedBlocks) + { + newBlockIds.assign(allBlockIds.begin() + reusedBlocks, allBlockIds.begin() + usedBlocks); + } + else + { + newBlockIds.clear(); + } + range.setBlockIds(std::move(newBlockIds)); + return range; } else { @@ -328,6 +352,7 @@ void CacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& sessio auto bufferEleSizes = getBufferSizeForTarget(); auto result = mCacheTransBufferManager->getOrAllocateSendBuffers( cacheBufferId, static_cast(bufferTargetNum), bufferEleSizes, bufferManager); + auto& outputSplitCaches = std::get<0>(result); auto& bufferCoverTargetNum = std::get<1>(result); auto& onlyUseDynamicBuffer = std::get<2>(result); @@ -380,7 +405,6 @@ void CacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& sessio } else { - // If cacheIdx< bufferCoverTargetNum, the ouputSplitCaches.at(cacheIdx) is allocated by cudaMallocAsync, // which is unable to be transferred by UCX GPU-direct RDMA. We need copy the data to pre-allocated // cudaMalloc buffer,and then start send. 
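The slicing above is pure block arithmetic, so a standalone worked example (toy values only, mirroring the same ceiling-division formula) shows which sub-range ends up transferred when part of the prompt was served from reuse:

    #include <algorithm>
    #include <cassert>

    // With tokensPerBlock = 8, a 17-token prompt occupies ceil(17/8) = 3 blocks;
    // if 8 tokens were prepopulated from reuse, ceil(8/8) = 1 block is skipped
    // and only the half-open range [1, 3) of block ids is received.
    int main()
    {
        int const tokensPerBlock = 8;
        int const promptLen = 17, prepopulated = 8, totalBlocks = 3;

        int const usedBlocks
            = std::min((promptLen + tokensPerBlock - 1) / tokensPerBlock, totalBlocks);
        int const reusedBlocks
            = std::min((prepopulated + tokensPerBlock - 1) / tokensPerBlock, usedBlocks);

        assert(usedBlocks == 3 && reusedBlocks == 1);
        // Newly allocated range to transfer: [reusedBlocks, usedBlocks) == blocks 1 and 2.
    }

Clamping reusedBlocks to usedBlocks is what keeps the sliced range well-formed when the whole prompt hits the reuse tree, in which case nothing is transferred.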
@@ -402,7 +426,6 @@ void CacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& sessio auto sendSize = std::min(remainSendSize, sendBufferEleSize); auto copySlice = runtime::ITensor::slice( outputSplitCaches[bufferIdx], needSendSize - remainSendSize, sendSize); - auto copyTargetSlice = runtime::ITensor::slice(sendUseAllocBuffer, 0, sendSize); bufferManager.copy(*copySlice, *copyTargetSlice); bufferManager.getStream().synchronize(); diff --git a/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp b/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp index c92cddba263..3ee6fbb25c3 100644 --- a/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp +++ b/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp @@ -699,9 +699,9 @@ class CacheReceiver::Impl lastBlockKey.extraKeys = std::move(extraKeys); } // Compute indexFromEnd from the number of requested blocks - size_t requestedBlockSize = requestedBlockRange.getBlockIds().size(); + int32_t requestedBlockSize = requestedBlockRange.getBlockIds().size(); TLLM_CHECK_WITH_INFO(requestedBlockSize > 0, "requestedBlockSize must be > 0"); - int32_t indexFromEnd = static_cast(requestedBlockSize - 1); + int32_t indexFromEnd = requestedBlockSize - 1; requestInfo = RequestInfo(requestId, mSelfState, indexFromEnd, lastBlockKey); } diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp index d9dcc9a20e2..ccb6f92ecea 100644 --- a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp @@ -1103,6 +1103,7 @@ std::optional> WindowBlockManager::findBlocksInReu std::lock_guard lock(mCachedBlocksRootMutex); auto blockedUniqueTokens = chopVectorIntoBlocks(blockKey.uniqueTokens, blockKey.uniqueTokens.size(), mTokensPerBlock, true); + std::vector blockKeys; for (auto const& blockedUniqueTokensList : blockedUniqueTokens) { @@ -2296,12 +2297,6 @@ void KVCacheManager::storeNewBlock(LlmRequest const& llmRequest) void KVCacheManager::removeSequence(RequestIdType requestId, OptionalRef llmRequest) { TLLM_LOG_TRACE("[%s]::%s start", isCrossKv() ? 
"CROSS" : "SELF", __PRETTY_FUNCTION__); - if (mBlockManager.getNumPools() == 1 && llmRequest - && llmRequest->getLlmRequestType() == LlmRequestType::LLMREQUEST_TYPE_CONTEXT_ONLY && mEnableBlockReuse - && !mBlockManager.isVariableWindow()) - { - pinBlocks(requestId); - } auto sequenceNode = [this, requestId] { std::scoped_lock lock(mSequencesMtx); diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp index 6a0493bc592..8cabb0cb3c6 100644 --- a/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp @@ -353,6 +353,7 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(nb::module_& m) .def("add_token", &BaseKVCacheManager::addToken, nb::call_guard()) .def("add_sequence", &BaseKVCacheManager::addSequence, nb::call_guard()) .def("remove_sequence", &BaseKVCacheManager::removeSequence, nb::call_guard()) + .def("pin_blocks", &BaseKVCacheManager::pinBlocks, nb::call_guard()) .def("scheduling_remove_sequence", &BaseKVCacheManager::schedulingRemoveSequence, nb::call_guard()) .def( diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index e251bca9549..ac5901ccfd0 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -1812,6 +1812,8 @@ def _do_terminate_request(self, request: LlmRequest): request, cache_block_ids): self.resource_manager.free_resources(request) else: + if self.block_reuse_enabled and not self.kv_cache_manager.is_vswa and request.is_context_only_request: + self.kv_cache_manager.pin_blocks(request.py_request_id) self.resource_manager.free_resources(request) @nvtx_range("_handle_canceled_requests") @@ -1920,7 +1922,7 @@ def _handle_responses(self): request_done = request.is_finished new_responses.append((req_id, response)) - if request_done: + if request_done or request.is_disagg_context_complete_state: if request.is_disagg_context_transmission_state: if self.block_reuse_enabled and not self.kv_cache_manager.is_vswa: requests_to_terminate.append(request) @@ -1942,7 +1944,7 @@ def _handle_responses(self): @nvtx_range("_terminate_ctx_finished_requests") def _terminate_ctx_finished_requests(self): for request, block_id in self.ctx_in_transmission_requests[:]: - if request.is_disagg_context_complete_state and request.is_finished: + if request.is_disagg_context_complete_state: if not self.block_reuse_enabled or self.kv_cache_manager.is_vswa: self._terminate_request(request) else: diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py index 3e201b1a8ab..66e2299d22b 100644 --- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py +++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py @@ -1011,6 +1011,9 @@ def _validate_and_adjust_attention_windows( else: return blocks_per_window, max_seq_len, max_attention_window_vec + def pin_blocks(self, request_id: int): + self.impl.pin_blocks(request_id) + def _set_temp_attention_window_inputs( self) -> Optional[TempAttentionWindowInputs]: """ From bd38219e170b867c3bb73e90470d0303612745e9 Mon Sep 17 00:00:00 2001 From: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com> Date: Mon, 22 Sep 2025 22:43:59 -0700 Subject: [PATCH 08/17] fix pp bugs Signed-off-by: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com> --- .../batch_manager/kvCacheManager.h | 37 +++++-- .../tensorrt_llm/batch_manager/kvCacheUtils.h | 4 
+- .../batch_manager/cacheFormatter.cpp | 5 +- .../batch_manager/kvCacheManager.cpp | 101 ++++++++++++++++-- .../nanobind/batch_manager/kvCacheManager.cpp | 17 ++- .../pybind/batch_manager/kvCacheManager.cpp | 17 ++- tensorrt_llm/_torch/pyexecutor/py_executor.py | 51 +++++---- .../_torch/pyexecutor/resource_manager.py | 11 +- 8 files changed, 195 insertions(+), 48 deletions(-) diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h index b670a18c90c..02199b119f9 100644 --- a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h +++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h @@ -607,6 +607,9 @@ class WindowBlockManager //! \brief Get the ids of all newly allocated (not reused) blocks for the sequence. std::vector<SizeType32> getNewlyAllocatedBlockIds(GenerationRequest const& sequence) const; + [[nodiscard]] std::optional<KVCacheBlock::IdType> storeBlocksForReuse( + GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest, bool pinBlocks = false); + void storeNewBlock(GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest); //! \brief Pin blocks associated with a sequence to prevent eviction. @@ -787,8 +790,10 @@ class WindowBlockManager //! \brief Store blocks in cached blocks. //! \param blockKeys Key of each block. //! \param blockIds Id of each block. - //! \return Number of actual blocks stored. - SizeType32 storeBlocks(std::vector<BlockKey> const& blockKeys, std::vector<KVCacheBlock::IdType> const& blockIds); + //! \param pinBlocks If true, increment ref count for blocks while storing (pin on store). + //! \return The id of the last block stored in the reuse tree, if any were stored. + [[nodiscard]] std::optional<KVCacheBlock::IdType> storeBlocks(std::vector<BlockKey> const& blockKeys, + std::vector<KVCacheBlock::IdType> const& blockIds, bool pinBlocks = false); [[nodiscard]] bool verifyQueueIntegrity(); @@ -981,7 +986,11 @@ class BlockManager std::vector<SizeType32> getNewlyAllocatedBlockIds( GenerationRequest const& sequence, SizeType32 windowSize) const; - void releaseBlocks(GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest = std::nullopt); + [[nodiscard]] std::optional<KVCacheBlock::IdType> releaseBlocks( + GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest = std::nullopt, bool pinBlocks = false); + + [[nodiscard]] std::optional<KVCacheBlock::IdType> storeBlocksForReuse( + GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest = std::nullopt, bool pinBlocks = false); void schedulingReleaseBlocks(LlmRequest::RequestIdType requestId); @@ -1010,10 +1019,10 @@ class BlockManager void offloadBlock(BlockPtr const& block, SizeType32 windowSize, executor::KvCacheTransferMode mode = executor::KvCacheTransferMode::DRAM, std::string const& directory = ""); - void storeBlocks(std::vector<BlockKey> const& blockKeys, std::vector<KVCacheBlock::IdType> const& blockIds, - SizeType32 windowSize) + [[nodiscard]] std::optional<KVCacheBlock::IdType> storeBlocks(std::vector<BlockKey> const& blockKeys, + std::vector<KVCacheBlock::IdType> const& blockIds, SizeType32 windowSize, bool pinBlocks = false) { - mWindowBlockManagers.at(windowSize).storeBlocks(blockKeys, blockIds); + return mWindowBlockManagers.at(windowSize).storeBlocks(blockKeys, blockIds, pinBlocks); } [[nodiscard]] bool verifyQueueIntegrity(SizeType32 windowSize); @@ -1350,8 +1359,8 @@ class BaseKVCacheManager OptionalRef<LlmRequest> llmRequest = std::nullopt) = 0; - virtual void removeSequence( - LlmRequest::RequestIdType requestId, OptionalRef<LlmRequest const> llmRequest = std::nullopt) + [[nodiscard]] virtual std::optional<KVCacheBlock::IdType> removeSequence(LlmRequest::RequestIdType requestId, + OptionalRef<LlmRequest const> llmRequest = std::nullopt, bool pinOnRelease = false) = 0; virtual void schedulingRemoveSequence(LlmRequest::RequestIdType requestId) = 0; @@ -1395,6 +1404,11 @@
class BaseKVCacheManager //! \details This block become reusable from next step. virtual void storeNewBlock(LlmRequest const& llmRequest) = 0; + /// \brief Store the blocks of a given request id for reuse + [[nodiscard]] virtual std::optional<KVCacheBlock::IdType> storeBlocksForReuse( + LlmRequest::RequestIdType requestId, OptionalRef<LlmRequest const> llmRequest, bool pinBlocks = false) + = 0; + //! \brief Get the block ids of a request [per beam] **for a given window size block manager** [[nodiscard]] virtual std::vector<std::vector<SizeType32>> const& getCacheBlockIds( LlmRequest::RequestIdType requestId, SizeType32 windowSize) const @@ -1664,8 +1678,8 @@ class KVCacheManager : public BaseKVCacheManager void addSequence(LlmRequest::RequestIdType requestId, SizeType32 inputLength, SizeType32 beamWidth, OptionalRef<LlmRequest> llmRequest = std::nullopt) override; - void removeSequence( - LlmRequest::RequestIdType requestId, OptionalRef<LlmRequest const> llmRequest = std::nullopt) override; + [[nodiscard]] std::optional<KVCacheBlock::IdType> removeSequence(LlmRequest::RequestIdType requestId, + OptionalRef<LlmRequest const> llmRequest = std::nullopt, bool pinOnRelease = false) override; void schedulingRemoveSequence(LlmRequest::RequestIdType requestId) override; @@ -1725,6 +1739,9 @@ class KVCacheManager : public BaseKVCacheManager //! \brief Store newest blocks for reuse void storeNewBlock(LlmRequest const& llmRequest) override; + [[nodiscard]] std::optional<KVCacheBlock::IdType> storeBlocksForReuse( + LlmRequest::RequestIdType requestId, OptionalRef<LlmRequest const> llmRequest, bool pinBlocks = false) override; + [[nodiscard]] static SizeType32 getSinkBubbleLength(SizeType32 sinkTokenLen, SizeType32 tokensPerBlock); [[nodiscard]] SizeType32 getMaxCapacityBatchSize(SizeType32 inputLength, SizeType32 outputLength) const override; diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheUtils.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheUtils.h index 3a5fb4f41dc..c1c686f6f28 100644 --- a/cpp/include/tensorrt_llm/batch_manager/kvCacheUtils.h +++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheUtils.h @@ -54,10 +54,12 @@ class BlockRange blockIds.reserve(numBlocksToCollect); for (int32_t i = 0; i < numBlocksToCollect; ++i) { - TLLM_CHECK_WITH_INFO(lastBlock->getPrevBlock(), "last block has no prev block"); + TLLM_CHECK_WITH_INFO( + lastBlock->getBlockId() != KVCacheBlock::kCachedBlocksRootId, "last block is the cached-blocks root and has no valid block id"); blockIds.push_back(lastBlock->getBlockId()); if (i + 1 < numBlocksToCollect) { + TLLM_CHECK_WITH_INFO(lastBlock->getPrevBlock(), "last block has no prev block"); lastBlock = lastBlock->getPrevBlock(); } } diff --git a/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp b/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp index e2ab9641079..8460d7ca364 100644 --- a/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp +++ b/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp @@ -82,7 +82,10 @@ BlockRange getBlockRangeForReceiving( } else { - newBlockIds.clear(); + if (usedBlocks > 0 && usedBlocks <= totalBlocks) + { + newBlockIds.push_back(allBlockIds[usedBlocks - 1]); + } } range.setBlockIds(std::move(newBlockIds)); return range; diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp index ccb6f92ecea..212a2b4bce0 100644 --- a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp @@ -303,7 +303,8 @@ void KVCacheBlock::incRefCount() void KVCacheBlock::decRefCount() { - TLLM_CHECK_WITH_INFO(hasRefs(), "Can't remove link from block that is not allocated"); + TLLM_CHECK_WITH_INFO( hasRefs(), "Can't remove link from block (id=%d) that is not 
allocated", static_cast(mBlockId)); mRefCount--; } @@ -774,7 +775,7 @@ void BlockManager::storeContextBlocks(GenerationRequest& sequence, LlmRequest co auto blockedUniqueTokens = chopVectorIntoBlocks(uniqueTokens, uniqueTokens.size() - 1, getTokensPerBlock(), false); auto blockKeys = buildBlockKeys(blockedUniqueTokens, llmRequest); - storeBlocks(std::move(blockKeys), cacheBlockIds[beamIdx], windowSize); + (void) mWindowBlockManagers.at(windowSize).storeBlocks(std::move(blockKeys), cacheBlockIds[beamIdx]); } } @@ -1116,6 +1117,7 @@ std::optional> WindowBlockManager::findBlocksInReu auto [partialMatch, numMatched, matchingBlock] = searchRoot != nullptr ? searchRoot->findMatchingBlock(blockKey, true, true) : std::make_tuple(false, 0, nullptr); + if (matchingBlock == nullptr) { return std::nullopt; @@ -1461,8 +1463,13 @@ void WindowBlockManager::allocateBlock(GenerationRequest& sequence, bool shareAm } } +<<<<<<< HEAD SizeType32 WindowBlockManager::storeBlocks( std::vector const& blockKeys, std::vector const& blockIds) +======= +std::optional WindowBlockManager::storeBlocks( + std::vector const& blockKeys, std::vector const& blockIds, bool pinBlocks) +>>>>>>> 07a060666 (fix pp bugs) { SizeType32 numBlocksStoredForReuse = 0; std::lock_guard lock(mCachedBlocksRootMutex); @@ -1474,6 +1481,7 @@ SizeType32 WindowBlockManager::storeBlocks( auto numBlocks = blockKeys.size(); std::vector storedBlocks; + std::optional lastStoredId = std::nullopt; for (std::size_t blockCnt = 0; blockCnt < numBlocks; ++blockCnt) { auto const bid = blockIds[blockCnt]; @@ -1521,12 +1529,21 @@ SizeType32 WindowBlockManager::storeBlocks( searchRoot = block; numBlocksStoredForReuse++; } + if (pinBlocks) + { + searchRoot->incRefCount(); + } + lastStoredId = searchRoot->getBlockId(); } if (mEventManager) { mEventManager->enqueueStoredEvent(storedBlocks, mWindowSize); } +<<<<<<< HEAD return numBlocksStoredForReuse; +======= + return lastStoredId; +>>>>>>> 07a060666 (fix pp bugs) } void BlockManager::replaceSharedBlock(GenerationRequest& sequence, SizeType32 windowSize, SizeType32 blockIdx) @@ -1637,21 +1654,53 @@ std::deque BlockManager::getLatestEvents(std::optionalgetEvents(timeout) : std::deque{}; } -void BlockManager::releaseBlocks(GenerationRequest& sequence, OptionalRef llmRequest) +std::optional BlockManager::storeBlocksForReuse( + GenerationRequest& sequence, OptionalRef llmRequest, bool pinBlocks) { + std::optional lastStoredId = std::nullopt; + for (auto& [_, manager] : mWindowBlockManagers) + { + lastStoredId = manager.storeBlocksForReuse(sequence, llmRequest, pinBlocks); + } + return lastStoredId; +} + +std::optional BlockManager::releaseBlocks( + GenerationRequest& sequence, OptionalRef llmRequest, bool pinBlocks) +{ +<<<<<<< HEAD // Released block will be stored when reuse is enabled. // Reuse is implied to be enabled if llmRequest is provided. +======= + // When releasing the blocks for a sequence, we store those blocks for potential reuse only if: + // - Block reuse is enabled. + // - A request was provided to this function call to identify which tokens these blocks cover + // - Beam search is NOT enabled <=> beam width == 1 + // - The sequence was not marked for use with cyclic kv-cache when it was added (when its context is too long to fit + // the max attention window). + // - The sequence did not switch to cyclic kv-cache during generation phase. + // A sequence is cyclic if its *minimum window size* is crossed, even if other window sizes were not reached. + // - The sequence is not a dummy request. 
+ bool const storeBlocksForReuse = sequence.getBeamWidth() == 1 && llmRequest.has_value() && !sequence.isCyclic() + && !llmRequest->isDummyRequest(); + std::optional lastStoredId = std::nullopt; +>>>>>>> 07a060666 (fix pp bugs) for (auto& [_, manager] : mWindowBlockManagers) { if (!llmRequest.has_value() || llmRequest->isDummyRequest() || sequence.getBeamWidth() > 1) { +<<<<<<< HEAD manager.releaseBlocks(sequence, std::nullopt); } else { manager.releaseBlocks(sequence, llmRequest); +======= + lastStoredId = manager.storeBlocksForReuse(sequence, llmRequest, pinBlocks); +>>>>>>> 07a060666 (fix pp bugs) } } + return lastStoredId; } void BlockManager::pinBlocks(GenerationRequest& sequence) @@ -1738,7 +1787,7 @@ void WindowBlockManager::storeNewBlock(GenerationRequest& sequence, OptionalRef< { // store all blocks TLLM_LOG_DEBUG("%s::storeNewBlock - store all blocks", mLogPrefix.c_str()); - storeBlocks(std::move(blockKeys), cacheBlockIds[beamIdx]); + (void) storeBlocks(std::move(blockKeys), cacheBlockIds[beamIdx]); return; } @@ -1749,7 +1798,7 @@ void WindowBlockManager::storeNewBlock(GenerationRequest& sequence, OptionalRef< if (prevBlock->getPrevBlock() == nullptr) { TLLM_LOG_DEBUG("%s::storeNewBlock - store all blocks", mLogPrefix.c_str()); - storeBlocks(std::move(blockKeys), cacheBlockIds[beamIdx]); + (void) storeBlocks(std::move(blockKeys), cacheBlockIds[beamIdx]); return; } @@ -1760,10 +1809,30 @@ void WindowBlockManager::storeNewBlock(GenerationRequest& sequence, OptionalRef< return; } TLLM_LOG_DEBUG("%s::storeNewBlock - store the last block", mLogPrefix.c_str()); - storeBlocks(std::move(blockKeys), cacheBlockIds[beamIdx]); + (void) storeBlocks(std::move(blockKeys), cacheBlockIds[beamIdx]); } +<<<<<<< HEAD void WindowBlockManager::releaseBlocks(GenerationRequest& sequence, OptionalRef llmRequest) +======= +std::optional WindowBlockManager::storeBlocksForReuse( + GenerationRequest& sequence, OptionalRef llmRequest, bool pinBlocks) +{ + auto constexpr beamIdx = 0; + auto const& uniqueTokens = llmRequest->getUniqueTokens(beamIdx); + auto const& cacheBlockIds = sequence.getCacheBlockIds(mWindowSize); + + // TODO: get the caller to mark tokens as filled / not filled, so that the kv-cache manager doesn't + // have to guess. Only (length - 1) tokens of the sequence have their kv-state recorded in kv-cache. We assume + // the last token's state is not filled yet. + auto const usableSize = static_cast(uniqueTokens.size()) - 1; + auto blockedUniqueTokens = chopVectorIntoBlocks(uniqueTokens, usableSize, mTokensPerBlock, true); + auto blockKeys = buildBlockKeys(blockedUniqueTokens, *llmRequest); + return storeBlocks(std::move(blockKeys), cacheBlockIds[beamIdx], pinBlocks); +} + +void WindowBlockManager::releaseBlocks(GenerationRequest& sequence) +>>>>>>> 07a060666 (fix pp bugs) { auto const requestId = sequence.getRequestId(); @@ -2294,7 +2363,8 @@ void KVCacheManager::storeNewBlock(LlmRequest const& llmRequest) mBlockManager.storeNewBlock(sequence, llmRequest); } -void KVCacheManager::removeSequence(RequestIdType requestId, OptionalRef llmRequest) +std::optional KVCacheManager::removeSequence( + RequestIdType requestId, OptionalRef llmRequest, bool pinBlocks) { TLLM_LOG_TRACE("[%s]::%s start", isCrossKv() ? 
"CROSS" : "SELF", __PRETTY_FUNCTION__); auto sequenceNode = [this, requestId] @@ -2302,18 +2372,31 @@ void KVCacheManager::removeSequence(RequestIdType requestId, OptionalRef lastStoredId = std::nullopt; if (!sequenceNode.empty()) { if (mEnableBlockReuse) { - mBlockManager.releaseBlocks(sequenceNode.mapped(), llmRequest); + lastStoredId = mBlockManager.releaseBlocks(sequenceNode.mapped(), llmRequest, pinBlocks); } else { - mBlockManager.releaseBlocks(sequenceNode.mapped(), std::nullopt); + lastStoredId = mBlockManager.releaseBlocks(sequenceNode.mapped(), std::nullopt, pinBlocks); } } TLLM_LOG_TRACE("[%s]::%s stop", isCrossKv() ? "CROSS" : "SELF", __PRETTY_FUNCTION__); + return lastStoredId; +} + +std::optional KVCacheManager::storeBlocksForReuse( + RequestIdType requestId, OptionalRef llmRequest, bool pinBlocks) +{ + TLLM_LOG_TRACE("[%s]::%s start", isCrossKv() ? "CROSS" : "SELF", __PRETTY_FUNCTION__); + auto& sequence = getSequence(requestId); + std::optional lastStoredId + = mBlockManager.storeBlocksForReuse(sequence, llmRequest, pinBlocks); + TLLM_LOG_TRACE("[%s]::%s stop", isCrossKv() ? "CROSS" : "SELF", __PRETTY_FUNCTION__); + return lastStoredId; } void KVCacheManager::schedulingRemoveSequence(RequestIdType requestId) diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp index 8cabb0cb3c6..ea474ebb154 100644 --- a/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp @@ -67,7 +67,7 @@ std::optional from_torch(std::optiona class PyKvCacheManager : public tbk::BaseKVCacheManager { public: - NB_TRAMPOLINE(tbk::BaseKVCacheManager, 28); + NB_TRAMPOLINE(tbk::BaseKVCacheManager, 29); // using BaseKVCacheManager::BaseKVCacheManager; // Inherit constructors void allocatePools(bool useUvm = false) override @@ -116,10 +116,17 @@ class PyKvCacheManager : public tbk::BaseKVCacheManager NB_OVERRIDE_PURE(addSequence, requestId, inputLength, beamWidth, llmRequest); } - void removeSequence(tb::LlmRequest::RequestIdType requestId, - tensorrt_llm::common::OptionalRef llmRequest = std::nullopt) override + std::optional removeSequence(tb::LlmRequest::RequestIdType requestId, + tensorrt_llm::common::OptionalRef llmRequest = std::nullopt, + bool pinOnRelease = false) override { - NB_OVERRIDE_PURE(removeSequence, requestId, llmRequest); + NB_OVERRIDE_PURE(removeSequence, requestId, llmRequest, pinOnRelease); + } + + std::optional storeBlocksForReuse(tb::LlmRequest::RequestIdType requestId, + tensorrt_llm::common::OptionalRef llmRequest, bool pinBlocks) override + { + NB_OVERRIDE_PURE(storeBlocksForReuse, requestId, llmRequest, pinBlocks); } tbk::GenerationRequest const& getSequence(tb::LlmRequest::RequestIdType requestId) const override @@ -460,6 +467,8 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(nb::module_& m) .def("rewind_kv_cache", &BaseKVCacheManager::rewindKVCache, nb::call_guard()) .def_prop_ro("cross_kv", &BaseKVCacheManager::isCrossKv) .def("store_context_blocks", &BaseKVCacheManager::storeContextBlocks, nb::call_guard()) + .def("store_blocks_for_reuse", &BaseKVCacheManager::storeBlocksForReuse, + nb::call_guard()) .def("get_cache_block_ids", &BaseKVCacheManager::getCacheBlockIds, nb::call_guard()) .def("get_batch_cache_block_ids", &BaseKVCacheManager::getBatchCacheBlockIds, nb::call_guard()) diff --git a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp index 
f83b4137304..f48e092a1ce 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp @@ -103,10 +103,19 @@ class PyKvCacheManager : public tbk::BaseKVCacheManager void, tbk::BaseKVCacheManager, addSequence, requestId, inputLength, beamWidth, llmRequest); } - void removeSequence(tb::LlmRequest::RequestIdType requestId, - tensorrt_llm::common::OptionalRef llmRequest = std::nullopt) override + std::optional removeSequence(tb::LlmRequest::RequestIdType requestId, + tensorrt_llm::common::OptionalRef llmRequest = std::nullopt, + bool pinOnRelease = false) override { - PYBIND11_OVERLOAD_PURE(void, tbk::BaseKVCacheManager, removeSequence, requestId, llmRequest); + PYBIND11_OVERLOAD_PURE(std::optional, tbk::BaseKVCacheManager, removeSequence, + requestId, llmRequest, pinOnRelease); + } + + std::optional storeBlocksForReuse(tb::LlmRequest::RequestIdType requestId, + tensorrt_llm::common::OptionalRef llmRequest, bool pinBlocks) override + { + PYBIND11_OVERLOAD_PURE(std::optional, tbk::BaseKVCacheManager, storeBlocksForReuse, + requestId, llmRequest, pinBlocks); } tbk::GenerationRequest const& getSequence(tb::LlmRequest::RequestIdType requestId) const override @@ -462,6 +471,8 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(py::module_& m) .def("rewind_kv_cache", &BaseKVCacheManager::rewindKVCache, py::call_guard()) .def_property_readonly("cross_kv", &BaseKVCacheManager::isCrossKv) .def("store_context_blocks", &BaseKVCacheManager::storeContextBlocks, py::call_guard()) + .def("store_blocks_for_reuse", &BaseKVCacheManager::storeBlocksForReuse, + py::call_guard()) .def("get_cache_block_ids", &BaseKVCacheManager::getCacheBlockIds, py::call_guard()) .def("get_batch_cache_block_ids", &BaseKVCacheManager::getBatchCacheBlockIds, py::call_guard()) diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index ac5901ccfd0..9f642384647 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -912,16 +912,17 @@ def _executor_loop_pp(self): with torch.cuda.nvtx.range("_handle_previous_batch_pp"): self._update_requests(previous_batch.sample_state) - if self.kv_cache_transceiver and previous_batch.scheduled_ctx_reqs: - self._send_disagg_ctx_cache( - previous_batch.scheduled_ctx_reqs) - self._handle_canceled_requests() self._handle_logits_communication( previous_batch, prev_microbatch_id) - finished_requests = self._handle_responses() + context_requests = [] + if self.kv_cache_transceiver and previous_batch.scheduled_ctx_reqs: + context_requests = previous_batch.scheduled_ctx_reqs + + finished_requests = self._handle_responses( + context_requests) previous_scheduled_batch = previous_batch.sample_state.scheduled_requests self.resource_manager.update_resources( previous_scheduled_batch) @@ -1812,8 +1813,6 @@ def _do_terminate_request(self, request: LlmRequest): request, cache_block_ids): self.resource_manager.free_resources(request) else: - if self.block_reuse_enabled and not self.kv_cache_manager.is_vswa and request.is_context_only_request: - self.kv_cache_manager.pin_blocks(request.py_request_id) self.resource_manager.free_resources(request) @nvtx_range("_handle_canceled_requests") @@ -1883,13 +1882,18 @@ def _handle_first_token_response(self, scheduled_batch): self._enqueue_responses(new_responses) @nvtx_range("_handle_responses") - def _handle_responses(self): + def _handle_responses(self, context_requests: 
List[LlmRequest] = None): new_responses = [] requests_to_terminate = [] new_active_requests = [] logger.debug( f'------before _handle_responses, rank = {self.dist.rank}, output = {self.active_requests}' ) + if context_requests is not None: + context_requests_map = set( + [request.py_request_id for request in context_requests]) + else: + context_requests_map = set() for request in self.active_requests: req_id = request.py_request_id # no responses for dummy request, and finish it @@ -1922,19 +1926,30 @@ def _handle_responses(self): request_done = request.is_finished new_responses.append((req_id, response)) - if request_done or request.is_disagg_context_complete_state: - if request.is_disagg_context_transmission_state: - if self.block_reuse_enabled and not self.kv_cache_manager.is_vswa: - requests_to_terminate.append(request) - self.ctx_in_transmission_requests.append( - (request, - self.kv_cache_manager.get_last_block_id( - request.py_request_id))) - else: + if request_done: + if self.block_reuse_enabled and not self.kv_cache_manager.is_vswa: + if request.py_request_id in context_requests_map: + block_id = self.kv_cache_manager.store_blocks_for_reuse( + request, True) + self.ctx_in_transmission_requests.append( + (request, block_id)) requests_to_terminate.append(request) + else: + if request.is_disagg_context_transmission_state: + self.ctx_in_transmission_requests.append( + (request, None)) + else: + if request.py_request_id in context_requests_map: + self.ctx_in_transmission_requests.append( + (request, None)) + else: + requests_to_terminate.append(request) else: new_active_requests.append(request) - self.active_requests.clear() + + if context_requests is not None and self.kv_cache_transceiver: + self._send_disagg_ctx_cache(context_requests) + self.active_requests.extend(new_active_requests) for request in requests_to_terminate: self._terminate_request(request) diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py index 66e2299d22b..8d30c73efce 100644 --- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py +++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py @@ -545,8 +545,15 @@ def update_resources(self, scheduled_batch: ScheduledRequests): for request in scheduled_batch.context_requests: self.impl.store_context_blocks(request) - def free_resources(self, request: LlmRequest): - self.impl.remove_sequence(request.py_request_id, request) + def free_resources(self, request: LlmRequest, pin_on_release: bool = False): + return self.impl.remove_sequence(request.py_request_id, request, + pin_on_release) + + def store_blocks_for_reuse(self, + request: LlmRequest, + pin_blocks: bool = False): + return self.impl.store_blocks_for_reuse(request.py_request_id, request, + pin_blocks) @staticmethod def calculate_scaling_factor_size_bytes( From 4ccfb4ee65af4f9aa36484ac8121c2f7c836b667 Mon Sep 17 00:00:00 2001 From: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com> Date: Tue, 23 Sep 2025 13:20:04 -0700 Subject: [PATCH 09/17] Fix pp bugs Signed-off-by: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com> --- cpp/tensorrt_llm/batch_manager/llmRequest.cpp | 1 - tensorrt_llm/_torch/pyexecutor/py_executor.py | 62 +++++++++++-------- 2 files changed, 35 insertions(+), 28 deletions(-) diff --git a/cpp/tensorrt_llm/batch_manager/llmRequest.cpp b/cpp/tensorrt_llm/batch_manager/llmRequest.cpp index fb6aa5cc67f..f8b74d7d48e 100644 --- a/cpp/tensorrt_llm/batch_manager/llmRequest.cpp +++ 
b/cpp/tensorrt_llm/batch_manager/llmRequest.cpp @@ -69,7 +69,6 @@ void LlmRequest::createSerializedResult( /// Note that there is some dependency on the order of operations in this method. Modify with care! std::optional LlmRequest::createResult(bool useFastLogits, int32_t mpiWorldRank) { - TLLM_CHECK(!isDisaggContextCompleteState()); if (!(isFinished() || (mIsStreaming && mState == LlmRequestState::kGENERATION_IN_PROGRESS))) { return std::nullopt; diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index 9f642384647..0eb85f86233 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -912,17 +912,24 @@ def _executor_loop_pp(self): with torch.cuda.nvtx.range("_handle_previous_batch_pp"): self._update_requests(previous_batch.sample_state) + if self.block_reuse_enabled and not self.kv_cache_manager.is_vswa and self.kv_cache_transceiver: + for req in previous_batch.scheduled_ctx_reqs: + if req.is_context_only_request and ( + req.is_context_finished + or req.is_finished_due_to_length): + block_id = self.kv_cache_manager.store_blocks_for_reuse( + req, True) + self.ctx_in_transmission_requests.append( + (req, block_id)) + + self._send_disagg_ctx_cache( + previous_batch.scheduled_ctx_reqs) self._handle_canceled_requests() self._handle_logits_communication( previous_batch, prev_microbatch_id) - context_requests = [] - if self.kv_cache_transceiver and previous_batch.scheduled_ctx_reqs: - context_requests = previous_batch.scheduled_ctx_reqs - - finished_requests = self._handle_responses( - context_requests) + finished_requests = self._handle_responses() previous_scheduled_batch = previous_batch.sample_state.scheduled_requests self.resource_manager.update_resources( previous_scheduled_batch) @@ -1105,6 +1112,15 @@ def _executor_loop(self): self._update_request_states(scheduled_batch) self._update_requests(sample_state, self.resource_manager) + if self.block_reuse_enabled and not self.kv_cache_manager.is_vswa and self.kv_cache_transceiver: + for req in scheduled_batch.context_requests: + if req.is_context_only_request and ( + req.is_context_finished + or req.is_finished_due_to_length): + block_id = self.kv_cache_manager.store_blocks_for_reuse( + req, True) + self.ctx_in_transmission_requests.append( + (req, block_id)) if self.kv_cache_transceiver: ctx_transmission_reqs = self._send_disagg_ctx_cache( @@ -1243,6 +1259,16 @@ def _executor_loop_overlap(self): elif self.previous_batch is not None and not use_previous_draft_tokens: self._update_requests(self.previous_batch.sample_state) + if self.block_reuse_enabled and not self.kv_cache_manager.is_vswa and self.kv_cache_transceiver: + for req in self.previous_batch.sample_state.scheduled_requests.context_requests: + if req.is_context_only_request and ( + req.is_context_finished + or req.is_finished_due_to_length): + block_id = self.kv_cache_manager.store_blocks_for_reuse( + req, True) + self.ctx_in_transmission_requests.append( + (req, block_id)) + if self.guided_decoder is not None: # add_batch must be called again to have updated new tokens. 
self.guided_decoder.add_batch(scheduled_batch) @@ -1842,8 +1868,6 @@ def _enqueue_responses(self, responses: Iterable[Tuple[int, LlmResponse]]): if 0 not in self.dist.mapping.tp_group and not self.gather_all_responses: return - logger.debug( - f'before gather, rank = {self.dist.rank}, responses = {responses}') if self.enable_attention_dp and self.dist.world_size != 1: if not self.gather_all_responses: responses_list = self.dist.tp_gather(responses) @@ -1882,18 +1906,13 @@ def _handle_first_token_response(self, scheduled_batch): self._enqueue_responses(new_responses) @nvtx_range("_handle_responses") - def _handle_responses(self, context_requests: List[LlmRequest] = None): + def _handle_responses(self): new_responses = [] requests_to_terminate = [] new_active_requests = [] logger.debug( f'------before _handle_responses, rank = {self.dist.rank}, output = {self.active_requests}' ) - if context_requests is not None: - context_requests_map = set( - [request.py_request_id for request in context_requests]) - else: - context_requests_map = set() for request in self.active_requests: req_id = request.py_request_id # no responses for dummy request, and finish it @@ -1928,28 +1947,17 @@ def _handle_responses(self, context_requests: List[LlmRequest] = None): if request_done: if self.block_reuse_enabled and not self.kv_cache_manager.is_vswa: - if request.py_request_id in context_requests_map: - block_id = self.kv_cache_manager.store_blocks_for_reuse( - request, True) - self.ctx_in_transmission_requests.append( - (request, block_id)) requests_to_terminate.append(request) else: if request.is_disagg_context_transmission_state: self.ctx_in_transmission_requests.append( (request, None)) else: - if request.py_request_id in context_requests_map: - self.ctx_in_transmission_requests.append( - (request, None)) - else: - requests_to_terminate.append(request) + requests_to_terminate.append(request) else: new_active_requests.append(request) - if context_requests is not None and self.kv_cache_transceiver: - self._send_disagg_ctx_cache(context_requests) - + self.active_requests.clear() self.active_requests.extend(new_active_requests) for request in requests_to_terminate: self._terminate_request(request) From 519f04a8211ebdcbcdb9b4444a1d62a656200928 Mon Sep 17 00:00:00 2001 From: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com> Date: Tue, 23 Sep 2025 13:23:04 -0700 Subject: [PATCH 10/17] remove newly allocated blocks Signed-off-by: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com> --- .../batch_manager/kvCacheManager.h | 13 --------- .../batch_manager/dataTransceiver.cpp | 3 +- .../batch_manager/kvCacheManager.cpp | 29 ------------------- cpp/tensorrt_llm/common/envUtils.h | 2 -- .../nanobind/batch_manager/kvCacheManager.cpp | 8 ----- .../pybind/batch_manager/kvCacheManager.cpp | 9 ------ 6 files changed, 2 insertions(+), 62 deletions(-) diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h index 02199b119f9..927c8583248 100644 --- a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h +++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h @@ -604,9 +604,6 @@ class WindowBlockManager void replaceSharedBlock(GenerationRequest& sequence, SizeType32 blockIdx); - //! \brief Get the ids of all newly allocated (not reused) blocks for the sequence. 
- std::vector getNewlyAllocatedBlockIds(GenerationRequest const& sequence) const; - [[nodiscard]] std::optional storeBlocksForReuse( GenerationRequest& sequence, OptionalRef llmRequest, bool pinBlocks = false); @@ -983,9 +980,6 @@ class BlockManager void replaceSharedBlock(GenerationRequest& sequence, SizeType32 windowSize, SizeType32 blockIdx); - std::vector getNewlyAllocatedBlockIds( - GenerationRequest const& sequence, SizeType32 windowSize) const; - [[nodiscard]] std::optional releaseBlocks( GenerationRequest& sequence, OptionalRef llmRequest = std::nullopt, bool pinBlocks = false); @@ -1419,10 +1413,6 @@ class BaseKVCacheManager std::vector const& requestIds, SizeType32 windowSize) const = 0; - [[nodiscard]] virtual std::vector getNewlyAllocatedBlockIds( - LlmRequest::RequestIdType requestId, SizeType32 windowSize) const - = 0; - /// @brief Get the last block id (beam 0) for a given sequence and window size [[nodiscard]] virtual std::optional getLastBlockId(LlmRequest::RequestIdType requestId) const = 0; @@ -1780,9 +1770,6 @@ class KVCacheManager : public BaseKVCacheManager std::vector>> getBatchCacheBlockIds( std::vector const& requestIds, SizeType32 windowSize) const override; - std::vector getNewlyAllocatedBlockIds( - LlmRequest::RequestIdType requestId, SizeType32 windowSize) const override; - runtime::ITensor::SharedPtr getUniquePrimaryPool() const override; runtime::ITensor::SharedPtr getPrimaryPool(SizeType32 layer_idx) const override; diff --git a/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp b/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp index 3ee6fbb25c3..9ceaff73bcb 100644 --- a/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp +++ b/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp @@ -363,7 +363,8 @@ class CacheSender::Impl { auto session = TransferSession(std::vector(peerRelativeRanks.size(), nullptr), DataContext{tagFromRequestId(requestId)}, mSelfState, info.getTransState(), mBufferManager, - info.getIndexFromEnd(), info.getLastBlockKey(), nullptr, !common::getEnvKVCacheTransferOutputPath().empty()); + info.getIndexFromEnd(), info.getLastBlockKey(), nullptr, + !common::getEnvKVCacheTransferOutputPath().empty()); it = mRequestToSession.emplace(requestId, std::move(session)).first; } it->second.setConnection(peerIdx, connection); diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp index 212a2b4bce0..62034afd04b 100644 --- a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp @@ -1597,29 +1597,6 @@ void WindowBlockManager::replaceSharedBlock(GenerationRequest& sequence, SizeTyp } } -std::vector BlockManager::getNewlyAllocatedBlockIds( - GenerationRequest const& sequence, SizeType32 windowSize) const -{ - return mWindowBlockManagers.at(windowSize).getNewlyAllocatedBlockIds(sequence); -} - -std::vector WindowBlockManager::getNewlyAllocatedBlockIds(GenerationRequest const& sequence) const -{ - std::vector allocatedBlockIds; - for (auto const& beamBlockIds : sequence.getCacheBlockIds(mWindowSize)) - { - for (auto const& blockId : beamBlockIds) - { - auto const& block = mAllBlocksById.at(blockId); - if (!blockInRadixTree(block)) - { - allocatedBlockIds.push_back(blockId); - } - } - } - return allocatedBlockIds; -} - void BlockManager::releaseLastBlock(GenerationRequest& sequence, SizeType32 windowSize) { mWindowBlockManagers.at(windowSize).releaseLastBlock(sequence); @@ -2789,12 +2766,6 @@ std::vector>> KVCacheManager::getBatchCacheB return result; 
} -std::vector KVCacheManager::getNewlyAllocatedBlockIds( - LlmRequest::RequestIdType requestId, SizeType32 windowSize) const -{ - return mBlockManager.getNewlyAllocatedBlockIds(getSequence(requestId), windowSize); -} - std::optional KVCacheManager::getLastBlockId(LlmRequest::RequestIdType requestId) const { auto const& seq = getSequence(requestId); diff --git a/cpp/tensorrt_llm/common/envUtils.h b/cpp/tensorrt_llm/common/envUtils.h index f5c0d854ba4..b7e5379c786 100644 --- a/cpp/tensorrt_llm/common/envUtils.h +++ b/cpp/tensorrt_llm/common/envUtils.h @@ -66,8 +66,6 @@ std::string getEnvNixlInterface(); bool getEnvDisaggLayerwise(); -bool getEnvDisableSelectiveCacheTransfer(); - bool getEnvParallelCacheSend(); bool getEnvRequestKVCacheConcurrent(); diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp index ea474ebb154..589bf6e2bf5 100644 --- a/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp @@ -199,12 +199,6 @@ class PyKvCacheManager : public tbk::BaseKVCacheManager NB_OVERRIDE_PURE(getBatchCacheBlockIds, requestIds, windowSize); } - std::vector getNewlyAllocatedBlockIds( - tb::LlmRequest::RequestIdType requestId, SizeType32 windowSize) const override - { - NB_OVERRIDE_PURE(getNewlyAllocatedBlockIds, requestId, windowSize); - } - SizeType32 getUsedNumBlocks() const override { NB_OVERRIDE_PURE(getUsedNumBlocks); @@ -472,8 +466,6 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(nb::module_& m) .def("get_cache_block_ids", &BaseKVCacheManager::getCacheBlockIds, nb::call_guard()) .def("get_batch_cache_block_ids", &BaseKVCacheManager::getBatchCacheBlockIds, nb::call_guard()) - .def("get_newly_allocated_block_ids", &BaseKVCacheManager::getNewlyAllocatedBlockIds, - nb::call_guard()) .def("flush_iteration_events", &BaseKVCacheManager::flushIterationEvents, nb::call_guard()) .def("get_last_block_id", &BaseKVCacheManager::getLastBlockId, nb::call_guard()) diff --git a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp index f48e092a1ce..2f40cd9de9d 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp @@ -195,13 +195,6 @@ class PyKvCacheManager : public tbk::BaseKVCacheManager getBatchCacheBlockIds, requestIds, windowSize); } - std::vector getNewlyAllocatedBlockIds( - tb::LlmRequest::RequestIdType requestId, SizeType32 windowSize) const override - { - PYBIND11_OVERLOAD_PURE( - std::vector, tbk::BaseKVCacheManager, getNewlyAllocatedBlockIds, requestId, windowSize); - } - SizeType32 getUsedNumBlocks() const override { PYBIND11_OVERLOAD_PURE(SizeType32, tbk::BaseKVCacheManager, getUsedNumBlocks); @@ -476,8 +469,6 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(py::module_& m) .def("get_cache_block_ids", &BaseKVCacheManager::getCacheBlockIds, py::call_guard()) .def("get_batch_cache_block_ids", &BaseKVCacheManager::getBatchCacheBlockIds, py::call_guard()) - .def("get_newly_allocated_block_ids", &BaseKVCacheManager::getNewlyAllocatedBlockIds, - py::call_guard()) .def("flush_iteration_events", &BaseKVCacheManager::flushIterationEvents, py::call_guard()) .def("get_last_block_id", &BaseKVCacheManager::getLastBlockId, py::call_guard()) From 78e2eff7a966f4a8e1a4a87af8966a36d48f4a69 Mon Sep 17 00:00:00 2001 From: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com> Date: 
Tue, 23 Sep 2025 15:47:39 -0700 Subject: [PATCH 11/17] Fix accuracy bug Signed-off-by: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com> --- cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp b/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp index 8460d7ca364..3663fc05a53 100644 --- a/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp +++ b/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp @@ -72,8 +72,8 @@ BlockRange getBlockRangeForReceiving( auto const totalUniqueTokens = llmRequest.getPromptLen(); auto const usedBlocks = std::min( static_cast((totalUniqueTokens + tokensPerBlock - 1) / tokensPerBlock), totalBlocks); - auto const reusedBlocks = std::min( - static_cast((prepopulatedTokens + tokensPerBlock - 1) / tokensPerBlock), usedBlocks); + auto const reusedBlocks + = std::min(static_cast((prepopulatedTokens / tokensPerBlock)), usedBlocks); std::vector newBlockIds; if (reusedBlocks < usedBlocks) From 0c310fa09d96560b044f1c0cfa4ae70db9d41995 Mon Sep 17 00:00:00 2001 From: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com> Date: Tue, 23 Sep 2025 16:27:10 -0700 Subject: [PATCH 12/17] Fix unit test Signed-off-by: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com> --- cpp/tests/unit_tests/multi_gpu/cacheTransceiverTest.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/tests/unit_tests/multi_gpu/cacheTransceiverTest.cpp b/cpp/tests/unit_tests/multi_gpu/cacheTransceiverTest.cpp index 2918ddd10a3..af9c1b096bb 100644 --- a/cpp/tests/unit_tests/multi_gpu/cacheTransceiverTest.cpp +++ b/cpp/tests/unit_tests/multi_gpu/cacheTransceiverTest.cpp @@ -208,7 +208,7 @@ class SymmetricalCacheTest : public ::testing::Test // NOLINT(cppcoreguidelines- auto totalNumBlocks = mMaxNumSequences * numBlocksPerSeq; auto constexpr blocksInSecondaryPool = 0; - auto constexpr enableBlockReuse = true; + auto constexpr enableBlockReuse = false; auto constexpr onboardBlocks = true; auto constexpr dataType = nvinfer1::DataType::kFLOAT; @@ -577,7 +577,7 @@ class AsymmetricalCacheTest : public ::testing::TestWithParam Date: Tue, 23 Sep 2025 16:30:42 -0700 Subject: [PATCH 13/17] fix Signed-off-by: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com> --- cpp/tensorrt_llm/batch_manager/dataTransceiver.h | 1 + 1 file changed, 1 insertion(+) diff --git a/cpp/tensorrt_llm/batch_manager/dataTransceiver.h b/cpp/tensorrt_llm/batch_manager/dataTransceiver.h index 8056f683a40..47f1a9bc1dd 100644 --- a/cpp/tensorrt_llm/batch_manager/dataTransceiver.h +++ b/cpp/tensorrt_llm/batch_manager/dataTransceiver.h @@ -129,6 +129,7 @@ class TransferSession int32_t mIndexFromEnd{0}; BlockKey mLastBlockKey{}; }; + using UniqueToken = tensorrt_llm::runtime::UniqueToken; using BlockKey = tensorrt_llm::batch_manager::kv_cache_manager::BlockKey; From 34bf1d423aae728988756404e4c790c6ad79de9c Mon Sep 17 00:00:00 2001 From: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com> Date: Wed, 24 Sep 2025 06:52:21 -0700 Subject: [PATCH 14/17] Fix compile error Signed-off-by: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com> --- .../batch_manager/kvCacheManager.h | 19 ++++--- .../batch_manager/kvCacheManager.cpp | 51 ++++--------------- .../utils/inflightBatchingUtils.cpp | 4 +- 3 files changed, 24 insertions(+), 50 deletions(-) diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h index 
927c8583248..26302ee39af 100644 --- a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h +++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h @@ -41,6 +41,7 @@ #include #include #include +#include <utility> #include namespace kvc = tensorrt_llm::executor::kv_cache; @@ -614,7 +615,8 @@ class WindowBlockManager //! \brief Release blocks of the sequence. //! \details When llmRequest is provided and reuse is enabled, blocks will be stored. - void releaseBlocks(GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest); + std::optional<KVCacheBlock::IdType> releaseBlocks( + GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest); //! \brief Simulate freeing all blocks for that sequence to check impact on number of free blocks void schedulingReleaseBlocks(LlmRequest::RequestIdType requestId); @@ -788,9 +790,10 @@ class WindowBlockManager //! \param blockKeys Key of each block. //! \param blockIds Id of each block. //! \param pinBlocks If true, increment ref count for blocks while storing (pin on store). - //! \return The id of the last block stored in the reuse tree, if any were stored. - [[nodiscard]] std::optional<KVCacheBlock::IdType> storeBlocks(std::vector<BlockKey> const& blockKeys, - std::vector<KVCacheBlock::IdType> const& blockIds, bool pinBlocks = false); + //! \return Pair of (num blocks stored for reuse, id of the last block stored if any). + [[nodiscard]] std::pair<SizeType32, std::optional<KVCacheBlock::IdType>> storeBlocks( + std::vector<BlockKey> const& blockKeys, std::vector<KVCacheBlock::IdType> const& blockIds, + bool pinBlocks = false); [[nodiscard]] bool verifyQueueIntegrity(); @@ -817,6 +820,7 @@ class WindowBlockManager { return mIsSWA; } + [[nodiscard]] std::optional> findBlocksInReuseTreeByBlockKey( BlockKey const& blockKey); @@ -980,7 +984,7 @@ class BlockManager void replaceSharedBlock(GenerationRequest& sequence, SizeType32 windowSize, SizeType32 blockIdx); - [[nodiscard]] std::optional<KVCacheBlock::IdType> releaseBlocks( + std::optional<KVCacheBlock::IdType> releaseBlocks( GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest = std::nullopt, bool pinBlocks = false); [[nodiscard]] std::optional<KVCacheBlock::IdType> storeBlocksForReuse( @@ -1013,8 +1017,9 @@ class BlockManager void offloadBlock(BlockPtr const& block, SizeType32 windowSize, executor::KvCacheTransferMode mode = executor::KvCacheTransferMode::DRAM, std::string const& directory = ""); - [[nodiscard]] std::optional<KVCacheBlock::IdType> storeBlocks(std::vector<BlockKey> const& blockKeys, - std::vector<KVCacheBlock::IdType> const& blockIds, SizeType32 windowSize, bool pinBlocks = false) + [[nodiscard]] std::pair<SizeType32, std::optional<KVCacheBlock::IdType>> storeBlocks( + std::vector<BlockKey> const& blockKeys, std::vector<KVCacheBlock::IdType> const& blockIds, + SizeType32 windowSize, bool pinBlocks = false) { return mWindowBlockManagers.at(windowSize).storeBlocks(blockKeys, blockIds, pinBlocks); } diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp index 62034afd04b..2e4978edafa 100644 --- a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp @@ -89,8 +89,6 @@ std::vector<BlockPtr> getAllSequenceBlocks(BlockPtr lastBlock) return sequenceBlocks; } - - } // namespace namespace tensorrt_llm::batch_manager::kv_cache_manager @@ -1461,13 +1461,8 @@ void WindowBlockManager::allocateBlock(GenerationRequest& sequence, bool shareAm } } -<<<<<<< HEAD -SizeType32 WindowBlockManager::storeBlocks( - std::vector const& blockKeys, std::vector const& blockIds) -======= -std::optional WindowBlockManager::storeBlocks( +std::pair<SizeType32, std::optional<KVCacheBlock::IdType>> WindowBlockManager::storeBlocks( std::vector<BlockKey> const& blockKeys, std::vector<KVCacheBlock::IdType> const& blockIds, bool pinBlocks) ->>>>>>> 07a060666 (fix pp bugs) { SizeType32 numBlocksStoredForReuse = 0; std::lock_guard<std::mutex> lock(mCachedBlocksRootMutex); @@ -1539,11 
+1532,7 @@ std::optional WindowBlockManager::storeBlocks( { mEventManager->enqueueStoredEvent(storedBlocks, mWindowSize); } -<<<<<<< HEAD - return numBlocksStoredForReuse; -======= - return lastStoredId; ->>>>>>> 07a060666 (fix pp bugs) + return {numBlocksStoredForReuse, lastStoredId}; } void BlockManager::replaceSharedBlock(GenerationRequest& sequence, SizeType32 windowSize, SizeType32 blockIdx) @@ -1645,36 +1634,18 @@ std::optional BlockManager::storeBlocksForReuse( std::optional BlockManager::releaseBlocks( GenerationRequest& sequence, OptionalRef llmRequest, bool pinBlocks) { -<<<<<<< HEAD // Released block will be stored when reuse is enabled. // Reuse is implied to be enabled if llmRequest is provided. -======= - // When releasing the blocks for a sequence, we store those blocks for potential reuse only if: - // - Block reuse is enabled. - // - A request was provided to this function call to identify which tokens these blocks cover - // - Beam search is NOT enabled <=> beam width == 1 - // - The sequence was not marked for use with cyclic kv-cache when it was added (when its context is too long to fit - // the max attention window). - // - The sequence did not switch to cyclic kv-cache during generation phase. - // A sequence is cyclic if its *minimum window size* is crossed, even if other window sizes were not reached. - // - The sequence is not a dummy request. - bool const storeBlocksForReuse = sequence.getBeamWidth() == 1 && llmRequest.has_value() && !sequence.isCyclic() - && !llmRequest->isDummyRequest(); std::optional lastStoredId = std::nullopt; ->>>>>>> 07a060666 (fix pp bugs) for (auto& [_, manager] : mWindowBlockManagers) { if (!llmRequest.has_value() || llmRequest->isDummyRequest() || sequence.getBeamWidth() > 1) { -<<<<<<< HEAD - manager.releaseBlocks(sequence, std::nullopt); + lastStoredId = manager.releaseBlocks(sequence, std::nullopt); } else { - manager.releaseBlocks(sequence, llmRequest); -======= - lastStoredId = manager.storeBlocksForReuse(sequence, llmRequest, pinBlocks); ->>>>>>> 07a060666 (fix pp bugs) + lastStoredId = manager.releaseBlocks(sequence, llmRequest); } } return lastStoredId; @@ -1789,9 +1760,6 @@ void WindowBlockManager::storeNewBlock(GenerationRequest& sequence, OptionalRef< (void) storeBlocks(std::move(blockKeys), cacheBlockIds[beamIdx]); } -<<<<<<< HEAD -void WindowBlockManager::releaseBlocks(GenerationRequest& sequence, OptionalRef llmRequest) -======= std::optional WindowBlockManager::storeBlocksForReuse( GenerationRequest& sequence, OptionalRef llmRequest, bool pinBlocks) { @@ -1805,14 +1773,14 @@ std::optional WindowBlockManager::storeBlocksForReuse( auto const usableSize = static_cast(uniqueTokens.size()) - 1; auto blockedUniqueTokens = chopVectorIntoBlocks(uniqueTokens, usableSize, mTokensPerBlock, true); auto blockKeys = buildBlockKeys(blockedUniqueTokens, *llmRequest); - return storeBlocks(std::move(blockKeys), cacheBlockIds[beamIdx], pinBlocks); + return storeBlocks(std::move(blockKeys), cacheBlockIds[beamIdx], pinBlocks).second; } -void WindowBlockManager::releaseBlocks(GenerationRequest& sequence) ->>>>>>> 07a060666 (fix pp bugs) +std::optional WindowBlockManager::releaseBlocks( + GenerationRequest& sequence, OptionalRef llmRequest) { auto const requestId = sequence.getRequestId(); - + std::optional lastStoredId = std::nullopt; auto node = mAllocatedBlocksPerSeq.extract(requestId); TLLM_CHECK(node); auto& allocatedBlocks = node.mapped(); @@ -1836,7 +1804,7 @@ void WindowBlockManager::releaseBlocks(GenerationRequest& sequence) 
std::transform(allocatedBlocks.begin(), allocatedBlocks.end(), cacheBlockIds.begin(), [](BlockPtr const& block) { return block->getBlockId(); }); - auto numBlocksStoredForReuse = storeBlocks(std::move(blockKeys), cacheBlockIds); + auto [numBlocksStoredForReuse, lastStoredId] = storeBlocks(std::move(blockKeys), cacheBlockIds); TLLM_LOG_DEBUG("%s::releaseBlocks Request %lu, %d blocks stored for reuse", mLogPrefix.c_str(), sequence.getRequestId(), numBlocksStoredForReuse); } @@ -1857,6 +1825,7 @@ void WindowBlockManager::releaseBlocks(GenerationRequest& sequence) } // Remove stored block ids in sequence sequence.clearCacheBlocks(mWindowSize); + return lastStoredId; } void BlockManager::schedulingReleaseBlocks(RequestIdType requestId) diff --git a/cpp/tensorrt_llm/batch_manager/utils/inflightBatchingUtils.cpp b/cpp/tensorrt_llm/batch_manager/utils/inflightBatchingUtils.cpp index 74ed6102ebc..4c2d4536fd9 100644 --- a/cpp/tensorrt_llm/batch_manager/utils/inflightBatchingUtils.cpp +++ b/cpp/tensorrt_llm/batch_manager/utils/inflightBatchingUtils.cpp @@ -263,11 +263,11 @@ void terminateRequest(SequenceSlotManager& seqSlotManager, LlmRequest& llmReq, S auto const requestId = llmReq.mRequestId; if (kvCacheManager) { - kvCacheManager->removeSequence(requestId, llmReq); + (void) kvCacheManager->removeSequence(requestId, llmReq); } if (crossKvCacheManager) { - crossKvCacheManager->removeSequence(requestId, llmReq); + (void) crossKvCacheManager->removeSequence(requestId, llmReq); } if (pause && !llmReq.isGenerationCompleteState()) { From db16b79cd62f01795870371a238f541e53d4cff9 Mon Sep 17 00:00:00 2001 From: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com> Date: Thu, 25 Sep 2025 10:01:37 -0700 Subject: [PATCH 15/17] Fix ci Signed-off-by: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com> --- cpp/tensorrt_llm/executor/executorImpl.cpp | 6 +- cpp/tensorrt_llm/executor/executorImpl.h | 5 +- .../batch_manager/kvCacheManagerTest.cpp | 68 +++++++++---------- 3 files changed, 39 insertions(+), 40 deletions(-) diff --git a/cpp/tensorrt_llm/executor/executorImpl.cpp b/cpp/tensorrt_llm/executor/executorImpl.cpp index cba4cecf5f3..101e1bef067 100644 --- a/cpp/tensorrt_llm/executor/executorImpl.cpp +++ b/cpp/tensorrt_llm/executor/executorImpl.cpp @@ -2233,15 +2233,11 @@ Executor::Impl::RequestList Executor::Impl::populateNewResponses( // move the in transmission requests to another tracker if (llmReq->isDisaggContextTransmissionState()) { - // Save either lastBlockId (reuse enabled and no VSWA) or just the request std::optional lastBlockId{}; auto kvMgr = mModel->getKVCacheManager(); if (kvMgr && kvMgr->isEnableBlockReuse() && !kvMgr->getBlockManager().isVariableWindow()) { - if (auto last = kvMgr->getLastBlockId(llmReq->mRequestId)) - { - lastBlockId = last.value(); - } + lastBlockId = kvMgr->storeBlocksForReuse(llmReq->mRequestId, llmReq, /*pinBlocks=*/true); mModel->terminateRequest(llmReq); } inTransmissionRequests.push_back(InTransmissionItem{*it, lastBlockId}); diff --git a/cpp/tensorrt_llm/executor/executorImpl.h b/cpp/tensorrt_llm/executor/executorImpl.h index 5a30e0c8a0a..19bd00bd65b 100644 --- a/cpp/tensorrt_llm/executor/executorImpl.h +++ b/cpp/tensorrt_llm/executor/executorImpl.h @@ -79,10 +79,13 @@ class Executor::Impl using LlmRequestPtr = std::shared_ptr; using RequestList = std::list; + // When block reuse is enabled for context worker for disaggregated serving, + // we need to store the last block id so that we can unpin the block when + // the request is finished. 
    struct InTransmissionItem
    {
        LlmRequestPtr request;
-        std::optional<SizeType32> lastBlockId; // present when reuse enabled and not variable window
+        std::optional<SizeType32> lastBlockId;
    };

    using InTransList = std::list<InTransmissionItem>;

diff --git a/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp b/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp
index c42be64cb75..b89fab52c97 100644
--- a/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp
+++ b/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp
@@ -634,7 +634,7 @@ TEST_F(KVCacheManagerTest, FindBlocksInReuseTreeByBlockKeysTest)
    auto cacheBlockIds = kvCacheManager.getSequence(requestId).getCacheBlockIds(maxAttentionWindow).at(beamIdx);
    EXPECT_THAT(cacheBlockIds, ::testing::ElementsAreArray({0, 1, 2}));

-    kvCacheManager.removeSequence(requestId, llmRequest0);
+    (void) kvCacheManager.removeSequence(requestId, llmRequest0);

    inputTokens->pop_back();
    BlockKey fullKey{*inputTokens};
@@ -2055,7 +2055,7 @@ TEST_F(KVCacheManagerTest, KVCacheManagerPerRequestStatsTest)
    EXPECT_EQ(llmRequest0->getAllocNewBlocksPerRequest(), numBlocks);
    EXPECT_EQ(llmRequest0->getMissedBlocksPerRequest(), numBlocks);

-    EXPECT_NO_THROW(kvCacheManager.removeSequence(requestId, llmRequest0));
+    EXPECT_NO_THROW((void) kvCacheManager.removeSequence(requestId, llmRequest0));

    requestId = 1;
    auto llmRequest1 = std::make_shared<LlmRequest>(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming);
@@ -2245,9 +2245,9 @@ TEST_F(KVCacheManagerTest, KVCacheManagerDecodeBlockPriorityTest)

    // remove both sequences, blocks get stored
    // leaf block 3 (priority 90), context blocks 2, 1, 0 (priority 5)
-    kvCacheManager.removeSequence(0, llmRequest0);
+    (void) kvCacheManager.removeSequence(0, llmRequest0);
    // leaf block 7 (priority 5), context blocks 6, 5, 4 (priority 90)
-    kvCacheManager.removeSequence(1, llmRequest1);
+    (void) kvCacheManager.removeSequence(1, llmRequest1);

    // all blocks are available again.
    EXPECT_EQ(kvCacheManager.getNumFreeBlocks(), 8);
@@ -2262,7 +2262,7 @@ TEST_F(KVCacheManagerTest, KVCacheManagerDecodeBlockPriorityTest)
    auto llmRequest2 = std::make_shared<LlmRequest>(2, maxNewTokens, inputTokens2, samplingConfig, isStreaming);
    kvCacheManager.addSequence(2, inputLength2, beamWidth, llmRequest2);
    // leaf block 2 (priority 35), context blocks 3, 7 (priority 35)
-    kvCacheManager.removeSequence(2, llmRequest2);
+    (void) kvCacheManager.removeSequence(2, llmRequest2);

    // reuse blocks 0 and 1, new block 2 (lowest priority leaf)
    // Uses 3 blocks 0, 1, 2 which contain [0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10]
@@ -2313,7 +2313,7 @@ TEST_F(KVCacheManagerTest, KVCacheManagerTimedEvictionTest)
    llmRequest0->setKvCacheRetentionConfig(
        KvCacheRetentionConfig({KvCacheRetentionConfig::TokenRangeRetentionConfig(0, std::nullopt, 80, 10ms)}, 80));
    kvCacheManager.addSequence(0, inputLength0, beamWidth, llmRequest0);
-    kvCacheManager.removeSequence(0, llmRequest0);
+    (void) kvCacheManager.removeSequence(0, llmRequest0);

    auto inputTokens1 = std::make_shared<VecTokens>(VecTokens{12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23});
    auto const inputLength1 = static_cast<SizeType32>(inputTokens1->size());
@@ -2321,7 +2321,7 @@
    llmRequest1->setKvCacheRetentionConfig(
        KvCacheRetentionConfig({KvCacheRetentionConfig::TokenRangeRetentionConfig(0, std::nullopt, 50)}, 80));
    kvCacheManager.addSequence(1, inputLength1, beamWidth, llmRequest1);
-    kvCacheManager.removeSequence(1, llmRequest1);
+    (void) kvCacheManager.removeSequence(1, llmRequest1);

    std::this_thread::sleep_for(std::chrono::milliseconds(50));
@@ -2333,7 +2333,7 @@
    auto const inputLength2 = static_cast<SizeType32>(inputTokens2->size());
    auto llmRequest2 = std::make_shared<LlmRequest>(2, maxNewTokens, inputTokens2, samplingConfig, isStreaming);
    kvCacheManager.addSequence(2, inputLength2, beamWidth, llmRequest2);
-    kvCacheManager.removeSequence(2, llmRequest2);
+    (void) kvCacheManager.removeSequence(2, llmRequest2);

    // Check that the [12, 13, 14, 15] block is still in the cache
    auto inputTokens3 = std::make_shared<VecTokens>(VecTokens{12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23});
@@ -2389,7 +2389,7 @@ TEST_F(KVCacheManagerTest, KVCacheManagerDecodeTimedEvictionTest)
            kvCacheManager.addToken(0);
            llmRequest0->addNewToken(0, 0);
        }
-        kvCacheManager.removeSequence(0, llmRequest0);
+        (void) kvCacheManager.removeSequence(0, llmRequest0);
    }
    {
        auto inputTokens1 = std::make_shared<VecTokens>(VecTokens{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11});
@@ -2404,7 +2404,7 @@
            kvCacheManager.addToken(1);
            llmRequest1->addNewToken(0, 0);
        }
-        kvCacheManager.removeSequence(1, llmRequest1);
+        (void) kvCacheManager.removeSequence(1, llmRequest1);
    }

    std::this_thread::sleep_for(std::chrono::milliseconds(50));
@@ -2414,7 +2414,7 @@
    auto const inputLength2 = static_cast<SizeType32>(inputTokens2->size());
    auto llmRequest2 = std::make_shared<LlmRequest>(2, maxNewTokens, inputTokens2, samplingConfig, isStreaming);
    kvCacheManager.addSequence(2, inputLength2, beamWidth, llmRequest2);
-    kvCacheManager.removeSequence(2, llmRequest2);
+    (void) kvCacheManager.removeSequence(2, llmRequest2);

    auto inputTokens3 = std::make_shared<VecTokens>(VecTokens{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11});
    auto const inputLength3 = static_cast<SizeType32>(inputTokens3->size());
@@ -2459,7 +2459,7 @@ TEST_F(KVCacheManagerTest, KVCacheManagerSecondaryBlockPrimaryChildTest)
    auto llmRequest0 =
        std::make_shared<LlmRequest>(0, maxNewTokens, inputTokens0, samplingConfig, isStreaming);
    // get new blocks 0, 1, 2
    kvCacheManager.addSequence(0, inputLength0, beamWidth, llmRequest0);
-    kvCacheManager.removeSequence(0, llmRequest0);
+    (void) kvCacheManager.removeSequence(0, llmRequest0);
    // store blocks 0, 1, 2 for reuse ([0,1,2,3], [4,5,6,7], [8,9,10])

    // Offload the last two blocks of llmRequest0 to secondary memory
@@ -2468,7 +2468,7 @@
    auto llmRequest1 = std::make_shared<LlmRequest>(1, maxNewTokens, inputTokens1, samplingConfig, isStreaming);
    // get blocks 3, 2, 1. This causes 2 and 1 to be offloaded to secondary
    kvCacheManager.addSequence(1, inputLength1, beamWidth, llmRequest1);
-    kvCacheManager.removeSequence(1, llmRequest1);
+    (void) kvCacheManager.removeSequence(1, llmRequest1);
    // store blocks 3, 2, 1 for reuse ([1,1,2,3], [4,5,6,7], [8,9,10])

    // Match the middle block of request 0
@@ -2493,7 +2493,7 @@
    kvCacheManager.addToken(2);

    // The middle block remains in secondary, but the third block is in primary
-    kvCacheManager.removeSequence(2, llmRequest2);
+    (void) kvCacheManager.removeSequence(2, llmRequest2);

    auto inputTokens3 = std::make_shared<VecTokens>(VecTokens{0, 1, 2, 3, 4, 5, 6, 7, 0, 0});
    auto const inputLength3 = static_cast<SizeType32>(inputTokens3->size());
@@ -2540,7 +2540,7 @@ TEST_F(KVCacheManagerTest, KVCacheManagerLeafBlockTest)
    kvCacheManager.addToken(0);

    // The second block allocated should be first in line for eviction.
-    kvCacheManager.removeSequence(0, llmRequest0);
+    (void) kvCacheManager.removeSequence(0, llmRequest0);

    auto inputTokens1 = std::make_shared<VecTokens>(VecTokens{1, 1, 2, 3});
    auto const inputLength1 = static_cast<SizeType32>(inputTokens1->size());
@@ -2558,8 +2558,8 @@
    auto llmRequest2 = std::make_shared<LlmRequest>(2, maxNewTokens, inputTokens2, samplingConfig, isStreaming);
    kvCacheManager.addSequence(2, inputLength2, beamWidth, llmRequest2);

-    kvCacheManager.removeSequence(1, llmRequest1);
-    kvCacheManager.removeSequence(2, llmRequest2);
+    (void) kvCacheManager.removeSequence(1, llmRequest1);
+    (void) kvCacheManager.removeSequence(2, llmRequest2);

    auto inputTokens3 = std::make_shared<VecTokens>(VecTokens{2, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11});
    auto const inputLength3 = static_cast<SizeType32>(inputTokens3->size());
@@ -2650,7 +2650,7 @@ TEST_F(KVCacheManagerTest, KVCacheManagerLeafBlockWithDependentTest)
    EXPECT_EQ(blockManager.getNumFreeBlocks(), 0);

    // Free first sequence
-    kvCacheManager.removeSequence(requestId0, llmRequest0);
+    (void) kvCacheManager.removeSequence(requestId0, llmRequest0);

    // Verify that 3 primary blocks are free.
    EXPECT_EQ(blockManager.getNumFreeBlocks(), 3);
@@ -2687,7 +2687,7 @@ TEST_F(KVCacheManagerTest, KVCacheManagerLeafBlockWithDependentTest)
    EXPECT_FALSE(block2->isPrimary());

    // Cleanup
-    kvCacheManager.removeSequence(requestId1, llmRequest1);
+    (void) kvCacheManager.removeSequence(requestId1, llmRequest1);
 }

 TEST_P(KVCacheManagerTest, DISABLED_KVCacheManagerAllocationTest)
@@ -2884,7 +2884,7 @@ TEST_P(KVCacheManagerTest, KVCacheManagerTest)
    EXPECT_EQ(blockManager.getNumFreeBlocks(), totalNumBlocks - numSharedBlocks - maxBeamWidth);
    EXPECT_NO_THROW(kvCacheManager.addToken(requestId));
    EXPECT_EQ(blockManager.getNumFreeBlocks(), totalNumBlocks - numSharedBlocks - maxBeamWidth * 2);
-    EXPECT_NO_THROW(kvCacheManager.removeSequence(requestId));
+    EXPECT_NO_THROW((void) kvCacheManager.removeSequence(requestId));
    EXPECT_EQ(blockManager.getNumFreeBlocks(), totalNumBlocks);

    auto currentNumBlocks = totalNumBlocks;
@@ -2961,7 +2961,7 @@ TEST_P(KVCacheManagerTest, KVCacheManagerRewindTokensTest)
    EXPECT_EQ(blockManager.getNumFreeBlocks(), totalNumBlocks - numBlocksPerSeq);
    EXPECT_NO_THROW(kvCacheManager.rewindKVCache(requestId, 4));
    EXPECT_EQ(blockManager.getNumFreeBlocks(), totalNumBlocks - numSharedBlocks - maxBeamWidth);
-    EXPECT_NO_THROW(kvCacheManager.removeSequence(requestId));
+    EXPECT_NO_THROW((void) kvCacheManager.removeSequence(requestId));
    EXPECT_EQ(blockManager.getNumFreeBlocks(), totalNumBlocks);

    auto currentNumBlocks = totalNumBlocks;
@@ -3100,7 +3100,7 @@ TEST_P(KVCacheManagerTest, KVCacheManagerMaxAttentionWindowTest)
    EXPECT_EQ(blockManager.getNumFreeBlocks(), totalNumBlocks - numSharedBlocks - maxBeamWidth);
    EXPECT_NO_THROW(kvCacheManager.addToken(requestId));
    EXPECT_EQ(blockManager.getNumFreeBlocks(), totalNumBlocks - numSharedBlocks - maxBeamWidth * 2);
-    EXPECT_NO_THROW(kvCacheManager.removeSequence(requestId));
+    EXPECT_NO_THROW((void) kvCacheManager.removeSequence(requestId));
    EXPECT_EQ(blockManager.getNumFreeBlocks(), totalNumBlocks);

    auto currentNumBlocks = totalNumBlocks;
@@ -3215,7 +3215,7 @@ TEST_F(KVCacheManagerTest, KVCacheManagerMaxAttentionWindowSmallerThanBlockSizeT
    EXPECT_EQ(numBlocks, 3);
    EXPECT_THAT(seq0.getCacheBlockIds(onlyWindowSize).at(beamIdx), ::testing::ElementsAreArray({0, 1, 2}));

-    EXPECT_NO_THROW(kvCacheManager.removeSequence(requestId, llmRequest));
+    EXPECT_NO_THROW((void) kvCacheManager.removeSequence(requestId, llmRequest));
    // no blocks stored because reuse is disabled
    EXPECT_EQ(blockManager.getNumAllocatedBlocks(), 0);
    EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool);
@@ -3275,7 +3275,7 @@ TEST_F(KVCacheManagerTest, KVCacheManagerEventStream)
    EXPECT_EQ(std::get<tle::KVCacheStoredData>(events.front().data).blocks[0].cacheLevel, 0);
    kvCacheManager.addToken(0);
    llmRequest0->addNewToken(0, 0);
-    kvCacheManager.removeSequence(0, llmRequest0);
+    (void) kvCacheManager.removeSequence(0, llmRequest0);

    auto newEvents = getEvents(kvCacheManager);
    EXPECT_EQ(newEvents.size(), 1);
@@ -3293,7 +3293,7 @@ TEST_F(KVCacheManagerTest, KVCacheManagerEventStream)
    llmRequest1->setLoraTaskId(42);
    kvCacheManager.addSequence(1, inputTokens1->size(), beamWidth, llmRequest1);
    kvCacheManager.storeContextBlocks(*llmRequest1);
-    kvCacheManager.removeSequence(1, llmRequest1);
+    (void) kvCacheManager.removeSequence(1, llmRequest1);

    events = getEvents(kvCacheManager);

@@ -3323,8 +3323,8 @@ TEST_F(KVCacheManagerTest, KVCacheManagerEventStream)
    EXPECT_THAT(std::get<tle::KVCacheRemovedData>(events.front().data).blockHashes,
        ::testing::ElementsAreArray({firstSwapped}));

-    kvCacheManager.removeSequence(2, llmRequest2);
-    kvCacheManager.removeSequence(3, llmRequest3);
+    (void) kvCacheManager.removeSequence(2, llmRequest2);
+    (void) kvCacheManager.removeSequence(3, llmRequest3);

    events = getEvents(kvCacheManager);
    EXPECT_EQ(events.size(), 2);
@@ -3427,7 +3427,7 @@ TEST_F(KVCacheManagerTest, KVCacheManagerEventStreamOverflow)
    events = getEvents(kvCacheManager);
    EXPECT_EQ(events.size(), 0);

-    kvCacheManager.removeSequence(0, llmRequest0);
+    (void) kvCacheManager.removeSequence(0, llmRequest0);

    events = getEvents(kvCacheManager);

@@ -3473,7 +3473,7 @@ TEST_F(KVCacheManagerTest, KVCacheManagerEventStreamPriority)
        std::vector{tle::KvCacheRetentionConfig::TokenRangeRetentionConfig(0, std::nullopt, 50)}, 35));
    kvCacheManager.addSequence(0, inputTokens0->size(), beamWidth, llmRequest0);
    kvCacheManager.storeContextBlocks(*llmRequest0);
-    kvCacheManager.removeSequence(0, llmRequest0);
+    (void) kvCacheManager.removeSequence(0, llmRequest0);

    auto events = getEvents(kvCacheManager);
    EXPECT_EQ(events.size(), 3);
@@ -3489,7 +3489,7 @@
    auto llmRequest1 = std::make_shared<LlmRequest>(1, 0, inputTokens1, samplingConfig, true);
    kvCacheManager.addSequence(1, inputTokens1->size(), beamWidth, llmRequest1);
    kvCacheManager.storeContextBlocks(*llmRequest1);
-    kvCacheManager.removeSequence(1, llmRequest1);
+    (void) kvCacheManager.removeSequence(1, llmRequest1);

    events = getEvents(kvCacheManager);
    EXPECT_EQ(events.size(), 1); // The second partial block gets stored. No priorities updated.
@@ -3566,7 +3566,7 @@ TEST_F(KVCacheManagerTest, PinAndUnpinBlocksById)

    KVCacheManager kvCacheManager(numLayers, numKvHeads, sizePerHead, tokensPerBlock, blocksPerWindow,
        maxNumSequences, beamWidth, std::vector{maxAttentionWindow}, std::nullopt, nvinfer1::DataType::kHALF,
-        0, stream, std::nullopt, true, onboardBlocks);
+        0, stream, maxAttentionWindow, true, onboardBlocks);
    kvCacheManager.allocatePools(false);

    LlmRequest::RequestIdType requestId{0};
@@ -3583,7 +3583,7 @@ TEST_F(KVCacheManagerTest, PinAndUnpinBlocksById)
    kvCacheManager.pinBlocks(requestId);
    auto lastBlockIdOpt = kvCacheManager.getLastBlockId(requestId);
    ASSERT_TRUE(lastBlockIdOpt.has_value());
-    kvCacheManager.removeSequence(requestId, llmRequest);
+    (void) kvCacheManager.removeSequence(requestId, llmRequest);

    auto const freeAfterRemovePinned = kvCacheManager.getNumFreeBlocks();
    EXPECT_LT(freeAfterRemovePinned, totalBlocks);
@@ -3867,7 +3867,7 @@ TEST_P(KVCacheManagerTest, DISABLED_KVCacheManagerSinkTokenLengthTest)
    EXPECT_EQ(blockManager.getNumFreeBlocks(), totalNumBlocks - numSharedBlocksCtx - maxBeamWidth * 2 + 1);
    EXPECT_NO_THROW(kvCacheManager.addToken(requestId));
    EXPECT_EQ(blockManager.getNumFreeBlocks(), totalNumBlocks - numSharedBlocksCtx - maxBeamWidth * 2 + 1);
-    EXPECT_NO_THROW(kvCacheManager.removeSequence(requestId));
+    EXPECT_NO_THROW((void) kvCacheManager.removeSequence(requestId));
    EXPECT_EQ(blockManager.getNumFreeBlocks(), totalNumBlocks);

    auto currentNumBlocks = totalNumBlocks;
@@ -4479,7 +4479,7 @@ TEST_P(FillKvCacheAndCompleteRequestsTest, FillKvCacheWithRequestsAndCompleteOne
            llmRequest.addNewToken(0, 0);
            kvCacheManager->addToken(llmRequest.mRequestId);
        }
-        kvCacheManager->removeSequence(llmRequest.mRequestId, llmRequest);
+        (void) kvCacheManager->removeSequence(llmRequest.mRequestId, llmRequest);
    }
    auto const [expectedNumFreeBlocks, _] = params.kvCacheManagerInstantiationParameters.blocksPerWindow.at(
        params.kvCacheManagerInstantiationParameters.maxAttentionWindow);

From c2ec52de9cc1e283e36ff81e04fcddfd43d66be5 Mon Sep 17 00:00:00 2001
From: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com>
Date: Thu, 25 Sep 2025 20:29:44 -0700
Subject: [PATCH 16/17] Fix CI

Signed-off-by: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com>
---
 cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp | 6 ------
 tensorrt_llm/_torch/pyexecutor/py_executor.py      | 5 +++--
 2 files changed, 3 insertions(+), 8 deletions(-)

diff --git a/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp b/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp
index 9ceaff73bcb..fe30046df98 100644
--- a/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp
+++ b/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp
@@ -244,12 +244,6 @@ class CacheSender::Impl
 public:
    using RequestIdType = LlmRequest::RequestIdType;

-    struct Response
-    {
-        LlmRequest* mRequest;
-        std::promise<void> mPromise;
-    };
-
    Impl(executor::kv_cache::ConnectionManager* manager, executor::kv_cache::CacheState selfCacheState,
        SizeType32 selfIndex, std::unique_ptr<BaseCacheFormatter> formatter)
        : mManager{manager}
diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py
index 0eb85f86233..9bbf93c6042 100644
--- a/tensorrt_llm/_torch/pyexecutor/py_executor.py
+++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -922,8 +922,9 @@ def _executor_loop_pp(self):
                    self.ctx_in_transmission_requests.append(
                        (req, block_id))

-                self._send_disagg_ctx_cache(
-                    previous_batch.scheduled_ctx_reqs)
+                if self.kv_cache_transceiver:
+                    self._send_disagg_ctx_cache(
+                        previous_batch.scheduled_ctx_reqs)

                self._handle_canceled_requests()
                self._handle_logits_communication(

From 66044e8b0f10fc213215b8f782d421cfdb8b88a0 Mon Sep 17 00:00:00 2001
From: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com>
Date: Fri, 26 Sep 2025 11:14:57 -0700
Subject: [PATCH 17/17] Address review comment

Signed-off-by: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com>
---
 cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h | 9 ++++-----
 cpp/include/tensorrt_llm/batch_manager/kvCacheUtils.h   | 2 +-
 cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp       | 5 ++---
 .../unit_tests/batch_manager/kvCacheManagerTest.cpp     | 5 ++---
 jenkins/Build.groovy                                    | 4 +++-
 5 files changed, 12 insertions(+), 13 deletions(-)

diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h
index 26302ee39af..ce0f28dab56 100644
--- a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h
+++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h
@@ -821,8 +821,7 @@ class WindowBlockManager
        return mIsSWA;
    }

-    [[nodiscard]] std::optional<std::shared_ptr<KVCacheBlock>> findBlocksInReuseTreeByBlockKey(
-        BlockKey const& blockKey);
+    [[nodiscard]] std::shared_ptr<KVCacheBlock> findBlocksInReuseTreeByBlockKey(BlockKey const& blockKey);

    //! \brief Unpin blocks by starting from a block id and walking prev pointers.
    void unpinBlocksById(KVCacheBlock::IdType blockId);
@@ -1194,7 +1193,7 @@ class BlockManager
        return mWindowBlockManagers.at(windowSize).getBlockById(blockId);
    }

-    [[nodiscard]] std::optional<std::shared_ptr<KVCacheBlock>> findBlocksInReuseTreeByBlockKey(
+    [[nodiscard]] std::shared_ptr<KVCacheBlock> findBlocksInReuseTreeByBlockKey(
        BlockKey const& blockKey, SizeType32 windowSize)
    {
        return mWindowBlockManagers.at(windowSize).findBlocksInReuseTreeByBlockKey(blockKey);
@@ -1491,7 +1490,7 @@ class BaseKVCacheManager

    [[nodiscard]] virtual CacheType getCacheType() const = 0;

-    [[nodiscard]] virtual std::optional<std::shared_ptr<KVCacheBlock>> findBlocksInReuseTreeByBlockKey(
+    [[nodiscard]] virtual std::shared_ptr<KVCacheBlock> findBlocksInReuseTreeByBlockKey(
        BlockKey const& blockKey, SizeType32 windowSize) = 0;

@@ -1794,7 +1793,7 @@ class KVCacheManager : public BaseKVCacheManager
        mBlockManager.flushIterationEvents();
    }

-    std::optional<std::shared_ptr<KVCacheBlock>> findBlocksInReuseTreeByBlockKey(
+    std::shared_ptr<KVCacheBlock> findBlocksInReuseTreeByBlockKey(
        BlockKey const& blockKey, SizeType32 windowSize) override
    {
        return mBlockManager.findBlocksInReuseTreeByBlockKey(blockKey, windowSize);
diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheUtils.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheUtils.h
index c1c686f6f28..64f1e85b8d7 100644
--- a/cpp/include/tensorrt_llm/batch_manager/kvCacheUtils.h
+++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheUtils.h
@@ -45,7 +45,7 @@ class BlockRange
    {
        auto const windowSize = firstWindowSize(cacheManager);
        // Find the last block in the reuse tree for the provided full sequence of block keys
-        auto lastBlock = *cacheManager.findBlocksInReuseTreeByBlockKey(lastBlockKey, windowSize);
+        auto lastBlock = cacheManager.findBlocksInReuseTreeByBlockKey(lastBlockKey, windowSize);
        // TODO: handle the case where the last block is not found
        TLLM_CHECK_WITH_INFO(lastBlock, "Couldn't find the requested block in the reuse tree");
        int32_t const numBlocksToCollect = indexFromEnd + 1;
diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp
index 2e4978edafa..cff2b1f41fd 100644
--- a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp
+++ b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp
@@ -1096,8 +1096,7 @@ bool WindowBlockManager::blockInRadixTree(BlockPtr const& block)
    return !block->getUniqueTokens().empty() && block->getPrevBlock() != nullptr;
 }

-std::optional<std::shared_ptr<KVCacheBlock>> WindowBlockManager::findBlocksInReuseTreeByBlockKey(
-    BlockKey const& blockKey)
+std::shared_ptr<KVCacheBlock> WindowBlockManager::findBlocksInReuseTreeByBlockKey(BlockKey const& blockKey)
 {
    std::lock_guard lock(mCachedBlocksRootMutex);
    auto blockedUniqueTokens
@@ -1118,7 +1117,7 @@ std::optional<std::shared_ptr<KVCacheBlock>> WindowBlockManager::findBlocksInReu
    if (matchingBlock == nullptr)
    {
-        return std::nullopt;
+        return nullptr;
    }
    searchRoot = std::move(matchingBlock);
diff --git a/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp b/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp
index b89fab52c97..81437e5e1c4 100644
--- a/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp
+++ b/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp
@@ -639,9 +639,8 @@ TEST_F(KVCacheManagerTest, FindBlocksInReuseTreeByBlockKeysTest)
    inputTokens->pop_back();
    BlockKey fullKey{*inputTokens};
    auto const foundFull = kvCacheManager.findBlocksInReuseTreeByBlockKey(fullKey, maxAttentionWindow);
-    ASSERT_TRUE(foundFull.has_value());
-    ASSERT_NE(foundFull.value(), nullptr);
-    auto const& lastBlock = foundFull.value();
+    ASSERT_NE(foundFull, nullptr);
+    auto const& lastBlock = foundFull;

    // Check the chain back to previous blocks
    auto const prev2 = lastBlock->getPrevBlock();
diff --git a/jenkins/Build.groovy b/jenkins/Build.groovy
index afbd55cb3d8..c84db49b9d5 100644
--- a/jenkins/Build.groovy
+++ b/jenkins/Build.groovy
@@ -102,17 +102,19 @@ def BUILD_CONFIGS = [
        (WHEEL_EXTRA_ARGS) : "--extra-cmake-vars WARNING_IS_ERROR=ON",
        (TARNAME) : "TensorRT-LLM-GH200-CU12.tar.gz",
        (WHEEL_ARCHS): "90-real;100-real;103-real;120-real",
+        (BUILD_JOBS_FOR_CONFIG): "4", // TODO: Remove after fixing the build OOM issue on SBSA
    ],
    (CONFIG_LINUX_AARCH64_PYBIND): [
        (WHEEL_EXTRA_ARGS) : "--binding_type pybind --extra-cmake-vars WARNING_IS_ERROR=ON --extra-cmake-vars NIXL_ROOT=/opt/nvidia/nvda_nixl",
        (TARNAME) : "pybind-TensorRT-LLM-GH200.tar.gz",
        (WHEEL_ARCHS): "90-real;100-real;103-real;120-real",
+        (BUILD_JOBS_FOR_CONFIG): "4", // TODO: Remove after fixing the build OOM issue on SBSA
    ],
    (CONFIG_LINUX_AARCH64_LLVM) : [
        (WHEEL_EXTRA_ARGS) : "--extra-cmake-vars WARNING_IS_ERROR=ON -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_CUDA_HOST_COMPILER=clang -DCMAKE_LINKER_TYPE=LLD",
        (TARNAME) : "llvm-TensorRT-LLM-GH200.tar.gz",
        (WHEEL_ARCHS): "90-real;100-real;103-real;120-real",
-        (BUILD_JOBS_FOR_CONFIG): "6", // TODO: Remove after fix the build OOM issue on SBSA
+        (BUILD_JOBS_FOR_CONFIG): "4", // TODO: Remove after fixing the build OOM issue on SBSA
    ],
]
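
Net effect of the final revision in this series: findBlocksInReuseTreeByBlockKey returns a std::shared_ptr<KVCacheBlock> directly, with nullptr signalling that the key is not in the reuse tree, and callers walk getPrevBlock() from the matched block. The caller-side sketch below illustrates that contract; it is illustrative only, not part of the patches, and the helper name collectBlockIdsFromEnd is hypothetical. Only accessors the patches themselves exercise (findBlocksInReuseTreeByBlockKey, getPrevBlock, getBlockId, TLLM_CHECK_WITH_INFO) are assumed to exist.

// Sketch (not part of the patch): resolve a BlockKey in the reuse tree and
// collect the ids of the last `numBlocks` blocks, oldest first. Assumes the
// TensorRT-LLM types used throughout this series (BaseKVCacheManager,
// BlockKey, KVCacheBlock, SizeType32).
#include <algorithm>
#include <vector>

std::vector<KVCacheBlock::IdType> collectBlockIdsFromEnd(
    BaseKVCacheManager& cacheManager, BlockKey const& lastBlockKey, SizeType32 windowSize, SizeType32 numBlocks)
{
    // A nullptr return (previously std::nullopt) means no block matches the key.
    auto block = cacheManager.findBlocksInReuseTreeByBlockKey(lastBlockKey, windowSize);
    TLLM_CHECK_WITH_INFO(block, "Couldn't find the requested block in the reuse tree");

    std::vector<KVCacheBlock::IdType> blockIds;
    for (SizeType32 i = 0; i < numBlocks && block != nullptr; ++i)
    {
        blockIds.push_back(block->getBlockId());
        block = block->getPrevBlock(); // same back-pointer walk as unpinBlocksById
    }
    std::reverse(blockIds.begin(), blockIds.end()); // oldest-to-newest order
    return blockIds;
}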