Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
add relevant tests
Browse files Browse the repository at this point in the history
  • Loading branch information
kshitij12345 committed Jul 29, 2019
1 parent d962df3 commit 6e64ee9
Showing 1 changed file with 16 additions and 1 deletion.
17 changes: 16 additions & 1 deletion tests/python/unittest/test_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import numpy as np
from common import assertRaises, models
from mxnet.base import NotImplementedForSymbol
from mxnet.test_utils import discard_stderr
from mxnet.test_utils import discard_stderr, rand_shape_nd
import pickle as pkl

def test_symbol_basic():
Expand Down Expand Up @@ -188,6 +188,21 @@ def test_symbol_infer_shape_var():
assert arg_shapes[1] == overwrite_shape
assert out_shapes[0] == overwrite_shape


def test_symbol_magic_abs():
for dim in range(1, 7):
with mx.name.NameManager():
data = mx.symbol.Variable('data')
method = data.abs(name='abs0')
magic = abs(data)
regular = mx.symbol.abs(data, name='abs0')
ctx = {'ctx': mx.context.current_context(), 'data': rand_shape_nd(dim)}
mx.test_utils.check_consistency(
[method, magic], ctx_list=[ctx, ctx])
mx.test_utils.check_consistency(
[regular, magic], ctx_list=[ctx, ctx])


def test_symbol_fluent():
has_grad = set(['flatten', 'expand_dims', 'flip', 'tile', 'transpose', 'sum', 'nansum', 'prod',
'nanprod', 'mean', 'max', 'min', 'reshape', 'broadcast_to', 'split',
Expand Down

0 comments on commit 6e64ee9

Please sign in to comment.