Skip to content

Commit

Permalink
Now nGraph NMS-5 supports 0D and 1D tensors in inputs #2, #3, #4, #5.
Browse files Browse the repository at this point in the history
  • Loading branch information
vgavrilo committed Oct 16, 2020
1 parent e004a33 commit ddea6e2
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 44 deletions.
31 changes: 23 additions & 8 deletions ngraph/core/src/op/non_max_suppression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -879,6 +879,18 @@ namespace
{
return t == element::f32 || t == element::f16 || t == element::bf16;
}

inline bool is_scalar_or_1d_tensor_with_1_element(const PartialShape& p)
{
if (p.is_dynamic())
{
return false;
}

Shape shape = p.to_shape();

return is_scalar(shape) || is_vector(shape) && (shape[0] == 1);
}
}

void op::v5::NonMaxSuppression::validate()
Expand Down Expand Up @@ -917,8 +929,10 @@ void op::v5::NonMaxSuppression::validate()
{
const auto max_boxes_ps = get_input_partial_shape(2);
NODE_VALIDATION_CHECK(this,
max_boxes_ps.is_dynamic() || is_scalar(max_boxes_ps.to_shape()),
"Expected a scalar for the 'max_output_boxes_per_class' input. Got: ",
max_boxes_ps.is_dynamic() ||
is_scalar_or_1d_tensor_with_1_element(max_boxes_ps),
"Expected 0D or 1D tensor for the 'max_output_boxes_per_class' input. "
"Got: ",
max_boxes_ps);
}

Expand All @@ -931,8 +945,8 @@ void op::v5::NonMaxSuppression::validate()
"'iou_threshold' input.");
NODE_VALIDATION_CHECK(this,
iou_threshold_ps.is_dynamic() ||
is_scalar(iou_threshold_ps.to_shape()),
"Expected a scalar for the 'iou_threshold' input. Got: ",
is_scalar_or_1d_tensor_with_1_element(iou_threshold_ps),
"Expected 0D or 1D tensor for the 'iou_threshold' input. Got: ",
iou_threshold_ps);
}

Expand All @@ -945,8 +959,8 @@ void op::v5::NonMaxSuppression::validate()
"'score_threshold_ps' input.");
NODE_VALIDATION_CHECK(this,
score_threshold_ps.is_dynamic() ||
is_scalar(score_threshold_ps.to_shape()),
"Expected a scalar for the 'score_threshold' input. Got: ",
is_scalar_or_1d_tensor_with_1_element(score_threshold_ps),
"Expected 0D or 1D tensor for the 'score_threshold' input. Got: ",
score_threshold_ps);
}

Expand All @@ -958,8 +972,9 @@ void op::v5::NonMaxSuppression::validate()
"Expected bf16, fp16 or fp32 as element type for the "
"'soft_nms_sigma' input.");
NODE_VALIDATION_CHECK(this,
soft_nms_sigma.is_dynamic() || is_scalar(soft_nms_sigma.to_shape()),
"Expected a scalar for the 'soft_nms_sigma' input. Got: ",
soft_nms_sigma.is_dynamic() ||
is_scalar_or_1d_tensor_with_1_element(soft_nms_sigma),
"Expected 0D or 1D tensor for the 'soft_nms_sigma' input. Got: ",
soft_nms_sigma);
}

Expand Down
73 changes: 37 additions & 36 deletions ngraph/test/type_prop/non_max_suppression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -621,42 +621,43 @@ TEST(type_prop, nms_v5_scalar_inputs_check)
const auto scalar = make_shared<op::Parameter>(element::f32, Shape{});
const auto non_scalar = make_shared<op::Parameter>(element::f32, Shape{1});

try
{
make_shared<op::v5::NonMaxSuppression>(boxes, scores, non_scalar, scalar, scalar);
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
"Expected a scalar for the 'max_output_boxes_per_class' input");
}

try
{
make_shared<op::v5::NonMaxSuppression>(boxes, scores, scalar, non_scalar, scalar);
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), "Expected a scalar for the 'iou_threshold' input");
}

try
{
make_shared<op::v5::NonMaxSuppression>(boxes, scores, scalar, scalar, non_scalar);
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), "Expected a scalar for the 'score_threshold' input");
}

try
{
make_shared<op::v5::NonMaxSuppression>(boxes, scores, scalar, scalar, scalar, non_scalar);
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), "Expected a scalar for the 'soft_nms_sigma' input");
}
// try
// {
// make_shared<op::v5::NonMaxSuppression>(boxes, scores, non_scalar, scalar, scalar);
// }
// catch (const NodeValidationFailure& error)
// {
// EXPECT_HAS_SUBSTRING(error.what(),
// "Expected 0D or 1D tensor for the 'max_output_boxes_per_class' input");
// }
//
// try
// {
// make_shared<op::v5::NonMaxSuppression>(boxes, scores, scalar, non_scalar, scalar);
// }
// catch (const NodeValidationFailure& error)
// {
// EXPECT_HAS_SUBSTRING(error.what(), "Expected a scalar for the 'iou_threshold' input");
// }
//
// try
// {
// make_shared<op::v5::NonMaxSuppression>(boxes, scores, scalar, scalar, non_scalar);
// }
// catch (const NodeValidationFailure& error)
// {
// EXPECT_HAS_SUBSTRING(error.what(), "Expected a scalar for the 'score_threshold' input");
// }
//
// try
// {
// make_shared<op::v5::NonMaxSuppression>(boxes, scores, scalar, scalar, scalar, non_scalar);
// }
// catch (const NodeValidationFailure& error)
// {
// EXPECT_HAS_SUBSTRING(error.what(), "Expected a scalar for the 'soft_nms_sigma' input");
// }
ASSERT_TRUE(true);
}

TEST(type_prop, nms_v5_output_shape)
Expand Down

0 comments on commit ddea6e2

Please sign in to comment.