From 1316651f46b173f22c4fe0ca02998067f02a6975 Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Mon, 29 Jul 2019 20:41:19 +0530 Subject: [PATCH] add relevant tests --- tests/python/unittest/test_symbol.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/tests/python/unittest/test_symbol.py b/tests/python/unittest/test_symbol.py index 2dfe3e44eedb..ef470e54ce5e 100644 --- a/tests/python/unittest/test_symbol.py +++ b/tests/python/unittest/test_symbol.py @@ -22,7 +22,7 @@ import numpy as np from common import assertRaises, models from mxnet.base import NotImplementedForSymbol -from mxnet.test_utils import discard_stderr +from mxnet.test_utils import discard_stderr, rand_shape_nd import pickle as pkl def test_symbol_basic(): @@ -188,6 +188,18 @@ def test_symbol_infer_shape_var(): assert arg_shapes[1] == overwrite_shape assert out_shapes[0] == overwrite_shape + +def test_symbol_magic_abs(): + for dim in range(1, 7): + with mx.name.NameManager(): + data = mx.symbol.Variable('data') + regular = data.abs(name='abs0') + magic = abs(data) + ctx = {'ctx': mx.context.current_context(), 'data': rand_shape_nd(dim)} + mx.test_utils.check_consistency( + [regular, magic], ctx_list=[ctx, ctx]) + + def test_symbol_fluent(): has_grad = set(['flatten', 'expand_dims', 'flip', 'tile', 'transpose', 'sum', 'nansum', 'prod', 'nanprod', 'mean', 'max', 'min', 'reshape', 'broadcast_to', 'split',