diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 7b5b9ebf3be4..ca407da244a8 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -6484,14 +6484,14 @@ def check_slice_axis_partial_infer(data, axis, begin, end, expected_out_shape): out = mx.sym.slice_axis(data, axis=axis, begin=begin, end=end) assert (out.infer_shape_partial()[1][0] == expected_out_shape), out.infer_shape_partial()[1] - var1 = mx.sym.var(name="data", shape=(0, 20)) - check_slice_partial_infer(var1, (None, None), (None, 10), [], (0, 10)) - check_slice_partial_infer(var1, (None, None), (None, 10), (None, 2), (0, 5)) - check_slice_partial_infer(var1, (None, 3), (None, 10), [], (0, 7)) - check_slice_partial_infer(var1, (None, 3), (5, 10), [], (0, 7)) - check_slice_partial_infer(var1, (2, 3), (None, 10), [], (0, 7)) - check_slice_partial_infer(var1, (2, 3), (None, 10), (None, 1), (0, 7)) - check_slice_partial_infer(var1, (2, 3), (None, 10), (3, 3), (0, 3)) + var1 = mx.sym.var(name="data", shape=(10, 20)) + check_slice_partial_infer(var1, (None, None), (None, 10), [], (10, 10)) + check_slice_partial_infer(var1, (None, None), (None, 10), (None, 2), (10, 5)) + check_slice_partial_infer(var1, (None, 3), (None, 10), [], (10, 7)) + check_slice_partial_infer(var1, (None, 3), (5, 10), [], (5, 7)) + check_slice_partial_infer(var1, (2, 3), (None, 10), [], (8, 7)) + check_slice_partial_infer(var1, (2, 3), (None, 10), (None, 1), (8, 7)) + check_slice_partial_infer(var1, (2, 3), (None, 10), (3, 3), (8, 3)) var1 = mx.sym.var(name="data", shape=(10, 0)) check_slice_axis_partial_infer(var1, 0, 0, 5, (5, 0))