@@ -54,6 +54,7 @@ constexpr uint32_t EP_RECV_COUNTS_INDEX = 2;
5454constexpr uint32_t TOPK_WEIGHTS_INDEX = 3 ;
5555constexpr uint32_t TP_RECV_COUNTS_INDEX = 4 ;
5656constexpr uint32_t OUTPUT_X_INDEX = 0 ;
57+ constexpr uint32_t OUTPUT_SEND_COST_INDEX = 1 ;
5758
5859constexpr uint32_t ATTR_GROUP_EP_INDEX = 0 ;
5960constexpr uint32_t ATTR_EP_WORLD_SIZE_INDEX = 1 ;
@@ -238,7 +239,7 @@ static bool CheckOptionalInputTensorDim(gert::TilingContext *context, const char
238239 return true ;
239240}
240241
241- static bool CheckOutputTensorDim (gert::TilingContext *context, const char *nodeName)
242+ static bool CheckOutputTensorDim (gert::TilingContext *context, const char *nodeName, const bool isEnableDiagnose )
242243{
243244 const gert::StorageShape *xStorageShape = context->GetOutputShape (OUTPUT_X_INDEX);
244245 OP_TILING_CHECK (xStorageShape == nullptr , OP_LOGE (nodeName, " x is null." ), return false );
@@ -249,25 +250,34 @@ static bool CheckOutputTensorDim(gert::TilingContext *context, const char *nodeN
249250 OP_LOGD (nodeName, " x dim0 = %ld" , xStorageShape->GetStorageShape ().GetDim (0 ));
250251 OP_LOGD (nodeName, " x dim1 = %ld" , xStorageShape->GetStorageShape ().GetDim (1 ));
251252
253+ if (isEnableDiagnose) {
254+ const gert::StorageShape *sendCostStatsStorageShape = context->GetOutputShape (OUTPUT_SEND_COST_INDEX);
255+ OP_TILING_CHECK (sendCostStatsStorageShape == nullptr , OP_LOGE (nodeName, " combine sendCostStatsShape is null." ),
256+ return false );
257+ OP_TILING_CHECK (sendCostStatsStorageShape->GetStorageShape ().GetDimNum () != ONE_DIM,
258+ OP_LOGE (nodeName, " combine sendCostStatsShape must be 1-dimension, but got %lu dim" ,
259+ sendCostStatsStorageShape->GetStorageShape ().GetDimNum ()),
260+ return false );
261+ }
252262 return true ;
253263}
254264
255- static bool CheckTensorDim (gert::TilingContext *context, const char *nodeName)
265+ static bool CheckTensorDim (gert::TilingContext *context, const char *nodeName, const bool isEnableDiagnose )
256266{
257267 OP_TILING_CHECK (!CheckInputTensorDim (context, nodeName),
258268 OP_LOGE (nodeName, " param shape of input tensor is invalid" ), return false );
259269
260270 OP_TILING_CHECK (!CheckOptionalInputTensorDim (context, nodeName),
261271 OP_LOGE (nodeName, " param shape of optional input tensor is invalid" ), return false );
262272
263- OP_TILING_CHECK (!CheckOutputTensorDim (context, nodeName),
273+ OP_TILING_CHECK (!CheckOutputTensorDim (context, nodeName, isEnableDiagnose ),
264274 OP_LOGE (nodeName, " param shape of output tensor is invalid" ), return false );
265275
266276 return true ;
267277}
268278
269279// 校验数据类型
270- static bool CheckTensorDataType (gert::TilingContext *context, const char *nodeName)
280+ static bool CheckTensorDataType (gert::TilingContext *context, const char *nodeName, const bool isEnableDiagnose )
271281{
272282 auto recvXDesc = context->GetInputDesc (RECV_X_INDEX);
273283 OP_TILING_CHECK (recvXDesc == nullptr , OP_LOGE (nodeName, " recvXDesc is null." ), return false );
@@ -296,10 +306,20 @@ static bool CheckTensorDataType(gert::TilingContext *context, const char *nodeNa
296306 OP_TILING_CHECK ((xDesc->GetDataType () != recvXDesc->GetDataType ()),
297307 OP_LOGE (nodeName, " x dataType is invalid, dataType should be equal to recvX dataType , but is " ),
298308 return false );
309+
310+ if (isEnableDiagnose) {
311+ auto sendCostStatsDesc = context->GetOutputDesc (OUTPUT_SEND_COST_INDEX);
312+ OP_TILING_CHECK (sendCostStatsDesc == nullptr , OP_LOGE (nodeName, " combine sendCostStatsDesc is null." ),
313+ return false );
314+ OP_TILING_CHECK (
315+ sendCostStatsDesc->GetDataType () != ge::DT_INT32,
316+ OP_LOGE (nodeName, " combine sendCostStatsDesc dataType is invalid, dataType should be int32, but is ." ),
317+ return false );
318+ }
299319 return true ;
300320}
301321
302- static bool CheckTensorFormat (gert::TilingContext *context, const char *nodeName)
322+ static bool CheckTensorFormat (gert::TilingContext *context, const char *nodeName, const bool isEnableDiagnose )
303323{
304324 auto recvXDesc = context->GetInputDesc (RECV_X_INDEX);
305325 OP_TILING_CHECK (recvXDesc == nullptr , OP_LOGE (nodeName, " recvXDesc is null." ), return false );
@@ -330,6 +350,14 @@ static bool CheckTensorFormat(gert::TilingContext *context, const char *nodeName
330350 OP_TILING_CHECK (static_cast <ge::Format>(ge::GetPrimaryFormat (xDesc->GetStorageFormat ())) == ge::FORMAT_FRACTAL_NZ,
331351 OP_LOGE (nodeName, " xFormat is invalid" ), return false );
332352
353+ if (isEnableDiagnose) {
354+ auto sendCostStatsDesc = context->GetOutputDesc (OUTPUT_SEND_COST_INDEX);
355+ OP_TILING_CHECK (sendCostStatsDesc == nullptr , OP_LOGE (nodeName, " combine sendCostStatsDesc is null." ),
356+ return false );
357+ OP_TILING_CHECK (static_cast <ge::Format>(ge::GetPrimaryFormat (sendCostStatsDesc->GetStorageFormat ())) ==
358+ ge::FORMAT_FRACTAL_NZ,
359+ OP_LOGE (nodeName, " combine sendCostStatsDesc format is invalid" ), return false );
360+ }
333361 return true ;
334362}
335363
@@ -435,17 +463,18 @@ static bool CheckAttrs(gert::TilingContext *context, CamMoeCombineNormalTilingDa
435463 return true ;
436464}
437465
438- static ge::graphStatus TilingCheckCamMoeCombineNormal (gert::TilingContext *context, const char *nodeName)
466+ static ge::graphStatus TilingCheckCamMoeCombineNormal (gert::TilingContext *context, const char *nodeName,
467+ const bool isEnableDiagnose)
439468{
440469 // 检查参数shape信息
441- OP_TILING_CHECK (!CheckTensorDim (context, nodeName), OP_LOGE (nodeName, " param shape is invalid" ),
470+ OP_TILING_CHECK (!CheckTensorDim (context, nodeName, isEnableDiagnose ), OP_LOGE (nodeName, " param shape is invalid" ),
442471 return ge::GRAPH_FAILED);
443472 // 检查参数dataType信息
444- OP_TILING_CHECK (!CheckTensorDataType (context, nodeName), OP_LOGE (nodeName, " param dataType is invalid " ),
445- return ge::GRAPH_FAILED);
473+ OP_TILING_CHECK (!CheckTensorDataType (context, nodeName, isEnableDiagnose ),
474+ OP_LOGE (nodeName, " param dataType is invalid " ), return ge::GRAPH_FAILED);
446475 // 检查参数format信息
447- OP_TILING_CHECK (!CheckTensorFormat (context, nodeName), OP_LOGE (nodeName, " param Format is invalid " ),
448- return ge::GRAPH_FAILED);
476+ OP_TILING_CHECK (!CheckTensorFormat (context, nodeName, isEnableDiagnose ),
477+ OP_LOGE (nodeName, " param Format is invalid " ), return ge::GRAPH_FAILED);
449478 return ge::GRAPH_SUCCESS;
450479}
451480
@@ -493,8 +522,11 @@ static ge::graphStatus CamMoeCombineNormalA3TilingFuncImpl(gert::TilingContext *
493522 OP_TILING_CHECK (GetAttrAndSetTilingData (context, *tilingData, nodeName, groupEp, groupTp) == ge::GRAPH_FAILED,
494523 OP_LOGE (nodeName, " Getting attr failed." ), return ge::GRAPH_FAILED);
495524
525+ auto sendCostStatsStorageShape = context->GetOutputShape (OUTPUT_SEND_COST_INDEX);
526+ bool isEnableDiagnose = (sendCostStatsStorageShape != nullptr );
527+ tilingData->camMoeCombineNormalInfo .isEnableDiagnose = isEnableDiagnose;
496528 // 检查输入输出的dim、format、dataType
497- OP_TILING_CHECK (TilingCheckCamMoeCombineNormal (context, nodeName) != ge::GRAPH_SUCCESS,
529+ OP_TILING_CHECK (TilingCheckCamMoeCombineNormal (context, nodeName, isEnableDiagnose ) != ge::GRAPH_SUCCESS,
498530 OP_LOGE (nodeName, " Tiling check params failed" ), return ge::GRAPH_FAILED);
499531
500532 // 检查属性的取值是否合法
0 commit comments