diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 891511b86d8d..0e07c3782922 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -5333,19 +5333,21 @@ def test_boolean_mask(): assert same(data.grad.asnumpy(), expected_grad) # test 0-size output - mx.set_np_shape(True) - data = mx.nd.array([[1, 2, 3],[4, 5, 6],[7, 8, 9]]) - index = mx.nd.array([0, 0, 0]) - data.attach_grad() - with mx.autograd.record(): - out = mx.nd.contrib.boolean_mask(data, index) - out.backward() - data.grad.wait_to_read() - expected = np.zeros((0, 3)) - expected_grad = np.array([[0, 0, 0], [0, 0, 0], [0, 0, 0]]) - assert same(out.asnumpy(), expected) - assert same(data.grad.asnumpy(), expected_grad) - mx.set_np_shape(False) + prev_np_shape = mx.set_np_shape(True) + try: + data = mx.nd.array([[1, 2, 3],[4, 5, 6],[7, 8, 9]]) + index = mx.nd.array([0, 0, 0]) + data.attach_grad() + with mx.autograd.record(): + out = mx.nd.contrib.boolean_mask(data, index) + out.backward() + data.grad.wait_to_read() + expected = np.zeros((0, 3)) + expected_grad = np.array([[0, 0, 0], [0, 0, 0], [0, 0, 0]]) + assert same(out.asnumpy(), expected) + assert same(data.grad.asnumpy(), expected_grad) + finally: + mx.set_np_shape(prev_np_shape) # test gradient shape = (100, 30) @@ -9463,7 +9465,8 @@ def test_sldwin_atten_op_impl(batch_size, seq_length, num_heads, def test_zero_sized_dim(): - mx.util.set_np_shape(True) # Must be done to prevent zero-sized dimension conversion to 'unknown' + # Must be done to prevent zero-sized dimension conversion to 'unknown' + prev_np_shape = mx.util.set_np_shape(True) def seq_last(): """Test for issue: https://github.com/apache/incubator-mxnet/issues/18938""" @@ -9483,9 +9486,12 @@ def seq_reverse(): res = mx.nd.op.SequenceReverse(data) assert data.shape == res.shape - seq_last() - seq_reverse() - seq_mask() + try: + seq_last() + seq_reverse() + seq_mask() + finally: + mx.util.set_np_shape(prev_np_shape) @mx.util.use_np def test_take_grads(): diff --git a/tests/python/unittest/test_thread_local.py b/tests/python/unittest/test_thread_local.py index 8e4370ea6466..9d1e529d4142 100644 --- a/tests/python/unittest/test_thread_local.py +++ b/tests/python/unittest/test_thread_local.py @@ -213,7 +213,7 @@ def g(): def test_np_global_shape(): - set_np_shape(2) + prev_np_shape = set_np_shape(2) data = [] def f(): @@ -229,4 +229,4 @@ def f(): assert_almost_equal(data[0].asnumpy(), np.ones(shape=())) assert_almost_equal(data[1].asnumpy(), np.ones(shape=(0, 1, 2))) finally: - set_np_shape(0) + set_np_shape(prev_np_shape)