Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions onnxruntime/core/providers/cpu/cpu_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, Sca
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Scale);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, If);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Loop);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, EyeLike);

// Opset 9
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, ConstantLike);
Expand Down Expand Up @@ -373,6 +374,7 @@ void RegisterOnnxOperatorKernels(std::function<void(KernelCreateInfo&&)> fn) {
fn(BuildKernel<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Scale)>());
fn(BuildKernel<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, If)>());
fn(BuildKernel<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Loop)>());
fn(BuildKernel<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, EyeLike)>());

// Opset 9
fn(BuildKernel<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, ConstantLike)>());
Expand Down
69 changes: 69 additions & 0 deletions onnxruntime/core/providers/cpu/tensor/eye_like.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/cpu/tensor/eye_like.h"
#include "core/framework/tensorprotoutils.h"
#include "core/util/math_cpuonly.h"

using namespace ::onnxruntime::common;

namespace onnxruntime {

// Kernel registration for the ONNX EyeLike operator (opset 9).
// Only float / int64 / uint64 are registered for now; the ONNX spec permits
// additional element types, to be added on demand.
ONNX_CPU_OPERATOR_KERNEL(
    EyeLike,
    9,
    KernelDefBuilder()
        .TypeConstraint("T1",
                        std::vector<MLDataType>{
                            DataTypeImpl::GetTensorType<float>(),
                            DataTypeImpl::GetTensorType<int64_t>(),
                            DataTypeImpl::GetTensorType<uint64_t>(),
                        })
        .TypeConstraint("T2",
                        std::vector<MLDataType>{
                            DataTypeImpl::GetTensorType<float>(),
                            DataTypeImpl::GetTensorType<uint64_t>(),
                            DataTypeImpl::GetTensorType<int64_t>(),
                        }),
    EyeLike);

// Resolves the output element type and dispatches to the typed implementation.
// The explicit 'dtype' attribute wins; otherwise the output mirrors the
// input tensor's element type.
Status EyeLike::Compute(OpKernelContext* context) const {
  const Tensor* input = context->Input<Tensor>(0);
  ONNXRUNTIME_ENFORCE(input != nullptr);

  const auto out_type = has_dtype_
                            ? static_cast<onnx::TensorProto::DataType>(dtype_)
                            : utils::GetTensorProtoType(*input);

  switch (out_type) {
    case onnx::TensorProto_DataType_FLOAT:
      return ComputeImpl<float>(context);
    case onnx::TensorProto_DataType_INT64:
      return ComputeImpl<int64_t>(context);
    case onnx::TensorProto_DataType_UINT64:
      return ComputeImpl<uint64_t>(context);
    default:
      ONNXRUNTIME_THROW("Unsupported 'dtype' value: ", out_type);
  }
}

// Typed implementation: writes a zero matrix shaped like the input, then
// sets the k-th diagonal (k_ attribute) to ones.
template <typename T>
Status EyeLike::ComputeImpl(OpKernelContext* context) const {
  const Tensor* input = context->Input<Tensor>(0);
  const std::vector<int64_t>& dims = input->Shape().GetDims();
  if (dims.size() != 2) {
    return Status(ONNXRUNTIME, INVALID_ARGUMENT, "EyeLike : Input tensor dimension is not 2");
  }

  const int64_t rows = dims[0];
  const int64_t cols = dims[1];

  // Output tensor has the same shape as the input; start from all zeros.
  auto* output = context->Output(0, dims);
  auto out_mat = EigenMatrixMapRowMajor<T>(
      output->template MutableData<T>(),
      rows,
      cols);
  out_mat.setZero();

  // A diagonal offset past the matrix extent selects no elements at all.
  // That is not an error: terminate early and return the all-zero matrix.
  if ((k_ >= 0 && k_ >= cols) || (k_ < 0 && std::abs(k_) >= rows)) {
    return Status::OK();
  }

  out_mat.diagonal(k_).array() = static_cast<T>(1);
  return Status::OK();
}
} // namespace onnxruntime
32 changes: 32 additions & 0 deletions onnxruntime/core/providers/cpu/tensor/eye_like.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/common/common.h"
#include "core/framework/op_kernel.h"

namespace onnxruntime {

// CPU kernel for the ONNX EyeLike operator (opset 9): given a 2D input
// tensor, produces a tensor of the same shape that is zero everywhere except
// for ones on the k-th diagonal.
class EyeLike final : public OpKernel {
 public:
  explicit EyeLike(const OpKernelInfo& info) : OpKernel(info) {
    // "k" selects the diagonal (0 = main, >0 above, <0 below); defaults to 0.
    if (!info.GetAttr("k", &k_).IsOK()) {
      k_ = 0;
    }

    // "dtype" optionally overrides the output element type; when absent,
    // the output uses the input tensor's element type.
    has_dtype_ = info.GetAttr("dtype", &dtype_).IsOK();
  }

  Status Compute(OpKernelContext* context) const override;

 private:
  // Typed implementation invoked by Compute once the output type is resolved.
  template <typename T>
  Status ComputeImpl(OpKernelContext* context) const;

