diff --git a/src/operator/tensor/indexing_op.cc b/src/operator/tensor/indexing_op.cc index 9961218b5482..470abee71a59 100644 --- a/src/operator/tensor/indexing_op.cc +++ b/src/operator/tensor/indexing_op.cc @@ -29,7 +29,7 @@ namespace mxnet { namespace op { template -struct TakeCPU { +struct TakeZeroAxisCPU { // assume that idx have been flattened to a 1-D tensor (N,) // assume that out_data and in_data have been flattened to 2-D tensors, (N, M) and (K, M) // M is the number of columns of in_data and out_data @@ -88,8 +88,9 @@ void EmbeddingOpForwardDnsImpl(mshadow::Stream* s, Tensor wmat = weight.get(s); Tensor out = output.get_with_shape( Shape2(oshape.ProdShape(0, oshape.ndim()-1), oshape[oshape.ndim()-1]), s); - Kernel, cpu>::Launch(s, oshape.Size() / wmat.shape_[1], out.dptr_, wmat.dptr_, - idx.dptr_, wmat.shape_[1], wmat.shape_[0]); + Kernel, cpu>::Launch(s, oshape.Size() / wmat.shape_[1], out.dptr_, + wmat.dptr_, idx.dptr_, + wmat.shape_[1], wmat.shape_[0]); }); }); } @@ -308,17 +309,17 @@ void TakeOpForward(const nnvm::NodeAttrs& attrs, } if (actual_axis == 0) { if (param.mode == take_::kClip) { - Kernel, cpu>::Launch(s, idxshape.Size(), - outputs[take_::kOut].dptr(), - inputs[take_::kArr].dptr(), - inputs[take_::kIdx].dptr(), - oshape.Size()/idxshape.Size(), arrshape[0]); + Kernel, cpu>::Launch(s, idxshape.Size(), + outputs[take_::kOut].dptr(), + inputs[take_::kArr].dptr(), + inputs[take_::kIdx].dptr(), + oshape.Size()/idxshape.Size(), arrshape[0]); } else { - Kernel, cpu>::Launch(s, idxshape.Size(), - outputs[take_::kOut].dptr(), - inputs[take_::kArr].dptr(), - inputs[take_::kIdx].dptr(), - oshape.Size()/idxshape.Size(), arrshape[0]); + Kernel, cpu>::Launch(s, idxshape.Size(), + outputs[take_::kOut].dptr(), + inputs[take_::kArr].dptr(), + inputs[take_::kIdx].dptr(), + oshape.Size()/idxshape.Size(), arrshape[0]); } } else { mshadow::Shape<10> in_strides; @@ -332,21 +333,25 @@ void TakeOpForward(const nnvm::NodeAttrs& attrs, out_strides[i] = stride; } if (param.mode == take_::kClip) { - Kernel, cpu>::Launch(s, oshape.Size(), - outputs[take_::kOut].dptr(), - inputs[take_::kArr].dptr(), - inputs[take_::kIdx].dptr(), - in_strides, out_strides, arrshape.ndim(), - oshape.ndim(), idxshape.ndim(), - arrshape[actual_axis], actual_axis); + Kernel, cpu>::Launch(s, oshape.Size(), + outputs[take_::kOut].dptr(), + inputs[take_::kArr].dptr(), + inputs[take_::kIdx].dptr(), + out_strides[actual_axis-1], + in_strides[actual_axis-1], + in_strides[actual_axis], arrshape.ndim(), + oshape.ndim(), idxshape.ndim(), + arrshape[actual_axis], actual_axis); } else { - Kernel, cpu>::Launch(s, oshape.Size(), - outputs[take_::kOut].dptr(), - inputs[take_::kArr].dptr(), - inputs[take_::kIdx].dptr(), - in_strides, out_strides, arrshape.ndim(), - oshape.ndim(), idxshape.ndim(), - arrshape[actual_axis], actual_axis); + Kernel, cpu>::Launch(s, oshape.Size(), + outputs[take_::kOut].dptr(), + inputs[take_::kArr].dptr(), + inputs[take_::kIdx].dptr(), + out_strides[actual_axis-1], + in_strides[actual_axis-1], + in_strides[actual_axis], arrshape.ndim(), + oshape.ndim(), idxshape.ndim(), + arrshape[actual_axis], actual_axis); } } }); diff --git a/src/operator/tensor/indexing_op.cu b/src/operator/tensor/indexing_op.cu index 0b4c20bf2bb5..3ccf1f39d4f7 100644 --- a/src/operator/tensor/indexing_op.cu +++ b/src/operator/tensor/indexing_op.cu @@ -116,11 +116,8 @@ struct AddTakeGradRspDeterministicKernel { } }; -/*! \brief name the struct Take instead of take - * to avoid conflict with the take function in mshadow - */ template -struct TakeGPU { +struct TakeZeroAxisGPU { // assume that idx have been flattened to a 1-D tensor (N,) // assume that out_data and in_data have been flattened to 2-D tensors, (N, M) and (K, M) // M is the number of columns of in_data and out_data @@ -180,8 +177,8 @@ void EmbeddingOpForwardDnsImpl(mshadow::Stream* s, Tensor wmat = weight.get(s); Tensor out = output.get_with_shape( Shape2(oshape.ProdShape(0, oshape.ndim()-1), oshape[oshape.ndim()-1]), s); - Kernel, gpu>::Launch(s, oshape.Size(), out.dptr_, wmat.dptr_, - idx.dptr_, wmat.shape_[1], wmat.shape_[0]); + Kernel, gpu>::Launch(s, oshape.Size(), out.dptr_, wmat.dptr_, + idx.dptr_, wmat.shape_[1], wmat.shape_[0]); }); }); } @@ -502,17 +499,17 @@ void TakeOpForward(const nnvm::NodeAttrs& attrs, } if (actual_axis == 0) { if (param.mode == take_::kClip) { - Kernel, gpu>::Launch(s, oshape.Size(), - outputs[take_::kOut].dptr(), - inputs[take_::kArr].dptr(), - inputs[take_::kIdx].dptr(), - oshape.Size()/idxshape.Size(), arrshape[0]); + Kernel, gpu>::Launch(s, oshape.Size(), + outputs[take_::kOut].dptr(), + inputs[take_::kArr].dptr(), + inputs[take_::kIdx].dptr(), + oshape.Size()/idxshape.Size(), arrshape[0]); } else { - Kernel, gpu>::Launch(s, oshape.Size(), - outputs[take_::kOut].dptr(), - inputs[take_::kArr].dptr(), - inputs[take_::kIdx].dptr(), - oshape.Size()/idxshape.Size(), arrshape[0]); + Kernel, gpu>::Launch(s, oshape.Size(), + outputs[take_::kOut].dptr(), + inputs[take_::kArr].dptr(), + inputs[take_::kIdx].dptr(), + oshape.Size()/idxshape.Size(), arrshape[0]); } } else { mshadow::Shape<10> in_strides; @@ -526,19 +523,27 @@ void TakeOpForward(const nnvm::NodeAttrs& attrs, out_strides[i] = stride; } if (param.mode == take_::kClip) { - Kernel, gpu>::Launch(s, oshape.Size(), - outputs[take_::kOut].dptr(), - inputs[take_::kArr].dptr(), - inputs[take_::kIdx].dptr(), - in_strides, out_strides, arrshape.ndim(), oshape.ndim(), - idxshape.ndim(), arrshape[actual_axis], actual_axis); + Kernel, gpu>::Launch(s, oshape.Size(), + outputs[take_::kOut].dptr(), + inputs[take_::kArr].dptr(), + inputs[take_::kIdx].dptr(), + out_strides[actual_axis-1], + in_strides[actual_axis-1], + in_strides[actual_axis], + arrshape.ndim(), oshape.ndim(), + idxshape.ndim(), arrshape[actual_axis], + actual_axis); } else { - Kernel, gpu>::Launch(s, oshape.Size(), - outputs[take_::kOut].dptr(), - inputs[take_::kArr].dptr(), - inputs[take_::kIdx].dptr(), - in_strides, out_strides, arrshape.ndim(), oshape.ndim(), - idxshape.ndim(), arrshape[actual_axis], actual_axis); + Kernel, gpu>::Launch(s, oshape.Size(), + outputs[take_::kOut].dptr(), + inputs[take_::kArr].dptr(), + inputs[take_::kIdx].dptr(), + out_strides[actual_axis-1], + in_strides[actual_axis-1], + in_strides[actual_axis], + arrshape.ndim(), oshape.ndim(), + idxshape.ndim(), arrshape[actual_axis], + actual_axis); } } }); diff --git a/src/operator/tensor/indexing_op.h b/src/operator/tensor/indexing_op.h index bb524dd0f5e9..828d761fefd4 100644 --- a/src/operator/tensor/indexing_op.h +++ b/src/operator/tensor/indexing_op.h @@ -296,11 +296,11 @@ inline bool SparseEmbeddingOpBackwardStorageType(const nnvm::NodeAttrs& attrs, return dispatched; } -/*! \brief name the struct Take instead of take - * to avoid conflict with the take function in mshadow +/*! \brief name the struct TakeNonzeroAxis for general take when + * axis is not zero, use TakeZeroAxisGPU or TakeZeroAxisCPU for axis zero */ template -struct Take { +struct TakeNonzeroAxis { /*! * \brief Map function for take operator * \param i global thread id @@ -315,28 +315,28 @@ struct Take { */ template MSHADOW_XINLINE static void Map(index_t i, DType* out_data, const DType* in_data, - const IType* idx, - const mshadow::Shape<10> in_stride, - const mshadow::Shape<10> out_stride, + const IType* idx, const int out_prev_stride, + const int in_prev_stride, const int in_stride, const int in_ndims, const int out_ndims, const int idx_ndims, const int axis_dim, const int axis) { // i is the global flattened index in the output - const int64_t out_head_index = (axis == 0) ? 0 : (i / out_stride[axis - 1]); - const int64_t out_rest_index = (axis == 0) ? i : (i % out_stride[axis - 1]); - const int64_t out_mid_index = out_rest_index / in_stride[axis]; + const int64_t out_head_index = i / out_prev_stride; + const int64_t out_rest_index = i % out_prev_stride; + const int64_t out_mid_index = out_rest_index / in_stride; const int64_t out_tail_index = (axis == in_ndims - 1) ? - 0 : (out_rest_index % in_stride[axis]); + 0 : (out_rest_index % in_stride); int64_t idx_index = static_cast(idx[out_mid_index]); if (clip) { idx_index = (idx_index < 0) ? 0 : idx_index; idx_index = (idx_index > axis_dim - 1) ? (axis_dim - 1) : idx_index; + } else { + idx_index %= axis_dim; + idx_index += (idx_index < 0) ? axis_dim : 0; } - idx_index %= axis_dim; - idx_index += (idx_index < 0) ? axis_dim : 0; const int64_t in_tail_index = out_tail_index; const int64_t in_head_index = out_head_index; - int64_t in_src_index = in_tail_index + idx_index * in_stride[axis]; - in_src_index += (axis == 0) ? 0 : in_head_index * in_stride[axis - 1]; + int64_t in_src_index = in_tail_index + idx_index * in_stride; + in_src_index += in_head_index * in_prev_stride; out_data[i] = in_data[in_src_index]; } };