Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
reuse existing function / add licenses / modify alignment
Browse files Browse the repository at this point in the history
  • Loading branch information
JiangZhaoh committed Dec 16, 2019
1 parent 597c12f commit 91e84e4
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 73 deletions.
120 changes: 49 additions & 71 deletions src/operator/numpy/np_delete_op-inl.h
Original file line number Diff line number Diff line change
@@ -1,3 +1,22 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2019 by Contributors
* \file np_delete_op-inl.h
Expand All @@ -11,6 +30,7 @@
#include <algorithm>
#include "../../common/utils.h"
#include "../tensor/sort_op.h"
#include "../tensor/init_op.h"
#include "../operator_common.h"
#include "../mxnet_op.h"
#include "../tensor/broadcast_reduce_op.h"
Expand Down Expand Up @@ -64,19 +84,11 @@ struct SliceToIndices {
}
};

template<typename IType>
struct AssignNum {
  /*! \brief Kernel: write the constant 'value' into every launched position of 'out'. */
  MSHADOW_XINLINE static void Map(int idx, IType* out, const IType value) {
    out[idx] = value;
  }
};

struct IsDeleteCal {
template<typename IType>
MSHADOW_XINLINE static void Map(int i, int N, bool* is_delete, const IType* indices) {
if ((static_cast<int64_t>(indices[i]) >= 0) &&
(static_cast<int64_t>(indices[i]) < N)) {
is_delete[static_cast<int64_t>(indices[i])] = true;
if ((indices[i] >= 0) && (indices[i] < N)) {
is_delete[static_cast<int>(indices[i])] = true;
}
}
};
Expand All @@ -98,64 +110,31 @@ struct OutPosCal {
}
};

/*!
 * \brief Compute row-major (C-order) strides for 'shape'.
 * \tparam ndim - capacity of the returned fixed-size shape; must be >= shape.ndim().
 * \param shape - the dynamic shape whose strides are required.
 * \return stride[i] = product of shape[i+1 .. ndim-1]; innermost stride is 1.
 */
template<int ndim>
inline mshadow::Shape<ndim> GetStride(const mxnet::TShape& shape) {
  mshadow::Shape<ndim> ret;
  size_t running = 1;
  for (int axis = shape.ndim() - 1; axis >= 0; --axis) {
    ret[axis] = running;
    running *= shape[axis];
  }
  return ret;
}

/*!
 * \brief Copy a dynamic mxnet::TShape into a fixed-capacity mshadow::Shape
 *        usable inside device kernels.
 * \tparam ndim - capacity of the returned shape; must be >= shape.ndim().
 * \param shape - the dynamic shape to copy.
 * \return a mshadow::Shape whose first shape.ndim() entries equal shape's dims.
 */
template<int ndim>
inline mshadow::Shape<ndim> GetKernelShape(const mxnet::TShape& shape) {
  mshadow::Shape<ndim> result;
  for (int axis = 0; axis < shape.ndim(); ++axis) {
    result[axis] = shape[axis];
  }
  return result;
}

