From 1776a53d77bc876224b12ce5ad9737c15dd98ae3 Mon Sep 17 00:00:00 2001 From: Anton Zaytsev Date: Wed, 26 Aug 2020 18:13:13 +0300 Subject: [PATCH] add evaluates function in evaluates_map (#4) --- .../runtime/interpreter/evaluates_map.cpp | 31 +++++++++++++++++++ .../runtime/interpreter/opset_int_tbl.hpp | 3 ++ 2 files changed, 34 insertions(+) diff --git a/ngraph/test/runtime/interpreter/evaluates_map.cpp b/ngraph/test/runtime/interpreter/evaluates_map.cpp index ab597a92767231..5f0289de2154c9 100644 --- a/ngraph/test/runtime/interpreter/evaluates_map.cpp +++ b/ngraph/test/runtime/interpreter/evaluates_map.cpp @@ -33,6 +33,7 @@ #include "reference/hard_sigmoid.hpp" #include "reference/elu.hpp" #include "reference/selu.hpp" +#include "reference/ctc_loss.hpp" using namespace ngraph; using namespace std; @@ -407,6 +408,36 @@ namespace { return true; } + template + bool evaluate(const shared_ptr &op, const HostTensorVector &outputs, + const HostTensorVector &input) { + using T = typename element_type_traits::value_type; +#define REF_CALL(elType) \ + runtime::reference::CTCLoss::value_type>( \ + input[0]->get_data_ptr(), \ + input[0]->get_shape(), \ + input[1]->get_data_ptr(), \ + input[2]->get_data_ptr(), \ + input[3]->get_data_ptr(), \ + input[4]->get_data_ptr(), \ + op->get_preprocess_collapse_repeated(), \ + op->get_ctc_merge_repeated(), \ + op->get_unique(), \ + outputs[0]->get_data_ptr()); \ + break; + + switch (input[1]->get_element_type()) { + case element::Type_t::i32: + REF_CALL(element::Type_t::i32); + case element::Type_t::i64: + REF_CALL(element::Type_t::i64); + default: + return false; + } +#undef REF_CALL + return true; + } + template bool evaluate_node(std::shared_ptr node, const HostTensorVector &outputs, const HostTensorVector &inputs) { switch (node->get_element_type()) { diff --git a/ngraph/test/runtime/interpreter/opset_int_tbl.hpp b/ngraph/test/runtime/interpreter/opset_int_tbl.hpp index 8eb0d521f41d42..0dbc601cab0d97 100644 --- a/ngraph/test/runtime/interpreter/opset_int_tbl.hpp +++ b/ngraph/test/runtime/interpreter/opset_int_tbl.hpp @@ -49,3 +49,6 @@ NGRAPH_OP(Elu, op::v0) NGRAPH_OP(Selu, op::v0) NGRAPH_OP(Ceiling, op::v0) NGRAPH_OP(Gelu, op::v0) + + +NGRAPH_OP(CTCLoss, op::v4)