diff --git a/csrc/deepep/ops/op_host/cam_moe_combine_normal_tiling.cc b/csrc/deepep/ops/op_host/cam_moe_combine_normal_tiling.cc index 2ccade07d..0fbac1977 100644 --- a/csrc/deepep/ops/op_host/cam_moe_combine_normal_tiling.cc +++ b/csrc/deepep/ops/op_host/cam_moe_combine_normal_tiling.cc @@ -398,7 +398,7 @@ static bool CheckTensorShape(gert::TilingContext *context, CamMoeCombineNormalTi int64_t xDim0 = xStorageShape->GetStorageShape().GetDim(0); int64_t xDim1 = xStorageShape->GetStorageShape().GetDim(1); OP_TILING_CHECK(xDim0 != topkWeightsDim0, - OP_LOGE(nodeName, "x's dim0 is greater than bs, bs = %ld, x's dim0 = %ld", topkWeightsDim0, xDim0), + OP_LOGE(nodeName, "x's dim0 not equal to bs, bs = %ld, x's dim0 = %ld", topkWeightsDim0, xDim0), return false); OP_TILING_CHECK(xDim1 != recvXDim1, OP_LOGE(nodeName, "x's dim1 not equal to h, x's dim1 = %ld, h = %ld", xDim1, recvXDim1), diff --git a/csrc/deepep/ops/op_kernel/notify_dispatch.h b/csrc/deepep/ops/op_kernel/notify_dispatch.h index a4f75a218..6d5b34312 100644 --- a/csrc/deepep/ops/op_kernel/notify_dispatch.h +++ b/csrc/deepep/ops/op_kernel/notify_dispatch.h @@ -51,7 +51,8 @@ class NotifyDispatch // Synchronization flag occupies length constexpr static int64_t FLAG_UNIT_INT_NUM = 4; constexpr static int64_t MAGIC_MASK = ~((1LL << 32) - 1); - constexpr static int32_t BATCH_ROUND = 32; + constexpr static int32_t EXPERT_NORMAL_NUM = 256; + constexpr static int32_t BATCH_ROUND = 16; public: __aicore__ inline NotifyDispatch(int rank, int rankSize, uint32_t extraFlag) @@ -71,7 +72,7 @@ class NotifyDispatch recvOffset_ = recvOffset; maxBs_ = maxBs; recvTokensPerExpert_ = recvTokensPerExpert; - batchRounds = BATCH_ROUND; + batchRounds = numExperts > EXPERT_NORMAL_NUM ? BATCH_ROUND : BATCH_ROUND * 2; tokenPerExpertDataAlignLen = Ceil(batchRounds * numExperts * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; sendDataOffsetAlignLen = Ceil(batchRounds * numExperts * sizeof(T), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; sendDataAlignLen = Ceil(batchRounds * numExperts * sendPerGroup * sizeof(T), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; @@ -339,12 +340,14 @@ class NotifyDispatch uint32_t singleRankTotalElemCount = round * numLocalExperts * sendPerGroup; uint32_t singleRankBatchElemCount = currentBatchRounds * numLocalExperts * sendPerGroup; uint32_t singleRankBatchDataLen = singleRankBatchElemCount * sizeof(int32_t); + uint32_t alignedDataLen = Ceil(singleRankBatchDataLen, UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + uint32_t strideElem = alignedDataLen / sizeof(int32_t); // 目标地址也改变,使用对齐后的地址 DataCopyExtParams recvDataParams = {1U, static_cast(singleRankBatchDataLen), 0, 0, 0}; DataCopyPadExtParams DataCopyPadExtParams{false, 0U, 0U, 0U}; for (uint32_t i = 0; i < rankSize; i++) { uint32_t srcOffset = i * singleRankTotalElemCount + rStart * numLocalExperts * sendPerGroup; - uint32_t dstOffset = i * singleRankBatchElemCount; + uint32_t dstOffset = i * strideElem; // 搬运该Rank下的 currentBatchRounds 数据 DataCopyPad(recvDataTensor[dstOffset], recvDataOutputGt[srcOffset], recvDataParams, DataCopyPadExtParams); } @@ -357,6 +360,9 @@ class NotifyDispatch Duplicate(recvCountTensor, 0, sendCountAlignLen / sizeof(int32_t)); // V SyncFunc(); + uint32_t singleRankBatchDataLen = currentBatchRounds * numLocalExperts * sendPerGroup * sizeof(int32_t); + uint32_t alignedDataLen = Ceil(singleRankBatchDataLen, UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + uint32_t strideElem = alignedDataLen / sizeof(int32_t); uint32_t computeNum = currentBatchRounds * numLocalExperts; for (uint32_t r = 0; r < currentBatchRounds; ++r) { uint32_t computeNumIn = r * numLocalExperts; @@ -364,7 +370,8 @@ class NotifyDispatch for (uint32_t expId = 0; expId < numLocalExperts; ++expId) { for (uint32_t srcRank = 0; srcRank < rankSize; ++srcRank) { uint32_t index = expId * rankSize + srcRank; - uint32_t pair_idx = sendPerGroup * (srcRank * computeNum + computeNumIn + expId); + uint32_t offsetInRank = sendPerGroup * (computeNumIn + expId); + uint32_t pair_idx = srcRank * strideElem + offsetInRank; recvCountTensor(computeNumOut + index) = recvDataTensor(pair_idx); } } @@ -376,6 +383,9 @@ class NotifyDispatch sendOffsetTensor = sendOffsetBuf.Get(); Duplicate(sendOffsetTensor, 0, sendCountAlignLen / sizeof(int32_t)); SyncFunc(); + uint32_t singleRankBatchDataLen = currentBatchRounds * numLocalExperts * sendPerGroup * sizeof(int32_t); + uint32_t alignedDataLen = Ceil(singleRankBatchDataLen, UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + uint32_t strideElem = alignedDataLen / sizeof(int32_t); uint32_t computeNum = currentBatchRounds * numLocalExperts; for (uint32_t r = 0; r < currentBatchRounds; ++r) { uint32_t computeNumIn = r * numLocalExperts; @@ -383,49 +393,24 @@ class NotifyDispatch for (uint32_t expId = 0; expId < numLocalExperts; ++expId) { for (uint32_t srcRank = 0; srcRank < rankSize; ++srcRank) { uint32_t index = expId * rankSize + srcRank; - uint32_t pair_idx = sendPerGroup * (srcRank * computeNum + computeNumIn + expId); + uint32_t offsetInRank = sendPerGroup * (computeNumIn + expId); + uint32_t pair_idx = srcRank * strideElem + offsetInRank; sendOffsetTensor(computeNumOut + index) = recvDataTensor(pair_idx + 1); } } } } - __aicore__ inline void ReorderSendTokensPerRankOutput() - { - pipe.InitBuffer(sendTokensPerRankBuf, sendTokensPerRankAlignLen); - pipe.InitBuffer(seenRoundBuf, sendTokensPerRankAlignLen); - sendTokensPerRankTensor = sendTokensPerRankBuf.Get(); - seenRoundTensor = seenRoundBuf.Get(); - Duplicate(sendTokensPerRankTensor, 0, sendTokensPerRankAlignLen / sizeof(int32_t)); - SyncFunc(); - SyncFunc(); - for (uint32_t r = 0; r < round; ++r) { - Duplicate(seenRoundTensor, 0, sendTokensPerRankAlignLen / sizeof(int32_t)); - SyncFunc(); - for (uint32_t expId = 0; expId < numLocalExperts; ++expId) { - for (uint32_t srcRank = 0; srcRank < rankSize; ++srcRank) { - uint32_t index = expId * rankSize + srcRank; - uint32_t pair_idx = - sendPerGroup * (srcRank * numLocalExperts * round + r * numLocalExperts + expId); - if (!seenRoundTensor(srcRank)) { - sendTokensPerRankTensor(srcRank) += recvDataTensor(pair_idx + 2); - seenRoundTensor(srcRank) = 1; - } - } - } - SyncFunc(); - } - } - __aicore__ inline void BuildTotalRecvTokens() { if (blockIdx != TOTAL_CNT_CORE) { return; } int32_t sumVal = 0; - - recvDataAlignLen = - Ceil(batchRounds * numExperts * sendPerGroup * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + uint32_t singleRankMaxElem = batchRounds * numLocalExperts * sendPerGroup; + uint32_t singleRankMaxLen = singleRankMaxElem * sizeof(int32_t); + uint32_t singleRankAlignLen = Ceil(singleRankMaxLen, UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + recvDataAlignLen = rankSize * singleRankAlignLen; pipe.InitBuffer(recvDataBuf, recvDataAlignLen); sendCountAlignLen = Ceil(batchRounds * numExperts * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; pipe.InitBuffer(recvCountBuf, sendCountAlignLen); @@ -467,8 +452,10 @@ class NotifyDispatch if (blockIdx != RECV_COUNT_CORE) { return; } - recvDataAlignLen = - Ceil(batchRounds * numExperts * sendPerGroup * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + uint32_t singleRankMaxElem = batchRounds * numLocalExperts * sendPerGroup; + uint32_t singleRankMaxLen = singleRankMaxElem * sizeof(int32_t); + uint32_t singleRankAlignLen = Ceil(singleRankMaxLen, UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + recvDataAlignLen = rankSize * singleRankAlignLen; pipe.InitBuffer(recvDataBuf, recvDataAlignLen); sendCountAlignLen = Ceil(batchRounds * numExperts * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; pipe.InitBuffer(recvCountBuf, sendCountAlignLen); @@ -505,8 +492,10 @@ class NotifyDispatch if (blockIdx != RECV_OFFSET_CORE) { return; } - recvDataAlignLen = - Ceil(batchRounds * numExperts * sendPerGroup * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + uint32_t singleRankMaxElem = batchRounds * numLocalExperts * sendPerGroup; + uint32_t singleRankMaxLen = singleRankMaxElem * sizeof(int32_t); + uint32_t singleRankAlignLen = Ceil(singleRankMaxLen, UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + recvDataAlignLen = rankSize * singleRankAlignLen; pipe.InitBuffer(recvDataBuf, recvDataAlignLen); sendCountAlignLen = Ceil(batchRounds * numExperts * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; pipe.InitBuffer(sendOffsetBuf, sendCountAlignLen); @@ -535,8 +524,10 @@ class NotifyDispatch if (blockIdx != MAX_BS_CORE) { return; } - recvDataAlignLen = - Ceil(batchRounds * numExperts * sendPerGroup * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + uint32_t singleRankMaxElem = batchRounds * numLocalExperts * sendPerGroup; + uint32_t singleRankMaxLen = singleRankMaxElem * sizeof(int32_t); + uint32_t singleRankAlignLen = Ceil(singleRankMaxLen, UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + recvDataAlignLen = rankSize * singleRankAlignLen; pipe.InitBuffer(recvDataBuf, recvDataAlignLen); pipe.InitBuffer(sendTokensPerRankBuf, sendTokensPerRankAlignLen); @@ -549,16 +540,19 @@ class NotifyDispatch SyncFunc(); for (uint32_t rStart = 0; rStart < round; rStart += batchRounds) { uint32_t currentBatchRounds = (rStart + batchRounds > round) ? (round - rStart) : batchRounds; - + uint32_t singleRankBatchDataLen = currentBatchRounds * numLocalExperts * sendPerGroup * sizeof(int32_t); + uint32_t alignedDataLen = Ceil(singleRankBatchDataLen, UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + uint32_t strideElem = alignedDataLen / sizeof(int32_t); ReorderOutput(rStart, currentBatchRounds); SyncFunc(); for (uint32_t r = 0; r < currentBatchRounds; ++r) { + uint32_t offsetInRound = r * numLocalExperts; Duplicate(seenRoundTensor, 0, sendTokensPerRankAlignLen / sizeof(int32_t)); SyncFunc(); for (uint32_t expId = 0; expId < numLocalExperts; ++expId) { for (uint32_t srcRank = 0; srcRank < rankSize; ++srcRank) { - uint32_t pair_idx = sendPerGroup * (srcRank * numLocalExperts * currentBatchRounds + - r * numLocalExperts + expId); + uint32_t offsetInRank = sendPerGroup * (offsetInRound + expId); + uint32_t pair_idx = srcRank * strideElem + offsetInRank; if (!seenRoundTensor(srcRank)) { sendTokensPerRankTensor(srcRank) += recvDataTensor(pair_idx + 2); seenRoundTensor(srcRank) = 1; @@ -585,8 +579,10 @@ class NotifyDispatch if (blockIdx != RECV_TOKEN_PER_EXP_CORE) { return; } - recvDataAlignLen = - Ceil(batchRounds * numExperts * sendPerGroup * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + uint32_t singleRankMaxElem = batchRounds * numLocalExperts * sendPerGroup; + uint32_t singleRankMaxLen = singleRankMaxElem * sizeof(int32_t); + uint32_t singleRankAlignLen = Ceil(singleRankMaxLen, UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + recvDataAlignLen = rankSize * singleRankAlignLen; pipe.InitBuffer(recvDataBuf, recvDataAlignLen); sendCountAlignLen = Ceil(batchRounds * numExperts * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; pipe.InitBuffer(recvCountBuf, sendCountAlignLen); @@ -630,8 +626,10 @@ class NotifyDispatch return; } - recvDataAlignLen = - Ceil(batchRounds * numExperts * sendPerGroup * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + uint32_t singleRankMaxElem = batchRounds * numLocalExperts * sendPerGroup; + uint32_t singleRankMaxLen = singleRankMaxElem * sizeof(int32_t); + uint32_t singleRankAlignLen = Ceil(singleRankMaxLen, UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + recvDataAlignLen = rankSize * singleRankAlignLen; pipe.InitBuffer(recvDataBuf, recvDataAlignLen); sendCountAlignLen = Ceil(batchRounds * numExperts * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; pipe.InitBuffer(recvCountBuf, sendCountAlignLen); @@ -676,8 +674,10 @@ class NotifyDispatch if (blockIdx != SRC_RANK_EXP_OFFSET_CORE) { return; } - recvDataAlignLen = - Ceil(batchRounds * numExperts * sendPerGroup * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + uint32_t singleRankMaxElem = batchRounds * numLocalExperts * sendPerGroup; + uint32_t singleRankMaxLen = singleRankMaxElem * sizeof(int32_t); + uint32_t singleRankAlignLen = Ceil(singleRankMaxLen, UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + recvDataAlignLen = rankSize * singleRankAlignLen; pipe.InitBuffer(recvDataBuf, recvDataAlignLen); sendCountAlignLen = Ceil(batchRounds * numExperts * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; // 32Kb pipe.InitBuffer(recvCountBuf, sendCountAlignLen); @@ -726,8 +726,10 @@ class NotifyDispatch if (blockIdx != R_IN_SRCRANK_OFFSET_CORE) { return; } - recvDataAlignLen = - Ceil(batchRounds * numExperts * sendPerGroup * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + uint32_t singleRankMaxElem = batchRounds * numLocalExperts * sendPerGroup; + uint32_t singleRankMaxLen = singleRankMaxElem * sizeof(int32_t); + uint32_t singleRankAlignLen = Ceil(singleRankMaxLen, UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + recvDataAlignLen = rankSize * singleRankAlignLen; pipe.InitBuffer(recvDataBuf, recvDataAlignLen); sendCountAlignLen = Ceil(batchRounds * numExperts * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; // 32Kb pipe.InitBuffer(recvCountBuf, sendCountAlignLen);