diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py index d3521cad1274..9b193f850a93 100644 --- a/python/mxnet/symbol/numpy/_symbol.py +++ b/python/mxnet/symbol/numpy/_symbol.py @@ -283,7 +283,7 @@ def __neg__(self): return negative(self) def __deepcopy__(self, _): - return super(_Symbol, self).as_np_ndarray() + return super().__deepcopy__(_).as_np_ndarray() def __eq__(self, other): """x.__eq__(y) <=> x == y""" diff --git a/tests/python/unittest/test_symbol.py b/tests/python/unittest/test_symbol.py index 1c84af0c668e..910b6ca15499 100644 --- a/tests/python/unittest/test_symbol.py +++ b/tests/python/unittest/test_symbol.py @@ -479,3 +479,13 @@ def test_infershape_happens_for_all_ops_in_graph(): assert False +def test_symbol_copy(): + a = mx.sym.Variable('a') + b = copy.copy(a) + b._set_attr(name='b') + assert a.name == 'a' and b.name == 'b' + + a = mx.sym.Variable('a').as_np_ndarray() + b = copy.copy(a) + b._set_attr(name='b') + assert a.name == 'a' and b.name == 'b'