Closed
55 commits
da1cbe4
Initial cpp binding stuff
jthomson04 Jul 22, 2025
4bd4df4
Basic connector tests
jthomson04 Jul 22, 2025
f83f0ed
Hook into torch runtime and py executor
jthomson04 Jul 23, 2025
d4e5178
Expose block pools as torch tensor
jthomson04 Jul 24, 2025
0c9fa7a
Little fixes
jthomson04 Jul 24, 2025
229e6c5
Scheduler Output bindings
jthomson04 Jul 25, 2025
614cb01
more little fixes - dont instantiate twice
jthomson04 Jul 25, 2025
6c26369
MEGA REFACTOR, move scheduler and worker into their own class, do ini…
jthomson04 Jul 27, 2025
d545a5d
Get num new matched tokens
jthomson04 Jul 28, 2025
9a1ba68
Suspend requests for async onboard
jthomson04 Jul 29, 2025
1f0a35b
async load and resume
jthomson04 Jul 29, 2025
e75e6c4
Little cleanup
jthomson04 Jul 29, 2025
66aa5a7
scheduler output for build_connector_meta
jthomson04 Jul 29, 2025
2f80a23
Worker-side hooks
jthomson04 Jul 29, 2025
a521b6a
Move a ton of stuff out of c++ into python
jthomson04 Jul 29, 2025
b80ed13
small refactorings and docs
jthomson04 Jul 30, 2025
50bcec3
A whole bunch of unit tests
jthomson04 Jul 30, 2025
e305010
precommit
jthomson04 Jul 30, 2025
fe45192
Fix wait_for_save
jthomson04 Jul 30, 2025
7ca84a2
start on integration tests
jthomson04 Jul 31, 2025
b85f749
Integration tests for async save and load
jthomson04 Jul 31, 2025
e16a38d
Simplify add token stuff
jthomson04 Jul 31, 2025
7081fe7
Tests for scheduler metadata
jthomson04 Jul 31, 2025
7d7dabe
Chunked prefill tests
jthomson04 Jul 31, 2025
65f58a4
simplify register_kv_caches handling
jthomson04 Aug 1, 2025
4140d52
remove changes to add token and update token
jthomson04 Aug 1, 2025
812fcf4
add support for the overlap scheduler + little refactoring
jthomson04 Aug 1, 2025
7b3795f
little cleanup
jthomson04 Aug 2, 2025
48e08ed
Little refactor, provide kv cache as a single contiguous tensor
jthomson04 Aug 5, 2025
1c3fe6f
Gate cuda graph support
jthomson04 Aug 6, 2025
914b34b
Include cache block ids in request_finished
jthomson04 Aug 6, 2025
5a5ea47
Little bugfixes and implement a basic example
jthomson04 Aug 7, 2025
2056e70
Address reviewer comments
jthomson04 Aug 7, 2025
96b71c4
more improvements + refactoring + docstrings
jthomson04 Aug 7, 2025
d0ad8a6
Nanobind support (finally)
jthomson04 Aug 11, 2025
c6afb96
Merge remote-tracking branch 'origin/main' into jthomson04/connector-api
jthomson04 Aug 11, 2025
b7d2ee6
coderabbit + refactor
jthomson04 Aug 11, 2025
03fa470
CI Integration, only support guarantee no evict, various coderabbit s…
jthomson04 Aug 11, 2025
921dd94
update state after alloc
jthomson04 Aug 11, 2025
35eeb02
Fix scheduler output
jthomson04 Aug 11, 2025
0fa51a7
Merge branch 'main' into jthomson04/connector-api
Tabrizian Aug 12, 2025
b142dde
fix license headers
jthomson04 Aug 12, 2025
36b1d0b
Merge branch 'main' into jthomson04/connector-api
Tabrizian Aug 12, 2025
0b210f0
fix tests and test list
jthomson04 Aug 13, 2025
209052a
Dont pass connector manager through add_sequence
jthomson04 Aug 13, 2025
5059269
Merge remote-tracking branch 'origin/main' into jthomson04/connector-api
jthomson04 Aug 14, 2025
6018fc9
Merge remote-tracking branch 'origin/main' into jthomson04/connector-api
jthomson04 Aug 15, 2025
df6350d
Init scheduler and worker concurrently
jthomson04 Aug 16, 2025
f24eff2
Merge branch 'main' into jthomson04/connector-api
jthomson04 Aug 19, 2025
d5f7f1d
maybe fix CI
jthomson04 Aug 19, 2025
ebfe401
Add fix for llm stability
jthomson04 Aug 20, 2025
f9a3960
Merge remote-tracking branch 'origin/main' into jthomson04/connector-api
jthomson04 Aug 20, 2025
60b3ad9
Merge branch 'main' into jthomson04/connector-api
jthomson04 Aug 21, 2025
a383d03
Dont call request_finished unless request has already been scheduled
jthomson04 Aug 21, 2025
19b03c4
Merge remote-tracking branch 'origin/main' into jthomson04/connector-api
jthomson04 Aug 22, 2025
46 changes: 46 additions & 0 deletions cpp/include/tensorrt_llm/batch_manager/kvCacheConnector.h
@@ -0,0 +1,46 @@
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
Review comment (Member): nit: Copyright looks incorrect
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include "tensorrt_llm/batch_manager/common.h"
#include "tensorrt_llm/batch_manager/llmRequest.h"
#include "tensorrt_llm/runtime/common.h"

#include <utility>
#include <vector>

using SizeType32 = tensorrt_llm::runtime::SizeType32;
using RequestIdType = tensorrt_llm::batch_manager::LlmRequest::RequestIdType;

/// See tensorrt_llm/_torch/pyexecutor/connector.py for details on the Connector API.

namespace tensorrt_llm::batch_manager::kv_connector
{

/// @brief The KV connector manager. This is passed into the C++ KV Cache Manager when adding sequences.
class KvCacheConnectorManager
{
public:
KvCacheConnectorManager() = default;
virtual ~KvCacheConnectorManager() = default;

/// @brief Handle the getNumNewMatchedTokens call inside the C++ KV Cache Manager.
/// @return The number of tokens that can be loaded from remote KV cache.
virtual SizeType32 getNumNewMatchedTokens(LlmRequest const& request, SizeType32 numComputedTokens) = 0;
};

} // namespace tensorrt_llm::batch_manager::kv_connector
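For orientation: the header above only declares the hook the C++ cache manager calls; the actual policy lives on the Python side of the Connector API (see tensorrt_llm/_torch/pyexecutor/connector.py). Below is a minimal sketch of the contract getNumNewMatchedTokens expresses — the names prompt_len and external_match_len, and the clamping policy, are illustrative assumptions, not part of this PR:

```python
# Illustration of the getNumNewMatchedTokens contract only; `prompt_len` and
# `external_match_len` stand in for whatever lookup a real connector performs
# against its remote/external KV store.
def get_num_new_matched_tokens(prompt_len: int,
                               num_computed_tokens: int,
                               external_match_len: int) -> int:
    """Return how many *additional* prompt tokens the connector can onboard.

    `num_computed_tokens` is the prefix already matched by local block reuse
    (prepopulatedPromptLen in WindowBlockManager::addSequence); the returned
    value is added on top of it before setPrepopulatedPromptLen is called.
    """
    # Only tokens beyond the locally matched prefix count, and a sensible
    # connector would never report more than the prompt itself.
    return max(0, min(external_match_len, prompt_len) - num_computed_tokens)
```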
20 changes: 15 additions & 5 deletions cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h
@@ -16,6 +16,7 @@

#pragma once

#include "tensorrt_llm/batch_manager/kvCacheConnector.h"
#include "tensorrt_llm/batch_manager/kvCacheEventManager.h"
#include "tensorrt_llm/batch_manager/kvCacheType.h"
#include "tensorrt_llm/batch_manager/llmRequest.h" // TODO forward declare
@@ -536,7 +537,8 @@ class WindowBlockManager
SizeType32 sizePerHead, SizeType32 tokensPerBlock, SizeType32 blocksInPrimaryPool,
SizeType32 blocksInSecondaryPool, SizeType32 maxNumSequences, std::shared_ptr<runtime::CudaStream> stream,
bool onboardBlocks, CacheType cacheType, std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse);
std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse,
std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager);

~WindowBlockManager();

@@ -833,6 +835,8 @@ class WindowBlockManager
bool mEnablePartialReuse;
// Whether partially matched blocks that are already in use should be copied and reused.
bool mCopyOnPartialReuse;
// The kv cache connector manager
std::shared_ptr<kv_connector::KvCacheConnectorManager> mKvCacheConnectorManager;
};

class BlockManager
@@ -850,7 +854,8 @@ class BlockManager
SizeType32 sinkBubbleLength, bool onboardBlocks, CacheType cacheType = CacheType::kSELF,
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority = std::nullopt,
std::shared_ptr<KVCacheEventManager> eventManager = nullptr, bool enablePartialReuse = true,
bool copyOnPartialReuse = true);
bool copyOnPartialReuse = true,
std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager = nullptr);

BlockManager(BlockManager const&) = delete;
BlockManager& operator=(BlockManager const&) = delete;
@@ -1285,6 +1290,7 @@ class BaseKVCacheManager
LlmRequest::RequestIdType requestId, SizeType32 windowSize) const
= 0;

[[nodiscard]] virtual runtime::ITensor::SharedPtr getUniquePrimaryPool() const = 0;
[[nodiscard]] virtual runtime::ITensor::SharedPtr getPrimaryPool(SizeType32 layer_idx) const = 0;
[[nodiscard]] virtual SizeType32 getPoolLayerIdx(SizeType32 layer_idx) const = 0;

@@ -1371,7 +1377,8 @@ class KVCacheManager : public BaseKVCacheManager
bool enableBlockReuse = false, bool onboardBlocks = true, CacheType cacheType = CacheType::kSELF,
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority = std::nullopt,
std::shared_ptr<KVCacheEventManager> eventManager = nullptr, bool enablePartialReuse = true,
bool copyOnpartialReuse = true);
bool copyOnpartialReuse = true,
std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager = nullptr);

KVCacheManager(std::vector<SizeType32> const& numKvHeadsPerLayer, SizeType32 sizePerHead, SizeType32 tokensPerBlock,
BlocksPerWindow const& blocksPerWindow, SizeType32 maxNumSequences, SizeType32 maxBeamWidth,
@@ -1381,7 +1388,8 @@
bool enableBlockReuse = false, bool onboardBlocks = true, CacheType cacheType = CacheType::kSELF,
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority = std::nullopt,
std::shared_ptr<KVCacheEventManager> eventManager = nullptr, bool enablePartialReuse = true,
bool copyOnpartialReuse = true);
bool copyOnpartialReuse = true,
std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager = nullptr);

KVCacheManager(SizeType32 numLayers, SizeType32 numKvHeads, SizeType32 sizePerHead, SizeType32 tokensPerBlock,
BlocksPerWindow const& blocksPerWindow, SizeType32 maxNumSequences, SizeType32 maxBeamWidth,
@@ -1391,7 +1399,8 @@
bool enableBlockReuse = true, bool onboardBlocks = true, CacheType cacheType = CacheType::kSELF,
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority = std::nullopt,
std::shared_ptr<KVCacheEventManager> eventManager = nullptr, bool enablePartialReuse = true,
bool copyOnpartialReuse = true);
bool copyOnpartialReuse = true,
std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager = nullptr);

KVCacheManager(SizeType32 numLayers, SizeType32 numKvHeads, SizeType32 sizePerHead, SizeType32 tokensPerBlock,
BlocksPerWindow const& blocksPerWindow, SizeType32 maxNumSequences, SizeType32 maxBeamWidth,
@@ -1621,6 +1630,7 @@ class KVCacheManager : public BaseKVCacheManager
std::vector<SizeType32> getNewlyAllocatedBlockIds(
LlmRequest::RequestIdType requestId, SizeType32 windowSize) const override;

runtime::ITensor::SharedPtr getUniquePrimaryPool() const override;
runtime::ITensor::SharedPtr getPrimaryPool(SizeType32 layer_idx) const override;

SizeType32 getPoolLayerIdx(SizeType32 layer_idx) const override
58 changes: 46 additions & 12 deletions cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp
@@ -504,7 +504,8 @@ BlockManager::BlockManager(std::vector<SizeType32> const& numKvHeadsPerLayer, Si
std::optional<TempAttentionWindowInputs> const& tempAttentionWindowInputs, nvinfer1::DataType dtype,
SizeType32 sinkBubbleLength, bool onboardBlocks, CacheType cacheType,
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse)
std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse,
std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager)
: mNumLayers{static_cast<SizeType32>(numKvHeadsPerLayer.size())}
, mTokensPerBlock{tokensPerBlock}
, mEventManager{std::move(eventManager)}
@@ -513,6 +514,10 @@ BlockManager::BlockManager(std::vector<SizeType32> const& numKvHeadsPerLayer, Si
{
auto const uniqueWindowSizeToLayers
= BaseKVCacheManager::groupLayersByWindowSize(maxAttentionWindowVec, mNumLayers);

TLLM_CHECK_WITH_INFO(kvCacheConnectorManager == nullptr || uniqueWindowSizeToLayers.size() == 1,
"KV Cache Connector is not supported with multiple window sizes");

auto const numUniqueWindowSizes = static_cast<SizeType32>(uniqueWindowSizeToLayers.size());

mIsVariableWindow = numUniqueWindowSizes > 1;
@@ -530,7 +535,7 @@ BlockManager::BlockManager(std::vector<SizeType32> const& numKvHeadsPerLayer, Si
mWindowBlockManagers.try_emplace(windowSize, dtype, windowSize, layersWithWindowSize, numKvHeadsPerLayer,
sizePerHead, tokensPerBlock, allottedPrimaryBlocks, allottedSecondaryBlocks, maxNumSequences, stream,
onboardBlocks, cacheType, secondaryOffloadMinPriority, mEventManager, enablePartialReuse,
copyOnPartialReuse);
copyOnPartialReuse, kvCacheConnectorManager);
}

auto const numAllPools = getNumPools();
@@ -572,7 +577,8 @@ WindowBlockManager::WindowBlockManager(nvinfer1::DataType dtype, SizeType32 wind
SizeType32 sizePerHead, SizeType32 tokensPerBlock, SizeType32 blocksInPrimaryPool, SizeType32 blocksInSecondaryPool,
SizeType32 maxNumSequences, std::shared_ptr<runtime::CudaStream> stream, bool onboardBlocks, CacheType cacheType,
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse)
std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse,
std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager)
: mDataType{dtype}
, mWindowSize{windowSize}
, mNumPrimaryBlocks{blocksInPrimaryPool}
@@ -596,6 +602,7 @@ WindowBlockManager::WindowBlockManager(nvinfer1::DataType dtype, SizeType32 wind
, mTotalInputTokens{0.0}
, mEnablePartialReuse{enablePartialReuse}
, mCopyOnPartialReuse{copyOnPartialReuse}
, mKvCacheConnectorManager{std::move(kvCacheConnectorManager)}
{
std::map<SizeType32, SizeType32> numLayersPerPool;

@@ -1188,9 +1195,18 @@ void WindowBlockManager::addSequence(
auto const prepopulatedPromptLen = loadOrAllocateBlocks(blockKeys, numContextBlocks, sequence, perBlockRetentions);
mReusedTokens += static_cast<double>(prepopulatedPromptLen);
mTotalInputTokens += static_cast<double>(uniqueTokens.size());
llmRequest.setPrepopulatedPromptLen(prepopulatedPromptLen, getTokensPerBlock());
TLLM_LOG_DEBUG("addSequence: Request %lu, inputLength %d, prepopulatedPromptLen %d", llmRequest.mRequestId,
inputLength, prepopulatedPromptLen);

SizeType32 numConnectorMatchedTokens = 0;

// If we're using a KV cache connector, check if any additional blocks can be loaded.
if (mKvCacheConnectorManager && !llmRequest.isDummyRequest())
{
numConnectorMatchedTokens = mKvCacheConnectorManager->getNumNewMatchedTokens(llmRequest, prepopulatedPromptLen);
}

llmRequest.setPrepopulatedPromptLen(prepopulatedPromptLen + numConnectorMatchedTokens, getTokensPerBlock());
TLLM_LOG_DEBUG("addSequence: Request %lu, inputLength %d, prepopulatedPromptLen %d, numConnectorMatchedTokens %d",
llmRequest.mRequestId, inputLength, prepopulatedPromptLen, numConnectorMatchedTokens);
}

// There are two versions of BlockManager::addSequence function.
@@ -1206,6 +1222,13 @@ void BlockManager::addSequence(
void WindowBlockManager::addSequence(
GenerationRequest& sequence, SizeType32 numContextBlocks, bool isShareLastContextBlock)
{
if (mKvCacheConnectorManager)
{
TLLM_LOG_WARNING(
"KV Cache Connector specified when block reuse is disabled. The KV Cache Connector will be "
"ignored.");
}

auto const requestId = sequence.getRequestId();
auto const [seqIt, emplaceDone] = mAllocatedBlocksPerSeq.emplace(requestId, std::vector<BlockPtr>{});
TLLM_CHECK(emplaceDone);
@@ -1618,12 +1641,13 @@ KVCacheManager::KVCacheManager(std::vector<SizeType32> const& numKvHeadsPerLayer
SizeType32 sinkTokenLength, int64_t stream, std::optional<runtime::SizeType32> maxSequenceLength,
bool enableBlockReuse, bool onboardBlocks, CacheType cacheType,
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse)
std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse,
std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager)
: KVCacheManager(numKvHeadsPerLayer, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, maxBeamWidth,
maxAttentionWindowVec, tempAttentionWindowInputs, dtype, sinkTokenLength,
std::make_shared<runtime::CudaStream>(reinterpret_cast<cudaStream_t>(stream)), maxSequenceLength,
enableBlockReuse, onboardBlocks, cacheType, secondaryOffloadMinPriority, eventManager, enablePartialReuse,
copyOnPartialReuse)
copyOnPartialReuse, kvCacheConnectorManager)
{
}

@@ -1634,7 +1658,8 @@ KVCacheManager::KVCacheManager(std::vector<SizeType32> const& numKvHeadsPerLayer
SizeType32 sinkTokenLength, CudaStreamPtr stream, std::optional<runtime::SizeType32> maxSequenceLength,
bool enableBlockReuse, bool onboardBlocks, CacheType cacheType,
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse)
std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse,
std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager)
: mMaxBeamWidth(maxBeamWidth)
, mDataType(dtype)
, mMaxAttentionWindow(*std::max_element(maxAttentionWindowVec.begin(), maxAttentionWindowVec.end()))
@@ -1644,7 +1669,7 @@ KVCacheManager::KVCacheManager(std::vector<SizeType32> const& numKvHeadsPerLayer
, mBlockManager(numKvHeadsPerLayer, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences,
std::move(stream), maxSequenceLength, maxBeamWidth, maxAttentionWindowVec, tempAttentionWindowInputs, dtype,
mSinkBubbleLength, onboardBlocks, cacheType, secondaryOffloadMinPriority, std::move(eventManager),
enablePartialReuse, copyOnPartialReuse)
enablePartialReuse, copyOnPartialReuse, std::move(kvCacheConnectorManager))
// disable block reuse for sink bubble since chopVectorIntoBlocks does not match KV cache blocks in this case
, mEnableBlockReuse{mSinkBubbleLength > 0 ? false : enableBlockReuse}
{
@@ -1668,11 +1693,12 @@ KVCacheManager::KVCacheManager(SizeType32 numLayers, SizeType32 numKvHeads, Size
SizeType32 sinkTokenLength, CudaStreamPtr stream, std::optional<runtime::SizeType32> maxSequenceLength,
bool enableBlockReuse, bool onboardBlocks, CacheType cacheType,
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse)
std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse,
std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager)
: KVCacheManager(std::vector<SizeType32>(numLayers, numKvHeads), sizePerHead, tokensPerBlock, blocksPerWindow,
maxNumSequences, maxBeamWidth, maxAttentionWindowVec, tempAttentionWindowInputs, dtype, sinkTokenLength,
std::move(stream), maxSequenceLength, enableBlockReuse, onboardBlocks, cacheType, secondaryOffloadMinPriority,
std::move(eventManager), enablePartialReuse, copyOnPartialReuse)
std::move(eventManager), enablePartialReuse, copyOnPartialReuse, std::move(kvCacheConnectorManager))
{
}

@@ -2381,6 +2407,13 @@ std::vector<SizeType32> KVCacheManager::getNewlyAllocatedBlockIds(
return mBlockManager.getNewlyAllocatedBlockIds(getSequence(requestId), windowSize);
}

runtime::ITensor::SharedPtr KVCacheManager::getUniquePrimaryPool() const
{
TLLM_CHECK_WITH_INFO(mBlockManager.getWindowSizesMetadata().size() == 1,
"getUniquePrimaryPool is only supported for a single window size");
return mBlockManager.getPrimaryPool(0);
}

runtime::ITensor::SharedPtr KVCacheManager::getPrimaryPool(SizeType32 layer_idx) const
{
return mBlockManager.getPrimaryPool(mBlockManager.getLayerPoolIdx(layer_idx));
@@ -2460,4 +2493,5 @@ SizeType32 KVCacheManager::calculateMaxBlockRequirements(SizeType32 inputLength,
auto const leftoverBlockCapacity = blockCapacity - outputBlockRequirements;
return std::min(outputLength + leftoverBlockCapacity * tokensPerBlock, inputLength + outputLength);
}

} // namespace tensorrt_llm::batch_manager::kv_cache_manager
1 change: 1 addition & 0 deletions cpp/tensorrt_llm/nanobind/CMakeLists.txt
@@ -7,6 +7,7 @@ set(SRCS
batch_manager/algorithms.cpp
batch_manager/bindings.cpp
batch_manager/cacheTransceiver.cpp
batch_manager/kvCacheConnector.cpp
batch_manager/kvCacheManager.cpp
batch_manager/llmRequest.cpp
executor/bindings.cpp
48 changes: 48 additions & 0 deletions cpp/tensorrt_llm/nanobind/batch_manager/kvCacheConnector.cpp
@@ -0,0 +1,48 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "tensorrt_llm/nanobind/batch_manager/kvCacheConnector.h"

#include <nanobind/trampoline.h>
#include <torch/extension.h>

namespace
{
using KvCacheConnectorManager = tensorrt_llm::batch_manager::kv_connector::KvCacheConnectorManager;

namespace tb = tensorrt_llm::batch_manager;

class PyKvCacheConnectorManager : KvCacheConnectorManager
{
public:
NB_TRAMPOLINE(KvCacheConnectorManager, 1);

SizeType32 getNumNewMatchedTokens(tb::LlmRequest const& request, SizeType32 numComputedTokens) override
{
NB_OVERRIDE_PURE_NAME("get_num_new_matched_tokens", getNumNewMatchedTokens, request, numComputedTokens);
}
};

} // namespace

void tensorrt_llm::batch_manager::kv_cache_manager::KVCacheManagerConnectorBindings::initBindings(nb::module_& m)
{
nb::class_<tb::kv_connector::KvCacheConnectorManager, PyKvCacheConnectorManager>(m, "KvCacheConnectorManager")
.def(nb::init<>())
.def("get_num_new_matched_tokens", &tb::kv_connector::KvCacheConnectorManager::getNumNewMatchedTokens,
nb::arg("request"), nb::arg("num_computed_tokens"));
}
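Because NB_OVERRIDE_PURE_NAME forwards the pure-virtual call to a Python method named get_num_new_matched_tokens, a plain Python subclass of the bound class is enough for the C++ KV cache manager to call back into Python during addSequence. A rough sketch of such a subclass — the import path and the scheduler-side helper are hypothetical, not defined by this PR:

```python
# Hypothetical import path; adjust to wherever the nanobind module is exposed.
from tensorrt_llm.bindings.internal.batch_manager import KvCacheConnectorManager


class SchedulerConnectorManager(KvCacheConnectorManager):
    """Forwards the C++ callback to a (hypothetical) scheduler-side connector."""

    def __init__(self, connector_scheduler):
        super().__init__()  # initialize the C++ base so the trampoline is wired up
        self._scheduler = connector_scheduler

    def get_num_new_matched_tokens(self, request, num_computed_tokens: int) -> int:
        # Invoked from WindowBlockManager::addSequence with the locally reused
        # prefix length; return the extra tokens this connector can onboard.
        return self._scheduler.get_num_new_matched_tokens(request, num_computed_tokens)
```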
39 changes: 39 additions & 0 deletions cpp/tensorrt_llm/nanobind/batch_manager/kvCacheConnector.h
@@ -0,0 +1,39 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include "tensorrt_llm/batch_manager/kvCacheConnector.h"
#include <nanobind/nanobind.h>

namespace nb = nanobind;

namespace tensorrt_llm::batch_manager::kv_cache_manager
{
class KVCacheManagerConnectorBindings
{
public:
static void initBindings(nb::module_& m);
};
} // namespace tensorrt_llm::batch_manager::kv_cache_manager

namespace tensorrt_llm::pybind::batch_manager::kv_connector
{

using namespace tensorrt_llm::batch_manager::kv_connector;

} // namespace tensorrt_llm::pybind::batch_manager::kv_connector