diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py index 3d8a7aa98c94..0b7dca4ebba4 100644 --- a/python/mxnet/ndarray/ndarray.py +++ b/python/mxnet/ndarray/ndarray.py @@ -205,6 +205,10 @@ def _to_shared_mem(self): self.handle, ctypes.byref(shared_pid), ctypes.byref(shared_id))) return shared_pid.value, shared_id.value, self.shape, self.dtype + def __abs__(self): + """x.__abs__() <=> abs(x) <=> x.abs() <=> mx.nd.abs(x, y)""" + return self.abs() + def __add__(self, other): """x.__add__(y) <=> x+y <=> mx.nd.add(x, y) """ return add(self, other) diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py index 1e2defab3713..68322297c8bb 100644 --- a/python/mxnet/symbol/symbol.py +++ b/python/mxnet/symbol/symbol.py @@ -93,6 +93,10 @@ def __iter__(self): """ return (self[i] for i in range(len(self))) + def __abs__(self): + """x.__abs__() <=> abs(x) <=> x.abs() <=> mx.symbol.abs(x, y)""" + return self.abs() + def __add__(self, other): """x.__add__(y) <=> x+y diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py index 56db1ebd640d..0f154bd67a1a 100644 --- a/tests/python/unittest/test_ndarray.py +++ b/tests/python/unittest/test_ndarray.py @@ -172,6 +172,15 @@ def test_ndarray_negate(): assert_almost_equal(npy, arr.asnumpy()) +@with_seed() +def test_ndarray_magic_abs(): + for dim in range(1, 7): + shape = rand_shape_nd(dim) + npy = np.random.uniform(-10, 10, shape) + arr = mx.nd.array(npy) + assert_almost_equal(abs(arr).asnumpy(), arr.abs().asnumpy()) + + @with_seed() def test_ndarray_reshape(): tensor = (mx.nd.arange(30) + 1).reshape(2, 3, 5) diff --git a/tests/python/unittest/test_symbol.py b/tests/python/unittest/test_symbol.py index 0c97c68b0880..963b32493b44 100644 --- a/tests/python/unittest/test_symbol.py +++ b/tests/python/unittest/test_symbol.py @@ -22,7 +22,7 @@ import numpy as np from common import assertRaises, models from mxnet.base import NotImplementedForSymbol -from mxnet.test_utils import discard_stderr +from mxnet.test_utils import discard_stderr, rand_shape_nd import pickle as pkl def test_symbol_basic(): @@ -188,6 +188,21 @@ def test_symbol_infer_shape_var(): assert arg_shapes[1] == overwrite_shape assert out_shapes[0] == overwrite_shape + +def test_symbol_magic_abs(): + for dim in range(1, 7): + with mx.name.NameManager(): + data = mx.symbol.Variable('data') + method = data.abs(name='abs0') + magic = abs(data) + regular = mx.symbol.abs(data, name='abs0') + ctx = {'ctx': mx.context.current_context(), 'data': rand_shape_nd(dim)} + mx.test_utils.check_consistency( + [method, magic], ctx_list=[ctx, ctx]) + mx.test_utils.check_consistency( + [regular, magic], ctx_list=[ctx, ctx]) + + def test_symbol_fluent(): has_grad = set(['flatten', 'expand_dims', 'flip', 'tile', 'transpose', 'sum', 'nansum', 'prod', 'nanprod', 'mean', 'max', 'min', 'reshape', 'broadcast_to', 'split',