From 410faa4c5b72e83d8a6d842d9db4a6a5e447bdc5 Mon Sep 17 00:00:00 2001 From: JiangZhaoh Date: Tue, 24 Dec 2019 16:57:04 +0800 Subject: [PATCH] add comment / abstract basic function --- src/operator/numpy/np_delete_op-inl.h | 128 +++++++++++++++++--------- 1 file changed, 84 insertions(+), 44 deletions(-) diff --git a/src/operator/numpy/np_delete_op-inl.h b/src/operator/numpy/np_delete_op-inl.h index bcdb126f233f..8a9fcb30840e 100644 --- a/src/operator/numpy/np_delete_op-inl.h +++ b/src/operator/numpy/np_delete_op-inl.h @@ -73,11 +73,15 @@ struct NumpyDeleteParam : public dmlc::Parameter { }; namespace delete_ { + enum DeleteOpInputs {kArr, kObj}; enum DeleteOpOutputs {kOut}; } // namespace delete_ struct SliceToIndices { + /*! + * \brief transfer slice to indices array + */ template MSHADOW_XINLINE static void Map(int i, IType* indices, int start, int step) { indices[i] = start + i * step; @@ -85,6 +89,13 @@ struct SliceToIndices { }; struct IsDeleteCal { + /*! + * \brief indicate which indices need to be deleted in input + * \param N used to check indices legality + * \param is_delete if is_delete[i] == False, index i need not be deleted from output + * if is_delete[i] == True, index i needs to be deleted from output + * \param indices the indices that need to be deleted + */ template MSHADOW_XINLINE static void Map(int i, int N, bool* is_delete, const IType* indices) { if ((indices[i] >= 0) && (indices[i] < N)) { @@ -95,7 +106,10 @@ struct IsDeleteCal { struct OutPosCal { /*! - * \brief map the index from input to output + * \brief map the index from input to output. e.g. + * \example original_position 0 1 2 3 4 + * is_delete F T T F F + * out_position 0 - - 1 2 */ MSHADOW_XINLINE static void Map(int i, int64_t* out_pos, const bool* is_delete) { if (!is_delete[i]) { @@ -111,7 +125,7 @@ struct OutPosCal { }; template -struct DeleteImpl { +struct DeleteKernel { /*! * \brief delete a sub-array from input along an axis according to 'is_delete'. 
* \param out_data - output: a new array with sub-arrays along an axis deleted. @@ -140,6 +154,47 @@ struct DeleteImpl { } }; +/*! + * \brief equivalent to numpy's slice.indices(range) + * \param pstart - slice.start + * \param pstep - slice.step + * \param pstop - slice.stop + * \return start - slice.indices(range).start + * \return stop - slice.indices(range).stop + * \return step - slice.indices(range).step + * \return tot - total number of slice.indices(range) + */ +inline void SliceIndices(const dmlc::optional& pstart, + const dmlc::optional& pstop, + const dmlc::optional& pstep, + const int range, + int* start, int* stop, int* step, + size_t* tot) { + *step = pstep.has_value() ? pstep.value() : 1; + CHECK_NE(*step, 0) << "'step' can not equal to 0."; + if (pstop.has_value()) { + *stop = pstop.value(); + *stop += (*stop < 0) ? range : 0; + *stop = (*stop < 0) ? ((*step < 0) ? -1 : 0) : *stop; + *stop = (*stop >= range) ? ((*step < 0) ? range - 1 : range) : *stop; + } else { + *stop = (*step > 0) ? range : -1; + } + if (pstart.has_value()) { + *start = pstart.value(); + *start += (*start < 0) ? range : 0; + *start = (*start < 0) ? ((*step < 0) ? -1 : 0) : *start; + *start = (*start >= range) ? ((*step < 0) ? range - 1 : range) : *start; + } else { + *start = (*step > 0) ? 0 : range - 1; + } + if (*step > 0 && *stop >= *start) { + *tot = static_cast((*stop - *start + *step - 1) / *step); + } else if (*step < 0 && *stop <= *start) { + *tot = static_cast((*stop - *start + *step + 1) / *step); + } +} + template void NumpyDeleteCompute(const nnvm::NodeAttrs& attrs, const OpContext &ctx, @@ -158,7 +213,7 @@ void NumpyDeleteCompute(const nnvm::NodeAttrs& attrs, int ndim = inputs[delete_::kArr].shape().ndim(); int axis = param.axis.has_value() ? 
param.axis.value() : -1; - NDArray arr; + NDArray arr; // original array if (!param.axis.has_value()) { arr = inputs[delete_::kArr].Reshape(Shape1(inputs[delete_::kArr].shape().Size())); @@ -176,43 +231,25 @@ void NumpyDeleteCompute(const nnvm::NodeAttrs& attrs, axis = CheckAxis(axis, ndim); int N = (arr.shape())[axis]; - mxnet::TShape newshape(arr.shape()); + mxnet::TShape outshape(arr.shape()); + // if obj is slice, they're obj's arguments int start = 0, stop = 0, step = 0; + // total number to be deleted size_t numtodel = 0; + // if obj is scalar, index is its value int index = 0; - if (param.step.has_value()) { - step = param.step.value(); - CHECK_NE(step, 0) << "'step' can not equal to 0."; - if (param.stop.has_value()) { - stop = param.stop.value(); - stop += (stop < 0) ? N : 0; - stop = (stop < 0) ? ((step < 0) ? -1 : 0) : stop; - stop = (stop >= N) ? ((step < 0) ? N - 1 : N) : stop; - } else { - stop = (step > 0) ? N : -1; - } - if (param.start.has_value()) { - start = param.start.value(); - start += (start < 0) ? N : 0; - start = (start < 0) ? ((step < 0) ? -1 : 0) : start; - start = (start >= N) ? ((step < 0) ? N - 1 : N) : start; - } else { - start = (step > 0) ? 
0 : N - 1; - } - if (step > 0 && stop >= start) { - numtodel = static_cast((stop - start + step - 1) / step); - } else if (step < 0 && stop <= start) { - numtodel = static_cast((stop - start + step + 1) / step); - } + if (param.step.has_value()) { // obj is slice + SliceIndices(param.start, param.stop, param.step, + N, &start, &stop, &step, &numtodel); if (numtodel == 0) { const_cast(outputs[delete_::kOut]).Init(arr.shape()); mxnet_op::copy(s, outputs[delete_::kOut].data(), inputs[delete_::kArr].data()); return; } - newshape[axis] -= numtodel; - const_cast(outputs[delete_::kOut]).Init(newshape); - } else if (param.int_ind.has_value()) { + outshape[axis] -= numtodel; + const_cast(outputs[delete_::kOut]).Init(outshape); + } else if (param.int_ind.has_value()) { // obj is scalar index = param.int_ind.value(); CHECK((index >= -1 * N) && (index < N)) << "index " << index @@ -220,13 +257,13 @@ void NumpyDeleteCompute(const nnvm::NodeAttrs& attrs, << " with size " << N << "\n"; index += ((index < 0) ? N : 0); numtodel = static_cast(1); - newshape[axis] -= 1; - const_cast(outputs[delete_::kOut]).Init(newshape); - } else { + outshape[axis] -= 1; + const_cast(outputs[delete_::kOut]).Init(outshape); + } else { // obj is tensor numtodel = inputs[delete_::kObj].shape().Size(); } - MSHADOW_TYPE_SWITCH(((inputs.size() == 2U) ? + MSHADOW_TYPE_SWITCH(((inputs.size() == 2U) ? 
// obj is tensor inputs[delete_::kObj].dtype() : mshadow::DataType::kFlag), IType, { size_t temp_mem_size = sizeof(int64_t) * arr.shape()[axis] + @@ -240,19 +277,22 @@ void NumpyDeleteCompute(const nnvm::NodeAttrs& attrs, bool* is_delete_ptr = reinterpret_cast (temp_mem.dptr_ + sizeof(int64_t) * arr.shape()[axis] + sizeof(IType) * numtodel); - if (param.step.has_value()) { + if (param.step.has_value()) { // obj is slice, transfer slice to tensor Kernel::Launch(s, numtodel, indices_ptr, start, step); - } else if (param.int_ind.has_value()) { + } else if (param.int_ind.has_value()) { // obj is scalar, copy it to tensor Fill(s, TBlob(indices_ptr, Shape1(numtodel), xpu::kDevMask), kWriteTo, index); - } else { + } else { // obj is tensor, copy it to a unified tensor mxnet_op::copy(s, TBlob(indices_ptr, inputs[delete_::kObj].shape(), inputs[delete_::kObj].data().dev_mask()), inputs[delete_::kObj].data()); } mxnet_op::Kernel::Launch(s, arr.shape()[axis], is_delete_ptr); + // mark which positions need to be deleted from input arr Kernel::Launch(s, numtodel, N, is_delete_ptr, indices_ptr); + // calculate output data's original position in input arr Kernel::Launch(s, arr.shape()[axis], out_pos_ptr, is_delete_ptr); - if (inputs.size() == 2U) { + if (inputs.size() == 2U) { // obj is tensor + // get total number of nonredundant indices #ifdef __CUDACC__ thrust::device_ptris_delete_dev(is_delete_ptr); thrust::device_vectorvec_is_delete(is_delete_dev, is_delete_dev + arr.shape()[axis]); @@ -265,14 +305,14 @@ void NumpyDeleteCompute(const nnvm::NodeAttrs& attrs, numtodel++; } } - newshape[axis] -= numtodel; - const_cast(outputs[delete_::kOut]).Init(newshape); + outshape[axis] -= numtodel; + const_cast(outputs[delete_::kOut]).Init(outshape); } - MXNET_NDIM_SWITCH(newshape.ndim(), ndim, { - mshadow::Shape out_strides = mxnet_op::calc_stride(newshape.get()); + MXNET_NDIM_SWITCH(outshape.ndim(), ndim, { + mshadow::Shape out_strides = mxnet_op::calc_stride(outshape.get()); 
MSHADOW_TYPE_SWITCH(outputs[delete_::kOut].dtype(), DType, { MXNET_ASSIGN_REQ_SWITCH(req[delete_::kOut], req_type, { - Kernel, xpu>::Launch( + Kernel, xpu>::Launch( s, arr.shape().Size(), outputs[delete_::kOut].data().dptr(), arr.data().dptr(),