2 changes: 2 additions & 0 deletions benchmarks/cpp/disaggServerBenchmark.cpp
@@ -636,6 +636,8 @@ class DisaggExecutorServer
: texec::DecodingMode::Auto(),
benchmarkParams.executorLookaheadConfig, benchmarkParams.medusaChoices));
executorConfig.setExtendedRuntimePerfKnobConfig(extendedRuntimePerfKnobConfig);
executorConfig.setCacheTransceiverConfig(
texec::CacheTransceiverConfig(texec::CacheTransceiverConfig::BackendType::DEFAULT));
constexpr int maxIterationsForRequestStats = 1000;
if (mEnableCollectKvCacheTransferTime)
{
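The two added lines are the entire opt-in for the benchmark. A minimal sketch of the same pattern, assuming an `ExecutorConfig` is being assembled as above (the helper name `enableCacheTransceiver` is illustrative, not part of the PR):

```cpp
#include "tensorrt_llm/executor/executor.h"

namespace texec = tensorrt_llm::executor;

// Illustrative helper: opt an executor into disaggregated KV-cache transfer.
// BackendType::DEFAULT defers the concrete transport (UCX/NIXL/MPI) to the
// factory, which resolves it from environment variables or falls back to UCX.
void enableCacheTransceiver(texec::ExecutorConfig& executorConfig)
{
    executorConfig.setCacheTransceiverConfig(
        texec::CacheTransceiverConfig(texec::CacheTransceiverConfig::BackendType::DEFAULT));
}
```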
19 changes: 5 additions & 14 deletions cpp/include/tensorrt_llm/batch_manager/cacheTransceiver.h
@@ -70,28 +70,20 @@ class BaseCacheTransceiver
class CacheTransceiver : public BaseCacheTransceiver
{
public:
enum class CommType : std::uint8_t
{
UNKNOWN = 0,
MPI = 1,
UCX = 2,
NIXL = 3
};

CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheManager, CommType commType,
CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheManager,
executor::kv_cache::CacheState::ModelConfig const& cacheStateModelCfg, runtime::WorldConfig const& worldConfig,
nvinfer1::DataType dataType,
executor::kv_cache::CacheState::AttentionType attentionType
= executor::kv_cache::CacheState::AttentionType::kDEFAULT,
std::optional<executor::CacheTransceiverConfig> cacheTransceiverConfig = std::nullopt);

CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheManager, CommType commType,
std::vector<SizeType32> numKvHeadsPerLayer, SizeType32 sizePerHead, SizeType32 tokensPerBlock,
runtime::WorldConfig const& worldConfig, nvinfer1::DataType dataType,
CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheManager, std::vector<SizeType32> numKvHeadsPerLayer,
SizeType32 sizePerHead, SizeType32 tokensPerBlock, runtime::WorldConfig const& worldConfig,
nvinfer1::DataType dataType,
executor::kv_cache::CacheState::AttentionType attentionType
= executor::kv_cache::CacheState::AttentionType::kDEFAULT,
std::optional<executor::CacheTransceiverConfig> cacheTransceiverConfig = std::nullopt)
: CacheTransceiver(cacheManager, commType,
: CacheTransceiver(cacheManager,
executor::kv_cache::CacheState::ModelConfig{numKvHeadsPerLayer, sizePerHead, tokensPerBlock}, worldConfig,
dataType, attentionType, cacheTransceiverConfig)
{
@@ -118,7 +110,6 @@ class CacheTransceiver : public BaseCacheTransceiver

void setContextState(LlmRequest* llmRequest);

CommType mCommType;
std::unique_ptr<DataResponder> mDataResponder;
std::unique_ptr<DataRequester> mDataRequester;
std::vector<std::pair<LlmRequest*, std::future<void>>> mResponderFutures;
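With the `CommType` enum gone, call sites select the transport through the config object rather than a positional argument. A before/after sketch of a call site, under the assumption that the caller already holds the manager and config objects (the wrapper function `makeUcxTransceiver` is illustrative):

```cpp
#include "tensorrt_llm/batch_manager/cacheTransceiver.h"
#include <memory>

namespace texec = tensorrt_llm::executor;
using namespace tensorrt_llm::batch_manager;

// Before (removed): the transport was a positional CommType argument:
//   std::make_unique<CacheTransceiver>(cacheManager, CacheTransceiver::CommType::UCX,
//       cacheStateModelCfg, worldConfig, dataType);
// After: the transport rides inside the optional CacheTransceiverConfig.
std::unique_ptr<CacheTransceiver> makeUcxTransceiver(kv_cache_manager::BaseKVCacheManager* cacheManager,
    texec::kv_cache::CacheState::ModelConfig const& cacheStateModelCfg,
    tensorrt_llm::runtime::WorldConfig const& worldConfig, nvinfer1::DataType dataType)
{
    return std::make_unique<CacheTransceiver>(cacheManager, cacheStateModelCfg, worldConfig, dataType,
        texec::kv_cache::CacheState::AttentionType::kDEFAULT,
        texec::CacheTransceiverConfig(texec::CacheTransceiverConfig::BackendType::UCX));
}
```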
19 changes: 15 additions & 4 deletions cpp/include/tensorrt_llm/executor/executor.h
@@ -1430,18 +1430,29 @@ class LogitsPostProcessorConfig
class CacheTransceiverConfig
{
public:
explicit CacheTransceiverConfig(std::optional<size_t> maxNumTokens = std::nullopt);
enum class BackendType : std::uint8_t
{
DEFAULT = 0,
MPI = 1,
UCX = 2,
NIXL = 3
};
explicit CacheTransceiverConfig(
std::optional<BackendType> backendType = std::nullopt, std::optional<size_t> maxNumTokens = std::nullopt);

bool operator==(CacheTransceiverConfig const& other) const;
void setBackendType(std::optional<BackendType> backendType);
void setMaxTokensInBuffer(std::optional<size_t> maxTokensInBuffer);

[[nodiscard]] std::optional<size_t> getMaxNumTokens() const;
void setMaxNumTokens(size_t maxNumTokens);
[[nodiscard]] std::optional<size_t> getMaxTokensInBuffer() const;
[[nodiscard]] std::optional<BackendType> getBackendType() const;

private:
std::optional<BackendType> mBackendType;
/// @brief The maximum number of tokens that the CacheTransceiver's pre-allocated buffer can hold. If the number of
/// kvCache tokens to be transferred for a single request is greater than this value, the performance of the cache
/// transfer may be degraded.
std::optional<size_t> mMaxNumTokens;
std::optional<size_t> mMaxTokensInBuffer;
};

/// @brief Configuration class for the model executor
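Read together, the declarations above imply a small round-trip through the config. A compile sketch inferred purely from the signatures in this hunk (not code from the PR):

```cpp
#include "tensorrt_llm/executor/executor.h"
#include <cassert>

namespace texec = tensorrt_llm::executor;

int main()
{
    // Both fields are optional; here we pin a backend and a buffer cap.
    texec::CacheTransceiverConfig config(
        texec::CacheTransceiverConfig::BackendType::UCX, /*maxNumTokens=*/8192);
    assert(config.getBackendType() == texec::CacheTransceiverConfig::BackendType::UCX);
    assert(config.getMaxTokensInBuffer() == 8192);

    // Accessors mirror the rename: setMaxTokensInBuffer replaces setMaxNumTokens.
    config.setBackendType(texec::CacheTransceiverConfig::BackendType::NIXL);
    config.setMaxTokensInBuffer(16384);
    return 0;
}
```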
37 changes: 28 additions & 9 deletions cpp/tensorrt_llm/batch_manager/cacheTransBuffer.cpp
@@ -210,7 +210,7 @@ CacheTransBufferManager::CacheTransBufferManager(
{
auto poolIdx = mCacheManager->getBlockManager().getLayerPoolIdx(layerId);
auto windowSize = static_cast<size_t>(mCacheManager->getBlockManager().getPoolWindowSize(poolIdx));
auto validTokenNum = windowSize < maxNumTokens.value() ? windowSize : maxNumTokens.value();
auto validTokenNum = (windowSize < maxNumTokens.value() ? windowSize : maxNumTokens.value());
bufferSizeFromMaxNumToken += validTokenNum * kvCacheByteSizePerTokenPerLayer;
}
}
@@ -230,26 +230,37 @@ CacheTransBufferManager::CacheTransBufferManager(
TLLM_LOG_INFO(
"CacheTransBufferManager: mMaxNumTokens:%ld, mRecvBufferCount:%ld, "
"mSendBufferCount:%ld,mTransferBufferSize:%ld, mPreAllocBufferSize:%ld,mOnlyUseDynamicBuffer:%d "
"mUseFabricMemory:%d",
"mUseFabricMemory:%d mDataType:%d",
maxNumTokens.has_value() ? maxNumTokens.value() : 0, mRecvBufferCount, mSendBufferCount, mTransferBufferSize,
mPreAllocBufferSize, mOnlyUseDynamicBuffer, mUseFabricMemory);
bool to_allocate = common::getEnvUseMPIKvCache() || common::getEnvUseUCXKvCache() || common::getEnvUseNixlKvCache();
mPreAllocBufferSize, mOnlyUseDynamicBuffer, mUseFabricMemory, mDataType);

TLLM_CHECK_WITH_INFO(to_allocate, "CacheTransBufferManager: to_allocate is false");
allocateBuffer();
}

size_t CacheTransBufferManager::preAllocBufferSize(std::optional<size_t> maxNumTokens)
size_t CacheTransBufferManager::preAllocBufferSize(
std::map<SizeType32, SizeType32> const& cacheSizeBytesPerTokenPerWindow,
std::optional<executor::CacheTransceiverConfig> const& cacheTransceiverConfig)
{
bool to_allocate = common::getEnvUseMPIKvCache() || common::getEnvUseUCXKvCache() || common::getEnvUseNixlKvCache();
if (!to_allocate)
if (!cacheTransceiverConfig.has_value())
{
return 0;
}
if (!cacheTransceiverConfig->getBackendType().has_value())
{
return 0;
}
auto maxNumTokens = cacheTransceiverConfig->getMaxTokensInBuffer();
size_t TransferBufferSize = common::getEnvMemSizeForKVCacheTransferBuffer();
if (maxNumTokens.has_value())
{
TransferBufferSize = maxNumTokens.value();
TransferBufferSize = 0;
for (auto const& [windowSize, cacheSizeBytesPerToken] : cacheSizeBytesPerTokenPerWindow)
{
auto validTokenNum
= (static_cast<size_t>(windowSize) < maxNumTokens.value() ? static_cast<size_t>(windowSize)
: maxNumTokens.value());
TransferBufferSize += validTokenNum * cacheSizeBytesPerToken;
}
}
bool useFabricMemory = FabricMemory::supportFbaricMemory()
&& (!(common::getEnvKVCacheTransferUseSyncBuffer() || common::getEnvKVCacheTransferUseAsyncBuffer()));
@@ -329,6 +340,14 @@ std::tuple<std::vector<runtime::ITensor::SharedPtr>, size_t, bool> CacheTransBuf
size_t bufferCoverTargetNum = std::min(
static_cast<size_t>(targetNum), mTransferBufferSize / (targetBufferEleSize * common::getDTypeSize(mDataType)));
TLLM_LOG_DEBUG("getOrAllocateBuffers bufferCoverTargetNum:%d", bufferCoverTargetNum);
if (bufferCoverTargetNum < static_cast<size_t>(targetNum))
{
TLLM_LOG_WARNING(
"CacheTransceiver getOrAllocateBuffers: bufferCoverTargetNum:%d < targetNum:%d, may use dynamic buffer, "
"it's better to increase MaxTokensInBuffer in cacheTransceiverConfig, otherwise, the performance may "
"be degraded",
bufferCoverTargetNum, targetNum);
}
if (bufferId.has_value())
{
TLLM_CHECK(static_cast<size_t>(bufferId.value()) < concurrenceResource.mBuffers.size());
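The new per-window clamp decides how big the pre-allocated buffer gets: each window contributes at most `maxTokensInBuffer` tokens, and never more tokens than the window itself holds. A standalone restatement of that arithmetic (function and variable names are illustrative, not the library's):

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <map>

// For each attention window, the pre-allocated transfer buffer holds at most
// maxTokensInBuffer tokens, and never more tokens than the window can contain.
std::size_t preAllocBytesSketch(
    std::map<std::int32_t, std::int32_t> const& cacheSizeBytesPerTokenPerWindow, std::size_t maxTokensInBuffer)
{
    std::size_t total = 0;
    for (auto const& [windowSize, bytesPerToken] : cacheSizeBytesPerTokenPerWindow)
    {
        std::size_t const validTokenNum
            = std::min(static_cast<std::size_t>(windowSize), maxTokensInBuffer);
        total += validTokenNum * static_cast<std::size_t>(bytesPerToken);
    }
    return total;
}
```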
4 changes: 3 additions & 1 deletion cpp/tensorrt_llm/batch_manager/cacheTransBuffer.h
@@ -18,6 +18,7 @@
#pragma once

#include "tensorrt_llm/batch_manager/kvCacheManager.h"
#include "tensorrt_llm/executor/executor.h"
#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/iTensor.h"
#include <atomic>
@@ -59,7 +60,8 @@ class CacheTransBufferManager
CacheTransBufferManager(
KVCacheManager::BaseKVCacheManager* cacheManager, std::optional<size_t> maxNumTokens = std::nullopt);

static size_t preAllocBufferSize(std::optional<size_t> maxNumTokens = std::nullopt);
static size_t preAllocBufferSize(std::map<SizeType32, SizeType32> const& cacheSizeBytesPerTokenPerWindow,
std::optional<executor::CacheTransceiverConfig> const& cacheTransceiverConfig = std::nullopt);

std::optional<int> assignBufferIndexForSend();
void freeBufferIndexForSend(std::optional<int> bufferId);
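A call-site sketch for the updated static signature, with made-up per-window byte costs standing in for values the KV-cache manager would supply:

```cpp
#include "tensorrt_llm/batch_manager/cacheTransBuffer.h"
#include <cstddef>
#include <map>

namespace texec = tensorrt_llm::executor;
using tensorrt_llm::batch_manager::kv_cache_manager::CacheTransBufferManager;
using tensorrt_llm::runtime::SizeType32;

// Hypothetical costs: a 4096-token window at 256 B/token, an 8192-token window
// at 512 B/token; capped at 4096 tokens of buffer per window.
std::size_t queryPreAllocSize()
{
    std::map<SizeType32, SizeType32> cacheSizeBytesPerTokenPerWindow{{4096, 256}, {8192, 512}};
    return CacheTransBufferManager::preAllocBufferSize(cacheSizeBytesPerTokenPerWindow,
        texec::CacheTransceiverConfig(texec::CacheTransceiverConfig::BackendType::UCX, /*maxNumTokens=*/4096));
}
```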
150 changes: 79 additions & 71 deletions cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp
@@ -62,41 +62,49 @@ std::unique_ptr<BaseCacheTransceiver> CacheTransceiverFactory::createCacheTransc
runtime::WorldConfig const& worldConfig, executor::kv_cache::CacheState::AttentionType attentionType,
std::optional<executor::CacheTransceiverConfig> cacheTransceiverConfig)
{

std::optional<CacheTransceiver::CommType> commType;
if (common::getEnvUseUCXKvCache())
{
commType = CacheTransceiver::CommType::UCX;
TLLM_LOG_INFO("Enable UCX KV cache transport.");
}
else if (common::getEnvUseNixlKvCache())
if (!cacheTransceiverConfig.has_value() || !cacheTransceiverConfig.value().getBackendType().has_value())
{
commType = CacheTransceiver::CommType::NIXL;
TLLM_LOG_INFO("Enable NIXL KV cache transport.");
TLLM_LOG_INFO("CacheTransceiver is disabled.");
return nullptr;
}
else if (common::getEnvUseMPIKvCache())
auto backendType = cacheTransceiverConfig.value().getBackendType();
if (backendType.value() == executor::CacheTransceiverConfig::BackendType::DEFAULT)
{
commType = CacheTransceiver::CommType::MPI;
TLLM_LOG_INFO("Enable MPI KV cache transport.");
if (common::getEnvUseUCXKvCache())
{
backendType = executor::CacheTransceiverConfig::BackendType::UCX;
TLLM_LOG_INFO("Enable UCX KV cache transport.");
}
else if (common::getEnvUseNixlKvCache())
{
backendType = executor::CacheTransceiverConfig::BackendType::NIXL;
TLLM_LOG_INFO("Enable NIXL KV cache transport.");
}
else if (common::getEnvUseMPIKvCache())
{
backendType = executor::CacheTransceiverConfig::BackendType::MPI;
TLLM_LOG_INFO("Enable MPI KV cache transport.");
TLLM_LOG_WARNING("MPI KV cache transport is deprecated, please use UCX or NIXL instead.");
}
else
{
backendType = executor::CacheTransceiverConfig::BackendType::UCX;
}
}
cacheTransceiverConfig.value().setBackendType(backendType);

if (commType)
{
executor::kv_cache::CacheState::ModelConfig cacheStateCfg{
modelConfig.getNumKvHeadsPerLayer(), modelConfig.getSizePerHead(), modelConfig.getTokensPerBlock()};
executor::kv_cache::CacheState::ModelConfig cacheStateCfg{
modelConfig.getNumKvHeadsPerLayer(), modelConfig.getSizePerHead(), modelConfig.getTokensPerBlock()};

return std::make_unique<CacheTransceiver>(cacheManager, commType.value(), cacheStateCfg, worldConfig,
modelConfig.getKvDataType(), attentionType, cacheTransceiverConfig);
}
return nullptr;
return std::make_unique<CacheTransceiver>(
cacheManager, cacheStateCfg, worldConfig, modelConfig.getKvDataType(), attentionType, cacheTransceiverConfig);
}

CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheManager, CommType commType,
CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheManager,
executor::kv_cache::CacheState::ModelConfig const& cacheStateModelCfg, runtime::WorldConfig const& worldConfig,
nvinfer1::DataType dataType, executor::kv_cache::CacheState::AttentionType attentionType,
std::optional<executor::CacheTransceiverConfig> cacheTransceiverConfig)
: mCommType{commType}
, mMpiGroupComm(std::addressof(tensorrt_llm::mpi::MpiComm::session()))
: mMpiGroupComm(std::addressof(tensorrt_llm::mpi::MpiComm::session()))
, mCacheTransceiverConfig{cacheTransceiverConfig}
{
using tensorrt_llm::batch_manager::kv_cache_manager::CacheFormatter;
@@ -138,59 +146,59 @@ CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheMa
}
}
bool isMLA = attentionType == executor::kv_cache::CacheState::AttentionType::kMLA;
if (mCommType == CommType::MPI || mCommType == CommType::UCX || mCommType == CommType::NIXL)
{
std::optional<size_t> maxNumTokens = std::nullopt;
if (mCacheTransceiverConfig.has_value())
{
maxNumTokens = mCacheTransceiverConfig.value().getMaxNumTokens();
}
mCacheTransBufferManager
= std::make_unique<kv_cache_manager::CacheTransBufferManager>(cacheManager, maxNumTokens);
if (mCommType == CommType::UCX)
{
std::lock_guard<std::mutex> lock(mDllMutex);
mWrapperLibHandle = dllOpen(UCX_WRAPPER_LIB_NAME);
TLLM_CHECK_WITH_INFO(mWrapperLibHandle != nullptr, "UCX wrapper library is not open correctly.");
auto load_sym = [](void* handle, char const* name)
{
void* ret = dllGetSym(handle, name);
TLLM_CHECK_WITH_INFO(ret != nullptr,
"Unable to load UCX wrapper library symbol, possible cause is that TensorRT-LLM library is not "
"built with UCX support, please rebuild in UCX-enabled environment.");
return ret;
};
std::unique_ptr<tensorrt_llm::executor::kv_cache::ConnectionManager> (*makeUcxConnectionManager)();
*(void**) (&makeUcxConnectionManager) = load_sym(mWrapperLibHandle, "makeUcxConnectionManager");
mManager = makeUcxConnectionManager();
TLLM_LOG_INFO("UCX Connection Manager created");
}
else if (mCommType == CommType::NIXL)
{
mManager = std::make_unique<tensorrt_llm::executor::kv_cache::AgentConnectionManager>(
mCacheTransBufferManager.get());
TLLM_LOG_INFO("NIXL Connection Manager created");
}
else
{
mMpiWorldComm = std::addressof(tensorrt_llm::mpi::MpiComm::world());
mManager = std::make_unique<executor::kv_cache::MpiConnectionManager>(mMpiWorldComm);
TLLM_LOG_INFO("MPI Connection Manager created");
}
TLLM_CHECK_WITH_INFO(mCacheTransceiverConfig.has_value(), "CacheTransceiverConfig is not set.");
auto backendType = mCacheTransceiverConfig.value().getBackendType();
TLLM_CHECK_WITH_INFO(
backendType.has_value() && (backendType.value() != executor::CacheTransceiverConfig::BackendType::DEFAULT),
" CacheTransceiverConfig::BackendType is not set.");

using tensorrt_llm::batch_manager::kv_cache_manager::MLACacheFormatter;
auto makeFormatter = [cacheManager, isMLA, this]()
{ return createCacheFormatter(cacheManager, mCacheTransBufferManager.get(), isMLA); };
std::optional<size_t> maxNumTokens = mCacheTransceiverConfig.value().getMaxTokensInBuffer();

mDataResponder = std::make_unique<DataResponder>(
std::make_unique<DataSenderImpl>(mManager.get(), *mCacheState, worldConfig.getRank(), makeFormatter()));
mDataRequester = std::make_unique<DataRequester>(
std::make_unique<DataReceiverImpl>(mManager.get(), *mCacheState, worldConfig.getRank(), makeFormatter()));
mCacheTransBufferManager = std::make_unique<kv_cache_manager::CacheTransBufferManager>(cacheManager, maxNumTokens);
if (backendType.value() == executor::CacheTransceiverConfig::BackendType::UCX)
{
std::lock_guard<std::mutex> lock(mDllMutex);
mWrapperLibHandle = dllOpen(UCX_WRAPPER_LIB_NAME);
TLLM_CHECK_WITH_INFO(mWrapperLibHandle != nullptr, "UCX wrapper library is not open correctly.");
auto load_sym = [](void* handle, char const* name)
{
void* ret = dllGetSym(handle, name);
TLLM_CHECK_WITH_INFO(ret != nullptr,
"Unable to load UCX wrapper library symbol, possible cause is that TensorRT-LLM library is not "
"built with UCX support, please rebuild in UCX-enabled environment.");
return ret;
};
std::unique_ptr<tensorrt_llm::executor::kv_cache::ConnectionManager> (*makeUcxConnectionManager)();
*(void**) (&makeUcxConnectionManager) = load_sym(mWrapperLibHandle, "makeUcxConnectionManager");
mManager = makeUcxConnectionManager();
TLLM_LOG_INFO("UCX Connection Manager created");
}
else if (backendType.value() == executor::CacheTransceiverConfig::BackendType::NIXL)
{
mManager = std::make_unique<tensorrt_llm::executor::kv_cache::AgentConnectionManager>(
mCacheTransBufferManager.get());
TLLM_LOG_INFO("NIXL Connection Manager created");
}
else if (backendType.value() == executor::CacheTransceiverConfig::BackendType::MPI)
{
mMpiWorldComm = std::addressof(tensorrt_llm::mpi::MpiComm::world());
mManager = std::make_unique<executor::kv_cache::MpiConnectionManager>(mMpiWorldComm);
TLLM_LOG_INFO("MPI Connection Manager created");
}
else
{
TLLM_THROW("Unsupported communication type.");
TLLM_THROW("Unsupported cache transceiver backend type ");
}

using tensorrt_llm::batch_manager::kv_cache_manager::MLACacheFormatter;
auto makeFormatter = [cacheManager, isMLA, this]()
{ return createCacheFormatter(cacheManager, mCacheTransBufferManager.get(), isMLA); };

mDataResponder = std::make_unique<DataResponder>(
std::make_unique<DataSenderImpl>(mManager.get(), *mCacheState, worldConfig.getRank(), makeFormatter()));
mDataRequester = std::make_unique<DataRequester>(
std::make_unique<DataReceiverImpl>(mManager.get(), *mCacheState, worldConfig.getRank(), makeFormatter()));

initializeCommState();
}

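The factory's handling of `BackendType::DEFAULT` preserves the old environment-variable behavior. Condensed into one helper for readability (the helper itself is not in the PR, and the env getters are assumed to live in `tensorrt_llm/common/envUtils.h`):

```cpp
#include "tensorrt_llm/common/envUtils.h"
#include "tensorrt_llm/executor/executor.h"

namespace texec = tensorrt_llm::executor;

// Condensed restatement of the factory branch above: explicit env overrides
// win in UCX -> NIXL -> MPI order, and UCX is the fallback.
texec::CacheTransceiverConfig::BackendType resolveDefaultBackend()
{
    using BackendType = texec::CacheTransceiverConfig::BackendType;
    namespace common = tensorrt_llm::common;
    if (common::getEnvUseUCXKvCache())
    {
        return BackendType::UCX;
    }
    if (common::getEnvUseNixlKvCache())
    {
        return BackendType::NIXL;
    }
    if (common::getEnvUseMPIKvCache())
    {
        return BackendType::MPI; // deprecated: prefer UCX or NIXL
    }
    return BackendType::UCX;
}
```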
9 changes: 2 additions & 7 deletions cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp
@@ -2235,13 +2235,8 @@ BlocksPerWindow BaseKVCacheManager::calculateMaxNumBlocks(executor::KvCacheConfi
cacheSizeBytesPerTokenPerWindow[windowSize] = cacheSizeBytesPerToken;
}

auto const extraCostMemoryBytes = extraCostMemory
* std::accumulate(cacheSizeBytesPerTokenPerWindow.cbegin(), cacheSizeBytesPerTokenPerWindow.cend(),
SizeType32{0}, [](SizeType32 acc, auto const cost) { return acc + cost.second; });

TLLM_LOG_DEBUG(
"extraCostMemoryBytes [all windows] [Gib]: %0.2f", extraCostMemoryBytes / static_cast<double>(1 << 30));

TLLM_LOG_DEBUG("extraCostMemory [Gib]: %0.2f", extraCostMemory / static_cast<double>(1 << 30));
allottedPrimaryMemBytes = allottedPrimaryMemBytes - extraCostMemory;
auto const tokensPerBlock = modelConfig.getTokensPerBlock();
auto const calculatePrimaryBlocks
= [&](SizeType32 windowSize, float windowSizeShare, SizeType32 cacheSizeBytesPerToken)