diff --git a/csrc/deepep/ops/op_kernel/notify_dispatch.h b/csrc/deepep/ops/op_kernel/notify_dispatch.h index 1b7c92a02..a4f75a218 100644 --- a/csrc/deepep/ops/op_kernel/notify_dispatch.h +++ b/csrc/deepep/ops/op_kernel/notify_dispatch.h @@ -51,6 +51,7 @@ class NotifyDispatch // Synchronization flag occupies length constexpr static int64_t FLAG_UNIT_INT_NUM = 4; constexpr static int64_t MAGIC_MASK = ~((1LL << 32) - 1); + constexpr static int32_t BATCH_ROUND = 32; public: __aicore__ inline NotifyDispatch(int rank, int rankSize, uint32_t extraFlag) @@ -70,14 +71,11 @@ class NotifyDispatch recvOffset_ = recvOffset; maxBs_ = maxBs; recvTokensPerExpert_ = recvTokensPerExpert; - tokenPerExpertDataAlignLen = Ceil(round * numExperts * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; - sendDataOffsetAlignLen = Ceil(round * numExperts * sizeof(T), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; - sendDataAlignLen = Ceil(round * numExperts * sendPerGroup * sizeof(T), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; - recvDataAlignLen = Ceil(round * numExperts * sendPerGroup * sizeof(int32_t), UB_ALIGN_SIZE) * - UB_ALIGN_SIZE; // 32 * 256 * 3 * 4 = 96KB + batchRounds = BATCH_ROUND; + tokenPerExpertDataAlignLen = Ceil(batchRounds * numExperts * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + sendDataOffsetAlignLen = Ceil(batchRounds * numExperts * sizeof(T), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + sendDataAlignLen = Ceil(batchRounds * numExperts * sendPerGroup * sizeof(T), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; sendTokensPerRankAlignLen = Ceil(numRanks * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; - sendCountAlignLen = - Ceil(round * numExperts * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; // 32 * 256 * 4 = 32KB // Initialize core grouping InitCoreGroup(); @@ -101,9 +99,6 @@ class NotifyDispatch sendDataOffsetOutputGt.SetGlobalBuffer((__gm__ T *)sendDataOffsetOutput); recvDataOutputGt.SetGlobalBuffer((__gm__ T *)recvDataOutput); recvDataOutGt.SetGlobalBuffer((__gm__ int32_t *)recvDataOutput); - 
pipe.InitBuffer(sendCountBuf, tokenPerExpertDataAlignLen); - pipe.InitBuffer(sendOffsetBuf, tokenPerExpertDataAlignLen); - pipe.InitBuffer(recvDataBuf, recvDataAlignLen); } __aicore__ inline void Process() @@ -119,8 +114,8 @@ class NotifyDispatch ShareToShareSlice(); } SyncAll(); - ReorderOutput(); - BuildTotalRecvTokens(); // 出错点 + pipe.Reset(); + BuildTotalRecvTokens(); BuildRecvCount(); BuildRecvOffset(); BuildMaxBs(); @@ -153,87 +148,90 @@ class NotifyDispatch pipe.InitBuffer(tokenPerExpertDataBuf, tokenPerExpertDataAlignLen); pipe.InitBuffer(sendDataBuf, sendDataAlignLen); pipe.InitBuffer(sendDataOffsetBuf, sendDataOffsetAlignLen); - int batchRounds = 32; int localExpertsNum = numExperts / rankSize; int newSendDataAlignLen = Ceil(batchRounds * localExpertsNum * sendPerGroup * sizeof(T), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; pipe.InitBuffer(newSendDataBuf, newSendDataAlignLen); - tokenPerExpertTensor = tokenPerExpertDataBuf.Get(); sendDataTensor = sendDataBuf.Get(); sendDataOffsetTensor = sendDataOffsetBuf.Get(); newSendDataTensor = newSendDataBuf.Get(); - DataCopyExtParams tokenPerExpertParams = {1U, tokenPerExpertDataAlignLen, 0U, 0U, 0U}; - DataCopyPadExtParams copyPadExtParams{false, 0U, 0U, 0U}; - DataCopyPad(tokenPerExpertTensor, tokenPerExpertDataInputGt, tokenPerExpertParams, copyPadExtParams); - - AscendC::SetFlag(EVENT_ID0); - AscendC::WaitFlag(EVENT_ID0); int realRound = (numTokens + perRoundTokens - 1) / perRoundTokens; int lastRoundNumTokens = numTokens % perRoundTokens; if (lastRoundNumTokens == 0 && numTokens > 0) { lastRoundNumTokens = perRoundTokens; } + int totalRounds = round; - int prefixSum = 0; - - for (int r = 0; r < realRound; ++r) { - prefixSum = 0; - for (int i = 0; i < numExperts; ++i) { - int numTokensExpert = tokenPerExpertTensor(r * numExperts + i); - int baseUB = r * numExperts * sendPerGroup + i * sendPerGroup; - sendDataTensor(baseUB) = numTokensExpert; - sendDataTensor(baseUB + 1) = prefixSum; - int roundNumTokens = (r == realRound - 
1 ? lastRoundNumTokens : perRoundTokens); - sendDataTensor(baseUB + 2) = roundNumTokens; - sendDataOffsetTensor(r * numExperts + i) = prefixSum; - prefixSum += numTokensExpert; + for (int rBase = 0; rBase < totalRounds; rBase += batchRounds) { + int currentBatch = (rBase + batchRounds > totalRounds) ? (totalRounds - rBase) : batchRounds; + uint32_t copyLen = currentBatch * numExperts * sizeof(int32_t); + DataCopyExtParams tokenPerExpertParams = {1U, copyLen, 0U, 0U, 0U}; + DataCopyPadExtParams copyPadExtParams{false, 0U, 0U, 0U}; + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + DataCopyPad(tokenPerExpertTensor, tokenPerExpertDataInputGt[rBase * numExperts], tokenPerExpertParams, + copyPadExtParams); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + + for (int r = 0; r < currentBatch; r++) { + int absRound = rBase + r; + int prefixSum = 0; + if (absRound < realRound) { + for (int i = 0; i < numExperts; ++i) { + int numTokensExpert = tokenPerExpertTensor(r * numExperts + i); // S operation + int baseUB = r * numExperts * sendPerGroup + i * sendPerGroup; + sendDataTensor(baseUB) = numTokensExpert; + sendDataTensor(baseUB + 1) = prefixSum; + int roundNumTokens = (absRound == realRound - 1 ? 
lastRoundNumTokens : perRoundTokens); + sendDataTensor(baseUB + 2) = roundNumTokens; + sendDataOffsetTensor(r * numExperts + i) = prefixSum; + prefixSum += numTokensExpert; + } + } else { + // padding round + for (int i = 0; i < numExperts; ++i) { + int baseUB = r * numExperts * sendPerGroup + i * sendPerGroup; + sendDataTensor(baseUB) = 0; + sendDataTensor(baseUB + 1) = 0; + sendDataTensor(baseUB + 2) = 0; + sendDataOffsetTensor(r * numExperts + i) = 0; + } + } } - } - for (int r = realRound; r < round; ++r) { - for (int i = 0; i < numExperts; ++i) { - int baseUB = r * numExperts * sendPerGroup + i * sendPerGroup; - sendDataTensor(baseUB) = 0; - sendDataTensor(baseUB + 1) = 0; - sendDataTensor(baseUB + 2) = 0; - sendDataOffsetTensor(r * numExperts + i) = 0; - } - } + uint32_t offsetCopyLen = currentBatch * numExperts * sizeof(T); + DataCopyExtParams sendDataOffsetParams = {1U, offsetCopyLen, 0U, 0U, 0U}; + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + DataCopyPad(sendDataOffsetOutputGt[rBase * numExperts], sendDataOffsetTensor, sendDataOffsetParams); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); - int totalRounds = round; - if (round > 1) { for (int tr = 0; tr < rankSize; ++tr) { - for (int rBase = 0; rBase < totalRounds; rBase += batchRounds) { - int currentBatch = (rBase + batchRounds > totalRounds) ? 
(totalRounds - rBase) : batchRounds; - for (int r = 0; r < currentBatch; ++r) { - int absRound = rBase + r; - for (int le = 0; le < localExpertsNum; ++le) { - int globalExpertIdx = tr * localExpertsNum + le; - int srcIdx = (absRound * numExperts + globalExpertIdx) * sendPerGroup; - int dstIdx = (r * localExpertsNum + le) * sendPerGroup; - newSendDataTensor(dstIdx) = sendDataTensor(srcIdx); - newSendDataTensor(dstIdx + 1) = sendDataTensor(srcIdx + 1); - newSendDataTensor(dstIdx + 2) = sendDataTensor(srcIdx + 2); - } + for (int r = 0; r < currentBatch; ++r) { + for (int le = 0; le < localExpertsNum; ++le) { + int globalExpertIdx = tr * localExpertsNum + le; + int srcIdx = (r * numExperts + globalExpertIdx) * sendPerGroup; + int dstIdx = (r * localExpertsNum + le) * sendPerGroup; + newSendDataTensor(dstIdx) = sendDataTensor(srcIdx); + newSendDataTensor(dstIdx + 1) = sendDataTensor(srcIdx + 1); + newSendDataTensor(dstIdx + 2) = sendDataTensor(srcIdx + 2); } - AscendC::SetFlag(EVENT_ID0); - AscendC::WaitFlag(EVENT_ID0); - uint32_t copyLen = currentBatch * localExpertsNum * sendPerGroup * sizeof(int32_t); - DataCopyExtParams copyParams = {1U, copyLen, 0U, 0U, 0U}; - uint64_t gmOffset = (uint64_t)tr * totalRounds * localExpertsNum * sendPerGroup + - (uint64_t)rBase * localExpertsNum * sendPerGroup; - DataCopyPad(sendDataInputGt[gmOffset], newSendDataTensor[0], copyParams); - AscendC::SetFlag(EVENT_ID0); - AscendC::WaitFlag(EVENT_ID0); } + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + uint32_t dataCopyLen = currentBatch * localExpertsNum * sendPerGroup * sizeof(int32_t); + DataCopyExtParams copyParams = {1U, dataCopyLen, 0U, 0U, 0U}; + uint64_t gmOffset = (uint64_t)tr * totalRounds * localExpertsNum * sendPerGroup + + (uint64_t)rBase * localExpertsNum * sendPerGroup; + DataCopyPad(sendDataInputGt[gmOffset], newSendDataTensor[0], copyParams); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); } - } else { - DataCopyPad(sendDataInputGt, 
sendDataTensor, {1U, sendDataAlignLen, 0U, 0U, 0U}); } - DataCopyExtParams sendDataOffsetParams = {1U, sendDataOffsetAlignLen, 0U, 0U, 0U}; - DataCopyPad(sendDataOffsetOutputGt, sendDataOffsetTensor, sendDataOffsetParams); AscendC::SetFlag(EVENT_ID0); AscendC::WaitFlag(EVENT_ID0); @@ -333,49 +331,60 @@ class NotifyDispatch } } - __aicore__ inline void ReorderOutput() + __aicore__ inline void ReorderOutput(uint32_t rStart, uint32_t currentBatchRounds) { - pipe.Reset(); - pipe.InitBuffer(recvDataBuf, recvDataAlignLen); recvDataTensor = recvDataBuf.Get(); - DataCopyExtParams recvDataParams = {1U, static_cast(recvDataAlignLen), 0, 0, 0}; + Duplicate(recvDataTensor, 0, recvDataAlignLen / sizeof(int32_t)); + + uint32_t singleRankTotalElemCount = round * numLocalExperts * sendPerGroup; + uint32_t singleRankBatchElemCount = currentBatchRounds * numLocalExperts * sendPerGroup; + uint32_t singleRankBatchDataLen = singleRankBatchElemCount * sizeof(int32_t); + DataCopyExtParams recvDataParams = {1U, static_cast(singleRankBatchDataLen), 0, 0, 0}; DataCopyPadExtParams DataCopyPadExtParams{false, 0U, 0U, 0U}; - DataCopyPad(recvDataTensor, recvDataOutputGt, recvDataParams, DataCopyPadExtParams); + + for (uint32_t i = 0; i < rankSize; i++) { + uint32_t srcOffset = i * singleRankTotalElemCount + rStart * numLocalExperts * sendPerGroup; + uint32_t dstOffset = i * singleRankBatchElemCount; + // 搬运该Rank下的 currentBatchRounds 数据 + DataCopyPad(recvDataTensor[dstOffset], recvDataOutputGt[srcOffset], recvDataParams, DataCopyPadExtParams); + } + SyncFunc(); } - __aicore__ inline void ReorderSendCountOutput() + __aicore__ inline void ReorderSendCountOutput(uint32_t currentBatchRounds) { - pipe.InitBuffer(sendCountBuf, sendCountAlignLen); - sendCountTensor = sendCountBuf.Get(); - Duplicate(sendCountTensor, 0, sendCountAlignLen / sizeof(int32_t)); + recvCountTensor = recvCountBuf.Get(); + Duplicate(recvCountTensor, 0, sendCountAlignLen / sizeof(int32_t)); // V + SyncFunc(); - SyncFunc(); - for 
(uint32_t r = 0; r < round; ++r) { + uint32_t computeNum = currentBatchRounds * numLocalExperts; + for (uint32_t r = 0; r < currentBatchRounds; ++r) { + uint32_t computeNumIn = r * numLocalExperts; + uint32_t computeNumOut = r * numExperts; for (uint32_t expId = 0; expId < numLocalExperts; ++expId) { for (uint32_t srcRank = 0; srcRank < rankSize; ++srcRank) { uint32_t index = expId * rankSize + srcRank; - uint32_t pair_idx = - sendPerGroup * (srcRank * numLocalExperts * round + r * numLocalExperts + expId); - sendCountTensor(r * numExperts + index) = recvDataTensor(pair_idx); + uint32_t pair_idx = sendPerGroup * (srcRank * computeNum + computeNumIn + expId); + recvCountTensor(computeNumOut + index) = recvDataTensor(pair_idx); } } } } - __aicore__ inline void ReorderSendOffsetOutput() + __aicore__ inline void ReorderSendOffsetOutput(uint32_t currentBatchRounds) { - pipe.InitBuffer(sendOffsetBuf, sendCountAlignLen); sendOffsetTensor = sendOffsetBuf.Get(); Duplicate(sendOffsetTensor, 0, sendCountAlignLen / sizeof(int32_t)); SyncFunc(); - SyncFunc(); - for (uint32_t r = 0; r < round; ++r) { + uint32_t computeNum = currentBatchRounds * numLocalExperts; + for (uint32_t r = 0; r < currentBatchRounds; ++r) { + uint32_t computeNumIn = r * numLocalExperts; + uint32_t computeNumOut = r * numExperts; for (uint32_t expId = 0; expId < numLocalExperts; ++expId) { for (uint32_t srcRank = 0; srcRank < rankSize; ++srcRank) { uint32_t index = expId * rankSize + srcRank; - uint32_t pair_idx = - sendPerGroup * (srcRank * numLocalExperts * round + r * numLocalExperts + expId); - sendOffsetTensor(r * numExperts + index) = recvDataTensor(pair_idx + 1); + uint32_t pair_idx = sendPerGroup * (srcRank * computeNum + computeNumIn + expId); + sendOffsetTensor(computeNumOut + index) = recvDataTensor(pair_idx + 1); } } } @@ -413,69 +422,111 @@ class NotifyDispatch if (blockIdx != TOTAL_CNT_CORE) { return; } + int32_t sumVal = 0; - ReorderSendCountOutput(); + recvDataAlignLen = + Ceil(batchRounds 
* numExperts * sendPerGroup * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + pipe.InitBuffer(recvDataBuf, recvDataAlignLen); + sendCountAlignLen = Ceil(batchRounds * numExperts * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + pipe.InitBuffer(recvCountBuf, sendCountAlignLen); + pipe.InitBuffer(tmpBuf2_, Ceil(batchRounds * numExperts * sizeof(float), UB_ALIGN_SIZE) * UB_ALIGN_SIZE); // holds the Cast output for one batch: batchRounds * numExperts floats (32 * 256 * 4 = 32KB) + + for (uint32_t rStart = 0; rStart < round; rStart += batchRounds) { + uint32_t currentBatchRounds = (rStart + batchRounds > round) ? (round - rStart) : batchRounds; + ReorderOutput(rStart, currentBatchRounds); + ReorderSendCountOutput(currentBatchRounds); + + LocalTensor batchCntFloat = tmpBuf2_.Get(); + LocalTensor batchSumCntLt = recvCountBuf.Get(); + LocalTensor sharedTmpBuffer = recvDataBuf.Get(); + uint32_t currComputeNum = currentBatchRounds * numExperts; + SyncFunc(); + Cast(batchCntFloat, recvCountTensor, RoundMode::CAST_NONE, currComputeNum); + PipeBarrier(); + ReduceSum(batchSumCntLt, batchCntFloat, sharedTmpBuffer, currComputeNum); + SyncFunc(); + sumVal += static_cast(batchSumCntLt.GetValue(0)); + SyncFunc(); + } pipe.InitBuffer(tmpBuf_, UB_ALIGN_SIZE); - pipe.InitBuffer(tmpBuf2_, Ceil(round * numExperts * sizeof(float), UB_ALIGN_SIZE) * UB_ALIGN_SIZE); - LocalTensor totalCntLt = tmpBuf_.Get(); - LocalTensor floatExpTokenCntLt = tmpBuf2_.Get(); - LocalTensor floatExpTokenSumCntLt = sendCountBuf.Get(); - LocalTensor sharedTmpBuffer = recvDataBuf.Get(); - - SyncFunc(); - Cast(floatExpTokenCntLt, sendCountTensor, RoundMode::CAST_NONE, round * numExperts); - PipeBarrier(); - ReduceSum(floatExpTokenSumCntLt, floatExpTokenCntLt, sharedTmpBuffer, round * numExperts); - SyncFunc(); - int32_t sumVal = static_cast(floatExpTokenSumCntLt.GetValue(0)); - PipeBarrier(); totalCntLt(0) = sumVal; - PipeBarrier(); - SyncFunc(); + SyncFunc(); + // Copy to outputGT GlobalTensor totalCntGt; totalCntGt.SetGlobalBuffer((__gm__ int32_t *)totalRecvTokens_); DataCopyExtParams copyParams{1, 
static_cast(1 * sizeof(int32_t)), 0, 0, 0}; DataCopyPad(totalCntGt, totalCntLt, copyParams); + SyncFunc(); } __aicore__ inline void BuildRecvCount() { - // 只需要sendCountTensor + // 只需要recvCountTensor if (blockIdx != RECV_COUNT_CORE) { return; } - ReorderSendCountOutput(); - for (uint32_t r = 0; r < round; ++r) { - int32_t recvCountNum = 0; - for (uint32_t expId = 0; expId < numExperts / rankSize; ++expId) { - for (uint32_t srcRank = 0; srcRank < rankSize; ++srcRank) { - uint32_t index = r * numExperts + expId * rankSize + srcRank; - recvCountNum += sendCountTensor(index); - sendCountTensor(index) = recvCountNum; + recvDataAlignLen = + Ceil(batchRounds * numExperts * sendPerGroup * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + pipe.InitBuffer(recvDataBuf, recvDataAlignLen); + sendCountAlignLen = Ceil(batchRounds * numExperts * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + pipe.InitBuffer(recvCountBuf, sendCountAlignLen); + for (uint32_t rStart = 0; rStart < round; rStart += batchRounds) { + uint32_t currentBatchRounds = (rStart + batchRounds > round) ? 
(round - rStart) : batchRounds; + + ReorderOutput(rStart, currentBatchRounds); + ReorderSendCountOutput(currentBatchRounds); + + for (uint32_t r = 0; r < currentBatchRounds; ++r) { + int32_t recvCountNum = 0; + for (uint32_t expId = 0; expId < numExperts / rankSize; ++expId) { + for (uint32_t srcRank = 0; srcRank < rankSize; ++srcRank) { + uint32_t index = r * numExperts + expId * rankSize + srcRank; + recvCountNum += recvCountTensor(index); + recvCountTensor(index) = recvCountNum; + } } } + GlobalTensor recvCntGt; + recvCntGt.SetGlobalBuffer((__gm__ int32_t *)recvCount_); + uint32_t globalOffset = rStart * numExperts; + DataCopyExtParams copyParams{1, static_cast(currentBatchRounds * numExperts * sizeof(int32_t)), 0, + 0, 0}; + SyncFunc(); + DataCopyPad(recvCntGt[globalOffset], recvCountTensor, copyParams); + + SyncFunc(); } - GlobalTensor recvCntGt; - recvCntGt.SetGlobalBuffer((__gm__ int32_t *)recvCount_); - DataCopyExtParams copyParams{1, static_cast(round * numExperts * sizeof(int32_t)), 0, 0, 0}; - SyncFunc(); - DataCopyPad(recvCntGt, sendCountTensor, copyParams); } __aicore__ inline void BuildRecvOffset() { - // 只需要sendOffsetTensor if (blockIdx != RECV_OFFSET_CORE) { return; } - ReorderSendOffsetOutput(); - GlobalTensor recvOffsetGt; - recvOffsetGt.SetGlobalBuffer((__gm__ int32_t *)recvOffset_); - DataCopyExtParams copyParams{1, static_cast(round * numExperts * sizeof(int32_t)), 0, 0, 0}; - SyncFunc(); - DataCopyPad(recvOffsetGt, sendOffsetTensor, copyParams); + recvDataAlignLen = + Ceil(batchRounds * numExperts * sendPerGroup * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + pipe.InitBuffer(recvDataBuf, recvDataAlignLen); + sendCountAlignLen = Ceil(batchRounds * numExperts * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + pipe.InitBuffer(sendOffsetBuf, sendCountAlignLen); + + for (uint32_t rStart = 0; rStart < round; rStart += batchRounds) { + uint32_t currentBatchRounds = (rStart + batchRounds > round) ? 
(round - rStart) : batchRounds; + + ReorderOutput(rStart, currentBatchRounds); + ReorderSendOffsetOutput(currentBatchRounds); + + GlobalTensor recvOffsetGt; + recvOffsetGt.SetGlobalBuffer((__gm__ int32_t *)recvOffset_); + uint32_t globalOffset = rStart * numExperts; + DataCopyExtParams copyParams{1, static_cast(currentBatchRounds * numExperts * sizeof(int32_t)), 0, + 0, 0}; + SyncFunc(); + DataCopyPad(recvOffsetGt[globalOffset], sendOffsetTensor, copyParams); + + SyncFunc(); + } } __aicore__ inline void BuildMaxBs() @@ -484,7 +535,40 @@ class NotifyDispatch if (blockIdx != MAX_BS_CORE) { return; } - ReorderSendTokensPerRankOutput(); + recvDataAlignLen = + Ceil(batchRounds * numExperts * sendPerGroup * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + pipe.InitBuffer(recvDataBuf, recvDataAlignLen); + + pipe.InitBuffer(sendTokensPerRankBuf, sendTokensPerRankAlignLen); + pipe.InitBuffer(seenRoundBuf, sendTokensPerRankAlignLen); + sendTokensPerRankTensor = sendTokensPerRankBuf.Get(); + seenRoundTensor = seenRoundBuf.Get(); + Duplicate(sendTokensPerRankTensor, 0, sendTokensPerRankAlignLen / sizeof(int32_t)); + + SyncFunc(); + SyncFunc(); + for (uint32_t rStart = 0; rStart < round; rStart += batchRounds) { + uint32_t currentBatchRounds = (rStart + batchRounds > round) ? 
(round - rStart) : batchRounds; + + ReorderOutput(rStart, currentBatchRounds); + SyncFunc(); + for (uint32_t r = 0; r < currentBatchRounds; ++r) { + Duplicate(seenRoundTensor, 0, sendTokensPerRankAlignLen / sizeof(int32_t)); + SyncFunc(); + for (uint32_t expId = 0; expId < numLocalExperts; ++expId) { + for (uint32_t srcRank = 0; srcRank < rankSize; ++srcRank) { + uint32_t pair_idx = sendPerGroup * (srcRank * numLocalExperts * currentBatchRounds + + r * numLocalExperts + expId); + if (!seenRoundTensor(srcRank)) { + sendTokensPerRankTensor(srcRank) += recvDataTensor(pair_idx + 2); + seenRoundTensor(srcRank) = 1; + } + } + } + } + SyncFunc(); + } + for (uint32_t srcRank = 0; srcRank < numRanks; ++srcRank) { uint32_t tempBs = sendTokensPerRankTensor(srcRank); maxBsNum = maxBsNum >= tempBs ? maxBsNum : tempBs; @@ -497,55 +581,86 @@ class NotifyDispatch __aicore__ inline void BuildRecvTokenPerExp() { - // 只需要sendCountTensor + // 只需要recvCountTensor if (blockIdx != RECV_TOKEN_PER_EXP_CORE) { return; } - ReorderSendCountOutput(); - pipe.InitBuffer(tmpBuf_, Ceil(round * numLocalExperts * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE); + recvDataAlignLen = + Ceil(batchRounds * numExperts * sendPerGroup * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + pipe.InitBuffer(recvDataBuf, recvDataAlignLen); + sendCountAlignLen = Ceil(batchRounds * numExperts * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + pipe.InitBuffer(recvCountBuf, sendCountAlignLen); + pipe.InitBuffer(tmpBuf_, Ceil(batchRounds * numLocalExperts * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE); LocalTensor tmpTensor = tmpBuf_.Get(); - for (uint32_t r = 0; r < round; r++) { - for (uint32_t expId = 0; expId < numLocalExperts; ++expId) { - int32_t localRecvCount = 0; - for (uint32_t srcRank = 0; srcRank < rankSize; ++srcRank) { - uint32_t index = r * numExperts + expId * rankSize + srcRank; - localRecvCount += sendCountTensor(index); + GlobalTensor recvTokenPerExpGt; + 
recvTokenPerExpGt.SetGlobalBuffer((__gm__ int32_t *)recvTokensPerExpert_); + for (uint32_t rStart = 0; rStart < round; rStart += batchRounds) { + uint32_t currentBatchRounds = (rStart + batchRounds > round) ? (round - rStart) : batchRounds; + SyncFunc(); + Duplicate(tmpTensor, 0, batchRounds * numLocalExperts); + + ReorderOutput(rStart, currentBatchRounds); + ReorderSendCountOutput(currentBatchRounds); + + for (uint32_t r = 0; r < currentBatchRounds; r++) { + for (uint32_t expId = 0; expId < numLocalExperts; ++expId) { + int32_t localRecvCount = 0; + for (uint32_t srcRank = 0; srcRank < rankSize; ++srcRank) { + uint32_t index = r * numExperts + expId * rankSize + srcRank; + localRecvCount += recvCountTensor(index); + } + tmpTensor(r * numLocalExperts + expId) = localRecvCount; } - tmpTensor(r * numLocalExperts + expId) = localRecvCount; } + SyncFunc(); + DataCopyExtParams copyParams{ + 1, static_cast(currentBatchRounds * numLocalExperts * sizeof(int32_t)), 0, 0, 0}; + SyncFunc(); + SyncFunc(); + DataCopyPad(recvTokenPerExpGt[rStart * numLocalExperts], tmpTensor, copyParams); + + SyncFunc(); } - GlobalTensor recvTokenPerExpGt; - recvTokenPerExpGt.SetGlobalBuffer((__gm__ int32_t *)recvTokensPerExpert_); - DataCopyExtParams copyParams{1, static_cast(round * numLocalExperts * sizeof(int32_t)), 0, 0, 0}; - SyncFunc(); - DataCopyPad(recvTokenPerExpGt, tmpTensor, copyParams); } __aicore__ inline void BuildExpGlobalOffset() { - // 只需要sendCountTensor + // 只需要recvCountTensor if (blockIdx != EXP_GLOBAL_OFFSET_CORE) { return; } - ReorderSendCountOutput(); + + recvDataAlignLen = + Ceil(batchRounds * numExperts * sendPerGroup * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + pipe.InitBuffer(recvDataBuf, recvDataAlignLen); + sendCountAlignLen = Ceil(batchRounds * numExperts * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + pipe.InitBuffer(recvCountBuf, sendCountAlignLen); + + // tmpBuf_,需要常驻,消耗:16 *4 pipe.InitBuffer(tmpBuf_, Ceil(numLocalExperts * sizeof(int32_t), 
UB_ALIGN_SIZE) * UB_ALIGN_SIZE); - pipe.InitBuffer(tmpBuf2_, Ceil(numLocalExperts * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE); LocalTensor tmpTensor = tmpBuf_.Get(); - LocalTensor expTensor = tmpBuf2_.Get(); Duplicate(tmpTensor, 0, numLocalExperts); - expTensor(0) = 0; + SyncFunc(); - int32_t localExpTotal = 0; - for (uint32_t r = 0; r < round; r++) { - for (uint32_t expId = 0; expId < numLocalExperts; ++expId) { - int32_t localRecvCount = 0; - for (uint32_t srcRank = 0; srcRank < rankSize; ++srcRank) { - uint32_t index = r * numExperts + expId * rankSize + srcRank; - localRecvCount += sendCountTensor(index); + for (uint32_t rStart = 0; rStart < round; rStart += batchRounds) { + uint32_t currentBatchRounds = (rStart + batchRounds > round) ? (round - rStart) : batchRounds; + ReorderOutput(rStart, currentBatchRounds); + ReorderSendCountOutput(currentBatchRounds); + for (uint32_t r = 0; r < currentBatchRounds; r++) { + for (uint32_t expId = 0; expId < numLocalExperts; ++expId) { + int32_t localRecvCount = 0; + for (uint32_t srcRank = 0; srcRank < rankSize; ++srcRank) { + uint32_t index = r * numExperts + expId * rankSize + srcRank; + localRecvCount += recvCountTensor(index); + } + tmpTensor(expId) += localRecvCount; } - tmpTensor(expId) += localRecvCount; } + SyncFunc(); // waiting for recvCountTensor } + pipe.InitBuffer(tmpBuf2_, Ceil(numLocalExperts * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE); + LocalTensor expTensor = tmpBuf2_.Get(); + expTensor(0) = 0; for (uint32_t expId = 1; expId < numLocalExperts; ++expId) { expTensor(expId) = expTensor(expId - 1) + tmpTensor(expId - 1); } @@ -561,24 +676,37 @@ class NotifyDispatch if (blockIdx != SRC_RANK_EXP_OFFSET_CORE) { return; } - ReorderSendCountOutput(); + recvDataAlignLen = + Ceil(batchRounds * numExperts * sendPerGroup * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + pipe.InitBuffer(recvDataBuf, recvDataAlignLen); + sendCountAlignLen = Ceil(batchRounds * numExperts * sizeof(int32_t), UB_ALIGN_SIZE) 
* UB_ALIGN_SIZE; // 32Kb + pipe.InitBuffer(recvCountBuf, sendCountAlignLen); + pipe.InitBuffer(tmpBuf_, Ceil(numRanks * numLocalExperts * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE); - pipe.InitBuffer(tmpBuf2_, Ceil(numRanks * numLocalExperts * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE); LocalTensor expSrcTotalTensor = tmpBuf_.Get(); - LocalTensor srcRankInExpOffsetTensor = tmpBuf2_.Get(); Duplicate(expSrcTotalTensor, 0, numExperts); SyncFunc(); - int32_t localExpTotal = 0; - for (uint32_t r = 0; r < round; r++) { - for (uint32_t expId = 0; expId < numLocalExperts; ++expId) { - int32_t localRecvCount = 0; - for (uint32_t srcRank = 0; srcRank < rankSize; ++srcRank) { - uint32_t index = r * numExperts + expId * rankSize + srcRank; - localRecvCount = sendCountTensor(index); - expSrcTotalTensor(expId * numRanks + srcRank) += localRecvCount; + + for (uint32_t rStart = 0; rStart < round; rStart += batchRounds) { + uint32_t currentBatchRounds = (rStart + batchRounds > round) ? (round - rStart) : batchRounds; + + ReorderOutput(rStart, currentBatchRounds); + ReorderSendCountOutput(currentBatchRounds); + SyncFunc(); + for (uint32_t r = 0; r < currentBatchRounds; r++) { + for (uint32_t expId = 0; expId < numLocalExperts; ++expId) { + int32_t localRecvCount = 0; + for (uint32_t srcRank = 0; srcRank < rankSize; ++srcRank) { + uint32_t index = r * numExperts + expId * rankSize + srcRank; + localRecvCount = recvCountTensor(index); + expSrcTotalTensor(expId * numRanks + srcRank) += localRecvCount; + } } } } + + pipe.InitBuffer(tmpBuf2_, Ceil(numRanks * numLocalExperts * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE); + LocalTensor srcRankInExpOffsetTensor = tmpBuf2_.Get(); for (uint32_t expId = 0; expId < numLocalExperts; ++expId) { int32_t cumOffset = 0; for (uint32_t srcRank = 0; srcRank < rankSize; ++srcRank) { @@ -598,32 +726,53 @@ class NotifyDispatch if (blockIdx != R_IN_SRCRANK_OFFSET_CORE) { return; } - ReorderSendCountOutput(); - pipe.InitBuffer(tmpBuf_, 
Ceil(round * numExperts * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE); + recvDataAlignLen = + Ceil(batchRounds * numExperts * sendPerGroup * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + pipe.InitBuffer(recvDataBuf, recvDataAlignLen); + sendCountAlignLen = Ceil(batchRounds * numExperts * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; // 32Kb + pipe.InitBuffer(recvCountBuf, sendCountAlignLen); + pipe.InitBuffer(tmpBuf2_, Ceil(numRanks * numLocalExperts * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE); - LocalTensor rInSrcrankOffsetTensor = tmpBuf_.Get(); LocalTensor expSrcCumPrevTensor = tmpBuf2_.Get(); Duplicate(expSrcCumPrevTensor, 0, numExperts); - SyncFunc(); - for (uint32_t r = 0; r < round; r++) { + + pipe.InitBuffer(tmpBuf_, Ceil(batchRounds * numExperts * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE); + GlobalTensor rInSrcrankOffsetGt; + rInSrcrankOffsetGt.SetGlobalBuffer((__gm__ int32_t *)rInSrcrankOffset_); + for (uint32_t rStart = 0; rStart < round; rStart += batchRounds) { + uint32_t currentBatchRounds = (rStart + batchRounds > round) ? 
(round - rStart) : batchRounds; + + ReorderOutput(rStart, currentBatchRounds); + ReorderSendCountOutput(currentBatchRounds); + LocalTensor rInSrcrankOffsetTensor = tmpBuf_.Get(); + + DataCopyExtParams copyParams{1, static_cast(currentBatchRounds * sizeof(int32_t)), 0, 0, 0}; + SyncFunc(); + for (uint32_t expId = 0; expId < numLocalExperts; ++expId) { - int32_t localRecvCount = 0; for (uint32_t srcRank = 0; srcRank < rankSize; ++srcRank) { - uint32_t pairIdx = r * numExperts + expId * rankSize + srcRank; uint32_t index = expId * rankSize + srcRank; - uint32_t cIdx = expId * numRanks * round + srcRank * round + r; - int32_t recvCnt = sendCountTensor(pairIdx); - int32_t offset = expSrcCumPrevTensor(index); - rInSrcrankOffsetTensor(cIdx) = offset; - expSrcCumPrevTensor(index) = offset + recvCnt; + uint32_t ubBlockOffset = (expId * rankSize + srcRank) * currentBatchRounds; + uint32_t ubBlockOffsetAlign = Ceil(ubBlockOffset * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + uint32_t ubBlockAlignIndex = ubBlockOffsetAlign / sizeof(int32_t); + uint32_t gmOffset = expId * numRanks * round + srcRank * round + rStart; + + Duplicate(rInSrcrankOffsetTensor, 0, currentBatchRounds * numExperts); + SyncFunc(); + for (uint32_t r = 0; r < currentBatchRounds; r++) { + uint32_t pairIdx = r * numExperts + index; + int32_t recvCnt = recvCountTensor(pairIdx); + int32_t offset = expSrcCumPrevTensor(index); + rInSrcrankOffsetTensor(ubBlockAlignIndex + r) = offset; + expSrcCumPrevTensor(index) = offset + recvCnt; + } + uint32_t copyLenByte = currentBatchRounds * sizeof(int32_t); + DataCopyPad(rInSrcrankOffsetGt[gmOffset], rInSrcrankOffsetTensor[ubBlockAlignIndex], copyParams); + SyncFunc(); } } + SyncFunc(); } - GlobalTensor rInSrcrankOffsetGt; - rInSrcrankOffsetGt.SetGlobalBuffer((__gm__ int32_t *)rInSrcrankOffset_); - DataCopyExtParams copyParams{1, static_cast(round * numExperts * sizeof(int32_t)), 0, 0, 0}; - SyncFunc(); - DataCopyPad(rInSrcrankOffsetGt, rInSrcrankOffsetTensor, 
copyParams); } __aicore__ inline int64_t GetDataCount(const int64_t dataLen, const int64_t useBlockNum); @@ -690,6 +839,7 @@ class NotifyDispatch int32_t blockIdx; // Index of the current aicore int32_t blockNum; // Total number of aicores for the current rank uint32_t maxBsNum{0}; + int batchRounds{32}; GM_ADDR scale; GM_ADDR shareAddrs[CAM_MAX_RANK_SIZE]; // List of shared memory addresses GM_ADDR totalRecvTokens_; @@ -706,7 +856,7 @@ class NotifyDispatch TBuf tBuf; TBuf<> tokenPerExpertDataBuf; TBuf<> sendDataOffsetBuf; - TBuf<> sendCountBuf; + TBuf<> recvCountBuf; TBuf<> sendOffsetBuf; TBuf<> sendDataBuf; TBuf<> newSendDataBuf; @@ -718,7 +868,7 @@ class NotifyDispatch LocalTensor sendDataTensor; LocalTensor sendDataOffsetTensor; LocalTensor newSendDataTensor; - LocalTensor sendCountTensor; + LocalTensor recvCountTensor; LocalTensor sendOffsetTensor; LocalTensor sendTokensPerRankTensor; LocalTensor recvDataTensor;