add gammaln, erf, erfinv (#16811)
xidulu authored and haojin2 committed Nov 14, 2019
1 parent 017f6fa commit 5139e2c
Showing 2 changed files with 57 additions and 0 deletions.
src/operator/tensor/elemwise_unary_op_basic.cc (3 additions, 0 deletions)
@@ -877,6 +877,7 @@ The storage type of ``fix`` output depends upon the input storage type:

 // erf
 MXNET_OPERATOR_REGISTER_UNARY(erf)
+.add_alias("_npx_erf")
 .describe(R"code(Returns element-wise gauss error function of the input.
 Example::

@@ -898,6 +899,7 @@ MXNET_OPERATOR_REGISTER_BINARY(_backward_erf)

 // erfinv
 MXNET_OPERATOR_REGISTER_UNARY(erfinv)
+.add_alias("_npx_erfinv")
 .describe(R"code(Returns element-wise inverse gauss error function of the input.
 Example::

@@ -929,6 +931,7 @@ MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR(_backward_gamma,

 // gammaln
 MXNET_OPERATOR_REGISTER_UNARY_WITH_SPARSE_DR(gammaln, cpu, mshadow_op::gammaln)
+.add_alias("_npx_gammaln")
 MXNET_ADD_SPARSE_OP_ALIAS(gammaln)
 .describe(R"code(Returns element-wise log of the absolute value of the gamma function \
 of the input.
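The three `.add_alias("_npx_*")` registrations surface these classic operators in the NumPy-compatible `npx` namespace, so Python code can call them as `mx.npx.erf`, `mx.npx.erfinv`, and `mx.npx.gammaln`. A minimal usage sketch, assuming an MXNet build that includes this commit (the call pattern mirrors the test added below):

    import mxnet as mx
    from mxnet import np, npx

    npx.set_np()                     # enable NumPy-compatible array semantics
    x = np.array([0.0, 0.5, 1.0])
    print(npx.erf(x))                # element-wise Gauss error function
    print(npx.erfinv(npx.erf(x)))    # inverts erf, recovering x up to float error
    print(npx.gammaln(np.array([1.0, 2.0, 5.0])))   # log|Gamma(x)| = [0, 0, log 24]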
tests/python/unittest/test_numpy_op.py (54 additions, 0 deletions)
@@ -24,6 +24,7 @@
 import platform
 import mxnet as mx
 import scipy.stats as ss
+import scipy.special as scipy_special
 from nose.tools import assert_raises
 from mxnet import np, npx
 from mxnet.gluon import HybridBlock

@@ -2550,6 +2551,59 @@ def _test_bernoulli_exception(prob, logit):
     assertRaises(MXNetError, _test_bernoulli_exception, scaled_prob, None)
 
 
+@with_seed()
+@use_np
+def test_npx_special_unary_func():
+    def check_unary_func(func, ref_grad, shape, low, high):
+        class TestUnary(HybridBlock):
+            def __init__(self, func):
+                super(TestUnary, self).__init__()
+                self._func = func
+
+            def hybrid_forward(self, F, a, *args, **kwargs):
+                return getattr(F.npx, self._func)(a)
+
+        np_func = getattr(scipy_special, func)
+        mx_func = TestUnary(func)
+        np_test_data = _np.random.uniform(low, high, shape).astype(_np.float32)
+        mx_test_data = mx.numpy.array(np_test_data)
+        # Exercise both the imperative and the hybridized (symbolic) paths.
+        for hybridize in [True, False]:
+            if hybridize:
+                mx_func.hybridize()
+            if ref_grad:
+                mx_test_data.attach_grad()
+            np_out = np_func(np_test_data)
+            with mx.autograd.record():
+                y = mx_func(mx_test_data)
+            assert y.shape == np_out.shape
+            assert_almost_equal(y.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
+            if np_out.dtype == np.bool_:
+                assert y.dtype == np.bool_
+
+            if ref_grad:
+                y.backward()
+                assert_almost_equal(mx_test_data.grad.asnumpy(), ref_grad(np_test_data), rtol=1e-1, atol=1e-2, equal_nan=True)
+
+        np_out = getattr(scipy_special, func)(np_test_data)
+        mx_out = getattr(mx.npx, func)(mx_test_data)
+        assert mx_out.shape == np_out.shape
+        assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
+
+    import math
+    # func -> (analytic gradient, uniform sampling bounds low/high)
+    funcs = {
+        'erf' : (lambda x: 2.0 / math.sqrt(math.pi) * _np.exp(-(x ** 2)), 0.5, 0.5),
+        'erfinv' : (lambda x: 0.5 * math.sqrt(math.pi) * _np.exp(scipy_special.erfinv(x) ** 2), 0.5, 0.5),
+        'gamma' : (lambda x: scipy_special.gamma(x) * scipy_special.psi(x), 0.5, 0.5),
+        'gammaln' : (lambda x: scipy_special.psi(x), 0.5, 0.5)
+    }
+    ndim = random.choice([2, 3, 4])
+    for shape in [rand_shape_nd(ndim, dim=3), (1, 0, 2)]:
+        for func, func_data in funcs.items():
+            ref_grad, low, high = func_data
+            check_unary_func(func, ref_grad, shape, low, high)
+
+
 @with_seed()
 @use_np
 def test_np_random():
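The `ref_grad` lambdas in `funcs` are the analytic derivatives that the backward kernels are compared against: d/dx erf(x) = 2/√π · exp(−x²); d/dx erfinv(x) = (√π/2) · exp(erfinv(x)²), i.e. the reciprocal of erf's derivative evaluated at erfinv(x); d/dx Γ(x) = Γ(x)ψ(x); and d/dx ln|Γ(x)| = ψ(x), where ψ is the digamma function. These identities can be sanity-checked with SciPy alone via a central finite difference; the helper below is illustrative and not part of the commit:

    import math
    import numpy as _np
    import scipy.special as scipy_special

    def fd_check(f, ref_grad, x, eps=1e-6):
        # central finite difference vs. the analytic derivative
        fd = (f(x + eps) - f(x - eps)) / (2 * eps)
        assert _np.allclose(fd, ref_grad(x), rtol=1e-4)

    x = _np.linspace(0.1, 0.9, 9)
    fd_check(scipy_special.erf, lambda x: 2.0 / math.sqrt(math.pi) * _np.exp(-(x ** 2)), x)
    fd_check(scipy_special.erfinv, lambda x: 0.5 * math.sqrt(math.pi) * _np.exp(scipy_special.erfinv(x) ** 2), x)
    fd_check(scipy_special.gammaln, scipy_special.psi, x)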
