From a3f995f25406afb794ed58a15191c0b4f708afa6 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Sat, 9 Feb 2019 20:52:33 +0000 Subject: [PATCH] begin=end not a valid input refactoring logic for indexing --- src/operator/tensor/matrix_op-inl.h | 65 ++++++++++++-------------- tests/python/unittest/test_operator.py | 9 ++++ 2 files changed, 38 insertions(+), 36 deletions(-) diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h index 3a58c1200ae0..5eecda622729 100644 --- a/src/operator/tensor/matrix_op-inl.h +++ b/src/operator/tensor/matrix_op-inl.h @@ -653,50 +653,43 @@ inline void GetIndexRange(const mxnet::TShape& dshape, } for (index_t i = 0; i < param_begin.ndim(); ++i) { - index_t b = 0, e = dshape[i], s = 1; - const index_t len = dshape[i]; - if (param_step.ndim() != 0U) { - const auto& opt_step_val = param_step[i]; - if (opt_step_val.has_value()) { - s = opt_step_val.value(); - CHECK_NE(s, 0) << "slice op step[" << i << "] cannot be 0"; - } - } + index_t s = param_step.ndim() != 0U && param_step[i].has_value() ? param_step[i].value() : 1; + CHECK_NE(s, 0) << "slice op step[" << i << "] cannot be 0"; - if (len) { - if (param_begin[i].has_value()) { - b = param_begin[i].value(); - if (b < 0) { - b += len; - CHECK_GE(b, 0) << "slicing with begin[" << i << "]=" - << b - len << " exceeds limit of " << len; - } - } else if (s < 0) { - b = len - 1; + index_t b = 0, e = 0; + const index_t len = dshape[i]; + if (len > 0) { + b = param_begin[i].has_value() ? param_begin[i].value() : (s < 0 ? len - 1 : 0); + e = param_end[i].has_value() ? param_end[i].value() : (s < 0 ? -1 : len); + + // checking upper and lower bounds for begin + if (b < 0) { + b += len; + CHECK_GE(b, 0) << "slicing with begin[" << i << "]=" << b - len + << " exceeds limit of input dimension[" << i << "]=" << len; } - CHECK_LT(b, len) << "slicing with begin[" << i << "]=" - << b << " exceends limit of " << len; - - if (param_end[i].has_value()) { - e = param_end[i].value(); - if (e < 0) { - e += len; - CHECK_GE(e, 0) << "slicing with end[" << i << "]=" - << e - len << " exceeds limit of " << len; - } - } else if (s < 0) { - e = -1; + CHECK_LT(b, len) << "slicing with begin[" << i << "]=" << b + << " exceeds limit of input dimension[" << i << "]=" << len; + + // checking upper and lower bounds for end + if (e < 0 && param_end[i].has_value()) { + e += len; + CHECK_GE(e, 0) << "slicing with end[" << i << "]=" << e - len + << " exceeds limit of input dimension[" << i << "]=" << len; } - CHECK_LE(e, len) << "slicing with end[" << i << "]=" - << e << " exceeds limit of " << len; - } else { - b = 0; - e = 0; + CHECK_LE(e, len) << "slicing with end[" << i << "]=" << e + << " exceeds limit of input dimension[" << i << "]=" << len; + + // checking begin==end case which is not supported + CHECK_NE(b, e) << "slicing with begin[" << i << "]=end[" << i << "]=" + << e << " results in an empty tensor and is not supported"; } + (*begin)[i] = b; (*end)[i] = e; (*step)[i] = s; } + for (index_t i = param_begin.ndim(); i < dshape.ndim(); ++i) { (*begin)[i] = 0; (*end)[i] = dshape[i]; diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 7169395205e0..f4d2ef32cc2e 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -6606,6 +6606,15 @@ def test_slice_forward_backward(a, index): for index in index_list: test_slice_forward_backward(arr, index) + def test_begin_equals_end(shape, begin, end, step): + in_arr = mx.nd.arange(np.prod(shape)).reshape(shape=shape) + out_arr = mx.nd.slice(in_arr, begin=begin, end=end, step=step) + + assertRaises(MXNetError, test_begin_equals_end, (4,), (2,), (2,), (1,)) + assertRaises(MXNetError, test_begin_equals_end, (1, 5), (None, 3), (None, 3), (-1, 1)) + assertRaises(MXNetError, test_begin_equals_end, (3, 4, 5), (1, 3, 1), (3, 3, 1), (1, -3, 2)) + assertRaises(MXNetError, test_begin_equals_end, (2, 4), (None, 2), (None, 2), (1, -1)) + # check numeric gradient in_data = np.arange(36).reshape(2, 2, 3, 3) data = mx.sym.Variable('data')