diff --git a/csrc/deepep/deep_ep.cpp b/csrc/deepep/deep_ep.cpp index bda7f1d0..56d573be 100644 --- a/csrc/deepep/deep_ep.cpp +++ b/csrc/deepep/deep_ep.cpp @@ -97,7 +97,8 @@ Buffer::intranode_dispatch(const at::Tensor &x, const std::optional const std::optional &num_tokens_per_rank, const at::Tensor &is_token_in_rank, const std::optional &num_tokens_per_expert, int cached_num_recv_tokens, const std::optional &cached_rank_prefix_matrix, - const std::optional &cached_channel_prefix_matrix, int expert_alignment, + const std::optional &cached_channel_prefix_matrix, + const std::optional &dispatch_wait_recv_cost_stats, int expert_alignment, int num_worst_tokens, const Config &config, std::optional &previous_event, bool async, bool allocate_on_comm_stream, bool use_quant) { @@ -172,6 +173,14 @@ Buffer::intranode_dispatch(const at::Tensor &x, const std::optional scale_hidden_stride = static_cast(x_scales->stride(1)); } + at::Tensor dispatch_wait_recv_cost_stats_out; + if (dispatch_wait_recv_cost_stats.has_value()) { + EP_HOST_ASSERT(dispatch_wait_recv_cost_stats->scalar_type() == torch::kInt32); + EP_HOST_ASSERT(dispatch_wait_recv_cost_stats->dim() == 1 and dispatch_wait_recv_cost_stats->is_contiguous()); + EP_HOST_ASSERT(dispatch_wait_recv_cost_stats->size(0) == num_ranks); + dispatch_wait_recv_cost_stats_out = dispatch_wait_recv_cost_stats.value(); + } + int send_per_group = 3; // (send_to_expert_num, send_to_expert_offset, send_rank_tokens) auto send_data = at::zeros({num_experts * send_per_group}, at::dtype(at::kInt).device(x.device())); @@ -270,7 +279,7 @@ Buffer::intranode_dispatch(const at::Tensor &x, const std::optional num_ranks, // rankSize rank, // rankId hcom_ep_name, tp_size, tp_rank, num_experts, quant_mode, global_bs, expandx_out, dynamic_scales_out, - expand_idx_out); + expand_idx_out, dispatch_wait_recv_cost_stats_out); auto recv_topk_idx = std::optional(); auto recv_topk_weights = std::optional(); @@ -305,9 +314,10 @@ void 
Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int return; } -std::tuple, std::optional> Buffer::intranode_combine( - const torch::Tensor &x, const torch::Tensor &topk_idx, const std::optional &topk_weights, - const torch::Tensor &src_idx, const torch::Tensor &send_head) +std::tuple, std::optional> +Buffer::intranode_combine(const torch::Tensor &x, const torch::Tensor &topk_idx, + const std::optional &topk_weights, const torch::Tensor &src_idx, + const torch::Tensor &send_head, const std::optional &combine_send_cost_stats) { EP_HOST_ASSERT(x.dim() == 2 and x.is_contiguous()); at::Tensor recv_x = x; @@ -345,6 +355,14 @@ std::tuple, std::optionalscalar_type() == torch::kInt32); + EP_HOST_ASSERT(combine_send_cost_stats->dim() == 1 and combine_send_cost_stats->is_contiguous()); + EP_HOST_ASSERT(combine_send_cost_stats->size(0) == num_ranks); + combine_send_cost_stats_out = combine_send_cost_stats.value(); + } + int64_t hidden = static_cast(recv_x.size(1)); at::Tensor tp_send_counts = at::empty({1}, at::dtype(at::kInt).device(device)); int64_t tp_world_size = 1; @@ -367,7 +385,7 @@ std::tuple, std::optionalis_padding) { if (this->padding_cnt == PADDING_SIZE) { diff --git a/csrc/deepep/deep_ep.hpp b/csrc/deepep/deep_ep.hpp index ffb7ed9d..f536c3fe 100644 --- a/csrc/deepep/deep_ep.hpp +++ b/csrc/deepep/deep_ep.hpp @@ -58,15 +58,17 @@ struct Buffer { const std::optional &num_tokens_per_rank, const at::Tensor &is_token_in_rank, const std::optional &num_tokens_per_expert, int cached_num_recv_tokens, const std::optional &cached_rank_prefix_matrix, - const std::optional &cached_channel_prefix_matrix, int expert_alignment, + const std::optional &cached_channel_prefix_matrix, + const std::optional &dispatch_wait_recv_cost_stats, int expert_alignment, int num_worst_tokens, const Config &config, std::optional &previous_event, bool async, bool allocate_on_comm_stream, bool use_quant); void clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int 
hidden, int num_experts); - std::tuple, std::optional> intranode_combine( - const torch::Tensor &x, const torch::Tensor &topk_idx, const std::optional &topk_weights, - const torch::Tensor &src_idx, const torch::Tensor &send_head); + std::tuple, std::optional> + intranode_combine(const torch::Tensor &x, const torch::Tensor &topk_idx, + const std::optional &topk_weights, const torch::Tensor &src_idx, + const torch::Tensor &send_head, const std::optional &combine_send_cost_stats); std::tuple, at::Tensor, at::Tensor, at::Tensor, std::optional, std::optional>> diff --git a/csrc/deepep/ops/op_host/cam_moe_combine_normal.cpp b/csrc/deepep/ops/op_host/cam_moe_combine_normal.cpp index a97feed4..3ea4e9c9 100644 --- a/csrc/deepep/ops/op_host/cam_moe_combine_normal.cpp +++ b/csrc/deepep/ops/op_host/cam_moe_combine_normal.cpp @@ -43,6 +43,12 @@ class CamMoeCombineNormal : public OpDef .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + this->Output("combine_send_cost_stats") + .ParamType(OPTIONAL) + .DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + this->Attr("ep_group_name").AttrType(REQUIRED).String(); this->Attr("ep_world_size").AttrType(REQUIRED).Int(); this->Attr("ep_rank_id").AttrType(REQUIRED).Int(); diff --git a/csrc/deepep/ops/op_host/cam_moe_combine_normal_tiling.cc b/csrc/deepep/ops/op_host/cam_moe_combine_normal_tiling.cc index 8cc57834..ddc53fa8 100644 --- a/csrc/deepep/ops/op_host/cam_moe_combine_normal_tiling.cc +++ b/csrc/deepep/ops/op_host/cam_moe_combine_normal_tiling.cc @@ -54,6 +54,7 @@ constexpr uint32_t EP_RECV_COUNTS_INDEX = 2; constexpr uint32_t TOPK_WEIGHTS_INDEX = 3; constexpr uint32_t TP_RECV_COUNTS_INDEX = 4; constexpr uint32_t OUTPUT_X_INDEX = 0; +constexpr uint32_t 
OUTPUT_SEND_COST_INDEX = 1; constexpr uint32_t ATTR_GROUP_EP_INDEX = 0; constexpr uint32_t ATTR_EP_WORLD_SIZE_INDEX = 1; @@ -238,7 +239,7 @@ static bool CheckOptionalInputTensorDim(gert::TilingContext *context, const char return true; } -static bool CheckOutputTensorDim(gert::TilingContext *context, const char *nodeName) +static bool CheckOutputTensorDim(gert::TilingContext *context, const char *nodeName, const bool isEnableDiagnose) { const gert::StorageShape *xStorageShape = context->GetOutputShape(OUTPUT_X_INDEX); OP_TILING_CHECK(xStorageShape == nullptr, OP_LOGE(nodeName, "x is null."), return false); @@ -249,10 +250,19 @@ static bool CheckOutputTensorDim(gert::TilingContext *context, const char *nodeN OP_LOGD(nodeName, "x dim0 = %ld", xStorageShape->GetStorageShape().GetDim(0)); OP_LOGD(nodeName, "x dim1 = %ld", xStorageShape->GetStorageShape().GetDim(1)); + if (isEnableDiagnose) { + const gert::StorageShape *sendCostStatsStorageShape = context->GetOutputShape(OUTPUT_SEND_COST_INDEX); + OP_TILING_CHECK(sendCostStatsStorageShape == nullptr, OP_LOGE(nodeName, "combine sendCostStatsShape is null."), + return false); + OP_TILING_CHECK(sendCostStatsStorageShape->GetStorageShape().GetDimNum() != ONE_DIM, + OP_LOGE(nodeName, "combine sendCostStatsShape must be 1-dimension, but got %lu dim", + sendCostStatsStorageShape->GetStorageShape().GetDimNum()), + return false); + } return true; } -static bool CheckTensorDim(gert::TilingContext *context, const char *nodeName) +static bool CheckTensorDim(gert::TilingContext *context, const char *nodeName, const bool isEnableDiagnose) { OP_TILING_CHECK(!CheckInputTensorDim(context, nodeName), OP_LOGE(nodeName, "param shape of input tensor is invalid"), return false); @@ -260,14 +270,14 @@ static bool CheckTensorDim(gert::TilingContext *context, const char *nodeName) OP_TILING_CHECK(!CheckOptionalInputTensorDim(context, nodeName), OP_LOGE(nodeName, "param shape of optional input tensor is invalid"), return false); - 
OP_TILING_CHECK(!CheckOutputTensorDim(context, nodeName), + OP_TILING_CHECK(!CheckOutputTensorDim(context, nodeName, isEnableDiagnose), OP_LOGE(nodeName, "param shape of output tensor is invalid"), return false); return true; } // 校验数据类型 -static bool CheckTensorDataType(gert::TilingContext *context, const char *nodeName) +static bool CheckTensorDataType(gert::TilingContext *context, const char *nodeName, const bool isEnableDiagnose) { auto recvXDesc = context->GetInputDesc(RECV_X_INDEX); OP_TILING_CHECK(recvXDesc == nullptr, OP_LOGE(nodeName, "recvXDesc is null."), return false); @@ -296,10 +306,20 @@ static bool CheckTensorDataType(gert::TilingContext *context, const char *nodeNa OP_TILING_CHECK((xDesc->GetDataType() != recvXDesc->GetDataType()), OP_LOGE(nodeName, "x dataType is invalid, dataType should be equal to recvX dataType , but is "), return false); + + if (isEnableDiagnose) { + auto sendCostStatsDesc = context->GetOutputDesc(OUTPUT_SEND_COST_INDEX); + OP_TILING_CHECK(sendCostStatsDesc == nullptr, OP_LOGE(nodeName, "combine sendCostStatsDesc is null."), + return false); + OP_TILING_CHECK( + sendCostStatsDesc->GetDataType() != ge::DT_INT32, + OP_LOGE(nodeName, "combine sendCostStatsDesc dataType is invalid, dataType should be int32, but is ."), + return false); + } return true; } -static bool CheckTensorFormat(gert::TilingContext *context, const char *nodeName) +static bool CheckTensorFormat(gert::TilingContext *context, const char *nodeName, const bool isEnableDiagnose) { auto recvXDesc = context->GetInputDesc(RECV_X_INDEX); OP_TILING_CHECK(recvXDesc == nullptr, OP_LOGE(nodeName, "recvXDesc is null."), return false); @@ -330,6 +350,14 @@ static bool CheckTensorFormat(gert::TilingContext *context, const char *nodeName OP_TILING_CHECK(static_cast(ge::GetPrimaryFormat(xDesc->GetStorageFormat())) == ge::FORMAT_FRACTAL_NZ, OP_LOGE(nodeName, "xFormat is invalid"), return false); + if (isEnableDiagnose) { + auto sendCostStatsDesc = 
context->GetOutputDesc(OUTPUT_SEND_COST_INDEX); + OP_TILING_CHECK(sendCostStatsDesc == nullptr, OP_LOGE(nodeName, "combine sendCostStatsDesc is null."), + return false); + OP_TILING_CHECK(static_cast(ge::GetPrimaryFormat(sendCostStatsDesc->GetStorageFormat())) == + ge::FORMAT_FRACTAL_NZ, + OP_LOGE(nodeName, "combine sendCostStatsDesc format is invalid"), return false); + } return true; } @@ -435,17 +463,18 @@ static bool CheckAttrs(gert::TilingContext *context, CamMoeCombineNormalTilingDa return true; } -static ge::graphStatus TilingCheckCamMoeCombineNormal(gert::TilingContext *context, const char *nodeName) +static ge::graphStatus TilingCheckCamMoeCombineNormal(gert::TilingContext *context, const char *nodeName, + const bool isEnableDiagnose) { // 检查参数shape信息 - OP_TILING_CHECK(!CheckTensorDim(context, nodeName), OP_LOGE(nodeName, "param shape is invalid"), + OP_TILING_CHECK(!CheckTensorDim(context, nodeName, isEnableDiagnose), OP_LOGE(nodeName, "param shape is invalid"), return ge::GRAPH_FAILED); // 检查参数dataType信息 - OP_TILING_CHECK(!CheckTensorDataType(context, nodeName), OP_LOGE(nodeName, "param dataType is invalid"), - return ge::GRAPH_FAILED); + OP_TILING_CHECK(!CheckTensorDataType(context, nodeName, isEnableDiagnose), + OP_LOGE(nodeName, "param dataType is invalid"), return ge::GRAPH_FAILED); // 检查参数format信息 - OP_TILING_CHECK(!CheckTensorFormat(context, nodeName), OP_LOGE(nodeName, "param Format is invalid"), - return ge::GRAPH_FAILED); + OP_TILING_CHECK(!CheckTensorFormat(context, nodeName, isEnableDiagnose), + OP_LOGE(nodeName, "param Format is invalid"), return ge::GRAPH_FAILED); return ge::GRAPH_SUCCESS; } @@ -493,8 +522,11 @@ static ge::graphStatus CamMoeCombineNormalA3TilingFuncImpl(gert::TilingContext * OP_TILING_CHECK(GetAttrAndSetTilingData(context, *tilingData, nodeName, groupEp, groupTp) == ge::GRAPH_FAILED, OP_LOGE(nodeName, "Getting attr failed."), return ge::GRAPH_FAILED); + auto sendCostStatsStorageShape = 
context->GetOutputShape(OUTPUT_SEND_COST_INDEX); + bool isEnableDiagnose = (sendCostStatsStorageShape != nullptr); + tilingData->camMoeCombineNormalInfo.isEnableDiagnose = isEnableDiagnose; // 检查输入输出的dim、format、dataType - OP_TILING_CHECK(TilingCheckCamMoeCombineNormal(context, nodeName) != ge::GRAPH_SUCCESS, + OP_TILING_CHECK(TilingCheckCamMoeCombineNormal(context, nodeName, isEnableDiagnose) != ge::GRAPH_SUCCESS, OP_LOGE(nodeName, "Tiling check params failed"), return ge::GRAPH_FAILED); // 检查属性的取值是否合法 diff --git a/csrc/deepep/ops/op_host/cam_moe_dispatch_normal.cpp b/csrc/deepep/ops/op_host/cam_moe_dispatch_normal.cpp index fab6bd78..fca41662 100644 --- a/csrc/deepep/ops/op_host/cam_moe_dispatch_normal.cpp +++ b/csrc/deepep/ops/op_host/cam_moe_dispatch_normal.cpp @@ -62,6 +62,12 @@ class CamMoeDispatchNormal : public OpDef .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + this->Output("dispatch_wait_recv_cost_stats") + .ParamType(OPTIONAL) + .DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + this->Attr("group_ep").AttrType(REQUIRED).String(); this->Attr("ep_world_size").AttrType(REQUIRED).Int(); this->Attr("ep_rank_id").AttrType(REQUIRED).Int(); diff --git a/csrc/deepep/ops/op_host/cam_moe_dispatch_normal_tiling.cc b/csrc/deepep/ops/op_host/cam_moe_dispatch_normal_tiling.cc index 95146211..11d0e1fd 100644 --- a/csrc/deepep/ops/op_host/cam_moe_dispatch_normal_tiling.cc +++ b/csrc/deepep/ops/op_host/cam_moe_dispatch_normal_tiling.cc @@ -59,6 +59,7 @@ constexpr uint32_t RECV_COUNT_INDEX = 5U; constexpr uint32_t OUTPUT_EXPAND_X_INDEX = 0U; constexpr uint32_t OUTPUT_DYNAMIC_SCALES_INDEX = 1U; constexpr uint32_t OUTPUT_ASSIST_INFO_INDEX = 2U; +constexpr uint32_t OUTPUT_WAIT_RECV_COST_INDEX = 
3U; constexpr uint32_t ATTR_GROUP_EP_INDEX = 0; constexpr uint32_t ATTR_EP_WORLD_SIZE_INDEX = 1; @@ -122,7 +123,8 @@ static void PrintTilingDataInfo(const char *nodeName, CamMoeDispatchNormalTiling OP_LOGD(nodeName, "totalWinSize is %lu.", tilingData.camMoeDispatchNormalInfo.totalWinSize); } -static bool CheckTensorDim(gert::TilingContext *context, const char *nodeName, const uint32_t quantMode) +static bool CheckTensorDim(gert::TilingContext *context, const char *nodeName, const uint32_t quantMode, + const bool isEnableDiagnose) { const gert::StorageShape *xStorageShape = context->GetInputShape(X_INDEX); OP_TILING_CHECK(xStorageShape == nullptr, OP_LOGE(nodeName, "xShape is null."), return false); @@ -172,10 +174,21 @@ static bool CheckTensorDim(gert::TilingContext *context, const char *nodeName, c return false); OP_LOGD(nodeName, "assistInfoForCombine dim0 = %ld", assistInfoStorageShape->GetStorageShape().GetDim(0)); + if (isEnableDiagnose) { + const gert::StorageShape *waitRecvcostStatsStorageShape = context->GetOutputShape(OUTPUT_WAIT_RECV_COST_INDEX); + OP_TILING_CHECK(waitRecvcostStatsStorageShape == nullptr, + OP_LOGE(nodeName, "dispatch waitRecvCostStatsShape is null."), return false); + OP_TILING_CHECK(waitRecvcostStatsStorageShape->GetStorageShape().GetDimNum() != ONE_DIM, + OP_LOGE(nodeName, "dispatch waitRecvCostStatsShape dim must be 1, but current dim num is %lu.", + waitRecvcostStatsStorageShape->GetStorageShape().GetDimNum()), + return false); + } + return true; } -static bool CheckTensorDataType(gert::TilingContext *context, const char *nodeName, const uint32_t quantMode) +static bool CheckTensorDataType(gert::TilingContext *context, const char *nodeName, const uint32_t quantMode, + const bool isEnableDiagnose) { auto xDesc = context->GetInputDesc(X_INDEX); OP_TILING_CHECK(xDesc == nullptr, OP_LOGE(nodeName, "xDesc is null."), return false); @@ -216,10 +229,21 @@ static bool CheckTensorDataType(gert::TilingContext *context, const char *nodeNa 
OP_LOGE(nodeName, "assistInfoForCombine dataType is invalid, dataType should be int32, but is ."), return false); + if (isEnableDiagnose) { + auto waitRecvCostStatsDesc = context->GetOutputDesc(OUTPUT_WAIT_RECV_COST_INDEX); + OP_TILING_CHECK(waitRecvCostStatsDesc == nullptr, OP_LOGE(nodeName, "dispatch waitRecvCostStatsDesc is null."), + return false); + OP_TILING_CHECK( + waitRecvCostStatsDesc->GetDataType() != ge::DT_INT32, + OP_LOGE(nodeName, "dispatch waitRecvCostStatsDesc dataType is invalid, dataType should be int32, but is ."), + return false); + } + return true; } -static bool CheckTensorFormat(gert::TilingContext *context, const char *nodeName, const uint32_t quantMode) +static bool CheckTensorFormat(gert::TilingContext *context, const char *nodeName, const uint32_t quantMode, + const bool isEnableDiagnose) { auto xDesc = context->GetInputDesc(X_INDEX); OP_TILING_CHECK(xDesc == nullptr, OP_LOGE(nodeName, "xDesc is null."), return false); @@ -252,6 +276,15 @@ static bool CheckTensorFormat(gert::TilingContext *context, const char *nodeName static_cast(ge::GetPrimaryFormat(assistInfoDesc->GetStorageFormat())) == ge::FORMAT_FRACTAL_NZ, OP_LOGE(nodeName, "assistInfoForCombine format is invalid."), return false); + if (isEnableDiagnose) { + auto waitRecvCostStatsDesc = context->GetOutputDesc(OUTPUT_WAIT_RECV_COST_INDEX); + OP_TILING_CHECK(waitRecvCostStatsDesc == nullptr, OP_LOGE(nodeName, "dispatch waitRecvCostStatsDesc is null."), + return false); + OP_TILING_CHECK(static_cast(ge::GetPrimaryFormat(waitRecvCostStatsDesc->GetStorageFormat())) == + ge::FORMAT_FRACTAL_NZ, + OP_LOGE(nodeName, "dispatch waitRecvCostStatsDesc format is invalid"), return false); + } + return true; } @@ -448,14 +481,14 @@ static ge::graphStatus CheckTensorShape(gert::TilingContext *context, const char } static ge::graphStatus TilingCheckCamMoeDispatchNormal(gert::TilingContext *context, const char *nodeName, - const uint32_t quantMode) + const uint32_t quantMode, const bool 
isEnableDiagnose) { - OP_TILING_CHECK(!CheckTensorDim(context, nodeName, quantMode), OP_LOGE(nodeName, "params shape is invalid."), - return ge::GRAPH_FAILED); - OP_TILING_CHECK(!CheckTensorDataType(context, nodeName, quantMode), + OP_TILING_CHECK(!CheckTensorDim(context, nodeName, quantMode, isEnableDiagnose), + OP_LOGE(nodeName, "params shape is invalid."), return ge::GRAPH_FAILED); + OP_TILING_CHECK(!CheckTensorDataType(context, nodeName, quantMode, isEnableDiagnose), OP_LOGE(nodeName, "params dataType is invalid."), return ge::GRAPH_FAILED); - OP_TILING_CHECK(!CheckTensorFormat(context, nodeName, quantMode), OP_LOGE(nodeName, "params format is invalid."), - return ge::GRAPH_FAILED); + OP_TILING_CHECK(!CheckTensorFormat(context, nodeName, quantMode, isEnableDiagnose), + OP_LOGE(nodeName, "params format is invalid."), return ge::GRAPH_FAILED); return ge::GRAPH_SUCCESS; } @@ -517,9 +550,14 @@ static ge::graphStatus CamMoeDispatchNormalA3TilingFuncImpl(gert::TilingContext quantMode = tilingData->camMoeDispatchNormalInfo.quantMode; + auto waitRecvcostStatsStorageShape = context->GetOutputShape(OUTPUT_WAIT_RECV_COST_INDEX); + bool isEnableDiagnose = (waitRecvcostStatsStorageShape != nullptr); + tilingData->camMoeDispatchNormalInfo.isEnableDiagnose = isEnableDiagnose; + // 检查输入输出的dim、format、dataType - OP_TILING_CHECK(TilingCheckCamMoeDispatchNormal(context, nodeName, quantMode) != ge::GRAPH_SUCCESS, - OP_LOGE(nodeName, "Tiling check param failed."), return ge::GRAPH_FAILED); + OP_TILING_CHECK( + TilingCheckCamMoeDispatchNormal(context, nodeName, quantMode, isEnableDiagnose) != ge::GRAPH_SUCCESS, + OP_LOGE(nodeName, "Tiling check param failed."), return ge::GRAPH_FAILED); // 检查属性的取值是否合法 OP_TILING_CHECK(CheckAttrs(context, nodeName, *tilingData, localMoeExpertNum) != ge::GRAPH_SUCCESS, diff --git a/csrc/deepep/ops/op_host/op_api/aclnn_cam_moe_combine_normal.cpp b/csrc/deepep/ops/op_host/op_api/aclnn_cam_moe_combine_normal.cpp index 67b88072..a25fea35 100644 --- 
a/csrc/deepep/ops/op_host/op_api/aclnn_cam_moe_combine_normal.cpp +++ b/csrc/deepep/ops/op_host/op_api/aclnn_cam_moe_combine_normal.cpp @@ -19,12 +19,14 @@ aclnnStatus aclnnCamMoeCombineNormalGetWorkspaceSize(const aclTensor *recvX, con const aclTensor *tpRecvCountsOptional, char *epGroupName, int64_t epWorldSize, int64_t epRankId, char *tpGroupNameOptional, int64_t tpWorldSize, int64_t tpRankId, int64_t moeExpertNum, - int64_t globalBs, const aclTensor *out, uint64_t *workspaceSize, + int64_t globalBs, const aclTensor *out, + const aclTensor *sendCostStats, uint64_t *workspaceSize, aclOpExecutor **executor) { - return aclnnInnerCamMoeCombineNormalGetWorkspaceSize( - recvX, tokenSrcInfo, epRecvCounts, recvTopkWeights, tpRecvCountsOptional, epGroupName, epWorldSize, epRankId, - tpGroupNameOptional, tpWorldSize, tpRankId, moeExpertNum, globalBs, out, workspaceSize, executor); + return aclnnInnerCamMoeCombineNormalGetWorkspaceSize(recvX, tokenSrcInfo, epRecvCounts, recvTopkWeights, + tpRecvCountsOptional, epGroupName, epWorldSize, epRankId, + tpGroupNameOptional, tpWorldSize, tpRankId, moeExpertNum, + globalBs, out, sendCostStats, workspaceSize, executor); } aclnnStatus aclnnCamMoeCombineNormal(void *workspace, uint64_t workspaceSize, aclOpExecutor *executor, diff --git a/csrc/deepep/ops/op_host/op_api/aclnn_cam_moe_combine_normal.h b/csrc/deepep/ops/op_host/op_api/aclnn_cam_moe_combine_normal.h index 4a2ceecb..7e3d91e5 100644 --- a/csrc/deepep/ops/op_host/op_api/aclnn_cam_moe_combine_normal.h +++ b/csrc/deepep/ops/op_host/op_api/aclnn_cam_moe_combine_normal.h @@ -29,7 +29,8 @@ __attribute__((visibility("default"))) aclnnStatus aclnnCamMoeCombineNormalGetWo const aclTensor *recvX, const aclTensor *tokenSrcInfo, const aclTensor *epRecvCounts, const aclTensor *recvTopkWeights, const aclTensor *tpRecvCountsOptional, char *epGroupName, int64_t epWorldSize, int64_t epRankId, char *tpGroupNameOptional, int64_t tpWorldSize, int64_t tpRankId, int64_t moeExpertNum, - int64_t 
globalBs, const aclTensor *out, uint64_t *workspaceSize, aclOpExecutor **executor); + int64_t globalBs, const aclTensor *out, const aclTensor *sendCostStats, uint64_t *workspaceSize, + aclOpExecutor **executor); /* function: aclnnMoeCombine * workspace : workspace memory addr(input). diff --git a/csrc/deepep/ops/op_host/op_api/aclnn_cam_moe_dispatch_normal.cpp b/csrc/deepep/ops/op_host/op_api/aclnn_cam_moe_dispatch_normal.cpp index 1c8684eb..0e730bbd 100644 --- a/csrc/deepep/ops/op_host/op_api/aclnn_cam_moe_dispatch_normal.cpp +++ b/csrc/deepep/ops/op_host/op_api/aclnn_cam_moe_dispatch_normal.cpp @@ -19,12 +19,12 @@ aclnnStatus aclnnCamMoeDispatchNormalGetWorkspaceSize( const aclTensor *recvOffset, const aclTensor *recvCount, char *groupEp, int64_t epWorldSize, int64_t epRankId, char *groupTpOptional, int64_t tpWorldSize, int64_t tpRankId, int64_t moeExpertNum, int64_t quantMode, int64_t globalBs, const aclTensor *recvX, const aclTensor *recvXScales, const aclTensor *assistInfoForCombine, - uint64_t *workspaceSize, aclOpExecutor **executor) + const aclTensor *waitRecvCostStats, uint64_t *workspaceSize, aclOpExecutor **executor) { - return aclnnInnerCamMoeDispatchNormalGetWorkspaceSize(x, topkIdx, sendOffset, sendTokenIdx, recvOffset, recvCount, - groupEp, epWorldSize, epRankId, groupTpOptional, tpWorldSize, - tpRankId, moeExpertNum, quantMode, globalBs, recvX, - recvXScales, assistInfoForCombine, workspaceSize, executor); + return aclnnInnerCamMoeDispatchNormalGetWorkspaceSize( + x, topkIdx, sendOffset, sendTokenIdx, recvOffset, recvCount, groupEp, epWorldSize, epRankId, groupTpOptional, + tpWorldSize, tpRankId, moeExpertNum, quantMode, globalBs, recvX, recvXScales, assistInfoForCombine, + waitRecvCostStats, workspaceSize, executor); } aclnnStatus aclnnCamMoeDispatchNormal(void *workspace, uint64_t workspaceSize, aclOpExecutor *executor, diff --git a/csrc/deepep/ops/op_host/op_api/aclnn_cam_moe_dispatch_normal.h 
b/csrc/deepep/ops/op_host/op_api/aclnn_cam_moe_dispatch_normal.h index a717c2f3..6ee50757 100644 --- a/csrc/deepep/ops/op_host/op_api/aclnn_cam_moe_dispatch_normal.h +++ b/csrc/deepep/ops/op_host/op_api/aclnn_cam_moe_dispatch_normal.h @@ -12,7 +12,7 @@ __attribute__((visibility("default"))) aclnnStatus aclnnCamMoeDispatchNormalGetW const aclTensor *recvOffset, const aclTensor *recvCount, char *groupEp, int64_t epWorldSize, int64_t epRankId, char *groupTpOptional, int64_t tpWorldSize, int64_t tpRankId, int64_t moeExpertNum, int64_t quantMode, int64_t globalBs, const aclTensor *recvX, const aclTensor *recvXScales, const aclTensor *assistInfoForCombine, - uint64_t *workspaceSize, aclOpExecutor **executor); + const aclTensor *waitRecvCostStats, uint64_t *workspaceSize, aclOpExecutor **executor); __attribute__((visibility("default"))) aclnnStatus aclnnCamMoeDispatchNormal(void *workspace, uint64_t workspaceSize, aclOpExecutor *executor, diff --git a/csrc/deepep/ops/op_kernel/cam_moe_combine_normal.cpp b/csrc/deepep/ops/op_kernel/cam_moe_combine_normal.cpp index eb159ef6..e24248c2 100644 --- a/csrc/deepep/ops/op_kernel/cam_moe_combine_normal.cpp +++ b/csrc/deepep/ops/op_kernel/cam_moe_combine_normal.cpp @@ -7,7 +7,8 @@ using namespace CamMoeCombineNormalImpl; extern "C" __global__ __aicore__ void cam_moe_combine_normal(GM_ADDR recvX, GM_ADDR tokenSrcInfo, GM_ADDR epRecvCount, GM_ADDR topkWeights, GM_ADDR tpRecvCount, GM_ADDR XOut, - GM_ADDR workspaceGM, GM_ADDR tilingGM) + GM_ADDR sendCostStatsOut, GM_ADDR workspaceGM, + GM_ADDR tilingGM) { REGISTER_TILING_DEFAULT(CamMoeCombineNormalTilingData); @@ -16,7 +17,8 @@ extern "C" __global__ __aicore__ void cam_moe_combine_normal(GM_ADDR recvX, GM_A #if (ORIG_DTYPE_RECV_X == DT_BF16 || ORIG_DTYPE_RECV_X == DT_FLOAT16) GET_TILING_DATA_WITH_STRUCT(CamMoeCombineNormalTilingData, tilingData, tilingGM); CamMoeCombineNormal op; - op.Init(recvX, tokenSrcInfo, epRecvCount, topkWeights, tpRecvCount, XOut, workspaceGM, &pipe, 
&tilingData); + op.Init(recvX, tokenSrcInfo, epRecvCount, topkWeights, tpRecvCount, XOut, sendCostStatsOut, workspaceGM, &pipe, + &tilingData); op.Process(); #endif } diff --git a/csrc/deepep/ops/op_kernel/cam_moe_combine_normal.h b/csrc/deepep/ops/op_kernel/cam_moe_combine_normal.h index dbd1d3d3..e51bef99 100644 --- a/csrc/deepep/ops/op_kernel/cam_moe_combine_normal.h +++ b/csrc/deepep/ops/op_kernel/cam_moe_combine_normal.h @@ -19,6 +19,7 @@ constexpr uint32_t MUL_256_ALIGN = 256U; constexpr uint64_t WIN_512_ALIGN = 512UL; constexpr uint32_t FLOAT_NUM_PER_ALIGN = 8U; constexpr uint8_t DOUBLE_BUFFER = 2; +constexpr int64_t CYCLE_TO_TIME = 50; // cycle num is converted into a fixed base unit of time, set at 50 template __aicore__ inline void SyncFunc() @@ -38,14 +39,14 @@ class CamMoeCombineNormal public: __aicore__ inline CamMoeCombineNormal(){}; __aicore__ inline void Init(GM_ADDR recvX, GM_ADDR tokenSrcInfo, GM_ADDR epRecvCount, GM_ADDR topkWeights, - GM_ADDR tpRecvCount, GM_ADDR XOut, GM_ADDR workspaceGM, TPipe *pipe, - const CamMoeCombineNormalTilingData *tilingData); + GM_ADDR tpRecvCount, GM_ADDR XOut, GM_ADDR sendCostStatsOut, GM_ADDR workspaceGM, + TPipe *pipe, const CamMoeCombineNormalTilingData *tilingData); __aicore__ inline void Process(); private: __aicore__ inline void InitMagic(); __aicore__ inline void InitGlobalBuffer(GM_ADDR recvX, GM_ADDR tokenSrcInfo, GM_ADDR epRecvCount, - GM_ADDR topkWeights, GM_ADDR XOut); + GM_ADDR topkWeights, GM_ADDR XOut, GM_ADDR sendCostStatsOut); __aicore__ inline void InitTilingData(const CamMoeCombineNormalTilingData *tilingData); __aicore__ inline void InitBuffLen(); __aicore__ inline void CopyBufferToShareAndSetStatus(); @@ -106,9 +107,13 @@ class CamMoeCombineNormal uint32_t h256AlignFloatLen_{0}; uint32_t h32AlignRecvXLen_{0}; uint32_t h512AlignRecvXLen_{0}; + uint32_t sendCostStatsBufSize_{0}; + + bool isEnableDiagnose_{false}; TPipe *tpipe_{nullptr}; TQue weightedSumQueue_; + TQue sendCostStatsOutQueue_; 
TQueBind localCopyQueue_; TBuf<> stateBuf_; TBuf<> topkWeightsBuf_; @@ -124,6 +129,7 @@ class CamMoeCombineNormal GlobalTensor epRecvCountGM_; GlobalTensor topkWeightsGM_; GlobalTensor xOutGlobal_; + GlobalTensor sendCostStatsGT_; GM_ADDR localRankGM_; GM_ADDR workspaceGM_; }; @@ -146,13 +152,17 @@ __aicore__ inline void CamMoeCombineNormal::InitMagic() template __aicore__ inline void CamMoeCombineNormal::InitGlobalBuffer(GM_ADDR recvX, GM_ADDR tokenSrcInfo, GM_ADDR epRecvCount, - GM_ADDR topkWeights, GM_ADDR XOut) + GM_ADDR topkWeights, GM_ADDR XOut, + GM_ADDR sendCostStatsOut) { recvXGM_.SetGlobalBuffer((__gm__ RecvXType *)recvX); tokenSrcInfoGM_.SetGlobalBuffer((__gm__ SrcInfoType *)tokenSrcInfo); epRecvCountGM_.SetGlobalBuffer((__gm__ int32_t *)epRecvCount); topkWeightsGM_.SetGlobalBuffer((__gm__ float *)topkWeights); xOutGlobal_.SetGlobalBuffer((__gm__ XType *)XOut); + if (isEnableDiagnose_) { + sendCostStatsGT_.SetGlobalBuffer((__gm__ int32_t *)sendCostStatsOut); + } } template @@ -167,6 +177,7 @@ CamMoeCombineNormal::InitTilingData(const CamMoeCombineNorm moeExpertPerRankNum_ = tilingData->camMoeCombineNormalInfo.moeExpertPerRankNum; epWorldSize_ = tilingData->camMoeCombineNormalInfo.epWorldSize; epRankId_ = tilingData->camMoeCombineNormalInfo.epRankId; + isEnableDiagnose_ = tilingData->camMoeCombineNormalInfo.isEnableDiagnose; } template @@ -178,22 +189,23 @@ __aicore__ inline void CamMoeCombineNormal::InitBuffLen() hRecvXTypeLen_ = axisH_ * sizeof(RecvXType); h32AlignRecvXLen_ = Ceil(hRecvXTypeLen_, UB_32_ALIGN) * UB_32_ALIGN; h512AlignRecvXLen_ = Ceil(hRecvXTypeLen_, WIN_512_ALIGN) * WIN_512_ALIGN; + if (isEnableDiagnose_) { + sendCostStatsBufSize_ = Ceil(epWorldSize_ * sizeof(int32_t), UB_32_ALIGN) * UB_32_ALIGN; + } } template -__aicore__ inline void CamMoeCombineNormal::Init(GM_ADDR recvX, GM_ADDR tokenSrcInfo, - GM_ADDR epRecvCount, GM_ADDR topkWeights, - GM_ADDR tpRecvCount, GM_ADDR XOut, - GM_ADDR workspaceGM, TPipe *pipe, - const 
CamMoeCombineNormalTilingData *tilingData) +__aicore__ inline void CamMoeCombineNormal::Init( + GM_ADDR recvX, GM_ADDR tokenSrcInfo, GM_ADDR epRecvCount, GM_ADDR topkWeights, GM_ADDR tpRecvCount, GM_ADDR XOut, + GM_ADDR sendCostStatsOut, GM_ADDR workspaceGM, TPipe *pipe, const CamMoeCombineNormalTilingData *tilingData) { workspaceGM_ = workspaceGM; tpipe_ = pipe; coreIdx_ = GetBlockIdx(); InitMagic(); - InitGlobalBuffer(recvX, tokenSrcInfo, epRecvCount, topkWeights, XOut); InitTilingData(tilingData); + InitGlobalBuffer(recvX, tokenSrcInfo, epRecvCount, topkWeights, XOut, sendCostStatsOut); InitBuffLen(); PipeBarrier(); @@ -226,17 +238,43 @@ __aicore__ inline void CamMoeCombineNormal::CopyBufferToSha const DataCopyExtParams dataCopyParams{1U, blockLen, 0U, 0U, 0U}; const DataCopyPadExtParams padParams{false, 0U, 0U, 0U}; DataCopyPad(srcInfoLocal, tokenSrcInfoGM_[startTokenId * TOKEN_SRC_INFO_LEN], dataCopyParams, padParams); - SyncFunc(); + + LocalTensor sendCostStatsTensor; + if (isEnableDiagnose_) { + tpipe_->InitBuffer(sendCostStatsOutQueue_, DOUBLE_BUFFER, sendCostStatsBufSize_); + sendCostStatsTensor = sendCostStatsOutQueue_.AllocTensor(); + Duplicate(sendCostStatsTensor, 0, sendCostStatsBufSize_ / sizeof(int32_t)); + } + for (uint32_t tokenIndex = startTokenId; tokenIndex < endTokenId; tokenIndex++) { uint32_t index = (tokenIndex - startTokenId) * TOKEN_SRC_INFO_LEN; uint32_t srcRankId = static_cast(srcInfoLocal(index + RANK_ID_OFFSET_IN_SRC_INFO)); uint32_t srcTokenId = static_cast(srcInfoLocal(index + TOKEN_IDX_OFFSET_IN_SRC_INFO)); uint32_t srcTopkId = static_cast(srcInfoLocal(index + TOPK_IDX_OFFSET_IN_SRC_INFO)); + int64_t sendStartCycle = GetSystemCycle(); + CopyBufferToShare(srcRankId, srcTokenId, srcTopkId, tokenIndex); PipeBarrier(); SetStatusBySrcInfo(srcRankId, srcTokenId, srcTopkId); + + if (isEnableDiagnose_) { + SyncFunc(); + int32_t durationTime = static_cast((GetSystemCycle() - sendStartCycle) / CYCLE_TO_TIME); // us + int32_t preTime = 
sendCostStatsTensor.GetValue(srcRankId); + sendCostStatsTensor.SetValue(srcRankId, preTime + durationTime); + } } + + if (isEnableDiagnose_) { + SyncFunc(); + AscendC::SetAtomicAdd(); + DataCopyExtParams statsCopyOutParams = {1U, static_cast(epWorldSize_ * sizeof(int32_t)), 0U, 0U, 0U}; + DataCopyPad(sendCostStatsGT_, sendCostStatsTensor, statsCopyOutParams); + AscendC::SetAtomicNone(); + sendCostStatsOutQueue_.FreeTensor(sendCostStatsTensor); + } + SyncFunc(); } diff --git a/csrc/deepep/ops/op_kernel/cam_moe_combine_normal_tiling.h b/csrc/deepep/ops/op_kernel/cam_moe_combine_normal_tiling.h index 56e55eae..5f3d844d 100644 --- a/csrc/deepep/ops/op_kernel/cam_moe_combine_normal_tiling.h +++ b/csrc/deepep/ops/op_kernel/cam_moe_combine_normal_tiling.h @@ -22,6 +22,7 @@ struct CamMoeCombineNormalInfo { uint64_t totalWinSize; float armAvgFactor; float epsilon; + bool isEnableDiagnose; }; struct CamMoeCombineNormalTilingData { Mc2InitTiling mc2InitTiling; diff --git a/csrc/deepep/ops/op_kernel/cam_moe_dispatch_normal.cpp b/csrc/deepep/ops/op_kernel/cam_moe_dispatch_normal.cpp index 8e6fb74a..cf75d1ce 100644 --- a/csrc/deepep/ops/op_kernel/cam_moe_dispatch_normal.cpp +++ b/csrc/deepep/ops/op_kernel/cam_moe_dispatch_normal.cpp @@ -12,7 +12,8 @@ extern "C" __global__ __aicore__ void cam_moe_dispatch_normal(GM_ADDR x, GM_ADDR GM_ADDR send_token_idx, GM_ADDR recv_offset, GM_ADDR recv_count, GM_ADDR expandXOut, GM_ADDR dynamicScalesOut, GM_ADDR assist_info_for_combine, - GM_ADDR workspaceGM, GM_ADDR tilingGM) + GM_ADDR waitRecvCostStatsOut, GM_ADDR workspaceGM, + GM_ADDR tilingGM) { REGISTER_TILING_DEFAULT(CamMoeDispatchNormalTilingData); TPipe pipe; @@ -21,7 +22,7 @@ extern "C" __global__ __aicore__ void cam_moe_dispatch_normal(GM_ADDR x, GM_ADDR GET_TILING_DATA_WITH_STRUCT(CamMoeDispatchNormalTilingData, tilingData, tilingGM); CamMoeDispatchNormal op; op.Init(x, expertIds, send_offset, send_token_idx, recv_offset, recv_count, expandXOut, dynamicScalesOut, - 
assist_info_for_combine, workspaceGM, &pipe, &tilingData); + assist_info_for_combine, waitRecvCostStatsOut, workspaceGM, &pipe, &tilingData); op.Process(); return; } @@ -30,7 +31,7 @@ extern "C" __global__ __aicore__ void cam_moe_dispatch_normal(GM_ADDR x, GM_ADDR GET_TILING_DATA_WITH_STRUCT(CamMoeDispatchNormalTilingData, tilingData, tilingGM); CamMoeDispatchNormal op; op.Init(x, expertIds, send_offset, send_token_idx, recv_offset, recv_count, expandXOut, dynamicScalesOut, - assist_info_for_combine, workspaceGM, &pipe, &tilingData); + assist_info_for_combine, waitRecvCostStatsOut, workspaceGM, &pipe, &tilingData); op.Process(); return; } diff --git a/csrc/deepep/ops/op_kernel/cam_moe_dispatch_normal.h b/csrc/deepep/ops/op_kernel/cam_moe_dispatch_normal.h index 850de4a9..93207a76 100644 --- a/csrc/deepep/ops/op_kernel/cam_moe_dispatch_normal.h +++ b/csrc/deepep/ops/op_kernel/cam_moe_dispatch_normal.h @@ -20,6 +20,7 @@ constexpr uint64_t STATE_WIN_OFFSET = 950UL * 1024UL; constexpr uint64_t WIN_ADDR_ALIGN = 512UL; constexpr uint32_t EXPAND_IDX_INFO = 3U; constexpr uint64_t COMBINE_STATE_WIN_OFFSET = 3UL * 1024UL * 1024UL; +constexpr int64_t CYCLE_TO_TIME = 50; // cycle num is converted into a fixed base unit of time, set at 50 template __aicore__ inline void SyncFunc() @@ -42,7 +43,7 @@ class CamMoeDispatchNormal __aicore__ inline CamMoeDispatchNormal(){}; __aicore__ inline void Init(GM_ADDR x, GM_ADDR expertIds, GM_ADDR send_offset, GM_ADDR send_tokenIdx, GM_ADDR recv_offset, GM_ADDR recv_count, GM_ADDR expandXOut, GM_ADDR dynamicScalesOut, - GM_ADDR expandIdxOut, GM_ADDR workspaceGM, TPipe *pipe, + GM_ADDR expandIdxOut, GM_ADDR waitRecvCostStatsOut, GM_ADDR workspaceGM, TPipe *pipe, const CamMoeDispatchNormalTilingData *tilingData); __aicore__ inline void Process(); @@ -89,6 +90,7 @@ class CamMoeDispatchNormal GlobalTensor expandIdxOutGT; GlobalTensor dstGT; GlobalTensor dstStatusGT; + GlobalTensor waitRecvCostStatsGT; LocalTensor xInTensor; LocalTensor 
xOutTensor; @@ -99,6 +101,9 @@ class CamMoeDispatchNormal LocalTensor recvOffsetTensor; LocalTensor recvCountTensor; LocalTensor statusTensor; + LocalTensor waitRecvCostStatsTensor; + LocalTensor recvStatusTensor1; + LocalTensor recvStatusTensor2; TBuf<> expertIdsBuf; TBuf<> sendOffsetBuf; @@ -111,6 +116,7 @@ class CamMoeDispatchNormal TBuf<> scalarBuf; TBuf<> tokenCastFloatBuf; TBuf<> tokenAbsFloatBuf; + TBuf<> recvStatusBuf; GM_ADDR expandXOutGM; GM_ADDR shareGM; @@ -127,6 +133,7 @@ class CamMoeDispatchNormal uint32_t tpRankId{0}; uint32_t moeExpertNum{0}; uint32_t moeExpertNumPerRank{0}; + bool isEnableDiagnose{false}; uint32_t hUBAlignSize{0}; uint32_t hOutGMAlignSize{0}; @@ -137,6 +144,8 @@ class CamMoeDispatchNormal uint32_t stateOffset{0}; uint32_t dataState{0}; uint32_t winDataSizeOffset{0}; + uint32_t waitRecvCostStatsBufSize{0}; + uint32_t srcRankOffset{0}; uint32_t startStatusId; uint32_t endStatusId; @@ -146,6 +155,7 @@ class CamMoeDispatchNormal TQueBind xQueue; TQue xInQueue; TQue xOutQueue; + TQue waitRecvCostStatsOutQueue; __gm__ HcclOpResParam *winContext_[COMM_NUM]{nullptr, nullptr}; @@ -153,12 +163,10 @@ class CamMoeDispatchNormal }; template -__aicore__ inline void CamMoeDispatchNormal::Init(GM_ADDR x, GM_ADDR expertIds, GM_ADDR send_offset, - GM_ADDR send_tokenIdx, GM_ADDR recv_offset, - GM_ADDR recv_count, GM_ADDR expandXOut, - GM_ADDR dynamicScalesOut, GM_ADDR expandIdxOut, - GM_ADDR workspaceGM, TPipe *pipe, - const CamMoeDispatchNormalTilingData *tilingData) +__aicore__ inline void CamMoeDispatchNormal::Init( + GM_ADDR x, GM_ADDR expertIds, GM_ADDR send_offset, GM_ADDR send_tokenIdx, GM_ADDR recv_offset, GM_ADDR recv_count, + GM_ADDR expandXOut, GM_ADDR dynamicScalesOut, GM_ADDR expandIdxOut, GM_ADDR waitRecvCostStatsOut, + GM_ADDR workspaceGM, TPipe *pipe, const CamMoeDispatchNormalTilingData *tilingData) { tpipe_ = pipe; blockIdx = GetBlockIdx(); @@ -180,6 +188,7 @@ __aicore__ inline void CamMoeDispatchNormal::Init(GM_ADDR x, GM_ADD 
epRankId = tilingData->camMoeDispatchNormalInfo.epRankId; moeExpertNum = tilingData->camMoeDispatchNormalInfo.moeExpertNum; moeExpertNumPerRank = moeExpertNum / epRankSize; + isEnableDiagnose = tilingData->camMoeDispatchNormalInfo.isEnableDiagnose; xGT.SetGlobalBuffer((__gm__ XType *)x); expertIdsGT.SetGlobalBuffer((__gm__ int32_t *)expertIds); @@ -189,6 +198,9 @@ __aicore__ inline void CamMoeDispatchNormal::Init(GM_ADDR x, GM_ADD recvCountGT.SetGlobalBuffer((__gm__ int32_t *)(recv_count)); dynamicScalesOutGT.SetGlobalBuffer((__gm__ float *)dynamicScalesOut); expandIdxOutGT.SetGlobalBuffer((__gm__ int32_t *)(expandIdxOut)); + if (isEnableDiagnose) { + waitRecvCostStatsGT.SetGlobalBuffer((__gm__ int32_t *)waitRecvCostStatsOut); + } expandXOutGM = expandXOut; @@ -212,6 +224,7 @@ __aicore__ inline void CamMoeDispatchNormal::Init(GM_ADDR x, GM_ADD } endStatusId = startStatusId + statusNumPerCore; stateOffset = STATE_OFFSET; + srcRankOffset = startStatusId / moeExpertNumPerRank; DataCacheCleanAndInvalid(selfDataStatusTensor); dataState = selfDataStatusTensor(0); if (dataState == 0) { @@ -434,6 +447,20 @@ __aicore__ inline void CamMoeDispatchNormal::WaitStatus() tpipe_->InitBuffer(recvOffsetBuf, moeExpertNum * sizeof(int32_t)); // moeNum * 4B tpipe_->InitBuffer(recvCountBuf, moeExpertNum * sizeof(int32_t)); // moeNum * 4B + if (isEnableDiagnose) { + waitRecvCostStatsBufSize = Ceil(statusNumPerCore * sizeof(int32_t), UB_ALIGN) * UB_ALIGN; + tpipe_->InitBuffer(waitRecvCostStatsOutQueue, BUFFER_NUM, waitRecvCostStatsBufSize); + tpipe_->InitBuffer(recvStatusBuf, waitRecvCostStatsBufSize * 2); + + waitRecvCostStatsTensor = waitRecvCostStatsOutQueue.AllocTensor(); + recvStatusTensor1 = recvStatusBuf.GetWithOffset(waitRecvCostStatsBufSize, 0); + recvStatusTensor2 = recvStatusBuf.GetWithOffset(waitRecvCostStatsBufSize, waitRecvCostStatsBufSize); + + Duplicate(waitRecvCostStatsTensor, 0, waitRecvCostStatsBufSize / sizeof(int32_t)); + Duplicate(recvStatusTensor1, 0, 
waitRecvCostStatsBufSize / sizeof(float)); + Duplicate(recvStatusTensor2, 0, waitRecvCostStatsBufSize / sizeof(float)); + } + recvOffsetTensor = recvOffsetBuf.Get(); recvCountTensor = recvCountBuf.Get(); DataCopyExtParams recvOffsetParams = {1U, static_cast(moeExpertNum * sizeof(uint32_t)), 0U, 0U, 0U}; @@ -456,6 +483,12 @@ __aicore__ inline void CamMoeDispatchNormal::WaitStatus() float compareTarget = static_cast(1.0) * statusNumPerCore; float sumOfFlag = static_cast(-1.0); DataCopyParams intriParams{static_cast(statusNumPerCore), 1, 0, 0}; + + int64_t systemCycleStart = 0; + if (isEnableDiagnose) { + systemCycleStart = GetSystemCycle(); + } + SyncFunc(); while (sumOfFlag != compareTarget) { DataCopy(statusFp32Tensor, windowInstatusFp32Tensor[startStatusId * stateOffset / sizeof(float)], intriParams); @@ -463,6 +496,34 @@ __aicore__ inline void CamMoeDispatchNormal::WaitStatus() ReduceSum(statusSumOutTensor, statusFp32Tensor, gatherMaskOutTensor, mask, statusNumPerCore, 1); SyncFunc(); sumOfFlag = statusSumOutTensor.GetValue(0); + + if (isEnableDiagnose) { + int32_t durationTime = static_cast((GetSystemCycle() - systemCycleStart) / CYCLE_TO_TIME); // us + SyncFunc(); + int32_t repeatTimes = Ceil(statusNumPerCore, 8); // 8 is the num of blocks within one iteration + int mask2 = (statusNumPerCore > 8 ? 
8 : statusNumPerCore) * 8; // num of elements within one iteration + AscendC::BlockReduceSum(recvStatusTensor1, statusFp32Tensor, repeatTimes, mask2, 1, 1, 8); + SyncFunc(); + for (uint32_t i = 0; i < statusNumPerCore; ++i) { + if (recvStatusTensor1.GetValue(i) != recvStatusTensor2.GetValue(i)) { + int32_t srcRank = (i + startStatusId) / moeExpertNumPerRank - srcRankOffset; + int32_t preTime = waitRecvCostStatsTensor.GetValue(srcRank); + waitRecvCostStatsTensor.SetValue(srcRank, preTime + durationTime); + float preStatus = recvStatusTensor1.GetValue(i); + recvStatusTensor2.SetValue(i, preStatus); + } + } + } + } + + if (isEnableDiagnose) { + // copy waitRecvCostStats from UB to GM + SyncFunc(); + AscendC::SetAtomicAdd(); + DataCopyExtParams statsCopyOutParams = {1U, waitRecvCostStatsBufSize, 0U, 0U, 0U}; + DataCopyPad(waitRecvCostStatsGT[srcRankOffset], waitRecvCostStatsTensor, statsCopyOutParams); + AscendC::SetAtomicNone(); + waitRecvCostStatsOutQueue.FreeTensor(waitRecvCostStatsTensor); } // 清状态 diff --git a/csrc/deepep/ops/op_kernel/cam_moe_dispatch_normal_tiling.h b/csrc/deepep/ops/op_kernel/cam_moe_dispatch_normal_tiling.h index ab73e31d..7c91cc5d 100644 --- a/csrc/deepep/ops/op_kernel/cam_moe_dispatch_normal_tiling.h +++ b/csrc/deepep/ops/op_kernel/cam_moe_dispatch_normal_tiling.h @@ -14,6 +14,7 @@ struct CamMoeDispatchNormalInfo { uint32_t h; // h uint32_t aivNum; // aivNum bool isQuant; // whether quant or not + bool isEnableDiagnose; // whether enable diagnose or not bool reserved2; // reserved bool reserved3; // reserved uint64_t totalUbSize; // epWorldSize diff --git a/python/deep_ep/deep_ep/buffer.py b/python/deep_ep/deep_ep/buffer.py index 5846cac8..53767920 100644 --- a/python/deep_ep/deep_ep/buffer.py +++ b/python/deep_ep/deep_ep/buffer.py @@ -234,6 +234,7 @@ def dispatch( previous_event: Optional[EventOverlap] = None, async_finish: bool = False, allocate_on_comm_stream: bool = False, + dispatch_wait_recv_cost_stats: Optional[torch.Tensor] = None, ) 
+ dispatch_wait_recv_cost_stats: `[num_ranks]` with `torch.int`, records the time the dispatch phase + takes on the current rank to receive all tokens from each remote rank.
@@ -399,7 +406,7 @@ def combine( # Launch the kernel recv_x, recv_topk_weights, event = self.runtime.intranode_combine( - x, topk_idx, topk_weights_ori, src_idx, send_head + x, topk_idx, topk_weights_ori, src_idx, send_head, combine_send_cost_stats ) return recv_x, recv_topk_weights, EventOverlap(event) diff --git a/tests/python/deepep/test_intranode.py b/tests/python/deepep/test_intranode.py index a0c7f1da..1295f59b 100644 --- a/tests/python/deepep/test_intranode.py +++ b/tests/python/deepep/test_intranode.py @@ -1,11 +1,21 @@ import argparse import time +from typing import Optional # noinspection PyUnresolvedReferences import deep_ep +import numpy as np import torch import torch.distributed as dist -from utils import bench, calc_diff, init_dist, inplace_unique, per_token_cast_back +import torch_npu +from utils import ( + bench, + calc_diff, + diagnose_matrix, + init_dist, + inplace_unique, + per_token_cast_back, +) # noinspection PyShadowingNames @@ -20,6 +30,7 @@ def test_main( # Settings num_tokens, hidden = args.num_tokens, args.hidden num_topk, num_experts = args.num_topk, args.num_experts + enable_diagnose = args.enable_diagnose assert num_experts % num_ranks == 0 if local_rank == 0: @@ -165,6 +176,72 @@ def check_data(check_x, rank_prefix_matrix): assert (check_x[check_start:check_end, :].int() - i).sum().item() == 0 check_start = check_end + # Test diagnose function + # noinspection PyShadowingNames + def test_diagnose( + dispatch_wait_recv_cost_stats: Optional[torch.Tensor] = None, + combine_send_cost_stats: Optional[torch.Tensor] = None, + ): + for current_x in filter(lambda elem: elem is not None, (x_pure_rand,)): + dispatch_args = { + "x": current_x, + "num_tokens_per_rank": num_tokens_per_rank, + "is_token_in_rank": is_token_in_rank, + "num_tokens_per_expert": num_tokens_per_expert, + "config": config, + "topk_idx": topk_idx, + "topk_weights": topk_weights_pure_rand, + "dispatch_wait_recv_cost_stats": dispatch_wait_recv_cost_stats, + } + if 
dispatch_wait_recv_cost_stats is not None: + bench(lambda: buffer.dispatch(**dispatch_args), num_warmups=0) + if combine_send_cost_stats is not None: + ( + recv_x, + recv_topk_idx, + recv_topk_weights, + recv_num_tokens_per_expert_list, + handle, + event, + ) = buffer.dispatch(**dispatch_args) + recv_x = ( + per_token_cast_back(*recv_x) + if isinstance(recv_x, tuple) + else recv_x + ) + combine_args = { + "x": recv_x, + "handle": handle, + "topk_weights": handle[7], + "config": config, + "async_finish": False, + "combine_send_cost_stats": combine_send_cost_stats, + } + bench(lambda: buffer.combine(**combine_args), num_warmups=0) + for stats, title in ( + (dispatch_wait_recv_cost_stats, "Dispatch wait recv cost"), + (combine_send_cost_stats, "Combine send cost"), + ): + if stats is None: + continue + gather_list = ( + [torch.zeros_like(stats) for _ in range(group.size())] + if local_rank == 0 + else None + ) + dist.gather(stats, gather_list=gather_list, group=group, dst=0) + if local_rank == 0: + stats_mat = torch.stack(gather_list, dim=0) + print(f"{title} stats:") + print(stats_mat) + res = diagnose_matrix( + stats_mat, thres_col=1.0, thres_row=2.0, thres_point=5.0 + ) + print( + f"[Diagnose {title}] abnormal_rows {res['abnormal_rows']}, " + f"abnormal_cols {res['abnormal_cols']}, abnormal_points {res['abnormal_points']}" + ) + for current_x in filter(lambda elem: elem is not None, (x_pure_rand, x)): if local_rank == 0: print( @@ -286,6 +363,19 @@ def check_data(check_x, rank_prefix_matrix): ) print("", flush=True) + # Diagnose test + if enable_diagnose: + dispatch_wait_recv_cost_stats = torch.zeros( + (num_ranks,), dtype=torch.int32, device="npu" + ) + combine_send_cost_stats = torch.zeros( + (num_ranks,), dtype=torch.int32, device="npu" + ) + test_diagnose( + dispatch_wait_recv_cost_stats=dispatch_wait_recv_cost_stats, + combine_send_cost_stats=combine_send_cost_stats, + ) + # noinspection PyUnboundLocalVariable,PyShadowingNames def test_loop(local_rank: int, 
+ mat (np.ndarray): 2D array where mat[i, j] is the waiting time of source rank i for destination rank j to + receive (dispatch) / send (combine) the tokens + thres_col/thres_row/thres_point (float): threshold on the ratio of a rank's average waiting time + to the average waiting time over all ranks + suppress_points_in_strong_rowscols (bool): If True, exclude points already in detected abnormal + rows/columns. + Returns: + dict: { + "abnormal_cols": List[List[int, float, float]], # [col index, col mean, ratio] per abnormal column + "abnormal_rows": List[List[int, float, float]], # [row index, row mean, ratio] per abnormal row + "abnormal_points": List[List[int, int, float, float]] # [row, col, value, ratio] per abnormal point + }
+ # 4. Return structured results for automated downstream processing