Skip to content

Commit

Permalink
Added checks for input element types for inputs #0, #1, #3, #4, #5.
Browse files Browse the repository at this point in the history
  • Loading branch information
vgavrilo committed Oct 16, 2020
1 parent debdf36 commit 3d9d2e4
Showing 1 changed file with 123 additions and 85 deletions.
208 changes: 123 additions & 85 deletions ngraph/core/src/op/non_max_suppression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -676,91 +676,6 @@ shared_ptr<Node>
m_output_type);
}

void op::v5::NonMaxSuppression::validate()
{
const auto boxes_ps = get_input_partial_shape(0);
const auto scores_ps = get_input_partial_shape(1);

NODE_VALIDATION_CHECK(this,
m_output_type == element::i64 || m_output_type == element::i32,
"Output type must be i32 or i64");

if (boxes_ps.is_dynamic() || scores_ps.is_dynamic())
{
return;
}

NODE_VALIDATION_CHECK(this,
boxes_ps.rank().is_static() && boxes_ps.rank().get_length() == 3,
"Expected a 3D tensor for the 'boxes' input. Got: ",
boxes_ps);

NODE_VALIDATION_CHECK(this,
scores_ps.rank().is_static() && scores_ps.rank().get_length() == 3,
"Expected a 3D tensor for the 'scores' input. Got: ",
scores_ps);

const auto max_boxes_ps = get_input_partial_shape(2);
NODE_VALIDATION_CHECK(this,
max_boxes_ps.is_dynamic() || is_scalar(max_boxes_ps.to_shape()),
"Expected a scalar for the 'max_output_boxes_per_class' input. Got: ",
max_boxes_ps);

const auto iou_threshold_ps = get_input_partial_shape(3);
NODE_VALIDATION_CHECK(this,
iou_threshold_ps.is_dynamic() || is_scalar(iou_threshold_ps.to_shape()),
"Expected a scalar for the 'iou_threshold' input. Got: ",
iou_threshold_ps);

const auto score_threshold_ps = get_input_partial_shape(4);
NODE_VALIDATION_CHECK(this,
score_threshold_ps.is_dynamic() ||
is_scalar(score_threshold_ps.to_shape()),
"Expected a scalar for the 'score_threshold' input. Got: ",
score_threshold_ps);

const auto soft_nms_sigma = get_input_partial_shape(5);
NODE_VALIDATION_CHECK(this,
soft_nms_sigma.is_dynamic() || is_scalar(soft_nms_sigma.to_shape()),
"Expected a scalar for the 'soft_nms_sigma' input. Got: ",
soft_nms_sigma);

const auto num_batches_boxes = boxes_ps[0];
const auto num_batches_scores = scores_ps[0];
NODE_VALIDATION_CHECK(this,
num_batches_boxes.same_scheme(num_batches_scores),
"The first dimension of both 'boxes' and 'scores' must match. Boxes: ",
num_batches_boxes,
"; Scores: ",
num_batches_scores);

const auto num_boxes_boxes = boxes_ps[1];
const auto num_boxes_scores = scores_ps[2];
NODE_VALIDATION_CHECK(this,
num_boxes_boxes.same_scheme(num_boxes_scores),
"'boxes' and 'scores' input shapes must match at the second and third "
"dimension respectively. Boxes: ",
num_boxes_boxes,
"; Scores: ",
num_boxes_scores);

NODE_VALIDATION_CHECK(this,
boxes_ps[2].is_static() && boxes_ps[2].get_length() == 4u,
"The last dimension of the 'boxes' input must be equal to 4. Got:",
boxes_ps[2]);
}

int64_t op::v5::NonMaxSuppression::max_boxes_output_from_input() const
{
int64_t max_output_boxes{0};

const auto max_output_boxes_input =
as_type_ptr<op::Constant>(input_value(2).get_node_shared_ptr());
max_output_boxes = max_output_boxes_input->cast_vector<int64_t>().at(0);

return max_output_boxes;
}

using V5BoxEncoding = op::v5::NonMaxSuppression::BoxEncodingType;

namespace
Expand Down Expand Up @@ -958,6 +873,129 @@ namespace
*valid_outputs_ptr = static_cast<int32_t>(valid_outputs);
}
}

inline bool is_float_type_admissible(const element::Type& t)
{
return t == element::f32 || t == element::f16 || t == element::bf16;
}
}

