Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Add magic method abs to NDArray and Symbol. (#15680)
Browse files Browse the repository at this point in the history
* add magic method abs to ndarray

* add relevant tests

* add magic method abs to symbol

* add relevant tests

* retrigger CI

* retrigger CI
  • Loading branch information
kshitij12345 authored and wkcn committed Aug 2, 2019
1 parent b07211f commit cf28b46
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 1 deletion.
4 changes: 4 additions & 0 deletions python/mxnet/ndarray/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,10 @@ def _to_shared_mem(self):
self.handle, ctypes.byref(shared_pid), ctypes.byref(shared_id)))
return shared_pid.value, shared_id.value, self.shape, self.dtype

def __abs__(self):
"""x.__abs__() <=> abs(x) <=> x.abs() <=> mx.nd.abs(x, y)"""
return self.abs()

def __add__(self, other):
"""x.__add__(y) <=> x+y <=> mx.nd.add(x, y) """
return add(self, other)
Expand Down
4 changes: 4 additions & 0 deletions python/mxnet/symbol/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,10 @@ def __iter__(self):
"""
return (self[i] for i in range(len(self)))

def __abs__(self):
"""x.__abs__() <=> abs(x) <=> x.abs() <=> mx.symbol.abs(x, y)"""
return self.abs()

def __add__(self, other):
"""x.__add__(y) <=> x+y
Expand Down
9 changes: 9 additions & 0 deletions tests/python/unittest/test_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,15 @@ def test_ndarray_negate():
assert_almost_equal(npy, arr.asnumpy())


@with_seed()
def test_ndarray_magic_abs():
for dim in range(1, 7):
shape = rand_shape_nd(dim)
npy = np.random.uniform(-10, 10, shape)
arr = mx.nd.array(npy)
assert_almost_equal(abs(arr).asnumpy(), arr.abs().asnumpy())


@with_seed()
def test_ndarray_reshape():
tensor = (mx.nd.arange(30) + 1).reshape(2, 3, 5)
Expand Down
17 changes: 16 additions & 1 deletion tests/python/unittest/test_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import numpy as np
from common import assertRaises, models
from mxnet.base import NotImplementedForSymbol
from mxnet.test_utils import discard_stderr
from mxnet.test_utils import discard_stderr, rand_shape_nd
import pickle as pkl

def test_symbol_basic():
Expand Down Expand Up @@ -188,6 +188,21 @@ def test_symbol_infer_shape_var():
assert arg_shapes[1] == overwrite_shape
assert out_shapes[0] == overwrite_shape


def test_symbol_magic_abs():
for dim in range(1, 7):
with mx.name.NameManager():
data = mx.symbol.Variable('data')
method = data.abs(name='abs0')
magic = abs(data)
regular = mx.symbol.abs(data, name='abs0')
ctx = {'ctx': mx.context.current_context(), 'data': rand_shape_nd(dim)}
mx.test_utils.check_consistency(
[method, magic], ctx_list=[ctx, ctx])
mx.test_utils.check_consistency(
[regular, magic], ctx_list=[ctx, ctx])


def test_symbol_fluent():
has_grad = set(['flatten', 'expand_dims', 'flip', 'tile', 'transpose', 'sum', 'nansum', 'prod',
'nanprod', 'mean', 'max', 'min', 'reshape', 'broadcast_to', 'split',
Expand Down

0 comments on commit cf28b46

Please sign in to comment.