Add erfinv operator for calculating inverse error function (#13811)
* add default behaviour for argmax

* prototype of erfinv

* add test

* gpu support

* Revert "add default behaviour for argmax"

This reverts commit 64e9f1a.

* move erfinv to contrib

* edit copyright

* remove atof

* use std and update license

* add license exclude file

* fix per eric's comments

* change license header
Ziyi Mu authored and eric-haibin-lin committed Jan 22, 2019
1 parent eebdd5f commit b86ccf1
Showing 10 changed files with 145 additions and 3 deletions.
1 change: 1 addition & 0 deletions docs/api/python/ndarray/ndarray.md
@@ -659,6 +659,7 @@ The `ndarray` package provides several classes:
relu
sigmoid
erf
erfinv
```
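For reference, here is a minimal usage sketch of the operator documented above, via the Python NDArray API. This is illustrative only: it assumes a build that includes this commit, and the expected values are taken from the operator docstring further down (erfinv(0.5) ≈ 0.4769, erfinv(-1) = -inf).

```python
import mxnet as mx

# erfinv is element-wise; +/-1 map to +/-inf per this commit's
# boundary handling, and out-of-range inputs produce NaN.
x = mx.nd.array([0.0, 0.5, -1.0])
y = mx.nd.erfinv(x)
print(y.asnumpy())  # approximately [0.0, 0.4769, -inf]
```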

### More
1 change: 1 addition & 0 deletions docs/api/python/symbol/symbol.md
@@ -659,6 +659,7 @@ Composite multiple symbols into a new one by an operator.
relu
sigmoid
erf
erfinv
```

### More
105 changes: 105 additions & 0 deletions src/operator/contrib/erfinv-inl.h
@@ -0,0 +1,105 @@
/*
* Copyright (c) 2014 Indiana University
* All rights reserved.
* Written by Prof. Gary L. Pavlis, Dept. of Geol. Sci.,
* Indiana University, Bloomington, IN
* This software is licensed under the New BSD license:
* Redistribution and use in source and binary forms,
* with or without modification, are permitted provided
* that the following conditions are met:
* Redistributions of source code must retain the above
* copyright notice, this list of conditions and the
* following disclaimer.
* Redistributions in binary form must reproduce the
* above copyright notice, this list of conditions and
* the following disclaimer in the documentation and/or
* other materials provided with the distribution.
* Neither the name of Indiana University nor
* the names of its contributors may be used to endorse
* or promote products derived from this software without
* specific prior written permission.
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
* CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED
* WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
* PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL
* THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
* PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF
* USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
* HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER
* IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
* NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
* USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
* POSSIBILITY OF SUCH DAMAGE.
*/
/*
* The next function is taken from
* https://github.com/antelopeusersgroup/antelope_contrib/blob/master/lib/location/libgenloc/erfinv.c.
* Output was modified to be inf or -inf when input is 1 or -1.
*/
#ifndef MXNET_OPERATOR_CONTRIB_ERFINV_INL_H_
#define MXNET_OPERATOR_CONTRIB_ERFINV_INL_H_

#define _USE_MATH_DEFINES

#include <mxnet/base.h>
#include <limits>
#include "math.h"

namespace mxnet {
namespace op {
namespace mshadow_op {

/*! \brief inverse Gauss error function */
struct erfinv : public mxnet_op::tunable {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType v) {
/* Function to calculate inverse error function. Rational approximation
is used to generate an initial approximation, which is then improved to
full accuracy by two steps of Newton's method. Code is a direct
translation of the erfinv m file in matlab version 2.0.
Author: Gary L. Pavlis, Indiana University
Date: February 1996
*/
const double central_range = 0.7;
double y = static_cast<double>(v);
double y_fab = std::fabs(y);
/* working variables */
double x = 0.0;
double z, num, dem;
/* coefficients in rational expansion */
double a[4]={ 0.886226899, -1.645349621, 0.914624893, -0.140543331};
double b[4]={-2.118377725, 1.442710462, -0.329097515, 0.012229801};
double c[4]={-1.970840454, -1.624906493, 3.429567803, 1.641345311};
double d[2]={ 3.543889200, 1.637067800};
if (y_fab > 1.0) {
/* This needs IEEE constant*/
return DType(std::numeric_limits<double>::quiet_NaN());
} else if (y_fab == 1.0) {
return DType((std::copysign(1.0, y))*std::numeric_limits<double>::infinity());
} else if (y_fab <= central_range) {
z = y*y;
num = (((a[3]*z + a[2])*z + a[1])*z + a[0]);
dem = ((((b[3]*z + b[2])*z + b[1])*z + b[0])*z + 1.0);
x = y*num/dem;
} else {
z = std::sqrt(-std::log((1.0-y_fab)/2.0));
num = ((c[3]*z + c[2])*z + c[1])*z + c[0];
dem = (d[1]*z + d[0])*z + 1.0;
x = (std::copysign(1.0, y))*num/dem;
}
/* Two steps of Newton-Raphson correction */
x = x - (std::erf(x) - y)/((2.0/std::sqrt(M_PI))*std::exp(-x*x));
x = x - (std::erf(x) - y)/((2.0/std::sqrt(M_PI))*std::exp(-x*x));

return DType(x);
}
};

} // namespace mshadow_op
} // namespace op
} // namespace mxnet

#endif // MXNET_OPERATOR_CONTRIB_ERFINV_INL_H_
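To see the numerics of `Map` in isolation, the following is a direct Python transcription of the kernel above (same coefficients, same central-range split at 0.7, same two Newton-Raphson steps), checked against `scipy.special.erfinv`. It is purely illustrative and not part of the commit.

```python
import math
from scipy import special

def erfinv_ref(y):
    """Pure-Python transcription of erfinv::Map above."""
    # coefficients of the rational approximations
    a = [0.886226899, -1.645349621, 0.914624893, -0.140543331]
    b = [-2.118377725, 1.442710462, -0.329097515, 0.012229801]
    c = [-1.970840454, -1.624906493, 3.429567803, 1.641345311]
    d = [3.543889200, 1.637067800]
    y_fab = abs(y)
    if y_fab > 1.0:
        return float('nan')
    if y_fab == 1.0:
        return math.copysign(math.inf, y)
    if y_fab <= 0.7:  # central range
        z = y * y
        num = ((a[3]*z + a[2])*z + a[1])*z + a[0]
        dem = ((((b[3]*z + b[2])*z + b[1])*z + b[0])*z + 1.0)
        x = y * num / dem
    else:
        z = math.sqrt(-math.log((1.0 - y_fab) / 2.0))
        num = ((c[3]*z + c[2])*z + c[1])*z + c[0]
        dem = (d[1]*z + d[0])*z + 1.0
        x = math.copysign(1.0, y) * num / dem
    # two Newton-Raphson corrections on erf(x) - y = 0
    for _ in range(2):
        x -= (math.erf(x) - y) / ((2.0 / math.sqrt(math.pi)) * math.exp(-x*x))
    return x

assert abs(erfinv_ref(0.5) - special.erfinv(0.5)) < 1e-12
```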
3 changes: 3 additions & 0 deletions src/operator/mshadow_op.h
@@ -31,6 +31,7 @@
#include "math_functions-inl.h"
#include "special_functions-inl.h"
#include "./operator_tune.h"
#include "./contrib/erfinv-inl.h"

#ifdef __CUDACC__
#include <cuda_fp16.h>
@@ -169,6 +170,8 @@ struct softrelu : public mxnet_op::tunable {

MXNET_UNARY_MATH_OP(softrelu_grad, -math::expm1(-a));

MXNET_UNARY_MATH_OP(erfinv_grad, 0.5 * math::sqrt(PI) * math::exp(math::sqr(erfinv::Map(a))));

MXNET_UNARY_MATH_OP(erf_grad, 2.0 / math::sqrt(PI) * math::exp(-(a * a)));

MXNET_SIMPLE_UNARY_MATH_OP(erf);
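A note on where the `erfinv_grad` expression above comes from — it is the standard inverse-function-rule derivation, not spelled out in the commit itself:

```latex
% derivative of erf, then the inverse function rule:
\frac{d}{dx}\,\mathrm{erf}(x) = \frac{2}{\sqrt{\pi}}\, e^{-x^{2}}
\quad\Longrightarrow\quad
\frac{d}{dy}\,\mathrm{erf}^{-1}(y)
    = \frac{1}{\mathrm{erf}'\!\bigl(\mathrm{erf}^{-1}(y)\bigr)}
    = \frac{\sqrt{\pi}}{2}\, e^{\bigl(\mathrm{erf}^{-1}(y)\bigr)^{2}}
```

which is exactly `0.5 * math::sqrt(PI) * math::exp(math::sqr(erfinv::Map(a)))`. Note that the backward op is registered with `ElemwiseGradUseIn` below, so `a` here is the forward *input* and `erfinv::Map(a)` re-computes the forward output.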
4 changes: 3 additions & 1 deletion src/operator/operator_tune.cc
@@ -234,9 +234,11 @@ IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::log2); // NOLINT()
IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::log2_grad); // NOLINT()
IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::log10); // NOLINT()
IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::log10_grad); // NOLINT()
IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::sin); // NOLINT()
IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::erf); // NOLINT()
IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::erf_grad); // NOLINT()
IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::erfinv); // NOLINT()
IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::erfinv_grad); // NOLINT()
IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::sin); // NOLINT()
IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::sin_grad); // NOLINT()
IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::sinh); // NOLINT()
IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::sinh_grad); // NOLINT()
16 changes: 16 additions & 0 deletions src/operator/tensor/elemwise_unary_op_basic.cc
@@ -916,6 +916,22 @@ MXNET_OPERATOR_REGISTER_BINARY(_backward_erf)
.set_attr<FCompute>("FCompute<cpu>",
ElemwiseBinaryOp::Compute<cpu, unary_bwd<mshadow_op::erf_grad>>);

// erfinv
MXNET_OPERATOR_REGISTER_UNARY(erfinv)
.describe(R"code(Returns element-wise inverse gauss error function of the input.
Example::
erfinv([0, 0.5., -1.]) = [0., 0.4769, -inf]
)code" ADD_FILELINE)
.set_attr<FCompute>("FCompute<cpu>", UnaryOp::Compute<cpu, mshadow_op::erfinv>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_erfinv"});

MXNET_OPERATOR_REGISTER_BINARY(_backward_erfinv)
.set_attr<FCompute>("FCompute<cpu>",
ElemwiseBinaryOp::Compute<cpu, unary_bwd<mshadow_op::erfinv_grad>>);

// rcbrt
MXNET_OPERATOR_REGISTER_UNARY(rcbrt)
.describe(R"code(Returns element-wise inverse cube-root value of the input.
8 changes: 8 additions & 0 deletions src/operator/tensor/elemwise_unary_op_basic.cu
@@ -62,6 +62,14 @@ NNVM_REGISTER_OP(_backward_erf)
.set_attr<FCompute>("FCompute<gpu>",
ElemwiseBinaryOp::Compute<gpu, unary_bwd<mshadow_op::erf_grad>>);

// erfinv
NNVM_REGISTER_OP(erfinv)
.set_attr<FCompute>("FCompute<gpu>", UnaryOp::Compute<gpu, mshadow_op::erfinv>);

NNVM_REGISTER_OP(_backward_erfinv)
.set_attr<FCompute>("FCompute<gpu>",
ElemwiseBinaryOp::Compute<gpu, unary_bwd<mshadow_op::erfinv_grad>>);

// copy
NNVM_REGISTER_OP(_copy)
.set_attr<FCompute>("FCompute<gpu>", UnaryOp::IdentityCompute<gpu>)
3 changes: 2 additions & 1 deletion tests/nightly/apache_rat_license_check/rat-excludes
@@ -35,6 +35,7 @@ _mask.pyx
coco.py
base.pyi
special_functions-inl.h
erfinv-inl.h
im2col.cuh
im2col.h
pool.h
@@ -49,4 +50,4 @@ deformable_im2col.h
REQUIRE
include/*
.*.iml
.*.json.ref
.*.json.ref
6 changes: 5 additions & 1 deletion tests/python/unittest/test_operator.py
@@ -3500,7 +3500,11 @@ def test_special_functions_using_scipy():

# erf
mathematical_core("erf", lambda x: mx.sym.erf(x), lambda x: scipy_special.erf(x),
lambda x: 2.0 / math.sqrt(math.pi) * math.exp(-(x ** 2)), 0.5, 0.5)
lambda x: 2.0 / math.sqrt(math.pi) * np.exp(-(x ** 2)), 0.5, 0.5)

# erfinv
mathematical_core("erfinv", lambda x: mx.sym.erfinv(x), lambda x: scipy_special.erfinv(x),
lambda x: 0.5 * math.sqrt(math.pi) * np.exp(scipy_special.erfinv(x) ** 2), 0.5, 0.5)


def rounding(name, forward_mxnet_call, forward_numpy_call, data_init=5., grad_init=2.):
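Beyond the `mathematical_core` hook above, the new gradient can also be sanity-checked end to end with `autograd` — a minimal sketch, assuming `scipy` is installed (the `1e-5` tolerance is an arbitrary choice):

```python
import math
import numpy as np
import mxnet as mx
from scipy import special as scipy_special

x = mx.nd.array([0.5])
x.attach_grad()
with mx.autograd.record():
    y = mx.nd.erfinv(x)
y.backward()

# analytic derivative: sqrt(pi)/2 * exp(erfinv(x)^2)
expected = 0.5 * math.sqrt(math.pi) * np.exp(scipy_special.erfinv(0.5) ** 2)
assert np.allclose(x.grad.asnumpy(), expected, atol=1e-5)
```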
1 change: 1 addition & 0 deletions tools/license_header.py
@@ -84,6 +84,7 @@
'src/operator/nn/im2col.cuh',

# Licenses in headers
'src/operator/contrib/erfinv-inl.h',
'docs/_static/searchtools_custom.js',
'docs/_static/js/clipboard.js',
'docs/_static/js/clipboard.min.js',

2 comments on commit b86ccf1

@szha (Member) commented on b86ccf1 on Jan 28, 2019:

This commit seems to have added a number of warnings:

src/operator/tensor/./.././contrib/erfinv-inl.h(79): warning: calling a constexpr __host__ function("quiet_NaN") from a __host__ __device__ function("Map") is not allowed. The experimental flag '--expt-relaxed-constexpr' can be used to allow this.
          detected during:
            instantiation of "void mxnet::op::mxnet_op::op_with_req<OP, req>::Map(mxnet::index_t, DType *, const DType *) [with OP=mxnet::op::mshadow_op::erfinv, req=1, DType=float]"
src/operator/tensor/./../mxnet_op.h(659): here
            instantiation of "void mxnet::op::mxnet_op::mxnet_generic_kernel<OP,Args...>(int, Args...) [with OP=mxnet::op::mxnet_op::op_with_req<mxnet::op::mshadow_op::erfinv, 1>, Args=<float *, float *>]"
src/operator/tensor/./../mxnet_op.h(679): here
            instantiation of "void mxnet::op::mxnet_op::Kernel<OP, mxnet::gpu>::Launch(mshadow::Stream<mshadow::gpu> *, int, Args...) [with OP=mxnet::op::mxnet_op::op_with_req<mxnet::op::mshadow_op::erfinv, 1>, Args=<float *, float *>]"
src/operator/tensor/./elemwise_unary_op.h(243): here
            instantiation of "void mxnet::op::UnaryOp::Compute<xpu,OP>(const nnvm::NodeAttrs &, const mxnet::OpContext &, const std::vector<mxnet::TBlob, std::allocator<mxnet::TBlob>> &, const std::vector<mxnet::OpReqType, std::allocator<mxnet::OpReqType>> &, const std::vector<mxnet::TBlob, std::allocator<mxnet::TBlob>> &) [with xpu=mxnet::gpu, OP=mxnet::op::mshadow_op::erfinv]"
src/operator/tensor/elemwise_unary_op_basic.cu(67): here

src/operator/tensor/./.././contrib/erfinv-inl.h(81): warning: calling a constexpr __host__ function("infinity") from a __host__ __device__ function("Map") is not allowed. The experimental flag '--expt-relaxed-constexpr' can be used to allow this.
          detected during:
            instantiation of "void mxnet::op::mxnet_op::op_with_req<OP, req>::Map(mxnet::index_t, DType *, const DType *) [with OP=mxnet::op::mshadow_op::erfinv, req=1, DType=float]"
src/operator/tensor/./../mxnet_op.h(659): here
            instantiation of "void mxnet::op::mxnet_op::mxnet_generic_kernel<OP,Args...>(int, Args...) [with OP=mxnet::op::mxnet_op::op_with_req<mxnet::op::mshadow_op::erfinv, 1>, Args=<float *, float *>]"
src/operator/tensor/./../mxnet_op.h(679): here
            instantiation of "void mxnet::op::mxnet_op::Kernel<OP, mxnet::gpu>::Launch(mshadow::Stream<mshadow::gpu> *, int, Args...) [with OP=mxnet::op::mxnet_op::op_with_req<mxnet::op::mshadow_op::erfinv, 1>, Args=<float *, float *>]"
src/operator/tensor/./elemwise_unary_op.h(243): here
            instantiation of "void mxnet::op::UnaryOp::Compute<xpu,OP>(const nnvm::NodeAttrs &, const mxnet::OpContext &, const std::vector<mxnet::TBlob, std::allocator<mxnet::TBlob>> &, const std::vector<mxnet::OpReqType, std::allocator<mxnet::OpReqType>> &, const std::vector<mxnet::TBlob, std::allocator<mxnet::TBlob>> &) [with xpu=mxnet::gpu, OP=mxnet::op::mshadow_op::erfinv]"
src/operator/tensor/elemwise_unary_op_basic.cu(67): here

src/operator/tensor/./.././contrib/erfinv-inl.h(79): warning: floating-point value does not fit in required integral type
          detected during:
            instantiation of "DType mxnet::op::mshadow_op::erfinv::Map(DType) [with DType=int8_t]"
src/operator/tensor/./../mxnet_op.h(437): here
            instantiation of "void mxnet::op::mxnet_op::op_with_req<OP, req>::Map(mxnet::index_t, DType *, const DType *) [with OP=mxnet::op::mshadow_op::erfinv, req=1, DType=int8_t]"
src/operator/tensor/./../mxnet_op.h(659): here
            instantiation of "void mxnet::op::mxnet_op::mxnet_generic_kernel<OP,Args...>(int, Args...) [with OP=mxnet::op::mxnet_op::op_with_req<mxnet::op::mshadow_op::erfinv, 1>, Args=<int8_t *, int8_t *>]"
src/operator/tensor/./../mxnet_op.h(679): here
            instantiation of "void mxnet::op::mxnet_op::Kernel<OP, mxnet::gpu>::Launch(mshadow::Stream<mshadow::gpu> *, int, Args...) [with OP=mxnet::op::mxnet_op::op_with_req<mxnet::op::mshadow_op::erfinv, 1>, Args=<int8_t *, int8_t *>]"
src/operator/tensor/./elemwise_unary_op.h(243): here
            instantiation of "void mxnet::op::UnaryOp::Compute<xpu,OP>(const nnvm::NodeAttrs &, const mxnet::OpContext &, const std::vector<mxnet::TBlob, std::allocator<mxnet::TBlob>> &, const std::vector<mxnet::OpReqType, std::allocator<mxnet::OpReqType>> &, const std::vector<mxnet::TBlob, std::allocator<mxnet::TBlob>> &) [with xpu=mxnet::gpu, OP=mxnet::op::mshadow_op::erfinv]"
src/operator/tensor/elemwise_unary_op_basic.cu(67): here

src/operator/tensor/./.././contrib/erfinv-inl.h(79): warning: floating-point value does not fit in required integral type
          detected during:
            instantiation of "DType mxnet::op::mshadow_op::erfinv::Map(DType) [with DType=int32_t]"
src/operator/tensor/./../mxnet_op.h(437): here
            instantiation of "void mxnet::op::mxnet_op::op_with_req<OP, req>::Map(mxnet::index_t, DType *, const DType *) [with OP=mxnet::op::mshadow_op::erfinv, req=1, DType=int32_t]"
src/operator/tensor/./../mxnet_op.h(659): here
            instantiation of "void mxnet::op::mxnet_op::mxnet_generic_kernel<OP,Args...>(int, Args...) [with OP=mxnet::op::mxnet_op::op_with_req<mxnet::op::mshadow_op::erfinv, 1>, Args=<int32_t *, int32_t *>]"
src/operator/tensor/./../mxnet_op.h(679): here
            instantiation of "void mxnet::op::mxnet_op::Kernel<OP, mxnet::gpu>::Launch(mshadow::Stream<mshadow::gpu> *, int, Args...) [with OP=mxnet::op::mxnet_op::op_with_req<mxnet::op::mshadow_op::erfinv, 1>, Args=<int32_t *, int32_t *>]"
src/operator/tensor/./elemwise_unary_op.h(243): here
            instantiation of "void mxnet::op::UnaryOp::Compute<xpu,OP>(const nnvm::NodeAttrs &, const mxnet::OpContext &, const std::vector<mxnet::TBlob, std::allocator<mxnet::TBlob>> &, const std::vector<mxnet::OpReqType, std::allocator<mxnet::OpReqType>> &, const std::vector<mxnet::TBlob, std::allocator<mxnet::TBlob>> &) [with xpu=mxnet::gpu, OP=mxnet::op::mshadow_op::erfinv]"
src/operator/tensor/elemwise_unary_op_basic.cu(67): here

src/operator/tensor/./.././contrib/erfinv-inl.h(79): warning: floating-point value does not fit in required integral type
          detected during:
            instantiation of "DType mxnet::op::mshadow_op::erfinv::Map(DType) [with DType=int64_t]"
src/operator/tensor/./../mxnet_op.h(437): here
            instantiation of "void mxnet::op::mxnet_op::op_with_req<OP, req>::Map(mxnet::index_t, DType *, const DType *) [with OP=mxnet::op::mshadow_op::erfinv, req=1, DType=int64_t]"
src/operator/tensor/./../mxnet_op.h(659): here
            instantiation of "void mxnet::op::mxnet_op::mxnet_generic_kernel<OP,Args...>(int, Args...) [with OP=mxnet::op::mxnet_op::op_with_req<mxnet::op::mshadow_op::erfinv, 1>, Args=<int64_t *, int64_t *>]"
src/operator/tensor/./../mxnet_op.h(679): here
            instantiation of "void mxnet::op::mxnet_op::Kernel<OP, mxnet::gpu>::Launch(mshadow::Stream<mshadow::gpu> *, int, Args...) [with OP=mxnet::op::mxnet_op::op_with_req<mxnet::op::mshadow_op::erfinv, 1>, Args=<int64_t *, int64_t *>]"
src/operator/tensor/./elemwise_unary_op.h(243): here
            instantiation of "void mxnet::op::UnaryOp::Compute<xpu,OP>(const nnvm::NodeAttrs &, const mxnet::OpContext &, const std::vector<mxnet::TBlob, std::allocator<mxnet::TBlob>> &, const std::vector<mxnet::OpReqType, std::allocator<mxnet::OpReqType>> &, const std::vector<mxnet::TBlob, std::allocator<mxnet::TBlob>> &) [with xpu=mxnet::gpu, OP=mxnet::op::mshadow_op::erfinv]"
src/operator/tensor/elemwise_unary_op_basic.cu(67): here

@eric-haibin-lin (Member) commented:
@rondogency could you take a look?
