diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index d377dca633b05..5677053f8b089 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -189,6 +189,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, Sca class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Scale); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, If); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Loop); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, EyeLike); // Opset 9 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, ConstantLike); @@ -373,6 +374,7 @@ void RegisterOnnxOperatorKernels(std::function fn) { fn(BuildKernel()); fn(BuildKernel()); fn(BuildKernel()); + fn(BuildKernel()); // Opset 9 fn(BuildKernel()); diff --git a/onnxruntime/core/providers/cpu/tensor/eye_like.cc b/onnxruntime/core/providers/cpu/tensor/eye_like.cc new file mode 100644 index 0000000000000..0ce03906fb4d0 --- /dev/null +++ b/onnxruntime/core/providers/cpu/tensor/eye_like.cc @@ -0,0 +1,69 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/cpu/tensor/eye_like.h" +#include "core/framework/tensorprotoutils.h" +#include "core/util/math_cpuonly.h" + +using namespace ::onnxruntime::common; + +namespace onnxruntime { + +ONNX_CPU_OPERATOR_KERNEL( + EyeLike, + 9, + KernelDefBuilder().TypeConstraint("T1", + std::vector{ + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + }) + .TypeConstraint("T2", + std::vector{ + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + }), + EyeLike); + +Status EyeLike::Compute(OpKernelContext* context) const { + const Tensor* T1 = context->Input(0); + ONNXRUNTIME_ENFORCE(T1 != nullptr); + + auto output_tensor_dtype = has_dtype_ ? static_cast(dtype_) : utils::GetTensorProtoType(*T1); + switch (output_tensor_dtype) { + case onnx::TensorProto_DataType_FLOAT: + return ComputeImpl(context); + case onnx::TensorProto_DataType_INT64: + return ComputeImpl(context); + case onnx::TensorProto_DataType_UINT64: + return ComputeImpl(context); + default: + ONNXRUNTIME_THROW("Unsupported 'dtype' value: ", output_tensor_dtype); + } +} + +template +Status EyeLike::ComputeImpl(OpKernelContext* context) const { + const Tensor* T1 = context->Input(0); + const std::vector& input_dims = T1->Shape().GetDims(); + if (input_dims.size() != 2) { + return Status(ONNXRUNTIME, INVALID_ARGUMENT, "EyeLike : Input tensor dimension is not 2"); + } + + // set output tensor shape same as input tensor and set all values to zero + auto* T2 = context->Output(0, input_dims); + auto output_mat = EigenMatrixMapRowMajor( + T2->template MutableData(), + input_dims[0], + input_dims[1]); + output_mat.setZero(); + + if ((k_ >= 0 && k_ >= input_dims[1]) || (k_ < 0 && std::abs(k_) >= input_dims[0])) { + return Status::OK(); + } + output_mat.diagonal(k_).array() = static_cast(1); + + return Status::OK(); +} +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/tensor/eye_like.h b/onnxruntime/core/providers/cpu/tensor/eye_like.h new file mode 100644 index 0000000000000..bbc81eb1e7a8d --- /dev/null +++ b/onnxruntime/core/providers/cpu/tensor/eye_like.h @@ -0,0 +1,32 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/common/common.h" +#include "core/framework/op_kernel.h" + +namespace onnxruntime { + +class EyeLike final : public OpKernel { + public: + EyeLike(const OpKernelInfo& info) : OpKernel(info) { + if (!info.GetAttr("k", &k_).IsOK()) { + k_ = 0; + } + + has_dtype_ = info.GetAttr("dtype", &dtype_).IsOK(); + } + + Status Compute(OpKernelContext* context) const override; + + private: + template + Status ComputeImpl(OpKernelContext* context) const; + + bool has_dtype_; + int64_t dtype_; + int64_t k_; +}; + +} //namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/tensor/eyelike_op_test.cc b/onnxruntime/test/providers/cpu/tensor/eyelike_op_test.cc new file mode 100644 index 0000000000000..6e5c5ffcd7b4d --- /dev/null +++ b/onnxruntime/test/providers/cpu/tensor/eyelike_op_test.cc @@ -0,0 +1,77 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "gtest/gtest.h" +#include "test/providers/provider_test_utils.h" + +namespace onnxruntime { +namespace test { + +TEST(EyeLikeOpTest, EyeLikeDefault) { + OpTester test("EyeLike", 9); + test.AddInput("T1", {3, 2}, {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}); + test.AddOutput("T2", {3, 2}, {1.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f}); + test.Run(); +} + +TEST(EyeLikeOpTest, EyeLike_DifferentDtype) { + OpTester test("EyeLike", 9); + test.AddAttribute("dtype", int64_t(7)); + test.AddInput("T1", {3, 3}, {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}); + test.AddOutput("T2", {3, 3}, {1, 0, 0, 0, 1, 0, 0, 0, 1}); + test.Run(); +} + +TEST(EyeLikeOpTest, EyeLike_K_EgdeCase_1) { + OpTester test("EyeLike", 9); + test.AddInput("T1", {3, 2}, {0, 0, 0, 0, 0, 0}); + test.AddAttribute("k", int64_t(3)); + test.AddAttribute("dtype", int64_t(7)); + test.AddOutput("T2", {3, 2}, {0, 0, 0, 0, 0, 0}); + test.Run(); +} + +TEST(EyeLikeOpTest, EyeLike_K_EgdeCase_2) { + OpTester test("EyeLike", 9); + test.AddInput("T1", {3, 2}, {0, 0, 0, 0, 0, 0}); + test.AddAttribute("k", int64_t(-3)); + test.AddOutput("T2", {3, 2}, {0, 0, 0, 0, 0, 0}); + test.Run(); +} + +TEST(EyeLikeOpTest, EyeLike_UpperDiagonal) { + OpTester test("EyeLike", 9); + test.AddInput("T1", {3, 4}, {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}); + test.AddAttribute("k", int64_t(2)); + test.AddOutput("T2", {3, 4}, {0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f}); + test.Run(); +} + +TEST(EyeLikeOpTest, EyeLike_UpperrDiagonal2) { + OpTester test("EyeLike", 9); + test.AddInput("T1", {3, 2}, {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}); + test.AddAttribute("k", int64_t(1)); + test.AddOutput("T2", {3, 2}, {0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f}); + test.Run(); +} + +TEST(EyeLikeOpTest, EyeLike_LowerDiagonal) { + OpTester test("EyeLike", 9); + test.AddInput("T1", {3, 2}, {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}); + test.AddAttribute("k", int64_t(-1)); + test.AddAttribute("dtype", int64_t(1)); + test.AddOutput("T2", {3, 2}, {0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 1.0f}); + test.Run(); +} + +TEST(EyeLikeOpTest, EyeLike_LowerDiagonal2) { + OpTester test("EyeLike", 9); + test.AddInput("T1", {3, 4}, {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}); + test.AddAttribute("k", int64_t(-2)); + test.AddAttribute("dtype", int64_t(1)); + test.AddOutput("T2", {3, 4}, {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f}); + test.Run(); +} + +} // namespace test +} // namespace onnxruntime