Skip to content

Commit

Permalink
handle broadcast of paddle elementwise with AutoBroadcastType::PDPD
Browse files Browse the repository at this point in the history
  • Loading branch information
meiyang-intel committed Nov 29, 2021
1 parent c380c16 commit fc8d57e
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ bool convert_divide(std::shared_ptr<ngraph::Node> node) {
ngraph::copy_runtime_info(div, pow.get_node_shared_ptr());
}

auto mul = std::make_shared<ngraph::opset1::Multiply>(div->input(0).get_source_output(), pow);
auto mul = std::make_shared<ngraph::opset1::Multiply>(div->input(0).get_source_output(), pow, div->get_autob());

mul->set_friendly_name(div->get_friendly_name());
ngraph::copy_runtime_info(div, mul);
Expand Down Expand Up @@ -72,4 +72,4 @@ ngraph::pass::ConvertDivideWithConstant::ConvertDivideWithConstant() {

auto m = std::make_shared<ngraph::pattern::Matcher>(div, matcher_name);
this->register_matcher(m, callback);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ ngraph::pass::ConvertSubtract::ConvertSubtract() {
auto neg = std::make_shared<ngraph::opset1::Multiply>(sub->input(1).get_source_output(),
opset1::Constant::create(sub->get_input_element_type(1), Shape{}, {-1}));

auto add = std::make_shared<ngraph::opset1::Add>(sub->input(0).get_source_output(), neg);
auto add = std::make_shared<ngraph::opset1::Add>(sub->input(0).get_source_output(), neg, sub->get_autob());

add->set_friendly_name(sub->get_friendly_name());
ngraph::copy_runtime_info(sub, {neg, add});
Expand Down
12 changes: 2 additions & 10 deletions src/frontends/paddlepaddle/src/op/elementwise_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,8 @@ NamedOutputs elementwise_ops(const NodeContext& node) {
if ((axis == -1) || (axis == x_rank - 1) || (x_rank == y_rank)) {
return node.default_single_output_mapping({std::make_shared<T>(x, y)}, {"Out"});
} else {
std::vector<int64_t> indices;
for (int64_t i = 0; i < axis; i++)
indices.push_back(i);
for (int64_t i = y_rank + axis; i < x_rank; i++)
indices.push_back(i);

auto indices_node =
default_opset::Constant::create(ngraph::element::i64, ngraph::Shape{indices.size()}, indices);
auto y_node = std::make_shared<default_opset::Unsqueeze>(y, indices_node);
return node.default_single_output_mapping({std::make_shared<T>(x, y_node)}, {"Out"});
ov::op::AutoBroadcastSpec pdpd_broadcast(ov::op::AutoBroadcastType::PDPD, axis);
return node.default_single_output_mapping({std::make_shared<T>(x, y, pdpd_broadcast)}, {"Out"});
}
}

Expand Down

0 comments on commit fc8d57e

Please sign in to comment.