From fb096deae714b3daab0341a1660419af4fe08d13 Mon Sep 17 00:00:00 2001 From: jialipen Date: Thu, 24 Jun 2021 21:04:35 +0800 Subject: [PATCH] fixes: 1. normalized support. 2. sort by score before keep_top_k inside a batch. --- .../src/runtime/reference/multiclass_nms.cpp | 65 ++++++++++++------- 1 file changed, 42 insertions(+), 23 deletions(-) diff --git a/ngraph/core/reference/src/runtime/reference/multiclass_nms.cpp b/ngraph/core/reference/src/runtime/reference/multiclass_nms.cpp index 887167b299c516..ea092a77bed918 100644 --- a/ngraph/core/reference/src/runtime/reference/multiclass_nms.cpp +++ b/ngraph/core/reference/src/runtime/reference/multiclass_nms.cpp @@ -40,10 +40,12 @@ namespace ngraph float y2 = 0.0f; }; - static float intersectionOverUnion(const Rectangle& boxI, const Rectangle& boxJ) + static float intersectionOverUnion(const Rectangle& boxI, const Rectangle& boxJ, const bool normalized) { - float areaI = (boxI.y2 - boxI.y1) * (boxI.x2 - boxI.x1); - float areaJ = (boxJ.y2 - boxJ.y1) * (boxJ.x2 - boxJ.x1); + const float norm = static_cast(normalized == false); + + float areaI = (boxI.y2 - boxI.y1 + norm) * (boxI.x2 - boxI.x1 + norm); + float areaJ = (boxJ.y2 - boxJ.y1 + norm) * (boxJ.x2 - boxJ.x1 + norm); if (areaI <= 0.0f || areaJ <= 0.0f) { @@ -56,8 +58,8 @@ namespace ngraph float intersection_xmax = std::min(boxI.x2, boxJ.x2); float intersection_area = - std::max(intersection_ymax - intersection_ymin, 0.0f) * - std::max(intersection_xmax - intersection_xmin, 0.0f); + std::max(intersection_ymax - intersection_ymin + norm, 0.0f) * + std::max(intersection_xmax - intersection_xmin + norm, 0.0f); return intersection_area / (areaI + areaJ - intersection_area); } @@ -241,7 +243,7 @@ namespace ngraph continue; } - // sort by score + // sort by score in current class std::partial_sort(candidate_boxes.begin(), candidate_boxes.begin() + candiate_size, candidate_boxes.end(), @@ -270,7 +272,7 @@ namespace ngraph --j) { float iou = multiclass_nms_v8::intersectionOverUnion( - next_candidate.box, selected[j].box); + next_candidate.box, selected[j].box, normalized); next_candidate.score *= func(iou, adaptive_threshold); if (iou >= adaptive_threshold) @@ -312,22 +314,18 @@ namespace ngraph num_dets += selected.size(); } // for each class - /* sort inside batch element */ - if (sort_result_type == op::v8::MulticlassNms::SortResultType::SCORE) - { - std::sort(selected_boxes.begin(), - selected_boxes.end(), - [](const BoxInfo& l, const BoxInfo& r) { - return ( - (l.batch_index == r.batch_index) && - ((l.score > r.score) || - ((std::fabs(l.score - r.score) < 1e-6) && - l.class_index < r.class_index) || - ((std::fabs(l.score - r.score) < 1e-6) && - l.class_index == r.class_index && l.index < r.index))); - }); - } - // in case of "NONE" and "CLASSID", pass through + // sort inside batch element before go through keep_top_k + std::sort(selected_boxes.begin(), + selected_boxes.end(), + [](const BoxInfo& l, const BoxInfo& r) { + return ( + (l.batch_index == r.batch_index) && + ((l.score > r.score) || + ((std::fabs(l.score - r.score) < 1e-6) && + l.class_index < r.class_index) || + ((std::fabs(l.score - r.score) < 1e-6) && + l.class_index == r.class_index && l.index < r.index))); + }); // threshold keep_top_k for each batch element if (keep_top_k > -1 && keep_top_k < num_dets) @@ -336,6 +334,27 @@ namespace ngraph selected_boxes.resize(num_dets); } + // sort + if (!sort_result_across_batch) + { + if (sort_result_type == op::v8::MulticlassNms::SortResultType::CLASSID) + { + std::sort(selected_boxes.begin(), + selected_boxes.end(), + [](const BoxInfo& l, const BoxInfo& r) { + return ( + (l.batch_index == r.batch_index) && + ((l.class_index < r.class_index) || + ((l.class_index == r.class_index) && + l.score > r.score) || + ((std::fabs(l.score - r.score) <= 1e-6) && + l.class_index == r.class_index && l.index < r.index))); + }); + } + // in case of "SCORE", pass through, as, + // it has already gurranteed. + } + *valid_outputs++ = num_dets; for (auto& v : selected_boxes) {