@@ -57,9 +57,12 @@ TargetRanksInfo TargetRanksInfoForDP(
5757 auto const selfPPNum = selfParConfig.mPipelineParallelism ;
5858 auto const peerTPNum = peerParConfig.mTensorParallelism ;
5959 auto const selfTPNum = selfParConfig.mTensorParallelism ;
60+ auto const peerCPNum = peerParConfig.mContextParallelism ;
61+ auto const selfCPNum = selfParConfig.mContextParallelism ;
6062
61- auto const selfTPRank = selfRank % selfParConfig.mTensorParallelism ;
62- auto const selfPPRank = selfRank / selfParConfig.mTensorParallelism ;
63+ auto const selfTPRank = selfRank % selfTPNum;
64+ auto const selfPPRank = selfRank / (selfTPNum * selfCPNum);
65+ auto const selfCPRank = (selfRank % (selfTPNum * selfCPNum)) / selfTPNum;
6366
6467 int peerPPRankStart = 0 ;
6568 int mDomainPPSize = 1 ;
@@ -108,13 +111,35 @@ TargetRanksInfo TargetRanksInfoForDP(
108111 peerTPRankEnd = peerTPRankStart + mDomainTPSize ;
109112 }
110113
114+ int mDomainCPSize = 1 ;
115+ int peerCPRankStart = 0 ;
116+ int peerCPRankEnd = 0 ;
117+ for (auto val : {peerCPNum, selfCPNum})
118+ {
119+ TLLM_CHECK (isPowerOfTwo (val));
120+ }
121+ if (selfCPNum <= peerCPNum)
122+ {
123+ mDomainCPSize = peerCPNum / selfCPNum;
124+ peerCPRankStart = selfCPRank * mDomainCPSize ;
125+ peerCPRankEnd = (selfCPRank + 1 ) * mDomainCPSize ;
126+ }
127+ else
128+ {
129+ peerCPRankStart = selfCPRank / (selfCPNum / peerCPNum);
130+ peerCPRankEnd = peerCPRankStart + mDomainCPSize ;
131+ }
132+
111133 std::vector<int > retRanks;
112134 for (int i = peerTPRankStart; i < peerTPRankEnd; i++)
113135 {
114- for (int j = peerPPRankStart ; j < peerPPRankEnd ; j++)
136+ for (int j = peerCPRankStart ; j < peerCPRankEnd ; j++)
115137 {
116- int irank = j * peerTPNum + i;
117- retRanks.push_back (irank);
138+ for (int k = peerPPRankStart; k < peerPPRankEnd; k++)
139+ {
140+ int irank = (k * peerTPNum * peerCPNum) + (j * peerTPNum) + i;
141+ retRanks.push_back (irank);
142+ }
118143 }
119144 }
120145
@@ -131,7 +156,7 @@ TargetRanksInfo TargetRanksInfoForDP(
131156 = (peerNbHeadsPerLayer * peerTPSizePerDPGroup) / (selfNbHeadsPerLayer * selfTPSizePerDPGroup);
132157 }
133158
134- return {mDomainPPSize , mDomainTPSize , std::move (retRanks), mDupHeadFactor , mPeerDupHeadFactor };
159+ return {mDomainPPSize , mDomainTPSize , mDomainCPSize , std::move (retRanks), mDupHeadFactor , mPeerDupHeadFactor };
135160}
136161
137162TargetRanksInfo targetIRanks (
0 commit comments