diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 6f48d4912f9a..f4d2ef32cc2e 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -6606,15 +6606,14 @@ def test_slice_forward_backward(a, index): for index in index_list: test_slice_forward_backward(arr, index) - def test_begin_equals_end(): - shape = (2, 4) - begin = (None, 2) - end = (None, 2) - step = (1, -1) + def test_begin_equals_end(shape, begin, end, step): in_arr = mx.nd.arange(np.prod(shape)).reshape(shape=shape) out_arr = mx.nd.slice(in_arr, begin=begin, end=end, step=step) - assertRaises(MXNetError, test_begin_equals_end) + assertRaises(MXNetError, test_begin_equals_end, (4,), (2,), (2,), (1,)) + assertRaises(MXNetError, test_begin_equals_end, (1, 5), (None, 3), (None, 3), (-1, 1)) + assertRaises(MXNetError, test_begin_equals_end, (3, 4, 5), (1, 3, 1), (3, 3, 1), (1, -3, 2)) + assertRaises(MXNetError, test_begin_equals_end, (2, 4), (None, 2), (None, 2), (1, -1)) # check numeric gradient in_data = np.arange(36).reshape(2, 2, 3, 3)