diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index 3a33a2a8a76c..5bac091b391a 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -465,7 +465,7 @@ def insert(arr, obj, values, axis=None): if not isinstance(values, NDArray): raise TypeError("'values' can not support type {}".format(str(type(values)))) if isinstance(obj, slice): - start = 0 if obj.start is None else obj.start + start = obj.start stop = obj.stop step = 1 if obj.step is None else obj.step return _npi.insert(arr, values, start=start, stop=stop, step=step, axis=axis) diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py index fbc2bf5124ec..19af9c2853be 100644 --- a/python/mxnet/symbol/numpy/_symbol.py +++ b/python/mxnet/symbol/numpy/_symbol.py @@ -2571,7 +2571,7 @@ def insert(arr, obj, values, axis=None): if not isinstance(values, ndarray): # pylint: disable= undefined-variable raise TypeError("'values' can not support type {}".format(str(type(values)))) if isinstance(obj, slice): - start = 0 if obj.start is None else obj.start + start = obj.start stop = obj.stop step = 1 if obj.step is None else obj.step return _npi.insert(arr, values, start=start, stop=stop, step=step, axis=axis) diff --git a/src/operator/numpy/np_insert_op-inl.h b/src/operator/numpy/np_insert_op-inl.h index aa8df33e6e1b..1e54cd760975 100644 --- a/src/operator/numpy/np_insert_op-inl.h +++ b/src/operator/numpy/np_insert_op-inl.h @@ -347,7 +347,7 @@ void NumpyInsertCompute(const nnvm::NodeAttrs& attrs, const NumpyInsertParam& param = nnvm::get(attrs.parsed); CHECK_EQ(inputs.size(), - (param.stop.has_value() || param.int_ind.has_value()) ? 2U : 3U); + (param.step.has_value() || param.int_ind.has_value()) ? 2U : 3U); CHECK_EQ(outputs.size(), 1U); CHECK_EQ(req.size(), 1U); mshadow::Stream *s = ctx.get_stream(); @@ -385,17 +385,25 @@ void NumpyInsertCompute(const nnvm::NodeAttrs& attrs, // get and check indices from slice or sequence of ints if (inputs.size() == 3U) { indices_len = inputs[insert_::kObj].shape_.Size(); - } else if (param.stop.has_value()) { + } else if (param.step.has_value()) { step = param.step.value(); CHECK_NE(step, 0) << "'step' can not equal to 0."; - stop = param.stop.value(); - stop += (stop < 0) ? N : 0; - stop = (stop < 0) ? ((step < 0) ? -1 : 0) : stop; - stop = (stop >= N) ? ((step < 0) ? N - 1 : N) : stop; - start = param.start.value(); - start += (start < 0) ? N : 0; - start = (start < 0) ? ((step < 0) ? -1 : 0) : start; - start = (start >= N) ? ((step < 0) ? N - 1 : N) : start; + if (param.stop.has_value()) { + stop = param.stop.value(); + stop += (stop < 0) ? N : 0; + stop = (stop < 0) ? ((step < 0) ? -1 : 0) : stop; + stop = (stop >= N) ? ((step < 0) ? N - 1 : N) : stop; + } else { + stop = (step > 0) ? N : -1; + } + if (param.start.has_value()) { + start = param.start.value(); + start += (start < 0) ? N : 0; + start = (start < 0) ? ((step < 0) ? -1 : 0) : start; + start = (start >= N) ? ((step < 0) ? N - 1 : N) : start; + } else { + start = (step > 0) ? 0 : N - 1; + } int seq_cnt = 0; if (step > 0 && stop >= start) { seq_cnt = (stop - start + step - 1) / step; @@ -451,7 +459,7 @@ void NumpyInsertCompute(const nnvm::NodeAttrs& attrs, } else if (indices_len == 1) { numnew = values.shape_[axis]; newshape[axis] += numnew; - if (param.start.has_value()) { + if (param.step.has_value()) { index = start; CHECK(index >= -1 * N && index <= N) << "Index should be in the range of [-r, r-1] where r is the dim size in 'axis'"; @@ -530,7 +538,7 @@ void NumpyInsertCompute(const nnvm::NodeAttrs& attrs, } else if (indices_len == 1) { MSHADOW_TYPE_SWITCH(outputs[insert_::kOut].type_flag_, DType, { MXNET_ASSIGN_REQ_SWITCH(req[insert_::kOut], req_type, { - if (param.stop.has_value()) { + if (param.step.has_value()) { Kernel, xpu>::Launch(s, outshape.Size(), outputs[insert_::kOut].dptr(), values.dptr(), arr.dptr(), @@ -589,7 +597,7 @@ void NumpyInsertCompute(const nnvm::NodeAttrs& attrs, Tensor order(order_ptr, Shape1(indices_len), s); int num_bits = common::ilog2ui(static_cast(indices_len) - 1); - if (param.stop.has_value()) { + if (param.step.has_value()) { Kernel::Launch(s, indices_len, indices_ptr, N, start, step); diff --git a/src/operator/numpy/np_insert_op.cc b/src/operator/numpy/np_insert_op.cc index ea34e77f16d2..1f4151883a3e 100644 --- a/src/operator/numpy/np_insert_op.cc +++ b/src/operator/numpy/np_insert_op.cc @@ -35,7 +35,7 @@ bool NumpyInsertType(const nnvm::NodeAttrs& attrs, std::vector *in_type, std::vector *out_type) { const NumpyInsertParam& param = nnvm::get(attrs.parsed); - int insize = (param.stop.has_value() || param.int_ind.has_value()) ? 2 : 3; + int insize = (param.step.has_value() || param.int_ind.has_value()) ? 2 : 3; CHECK_EQ(in_type->size(), insize); CHECK_EQ(out_type->size(), 1U); if (insize == 3) { @@ -56,7 +56,7 @@ bool NumpyInsertShape(const nnvm::NodeAttrs& attrs, using namespace mshadow; const NumpyInsertParam& param = nnvm::get(attrs.parsed); CHECK_EQ(in_shape->size(), - (param.stop.has_value() || param.int_ind.has_value()) ? 2U : 3U); + (param.step.has_value() || param.int_ind.has_value()) ? 2U : 3U); mxnet::TShape &arrshape = (*in_shape)[insert_::kArr]; mxnet::TShape &valshape = (*in_shape)[insert_::kValues]; mxnet::TShape &objShape = (*in_shape)[insert_::kObj]; @@ -88,16 +88,25 @@ bool NumpyInsertShape(const nnvm::NodeAttrs& attrs, int N = arrshape[axis]; if (in_shape->size() == 3U) { seq_cnt = objShape.Size(); - } else if (param.stop.has_value()) { + } else if (param.step.has_value()) { int step = param.step.value(); - int stop = param.stop.value(); - stop += (stop < 0) ? N : 0; - stop = (stop < 0) ? ((step < 0) ? -1 : 0) : stop; - stop = (stop >= N) ? ((step < 0) ? N - 1 : N) : stop; - int start = param.start.value(); - start += (start < 0) ? N : 0; - start = (start < 0) ? ((step < 0) ? -1 : 0) : start; - start = (start >= N) ? ((step < 0) ? N - 1 : N) : start; + int stop, start; + if (param.stop.has_value()) { + stop = param.stop.value(); + stop += (stop < 0) ? N : 0; + stop = (stop < 0) ? ((step < 0) ? -1 : 0) : stop; + stop = (stop >= N) ? ((step < 0) ? N - 1 : N) : stop; + } else { + stop = (step > 0) ? N : -1; + } + if (param.start.has_value()) { + start = param.start.value(); + start += (start < 0) ? N : 0; + start = (start < 0) ? ((step < 0) ? -1 : 0) : start; + start = (start >= N) ? ((step < 0) ? N - 1 : N) : start; + } else { + start = (step > 0) ? 0 : N - 1; + } seq_cnt = 0; if (step > 0 && stop >= start) { seq_cnt = (stop - start + step - 1) / step; @@ -140,13 +149,13 @@ NNVM_REGISTER_OP(_npi_insert) .set_attr_parser(ParamParser) .set_num_inputs([](const NodeAttrs& attrs) { const NumpyInsertParam& params = nnvm::get(attrs.parsed); - return (params.stop.has_value() || params.int_ind.has_value()) ? 2U : 3U; + return (params.step.has_value() || params.int_ind.has_value()) ? 2U : 3U; }) .set_num_outputs(1) .set_attr("FListInputNames", [](const NodeAttrs& attrs) { const NumpyInsertParam& params = nnvm::get(attrs.parsed); - return (params.stop.has_value() || params.int_ind.has_value()) ? + return (params.step.has_value() || params.int_ind.has_value()) ? std::vector{"arr", "values"} : std::vector{"arr", "values", "obj"}; }) diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index c1a967c4597b..1adf9f64e791 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -2023,9 +2023,9 @@ def GetNdim(tp): idx = _np.random.randint(-1 * H[3], H[3] + 1, size = (3)).tolist() config.append(tuple([H, idx, E, 3])) # test slice - for st in range(-5, 5): - for ed in range(-5,5): - for stp in [-1, 1, 2]: + for st in [-5, -3, -1, 0, 1, 3, 5, None]: + for ed in [-5, -3, -1, 0, 1, 3, 5, None]: + for stp in [-1, 1, 2, None]: config.append(tuple([A, slice(st, ed, stp), F, 1])) for arr_shape, obj, val_shape, axis in config: