add comment / abstract basic function
JiangZhaoh committed Dec 24, 2019
1 parent 71ecb71 commit 410faa4
Showing 1 changed file with 84 additions and 44 deletions.
128 changes: 84 additions & 44 deletions src/operator/numpy/np_delete_op-inl.h
@@ -73,18 +73,29 @@ struct NumpyDeleteParam : public dmlc::Parameter<NumpyDeleteParam> {
};

namespace delete_ {

enum DeleteOpInputs {kArr, kObj};
enum DeleteOpOutputs {kOut};
} // namespace delete_

struct SliceToIndices {
/*!
* \brief expand a slice into an array of indices
*/
template<typename IType>
MSHADOW_XINLINE static void Map(int i, IType* indices, int start, int step) {
indices[i] = start + i * step;
}
};
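A quick standalone check of the arithmetic above, outside the kernel-launch machinery (plain C++; the loop plays the role of the parallel Map over i):

```cpp
#include <cstdio>

int main() {
  const int start = 1, step = 2, numtodel = 3;  // e.g. the slice 1:7:2
  int indices[3];
  for (int i = 0; i < numtodel; ++i) {
    indices[i] = start + i * step;  // same expression as SliceToIndices::Map
  }
  for (int i = 0; i < numtodel; ++i) {
    std::printf("%d ", indices[i]);  // prints: 1 3 5
  }
  return 0;
}
```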

struct IsDeleteCal {
/*!
* \brief mark which indices are to be deleted from the input
* \param N used to check that the indices are within bounds
* \param is_delete if is_delete[i] == false, index i is kept in the output;
*                  if is_delete[i] == true, index i is deleted from the output
* \param indices the indices to be deleted
*/
template<typename IType>
MSHADOW_XINLINE static void Map(int i, int N, bool* is_delete, const IType* indices) {
if ((indices[i] >= 0) && (indices[i] < N)) {
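// (the rest of Map is collapsed here; judging from the doc comment it
//  presumably just sets is_delete[indices[i]] = true for each in-bounds index)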
@@ -95,7 +106,10 @@ struct IsDeleteCal {

struct OutPosCal {
/*!
* \brief map the index from input to output
* \brief map each input index to its output position, e.g.:
* \example original_position 0 1 2 3 4
*          is_delete         F T T F F
*          out_position      0 - - 1 2
*/
MSHADOW_XINLINE static void Map(int i, int64_t* out_pos, const bool* is_delete) {
if (!is_delete[i]) {
@@ -111,7 +125,7 @@
};
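The kernel body is collapsed in this hunk, but the table above pins down the contract: each surviving position is numbered by how many surviving positions precede it. A minimal CPU sketch of that prefix count (the names here are illustrative, not from the file):

```cpp
#include <cstdint>
#include <vector>

// out_pos[i] = number of non-deleted positions strictly before i;
// only meaningful where is_delete[i] is false.
std::vector<int64_t> OutPositions(const std::vector<bool>& is_delete) {
  std::vector<int64_t> out_pos(is_delete.size(), -1);  // -1: deleted, no slot
  int64_t next = 0;
  for (std::size_t i = 0; i < is_delete.size(); ++i) {
    if (!is_delete[i]) {
      out_pos[i] = next++;
    }
  }
  return out_pos;
}
// e.g. is_delete = {F, T, T, F, F}  ->  out_pos = {0, -1, -1, 1, 2}
```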

template<int req, int ndim>
struct DeleteImpl {
struct DeleteKernel {
/*!
* \brief delete a sub-array from input along an axis according to 'is_delete'.
* \param out_data - output: a new array with sub-arrays along an axis deleted.
@@ -140,6 +154,47 @@
}
};
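Most of DeleteKernel's body is collapsed above. Conceptually it is a gather of the surviving elements, with out_pos supplying each survivor's destination; a hedged 1-D sketch of that idea (the real kernel additionally handles N-D strides and the req write mode):

```cpp
#include <cstdint>
#include <vector>

template <typename DType>
std::vector<DType> Delete1D(const std::vector<DType>& in,
                            const std::vector<bool>& is_delete,
                            const std::vector<int64_t>& out_pos) {
  std::vector<DType> out;
  out.reserve(in.size());
  for (std::size_t i = 0; i < in.size(); ++i) {
    if (!is_delete[i]) {
      out.push_back(in[i]);  // lands at index out_pos[i] by construction
    }
  }
  return out;
}
```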

/*!
* \brief equivalent to Python's slice.indices(range)
* \param pstart - slice.start
* \param pstop - slice.stop
* \param pstep - slice.step
* \return start - slice.indices(range).start
* \return stop - slice.indices(range).stop
* \return step - slice.indices(range).step
* \return tot - total number of indices in slice.indices(range)
*/
inline void SliceIndices(const dmlc::optional<int>& pstart,
const dmlc::optional<int>& pstop,
const dmlc::optional<int>& pstep,
const int range,
int* start, int* stop, int* step,
size_t* tot) {
*step = pstep.has_value() ? pstep.value() : 1;
CHECK_NE(*step, 0) << "'step' cannot be 0.";
if (pstop.has_value()) {
*stop = pstop.value();
*stop += (*stop < 0) ? range : 0;
*stop = (*stop < 0) ? ((*step < 0) ? -1 : 0) : *stop;
*stop = (*stop >= range) ? ((*step < 0) ? range - 1 : range) : *stop;
} else {
*stop = (*step > 0) ? range : -1;
}
if (pstart.has_value()) {
*start = pstart.value();
*start += (*start < 0) ? range : 0;
*start = (*start < 0) ? ((*step < 0) ? -1 : 0) : *start;
*start = (*start >= range) ? ((*step < 0) ? range - 1 : range) : *start;
} else {
*start = (*step > 0) ? 0 : range - 1;
}
if (*step > 0 && *stop >= *start) {
*tot = static_cast<size_t>((*stop - *start + *step - 1) / *step);
} else if (*step < 0 && *stop <= *start) {
*tot = static_cast<size_t>((*stop - *start + *step + 1) / *step);
}
}
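A sanity check of the normalization above, mirroring Python, where slice(None, None, -1).indices(5) == (4, -1, -1). Note that the caller must zero-initialize tot, since SliceIndices leaves it untouched when the slice is empty (a snippet assuming this header's context):

```cpp
int start, stop, step;
size_t tot = 0;  // not written for an empty slice, so initialize it
SliceIndices(dmlc::optional<int>(),    // start: None
             dmlc::optional<int>(),    // stop:  None
             dmlc::optional<int>(-1),  // step:  -1
             5,                        // range: axis length N = 5
             &start, &stop, &step, &tot);
// start == 4, stop == -1, step == -1, tot == 5 (visits 4, 3, 2, 1, 0)
```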

template<typename xpu>
void NumpyDeleteCompute(const nnvm::NodeAttrs& attrs,
const OpContext &ctx,
@@ -158,7 +213,7 @@ void NumpyDeleteCompute(const nnvm::NodeAttrs& attrs,

int ndim = inputs[delete_::kArr].shape().ndim();
int axis = param.axis.has_value() ? param.axis.value() : -1;
NDArray arr;
NDArray arr; // original array

if (!param.axis.has_value()) {
arr = inputs[delete_::kArr].Reshape(Shape1(inputs[delete_::kArr].shape().Size()));
@@ -176,57 +231,39 @@ void NumpyDeleteCompute(const nnvm::NodeAttrs& attrs,

axis = CheckAxis(axis, ndim);
int N = (arr.shape())[axis];
mxnet::TShape newshape(arr.shape());
mxnet::TShape outshape(arr.shape());
// if obj is a slice, these are its arguments
int start = 0, stop = 0, step = 0;
// total number to be deleted
size_t numtodel = 0;
// if obj is a scalar, index is its value
int index = 0;

if (param.step.has_value()) {
step = param.step.value();
CHECK_NE(step, 0) << "'step' can not equal to 0.";
if (param.stop.has_value()) {
stop = param.stop.value();
stop += (stop < 0) ? N : 0;
stop = (stop < 0) ? ((step < 0) ? -1 : 0) : stop;
stop = (stop >= N) ? ((step < 0) ? N - 1 : N) : stop;
} else {
stop = (step > 0) ? N : -1;
}
if (param.start.has_value()) {
start = param.start.value();
start += (start < 0) ? N : 0;
start = (start < 0) ? ((step < 0) ? -1 : 0) : start;
start = (start >= N) ? ((step < 0) ? N - 1 : N) : start;
} else {
start = (step > 0) ? 0 : N - 1;
}
if (step > 0 && stop >= start) {
numtodel = static_cast<size_t>((stop - start + step - 1) / step);
} else if (step < 0 && stop <= start) {
numtodel = static_cast<size_t>((stop - start + step + 1) / step);
}
if (param.step.has_value()) { // obj is a slice
SliceIndices(param.start, param.stop, param.step,
N, &start, &stop, &step, &numtodel);
if (numtodel == 0) {
const_cast<NDArray &>(outputs[delete_::kOut]).Init(arr.shape());
mxnet_op::copy(s, outputs[delete_::kOut].data(), inputs[delete_::kArr].data());
return;
}
newshape[axis] -= numtodel;
const_cast<NDArray &>(outputs[delete_::kOut]).Init(newshape);
} else if (param.int_ind.has_value()) {
outshape[axis] -= numtodel;
const_cast<NDArray &>(outputs[delete_::kOut]).Init(outshape);
} else if (param.int_ind.has_value()) { // obj is a scalar
index = param.int_ind.value();
CHECK((index >= -1 * N) && (index < N))
<< "index " << index
<< " is out of bounds for axis " << axis
<< " with size " << N << "\n";
index += ((index < 0) ? N : 0);
numtodel = static_cast<size_t>(1);
newshape[axis] -= 1;
const_cast<NDArray &>(outputs[delete_::kOut]).Init(newshape);
} else {
outshape[axis] -= 1;
const_cast<NDArray &>(outputs[delete_::kOut]).Init(outshape);
} else { // obj is a tensor
numtodel = inputs[delete_::kObj].shape().Size();
}

MSHADOW_TYPE_SWITCH(((inputs.size() == 2U) ?
MSHADOW_TYPE_SWITCH(((inputs.size() == 2U) ? // obj is tensor
inputs[delete_::kObj].dtype() :
mshadow::DataType<int64_t>::kFlag), IType, {
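// temp_mem is one contiguous workspace carved into three regions,
// matching the pointer offsets below:
//   out_pos   - arr.shape()[axis] elements of int64_t
//   indices   - numtodel elements of IType
//   is_delete - arr.shape()[axis] elements of bool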
size_t temp_mem_size = sizeof(int64_t) * arr.shape()[axis] +
@@ -240,19 +277,22 @@ void NumpyDeleteCompute(const nnvm::NodeAttrs& attrs,
bool* is_delete_ptr = reinterpret_cast<bool*>
(temp_mem.dptr_ + sizeof(int64_t) * arr.shape()[axis] +
sizeof(IType) * numtodel);
if (param.step.has_value()) {
if (param.step.has_value()) { // obj is a slice; expand it into an index tensor
Kernel<SliceToIndices, xpu>::Launch(s, numtodel, indices_ptr, start, step);
} else if (param.int_ind.has_value()) {
} else if (param.int_ind.has_value()) { // obj is a scalar; fill the index tensor with it
Fill(s, TBlob(indices_ptr, Shape1(numtodel), xpu::kDevMask), kWriteTo, index);
} else {
} else { // obj is a tensor; copy it into the workspace index tensor
mxnet_op::copy(s,
TBlob(indices_ptr, inputs[delete_::kObj].shape(), inputs[delete_::kObj].data().dev_mask()),
inputs[delete_::kObj].data());
}
mxnet_op::Kernel<mxnet_op::set_zero, xpu>::Launch(s, arr.shape()[axis], is_delete_ptr);
// mark which positions need to be deleted from the input arr
Kernel<IsDeleteCal, xpu>::Launch(s, numtodel, N, is_delete_ptr, indices_ptr);
// calculate each output element's original position in the input arr
Kernel<OutPosCal, xpu>::Launch(s, arr.shape()[axis], out_pos_ptr, is_delete_ptr);
if (inputs.size() == 2U) {
if (inputs.size() == 2U) { // obj is a tensor
// count the non-redundant indices, i.e. how many positions are actually deleted
#ifdef __CUDACC__
thrust::device_ptr<bool>is_delete_dev(is_delete_ptr);
thrust::device_vector<bool>vec_is_delete(is_delete_dev, is_delete_dev + arr.shape()[axis]);
@@ -265,14 +305,14 @@ void NumpyDeleteCompute(const nnvm::NodeAttrs& attrs,
numtodel++;
}
}
newshape[axis] -= numtodel;
const_cast<NDArray &>(outputs[delete_::kOut]).Init(newshape);
outshape[axis] -= numtodel;
const_cast<NDArray &>(outputs[delete_::kOut]).Init(outshape);
}
MXNET_NDIM_SWITCH(newshape.ndim(), ndim, {
mshadow::Shape<ndim> out_strides = mxnet_op::calc_stride(newshape.get<ndim>());
MXNET_NDIM_SWITCH(outshape.ndim(), ndim, {
mshadow::Shape<ndim> out_strides = mxnet_op::calc_stride(outshape.get<ndim>());
MSHADOW_TYPE_SWITCH(outputs[delete_::kOut].dtype(), DType, {
MXNET_ASSIGN_REQ_SWITCH(req[delete_::kOut], req_type, {
Kernel<DeleteImpl<req_type, ndim>, xpu>::Launch(
Kernel<DeleteKernel<req_type, ndim>, xpu>::Launch(
s, arr.shape().Size(),
outputs[delete_::kOut].data().dptr<DType>(),
arr.data().dptr<DType>(),