Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Modify ndarray slice to have numpy compatible behaviour
Browse files Browse the repository at this point in the history
  • Loading branch information
Mike Mao committed Aug 9, 2019
1 parent a3babc4 commit 18792d5
Show file tree
Hide file tree
Showing 3 changed files with 127 additions and 25 deletions.
79 changes: 54 additions & 25 deletions src/operator/tensor/matrix_op-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -685,13 +685,13 @@ inline void GetIndexRange(const mxnet::TShape& dshape,
<< "Static array size=" << ndim
<< " is not equal to data shape ndim=" << dshape.ndim();

if (param_step.ndim() != 0) {
if (param_step.ndim() > 0) {
CHECK_EQ(param_step.ndim(), param_begin.ndim())
<< "step and begin must have the same length";
}

for (int i = 0; i < param_begin.ndim(); ++i) {
index_t s = param_step.ndim() != 0U && param_step[i].has_value() ? param_step[i].value() : 1;
index_t s = param_step.ndim() > 0 && param_step[i].has_value() ? param_step[i].value() : 1;
CHECK_NE(s, 0) << "slice op step[" << i << "] cannot be 0";

index_t b = 0, e = 0;
Expand All @@ -703,29 +703,44 @@ inline void GetIndexRange(const mxnet::TShape& dshape,
// checking upper and lower bounds for begin
if (b < 0) {
b += len;
CHECK_GE(b, 0) << "slicing with begin[" << i << "]=" << b - len
<< " exceeds limit of input dimension[" << i << "]=" << len;
if (!Imperative::Get()->is_np_shape()) {
CHECK_GE(b, 0) << "slicing with begin[" << i << "]=" << b - len
<< " exceeds limit of input dimension[" << i << "]=" << len;
}
}
if (!Imperative::Get()->is_np_shape()) {
CHECK_LT(b, len) << "slicing with begin[" << i << "]=" << b
<< " exceeds limit of input dimension[" << i << "]=" << len;
}
CHECK_LT(b, len) << "slicing with begin[" << i << "]=" << b
<< " exceeds limit of input dimension[" << i << "]=" << len;

// checking upper and lower bounds for end
if (e < 0 && param_end[i].has_value()) {
if (!(s < 0 && e == -1)) {
// Keep end=-1 as one-beyond-limits index for negative stride
e += len;
e += len;
if (!Imperative::Get()->is_np_shape()) {
CHECK_GE(e, 0) << "slicing with end[" << i << "]=" << e - len
<< " exceeds limit of input dimension[" << i << "]=" << len;
}
CHECK_GE(e, 0) << "slicing with end[" << i << "]=" << e - len
<< " exceeds limit of input dimension[" << i << "]=" << len;
}
CHECK_LE(e, len) << "slicing with end[" << i << "]=" << e
<< " exceeds limit of input dimension[" << i << "]=" << len;
if (!Imperative::Get()->is_np_shape()) {
CHECK_LE(e, len) << "slicing with end[" << i << "]=" << e
<< " exceeds limit of input dimension[" << i << "]=" << len;
}

// checking begin==end case which is not supported
CHECK_NE(b, e) << "slicing with begin[" << i << "]=end[" << i << "]="
<< e << " results in an empty tensor and is not supported";
if (!Imperative::Get()->is_np_shape()) {
CHECK_NE(b, e) << "slicing with begin[" << i << "]=end[" << i << "]="
<< e << " results in an empty tensor and is not supported";
}
}

if (Imperative::Get()->is_np_shape()) {
// move the begin and end to correct position for calculating dim size
b = b < 0 && s > 0 ? 0 : b;
b = b > len-1 && s < 0 ? len-1 : b;
// if the start value lead to empty tensor under step s, use -1 for indication
b = b < 0 || b > len-1 ? -1 : b;
e = e > -1 ? e : -1;
e = e > len ? len : e;
}
(*begin)[i] = b;
(*end)[i] = e;
(*step)[i] = s;
Expand All @@ -741,17 +756,29 @@ inline void GetIndexRange(const mxnet::TShape& dshape,
// Computes the output extent oshape->[i] for axis i of a slice, given the
// normalized begin `b`, end `e` and non-zero step `s` produced by
// GetIndexRange for that axis.
//
// Legacy (non-numpy) mode: begin/end must describe a non-empty range in the
// direction of the step (enforced with CHECK_LT); when e == b, oshape[i] is
// deliberately left untouched (0) so partial shape inference can proceed.
//
// Numpy-compatible mode: out-of-range or empty slices are not an error.
// GetIndexRange signals an empty result for this axis with b < 0 (see the
// "use -1 for indication" normalization there), so any such case — as well
// as a begin/end pair pointing the wrong way for the step — yields a
// 0-sized dimension instead of failing.
//
// Note: this removes the stale pre-change body that the diff view had
// interleaved here, and fixes the "begin=[" typo in the CHECK messages so
// they match the "begin[i]=" form used by every other check in this file.
inline void SetSliceOpOutputDimSize(const index_t i, const int b,
                                    const int e, const int s,
                                    mxnet::TShape* oshape) {
  if (!Imperative::Get()->is_np_shape()) {  // handle as legacy ndarray
    if (e != b) {
      if (s > 0) {
        CHECK_LT(b, e) << "slicing with begin[" << i << "]=" << b << ", end[" << i << "]="
                       << e << ", and step[" << i << "]=" << s << " is invalid";
        (*oshape)[i] = (e - b - 1) / s + 1;
      } else {
        CHECK_LT(e, b) << "slicing with begin[" << i << "]=" << b << ", end[" << i << "]="
                       << e << ", and step[" << i << "]=" << s << " is invalid";
        (*oshape)[i] = (b - e - 1) / (-s) + 1;
      }
    }  // else leave oshape[i] as 0 for partial infer
  } else {  // handle as numpy compatible array
    if (e != b && b >= 0) {
      if (s > 0) {
        (*oshape)[i] = e > b ? (e - b - 1) / s + 1 : 0;
      } else {
        (*oshape)[i] = e < b ? (b - e - 1) / (-s) + 1 : 0;
      }
    } else {
      // b < 0 marks an empty slice (set by GetIndexRange), and e == b is
      // always empty under numpy semantics.
      (*oshape)[i] = 0;
    }
  }
}

inline bool SliceOpShape(const nnvm::NodeAttrs& attrs,
Expand Down Expand Up @@ -852,6 +879,7 @@ void SliceOpForward(const nnvm::NodeAttrs& attrs,
Stream<xpu>* s = ctx.get_stream<xpu>();
const TBlob& data = inputs[0];
const TBlob& out = outputs[0];
if (Imperative::Get()->is_np_shape() && out.Size() == 0) return;
const SliceParam& param = nnvm::get<SliceParam>(attrs.parsed);
MXNET_NDIM_SWITCH(data.ndim(), ndim, {
common::StaticArray<index_t, ndim> begin, end, step;
Expand Down Expand Up @@ -951,6 +979,7 @@ void SliceOpBackward(const nnvm::NodeAttrs& attrs,
} else if (req[0] == kWriteInplace) {
LOG(FATAL) << "_slice_backward does not support kWriteInplace";
}
if (Imperative::Get()->is_np_shape() && ograd.Size() == 0) return;
MXNET_NDIM_SWITCH(ograd.ndim(), ndim, {
common::StaticArray<index_t, ndim> begin, end, step;
GetIndexRange(igrad.shape_, param.begin, param.end, param.step, &begin, &end, &step);
Expand Down
1 change: 1 addition & 0 deletions src/operator/tensor/matrix_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,7 @@ Example::
[5., 7.],
[1., 3.]]
)code" ADD_FILELINE)
.add_alias("_npx_slice")
.set_attr_parser(ParamParser<SliceParam>)
.set_attr<mxnet::FInferShape>("FInferShape", SliceOpShape)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
Expand Down
72 changes: 72 additions & 0 deletions tests/python/unittest/test_numpy_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,78 @@ def is_int(dtype):
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)


@with_seed()
@use_np
def test_npx_slice():
    # Exercises the numpy-compatible `npx.slice` operator (alias `_npx_slice`
    # per the matrix_op.cc registration in this commit), both imperatively and
    # hybridized, and checks forward results against the legacy ndarray
    # `slice` op and gradients against a hand-built numpy reference.
    class TestSlice(HybridBlock):
        # Thin Gluon wrapper so the op can run through a hybridized
        # (symbolic) graph as well as eagerly.
        def __init__(self, begin, end, step):
            super(TestSlice, self).__init__()
            self._begin = begin
            self._end = end
            self._step = step

        def hybrid_forward(self, F, a, *args, **kwargs):
            return F.npx.slice(a, begin=self._begin, end=self._end, step=self._step)

    def get_start_end_step(shape):
        # Draws a random, non-empty per-axis slice spec for `shape`.
        # step_switch: 1 -> explicit step of +1, -1 -> step of -1 (begin/end
        # swapped so the slice still selects something), 0 -> step omitted.
        start = []
        end = []
        step_switch = random.randint(-1,1)
        step = None if step_switch == 0 else []
        for i in range(len(shape)):
            # s < e always holds, so the forward slice [s, e) is non-empty.
            s = random.randint(0, shape[i]-1)
            e = random.randint(s+1, shape[i])
            if step_switch == 1:
                step.append(1)
                start.append(s)
                end.append(e)
            elif step_switch == -1:
                # Reversed slice: start from e, stop (exclusively) at s,
                # stepping backwards.
                step.append(-1)
                if e == shape[i]:
                    # Keep the start index within bounds for the legacy op.
                    e -= 1
                s -= 1
                if s == -1:
                    # -1 cannot serve as an exclusive stop when stepping
                    # backwards (it would mean "last element"), use None.
                    s = None
                start.append(e)
                end.append(s)
            else:
                start.append(s)
                end.append(e)
        return start, end, step

    for hybridize in [True, False]:
        for i in range(10):
            dim = random.randint(1,4)
            shape = [random.randint(1,5) for i in range(dim)]

            # test gluon
            start, end, step = get_start_end_step(shape)
            test_slice = TestSlice(begin=start, end=end, step=step)
            if hybridize:
                test_slice.hybridize()

            a = mx.nd.random.uniform(shape=shape).as_np_ndarray()
            a.attach_grad()
            # Reference forward result from the legacy ndarray slice op.
            if step is not None:
                expected_ret = a.as_nd_ndarray().slice(start, end, step)
            else:
                expected_ret = a.as_nd_ndarray().slice(start, end)
            with mx.autograd.record():
                y = test_slice(a)

            assert_almost_equal(y.asnumpy(), expected_ret.asnumpy(), rtol=1e-3, atol=1e-5)

            # test backward: d(slice)/da is 1 exactly on the sliced positions.
            mx.autograd.backward(y)
            expected_grad = _np.zeros(shape)
            basic_index = tuple([
                slice(start[i], end[i], step[i]) if step is not None else slice(start[i], end[i])
                for i in range(len(start))
            ])
            expected_grad[basic_index] = 1
            assert_almost_equal(a.grad.asnumpy(), expected_grad, rtol=1e-3, atol=1e-5)


if __name__ == '__main__':
    # Standalone entry point: run every test in this module under nose.
    import nose
    nose.runmodule()

0 comments on commit 18792d5

Please sign in to comment.