First Draft EyeLike CPU OP9 #121
eye_like.cc (new file):

```cpp
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/cpu/tensor/eye_like.h"
#include "core/framework/tensorprotoutils.h"
#include "core/util/math_cpuonly.h"

using namespace ::onnxruntime::common;

namespace onnxruntime {

ONNX_CPU_OPERATOR_KERNEL(
    EyeLike,
    9,
    KernelDefBuilder().TypeConstraint("T1",
                                      std::vector<MLDataType>{
                                          DataTypeImpl::GetTensorType<float>(),
                                          DataTypeImpl::GetTensorType<int64_t>(),
                                          DataTypeImpl::GetTensorType<uint64_t>(),
                                      })
        .TypeConstraint("T2",
                        std::vector<MLDataType>{
                            DataTypeImpl::GetTensorType<float>(),
                            DataTypeImpl::GetTensorType<uint64_t>(),
                            DataTypeImpl::GetTensorType<int64_t>(),
                        }),
    EyeLike);

Status EyeLike::Compute(OpKernelContext* context) const {
  const Tensor* T1 = context->Input<Tensor>(0);
  ONNXRUNTIME_ENFORCE(T1 != nullptr);

  auto output_tensor_dtype = has_dtype_ ? static_cast<onnx::TensorProto::DataType>(dtype_) : utils::GetTensorProtoType(*T1);
  switch (output_tensor_dtype) {
    case onnx::TensorProto_DataType_FLOAT:
      return ComputeImpl<float>(context);
    case onnx::TensorProto_DataType_INT64:
      return ComputeImpl<int64_t>(context);
    case onnx::TensorProto_DataType_UINT64:
      return ComputeImpl<uint64_t>(context);
    default:
      ONNXRUNTIME_THROW("Unsupported 'dtype' value: ", output_tensor_dtype);
  }
}

template <typename T>
Status EyeLike::ComputeImpl(OpKernelContext* context) const {
  const Tensor* T1 = context->Input<Tensor>(0);
  const std::vector<int64_t>& input_dims = T1->Shape().GetDims();
  if (input_dims.size() != 2) {
    return Status(ONNXRUNTIME, INVALID_ARGUMENT, "EyeLike : Input tensor dimension is not 2");
  }

  // set output tensor shape same as input tensor and set all values to zero
  auto* T2 = context->Output(0, input_dims);
  auto output_mat = EigenMatrixMapRowMajor<T>(
      T2->template MutableData<T>(),
      input_dims[0],
      input_dims[1]);
  output_mat.setZero();

  if ((k_ >= 0 && k_ >= input_dims[1]) || (k_ < 0 && std::abs(k_) >= input_dims[0])) {
```
A review comment on the check above:

**Contributor:** Is this an error condition?

**Author:** Not really an error. If k is greater than or equal to the number of rows (for a lower diagonal) or the number of columns (for an upper diagonal), we terminate early and return a tensor of all zeros.
```cpp
    return Status::OK();
  }
  output_mat.diagonal(k_).array() = static_cast<T>(1);

  return Status::OK();
}
}  // namespace onnxruntime
```
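To make the early-exit condition discussed above concrete, here is a small standalone reference sketch (hypothetical code, not part of the PR): an m x n output has a non-empty diagonal at offset k only when -m < k < n; otherwise the all-zero output is already correct.

```cpp
// Minimal standalone sketch of the EyeLike k-offset semantics (illustrative only).
#include <cstdint>
#include <iostream>
#include <vector>

std::vector<int64_t> EyeLikeReference(int64_t rows, int64_t cols, int64_t k) {
  std::vector<int64_t> out(rows * cols, 0);
  // Early exit mirrors the kernel's check: the requested diagonal lies
  // entirely outside the matrix, so the all-zero output is already correct.
  if ((k >= 0 && k >= cols) || (k < 0 && -k >= rows)) return out;
  // Walk the diagonal: element (i, i + k) for every row i that stays in bounds.
  for (int64_t i = 0; i < rows; ++i) {
    const int64_t j = i + k;
    if (j >= 0 && j < cols) out[i * cols + j] = 1;
  }
  return out;
}

int main() {
  // 3x2 input with k = 3: the diagonal starts past the last column, so the
  // result stays all zeros (same behavior as the EyeLike_K_EdgeCase_1 test below).
  for (int64_t v : EyeLikeReference(3, 2, 3)) std::cout << v << ' ';
  std::cout << '\n';  // prints: 0 0 0 0 0 0
  // 3x4 input with k = -2: a single 1 at row 2, column 0 (lower diagonal),
  // matching the EyeLike_LowerDiagonal2 test below.
  for (int64_t v : EyeLikeReference(3, 4, -2)) std::cout << v << ' ';
  std::cout << '\n';  // prints: 0 0 0 0 0 0 0 0 1 0 0 0
  return 0;
}
```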
eye_like.h (new file):

```cpp
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/common/common.h"
#include "core/framework/op_kernel.h"

namespace onnxruntime {

class EyeLike final : public OpKernel {
 public:
  EyeLike(const OpKernelInfo& info) : OpKernel(info) {
    if (!info.GetAttr("k", &k_).IsOK()) {
      k_ = 0;
    }

    has_dtype_ = info.GetAttr("dtype", &dtype_).IsOK();
  }

  Status Compute(OpKernelContext* context) const override;

 private:
  template <typename T>
  Status ComputeImpl(OpKernelContext* context) const;

  bool has_dtype_;
  int64_t dtype_;
  int64_t k_;
};

}  // namespace onnxruntime
```
Unit tests (new file):

```cpp
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "gtest/gtest.h"
#include "test/providers/provider_test_utils.h"

namespace onnxruntime {
namespace test {

TEST(EyeLikeOpTest, EyeLikeDefault) {
  OpTester test("EyeLike", 9);
  test.AddInput<float>("T1", {3, 2}, {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f});
  test.AddOutput<float>("T2", {3, 2}, {1.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f});
  test.Run();
}

TEST(EyeLikeOpTest, EyeLike_DifferentDtype) {
  OpTester test("EyeLike", 9);
  test.AddAttribute("dtype", int64_t(7));
  test.AddInput<float>("T1", {3, 3}, {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f});
  test.AddOutput<int64_t>("T2", {3, 3}, {1, 0, 0, 0, 1, 0, 0, 0, 1});
  test.Run();
}

TEST(EyeLikeOpTest, EyeLike_K_EdgeCase_1) {
  OpTester test("EyeLike", 9);
  test.AddInput<int64_t>("T1", {3, 2}, {0, 0, 0, 0, 0, 0});
  test.AddAttribute("k", int64_t(3));
  test.AddAttribute("dtype", int64_t(7));
  test.AddOutput<int64_t>("T2", {3, 2}, {0, 0, 0, 0, 0, 0});
  test.Run();
}

TEST(EyeLikeOpTest, EyeLike_K_EdgeCase_2) {
  OpTester test("EyeLike", 9);
  test.AddInput<int64_t>("T1", {3, 2}, {0, 0, 0, 0, 0, 0});
  test.AddAttribute("k", int64_t(-3));
  test.AddOutput<int64_t>("T2", {3, 2}, {0, 0, 0, 0, 0, 0});
  test.Run();
}

TEST(EyeLikeOpTest, EyeLike_UpperDiagonal) {
  OpTester test("EyeLike", 9);
  test.AddInput<float>("T1", {3, 4}, {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f});
  test.AddAttribute("k", int64_t(2));
  test.AddOutput<float>("T2", {3, 4}, {0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f});
  test.Run();
}

TEST(EyeLikeOpTest, EyeLike_UpperDiagonal2) {
  OpTester test("EyeLike", 9);
  test.AddInput<float>("T1", {3, 2}, {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f});
  test.AddAttribute("k", int64_t(1));
  test.AddOutput<float>("T2", {3, 2}, {0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f});
  test.Run();
}

TEST(EyeLikeOpTest, EyeLike_LowerDiagonal) {
  OpTester test("EyeLike", 9);
  test.AddInput<float>("T1", {3, 2}, {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f});
  test.AddAttribute("k", int64_t(-1));
  test.AddAttribute("dtype", int64_t(1));
  test.AddOutput<float>("T2", {3, 2}, {0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 1.0f});
  test.Run();
}

TEST(EyeLikeOpTest, EyeLike_LowerDiagonal2) {
  OpTester test("EyeLike", 9);
  test.AddInput<float>("T1", {3, 4}, {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f});
  test.AddAttribute("k", int64_t(-2));
  test.AddAttribute("dtype", int64_t(1));
  test.AddOutput<float>("T2", {3, 4}, {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f});
  test.Run();
}

}  // namespace test
}  // namespace onnxruntime
```
Review discussion on type support:

**Author:** This is the entire list of supported types per the ONNX spec... should I add all of them?
tensor(float16), tensor(float), tensor(double), tensor(int8), tensor(int16), tensor(int32), tensor(int64), tensor(uint8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(bool)

**Contributor:** We could add the types step by step, based on requirements. Are there any models that require other types?

**Author:** I don't know whether any model actually needs this operator right now; I am adding it as part of OP9 support. Do we maintain this information somewhere? How do I check?

**Contributor:** By default, we start with tensor(float) and add other types if need be.
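For reference, widening type support later stays cheap because all per-type work is funneled through ComputeImpl&lt;T&gt;. Below is a standalone, hypothetical sketch (plain C++, not onnxruntime API; all names are illustrative) of that dispatch pattern: adding a type such as double would amount to one new TypeConstraint entry in the registration plus one new switch case like the kDouble arm here.

```cpp
// Hypothetical standalone sketch of the dtype-dispatch pattern used by the kernel.
#include <cstdint>
#include <iostream>
#include <stdexcept>

// Stand-in for the onnx::TensorProto data-type codes the real kernel switches on.
enum class ProtoType { kFloat, kInt64, kUint64, kDouble };

template <typename T>
void ComputeImplSketch() {
  // In the real kernel this would allocate the output tensor for type T,
  // zero it, and set diagonal k to 1.
  std::cout << "computed with element size " << sizeof(T) << " bytes\n";
}

void DispatchSketch(ProtoType dtype) {
  switch (dtype) {
    case ProtoType::kFloat:  return ComputeImplSketch<float>();
    case ProtoType::kInt64:  return ComputeImplSketch<int64_t>();
    case ProtoType::kUint64: return ComputeImplSketch<uint64_t>();
    // A newly supported type is one extra case forwarding to the same template:
    case ProtoType::kDouble: return ComputeImplSketch<double>();
    default: throw std::runtime_error("Unsupported 'dtype' value");
  }
}

int main() {
  DispatchSketch(ProtoType::kDouble);  // prints: computed with element size 8 bytes
  return 0;
}
```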