Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Add magic method abs to NDArray and Symbol. #15680

Merged
merged 7 commits into from
Aug 2, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions python/mxnet/ndarray/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,10 @@ def _to_shared_mem(self):
self.handle, ctypes.byref(shared_pid), ctypes.byref(shared_id)))
return shared_pid.value, shared_id.value, self.shape, self.dtype

def __abs__(self):
"""x.__abs__() <=> abs(x) <=> x.abs() <=> mx.nd.abs(x, y)"""
return self.abs()

def __add__(self, other):
"""x.__add__(y) <=> x+y <=> mx.nd.add(x, y) """
return add(self, other)
Expand Down
4 changes: 4 additions & 0 deletions python/mxnet/symbol/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,10 @@ def __iter__(self):
"""
return (self[i] for i in range(len(self)))

def __abs__(self):
"""x.__abs__() <=> abs(x) <=> x.abs() <=> mx.symbol.abs(x, y)"""
return self.abs()

def __add__(self, other):
"""x.__add__(y) <=> x+y

Expand Down
9 changes: 9 additions & 0 deletions tests/python/unittest/test_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,15 @@ def test_ndarray_negate():
assert_almost_equal(npy, arr.asnumpy())


@with_seed()
def test_ndarray_magic_abs():
for dim in range(1, 7):
shape = rand_shape_nd(dim)
npy = np.random.uniform(-10, 10, shape)
arr = mx.nd.array(npy)
assert_almost_equal(abs(arr).asnumpy(), arr.abs().asnumpy())


@with_seed()
def test_ndarray_reshape():
tensor = (mx.nd.arange(30) + 1).reshape(2, 3, 5)
Expand Down
17 changes: 16 additions & 1 deletion tests/python/unittest/test_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import numpy as np
from common import assertRaises, models
from mxnet.base import NotImplementedForSymbol
from mxnet.test_utils import discard_stderr
from mxnet.test_utils import discard_stderr, rand_shape_nd
import pickle as pkl

def test_symbol_basic():
Expand Down Expand Up @@ -188,6 +188,21 @@ def test_symbol_infer_shape_var():
assert arg_shapes[1] == overwrite_shape
assert out_shapes[0] == overwrite_shape


def test_symbol_magic_abs():
for dim in range(1, 7):
with mx.name.NameManager():
data = mx.symbol.Variable('data')
method = data.abs(name='abs0')
magic = abs(data)
regular = mx.symbol.abs(data, name='abs0')
ctx = {'ctx': mx.context.current_context(), 'data': rand_shape_nd(dim)}
mx.test_utils.check_consistency(
[method, magic], ctx_list=[ctx, ctx])
mx.test_utils.check_consistency(
[regular, magic], ctx_list=[ctx, ctx])


def test_symbol_fluent():
has_grad = set(['flatten', 'expand_dims', 'flip', 'tile', 'transpose', 'sum', 'nansum', 'prod',
'nanprod', 'mean', 'max', 'min', 'reshape', 'broadcast_to', 'split',
Expand Down