Commit ca5d42f

Add diagnostic modules to dispatch and combine

1 parent 32d5437 · commit ca5d42f

19 files changed, +445 -69 lines changed

csrc/deepep/deep_ep.cpp

Lines changed: 24 additions & 6 deletions
@@ -97,7 +97,8 @@ Buffer::intranode_dispatch(const at::Tensor &x, const std::optional<at::Tensor>
     const std::optional<at::Tensor> &num_tokens_per_rank, const at::Tensor &is_token_in_rank,
     const std::optional<at::Tensor> &num_tokens_per_expert, int cached_num_recv_tokens,
     const std::optional<at::Tensor> &cached_rank_prefix_matrix,
-    const std::optional<at::Tensor> &cached_channel_prefix_matrix, int expert_alignment,
+    const std::optional<at::Tensor> &cached_channel_prefix_matrix,
+    const std::optional<at::Tensor> &dispatch_wait_recv_cost_stats, int expert_alignment,
     int num_worst_tokens, const Config &config, std::optional<EventHandle> &previous_event,
     bool async, bool allocate_on_comm_stream, bool use_quant)
 {
@@ -172,6 +173,14 @@ Buffer::intranode_dispatch(const at::Tensor &x, const std::optional<at::Tensor>
         scale_hidden_stride = static_cast<int>(x_scales->stride(1));
     }
 
+    at::Tensor dispatch_wait_recv_cost_stats_out;
+    if (dispatch_wait_recv_cost_stats.has_value()) {
+        EP_HOST_ASSERT(dispatch_wait_recv_cost_stats->scalar_type() == torch::kInt32);
+        EP_HOST_ASSERT(dispatch_wait_recv_cost_stats->dim() == 1 and dispatch_wait_recv_cost_stats->is_contiguous());
+        EP_HOST_ASSERT(dispatch_wait_recv_cost_stats->size(0) == num_ranks);
+        dispatch_wait_recv_cost_stats_out = dispatch_wait_recv_cost_stats.value();
+    }
+
     int send_per_group = 3; // (send_to_expert_num, send_to_expert_offset, send_rank_tokens)
 
     auto send_data = at::zeros({num_experts * send_per_group}, at::dtype(at::kInt).device(x.device()));
@@ -270,7 +279,7 @@ Buffer::intranode_dispatch(const at::Tensor &x, const std::optional<at::Tensor>
         num_ranks, // rankSize
         rank,      // rankId
         hcom_ep_name, tp_size, tp_rank, num_experts, quant_mode, global_bs, expandx_out, dynamic_scales_out,
-        expand_idx_out);
+        expand_idx_out, dispatch_wait_recv_cost_stats_out);
 
     auto recv_topk_idx = std::optional<at::Tensor>();
     auto recv_topk_weights = std::optional<at::Tensor>();
@@ -305,9 +314,10 @@ void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int
     return;
 }
 
-std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>> Buffer::intranode_combine(
-    const torch::Tensor &x, const torch::Tensor &topk_idx, const std::optional<torch::Tensor> &topk_weights,
-    const torch::Tensor &src_idx, const torch::Tensor &send_head)
+std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>>
+Buffer::intranode_combine(const torch::Tensor &x, const torch::Tensor &topk_idx,
+                          const std::optional<torch::Tensor> &topk_weights, const torch::Tensor &src_idx,
+                          const torch::Tensor &send_head, const std::optional<at::Tensor> &combine_send_cost_stats)
 {
     EP_HOST_ASSERT(x.dim() == 2 and x.is_contiguous());
     at::Tensor recv_x = x;
@@ -345,6 +355,14 @@ std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandl
         expert_scales = at::ones({num_tokens, num_topk}, at::dtype(at::kFloat).device(device));
     }
 
+    at::Tensor combine_send_cost_stats_out;
+    if (combine_send_cost_stats.has_value()) {
+        EP_HOST_ASSERT(combine_send_cost_stats->scalar_type() == torch::kInt32);
+        EP_HOST_ASSERT(combine_send_cost_stats->dim() == 1 and combine_send_cost_stats->is_contiguous());
+        EP_HOST_ASSERT(combine_send_cost_stats->size(0) == num_ranks);
+        combine_send_cost_stats_out = combine_send_cost_stats.value();
+    }
+
     int64_t hidden = static_cast<int>(recv_x.size(1));
     at::Tensor tp_send_counts = at::empty({1}, at::dtype(at::kInt).device(device));
     int64_t tp_world_size = 1;
@@ -367,7 +385,7 @@ std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandl
 
     EXEC_NPU_CMD(aclnnCamMoeCombineNormal, recv_x, token_src_info, ep_send_counts, expert_scales, tp_send_counts,
                  hcom_ep_name, num_ranks, rank, hcom_ep_name, tp_world_size, tp_rankId, moe_expert_number, global_bs,
-                 combined_x);
+                 combined_x, combine_send_cost_stats_out);
 
     if (this->is_padding) {
         if (this->padding_cnt == PADDING_SIZE) {
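
Both new parameters follow the same host-side contract: the caller pre-allocates an int32, 1-D, contiguous tensor with one slot per EP rank, or passes an empty optional to leave diagnostics off (in which case `*_stats_out` stays a default-constructed tensor, which by inference leaves the op's optional output unbound; the commit does not spell that out). A minimal caller-side sketch, where `device` and `num_ranks` are stand-ins for the caller's real values:

    // Pre-allocate per-rank diagnostic buffers exactly as the new
    // EP_HOST_ASSERTs require: int32, 1-D, contiguous, size == num_ranks.
    auto opts = at::dtype(at::kInt).device(device);
    std::optional<at::Tensor> dispatch_wait_recv_cost_stats = at::zeros({num_ranks}, opts);
    std::optional<at::Tensor> combine_send_cost_stats = at::zeros({num_ranks}, opts);
    // Passing std::nullopt instead keeps the pre-commit behavior.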

csrc/deepep/deep_ep.hpp

Lines changed: 6 additions & 4 deletions
@@ -58,15 +58,17 @@ struct Buffer {
         const std::optional<at::Tensor> &num_tokens_per_rank, const at::Tensor &is_token_in_rank,
         const std::optional<at::Tensor> &num_tokens_per_expert, int cached_num_recv_tokens,
         const std::optional<at::Tensor> &cached_rank_prefix_matrix,
-        const std::optional<at::Tensor> &cached_channel_prefix_matrix, int expert_alignment,
+        const std::optional<at::Tensor> &cached_channel_prefix_matrix,
+        const std::optional<at::Tensor> &dispatch_wait_recv_cost_stats, int expert_alignment,
         int num_worst_tokens, const Config &config, std::optional<EventHandle> &previous_event,
         bool async, bool allocate_on_comm_stream, bool use_quant);
 
     void clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts);
 
-    std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>> intranode_combine(
-        const torch::Tensor &x, const torch::Tensor &topk_idx, const std::optional<torch::Tensor> &topk_weights,
-        const torch::Tensor &src_idx, const torch::Tensor &send_head);
+    std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>>
+    intranode_combine(const torch::Tensor &x, const torch::Tensor &topk_idx,
+                      const std::optional<torch::Tensor> &topk_weights, const torch::Tensor &src_idx,
+                      const torch::Tensor &send_head, const std::optional<at::Tensor> &combine_send_cost_stats);
 
     std::tuple<at::Tensor, std::optional<at::Tensor>, at::Tensor, at::Tensor, at::Tensor, std::optional<EventHandle>,
                std::optional<std::function<void()>>>
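
A hypothetical call site against the extended declaration (all variables below are illustrative, not from this commit); only the trailing argument is new:

    // Bind a stats tensor to opt in to diagnostics; std::nullopt opts out.
    std::optional<at::Tensor> send_cost_stats =
        at::zeros({num_ranks}, at::dtype(at::kInt).device(device));
    auto [combined_x, combined_weights, event] =
        buffer.intranode_combine(x, topk_idx, topk_weights, src_idx, send_head, send_cost_stats);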

csrc/deepep/ops/op_host/cam_moe_combine_normal.cpp

Lines changed: 6 additions & 0 deletions
@@ -43,6 +43,12 @@ class CamMoeCombineNormal : public OpDef
             .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
             .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
 
+        this->Output("combine_send_cost_stats")
+            .ParamType(OPTIONAL)
+            .DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32})
+            .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
+            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
+
         this->Attr("ep_group_name").AttrType(REQUIRED).String();
         this->Attr("ep_world_size").AttrType(REQUIRED).Int();
         this->Attr("ep_rank_id").AttrType(REQUIRED).Int();

csrc/deepep/ops/op_host/cam_moe_combine_normal_tiling.cc

Lines changed: 44 additions & 12 deletions
@@ -54,6 +54,7 @@ constexpr uint32_t EP_RECV_COUNTS_INDEX = 2;
 constexpr uint32_t TOPK_WEIGHTS_INDEX = 3;
 constexpr uint32_t TP_RECV_COUNTS_INDEX = 4;
 constexpr uint32_t OUTPUT_X_INDEX = 0;
+constexpr uint32_t OUTPUT_SEND_COST_INDEX = 1;
 
 constexpr uint32_t ATTR_GROUP_EP_INDEX = 0;
 constexpr uint32_t ATTR_EP_WORLD_SIZE_INDEX = 1;
@@ -238,7 +239,7 @@ static bool CheckOptionalInputTensorDim(gert::TilingContext *context, const char
     return true;
 }
 
-static bool CheckOutputTensorDim(gert::TilingContext *context, const char *nodeName)
+static bool CheckOutputTensorDim(gert::TilingContext *context, const char *nodeName, const bool isEnableDiagnose)
 {
     const gert::StorageShape *xStorageShape = context->GetOutputShape(OUTPUT_X_INDEX);
     OP_TILING_CHECK(xStorageShape == nullptr, OP_LOGE(nodeName, "x is null."), return false);
@@ -249,25 +250,34 @@ static bool CheckOutputTensorDim(gert::TilingContext *context, const char *nodeN
     OP_LOGD(nodeName, "x dim0 = %ld", xStorageShape->GetStorageShape().GetDim(0));
     OP_LOGD(nodeName, "x dim1 = %ld", xStorageShape->GetStorageShape().GetDim(1));
 
+    if (isEnableDiagnose) {
+        const gert::StorageShape *sendCostStatsStorageShape = context->GetOutputShape(OUTPUT_SEND_COST_INDEX);
+        OP_TILING_CHECK(sendCostStatsStorageShape == nullptr, OP_LOGE(nodeName, "combine sendCostStatsShape is null."),
+                        return false);
+        OP_TILING_CHECK(sendCostStatsStorageShape->GetStorageShape().GetDimNum() != ONE_DIM,
+                        OP_LOGE(nodeName, "combine sendCostStatsShape must be 1-dimension, but got %lu dim",
+                                sendCostStatsStorageShape->GetStorageShape().GetDimNum()),
+                        return false);
+    }
     return true;
 }
 
-static bool CheckTensorDim(gert::TilingContext *context, const char *nodeName)
+static bool CheckTensorDim(gert::TilingContext *context, const char *nodeName, const bool isEnableDiagnose)
 {
     OP_TILING_CHECK(!CheckInputTensorDim(context, nodeName),
                     OP_LOGE(nodeName, "param shape of input tensor is invalid"), return false);
 
     OP_TILING_CHECK(!CheckOptionalInputTensorDim(context, nodeName),
                     OP_LOGE(nodeName, "param shape of optional input tensor is invalid"), return false);
 
-    OP_TILING_CHECK(!CheckOutputTensorDim(context, nodeName),
+    OP_TILING_CHECK(!CheckOutputTensorDim(context, nodeName, isEnableDiagnose),
                     OP_LOGE(nodeName, "param shape of output tensor is invalid"), return false);
 
     return true;
 }
 
 // Validate data types
-static bool CheckTensorDataType(gert::TilingContext *context, const char *nodeName)
+static bool CheckTensorDataType(gert::TilingContext *context, const char *nodeName, const bool isEnableDiagnose)
 {
     auto recvXDesc = context->GetInputDesc(RECV_X_INDEX);
     OP_TILING_CHECK(recvXDesc == nullptr, OP_LOGE(nodeName, "recvXDesc is null."), return false);
@@ -296,10 +306,20 @@ static bool CheckTensorDataType(gert::TilingContext *context, const char *nodeNa
     OP_TILING_CHECK((xDesc->GetDataType() != recvXDesc->GetDataType()),
                     OP_LOGE(nodeName, "x dataType is invalid, dataType should be equal to recvX dataType , but is "),
                     return false);
+
+    if (isEnableDiagnose) {
+        auto sendCostStatsDesc = context->GetOutputDesc(OUTPUT_SEND_COST_INDEX);
+        OP_TILING_CHECK(sendCostStatsDesc == nullptr, OP_LOGE(nodeName, "combine sendCostStatsDesc is null."),
+                        return false);
+        OP_TILING_CHECK(
+            sendCostStatsDesc->GetDataType() != ge::DT_INT32,
+            OP_LOGE(nodeName, "combine sendCostStatsDesc dataType is invalid, dataType should be int32, but is ."),
+            return false);
+    }
     return true;
 }
 
-static bool CheckTensorFormat(gert::TilingContext *context, const char *nodeName)
+static bool CheckTensorFormat(gert::TilingContext *context, const char *nodeName, const bool isEnableDiagnose)
 {
     auto recvXDesc = context->GetInputDesc(RECV_X_INDEX);
     OP_TILING_CHECK(recvXDesc == nullptr, OP_LOGE(nodeName, "recvXDesc is null."), return false);
@@ -330,6 +350,14 @@ static bool CheckTensorFormat(gert::TilingContext *context, const char *nodeName
     OP_TILING_CHECK(static_cast<ge::Format>(ge::GetPrimaryFormat(xDesc->GetStorageFormat())) == ge::FORMAT_FRACTAL_NZ,
                     OP_LOGE(nodeName, "xFormat is invalid"), return false);
 
+    if (isEnableDiagnose) {
+        auto sendCostStatsDesc = context->GetOutputDesc(OUTPUT_SEND_COST_INDEX);
+        OP_TILING_CHECK(sendCostStatsDesc == nullptr, OP_LOGE(nodeName, "combine sendCostStatsDesc is null."),
+                        return false);
+        OP_TILING_CHECK(static_cast<ge::Format>(ge::GetPrimaryFormat(sendCostStatsDesc->GetStorageFormat())) ==
+                            ge::FORMAT_FRACTAL_NZ,
+                        OP_LOGE(nodeName, "combine sendCostStatsDesc format is invalid"), return false);
+    }
     return true;
 }
 
@@ -435,17 +463,18 @@ static bool CheckAttrs(gert::TilingContext *context, CamMoeCombineNormalTilingDa
     return true;
 }
 
-static ge::graphStatus TilingCheckCamMoeCombineNormal(gert::TilingContext *context, const char *nodeName)
+static ge::graphStatus TilingCheckCamMoeCombineNormal(gert::TilingContext *context, const char *nodeName,
+                                                      const bool isEnableDiagnose)
 {
     // Check param shape info
-    OP_TILING_CHECK(!CheckTensorDim(context, nodeName), OP_LOGE(nodeName, "param shape is invalid"),
+    OP_TILING_CHECK(!CheckTensorDim(context, nodeName, isEnableDiagnose), OP_LOGE(nodeName, "param shape is invalid"),
                     return ge::GRAPH_FAILED);
     // Check param dataType info
-    OP_TILING_CHECK(!CheckTensorDataType(context, nodeName), OP_LOGE(nodeName, "param dataType is invalid"),
-                    return ge::GRAPH_FAILED);
+    OP_TILING_CHECK(!CheckTensorDataType(context, nodeName, isEnableDiagnose),
+                    OP_LOGE(nodeName, "param dataType is invalid"), return ge::GRAPH_FAILED);
     // Check param format info
-    OP_TILING_CHECK(!CheckTensorFormat(context, nodeName), OP_LOGE(nodeName, "param Format is invalid"),
-                    return ge::GRAPH_FAILED);
+    OP_TILING_CHECK(!CheckTensorFormat(context, nodeName, isEnableDiagnose),
+                    OP_LOGE(nodeName, "param Format is invalid"), return ge::GRAPH_FAILED);
     return ge::GRAPH_SUCCESS;
 }
 
@@ -493,8 +522,11 @@ static ge::graphStatus CamMoeCombineNormalA3TilingFuncImpl(gert::TilingContext *
     OP_TILING_CHECK(GetAttrAndSetTilingData(context, *tilingData, nodeName, groupEp, groupTp) == ge::GRAPH_FAILED,
                     OP_LOGE(nodeName, "Getting attr failed."), return ge::GRAPH_FAILED);
 
+    auto sendCostStatsStorageShape = context->GetOutputShape(OUTPUT_SEND_COST_INDEX);
+    bool isEnableDiagnose = (sendCostStatsStorageShape != nullptr);
+    tilingData->camMoeCombineNormalInfo.isEnableDiagnose = isEnableDiagnose;
     // Check dim, format, and dataType of inputs and outputs
-    OP_TILING_CHECK(TilingCheckCamMoeCombineNormal(context, nodeName) != ge::GRAPH_SUCCESS,
+    OP_TILING_CHECK(TilingCheckCamMoeCombineNormal(context, nodeName, isEnableDiagnose) != ge::GRAPH_SUCCESS,
                     OP_LOGE(nodeName, "Tiling check params failed"), return ge::GRAPH_FAILED);
 
     // Check that attribute values are valid
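
Note that the diagnose switch is not an op attribute: tiling infers it from whether the caller bound the optional output, since `GetOutputShape` returns nullptr for an absent optional. A condensed restatement of the idiom used in `CamMoeCombineNormalA3TilingFuncImpl` above (a sketch, not the literal tiling code):

    // Diagnostics are enabled iff the optional output exists at tiling time.
    static bool IsDiagnoseEnabled(gert::TilingContext *context)
    {
        return context->GetOutputShape(OUTPUT_SEND_COST_INDEX) != nullptr;
    }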

csrc/deepep/ops/op_host/cam_moe_dispatch_normal.cpp

Lines changed: 6 additions & 0 deletions
@@ -62,6 +62,12 @@ class CamMoeDispatchNormal : public OpDef
             .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
             .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
 
+        this->Output("dispatch_wait_recv_cost_stats")
+            .ParamType(OPTIONAL)
+            .DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32})
+            .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
+            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
+
         this->Attr("group_ep").AttrType(REQUIRED).String();
         this->Attr("ep_world_size").AttrType(REQUIRED).Int();
         this->Attr("ep_rank_id").AttrType(REQUIRED).Int();
