Skip to content

Commit

Permalink
add evaluates function in evaluates_map (#4)
Browse files Browse the repository at this point in the history
  • Loading branch information
antonzaycev96 authored Aug 26, 2020
1 parent 1e083c2 commit 1776a53
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 0 deletions.
31 changes: 31 additions & 0 deletions ngraph/test/runtime/interpreter/evaluates_map.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include "reference/hard_sigmoid.hpp"
#include "reference/elu.hpp"
#include "reference/selu.hpp"
#include "reference/ctc_loss.hpp"

using namespace ngraph;
using namespace std;
Expand Down Expand Up @@ -407,6 +408,36 @@ namespace {
return true;
}

template<element::Type_t ET>
bool evaluate(const shared_ptr<op::v4::CTCLoss> &op, const HostTensorVector &outputs,
const HostTensorVector &input) {
using T = typename element_type_traits<ET>::value_type;
#define REF_CALL(elType) \
runtime::reference::CTCLoss<T, typename element_type_traits<elType>::value_type>( \
input[0]->get_data_ptr<T>(), \
input[0]->get_shape(), \
input[1]->get_data_ptr<elType>(), \
input[2]->get_data_ptr<elType>(), \
input[3]->get_data_ptr<elType>(), \
input[4]->get_data_ptr<elType>(), \
op->get_preprocess_collapse_repeated(), \
op->get_ctc_merge_repeated(), \
op->get_unique(), \
outputs[0]->get_data_ptr<T>()); \
break;

switch (input[1]->get_element_type()) {
case element::Type_t::i32:
REF_CALL(element::Type_t::i32);
case element::Type_t::i64:
REF_CALL(element::Type_t::i64);
default:
return false;
}
#undef REF_CALL
return true;
}

template<typename T>
bool evaluate_node(std::shared_ptr<Node> node, const HostTensorVector &outputs, const HostTensorVector &inputs) {
switch (node->get_element_type()) {
Expand Down
3 changes: 3 additions & 0 deletions ngraph/test/runtime/interpreter/opset_int_tbl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,6 @@ NGRAPH_OP(Elu, op::v0)
NGRAPH_OP(Selu, op::v0)
NGRAPH_OP(Ceiling, op::v0)
NGRAPH_OP(Gelu, op::v0)


NGRAPH_OP(CTCLoss, op::v4)

0 comments on commit 1776a53

Please sign in to comment.