Updated Mul->Add conversion to support dynamic shapes (#512)
* Updated Mul Add conversion to support dynamic shapes

* Keep changes

* Fix for cases when eltwise performs broadcasting via Constant

* Added comments; fixed eltwise shape infer; updated tests
GlebKazantaev authored May 26, 2020
1 parent e835a4c commit d3764a7
Showing 8 changed files with 410 additions and 49 deletions.
@@ -35,5 +35,13 @@ enum class CONVERSION_RESULT {
     NONE
 };
 
+/*
+ * check_constant checks how the given constant performs an elementwise operation with the given input shape.
+ * CONVERSION_RESULT has several types:
+ *   SCALE_SHIFT - the constant applies per-channel only
+ *   POWER - the constant applies as a single value
+ *   NONE - default return value
+ */
+
 INFERENCE_ENGINE_API_CPP(CONVERSION_RESULT)
-check_constant(const std::shared_ptr<ngraph::op::Constant> & constant, const ngraph::Shape & shape);
+check_constant(const std::shared_ptr<ngraph::op::Constant> & constant, const ngraph::PartialShape & shape);
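
To make the new contract concrete, here is a minimal usage sketch. This is an editor's illustration, not part of the commit: it assumes the header above is on the include path, and the expected results are inferred from the implementation further down.

#include <ngraph/ngraph.hpp>
// ... plus the header above, which declares check_constant.

int main() {
    using namespace ngraph;

    // NCHW-like input with a static rank but dynamic spatial dims: {1, 64, ?, ?}
    PartialShape data_shape{1, 64, Dimension::dynamic(), Dimension::dynamic()};

    // A scalar constant applies as a single value -> expected POWER.
    auto scalar = opset1::Constant::create(element::f32, Shape{}, {2.0f});

    // A per-channel constant matching the static channel dim -> expected SCALE_SHIFT.
    auto per_channel = opset1::Constant::create(
        element::f32, Shape{1, 64, 1, 1}, std::vector<float>(64, 0.5f));

    auto r1 = check_constant(scalar, data_shape);       // CONVERSION_RESULT::POWER
    auto r2 = check_constant(per_channel, data_shape);  // CONVERSION_RESULT::SCALE_SHIFT

    // A fully dynamic rank now short-circuits to NONE.
    auto r3 = check_constant(per_channel, PartialShape::dynamic());  // CONVERSION_RESULT::NONE

    (void)r1; (void)r2; (void)r3;
    return 0;
}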
@@ -70,10 +70,13 @@ ngraph::graph_rewrite_callback get_callback() {
                   "Unsupported template parameter. Only Add or Multiply allowed!");
 
     auto lin_op = std::dynamic_pointer_cast<T> (m.get_match_root());
-    if (!lin_op) {
+    if (!lin_op || lin_op->output(0).get_partial_shape().rank().is_dynamic()) {
         return false;
     }
 
+    const auto output_shape = lin_op->output(0).get_partial_shape();
+    const auto output_shape_rank = output_shape.rank().get_length();
+
     if (!lin_op->get_element_type().is_real()) {
         return convert_to_eltwise<T>(lin_op,
                                      lin_op->input(0).get_source_output(),
@@ -93,39 +96,58 @@ ngraph::graph_rewrite_callback get_callback() {
         }
     }
 
-    // Check that eltwise is not useless otherwise we remove it
-    if ((std::is_same<T, ngraph::opset1::Add>() && ngraph::op::util::constantIsEqualTo(const_node, 0)) ||
-        (std::is_same<T, ngraph::opset1::Multiply>() && ngraph::op::util::constantIsEqualTo(const_node, 1))) {
-        bool has_result_output = false;
-        for (const auto & output : lin_op->output(0).get_target_inputs()) {
-            if (dynamic_cast<ngraph::op::Result*>(output.get_node())) {
-                has_result_output = true;
-            }
-        }
-
-        auto parent = data_node.get_node_shared_ptr();
-        size_t consumers_count = 0;
-        for (const auto &output : parent->outputs()) {
-            consumers_count += output.get_target_inputs().size();
-        }
-
-        if (!has_result_output || consumers_count == 1) {
-            if (!std::dynamic_pointer_cast<ngraph::op::Parameter>(parent)) {
-                parent->set_friendly_name(lin_op->get_friendly_name());
-            }
-            // TODO: due to ngraph::replace_node function limitations we have to reconnect output port consumers to the new input
-            // using replace_source_output method
-            for (auto &input : lin_op->output(0).get_target_inputs()) {
-                input.replace_source_output(data_node);
-            }
-            return false;
-        }
-    }
+    /* This lambda checks the data and constant shapes for broadcasting.
+       For example:
+       1. data_shape{1, 64, 64} and const_shape{64, 1, 1} - the constant broadcasts dimension 0 of data_shape
+       2. data_shape{DYN, 64, 64} and const_shape{1, 1, 64} - the constant does not broadcast data_shape
+       3. data_shape{64, 64} and const_shape{1, 1, 1} - the constant broadcasts data_shape with an additional dimension
+    */
+    auto constant_broadcast_output = [](const ngraph::PartialShape & data_pshape, const ngraph::Shape & const_shape) -> bool {
+        if (data_pshape.rank().is_dynamic() || const_shape.size() > data_pshape.rank().get_length()) {
+            return true;
+        }
+
+        std::vector<ngraph::Dimension> data_shape(data_pshape);
+
+        auto const_shape_it = const_shape.rbegin();
+        auto data_shape_it = data_shape.rbegin();
+
+        while (const_shape_it != const_shape.rend()) {
+            auto data_dim = *data_shape_it;
+            auto const_dim = *const_shape_it;
+
+            /* DATA DIM - CONST DIM - CONSTANT BROADCASTS OUTPUT
+               DYN      - 64        - TRUE
+               DYN      - 1         - FALSE
+               64       - 1         - FALSE
+               1        - 64        - TRUE
+               64       - 64        - FALSE
+            */
+            if ((data_dim.is_dynamic() && const_dim != 1) ||
+                (data_dim.is_static() && data_dim.get_length() == 1 && const_dim != 1)) {
+                return true;
+            }
+
+            ++const_shape_it;
+            ++data_shape_it;
+        }
+
+        return false;
+    };
+
+    // Check that the eltwise is not useless and does not broadcast its output; otherwise we remove it
+    if (((std::is_same<T, ngraph::opset1::Add>() && ngraph::op::util::constantIsEqualTo(const_node, 0)) ||
+         (std::is_same<T, ngraph::opset1::Multiply>() && ngraph::op::util::constantIsEqualTo(const_node, 1))) &&
+        !constant_broadcast_output(data_node.get_partial_shape(), const_node->get_shape())) {
+        bool ret_status = ngraph::replace_output_update_name(lin_op->output(0), data_node);
+        if (ret_status) {
+            return true;
+        }
+    }
 
-    auto res = check_constant(const_node, data_node.get_shape());
+    auto res = check_constant(const_node, data_node.get_partial_shape());
 
-    if (res == CONVERSION_RESULT::NONE || (res == CONVERSION_RESULT::SCALE_SHIFT && lin_op->get_shape().size() < 4)) {
+    if (res == CONVERSION_RESULT::NONE || (res == CONVERSION_RESULT::SCALE_SHIFT && output_shape_rank < 4)) {
         return convert_to_eltwise<T>(lin_op,
                                      lin_op->input(0).get_source_output(),
                                      lin_op->input(1).get_source_output());
@@ -140,12 +162,12 @@ ngraph::graph_rewrite_callback get_callback() {
     std::shared_ptr<ngraph::op::ScaleShiftIE> scaleshift;
     if (std::is_same<T, ngraph::opset1::Add>()) {
         auto weights = ngraph::opset1::Constant::create(weights_et, weights_shape, {1});
-        scaleshift = std::make_shared<ngraph::op::ScaleShiftIE>(data_node, ngraph::op::util::normalize_constant(weights, lin_op->get_shape()),
-                                                                ngraph::op::util::normalize_constant(const_node, lin_op->get_shape()));
+        scaleshift = std::make_shared<ngraph::op::ScaleShiftIE>(data_node, ngraph::op::util::normalize_constant(weights, output_shape),
+                                                                ngraph::op::util::normalize_constant(const_node, output_shape));
     } else {
         auto bias = ngraph::opset1::Constant::create(weights_et, weights_shape, {0});
-        scaleshift = std::make_shared<ngraph::op::ScaleShiftIE>(data_node, ngraph::op::util::normalize_constant(const_node, lin_op->get_shape()),
-                                                                ngraph::op::util::normalize_constant(bias, lin_op->get_shape()));
+        scaleshift = std::make_shared<ngraph::op::ScaleShiftIE>(data_node, ngraph::op::util::normalize_constant(const_node, output_shape),
+                                                                ngraph::op::util::normalize_constant(bias, output_shape));
     }
 
     scaleshift->set_friendly_name(lin_op->get_friendly_name());
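
The broadcasting check above is the heart of this change, so here is an editor's standalone restatement of the constant_broadcast_output lambda, exercised against the three examples from its comment. The free-function name is hypothetical; only plain ngraph types are assumed.

#include <cassert>
#include <ngraph/ngraph.hpp>

// Returns true when an eltwise with this constant would broadcast (grow) the
// data shape; mirrors the lambda's trailing-dimension walk.
static bool constant_broadcasts_output(const ngraph::PartialShape& data_pshape,
                                       const ngraph::Shape& const_shape) {
    // An unknown data rank, or a constant of higher rank, always broadcasts.
    if (data_pshape.rank().is_dynamic() ||
        const_shape.size() > static_cast<size_t>(data_pshape.rank().get_length())) {
        return true;
    }

    std::vector<ngraph::Dimension> data_shape(data_pshape);
    auto c = const_shape.rbegin();
    auto d = data_shape.rbegin();

    // A const dim > 1 against a dynamic or size-1 data dim may grow the output.
    while (c != const_shape.rend()) {
        if ((d->is_dynamic() && *c != 1) ||
            (d->is_static() && d->get_length() == 1 && *c != 1)) {
            return true;
        }
        ++c;
        ++d;
    }
    return false;
}

int main() {
    using ngraph::Dimension;
    using ngraph::PartialShape;
    using ngraph::Shape;

    // 1. The constant broadcasts dimension 0 of the data shape.
    assert(constant_broadcasts_output(PartialShape{1, 64, 64}, Shape{64, 1, 1}));
    // 2. Dynamic batch, per-last-dim constant: no broadcasting.
    assert(!constant_broadcasts_output(PartialShape{Dimension::dynamic(), 64, 64}, Shape{1, 1, 64}));
    // 3. The constant has higher rank than the data: broadcasts via an extra dimension.
    assert(constant_broadcasts_output(PartialShape{64, 64}, Shape{1, 1, 1}));
    return 0;
}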
@@ -47,7 +47,7 @@ bool has_op_with_type(const std::shared_ptr<const ngraph::Function> &function) {
 INFERENCE_ENGINE_API_CPP(bool) get_single_value(const std::shared_ptr<op::Constant> & const_node, float & value);
 
 INFERENCE_ENGINE_API_CPP(std::shared_ptr<ngraph::Node>) normalize_constant(const std::shared_ptr<op::Constant> & constant,
-                                                                           const Shape & shape);
+                                                                           const PartialShape & shape);
 
 INFERENCE_ENGINE_API_CPP(std::shared_ptr<ngraph::Node>) broadcastTo(const Output<Node>& input, const Shape& shape);

inference-engine/src/transformations/src/ngraph_ops/eltwise.cpp (13 additions, 5 deletions)
@@ -37,16 +37,24 @@ void op::Eltwise::validate_and_infer_types() {
     NODE_VALIDATION_CHECK(this, element::Type::merge(et_result, data1_et, data2_et),
                           "Element types for first and second do not match :", data1_et, " and ", data2_et);
 
-    auto shape1 = get_input_partial_shape(0).to_shape();
-    auto shape2 = get_input_partial_shape(1).to_shape();
+    if (get_input_partial_shape(0).rank().is_dynamic() ||
+        get_input_partial_shape(1).rank().is_dynamic()) {
+        set_output_type(0, et_result, PartialShape::dynamic());
+        return;
+    }
+
+    std::vector<Dimension> shape1(get_input_partial_shape(0));
+    std::vector<Dimension> shape2(get_input_partial_shape(1));
 
-    ngraph::Shape output_shape(std::max(shape1.size(), shape2.size()));
+    std::vector<Dimension> output_shape(PartialShape::dynamic(std::max(shape1.size(), shape2.size())));
     auto output_shape_it = output_shape.rbegin();
 
     auto shape1_it = shape1.rbegin(), shape2_it = shape2.rbegin();
     while (shape1_it != shape1.rend() || shape2_it != shape2.rend()) {
         if (shape1_it != shape1.rend() && shape2_it != shape2.rend()) {
-            *output_shape_it = std::max(*shape1_it, *shape2_it);
+            if (shape1_it->is_static() && shape2_it->is_static()) {
+                *output_shape_it = (shape1_it->get_length() > shape2_it->get_length() ? *shape1_it : *shape2_it);
+            }
         } else if (shape1_it != shape1.rend()) {
             *output_shape_it = *shape1_it;
         } else if (shape2_it != shape2.rend()) {
@@ -61,5 +69,5 @@ void op::Eltwise::validate_and_infer_types() {
         }
     }
 
-    set_output_type(0, data1_et, PartialShape(output_shape));
+    set_output_type(0, et_result, output_shape);
 }
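
Distilled to a single rule per trailing dimension pair, the new shape inference behaves as sketched below (editor's sketch; merge_eltwise_dim is a hypothetical name). Note the conservative choice: a dynamic dimension paired with a static one stays dynamic rather than assuming the static extent.

#include <ngraph/ngraph.hpp>

// Both dims static -> keep the larger one (numpy-style broadcasting);
// otherwise the output dimension stays dynamic.
static ngraph::Dimension merge_eltwise_dim(const ngraph::Dimension& a,
                                           const ngraph::Dimension& b) {
    if (a.is_static() && b.is_static()) {
        return a.get_length() > b.get_length() ? a : b;
    }
    return ngraph::Dimension::dynamic();
}

int main() {
    using ngraph::Dimension;
    auto d1 = merge_eltwise_dim(Dimension(64), Dimension(1));          // 64
    auto d2 = merge_eltwise_dim(Dimension::dynamic(), Dimension(64));  // ?
    auto d3 = merge_eltwise_dim(Dimension(1), Dimension(1));           // 1
    (void)d1; (void)d2; (void)d3;
    return 0;
}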
@@ -17,11 +17,11 @@
 #include "ngraph_ops/scaleshift.hpp"
 
 CONVERSION_RESULT check_constant(const std::shared_ptr<ngraph::opset1::Constant>& constant,
-                                 const ngraph::Shape& shape) {
-    if (!constant) return CONVERSION_RESULT::NONE;
+                                 const ngraph::PartialShape& shape) {
+    if (!constant || shape.rank().is_dynamic()) return CONVERSION_RESULT::NONE;
 
     auto const_shape = constant->get_shape();
-    auto input_shape = shape;
+    std::vector<ngraph::Dimension> input_shape(shape);
 
     // In case of scalar we will convert it to Power
     if (const_shape.empty() || (const_shape.size() == 1 && const_shape[0] == 1)) {
@@ -47,7 +47,7 @@ CONVERSION_RESULT check_constant(const std::shared_ptr<ngraph::opset1::Constant>
 
         if (idx == feature_index && *in_it == 1) {
             is_power = true;
-        } else if (idx == feature_index && *in_it != *out_it) {
+        } else if (idx == feature_index && (out_it->is_dynamic() || *in_it != out_it->get_length())) {
             return CONVERSION_RESULT::NONE;
         }
     }
@@ -95,6 +95,11 @@ void ngraph::pass::ConvertMulAddToScaleShiftOrPower::convert_mul_add_to_scaleshi
         const_weights_node = ngraph::as_type_ptr<ngraph::opset1::Constant>(mul_input_0);
     }
 
+    if (add_node->get_output_partial_shape(0).rank().is_dynamic() ||
+        mul_node->get_output_partial_shape(0).rank().is_dynamic()) {
+        return false;
+    }
+
     // Check that eltwise is not useless otherwise we remove it
     if (ngraph::op::util::constantIsEqualTo(const_weights_node, 1) &&
         ngraph::op::util::constantIsEqualTo(const_bias_node, 0)) {
@@ -124,20 +129,23 @@ void ngraph::pass::ConvertMulAddToScaleShiftOrPower::convert_mul_add_to_scaleshi
         }
     }
 
-    auto res1 = check_constant(const_weights_node, data_node.get_shape());
-    auto res2 = check_constant(const_bias_node, mul_node->get_output_shape(0));
+    auto res1 = check_constant(const_weights_node, data_node.get_partial_shape());
+    auto res2 = check_constant(const_bias_node, mul_node->get_output_partial_shape(0));
+
+    const auto output_shape = add_node->get_output_partial_shape(0);
+    const auto output_shape_rank = output_shape.rank().get_length();
 
     if (res1 == CONVERSION_RESULT::NONE || res2 == CONVERSION_RESULT::NONE ||
-        ((res1 == CONVERSION_RESULT::SCALE_SHIFT || res2 == CONVERSION_RESULT::SCALE_SHIFT) && add_node->get_shape().size() < 4)) {
+        ((res1 == CONVERSION_RESULT::SCALE_SHIFT || res2 == CONVERSION_RESULT::SCALE_SHIFT) && output_shape_rank < 4)) {
         return false;
     }
 
     // TODO: in case scale and shift constants have equal values, the best way is to convert them to Power
     if (res1 == CONVERSION_RESULT::SCALE_SHIFT || res2 == CONVERSION_RESULT::SCALE_SHIFT) {
         NodeVector new_ops;
 
-        auto weights_in = ngraph::op::util::normalize_constant(const_weights_node, add_node->get_shape());
-        auto biases_in = ngraph::op::util::normalize_constant(const_bias_node, add_node->get_shape());
+        auto weights_in = ngraph::op::util::normalize_constant(const_weights_node, output_shape);
+        auto biases_in = ngraph::op::util::normalize_constant(const_bias_node, output_shape);
         new_ops.push_back(weights_in);
         new_ops.push_back(biases_in);
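
For context, the kind of graph this pass can now fold looks like the sketch below: only the output rank (4) must be static, not every dimension. This is an editor's illustrative construction; the shapes and values are assumptions, not taken from the tests.

#include <ngraph/ngraph.hpp>

// Mul (per-channel scale) followed by Add (per-channel shift) over an input
// with dynamic spatial dimensions - the ScaleShiftIE fusion target.
std::shared_ptr<ngraph::Function> make_mul_add_pattern() {
    using namespace ngraph;
    auto data = std::make_shared<opset1::Parameter>(
        element::f32, PartialShape{1, 64, Dimension::dynamic(), Dimension::dynamic()});
    auto scale = opset1::Constant::create(element::f32, Shape{1, 64, 1, 1},
                                          std::vector<float>(64, 2.0f));
    auto shift = opset1::Constant::create(element::f32, Shape{1, 64, 1, 1},
                                          std::vector<float>(64, 1.0f));
    auto mul = std::make_shared<opset1::Multiply>(data, scale);
    auto add = std::make_shared<opset1::Add>(mul, shift);
    return std::make_shared<Function>(NodeVector{add}, ParameterVector{data});
}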
@@ -49,12 +49,12 @@ bool get_single_value(const std::shared_ptr<op::Constant>& const_node, float& value) {
 }
 
 std::shared_ptr<Node> normalize_constant(const std::shared_ptr<op::Constant>& constant,
-                                         const Shape& shape) {
+                                         const PartialShape& shape) {
     auto const_shape = constant->get_shape();
-    if (const_shape.size() == shape.size()) {
+    if (const_shape.size() == shape.rank().get_length()) {
         return constant;
     }
-    int cnt = shape.size() - const_shape.size();
+    int64_t cnt = shape.rank().get_length() - const_shape.size();
     for (int i = 0; i < cnt; ++i) {
         const_shape.insert(const_shape.begin(), 1);
     }
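
The padding step in normalize_constant now consults only the rank of the target shape, which is easy to see in isolation (editor's sketch of the padding logic above; the final reshape of the Constant is omitted):

#include <ngraph/ngraph.hpp>

int main() {
    using namespace ngraph;
    Shape const_shape{64};
    PartialShape target{1, 64, Dimension::dynamic(), Dimension::dynamic()};

    // Dynamic dims are fine because only the rank is used:
    // {64} is padded with leading 1s to {1, 1, 1, 64}.
    int64_t cnt = target.rank().get_length() - static_cast<int64_t>(const_shape.size());
    for (int64_t i = 0; i < cnt; ++i) {
        const_shape.insert(const_shape.begin(), 1);
    }
    return 0;
}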
@@ -1757,7 +1757,7 @@ TEST_F(NGraphReaderTests, RemoveAdd2) {
         </output>
     </layer>
     <layer id="3" name="add" precision="FP32" type="ReLU">
-        <data originalLayersNames="relu"/>
+        <data originalLayersNames="add,relu"/>
         <input>
             <port id="0">
                 <dim>1</dim>
