diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 5f4c104009f53..a35671c461fba 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -285,6 +285,7 @@ set_target_properties(onnx_proto PROPERTIES FOLDER "External/ONNX") # fix a warning in onnx code we can't do anything about if (MSVC) target_compile_options(onnx_proto PRIVATE /wd4146) # unary minus operator applied to unsigned type + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DEIGEN_HAS_C99_MATH") # required to be set explicitly to enable Eigen-Unsupported SpecialFunctions endif() set(onnxruntime_EXTERNAL_DEPENDENCIES gsl onnx_proto) diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index f5303df02c82f..cdcb09e523654 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -195,6 +195,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Con class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, EyeLike); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, float, IsNaN); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, MLFloat16, IsNaN); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Erf); void RegisterOnnxOperatorKernels(std::function fn) { fn(BuildKernel()); @@ -382,6 +383,7 @@ void RegisterOnnxOperatorKernels(std::function fn) { fn(BuildKernel()); fn(BuildKernel()); fn(BuildKernel()); + fn(BuildKernel()); } // Forward declarations of ml op kernels diff --git a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc index 8f48195ea0145..ab05ed82be80c 100644 --- a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc +++ b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "core/providers/cpu/math/element_wise_ops.h" +#include namespace onnxruntime { @@ -311,6 +312,12 @@ ONNX_CPU_OPERATOR_KERNEL( KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), Scale); +ONNX_CPU_OPERATOR_KERNEL( + Erf, + 9, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + Erf); + template Status Add::Compute(OpKernelContext* context) const { return BroadcastTwo( @@ -874,4 +881,16 @@ Status Scale::Compute(OpKernelContext* ctx) const { return Status::OK(); } +template <> +Status Erf::Compute(OpKernelContext* context) const { + auto X_ptr = context->Input(0); + ONNXRUNTIME_ENFORCE(X_ptr != nullptr); + auto& X = *X_ptr; + auto& Y = *context->Output(0, X.Shape()); + + EigenMap(Y) = EigenMap(X).array().erf(); + + return Status::OK(); +} + } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/math/element_wise_ops.h b/onnxruntime/core/providers/cpu/math/element_wise_ops.h index 912d352f39e24..feaa5c0a82b13 100644 --- a/onnxruntime/core/providers/cpu/math/element_wise_ops.h +++ b/onnxruntime/core/providers/cpu/math/element_wise_ops.h @@ -317,6 +317,15 @@ class Scale final : public OpKernel { float scale_; }; +template +class Erf final : public OpKernel { + public: + Erf(const OpKernelInfo& info) : OpKernel(info) { + } + + Status Compute(OpKernelContext* context) const override; +}; + template auto MakeEigenArrayMap(Tensor& t) { return EigenVectorArrayMap(t.template MutableData(), t.Shape().Size()); } template diff --git a/onnxruntime/test/onnx/main.cc b/onnxruntime/test/onnx/main.cc index 4871e1102ceed..77f717987e026 100644 --- a/onnxruntime/test/onnx/main.cc +++ b/onnxruntime/test/onnx/main.cc @@ -326,7 +326,6 @@ int real_main(int argc, char* argv[]) { {"acosh_example", "opset 9 not supported yet"}, {"atanh_example", "opset 9 not supported yet"}, {"sign_model", "opset 9 not supported yet"}, - {"erf", "opset 9 not supported yet"}, {"sign", "opset 9 not supported yet"}, {"scatter_with_axis", "opset 9 not supported yet"}, {"scatter_without_axis", "opset 9 not supported yet"}, diff --git a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc index e6eaf6c99e782..e8845542390fe 100644 --- a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc +++ b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc @@ -869,6 +869,13 @@ TEST(MathOpTest, Scale_Default) { test.Run(); } +TEST(MathOpTest, Erf) { + OpTester test("Erf", 9); + std::vector dims{2, 2}; + test.AddInput("A", dims, {0.5f, 1.0f, 0.7f, 2.0f}); + test.AddOutput("B", dims, {0.5204999f, 0.8427008f, 0.6778012f, 0.9953223f}); + test.Run(); +} } // namespace test } // namespace onnxruntime