  bool has_dtype_{false};  // true iff the "dtype" attribute was supplied
  int64_t dtype_{0};       // raw TensorProto_DataType value; only meaningful when has_dtype_
  int64_t k_{0};           // diagonal offset
};

} //namespace onnxruntime
77 changes: 77 additions & 0 deletions onnxruntime/test/providers/cpu/tensor/eyelike_op_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "gtest/gtest.h"
#include "test/providers/provider_test_utils.h"

namespace onnxruntime {
namespace test {

// No attributes: ones on the main diagonal, output dtype mirrors the input.
TEST(EyeLikeOpTest, EyeLikeDefault) {
  OpTester tester("EyeLike", 9);
  const std::vector<float> zeros(6, 0.0f);
  tester.AddInput<float>("T1", {3, 2}, zeros);
  tester.AddOutput<float>("T2", {3, 2}, {1.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f});
  tester.Run();
}

// dtype=7 (TensorProto INT64): float input, int64 identity output.
TEST(EyeLikeOpTest, EyeLike_DifferentDtype) {
  OpTester tester("EyeLike", 9);
  tester.AddAttribute("dtype", int64_t(7));
  const std::vector<float> zeros(9, 0.0f);
  tester.AddInput<float>("T1", {3, 3}, zeros);
  tester.AddOutput<int64_t>("T2", {3, 3}, {1, 0, 0, 0, 1, 0, 0, 0, 1});
  tester.Run();
}

// k beyond the column count selects no diagonal: output stays all zeros.
// (Fixes test-name typo: "EgdeCase" -> "EdgeCase".)
TEST(EyeLikeOpTest, EyeLike_K_EdgeCase_1) {
  OpTester test("EyeLike", 9);
  test.AddInput<int64_t>("T1", {3, 2}, {0, 0, 0, 0, 0, 0});
  test.AddAttribute("k", int64_t(3));
  test.AddAttribute("dtype", int64_t(7));
  test.AddOutput<int64_t>("T2", {3, 2}, {0, 0, 0, 0, 0, 0});
  test.Run();
}

// Negative k with |k| beyond the row count selects no diagonal: all zeros.
// (Fixes test-name typo: "EgdeCase" -> "EdgeCase".)
TEST(EyeLikeOpTest, EyeLike_K_EdgeCase_2) {
  OpTester test("EyeLike", 9);
  test.AddInput<int64_t>("T1", {3, 2}, {0, 0, 0, 0, 0, 0});
  test.AddAttribute("k", int64_t(-3));
  test.AddOutput<int64_t>("T2", {3, 2}, {0, 0, 0, 0, 0, 0});
  test.Run();
}

// k=2: ones appear on the second diagonal above the main one.
TEST(EyeLikeOpTest, EyeLike_UpperDiagonal) {
  OpTester tester("EyeLike", 9);
  tester.AddAttribute("k", int64_t(2));
  const std::vector<float> zeros(12, 0.0f);
  tester.AddInput<float>("T1", {3, 4}, zeros);
  tester.AddOutput<float>("T2", {3, 4},
                          {0.0f, 0.0f, 1.0f, 0.0f,
                           0.0f, 0.0f, 0.0f, 1.0f,
                           0.0f, 0.0f, 0.0f, 0.0f});
  tester.Run();
}

// k=1 on a tall matrix: a single one fits on the upper diagonal.
// (Fixes test-name typo: "Upperr" -> "Upper".)
TEST(EyeLikeOpTest, EyeLike_UpperDiagonal2) {
  OpTester test("EyeLike", 9);
  test.AddInput<float>("T1", {3, 2}, {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f});
  test.AddAttribute("k", int64_t(1));
  test.AddOutput<float>("T2", {3, 2}, {0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f});
  test.Run();
}

// k=-1 with explicit dtype=1 (float): ones on the first subdiagonal.
TEST(EyeLikeOpTest, EyeLike_LowerDiagonal) {
  OpTester tester("EyeLike", 9);
  tester.AddAttribute("k", int64_t(-1));
  tester.AddAttribute("dtype", int64_t(1));
  const std::vector<float> zeros(6, 0.0f);
  tester.AddInput<float>("T1", {3, 2}, zeros);
  tester.AddOutput<float>("T2", {3, 2},
                          {0.0f, 0.0f,
                           1.0f, 0.0f,
                           0.0f, 1.0f});
  tester.Run();
}

// k=-2 with explicit dtype=1 (float): single one on the second subdiagonal.
TEST(EyeLikeOpTest, EyeLike_LowerDiagonal2) {
  OpTester tester("EyeLike", 9);
  tester.AddAttribute("k", int64_t(-2));
  tester.AddAttribute("dtype", int64_t(1));
  const std::vector<float> zeros(12, 0.0f);
  tester.AddInput<float>("T1", {3, 4}, zeros);
  tester.AddOutput<float>("T2", {3, 4},
                          {0.0f, 0.0f, 0.0f, 0.0f,
                           0.0f, 0.0f, 0.0f, 0.0f,
                           1.0f, 0.0f, 0.0f, 0.0f});
  tester.Run();
}

} // namespace test
} // namespace onnxruntime