33 changes: 12 additions & 21 deletions cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp
@@ -150,28 +150,26 @@ void MLACacheFormatter::format(TransferSession& session)
auto cacheBlockSize = inputKvCacheBlocks.at(0)->getSize();

auto cacheBufferId = mCacheTransBufferManager->assignBufferIndexForSend();
// diff start

auto targetInfo = executor::kv_cache::targetIRanks(destConfig, selfConfig, selfIdx);

size_t const pPDomainSize = targetInfo.mDomainPPSize;
TLLM_CHECK((cacheBlockSize * blockNum) % pPDomainSize == 0);
auto const targetBufferSize = (cacheBlockSize * blockNum) / pPDomainSize;
size_t const cPDomainSize = targetInfo.mDomainCPSize;
TLLM_CHECK((cacheBlockSize * blockNum) % (pPDomainSize * cPDomainSize) == 0);
auto const targetBufferSize = (cacheBlockSize * blockNum) / (pPDomainSize * cPDomainSize);
auto result = mCacheTransBufferManager->getOrAllocateSendBuffers(
cacheBufferId, pPDomainSize, targetBufferSize, bufferManager);
cacheBufferId, pPDomainSize * cPDomainSize, targetBufferSize, bufferManager);
auto& outputSplitCaches = std::get<0>(result);
auto& bufferCoverTargetNum = std::get<1>(result);
auto& onlyUseDynamicBuffer = std::get<2>(result);
auto* agentConnnecion = dynamic_cast<executor::kv_cache::AgentConnection const*>(connections[0]);
if (agentConnnecion != nullptr)
auto* agentConnnection = dynamic_cast<executor::kv_cache::AgentConnection const*>(connections[0]);
if (agentConnnection != nullptr)
{
TLLM_CHECK_WITH_INFO(bufferCoverTargetNum == pPDomainSize, "Agent need all buffer pre-allocated");
TLLM_CHECK_WITH_INFO(bufferCoverTargetNum == pPDomainSize * cPDomainSize, "Agent need all buffer pre-allocated");
TLLM_CHECK(onlyUseDynamicBuffer == false);
}
// diff end

// The size of outputSplitCaches should be equal to pPDomainSize

// The size of outputSplitCaches should be equal to pPDomainSize * cPDomainSize.
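// Illustrative numbers only (not from this PR): with cacheBlockSize * blockNum == 1048576
// elements, pPDomainSize == 2 and cPDomainSize == 2, getOrAllocateSendBuffers() returns
// 2 * 2 == 4 split buffers of 262144 elements each, one per peer (PP, CP) domain rank.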
SizeType32 window = mCacheManager->getBlockManager().getPoolWindowSize(0);
std::map<SizeType32, std::vector<runtime::ITensor::SharedPtr>> inputKvCacheBlocksPerWindow;
inputKvCacheBlocksPerWindow.emplace(window, inputKvCacheBlocks);
@@ -191,8 +189,9 @@ void MLACacheFormatter::format(TransferSession& session)

TLLM_CUDA_CHECK(cudaSetDevice(deviceId));
auto startTime = std::chrono::steady_clock::now();
auto cacheIdx = processIdx % pPDomainSize;
auto cacheIdx = processIdx % (pPDomainSize * cPDomainSize);
size_t size;
// @B: What does this check mean?
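// Presumably: splits with cacheIdx < bufferCoverTargetNum are backed by the pre-allocated
// send buffers obtained above, while splits beyond that count take the fallback path that
// is handled further down (collapsed in this hunk).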
if (cacheIdx < bufferCoverTargetNum)
{
size = outputSplitCaches.at(cacheIdx)->getSizeInBytes();
@@ -252,7 +251,7 @@ void MLACacheFormatter::format(TransferSession& session)
else
{
// concurrency num
auto concurrencyNum = std::min(std::max(static_cast<size_t>(1), bufferCoverTargetNum), pPDomainSize);
auto concurrencyNum = std::min(std::max(static_cast<size_t>(1), bufferCoverTargetNum), pPDomainSize * cPDomainSize);

auto remainSendNum = connections.size();

Expand Down Expand Up @@ -489,7 +488,7 @@ void MLACacheFormatter::unformat(TransferSession& session)
outputCachesPerWindow.emplace(window, outputBuffers);
NVTX3_SCOPED_RANGE(formatInputConcatenate);

// recvSplitCaches size == ppdomainsize
// recvSplitCaches size == ppdomainsize * cPDomainSize.
executor::kv_cache::concatKvCacheV2Dispatch(
recvSplitCaches, outputCachesPerWindow, destConfig, selfConfig, selfIdx, bufferManager);
}
@@ -564,14 +563,6 @@ void MLACacheFormatter::unformat(TransferSession& session)
TLLM_LOG_WARNING("MLACacheFormatter::inquireSupport: TP size must be divisible by DP size");
return false;
}
if (selfConfig.getParallelConfig().mContextParallelism != 1
|| destConfig.getParallelConfig().mContextParallelism != 1)
{
TLLM_LOG_WARNING(
"MLACacheFormatter::inquireSupport: context parallelism is not currently supported (selfCP=%d, destCP=%d).",
selfConfig.getParallelConfig().mContextParallelism, destConfig.getParallelConfig().mContextParallelism);
return false;
}
if (destConfig.getParallelConfig().mEnableAttentionDP
&& (destConfig.getParallelConfig().mTensorParallelism % destConfig.getParallelConfig().mDPsize != 0))
{
65 changes: 48 additions & 17 deletions cpp/tensorrt_llm/executor/cache_transmission/cacheSplitConcat.cu
@@ -57,9 +57,12 @@ TargetRanksInfo TargetRanksInfoForDP(
auto const selfPPNum = selfParConfig.mPipelineParallelism;
auto const peerTPNum = peerParConfig.mTensorParallelism;
auto const selfTPNum = selfParConfig.mTensorParallelism;
auto const peerCPNum = peerParConfig.mContextParallelism;
auto const selfCPNum = selfParConfig.mContextParallelism;

auto const selfTPRank = selfRank % selfParConfig.mTensorParallelism;
auto const selfPPRank = selfRank / selfParConfig.mTensorParallelism;
auto const selfTPRank = selfRank % selfTPNum;
auto const selfPPRank = selfRank / (selfTPNum * selfCPNum);
auto const selfCPRank = (selfRank % (selfTPNum * selfCPNum)) / selfTPNum;

int peerPPRankStart = 0;
int mDomainPPSize = 1;
@@ -108,13 +111,35 @@ TargetRanksInfo TargetRanksInfoForDP(
peerTPRankEnd = peerTPRankStart + mDomainTPSize;
}

int mDomainCPSize = 1;
int peerCPRankStart = 0;
int peerCPRankEnd = 0;
for (auto val : {peerCPNum, selfCPNum})
{
TLLM_CHECK(isPowerOfTwo(val));
}
if (selfCPNum <= peerCPNum)
{
mDomainCPSize = peerCPNum / selfCPNum;
peerCPRankStart = selfCPRank * mDomainCPSize;
peerCPRankEnd = (selfCPRank + 1) * mDomainCPSize;
}
else
{
peerCPRankStart = selfCPRank / (selfCPNum / peerCPNum);
peerCPRankEnd = peerCPRankStart + mDomainCPSize;
}
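// Example with illustrative values: selfCPNum == 1, peerCPNum == 2, selfCPRank == 0 gives
// mDomainCPSize == 2 and peer CP ranks [0, 2); conversely, selfCPNum == 4, peerCPNum == 2
// keeps mDomainCPSize == 1 and maps each pair of self CP ranks onto a single peer CP rank.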

std::vector<int> retRanks;
for (int i = peerTPRankStart; i < peerTPRankEnd; i++)
{
for (int j = peerPPRankStart; j < peerPPRankEnd; j++)
for (int j = peerCPRankStart; j < peerCPRankEnd; j++)
{
int irank = j * peerTPNum + i;
retRanks.push_back(irank);
for (int k = peerPPRankStart; k < peerPPRankEnd; k++)
{
int irank = (k * peerTPNum * peerCPNum) + (j * peerTPNum) + i;
retRanks.push_back(irank);
}
}
}
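
The nested loops above enumerate the peer ranks owned by one self rank in TP-major order, with CP varying next and PP fastest. Below is a minimal, self-contained sketch of that mapping; it is an illustration only, it assumes no attention DP and no head duplication, and the names PeerGroup, mapDim and enumeratePeerRanks are invented here rather than taken from the codebase.

// Hedged sketch of the peer-rank enumeration; not part of the PR.
#include <cassert>
#include <cstdio>
#include <vector>

struct PeerGroup
{
    int domainPP{1}, domainTP{1}, domainCP{1};
    std::vector<int> ranks;
};

// One parallel dimension: how many peer ranks a self rank fans out to, and their range.
static void mapDim(int selfNum, int peerNum, int selfRank, int& domain, int& start, int& end)
{
    if (selfNum <= peerNum)
    {
        domain = peerNum / selfNum;
        start = selfRank * domain;
    }
    else
    {
        domain = 1;
        start = selfRank / (selfNum / peerNum);
    }
    end = start + domain;
}

static PeerGroup enumeratePeerRanks(
    int selfTP, int selfPP, int selfCP, int peerTP, int peerPP, int peerCP, int selfRank)
{
    // Rank layout assumed to match the PR: rank = pp * (tp * cp) + cp * tp + tp.
    int const selfTPRank = selfRank % selfTP;
    int const selfCPRank = (selfRank % (selfTP * selfCP)) / selfTP;
    int const selfPPRank = selfRank / (selfTP * selfCP);

    PeerGroup g;
    int tpS, tpE, ppS, ppE, cpS, cpE;
    mapDim(selfTP, peerTP, selfTPRank, g.domainTP, tpS, tpE);
    mapDim(selfPP, peerPP, selfPPRank, g.domainPP, ppS, ppE);
    mapDim(selfCP, peerCP, selfCPRank, g.domainCP, cpS, cpE);

    // Same loop order as the PR: TP outermost, then CP, then PP.
    for (int tp = tpS; tp < tpE; ++tp)
        for (int cp = cpS; cp < cpE; ++cp)
            for (int pp = ppS; pp < ppE; ++pp)
                g.ranks.push_back(pp * peerTP * peerCP + cp * peerTP + tp);
    return g;
}

int main()
{
    // Context side: TP=2, PP=1, CP=1; generation side: TP=2, PP=1, CP=2; self rank 1.
    auto g = enumeratePeerRanks(2, 1, 1, 2, 1, 2, /*selfRank=*/1);
    assert(g.domainCP == 2 && g.ranks.size() == 2u);
    for (int r : g.ranks)
        std::printf("peer rank %d\n", r); // prints 1 and 3 for this configuration
    return 0;
}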

@@ -131,7 +156,7 @@ TargetRanksInfo TargetRanksInfoForDP(
= (peerNbHeadsPerLayer * peerTPSizePerDPGroup) / (selfNbHeadsPerLayer * selfTPSizePerDPGroup);
}

return {mDomainPPSize, mDomainTPSize, std::move(retRanks), mDupHeadFactor, mPeerDupHeadFactor};
return {mDomainPPSize, mDomainTPSize, mDomainCPSize, std::move(retRanks), mDupHeadFactor, mPeerDupHeadFactor};
}

TargetRanksInfo targetIRanks(
@@ -476,7 +501,7 @@ nvinfer1::Dims makeShapeFromCacheState(kv_cache::CacheState const& cacheState)
template <typename T, int subWarpSize, int vecSizeByte>
__global__ void splitKVCacheForMLAKernel(T const** __restrict__ inputBlocks, T** __restrict__ outputCaches,
int tokensPerBlock, int numLayers, int headNum, int dimsPerHead, int inputBlockNum, int DomainPPSize,
int DomainTPSize, int layerNumDomainPP, int kvFactor)
int DomainTPSize, int DomainCPSize, int layerNumDomainPP, int kvFactor)
{
int const subWarpId = threadIdx.x / subWarpSize;
int const laneId = threadIdx.x % subWarpSize;
@@ -905,11 +930,16 @@ void splitKVCache(std::map<SizeType32, std::vector<runtime::ITensor::SharedPtr>>
}
auto targetRankInfo = targetIRanks(destCacheState, selfCacheState, selfIdx);
TLLM_CHECK(targetRankInfo.mIRanks.size()
== (static_cast<size_t>(targetRankInfo.mDomainPPSize * targetRankInfo.mDomainTPSize)));
== (static_cast<size_t>(targetRankInfo.mDomainPPSize * targetRankInfo.mDomainTPSize * targetRankInfo.mDomainCPSize)));
TLLM_LOG_INFO("[splitKVCache] targetRankInfo.mIRanks.size(): %d", targetRankInfo.mIRanks.size());
for (auto rank : targetRankInfo.mIRanks)
{
TLLM_LOG_INFO("[splitKVCache] target rank: %d, ", rank);
}
auto outputCacheNum = targetRankInfo.mIRanks.size();
if (selfCacheState.getAttentionConfig().mAttentionType == CacheState::AttentionType::kMLA)
{
outputCacheNum = targetRankInfo.mDomainPPSize;
outputCacheNum = targetRankInfo.mDomainPPSize * targetRankInfo.mDomainCPSize;
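// Presumably: with MLA the latent KV head is shared across the TP domain, so only one
// output cache per peer (PP, CP) domain rank is needed here.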
}
else
{
@@ -929,6 +959,7 @@ void splitKVCache(std::map<SizeType32, std::vector<runtime::ITensor::SharedPtr>>
{
auto cacheBlockSize = blocks.front()->getSize();
auto cacheDataType = blocks.front()->getDataType();
TLLM_LOG_DEBUG("[splitKVCache] cacheBlockSize: %zu, cacheDataType: %d", cacheBlockSize, cacheDataType);
windowSizes.push_back(window);
blockNumInwindow.push_back(blocks.size());
TLLM_LOG_DEBUG("window: %d, blockNum: %d blockshape:[%d,%d]", window, blocks.size(),
@@ -972,7 +1003,6 @@ void splitKVCache(std::map<SizeType32, std::vector<runtime::ITensor::SharedPtr>>

for (auto layerNum : layersInWindow)
{

TLLM_CHECK_WITH_INFO(
layerNum % targetRankInfo.mDomainPPSize == 0, "layerNum in Window must be divisible by domainPPSize");
}
@@ -1018,6 +1048,7 @@ void splitKVCache(std::map<SizeType32, std::vector<runtime::ITensor::SharedPtr>>
int const dimsPerHead = selfModelConfig.mSizePerHead;
int const DomainPPSize = targetRankInfo.mDomainPPSize;
int const DomainTPSize = targetRankInfo.mDomainTPSize;
int const DomainCPSize = targetRankInfo.mDomainCPSize;
int const layerNumDomainPP = numLayers / DomainPPSize;
int const headNumDomainTP
= headNum / (DomainTPSize / targetRankInfo.mPeerDupHeadFactor); // TODO: duplicate head factor
@@ -1026,9 +1057,9 @@ void splitKVCache(std::map<SizeType32, std::vector<runtime::ITensor::SharedPtr>>
constexpr int mlaSubWarpSize = 16;

TLLM_LOG_DEBUG(
"splitKVCache - numLayers: %d, headNum: %d, domainPPSize: %d, domainTPSize: %d, "
"splitKVCache - numLayers: %d, headNum: %d, domainPPSize: %d, domainTPSize: %d, domainCPSize: %d, "
"layersPerDomainPP: %d, headsPerDomainTP: %d",
numLayers, headNum, DomainPPSize, DomainTPSize, layerNumDomainPP, headNumDomainTP);
numLayers, headNum, DomainPPSize, DomainTPSize, DomainCPSize, layerNumDomainPP, headNumDomainTP);

int const remainder = sizePerHead * sizeof(T) % 16;
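// The switch below (cases partly collapsed in this diff) appears to pick the widest vector
// width that still divides a head row evenly: remainder 0 -> 16-byte copies, then 8-, 4-,
// 2- and finally 1-byte fallbacks, matching the kernel instantiations that follow.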
switch (remainder)
@@ -1039,7 +1070,7 @@ void splitKVCache(std::map<SizeType32, std::vector<runtime::ITensor::SharedPtr>>
{
splitKVCacheForMLAKernel<T, mlaSubWarpSize, 16><<<gridDim, blockDimx, 0, bufferManager.getStream().get()>>>(
inputBlockPtrsDev, outputCachePtrsDev, tokensPerBlock, numLayers, headNum, dimsPerHead,
inputBlockNumSum, DomainPPSize, DomainTPSize, layerNumDomainPP, kvFactor);
inputBlockNumSum, DomainPPSize, DomainTPSize, DomainCPSize, layerNumDomainPP, kvFactor);
}
else if (isWindow)
{
@@ -1063,7 +1094,7 @@ void splitKVCache(std::map<SizeType32, std::vector<runtime::ITensor::SharedPtr>>
{
splitKVCacheForMLAKernel<T, mlaSubWarpSize, 8><<<gridDim, blockDimx, 0, bufferManager.getStream().get()>>>(
inputBlockPtrsDev, outputCachePtrsDev, tokensPerBlock, numLayers, headNum, dimsPerHead,
inputBlockNumSum, DomainPPSize, DomainTPSize, layerNumDomainPP, kvFactor);
inputBlockNumSum, DomainPPSize, DomainTPSize, DomainCPSize, layerNumDomainPP, kvFactor);
}
else if (isWindow)
{
@@ -1091,7 +1122,7 @@ void splitKVCache(std::map<SizeType32, std::vector<runtime::ITensor::SharedPtr>>
splitKVCacheForMLAKernel<T, mlaSubWarpSize, 4>
<<<gridDim, blockDimx, 0, bufferManager.getStream().get()>>>(inputBlockPtrsDev, outputCachePtrsDev,
tokensPerBlock, numLayers, headNum, dimsPerHead, inputBlockNumSum, DomainPPSize, DomainTPSize,
layerNumDomainPP, kvFactor);
DomainCPSize, layerNumDomainPP, kvFactor);
}
else if (isWindow)
{
@@ -1124,7 +1155,7 @@ void splitKVCache(std::map<SizeType32, std::vector<runtime::ITensor::SharedPtr>>
splitKVCacheForMLAKernel<T, mlaSubWarpSize, 2>
<<<gridDim, blockDimx, 0, bufferManager.getStream().get()>>>(inputBlockPtrsDev, outputCachePtrsDev,
tokensPerBlock, numLayers, headNum, dimsPerHead, inputBlockNumSum, DomainPPSize, DomainTPSize,
layerNumDomainPP, kvFactor);
DomainCPSize, layerNumDomainPP, kvFactor);
}
else if (isWindow)
{
@@ -1153,7 +1184,7 @@ void splitKVCache(std::map<SizeType32, std::vector<runtime::ITensor::SharedPtr>>
splitKVCacheForMLAKernel<T, mlaSubWarpSize, 1>
<<<gridDim, blockDimx, 0, bufferManager.getStream().get()>>>(inputBlockPtrsDev, outputCachePtrsDev,
tokensPerBlock, numLayers, headNum, dimsPerHead, inputBlockNumSum, DomainPPSize, DomainTPSize,
layerNumDomainPP, kvFactor);
DomainCPSize, layerNumDomainPP, kvFactor);
}
else if (isWindow)
{
@@ -36,6 +36,7 @@ struct TargetRanksInfo
{
int mDomainPPSize;
int mDomainTPSize;
int mDomainCPSize;
std::vector<int> mIRanks;
int mDupHeadFactor;
int mPeerDupHeadFactor;