Skip to content

Commit e7865b3

Browse files
committed
None: Update TargetRanksInfo to support context parallelism (CP) in disagg later
Signed-off-by: Balaram Buddharaju <[email protected]>
1 parent a419b77 commit e7865b3

File tree

3 files changed

+292
-57
lines changed

3 files changed

+292
-57
lines changed

cpp/tensorrt_llm/executor/cache_transmission/cacheSplitConcat.cu

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,12 @@ TargetRanksInfo TargetRanksInfoForDP(
5757
auto const selfPPNum = selfParConfig.mPipelineParallelism;
5858
auto const peerTPNum = peerParConfig.mTensorParallelism;
5959
auto const selfTPNum = selfParConfig.mTensorParallelism;
60+
auto const peerCPNum = peerParConfig.mContextParallelism;
61+
auto const selfCPNum = selfParConfig.mContextParallelism;
6062

61-
auto const selfTPRank = selfRank % selfParConfig.mTensorParallelism;
62-
auto const selfPPRank = selfRank / selfParConfig.mTensorParallelism;
63+
auto const selfTPRank = selfRank % selfTPNum;
64+
auto const selfPPRank = selfRank / (selfTPNum * selfCPNum);
65+
auto const selfCPRank = (selfRank % (selfTPNum * selfCPNum)) / selfTPNum;
6366

6467
int peerPPRankStart = 0;
6568
int mDomainPPSize = 1;
@@ -108,13 +111,35 @@ TargetRanksInfo TargetRanksInfoForDP(
108111
peerTPRankEnd = peerTPRankStart + mDomainTPSize;
109112
}
110113

114+
int mDomainCPSize = 1;
115+
int peerCPRankStart = 0;
116+
int peerCPRankEnd = 0;
117+
for (auto val : {peerCPNum, selfCPNum})
118+
{
119+
TLLM_CHECK(isPowerOfTwo(val));
120+
}
121+
if (selfCPNum <= peerCPNum)
122+
{
123+
mDomainCPSize = peerCPNum / selfCPNum;
124+
peerCPRankStart = selfCPRank * mDomainCPSize;
125+
peerCPRankEnd = (selfCPRank + 1) * mDomainCPSize;
126+
}
127+
else
128+
{
129+
peerCPRankStart = selfCPRank / (selfCPNum / peerCPNum);
130+
peerCPRankEnd = peerCPRankStart + mDomainCPSize;
131+
}
132+
111133
std::vector<int> retRanks;
112134
for (int i = peerTPRankStart; i < peerTPRankEnd; i++)
113135
{
114-
for (int j = peerPPRankStart; j < peerPPRankEnd; j++)
136+
for (int j = peerCPRankStart; j < peerCPRankEnd; j++)
115137
{
116-
int irank = j * peerTPNum + i;
117-
retRanks.push_back(irank);
138+
for (int k = peerPPRankStart; k < peerPPRankEnd; k++)
139+
{
140+
int irank = (k * peerTPNum * peerCPNum) + (j * peerTPNum) + i;
141+
retRanks.push_back(irank);
142+
}
118143
}
119144
}
120145

@@ -131,7 +156,7 @@ TargetRanksInfo TargetRanksInfoForDP(
131156
= (peerNbHeadsPerLayer * peerTPSizePerDPGroup) / (selfNbHeadsPerLayer * selfTPSizePerDPGroup);
132157
}
133158

134-
return {mDomainPPSize, mDomainTPSize, std::move(retRanks), mDupHeadFactor, mPeerDupHeadFactor};
159+
return {mDomainPPSize, mDomainTPSize, mDomainCPSize, std::move(retRanks), mDupHeadFactor, mPeerDupHeadFactor};
135160
}
136161

137162
TargetRanksInfo targetIRanks(

cpp/tensorrt_llm/executor/cache_transmission/cacheSplitConcat.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ struct TargetRanksInfo
3636
{
3737
int mDomainPPSize;
3838
int mDomainTPSize;
39+
int mDomainCPSize;
3940
std::vector<int> mIRanks;
4041
int mDupHeadFactor;
4142
int mPeerDupHeadFactor;

0 commit comments

Comments
 (0)