Skip to content
Closed
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cpp/tensorrt_llm/batch_manager/cacheTransBuffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ CacheTransBufferManager::CacheTransBufferManager(
mPreAllocBufferSize = mTransferBufferSize * (mRecvBufferCount + mSendBufferCount);
TLLM_LOG_INFO(
"CacheTransBufferManager: mMaxNumTokens:%ld, mRecvBufferCount:%ld, "
"mSendBufferCount:%ld,mTransferBufferSize:%ld, mPreAllocBufferSize:%ld,mOnlyUseDynamicBuffer:%d "
"mSendBufferCount:%ld, mTransferBufferSize:%ld, mPreAllocBufferSize:%ld, mOnlyUseDynamicBuffer:%d "
"mUseFabricMemory:%d mDataType:%d",
maxNumTokens.has_value() ? maxNumTokens.value() : 0, mRecvBufferCount, mSendBufferCount, mTransferBufferSize,
mPreAllocBufferSize, mOnlyUseDynamicBuffer, mUseFabricMemory, mDataType);
Expand Down
38 changes: 14 additions & 24 deletions cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,14 +97,11 @@ void MLACacheFormatter::format(TransferSession& session)
auto& bufferManager = session.getBufferManager();
TLLM_CHECK_WITH_INFO(llmRequest.mSamplingConfig.beamWidth == 1, "Currently only supports beam width 1.");
TLLM_CHECK(!connections.empty());
// diff start
if (!needSendCache(selfConfig, destConfig, selfIdx))
{
return;
}

// diff end

auto const numPools = mCacheManager->getBlockManager().getNumPools();
auto blockRange = getBlockRangeForSending(mCacheManager, llmRequest);

Expand Down Expand Up @@ -150,28 +147,28 @@ void MLACacheFormatter::format(TransferSession& session)
auto cacheBlockSize = inputKvCacheBlocks.at(0)->getSize();

auto cacheBufferId = mCacheTransBufferManager->assignBufferIndexForSend();
// diff start

auto targetInfo = executor::kv_cache::targetIRanks(destConfig, selfConfig, selfIdx);

size_t const pPDomainSize = targetInfo.mDomainPPSize;
TLLM_CHECK((cacheBlockSize * blockNum) % pPDomainSize == 0);
auto const targetBufferSize = (cacheBlockSize * blockNum) / pPDomainSize;
size_t const cPDomainSize = targetInfo.mDomainCPSize;
TLLM_CHECK((cacheBlockSize * blockNum) % (pPDomainSize * cPDomainSize) == 0);
// @B: This works as if all output caches are of the same size. Is this a fair assumption?
auto const targetBufferSize = (cacheBlockSize * blockNum) / (pPDomainSize * cPDomainSize);
TLLM_LOG_INFO("[MLACacheFormatter::format] BEFORE getOrAllocateSendBuffers cacheBlockSize: %zu, blockNum: %d, pPDomainSize: %zu, cPDomainSize: %zu, targetBufferSize: %zu", cacheBlockSize, blockNum, pPDomainSize, cPDomainSize, targetBufferSize);
auto result = mCacheTransBufferManager->getOrAllocateSendBuffers(
cacheBufferId, pPDomainSize, targetBufferSize, bufferManager);
cacheBufferId, pPDomainSize * cPDomainSize, targetBufferSize, bufferManager);
auto& outputSplitCaches = std::get<0>(result);
Comment on lines 153 to 161
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Guard: ensure connections.size() matches PP×CP before modulo mapping.

Add a sanity check to prevent modulo on a smaller connection set, which would scatter to wrong peers.

     auto targetInfo = executor::kv_cache::targetIRanks(destConfig, selfConfig, selfIdx);
     size_t const pPDomainSize = targetInfo.mDomainPPSize;
     size_t const cPDomainSize = targetInfo.mDomainCPSize;
