Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ enum NnopbaseHcclServerType {
NNOPBASE_HCCL_SERVER_TYPE_END
};

extern aclnnStatus aclnnInnerDispatchFFNCombineGetWorkspaceSize(const aclTensor* x, const aclTensor* weight1, const aclTensor* weight2,
const aclTensor* expertId, const aclTensor* scale1, const aclTensor* scale2,
extern aclnnStatus aclnnInnerDispatchFFNCombineGetWorkspaceSize(const aclTensor* x, const aclTensorList* weight1, const aclTensorList* weight2,
const aclTensor* expertId, const aclTensorList* scale1, const aclTensorList* scale2,
const aclTensor* probs,
const char* group, int64_t maxOutputSize,
bool transB, bool weightNz,
Expand All @@ -55,8 +55,8 @@ extern "C" void __attribute__((weak)) NnopbaseSetHcclServerType(void *executor,



aclnnStatus aclnnDispatchFFNCombineGetWorkspaceSize(const aclTensor* x, const aclTensor* weight1, const aclTensor* weight2,
const aclTensor* expertId, const aclTensor* scale1, const aclTensor* scale2,
aclnnStatus aclnnDispatchFFNCombineGetWorkspaceSize(const aclTensor* x, const aclTensorList* weight1, const aclTensorList* weight2,
const aclTensor* expertId, const aclTensorList* scale1, const aclTensorList* scale2,
const aclTensor* probs,
const char* group, int64_t maxOutputSize,
const aclTensor* out,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ extern "C" {
* @param [out] executor: op executor containing the operator compute flow.
* @return aclnnStatus: status code.
*/
__attribute__((visibility("default"))) aclnnStatus aclnnDispatchFFNCombineGetWorkspaceSize(const aclTensor* x, const aclTensor* weight1, const aclTensor* weight2,
const aclTensor* expertId, const aclTensor* scale1, const aclTensor* scale2,
__attribute__((visibility("default"))) aclnnStatus aclnnDispatchFFNCombineGetWorkspaceSize(const aclTensor* x, const aclTensorList* weight1, const aclTensorList* weight2,
const aclTensor* expertId, const aclTensorList* scale1, const aclTensorList* scale2,
const aclTensor* probs,
const char* group, int64_t maxOutputSize,
const aclTensor* out,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,13 @@ class DispatchFFNCombine : public OpDef {
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("w1")
.ParamType(REQUIRED)
.ParamType(DYNAMIC)
.DataType({ge::DT_INT8, ge::DT_INT8, ge::DT_INT8})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ})
.IgnoreContiguous();
this->Input("w2")
.ParamType(REQUIRED)
.ParamType(DYNAMIC)
.DataType({ge::DT_INT8, ge::DT_INT8, ge::DT_INT8})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ})
Expand All @@ -41,12 +41,12 @@ class DispatchFFNCombine : public OpDef {
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("scale1")
.ParamType(REQUIRED)
.ParamType(DYNAMIC)
.DataType({ge::DT_INT64, ge::DT_INT64, ge::DT_INT64})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("scale2")
.ParamType(REQUIRED)
.ParamType(DYNAMIC)
.DataType({ge::DT_INT64, ge::DT_INT64, ge::DT_INT64})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,27 +91,42 @@ static ge::graphStatus DispatchFFNCombineCheckAttrAndSetTiling(gert::TilingConte
/**
 * Reads the problem dimensions from the tiling context and stores them in `info`.
 *
 * Values written:
 *   - info.M / info.K   : dims 0 and 1 of the storage shape of input X_INDEX.
 *   - info.N            : storage-shape dim at index (origin rank - 1) of the first
 *                         tensor of the dynamic input WEIGHT_INDEX.
 *   - info.topK         : dim 1 of the storage shape of the expert-id tensor.
 *   - info.listLen      : number of tensors supplied for the dynamic input WEIGHT_INDEX.
 *   - info.expertPerRank: dim 0 of the single weight tensor when listLen == 1,
 *                         otherwise listLen.
 *
 * @param context tiling context to query; assumed non-null (not checked here —
 *                NOTE(review): confirm the caller guarantees this).
 * @param info    output structure receiving the values listed above.
 * @return ge::GRAPH_SUCCESS on success; ge::GRAPH_FAILED when a required input is
 *         missing or has too few dimensions to be indexed as described.
 */
static ge::graphStatus DispatchFFNCombineCheckShapeAndSetTiling(gert::TilingContext *context, DispatchFFNCombineInfo &info)
{
    // Both x and expertId are indexed at dim 1 below, so they need at least two dims.
    constexpr size_t kMinRank = 2;

    const gert::StorageShape *aStorageShape = context->GetInputShape(X_INDEX);
    auto expertIdxTensor = context->GetDynamicInputTensor(EXPERTID_INDEX, 0);
    auto wTensor = context->GetDynamicInputTensor(WEIGHT_INDEX, 0);
    if (aStorageShape == nullptr || expertIdxTensor == nullptr || wTensor == nullptr) {
        OP_LOGD(K_INNER_DEBUG, "missing input: x=%d expertId=%d weight=%d ",
                static_cast<int>(aStorageShape != nullptr),
                static_cast<int>(expertIdxTensor != nullptr),
                static_cast<int>(wTensor != nullptr));
        return ge::GRAPH_FAILED;
    }

    const auto &xShape = aStorageShape->GetStorageShape();
    const auto &expertIdxShape = expertIdxTensor->GetStorageShape();
    const auto &wStorageShape = wTensor->GetStorageShape();
    const size_t wTensorDims = wTensor->GetOriginShape().GetDimNum();

    // Guard every GetDim() index used below. In particular wTensorDims - 1 would wrap
    // around for a rank-0 weight, and it must also be a valid storage-shape index.
    if (xShape.GetDimNum() < kMinRank || expertIdxShape.GetDimNum() < kMinRank ||
        wTensorDims == 0 || wTensorDims > wStorageShape.GetDimNum()) {
        OP_LOGD(K_INNER_DEBUG, "unexpected rank: x=%d expertId=%d weight=%d ",
                static_cast<int>(xShape.GetDimNum()),
                static_cast<int>(expertIdxShape.GetDimNum()),
                static_cast<int>(wTensorDims));
        return ge::GRAPH_FAILED;
    }

    const uint32_t M = static_cast<uint32_t>(xShape.GetDim(0));
    const uint32_t K = static_cast<uint32_t>(xShape.GetDim(1));
    // NOTE(review): the index comes from the origin rank but is applied to the storage
    // shape; confirm this is the intended dim for FRACTAL_NZ weights.
    const uint32_t N = static_cast<uint32_t>(wStorageShape.GetDim(wTensorDims - 1));
    const uint32_t topK = static_cast<uint32_t>(expertIdxShape.GetDim(1));

    // Count the weight tensors: index 0 is known to exist, so probe upwards until the
    // context reports no tensor at that position.
    uint32_t listLen = 1;
    while (context->GetDynamicInputTensor(WEIGHT_INDEX, listLen) != nullptr) {
        ++listLen;
    }

    // A single weight tensor carries the expert count in dim 0; otherwise the list
    // length itself is used as the expert count.
    const uint32_t expertPerRank =
        (listLen == 1) ? static_cast<uint32_t>(wStorageShape.GetDim(0)) : listLen;

    info.M = M;
    info.N = N;
    info.K = K;
    info.expertPerRank = expertPerRank;
    info.topK = topK;
    info.listLen = listLen;
    OP_LOGD(K_INNER_DEBUG, "M=%d ", info.M);
    OP_LOGD(K_INNER_DEBUG, "K=%d ", info.K);
    OP_LOGD(K_INNER_DEBUG, "N=%d ", info.N);
    OP_LOGD(K_INNER_DEBUG, "expertPerRank=%d ", info.expertPerRank);
    OP_LOGD(K_INNER_DEBUG, "topK=%d ", info.topK);
    OP_LOGD(K_INNER_DEBUG, "listLen=%d ", info.listLen);

    return ge::GRAPH_SUCCESS;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ class DispatchFFNCombine {
int32_t expertPerRank;
int32_t maxOutputSize;
int32_t EP;
int32_t listLen;

optiling::MoeInitRoutingQuantV2TilingData moeInitRoutingQuantV2TilingData;
uint64_t initRoutingQuantTilingKey;
Expand Down Expand Up @@ -138,6 +139,7 @@ __aicore__ inline void DispatchFFNCombine<TemplateMMA2ACFunc>::Init(GM_ADDR xGM,
topK = tilingData.dispatchFFNCombineInfo.topK;
expertPerRank = tilingData.dispatchFFNCombineInfo.expertPerRank;
maxOutputSize = tilingData.dispatchFFNCombineInfo.maxOutputSize;
listLen = tilingData.dispatchFFNCombineInfo.listLen;

m0 = tilingData.cocTiling.m0;
k0 = tilingData.cocTiling.k0;
Expand Down Expand Up @@ -254,7 +256,7 @@ __aicore__ inline void DispatchFFNCombine<TemplateMMA2ACFunc>::Process()
uint32_t epilogueGranularity = expertPerRank - 1;

typename MatmulKernel::Params params{
problemShape, static_cast<uint32_t>(EP), static_cast<uint32_t>(expertPerRank), static_cast<uint32_t>(maxOutputSize),
problemShape, static_cast<uint32_t>(EP), static_cast<uint32_t>(listLen), static_cast<uint32_t>(expertPerRank), static_cast<uint32_t>(maxOutputSize),
static_cast<uint32_t>(rank), static_cast<uint32_t>(rankSize),
static_cast<uint32_t>(topK), initRoutingQuantTilingKey,
epilogueCoreNum, epilogueGranularity,
Expand Down
Loading
Loading