Skip to content

Commit ee471df

Browse files
authored
[None][chore] optimize kv cache transfer for context TEP and gen DEP (#6657)
Signed-off-by: Chuang Zhu <[email protected]>
1 parent 3e41e6c commit ee471df

File tree

3 files changed

+24
-13
lines changed

3 files changed

+24
-13
lines changed

cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,6 @@ BlockRange getBlockRangeForReceiving(BaseKVCacheManager* cacheManager, LlmReques
7575
bool CacheFormatter::needSendCache(
7676
CacheState const& selfConfig, CacheState const& destConfig, runtime::SizeType32 selfIdx)
7777
{
78-
// int selfTpRank = selfIdx % selfConfig.getParallelConfig().mTensorParallelism;
7978
auto targetInfo = executor::kv_cache::targetIRanks(destConfig, selfConfig, selfIdx);
8079
if (targetInfo.mDupHeadFactor <= 1)
8180
{
@@ -90,8 +89,9 @@ bool CacheFormatter::needSendCache(
9089
= selfConfig.getParallelConfig().mTensorParallelism / selfConfig.getParallelConfig().mDPsize;
9190
selfTpRankInDpGroup = selfTpRank % selfTPNumInDPGroup;
9291
}
92+
int destDPRank = destConfig.getParallelConfig().mEnableAttentionDP ? destConfig.getParallelConfig().mDPrank : 0;
9393

94-
return selfTpRankInDpGroup % targetInfo.mDupHeadFactor == 0;
94+
return (destDPRank % targetInfo.mDupHeadFactor) == (selfTpRankInDpGroup % targetInfo.mDupHeadFactor);
9595
}
9696

9797
void checkAlternateWindow(BaseKVCacheManager* cacheManager, BaseCacheFormatter::CacheState const& selfConfig,
@@ -128,11 +128,12 @@ std::vector<size_t> CacheFormatter::pickRecvConnections(
128128
return ret;
129129
}
130130
TLLM_CHECK(numConnections == targetInfo.mIRanks.size());
131+
int selfDPRank = selfConfig.getParallelConfig().mEnableAttentionDP ? selfConfig.getParallelConfig().mDPrank : 0;
131132

132133
std::vector<size_t> ret;
133134
for (int i = 0; i < targetInfo.mDomainTPSize; i++)
134135
{
135-
if (i % targetInfo.mPeerDupHeadFactor == 0)
136+
if ((i % targetInfo.mPeerDupHeadFactor) == (selfDPRank % targetInfo.mPeerDupHeadFactor))
136137
{
137138
for (int j = 0; j < targetInfo.mDomainPPSize; j++)
138139
{

cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,12 @@ std::vector<size_t> MLACacheFormatter::pickRecvConnections(
4545
auto targetInfo = executor::kv_cache::targetIRanks(destConfig, selfConfig, selfIdx);
4646
TLLM_CHECK(numConnections == targetInfo.mIRanks.size());
4747
std::vector<size_t> ret;
48-
// targetInfo , mRanks [tpranks, dpranks]
48+
// targetInfo , mRanks [tpranks, ppranks]
49+
int dpRank = selfConfig.getParallelConfig().mEnableAttentionDP ? selfConfig.getParallelConfig().mDPrank : 0;
50+
4951
for (int i = 0; i < targetInfo.mDomainPPSize; i++)
5052
{
51-
ret.push_back(i);
53+
ret.push_back(i + (dpRank % (targetInfo.mDomainTPSize)) * targetInfo.mDomainPPSize);
5254
}
5355
return ret;
5456
}
@@ -58,19 +60,24 @@ bool MLACacheFormatter::needSendCache(
5860
{
5961
int selfTpRank = selfIdx % selfConfig.getParallelConfig().mTensorParallelism;
6062

63+
int destTPNumInDPGroup = destConfig.getParallelConfig().mEnableAttentionDP
64+
? destConfig.getParallelConfig().mTensorParallelism / destConfig.getParallelConfig().mDPsize
65+
: destConfig.getParallelConfig().mTensorParallelism;
66+
int destDPRank = destConfig.getParallelConfig().mEnableAttentionDP ? destConfig.getParallelConfig().mDPrank : 0;
67+
6168
if (selfConfig.getParallelConfig().mEnableAttentionDP)
6269
{
6370
int selfTPNumInDPGroup
6471
= selfConfig.getParallelConfig().mTensorParallelism / selfConfig.getParallelConfig().mDPsize;
65-
int destTPNumInDPGroup = destConfig.getParallelConfig().mEnableAttentionDP
66-
? destConfig.getParallelConfig().mTensorParallelism / destConfig.getParallelConfig().mDPsize
67-
: destConfig.getParallelConfig().mTensorParallelism;
72+
6873
int selfTPrankINDPGroup = selfTpRank % selfTPNumInDPGroup;
6974
if (selfTPNumInDPGroup <= destTPNumInDPGroup)
7075
{
7176
return true;
7277
}
73-
return selfTPrankINDPGroup % (selfTPNumInDPGroup / destTPNumInDPGroup) == 0;
78+
79+
int dupHeadFactor = selfTPNumInDPGroup / destTPNumInDPGroup;
80+
return selfTPrankINDPGroup % dupHeadFactor == destDPRank;
7481
}
7582

7683
int destTPNum = destConfig.getParallelConfig().mEnableAttentionDP
@@ -81,7 +88,8 @@ bool MLACacheFormatter::needSendCache(
8188
{
8289
return true;
8390
}
84-
return selfTpRank % (selfTPNum / destTPNum) == 0;
91+
int dupHeadFactor = selfTPNum / destTPNum;
92+
return selfTpRank % dupHeadFactor == destDPRank;
8593
}
8694

8795
void MLACacheFormatter::format(TransferSession& session)

cpp/tests/batch_manager/cacheTransceiverTest.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1457,12 +1457,15 @@ TEST(targetTest, CacheStateNODP)
14571457

14581458
verifyContext(
14591459
/*contextRank*/ 0, /*expectRanks*/ {0}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1, /*expectNeedSend*/ true);
1460+
14601461
verifyContext(
14611462
/*contextRank*/ 1, /*expectRanks*/ {0}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1, /*expectNeedSend*/ false);
1463+
14621464
verifyContext(
14631465
/*contextRank*/ 2, /*expectRanks*/ {1}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1, /*expectNeedSend*/ true);
14641466
verifyContext(
14651467
/*contextRank*/ 3, /*expectRanks*/ {1}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1, /*expectNeedSend*/ false);
1468+
14661469
verifyContext(
14671470
/*contextRank*/ 4, /*expectRanks*/ {2}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1, /*expectNeedSend*/ true);
14681471
verifyContext(
@@ -1474,7 +1477,6 @@ TEST(targetTest, CacheStateNODP)
14741477

14751478
contextTP = 2;
14761479
genTP = 4;
1477-
14781480
verifyContext(
14791481
/*contextRank*/ 0, /*expectRanks*/ {0, 1}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 2, /*expectNeedSend*/ true);
14801482
verifyContext(/*contextRank*/ 1, /*expectRanks*/ {2, 3}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 2,
@@ -1564,13 +1566,13 @@ TEST(targetTest, CacheStateContextDP)
15641566
/*expectNeedSend*/ true);
15651567
verifyContext(
15661568
/*contextRank*/ 0, /*generationRank*/ 1, /*expectRanks*/ {1}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1,
1567-
/*expectNeedSend*/ true);
1569+
/*expectNeedSend*/ false);
15681570
verifyContext(
15691571
/*contextRank*/ 1, /*generationRank*/ 0, /*expectRanks*/ {0}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1,
15701572
/*expectNeedSend*/ false);
15711573
verifyContext(
15721574
/*contextRank*/ 1, /*generationRank*/ 1, /*expectRanks*/ {1}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1,
1573-
/*expectNeedSend*/ false);
1575+
/*expectNeedSend*/ true);
15741576
verifyContext(
15751577
/*contextRank*/ 2, /*generationRank*/ 0, /*expectRanks*/ {0}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1,
15761578
/*expectNeedSend*/ false);

0 commit comments

Comments
 (0)