Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Fix slice op issues #13760 #14403

Merged
merged 1 commit into from
Mar 18, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 29 additions & 36 deletions src/operator/tensor/matrix_op-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -653,50 +653,43 @@ inline void GetIndexRange(const mxnet::TShape& dshape,
}

for (index_t i = 0; i < param_begin.ndim(); ++i) {
index_t b = 0, e = dshape[i], s = 1;
const index_t len = dshape[i];
if (param_step.ndim() != 0U) {
const auto& opt_step_val = param_step[i];
if (opt_step_val.has_value()) {
s = opt_step_val.value();
CHECK_NE(s, 0) << "slice op step[" << i << "] cannot be 0";
}
}
index_t s = param_step.ndim() != 0U && param_step[i].has_value() ? param_step[i].value() : 1;
CHECK_NE(s, 0) << "slice op step[" << i << "] cannot be 0";

if (len) {
if (param_begin[i].has_value()) {
b = param_begin[i].value();
if (b < 0) {
b += len;
CHECK_GE(b, 0) << "slicing with begin[" << i << "]="
<< b - len << " exceeds limit of " << len;
}
} else if (s < 0) {
b = len - 1;
index_t b = 0, e = 0;
const index_t len = dshape[i];
if (len > 0) {
b = param_begin[i].has_value() ? param_begin[i].value() : (s < 0 ? len - 1 : 0);
e = param_end[i].has_value() ? param_end[i].value() : (s < 0 ? -1 : len);

// checking upper and lower bounds for begin
if (b < 0) {
b += len;
CHECK_GE(b, 0) << "slicing with begin[" << i << "]=" << b - len
<< " exceeds limit of input dimension[" << i << "]=" << len;
}
CHECK_LT(b, len) << "slicing with begin[" << i << "]="
<< b << " exceends limit of " << len;

if (param_end[i].has_value()) {
e = param_end[i].value();
if (e < 0) {
e += len;
CHECK_GE(e, 0) << "slicing with end[" << i << "]="
<< e - len << " exceeds limit of " << len;
}
} else if (s < 0) {
e = -1;
CHECK_LT(b, len) << "slicing with begin[" << i << "]=" << b
<< " exceeds limit of input dimension[" << i << "]=" << len;

// checking upper and lower bounds for end
if (e < 0 && param_end[i].has_value()) {
e += len;
CHECK_GE(e, 0) << "slicing with end[" << i << "]=" << e - len
<< " exceeds limit of input dimension[" << i << "]=" << len;
}
CHECK_LE(e, len) << "slicing with end[" << i << "]="
<< e << " exceeds limit of " << len;
} else {
b = 0;
e = 0;
CHECK_LE(e, len) << "slicing with end[" << i << "]=" << e
<< " exceeds limit of input dimension[" << i << "]=" << len;

// checking begin==end case which is not supported
CHECK_NE(b, e) << "slicing with begin[" << i << "]=end[" << i << "]="
<< e << " results in an empty tensor and is not supported";
}

(*begin)[i] = b;
(*end)[i] = e;
(*step)[i] = s;
}

for (index_t i = param_begin.ndim(); i < dshape.ndim(); ++i) {
(*begin)[i] = 0;
(*end)[i] = dshape[i];
Expand Down
9 changes: 9 additions & 0 deletions tests/python/unittest/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6606,6 +6606,15 @@ def test_slice_forward_backward(a, index):
for index in index_list:
test_slice_forward_backward(arr, index)

def test_begin_equals_end(shape, begin, end, step):
in_arr = mx.nd.arange(np.prod(shape)).reshape(shape=shape)
out_arr = mx.nd.slice(in_arr, begin=begin, end=end, step=step)

assertRaises(MXNetError, test_begin_equals_end, (4,), (2,), (2,), (1,))
assertRaises(MXNetError, test_begin_equals_end, (1, 5), (None, 3), (None, 3), (-1, 1))
assertRaises(MXNetError, test_begin_equals_end, (3, 4, 5), (1, 3, 1), (3, 3, 1), (1, -3, 2))
assertRaises(MXNetError, test_begin_equals_end, (2, 4), (None, 2), (None, 2), (1, -1))

# check numeric gradient
in_data = np.arange(36).reshape(2, 2, 3, 3)
data = mx.sym.Variable('data')
Expand Down