Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion csrc/deepep/ops2/op_host/cam_moe_combine_normal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class CamMoeCombineNormal : public OpDef
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.AutoContiguous();
this->Input("topk_idx")
this->Input("token_idx")
.ParamType(REQUIRED)
.DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
Expand Down
88 changes: 36 additions & 52 deletions csrc/deepep/ops2/op_host/cam_moe_combine_normal_tiling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "error_log.h"
#include "graph/utils/type_utils.h"
#include "register/op_def_registry.h"
#include "mc2_tiling_utils.h"
#include "../op_kernel/cam_moe_combine_normal_tiling.h"
#include "tiling_args.h"

Expand All @@ -26,37 +27,12 @@ using namespace ge;
using namespace Moe;

namespace {
/// Helper for the tiling code that resolves the HCCL communication window size.
class Mc2TilingUtils
{
public:
    /// Returns the maximum HCCL window size in bytes.
    ///
    /// The size (in MB) is taken from DEEPEP_HCCL_BUFFSIZE when that variable is set,
    /// otherwise from HCCL_BUFFSIZE. When neither is set, or the value cannot be parsed
    /// or is negative, the 200 MB default is used.
    static uint64_t GetMaxWindowSize()
    {
        // Held in 64 bits: storing the parsed int in a uint16_t silently wrapped
        // values above 65535 and turned negative values into huge window sizes.
        uint64_t windowSizeMb = 200;
        // DEEPEP_HCCL_BUFFSIZE takes precedence over HCCL_BUFFSIZE; read each variable once.
        const char *deepepValue = getenv("DEEPEP_HCCL_BUFFSIZE");
        const char *envValue = (deepepValue != nullptr) ? deepepValue : getenv("HCCL_BUFFSIZE");
        if (envValue == nullptr) {
            OP_LOGD("", "Env HCCL_BUFFSIZE don't set");
        } else {
            try {
                const int parsedMb = std::stoi(std::string(envValue));
                if (parsedMb < 0) {
                    // A negative size is meaningless; keep the default instead of wrapping.
                    OP_LOGE("", "Negative value when parsing HCCL_BUFFSIZE: %s", envValue);
                } else {
                    windowSizeMb = static_cast<uint64_t>(parsedMb);
                }
            } catch (const std::invalid_argument &ia) {
                OP_LOGE("", "Invalid argument when parsing HCCL_BUFFSIZE: %s", ia.what());
            } catch (const std::out_of_range &oor) {
                OP_LOGE("", "Out of range when parsing HCCL_BUFFSIZE: %s", oor.what());
            }
        }
        // Convert MB to bytes.
        const uint64_t maxWindowSize = windowSizeMb * 1024UL * 1024UL;
        OP_LOGI("", "Get maxWindowSize is %lu", maxWindowSize);
        return maxWindowSize;
    }
};
constexpr uint32_t RECV_X_INDEX = 0;
constexpr uint32_t TOKEN_SRC_INFO_INDEX = 1;
constexpr uint32_t EP_RECV_COUNTS_INDEX = 2;
constexpr uint32_t TOPK_WEIGHTS_INDEX = 3;
constexpr uint32_t TOPK_IDX_INDEX = 4;
constexpr uint32_t TOKEN_IDX_INDEX = 4;
constexpr uint32_t TP_RECV_COUNTS_INDEX = 5;

constexpr uint32_t OUTPUT_X_INDEX = 0;
constexpr uint32_t OUTPUT_SEND_COST_INDEX = 1;

Expand All @@ -80,7 +56,7 @@ constexpr size_t MAX_GROUP_NAME_LENGTH = 128UL;
constexpr int64_t MAX_EP_WORLD_SIZE = 384;
constexpr int64_t MIN_EP_WORLD_SIZE = 2;
constexpr int64_t MAX_TP_WORLD_SIZE = 2;
constexpr int64_t BS_UPPER_BOUND = 8000;
constexpr int64_t BS_UPPER_BOUND = 65536;

constexpr uint32_t SYSTEM_NEED_WORKSPACE = 16 * 1024 * 1024;
constexpr int32_t HCCL_BUFFER_SIZE_DEFAULT = 200 * 1024 * 1024; // Bytes
Expand All @@ -96,6 +72,7 @@ constexpr uint64_t DOUBLE_DATA_BUFFER = 2UL;
constexpr uint64_t MAX_OUT_DTYPE_SIZE = 2UL;
constexpr uint64_t UB_ALIGN = 32UL;
constexpr int64_t DISPATCH_STATUS_MAX_SUPPORT_NUM = 1280UL;
constexpr uint64_t INIT_TILINGKEY = 10000UL;

enum class CommQuantMode : int32_t { NON_QUANT = 0, INT12_QUANT = 1, INT8_QUANT = 2 };
using CommQuantModeType = std::underlying_type<CommQuantMode>;
Expand Down Expand Up @@ -231,14 +208,14 @@ static bool CheckInputTensorDim(gert::TilingContext *context, const char *nodeNa
OP_LOGD(nodeName, "topkWeights dim0 = %ld", topkWeightsStorageShape->GetStorageShape().GetDim(0));
OP_LOGD(nodeName, "topkWeights dim1 = %ld", topkWeightsStorageShape->GetStorageShape().GetDim(1));

const gert::StorageShape *topIdxStorageShape = context->GetInputShape(TOPK_IDX_INDEX);
OP_TILING_CHECK(topIdxStorageShape == nullptr, OP_LOGE(nodeName, "topkWeights is null."), return false);
OP_TILING_CHECK(topIdxStorageShape->GetStorageShape().GetDimNum() != TWO_DIMS,
OP_LOGE(nodeName, "topkIdx must be 2-dimension, but got %lu dim",
topIdxStorageShape->GetStorageShape().GetDimNum()),
const gert::StorageShape *tokenIdxStorageShape = context->GetInputShape(TOKEN_IDX_INDEX);
OP_TILING_CHECK(tokenIdxStorageShape == nullptr, OP_LOGE(nodeName, "tokenIdx is null."), return false);
OP_TILING_CHECK(tokenIdxStorageShape->GetStorageShape().GetDimNum() != TWO_DIMS,
OP_LOGE(nodeName, "tokenIdx must be 2-dimension, but got %lu dim",
tokenIdxStorageShape->GetStorageShape().GetDimNum()),
return false);
OP_LOGD(nodeName, "topkIdx dim0 = %ld", topIdxStorageShape->GetStorageShape().GetDim(0));
OP_LOGD(nodeName, "topkIdx dim1 = %ld", topIdxStorageShape->GetStorageShape().GetDim(1));
OP_LOGD(nodeName, "tokenIdx dim0 = %ld", tokenIdxStorageShape->GetStorageShape().GetDim(0));
OP_LOGD(nodeName, "tokenIdx dim1 = %ld", tokenIdxStorageShape->GetStorageShape().GetDim(1));

return true;
}
Expand Down Expand Up @@ -318,13 +295,10 @@ static bool CheckTensorDataType(gert::TilingContext *context, const char *nodeNa
OP_TILING_CHECK((topkWeightsDesc->GetDataType() != ge::DT_FLOAT),
OP_LOGE(nodeName, "topkWeights dataType is invalid, dataType should be float, but is "),
return false);
auto topkIdxDesc = context->GetInputDesc(TOPK_IDX_INDEX);
OP_TILING_CHECK(topkIdxDesc == nullptr, OP_LOGE(nodeName, "topkIdxDesc is null."), return false);
OP_TILING_CHECK((topkIdxDesc->GetDataType() != ge::DT_INT32),
OP_LOGE(nodeName,
"topkIdxForCombine dataType is invalid,"
" dataType should be int32, but is"),
return false);
auto tokenIdxDesc = context->GetInputDesc(TOKEN_IDX_INDEX);
OP_TILING_CHECK(tokenIdxDesc == nullptr, OP_LOGE(nodeName, "tokenIdxDesc is null."), return false);
OP_TILING_CHECK((tokenIdxDesc->GetDataType() != ge::DT_INT32),
OP_LOGE(nodeName, "tokenIdx dataType is invalid, dataType should be int32, but is "), return false);
auto xDesc = context->GetOutputDesc(OUTPUT_X_INDEX);
OP_TILING_CHECK(xDesc == nullptr, OP_LOGE(nodeName, "xDesc is null."), return false);
OP_TILING_CHECK((xDesc->GetDataType() != recvXDesc->GetDataType()),
Expand Down Expand Up @@ -369,11 +343,11 @@ static bool CheckTensorFormat(gert::TilingContext *context, const char *nodeName
static_cast<ge::Format>(ge::GetPrimaryFormat(topkWeightsDesc->GetStorageFormat())) == ge::FORMAT_FRACTAL_NZ,
OP_LOGE(nodeName, "topkWeightsFormat is invalid"), return false);

auto topkIdxsDesc = context->GetOptionalInputDesc(TOPK_IDX_INDEX);
OP_TILING_CHECK(topkIdxsDesc == nullptr, OP_LOGE(nodeName, "topkIdxsDesc is null."), return false);
auto tokenIdxDesc = context->GetInputDesc(TOKEN_IDX_INDEX);
OP_TILING_CHECK(tokenIdxDesc == nullptr, OP_LOGE(nodeName, "tokenIdxDesc is null."), return false);
OP_TILING_CHECK(
static_cast<ge::Format>(ge::GetPrimaryFormat(topkIdxsDesc->GetStorageFormat())) == ge::FORMAT_FRACTAL_NZ,
OP_LOGE(nodeName, "topkIdxsFormat is invalid"), return false);
static_cast<ge::Format>(ge::GetPrimaryFormat(tokenIdxDesc->GetStorageFormat())) == ge::FORMAT_FRACTAL_NZ,
OP_LOGE(nodeName, "tokenIdxFormat is invalid"), return false);

auto xDesc = context->GetOutputDesc(OUTPUT_X_INDEX);
OP_TILING_CHECK(xDesc == nullptr, OP_LOGE(nodeName, "xDesc is null."), return false);
Expand Down Expand Up @@ -424,7 +398,7 @@ static bool CheckTensorShape(gert::TilingContext *context, CamMoeCombineNormalTi
int64_t xDim0 = xStorageShape->GetStorageShape().GetDim(0);
int64_t xDim1 = xStorageShape->GetStorageShape().GetDim(1);
OP_TILING_CHECK(xDim0 != topkWeightsDim0,
OP_LOGE(nodeName, "x's dim0 not equal to bs, bs = %ld, x's dim0 = %ld", topkWeightsDim0, xDim0),
OP_LOGE(nodeName, "x's dim0 not equal bs, bs = %ld, x's dim0 = %ld", topkWeightsDim0, xDim0),
return false);
OP_TILING_CHECK(xDim1 != recvXDim1,
OP_LOGE(nodeName, "x's dim1 not equal to h, x's dim1 = %ld, h = %ld", xDim1, recvXDim1),
Expand Down Expand Up @@ -577,19 +551,22 @@ static ge::graphStatus CamMoeCombineNormalA3TilingFuncImpl(gert::TilingContext *
uint64_t epWorldSize = static_cast<uint64_t>(tilingData->camMoeCombineNormalInfo.epWorldSize);
uint64_t k = static_cast<uint64_t>(tilingData->camMoeCombineNormalInfo.k);
uint64_t perRoundTokens = tilingData->camMoeCombineNormalInfo.perRoundTokens;
uint64_t realMaxBs = tilingData->camMoeCombineNormalInfo.realMaxBs;
uint64_t realBs = std::min(perRoundTokens, realMaxBs);
uint32_t maxRound = tilingData->camMoeCombineNormalInfo.maxRound;
// Combine data region: the start address of each token is aligned to 512
uint64_t tokenNeedSizeCombine = ((h * MAX_OUT_DTYPE_SIZE + WIN_ADDR_ALIGN - 1UL) / WIN_ADDR_ALIGN) * WIN_ADDR_ALIGN;
uint64_t actualSize =
(perRoundTokens * k * tokenNeedSizeCombine + COMBINE_STATE_WIN_OFFSET + NOTIFY_DISPATCH_WIN_OFFSET) *
DOUBLE_DATA_BUFFER;
tokenNeedSizeCombine = maxRound > 1 ? tokenNeedSizeCombine * 2 : tokenNeedSizeCombine;
uint64_t actualSize = (realBs * k * tokenNeedSizeCombine + COMBINE_STATE_WIN_OFFSET + NOTIFY_DISPATCH_WIN_OFFSET) *
DOUBLE_DATA_BUFFER;
OP_TILING_CHECK(
(actualSize > maxWindowSize),
OP_LOGE(nodeName,
"HCCL_BUFFSIZE is too SMALL, perRoundTokens = %lu, h = %lu, epWorldSize = %lu, localMoeExpertNum = %u,"
"HCCL_BUFFSIZE is too SMALL, realBs = %lu, h = %lu, epWorldSize = %lu, localMoeExpertNum = %u,"
" tokenNeedSizeCombine = %lu, k = %lu, NEEDED_HCCL_BUFFSIZE("
"((perRoundTokens * k * tokenNeedSizeCombine)) + 8MB + 102MB) * 2) = %luMB, "
"((realBs * k * tokenNeedSizeCombine * 2)) + 8MB + 404MB) * 2) = %luMB, "
"HCCL_BUFFSIZE=%luMB.",
perRoundTokens, h, epWorldSize, localMoeExpertNum, tokenNeedSizeCombine, k, actualSize / MB_SIZE + 1UL,
realBs, h, epWorldSize, localMoeExpertNum, tokenNeedSizeCombine, k, actualSize / MB_SIZE + 1UL,
maxWindowSize / MB_SIZE),
return ge::GRAPH_FAILED);
tilingData->camMoeCombineNormalInfo.totalWinSize = maxWindowSize;
Expand All @@ -615,6 +592,13 @@ static ge::graphStatus CamMoeCombineNormalA3TilingFuncImpl(gert::TilingContext *
OP_LOGD(nodeName, "blockdim = %u, aivNum = %lu, ubsize = %lu", blockDim, aivNum, ubSize);
PrintTilingDataInfo(nodeName, *tilingData);

uint64_t tilingKey = INIT_TILINGKEY;
if (maxRound > 1) {
tilingKey += 1;
}
OP_LOGD(nodeName, "tilingKey is %lu", tilingKey);
context->SetTilingKey(tilingKey);

return ge::GRAPH_SUCCESS;
}

Expand Down
36 changes: 8 additions & 28 deletions csrc/deepep/ops2/op_host/cam_moe_dispatch_normal_tiling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "error_log.h"
#include "graph/utils/type_utils.h"
#include "register/op_def_registry.h"
#include "mc2_tiling_utils.h"
#include "../op_kernel/cam_moe_dispatch_normal_tiling.h"
#include "tiling_args.h"

Expand All @@ -25,30 +26,6 @@ using namespace ge;
using namespace Moe;

namespace {
/// Helper for the tiling code that resolves the HCCL communication window size.
class Mc2TilingUtils
{
public:
    /// Returns the maximum HCCL window size in bytes.
    ///
    /// The size (in MB) is taken from DEEPEP_HCCL_BUFFSIZE when that variable is set,
    /// otherwise from HCCL_BUFFSIZE. When neither is set, or the value cannot be parsed
    /// or is negative, the 200 MB default is used.
    static uint64_t GetMaxWindowSize()
    {
        // Held in 64 bits: storing the parsed int in a uint16_t silently wrapped
        // values above 65535 and turned negative values into huge window sizes.
        uint64_t windowSizeMb = 200;
        // DEEPEP_HCCL_BUFFSIZE takes precedence over HCCL_BUFFSIZE; read each variable once.
        const char *deepepValue = getenv("DEEPEP_HCCL_BUFFSIZE");
        const char *envValue = (deepepValue != nullptr) ? deepepValue : getenv("HCCL_BUFFSIZE");
        if (envValue == nullptr) {
            OP_LOGD("", "Env HCCL_BUFFSIZE don't set");
        } else {
            try {
                const int parsedMb = std::stoi(std::string(envValue));
                if (parsedMb < 0) {
                    // A negative size is meaningless; keep the default instead of wrapping.
                    OP_LOGE("", "Negative value when parsing HCCL_BUFFSIZE: %s", envValue);
                } else {
                    windowSizeMb = static_cast<uint64_t>(parsedMb);
                }
            } catch (const std::invalid_argument &ia) {
                OP_LOGE("", "Invalid argument when parsing HCCL_BUFFSIZE: %s", ia.what());
            } catch (const std::out_of_range &oor) {
                OP_LOGE("", "Out of range when parsing HCCL_BUFFSIZE: %s", oor.what());
            }
        }
        // Convert MB to bytes.
        const uint64_t maxWindowSize = windowSizeMb * 1024UL * 1024UL;
        OP_LOGI("", "Get maxWindowSize is %lu", maxWindowSize);
        return maxWindowSize;
    }
};
constexpr uint32_t X_INDEX = 0U;
constexpr uint32_t EXPERT_IDS_INDEX = 1U;
constexpr uint32_t SEND_OFFSET_INDEX = 2U;
Expand Down Expand Up @@ -87,7 +64,7 @@ constexpr size_t MAX_GROUP_NAME_LENGTH = 128UL;
constexpr int64_t MAX_EP_WORLD_SIZE = 384;
constexpr int64_t MIN_EP_WORLD_SIZE = 2;
constexpr int64_t MAX_TP_WORLD_SIZE = 2;
constexpr int64_t BS_UPPER_BOUND = 32768; // 最大bs
constexpr int64_t BS_UPPER_BOUND = 65536; // 最大bs

constexpr uint32_t TILINGKEY_TP_WORLD_SIZE = 100;
constexpr uint32_t TP_WORLD_SIZE_TWO = 2;
Expand Down Expand Up @@ -586,22 +563,25 @@ static ge::graphStatus CamMoeDispatchNormalA3TilingFuncImpl(gert::TilingContext
uint64_t k = static_cast<uint64_t>(tilingData->camMoeDispatchNormalInfo.k);
uint64_t epWorldSize = static_cast<uint64_t>(tilingData->camMoeDispatchNormalInfo.epWorldSize);
uint64_t maxBs = static_cast<uint64_t>(tilingData->camMoeDispatchNormalInfo.globalBs) / epWorldSize;

uint32_t round = tilingData->camMoeDispatchNormalInfo.round;
// Dispatch data region: the start of each token is aligned to 512; effective token length = h_align_32b + scale(32b) + triple(3*4b)
uint64_t tokenActualLen =
((h * MAX_OUT_DTYPE_SIZE + UB_ALIGN - 1UL) / UB_ALIGN) * UB_ALIGN + SCALE_EXPAND_IDX_BUFFER;
uint64_t tokenNeedSizeDispatch = ((tokenActualLen + WIN_ADDR_ALIGN - 1UL) / WIN_ADDR_ALIGN) * WIN_ADDR_ALIGN;
uint64_t tokenNeedSizeCombine = ((h * MAX_OUT_DTYPE_SIZE + WIN_ADDR_ALIGN - 1UL) / WIN_ADDR_ALIGN) * WIN_ADDR_ALIGN;
tokenNeedSizeCombine =
round > 1 ? tokenNeedSizeCombine * 2 : tokenNeedSizeCombine; // round > 1 combine要使用double buffer
// Size without taking dual streams into account
uint64_t actualSize = (maxBs * k * (tokenNeedSizeCombine + tokenNeedSizeDispatch) + COMBINE_STATE_WIN_OFFSET +
NOTIFY_DISPATCH_WIN_OFFSET) *
DOUBLE_DATA_BUFFER;
DOUBLE_DATA_BUFFER +
Moe::STATE_SIZE * 4;
OP_TILING_CHECK((actualSize > maxWindowSize),
OP_LOGE(nodeName,
"HCCL_BUFFSIZE is too SMALL, maxBs = %lu, h = %lu, epWorldSize = %lu,"
" localMoeExpertNum = %u, tokenNeedSizeDispatch = %lu, tokenNeedSizeCombine = %lu,"
" k = %lu, NEEDED_HCCL_BUFFSIZE((maxBs * k * (tokenNeedSizeDispatch"
" + tokenNeedSizeCombine) + 3MB + 204MB) * 2) = %luMB, HCCL_BUFFSIZE=%luMB.",
" + tokenNeedSizeCombine) + 4MB + 404MB) * 2 + 4 * 2MB) = %luMB, HCCL_BUFFSIZE=%luMB.",
maxBs, h, epWorldSize, localMoeExpertNum, tokenNeedSizeDispatch, tokenNeedSizeCombine, k,
actualSize / MB_SIZE + 1UL, maxWindowSize / MB_SIZE),
return ge::GRAPH_FAILED);
Expand Down
14 changes: 10 additions & 4 deletions csrc/deepep/ops2/op_host/dispatch_layout_tiling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ constexpr uint32_t OUTPUT_NUM_TOKEN_PER_RANK_INDEX = 0;
constexpr uint32_t OUTPUT_NUM_TOKEN_PER_EXPERT_INDEX = 1;
constexpr uint32_t OUTPUT_IS_TOKEN_IN_RANK_INDEX = 2;
constexpr uint32_t OUTPUT_NOTIFY_SEND_DATA_INDEX = 3;
constexpr uint32_t OUTPUT_SEND_TOKEN_IDX_SMALL_INDEX = 4;

constexpr uint32_t ATTR_NUM_TOKENS_INDEX = 0;
constexpr uint32_t ATTR_NUM_RANKS_INDEX = 1;
Expand Down Expand Up @@ -98,18 +99,18 @@ static ge::graphStatus GetAttrAndSetTilingData(gert::TilingContext *context, con
auto numExpertsPtr = attrs->GetAttrPointer<int64_t>(static_cast<int>(ATTR_NUM_EXPERTS_INDEX));
auto numTopkPtr = attrs->GetAttrPointer<int64_t>(static_cast<int>(ATTR_NUM_TOPK_INDEX));
auto localRankSizePtr = attrs->GetAttrPointer<int64_t>(static_cast<int>(ATTR_LOCAL_RANKSIZE_INDEX));
auto rankIdPtr = attrs->GetAttrPointer<int64_t>(static_cast<int>(ATTR_RANK_ID_INDEX));
auto perRoundTokensPtr = attrs->GetAttrPointer<int64_t>(static_cast<int>(ATTR_PER_ROUND_TOKENS_INDEX));
auto rankIdPtr = attrs->GetAttrPointer<int64_t>(static_cast<int>(ATTR_RANK_ID_INDEX));

OP_TILING_CHECK(numTokensPtr == nullptr, OP_LOGE(nodeName, "numTokensPtr is null."), return ge::GRAPH_FAILED);
OP_TILING_CHECK(numRanksPtr == nullptr, OP_LOGE(nodeName, "numRanksPtr is null."), return ge::GRAPH_FAILED);
OP_TILING_CHECK(numExpertsPtr == nullptr, OP_LOGE(nodeName, "numExpertsPtr is null."), return ge::GRAPH_FAILED);
OP_TILING_CHECK(numTopkPtr == nullptr, OP_LOGE(nodeName, "numTopkPtr is null."), return ge::GRAPH_FAILED);
OP_TILING_CHECK(localRankSizePtr == nullptr, OP_LOGE(nodeName, "localRankSizePtr is null."),
return ge::GRAPH_FAILED);
OP_TILING_CHECK(rankIdPtr == nullptr, OP_LOGE(nodeName, "rankIdPtr is null."), return ge::GRAPH_FAILED);
OP_TILING_CHECK(perRoundTokensPtr == nullptr, OP_LOGE(nodeName, "perRoundTokensPtr is null."),
return ge::GRAPH_FAILED);
OP_TILING_CHECK(rankIdPtr == nullptr, OP_LOGE(nodeName, "rankIdPtr is null."), return ge::GRAPH_FAILED);

OP_TILING_CHECK((*numRanksPtr <= 0) || (*numRanksPtr > MAX_COMM_WORLD_SIZE),
OP_LOGE(nodeName, "rankSize is invalid, only support (0, %ld], but got rankSize=%ld.",
Expand All @@ -133,9 +134,8 @@ static ge::graphStatus GetAttrAndSetTilingData(gert::TilingContext *context, con
tilingData.dispatchLayoutInfo.numExperts = static_cast<uint32_t>(*numExpertsPtr);
tilingData.dispatchLayoutInfo.numTopk = static_cast<uint32_t>(*numTopkPtr);
tilingData.dispatchLayoutInfo.localRankSize = static_cast<uint32_t>(*localRankSizePtr);
tilingData.dispatchLayoutInfo.rankId = static_cast<uint32_t>(*rankIdPtr);
tilingData.dispatchLayoutInfo.perRoundTokens = static_cast<uint32_t>(*perRoundTokensPtr);

tilingData.dispatchLayoutInfo.rankId = static_cast<uint32_t>(*rankIdPtr);
if (CheckIfA2MultiMachine(context, tilingData)) {
OP_TILING_CHECK(
(*localRankSizePtr <= 0) || (*localRankSizePtr > MAX_LOCAL_RANKSIZE),
Expand Down Expand Up @@ -167,12 +167,14 @@ static bool CheckTensorDataType(gert::TilingContext *context, const char *nodeNa
auto numTokensPerExpert = context->GetOutputDesc(OUTPUT_NUM_TOKEN_PER_EXPERT_INDEX);
auto isTokenInRank = context->GetOutputDesc(OUTPUT_IS_TOKEN_IN_RANK_INDEX);
auto notifySendData = context->GetOutputDesc(OUTPUT_NOTIFY_SEND_DATA_INDEX);
auto sendTokenIdxSmall = context->GetOutputDesc(OUTPUT_SEND_TOKEN_IDX_SMALL_INDEX);

OP_TILING_CHECK(topkIdx == nullptr, OP_LOGE(nodeName, "topkIdx is null."), return false);
OP_TILING_CHECK(numTokensPerRank == nullptr, OP_LOGE(nodeName, "numTokensPerRank is null."), return false);
OP_TILING_CHECK(numTokensPerExpert == nullptr, OP_LOGE(nodeName, "numTokensPerExpert is null."), return false);
OP_TILING_CHECK(isTokenInRank == nullptr, OP_LOGE(nodeName, "isTokenInRank is null."), return false);
OP_TILING_CHECK(notifySendData == nullptr, OP_LOGE(nodeName, "notifySendData is null."), return false);
OP_TILING_CHECK(sendTokenIdxSmall == nullptr, OP_LOGE(nodeName, "sendTokenIdxSmall is null."), return false);

OP_TILING_CHECK((topkIdx->GetDataType() != ge::DT_INT64),
OP_LOGE(nodeName, "topkIdx datatype is invalid, datatype should be int, but is %d.",
Expand All @@ -194,6 +196,10 @@ static bool CheckTensorDataType(gert::TilingContext *context, const char *nodeNa
OP_LOGE(nodeName, "notifySendData datatype is invalid, datatype should be int, but is %d.",
static_cast<ge::DataType>(notifySendData->GetDataType())),
return false);
OP_TILING_CHECK((sendTokenIdxSmall->GetDataType() != ge::DT_INT32),
OP_LOGE(nodeName, "sendTokenIdxSmall datatype is invalid, datatype should be int, but is %d.",
static_cast<ge::DataType>(sendTokenIdxSmall->GetDataType())),
return false);

return true;
}
Expand Down
Loading