template<int req>
template<int req, int ndim>
struct DeleteImpl {
/*!
* \brief delete a sub-array from input along an axis according to 'is_delete'.
* \tparam xpu - cpu or gpu.
* \param out_data - output: a new array with sub-arrays along an axis deleted.
* \param in_arr - input: 'arr', original array.
* \param is_delete - marks which positions along the axis are deleted (true) or retained (false) in 'arr'.
* \param out_pos - if is_delete[i] is 'false', out_pos[i] indicates its position in the output.
* \param arrshape - the shape of 'arr'.
* \param arr_stride - the stride of 'arr'.
* \param out_stride - the stride of 'out_data'.
* \param out_ndim - the ndim of 'out_data'.
* \param axis - delete sub-array along this axis
*/
template<typename DType>
MSHADOW_XINLINE static void Map(int i, DType* out_data,
const DType* in_arr,
const bool* is_delete,
const int64_t* out_pos,
const mshadow::Shape<10> arrshape,
const mshadow::Shape<10> arr_stride,
const mshadow::Shape<10> out_stride,
const int out_ndim, const int axis) {
const int64_t arr_head = i / arr_stride[axis];
const int64_t arr_mid = arr_head % arrshape[axis];
mshadow::Shape<10> arr_idx; // i -> position in in_arr's shape
for (int j = 0; j < out_ndim; ++j) {
const int64_t head = i / arr_stride[j];
const int64_t mid = head % arrshape[j];
arr_idx[j] = mid;
}
if (!is_delete[arr_mid]) {
arr_idx[axis] = out_pos[arr_mid];
int64_t dest_idx = 0;
for (int j =0; j < out_ndim; ++j) {
dest_idx += out_stride[j] * arr_idx[j];
}
const mshadow::Shape<ndim> arrshape,
const mshadow::Shape<ndim> out_stride,
const int axis) {
// i -> position in in_arr's shape
mshadow::Shape<ndim> arr_idx = mxnet_op::unravel(i, arrshape);
if (!is_delete[arr_idx[axis]]) {
arr_idx[axis] = out_pos[arr_idx[axis]];
int64_t dest_idx = mxnet_op::dot(arr_idx, out_stride);
KERNEL_ASSIGN(out_data[dest_idx], req, in_arr[i]);
}
}
Expand Down Expand Up @@ -248,24 +227,23 @@ void NumpyDeleteCompute(const nnvm::NodeAttrs& attrs,
}

MSHADOW_TYPE_SWITCH(((inputs.size() == 2U) ?
inputs[delete_::kObj].dtype() :
mshadow::DataType<int64_t>::kFlag), IType, {
inputs[delete_::kObj].dtype() :
mshadow::DataType<int64_t>::kFlag), IType, {
size_t temp_mem_size = sizeof(int64_t) * arr.shape()[axis] +
sizeof(IType) * numtodel +
sizeof(bool) * arr.shape()[axis];
Tensor<xpu, 1, char> temp_mem =
ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(temp_mem_size), s);
ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(temp_mem_size), s);
int64_t* out_pos_ptr = reinterpret_cast<int64_t*>(temp_mem.dptr_);
IType* indices_ptr = reinterpret_cast<IType*>
(temp_mem.dptr_ + sizeof(int64_t) * arr.shape()[axis]);
bool* is_delete_ptr = reinterpret_cast<bool*>
(temp_mem.dptr_ + sizeof(int64_t) * arr.shape()[axis] +
sizeof(IType) * numtodel);
sizeof(IType) * numtodel);
if (param.step.has_value()) {
Kernel<SliceToIndices, xpu>::Launch(s, numtodel,
indices_ptr, start, step);
Kernel<SliceToIndices, xpu>::Launch(s, numtodel, indices_ptr, start, step);
} else if (param.int_ind.has_value()) {
Kernel<AssignNum<IType>, xpu>::Launch(s, numtodel, indices_ptr, index);
Fill(s, TBlob(indices_ptr, Shape1(numtodel), xpu::kDevMask), kWriteTo, index);
} else {
mxnet_op::copy(s,
TBlob(indices_ptr, inputs[delete_::kObj].shape(), inputs[delete_::kObj].data().dev_mask()),
Expand All @@ -290,18 +268,18 @@ void NumpyDeleteCompute(const nnvm::NodeAttrs& attrs,
newshape[axis] -= numtodel;
const_cast<NDArray &>(outputs[delete_::kOut]).Init(newshape);
}
mshadow::Shape<10> arr_strides = GetStride<10>(arr.shape());
mshadow::Shape<10> out_strides = GetStride<10>(newshape);
mshadow::Shape<10> k_arrshape = GetKernelShape<10>(arr.shape());
MSHADOW_TYPE_SWITCH(outputs[delete_::kOut].dtype(), DType, {
MXNET_ASSIGN_REQ_SWITCH(req[delete_::kOut], req_type, {
Kernel<DeleteImpl<req_type>, xpu>::Launch(
s, arr.shape().Size(),
outputs[delete_::kOut].data().dptr<DType>(),
arr.data().dptr<DType>(),
is_delete_ptr, out_pos_ptr,
k_arrshape, arr_strides, out_strides,
newshape.ndim(), axis);
MXNET_NDIM_SWITCH(newshape.ndim(), ndim, {
mshadow::Shape<ndim> out_strides = mxnet_op::calc_stride(newshape.get<ndim>());
MSHADOW_TYPE_SWITCH(outputs[delete_::kOut].dtype(), DType, {
MXNET_ASSIGN_REQ_SWITCH(req[delete_::kOut], req_type, {
Kernel<DeleteImpl<req_type, ndim>, xpu>::Launch(
s, arr.shape().Size(),
outputs[delete_::kOut].data().dptr<DType>(),
arr.data().dptr<DType>(),
is_delete_ptr, out_pos_ptr,
arr.shape().get<ndim>(),
out_strides, axis);
});
});
});
});
Expand Down
23 changes: 21 additions & 2 deletions src/operator/numpy/np_delete_op.cc
Original file line number Diff line number Diff line change
@@ -1,3 +1,22 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2019 by Contributors
* \file np_delete_op.cc
Expand All @@ -21,8 +40,8 @@ bool NumpyDeleteType(const nnvm::NodeAttrs& attrs,
CHECK_EQ(out_type->size(), 1U);
if (insize == 3) {
CHECK_NE((*in_type)[1], -1) << "Index type must be set for insert operator\n";
CHECK(((*in_type)[1] == mshadow::DataType<int64_t>::kFlag)
|| ((*in_type)[1] == mshadow::DataType<int32_t>::kFlag))
CHECK(((*in_type)[1] == mshadow::DataType<int64_t>::kFlag) ||
((*in_type)[1] == mshadow::DataType<int32_t>::kFlag))
<< "Index type only support int32 or int64.\n";
}
TYPE_ASSIGN_CHECK(*out_type, 0, (*in_type)[0]);
Expand Down
19 changes: 19 additions & 0 deletions src/operator/numpy/np_delete_op.cu
Original file line number Diff line number Diff line change
@@ -1,3 +1,22 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2019 by Contributors
* \file np_delete_op.cu
Expand Down

0 comments on commit 91e84e4

Please sign in to comment.