+    TLLM_CHECK_WITH_INFO(
+        connections.size() == pPDomainSize * cPDomainSize,
+        "Mismatch: number of connections (%zu) must equal PP×CP (%zu).",
+        connections.size(), pPDomainSize * cPDomainSize);
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
size_t const pPDomainSize = targetInfo.mDomainPPSize;
TLLM_CHECK((cacheBlockSize * blockNum) % pPDomainSize == 0);
auto const targetBufferSize = (cacheBlockSize * blockNum) / pPDomainSize;
size_t const cPDomainSize = targetInfo.mDomainCPSize;
TLLM_CHECK((cacheBlockSize * blockNum) % (pPDomainSize * cPDomainSize) == 0);
// @B: This works as if all output caches are of the same size. Is this a fair assumption?
auto const targetBufferSize = (cacheBlockSize * blockNum) / (pPDomainSize * cPDomainSize);
TLLM_LOG_INFO("[MLACacheFormatter::format] BEFORE getOrAllocateSendBuffers cacheBlockSize: %zu, blockNum: %d, pPDomainSize: %zu, cPDomainSize: %zu, targetBufferSize: %zu", cacheBlockSize, blockNum, pPDomainSize, cPDomainSize, targetBufferSize);
auto result = mCacheTransBufferManager->getOrAllocateSendBuffers(
cacheBufferId, pPDomainSize, targetBufferSize, bufferManager);
cacheBufferId, pPDomainSize * cPDomainSize, targetBufferSize, bufferManager);
auto& outputSplitCaches = std::get<0>(result);
size_t const pPDomainSize = targetInfo.mDomainPPSize;
size_t const cPDomainSize = targetInfo.mDomainCPSize;
TLLM_CHECK_WITH_INFO(
connections.size() == pPDomainSize * cPDomainSize,
"Mismatch: number of connections (%zu) must equal PP×CP (%zu).",
connections.size(), pPDomainSize * cPDomainSize);
TLLM_CHECK((cacheBlockSize * blockNum) % (pPDomainSize * cPDomainSize) == 0);
// @B: This works as if all output caches are of the same size. Is this a fair assumption?
auto const targetBufferSize = (cacheBlockSize * blockNum) / (pPDomainSize * cPDomainSize);
TLLM_LOG_INFO("[MLACacheFormatter::format] BEFORE getOrAllocateSendBuffers cacheBlockSize: %zu, blockNum: %d, pPDomainSize: %zu, cPDomainSize: %zu, targetBufferSize: %zu",
cacheBlockSize, blockNum, pPDomainSize, cPDomainSize, targetBufferSize);
auto result = mCacheTransBufferManager->getOrAllocateSendBuffers(
cacheBufferId, pPDomainSize * cPDomainSize, targetBufferSize, bufferManager);
auto& outputSplitCaches = std::get<0>(result);
🤖 Prompt for AI Agents
cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp around lines 153 to 161:
the code assumes the number of connections equals pPDomainSize * cPDomainSize
before doing modulo-based mapping which can scatter to wrong peers if
connections is smaller; add a sanity guard that verifies connections.size() >=
pPDomainSize * cPDomainSize (or == if strict) and fail fast with an explanatory
TLLM_CHECK or error log if the condition is not met, before computing
targetBufferSize and calling getOrAllocateSendBuffers; ensure the check prevents
division/modulo mapping against a smaller connection set and include minimal
context in the error message (e.g., actual sizes) so callers can debug.

auto& bufferCoverTargetNum = std::get<1>(result);
auto& onlyUseDynamicBuffer = std::get<2>(result);
auto* agentConnnecion = dynamic_cast<executor::kv_cache::AgentConnection const*>(connections[0]);
if (agentConnnecion != nullptr)
auto* agentConnnection = dynamic_cast<executor::kv_cache::AgentConnection const*>(connections[0]);
if (agentConnnection != nullptr)
{
TLLM_CHECK_WITH_INFO(bufferCoverTargetNum == pPDomainSize, "Agent need all buffer pre-allocated");
TLLM_CHECK_WITH_INFO(bufferCoverTargetNum == pPDomainSize * cPDomainSize, "Agent need all buffer pre-allocated");
TLLM_CHECK(onlyUseDynamicBuffer == false);
}
// diff end

// The size of outputSplitCaches should be equal to pPDomainSize

