Commit 3e9cf37
fix slice bug
JiangZhaoh committed Nov 22, 2019
1 parent 458401a commit 3e9cf37
Showing 5 changed files with 48 additions and 31 deletions.
2 changes: 1 addition & 1 deletion python/mxnet/ndarray/numpy/_op.py
@@ -464,7 +464,7 @@ def insert(arr, obj, values, axis=None):
     if not isinstance(values, NDArray):
         raise TypeError("'values' can not support type {}".format(str(type(values))))
     if isinstance(obj, slice):
-        start = 0 if obj.start is None else obj.start
+        start = obj.start
         stop = obj.stop
         step = 1 if obj.step is None else obj.step
         return _npi.insert(arr, values, start=start, stop=stop, step=step, axis=axis)
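Context for the frontend change (not part of the commit itself): the default for an omitted slice bound depends on the sign of step, so the old line that forced start to 0 expanded reversed slices from the wrong end, and an omitted stop additionally left param.stop unset, which the backend also used for its 2-vs-3-input bookkeeping below. NumPy resolves omitted bounds via slice.indices, which is what the backend now reproduces. A NumPy-only sketch:

    import numpy as np

    a = np.arange(6)
    v = np.array([-1, -1, -1])

    sl = slice(None, None, -2)            # start/stop omitted, negative step
    start, stop, step = sl.indices(len(a))
    print(start, stop, step)              # 5 -1 -2  -> insertion indices 5, 3, 1

    # Inserting via the slice is the same as inserting via the expanded indices,
    # which only holds if the defaults are resolved relative to the step sign.
    assert np.array_equal(np.insert(a, sl, v),
                          np.insert(a, list(range(start, stop, step)), v))

With the old default of start = 0, the same slice would expand as range(0, stop, -2), i.e. an empty or wrong set of positions.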
2 changes: 1 addition & 1 deletion python/mxnet/symbol/numpy/_symbol.py
@@ -2570,7 +2570,7 @@ def insert(arr, obj, values, axis=None):
     if not isinstance(values, ndarray):    # pylint: disable= undefined-variable
         raise TypeError("'values' can not support type {}".format(str(type(values))))
     if isinstance(obj, slice):
-        start = 0 if obj.start is None else obj.start
+        start = obj.start
         stop = obj.stop
         step = 1 if obj.step is None else obj.step
         return _npi.insert(arr, values, start=start, stop=stop, step=step, axis=axis)
34 changes: 21 additions & 13 deletions src/operator/numpy/np_insert_op-inl.h
@@ -347,7 +347,7 @@ void NumpyInsertCompute(const nnvm::NodeAttrs& attrs,

   const NumpyInsertParam& param = nnvm::get<NumpyInsertParam>(attrs.parsed);
   CHECK_EQ(inputs.size(),
-           (param.stop.has_value() || param.int_ind.has_value()) ? 2U : 3U);
+           (param.step.has_value() || param.int_ind.has_value()) ? 2U : 3U);
   CHECK_EQ(outputs.size(), 1U);
   CHECK_EQ(req.size(), 1U);
   mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
@@ -385,17 +385,25 @@ void NumpyInsertCompute(const nnvm::NodeAttrs& attrs,
   // get and check indices from slice or sequence of ints
   if (inputs.size() == 3U) {
     indices_len = inputs[insert_::kObj].shape_.Size();
-  } else if (param.stop.has_value()) {
+  } else if (param.step.has_value()) {
     step = param.step.value();
     CHECK_NE(step, 0) << "'step' can not equal to 0.";
-    stop = param.stop.value();
-    stop += (stop < 0) ? N : 0;
-    stop = (stop < 0) ? ((step < 0) ? -1 : 0) : stop;
-    stop = (stop >= N) ? ((step < 0) ? N - 1 : N) : stop;
-    start = param.start.value();
-    start += (start < 0) ? N : 0;
-    start = (start < 0) ? ((step < 0) ? -1 : 0) : start;
-    start = (start >= N) ? ((step < 0) ? N - 1 : N) : start;
+    if (param.stop.has_value()) {
+      stop = param.stop.value();
+      stop += (stop < 0) ? N : 0;
+      stop = (stop < 0) ? ((step < 0) ? -1 : 0) : stop;
+      stop = (stop >= N) ? ((step < 0) ? N - 1 : N) : stop;
+    } else {
+      stop = (step > 0) ? N : -1;
+    }
+    if (param.start.has_value()) {
+      start = param.start.value();
+      start += (start < 0) ? N : 0;
+      start = (start < 0) ? ((step < 0) ? -1 : 0) : start;
+      start = (start >= N) ? ((step < 0) ? N - 1 : N) : start;
+    } else {
+      start = (step > 0) ? 0 : N - 1;
+    }
     int seq_cnt = 0;
     if (step > 0 && stop >= start) {
       seq_cnt = (stop - start + step - 1) / step;
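The added branches make the backend resolve missing bounds the way Python's slice.indices(N) does: an omitted start/stop falls back to the end appropriate for the step sign, and explicit values are wrapped once by N and then clamped to [-1, N-1] for negative steps or [0, N] for positive ones. A small Python sketch of the same normalization (normalize is a hypothetical helper written only to mirror the C++ branches above), including the ceiling-division count used for seq_cnt:

    def normalize(start, stop, step, N):
        assert step != 0, "'step' can not equal to 0."
        if stop is None:
            stop = N if step > 0 else -1
        else:
            stop += N if stop < 0 else 0
            stop = ((-1 if step < 0 else 0) if stop < 0 else
                    ((N - 1 if step < 0 else N) if stop >= N else stop))
        if start is None:
            start = 0 if step > 0 else N - 1
        else:
            start += N if start < 0 else 0
            start = ((-1 if step < 0 else 0) if start < 0 else
                     ((N - 1 if step < 0 else N) if start >= N else start))
        return start, stop, step

    # Matches Python's own resolution of the same slices:
    for st in (None, 1, -3, 100):
        for sp in (1, 2, -1, -2):
            assert normalize(st, None, sp, 5) == slice(st, None, sp).indices(5)

    # seq_cnt is just ceil((stop - start) / step) for a forward slice:
    start, stop, step = 2, 7, 3
    assert (stop - start + step - 1) // step == len(range(start, stop, step))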
@@ -451,7 +459,7 @@ void NumpyInsertCompute(const nnvm::NodeAttrs& attrs,
   } else if (indices_len == 1) {
     numnew = values.shape_[axis];
     newshape[axis] += numnew;
-    if (param.start.has_value()) {
+    if (param.step.has_value()) {
       index = start;
       CHECK(index >= -1 * N && index <= N)
         << "Index should be in the range of [-r, r-1] where r is the dim size in 'axis'";
@@ -530,7 +538,7 @@ void NumpyInsertCompute(const nnvm::NodeAttrs& attrs,
   } else if (indices_len == 1) {
     MSHADOW_TYPE_SWITCH(outputs[insert_::kOut].type_flag_, DType, {
       MXNET_ASSIGN_REQ_SWITCH(req[insert_::kOut], req_type, {
-        if (param.stop.has_value()) {
+        if (param.step.has_value()) {
           Kernel<InsertSingleIndexForward<req_type>, xpu>::Launch(s, outshape.Size(),
                                                                   outputs[insert_::kOut].dptr<DType>(),
                                                                   values.dptr<DType>(), arr.dptr<DType>(),
@@ -589,7 +597,7 @@ void NumpyInsertCompute(const nnvm::NodeAttrs& attrs,
     Tensor<xpu, 1, int> order(order_ptr, Shape1(indices_len), s);
     int num_bits = common::ilog2ui(static_cast<unsigned int>(indices_len) - 1);

-    if (param.stop.has_value()) {
+    if (param.step.has_value()) {
       Kernel<SliceToIndices, xpu>::Launch(s, indices_len,
                                           indices_ptr, N,
                                           start, step);
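Presumably the SliceToIndices kernel just materializes the slice as an explicit index array before the generic insert path runs, something like the following (a sketch of the intended expansion, not the kernel code itself):

    start, step, indices_len = 5, -2, 3
    indices = [start + i * step for i in range(indices_len)]   # [5, 3, 1]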
35 changes: 22 additions & 13 deletions src/operator/numpy/np_insert_op.cc
@@ -35,7 +35,7 @@ bool NumpyInsertType(const nnvm::NodeAttrs& attrs,
                      std::vector<int> *in_type,
                      std::vector<int> *out_type) {
   const NumpyInsertParam& param = nnvm::get<NumpyInsertParam>(attrs.parsed);
-  int insize = (param.stop.has_value() || param.int_ind.has_value()) ? 2 : 3;
+  int insize = (param.step.has_value() || param.int_ind.has_value()) ? 2 : 3;
   CHECK_EQ(in_type->size(), insize);
   CHECK_EQ(out_type->size(), 1U);
   if (insize == 3) {
@@ -56,7 +56,7 @@ bool NumpyInsertShape(const nnvm::NodeAttrs& attrs,
   using namespace mshadow;
   const NumpyInsertParam& param = nnvm::get<NumpyInsertParam>(attrs.parsed);
   CHECK_EQ(in_shape->size(),
-           (param.stop.has_value() || param.int_ind.has_value()) ? 2U : 3U);
+           (param.step.has_value() || param.int_ind.has_value()) ? 2U : 3U);
   mxnet::TShape &arrshape = (*in_shape)[insert_::kArr];
   mxnet::TShape &valshape = (*in_shape)[insert_::kValues];
   mxnet::TShape &objShape = (*in_shape)[insert_::kObj];
@@ -88,16 +88,25 @@ bool NumpyInsertShape(const nnvm::NodeAttrs& attrs,
   int N = arrshape[axis];
   if (in_shape->size() == 3U) {
     seq_cnt = objShape.Size();
-  } else if (param.stop.has_value()) {
+  } else if (param.step.has_value()) {
     int step = param.step.value();
-    int stop = param.stop.value();
-    stop += (stop < 0) ? N : 0;
-    stop = (stop < 0) ? ((step < 0) ? -1 : 0) : stop;
-    stop = (stop >= N) ? ((step < 0) ? N - 1 : N) : stop;
-    int start = param.start.value();
-    start += (start < 0) ? N : 0;
-    start = (start < 0) ? ((step < 0) ? -1 : 0) : start;
-    start = (start >= N) ? ((step < 0) ? N - 1 : N) : start;
+    int stop, start;
+    if (param.stop.has_value()) {
+      stop = param.stop.value();
+      stop += (stop < 0) ? N : 0;
+      stop = (stop < 0) ? ((step < 0) ? -1 : 0) : stop;
+      stop = (stop >= N) ? ((step < 0) ? N - 1 : N) : stop;
+    } else {
+      stop = (step > 0) ? N : -1;
+    }
+    if (param.start.has_value()) {
+      start = param.start.value();
+      start += (start < 0) ? N : 0;
+      start = (start < 0) ? ((step < 0) ? -1 : 0) : start;
+      start = (start >= N) ? ((step < 0) ? N - 1 : N) : start;
+    } else {
+      start = (step > 0) ? 0 : N - 1;
+    }
     seq_cnt = 0;
     if (step > 0 && stop >= start) {
       seq_cnt = (stop - start + step - 1) / step;
@@ -140,13 +149,13 @@ NNVM_REGISTER_OP(_npi_insert)
 .set_attr_parser(ParamParser<NumpyInsertParam>)
 .set_num_inputs([](const NodeAttrs& attrs) {
   const NumpyInsertParam& params = nnvm::get<NumpyInsertParam>(attrs.parsed);
-  return (params.stop.has_value() || params.int_ind.has_value()) ? 2U : 3U;
+  return (params.step.has_value() || params.int_ind.has_value()) ? 2U : 3U;
 })
 .set_num_outputs(1)
 .set_attr<nnvm::FListInputNames>("FListInputNames",
   [](const NodeAttrs& attrs) {
     const NumpyInsertParam& params = nnvm::get<NumpyInsertParam>(attrs.parsed);
-    return (params.stop.has_value() || params.int_ind.has_value()) ?
+    return (params.step.has_value() || params.int_ind.has_value()) ?
       std::vector<std::string>{"arr", "values"} :
       std::vector<std::string>{"arr", "values", "obj"};
   })
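Why the input-count and input-name logic also moves from stop to step: the frontend shown at the top of this commit always supplies step for a slice (defaulting it to 1) while start and stop may now be None, so step.has_value() — together with int_ind for a scalar index — is what still distinguishes the 2-input (arr, values) forms from the 3-input (arr, values, obj) form. A rough sketch of the two call shapes (illustrative only; the indices-tensor path and its int64 dtype are assumptions, not shown in this diff):

    import mxnet as mx

    arr = mx.np.arange(6)
    vals = mx.np.array([7., 8., 9.])

    # Slice obj: lowered to the 2-input form; the slice travels as scalar
    # params (start=None, stop=None, step=-2 here).
    out1 = mx.np.insert(arr, slice(None, None, -2), vals, axis=0)

    # Index-array obj: lowered to the 3-input form with obj as a tensor input.
    out2 = mx.np.insert(arr, mx.np.array([0, 2, 4], dtype='int64'), vals, axis=0)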
6 changes: 3 additions & 3 deletions tests/python/unittest/test_numpy_op.py
@@ -2028,9 +2028,9 @@ def GetNdim(tp):
     idx = _np.random.randint(-1 * H[3], H[3] + 1, size = (3)).tolist()
     config.append(tuple([H, idx, E, 3]))
     # test slice
-    for st in range(-5, 5):
-        for ed in range(-5,5):
-            for stp in [-1, 1, 2]:
+    for st in [-5, -3, -1, 0, 1, 3, 5, None]:
+        for ed in [-5, -3, -1, 0, 1, 3, 5, None]:
+            for stp in [-1, 1, 2, None]:
                 config.append(tuple([A, slice(st, ed, stp), F, 1]))

     for arr_shape, obj, val_shape, axis in config:
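The new grid deliberately includes None for start, stop and step — exactly the inputs the old frontend mishandled — plus negative steps. For reference, here is what one of the added slice configurations means in plain NumPy terms (a standalone NumPy sketch, not the MXNet test harness itself):

    import numpy as _np

    a = _np.array([10, 20, 30, 40])
    v = _np.array([0, 1, 2, 3])

    # slice(None, None, -1) resolves to indices [3, 2, 1, 0]; each value is
    # inserted before the corresponding index of the original array.
    out = _np.insert(a, slice(None, None, -1), v)
    print(out)    # [ 3 10  2 20  1 30  0 40]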