void op::v5::NonMaxSuppression::validate()
{
const auto boxes_ps = get_input_partial_shape(0);
const auto scores_ps = get_input_partial_shape(1);

NODE_VALIDATION_CHECK(this,
m_output_type == element::i64 || m_output_type == element::i32,
"Output type must be i32 or i64");

if (boxes_ps.is_dynamic() || scores_ps.is_dynamic())
{
return;
}

NODE_VALIDATION_CHECK(this,
is_float_type_admissible(get_input_element_type(0)),
"Expected bf16, fp16 or fp32 as element type for the 'boxes' input.");

NODE_VALIDATION_CHECK(this,
is_float_type_admissible(get_input_element_type(1)),
"Expected bf16, fp16 or fp32 as element type for the 'scores' input.");

NODE_VALIDATION_CHECK(this,
boxes_ps.rank().is_static() && boxes_ps.rank().get_length() == 3,
"Expected a 3D tensor for the 'boxes' input. Got: ",
boxes_ps);

NODE_VALIDATION_CHECK(this,
scores_ps.rank().is_static() && scores_ps.rank().get_length() == 3,
"Expected a 3D tensor for the 'scores' input. Got: ",
scores_ps);

if (inputs().size() >= 3)
{
const auto max_boxes_ps = get_input_partial_shape(2);
NODE_VALIDATION_CHECK(this,
max_boxes_ps.is_dynamic() || is_scalar(max_boxes_ps.to_shape()),
"Expected a scalar for the 'max_output_boxes_per_class' input. Got: ",
max_boxes_ps);
}

if (inputs().size() >= 4)
{
const auto iou_threshold_ps = get_input_partial_shape(3);
NODE_VALIDATION_CHECK(this,
is_float_type_admissible(get_input_element_type(3)),
"Expected bf16, fp16 or fp32 as element type for the "
"'iou_threshold' input.");
NODE_VALIDATION_CHECK(this,
iou_threshold_ps.is_dynamic() ||
is_scalar(iou_threshold_ps.to_shape()),
"Expected a scalar for the 'iou_threshold' input. Got: ",
iou_threshold_ps);
}

if (inputs().size() >= 5)
{
const auto score_threshold_ps = get_input_partial_shape(4);
NODE_VALIDATION_CHECK(this,
is_float_type_admissible(get_input_element_type(4)),
"Expected bf16, fp16 or fp32 as element type for the "
"'score_threshold_ps' input.");
NODE_VALIDATION_CHECK(this,
score_threshold_ps.is_dynamic() ||
is_scalar(score_threshold_ps.to_shape()),
"Expected a scalar for the 'score_threshold' input. Got: ",
score_threshold_ps);
}

if (inputs().size() >= 6)
{
const auto soft_nms_sigma = get_input_partial_shape(5);
NODE_VALIDATION_CHECK(this,
is_float_type_admissible(get_input_element_type(5)),
"Expected bf16, fp16 or fp32 as element type for the "
"'soft_nms_sigma' input.");
NODE_VALIDATION_CHECK(this,
soft_nms_sigma.is_dynamic() || is_scalar(soft_nms_sigma.to_shape()),
"Expected a scalar for the 'soft_nms_sigma' input. Got: ",
soft_nms_sigma);
}

const auto num_batches_boxes = boxes_ps[0];
const auto num_batches_scores = scores_ps[0];
NODE_VALIDATION_CHECK(this,
num_batches_boxes.same_scheme(num_batches_scores),
"The first dimension of both 'boxes' and 'scores' must match. Boxes: ",
num_batches_boxes,
"; Scores: ",
num_batches_scores);

const auto num_boxes_boxes = boxes_ps[1];
const auto num_boxes_scores = scores_ps[2];
NODE_VALIDATION_CHECK(this,
num_boxes_boxes.same_scheme(num_boxes_scores),
"'boxes' and 'scores' input shapes must match at the second and third "
"dimension respectively. Boxes: ",
num_boxes_boxes,
"; Scores: ",
num_boxes_scores);

NODE_VALIDATION_CHECK(this,
boxes_ps[2].is_static() && boxes_ps[2].get_length() == 4u,
"The last dimension of the 'boxes' input must be equal to 4. Got:",
boxes_ps[2]);
}

int64_t op::v5::NonMaxSuppression::max_boxes_output_from_input() const
{
int64_t max_output_boxes{0};

const auto max_output_boxes_input =
as_type_ptr<op::Constant>(input_value(2).get_node_shared_ptr());
max_output_boxes = max_output_boxes_input->cast_vector<int64_t>().at(0);

return max_output_boxes;
}

float op::v5::NonMaxSuppression::iou_threshold_from_input() const
Expand Down

0 comments on commit 3d9d2e4

Please sign in to comment.