109 changes: 65 additions & 44 deletions cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp
@@ -37,12 +37,39 @@
namespace tensorrt_llm::batch_manager::kv_cache_manager
{

int getBlockNumAccountingForCP(int cpRank, int cpSize, int numTotalBlocks, bool strict)
{
TLLM_CHECK(cpRank >= 0 && cpRank < cpSize);
if (cpSize == 1)
{
return numTotalBlocks;
}
// NOTE: Non-strict mode may over-allocate blocks when numTotalBlocks is not divisible by cpSize.
// This is a known limitation and will be addressed in a future MR.
if (!strict)
{
// Simple ceiling division.
return (numTotalBlocks + cpSize - 1) / cpSize;
}
// In strict mode, blocks are distributed across CP ranks in round-robin fashion, as evenly as possible.
// When the number of blocks is not divisible by cpSize, each of the lowest-indexed CP ranks (the
// "overflow" ranks) takes exactly one extra block.
int numBlocksCurrRank = numTotalBlocks / cpSize;
if (numTotalBlocks % cpSize > cpRank)
{
numBlocksCurrRank++;
}
return numBlocksCurrRank;
}
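// Illustration (hypothetical standalone sketch, not part of this file): the two distribution
// modes above, with assumed values numTotalBlocks = 10 and cpSize = 4.
#include <cassert>

// Restates getBlockNumAccountingForCP's arithmetic for illustration.
int blocksForRank(int cpRank, int cpSize, int numTotalBlocks, bool strict)
{
    if (!strict)
    {
        // Ceiling division: every rank reserves the same count, possibly over-allocating.
        return (numTotalBlocks + cpSize - 1) / cpSize;
    }
    // Strict round-robin: the first (numTotalBlocks % cpSize) ranks take one extra block.
    return numTotalBlocks / cpSize + (cpRank < numTotalBlocks % cpSize ? 1 : 0);
}

int main()
{
    assert(blocksForRank(0, 4, 10, /*strict=*/true) == 3);  // overflow rank
    assert(blocksForRank(2, 4, 10, /*strict=*/true) == 2);  // 3 + 3 + 2 + 2 == 10 exactly
    assert(blocksForRank(2, 4, 10, /*strict=*/false) == 3); // ceil(10 / 4); 4 * 3 == 12 over-allocates
    return 0;
}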

// Select the subset of connections (context ranks) that this generation rank receives cache from.
std::vector<size_t> MLACacheFormatter::pickRecvConnections(
size_t numConnections, CacheState const& selfConfig, SizeType32 selfIdx, CacheState const& destConfig) const
{

auto targetInfo = executor::kv_cache::targetIRanks(destConfig, selfConfig, selfIdx);
// This function is called only on the generation side, and only CP size 1 is supported on the context side.
TLLM_CHECK(targetInfo.mDomainCPSize == 1);
TLLM_CHECK(numConnections == targetInfo.mIRanks.size());
std::vector<size_t> ret;
// targetInfo.mIRanks layout: [tpRanks, dpRanks]
@@ -97,14 +124,11 @@ void MLACacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& ses
auto& bufferManager = session.getBufferManager();
TLLM_CHECK_WITH_INFO(llmRequest.mSamplingConfig.beamWidth == 1, "Currently only supports beam width 1.");
TLLM_CHECK(!connections.empty());
// diff start
if (!needSendCache(selfConfig, destConfig, selfIdx))
{
return;
}

// diff end

auto const numPools = mCacheManager->getBlockManager().getNumPools();
auto blockRange = getBlockRangeForSending(mCacheManager, llmRequest);

@@ -147,43 +171,48 @@ void MLACacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& ses
return;
}

auto cacheBlockSize = inputKvCacheBlocks.at(0)->getSize();

auto cacheBufferId = mCacheTransBufferManager->assignBufferIndexForSend();
// diff start

auto targetInfo = executor::kv_cache::targetIRanks(destConfig, selfConfig, selfIdx);
auto ppRank = selfIdx
/ (selfConfig.getParallelConfig().mTensorParallelism * selfConfig.getParallelConfig().mContextParallelism);
int selfAttentionLayerNum = selfConfig.getParallelConfig().mAttentionLayerNumPerPP.at(ppRank);
size_t pPDomainSize = targetInfo.mDomainPPSize;
size_t cPDomainSize = targetInfo.mDomainCPSize;

auto getBufferSizeForTarget = [&]()
{
std::vector<size_t> bufferSizeForTarget(pPDomainSize, 0);
size_t cacheSizePerLayer = cacheBlockSize * blockNum / selfAttentionLayerNum;
for (size_t i = 0; i < pPDomainSize; i++)
auto const ppRank = selfIdx
/ (selfConfig.getParallelConfig().mTensorParallelism * selfConfig.getParallelConfig().mContextParallelism);
auto const selfAttentionLayerNum = selfConfig.getParallelConfig().mAttentionLayerNumPerPP.at(ppRank);
auto const cacheBlockSize = inputKvCacheBlocks.at(0)->getSize();
auto const blockSizePerLayer = cacheBlockSize / selfAttentionLayerNum;
std::vector<size_t> bufferSizeForTarget(pPDomainSize * cPDomainSize, 0);
for (size_t ppDomainIdx = 0; ppDomainIdx < pPDomainSize; ppDomainIdx++)
{
auto layerNum = targetInfo.getPeerPPDomainLayerNum(i);
bufferSizeForTarget[i] = cacheSizePerLayer * layerNum;
auto const peerAttentionLayerNum = targetInfo.getPeerPPDomainLayerNum(ppDomainIdx);
for (size_t cpDomainIdx = 0; cpDomainIdx < cPDomainSize; cpDomainIdx++)
{
auto const idx = cpDomainIdx * pPDomainSize + ppDomainIdx;
// Note: contextCP is always 1. So, cpDomainSize == genCPSize and cpDomainIdx == genCPRank.
auto const peerBlockNum
= getBlockNumAccountingForCP(cpDomainIdx, cPDomainSize, blockNum, /*strict=*/false);
bufferSizeForTarget[idx] = blockSizePerLayer * peerAttentionLayerNum * peerBlockNum;
}
}
return bufferSizeForTarget;
};
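// Illustration (hypothetical standalone sketch, not part of this file): the flattened buffer
// index above is CP-major — the pPDomainSize entries for one CP peer are contiguous — and the
// same convention is what `processIdx % (pPDomainSize * cPDomainSize)` recovers further below.
#include <cstddef>
#include <cstdio>

int main()
{
    std::size_t const pPDomainSize = 2; // assumed: PP peers per target
    std::size_t const cPDomainSize = 3; // assumed: CP peers per target
    for (std::size_t cpDomainIdx = 0; cpDomainIdx < cPDomainSize; ++cpDomainIdx)
    {
        for (std::size_t ppDomainIdx = 0; ppDomainIdx < pPDomainSize; ++ppDomainIdx)
        {
            // Same flattening as getBufferSizeForTarget: idx visits (cp0,pp0), (cp0,pp1), (cp1,pp0), ...
            std::size_t const idx = cpDomainIdx * pPDomainSize + ppDomainIdx;
            std::printf("cp=%zu pp=%zu -> idx=%zu\n", cpDomainIdx, ppDomainIdx, idx);
        }
    }
    return 0;
}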
auto bufferEleSizes = getBufferSizeForTarget();
auto cacheBufferId = mCacheTransBufferManager->assignBufferIndexForSend();
auto result = mCacheTransBufferManager->getOrAllocateSendBuffers(
cacheBufferId, static_cast<int>(pPDomainSize), bufferEleSizes, bufferManager);
cacheBufferId, static_cast<int>(pPDomainSize * cPDomainSize), bufferEleSizes, bufferManager);
auto& outputSplitCaches = std::get<0>(result);
auto& bufferCoverTargetNum = std::get<1>(result);
auto& onlyUseDynamicBuffer = std::get<2>(result);
auto* agentConnection = dynamic_cast<executor::kv_cache::AgentConnection const*>(connections[0]);
if (agentConnection != nullptr)
{
TLLM_CHECK_WITH_INFO(bufferCoverTargetNum == pPDomainSize, "Agent needs all buffers pre-allocated");
TLLM_CHECK_WITH_INFO(
bufferCoverTargetNum == pPDomainSize * cPDomainSize, "Agent needs all buffers pre-allocated");
TLLM_CHECK(onlyUseDynamicBuffer == false);
}
// diff end

// The size of outputSplitCaches should be equal to pPDomainSize

// The size of outputSplitCaches should be equal to pPDomainSize * cPDomainSize.
SizeType32 window = mCacheManager->getBlockManager().getPoolWindowSize(0);
std::map<SizeType32, std::vector<runtime::ITensor::SharedPtr>> inputKvCacheBlocksPerWindow;
inputKvCacheBlocksPerWindow.emplace(window, inputKvCacheBlocks);
@@ -203,7 +232,7 @@ void MLACacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& ses

TLLM_CUDA_CHECK(cudaSetDevice(deviceId));
auto startTime = std::chrono::steady_clock::now();
auto cacheIdx = processIdx % pPDomainSize;
auto cacheIdx = processIdx % (pPDomainSize * cPDomainSize);
if (cacheIdx < bufferCoverTargetNum)
{
size_t size = outputSplitCaches.at(cacheIdx)->getSizeInBytes();
@@ -259,7 +288,8 @@ void MLACacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& ses
else
{
// Number of concurrent sends, bounded by how many pre-allocated buffers are available.
auto concurrencyNum = std::min(std::max(static_cast<size_t>(1), bufferCoverTargetNum), pPDomainSize);
auto concurrencyNum
= std::min(std::max(static_cast<size_t>(1), bufferCoverTargetNum), pPDomainSize * cPDomainSize);

auto remainSendNum = connections.size();

@@ -307,9 +337,7 @@ void MLACacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& s
auto& bufferManager = session.getBufferManager();
auto arrivalTime = llmRequest.getPerfMetrics().timingMetrics.arrivalTime;
bool recordDelay = arrivalTime != std::chrono::steady_clock::time_point();
// diff start
auto pickUpConnections = pickRecvConnections(connections.size(), selfConfig, selfIdx, destConfig);
// diff end
auto blockRange = getBlockRangeForReceiving(mCacheManager, llmRequest);
std::vector<runtime::ITensor::SharedPtr> recvBufferTmps;
std::vector<runtime::ITensor::SharedPtr> outputBuffers;
@@ -364,23 +392,24 @@ void MLACacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& s
cacheBufferId = mCacheTransBufferManager->assignBufferIndexForRecv();
}

auto cacheBlockSize = outputBuffers.at(0)->getSize();

auto targetNum = pickUpConnections.size();
auto targetInfo = executor::kv_cache::targetIRanks(destConfig, selfConfig, selfIdx);
auto ppRank = selfIdx
/ (selfConfig.getParallelConfig().mTensorParallelism * selfConfig.getParallelConfig().mContextParallelism);
auto selfAttentionLayerNum = selfConfig.getParallelConfig().mAttentionLayerNumPerPP.at(ppRank);
TLLM_CHECK_WITH_INFO(selfAttentionLayerNum != 0, "selfAttentionLayerNum should not be 0");

auto getBufferSizeForTarget = [&]()
{
auto const targetInfo = executor::kv_cache::targetIRanks(destConfig, selfConfig, selfIdx);
auto const cacheBlockSize = outputBuffers.at(0)->getSize();
auto const ppRank = selfIdx
/ (selfConfig.getParallelConfig().mTensorParallelism
* selfConfig.getParallelConfig().mContextParallelism);
auto const selfAttentionLayerNum = selfConfig.getParallelConfig().mAttentionLayerNumPerPP.at(ppRank);
TLLM_CHECK_WITH_INFO(selfAttentionLayerNum != 0, "selfAttentionLayerNum should not be 0");
std::vector<size_t> bufferEleSizes(targetNum, 0);
auto cacheSizePerLayer = cacheBlockSize * blockNum / selfAttentionLayerNum;
auto const cacheSizePerLayer = cacheBlockSize * blockNum / selfAttentionLayerNum;
for (size_t i = 0; i < targetNum; i++)
{
auto layerNum = targetInfo.getPeerPPDomainLayerNum(static_cast<SizeType32>(pickUpConnections[i]));
bufferEleSizes[i] = cacheSizePerLayer * layerNum;
auto const peerAttentionLayerNum
= targetInfo.getPeerPPDomainLayerNum(static_cast<SizeType32>(pickUpConnections[i]));
bufferEleSizes[i] = cacheSizePerLayer * peerAttentionLayerNum;
}
return bufferEleSizes;
};
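// Illustration (hypothetical standalone sketch, not part of this file): the receive-side sizing
// with assumed numbers — the per-layer element count comes from the local block size and layer
// count, then scales by each peer's layer count.
#include <cassert>
#include <cstddef>

int main()
{
    std::size_t const cacheBlockSize = 4096;     // assumed: elements per local cache block
    std::size_t const blockNum = 8;              // assumed: blocks being received
    std::size_t const selfAttentionLayerNum = 4; // assumed: attention layers on this PP rank
    std::size_t const peerAttentionLayerNum = 2; // assumed: layers held by one peer PP domain

    // Mirrors bufferEleSizes[i] in the unformat path's getBufferSizeForTarget.
    std::size_t const cacheSizePerLayer = cacheBlockSize * blockNum / selfAttentionLayerNum;
    assert(cacheSizePerLayer == 8192);
    assert(cacheSizePerLayer * peerAttentionLayerNum == 16384);
    return 0;
}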
Expand Down Expand Up @@ -506,7 +535,7 @@ void MLACacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& s
outputCachesPerWindow.emplace(window, outputBuffers);
NVTX3_SCOPED_RANGE(formatInputConcatenate);

// recvSplitCaches size == ppdomainsize
// recvSplitCaches size == ppdomainsize * cpdomainsize.
executor::kv_cache::concatKvCacheV2Dispatch(
recvSplitCaches, outputCachesPerWindow, destConfig, selfConfig, selfIdx, bufferManager);
}
@@ -581,14 +610,6 @@ void MLACacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& s
TLLM_LOG_WARNING("MLACacheFormatter::inquireSupport: TP size must be divisible by DP size");
return false;
}
if (selfConfig.getParallelConfig().mContextParallelism != 1
|| destConfig.getParallelConfig().mContextParallelism != 1)
{
TLLM_LOG_WARNING(
"MLACacheFormatter::inquireSupport: context parallelism is not currently supported (selfCP=%d, destCP=%d).",
selfConfig.getParallelConfig().mContextParallelism, destConfig.getParallelConfig().mContextParallelism);
return false;
}
if (destConfig.getParallelConfig().mEnableAttentionDP
&& (destConfig.getParallelConfig().mTensorParallelism % destConfig.getParallelConfig().mDPsize != 0))
{
18 changes: 18 additions & 0 deletions cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.h
@@ -22,6 +22,24 @@
namespace tensorrt_llm::batch_manager::kv_cache_manager
{

/**
* @brief Calculate the number of blocks allocated to a specific Context Parallelism (CP) rank.
*
* This function determines how many blocks should be allocated to a given CP rank when
* distributing a total number of blocks across multiple CP ranks. It supports two distribution
* modes: strict and non-strict.
*
* @param cpRank The rank (index) of the current CP process. Must be in range [0, cpSize).
* @param cpSize The total number of CP ranks/processes in the parallel group.
* @param numTotalBlocks The total number of blocks to be distributed across all CP ranks.
* @param strict Flag controlling the distribution strategy:
* - true: Use strict round-robin distribution with exact allocation
* - false: Use ceiling division which may over-allocate
*
* @return The number of blocks allocated to the specified CP rank.
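*
* Example (illustrative): with numTotalBlocks = 10 and cpSize = 4, strict mode returns
* {3, 3, 2, 2} for ranks 0 through 3 (summing to exactly 10), while non-strict mode returns
* ceil(10 / 4) = 3 for every rank (12 total, over-allocating by 2).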
*/
int getBlockNumAccountingForCP(int cpRank, int cpSize, int numTotalBlocks, bool strict);

// Simple cache block copy. Because it does not involve data splitting or merging, it performs best when the
// parallel topology is completely identical, making it the preferred method.
class MLACacheFormatter final : public BaseCacheFormatter