Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fixes: 1. normalized support. 2. sort by score before keep_top_k insi… #8

Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 42 additions & 23 deletions ngraph/core/reference/src/runtime/reference/multiclass_nms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,12 @@ namespace ngraph
float y2 = 0.0f;
};

static float intersectionOverUnion(const Rectangle& boxI, const Rectangle& boxJ)
static float intersectionOverUnion(const Rectangle& boxI, const Rectangle& boxJ, const bool normalized)
{
float areaI = (boxI.y2 - boxI.y1) * (boxI.x2 - boxI.x1);
float areaJ = (boxJ.y2 - boxJ.y1) * (boxJ.x2 - boxJ.x1);
const float norm = static_cast<float>(normalized == false);

float areaI = (boxI.y2 - boxI.y1 + norm) * (boxI.x2 - boxI.x1 + norm);
float areaJ = (boxJ.y2 - boxJ.y1 + norm) * (boxJ.x2 - boxJ.x1 + norm);

if (areaI <= 0.0f || areaJ <= 0.0f)
{
Expand All @@ -56,8 +58,8 @@ namespace ngraph
float intersection_xmax = std::min(boxI.x2, boxJ.x2);

float intersection_area =
std::max(intersection_ymax - intersection_ymin, 0.0f) *
std::max(intersection_xmax - intersection_xmin, 0.0f);
std::max(intersection_ymax - intersection_ymin + norm, 0.0f) *
std::max(intersection_xmax - intersection_xmin + norm, 0.0f);

return intersection_area / (areaI + areaJ - intersection_area);
}
Expand Down Expand Up @@ -241,7 +243,7 @@ namespace ngraph
continue;
}

// sort by score
// sort by score in current class
std::partial_sort(candidate_boxes.begin(),
candidate_boxes.begin() + candiate_size,
candidate_boxes.end(),
Expand Down Expand Up @@ -270,7 +272,7 @@ namespace ngraph
--j)
{
float iou = multiclass_nms_v8::intersectionOverUnion(
next_candidate.box, selected[j].box);
next_candidate.box, selected[j].box, normalized);
next_candidate.score *= func(iou, adaptive_threshold);

if (iou >= adaptive_threshold)
Expand Down Expand Up @@ -312,22 +314,18 @@ namespace ngraph
num_dets += selected.size();
} // for each class

/* sort inside batch element */
if (sort_result_type == op::v8::MulticlassNms::SortResultType::SCORE)
{
std::sort(selected_boxes.begin(),
selected_boxes.end(),
[](const BoxInfo& l, const BoxInfo& r) {
return (
(l.batch_index == r.batch_index) &&
((l.score > r.score) ||
((std::fabs(l.score - r.score) < 1e-6) &&
l.class_index < r.class_index) ||
((std::fabs(l.score - r.score) < 1e-6) &&
l.class_index == r.class_index && l.index < r.index)));
});
}
// in case of "NONE" and "CLASSID", pass through
// sort inside batch element before go through keep_top_k
std::sort(selected_boxes.begin(),
selected_boxes.end(),
[](const BoxInfo& l, const BoxInfo& r) {
return (
(l.batch_index == r.batch_index) &&
((l.score > r.score) ||
((std::fabs(l.score - r.score) < 1e-6) &&
l.class_index < r.class_index) ||
((std::fabs(l.score - r.score) < 1e-6) &&
l.class_index == r.class_index && l.index < r.index)));
});

// threshold keep_top_k for each batch element
if (keep_top_k > -1 && keep_top_k < num_dets)
Expand All @@ -336,6 +334,27 @@ namespace ngraph
selected_boxes.resize(num_dets);
}

// sort
if (!sort_result_across_batch)
{
if (sort_result_type == op::v8::MulticlassNms::SortResultType::CLASSID)
{
std::sort(selected_boxes.begin(),
selected_boxes.end(),
[](const BoxInfo& l, const BoxInfo& r) {
return (
(l.batch_index == r.batch_index) &&
((l.class_index < r.class_index) ||
((l.class_index == r.class_index) &&
l.score > r.score) ||
((std::fabs(l.score - r.score) <= 1e-6) &&
l.class_index == r.class_index && l.index < r.index)));
});
}
// in case of "SCORE", pass through, as,
// it has already gurranteed.
}

*valid_outputs++ = num_dets;
for (auto& v : selected_boxes)
{
Expand Down