Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion csrc/deepep/ops/op_host/cam_moe_combine_normal_tiling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,7 @@ static bool CheckTensorShape(gert::TilingContext *context, CamMoeCombineNormalTi
int64_t xDim0 = xStorageShape->GetStorageShape().GetDim(0);
int64_t xDim1 = xStorageShape->GetStorageShape().GetDim(1);
OP_TILING_CHECK(xDim0 != topkWeightsDim0,
OP_LOGE(nodeName, "x's dim0 is greater than bs, bs = %ld, x's dim0 = %ld", topkWeightsDim0, xDim0),
OP_LOGE(nodeName, "x's dim0 not equal to bs, bs = %ld, x's dim0 = %ld", topkWeightsDim0, xDim0),
return false);
OP_TILING_CHECK(xDim1 != recvXDim1,
OP_LOGE(nodeName, "x's dim1 not equal to h, x's dim1 = %ld, h = %ld", xDim1, recvXDim1),
Expand Down
106 changes: 54 additions & 52 deletions csrc/deepep/ops/op_kernel/notify_dispatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ class NotifyDispatch
// Synchronization flag occupies length
constexpr static int64_t FLAG_UNIT_INT_NUM = 4;
constexpr static int64_t MAGIC_MASK = ~((1LL << 32) - 1);
constexpr static int32_t BATCH_ROUND = 32;
constexpr static int32_t EXPERT_NORMAL_NUM = 256;
constexpr static int32_t BATCH_ROUND = 16;

public:
__aicore__ inline NotifyDispatch(int rank, int rankSize, uint32_t extraFlag)
Expand All @@ -71,7 +72,7 @@ class NotifyDispatch
recvOffset_ = recvOffset;
maxBs_ = maxBs;
recvTokensPerExpert_ = recvTokensPerExpert;
batchRounds = BATCH_ROUND;
batchRounds = numExperts > EXPERT_NORMAL_NUM ? BATCH_ROUND : BATCH_ROUND * 2;
tokenPerExpertDataAlignLen = Ceil(batchRounds * numExperts * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
sendDataOffsetAlignLen = Ceil(batchRounds * numExperts * sizeof(T), UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
sendDataAlignLen = Ceil(batchRounds * numExperts * sendPerGroup * sizeof(T), UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
Expand Down Expand Up @@ -339,12 +340,14 @@ class NotifyDispatch
uint32_t singleRankTotalElemCount = round * numLocalExperts * sendPerGroup;
uint32_t singleRankBatchElemCount = currentBatchRounds * numLocalExperts * sendPerGroup;
uint32_t singleRankBatchDataLen = singleRankBatchElemCount * sizeof(int32_t);
uint32_t alignedDataLen = Ceil(singleRankBatchDataLen, UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
uint32_t strideElem = alignedDataLen / sizeof(int32_t); // 目标地址也改变,使用对齐后的地址
DataCopyExtParams recvDataParams = {1U, static_cast<uint32_t>(singleRankBatchDataLen), 0, 0, 0};
DataCopyPadExtParams<T> DataCopyPadExtParams{false, 0U, 0U, 0U};

for (uint32_t i = 0; i < rankSize; i++) {
uint32_t srcOffset = i * singleRankTotalElemCount + rStart * numLocalExperts * sendPerGroup;
uint32_t dstOffset = i * singleRankBatchElemCount;
uint32_t dstOffset = i * strideElem;
// Copy this rank's data for the currentBatchRounds rounds of the current batch
DataCopyPad(recvDataTensor[dstOffset], recvDataOutputGt[srcOffset], recvDataParams, DataCopyPadExtParams);
}
Expand All @@ -357,14 +360,18 @@ class NotifyDispatch
Duplicate<T>(recvCountTensor, 0, sendCountAlignLen / sizeof(int32_t)); // V

SyncFunc<AscendC::HardEvent::V_S>();
uint32_t singleRankBatchDataLen = currentBatchRounds * numLocalExperts * sendPerGroup * sizeof(int32_t);
uint32_t alignedDataLen = Ceil(singleRankBatchDataLen, UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
uint32_t strideElem = alignedDataLen / sizeof(int32_t);
uint32_t computeNum = currentBatchRounds * numLocalExperts;
for (uint32_t r = 0; r < currentBatchRounds; ++r) {
uint32_t computeNumIn = r * numLocalExperts;
uint32_t computeNumOut = r * numExperts;
for (uint32_t expId = 0; expId < numLocalExperts; ++expId) {
for (uint32_t srcRank = 0; srcRank < rankSize; ++srcRank) {
uint32_t index = expId * rankSize + srcRank;
uint32_t pair_idx = sendPerGroup * (srcRank * computeNum + computeNumIn + expId);
uint32_t offsetInRank = sendPerGroup * (computeNumIn + expId);
uint32_t pair_idx = srcRank * strideElem + offsetInRank;
recvCountTensor(computeNumOut + index) = recvDataTensor(pair_idx);
}
}
Expand All @@ -376,56 +383,34 @@ class NotifyDispatch
sendOffsetTensor = sendOffsetBuf.Get<T>();
Duplicate<T>(sendOffsetTensor, 0, sendCountAlignLen / sizeof(int32_t));
SyncFunc<AscendC::HardEvent::V_S>();
uint32_t singleRankBatchDataLen = currentBatchRounds * numLocalExperts * sendPerGroup * sizeof(int32_t);
uint32_t alignedDataLen = Ceil(singleRankBatchDataLen, UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
uint32_t strideElem = alignedDataLen / sizeof(int32_t);
uint32_t computeNum = currentBatchRounds * numLocalExperts;
for (uint32_t r = 0; r < currentBatchRounds; ++r) {
uint32_t computeNumIn = r * numLocalExperts;
uint32_t computeNumOut = r * numExperts;
for (uint32_t expId = 0; expId < numLocalExperts; ++expId) {
for (uint32_t srcRank = 0; srcRank < rankSize; ++srcRank) {
uint32_t index = expId * rankSize + srcRank;
uint32_t pair_idx = sendPerGroup * (srcRank * computeNum + computeNumIn + expId);
uint32_t offsetInRank = sendPerGroup * (computeNumIn + expId);
uint32_t pair_idx = srcRank * strideElem + offsetInRank;
sendOffsetTensor(computeNumOut + index) = recvDataTensor(pair_idx + 1);
}
}
}
}

/**
 * Accumulates one value per source rank into sendTokensPerRankTensor, reading from
 * recvDataTensor (which this function does not fill) over all `round` rounds.
 *
 * Within each round, seenRoundTensor is cleared first, so only the first record
 * visited for a given source rank contributes; later experts of the same rank in
 * that round are skipped. The contributed value is the element at offset 2 inside
 * the sendPerGroup-wide record.
 * NOTE(review): offset 2 presumably carries the per-rank token total repeated in
 * every expert record -- confirm against the producer of recvDataTensor.
 */
__aicore__ inline void ReorderSendTokensPerRankOutput()
{
    pipe.InitBuffer(sendTokensPerRankBuf, sendTokensPerRankAlignLen);
    pipe.InitBuffer(seenRoundBuf, sendTokensPerRankAlignLen);
    sendTokensPerRankTensor = sendTokensPerRankBuf.Get<int32_t>();
    seenRoundTensor = seenRoundBuf.Get<int32_t>();

    // Both buffers share the same byte length, so one element count serves both fills.
    const auto perRankElemCount = sendTokensPerRankAlignLen / sizeof(int32_t);

    Duplicate<int32_t>(sendTokensPerRankTensor, 0, perRankElemCount);
    // NOTE(review): presumably orders the vector fill and the inbound copy into
    // recvDataTensor before the scalar accesses below -- confirm.
    SyncFunc<AscendC::HardEvent::V_S>();
    SyncFunc<AscendC::HardEvent::MTE2_S>();

    for (uint32_t r = 0; r < round; ++r) {
        // Reset the per-round "already counted" markers.
        Duplicate<int32_t>(seenRoundTensor, 0, perRankElemCount);
        SyncFunc<AscendC::HardEvent::V_S>();
        for (uint32_t expId = 0; expId < numLocalExperts; ++expId) {
            for (uint32_t srcRank = 0; srcRank < rankSize; ++srcRank) {
                // Record start for (srcRank, r, expId): each rank owns
                // numLocalExperts * round records of sendPerGroup elements.
                uint32_t pair_idx =
                    sendPerGroup * (srcRank * numLocalExperts * round + r * numLocalExperts + expId);
                if (!seenRoundTensor(srcRank)) {
                    sendTokensPerRankTensor(srcRank) += recvDataTensor(pair_idx + 2);
                    seenRoundTensor(srcRank) = 1;
                }
            }
        }
        // Scalar writes must land before the next round's vector fill.
        SyncFunc<AscendC::HardEvent::S_V>();
    }
}

__aicore__ inline void BuildTotalRecvTokens()
{
if (blockIdx != TOTAL_CNT_CORE) {
return;
}
int32_t sumVal = 0;

recvDataAlignLen =
Ceil(batchRounds * numExperts * sendPerGroup * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
uint32_t singleRankMaxElem = batchRounds * numLocalExperts * sendPerGroup;
uint32_t singleRankMaxLen = singleRankMaxElem * sizeof(int32_t);
uint32_t singleRankAlignLen = Ceil(singleRankMaxLen, UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
recvDataAlignLen = rankSize * singleRankAlignLen;
pipe.InitBuffer(recvDataBuf, recvDataAlignLen);
sendCountAlignLen = Ceil(batchRounds * numExperts * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
pipe.InitBuffer(recvCountBuf, sendCountAlignLen);
Expand Down Expand Up @@ -467,8 +452,10 @@ class NotifyDispatch
if (blockIdx != RECV_COUNT_CORE) {
return;
}
recvDataAlignLen =
Ceil(batchRounds * numExperts * sendPerGroup * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
uint32_t singleRankMaxElem = batchRounds * numLocalExperts * sendPerGroup;
uint32_t singleRankMaxLen = singleRankMaxElem * sizeof(int32_t);
uint32_t singleRankAlignLen = Ceil(singleRankMaxLen, UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
recvDataAlignLen = rankSize * singleRankAlignLen;
pipe.InitBuffer(recvDataBuf, recvDataAlignLen);
sendCountAlignLen = Ceil(batchRounds * numExperts * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
pipe.InitBuffer(recvCountBuf, sendCountAlignLen);
Expand Down Expand Up @@ -505,8 +492,10 @@ class NotifyDispatch
if (blockIdx != RECV_OFFSET_CORE) {
return;
}
recvDataAlignLen =
Ceil(batchRounds * numExperts * sendPerGroup * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
uint32_t singleRankMaxElem = batchRounds * numLocalExperts * sendPerGroup;
uint32_t singleRankMaxLen = singleRankMaxElem * sizeof(int32_t);
uint32_t singleRankAlignLen = Ceil(singleRankMaxLen, UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
recvDataAlignLen = rankSize * singleRankAlignLen;
pipe.InitBuffer(recvDataBuf, recvDataAlignLen);
sendCountAlignLen = Ceil(batchRounds * numExperts * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
pipe.InitBuffer(sendOffsetBuf, sendCountAlignLen);
Expand Down Expand Up @@ -535,8 +524,10 @@ class NotifyDispatch
if (blockIdx != MAX_BS_CORE) {
return;
}
recvDataAlignLen =
Ceil(batchRounds * numExperts * sendPerGroup * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
uint32_t singleRankMaxElem = batchRounds * numLocalExperts * sendPerGroup;
uint32_t singleRankMaxLen = singleRankMaxElem * sizeof(int32_t);
uint32_t singleRankAlignLen = Ceil(singleRankMaxLen, UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
recvDataAlignLen = rankSize * singleRankAlignLen;
pipe.InitBuffer(recvDataBuf, recvDataAlignLen);

pipe.InitBuffer(sendTokensPerRankBuf, sendTokensPerRankAlignLen);
Expand All @@ -549,16 +540,19 @@ class NotifyDispatch
SyncFunc<AscendC::HardEvent::MTE2_S>();
for (uint32_t rStart = 0; rStart < round; rStart += batchRounds) {
uint32_t currentBatchRounds = (rStart + batchRounds > round) ? (round - rStart) : batchRounds;

uint32_t singleRankBatchDataLen = currentBatchRounds * numLocalExperts * sendPerGroup * sizeof(int32_t);
uint32_t alignedDataLen = Ceil(singleRankBatchDataLen, UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
uint32_t strideElem = alignedDataLen / sizeof(int32_t);
ReorderOutput(rStart, currentBatchRounds);
SyncFunc<AscendC::HardEvent::MTE2_S>();
for (uint32_t r = 0; r < currentBatchRounds; ++r) {
uint32_t offsetInRound = r * numLocalExperts;
Duplicate<int32_t>(seenRoundTensor, 0, sendTokensPerRankAlignLen / sizeof(int32_t));
SyncFunc<AscendC::HardEvent::V_S>();
for (uint32_t expId = 0; expId < numLocalExperts; ++expId) {
for (uint32_t srcRank = 0; srcRank < rankSize; ++srcRank) {
uint32_t pair_idx = sendPerGroup * (srcRank * numLocalExperts * currentBatchRounds +
r * numLocalExperts + expId);
uint32_t offsetInRank = sendPerGroup * (offsetInRound + expId);
uint32_t pair_idx = srcRank * strideElem + offsetInRank;
if (!seenRoundTensor(srcRank)) {
sendTokensPerRankTensor(srcRank) += recvDataTensor(pair_idx + 2);
seenRoundTensor(srcRank) = 1;
Expand All @@ -585,8 +579,10 @@ class NotifyDispatch
if (blockIdx != RECV_TOKEN_PER_EXP_CORE) {
return;
}
recvDataAlignLen =
Ceil(batchRounds * numExperts * sendPerGroup * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
uint32_t singleRankMaxElem = batchRounds * numLocalExperts * sendPerGroup;
uint32_t singleRankMaxLen = singleRankMaxElem * sizeof(int32_t);
uint32_t singleRankAlignLen = Ceil(singleRankMaxLen, UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
recvDataAlignLen = rankSize * singleRankAlignLen;
pipe.InitBuffer(recvDataBuf, recvDataAlignLen);
sendCountAlignLen = Ceil(batchRounds * numExperts * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
pipe.InitBuffer(recvCountBuf, sendCountAlignLen);
Expand Down Expand Up @@ -630,8 +626,10 @@ class NotifyDispatch
return;
}

recvDataAlignLen =
Ceil(batchRounds * numExperts * sendPerGroup * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
uint32_t singleRankMaxElem = batchRounds * numLocalExperts * sendPerGroup;
uint32_t singleRankMaxLen = singleRankMaxElem * sizeof(int32_t);
uint32_t singleRankAlignLen = Ceil(singleRankMaxLen, UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
recvDataAlignLen = rankSize * singleRankAlignLen;
pipe.InitBuffer(recvDataBuf, recvDataAlignLen);
sendCountAlignLen = Ceil(batchRounds * numExperts * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
pipe.InitBuffer(recvCountBuf, sendCountAlignLen);
Expand Down Expand Up @@ -676,8 +674,10 @@ class NotifyDispatch
if (blockIdx != SRC_RANK_EXP_OFFSET_CORE) {
return;
}
recvDataAlignLen =
Ceil(batchRounds * numExperts * sendPerGroup * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
uint32_t singleRankMaxElem = batchRounds * numLocalExperts * sendPerGroup;
uint32_t singleRankMaxLen = singleRankMaxElem * sizeof(int32_t);
uint32_t singleRankAlignLen = Ceil(singleRankMaxLen, UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
recvDataAlignLen = rankSize * singleRankAlignLen;
pipe.InitBuffer(recvDataBuf, recvDataAlignLen);
sendCountAlignLen = Ceil(batchRounds * numExperts * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; // 32Kb
pipe.InitBuffer(recvCountBuf, sendCountAlignLen);
Expand Down Expand Up @@ -726,8 +726,10 @@ class NotifyDispatch
if (blockIdx != R_IN_SRCRANK_OFFSET_CORE) {
return;
}
recvDataAlignLen =
Ceil(batchRounds * numExperts * sendPerGroup * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
uint32_t singleRankMaxElem = batchRounds * numLocalExperts * sendPerGroup;
uint32_t singleRankMaxLen = singleRankMaxElem * sizeof(int32_t);
uint32_t singleRankAlignLen = Ceil(singleRankMaxLen, UB_ALIGN_SIZE) * UB_ALIGN_SIZE;
recvDataAlignLen = rankSize * singleRankAlignLen;
pipe.InitBuffer(recvDataBuf, recvDataAlignLen);
sendCountAlignLen = Ceil(batchRounds * numExperts * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; // 32Kb
pipe.InitBuffer(recvCountBuf, sendCountAlignLen);
Expand Down