Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion csrc/deepep/ops/op_host/cam_moe_combine_normal_tiling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ constexpr uint64_t DOUBLE_DATA_BUFFER = 2UL;
constexpr uint64_t MAX_OUT_DTYPE_SIZE = 2UL;
constexpr uint64_t UB_ALIGN = 32UL;
constexpr int64_t DISPATCH_STATUS_MAX_SUPPORT_NUM = 1280UL;
constexpr uint64_t INIT_TILINGKEY = 10000UL;

enum class CommQuantMode : int32_t { NON_QUANT = 0, INT12_QUANT = 1, INT8_QUANT = 2 };
using CommQuantModeType = std::underlying_type<CommQuantMode>;
Expand Down Expand Up @@ -541,7 +542,7 @@ static ge::graphStatus CamMoeCombineNormalA3TilingFuncImpl(gert::TilingContext *
OP_LOGE(nodeName,
"HCCL_BUFFSIZE is too SMALL, realBs = %lu, h = %lu, epWorldSize = %lu, localMoeExpertNum = %u,"
" tokenNeedSizeCombine = %lu, k = %lu, NEEDED_HCCL_BUFFSIZE("
"((perRoundTokens * k * tokenNeedSizeCombine)) + 8MB + 102MB) * 2) = %luMB, "
"((realBs * k * tokenNeedSizeCombine)) + 4MB + 204MB) * 2) = %luMB, "
"HCCL_BUFFSIZE=%luMB.",
realBs, h, epWorldSize, localMoeExpertNum, tokenNeedSizeCombine, k, actualSize / MB_SIZE + 1UL,
maxWindowSize / MB_SIZE),
Expand Down Expand Up @@ -569,6 +570,14 @@ static ge::graphStatus CamMoeCombineNormalA3TilingFuncImpl(gert::TilingContext *
OP_LOGD(nodeName, "blockdim = %u, aivNum = %lu, ubsize = %lu", blockDim, aivNum, ubSize);
PrintTilingDataInfo(nodeName, *tilingData);

uint64_t tilingKey = INIT_TILINGKEY;
uint32_t maxRound = tilingData->camMoeCombineNormalInfo.maxRound;
if (maxRound > 1) {
tilingKey += 1;
}
OP_LOGD(nodeName, "tilingKey is %lu", tilingKey);
context->SetTilingKey(tilingKey);

return ge::GRAPH_SUCCESS;
}

Expand Down
2 changes: 1 addition & 1 deletion csrc/deepep/ops/op_host/cam_moe_dispatch_normal_tiling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -578,7 +578,7 @@ static ge::graphStatus CamMoeDispatchNormalA3TilingFuncImpl(gert::TilingContext
"HCCL_BUFFSIZE is too SMALL, maxBs = %lu, h = %lu, epWorldSize = %lu,"
" localMoeExpertNum = %u, tokenNeedSizeDispatch = %lu, tokenNeedSizeCombine = %lu,"
" k = %lu, NEEDED_HCCL_BUFFSIZE((maxBs * k * (tokenNeedSizeDispatch"
" + tokenNeedSizeCombine) + 3MB + 204MB) * 2) = %luMB, HCCL_BUFFSIZE=%luMB.",
" + tokenNeedSizeCombine) + 4MB + 204MB) * 2) = %luMB, HCCL_BUFFSIZE=%luMB.",
maxBs, h, epWorldSize, localMoeExpertNum, tokenNeedSizeDispatch, tokenNeedSizeCombine, k,
actualSize / MB_SIZE + 1UL, maxWindowSize / MB_SIZE),
return ge::GRAPH_FAILED);
Expand Down
4 changes: 2 additions & 2 deletions csrc/deepep/ops/op_host/tiling_args.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#include <cstdint>

namespace Moe {
constexpr uint64_t COMBINE_STATE_WIN_OFFSET = 8U * 1024UL * 1024UL;
constexpr uint64_t NOTIFY_DISPATCH_WIN_OFFSET = 102U * 1024UL * 1024UL;
constexpr uint64_t COMBINE_STATE_WIN_OFFSET = 4U * 1024UL * 1024UL;
constexpr uint64_t NOTIFY_DISPATCH_WIN_OFFSET = 204U * 1024UL * 1024UL;
} // namespace Moe
#endif // TILING_ARGS_H
20 changes: 15 additions & 5 deletions csrc/deepep/ops/op_kernel/cam_moe_combine_normal.cpp
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
#include "kernel_operator.h"
#include "lib/matmul_intf.h"
#include "cam_moe_combine_normal.h"
#include "cam_moe_combine_normal_multi_round.h"
#include "cam_moe_combine_normal_tiling.h"
using namespace AscendC;
using namespace CamMoeCombineNormalImpl;

#define TILINGKEY_MULTI_ROUND 10001
#define TILINGKEY_SINGLE_ROUND 10000

extern "C" __global__ __aicore__ void cam_moe_combine_normal(GM_ADDR recvX, GM_ADDR tokenSrcInfo, GM_ADDR epRecvCount,
GM_ADDR topkWeights, GM_ADDR tpRecvCount, GM_ADDR XOut,
Expand All @@ -16,9 +19,16 @@ extern "C" __global__ __aicore__ void cam_moe_combine_normal(GM_ADDR recvX, GM_A

#if (ORIG_DTYPE_RECV_X == DT_BF16 || ORIG_DTYPE_RECV_X == DT_FLOAT16)
GET_TILING_DATA_WITH_STRUCT(CamMoeCombineNormalTilingData, tilingData, tilingGM);
CamMoeCombineNormal<DTYPE_RECV_X, DTYPE_X, int32_t> op;
op.Init(recvX, tokenSrcInfo, epRecvCount, topkWeights, tpRecvCount, XOut, sendCostStatsOut, workspaceGM, &pipe,
&tilingData);
op.Process();
if (TILING_KEY_IS(TILINGKEY_MULTI_ROUND)) {
CamMoeCombineNormalMultiRoundImpl::CamMoeCombineNormalMultiRound<DTYPE_RECV_X, DTYPE_X, int32_t> op;
op.Init(recvX, tokenSrcInfo, epRecvCount, topkWeights, tpRecvCount, XOut, sendCostStatsOut, workspaceGM, &pipe,
&tilingData);
op.Process();
} else if (TILING_KEY_IS(TILINGKEY_SINGLE_ROUND)) {
CamMoeCombineNormalImpl::CamMoeCombineNormal<DTYPE_RECV_X, DTYPE_X, int32_t> op;
op.Init(recvX, tokenSrcInfo, epRecvCount, topkWeights, tpRecvCount, XOut, sendCostStatsOut, workspaceGM, &pipe,
&tilingData);
op.Process();
}
#endif
}
Loading