From 64e9f1a9e3c9cabf312b8d80b3520b22da31c0b6 Mon Sep 17 00:00:00 2001 From: rondogency Date: Mon, 24 Dec 2018 16:44:56 -0500 Subject: [PATCH 01/12] add default behaviour for argmax --- src/operator/tensor/broadcast_reduce_op.h | 6 +++--- tests/python/unittest/test_ndarray.py | 26 ++++++++++++++++++++++- 2 files changed, 28 insertions(+), 4 deletions(-) diff --git a/src/operator/tensor/broadcast_reduce_op.h b/src/operator/tensor/broadcast_reduce_op.h index 1edcb5a74a77..f22f263b6ee9 100644 --- a/src/operator/tensor/broadcast_reduce_op.h +++ b/src/operator/tensor/broadcast_reduce_op.h @@ -88,11 +88,11 @@ struct ReduceAxisParam : public dmlc::Parameter { dmlc::optional axis; bool keepdims; DMLC_DECLARE_PARAMETER(ReduceAxisParam) { - DMLC_DECLARE_FIELD(axis).set_default(dmlc::optional()) + DMLC_DECLARE_FIELD(axis).set_default(dmlc::optional(-1)) .describe("The axis along which to perform the reduction. " "Negative values means indexing from right to left. " - "``Requires axis to be set as int, because global reduction " - "is not supported yet.``"); + "``The axis need to be set as an int. If the axis is " + "not set, the rightmost axis will be reduced.``"); DMLC_DECLARE_FIELD(keepdims).set_default(false) .describe("If this is set to `True`, the reduced axis is left " "in the result as dimension with size one."); diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py index 0aa48553901b..897079dc39c8 100644 --- a/tests/python/unittest/test_ndarray.py +++ b/tests/python/unittest/test_ndarray.py @@ -477,7 +477,6 @@ def test_dot(): C = mx.nd.dot(A, B, transpose_a=True, transpose_b=True) assert_almost_equal(c, C.asnumpy(), atol=atol) - @with_seed() def test_reduce(): sample_num = 200 @@ -524,6 +523,31 @@ def test_reduce_inner(numpy_reduce_func, nd_reduce_func, multi_axes): keepdims:np_reduce(np.float32(data), axis, keepdims, np.argmin), mx.nd.argmin, False) +@with_seed() +def test_argmax_argmin(): + # test optional parameters + # test name : input data, argmax result, argmin result + tests = { + 'axis_0' : [[[1, 2, 3], [4, 5, 6]], [1, 1, 1], [0, 0, 0]], + 'keep_dims' : [[[1, 2, 3], [4, 5, 6]], [[2], [2]], [[0], [0]]], + 'axis_none' : [[1, 2, 3, 4], 3, 0] + } + + arg_max = mx.nd.array(tests['axis_0'][0]).argmax(axis=0) + arg_min = mx.nd.array(tests['axis_0'][0]).argmin(axis=0) + assert_almost_equal(arg_max.asnumpy(), tests['axis_0'][1]) + assert_almost_equal(arg_min.asnumpy(), tests['axis_0'][2]) + + arg_max = mx.nd.array(tests['keep_dims'][0]).argmax(axis=1, keepdims=True) + arg_min = mx.nd.array(tests['keep_dims'][0]).argmin(axis=1, keepdims=True) + assert_almost_equal(arg_max.asnumpy(), tests['keep_dims'][1]) + assert_almost_equal(arg_min.asnumpy(), tests['keep_dims'][2]) + + arg_max = mx.nd.array(tests['axis_none'][0]).argmax() + arg_min = mx.nd.array(tests['axis_none'][0]).argmin() + assert_almost_equal(arg_max.asnumpy(), tests['axis_none'][1]) + assert_almost_equal(arg_min.asnumpy(), tests['axis_none'][2]) + @with_seed() def test_broadcast(): sample_num = 1000 From d610bfa515b7bf727c90ee06dc91ab0e8a6c1239 Mon Sep 17 00:00:00 2001 From: rondogency Date: Mon, 7 Jan 2019 23:40:59 -0500 Subject: [PATCH 02/12] prototype of erfvin --- docs/api/python/ndarray/ndarray.md | 1 + docs/api/python/symbol/symbol.md | 1 + src/operator/mshadow_op.h | 87 +++++++++++++++++++ src/operator/operator_tune.cc | 4 +- .../tensor/elemwise_unary_op_basic.cc | 16 ++++ 5 files changed, 108 insertions(+), 1 deletion(-) diff --git a/docs/api/python/ndarray/ndarray.md b/docs/api/python/ndarray/ndarray.md index 6419c4ed4067..2df18c286ba7 100644 --- a/docs/api/python/ndarray/ndarray.md +++ b/docs/api/python/ndarray/ndarray.md @@ -659,6 +659,7 @@ The `ndarray` package provides several classes: relu sigmoid erf + erfinv ``` ### More diff --git a/docs/api/python/symbol/symbol.md b/docs/api/python/symbol/symbol.md index 9eba2618065b..0fc2aa7c6cf2 100644 --- a/docs/api/python/symbol/symbol.md +++ b/docs/api/python/symbol/symbol.md @@ -659,6 +659,7 @@ Composite multiple symbols into a new one by an operator. relu sigmoid erf + erfinv ``` ### More diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h index 0b20a02634c3..9199c430beed 100644 --- a/src/operator/mshadow_op.h +++ b/src/operator/mshadow_op.h @@ -169,6 +169,93 @@ struct softrelu : public mxnet_op::tunable { MXNET_UNARY_MATH_OP(softrelu_grad, -math::expm1(-a)); + +/* The next function is taken from +https://github.com/antelopeusersgroup/antelope_contrib/blob/master/lib/location/libgenloc/erfinv.c. +Below is the copyright. Output was modified to be inf or -inf when input is 1 or -1. + + Copyright (c) 2014 Indiana University + All rights reserved. + Written by Prof. Gary L. Pavlis, Dept. of Geol. Sci., + Indiana University, Bloomington, IN + This software is licensed under the New BSD license: + Redistribution and use in source and binary forms, + with or without modification, are permitted provided + that the following conditions are met: + Redistributions of source code must retain the above + copyright notice, this list of conditions and the + following disclaimer. + Redistributions in binary form must reproduce the + above copyright notice, this list of conditions and + the following disclaimer in the documentation and/or + other materials provided with the distribution. + Neither the name of Indiana University nor + the names of its contributors may be used to endorse + or promote products derived from this software without + specific prior written permission. + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND + CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED + WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A + PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL + THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY + DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF + USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER + IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE + USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + POSSIBILITY OF SUCH DAMAGE. +*/ + +#define CENTRAL_RANGE 0.7 +/*! \brief inverse gauss error function */ +struct erfinv : public mxnet_op::tunable { + template + MSHADOW_XINLINE static DType Map(DType v) { + /* Function to calculate inverse error function. Rational approximation + is used to generate an initial approximation, which is then improved to + full accuracy by two steps of Newton's method. Code is a direct + translation of the erfinv m file in matlab version 2.0. + Author: Gary L. Pavlis, Indiana University + Date: February 1996 + */ + double y = static_cast(v); + double x,z,num,dem; /*working variables */ + /* coefficients in rational expansion */ + double a[4]={ 0.886226899, -1.645349621, 0.914624893, -0.140543331}; + double b[4]={-2.118377725, 1.442710462, -0.329097515, 0.012229801}; + double c[4]={-1.970840454, -1.624906493, 3.429567803, 1.641345311}; + double d[2]={ 3.543889200, 1.637067800}; + if(fabs(y) > 1.0) return DType(atof("NaN")); /* This needs IEEE constant*/ + if(fabs(y) == 1.0) return DType((copysign(1.0,y))*atof("INFINITY")); + if(fabs(y) <= CENTRAL_RANGE ) + { + z = y*y; + num = (((a[3]*z + a[2])*z + a[1])*z + a[0]); + dem = ((((b[3]*z + b[2])*z + b[1])*z +b[0])*z + 1.0); + x = y*num/dem; + } + else if( (fabs(y) > CENTRAL_RANGE) && (fabs(y) < 1.0) ) + { + z = sqrt(-log((1.0-fabs(y))/2.0)); + num = ((c[3]*z + c[2])*z + c[1])*z + c[0]; + dem = (d[1]*z + d[0])*z + 1.0; + x = (copysign(1.0,y))*num/dem; + } + /* Two steps of Newton-Raphson correction */ + x = x - (erf(x) - y)/( (2.0/sqrt(M_PI))*exp(-x*x)); + x = x - (erf(x) - y)/( (2.0/sqrt(M_PI))*exp(-x*x)); + + return DType(x); + } +}; +#undef CENTRAL_RANGE + +MXNET_UNARY_MATH_OP(erfinv_grad, 0.5 * math::sqrt(PI) * math::exp(math::sqr(mshadow_op::erfinv::Map(a)))); + MXNET_UNARY_MATH_OP(erf_grad, 2.0 / math::sqrt(PI) * math::exp(-(a * a))); MXNET_SIMPLE_UNARY_MATH_OP(erf); diff --git a/src/operator/operator_tune.cc b/src/operator/operator_tune.cc index 2018e80cb48b..56d35b23b369 100644 --- a/src/operator/operator_tune.cc +++ b/src/operator/operator_tune.cc @@ -234,9 +234,11 @@ IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::log2); // NOLINT() IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::log2_grad); // NOLINT() IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::log10); // NOLINT() IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::log10_grad); // NOLINT() -IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::sin); // NOLINT() IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::erf); // NOLINT() IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::erf_grad); // NOLINT() +IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::erfinv); // NOLINT() +IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::erfinv_grad); // NOLINT() +IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::sin); // NOLINT() IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::sin_grad); // NOLINT() IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::sinh); // NOLINT() IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::sinh_grad); // NOLINT() diff --git a/src/operator/tensor/elemwise_unary_op_basic.cc b/src/operator/tensor/elemwise_unary_op_basic.cc index 7f69395d1c87..d752c26dc50e 100644 --- a/src/operator/tensor/elemwise_unary_op_basic.cc +++ b/src/operator/tensor/elemwise_unary_op_basic.cc @@ -916,6 +916,22 @@ MXNET_OPERATOR_REGISTER_BINARY(_backward_erf) .set_attr("FCompute", ElemwiseBinaryOp::Compute>); +// erfinv +MXNET_OPERATOR_REGISTER_UNARY(erfinv) +.describe(R"code(Returns element-wise inverse gauss error function of the input. + +Example:: + + erfinv([0, 0.5., -1.]) = [0., 0.4769, -inf] + +)code" ADD_FILELINE) +.set_attr("FCompute", UnaryOp::Compute) +.set_attr("FGradient", ElemwiseGradUseIn{"_backward_erfinv"}); + +MXNET_OPERATOR_REGISTER_BINARY(_backward_erfinv) +.set_attr("FCompute", + ElemwiseBinaryOp::Compute>); + // rcbrt MXNET_OPERATOR_REGISTER_UNARY(rcbrt) .describe(R"code(Returns element-wise inverse cube-root value of the input. From 57d2bbaf7ca952d871624959daaff987f2ad65b3 Mon Sep 17 00:00:00 2001 From: rondogency Date: Tue, 8 Jan 2019 23:38:13 -0500 Subject: [PATCH 03/12] add test --- tests/python/unittest/test_operator.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index d8e80d7d6938..f15d08622966 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -3500,7 +3500,11 @@ def test_special_functions_using_scipy(): # erf mathematical_core("erf", lambda x: mx.sym.erf(x), lambda x: scipy_special.erf(x), - lambda x: 2.0 / math.sqrt(math.pi) * math.exp(-(x ** 2)), 0.5, 0.5) + lambda x: 2.0 / math.sqrt(math.pi) * np.exp(-(x ** 2)), 0.5, 0.5) + + # erfinv + mathematical_core("erfinv", lambda x: mx.sym.erfinv(x), lambda x: scipy_special.erfinv(x), + lambda x: 0.5 * math.sqrt(math.pi) * np.exp(scipy_special.erfinv(x) ** 2), 0.5, 0.5) def rounding(name, forward_mxnet_call, forward_numpy_call, data_init=5., grad_init=2.): From 08aec73db460206bd2e78fcd225d24bbe7e17fcb Mon Sep 17 00:00:00 2001 From: rondogency Date: Tue, 8 Jan 2019 23:50:11 -0500 Subject: [PATCH 04/12] gpu support --- src/operator/tensor/elemwise_unary_op_basic.cu | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/operator/tensor/elemwise_unary_op_basic.cu b/src/operator/tensor/elemwise_unary_op_basic.cu index 14f2be02ab1a..642cb0e6e48b 100644 --- a/src/operator/tensor/elemwise_unary_op_basic.cu +++ b/src/operator/tensor/elemwise_unary_op_basic.cu @@ -62,6 +62,14 @@ NNVM_REGISTER_OP(_backward_erf) .set_attr("FCompute", ElemwiseBinaryOp::Compute>); +// erfinv +NNVM_REGISTER_OP(erfinv) +.set_attr("FCompute", UnaryOp::Compute); + +NNVM_REGISTER_OP(_backward_erfinv) +.set_attr("FCompute", + ElemwiseBinaryOp::Compute>); + // copy NNVM_REGISTER_OP(_copy) .set_attr("FCompute", UnaryOp::IdentityCompute) From 006b2c827f87c824292f221b14571ac34f846b19 Mon Sep 17 00:00:00 2001 From: rondogency Date: Wed, 9 Jan 2019 00:13:12 -0500 Subject: [PATCH 05/12] Revert "add default behaviour for argmax" This reverts commit 64e9f1a9e3c9cabf312b8d80b3520b22da31c0b6. --- src/operator/tensor/broadcast_reduce_op.h | 6 +++--- tests/python/unittest/test_ndarray.py | 26 +---------------------- 2 files changed, 4 insertions(+), 28 deletions(-) diff --git a/src/operator/tensor/broadcast_reduce_op.h b/src/operator/tensor/broadcast_reduce_op.h index f22f263b6ee9..1edcb5a74a77 100644 --- a/src/operator/tensor/broadcast_reduce_op.h +++ b/src/operator/tensor/broadcast_reduce_op.h @@ -88,11 +88,11 @@ struct ReduceAxisParam : public dmlc::Parameter { dmlc::optional axis; bool keepdims; DMLC_DECLARE_PARAMETER(ReduceAxisParam) { - DMLC_DECLARE_FIELD(axis).set_default(dmlc::optional(-1)) + DMLC_DECLARE_FIELD(axis).set_default(dmlc::optional()) .describe("The axis along which to perform the reduction. " "Negative values means indexing from right to left. " - "``The axis need to be set as an int. If the axis is " - "not set, the rightmost axis will be reduced.``"); + "``Requires axis to be set as int, because global reduction " + "is not supported yet.``"); DMLC_DECLARE_FIELD(keepdims).set_default(false) .describe("If this is set to `True`, the reduced axis is left " "in the result as dimension with size one."); diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py index 897079dc39c8..0aa48553901b 100644 --- a/tests/python/unittest/test_ndarray.py +++ b/tests/python/unittest/test_ndarray.py @@ -477,6 +477,7 @@ def test_dot(): C = mx.nd.dot(A, B, transpose_a=True, transpose_b=True) assert_almost_equal(c, C.asnumpy(), atol=atol) + @with_seed() def test_reduce(): sample_num = 200 @@ -523,31 +524,6 @@ def test_reduce_inner(numpy_reduce_func, nd_reduce_func, multi_axes): keepdims:np_reduce(np.float32(data), axis, keepdims, np.argmin), mx.nd.argmin, False) -@with_seed() -def test_argmax_argmin(): - # test optional parameters - # test name : input data, argmax result, argmin result - tests = { - 'axis_0' : [[[1, 2, 3], [4, 5, 6]], [1, 1, 1], [0, 0, 0]], - 'keep_dims' : [[[1, 2, 3], [4, 5, 6]], [[2], [2]], [[0], [0]]], - 'axis_none' : [[1, 2, 3, 4], 3, 0] - } - - arg_max = mx.nd.array(tests['axis_0'][0]).argmax(axis=0) - arg_min = mx.nd.array(tests['axis_0'][0]).argmin(axis=0) - assert_almost_equal(arg_max.asnumpy(), tests['axis_0'][1]) - assert_almost_equal(arg_min.asnumpy(), tests['axis_0'][2]) - - arg_max = mx.nd.array(tests['keep_dims'][0]).argmax(axis=1, keepdims=True) - arg_min = mx.nd.array(tests['keep_dims'][0]).argmin(axis=1, keepdims=True) - assert_almost_equal(arg_max.asnumpy(), tests['keep_dims'][1]) - assert_almost_equal(arg_min.asnumpy(), tests['keep_dims'][2]) - - arg_max = mx.nd.array(tests['axis_none'][0]).argmax() - arg_min = mx.nd.array(tests['axis_none'][0]).argmin() - assert_almost_equal(arg_max.asnumpy(), tests['axis_none'][1]) - assert_almost_equal(arg_min.asnumpy(), tests['axis_none'][2]) - @with_seed() def test_broadcast(): sample_num = 1000 From 5545e67e29e1a0834ad7530f0b2810c683ac8ba3 Mon Sep 17 00:00:00 2001 From: rondogency Date: Sun, 13 Jan 2019 00:26:02 -0500 Subject: [PATCH 06/12] move erfinv to contrib --- src/operator/contrib/erfinv-inl.h | 128 ++++++++++++++++++++++++++++++ src/operator/mshadow_op.h | 88 +------------------- 2 files changed, 130 insertions(+), 86 deletions(-) create mode 100644 src/operator/contrib/erfinv-inl.h diff --git a/src/operator/contrib/erfinv-inl.h b/src/operator/contrib/erfinv-inl.h new file mode 100644 index 000000000000..178fbfc074a1 --- /dev/null +++ b/src/operator/contrib/erfinv-inl.h @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file erfinv-inl.h + * \brief + * \author Ziyi Mu +*/ + +/* The next function is taken from +https://github.com/antelopeusersgroup/antelope_contrib/blob/master/lib/location/libgenloc/erfinv.c. +Below is the copyright. Output was modified to be inf or -inf when input is 1 or -1. + + Copyright (c) 2014 Indiana University + All rights reserved. + Written by Prof. Gary L. Pavlis, Dept. of Geol. Sci., + Indiana University, Bloomington, IN + This software is licensed under the New BSD license: + Redistribution and use in source and binary forms, + with or without modification, are permitted provided + that the following conditions are met: + Redistributions of source code must retain the above + copyright notice, this list of conditions and the + following disclaimer. + Redistributions in binary form must reproduce the + above copyright notice, this list of conditions and + the following disclaimer in the documentation and/or + other materials provided with the distribution. + Neither the name of Indiana University nor + the names of its contributors may be used to endorse + or promote products derived from this software without + specific prior written permission. + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND + CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED + WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A + PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL + THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY + DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF + USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER + IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE + USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + POSSIBILITY OF SUCH DAMAGE. +*/ + +#ifndef MXNET_OPERATOR_CONTRIB_ERFINV_INL_H_ +#define MXNET_OPERATOR_CONTRIB_ERFINV_INL_H_ + +#include +#include "math.h" + +namespace mxnet { +namespace op { +namespace mshadow_op { + +#define CENTRAL_RANGE 0.7 + +/*! \brief inverse gauss error function */ +struct erfinv : public mxnet_op::tunable { + template + MSHADOW_XINLINE static DType Map(DType v) { + /* Function to calculate inverse error function. Rational approximation + is used to generate an initial approximation, which is then improved to + full accuracy by two steps of Newton's method. Code is a direct + translation of the erfinv m file in matlab version 2.0. + Author: Gary L. Pavlis, Indiana University + Date: February 1996 + */ + double y = static_cast(v); + /*working variables */ + double x = 0.0; + double z, num, dem; + /* coefficients in rational expansion */ + double a[4]={ 0.886226899, -1.645349621, 0.914624893, -0.140543331}; + double b[4]={-2.118377725, 1.442710462, -0.329097515, 0.012229801}; + double c[4]={-1.970840454, -1.624906493, 3.429567803, 1.641345311}; + double d[2]={ 3.543889200, 1.637067800}; + if (fabs(y) > 1.0) { + /* This needs IEEE constant*/ + return DType(atof("NaN")); + } else if (fabs(y) == 1.0) { + return DType((copysign(1.0, y))*atof("INFINITY")); + } else if (fabs(y) <= CENTRAL_RANGE) { + z = y*y; + num = (((a[3]*z + a[2])*z + a[1])*z + a[0]); + dem = ((((b[3]*z + b[2])*z + b[1])*z +b[0])*z + 1.0); + x = y*num/dem; + } else { + z = sqrt(-log((1.0-fabs(y))/2.0)); + num = ((c[3]*z + c[2])*z + c[1])*z + c[0]; + dem = (d[1]*z + d[0])*z + 1.0; + x = (copysign(1.0, y))*num/dem; + } + /* Two steps of Newton-Raphson correction */ + x = x - (erf(x) - y)/((2.0/sqrt(M_PI))*exp(-x*x)); + x = x - (erf(x) - y)/((2.0/sqrt(M_PI))*exp(-x*x)); + + return DType(x); + } +}; +#undef CENTRAL_RANGE + +} // namespace mshadow_op +} // namespace op +} // namespace mxnet + +#endif // MXNET_OPERATOR_CONTRIB_ERFINV_INL_H_ diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h index 9199c430beed..f56436b8fa0c 100644 --- a/src/operator/mshadow_op.h +++ b/src/operator/mshadow_op.h @@ -31,6 +31,7 @@ #include "math_functions-inl.h" #include "special_functions-inl.h" #include "./operator_tune.h" +#include "./contrib/erfinv-inl.h" #ifdef __CUDACC__ #include @@ -169,92 +170,7 @@ struct softrelu : public mxnet_op::tunable { MXNET_UNARY_MATH_OP(softrelu_grad, -math::expm1(-a)); - -/* The next function is taken from -https://github.com/antelopeusersgroup/antelope_contrib/blob/master/lib/location/libgenloc/erfinv.c. -Below is the copyright. Output was modified to be inf or -inf when input is 1 or -1. - - Copyright (c) 2014 Indiana University - All rights reserved. - Written by Prof. Gary L. Pavlis, Dept. of Geol. Sci., - Indiana University, Bloomington, IN - This software is licensed under the New BSD license: - Redistribution and use in source and binary forms, - with or without modification, are permitted provided - that the following conditions are met: - Redistributions of source code must retain the above - copyright notice, this list of conditions and the - following disclaimer. - Redistributions in binary form must reproduce the - above copyright notice, this list of conditions and - the following disclaimer in the documentation and/or - other materials provided with the distribution. - Neither the name of Indiana University nor - the names of its contributors may be used to endorse - or promote products derived from this software without - specific prior written permission. - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND - CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED - WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED - WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A - PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL - THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY - DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR - CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, - PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF - USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) - HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER - IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING - NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE - USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE - POSSIBILITY OF SUCH DAMAGE. -*/ - -#define CENTRAL_RANGE 0.7 -/*! \brief inverse gauss error function */ -struct erfinv : public mxnet_op::tunable { - template - MSHADOW_XINLINE static DType Map(DType v) { - /* Function to calculate inverse error function. Rational approximation - is used to generate an initial approximation, which is then improved to - full accuracy by two steps of Newton's method. Code is a direct - translation of the erfinv m file in matlab version 2.0. - Author: Gary L. Pavlis, Indiana University - Date: February 1996 - */ - double y = static_cast(v); - double x,z,num,dem; /*working variables */ - /* coefficients in rational expansion */ - double a[4]={ 0.886226899, -1.645349621, 0.914624893, -0.140543331}; - double b[4]={-2.118377725, 1.442710462, -0.329097515, 0.012229801}; - double c[4]={-1.970840454, -1.624906493, 3.429567803, 1.641345311}; - double d[2]={ 3.543889200, 1.637067800}; - if(fabs(y) > 1.0) return DType(atof("NaN")); /* This needs IEEE constant*/ - if(fabs(y) == 1.0) return DType((copysign(1.0,y))*atof("INFINITY")); - if(fabs(y) <= CENTRAL_RANGE ) - { - z = y*y; - num = (((a[3]*z + a[2])*z + a[1])*z + a[0]); - dem = ((((b[3]*z + b[2])*z + b[1])*z +b[0])*z + 1.0); - x = y*num/dem; - } - else if( (fabs(y) > CENTRAL_RANGE) && (fabs(y) < 1.0) ) - { - z = sqrt(-log((1.0-fabs(y))/2.0)); - num = ((c[3]*z + c[2])*z + c[1])*z + c[0]; - dem = (d[1]*z + d[0])*z + 1.0; - x = (copysign(1.0,y))*num/dem; - } - /* Two steps of Newton-Raphson correction */ - x = x - (erf(x) - y)/( (2.0/sqrt(M_PI))*exp(-x*x)); - x = x - (erf(x) - y)/( (2.0/sqrt(M_PI))*exp(-x*x)); - - return DType(x); - } -}; -#undef CENTRAL_RANGE - -MXNET_UNARY_MATH_OP(erfinv_grad, 0.5 * math::sqrt(PI) * math::exp(math::sqr(mshadow_op::erfinv::Map(a)))); +MXNET_UNARY_MATH_OP(erfinv_grad, 0.5 * math::sqrt(PI) * math::exp(math::sqr(erfinv::Map(a)))); MXNET_UNARY_MATH_OP(erf_grad, 2.0 / math::sqrt(PI) * math::exp(-(a * a))); From 722a5733e99a60298997f2a522853a7a2034b08f Mon Sep 17 00:00:00 2001 From: rondogency Date: Sun, 13 Jan 2019 18:22:58 -0500 Subject: [PATCH 07/12] edit copyright --- src/operator/contrib/erfinv-inl.h | 32 +++++-------------------------- 1 file changed, 5 insertions(+), 27 deletions(-) diff --git a/src/operator/contrib/erfinv-inl.h b/src/operator/contrib/erfinv-inl.h index 178fbfc074a1..dafbd244471f 100644 --- a/src/operator/contrib/erfinv-inl.h +++ b/src/operator/contrib/erfinv-inl.h @@ -1,29 +1,3 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * Copyright (c) 2019 by Contributors - * \file erfinv-inl.h - * \brief - * \author Ziyi Mu -*/ - /* The next function is taken from https://github.com/antelopeusersgroup/antelope_contrib/blob/master/lib/location/libgenloc/erfinv.c. Below is the copyright. Output was modified to be inf or -inf when input is 1 or -1. @@ -67,6 +41,9 @@ Below is the copyright. Output was modified to be inf or -inf when input is 1 or #ifndef MXNET_OPERATOR_CONTRIB_ERFINV_INL_H_ #define MXNET_OPERATOR_CONTRIB_ERFINV_INL_H_ +#define _USE_MATH_DEFINES +#define CENTRAL_RANGE 0.7 + #include #include "math.h" @@ -74,7 +51,8 @@ namespace mxnet { namespace op { namespace mshadow_op { -#define CENTRAL_RANGE 0.7 +using std::copysign; +using std::atof; /*! \brief inverse gauss error function */ struct erfinv : public mxnet_op::tunable { From d878e0708f70552cddb178f3617b8065b9bad378 Mon Sep 17 00:00:00 2001 From: rondogency Date: Sun, 13 Jan 2019 22:40:59 -0500 Subject: [PATCH 08/12] remove atof --- src/operator/contrib/erfinv-inl.h | 85 ++++++++++++++++--------------- 1 file changed, 43 insertions(+), 42 deletions(-) diff --git a/src/operator/contrib/erfinv-inl.h b/src/operator/contrib/erfinv-inl.h index dafbd244471f..40c36929e95a 100644 --- a/src/operator/contrib/erfinv-inl.h +++ b/src/operator/contrib/erfinv-inl.h @@ -1,43 +1,44 @@ -/* The next function is taken from -https://github.com/antelopeusersgroup/antelope_contrib/blob/master/lib/location/libgenloc/erfinv.c. -Below is the copyright. Output was modified to be inf or -inf when input is 1 or -1. - - Copyright (c) 2014 Indiana University - All rights reserved. - Written by Prof. Gary L. Pavlis, Dept. of Geol. Sci., - Indiana University, Bloomington, IN - This software is licensed under the New BSD license: - Redistribution and use in source and binary forms, - with or without modification, are permitted provided - that the following conditions are met: - Redistributions of source code must retain the above - copyright notice, this list of conditions and the - following disclaimer. - Redistributions in binary form must reproduce the - above copyright notice, this list of conditions and - the following disclaimer in the documentation and/or - other materials provided with the distribution. - Neither the name of Indiana University nor - the names of its contributors may be used to endorse - or promote products derived from this software without - specific prior written permission. - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND - CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED - WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED - WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A - PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL - THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY - DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR - CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, - PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF - USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) - HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER - IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING - NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE - USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE - POSSIBILITY OF SUCH DAMAGE. -*/ - +/* + * Copyright (c) 2014 Indiana University + * All rights reserved. + * Written by Prof. Gary L. Pavlis, Dept. of Geol. Sci., + * Indiana University, Bloomington, IN + * This software is licensed under the New BSD license: + * Redistribution and use in source and binary forms, + * with or without modification, are permitted provided + * that the following conditions are met: + * Redistributions of source code must retain the above + * copyright notice, this list of conditions and the + * following disclaimer. + * Redistributions in binary form must reproduce the + * above copyright notice, this list of conditions and + * the following disclaimer in the documentation and/or + * other materials provided with the distribution. + * Neither the name of Indiana University nor + * the names of its contributors may be used to endorse + * or promote products derived from this software without + * specific prior written permission. + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND + * CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED + * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A + * PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL + * THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF + * USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER + * IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE + * USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + */ +/* + * The next function is taken from + * https://github.com/antelopeusersgroup/antelope_contrib/blob/master/lib/location/libgenloc/erfinv.c. + * Output was modified to be inf or -inf when input is 1 or -1. + */ #ifndef MXNET_OPERATOR_CONTRIB_ERFINV_INL_H_ #define MXNET_OPERATOR_CONTRIB_ERFINV_INL_H_ @@ -76,9 +77,9 @@ struct erfinv : public mxnet_op::tunable { double d[2]={ 3.543889200, 1.637067800}; if (fabs(y) > 1.0) { /* This needs IEEE constant*/ - return DType(atof("NaN")); + return DType(10000); } else if (fabs(y) == 1.0) { - return DType((copysign(1.0, y))*atof("INFINITY")); + return DType((copysign(1.0, y))*10000); } else if (fabs(y) <= CENTRAL_RANGE) { z = y*y; num = (((a[3]*z + a[2])*z + a[1])*z + a[0]); From cb3a2170ad4d010bd8ce18e3800d185332bcfa11 Mon Sep 17 00:00:00 2001 From: rondogency Date: Wed, 16 Jan 2019 00:38:28 -0500 Subject: [PATCH 09/12] use std and update license --- src/operator/contrib/erfinv-inl.h | 16 +++++++--------- tools/license_header.py | 1 + 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/src/operator/contrib/erfinv-inl.h b/src/operator/contrib/erfinv-inl.h index 40c36929e95a..006a0f008cba 100644 --- a/src/operator/contrib/erfinv-inl.h +++ b/src/operator/contrib/erfinv-inl.h @@ -46,15 +46,13 @@ #define CENTRAL_RANGE 0.7 #include +#include #include "math.h" namespace mxnet { namespace op { namespace mshadow_op { -using std::copysign; -using std::atof; - /*! \brief inverse gauss error function */ struct erfinv : public mxnet_op::tunable { template @@ -77,23 +75,23 @@ struct erfinv : public mxnet_op::tunable { double d[2]={ 3.543889200, 1.637067800}; if (fabs(y) > 1.0) { /* This needs IEEE constant*/ - return DType(10000); + return DType(std::numeric_limits::quiet_NaN()); } else if (fabs(y) == 1.0) { - return DType((copysign(1.0, y))*10000); + return DType((std::copysign(1.0, y))*std::numeric_limits::infinity()); } else if (fabs(y) <= CENTRAL_RANGE) { z = y*y; num = (((a[3]*z + a[2])*z + a[1])*z + a[0]); dem = ((((b[3]*z + b[2])*z + b[1])*z +b[0])*z + 1.0); x = y*num/dem; } else { - z = sqrt(-log((1.0-fabs(y))/2.0)); + z = std::sqrt(-std::log((1.0-fabs(y))/2.0)); num = ((c[3]*z + c[2])*z + c[1])*z + c[0]; dem = (d[1]*z + d[0])*z + 1.0; - x = (copysign(1.0, y))*num/dem; + x = (std::copysign(1.0, y))*num/dem; } /* Two steps of Newton-Raphson correction */ - x = x - (erf(x) - y)/((2.0/sqrt(M_PI))*exp(-x*x)); - x = x - (erf(x) - y)/((2.0/sqrt(M_PI))*exp(-x*x)); + x = x - (std::erf(x) - y)/((2.0/std::sqrt(M_PI))*std::exp(-x*x)); + x = x - (std::erf(x) - y)/((2.0/std::sqrt(M_PI))*std::exp(-x*x)); return DType(x); } diff --git a/tools/license_header.py b/tools/license_header.py index 10ba8b909e70..e5e24aa18fc8 100755 --- a/tools/license_header.py +++ b/tools/license_header.py @@ -66,6 +66,7 @@ 'src/operator/mkl/', 'src/operator/special_functions-inl.h', 'src/operator/nn/pool.h', + 'src/operator/contrib/erfinv-inl.h', 'src/operator/contrib/psroi_pooling-inl.h', 'src/operator/contrib/nn/deformable_im2col.h', 'src/operator/contrib/nn/deformable_im2col.cuh', From 9cc6c467af0f99b560c44f2f444ad02ace249a94 Mon Sep 17 00:00:00 2001 From: rondogency Date: Fri, 18 Jan 2019 13:25:00 -0500 Subject: [PATCH 10/12] add license exclude file --- tests/nightly/apache_rat_license_check/rat-excludes | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/nightly/apache_rat_license_check/rat-excludes b/tests/nightly/apache_rat_license_check/rat-excludes index a488eb84d069..efc32196b0d2 100755 --- a/tests/nightly/apache_rat_license_check/rat-excludes +++ b/tests/nightly/apache_rat_license_check/rat-excludes @@ -60,4 +60,5 @@ deformable_im2col.h REQUIRE include/* */test/test-symbol.json.ref -*/profiler/test/profile-matmul-20iter.json.ref \ No newline at end of file +*/profiler/test/profile-matmul-20iter.json.ref +erfinv-inl.h \ No newline at end of file From b012d8447fa64304b91ff09979e247da382770a5 Mon Sep 17 00:00:00 2001 From: rondogency Date: Mon, 21 Jan 2019 00:13:34 -0500 Subject: [PATCH 11/12] fix per eric's comments --- src/operator/contrib/erfinv-inl.h | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/operator/contrib/erfinv-inl.h b/src/operator/contrib/erfinv-inl.h index 006a0f008cba..8d718ade6562 100644 --- a/src/operator/contrib/erfinv-inl.h +++ b/src/operator/contrib/erfinv-inl.h @@ -43,7 +43,6 @@ #define MXNET_OPERATOR_CONTRIB_ERFINV_INL_H_ #define _USE_MATH_DEFINES -#define CENTRAL_RANGE 0.7 #include #include @@ -64,7 +63,9 @@ struct erfinv : public mxnet_op::tunable { Author: Gary L. Pavlis, Indiana University Date: February 1996 */ + const double central_range = 0.7; double y = static_cast(v); + double y_fab = std::fabs(y); /*working variables */ double x = 0.0; double z, num, dem; @@ -73,18 +74,18 @@ struct erfinv : public mxnet_op::tunable { double b[4]={-2.118377725, 1.442710462, -0.329097515, 0.012229801}; double c[4]={-1.970840454, -1.624906493, 3.429567803, 1.641345311}; double d[2]={ 3.543889200, 1.637067800}; - if (fabs(y) > 1.0) { + if (y_fab > 1.0) { /* This needs IEEE constant*/ return DType(std::numeric_limits::quiet_NaN()); - } else if (fabs(y) == 1.0) { + } else if (y_fab == 1.0) { return DType((std::copysign(1.0, y))*std::numeric_limits::infinity()); - } else if (fabs(y) <= CENTRAL_RANGE) { + } else if (y_fab <= central_range) { z = y*y; num = (((a[3]*z + a[2])*z + a[1])*z + a[0]); dem = ((((b[3]*z + b[2])*z + b[1])*z +b[0])*z + 1.0); x = y*num/dem; } else { - z = std::sqrt(-std::log((1.0-fabs(y))/2.0)); + z = std::sqrt(-std::log((1.0-y_fab)/2.0)); num = ((c[3]*z + c[2])*z + c[1])*z + c[0]; dem = (d[1]*z + d[0])*z + 1.0; x = (std::copysign(1.0, y))*num/dem; @@ -96,7 +97,6 @@ struct erfinv : public mxnet_op::tunable { return DType(x); } }; -#undef CENTRAL_RANGE } // namespace mshadow_op } // namespace op From 647747b2c2e4ef95c00457686d79892948247713 Mon Sep 17 00:00:00 2001 From: rondogency Date: Mon, 21 Jan 2019 15:45:12 -0500 Subject: [PATCH 12/12] change license header --- tools/license_header.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/license_header.py b/tools/license_header.py index 6d55c4f6de19..11cc92839993 100755 --- a/tools/license_header.py +++ b/tools/license_header.py @@ -77,7 +77,6 @@ # Licensed under Caffe header 'src/operator/nn/pool.h', - 'src/operator/contrib/erfinv-inl.h', 'src/operator/contrib/psroi_pooling-inl.h', 'src/operator/contrib/nn/deformable_im2col.h', 'src/operator/contrib/nn/deformable_im2col.cuh', @@ -85,6 +84,7 @@ 'src/operator/nn/im2col.cuh', # Licenses in headers + 'src/operator/contrib/erfinv-inl.h', 'docs/_static/searchtools_custom.js', 'docs/_static/js/clipboard.js', 'docs/_static/js/clipboard.min.js',