// The size of outputSplitCaches should be equal to pPDomainSize * cPDomainSize.
SizeType32 window = mCacheManager->getBlockManager().getPoolWindowSize(0);
std::map<SizeType32, std::vector<runtime::ITensor::SharedPtr>> inputKvCacheBlocksPerWindow;
inputKvCacheBlocksPerWindow.emplace(window, inputKvCacheBlocks);
Expand All @@ -191,8 +188,9 @@ void MLACacheFormatter::format(TransferSession& session)

TLLM_CUDA_CHECK(cudaSetDevice(deviceId));
auto startTime = std::chrono::steady_clock::now();
auto cacheIdx = processIdx % pPDomainSize;
auto cacheIdx = processIdx % (pPDomainSize * cPDomainSize);
size_t size;
// @B: What does this check mean?
if (cacheIdx < bufferCoverTargetNum)
{
size = outputSplitCaches.at(cacheIdx)->getSizeInBytes();
Expand Down Expand Up @@ -252,7 +250,7 @@ void MLACacheFormatter::format(TransferSession& session)
else
{
// concurrency num
auto concurrencyNum = std::min(std::max(static_cast<size_t>(1), bufferCoverTargetNum), pPDomainSize);
auto concurrencyNum = std::min(std::max(static_cast<size_t>(1), bufferCoverTargetNum), pPDomainSize * cPDomainSize);

auto remainSendNum = connections.size();

Expand Down Expand Up @@ -489,7 +487,7 @@ void MLACacheFormatter::unformat(TransferSession& session)
outputCachesPerWindow.emplace(window, outputBuffers);
NVTX3_SCOPED_RANGE(formatInputConcatenate);

// recvSplitCaches size == ppdomainsize
// recvSplitCaches size == ppdomainsize * cPDomainSize.
executor::kv_cache::concatKvCacheV2Dispatch(
recvSplitCaches, outputCachesPerWindow, destConfig, selfConfig, selfIdx, bufferManager);
}
Expand Down Expand Up @@ -564,14 +562,6 @@ void MLACacheFormatter::unformat(TransferSession& session)
TLLM_LOG_WARNING("MLACacheFormatter::inquireSupport: TP size must be divisible by DP size");
return false;
}
if (selfConfig.getParallelConfig().mContextParallelism != 1
|| destConfig.getParallelConfig().mContextParallelism != 1)
{
TLLM_LOG_WARNING(
"MLACacheFormatter::inquireSupport: context parallelism is not currently supported (selfCP=%d, destCP=%d).",
selfConfig.getParallelConfig().mContextParallelism, destConfig.getParallelConfig().mContextParallelism);
return false;
}
if (destConfig.getParallelConfig().mEnableAttentionDP
&& (destConfig.getParallelConfig().mTensorParallelism % destConfig.getParallelConfig().mDPsize != 0))
{
Expand Down
77 changes: 56 additions & 21 deletions cpp/tensorrt_llm/executor/cache_transmission/cacheSplitConcat.cu
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,12 @@ TargetRanksInfo TargetRanksInfoForDP(
auto const selfPPNum = selfParConfig.mPipelineParallelism;
auto const peerTPNum = peerParConfig.mTensorParallelism;
auto const selfTPNum = selfParConfig.mTensorParallelism;
auto const peerCPNum = peerParConfig.mContextParallelism;
auto const selfCPNum = selfParConfig.mContextParallelism;

auto const selfTPRank = selfRank % selfParConfig.mTensorParallelism;
auto const selfPPRank = selfRank / selfParConfig.mTensorParallelism;
auto const selfTPRank = selfRank % selfTPNum;
auto const selfPPRank = selfRank / (selfTPNum * selfCPNum);
auto const selfCPRank = (selfRank % (selfTPNum * selfCPNum)) / selfTPNum;

int peerPPRankStart = 0;
int mDomainPPSize = 1;
Expand Down Expand Up @@ -108,13 +111,35 @@ TargetRanksInfo TargetRanksInfoForDP(
peerTPRankEnd = peerTPRankStart + mDomainTPSize;
}

int mDomainCPSize = 1;
int peerCPRankStart = 0;
int peerCPRankEnd = 0;
for (auto val : {peerCPNum, selfCPNum})
{
TLLM_CHECK(isPowerOfTwo(val));
}
if (selfCPNum <= peerCPNum)
{
mDomainCPSize = peerCPNum / selfCPNum;
peerCPRankStart = selfCPRank * mDomainCPSize;
peerCPRankEnd = (selfCPRank + 1) * mDomainCPSize;
}
else
{
peerCPRankStart = selfCPRank / (selfCPNum / peerCPNum);
peerCPRankEnd = peerCPRankStart + mDomainCPSize;
}

std::vector<int> retRanks;
for (int i = peerTPRankStart; i < peerTPRankEnd; i++)
{
for (int j = peerPPRankStart; j < peerPPRankEnd; j++)
for (int j = peerCPRankStart; j < peerCPRankEnd; j++)
{
int irank = j * peerTPNum + i;
retRanks.push_back(irank);
for (int k = peerPPRankStart; k < peerPPRankEnd; k++)
{
int irank = (k * peerTPNum * peerCPNum) + (j * peerTPNum) + i;
retRanks.push_back(irank);
}
}
}

Expand All @@ -131,7 +156,7 @@ TargetRanksInfo TargetRanksInfoForDP(
= (peerNbHeadsPerLayer * peerTPSizePerDPGroup) / (selfNbHeadsPerLayer * selfTPSizePerDPGroup);
}

return {mDomainPPSize, mDomainTPSize, std::move(retRanks), mDupHeadFactor, mPeerDupHeadFactor};
return {mDomainPPSize, mDomainTPSize, mDomainCPSize, std::move(retRanks), mDupHeadFactor, mPeerDupHeadFactor};
}

TargetRanksInfo targetIRanks(
Expand Down Expand Up @@ -472,11 +497,11 @@ nvinfer1::Dims makeShapeFromCacheState(kv_cache::CacheState const& cacheState)
}

// MLA Head 1: One thread block per [(2), tokens, dimsPerHead]

// @B: Why do we not use Domain{P,T,C}PSize?
template <typename T, int subWarpSize, int vecSizeByte>
__global__ void splitKVCacheForMLAKernel(T const** __restrict__ inputBlocks, T** __restrict__ outputCaches,
int tokensPerBlock, int numLayers, int headNum, int dimsPerHead, int inputBlockNum, int DomainPPSize,
int DomainTPSize, int layerNumDomainPP, int kvFactor)
int tokensPerBlock, int numLayers, int headNum, int dimsPerHead, int inputBlockNum, int domainPPSize,
int domainTPSize, int domainCPSize, int layerNumDomainPP, int kvFactor)
{
int const subWarpId = threadIdx.x / subWarpSize;
int const laneId = threadIdx.x % subWarpSize;
Expand All @@ -496,16 +521,20 @@ __global__ void splitKVCacheForMLAKernel(T const** __restrict__ inputBlocks, T**
#pragma unroll 1
for (int headId = 0; headId < headNum; headId++)
{
// Input block memory layout: [Layer][KV_Factor][Head][Token][Dimension].
T const* inputBlockPtr = inputBlocks[blockId];
T const* kInputPtr = inputBlockPtr + layerId * kvFactor * headNum * tokensPerBlock * dimsPerHead
+ headId * tokensPerBlock * dimsPerHead;
int const outputCacheIdx = layerId / layerNumDomainPP;
int const outputCacheIdx = (layerId / layerNumDomainPP) * domainCPSize + blockId % domainCPSize;
TLLM_LOG_INFO("[splitKVCacheForMLAKernel] layerId: %d, blockId: %d, outputCacheIdx: %d", layerId, blockId, outputCacheIdx);
T* outputCachePtr = outputCaches[outputCacheIdx];
int const layerIdInDomainPP = layerId % layerNumDomainPP;
int const headIdInDomainTP = headId;
int const blockIdInDomainCP = blockId / domainCPSize;

// Unlike inputBlocks which are non-contiguous, blocks in outputCachePtr are contiguous.
T* kOutputPtr = outputCachePtr
+ blockId * (layerNumDomainPP * kvFactor * headNum * tokensPerBlock * dimsPerHead)
+ blockIdInDomainCP * (layerNumDomainPP * kvFactor * headNum * tokensPerBlock * dimsPerHead)
+ layerIdInDomainPP * kvFactor * headNum * tokensPerBlock * dimsPerHead
+ headIdInDomainTP * tokensPerBlock * dimsPerHead;
int const kvOffset = headNum * tokensPerBlock * dimsPerHead;
Expand Down Expand Up @@ -905,11 +934,16 @@ void splitKVCache(std::map<SizeType32, std::vector<runtime::ITensor::SharedPtr>>
}
auto targetRankInfo = targetIRanks(destCacheState, selfCacheState, selfIdx);
TLLM_CHECK(targetRankInfo.mIRanks.size()
== (static_cast<size_t>(targetRankInfo.mDomainPPSize * targetRankInfo.mDomainTPSize)));
== (static_cast<size_t>(targetRankInfo.mDomainPPSize * targetRankInfo.mDomainTPSize * targetRankInfo.mDomainCPSize)));
TLLM_LOG_INFO("[splitKVCache] targetRankInfo.mIRanks.size(): %d", targetRankInfo.mIRanks.size());
for (auto rank : targetRankInfo.mIRanks)
{
TLLM_LOG_INFO("[splitKVCache] target rank: %d, ", rank);
}
auto outputCacheNum = targetRankInfo.mIRanks.size();
if (selfCacheState.getAttentionConfig().mAttentionType == CacheState::AttentionType::kMLA)
{
outputCacheNum = targetRankInfo.mDomainPPSize;
outputCacheNum = targetRankInfo.mDomainPPSize * targetRankInfo.mDomainCPSize;
}
else
{
Expand All @@ -929,6 +963,7 @@ void splitKVCache(std::map<SizeType32, std::vector<runtime::ITensor::SharedPtr>>
{
auto cacheBlockSize = blocks.front()->getSize();
auto cacheDataType = blocks.front()->getDataType();
TLLM_LOG_DEBUG("[splitKVCache] cacheBlockSize: %zu, cacheDataType: %d", cacheBlockSize, cacheDataType);
windowSizes.push_back(window);
blockNumInwindow.push_back(blocks.size());
TLLM_LOG_DEBUG("window: %d, blockNum: %d blockshape:[%d,%d]", window, blocks.size(),
Expand Down Expand Up @@ -972,7 +1007,6 @@ void splitKVCache(std::map<SizeType32, std::vector<runtime::ITensor::SharedPtr>>

for (auto layerNum : layersInWindow)
{

TLLM_CHECK_WITH_INFO(
layerNum % targetRankInfo.mDomainPPSize == 0, "layerNum in Window must be divisible by domainPPSize");
}
Expand Down Expand Up @@ -1018,6 +1052,7 @@ void splitKVCache(std::map<SizeType32, std::vector<runtime::ITensor::SharedPtr>>
int const dimsPerHead = selfModelConfig.mSizePerHead;
int const DomainPPSize = targetRankInfo.mDomainPPSize;
int const DomainTPSize = targetRankInfo.mDomainTPSize;
int const DomainCPSize = targetRankInfo.mDomainCPSize;
int const layerNumDomainPP = numLayers / DomainPPSize;
int const headNumDomainTP
= headNum / (DomainTPSize / targetRankInfo.mPeerDupHeadFactor); // TODO: duplicate head factor
Expand All @@ -1026,9 +1061,9 @@ void splitKVCache(std::map<SizeType32, std::vector<runtime::ITensor::SharedPtr>>
constexpr int mlaSubWarpSize = 16;

TLLM_LOG_DEBUG(
"splitKVCache - numLayers: %d, headNum: %d, domainPPSize: %d, domainTPSize: %d, "
"splitKVCache - numLayers: %d, headNum: %d, domainPPSize: %d, domainTPSize: %d, domainCPSize: %d, "
"layersPerDomainPP: %d, headsPerDomainTP: %d",
numLayers, headNum, DomainPPSize, DomainTPSize, layerNumDomainPP, headNumDomainTP);
numLayers, headNum, DomainPPSize, DomainTPSize, DomainCPSize, layerNumDomainPP, headNumDomainTP);

int const remainder = sizePerHead * sizeof(T) % 16;
switch (remainder)
Expand All @@ -1039,7 +1074,7 @@ void splitKVCache(std::map<SizeType32, std::vector<runtime::ITensor::SharedPtr>>
{
splitKVCacheForMLAKernel<T, mlaSubWarpSize, 16><<<gridDim, blockDimx, 0, bufferManager.getStream().get()>>>(
inputBlockPtrsDev, outputCachePtrsDev, tokensPerBlock, numLayers, headNum, dimsPerHead,
inputBlockNumSum, DomainPPSize, DomainTPSize, layerNumDomainPP, kvFactor);
inputBlockNumSum, DomainPPSize, DomainTPSize, DomainCPSize, layerNumDomainPP, kvFactor);
}
else if (isWindow)
{
Expand All @@ -1063,7 +1098,7 @@ void splitKVCache(std::map<SizeType32, std::vector<runtime::ITensor::SharedPtr>>
{
splitKVCacheForMLAKernel<T, mlaSubWarpSize, 8><<<gridDim, blockDimx, 0, bufferManager.getStream().get()>>>(
inputBlockPtrsDev, outputCachePtrsDev, tokensPerBlock, numLayers, headNum, dimsPerHead,
inputBlockNumSum, DomainPPSize, DomainTPSize, layerNumDomainPP, kvFactor);
inputBlockNumSum, DomainPPSize, DomainTPSize, DomainCPSize, layerNumDomainPP, kvFactor);
}
else if (isWindow)
{
Expand Down Expand Up @@ -1091,7 +1126,7 @@ void splitKVCache(std::map<SizeType32, std::vector<runtime::ITensor::SharedPtr>>
splitKVCacheForMLAKernel<T, mlaSubWarpSize, 4>
<<<gridDim, blockDimx, 0, bufferManager.getStream().get()>>>(inputBlockPtrsDev, outputCachePtrsDev,
tokensPerBlock, numLayers, headNum, dimsPerHead, inputBlockNumSum, DomainPPSize, DomainTPSize,
layerNumDomainPP, kvFactor);
DomainCPSize, layerNumDomainPP, kvFactor);
}
else if (isWindow)
{
Expand Down Expand Up @@ -1124,7 +1159,7 @@ void splitKVCache(std::map<SizeType32, std::vector<runtime::ITensor::SharedPtr>>
splitKVCacheForMLAKernel<T, mlaSubWarpSize, 2>
<<<gridDim, blockDimx, 0, bufferManager.getStream().get()>>>(inputBlockPtrsDev, outputCachePtrsDev,
tokensPerBlock, numLayers, headNum, dimsPerHead, inputBlockNumSum, DomainPPSize, DomainTPSize,
layerNumDomainPP, kvFactor);
DomainCPSize, layerNumDomainPP, kvFactor);
}
else if (isWindow)
{
Expand Down Expand Up @@ -1153,7 +1188,7 @@ void splitKVCache(std::map<SizeType32, std::vector<runtime::ITensor::SharedPtr>>
splitKVCacheForMLAKernel<T, mlaSubWarpSize, 1>
<<<gridDim, blockDimx, 0, bufferManager.getStream().get()>>>(inputBlockPtrsDev, outputCachePtrsDev,
tokensPerBlock, numLayers, headNum, dimsPerHead, inputBlockNumSum, DomainPPSize, DomainTPSize,
layerNumDomainPP, kvFactor);
DomainCPSize, layerNumDomainPP, kvFactor);
}
else if (isWindow)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ struct TargetRanksInfo
{
int mDomainPPSize;
int mDomainTPSize;
int mDomainCPSize;
std::vector<int> mIRanks;
int mDupHeadFactor;
int mPeerDupHeadFactor;
Expand Down
Loading
Loading