From d08cc83eabdc55c2e2293f8332c74a821c0d5bd6 Mon Sep 17 00:00:00 2001 From: luanyundu <1425036963@qq.com> Date: Tue, 3 Feb 2026 16:32:51 +0800 Subject: [PATCH 1/2] adapt ant moving to A2 single machine --- .../ops2/op_host/cam_moe_combine_normal.cpp | 2 +- .../op_host/cam_moe_combine_normal_tiling.cc | 88 +-- .../op_host/cam_moe_dispatch_normal_tiling.cc | 36 +- .../ops2/op_host/dispatch_layout_tiling.cc | 14 +- csrc/deepep/ops2/op_host/mc2_tiling_utils.h | 25 - .../moe_distribute_combine_v2_tiling.cc | 4 +- .../moe_distribute_dispatch_v2_tiling.cc | 4 +- .../ops2/op_host/notify_dispatch_tiling.cc | 27 +- .../op_api/aclnn_cam_moe_combine_normal.cpp | 4 +- .../op_api/aclnn_cam_moe_combine_normal.h | 3 +- .../op_host/op_api/aclnn_dispatch_layout.h | 1 + csrc/deepep/ops2/op_host/tiling_args.h | 5 +- .../ops2/op_kernel/cam_moe_combine_normal.cpp | 22 +- .../ops2/op_kernel/cam_moe_combine_normal.h | 53 +- .../cam_moe_combine_normal_multi_round.h | 649 ++++++++++++++++ .../op_kernel/cam_moe_combine_normal_tiling.h | 1 - .../op_kernel/cam_moe_dispatch_normal.cpp | 10 +- .../ops2/op_kernel/cam_moe_dispatch_normal.h | 257 +++++-- csrc/deepep/ops2/op_kernel/check_winsize.h | 17 +- csrc/deepep/ops2/op_kernel/comm_args.h | 4 +- csrc/deepep/ops2/op_kernel/dispatch_layout.h | 242 +++--- .../deepep/ops2/op_kernel/notify_dispatch.cpp | 24 +- csrc/deepep/ops2/op_kernel/notify_dispatch.h | 726 ++++++++++++++---- .../ops2/op_kernel/notify_dispatch_tiling.h | 1 + 24 files changed, 1701 insertions(+), 518 deletions(-) create mode 100644 csrc/deepep/ops2/op_kernel/cam_moe_combine_normal_multi_round.h diff --git a/csrc/deepep/ops2/op_host/cam_moe_combine_normal.cpp b/csrc/deepep/ops2/op_host/cam_moe_combine_normal.cpp index 83289564d..d419f4abb 100644 --- a/csrc/deepep/ops2/op_host/cam_moe_combine_normal.cpp +++ b/csrc/deepep/ops2/op_host/cam_moe_combine_normal.cpp @@ -30,7 +30,7 @@ class CamMoeCombineNormal : public OpDef .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, 
ge::FORMAT_ND}) .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) .AutoContiguous(); - this->Input("topk_idx") + this->Input("token_idx") .ParamType(REQUIRED) .DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32}) .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) diff --git a/csrc/deepep/ops2/op_host/cam_moe_combine_normal_tiling.cc b/csrc/deepep/ops2/op_host/cam_moe_combine_normal_tiling.cc index 44c6330aa..68b731ac1 100644 --- a/csrc/deepep/ops2/op_host/cam_moe_combine_normal_tiling.cc +++ b/csrc/deepep/ops2/op_host/cam_moe_combine_normal_tiling.cc @@ -18,6 +18,7 @@ #include "error_log.h" #include "graph/utils/type_utils.h" #include "register/op_def_registry.h" +#include "mc2_tiling_utils.h" #include "../op_kernel/cam_moe_combine_normal_tiling.h" #include "tiling_args.h" @@ -26,37 +27,12 @@ using namespace ge; using namespace Moe; namespace { -class Mc2TilingUtils -{ -public: - static uint64_t GetMaxWindowSize() - { - uint16_t defaultWindowSize = 200; - const char *hcclBuffSize = getenv("DEEPEP_HCCL_BUFFSIZE") == nullptr ? 
"HCCL_BUFFSIZE" : "DEEPEP_HCCL_BUFFSIZE"; - if (getenv(hcclBuffSize) == nullptr) { - OP_LOGD("", "Env HCCL_BUFFSIZE don't set"); - } else { - try { - std::string envStr(getenv(hcclBuffSize)); - defaultWindowSize = std::stoi(envStr); - } catch (const std::invalid_argument &ia) { - OP_LOGE("", "Invalid argument when parsing HCCL_BUFFSIZE: %s", ia.what()); - } catch (const std::out_of_range &oor) { - OP_LOGE("", "Out of range when parsing HCCL_BUFFSIZE: %s", oor.what()); - } - } - const uint64_t maxWindowSize = static_cast(defaultWindowSize) * 1024UL * 1024UL; - OP_LOGI("", "Get maxWindowSize is %lu", maxWindowSize); - return maxWindowSize; - } -}; constexpr uint32_t RECV_X_INDEX = 0; constexpr uint32_t TOKEN_SRC_INFO_INDEX = 1; constexpr uint32_t EP_RECV_COUNTS_INDEX = 2; constexpr uint32_t TOPK_WEIGHTS_INDEX = 3; -constexpr uint32_t TOPK_IDX_INDEX = 4; +constexpr uint32_t TOKEN_IDX_INDEX = 4; constexpr uint32_t TP_RECV_COUNTS_INDEX = 5; - constexpr uint32_t OUTPUT_X_INDEX = 0; constexpr uint32_t OUTPUT_SEND_COST_INDEX = 1; @@ -80,7 +56,7 @@ constexpr size_t MAX_GROUP_NAME_LENGTH = 128UL; constexpr int64_t MAX_EP_WORLD_SIZE = 384; constexpr int64_t MIN_EP_WORLD_SIZE = 2; constexpr int64_t MAX_TP_WORLD_SIZE = 2; -constexpr int64_t BS_UPPER_BOUND = 8000; +constexpr int64_t BS_UPPER_BOUND = 65536; constexpr uint32_t SYSTEM_NEED_WORKSPACE = 16 * 1024 * 1024; constexpr int32_t HCCL_BUFFER_SIZE_DEFAULT = 200 * 1024 * 1024; // Bytes @@ -96,6 +72,7 @@ constexpr uint64_t DOUBLE_DATA_BUFFER = 2UL; constexpr uint64_t MAX_OUT_DTYPE_SIZE = 2UL; constexpr uint64_t UB_ALIGN = 32UL; constexpr int64_t DISPATCH_STATUS_MAX_SUPPORT_NUM = 1280UL; +constexpr uint64_t INIT_TILINGKEY = 10000UL; enum class CommQuantMode : int32_t { NON_QUANT = 0, INT12_QUANT = 1, INT8_QUANT = 2 }; using CommQuantModeType = std::underlying_type; @@ -231,14 +208,14 @@ static bool CheckInputTensorDim(gert::TilingContext *context, const char *nodeNa OP_LOGD(nodeName, "topkWeights dim0 = %ld", 
topkWeightsStorageShape->GetStorageShape().GetDim(0)); OP_LOGD(nodeName, "topkWeights dim1 = %ld", topkWeightsStorageShape->GetStorageShape().GetDim(1)); - const gert::StorageShape *topIdxStorageShape = context->GetInputShape(TOPK_IDX_INDEX); - OP_TILING_CHECK(topIdxStorageShape == nullptr, OP_LOGE(nodeName, "topkWeights is null."), return false); - OP_TILING_CHECK(topIdxStorageShape->GetStorageShape().GetDimNum() != TWO_DIMS, - OP_LOGE(nodeName, "topkIdx must be 2-dimension, but got %lu dim", - topIdxStorageShape->GetStorageShape().GetDimNum()), + const gert::StorageShape *tokenIdxStorageShape = context->GetInputShape(TOKEN_IDX_INDEX); + OP_TILING_CHECK(tokenIdxStorageShape == nullptr, OP_LOGE(nodeName, "tokenIdx is null."), return false); + OP_TILING_CHECK(tokenIdxStorageShape->GetStorageShape().GetDimNum() != TWO_DIMS, + OP_LOGE(nodeName, "tokenIdx must be 2-dimension, but got %lu dim", + tokenIdxStorageShape->GetStorageShape().GetDimNum()), return false); - OP_LOGD(nodeName, "topkIdx dim0 = %ld", topIdxStorageShape->GetStorageShape().GetDim(0)); - OP_LOGD(nodeName, "topkIdx dim1 = %ld", topIdxStorageShape->GetStorageShape().GetDim(1)); + OP_LOGD(nodeName, "tokenIdx dim0 = %ld", tokenIdxStorageShape->GetStorageShape().GetDim(0)); + OP_LOGD(nodeName, "tokenIdx dim1 = %ld", tokenIdxStorageShape->GetStorageShape().GetDim(1)); return true; } @@ -318,13 +295,10 @@ static bool CheckTensorDataType(gert::TilingContext *context, const char *nodeNa OP_TILING_CHECK((topkWeightsDesc->GetDataType() != ge::DT_FLOAT), OP_LOGE(nodeName, "topkWeights dataType is invalid, dataType should be float, but is "), return false); - auto topkIdxDesc = context->GetInputDesc(TOPK_IDX_INDEX); - OP_TILING_CHECK(topkIdxDesc == nullptr, OP_LOGE(nodeName, "topkIdxDesc is null."), return false); - OP_TILING_CHECK((topkIdxDesc->GetDataType() != ge::DT_INT32), - OP_LOGE(nodeName, - "topkIdxForCombine dataType is invalid," - " dataType should be int32, but is"), - return false); + auto tokenIdxDesc 
= context->GetInputDesc(TOKEN_IDX_INDEX); + OP_TILING_CHECK(tokenIdxDesc == nullptr, OP_LOGE(nodeName, "tokenIdxDesc is null."), return false); + OP_TILING_CHECK((tokenIdxDesc->GetDataType() != ge::DT_INT32), + OP_LOGE(nodeName, "tokenIdx dataType is invalid, dataType should be int32, but is "), return false); auto xDesc = context->GetOutputDesc(OUTPUT_X_INDEX); OP_TILING_CHECK(xDesc == nullptr, OP_LOGE(nodeName, "xDesc is null."), return false); OP_TILING_CHECK((xDesc->GetDataType() != recvXDesc->GetDataType()), @@ -369,11 +343,11 @@ static bool CheckTensorFormat(gert::TilingContext *context, const char *nodeName static_cast(ge::GetPrimaryFormat(topkWeightsDesc->GetStorageFormat())) == ge::FORMAT_FRACTAL_NZ, OP_LOGE(nodeName, "topkWeightsFormat is invalid"), return false); - auto topkIdxsDesc = context->GetOptionalInputDesc(TOPK_IDX_INDEX); - OP_TILING_CHECK(topkIdxsDesc == nullptr, OP_LOGE(nodeName, "topkIdxsDesc is null."), return false); + auto tokenIdxDesc = context->GetInputDesc(TOKEN_IDX_INDEX); + OP_TILING_CHECK(tokenIdxDesc == nullptr, OP_LOGE(nodeName, "tokenIdxDesc is null."), return false); OP_TILING_CHECK( - static_cast(ge::GetPrimaryFormat(topkIdxsDesc->GetStorageFormat())) == ge::FORMAT_FRACTAL_NZ, - OP_LOGE(nodeName, "topkIdxsFormat is invalid"), return false); + static_cast(ge::GetPrimaryFormat(tokenIdxDesc->GetStorageFormat())) == ge::FORMAT_FRACTAL_NZ, + OP_LOGE(nodeName, "tokenIdxFormat is invalid"), return false); auto xDesc = context->GetOutputDesc(OUTPUT_X_INDEX); OP_TILING_CHECK(xDesc == nullptr, OP_LOGE(nodeName, "xDesc is null."), return false); @@ -424,7 +398,7 @@ static bool CheckTensorShape(gert::TilingContext *context, CamMoeCombineNormalTi int64_t xDim0 = xStorageShape->GetStorageShape().GetDim(0); int64_t xDim1 = xStorageShape->GetStorageShape().GetDim(1); OP_TILING_CHECK(xDim0 != topkWeightsDim0, - OP_LOGE(nodeName, "x's dim0 not equal to bs, bs = %ld, x's dim0 = %ld", topkWeightsDim0, xDim0), + OP_LOGE(nodeName, "x's dim0 is 
greater than bs, bs = %ld, x's dim0 = %ld", topkWeightsDim0, xDim0), return false); OP_TILING_CHECK(xDim1 != recvXDim1, OP_LOGE(nodeName, "x's dim1 not equal to h, x's dim1 = %ld, h = %ld", xDim1, recvXDim1), @@ -577,19 +551,22 @@ static ge::graphStatus CamMoeCombineNormalA3TilingFuncImpl(gert::TilingContext * uint64_t epWorldSize = static_cast(tilingData->camMoeCombineNormalInfo.epWorldSize); uint64_t k = static_cast(tilingData->camMoeCombineNormalInfo.k); uint64_t perRoundTokens = tilingData->camMoeCombineNormalInfo.perRoundTokens; + uint64_t realMaxBs = tilingData->camMoeCombineNormalInfo.realMaxBs; + uint64_t realBs = std::min(perRoundTokens, realMaxBs); + uint32_t maxRound = tilingData->camMoeCombineNormalInfo.maxRound; // combine数据区 token首地址对齐512 uint64_t tokenNeedSizeCombine = ((h * MAX_OUT_DTYPE_SIZE + WIN_ADDR_ALIGN - 1UL) / WIN_ADDR_ALIGN) * WIN_ADDR_ALIGN; - uint64_t actualSize = - (perRoundTokens * k * tokenNeedSizeCombine + COMBINE_STATE_WIN_OFFSET + NOTIFY_DISPATCH_WIN_OFFSET) * - DOUBLE_DATA_BUFFER; + tokenNeedSizeCombine = maxRound > 1 ? 
tokenNeedSizeCombine * 2 : tokenNeedSizeCombine; + uint64_t actualSize = (realBs * k * tokenNeedSizeCombine + COMBINE_STATE_WIN_OFFSET + NOTIFY_DISPATCH_WIN_OFFSET) * + DOUBLE_DATA_BUFFER; OP_TILING_CHECK( (actualSize > maxWindowSize), OP_LOGE(nodeName, - "HCCL_BUFFSIZE is too SMALL, perRoundTokens = %lu, h = %lu, epWorldSize = %lu, localMoeExpertNum = %u," + "HCCL_BUFFSIZE is too SMALL, realBs = %lu, h = %lu, epWorldSize = %lu, localMoeExpertNum = %u," " tokenNeedSizeCombine = %lu, k = %lu, NEEDED_HCCL_BUFFSIZE(" - "((perRoundTokens * k * tokenNeedSizeCombine)) + 8MB + 102MB) * 2) = %luMB, " + "((realBs * k * tokenNeedSizeCombine * 2)) + 8MB + 404MB) * 2) = %luMB, " "HCCL_BUFFSIZE=%luMB.", - perRoundTokens, h, epWorldSize, localMoeExpertNum, tokenNeedSizeCombine, k, actualSize / MB_SIZE + 1UL, + realBs, h, epWorldSize, localMoeExpertNum, tokenNeedSizeCombine, k, actualSize / MB_SIZE + 1UL, maxWindowSize / MB_SIZE), return ge::GRAPH_FAILED); tilingData->camMoeCombineNormalInfo.totalWinSize = maxWindowSize; @@ -615,6 +592,13 @@ static ge::graphStatus CamMoeCombineNormalA3TilingFuncImpl(gert::TilingContext * OP_LOGD(nodeName, "blockdim = %u, aivNum = %lu, ubsize = %lu", blockDim, aivNum, ubSize); PrintTilingDataInfo(nodeName, *tilingData); + uint64_t tilingKey = INIT_TILINGKEY; + if (maxRound > 1) { + tilingKey += 1; + } + OP_LOGD(nodeName, "tilingKey is %lu", tilingKey); + context->SetTilingKey(tilingKey); + return ge::GRAPH_SUCCESS; } diff --git a/csrc/deepep/ops2/op_host/cam_moe_dispatch_normal_tiling.cc b/csrc/deepep/ops2/op_host/cam_moe_dispatch_normal_tiling.cc index cb55957f8..4c5d47b74 100644 --- a/csrc/deepep/ops2/op_host/cam_moe_dispatch_normal_tiling.cc +++ b/csrc/deepep/ops2/op_host/cam_moe_dispatch_normal_tiling.cc @@ -17,6 +17,7 @@ #include "error_log.h" #include "graph/utils/type_utils.h" #include "register/op_def_registry.h" +#include "mc2_tiling_utils.h" #include "../op_kernel/cam_moe_dispatch_normal_tiling.h" #include "tiling_args.h" @@ -25,30 +26,6 
@@ using namespace ge; using namespace Moe; namespace { -class Mc2TilingUtils -{ -public: - static uint64_t GetMaxWindowSize() - { - uint16_t defaultWindowSize = 200; - const char *hcclBuffSize = getenv("DEEPEP_HCCL_BUFFSIZE") == nullptr ? "HCCL_BUFFSIZE" : "DEEPEP_HCCL_BUFFSIZE"; - if (getenv(hcclBuffSize) == nullptr) { - OP_LOGD("", "Env HCCL_BUFFSIZE don't set"); - } else { - try { - std::string envStr(getenv(hcclBuffSize)); - defaultWindowSize = std::stoi(envStr); - } catch (const std::invalid_argument &ia) { - OP_LOGE("", "Invalid argument when parsing HCCL_BUFFSIZE: %s", ia.what()); - } catch (const std::out_of_range &oor) { - OP_LOGE("", "Out of range when parsing HCCL_BUFFSIZE: %s", oor.what()); - } - } - const uint64_t maxWindowSize = static_cast(defaultWindowSize) * 1024UL * 1024UL; - OP_LOGI("", "Get maxWindowSize is %lu", maxWindowSize); - return maxWindowSize; - } -}; constexpr uint32_t X_INDEX = 0U; constexpr uint32_t EXPERT_IDS_INDEX = 1U; constexpr uint32_t SEND_OFFSET_INDEX = 2U; @@ -87,7 +64,7 @@ constexpr size_t MAX_GROUP_NAME_LENGTH = 128UL; constexpr int64_t MAX_EP_WORLD_SIZE = 384; constexpr int64_t MIN_EP_WORLD_SIZE = 2; constexpr int64_t MAX_TP_WORLD_SIZE = 2; -constexpr int64_t BS_UPPER_BOUND = 32768; // 最大bs +constexpr int64_t BS_UPPER_BOUND = 65536; // 最大bs constexpr uint32_t TILINGKEY_TP_WORLD_SIZE = 100; constexpr uint32_t TP_WORLD_SIZE_TWO = 2; @@ -586,22 +563,25 @@ static ge::graphStatus CamMoeDispatchNormalA3TilingFuncImpl(gert::TilingContext uint64_t k = static_cast(tilingData->camMoeDispatchNormalInfo.k); uint64_t epWorldSize = static_cast(tilingData->camMoeDispatchNormalInfo.epWorldSize); uint64_t maxBs = static_cast(tilingData->camMoeDispatchNormalInfo.globalBs) / epWorldSize; - + uint32_t round = tilingData->camMoeDispatchNormalInfo.round; // dispatch数据区 token首对齐512,有效token长度h_align_32b + scale(32b) + 三元组(3*4b) uint64_t tokenActualLen = ((h * MAX_OUT_DTYPE_SIZE + UB_ALIGN - 1UL) / UB_ALIGN) * UB_ALIGN + SCALE_EXPAND_IDX_BUFFER; 
uint64_t tokenNeedSizeDispatch = ((tokenActualLen + WIN_ADDR_ALIGN - 1UL) / WIN_ADDR_ALIGN) * WIN_ADDR_ALIGN; uint64_t tokenNeedSizeCombine = ((h * MAX_OUT_DTYPE_SIZE + WIN_ADDR_ALIGN - 1UL) / WIN_ADDR_ALIGN) * WIN_ADDR_ALIGN; + tokenNeedSizeCombine = + round > 1 ? tokenNeedSizeCombine * 2 : tokenNeedSizeCombine; // round > 1 combine要使用double buffer // 未考虑双流时大小 uint64_t actualSize = (maxBs * k * (tokenNeedSizeCombine + tokenNeedSizeDispatch) + COMBINE_STATE_WIN_OFFSET + NOTIFY_DISPATCH_WIN_OFFSET) * - DOUBLE_DATA_BUFFER; + DOUBLE_DATA_BUFFER + + Moe::STATE_SIZE * 4; OP_TILING_CHECK((actualSize > maxWindowSize), OP_LOGE(nodeName, "HCCL_BUFFSIZE is too SMALL, maxBs = %lu, h = %lu, epWorldSize = %lu," " localMoeExpertNum = %u, tokenNeedSizeDispatch = %lu, tokenNeedSizeCombine = %lu," " k = %lu, NEEDED_HCCL_BUFFSIZE((maxBs * k * (tokenNeedSizeDispatch" - " + tokenNeedSizeCombine) + 3MB + 204MB) * 2) = %luMB, HCCL_BUFFSIZE=%luMB.", + " + tokenNeedSizeCombine) + 4MB + 404MB) * 2 + 4 * 2MB) = %luMB, HCCL_BUFFSIZE=%luMB.", maxBs, h, epWorldSize, localMoeExpertNum, tokenNeedSizeDispatch, tokenNeedSizeCombine, k, actualSize / MB_SIZE + 1UL, maxWindowSize / MB_SIZE), return ge::GRAPH_FAILED); diff --git a/csrc/deepep/ops2/op_host/dispatch_layout_tiling.cc b/csrc/deepep/ops2/op_host/dispatch_layout_tiling.cc index 141550ad6..c35e611b9 100644 --- a/csrc/deepep/ops2/op_host/dispatch_layout_tiling.cc +++ b/csrc/deepep/ops2/op_host/dispatch_layout_tiling.cc @@ -34,6 +34,7 @@ constexpr uint32_t OUTPUT_NUM_TOKEN_PER_RANK_INDEX = 0; constexpr uint32_t OUTPUT_NUM_TOKEN_PER_EXPERT_INDEX = 1; constexpr uint32_t OUTPUT_IS_TOKEN_IN_RANK_INDEX = 2; constexpr uint32_t OUTPUT_NOTIFY_SEND_DATA_INDEX = 3; +constexpr uint32_t OUTPUT_SEND_TOKEN_IDX_SMALL_INDEX = 4; constexpr uint32_t ATTR_NUM_TOKENS_INDEX = 0; constexpr uint32_t ATTR_NUM_RANKS_INDEX = 1; @@ -98,8 +99,8 @@ static ge::graphStatus GetAttrAndSetTilingData(gert::TilingContext *context, con auto numExpertsPtr = 
attrs->GetAttrPointer(static_cast(ATTR_NUM_EXPERTS_INDEX)); auto numTopkPtr = attrs->GetAttrPointer(static_cast(ATTR_NUM_TOPK_INDEX)); auto localRankSizePtr = attrs->GetAttrPointer(static_cast(ATTR_LOCAL_RANKSIZE_INDEX)); - auto rankIdPtr = attrs->GetAttrPointer(static_cast(ATTR_RANK_ID_INDEX)); auto perRoundTokensPtr = attrs->GetAttrPointer(static_cast(ATTR_PER_ROUND_TOKENS_INDEX)); + auto rankIdPtr = attrs->GetAttrPointer(static_cast(ATTR_RANK_ID_INDEX)); OP_TILING_CHECK(numTokensPtr == nullptr, OP_LOGE(nodeName, "numTokensPtr is null."), return ge::GRAPH_FAILED); OP_TILING_CHECK(numRanksPtr == nullptr, OP_LOGE(nodeName, "numRanksPtr is null."), return ge::GRAPH_FAILED); @@ -107,9 +108,9 @@ static ge::graphStatus GetAttrAndSetTilingData(gert::TilingContext *context, con OP_TILING_CHECK(numTopkPtr == nullptr, OP_LOGE(nodeName, "numTopkPtr is null."), return ge::GRAPH_FAILED); OP_TILING_CHECK(localRankSizePtr == nullptr, OP_LOGE(nodeName, "localRankSizePtr is null."), return ge::GRAPH_FAILED); - OP_TILING_CHECK(rankIdPtr == nullptr, OP_LOGE(nodeName, "rankIdPtr is null."), return ge::GRAPH_FAILED); OP_TILING_CHECK(perRoundTokensPtr == nullptr, OP_LOGE(nodeName, "perRoundTokensPtr is null."), return ge::GRAPH_FAILED); + OP_TILING_CHECK(rankIdPtr == nullptr, OP_LOGE(nodeName, "rankIdPtr is null."), return ge::GRAPH_FAILED); OP_TILING_CHECK((*numRanksPtr <= 0) || (*numRanksPtr > MAX_COMM_WORLD_SIZE), OP_LOGE(nodeName, "rankSize is invalid, only support (0, %ld], but got rankSize=%ld.", @@ -133,9 +134,8 @@ static ge::graphStatus GetAttrAndSetTilingData(gert::TilingContext *context, con tilingData.dispatchLayoutInfo.numExperts = static_cast(*numExpertsPtr); tilingData.dispatchLayoutInfo.numTopk = static_cast(*numTopkPtr); tilingData.dispatchLayoutInfo.localRankSize = static_cast(*localRankSizePtr); - tilingData.dispatchLayoutInfo.rankId = static_cast(*rankIdPtr); tilingData.dispatchLayoutInfo.perRoundTokens = static_cast(*perRoundTokensPtr); - + 
tilingData.dispatchLayoutInfo.rankId = static_cast(*rankIdPtr); if (CheckIfA2MultiMachine(context, tilingData)) { OP_TILING_CHECK( (*localRankSizePtr <= 0) || (*localRankSizePtr > MAX_LOCAL_RANKSIZE), @@ -167,12 +167,14 @@ static bool CheckTensorDataType(gert::TilingContext *context, const char *nodeNa auto numTokensPerExpert = context->GetOutputDesc(OUTPUT_NUM_TOKEN_PER_EXPERT_INDEX); auto isTokenInRank = context->GetOutputDesc(OUTPUT_IS_TOKEN_IN_RANK_INDEX); auto notifySendData = context->GetOutputDesc(OUTPUT_NOTIFY_SEND_DATA_INDEX); + auto sendTokenIdxSmall = context->GetOutputDesc(OUTPUT_SEND_TOKEN_IDX_SMALL_INDEX); OP_TILING_CHECK(topkIdx == nullptr, OP_LOGE(nodeName, "topkIdx is null."), return false); OP_TILING_CHECK(numTokensPerRank == nullptr, OP_LOGE(nodeName, "numTokensPerRank is null."), return false); OP_TILING_CHECK(numTokensPerExpert == nullptr, OP_LOGE(nodeName, "numTokensPerExpert is null."), return false); OP_TILING_CHECK(isTokenInRank == nullptr, OP_LOGE(nodeName, "isTokenInRank is null."), return false); OP_TILING_CHECK(notifySendData == nullptr, OP_LOGE(nodeName, "notifySendData is null."), return false); + OP_TILING_CHECK(sendTokenIdxSmall == nullptr, OP_LOGE(nodeName, "sendTokenIdxSmall is null."), return false); OP_TILING_CHECK((topkIdx->GetDataType() != ge::DT_INT64), OP_LOGE(nodeName, "topkIdx datatype is invalid, datatype should be int, but is %d.", @@ -194,6 +196,10 @@ static bool CheckTensorDataType(gert::TilingContext *context, const char *nodeNa OP_LOGE(nodeName, "notifySendData datatype is invalid, datatype should be int, but is %d.", static_cast(notifySendData->GetDataType())), return false); + OP_TILING_CHECK((sendTokenIdxSmall->GetDataType() != ge::DT_INT32), + OP_LOGE(nodeName, "sendTokenIdxSmall datatype is invalid, datatype should be int, but is %d.", + static_cast(sendTokenIdxSmall->GetDataType())), + return false); return true; } diff --git a/csrc/deepep/ops2/op_host/mc2_tiling_utils.h 
b/csrc/deepep/ops2/op_host/mc2_tiling_utils.h index bd43f0d28..484a0a7f3 100644 --- a/csrc/deepep/ops2/op_host/mc2_tiling_utils.h +++ b/csrc/deepep/ops2/op_host/mc2_tiling_utils.h @@ -1,24 +1,3 @@ -/** - * Copyright (c) Huawei Technologies Co., Ltd. 2024-2025. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/*! - * \file mc2_tiling_utils.h - * \brief - */ - #ifndef __MC2_TILING_UTILS_H__ #define __MC2_TILING_UTILS_H__ @@ -30,8 +9,6 @@ #include "tiling/tiling_api.h" #include "error_log.h" -namespace mc2tiling { - constexpr uint32_t AICPU_BLOCK_DIM_A2 = 6U; class Mc2TilingUtils { @@ -58,6 +35,4 @@ class Mc2TilingUtils } }; -} // namespace mc2tiling - #endif diff --git a/csrc/deepep/ops2/op_host/moe_distribute_combine_v2_tiling.cc b/csrc/deepep/ops2/op_host/moe_distribute_combine_v2_tiling.cc index 5f5677c55..712cdfba8 100644 --- a/csrc/deepep/ops2/op_host/moe_distribute_combine_v2_tiling.cc +++ b/csrc/deepep/ops2/op_host/moe_distribute_combine_v2_tiling.cc @@ -937,7 +937,7 @@ static ge::graphStatus MoeDistributeCombineA2SingleTilingFuncImpl(gert::TilingCo OP_LOGE(nodeName, "param dim check failed."), return ge::GRAPH_FAILED); // 校验win区大小 - uint64_t maxWindowSize = mc2tiling::Mc2TilingUtils::GetMaxWindowSize(); + uint64_t maxWindowSize = Mc2TilingUtils::GetMaxWindowSize(); uint64_t h = static_cast(tilingData->moeDistributeCombineV2Info.h); uint64_t epWorldSize = static_cast(tilingData->moeDistributeCombineV2Info.epWorldSize); 
uint64_t k = static_cast(tilingData->moeDistributeCombineV2Info.k); @@ -1383,7 +1383,7 @@ static ge::graphStatus MoeDistributeCombineA2TilingFuncImpl(gert::TilingContext uint32_t aivNum = ascendcPlatform.GetCoreNumAiv(); blockDim = ascendcPlatform.CalcTschBlockDim(aivNum, 0, aivNum); context->SetBlockDim(blockDim); - context->SetAicpuBlockDim(mc2tiling::AICPU_BLOCK_DIM_A2); + context->SetAicpuBlockDim(AICPU_BLOCK_DIM_A2); uint64_t tilingKey = MoeDistributeCombineA2CalcTilingKey(context, isLayered, commQuantMode); context->SetTilingKey(tilingKey); diff --git a/csrc/deepep/ops2/op_host/moe_distribute_dispatch_v2_tiling.cc b/csrc/deepep/ops2/op_host/moe_distribute_dispatch_v2_tiling.cc index b334868f1..8f786a71f 100644 --- a/csrc/deepep/ops2/op_host/moe_distribute_dispatch_v2_tiling.cc +++ b/csrc/deepep/ops2/op_host/moe_distribute_dispatch_v2_tiling.cc @@ -956,7 +956,7 @@ static ge::graphStatus MoeDistributeDispatchA2SingleTilingFuncImpl(gert::TilingC OP_LOGE(nodeName, "Check tensor shape failed."), return ge::GRAPH_FAILED); // 校验win区大小 - uint64_t maxWindowSize = mc2tiling::Mc2TilingUtils::GetMaxWindowSize(); + uint64_t maxWindowSize = Mc2TilingUtils::GetMaxWindowSize(); uint64_t h = static_cast(tilingData->moeDistributeDispatchV2Info.h); uint64_t k = static_cast(tilingData->moeDistributeDispatchV2Info.k); uint64_t epWorldSize = static_cast(tilingData->moeDistributeDispatchV2Info.epWorldSize); @@ -1319,7 +1319,7 @@ static ge::graphStatus MoeDistributeDispatchA2TilingFuncImpl(gert::TilingContext uint32_t aivNum = ascendcPlatform.GetCoreNumAiv(); blockDim = ascendcPlatform.CalcTschBlockDim(aivNum, 0, aivNum); context->SetBlockDim(blockDim); - context->SetAicpuBlockDim(mc2tiling::AICPU_BLOCK_DIM_A2); + context->SetAicpuBlockDim(AICPU_BLOCK_DIM_A2); uint64_t tilingKey = MoeDistributeDispatchA2CalcTilingKey(context, isLayered); context->SetTilingKey(tilingKey); diff --git a/csrc/deepep/ops2/op_host/notify_dispatch_tiling.cc 
b/csrc/deepep/ops2/op_host/notify_dispatch_tiling.cc index bdd2234d7..43ce8ed9a 100644 --- a/csrc/deepep/ops2/op_host/notify_dispatch_tiling.cc +++ b/csrc/deepep/ops2/op_host/notify_dispatch_tiling.cc @@ -14,6 +14,7 @@ #include "error_log.h" #include "graph/utils/type_utils.h" #include "register/op_def_registry.h" +#include "mc2_tiling_utils.h" #include "../op_kernel/notify_dispatch_tiling.h" #include "tiling/platform/platform_ascendc.h" #include "tiling/hccl/hccl_tiling.h" @@ -28,30 +29,6 @@ using namespace ge; namespace { -class Mc2TilingUtils -{ -public: - static uint64_t GetMaxWindowSize() - { - uint16_t defaultWindowSize = 200; - const char *hcclBuffSize = getenv("DEEPEP_HCCL_BUFFSIZE") == nullptr ? "HCCL_BUFFSIZE" : "DEEPEP_HCCL_BUFFSIZE"; - if (getenv(hcclBuffSize) == nullptr) { - OP_LOGD("", "Env HCCL_BUFFSIZE don't set"); - } else { - try { - std::string envStr(getenv(hcclBuffSize)); - defaultWindowSize = std::stoi(envStr); - } catch (const std::invalid_argument &ia) { - OP_LOGE("", "Invalid argument when parsing HCCL_BUFFSIZE: %s", ia.what()); - } catch (const std::out_of_range &oor) { - OP_LOGE("", "Out of range when parsing HCCL_BUFFSIZE: %s", oor.what()); - } - } - const uint64_t maxWindowSize = static_cast(defaultWindowSize) * 1024UL * 1024UL; - OP_LOGI("", "Get maxWindowSize is %lu", maxWindowSize); - return maxWindowSize; - } -}; constexpr uint32_t OP_TYPE_ALL_TO_ALL = 8U; // numeric representation of AlltoAll constexpr uint32_t INPUT_SEND_DATA_INDEX = 0; @@ -109,6 +86,7 @@ static void PrintTilingDataInfo(const char *nodeName, NotifyDispatchTilingData & OP_LOGD(nodeName, "perRoundTokens is %u.", tilingData.notifyDispatchInfo.perRoundTokens); OP_LOGD(nodeName, "aivNum is %u.", tilingData.notifyDispatchInfo.aivNum); OP_LOGD(nodeName, "totalUbSize is %lu.", tilingData.notifyDispatchInfo.totalUbSize); + OP_LOGD(nodeName, "totalWinSize is %lu.", tilingData.notifyDispatchInfo.totalWinSize); } static ge::graphStatus 
GetAttrAndSetTilingData(gert::TilingContext *context, const char *nodeName, @@ -304,6 +282,7 @@ static bool CheckTensorDataType(gert::TilingContext *context, const char *nodeNa OP_LOGE(nodeName, "HCCL_BUFFSIZE is too SMALL, should larger than %luMB.", actualSize / MB_SIZE); return false; } + tilingData->notifyDispatchInfo.totalWinSize = maxWindowSize; return true; } diff --git a/csrc/deepep/ops2/op_host/op_api/aclnn_cam_moe_combine_normal.cpp b/csrc/deepep/ops2/op_host/op_api/aclnn_cam_moe_combine_normal.cpp index 3250967bf..2f0bb6f1f 100644 --- a/csrc/deepep/ops2/op_host/op_api/aclnn_cam_moe_combine_normal.cpp +++ b/csrc/deepep/ops2/op_host/op_api/aclnn_cam_moe_combine_normal.cpp @@ -16,13 +16,13 @@ extern "C" { aclnnStatus aclnnCamMoeCombineNormalGetWorkspaceSize( const aclTensor *recvX, const aclTensor *tokenSrcInfo, const aclTensor *epRecvCounts, - const aclTensor *recvTopkWeights, const aclTensor *topkIdx, const aclTensor *tpRecvCountsOptional, + const aclTensor *recvTopkWeights, const aclTensor *tokenIdx, const aclTensor *tpRecvCountsOptional, char *epGroupName, int64_t epWorldSize, int64_t epRankId, char *tpGroupNameOptional, int64_t tpWorldSize, int64_t tpRankId, int64_t moeExpertNum, int64_t realMaxBs, int32_t round, int32_t per_round_tokens, const aclTensor *out, const aclTensor *sendCostStats, uint64_t *workspaceSize, aclOpExecutor **executor) { return aclnnInnerCamMoeCombineNormalGetWorkspaceSize( - recvX, tokenSrcInfo, epRecvCounts, recvTopkWeights, topkIdx, tpRecvCountsOptional, epGroupName, epWorldSize, + recvX, tokenSrcInfo, epRecvCounts, recvTopkWeights, tokenIdx, tpRecvCountsOptional, epGroupName, epWorldSize, epRankId, tpGroupNameOptional, tpWorldSize, tpRankId, moeExpertNum, realMaxBs, round, per_round_tokens, out, sendCostStats, workspaceSize, executor); } diff --git a/csrc/deepep/ops2/op_host/op_api/aclnn_cam_moe_combine_normal.h b/csrc/deepep/ops2/op_host/op_api/aclnn_cam_moe_combine_normal.h index 0e8446bd5..01799fdbf 100644 --- 
a/csrc/deepep/ops2/op_host/op_api/aclnn_cam_moe_combine_normal.h +++ b/csrc/deepep/ops2/op_host/op_api/aclnn_cam_moe_combine_normal.h @@ -12,6 +12,7 @@ extern "C" { * tokenSrcInfo : required * epRecvCounts : required * recvTopkWeights : required + * tokenIdx : required * tpRecvCountsOptional : required * epGroupName : optional * epWorldSize : required @@ -27,7 +28,7 @@ extern "C" { */ __attribute__((visibility("default"))) aclnnStatus aclnnCamMoeCombineNormalGetWorkspaceSize( const aclTensor *recvX, const aclTensor *tokenSrcInfo, const aclTensor *epRecvCounts, - const aclTensor *recvTopkWeights, const aclTensor *topkIdx, const aclTensor *tpRecvCountsOptional, + const aclTensor *recvTopkWeights, const aclTensor *tokenIdx, const aclTensor *tpRecvCountsOptional, char *epGroupName, int64_t epWorldSize, int64_t epRankId, char *tpGroupNameOptional, int64_t tpWorldSize, int64_t tpRankId, int64_t moeExpertNum, int64_t realMaxBs, int32_t round, int32_t per_round_tokens, const aclTensor *out, const aclTensor *sendCostStats, uint64_t *workspaceSize, aclOpExecutor **executor); diff --git a/csrc/deepep/ops2/op_host/op_api/aclnn_dispatch_layout.h b/csrc/deepep/ops2/op_host/op_api/aclnn_dispatch_layout.h index 0c7d4189c..4bd2b5c60 100644 --- a/csrc/deepep/ops2/op_host/op_api/aclnn_dispatch_layout.h +++ b/csrc/deepep/ops2/op_host/op_api/aclnn_dispatch_layout.h @@ -20,6 +20,7 @@ extern "C" { * numTokensPerExpert : required * isTokenInRank : required * notifySendData : required + * sendTokenIdxSmall : required * workspaceSize : size of workspace(output). * executor : executor context(output). 
*/ diff --git a/csrc/deepep/ops2/op_host/tiling_args.h b/csrc/deepep/ops2/op_host/tiling_args.h index 950cbe904..9230c9039 100644 --- a/csrc/deepep/ops2/op_host/tiling_args.h +++ b/csrc/deepep/ops2/op_host/tiling_args.h @@ -3,7 +3,8 @@ #include namespace Moe { -constexpr uint64_t COMBINE_STATE_WIN_OFFSET = 3U * 1024UL * 1024UL; -constexpr uint64_t NOTIFY_DISPATCH_WIN_OFFSET = 204U * 1024UL * 1024UL; +constexpr uint64_t COMBINE_STATE_WIN_OFFSET = 8U * 1024UL * 1024UL; +constexpr uint64_t NOTIFY_DISPATCH_WIN_OFFSET = 404U * 1024UL * 1024UL; +constexpr int64_t STATE_SIZE = 2 * 1024 * 1024; } // namespace Moe #endif // TILING_ARGS_H diff --git a/csrc/deepep/ops2/op_kernel/cam_moe_combine_normal.cpp b/csrc/deepep/ops2/op_kernel/cam_moe_combine_normal.cpp index bed704181..ab37cda91 100644 --- a/csrc/deepep/ops2/op_kernel/cam_moe_combine_normal.cpp +++ b/csrc/deepep/ops2/op_kernel/cam_moe_combine_normal.cpp @@ -1,12 +1,15 @@ #include "kernel_operator.h" #include "lib/matmul_intf.h" #include "cam_moe_combine_normal.h" +#include "cam_moe_combine_normal_multi_round.h" #include "cam_moe_combine_normal_tiling.h" using namespace AscendC; -using namespace CamMoeCombineNormalImpl; + +#define TILINGKEY_MULTI_ROUND 10001 +#define TILINGKEY_SINGLE_ROUND 10000 extern "C" __global__ __aicore__ void cam_moe_combine_normal(GM_ADDR recvX, GM_ADDR tokenSrcInfo, GM_ADDR epRecvCount, - GM_ADDR topkWeights, GM_ADDR topkIdx, GM_ADDR tpRecvCount, + GM_ADDR topkWeights, GM_ADDR tokenIdx, GM_ADDR tpRecvCount, GM_ADDR XOut, GM_ADDR sendCostStatsOut, GM_ADDR workspaceGM, GM_ADDR tilingGM) @@ -16,9 +19,16 @@ extern "C" __global__ __aicore__ void cam_moe_combine_normal(GM_ADDR recvX, GM_A #if (ORIG_DTYPE_RECV_X == DT_BF16 || ORIG_DTYPE_RECV_X == DT_FLOAT16) GET_TILING_DATA_WITH_STRUCT(CamMoeCombineNormalTilingData, tilingData, tilingGM); - CamMoeCombineNormal op; - op.Init(recvX, tokenSrcInfo, epRecvCount, topkWeights, topkIdx, tpRecvCount, XOut, sendCostStatsOut, workspaceGM, - &pipe, tilingGM); - 
op.Process(); + if (TILING_KEY_IS(TILINGKEY_MULTI_ROUND)) { + CamMoeCombineNormalMultiRoundImpl::CamMoeCombineNormalMultiRound op; + op.Init(recvX, tokenSrcInfo, epRecvCount, topkWeights, tokenIdx, tpRecvCount, XOut, sendCostStatsOut, + workspaceGM, &pipe, tilingGM); + op.Process(); + } else if (TILING_KEY_IS(TILINGKEY_SINGLE_ROUND)) { + CamMoeCombineNormalImpl::CamMoeCombineNormal op; + op.Init(recvX, tokenSrcInfo, epRecvCount, topkWeights, tokenIdx, tpRecvCount, XOut, sendCostStatsOut, + workspaceGM, &pipe, tilingGM); + op.Process(); + } #endif } diff --git a/csrc/deepep/ops2/op_kernel/cam_moe_combine_normal.h b/csrc/deepep/ops2/op_kernel/cam_moe_combine_normal.h index 454ae61ad..05667a17b 100644 --- a/csrc/deepep/ops2/op_kernel/cam_moe_combine_normal.h +++ b/csrc/deepep/ops2/op_kernel/cam_moe_combine_normal.h @@ -11,7 +11,7 @@ namespace CamMoeCombineNormalImpl { constexpr uint32_t RANK_ID_OFFSET_IN_SRC_INFO = 0U; constexpr uint32_t TOKEN_IDX_OFFSET_IN_SRC_INFO = 1U; constexpr uint32_t TOPK_IDX_OFFSET_IN_SRC_INFO = 2U; -constexpr uint64_t COMBINE_STATE_WIN_OFFSET = 3UL * 1024UL * 1024UL; +constexpr uint64_t COMBINE_STATE_WIN_OFFSET = 8UL * 1024UL * 1024UL; constexpr uint64_t MAGIC_WIN_OFFSET = 975UL * 1024UL; constexpr uint32_t TOKEN_SRC_INFO_LEN = 3U; constexpr uint32_t UB_32_ALIGN = 32U; @@ -39,14 +39,14 @@ class CamMoeCombineNormal public: __aicore__ inline CamMoeCombineNormal(){}; __aicore__ inline void Init(GM_ADDR recvX, GM_ADDR tokenSrcInfo, GM_ADDR epRecvCount, GM_ADDR topkWeights, - GM_ADDR topkIdx, GM_ADDR tpRecvCount, GM_ADDR XOut, GM_ADDR sendCostStatsOut, + GM_ADDR tokenIdx, GM_ADDR tpRecvCount, GM_ADDR XOut, GM_ADDR sendCostStatsOut, GM_ADDR workspaceGM, TPipe *pipe, GM_ADDR tiling); __aicore__ inline void Process(); private: __aicore__ inline void InitMagic(); __aicore__ inline void InitGlobalBuffer(GM_ADDR recvX, GM_ADDR tokenSrcInfo, GM_ADDR epRecvCount, - GM_ADDR topkWeights, GM_ADDR topkIdx, GM_ADDR XOut, + GM_ADDR topkWeights, GM_ADDR 
tokenIdx, GM_ADDR XOut, GM_ADDR sendCostStatsOut); __aicore__ inline void InitTilingData(__gm__ CamMoeCombineNormalTilingData *tilingData); __aicore__ inline void InitBuffLen(); @@ -99,12 +99,13 @@ class CamMoeCombineNormal uint64_t winDataSizeOffset_{0}; uint32_t selfSendCnt_{0}; uint32_t hRecvXTypeLen_{0}; + uint32_t tokenIdx32AlignLen_{0}; uint32_t h32AlignFloatLen_{0}; uint32_t h256AlignFloatLen_{0}; uint32_t h32AlignRecvXLen_{0}; uint32_t h512AlignRecvXLen_{0}; uint32_t sendCostStatsBufSize_{0}; - uint32_t tokenIdx32AlignLen_{0}; + uint64_t totalWinSize_{0}; bool isEnableDiagnose_{false}; @@ -113,8 +114,8 @@ class CamMoeCombineNormal TQue sendCostStatsOutQueue_; TQueBind localCopyQueue_; TBuf<> stateBuf_; + TBuf<> tokenIdxBuf_; TBuf<> topkWeightsBuf_; - TBuf<> topkIdxBuf_; TBuf<> tokenFloatBuf_; TBuf<> sumFloatBuf_; TBuf<> weightedMulBuf_; @@ -125,8 +126,8 @@ class CamMoeCombineNormal GlobalTensor recvXGM_; GlobalTensor tokenSrcInfoGM_; GlobalTensor epRecvCountGM_; + GlobalTensor tokenIdxGM_; GlobalTensor topkWeightsGM_; - GlobalTensor topkIdxGM_; GlobalTensor xOutGlobal_; GlobalTensor sendCostStatsGT_; GM_ADDR localRankGM_; @@ -137,7 +138,7 @@ template __aicore__ inline void CamMoeCombineNormal::InitMagic() { GlobalTensor selfMagicTensor; - selfMagicTensor.SetGlobalBuffer((__gm__ int32_t *)(hccl_.GetWindowsInAddr(epRankId_) + epWinContext_->winSize - + selfMagicTensor.SetGlobalBuffer((__gm__ int32_t *)(hccl_.GetWindowsInAddr(epRankId_) + totalWinSize_ - Moe::STATE_SIZE + MAGIC_WIN_OFFSET + coreIdx_ * WIN_512_ALIGN)); DataCacheCleanAndInvalid(selfMagicTensor); magic_ = selfMagicTensor(0); @@ -148,15 +149,15 @@ __aicore__ inline void CamMoeCombineNormal::InitMagic() template __aicore__ inline void CamMoeCombineNormal::InitGlobalBuffer(GM_ADDR recvX, GM_ADDR tokenSrcInfo, GM_ADDR epRecvCount, - GM_ADDR topkWeights, GM_ADDR topkIdx, + GM_ADDR topkWeights, GM_ADDR tokenIdx, GM_ADDR XOut, GM_ADDR sendCostStatsOut) { recvXGM_.SetGlobalBuffer((__gm__ RecvXType 
*)recvX); tokenSrcInfoGM_.SetGlobalBuffer((__gm__ SrcInfoType *)tokenSrcInfo); epRecvCountGM_.SetGlobalBuffer((__gm__ int32_t *)epRecvCount); + tokenIdxGM_.SetGlobalBuffer((__gm__ int32_t *)tokenIdx); topkWeightsGM_.SetGlobalBuffer((__gm__ float *)topkWeights); - topkIdxGM_.SetGlobalBuffer((__gm__ int32_t *)topkIdx); xOutGlobal_.SetGlobalBuffer((__gm__ XType *)XOut); if (isEnableDiagnose_) { sendCostStatsGT_.SetGlobalBuffer((__gm__ int32_t *)sendCostStatsOut); @@ -176,6 +177,7 @@ CamMoeCombineNormal::InitTilingData(__gm__ CamMoeCombineNor epWorldSize_ = tilingData->camMoeCombineNormalInfo.epWorldSize; epRankId_ = tilingData->camMoeCombineNormalInfo.epRankId; isEnableDiagnose_ = tilingData->camMoeCombineNormalInfo.isEnableDiagnose; + totalWinSize_ = tilingData->camMoeCombineNormalInfo.totalWinSize; } template @@ -195,7 +197,7 @@ __aicore__ inline void CamMoeCombineNormal::InitBuffLen() template __aicore__ inline void CamMoeCombineNormal::Init(GM_ADDR recvX, GM_ADDR tokenSrcInfo, GM_ADDR epRecvCount, GM_ADDR topkWeights, - GM_ADDR topkIdx, GM_ADDR tpRecvCount, + GM_ADDR tokenIdx, GM_ADDR tpRecvCount, GM_ADDR XOut, GM_ADDR sendCostStatsOut, GM_ADDR workspaceGM, TPipe *pipe, GM_ADDR tiling) { @@ -214,11 +216,11 @@ __aicore__ inline void CamMoeCombineNormal::Init(GM_ADDR re epWinContext_ = (__gm__ HcclOpResParam *)contextGM0; InitTilingData(tilingData); InitMagic(); - InitGlobalBuffer(recvX, tokenSrcInfo, epRecvCount, topkWeights, topkIdx, XOut, sendCostStatsOut); + InitGlobalBuffer(recvX, tokenSrcInfo, epRecvCount, topkWeights, tokenIdx, XOut, sendCostStatsOut); InitBuffLen(); PipeBarrier(); - winDataSizeOffset_ = static_cast(magic_) * (tilingData->camMoeCombineNormalInfo.totalWinSize / 2UL); + winDataSizeOffset_ = static_cast(magic_) * ((totalWinSize_ - 4 * Moe::STATE_SIZE) / 2UL); localRankGM_ = GetBufferAddrByRankId(epRankId_); DataCacheCleanAndInvalid( epRecvCountGM_[moeExpertNum_ - 1]); @@ -240,7 +242,7 @@ __aicore__ inline void 
CamMoeCombineNormal::CopyBufferToSha tpipe_->InitBuffer(stateBuf_, UB_32_ALIGN); tpipe_->InitBuffer(localCopyQueue_, DOUBLE_BUFFER, h32AlignRecvXLen_); tpipe_->InitBuffer(srcInfoBuf_, blockLen); - LocalTensor statusTensor = stateBuf_.AllocTensor(); + LocalTensor statusTensor = stateBuf_.Get(); Duplicate(statusTensor, 0x3F800000, FLOAT_NUM_PER_ALIGN); LocalTensor srcInfoLocal = srcInfoBuf_.Get(); @@ -313,7 +315,7 @@ __aicore__ inline void CamMoeCombineNormal::SetStatusBySrcI uint32_t srcTokenId, uint32_t srcTopkId) { - LocalTensor statusTensor = stateBuf_.AllocTensor(); + LocalTensor statusTensor = stateBuf_.Get(); GM_ADDR stateGM = GetStateAddrByRankId(srcRankId) + (srcTokenId * axisK_ + srcTopkId) * UB_32_ALIGN; GlobalTensor stateGMTensor; stateGMTensor.SetGlobalBuffer((__gm__ uint32_t *)stateGM); @@ -324,16 +326,16 @@ template __aicore__ inline void CamMoeCombineNormal::WaitBuffCopy(uint32_t tokenIndex, uint32_t startTokenIndex) { - uint32_t calCount = axisK_ * FLOAT_NUM_PER_ALIGN; - LocalTensor topkIdxTensorLocal = topkIdxBuf_.Get(); + LocalTensor tokenIdxLocal = tokenIdxBuf_.Get(); int tempValidCount = 0; for (int topkId = 0; topkId < axisK_; ++topkId) { - int expertId = topkIdxTensorLocal.GetValue((tokenIndex - startTokenIndex) * axisK_ + topkId); + int expertId = tokenIdxLocal.GetValue((tokenIndex - startTokenIndex) * axisK_ + topkId); if (expertId < 0 || expertId >= moeExpertNum_) { continue; } ++tempValidCount; } + uint32_t calCount = axisK_ * FLOAT_NUM_PER_ALIGN; GM_ADDR stateGM = GetStateAddrByRankId(epRankId_) + tokenIndex * axisK_ * UB_32_ALIGN; // 计算地址偏移 GlobalTensor stateGMTensor; stateGMTensor.SetGlobalBuffer((__gm__ float *)stateGM); @@ -364,18 +366,17 @@ __aicore__ inline void CamMoeCombineNormal::ReadBufferAndWe LocalTensor weightedMulBufLocal = weightedMulBuf_.Get(); LocalTensor sumFloatBufLocal = sumFloatBuf_.Get(); LocalTensor topkWeightsLocal = topkWeightsBuf_.Get(); - LocalTensor topkIdxTensorLocal = topkIdxBuf_.Get(); LocalTensor 
stateTensorLocal = stateBuf_.Get(); + LocalTensor tokenIdxLocal = tokenIdxBuf_.Get(); Duplicate(sumFloatBufLocal, static_cast(0), axisH_); const DataCopyExtParams xOutCopyParams{1U, static_cast(hRecvXTypeLen_), 0U, 0U, 0U}; for (uint32_t topkId = 0U; topkId < axisK_; topkId++) { - uint32_t topkIdOffset = (tokenIndex - startTokenIndex) * axisK_ + topkId; - int32_t expertId = topkIdxTensorLocal.GetValue(topkIdOffset); + int expertId = tokenIdxLocal.GetValue((tokenIndex - startTokenIndex) * axisK_ + topkId); if (expertId < 0 || expertId >= moeExpertNum_) { continue; } - float scale = topkWeightsLocal.GetValue(topkIdOffset); + float scale = topkWeightsLocal.GetValue((tokenIndex - startTokenIndex) * axisK_ + topkId); GM_ADDR localTokenAddr = localRankGM_ + (tokenIndex * axisK_ + topkId) * h512AlignRecvXLen_; GlobalTensor localTokenTensor; localTokenTensor.SetGlobalBuffer((__gm__ XType *)localTokenAddr); @@ -396,6 +397,7 @@ __aicore__ inline void CamMoeCombineNormal::ReadBufferAndWe LocalTensor xOutLocal = xOutBuf_.Get(); Cast(xOutLocal, sumFloatBufLocal, AscendC::RoundMode::CAST_RINT, axisH_); SyncFunc(); + SyncFunc(); DataCopyPad(xOutGlobal_[tokenIndex * axisH_], xOutLocal, xOutCopyParams); } @@ -407,12 +409,13 @@ __aicore__ inline void CamMoeCombineNormal::ReadBufferFromR } uint32_t tokenPerBlock = 0U, startTokenIndex = 0U, endTokenIndex = 0U; SplitCoreCal(axisBS_, tokenPerBlock, startTokenIndex, endTokenIndex); - tokenIdx32AlignLen_ = Ceil(tokenPerBlock * axisK_ * sizeof(int32_t), UB_32_ALIGN) * UB_32_ALIGN; if (tokenPerBlock == 0U) { return; } + tokenIdx32AlignLen_ = Ceil(tokenPerBlock * axisK_ * sizeof(int32_t), UB_32_ALIGN) * UB_32_ALIGN; + tpipe_->Reset(); tpipe_->InitBuffer(xOutBuf_, h32AlignRecvXLen_); tpipe_->InitBuffer(tokenFloatBuf_, h32AlignFloatLen_); @@ -422,16 +425,16 @@ __aicore__ inline void CamMoeCombineNormal::ReadBufferFromR tpipe_->InitBuffer(stateBuf_, (axisK_)*UB_32_ALIGN); tpipe_->InitBuffer(tempStateBuf_, (axisK_)*UB_32_ALIGN); 
tpipe_->InitBuffer(topkWeightsBuf_, tokenPerBlock * axisK_ * sizeof(float)); - tpipe_->InitBuffer(topkIdxBuf_, tokenIdx32AlignLen_); + tpipe_->InitBuffer(tokenIdxBuf_, tokenIdx32AlignLen_); LocalTensor topkWeightsLocal = topkWeightsBuf_.Get(); - LocalTensor topkIdxTensorLocal = topkIdxBuf_.Get(); + LocalTensor tokenIdxLocal = tokenIdxBuf_.Get(); const DataCopyExtParams bskParams{1U, static_cast(tokenPerBlock * axisK_ * sizeof(float)), 0U, 0U, 0U}; const DataCopyExtParams tokenIdxParams{1U, tokenIdx32AlignLen_, 0U, 0U, 0U}; const DataCopyPadExtParams copyPadFloatParams{false, 0U, 0U, 0U}; const DataCopyPadExtParams copyPadIntParams{false, 0U, 0U, 0U}; DataCopyPad(topkWeightsLocal, topkWeightsGM_[startTokenIndex * axisK_], bskParams, copyPadFloatParams); - DataCopyPad(topkIdxTensorLocal, topkIdxGM_[startTokenIndex * axisK_], tokenIdxParams, copyPadIntParams); + DataCopyPad(tokenIdxLocal, tokenIdxGM_[startTokenIndex * axisK_], tokenIdxParams, copyPadIntParams); SyncFunc(); for (uint32_t tokenIndex = startTokenIndex; tokenIndex < endTokenIndex; tokenIndex++) { diff --git a/csrc/deepep/ops2/op_kernel/cam_moe_combine_normal_multi_round.h b/csrc/deepep/ops2/op_kernel/cam_moe_combine_normal_multi_round.h new file mode 100644 index 000000000..71a80888b --- /dev/null +++ b/csrc/deepep/ops2/op_kernel/cam_moe_combine_normal_multi_round.h @@ -0,0 +1,649 @@ +#ifndef CAM_MOE_COMBINE_NORMAL_MULTI_ROUND_H +#define CAM_MOE_COMBINE_NORMAL_MULTI_ROUND_H + +#include "kernel_operator.h" +#include "kernel_tiling/kernel_tiling.h" +#include "moe_distribute_base.h" +#include "cam_moe_combine_normal_tiling.h" +#include "comm_args.h" + +namespace CamMoeCombineNormalMultiRoundImpl { +constexpr uint32_t RANK_ID_OFFSET_IN_SRC_INFO = 0U; +constexpr uint32_t TOKEN_IDX_OFFSET_IN_SRC_INFO = 1U; +constexpr uint32_t TOPK_IDX_OFFSET_IN_SRC_INFO = 2U; +constexpr uint64_t STATE_WIN_SIZE = 8UL * 1024UL * 1024UL; +constexpr uint64_t STATE_WIN_SIZE_HALF = STATE_WIN_SIZE / 2; +constexpr uint64_t 
MAGIC_WIN_OFFSET = 975UL * 1024UL; +constexpr uint64_t ROUND_STATE_OFFSET = Moe::BASE_ROUND_STATE_OFFSET + Moe::ROUND_STATE_MAX_SIZE * 2UL; // 458*1024 +constexpr uint32_t TOKEN_SRC_INFO_LEN = 3U; +constexpr uint32_t UB_32_ALIGN = 32U; +constexpr uint32_t MUL_256_ALIGN = 256U; +constexpr uint64_t WIN_512_ALIGN = 512UL; +constexpr uint32_t FLOAT_NUM_PER_ALIGN = 8U; +constexpr uint8_t DOUBLE_BUFFER = 2; +constexpr uint32_t WAIT_ROUND_INDEX = 2U; +constexpr int64_t CYCLE_TO_TIME = 50; // cycle num is converted into a fixed base unit of time, set at 50 +constexpr uint32_t STATE_OFFSET = 32U; +constexpr uint32_t BATCH_SRC_INFO_CNT = 128U; + +template +__aicore__ inline void SyncFunc() +{ + int32_t eventID = static_cast(GetTPipePtr()->FetchEventID(event)); + AscendC::SetFlag(eventID); + AscendC::WaitFlag(eventID); +} + +#define TemplateMC2TypeClass typename RecvXType, typename XType, typename SrcInfoType +#define TemplateMC2TypeFunc RecvXType, XType, SrcInfoType + +using namespace AscendC; +template +class CamMoeCombineNormalMultiRound +{ +public: + __aicore__ inline CamMoeCombineNormalMultiRound(){}; + __aicore__ inline void Init(GM_ADDR recvX, GM_ADDR tokenSrcInfo, GM_ADDR epRecvCount, GM_ADDR topkWeights, + GM_ADDR tokenIdx, GM_ADDR tpRecvCount, GM_ADDR XOut, GM_ADDR sendCostStatsOut, + GM_ADDR workspaceGM, TPipe *pipe, GM_ADDR tiling); + __aicore__ inline void Process(); + +private: + __aicore__ inline void InitMagic(); + __aicore__ inline void InitGlobalBuffer(GM_ADDR recvX, GM_ADDR tokenSrcInfo, GM_ADDR epRecvCount, + GM_ADDR topkWeights, GM_ADDR tokenIdx, GM_ADDR XOut, + GM_ADDR sendCostStatsOut); + __aicore__ inline void InitTilingData(__gm__ CamMoeCombineNormalTilingData *tilingData); + __aicore__ inline void InitBuffLen(); + __aicore__ inline void CopyBufferToShareAndSetStatus(); + __aicore__ inline void CopyBufferToShare(uint32_t srcRankId, uint32_t srcTokenId, uint32_t srcTopkId, + uint32_t tkIndex); + __aicore__ inline void ReadBufferFromRemote(); + 
__aicore__ inline void WaitBuffCopy(uint32_t recvXTokenIdx, uint32_t topkWeightTokenIdx); + __aicore__ inline void SetStatusBySrcInfo(uint32_t srcRankId, uint32_t srcTokenId, uint32_t srcTopkId); + __aicore__ inline void ReadBufferAndWeightedSum(uint32_t recvXTokenIdx, uint32_t topkWeightTokenIdx); + __aicore__ inline void InitRoundSendData(); + __aicore__ inline void SetRoundStatus(); + __aicore__ inline void WaitRoundStatus(); + __aicore__ inline void InitRoundRecvData(); + + __aicore__ GM_ADDR GetStateAddrByRankId(const int32_t rankId) + { + return hccl_.GetWindowsInAddr(rankId) + winDataSizeOffset_ + Moe::NOTIFY_DISPATCH_BUFF_OFFSET; + } + + __aicore__ GM_ADDR GetBufferAddrByRankId(const int32_t rankId) + { + return GetStateAddrByRankId(rankId) + STATE_WIN_SIZE + roundMagic_ * combineDataBuffSize_; + } + + __aicore__ inline GM_ADDR GetRoundStateAddrByRankId(const int32_t rankId) + { + return hccl_.GetWindowsInAddr(rankId) + totalWinSize_ - Moe::STATE_SIZE + + roundMagic_ * Moe::ROUND_STATE_MAX_SIZE + ROUND_STATE_OFFSET; + } + + __aicore__ inline void SplitCoreCal(uint32_t totalNum, uint32_t &perCoreNum, uint32_t &startIdx, uint32_t &endIdx) + { + perCoreNum = totalNum / aivNum_; + uint32_t remainderRankNum = totalNum % aivNum_; + + startIdx = perCoreNum * coreIdx_; + if (coreIdx_ < remainderRankNum) { + perCoreNum++; + startIdx += coreIdx_; + } else { + startIdx += remainderRankNum; + } + endIdx = startIdx + perCoreNum; + } + + Hccl hccl_; + uint32_t axisBS_{0}; + uint32_t axisH_{0}; + uint32_t axisK_{0}; + uint32_t aivNum_{0}; + uint32_t epWorldSize_{0}; + uint32_t epRankId_{0}; + uint32_t coreIdx_{0}; + uint32_t moeExpertNum_{0}; + uint32_t moeExpertPerRankNum_{0}; + uint32_t magic_{0}; + uint32_t roundMagic_{0}; + uint64_t winDataSizeOffset_{0}; + uint32_t hRecvXTypeLen_{0}; + uint32_t h32AlignFloatLen_{0}; + uint32_t h256AlignFloatLen_{0}; + uint32_t h32AlignRecvXLen_{0}; + uint32_t h512AlignRecvXLen_{0}; + uint32_t tokenIdx32AlignLen_{0}; + uint32_t 
roundIndex_{0}; + uint32_t realMaxBs_{0}; + uint32_t perRoundTokens_{0}; + uint64_t totalWinSize_{0}; + uint32_t maxRound_{0}; + // data used by the send phase + uint32_t sendCostStatsBufSize_{0}; + uint32_t needSendTokenCnt_{0}; + uint32_t RecvTokenNum_{0}; + uint32_t perCoreBlockNum_{0}; // number of blocks this core is responsible for; one block is the tokens that one expert received from one rank + uint32_t startBlockId_{0}; + uint32_t endBlockId_{0}; + uint32_t preRecvCount_{0}; + // data used by the receive phase + uint32_t totalNeedRecvTokenCnt_{0}; // number of tokens still to be received, initialized to axisBS_ + uint32_t roundTotalRecvTokenCnt_{0}; // total number of tokens all cores need to receive in one round + uint32_t roundRecvTokenCnt_{0}; // number of tokens each core receives in one round; recomputed before each round's receive starts + uint32_t roundRecvStartTokenIdx_{0}; // start index of the tokens each core receives from the HCCL buffer in one round; recomputed before each round's receive starts + uint32_t roundRecvEndTokenIdx_{0}; // end index of the tokens each core receives from the HCCL buffer in one round; recomputed before each round's receive starts + // offset in xOut where the tokens received in this round are stored, i.e. the number of tokens received in the previous rounds; also needed when each core copies weights from topkWeightsGM_ every round + uint32_t xOutTokenOffset_{0}; + uint32_t stateOffset_{0}; + uint32_t combineDataBuffSize_{0}; + + bool isEnableDiagnose_{false}; + + TPipe *tpipe_{nullptr}; + TQue weightedSumQueue_; + TQue sendCostStatsOutQueue_; + TQueBind localCopyQueue_; + TBuf<> setStateBuf_; + TBuf<> waitStateBuf_; + TBuf<> waitTempStateBuf_; + TBuf<> topkWeightsBuf_; + TBuf<> tokenFloatBuf_; + TBuf<> sumFloatBuf_; + TBuf<> weightedMulBuf_; + TBuf<> srcInfoBuf_; + TBuf<> tokenIdxBuf_; + TBuf<> xOutBuf_; + TBuf<> setRoundStateBuf_; + TBuf<> waitRoundStateBuf_; + TBuf<> tempRoundStateBuf_; + TBuf<> roundNeedSendCntBuf_; + TBuf<> roundSendOffsetBuf_; + TBuf<> tempRecvCountBuf_; + + LocalTensor setStateLT_; + LocalTensor roundNeedSendCntLT_; + LocalTensor roundSendOffsetLT_; + LocalTensor srcInfoLT_; + LocalTensor tokenIdxLT_; + LocalTensor topkWeightsLT_; + + GlobalTensor recvXGM_; + GlobalTensor tokenSrcInfoGM_; + GlobalTensor epRecvCountGM_; + GlobalTensor tokenIdxGM_; + GlobalTensor topkWeightsGM_; + GlobalTensor xOutGlobal_; + GlobalTensor sendCostStatsGT_; + GlobalTensor dstRoundStatusGT_; + GM_ADDR workspaceGM_; +}; + +template 
+__aicore__ inline void CamMoeCombineNormalMultiRound::InitMagic() +{ + GlobalTensor selfMagicTensor; + selfMagicTensor.SetGlobalBuffer((__gm__ int32_t *)(hccl_.GetWindowsInAddr(epRankId_) + totalWinSize_ - + Moe::STATE_SIZE + MAGIC_WIN_OFFSET + coreIdx_ * WIN_512_ALIGN)); + DataCacheCleanAndInvalid(selfMagicTensor); + magic_ = selfMagicTensor(0); + selfMagicTensor(0) = ((magic_ == 0) ? 1 : 0); + DataCacheCleanAndInvalid(selfMagicTensor); +} + +template +__aicore__ inline void CamMoeCombineNormalMultiRound::InitGlobalBuffer( + GM_ADDR recvX, GM_ADDR tokenSrcInfo, GM_ADDR epRecvCount, GM_ADDR topkWeights, GM_ADDR tokenIdx, GM_ADDR XOut, + GM_ADDR sendCostStatsOut) +{ + recvXGM_.SetGlobalBuffer((__gm__ RecvXType *)recvX); + tokenSrcInfoGM_.SetGlobalBuffer((__gm__ SrcInfoType *)tokenSrcInfo); + epRecvCountGM_.SetGlobalBuffer((__gm__ int32_t *)epRecvCount); + tokenIdxGM_.SetGlobalBuffer((__gm__ int32_t *)tokenIdx); + topkWeightsGM_.SetGlobalBuffer((__gm__ float *)topkWeights); + xOutGlobal_.SetGlobalBuffer((__gm__ XType *)XOut); + if (isEnableDiagnose_) { + sendCostStatsGT_.SetGlobalBuffer((__gm__ int32_t *)sendCostStatsOut); + } +} + +template +__aicore__ inline void +CamMoeCombineNormalMultiRound::InitTilingData(__gm__ CamMoeCombineNormalTilingData *tilingData) +{ + axisBS_ = tilingData->camMoeCombineNormalInfo.bs; + axisH_ = tilingData->camMoeCombineNormalInfo.h; + axisK_ = tilingData->camMoeCombineNormalInfo.k; + aivNum_ = tilingData->camMoeCombineNormalInfo.aivNum; + moeExpertNum_ = tilingData->camMoeCombineNormalInfo.moeExpertNum; + moeExpertPerRankNum_ = tilingData->camMoeCombineNormalInfo.moeExpertPerRankNum; + epWorldSize_ = tilingData->camMoeCombineNormalInfo.epWorldSize; + epRankId_ = tilingData->camMoeCombineNormalInfo.epRankId; + isEnableDiagnose_ = tilingData->camMoeCombineNormalInfo.isEnableDiagnose; + realMaxBs_ = tilingData->camMoeCombineNormalInfo.realMaxBs; + maxRound_ = tilingData->camMoeCombineNormalInfo.maxRound; + perRoundTokens_ = 
tilingData->camMoeCombineNormalInfo.perRoundTokens; + totalWinSize_ = tilingData->camMoeCombineNormalInfo.totalWinSize; +} + +template +__aicore__ inline void CamMoeCombineNormalMultiRound::InitBuffLen() +{ + uint32_t hFloatSize = axisH_ * static_cast(sizeof(float)); + h32AlignFloatLen_ = Ceil(hFloatSize, UB_32_ALIGN) * UB_32_ALIGN; + h256AlignFloatLen_ = Ceil(hFloatSize, MUL_256_ALIGN) * MUL_256_ALIGN; + hRecvXTypeLen_ = axisH_ * sizeof(RecvXType); + h32AlignRecvXLen_ = Ceil(hRecvXTypeLen_, UB_32_ALIGN) * UB_32_ALIGN; + h512AlignRecvXLen_ = Ceil(hRecvXTypeLen_, WIN_512_ALIGN) * WIN_512_ALIGN; + if (isEnableDiagnose_) { + sendCostStatsBufSize_ = Ceil(epWorldSize_ * sizeof(int32_t), UB_32_ALIGN) * UB_32_ALIGN; + } +} + +template +__aicore__ inline void CamMoeCombineNormalMultiRound::InitRoundSendData() +{ + SplitCoreCal(moeExpertNum_, perCoreBlockNum_, startBlockId_, + endBlockId_); // split the cores by expert; each core is responsible for sending data to perBlockRankNum ranks + if (perCoreBlockNum_ == 0) { + return; + } + uint32_t sendBlockLen = perCoreBlockNum_ * sizeof(int32_t); + tpipe_->Reset(); + tpipe_->InitBuffer(tempRecvCountBuf_, sendBlockLen); // 64B + tpipe_->InitBuffer(roundNeedSendCntBuf_, sendBlockLen); // 64B + tpipe_->InitBuffer(roundSendOffsetBuf_, sendBlockLen); // 64B + + // copy epRecvCountGM_ into UB + LocalTensor tempRecvCountTensor = tempRecvCountBuf_.Get(); + const DataCopyExtParams sendBlockCopyParams{1U, sendBlockLen, 0U, 0U, 0U}; + const DataCopyPadExtParams sendBlockPadParams{false, 0U, 0U, 0U}; + DataCopyPad(tempRecvCountTensor, epRecvCountGM_[startBlockId_], sendBlockCopyParams, sendBlockPadParams); + SyncFunc(); + + // each core computes the number of tokens to send to every expert and their start offset, as well as the srcInfo offset of every block + preRecvCount_ = startBlockId_ == 0 ? 
0 : epRecvCountGM_(startBlockId_ - 1); // records the start offset of the tokens sent by the current core + needSendTokenCnt_ = tempRecvCountTensor(perCoreBlockNum_ - 1) - preRecvCount_; + roundNeedSendCntLT_ = roundNeedSendCntBuf_.Get(); + roundSendOffsetLT_ = roundSendOffsetBuf_.Get(); + roundSendOffsetLT_(0) = preRecvCount_; + roundNeedSendCntLT_(0) = tempRecvCountTensor(0) - preRecvCount_; + for (uint32_t i = 1; i < perCoreBlockNum_; ++i) { + roundSendOffsetLT_(i) = tempRecvCountTensor(i - 1); + roundNeedSendCntLT_(i) = tempRecvCountTensor(i) - tempRecvCountTensor(i - 1); + } + + // create srcInfoLT_ + // to support up to 8192 bs in one round, srcInfo is copied in batches of BATCH_SRC_INFO_CNT entries, which keeps the UB usage low + uint32_t srcInfoLen = static_cast(BATCH_SRC_INFO_CNT * TOKEN_SRC_INFO_LEN * sizeof(SrcInfoType)); + tpipe_->InitBuffer(srcInfoBuf_, srcInfoLen); // 128*3*4/1024=1.5KB + srcInfoLT_ = srcInfoBuf_.Get(); + + // create setStateLT_ + tpipe_->InitBuffer(setStateBuf_, UB_32_ALIGN); // 32B + setStateLT_ = setStateBuf_.Get(); + Duplicate(setStateLT_, 0x3F800000, FLOAT_NUM_PER_ALIGN); + + // create localCopyQueue_, used to hold the tokens copied from GM into UB + tpipe_->InitBuffer(localCopyQueue_, DOUBLE_BUFFER, h32AlignRecvXLen_); // 28KB +} + +template +__aicore__ inline void CamMoeCombineNormalMultiRound::InitRoundRecvData() +{ + totalNeedRecvTokenCnt_ = axisBS_; + + // per the original author: each core needs to receive at most Ceil(perRoundTokens_, aivNum_) * aivNum_ tokens in one round, so topkWeightsBuf_ only needs to be that large + tpipe_->InitBuffer(xOutBuf_, h32AlignRecvXLen_); // 14KB + tpipe_->InitBuffer(tokenFloatBuf_, h32AlignFloatLen_); // 28KB + tpipe_->InitBuffer(weightedMulBuf_, h256AlignFloatLen_); // 28KB + tpipe_->InitBuffer(sumFloatBuf_, h32AlignFloatLen_); // 28KB + tpipe_->InitBuffer(weightedSumQueue_, DOUBLE_BUFFER, h32AlignRecvXLen_); // 14KB + tpipe_->InitBuffer(waitStateBuf_, axisK_ * UB_32_ALIGN); // 196B + tpipe_->InitBuffer(waitTempStateBuf_, axisK_ * UB_32_ALIGN); // 196B + tpipe_->InitBuffer(setRoundStateBuf_, epWorldSize_ * FLOAT_NUM_PER_ALIGN * sizeof(float)); // used for SetRoundStatus + tpipe_->InitBuffer(waitRoundStateBuf_, epWorldSize_ * 
FLOAT_NUM_PER_ALIGN * sizeof(float)); // used for WaitRoundStatus + tpipe_->InitBuffer(tempRoundStateBuf_, epWorldSize_ * FLOAT_NUM_PER_ALIGN * sizeof(float)); // used for WaitRoundStatus + + // create topkWeightsLT_, which holds the weight information of each core for every round + uint32_t maxTopkWeightsLen = (perRoundTokens_ / aivNum_ + 1) * axisK_ * sizeof(float); + tokenIdx32AlignLen_ = (perRoundTokens_ / aivNum_ + 1) * axisK_ * sizeof(int32_t); /* NOTE(review): despite the "32Align" in its name, this length is not rounded up to UB_32_ALIGN here (the single-round kernel uses Ceil(...) * UB_32_ALIGN) - confirm InitBuffer accepts that */ + tpipe_->InitBuffer(tokenIdxBuf_, tokenIdx32AlignLen_); + tpipe_->InitBuffer(topkWeightsBuf_, maxTopkWeightsLen); // 512 split across 48 cores needs 352B + tokenIdxLT_ = tokenIdxBuf_.Get(); + topkWeightsLT_ = topkWeightsBuf_.Get(); +} + +template +__aicore__ inline void CamMoeCombineNormalMultiRound::Init( + GM_ADDR recvX, GM_ADDR tokenSrcInfo, GM_ADDR epRecvCount, GM_ADDR topkWeights, GM_ADDR tokenIdx, + GM_ADDR tpRecvCount, GM_ADDR XOut, GM_ADDR sendCostStatsOut, GM_ADDR workspaceGM, TPipe *pipe, GM_ADDR tiling) +{ + workspaceGM_ = workspaceGM; + tpipe_ = pipe; + coreIdx_ = GetBlockIdx(); + stateOffset_ = STATE_OFFSET; + + auto tilingData = (__gm__ CamMoeCombineNormalTilingData *)tiling; + __gm__ void *mc2InitTiling = (__gm__ void *)(&(tilingData->mc2InitTiling)); + __gm__ void *mc2CcTiling = (__gm__ void *)(&(tilingData->mc2CcTiling1)); + + auto contextGM0 = AscendC::GetHcclContext(); + + hccl_.Init(contextGM0, mc2InitTiling); + hccl_.SetCcTiling(mc2CcTiling); + + InitTilingData(tilingData); + InitMagic(); + InitGlobalBuffer(recvX, tokenSrcInfo, epRecvCount, topkWeights, tokenIdx, XOut, sendCostStatsOut); + InitBuffLen(); + combineDataBuffSize_ = perRoundTokens_ * axisK_ * h512AlignRecvXLen_; + PipeBarrier(); + winDataSizeOffset_ = static_cast(magic_) * ((totalWinSize_ - 4 * Moe::STATE_SIZE) / 2UL); + DataCacheCleanAndInvalid( + epRecvCountGM_[moeExpertNum_ - 1]); + + InitRoundSendData(); + InitRoundRecvData(); +} + +template +__aicore__ inline void CamMoeCombineNormalMultiRound::CopyBufferToShareAndSetStatus() +{ + if (needSendTokenCnt_ == 0) { + return; + } + LocalTensor sendCostStatsTensor; + if 
(isEnableDiagnose_) { + tpipe_->InitBuffer(sendCostStatsOutQueue_, DOUBLE_BUFFER, sendCostStatsBufSize_); /* NOTE(review): Process() calls this function once per round, so with diagnose enabled this InitBuffer can run again on the same queue every round - confirm that is intended */ + sendCostStatsTensor = sendCostStatsOutQueue_.AllocTensor(); + Duplicate(sendCostStatsTensor, 0, sendCostStatsBufSize_ / sizeof(int32_t)); + } + + uint32_t startTokenIndex = preRecvCount_; /* NOTE(review): startTokenIndex is not read again in this function */ + int64_t sendStartCycle; + for (uint32_t blockIndex = 0; blockIndex < perCoreBlockNum_; ++blockIndex) { + uint32_t roundMaxSendCount = roundNeedSendCntLT_(blockIndex) >= perRoundTokens_ + ? perRoundTokens_ + : roundNeedSendCntLT_(blockIndex); // at most roundMaxSendCount tokens are sent in this round + + uint32_t sendTokenOffset = roundSendOffsetLT_(blockIndex); + uint32_t startTokenId = sendTokenOffset; // offset in recvX of the tokens to send in this round + uint32_t roundActualSendCount = 0; // number of tokens blockIndex actually sends in this round, <= roundMaxSendCount + while (roundActualSendCount < roundMaxSendCount) { + uint32_t recvXTokenIdx = startTokenId + roundActualSendCount; // position in recvX of the token to send + uint32_t tokenIdxInBatch = roundActualSendCount % BATCH_SRC_INFO_CNT; + if (tokenIdxInBatch == 0) { + uint32_t tokenCount = min(BATCH_SRC_INFO_CNT, roundMaxSendCount - roundActualSendCount); + uint32_t srcInfoLen = tokenCount * TOKEN_SRC_INFO_LEN * sizeof(SrcInfoType); + const DataCopyExtParams dataCopyParams{1U, srcInfoLen, 0U, 0U, 0U}; + const DataCopyPadExtParams padParams{false, 0U, 0U, 0U}; + DataCopyPad(srcInfoLT_, tokenSrcInfoGM_[recvXTokenIdx * TOKEN_SRC_INFO_LEN], dataCopyParams, padParams); + SyncFunc(); + } + // position within srcInfoLT_ of the srcInfo of the token to send, so the start offset is 0 + uint32_t srcInfoIdx = tokenIdxInBatch * TOKEN_SRC_INFO_LEN; + uint32_t srcRankId = static_cast(srcInfoLT_(srcInfoIdx + RANK_ID_OFFSET_IN_SRC_INFO)); + uint32_t srcTokenId = static_cast(srcInfoLT_(srcInfoIdx + TOKEN_IDX_OFFSET_IN_SRC_INFO)); + if (srcTokenId >= (roundIndex_ + 1) * perRoundTokens_) { + // the number of tokens actually sent in this round; the receiver takes at most perRoundTokens_ tokens per round + break; + } + uint32_t srcTopkId = static_cast(srcInfoLT_(srcInfoIdx + TOPK_IDX_OFFSET_IN_SRC_INFO)); + // (EN: every round, the offsets used to put token and state into the target rank's hccl buffer are computed from 0 again; original text continues on the next line) 每一轮 put token和state 到目标rank的hccl 
buffer的偏移都要从0开始计算 + uint32_t roundTokenId = srcTokenId % perRoundTokens_; + if (isEnableDiagnose_) { + sendStartCycle = GetSystemCycle(); + } + CopyBufferToShare(srcRankId, roundTokenId, srcTopkId, recvXTokenIdx); + PipeBarrier(); + SetStatusBySrcInfo(srcRankId, roundTokenId, srcTopkId); + + if (isEnableDiagnose_) { + SyncFunc(); + int32_t durationTime = static_cast((GetSystemCycle() - sendStartCycle) / CYCLE_TO_TIME); // us + int32_t preTime = sendCostStatsTensor.GetValue(srcRankId); + sendCostStatsTensor.SetValue(srcRankId, preTime + durationTime); + } + ++roundActualSendCount; + } + + roundSendOffsetLT_(blockIndex) += roundActualSendCount; + roundNeedSendCntLT_(blockIndex) -= roundActualSendCount; + needSendTokenCnt_ -= roundActualSendCount; + } + + if (isEnableDiagnose_) { + SyncFunc(); + AscendC::SetAtomicAdd(); + DataCopyExtParams statsCopyOutParams = {1U, static_cast(epWorldSize_ * sizeof(int32_t)), 0U, 0U, 0U}; + DataCopyPad(sendCostStatsGT_, sendCostStatsTensor, statsCopyOutParams); + AscendC::SetAtomicNone(); + sendCostStatsOutQueue_.FreeTensor(sendCostStatsTensor); + } + SyncFunc(); +} + +template +__aicore__ inline void CamMoeCombineNormalMultiRound::CopyBufferToShare(uint32_t srcRankId, + uint32_t srcTokenId, + uint32_t srcTopkId, + uint32_t tkIndex) +{ + uint32_t tokenOffset = tkIndex * axisH_; + GM_ADDR dstGM = GetBufferAddrByRankId(srcRankId) + (srcTokenId * axisK_ + srcTopkId) * h512AlignRecvXLen_; + GlobalTensor dstWindow; + dstWindow.SetGlobalBuffer((__gm__ XType *)dstGM); + DataCopyExtParams xOutCopyParams{1U, static_cast(hRecvXTypeLen_), 0U, 0U, 0U}; + DataCopyPadExtParams copyPadExtParams{false, 0U, 0U, 0U}; + + LocalTensor localCopyTensor; + localCopyTensor = localCopyQueue_.AllocTensor(); + DataCopyPad(localCopyTensor, recvXGM_[tokenOffset], xOutCopyParams, copyPadExtParams); + localCopyQueue_.EnQue(localCopyTensor); + localCopyTensor = localCopyQueue_.DeQue(); + DataCopyPad(dstWindow, localCopyTensor, xOutCopyParams); + 
localCopyQueue_.FreeTensor(localCopyTensor); +} + +template +__aicore__ inline void CamMoeCombineNormalMultiRound::SetStatusBySrcInfo(uint32_t srcRankId, + uint32_t srcTokenId, + uint32_t srcTopkId) +{ + uint32_t stateOffset = roundMagic_ * STATE_WIN_SIZE_HALF; + GM_ADDR stateGM = GetStateAddrByRankId(srcRankId) + stateOffset + (srcTokenId * axisK_ + srcTopkId) * UB_32_ALIGN; + GlobalTensor stateGMTensor; + stateGMTensor.SetGlobalBuffer((__gm__ uint32_t *)stateGM); + DataCopy(stateGMTensor, setStateLT_, FLOAT_NUM_PER_ALIGN); +} + +template +__aicore__ inline void CamMoeCombineNormalMultiRound::WaitBuffCopy(uint32_t recvXTokenIdx, + uint32_t topkWeightTokenIdx) +{ + int tempValidCount = 0; + for (int topkId = 0; topkId < axisK_; ++topkId) { + int expertId = tokenIdxLT_.GetValue((topkWeightTokenIdx)*axisK_ + topkId); + if (expertId < 0 || expertId >= moeExpertNum_) { + continue; + } + ++tempValidCount; + } + uint32_t calCount = axisK_ * FLOAT_NUM_PER_ALIGN; + uint32_t stateOffset = roundMagic_ * STATE_WIN_SIZE_HALF; + GM_ADDR stateGM = + GetStateAddrByRankId(epRankId_) + stateOffset + recvXTokenIdx * axisK_ * UB_32_ALIGN; // compute the address offset + GlobalTensor stateGMTensor; + stateGMTensor.SetGlobalBuffer((__gm__ float *)stateGM); + float current = (float)0.0; + float target = (float)1.0 * tempValidCount * FLOAT_NUM_PER_ALIGN; + SumParams sumPerKParams{1, calCount, calCount}; + LocalTensor stateTensorLocal = waitStateBuf_.Get(); + LocalTensor tempStateTensorLocal = waitTempStateBuf_.Get(); + while (current != target) { + SyncFunc(); + DataCopy(stateTensorLocal, stateGMTensor, calCount); + SyncFunc(); + Sum(tempStateTensorLocal, stateTensorLocal, sumPerKParams); + SyncFunc(); + current = tempStateTensorLocal(0); + } + SyncFunc(); + Duplicate(tempStateTensorLocal, (float)0.0, calCount); + SyncFunc(); + DataCopy(stateGMTensor, tempStateTensorLocal, calCount); +} + +template +__aicore__ inline void CamMoeCombineNormalMultiRound::ReadBufferAndWeightedSum( + uint32_t recvXTokenIdx, 
uint32_t topkWeightTokenIdx) +{ + LocalTensor tokenFloatLocal = tokenFloatBuf_.Get(); + LocalTensor weightedMulBufLocal = weightedMulBuf_.Get(); + LocalTensor sumFloatBufLocal = sumFloatBuf_.Get(); + Duplicate(sumFloatBufLocal, static_cast(0), axisH_); + const DataCopyExtParams xOutCopyParams{1U, static_cast(hRecvXTypeLen_), 0U, 0U, 0U}; + uint32_t xOutTokenIdx = recvXTokenIdx + xOutTokenOffset_; + + for (uint32_t topkId = 0U; topkId < axisK_; topkId++) { + int expertId = tokenIdxLT_.GetValue(topkWeightTokenIdx * axisK_ + topkId); + if (expertId < 0 || expertId >= moeExpertNum_) { + continue; + } + float scale = topkWeightsLT_.GetValue(topkWeightTokenIdx * axisK_ + topkId); + GM_ADDR localTokenAddr = + GetBufferAddrByRankId(epRankId_) + (recvXTokenIdx * axisK_ + topkId) * h512AlignRecvXLen_; + GlobalTensor localTokenTensor; + localTokenTensor.SetGlobalBuffer((__gm__ XType *)localTokenAddr); + + LocalTensor tmpToken = weightedSumQueue_.AllocTensor(); + const DataCopyPadExtParams copyPadExtParams{false, 0U, 0U, 0U}; + DataCopyPad(tmpToken, localTokenTensor, xOutCopyParams, copyPadExtParams); + weightedSumQueue_.EnQue(tmpToken); + tmpToken = weightedSumQueue_.DeQue(); + Cast(tokenFloatLocal, tmpToken, AscendC::RoundMode::CAST_NONE, axisH_); + PipeBarrier(); + AscendC::Muls(weightedMulBufLocal, tokenFloatLocal, scale, axisH_); + PipeBarrier(); + AscendC::Add(sumFloatBufLocal, sumFloatBufLocal, weightedMulBufLocal, axisH_); + weightedSumQueue_.FreeTensor(tmpToken); + } + PipeBarrier(); + LocalTensor xOutLocal = xOutBuf_.Get(); + Cast(xOutLocal, sumFloatBufLocal, AscendC::RoundMode::CAST_RINT, axisH_); + SyncFunc(); + DataCopyPad(xOutGlobal_[xOutTokenIdx * axisH_], xOutLocal, xOutCopyParams); +} + +template +__aicore__ inline void CamMoeCombineNormalMultiRound::ReadBufferFromRemote() +{ + if (totalNeedRecvTokenCnt_ == 0) { + return; + } + roundTotalRecvTokenCnt_ = min(perRoundTokens_, totalNeedRecvTokenCnt_); + SplitCoreCal(roundTotalRecvTokenCnt_, roundRecvTokenCnt_, 
roundRecvStartTokenIdx_, roundRecvEndTokenIdx_); + if (roundRecvTokenCnt_ == 0) { + return; + } + const DataCopyExtParams bskParams{1U, static_cast(roundRecvTokenCnt_ * axisK_ * sizeof(float)), 0U, 0U, + 0U}; + const DataCopyExtParams tokenIdxParams{1U, static_cast(roundRecvTokenCnt_ * axisK_ * sizeof(int32_t)), 0U, + 0U, 0U}; + const DataCopyPadExtParams copyPadFloatParams{false, 0U, 0U, 0U}; + const DataCopyPadExtParams copyPadIntParams{false, 0U, 0U, 0U}; + DataCopyPad(topkWeightsLT_, topkWeightsGM_[(xOutTokenOffset_ + roundRecvStartTokenIdx_) * axisK_], bskParams, + copyPadFloatParams); + DataCopyPad(tokenIdxLT_, tokenIdxGM_[(xOutTokenOffset_ + roundRecvStartTokenIdx_) * axisK_], tokenIdxParams, + copyPadIntParams); + PipeBarrier(); + SyncFunc(); + + for (uint32_t roundTokenIdx = roundRecvStartTokenIdx_; roundTokenIdx < roundRecvEndTokenIdx_; + roundTokenIdx++) { // every round reads the put data starting from the beginning of the hccl buffer + uint32_t topkWeightIdx = roundTokenIdx - roundRecvStartTokenIdx_; // used to compute the offset of the weight that corresponds to each round's token + WaitBuffCopy(roundTokenIdx, topkWeightIdx); + SyncFunc(); // same tensor as the DataCopy that moves the result out + ReadBufferAndWeightedSum(roundTokenIdx, topkWeightIdx); + } + totalNeedRecvTokenCnt_ -= roundTotalRecvTokenCnt_; + xOutTokenOffset_ += roundTotalRecvTokenCnt_; +} +template +__aicore__ inline void CamMoeCombineNormalMultiRound::SetRoundStatus() +{ + if (coreIdx_ != 0) { + return; + } + LocalTensor roundStateTensor = setRoundStateBuf_.Get(); + Duplicate(roundStateTensor, 1.0, FLOAT_NUM_PER_ALIGN); + SyncFunc(); + for (uint32_t i = 0; i < epWorldSize_; ++i) { + uint32_t targetRankId = i; + uint32_t offset = stateOffset_ * epRankId_; + GM_ADDR rankGM = GetRoundStateAddrByRankId(targetRankId) + offset; + dstRoundStatusGT_.SetGlobalBuffer((__gm__ float *)rankGM); + DataCopy(dstRoundStatusGT_, roundStateTensor, FLOAT_NUM_PER_ALIGN); + } + SyncFunc(); +} + +template +__aicore__ inline void CamMoeCombineNormalMultiRound::WaitRoundStatus() +{ + if (coreIdx_ != 0) { + return; + } + uint32_t count = 
epWorldSize_ * FLOAT_NUM_PER_ALIGN; + uint32_t inner = (count * sizeof(float) + 32 - 1) / 32 * 32 / sizeof(float); + GM_ADDR roundStateGM = GetRoundStateAddrByRankId(epRankId_); + GlobalTensor roundStatusGMTensor; + + roundStatusGMTensor.SetGlobalBuffer((__gm__ float *)roundStateGM); + float current = (float)0.0; + float target = (float)(1.0) * epWorldSize_ * FLOAT_NUM_PER_ALIGN; + SumParams sumPerRankParams{1, inner, count}; + LocalTensor stateTensorLocal = waitRoundStateBuf_.Get(); + LocalTensor tempRoundStateTensorLocal = tempRoundStateBuf_.Get(); + + while (current != target) { + SyncFunc(); + DataCopy(stateTensorLocal, roundStatusGMTensor, count); + SyncFunc(); + Sum(tempRoundStateTensorLocal, stateTensorLocal, sumPerRankParams); + SyncFunc(); + current = tempRoundStateTensorLocal.GetValue(0); + } + + SyncFunc(); + Duplicate(tempRoundStateTensorLocal, (float)0.0, count); + SyncFunc(); + DataCopy(roundStatusGMTensor, tempRoundStateTensorLocal, count); +} + +template +__aicore__ inline void CamMoeCombineNormalMultiRound::Process() +{ + if ASCEND_IS_AIV { // 全aiv处理 + uint32_t realRound = (realMaxBs_ + perRoundTokens_ - 1) / perRoundTokens_; + while (roundIndex_ < realRound) { + CopyBufferToShareAndSetStatus(); + ReadBufferFromRemote(); + if (realRound > 1) { + SetRoundStatus(); + WaitRoundStatus(); + roundMagic_ = roundMagic_ == 0 ? 
1 : 0; + SyncAll(); + } + roundIndex_ += 1; + } + } + hccl_.Finalize(); +} + +} // namespace CamMoeCombineNormalMultiRoundImpl +#endif // MOE_COMBINE_IMPL_H diff --git a/csrc/deepep/ops2/op_kernel/cam_moe_combine_normal_tiling.h b/csrc/deepep/ops2/op_kernel/cam_moe_combine_normal_tiling.h index f7205a549..89d8e29aa 100644 --- a/csrc/deepep/ops2/op_kernel/cam_moe_combine_normal_tiling.h +++ b/csrc/deepep/ops2/op_kernel/cam_moe_combine_normal_tiling.h @@ -4,7 +4,6 @@ #include #include "kernel_tiling/kernel_tiling.h" -// a3 struct CamMoeCombineNormalInfo { uint32_t epWorldSize; uint32_t tpWorldSize; diff --git a/csrc/deepep/ops2/op_kernel/cam_moe_dispatch_normal.cpp b/csrc/deepep/ops2/op_kernel/cam_moe_dispatch_normal.cpp index af0288c59..3de89a5aa 100644 --- a/csrc/deepep/ops2/op_kernel/cam_moe_dispatch_normal.cpp +++ b/csrc/deepep/ops2/op_kernel/cam_moe_dispatch_normal.cpp @@ -20,8 +20,9 @@ extern "C" __global__ __aicore__ void cam_moe_dispatch_normal( if (TILING_KEY_IS(TILINGKEY_NO_QUANT)) { GET_TILING_DATA_WITH_STRUCT(CamMoeDispatchNormalTilingData, tilingData, tilingGM); CamMoeDispatchNormal op; - op.Init(x, expertIds, send_offset, send_token_idx, recv_offset, recv_count, expandXOut, dynamicScalesOut, - assist_info_for_combine, waitRecvCostStatsOut, workspaceGM, &pipe, tilingGM); + op.Init(x, expertIds, send_offset, send_token_idx, recv_offset, recv_count, expert_global_offset, + srcrank_in_expert_offset, r_in_srcrank_offset, expandXOut, dynamicScalesOut, assist_info_for_combine, + waitRecvCostStatsOut, workspaceGM, &pipe, tilingGM); op.Process(); return; } @@ -29,8 +30,9 @@ extern "C" __global__ __aicore__ void cam_moe_dispatch_normal( if (TILING_KEY_IS(TILINGKEY_QUANT)) { GET_TILING_DATA_WITH_STRUCT(CamMoeDispatchNormalTilingData, tilingData, tilingGM); CamMoeDispatchNormal op; - op.Init(x, expertIds, send_offset, send_token_idx, recv_offset, recv_count, expandXOut, dynamicScalesOut, - assist_info_for_combine, waitRecvCostStatsOut, workspaceGM, &pipe, 
tilingGM); + op.Init(x, expertIds, send_offset, send_token_idx, recv_offset, recv_count, expert_global_offset, + srcrank_in_expert_offset, r_in_srcrank_offset, expandXOut, dynamicScalesOut, assist_info_for_combine, + waitRecvCostStatsOut, workspaceGM, &pipe, tilingGM); op.Process(); return; } diff --git a/csrc/deepep/ops2/op_kernel/cam_moe_dispatch_normal.h b/csrc/deepep/ops2/op_kernel/cam_moe_dispatch_normal.h index 718724127..75ebb2ed4 100644 --- a/csrc/deepep/ops2/op_kernel/cam_moe_dispatch_normal.h +++ b/csrc/deepep/ops2/op_kernel/cam_moe_dispatch_normal.h @@ -20,8 +20,10 @@ constexpr uint64_t STATE_WIN_OFFSET = 950UL * 1024UL; constexpr uint64_t WIN_ADDR_ALIGN = 512UL; constexpr uint64_t STATUS_OFFSET = 1024UL * 1024UL; constexpr uint32_t EXPAND_IDX_INFO = 3U; -constexpr uint64_t COMBINE_STATE_WIN_OFFSET = 3UL * 1024UL * 1024UL; +constexpr uint64_t COMBINE_STATE_WIN_OFFSET = 8UL * 1024UL * 1024UL; constexpr int64_t CYCLE_TO_TIME = 50; // cycle num is converted into a fixed base unit of time, set at 50 +constexpr uint64_t ROUND_STATE_OFFSET = Moe::BASE_ROUND_STATE_OFFSET; +constexpr uint32_t FLOAT_NUM_PER_ALIGN = 8U; template __aicore__ inline void SyncFunc() @@ -43,15 +45,19 @@ class CamMoeDispatchNormal public: __aicore__ inline CamMoeDispatchNormal(){}; __aicore__ inline void Init(GM_ADDR x, GM_ADDR expertIds, GM_ADDR send_offset, GM_ADDR send_tokenIdx, - GM_ADDR recv_offset, GM_ADDR recv_count, GM_ADDR expandXOut, GM_ADDR dynamicScalesOut, - GM_ADDR expandIdxOut, GM_ADDR waitRecvCostStatsOut, GM_ADDR workspaceGM, TPipe *pipe, - GM_ADDR tilingGM); + GM_ADDR recv_offset, GM_ADDR recv_count, GM_ADDR expert_global_offset, + GM_ADDR srcrank_in_expert_offset, GM_ADDR r_in_srcrank_offset, GM_ADDR expandXOut, + GM_ADDR dynamicScalesOut, GM_ADDR expandIdxOut, GM_ADDR waitRecvCostStatsOut, + GM_ADDR workspaceGM, TPipe *pipe, GM_ADDR tilingGM); __aicore__ inline void Process(); private: __aicore__ inline void InputToShare(); __aicore__ inline void SetStatus(); + 
__aicore__ inline void SetRoundStatus(); __aicore__ inline void WaitStatus(); + __aicore__ inline void WaitRoundStatus(); + __aicore__ inline void ShareToOutputLongSeq(); __aicore__ inline void ShareToOutput(); __aicore__ inline void UpdateOutput(); __aicore__ inline void FillTriple(LocalTensor &xOutTensor, uint32_t tokenIndex, uint32_t k); @@ -66,8 +72,13 @@ class CamMoeDispatchNormal __aicore__ inline GM_ADDR GetWindStateAddrByRankId(uint8_t ctxIdx, const int32_t rankId) { - return hccl_.GetWindowsInAddr(rankId) + winContext_[COMM_EP_IDX]->winSize - Moe::STATE_SIZE * 2 + - dataState * WIN_STATE_OFFSET; + return hccl_.GetWindowsInAddr(rankId) + totalWinSize_ - Moe::STATE_SIZE * 2 + dataState * WIN_STATE_OFFSET; + } + + __aicore__ inline GM_ADDR GetRoundStateAddrByRankId(uint8_t ctxIdx, const int32_t rankId) + { + return hccl_.GetWindowsInAddr(rankId) + totalWinSize_ - Moe::STATE_SIZE * 2 + + dataState * Moe::ROUND_STATE_MAX_SIZE + ROUND_STATE_OFFSET; } TPipe *tpipe_{nullptr}; @@ -77,12 +88,15 @@ class CamMoeDispatchNormal GlobalTensor sendTokenIdxGT; GlobalTensor recvOffsetGT; GlobalTensor recvCountGT; + GlobalTensor expertGlobalOffsetGT; + GlobalTensor srcrankInExpertOffsetGT; + GlobalTensor rInSrcrankOffsetGT; GlobalTensor dynamicScalesOutGT; GlobalTensor expandIdxOutGT; GlobalTensor dstGT; GlobalTensor dstStatusGT; GlobalTensor waitRecvCostStatsGT; - + GlobalTensor dstRoundStatusGT; LocalTensor xInTensor; LocalTensor xOutTensor; LocalTensor xTmpTensor; @@ -95,6 +109,9 @@ class CamMoeDispatchNormal LocalTensor waitRecvCostStatsTensor; LocalTensor recvStatusTensor1; LocalTensor recvStatusTensor2; + LocalTensor expertGlobalOffsetTensor; + LocalTensor srcrankInExpertOffsetTensor; + LocalTensor rInSrcrankOffsetTensor; TBuf<> expertIdsBuf; TBuf<> sendOffsetBuf; @@ -108,12 +125,20 @@ class CamMoeDispatchNormal TBuf<> tokenCastFloatBuf; TBuf<> tokenAbsFloatBuf; TBuf<> recvStatusBuf; + TBuf<> roundStatusBuf; + TBuf<> tempRoundStatusBuf; + TBuf<> expertGlobalOffsetBuf; + 
TBuf<> srcrankInExpertOffsetBuf; + TBuf<> rInSrcrankOffsetBuf; GM_ADDR expandXOutGM; GM_ADDR shareGM; uint32_t batchSize{0}; + uint32_t realMaxBatchSize{0}; uint32_t globalBatchSize{0}; + uint32_t round{4}; + uint32_t perRoundTokens{1024}; uint32_t h{0}; uint32_t topK{0}; uint32_t blockNum{0}; @@ -124,6 +149,7 @@ class CamMoeDispatchNormal uint32_t tpRankId{0}; uint32_t moeExpertNum{0}; uint32_t moeExpertNumPerRank{0}; + uint64_t totalWinSize_{0}; bool isEnableDiagnose{false}; uint32_t hUBAlignSize{0}; @@ -142,6 +168,8 @@ class CamMoeDispatchNormal uint32_t endStatusId; uint32_t statusNumPerCore; uint32_t remainStatus; + uint32_t roundIndex; + uint32_t hScaleIdxSize; TQueBind xQueue; TQue xInQueue; @@ -155,12 +183,11 @@ class CamMoeDispatchNormal }; template -__aicore__ inline void CamMoeDispatchNormal::Init(GM_ADDR x, GM_ADDR expertIds, GM_ADDR send_offset, - GM_ADDR send_tokenIdx, GM_ADDR recv_offset, - GM_ADDR recv_count, GM_ADDR expandXOut, - GM_ADDR dynamicScalesOut, GM_ADDR expandIdxOut, - GM_ADDR waitRecvCostStatsOut, GM_ADDR workspaceGM, - TPipe *pipe, GM_ADDR tilingGM) +__aicore__ inline void CamMoeDispatchNormal::Init( + GM_ADDR x, GM_ADDR expertIds, GM_ADDR send_offset, GM_ADDR send_tokenIdx, GM_ADDR recv_offset, GM_ADDR recv_count, + GM_ADDR expert_global_offset, GM_ADDR srcrank_in_expert_offset, GM_ADDR r_in_srcrank_offset, GM_ADDR expandXOut, + GM_ADDR dynamicScalesOut, GM_ADDR expandIdxOut, GM_ADDR waitRecvCostStatsOut, GM_ADDR workspaceGM, TPipe *pipe, + GM_ADDR tilingGM) { tpipe_ = pipe; blockIdx = GetBlockIdx(); @@ -174,7 +201,10 @@ __aicore__ inline void CamMoeDispatchNormal::Init(GM_ADDR x, GM_ADD winContext_[COMM_EP_IDX] = (__gm__ HcclOpResParam *)AscendC::GetHcclContext(); batchSize = tilingData->camMoeDispatchNormalInfo.bs; + realMaxBatchSize = tilingData->camMoeDispatchNormalInfo.realMaxBs; globalBatchSize = tilingData->camMoeDispatchNormalInfo.globalBs; + round = tilingData->camMoeDispatchNormalInfo.round; + perRoundTokens = 
tilingData->camMoeDispatchNormalInfo.perRoundTokens; h = tilingData->camMoeDispatchNormalInfo.h; topK = tilingData->camMoeDispatchNormalInfo.k; blockNum = tilingData->camMoeDispatchNormalInfo.aivNum; @@ -183,10 +213,10 @@ __aicore__ inline void CamMoeDispatchNormal::Init(GM_ADDR x, GM_ADD moeExpertNum = tilingData->camMoeDispatchNormalInfo.moeExpertNum; moeExpertNumPerRank = moeExpertNum / epRankSize; isEnableDiagnose = tilingData->camMoeDispatchNormalInfo.isEnableDiagnose; + totalWinSize_ = tilingData->camMoeDispatchNormalInfo.totalWinSize; GlobalTensor selfDataStatusTensor; - GM_ADDR statusDataSpaceGm = - hccl_.GetWindowsInAddr(epRankId) + winContext_[COMM_EP_IDX]->winSize - Moe::STATE_SIZE * 2; + GM_ADDR statusDataSpaceGm = hccl_.GetWindowsInAddr(epRankId) + totalWinSize_ - Moe::STATE_SIZE * 2; selfDataStatusTensor.SetGlobalBuffer( (__gm__ int32_t *)(statusDataSpaceGm + STATE_WIN_OFFSET + blockIdx * WIN_ADDR_ALIGN)); @@ -196,6 +226,9 @@ __aicore__ inline void CamMoeDispatchNormal::Init(GM_ADDR x, GM_ADD sendTokenIdxGT.SetGlobalBuffer((__gm__ int32_t *)(send_tokenIdx)); recvOffsetGT.SetGlobalBuffer((__gm__ int32_t *)(recv_offset)); recvCountGT.SetGlobalBuffer((__gm__ int32_t *)(recv_count)); + expertGlobalOffsetGT.SetGlobalBuffer((__gm__ int32_t *)(expert_global_offset)); + srcrankInExpertOffsetGT.SetGlobalBuffer((__gm__ int32_t *)(srcrank_in_expert_offset)); + rInSrcrankOffsetGT.SetGlobalBuffer((__gm__ int32_t *)(r_in_srcrank_offset)); dynamicScalesOutGT.SetGlobalBuffer((__gm__ float *)dynamicScalesOut); expandIdxOutGT.SetGlobalBuffer((__gm__ int32_t *)(expandIdxOut)); if (isEnableDiagnose) { @@ -208,11 +241,12 @@ __aicore__ inline void CamMoeDispatchNormal::Init(GM_ADDR x, GM_ADD uint32_t hScaleSizeAlign = hUBAlignSize + UB_ALIGN; expandIdxStartIdx = hScaleSizeAlign / sizeof(int32_t); - uint32_t hScaleIdxSize = hScaleSizeAlign + EXPAND_IDX_INFO * sizeof(int32_t); + hScaleIdxSize = hScaleSizeAlign + EXPAND_IDX_INFO * sizeof(int32_t); hOutGMAlignSize = 
Ceil(hScaleIdxSize, WIN_ADDR_ALIGN) * WIN_ADDR_ALIGN; hGMAlignCnt = hOutGMAlignSize / sizeof(ExpandXOutType); expertIdsCnt = batchSize * topK; + roundIndex = 0; statusNumPerCore = moeExpertNum / blockNum; remainStatus = moeExpertNum % blockNum; startStatusId = statusNumPerCore * blockIdx; @@ -236,20 +270,11 @@ __aicore__ inline void CamMoeDispatchNormal::Init(GM_ADDR x, GM_ADD PipeBarrier(); uint64_t hSizeAlignCombine = Ceil(h * sizeof(XType), WIN_ADDR_ALIGN) * WIN_ADDR_ALIGN; - winDataSizeOffset = dataState * (tilingData->camMoeDispatchNormalInfo.totalWinSize / 2) + - globalBatchSize / epRankSize * topK * hSizeAlignCombine; + hSizeAlignCombine = round > 1 ? hSizeAlignCombine * 2 : hSizeAlignCombine; + winDataSizeOffset = dataState * ((totalWinSize_ - Moe::STATE_SIZE * 4) / 2) + + min(realMaxBatchSize, perRoundTokens) * topK * hSizeAlignCombine; // *2 是因为double buffer shareGM = GetWindAddrByRankId(COMM_EP_IDX, epRankId); - hOutUBAlignSize = Ceil(hScaleIdxSize, UB_ALIGN) * UB_ALIGN; - if constexpr (DynamicQuant) { - QuantInit(); - } else { - tpipe_->InitBuffer(xQueue, BUFFER_NUM, hOutUBAlignSize); // 2 * 14K = 28K - } - - tpipe_->InitBuffer(sendOffsetBuf, moeExpertNum * sizeof(int32_t)); // 4 * moeNum - sendOffsetTensor = sendOffsetBuf.Get(); - hCommuCopyOutParams = {1U, static_cast(hScaleIdxSize), 0U, 0U, 0U}; } @@ -341,12 +366,38 @@ __aicore__ inline void CamMoeDispatchNormal::FillTriple(LocalTensor template __aicore__ inline void CamMoeDispatchNormal::InputToShare() { + tpipe_->Reset(); + hOutUBAlignSize = Ceil(hScaleIdxSize, UB_ALIGN) * UB_ALIGN; + if constexpr (DynamicQuant) { + QuantInit(); + } else { + tpipe_->InitBuffer(xQueue, BUFFER_NUM, hOutUBAlignSize); // 2 * 14K = 28K + } + tpipe_->InitBuffer(sendOffsetBuf, moeExpertNum * sizeof(int32_t)); // 4 * moeNum + sendOffsetTensor = sendOffsetBuf.Get(); + DataCopyExtParams sendOffsetParams = {1U, static_cast(moeExpertNum * sizeof(uint32_t)), 0U, 0U, 0U}; DataCopyPadExtParams sendOffsetCopyPadParams{false, 0U, 
0U, 0U}; - DataCopyPad(sendOffsetTensor, sendOffsetGT, sendOffsetParams, sendOffsetCopyPadParams); + DataCopyPad(sendOffsetTensor, sendOffsetGT[roundIndex * moeExpertNum], sendOffsetParams, sendOffsetCopyPadParams); SyncFunc(); uint32_t startTokenId, endTokenId, sendTokenNum, remainTokenNum; + + uint32_t realRound = (realMaxBatchSize + perRoundTokens - 1) / perRoundTokens; + uint32_t localRound = (batchSize + perRoundTokens - 1) / perRoundTokens; + + if (roundIndex >= localRound) { + expertIdsCnt = 0; + } else if (roundIndex < localRound - 1) { + expertIdsCnt = perRoundTokens * topK; + } else { + uint32_t processedTokens = perRoundTokens * roundIndex; + uint32_t remaining = (batchSize > processedTokens) ? (batchSize - processedTokens) : 0; + expertIdsCnt = remaining * topK; + } + if (expertIdsCnt == 0) { + return; + } sendTokenNum = expertIdsCnt / blockNum; remainTokenNum = expertIdsCnt % blockNum; startTokenId = sendTokenNum * blockIdx; @@ -358,7 +409,7 @@ __aicore__ inline void CamMoeDispatchNormal::InputToShare() } endTokenId = startTokenId + sendTokenNum; - if (startTokenId >= expertIdsCnt) { + if (startTokenId >= expertIdsCnt || sendTokenNum == 0) { return; } tpipe_->InitBuffer(expertIdsBuf, sendTokenNum * sizeof(int32_t)); // 4 * bs * k / 48 @@ -369,8 +420,10 @@ __aicore__ inline void CamMoeDispatchNormal::InputToShare() DataCopyExtParams sendTokenIdxParams = {1U, static_cast(sendTokenNum * sizeof(uint32_t)), 0U, 0U, 0U}; DataCopyPadExtParams copyPadExtParams{false, 0U, 0U, 0U}; DataCopyPadExtParams tokenCopyPadExtParams{false, 0U, 0U, 0U}; - DataCopyPad(expertIdsTensor, expertIdsGT[startTokenId], expertIdsCntParams, copyPadExtParams); - DataCopyPad(sendTokenIdxTensor, sendTokenIdxGT[startTokenId], sendTokenIdxParams, copyPadExtParams); + DataCopyPad(expertIdsTensor, expertIdsGT[roundIndex * perRoundTokens * topK + startTokenId], expertIdsCntParams, + copyPadExtParams); + DataCopyPad(sendTokenIdxTensor, sendTokenIdxGT[roundIndex * perRoundTokens * topK + 
startTokenId], + sendTokenIdxParams, copyPadExtParams); SyncFunc(); DataCopyExtParams xCopyParams = {1U, static_cast(h * sizeof(XType)), 0U, 0U, 0U}; @@ -386,22 +439,24 @@ __aicore__ inline void CamMoeDispatchNormal::InputToShare() if constexpr (DynamicQuant) { xInTensor = xInQueue.AllocTensor(); - DataCopyPad(xInTensor, xGT[tokenIndex / topK * h], xCopyParams, tokenCopyPadExtParams); + DataCopyPad(xInTensor, xGT[(roundIndex * perRoundTokens + tokenIndex / topK) * h], xCopyParams, + tokenCopyPadExtParams); xInQueue.EnQue(xInTensor); xInTensor = xInQueue.DeQue(); xOutTensor = xOutQueue.AllocTensor(); QuantProcess(); xOutQueue.EnQue(xOutTensor); xOutTensor = xOutQueue.DeQue(); - FillTriple(xOutTensor, tokenIndex / topK, tokenIndex % topK); + FillTriple(xOutTensor, (roundIndex * perRoundTokens + tokenIndex / topK), tokenIndex % topK); DataCopyPad(dstGT, xOutTensor, hCommuCopyOutParams); xOutQueue.FreeTensor(xOutTensor); } else { xTmpTensor = xQueue.AllocTensor(); - DataCopyPad(xTmpTensor, xGT[tokenIndex / topK * h], xCopyParams, tokenCopyPadExtParams); + DataCopyPad(xTmpTensor, xGT[(roundIndex * perRoundTokens + tokenIndex / topK) * h], xCopyParams, + tokenCopyPadExtParams); xQueue.EnQue(xTmpTensor); xTmpTensor = xQueue.DeQue(); - FillTriple(xTmpTensor, tokenIndex / topK, tokenIndex % topK); + FillTriple(xTmpTensor, (roundIndex * perRoundTokens + tokenIndex / topK), tokenIndex % topK); DataCopyPad(dstGT, xTmpTensor, hCommuCopyOutParams); xQueue.FreeTensor(xTmpTensor); } @@ -438,6 +493,26 @@ __aicore__ inline void CamMoeDispatchNormal::SetStatus() SyncFunc(); } +template +__aicore__ inline void CamMoeDispatchNormal::SetRoundStatus() +{ + if (blockIdx >= 1) { + return; + } + tpipe_->InitBuffer(roundStatusBuf, epRankSize * UB_ALIGN); + LocalTensor roundStatusTensor = roundStatusBuf.AllocTensor(); + Duplicate(roundStatusTensor, 1.0, FLOAT_NUM_PER_ALIGN); + for (uint32_t i = 0; i < epRankSize; ++i) { + uint32_t targetRankId = i; + uint32_t offset = stateOffset * epRankId; 
+ GM_ADDR rankGM = GetRoundStateAddrByRankId(COMM_EP_IDX, targetRankId) + offset; + dstRoundStatusGT.SetGlobalBuffer((__gm__ float *)rankGM); + DataCopy(dstRoundStatusGT, roundStatusTensor, FLOAT_NUM_PER_ALIGN); + } + SyncFunc(); + roundStatusBuf.FreeTensor(roundStatusTensor); +} + template __aicore__ inline void CamMoeDispatchNormal::WaitStatus() { @@ -469,8 +544,8 @@ __aicore__ inline void CamMoeDispatchNormal::WaitStatus() DataCopyExtParams recvOffsetParams = {1U, static_cast(moeExpertNum * sizeof(uint32_t)), 0U, 0U, 0U}; DataCopyExtParams recvCountParams = {1U, static_cast(moeExpertNum * sizeof(uint32_t)), 0U, 0U, 0U}; DataCopyPadExtParams copyPadExtParams{false, 0U, 0U, 0U}; - DataCopyPad(recvOffsetTensor, recvOffsetGT, recvOffsetParams, copyPadExtParams); - DataCopyPad(recvCountTensor, recvCountGT, recvCountParams, copyPadExtParams); + DataCopyPad(recvOffsetTensor, recvOffsetGT[roundIndex * moeExpertNum], recvOffsetParams, copyPadExtParams); + DataCopyPad(recvCountTensor, recvCountGT[roundIndex * moeExpertNum], recvCountParams, copyPadExtParams); if (startStatusId >= moeExpertNum) { SyncAll(); @@ -544,12 +619,74 @@ __aicore__ inline void CamMoeDispatchNormal::WaitStatus() } template -__aicore__ inline void CamMoeDispatchNormal::ShareToOutput() +__aicore__ inline void CamMoeDispatchNormal::WaitRoundStatus() +{ + tpipe_->Reset(); + if (blockIdx >= 1) { + return; + } + tpipe_->InitBuffer(roundStatusBuf, epRankSize * sizeof(float)); + tpipe_->InitBuffer(tempRoundStatusBuf, epRankSize * sizeof(float)); + uint32_t count = epRankSize * FLOAT_NUM_PER_ALIGN; + uint32_t inner = (count * sizeof(float) + 32 - 1) / 32 * 32 / sizeof(float); + GM_ADDR roundStateGM = GetRoundStateAddrByRankId(COMM_EP_IDX, epRankId); + GlobalTensor roundStatusGMTensor; + + roundStatusGMTensor.SetGlobalBuffer((__gm__ float *)roundStateGM); + float current = (float)0.0; + float target = (float)(1.0) * epRankSize * FLOAT_NUM_PER_ALIGN; + SumParams sumPerRankParams{1, inner, count}; + LocalTensor 
stateTensorLocal = roundStatusBuf.Get(); + LocalTensor tempRoundStateTensorLocal = tempRoundStatusBuf.Get(); + + int64_t systemCycleBefore = AscendC::GetSystemCycle(); + while (current != target) { + SyncFunc(); + DataCopy(stateTensorLocal, roundStatusGMTensor, count); + SyncFunc(); + Sum(tempRoundStateTensorLocal, stateTensorLocal, sumPerRankParams); + SyncFunc(); + current = tempRoundStateTensorLocal.GetValue(0); + int64_t systemCycleAfter = AscendC::GetSystemCycle(); + } + + SyncFunc(); + Duplicate(tempRoundStateTensorLocal, (float)0.0, count); + SyncFunc(); + DataCopy(roundStatusGMTensor, tempRoundStateTensorLocal, count); + SyncFunc(); +} + +template +__aicore__ inline void CamMoeDispatchNormal::ShareToOutputLongSeq() { if (startStatusId >= moeExpertNum) { return; } - uint32_t fromRank, count, preCount, recvOffset, targetOffset; + + tpipe_->InitBuffer(expertGlobalOffsetBuf, moeExpertNumPerRank * sizeof(int32_t)); + expertGlobalOffsetTensor = expertGlobalOffsetBuf.Get(); + DataCopyExtParams expertGlobalOffsetParams{1U, static_cast(sizeof(int32_t) * moeExpertNumPerRank), 0U, 0U, + 0U}; + DataCopyPadExtParams expertGlobalOffsetCopyPadExtParams{false, 0U, 0U, 0U}; + DataCopyPad(expertGlobalOffsetTensor, expertGlobalOffsetGT, expertGlobalOffsetParams, + expertGlobalOffsetCopyPadExtParams); + + tpipe_->InitBuffer(srcrankInExpertOffsetBuf, moeExpertNum * sizeof(int32_t)); + srcrankInExpertOffsetTensor = srcrankInExpertOffsetBuf.Get(); + DataCopyExtParams srcrankInExpertOffsetParams{1U, static_cast(sizeof(int32_t) * moeExpertNum), 0U, 0U, + 0U}; + DataCopyPadExtParams srcrankInExpertOffsetCopyPadExtParams{false, 0U, 0U, 0U}; + DataCopyPad(srcrankInExpertOffsetTensor, srcrankInExpertOffsetGT, srcrankInExpertOffsetParams, + srcrankInExpertOffsetCopyPadExtParams); + + tpipe_->InitBuffer(rInSrcrankOffsetBuf, round * moeExpertNum * sizeof(int32_t)); + rInSrcrankOffsetTensor = rInSrcrankOffsetBuf.Get(); + DataCopyExtParams CParams{1U, static_cast(sizeof(int32_t) * 
moeExpertNum * round), 0U, 0U, 0U}; + DataCopyPadExtParams CCopyPadExtParams{false, 0U, 0U, 0U}; + DataCopyPad(rInSrcrankOffsetTensor, rInSrcrankOffsetGT, CParams, CCopyPadExtParams); + + uint32_t fromRank, count, preCount, recvOffset, targetOffset, local_e; DataCopyPadExtParams copyPadExtParams{false, 0U, 0U, 0U}; DataCopyExtParams dataCopyExandIdxParams{1U, sizeof(int32_t) * EXPAND_IDX_INFO, 0U, 0U, 0U}; DataCopyExtParams dataCopyOutParams{1U, static_cast(statusNumPerCore * sizeof(int32_t)), 0U, 0U, 0U}; @@ -558,15 +695,25 @@ __aicore__ inline void CamMoeDispatchNormal::ShareToOutput() AscendC::TQueSync recvCountLocalSync; recvCountLocalSync.SetFlag(0); recvCountLocalSync.WaitFlag(0); + for (uint32_t i = startStatusId; i < endStatusId; ++i) { preCount = 0; if (likely(i != 0)) { preCount = recvCountTensor(i - 1); } + fromRank = i % epRankSize; + local_e = i / epRankSize; count = recvCountTensor(i) - preCount; recvOffset = recvOffsetTensor(i); - targetOffset = preCount; + + // 目标地址 = 专家全局起始 + B[es_idx](源rank在专家内偏移) + r_in_srcrank_offset[c_idx](轮次在源rank内偏移) + int32_t rInSrcrankIndex = local_e * epRankSize * round + fromRank * round + roundIndex; + int32_t expertGlobalOffset = expertGlobalOffsetTensor(local_e); + int32_t srcrankInExpertOffset = srcrankInExpertOffsetTensor(i); + int32_t rInSrcrankOffset = rInSrcrankOffsetTensor(rInSrcrankIndex); + int32_t writeOffset = expertGlobalOffset + srcrankInExpertOffset + rInSrcrankOffset; + GM_ADDR recvStart = (__gm__ uint8_t *)(GetWindAddrByRankId(COMM_EP_IDX, fromRank)) + recvOffset * hOutGMAlignSize; GlobalTensor srcTokenGT, dstTokenGT; @@ -574,19 +721,23 @@ __aicore__ inline void CamMoeDispatchNormal::ShareToOutput() srcTokenGT.SetGlobalBuffer((__gm__ ExpandXOutType *)(recvStart + j * hOutGMAlignSize)); xTmpTensor = xQueue.AllocTensor(); DataCopyPad(xTmpTensor, srcTokenGT, hCommuCopyOutParams, copyPadExtParams); + xQueue.EnQue(xTmpTensor); xTmpTensor = xQueue.DeQue(); xTmpTensorInt = xTmpTensor.template ReinterpretCast(); 
- DataCopyPad(expandIdxOutGT[(targetOffset + j) * EXPAND_IDX_INFO], xTmpTensorInt[expandIdxStartIdx], + DataCopyPad(expandIdxOutGT[(writeOffset + j) * EXPAND_IDX_INFO], xTmpTensorInt[expandIdxStartIdx], dataCopyExandIdxParams); + if constexpr (DynamicQuant) { DataCopyExtParams floatDataCopyParams = {1U, sizeof(float), 0U, 0U, 0U}; LocalTensor xOutFp32Tensor = xTmpTensor.template ReinterpretCast(); - DataCopyPad(dynamicScalesOutGT[targetOffset + j], xOutFp32Tensor[hUBAlignSize / sizeof(float)], + DataCopyPad(dynamicScalesOutGT[writeOffset + j], xOutFp32Tensor[hUBAlignSize / sizeof(float)], floatDataCopyParams); } - dstTokenGT.SetGlobalBuffer((__gm__ ExpandXOutType *)(expandXOutGM) + (targetOffset + j) * h, h); + + dstTokenGT.SetGlobalBuffer((__gm__ ExpandXOutType *)(expandXOutGM) + (writeOffset + j) * h, h); DataCopyPad(dstTokenGT, xTmpTensor, expandXCopyParams); + xQueue.FreeTensor(xTmpTensor); } } @@ -596,10 +747,20 @@ template __aicore__ inline void CamMoeDispatchNormal::Process() { if ASCEND_IS_AIV { - InputToShare(); - SetStatus(); - WaitStatus(); - ShareToOutput(); + uint32_t realRound = (realMaxBatchSize + perRoundTokens - 1) / perRoundTokens; + while (roundIndex < realRound) { + InputToShare(); + SetStatus(); + WaitStatus(); + ShareToOutputLongSeq(); + if (realRound > 1) { + SyncAll(); + SetRoundStatus(); + WaitRoundStatus(); + SyncAll(); + } + roundIndex += 1; + } } hccl_.Finalize(); } diff --git a/csrc/deepep/ops2/op_kernel/check_winsize.h b/csrc/deepep/ops2/op_kernel/check_winsize.h index 654496fec..26ddc5e32 100644 --- a/csrc/deepep/ops2/op_kernel/check_winsize.h +++ b/csrc/deepep/ops2/op_kernel/check_winsize.h @@ -1,18 +1,3 @@ -/** - * Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. 
- * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - */ - -/*! - * \file check_winsize.h - * \brief - */ - #ifndef CHECK_WINSIZE_H #define CHECK_WINSIZE_H @@ -30,7 +15,7 @@ __aicore__ inline void CheckWindowSize(uint64_t tilingWinSizeBytes, uint64_t rea AscendC::TBuf exceptionBuf; tpipe_->InitBuffer(exceptionBuf, 1); // 初始化一个缓冲区 AscendC::LocalTensor exceptionLocal = exceptionBuf.Get(); - AscendC::DataCopy(exceptionLocal[0], exceptionGlobal, 1); // 从全局地址复制数据到本地地址 + AscendC::DataCopy(exceptionLocal[1], exceptionGlobal, 1); // 从全局地址复制数据到本地地址 } } #endif // CHECK_WINSIZE_H diff --git a/csrc/deepep/ops2/op_kernel/comm_args.h b/csrc/deepep/ops2/op_kernel/comm_args.h index 8cb1f6f08..0f82f7fd1 100644 --- a/csrc/deepep/ops2/op_kernel/comm_args.h +++ b/csrc/deepep/ops2/op_kernel/comm_args.h @@ -9,7 +9,7 @@ namespace Moe { constexpr int CAM_MAX_RANK_SIZE = 384; // Maximum number of NPU cards supported by the communication library -constexpr uint64_t NOTIFY_DISPATCH_BUFF_OFFSET = 404UL * 1024UL * 1024UL; +constexpr uint64_t NOTIFY_DISPATCH_BUFF_OFFSET = 202UL * 1024UL * 1024UL; constexpr int64_t IPC_BUFF_MAX_SIZE = 200 * 1024 * 1024; constexpr int64_t IPC_DATA_OFFSET = 2 * 1024 * 1024; // First 2MB as flag, then 100MB as data storage constexpr int64_t PING_PONG_SIZE = 2; @@ -29,6 +29,8 @@ constexpr int64_t WAIT_SUCCESS = 112233445566; constexpr int64_t IPC_CHUNK_FLAG = 0; // Start offset for send recv, chunk flag region constexpr int64_t MAX_WAIT_ROUND_UNIT = 10 * 1000 * 1000; // Threshold for waiting to get Flag under normal conditions within the same SIO +constexpr uint64_t ROUND_STATE_MAX_SIZE = 4UL * 1024UL; +constexpr uint64_t BASE_ROUND_STATE_OFFSET = 450UL * 1024UL; constexpr static int32_t UB_HEAD_OFFSET = 96; constexpr 
static int32_t UB_MID_OFFSET = UB_HEAD_OFFSET + UB_SINGLE_PING_PONG_ADD_SIZE_MAX + UB_ALIGN_SIZE; diff --git a/csrc/deepep/ops2/op_kernel/dispatch_layout.h b/csrc/deepep/ops2/op_kernel/dispatch_layout.h index 0540e2a92..0b8ba82bb 100644 --- a/csrc/deepep/ops2/op_kernel/dispatch_layout.h +++ b/csrc/deepep/ops2/op_kernel/dispatch_layout.h @@ -12,6 +12,7 @@ namespace MoeDispatchLayout { constexpr uint32_t UB_32_ALIGN = 32U; +constexpr uint32_t ONE_PIECE = 8U; template __aicore__ inline void SyncFunc() @@ -37,48 +38,48 @@ class DispatchLayout numRanks_ = tilingData->dispatchLayoutInfo.numRanks; numExperts_ = tilingData->dispatchLayoutInfo.numExperts; numTopk_ = tilingData->dispatchLayoutInfo.numTopk; + perRoundTokens_ = tilingData->dispatchLayoutInfo.perRoundTokens; + rankId_ = tilingData->dispatchLayoutInfo.rankId; + round_ = (numTokens_ + perRoundTokens_ - 1) / perRoundTokens_; tpipe_ = pipe; - coreIdx_ = GetBlockIdx(); + topkIdx_ = topkIdx; + numTokensPerExpert_ = numTokensPerExpert; + isTokenInRank_ = isTokenInRank; + sendTokenIdxSmall_ = sendTokenIdxSmall; + uint32_t maxAivNum = GetBlockNum(); - aivNum_ = numTokens_ <= maxAivNum ? numTokens_ : maxAivNum; - if (coreIdx_ >= aivNum_) { - return; - } - uint32_t temp = numTokens_ / aivNum_; - uint32_t restNum = numTokens_ % aivNum_; - int64_t topkIdxOffset; - int64_t isTokenOffset; - tempTokens_ = temp; + coreIdx_ = GetBlockIdx(); + int32_t firstRoundTokens = (round_ == 1) ? numTokens_ : perRoundTokens_; + aivNum_ = firstRoundTokens <= maxAivNum ? 
firstRoundTokens : maxAivNum; + tempTokens_ = firstRoundTokens / aivNum_; + int32_t restNum = firstRoundTokens % aivNum_; if (coreIdx_ < restNum) { - tempTokens_++; + ++tempTokens_; } topkIdx32AlignIntLen_ = Ceil(tempTokens_ * numTopk_ * sizeof(int64_t), UB_32_ALIGN) * UB_32_ALIGN; numTokensPerRank32AlignIntLen_ = Ceil(numRanks_ * sizeof(T), UB_32_ALIGN) * UB_32_ALIGN; numTokensPerExpert32AlignIntLen_ = Ceil(numExperts_ * sizeof(T), UB_32_ALIGN) * UB_32_ALIGN; isTokenInRank32AlignIntLen_ = Ceil(tempTokens_ * numRanks_ * sizeof(T), UB_32_ALIGN) * UB_32_ALIGN; - sendTokenIdx32AlignIntLen_ = Ceil(tempTokens_ * numExperts_ * sizeof(T), UB_32_ALIGN) * UB_32_ALIGN; + if (coreIdx_ < restNum) { - topkIdxOffset = coreIdx_ * tempTokens_ * numTopk_ * sizeof(int64_t); - isTokenOffset = coreIdx_ * tempTokens_ * numRanks_ * sizeof(T); + topkIdxOffset_ = coreIdx_ * tempTokens_ * numTopk_ * sizeof(int64_t); + sendIdxOffset_ = coreIdx_ * tempTokens_ * numTopk_ * sizeof(T); + isTokenOffset_ = coreIdx_ * tempTokens_ * numRanks_ * sizeof(T); + tokenIdxOffset_ = coreIdx_ * tempTokens_ * sizeof(T); } else { - topkIdxOffset = (restNum + coreIdx_ * tempTokens_) * numTopk_ * sizeof(int64_t); - isTokenOffset = (restNum + coreIdx_ * tempTokens_) * numRanks_ * sizeof(T); + topkIdxOffset_ = (restNum + coreIdx_ * tempTokens_) * numTopk_ * sizeof(int64_t); + sendIdxOffset_ = (restNum + coreIdx_ * tempTokens_) * numTopk_ * sizeof(T); + isTokenOffset_ = (restNum + coreIdx_ * tempTokens_) * numRanks_ * sizeof(T); + tokenIdxOffset_ = (restNum + coreIdx_ * tempTokens_) * sizeof(T); } + tempExpertGM_.SetGlobalBuffer((__gm__ T *)notifySendData); - topkIdxGM_.SetGlobalBuffer((__gm__ int64_t *)(topkIdx + topkIdxOffset)); numTokensPerRankGM_.SetGlobalBuffer((__gm__ T *)numTokensPerRank); - numTokensPerExpertGM_.SetGlobalBuffer((__gm__ T *)numTokensPerExpert); - isTokenInRankGM_.SetGlobalBuffer((__gm__ T *)(isTokenInRank + isTokenOffset)); - sendTokenIdxSmallGM_.SetGlobalBuffer((__gm__ T 
*)(sendTokenIdxSmall + topkIdxOffset / 2)); } __aicore__ inline void Process() { - if (coreIdx_ >= aivNum_) { - SyncAll(); - return; - } tpipe_->Reset(); tpipe_->InitBuffer(topkIdxBuf_, topkIdx32AlignIntLen_); tpipe_->InitBuffer(numTokensPerRankBuf_, numTokensPerRank32AlignIntLen_); @@ -86,82 +87,135 @@ class DispatchLayout tpipe_->InitBuffer(isTokenInRankBuf_, isTokenInRank32AlignIntLen_); tpipe_->InitBuffer(seenRankBuf_, numRanks_ * sizeof(T)); tpipe_->InitBuffer(sendTokenIdxSmallBuf_, topkIdx32AlignIntLen_); - LocalTensor topkIdxTensor = topkIdxBuf_.AllocTensor(); - const DataCopyExtParams dataCopyParams{1U, topkIdx32AlignIntLen_, 0U, 0U, 0U}; - const DataCopyPadExtParams padParams{false, 0U, 0U, 0U}; - DataCopyPad(topkIdxTensor, topkIdxGM_, dataCopyParams, padParams); - SyncFunc(); LocalTensor numTokensPerRankTensor = numTokensPerRankBuf_.AllocTensor(); LocalTensor numTokensPerExpertTensor = numTokensPerExpertBuf_.AllocTensor(); LocalTensor isTokenInRankTensor = isTokenInRankBuf_.AllocTensor(); LocalTensor seenRankTensor = seenRankBuf_.AllocTensor(); LocalTensor sendTokenIdxSmallTensor = sendTokenIdxSmallBuf_.AllocTensor(); - Duplicate(numTokensPerRankTensor, 0, numRanks_); - Duplicate(numTokensPerExpertTensor, 0, numTokensPerExpert32AlignIntLen_ / sizeof(T)); - Duplicate(isTokenInRankTensor, 0, tempTokens_ * numRanks_); - SyncFunc(); - - int experts_per_rank = numExperts_ / numRanks_; - for (int i = 0; i < tempTokens_; ++i) { - SyncFunc(); - Duplicate(seenRankTensor, 0, numRanks_); - SyncFunc(); - for (int j = 0; j < numTopk_; ++j) { - int64_t expert_idx = topkIdxTensor.GetValue(i * numTopk_ + j); - if (expert_idx < 0 || expert_idx >= numExperts_) { - continue; + + int32_t preRoundCount = 0; + for (int r = 0; r < round_; r++) { + uint32_t roundTokens = perRoundTokens_; + if (r == round_ - 1 && (numTokens_ % perRoundTokens_ != 0)) { + roundTokens = numTokens_ % perRoundTokens_; + uint32_t temp = roundTokens / aivNum_; + uint32_t restNum = roundTokens % aivNum_; 
+ tempTokens_ = temp; + if (coreIdx_ < restNum) { + tempTokens_++; } - uint32_t per_expert_num = numTokensPerExpertTensor.GetValue(expert_idx) + 1; - numTokensPerExpertTensor.SetValue(expert_idx, per_expert_num); - int rank_id = expert_idx / experts_per_rank; - if (!seenRankTensor.GetValue(rank_id)) { - uint32_t per_rank_num = numTokensPerRankTensor.GetValue(rank_id) + 1; - isTokenInRankTensor.SetValue(i * numRanks_ + rank_id, 1); - seenRankTensor.SetValue(rank_id, 1); - numTokensPerRankTensor.SetValue(rank_id, per_rank_num); + topkIdx32AlignIntLen_ = Ceil(tempTokens_ * numTopk_ * sizeof(int64_t), UB_32_ALIGN) * UB_32_ALIGN; + isTokenInRank32AlignIntLen_ = Ceil(tempTokens_ * numRanks_ * sizeof(T), UB_32_ALIGN) * UB_32_ALIGN; + + if (coreIdx_ < restNum) { + topkIdxOffset_ = coreIdx_ * tempTokens_ * numTopk_ * sizeof(int64_t); + sendIdxOffset_ = coreIdx_ * tempTokens_ * numTopk_ * sizeof(T); + isTokenOffset_ = coreIdx_ * tempTokens_ * numRanks_ * sizeof(T); + tokenIdxOffset_ = coreIdx_ * tempTokens_ * sizeof(T); + } else { + topkIdxOffset_ = (restNum + coreIdx_ * tempTokens_) * numTopk_ * sizeof(int64_t); + sendIdxOffset_ = (restNum + coreIdx_ * tempTokens_) * numTopk_ * sizeof(T); + isTokenOffset_ = (restNum + coreIdx_ * tempTokens_) * numRanks_ * sizeof(T); + tokenIdxOffset_ = (restNum + coreIdx_ * tempTokens_) * sizeof(T); } } - } - uint32_t sendSize = tempTokens_ * numRanks_ * sizeof(T); - const DataCopyExtParams isTokenInRankDataCopyParams{1U, sendSize, 0U, 0U, 0U}; - DataCopyPad(isTokenInRankGM_, isTokenInRankTensor, isTokenInRankDataCopyParams); - AscendC::SetAtomicAdd(); - const DataCopyExtParams tempExpertDataCopyParams{1U, numTokensPerExpert32AlignIntLen_, 0U, 0U, 0U}; - for (int i = coreIdx_ + 1; i < aivNum_; ++i) { - DataCopyPad(tempExpertGM_[i * numExperts_], numTokensPerExpertTensor, tempExpertDataCopyParams); - } - sendSize = numRanks_ * sizeof(T); - const DataCopyExtParams numTokensPerRankDataCopyParams{1U, sendSize, 0U, 0U, 0U}; - 
DataCopyPad(numTokensPerRankGM_, numTokensPerRankTensor, numTokensPerRankDataCopyParams); - sendSize = numExperts_ * sizeof(T); - const DataCopyExtParams numTokensPerExpertDataCopyParams{1U, sendSize, 0U, 0U, 0U}; - DataCopyPad(numTokensPerExpertGM_, numTokensPerExpertTensor, numTokensPerExpertDataCopyParams); - AscendC::SetAtomicNone(); - PipeBarrier(); - SyncAll(); - SyncFunc(); - const DataCopyPadExtParams tempPadParams{false, 0U, 0U, 0U}; - DataCopyPad(numTokensPerExpertTensor, tempExpertGM_[coreIdx_ * numExperts_], tempExpertDataCopyParams, - tempPadParams); - - SyncFunc(); - for (int i = 0; i < tempTokens_; ++i) { - for (int j = 0; j < numTopk_; ++j) { - int64_t expert_idx = topkIdxTensor.GetValue(i * numTopk_ + j); - if (expert_idx < 0 || expert_idx >= numExperts_) { - continue; + uint32_t maxAivNum = GetBlockNum(); + aivNum_ = roundTokens <= maxAivNum ? roundTokens : maxAivNum; + if (coreIdx_ >= aivNum_) { + SyncAll(); + SyncAll(); + continue; + } + + int64_t round_topkIdx_offset = r * perRoundTokens_ * numTopk_ * sizeof(int64_t); + int64_t round_sendIdx_offset = r * perRoundTokens_ * numTopk_ * sizeof(T); + + sendTokenIdxSmallGM_.SetGlobalBuffer( + (__gm__ T *)(sendTokenIdxSmall_ + round_sendIdx_offset + sendIdxOffset_)); + topkIdxGM_.SetGlobalBuffer((__gm__ int64_t *)(topkIdx_ + round_topkIdx_offset + topkIdxOffset_)); + numTokensPerExpertGM_.SetGlobalBuffer((__gm__ T *)(numTokensPerExpert_ + numExperts_ * r * sizeof(T))); + // tokens * rank; + isTokenInRankGM_.SetGlobalBuffer( + (__gm__ T *)(isTokenInRank_ + r * perRoundTokens_ * numRanks_ * sizeof(T) + isTokenOffset_)); + + const DataCopyExtParams dataCopyParams{1U, topkIdx32AlignIntLen_, 0U, 0U, 0U}; + const DataCopyPadExtParams padParams{false, 0U, 0U, 0U}; + SyncFunc(); + DataCopyPad(topkIdxTensor, topkIdxGM_, dataCopyParams, padParams); + SyncFunc(); + SyncFunc(); + Duplicate(numTokensPerRankTensor, 0, numTokensPerRank32AlignIntLen_ / sizeof(T)); + Duplicate(isTokenInRankTensor, 0, 
isTokenInRank32AlignIntLen_ / sizeof(T)); + Duplicate(numTokensPerExpertTensor, 0, numTokensPerExpert32AlignIntLen_ / sizeof(T)); + SyncFunc(); + SyncFunc(); + const DataCopyExtParams clearGmParams{1U, numTokensPerExpert32AlignIntLen_, 0U, 0U, 0U}; + DataCopyPad(tempExpertGM_[coreIdx_ * numExperts_], numTokensPerExpertTensor, clearGmParams); + PipeBarrier(); + SyncAll(); + + int experts_per_rank = numExperts_ / numRanks_; + for (int i = 0; i < tempTokens_; ++i) { + SyncFunc(); + Duplicate(seenRankTensor, 0, numRanks_); + SyncFunc(); + for (int j = 0; j < numTopk_; ++j) { + int64_t expert_idx = topkIdxTensor.GetValue(i * numTopk_ + j); + if (expert_idx < 0 || expert_idx >= numExperts_) { + continue; + } + uint32_t per_expert_num = numTokensPerExpertTensor.GetValue(expert_idx) + 1; + numTokensPerExpertTensor.SetValue(expert_idx, per_expert_num); + int rank_id = expert_idx / experts_per_rank; + if (!seenRankTensor.GetValue(rank_id)) { + uint32_t per_rank_num = numTokensPerRankTensor.GetValue(rank_id) + 1; + isTokenInRankTensor.SetValue(i * numRanks_ + rank_id, 1); + seenRankTensor.SetValue(rank_id, 1); + numTokensPerRankTensor.SetValue(rank_id, per_rank_num); + } + } + } + uint32_t sendSize = tempTokens_ * numRanks_ * sizeof(T); + const DataCopyExtParams isTokenInRankDataCopyParams{1U, sendSize, 0U, 0U, 0U}; + SyncFunc(); + DataCopyPad(isTokenInRankGM_, isTokenInRankTensor, isTokenInRankDataCopyParams); + AscendC::SetAtomicAdd(); + const DataCopyExtParams tempExpertDataCopyParams{1U, numTokensPerExpert32AlignIntLen_, 0U, 0U, 0U}; + for (int i = coreIdx_ + 1; i < aivNum_; ++i) { + DataCopyPad(tempExpertGM_[i * numExperts_], numTokensPerExpertTensor, tempExpertDataCopyParams); + } + sendSize = numRanks_ * sizeof(T); + const DataCopyExtParams numTokensPerRankDataCopyParams{1U, sendSize, 0U, 0U, 0U}; + DataCopyPad(numTokensPerRankGM_, numTokensPerRankTensor, numTokensPerRankDataCopyParams); + sendSize = numExperts_ * sizeof(T); + const DataCopyExtParams 
numTokensPerExpertDataCopyParams{1U, sendSize, 0U, 0U, 0U}; + DataCopyPad(numTokensPerExpertGM_, numTokensPerExpertTensor, numTokensPerExpertDataCopyParams); + AscendC::SetAtomicNone(); + PipeBarrier(); + SyncAll(); + SyncFunc(); + const DataCopyPadExtParams tempPadParams{false, 0U, 0U, 0U}; + DataCopyPad(numTokensPerExpertTensor, tempExpertGM_[coreIdx_ * numExperts_], tempExpertDataCopyParams, + tempPadParams); + SyncFunc(); + for (int i = 0; i < tempTokens_; ++i) { + for (int j = 0; j < numTopk_; ++j) { + int64_t expert_idx = topkIdxTensor.GetValue(i * numTopk_ + j); + if (expert_idx < 0 || expert_idx >= numExperts_) { + continue; + } + T valT = numTokensPerExpertTensor(expert_idx); + sendTokenIdxSmallTensor(i * numTopk_ + j) = valT; + numTokensPerExpertTensor(expert_idx) = valT + 1; } - T valT = numTokensPerExpertTensor(expert_idx); - sendTokenIdxSmallTensor(i * numTopk_ + j) = valT; - numTokensPerExpertTensor(expert_idx) = valT + 1; } + SyncFunc(); + const DataCopyExtParams sendTokenIdxSmallDataCopyParams{ + 1U, static_cast(tempTokens_ * numTopk_ * sizeof(T)), 0U, 0U, 0U}; + DataCopyPad(sendTokenIdxSmallGM_, sendTokenIdxSmallTensor, sendTokenIdxSmallDataCopyParams); } - SyncFunc(); - const DataCopyExtParams sendTokenIdxSmallDataCopyParams{ - 1U, static_cast(tempTokens_ * numTopk_ * sizeof(T)), 0U, 0U, 0U}; - DataCopyPad(sendTokenIdxSmallGM_, sendTokenIdxSmallTensor, sendTokenIdxSmallDataCopyParams); } private: @@ -187,12 +241,24 @@ class DispatchLayout uint32_t coreIdx_{0}; uint32_t aivNum_{0}; uint32_t tempTokens_{0}; + uint32_t round_{0}; + uint32_t rankId_{0}; + uint32_t perRoundTokens_{0}; + uint32_t preRoundsCount_{0}; + int64_t topkIdxOffset_{0}; + int64_t sendIdxOffset_{0}; + int64_t isTokenOffset_{0}; + int64_t tokenIdxOffset_{0}; uint32_t topkIdx32AlignIntLen_{0}; uint32_t numTokensPerRank32AlignIntLen_{0}; uint32_t numTokensPerExpert32AlignIntLen_{0}; uint32_t isTokenInRank32AlignIntLen_{0}; - uint32_t sendTokenIdx32AlignIntLen_{0}; + + GM_ADDR 
topkIdx_; + GM_ADDR numTokensPerExpert_; + GM_ADDR isTokenInRank_; + GM_ADDR sendTokenIdxSmall_; }; } // namespace MoeDispatchLayout diff --git a/csrc/deepep/ops2/op_kernel/notify_dispatch.cpp b/csrc/deepep/ops2/op_kernel/notify_dispatch.cpp index fbef4497e..d504a79ee 100644 --- a/csrc/deepep/ops2/op_kernel/notify_dispatch.cpp +++ b/csrc/deepep/ops2/op_kernel/notify_dispatch.cpp @@ -2,9 +2,6 @@ #include "notify_dispatch.h" #include "notify_dispatch_tiling.h" -#define TILING_KEY_FLOAT16 20 -#define TILING_KEY_BFLOAT16 21 -#define TILING_KEY_FLOAT 22 #define TILING_KEY_INT 23 #define KERNEL_USE_WORKSPACE (1 * 1024 * 1024) @@ -14,17 +11,20 @@ extern "C" __global__ __aicore__ void notify_dispatch(GM_ADDR sendData, GM_ADDR GM_ADDR recvOffset, GM_ADDR expertGlobalOffset, GM_ADDR srcrankInExpertOffset, GM_ADDR rInSrcrankOffset, GM_ADDR totalRecvTokens, GM_ADDR maxBs, - GM_ADDR recvTokensPerExpert, GM_ADDR workspace, GM_ADDR tilingGM) + GM_ADDR recvTokensPerExpert, GM_ADDR workspace, GM_ADDR tiling) { REGISTER_TILING_DEFAULT(NotifyDispatchTilingData); - GET_TILING_DATA_WITH_STRUCT(NotifyDispatchTilingData, tilingData, tilingGM); + GET_TILING_DATA_WITH_STRUCT(NotifyDispatchTilingData, tilingData, tiling); int localRank = tilingData.notifyDispatchInfo.localRankId; int localRankSize = tilingData.notifyDispatchInfo.localRankSize; int rank = tilingData.notifyDispatchInfo.rankId; int rankSize = tilingData.notifyDispatchInfo.rankSize; int64_t len = tilingData.notifyDispatchInfo.sendCount; - int64_t numTokens = tilingData.notifyDispatchInfo.numTokens; + int numTokens = tilingData.notifyDispatchInfo.numTokens; + int round = tilingData.notifyDispatchInfo.round; + int perRoundTokens = tilingData.notifyDispatchInfo.perRoundTokens; + uint64_t totalWinSize = tilingData.notifyDispatchInfo.totalWinSize; GM_ADDR sendDataInput = sendData; GM_ADDR tokenPerExpertDataInput = tokenPerExpertData; @@ -37,19 +37,11 @@ extern "C" __global__ __aicore__ void notify_dispatch(GM_ADDR sendData, GM_ADDR 
int root = 0; int op = 0; int cycleCount = 0; - int64_t scaleCount = 0; + int scaleCount = 0; GM_ADDR offset = nullptr; int blockNum = GetBlockNum(); - if (TILING_KEY_IS(TILING_KEY_FLOAT16)) { - NotifyDispatch opKernel(rank, rankSize, extraFlag); - opKernel.Init(KERNELS_ARGS_CALL_ALL2ALL()); - opKernel.Process(); - } else if (TILING_KEY_IS(TILING_KEY_FLOAT)) { - NotifyDispatch opKernel(rank, rankSize, extraFlag); - opKernel.Init(KERNELS_ARGS_CALL_ALL2ALL()); - opKernel.Process(); - } else if (TILING_KEY_IS(TILING_KEY_INT)) { + if (TILING_KEY_IS(TILING_KEY_INT)) { NotifyDispatch opKernel(rank, rankSize, extraFlag); opKernel.Init(KERNELS_ARGS_CALL_ALL2ALL()); opKernel.Process(); diff --git a/csrc/deepep/ops2/op_kernel/notify_dispatch.h b/csrc/deepep/ops2/op_kernel/notify_dispatch.h index 472320ff0..51dc076b4 100644 --- a/csrc/deepep/ops2/op_kernel/notify_dispatch.h +++ b/csrc/deepep/ops2/op_kernel/notify_dispatch.h @@ -23,25 +23,37 @@ __aicore__ inline void SyncFunc() #define KERNELS_ARGS_FUN_ALL2ALL() \ GM_ADDR sendDataInput, GM_ADDR tokenPerExpertDataInput, GM_ADDR sendDataOffsetOutput, GM_ADDR recvDataOutput, \ - GM_ADDR totalRecvTokens, GM_ADDR recvCount, GM_ADDR recvOffset, GM_ADDR maxBs, GM_ADDR recvTokensPerExpert, \ - int64_t len, int64_t numTokens, int op, int root, int cycleCount, GM_ADDR scale, int64_t scaleCount, \ - GM_ADDR offset, int localRank, int localRankSize, GM_ADDR tilingGM + GM_ADDR recvCount, GM_ADDR recvOffset, GM_ADDR expertGlobalOffset, GM_ADDR srcrankInExpertOffset, \ + GM_ADDR rInSrcrankOffset, GM_ADDR totalRecvTokens, GM_ADDR maxBs, GM_ADDR recvTokensPerExpert, int64_t len, \ + int32_t round, int32_t perRoundTokens, int32_t numTokens, int op, int root, int cycleCount, GM_ADDR scale, \ + int32_t scaleCount, GM_ADDR offset, int localRank, int localRankSize, uint64_t totalWinSize, GM_ADDR tiling -#define KERNELS_ARGS_CALL_ALL2ALL() \ - sendDataInput, tokenPerExpertDataInput, sendDataOffsetOutput, recvDataOutput, totalRecvTokens, recvCount, \ 
- recvOffset, maxBs, recvTokensPerExpert, len, numTokens, op, root, cycleCount, scale, scaleCount, offset, \ - localRank, localRankSize, tilingGM +#define KERNELS_ARGS_CALL_ALL2ALL() \ + sendDataInput, tokenPerExpertDataInput, sendDataOffsetOutput, recvDataOutput, recvCount, recvOffset, \ + expertGlobalOffset, srcrankInExpertOffset, rInSrcrankOffset, totalRecvTokens, maxBs, recvTokensPerExpert, len, \ + round, perRoundTokens, numTokens, op, root, cycleCount, scale, scaleCount, offset, localRank, localRankSize, \ + totalWinSize, tiling template class NotifyDispatch { - constexpr static int64_t MAX_RANK_PER_CORE = 8; - constexpr static int64_t MULTI_RANK_SIZE = 40; - constexpr static int64_t MAX_BUFFER_NUMBER = 10; + constexpr static int32_t MAX_RANK_PER_CORE = 8; + constexpr static int32_t MULTI_RANK_SIZE = 48; + constexpr static int32_t MAX_BUFFER_NUMBER = 10; constexpr static uint32_t UB_FLAG_SIZE = 8U * 1024U; + + constexpr static int32_t TOTAL_CNT_CORE = 0; + constexpr static int32_t RECV_COUNT_CORE = 1; + constexpr static int32_t RECV_OFFSET_CORE = 2; + constexpr static int32_t MAX_BS_CORE = 3; + constexpr static int32_t RECV_TOKEN_PER_EXP_CORE = 4; + constexpr static int32_t EXP_GLOBAL_OFFSET_CORE = 5; + constexpr static int32_t SRC_RANK_EXP_OFFSET_CORE = 6; + constexpr static int32_t R_IN_SRCRANK_OFFSET_CORE = 7; // Synchronization flag occupies length constexpr static int64_t FLAG_UNIT_INT_NUM = 4; constexpr static int64_t MAGIC_MASK = ~((1LL << 32) - 1); + constexpr static int32_t BATCH_ROUND = 32; public: __aicore__ inline NotifyDispatch(int rank, int rankSize, uint32_t extraFlag) @@ -61,16 +73,25 @@ class NotifyDispatch recvOffset_ = recvOffset; maxBs_ = maxBs; recvTokensPerExpert_ = recvTokensPerExpert; - recvDataAlignLen = Ceil(numExperts * sendPerGroup * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; - tokenPerExpertDataAlignLen = Ceil(numExperts * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; - sendDataOffsetAlignLen = Ceil(numExperts * 
sizeof(T), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; - sendDataAlignLen = Ceil(numExperts * sendPerGroup * sizeof(T), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + batchRounds = BATCH_ROUND; + tokenPerExpertDataAlignLen = Ceil(batchRounds * numExperts * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + sendDataOffsetAlignLen = Ceil(batchRounds * numExperts * sizeof(T), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + sendDataAlignLen = Ceil(batchRounds * numExperts * sendPerGroup * sizeof(T), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + sendTokensPerRankAlignLen = Ceil(numRanks * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; // Initialize core grouping InitCoreGroup(); // Initialize data slicing InitDataSlice(); + totalRecvTokens_ = totalRecvTokens; + recvCount_ = recvCount; + recvOffset_ = recvOffset; + maxBs_ = maxBs; + recvTokensPerExpert_ = recvTokensPerExpert; + expertGlobalOffset_ = expertGlobalOffset; + srcrankInExpertOffset_ = srcrankInExpertOffset; + rInSrcrankOffset_ = rInSrcrankOffset; this->sendDataInput = (__gm__ T *)sendDataInput; this->tokenPerExpertDataInput = (__gm__ int32_t *)tokenPerExpertDataInput; this->sendDataOffsetOutput = (__gm__ T *)sendDataOffsetOutput; @@ -80,9 +101,6 @@ class NotifyDispatch sendDataOffsetOutputGt.SetGlobalBuffer((__gm__ T *)sendDataOffsetOutput); recvDataOutputGt.SetGlobalBuffer((__gm__ T *)recvDataOutput); recvDataOutGt.SetGlobalBuffer((__gm__ int32_t *)recvDataOutput); - pipe.InitBuffer(sendCountBuf, tokenPerExpertDataAlignLen); - pipe.InitBuffer(sendOffsetBuf, tokenPerExpertDataAlignLen); - pipe.InitBuffer(recvDataBuf, recvDataAlignLen); } __aicore__ inline void Process() @@ -98,12 +116,15 @@ class NotifyDispatch ShareToShareSlice(); } SyncAll(); - ReorderOutput(); + pipe.Reset(); BuildTotalRecvTokens(); BuildRecvCount(); BuildRecvOffset(); BuildMaxBs(); BuildRecvTokenPerExp(); + BuildExpGlobalOffset(); + BuildsrcRankInExpOffset(); + BuildRInSrcrankOffset(); hccl_.Finalize(); } @@ -126,33 +147,95 @@ class NotifyDispatch __aicore__ inline void 
AssembleSendData() { + pipe.Reset(); pipe.InitBuffer(tokenPerExpertDataBuf, tokenPerExpertDataAlignLen); pipe.InitBuffer(sendDataBuf, sendDataAlignLen); pipe.InitBuffer(sendDataOffsetBuf, sendDataOffsetAlignLen); + int localExpertsNum = numExperts / rankSize; + int newSendDataAlignLen = + Ceil(batchRounds * localExpertsNum * sendPerGroup * sizeof(T), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + pipe.InitBuffer(newSendDataBuf, newSendDataAlignLen); + tokenPerExpertTensor = tokenPerExpertDataBuf.Get(); + sendDataTensor = sendDataBuf.Get(); + sendDataOffsetTensor = sendDataOffsetBuf.Get(); + newSendDataTensor = newSendDataBuf.Get(); + + int realRound = (numTokens + perRoundTokens - 1) / perRoundTokens; + int lastRoundNumTokens = numTokens % perRoundTokens; + if (lastRoundNumTokens == 0 && numTokens > 0) { + lastRoundNumTokens = perRoundTokens; + } + int totalRounds = round; + + for (int rBase = 0; rBase < totalRounds; rBase += batchRounds) { + int currentBatch = (rBase + batchRounds > totalRounds) ? (totalRounds - rBase) : batchRounds; + uint32_t copyLen = currentBatch * numExperts * sizeof(int32_t); + DataCopyExtParams tokenPerExpertParams = {1U, copyLen, 0U, 0U, 0U}; + DataCopyPadExtParams copyPadExtParams{false, 0U, 0U, 0U}; + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + DataCopyPad(tokenPerExpertTensor, tokenPerExpertDataInputGt[rBase * numExperts], tokenPerExpertParams, + copyPadExtParams); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + + for (int r = 0; r < currentBatch; r++) { + int absRound = rBase + r; + int prefixSum = 0; + if (absRound < realRound) { + for (int i = 0; i < numExperts; ++i) { + int numTokensExpert = tokenPerExpertTensor(r * numExperts + i); // S operation + int baseUB = r * numExperts * sendPerGroup + i * sendPerGroup; + sendDataTensor(baseUB) = numTokensExpert; + sendDataTensor(baseUB + 1) = prefixSum; + int roundNumTokens = (absRound == realRound - 1 ? 
lastRoundNumTokens : perRoundTokens); + sendDataTensor(baseUB + 2) = roundNumTokens; + sendDataOffsetTensor(r * numExperts + i) = prefixSum; + prefixSum += numTokensExpert; + } + } else { + // padding round + for (int i = 0; i < numExperts; ++i) { + int baseUB = r * numExperts * sendPerGroup + i * sendPerGroup; + sendDataTensor(baseUB) = 0; + sendDataTensor(baseUB + 1) = 0; + sendDataTensor(baseUB + 2) = 0; + sendDataOffsetTensor(r * numExperts + i) = 0; + } + } + } - __ubuf__ int32_t *tokenPerExpertUB = (__ubuf__ int32_t *)get_imm(96); - CpGM2UB(tokenPerExpertUB, (__gm__ int32_t *)tokenPerExpertDataInputGt.GetPhyAddr(), tokenPerExpertDataAlignLen); - AscendC::SetFlag(EVENT_ID0); - AscendC::WaitFlag(EVENT_ID0); - - __ubuf__ T *sendDataOffsetUB = (__ubuf__ T *)get_imm(96 + tokenPerExpertDataAlignLen); - __ubuf__ T *sendDataUB = (__ubuf__ T *)get_imm(96 + tokenPerExpertDataAlignLen + sendDataOffsetAlignLen); - - int prefixSum = 0; - for (int i = 0; i < numExperts; ++i) { - int numTokensExpert = tokenPerExpertUB[i]; - sendDataUB[i * sendPerGroup] = numTokensExpert; - sendDataUB[i * sendPerGroup + 1] = prefixSum; - sendDataUB[i * sendPerGroup + 2] = numTokens; - sendDataOffsetUB[i] = prefixSum; - - prefixSum += numTokensExpert; + uint32_t offsetCopyLen = currentBatch * numExperts * sizeof(T); + DataCopyExtParams sendDataOffsetParams = {1U, offsetCopyLen, 0U, 0U, 0U}; + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + DataCopyPad(sendDataOffsetOutputGt[rBase * numExperts], sendDataOffsetTensor, sendDataOffsetParams); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + + for (int tr = 0; tr < rankSize; ++tr) { + for (int r = 0; r < currentBatch; ++r) { + for (int le = 0; le < localExpertsNum; ++le) { + int globalExpertIdx = tr * localExpertsNum + le; + int srcIdx = (r * numExperts + globalExpertIdx) * sendPerGroup; + int dstIdx = (r * localExpertsNum + le) * sendPerGroup; + newSendDataTensor(dstIdx) = sendDataTensor(srcIdx); + 
newSendDataTensor(dstIdx + 1) = sendDataTensor(srcIdx + 1); + newSendDataTensor(dstIdx + 2) = sendDataTensor(srcIdx + 2); + } + } + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + uint32_t dataCopyLen = currentBatch * localExpertsNum * sendPerGroup * sizeof(int32_t); + DataCopyExtParams copyParams = {1U, dataCopyLen, 0U, 0U, 0U}; + uint64_t gmOffset = (uint64_t)tr * totalRounds * localExpertsNum * sendPerGroup + + (uint64_t)rBase * localExpertsNum * sendPerGroup; + DataCopyPad(sendDataInputGt[gmOffset], newSendDataTensor[0], copyParams); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + } } - AscendC::SetFlag(EVENT_ID0); - AscendC::WaitFlag(EVENT_ID0); - CpUB2GM((__gm__ T *)sendDataInputGt.GetPhyAddr(), sendDataUB, sendDataAlignLen); - CpUB2GM((__gm__ T *)sendDataOffsetOutputGt.GetPhyAddr(), sendDataOffsetUB, sendDataOffsetAlignLen); AscendC::SetFlag(EVENT_ID0); AscendC::WaitFlag(EVENT_ID0); } @@ -161,7 +244,7 @@ class NotifyDispatch __aicore__ inline void InputToShareSlice() { __ubuf__ uint64_t *inputUB = (__ubuf__ uint64_t *)get_imm(0); - int64_t copyOffset = blockIdx * rankNumPerCore; + int32_t copyOffset = blockIdx * rankNumPerCore; copyLen = rankSize - copyOffset < rankNumPerCore ? rankSize - copyOffset : rankNumPerCore; if (copyLen > 0) { readGt = sendDataInputGt[copyOffset * perRankDataNum]; @@ -228,7 +311,7 @@ class NotifyDispatch __aicore__ inline void ShareToShareSlice() { __ubuf__ T *inputUB = (__ubuf__ T *)get_imm(96); - int64_t copyOffset = blockIdx * rankNumPerCore; + int32_t copyOffset = blockIdx * rankNumPerCore; copyLen = rankSize - copyOffset < rankNumPerCore ? 
rankSize - copyOffset : rankNumPerCore; if (copyLen > 0) { int checkRank[MAX_RANK_PER_CORE]; @@ -241,7 +324,9 @@ class NotifyDispatch for (int i = 0; i < copyLen; i++) { readGt1[i].SetGlobalBuffer((__gm__ T *)(shareAddrs[checkRank[i]] + IPC_DATA_OFFSET)); } + WaitSyncFlag(magic, 1, copyOffset, rank, copyLen); + for (int i = 0; i < copyLen; i++) { CpGM2GMPingPong(perRankDataNum * sizeof(T), readGt1[i][rank * perRankDataNum], recvDataOutputGt[checkRank[i] * perRankDataNum], COPYONLY); @@ -249,133 +334,248 @@ class NotifyDispatch } } - __aicore__ inline void ReorderOutput() + __aicore__ inline void ReorderOutput(uint32_t rStart, uint32_t currentBatchRounds) { - recvDataTensor = recvDataBuf.Get(); - DataCopyExtParams recvDataParams = {1U, static_cast(recvDataAlignLen), 0, 0, 0}; - DataCopyPadExtParams DataCopyPadExtParams{false, 0U, 0U, 0U}; - DataCopyPad(recvDataTensor, recvDataOutGt, recvDataParams, DataCopyPadExtParams); + recvDataTensor = recvDataBuf.Get(); + Duplicate(recvDataTensor, 0, recvDataAlignLen / sizeof(int32_t)); + + uint32_t singleRankTotalElemCount = round * numLocalExperts * sendPerGroup; + uint32_t singleRankBatchElemCount = currentBatchRounds * numLocalExperts * sendPerGroup; + uint32_t singleRankBatchDataLen = singleRankBatchElemCount * sizeof(int32_t); + DataCopyExtParams recvDataParams = {1U, static_cast(singleRankBatchDataLen), 0, 0, 0}; + DataCopyPadExtParams DataCopyPadExtParams{false, 0U, 0U, 0U}; + + for (uint32_t i = 0; i < rankSize; i++) { + uint32_t srcOffset = i * singleRankTotalElemCount + rStart * numLocalExperts * sendPerGroup; + uint32_t dstOffset = i * singleRankBatchElemCount; + // 搬运该Rank下的 currentBatchRounds 数据 + DataCopyPad(recvDataTensor[dstOffset], recvDataOutputGt[srcOffset], recvDataParams, DataCopyPadExtParams); + } + SyncFunc(); } - __aicore__ inline void ReorderSendCountOutput() + __aicore__ inline void ReorderSendCountOutput(uint32_t currentBatchRounds) { - sendCountTensor = sendCountBuf.Get(); - 
Duplicate(sendCountTensor, 0, tokenPerExpertDataAlignLen / sizeof(int32_t)); + recvCountTensor = recvCountBuf.Get(); + Duplicate(recvCountTensor, 0, sendCountAlignLen / sizeof(int32_t)); // V + SyncFunc(); - SyncFunc(); - for (uint32_t expId = 0; expId < numExperts / rankSize; ++expId) { - for (uint32_t srcRank = 0; srcRank < rankSize; ++srcRank) { - uint32_t index = expId * rankSize + srcRank; - uint32_t pair_idx = sendPerGroup * (srcRank * numExperts / rankSize + expId); - sendCountTensor(index) = recvDataTensor(pair_idx); + uint32_t computeNum = currentBatchRounds * numLocalExperts; + for (uint32_t r = 0; r < currentBatchRounds; ++r) { + uint32_t computeNumIn = r * numLocalExperts; + uint32_t computeNumOut = r * numExperts; + for (uint32_t expId = 0; expId < numLocalExperts; ++expId) { + for (uint32_t srcRank = 0; srcRank < rankSize; ++srcRank) { + uint32_t index = expId * rankSize + srcRank; + uint32_t pair_idx = sendPerGroup * (srcRank * computeNum + computeNumIn + expId); + recvCountTensor(computeNumOut + index) = recvDataTensor(pair_idx); + } } } } - __aicore__ inline void ReorderSendOffsetOutput() + __aicore__ inline void ReorderSendOffsetOutput(uint32_t currentBatchRounds) { - sendOffsetTensor = sendOffsetBuf.Get(); - Duplicate(sendOffsetTensor, 0, tokenPerExpertDataAlignLen / sizeof(int32_t)); + sendOffsetTensor = sendOffsetBuf.Get(); + Duplicate(sendOffsetTensor, 0, sendCountAlignLen / sizeof(int32_t)); SyncFunc(); - SyncFunc(); - for (uint32_t expId = 0; expId < numExperts / rankSize; ++expId) { - for (uint32_t srcRank = 0; srcRank < rankSize; ++srcRank) { - uint32_t index = expId * rankSize + srcRank; - uint32_t pair_idx = sendPerGroup * (srcRank * numExperts / rankSize + expId); - sendOffsetTensor(index) = recvDataTensor(pair_idx + 1); + uint32_t computeNum = currentBatchRounds * numLocalExperts; + for (uint32_t r = 0; r < currentBatchRounds; ++r) { + uint32_t computeNumIn = r * numLocalExperts; + uint32_t computeNumOut = r * numExperts; + for 
(uint32_t expId = 0; expId < numLocalExperts; ++expId) { + for (uint32_t srcRank = 0; srcRank < rankSize; ++srcRank) { + uint32_t index = expId * rankSize + srcRank; + uint32_t pair_idx = sendPerGroup * (srcRank * computeNum + computeNumIn + expId); + sendOffsetTensor(computeNumOut + index) = recvDataTensor(pair_idx + 1); + } } } } - __aicore__ inline void ReorderMaxBsOutput() + __aicore__ inline void ReorderSendTokensPerRankOutput() { + pipe.InitBuffer(sendTokensPerRankBuf, sendTokensPerRankAlignLen); + pipe.InitBuffer(seenRoundBuf, sendTokensPerRankAlignLen); + sendTokensPerRankTensor = sendTokensPerRankBuf.Get(); + seenRoundTensor = seenRoundBuf.Get(); + Duplicate(sendTokensPerRankTensor, 0, sendTokensPerRankAlignLen / sizeof(int32_t)); + SyncFunc(); SyncFunc(); - for (uint32_t expId = 0; expId < numExperts / rankSize; ++expId) { - for (uint32_t srcRank = 0; srcRank < rankSize; ++srcRank) { - uint32_t pair_idx = sendPerGroup * (srcRank * numExperts / rankSize + expId); - uint32_t BsCnt = recvDataTensor(pair_idx + 2); - maxBsNum = maxBsNum < BsCnt ? 
BsCnt : maxBsNum; + for (uint32_t r = 0; r < round; ++r) { + Duplicate(seenRoundTensor, 0, sendTokensPerRankAlignLen / sizeof(int32_t)); + SyncFunc(); + for (uint32_t expId = 0; expId < numLocalExperts; ++expId) { + for (uint32_t srcRank = 0; srcRank < rankSize; ++srcRank) { + uint32_t index = expId * rankSize + srcRank; + uint32_t pair_idx = + sendPerGroup * (srcRank * numLocalExperts * round + r * numLocalExperts + expId); + if (!seenRoundTensor(srcRank)) { + sendTokensPerRankTensor(srcRank) += recvDataTensor(pair_idx + 2); + seenRoundTensor(srcRank) = 1; + } + } } + SyncFunc(); } } __aicore__ inline void BuildTotalRecvTokens() { - // 只需要sendCountTensor - if (blockIdx > 0) { + if (blockIdx != TOTAL_CNT_CORE) { return; } - ReorderSendCountOutput(); - pipe.InitBuffer(tmpBuf_, Ceil(numExperts * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE); - pipe.InitBuffer(tmpBuf2_, Ceil(numExperts * sizeof(float), UB_ALIGN_SIZE) * UB_ALIGN_SIZE); - pipe.InitBuffer(tmpBuf3_, Ceil(numExperts * sizeof(float), UB_ALIGN_SIZE) * UB_ALIGN_SIZE); - pipe.InitBuffer(tmpBuf4_, Ceil(numExperts * sizeof(float), UB_ALIGN_SIZE) * UB_ALIGN_SIZE); + int32_t sumVal = 0; + recvDataAlignLen = + Ceil(batchRounds * numExperts * sendPerGroup * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + pipe.InitBuffer(recvDataBuf, recvDataAlignLen); + sendCountAlignLen = Ceil(batchRounds * numExperts * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + pipe.InitBuffer(recvCountBuf, sendCountAlignLen); + pipe.InitBuffer(tmpBuf2_, Ceil(batchRounds * sizeof(float), UB_ALIGN_SIZE) * UB_ALIGN_SIZE); // 32KB + + for (uint32_t rStart = 0; rStart < round; rStart += batchRounds) { + uint32_t currentBatchRounds = (rStart + batchRounds > round) ? 
(round - rStart) : batchRounds; + ReorderOutput(rStart, currentBatchRounds); + ReorderSendCountOutput(currentBatchRounds); + + LocalTensor batchCntFloat = tmpBuf2_.Get(); + LocalTensor batchSumCntLt = recvCountBuf.Get(); + LocalTensor sharedTmpBuffer = recvDataBuf.Get(); + uint32_t currComputeNum = currentBatchRounds * numExperts; + SyncFunc(); + Cast(batchCntFloat, recvCountTensor, RoundMode::CAST_NONE, currComputeNum); + PipeBarrier(); + ReduceSum(batchSumCntLt, batchCntFloat, sharedTmpBuffer, currComputeNum); + SyncFunc(); + sumVal += static_cast(batchSumCntLt.GetValue(0)); + SyncFunc(); + } + pipe.InitBuffer(tmpBuf_, UB_ALIGN_SIZE); LocalTensor totalCntLt = tmpBuf_.Get(); - LocalTensor floatExpTokenCntLt = tmpBuf2_.Get(); - LocalTensor floatExpTokenSumCntLt = tmpBuf3_.Get(); - LocalTensor sharedTmpBuffer = tmpBuf4_.Get(); - SyncFunc(); - Cast(floatExpTokenCntLt, sendCountTensor, RoundMode::CAST_NONE, numExperts); - PipeBarrier(); - ReduceSum(floatExpTokenSumCntLt, floatExpTokenCntLt, sharedTmpBuffer, numExperts); - SyncFunc(); - int32_t sumVal = static_cast(floatExpTokenSumCntLt.GetValue(0)); - PipeBarrier(); totalCntLt(0) = sumVal; - PipeBarrier(); - SyncFunc(); + SyncFunc(); // 拷贝到outputGT GlobalTensor totalCntGt; totalCntGt.SetGlobalBuffer((__gm__ int32_t *)totalRecvTokens_); DataCopyExtParams copyParams{1, static_cast(1 * sizeof(int32_t)), 0, 0, 0}; DataCopyPad(totalCntGt, totalCntLt, copyParams); + SyncFunc(); } __aicore__ inline void BuildRecvCount() { - // 只需要sendCountTensor - if (blockIdx != 1) { + // 只需要recvCountTensor + if (blockIdx != RECV_COUNT_CORE) { return; } - ReorderSendCountOutput(); - int32_t recvCountNum = 0; - for (uint32_t expId = 0; expId < numExperts / rankSize; ++expId) { - for (uint32_t srcRank = 0; srcRank < rankSize; ++srcRank) { - uint32_t index = expId * rankSize + srcRank; - recvCountNum += sendCountTensor(index); - sendCountTensor(index) = recvCountNum; + recvDataAlignLen = + Ceil(batchRounds * numExperts * sendPerGroup * 
sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + pipe.InitBuffer(recvDataBuf, recvDataAlignLen); + sendCountAlignLen = Ceil(batchRounds * numExperts * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + pipe.InitBuffer(recvCountBuf, sendCountAlignLen); + for (uint32_t rStart = 0; rStart < round; rStart += batchRounds) { + uint32_t currentBatchRounds = (rStart + batchRounds > round) ? (round - rStart) : batchRounds; + + ReorderOutput(rStart, currentBatchRounds); + ReorderSendCountOutput(currentBatchRounds); + + for (uint32_t r = 0; r < currentBatchRounds; ++r) { + int32_t recvCountNum = 0; + for (uint32_t expId = 0; expId < numExperts / rankSize; ++expId) { + for (uint32_t srcRank = 0; srcRank < rankSize; ++srcRank) { + uint32_t index = r * numExperts + expId * rankSize + srcRank; + recvCountNum += recvCountTensor(index); + recvCountTensor(index) = recvCountNum; + } + } } + GlobalTensor recvCntGt; + recvCntGt.SetGlobalBuffer((__gm__ int32_t *)recvCount_); + uint32_t globalOffset = rStart * numExperts; + DataCopyExtParams copyParams{1, static_cast(currentBatchRounds * numExperts * sizeof(int32_t)), 0, + 0, 0}; + SyncFunc(); + DataCopyPad(recvCntGt[globalOffset], recvCountTensor, copyParams); + + SyncFunc(); } - GlobalTensor recvCntGt; - recvCntGt.SetGlobalBuffer((__gm__ int32_t *)recvCount_); - DataCopyExtParams copyParams{1, static_cast(numExperts * sizeof(int32_t)), 0, 0, 0}; - SyncFunc(); - DataCopyPad(recvCntGt, sendCountTensor, copyParams); } __aicore__ inline void BuildRecvOffset() { - // 只需要sendOffsetTensor - if (blockIdx != 2) { + if (blockIdx != RECV_OFFSET_CORE) { return; } - ReorderSendOffsetOutput(); - GlobalTensor recvOffsetGt; - recvOffsetGt.SetGlobalBuffer((__gm__ int32_t *)recvOffset_); - DataCopyExtParams copyParams{1, static_cast(numExperts * sizeof(int32_t)), 0, 0, 0}; - SyncFunc(); - DataCopyPad(recvOffsetGt, sendOffsetTensor, copyParams); + recvDataAlignLen = + Ceil(batchRounds * numExperts * sendPerGroup * sizeof(int32_t), UB_ALIGN_SIZE) * 
UB_ALIGN_SIZE; + pipe.InitBuffer(recvDataBuf, recvDataAlignLen); + sendCountAlignLen = Ceil(batchRounds * numExperts * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + pipe.InitBuffer(sendOffsetBuf, sendCountAlignLen); + + for (uint32_t rStart = 0; rStart < round; rStart += batchRounds) { + uint32_t currentBatchRounds = (rStart + batchRounds > round) ? (round - rStart) : batchRounds; + + ReorderOutput(rStart, currentBatchRounds); + ReorderSendOffsetOutput(currentBatchRounds); + + GlobalTensor recvOffsetGt; + recvOffsetGt.SetGlobalBuffer((__gm__ int32_t *)recvOffset_); + uint32_t globalOffset = rStart * numExperts; + DataCopyExtParams copyParams{1, static_cast(currentBatchRounds * numExperts * sizeof(int32_t)), 0, + 0, 0}; + SyncFunc(); + DataCopyPad(recvOffsetGt[globalOffset], sendOffsetTensor, copyParams); + + SyncFunc(); + } } __aicore__ inline void BuildMaxBs() { // 只需要maxBsNum - if (blockIdx != 3) { + if (blockIdx != MAX_BS_CORE) { return; } - ReorderMaxBsOutput(); + recvDataAlignLen = + Ceil(batchRounds * numExperts * sendPerGroup * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + pipe.InitBuffer(recvDataBuf, recvDataAlignLen); + + pipe.InitBuffer(sendTokensPerRankBuf, sendTokensPerRankAlignLen); + pipe.InitBuffer(seenRoundBuf, sendTokensPerRankAlignLen); + sendTokensPerRankTensor = sendTokensPerRankBuf.Get(); + seenRoundTensor = seenRoundBuf.Get(); + Duplicate(sendTokensPerRankTensor, 0, sendTokensPerRankAlignLen / sizeof(int32_t)); + + SyncFunc(); + SyncFunc(); + for (uint32_t rStart = 0; rStart < round; rStart += batchRounds) { + uint32_t currentBatchRounds = (rStart + batchRounds > round) ? 
(round - rStart) : batchRounds; + + ReorderOutput(rStart, currentBatchRounds); + SyncFunc(); + for (uint32_t r = 0; r < currentBatchRounds; ++r) { + Duplicate(seenRoundTensor, 0, sendTokensPerRankAlignLen / sizeof(int32_t)); + SyncFunc(); + for (uint32_t expId = 0; expId < numLocalExperts; ++expId) { + for (uint32_t srcRank = 0; srcRank < rankSize; ++srcRank) { + uint32_t pair_idx = sendPerGroup * (srcRank * numLocalExperts * currentBatchRounds + + r * numLocalExperts + expId); + if (!seenRoundTensor(srcRank)) { + sendTokensPerRankTensor(srcRank) += recvDataTensor(pair_idx + 2); + seenRoundTensor(srcRank) = 1; + } + } + } + } + SyncFunc(); + } + + for (uint32_t srcRank = 0; srcRank < numRanks; ++srcRank) { + uint32_t tempBs = sendTokensPerRankTensor(srcRank); + maxBsNum = maxBsNum >= tempBs ? maxBsNum : tempBs; + } GlobalTensor maxBsGt; maxBsGt.SetGlobalBuffer((__gm__ int32_t *)maxBs_); maxBsGt.SetValue(0, maxBsNum); @@ -384,27 +584,198 @@ class NotifyDispatch __aicore__ inline void BuildRecvTokenPerExp() { - // 只需要sendCountTensor - if (blockIdx != 4) { + // 只需要recvCountTensor + if (blockIdx != RECV_TOKEN_PER_EXP_CORE) { return; } - ReorderSendCountOutput(); - pipe.InitBuffer(tmpBuf_, Ceil(numExperts / rankSize * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE); + recvDataAlignLen = + Ceil(batchRounds * numExperts * sendPerGroup * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + pipe.InitBuffer(recvDataBuf, recvDataAlignLen); + sendCountAlignLen = Ceil(batchRounds * numExperts * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + pipe.InitBuffer(recvCountBuf, sendCountAlignLen); + pipe.InitBuffer(tmpBuf_, Ceil(batchRounds * numLocalExperts * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE); LocalTensor tmpTensor = tmpBuf_.Get(); - for (uint32_t expId = 0; expId < numExperts / rankSize; ++expId) { - int32_t localRecvCount = 0; - for (uint32_t srcRank = 0; srcRank < rankSize; ++srcRank) { - uint32_t index = expId * rankSize + srcRank; - localRecvCount += 
sendCountTensor(index); + GlobalTensor recvTokenPerExpGt; + recvTokenPerExpGt.SetGlobalBuffer((__gm__ int32_t *)recvTokensPerExpert_); + for (uint32_t rStart = 0; rStart < round; rStart += batchRounds) { + uint32_t currentBatchRounds = (rStart + batchRounds > round) ? (round - rStart) : batchRounds; + SyncFunc(); + Duplicate(tmpTensor, 0, batchRounds * numLocalExperts); + + ReorderOutput(rStart, currentBatchRounds); + ReorderSendCountOutput(currentBatchRounds); + + for (uint32_t r = 0; r < currentBatchRounds; r++) { + for (uint32_t expId = 0; expId < numLocalExperts; ++expId) { + int32_t localRecvCount = 0; + for (uint32_t srcRank = 0; srcRank < rankSize; ++srcRank) { + uint32_t index = r * numExperts + expId * rankSize + srcRank; + localRecvCount += recvCountTensor(index); + } + tmpTensor(r * numLocalExperts + expId) = localRecvCount; + } } - tmpTensor(expId) = localRecvCount; + SyncFunc(); + DataCopyExtParams copyParams{ + 1, static_cast(currentBatchRounds * numLocalExperts * sizeof(int32_t)), 0, 0, 0}; + SyncFunc(); + SyncFunc(); + DataCopyPad(recvTokenPerExpGt[rStart * numLocalExperts], tmpTensor, copyParams); + + SyncFunc(); } + } + + __aicore__ inline void BuildExpGlobalOffset() + { + // 只需要recvCountTensor + if (blockIdx != EXP_GLOBAL_OFFSET_CORE) { + return; + } + + recvDataAlignLen = + Ceil(batchRounds * numExperts * sendPerGroup * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + pipe.InitBuffer(recvDataBuf, recvDataAlignLen); + sendCountAlignLen = Ceil(batchRounds * numExperts * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + pipe.InitBuffer(recvCountBuf, sendCountAlignLen); + + // tmpBuf_,需要常驻,消耗:16 *4 + pipe.InitBuffer(tmpBuf_, Ceil(numLocalExperts * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE); + LocalTensor tmpTensor = tmpBuf_.Get(); + Duplicate(tmpTensor, 0, numLocalExperts); + + SyncFunc(); + for (uint32_t rStart = 0; rStart < round; rStart += batchRounds) { + uint32_t currentBatchRounds = (rStart + batchRounds > round) ? 
(round - rStart) : batchRounds; + ReorderOutput(rStart, currentBatchRounds); + ReorderSendCountOutput(currentBatchRounds); + for (uint32_t r = 0; r < currentBatchRounds; r++) { + for (uint32_t expId = 0; expId < numLocalExperts; ++expId) { + int32_t localRecvCount = 0; + for (uint32_t srcRank = 0; srcRank < rankSize; ++srcRank) { + uint32_t index = r * numExperts + expId * rankSize + srcRank; + localRecvCount += recvCountTensor(index); + } + tmpTensor(expId) += localRecvCount; + } + } + SyncFunc(); // waiting for recvCountTensor + } + pipe.InitBuffer(tmpBuf2_, Ceil(numLocalExperts * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE); + LocalTensor expTensor = tmpBuf2_.Get(); + expTensor(0) = 0; + for (uint32_t expId = 1; expId < numLocalExperts; ++expId) { + expTensor(expId) = expTensor(expId - 1) + tmpTensor(expId - 1); + } + GlobalTensor expGlobalOffsetGt; + expGlobalOffsetGt.SetGlobalBuffer((__gm__ int32_t *)expertGlobalOffset_); + DataCopyExtParams copyParams{1, static_cast(numLocalExperts * sizeof(int32_t)), 0, 0, 0}; SyncFunc(); - GlobalTensor recvTokenPerExpGt; - recvTokenPerExpGt.SetGlobalBuffer((__gm__ int32_t *)recvTokensPerExpert_); - DataCopyExtParams copyParams{1, static_cast(numExperts / rankSize * sizeof(int32_t)), 0, 0, 0}; + DataCopyPad(expGlobalOffsetGt, expTensor, copyParams); + } + + __aicore__ inline void BuildsrcRankInExpOffset() + { + if (blockIdx != SRC_RANK_EXP_OFFSET_CORE) { + return; + } + recvDataAlignLen = + Ceil(batchRounds * numExperts * sendPerGroup * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + pipe.InitBuffer(recvDataBuf, recvDataAlignLen); + sendCountAlignLen = Ceil(batchRounds * numExperts * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; // 32Kb + pipe.InitBuffer(recvCountBuf, sendCountAlignLen); + + pipe.InitBuffer(tmpBuf_, Ceil(numRanks * numLocalExperts * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE); + LocalTensor expSrcTotalTensor = tmpBuf_.Get(); + Duplicate(expSrcTotalTensor, 0, numExperts); + SyncFunc(); + + for 
(uint32_t rStart = 0; rStart < round; rStart += batchRounds) { + uint32_t currentBatchRounds = (rStart + batchRounds > round) ? (round - rStart) : batchRounds; + + ReorderOutput(rStart, currentBatchRounds); + ReorderSendCountOutput(currentBatchRounds); + SyncFunc(); + for (uint32_t r = 0; r < currentBatchRounds; r++) { + for (uint32_t expId = 0; expId < numLocalExperts; ++expId) { + int32_t localRecvCount = 0; + for (uint32_t srcRank = 0; srcRank < rankSize; ++srcRank) { + uint32_t index = r * numExperts + expId * rankSize + srcRank; + localRecvCount = recvCountTensor(index); + expSrcTotalTensor(expId * numRanks + srcRank) += localRecvCount; + } + } + } + } + + pipe.InitBuffer(tmpBuf2_, Ceil(numRanks * numLocalExperts * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE); + LocalTensor srcRankInExpOffsetTensor = tmpBuf2_.Get(); + for (uint32_t expId = 0; expId < numLocalExperts; ++expId) { + int32_t cumOffset = 0; + for (uint32_t srcRank = 0; srcRank < rankSize; ++srcRank) { + srcRankInExpOffsetTensor(expId * numRanks + srcRank) = cumOffset; + cumOffset += expSrcTotalTensor(expId * numRanks + srcRank); + } + } + GlobalTensor srcRankInExpOffsetGt; + srcRankInExpOffsetGt.SetGlobalBuffer((__gm__ int32_t *)srcrankInExpertOffset_); + DataCopyExtParams copyParams{1, static_cast(numExperts * sizeof(int32_t)), 0, 0, 0}; SyncFunc(); - DataCopyPad(recvTokenPerExpGt, tmpTensor, copyParams); + DataCopyPad(srcRankInExpOffsetGt, srcRankInExpOffsetTensor, copyParams); + } + + __aicore__ inline void BuildRInSrcrankOffset() + { + if (blockIdx != R_IN_SRCRANK_OFFSET_CORE) { + return; + } + recvDataAlignLen = + Ceil(batchRounds * numExperts * sendPerGroup * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + pipe.InitBuffer(recvDataBuf, recvDataAlignLen); + sendCountAlignLen = Ceil(batchRounds * numExperts * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; // 32Kb + pipe.InitBuffer(recvCountBuf, sendCountAlignLen); + + pipe.InitBuffer(tmpBuf2_, Ceil(numRanks * numLocalExperts * 
sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE); + LocalTensor expSrcCumPrevTensor = tmpBuf2_.Get(); + Duplicate(expSrcCumPrevTensor, 0, numExperts); + + pipe.InitBuffer(tmpBuf_, Ceil(batchRounds * numExperts * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE); + GlobalTensor rInSrcrankOffsetGt; + rInSrcrankOffsetGt.SetGlobalBuffer((__gm__ int32_t *)rInSrcrankOffset_); + for (uint32_t rStart = 0; rStart < round; rStart += batchRounds) { + uint32_t currentBatchRounds = (rStart + batchRounds > round) ? (round - rStart) : batchRounds; + + ReorderOutput(rStart, currentBatchRounds); + ReorderSendCountOutput(currentBatchRounds); + LocalTensor rInSrcrankOffsetTensor = tmpBuf_.Get(); + + DataCopyExtParams copyParams{1, static_cast(currentBatchRounds * sizeof(int32_t)), 0, 0, 0}; + SyncFunc(); + + for (uint32_t expId = 0; expId < numLocalExperts; ++expId) { + for (uint32_t srcRank = 0; srcRank < rankSize; ++srcRank) { + uint32_t index = expId * rankSize + srcRank; + uint32_t ubBlockOffset = (expId * rankSize + srcRank) * currentBatchRounds; + uint32_t ubBlockOffsetAlign = Ceil(ubBlockOffset * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + uint32_t ubBlockAlignIndex = ubBlockOffsetAlign / sizeof(int32_t); + uint32_t gmOffset = expId * numRanks * round + srcRank * round + rStart; + + Duplicate(rInSrcrankOffsetTensor, 0, currentBatchRounds * numExperts); + SyncFunc(); + for (uint32_t r = 0; r < currentBatchRounds; r++) { + uint32_t pairIdx = r * numExperts + index; + int32_t recvCnt = recvCountTensor(pairIdx); + int32_t offset = expSrcCumPrevTensor(index); + rInSrcrankOffsetTensor(ubBlockAlignIndex + r) = offset; + expSrcCumPrevTensor(index) = offset + recvCnt; + } + uint32_t copyLenByte = currentBatchRounds * sizeof(int32_t); + DataCopyPad(rInSrcrankOffsetGt[gmOffset], rInSrcrankOffsetTensor[ubBlockAlignIndex], copyParams); + SyncFunc(); + } + } + SyncFunc(); + } } __aicore__ inline int64_t GetDataCount(const int64_t dataLen, const int64_t useBlockNum); @@ -434,27 +805,19 
@@ class NotifyDispatch __gm__ int *tokenPerExpertDataInput; __gm__ T *sendDataOffsetOutput; __gm__ T *recvDataOutput; - int64_t isPad = 0; - int64_t maxSliceNum; - int64_t revLen = 0; - int64_t sendLen = 0; - int64_t sliceLen; int64_t perNodeDataNum; int64_t perRankDataNum; int64_t curRankDataNum; - int64_t sendOffset[MULTI_RANK_SIZE]; - int64_t revOffset[MULTI_RANK_SIZE]; - int64_t inputDataLen[MULTI_RANK_SIZE]; - - int64_t nodeNum; - int64_t localRankId; - int64_t localNodeId; - int64_t coreNumPerStageX; // Number of cores used per stage - int64_t coreNumPerStageY; // Number of cores used per stage - int64_t coreNumPerStageZ; // Number of cores used per stage - int64_t coreNumPerRank; // Number of cores allocated per rank - int64_t rankNumPerCore; // Number of ranks responsible per core - int64_t copyLen; // Length of the current data slice being copied (in terms of T) + + int32_t nodeNum; + int32_t localRankId; + int32_t localNodeId; + int32_t coreNumPerStageX; // Number of cores used per stage + int32_t coreNumPerStageY; // Number of cores used per stage + int32_t coreNumPerStageZ; // Number of cores used per stage + int32_t coreNumPerRank; // Number of cores allocated per rank + int32_t rankNumPerCore; // Number of ranks responsible per core + int32_t copyLen; // Length of the current data slice being copied (in terms of T) // for coll int rank; @@ -465,41 +828,62 @@ class NotifyDispatch int yRankSize = 0; int xRankIdx = 0; int yRankIdx = 0; + uint64_t totalWinSize_ = 0; uint32_t extraFlag; + int round; + int32_t perRoundTokens; int numTokens; + int numRanks; int sendPerGroup = 3; int root; int64_t len; - int64_t numExperts; + int32_t numExperts; + int32_t numLocalExperts; uint64_t magic{0}; - int64_t blockIdx; // Index of the current aicore - int64_t blockNum; // Total number of aicores for the current rank + int32_t blockIdx; // Index of the current aicore + int32_t blockNum; // Total number of aicores for the current rank uint32_t maxBsNum{0}; + int 
batchRounds{32}; + GM_ADDR scale; + GM_ADDR shareAddrs[CAM_MAX_RANK_SIZE]; // List of shared memory addresses GM_ADDR totalRecvTokens_; GM_ADDR recvCount_; GM_ADDR recvOffset_; + GM_ADDR expertGlobalOffset_; + GM_ADDR srcrankInExpertOffset_; + GM_ADDR rInSrcrankOffset_; GM_ADDR maxBs_; GM_ADDR recvTokensPerExpert_; - GM_ADDR scale; - GM_ADDR shareAddrs[CAM_MAX_RANK_SIZE]; // List of shared memory addresses __gm__ HcclOpResParam *winContext_[COMM_NUM]{nullptr, nullptr}; Hccl hccl_; TPipe pipe; TBuf tBuf; TBuf<> tokenPerExpertDataBuf; TBuf<> sendDataOffsetBuf; - TBuf<> sendCountBuf; + TBuf<> recvCountBuf; TBuf<> sendOffsetBuf; TBuf<> sendDataBuf; + TBuf<> newSendDataBuf; TBuf<> recvDataBuf; - LocalTensor sendCountTensor; + TBuf<> sendTokensPerRankBuf; + TBuf<> seenRoundBuf; + + LocalTensor tokenPerExpertTensor; + LocalTensor sendDataTensor; + LocalTensor sendDataOffsetTensor; + LocalTensor newSendDataTensor; + LocalTensor recvCountTensor; LocalTensor sendOffsetTensor; + LocalTensor sendTokensPerRankTensor; LocalTensor recvDataTensor; + LocalTensor seenRoundTensor; uint32_t sendDataAlignLen{0}; uint32_t tokenPerExpertDataAlignLen{0}; - uint32_t sendDataOffsetAlignLen{0}; uint32_t recvDataAlignLen{0}; + uint32_t sendDataOffsetAlignLen{0}; + uint32_t sendCountAlignLen{0}; + uint32_t sendTokensPerRankAlignLen{0}; TBuf<> tmpBuf_; TBuf<> tmpBuf2_; @@ -531,7 +915,7 @@ __aicore__ inline uint64_t NotifyDispatch::GetMagicValue(void) { uint64_t magic = 0; GlobalTensor selfDataStatusTensor; - GM_ADDR statusDataSpaceGm = hccl_.GetWindowsInAddr(rank) + winContext_[COMM_EP_IDX]->winSize - Moe::STATE_SIZE * 3; + GM_ADDR statusDataSpaceGm = hccl_.GetWindowsInAddr(rank) + totalWinSize_ - Moe::STATE_SIZE * 3; selfDataStatusTensor.SetGlobalBuffer((__gm__ uint64_t *)(statusDataSpaceGm + STATE_WIN_OFFSET)); DataCacheCleanAndInvalid( selfDataStatusTensor[blockIdx * UB_ALIGN_SIZE]); @@ -548,7 +932,11 @@ __aicore__ inline void NotifyDispatch::InitSmallFullMesh(KERNELS_ARGS_FUN_ALL { 
this->root = root; this->len = len; - this->numExperts = len / sendPerGroup; + this->round = round; + this->perRoundTokens = perRoundTokens; + this->numRanks = rankSize; + this->numExperts = len / sendPerGroup / round; + this->numLocalExperts = numExperts / rankSize; this->numTokens = numTokens; this->scale = scale; this->localRank = localRank; @@ -557,11 +945,12 @@ __aicore__ inline void NotifyDispatch::InitSmallFullMesh(KERNELS_ARGS_FUN_ALL this->yRankSize = rankSize / localRankSize; this->xRankIdx = rank % localRankSize; this->yRankIdx = rank / localRankSize; + this->totalWinSize_ = totalWinSize; blockIdx = GetBlockIdx(); blockNum = GetBlockNum(); uint8_t ctxIdx; - auto tilingData = (__gm__ NotifyDispatchTilingData *)tilingGM; + auto tilingData = (__gm__ NotifyDispatchTilingData *)tiling; __gm__ void *mc2InitTiling = (__gm__ void *)(&(tilingData->mc2InitTiling)); __gm__ void *mc2CcTiling = (__gm__ void *)(&(tilingData->mc2CcTiling1)); @@ -569,21 +958,20 @@ __aicore__ inline void NotifyDispatch::InitSmallFullMesh(KERNELS_ARGS_FUN_ALL hccl_.Init(contextGM0, mc2InitTiling); hccl_.SetCcTiling(mc2CcTiling); - this->winContext_[COMM_EP_IDX] = (__gm__ HcclOpResParam *)contextGM0; - // notifyMemoryOffset = winContext_[COMM_EP_IDX]->winSize - IPC_BUFF_MAX_SIZE * 2; + + winContext_[COMM_EP_IDX] = (__gm__ HcclOpResParam *)contextGM0; this->magic = GetMagicValue(); ctxIdx = COMM_EP_IDX; + uint64_t winDataOffset = ((totalWinSize_ - 4 * Moe::STATE_SIZE) / 2) * (this->magic % PING_PONG_SIZE); - shareAddrs[rank] = - GetWindAddrByRankId(rank, ctxIdx) + (this->magic % PING_PONG_SIZE) * (IPC_BUFF_MAX_SIZE + IPC_DATA_OFFSET); + shareAddrs[rank] = GetWindAddrByRankId(rank, ctxIdx) + winDataOffset; - int64_t rankNumPerCore = (rankSize + blockNum - 1) / blockNum; - int64_t copyOffset = blockIdx * rankNumPerCore; - int64_t copyLen = rankSize - copyOffset < rankNumPerCore ? 
rankSize - copyOffset : rankNumPerCore; + int32_t rankNumPerCore = (rankSize + blockNum - 1) / blockNum; + int32_t copyOffset = blockIdx * rankNumPerCore; + int32_t copyLen = rankSize - copyOffset < rankNumPerCore ? rankSize - copyOffset : rankNumPerCore; if (copyLen > 0) { for (int i = copyOffset; i < copyOffset + copyLen; ++i) { - shareAddrs[i] = - GetWindAddrByRankId(i, ctxIdx) + (this->magic % PING_PONG_SIZE) * (IPC_BUFF_MAX_SIZE + IPC_DATA_OFFSET); + shareAddrs[i] = GetWindAddrByRankId(i, ctxIdx) + winDataOffset; } } @@ -594,11 +982,9 @@ __aicore__ inline void NotifyDispatch::InitSmallFullMesh(KERNELS_ARGS_FUN_ALL int maxCore = coreNumPerRank * rankSize; // Calculate the maximum number of cores that can be used for reading, // cores exceeding this number will not take action if (blockIdx < maxCore) { - int readRank = - blockIdx / - coreNumPerRank; // Calculate the rank to be read based on the block, 48 cores divided into 4 groups - shareAddrs[readRank] = GetWindAddrByRankId(readRank, ctxIdx) + - (this->magic % PING_PONG_SIZE) * (IPC_BUFF_MAX_SIZE + IPC_DATA_OFFSET); + // Calculate the rank to be read based on the block, 48 cores divided into 4 groups + int readRank = blockIdx / coreNumPerRank; + shareAddrs[readRank] = GetWindAddrByRankId(readRank, ctxIdx) + winDataOffset; } pipe.InitBuffer(tBuf, UB_FLAG_SIZE); diff --git a/csrc/deepep/ops2/op_kernel/notify_dispatch_tiling.h b/csrc/deepep/ops2/op_kernel/notify_dispatch_tiling.h index 4e0e66957..c96389e6b 100644 --- a/csrc/deepep/ops2/op_kernel/notify_dispatch_tiling.h +++ b/csrc/deepep/ops2/op_kernel/notify_dispatch_tiling.h @@ -14,6 +14,7 @@ struct NotifyDispatchInfo { uint32_t perRoundTokens; uint32_t aivNum; uint64_t totalUbSize; + uint64_t totalWinSize; }; struct NotifyDispatchTilingData { From 176b24909e73da847e600d245f14a490331005ce Mon Sep 17 00:00:00 2001 From: luanyundu <1425036963@qq.com> Date: Wed, 4 Feb 2026 15:06:35 +0800 Subject: [PATCH 2/2] fix CI bug that misalign when localExpertsNum less than 
8 or more than 256 Co-authored-by: WSEmma --- .../op_host/cam_moe_combine_normal_tiling.cc | 2 +- csrc/deepep/ops2/op_kernel/notify_dispatch.h | 120 ++++++++---------- 2 files changed, 57 insertions(+), 65 deletions(-) diff --git a/csrc/deepep/ops2/op_host/cam_moe_combine_normal_tiling.cc b/csrc/deepep/ops2/op_host/cam_moe_combine_normal_tiling.cc index 68b731ac1..d6cee1946 100644 --- a/csrc/deepep/ops2/op_host/cam_moe_combine_normal_tiling.cc +++ b/csrc/deepep/ops2/op_host/cam_moe_combine_normal_tiling.cc @@ -398,7 +398,7 @@ static bool CheckTensorShape(gert::TilingContext *context, CamMoeCombineNormalTi int64_t xDim0 = xStorageShape->GetStorageShape().GetDim(0); int64_t xDim1 = xStorageShape->GetStorageShape().GetDim(1); OP_TILING_CHECK(xDim0 != topkWeightsDim0, - OP_LOGE(nodeName, "x's dim0 is greater than bs, bs = %ld, x's dim0 = %ld", topkWeightsDim0, xDim0), + OP_LOGE(nodeName, "x's dim0 not equal bs, bs = %ld, x's dim0 = %ld", topkWeightsDim0, xDim0), return false); OP_TILING_CHECK(xDim1 != recvXDim1, OP_LOGE(nodeName, "x's dim1 not equal to h, x's dim1 = %ld, h = %ld", xDim1, recvXDim1), diff --git a/csrc/deepep/ops2/op_kernel/notify_dispatch.h b/csrc/deepep/ops2/op_kernel/notify_dispatch.h index 51dc076b4..c9e8bee4d 100644 --- a/csrc/deepep/ops2/op_kernel/notify_dispatch.h +++ b/csrc/deepep/ops2/op_kernel/notify_dispatch.h @@ -53,7 +53,8 @@ class NotifyDispatch // Synchronization flag occupies length constexpr static int64_t FLAG_UNIT_INT_NUM = 4; constexpr static int64_t MAGIC_MASK = ~((1LL << 32) - 1); - constexpr static int32_t BATCH_ROUND = 32; + constexpr static int32_t EXPERT_NORMAL_NUM = 256; + constexpr static int32_t BATCH_ROUND = 16; public: __aicore__ inline NotifyDispatch(int rank, int rankSize, uint32_t extraFlag) @@ -73,7 +74,7 @@ class NotifyDispatch recvOffset_ = recvOffset; maxBs_ = maxBs; recvTokensPerExpert_ = recvTokensPerExpert; - batchRounds = BATCH_ROUND; + batchRounds = numExperts > EXPERT_NORMAL_NUM ? 
BATCH_ROUND : BATCH_ROUND * 2; tokenPerExpertDataAlignLen = Ceil(batchRounds * numExperts * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; sendDataOffsetAlignLen = Ceil(batchRounds * numExperts * sizeof(T), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; sendDataAlignLen = Ceil(batchRounds * numExperts * sendPerGroup * sizeof(T), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; @@ -338,16 +339,16 @@ class NotifyDispatch { recvDataTensor = recvDataBuf.Get(); Duplicate(recvDataTensor, 0, recvDataAlignLen / sizeof(int32_t)); - uint32_t singleRankTotalElemCount = round * numLocalExperts * sendPerGroup; uint32_t singleRankBatchElemCount = currentBatchRounds * numLocalExperts * sendPerGroup; uint32_t singleRankBatchDataLen = singleRankBatchElemCount * sizeof(int32_t); + uint32_t alignedDataLen = Ceil(singleRankBatchDataLen, UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + uint32_t strideElem = alignedDataLen / sizeof(int32_t); // 目标地址也改变,使用对齐后的地址 DataCopyExtParams recvDataParams = {1U, static_cast(singleRankBatchDataLen), 0, 0, 0}; DataCopyPadExtParams DataCopyPadExtParams{false, 0U, 0U, 0U}; - for (uint32_t i = 0; i < rankSize; i++) { uint32_t srcOffset = i * singleRankTotalElemCount + rStart * numLocalExperts * sendPerGroup; - uint32_t dstOffset = i * singleRankBatchElemCount; + uint32_t dstOffset = i * strideElem; // 搬运该Rank下的 currentBatchRounds 数据 DataCopyPad(recvDataTensor[dstOffset], recvDataOutputGt[srcOffset], recvDataParams, DataCopyPadExtParams); } @@ -360,6 +361,10 @@ class NotifyDispatch Duplicate(recvCountTensor, 0, sendCountAlignLen / sizeof(int32_t)); // V SyncFunc(); + // 新增 + uint32_t singleRankBatchDataLen = currentBatchRounds * numLocalExperts * sendPerGroup * sizeof(int32_t); + uint32_t alignedDataLen = Ceil(singleRankBatchDataLen, UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + uint32_t strideElem = alignedDataLen / sizeof(int32_t); uint32_t computeNum = currentBatchRounds * numLocalExperts; for (uint32_t r = 0; r < currentBatchRounds; ++r) { uint32_t computeNumIn = r * numLocalExperts; @@ -367,7 +372,8 @@ 
class NotifyDispatch for (uint32_t expId = 0; expId < numLocalExperts; ++expId) { for (uint32_t srcRank = 0; srcRank < rankSize; ++srcRank) { uint32_t index = expId * rankSize + srcRank; - uint32_t pair_idx = sendPerGroup * (srcRank * computeNum + computeNumIn + expId); + uint32_t offsetInRank = sendPerGroup * (computeNumIn + expId); + uint32_t pair_idx = srcRank * strideElem + offsetInRank; recvCountTensor(computeNumOut + index) = recvDataTensor(pair_idx); } } @@ -379,6 +385,10 @@ class NotifyDispatch sendOffsetTensor = sendOffsetBuf.Get(); Duplicate(sendOffsetTensor, 0, sendCountAlignLen / sizeof(int32_t)); SyncFunc(); + // 新增 + uint32_t singleRankBatchDataLen = currentBatchRounds * numLocalExperts * sendPerGroup * sizeof(int32_t); + uint32_t alignedDataLen = Ceil(singleRankBatchDataLen, UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + uint32_t strideElem = alignedDataLen / sizeof(int32_t); uint32_t computeNum = currentBatchRounds * numLocalExperts; for (uint32_t r = 0; r < currentBatchRounds; ++r) { uint32_t computeNumIn = r * numLocalExperts; @@ -386,54 +396,28 @@ class NotifyDispatch for (uint32_t expId = 0; expId < numLocalExperts; ++expId) { for (uint32_t srcRank = 0; srcRank < rankSize; ++srcRank) { uint32_t index = expId * rankSize + srcRank; - uint32_t pair_idx = sendPerGroup * (srcRank * computeNum + computeNumIn + expId); + uint32_t offsetInRank = sendPerGroup * (computeNumIn + expId); + uint32_t pair_idx = srcRank * strideElem + offsetInRank; sendOffsetTensor(computeNumOut + index) = recvDataTensor(pair_idx + 1); } } } } - __aicore__ inline void ReorderSendTokensPerRankOutput() - { - pipe.InitBuffer(sendTokensPerRankBuf, sendTokensPerRankAlignLen); - pipe.InitBuffer(seenRoundBuf, sendTokensPerRankAlignLen); - sendTokensPerRankTensor = sendTokensPerRankBuf.Get(); - seenRoundTensor = seenRoundBuf.Get(); - Duplicate(sendTokensPerRankTensor, 0, sendTokensPerRankAlignLen / sizeof(int32_t)); - SyncFunc(); - SyncFunc(); - for (uint32_t r = 0; r < round; ++r) { - 
Duplicate(seenRoundTensor, 0, sendTokensPerRankAlignLen / sizeof(int32_t)); - SyncFunc(); - for (uint32_t expId = 0; expId < numLocalExperts; ++expId) { - for (uint32_t srcRank = 0; srcRank < rankSize; ++srcRank) { - uint32_t index = expId * rankSize + srcRank; - uint32_t pair_idx = - sendPerGroup * (srcRank * numLocalExperts * round + r * numLocalExperts + expId); - if (!seenRoundTensor(srcRank)) { - sendTokensPerRankTensor(srcRank) += recvDataTensor(pair_idx + 2); - seenRoundTensor(srcRank) = 1; - } - } - } - SyncFunc(); - } - } - __aicore__ inline void BuildTotalRecvTokens() { if (blockIdx != TOTAL_CNT_CORE) { return; } int32_t sumVal = 0; - - recvDataAlignLen = - Ceil(batchRounds * numExperts * sendPerGroup * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + uint32_t singleRankMaxElem = batchRounds * numLocalExperts * sendPerGroup; + uint32_t singleRankMaxLen = singleRankMaxElem * sizeof(int32_t); + uint32_t singleRankAlignLen = Ceil(singleRankMaxLen, UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + recvDataAlignLen = rankSize * singleRankAlignLen; pipe.InitBuffer(recvDataBuf, recvDataAlignLen); sendCountAlignLen = Ceil(batchRounds * numExperts * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; pipe.InitBuffer(recvCountBuf, sendCountAlignLen); pipe.InitBuffer(tmpBuf2_, Ceil(batchRounds * sizeof(float), UB_ALIGN_SIZE) * UB_ALIGN_SIZE); // 32KB - for (uint32_t rStart = 0; rStart < round; rStart += batchRounds) { uint32_t currentBatchRounds = (rStart + batchRounds > round) ? 
(round - rStart) : batchRounds; ReorderOutput(rStart, currentBatchRounds); @@ -470,17 +454,17 @@ class NotifyDispatch if (blockIdx != RECV_COUNT_CORE) { return; } - recvDataAlignLen = - Ceil(batchRounds * numExperts * sendPerGroup * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + uint32_t singleRankMaxElem = batchRounds * numLocalExperts * sendPerGroup; + uint32_t singleRankMaxLen = singleRankMaxElem * sizeof(int32_t); + uint32_t singleRankAlignLen = Ceil(singleRankMaxLen, UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + recvDataAlignLen = rankSize * singleRankAlignLen; pipe.InitBuffer(recvDataBuf, recvDataAlignLen); sendCountAlignLen = Ceil(batchRounds * numExperts * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; pipe.InitBuffer(recvCountBuf, sendCountAlignLen); for (uint32_t rStart = 0; rStart < round; rStart += batchRounds) { uint32_t currentBatchRounds = (rStart + batchRounds > round) ? (round - rStart) : batchRounds; - ReorderOutput(rStart, currentBatchRounds); ReorderSendCountOutput(currentBatchRounds); - for (uint32_t r = 0; r < currentBatchRounds; ++r) { int32_t recvCountNum = 0; for (uint32_t expId = 0; expId < numExperts / rankSize; ++expId) { @@ -498,7 +482,6 @@ class NotifyDispatch 0, 0}; SyncFunc(); DataCopyPad(recvCntGt[globalOffset], recvCountTensor, copyParams); - SyncFunc(); } } @@ -508,18 +491,17 @@ class NotifyDispatch if (blockIdx != RECV_OFFSET_CORE) { return; } - recvDataAlignLen = - Ceil(batchRounds * numExperts * sendPerGroup * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + uint32_t singleRankMaxElem = batchRounds * numLocalExperts * sendPerGroup; + uint32_t singleRankMaxLen = singleRankMaxElem * sizeof(int32_t); + uint32_t singleRankAlignLen = Ceil(singleRankMaxLen, UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + recvDataAlignLen = rankSize * singleRankAlignLen; pipe.InitBuffer(recvDataBuf, recvDataAlignLen); sendCountAlignLen = Ceil(batchRounds * numExperts * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; pipe.InitBuffer(sendOffsetBuf, sendCountAlignLen); - 
for (uint32_t rStart = 0; rStart < round; rStart += batchRounds) { uint32_t currentBatchRounds = (rStart + batchRounds > round) ? (round - rStart) : batchRounds; - ReorderOutput(rStart, currentBatchRounds); ReorderSendOffsetOutput(currentBatchRounds); - GlobalTensor recvOffsetGt; recvOffsetGt.SetGlobalBuffer((__gm__ int32_t *)recvOffset_); uint32_t globalOffset = rStart * numExperts; @@ -527,7 +509,6 @@ class NotifyDispatch 0, 0}; SyncFunc(); DataCopyPad(recvOffsetGt[globalOffset], sendOffsetTensor, copyParams); - SyncFunc(); } } @@ -538,8 +519,10 @@ class NotifyDispatch if (blockIdx != MAX_BS_CORE) { return; } - recvDataAlignLen = - Ceil(batchRounds * numExperts * sendPerGroup * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + uint32_t singleRankMaxElem = batchRounds * numLocalExperts * sendPerGroup; + uint32_t singleRankMaxLen = singleRankMaxElem * sizeof(int32_t); + uint32_t singleRankAlignLen = Ceil(singleRankMaxLen, UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + recvDataAlignLen = rankSize * singleRankAlignLen; pipe.InitBuffer(recvDataBuf, recvDataAlignLen); pipe.InitBuffer(sendTokensPerRankBuf, sendTokensPerRankAlignLen); @@ -552,16 +535,19 @@ class NotifyDispatch SyncFunc(); for (uint32_t rStart = 0; rStart < round; rStart += batchRounds) { uint32_t currentBatchRounds = (rStart + batchRounds > round) ? 
(round - rStart) : batchRounds; - + uint32_t singleRankBatchDataLen = currentBatchRounds * numLocalExperts * sendPerGroup * sizeof(int32_t); + uint32_t alignedDataLen = Ceil(singleRankBatchDataLen, UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + uint32_t strideElem = alignedDataLen / sizeof(int32_t); ReorderOutput(rStart, currentBatchRounds); SyncFunc(); for (uint32_t r = 0; r < currentBatchRounds; ++r) { + uint32_t offsetInRound = r * numLocalExperts; Duplicate(seenRoundTensor, 0, sendTokensPerRankAlignLen / sizeof(int32_t)); SyncFunc(); for (uint32_t expId = 0; expId < numLocalExperts; ++expId) { for (uint32_t srcRank = 0; srcRank < rankSize; ++srcRank) { - uint32_t pair_idx = sendPerGroup * (srcRank * numLocalExperts * currentBatchRounds + - r * numLocalExperts + expId); + uint32_t offsetInRank = sendPerGroup * (offsetInRound + expId); + uint32_t pair_idx = srcRank * strideElem + offsetInRank; if (!seenRoundTensor(srcRank)) { sendTokensPerRankTensor(srcRank) += recvDataTensor(pair_idx + 2); seenRoundTensor(srcRank) = 1; @@ -588,8 +574,10 @@ class NotifyDispatch if (blockIdx != RECV_TOKEN_PER_EXP_CORE) { return; } - recvDataAlignLen = - Ceil(batchRounds * numExperts * sendPerGroup * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + uint32_t singleRankMaxElem = batchRounds * numLocalExperts * sendPerGroup; + uint32_t singleRankMaxLen = singleRankMaxElem * sizeof(int32_t); + uint32_t singleRankAlignLen = Ceil(singleRankMaxLen, UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + recvDataAlignLen = rankSize * singleRankAlignLen; pipe.InitBuffer(recvDataBuf, recvDataAlignLen); sendCountAlignLen = Ceil(batchRounds * numExperts * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; pipe.InitBuffer(recvCountBuf, sendCountAlignLen); @@ -632,9 +620,10 @@ class NotifyDispatch if (blockIdx != EXP_GLOBAL_OFFSET_CORE) { return; } - - recvDataAlignLen = - Ceil(batchRounds * numExperts * sendPerGroup * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + uint32_t singleRankMaxElem = batchRounds * numLocalExperts * 
sendPerGroup; + uint32_t singleRankMaxLen = singleRankMaxElem * sizeof(int32_t); + uint32_t singleRankAlignLen = Ceil(singleRankMaxLen, UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + recvDataAlignLen = rankSize * singleRankAlignLen; pipe.InitBuffer(recvDataBuf, recvDataAlignLen); sendCountAlignLen = Ceil(batchRounds * numExperts * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; pipe.InitBuffer(recvCountBuf, sendCountAlignLen); @@ -679,8 +668,10 @@ class NotifyDispatch if (blockIdx != SRC_RANK_EXP_OFFSET_CORE) { return; } - recvDataAlignLen = - Ceil(batchRounds * numExperts * sendPerGroup * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + uint32_t singleRankMaxElem = batchRounds * numLocalExperts * sendPerGroup; + uint32_t singleRankMaxLen = singleRankMaxElem * sizeof(int32_t); + uint32_t singleRankAlignLen = Ceil(singleRankMaxLen, UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + recvDataAlignLen = rankSize * singleRankAlignLen; pipe.InitBuffer(recvDataBuf, recvDataAlignLen); sendCountAlignLen = Ceil(batchRounds * numExperts * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; // 32Kb pipe.InitBuffer(recvCountBuf, sendCountAlignLen); @@ -729,8 +720,10 @@ class NotifyDispatch if (blockIdx != R_IN_SRCRANK_OFFSET_CORE) { return; } - recvDataAlignLen = - Ceil(batchRounds * numExperts * sendPerGroup * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + uint32_t singleRankMaxElem = batchRounds * numLocalExperts * sendPerGroup; + uint32_t singleRankMaxLen = singleRankMaxElem * sizeof(int32_t); + uint32_t singleRankAlignLen = Ceil(singleRankMaxLen, UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + recvDataAlignLen = rankSize * singleRankAlignLen; pipe.InitBuffer(recvDataBuf, recvDataAlignLen); sendCountAlignLen = Ceil(batchRounds * numExperts * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; // 32Kb pipe.InitBuffer(recvCountBuf, sendCountAlignLen); @@ -949,7 +942,6 @@ __aicore__ inline void NotifyDispatch::InitSmallFullMesh(KERNELS_ARGS_FUN_ALL blockIdx = GetBlockIdx(); blockNum = GetBlockNum(); uint8_t ctxIdx; - 
auto tilingData = (__gm__ NotifyDispatchTilingData *)tiling; __gm__ void *mc2InitTiling = (__gm__ void *)(&(tilingData->mc2InitTiling)); __gm__ void *mc2CcTiling = (__gm__ void *)(&(tilingData->mc2CcTiling1));