Skip to content

Commit

Permalink
pdpd frontend: enable deformable_conv, enable multiclass_nms and matr…
Browse files Browse the repository at this point in the history
…ix_nms. (#6833)
  • Loading branch information
ceciliapeng2011 authored Aug 3, 2021
1 parent 344e063 commit 950d7b8
Show file tree
Hide file tree
Showing 4 changed files with 258 additions and 1 deletion.
74 changes: 74 additions & 0 deletions ngraph/frontend/paddlepaddle/src/op/deformable_conv.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <ngraph/opsets/opset8.hpp>
#include <node_context.hpp>
#include "conv2d_utils.hpp"

namespace ngraph
{
namespace frontend
{
namespace pdpd
{
namespace op
{
NamedOutputs deformable_conv(const NodeContext& node)
{
auto input = node.get_ng_input("Input");
auto filter = node.get_ng_input("Filter");
auto offset = node.get_ng_input("Offset");

auto strides = node.get_attribute<std::vector<int>>("strides");
auto dilations = node.get_attribute<std::vector<int>>("dilations");

auto groups = node.get_attribute<int>("groups");
auto deformable_groups = node.get_attribute<int>("deformable_groups");

const auto paddings = get_pads(node);
const auto pads_begin = paddings.first;
const auto pads_end = paddings.second;

const ngraph::op::PadType auto_pad{ngraph::op::PadType::EXPLICIT};

std::shared_ptr<Node> output_node;
if (node.has_ng_input("Mask"))
{
auto mask = node.get_ng_input("Mask");
output_node = std::make_shared<ngraph::opset8::DeformableConvolution>(
input,
offset,
filter,
mask,
ngraph::Strides(strides.begin(), strides.end()),
pads_begin,
pads_end,
ngraph::Strides(dilations.begin(), dilations.end()),
auto_pad,
groups,
deformable_groups,
true);
}
else
{
output_node = std::make_shared<ngraph::opset8::DeformableConvolution>(
input,
offset,
filter,
ngraph::Strides(strides.begin(), strides.end()),
pads_begin,
pads_end,
ngraph::Strides(dilations.begin(), dilations.end()),
auto_pad,
groups,
deformable_groups,
true);
}

return node.default_single_output_mapping({output_node}, {"Output"});
}

} // namespace op
} // namespace pdpd
} // namespace frontend
} // namespace ngraph
99 changes: 99 additions & 0 deletions ngraph/frontend/paddlepaddle/src/op/matrix_nms.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <ngraph/opsets/opset8.hpp>
#include <node_context.hpp>
#include <paddlepaddle_frontend/utility.hpp>

namespace ngraph
{
namespace frontend
{
namespace pdpd
{
namespace op
{
NamedOutputs matrix_nms(const NodeContext& node)
{
using namespace ngraph;
using namespace opset8;
using namespace element;

auto bboxes = node.get_ng_input("BBoxes");
auto scores = node.get_ng_input("Scores");

auto score_threshold = node.get_attribute<float>("score_threshold");
auto post_threshold = node.get_attribute<float>("post_threshold");
auto nms_top_k = node.get_attribute<int>("nms_top_k");
auto keep_top_k = node.get_attribute<int>("keep_top_k");
auto background_class = node.get_attribute<int>("background_label");

auto gaussian_sigma = node.get_attribute<float>("gaussian_sigma");
auto use_gaussian = node.get_attribute<bool>("use_gaussian");
auto decay_function = MatrixNms::DecayFunction::LINEAR;
if (use_gaussian)
{
decay_function = MatrixNms::DecayFunction::GAUSSIAN;
}

auto out_names = node.get_output_names();
PDPD_ASSERT(out_names.size() == 3 || out_names.size() == 2,
"Unexpected number of outputs of MatrixNMS: " + out_names.size());

element::Type type_num = i32;
bool return_rois_num = true;
auto it = std::find(out_names.begin(), out_names.end(), "RoisNum");
if (it != out_names.end())
{
type_num = node.get_out_port_type("RoisNum");
}
else
{
return_rois_num = false;
}

auto type_index = node.get_out_port_type("Index");
PDPD_ASSERT((type_index == i32 || type_index == i64) &&
(type_num == i32 || type_num == i64),
"Unexpected data type of outputs of MatrixNMS");

auto normalized = node.get_attribute<bool>("normalized");

NamedOutputs named_outputs;
std::vector<Output<Node>> nms_outputs;
MatrixNms::Attributes attrs;
attrs.nms_top_k = nms_top_k;
attrs.post_threshold = post_threshold;
attrs.score_threshold = score_threshold;
attrs.sort_result_type = MatrixNms::SortResultType::SCORE;
attrs.keep_top_k = keep_top_k;
attrs.background_class = background_class;
attrs.normalized = normalized;
attrs.output_type = type_index;
attrs.sort_result_across_batch = false;
attrs.decay_function = decay_function;
attrs.gaussian_sigma = gaussian_sigma;

nms_outputs = std::make_shared<MatrixNms>(bboxes, scores, attrs)->outputs();

named_outputs["Out"] = {nms_outputs[0]};
named_outputs["Index"] = {nms_outputs[1]};
if (return_rois_num)
{
named_outputs["RoisNum"] = {nms_outputs[2]};

if (type_num != type_index)
{
// adapter
auto node_convert = std::make_shared<Convert>(nms_outputs[2], type_num);
named_outputs["RoisNum"] = {node_convert};
}
}

return named_outputs;
}

} // namespace op
} // namespace pdpd
} // namespace frontend
} // namespace ngraph
78 changes: 78 additions & 0 deletions ngraph/frontend/paddlepaddle/src/op/multiclass_nms.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <ngraph/opsets/opset8.hpp>
#include <node_context.hpp>
#include <paddlepaddle_frontend/utility.hpp>

namespace ngraph
{
namespace frontend
{
namespace pdpd
{
namespace op
{
NamedOutputs multiclass_nms(const NodeContext& node)
{
using namespace ngraph;
using namespace opset8;
using namespace element;

auto bboxes = node.get_ng_input("BBoxes");
auto scores = node.get_ng_input("Scores");

auto score_threshold = node.get_attribute<float>("score_threshold");
auto iou_threshold = node.get_attribute<float>("nms_threshold");
auto nms_top_k = node.get_attribute<int>("nms_top_k");
auto keep_top_k = node.get_attribute<int>("keep_top_k");
auto background_class = node.get_attribute<int>("background_label");
auto nms_eta = node.get_attribute<float>("nms_eta");

auto out_names = node.get_output_names();
PDPD_ASSERT(out_names.size() == 3,
"Unexpected number of outputs of MulticlassNMS");

auto type_index = node.get_out_port_type("Index");
auto type_num = node.get_out_port_type("NmsRoisNum");
PDPD_ASSERT((type_index == i32 || type_index == i64) &&
(type_num == i32 || type_num == i64),
"Unexpected data type of outputs of MulticlassNMS: " +
out_names.size());

auto normalized = node.get_attribute<bool>("normalized");

NamedOutputs named_outputs;
std::vector<Output<Node>> nms_outputs;
MulticlassNms::Attributes attrs;
attrs.nms_top_k = nms_top_k;
attrs.iou_threshold = iou_threshold;
attrs.score_threshold = score_threshold;
attrs.sort_result_type = MulticlassNms::SortResultType::CLASSID;
attrs.keep_top_k = keep_top_k;
attrs.background_class = background_class;
attrs.nms_eta = nms_eta;
attrs.normalized = normalized;
attrs.output_type = type_index;
attrs.sort_result_across_batch = false;

nms_outputs = std::make_shared<MulticlassNms>(bboxes, scores, attrs)->outputs();

named_outputs["Out"] = {nms_outputs[0]};
named_outputs["Index"] = {nms_outputs[1]};
named_outputs["NmsRoisNum"] = {nms_outputs[2]};

if (type_num != type_index)
{
// adapter
auto node_convert = std::make_shared<Convert>(nms_outputs[2], type_num);
named_outputs["NmsRoisNum"] = {node_convert};
}

return named_outputs;
}

} // namespace op
} // namespace pdpd
} // namespace frontend
} // namespace ngraph
8 changes: 7 additions & 1 deletion ngraph/frontend/paddlepaddle/src/op_table.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "op_table.hpp"

namespace ngraph
Expand All @@ -23,6 +22,7 @@ namespace ngraph
OP_CONVERTER(concat);
OP_CONVERTER(conv2d);
OP_CONVERTER(conv2d_transpose);
OP_CONVERTER(deformable_conv);
OP_CONVERTER(dropout);
OP_CONVERTER(elementwise_add);
OP_CONVERTER(elementwise_div);
Expand All @@ -43,6 +43,8 @@ namespace ngraph
OP_CONVERTER(log);
OP_CONVERTER(logical_not);
OP_CONVERTER(matmul);
OP_CONVERTER(matrix_nms);
OP_CONVERTER(multiclass_nms);
OP_CONVERTER(nearest_interp_v2);
OP_CONVERTER(rnn);
OP_CONVERTER(relu);
Expand Down Expand Up @@ -74,6 +76,8 @@ namespace ngraph
{"concat", op::concat},
{"conv2d", op::conv2d},
{"conv2d_transpose", op::conv2d_transpose},
{"deformable_conv", op::deformable_conv},
{"deformable_conv_v1", op::deformable_conv},
{"depthwise_conv2d", op::conv2d},
{"depthwise_conv2d_transpose", op::conv2d_transpose},
{"dropout", op::dropout},
Expand All @@ -96,6 +100,8 @@ namespace ngraph
{"log", op::log},
{"logical_not", op::logical_not},
{"matmul", op::matmul},
{"matrix_nms", op::matrix_nms},
{"multiclass_nms3", op::multiclass_nms},
{"nearest_interp_v2", op::nearest_interp_v2},
{"nearest_interp", op::nearest_interp_v2},
{"rnn", op::rnn},
Expand Down

0 comments on commit 950d7b8

Please sign in to comment.