diff --git a/tests/python/unittest/test_symbol.py b/tests/python/unittest/test_symbol.py index 1c84af0c668e..910b6ca15499 100644 --- a/tests/python/unittest/test_symbol.py +++ b/tests/python/unittest/test_symbol.py @@ -479,3 +479,13 @@ def test_infershape_happens_for_all_ops_in_graph(): assert False +def test_symbol_copy(): + a = mx.sym.Variable('a') + b = copy.copy(a) + b._set_attr(name='b') + assert a.name == 'a' and b.name == 'b' + + a = mx.sym.Variable('a').as_np_ndarray() + b = copy.copy(a) + b._set_attr(name='b') + assert a.name == 'a' and b.name == 'b'