
Faster general take #16615

Merged
merged 5 commits on Oct 26, 2019
59 changes: 32 additions & 27 deletions src/operator/tensor/indexing_op.cc
@@ -29,7 +29,7 @@ namespace mxnet {
 namespace op {
 
 template<bool clip = true>
-struct TakeCPU {
+struct TakeZeroAxisCPU {
   // assume that idx have been flattened to a 1-D tensor (N,)
   // assume that out_data and in_data have been flattened to 2-D tensors, (N, M) and (K, M)
   // M is the number of columns of in_data and out_data
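A note for readers of this hunk: the renamed kernel covers only the axis-0 case, and each CPU work item copies one full row. A minimal standalone sketch of that computation, under the layout assumptions stated in the comments above (illustrative C++ with hypothetical names, not the MXNet source):

```cpp
#include <cstdint>

// Sketch: zero-axis take with in_data (K, M), idx (N,), out_data (N, M).
// `clip` mirrors the kernel's template flag; all names here are hypothetical.
inline void TakeZeroAxisRow(int64_t n, float* out_data, const float* in_data,
                            const int64_t* idx, int64_t M, int64_t K, bool clip) {
  int64_t j = idx[n];
  if (clip) {
    j = (j < 0) ? 0 : ((j >= K) ? K - 1 : j);  // clamp into [0, K)
  } else {
    j %= K;                                    // wrap negative/oversized indices
    if (j < 0) j += K;
  }
  for (int64_t m = 0; m < M; ++m) {            // copy one row of M columns
    out_data[n * M + m] = in_data[j * M + m];
  }
}
```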
@@ -88,8 +88,9 @@ void EmbeddingOpForwardDnsImpl<cpu>(mshadow::Stream<cpu>* s,
       Tensor<cpu, 2, DType> wmat = weight.get<cpu, 2, DType>(s);
       Tensor<cpu, 2, DType> out = output.get_with_shape<cpu, 2, DType>(
         Shape2(oshape.ProdShape(0, oshape.ndim()-1), oshape[oshape.ndim()-1]), s);
-      Kernel<TakeCPU<true>, cpu>::Launch(s, oshape.Size() / wmat.shape_[1], out.dptr_, wmat.dptr_,
-                                         idx.dptr_, wmat.shape_[1], wmat.shape_[0]);
+      Kernel<TakeZeroAxisCPU<true>, cpu>::Launch(s, oshape.Size() / wmat.shape_[1], out.dptr_,
+                                                 wmat.dptr_, idx.dptr_,
+                                                 wmat.shape_[1], wmat.shape_[0]);
     });
   });
 }
@@ -308,17 +309,17 @@ void TakeOpForward<cpu>(const nnvm::NodeAttrs& attrs,
   }
   if (actual_axis == 0) {
     if (param.mode == take_::kClip) {
-      Kernel<TakeCPU<true>, cpu>::Launch(s, idxshape.Size(),
-                                         outputs[take_::kOut].dptr<DType>(),
-                                         inputs[take_::kArr].dptr<DType>(),
-                                         inputs[take_::kIdx].dptr<IType>(),
-                                         oshape.Size()/idxshape.Size(), arrshape[0]);
+      Kernel<TakeZeroAxisCPU<true>, cpu>::Launch(s, idxshape.Size(),
+                                                 outputs[take_::kOut].dptr<DType>(),
+                                                 inputs[take_::kArr].dptr<DType>(),
+                                                 inputs[take_::kIdx].dptr<IType>(),
+                                                 oshape.Size()/idxshape.Size(), arrshape[0]);
     } else {
-      Kernel<TakeCPU<false>, cpu>::Launch(s, idxshape.Size(),
-                                          outputs[take_::kOut].dptr<DType>(),
-                                          inputs[take_::kArr].dptr<DType>(),
-                                          inputs[take_::kIdx].dptr<IType>(),
-                                          oshape.Size()/idxshape.Size(), arrshape[0]);
+      Kernel<TakeZeroAxisCPU<false>, cpu>::Launch(s, idxshape.Size(),
+                                                  outputs[take_::kOut].dptr<DType>(),
+                                                  inputs[take_::kArr].dptr<DType>(),
+                                                  inputs[take_::kIdx].dptr<IType>(),
+                                                  oshape.Size()/idxshape.Size(), arrshape[0]);
     }
   } else {
     mshadow::Shape<10> in_strides;
@@ -332,21 +333,25 @@ void TakeOpForward<cpu>(const nnvm::NodeAttrs& attrs,
       out_strides[i] = stride;
     }
     if (param.mode == take_::kClip) {
-      Kernel<Take<true>, cpu>::Launch(s, oshape.Size(),
-                                      outputs[take_::kOut].dptr<DType>(),
-                                      inputs[take_::kArr].dptr<DType>(),
-                                      inputs[take_::kIdx].dptr<IType>(),
-                                      in_strides, out_strides, arrshape.ndim(),
-                                      oshape.ndim(), idxshape.ndim(),
-                                      arrshape[actual_axis], actual_axis);
+      Kernel<TakeNonzeroAxis<true>, cpu>::Launch(s, oshape.Size(),
+                                                 outputs[take_::kOut].dptr<DType>(),
+                                                 inputs[take_::kArr].dptr<DType>(),
+                                                 inputs[take_::kIdx].dptr<IType>(),
+                                                 out_strides[actual_axis-1],
+                                                 in_strides[actual_axis-1],
+                                                 in_strides[actual_axis], arrshape.ndim(),
+                                                 oshape.ndim(), idxshape.ndim(),
+                                                 arrshape[actual_axis], actual_axis);
     } else {
-      Kernel<Take<false>, cpu>::Launch(s, oshape.Size(),
-                                       outputs[take_::kOut].dptr<DType>(),
-                                       inputs[take_::kArr].dptr<DType>(),
-                                       inputs[take_::kIdx].dptr<IType>(),
-                                       in_strides, out_strides, arrshape.ndim(),
-                                       oshape.ndim(), idxshape.ndim(),
-                                       arrshape[actual_axis], actual_axis);
+      Kernel<TakeNonzeroAxis<false>, cpu>::Launch(s, oshape.Size(),
+                                                  outputs[take_::kOut].dptr<DType>(),
+                                                  inputs[take_::kArr].dptr<DType>(),
+                                                  inputs[take_::kIdx].dptr<IType>(),
+                                                  out_strides[actual_axis-1],
+                                                  in_strides[actual_axis-1],
+                                                  in_strides[actual_axis], arrshape.ndim(),
+                                                  oshape.ndim(), idxshape.ndim(),
+                                                  arrshape[actual_axis], actual_axis);
     }
   }
   });
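The speedup in the non-zero-axis path comes largely from the kernel arguments: the old Take kernel received two mshadow::Shape<10> structs by value and branched on axis == 0 for every element, while TakeNonzeroAxis receives only the three scalar strides it actually uses. A sketch of how those scalars relate to row-major shapes (standalone, hypothetical names mirroring the diff):

```cpp
#include <cstdint>
#include <vector>

// Row-major strides: strides[i] = product of dims after i, as in the
// in_strides/out_strides loop shown in the hunk above.
static std::vector<int64_t> RowMajorStrides(const std::vector<int64_t>& shape) {
  std::vector<int64_t> strides(shape.size(), 1);
  for (int i = static_cast<int>(shape.size()) - 2; i >= 0; --i)
    strides[i] = strides[i + 1] * shape[i + 1];
  return strides;
}

struct TakeStrides {
  int64_t out_prev_stride;  // output elements per step on the dim before `axis`
  int64_t in_prev_stride;   // input elements per step on the dim before `axis`
  int64_t in_stride;        // input elements per step along `axis`
};

// Assumes axis > 0; the axis == 0 case is routed to TakeZeroAxisCPU instead.
inline TakeStrides ComputeTakeStrides(const std::vector<int64_t>& arrshape,
                                      const std::vector<int64_t>& oshape,
                                      int axis) {
  const auto in = RowMajorStrides(arrshape);
  const auto out = RowMajorStrides(oshape);
  return {out[axis - 1], in[axis - 1], in[axis]};
}
```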
61 changes: 33 additions & 28 deletions src/operator/tensor/indexing_op.cu
@@ -116,11 +116,8 @@ struct AddTakeGradRspDeterministicKernel {
   }
 };
 
-/*! \brief name the struct Take instead of take
- * to avoid conflict with the take function in mshadow
- */
 template<bool clip = true>
-struct TakeGPU {
+struct TakeZeroAxisGPU {
   // assume that idx have been flattened to a 1-D tensor (N,)
   // assume that out_data and in_data have been flattened to 2-D tensors, (N, M) and (K, M)
   // M is the number of columns of in_data and out_data
@@ -180,8 +177,8 @@ void EmbeddingOpForwardDnsImpl<gpu>(mshadow::Stream<gpu>* s,
       Tensor<gpu, 2, DType> wmat = weight.get<gpu, 2, DType>(s);
       Tensor<gpu, 2, DType> out = output.get_with_shape<gpu, 2, DType>(
         Shape2(oshape.ProdShape(0, oshape.ndim()-1), oshape[oshape.ndim()-1]), s);
-      Kernel<TakeGPU<true>, gpu>::Launch(s, oshape.Size(), out.dptr_, wmat.dptr_,
-                                         idx.dptr_, wmat.shape_[1], wmat.shape_[0]);
+      Kernel<TakeZeroAxisGPU<true>, gpu>::Launch(s, oshape.Size(), out.dptr_, wmat.dptr_,
+                                                 idx.dptr_, wmat.shape_[1], wmat.shape_[0]);
     });
   });
 }
@@ -502,17 +499,17 @@ void TakeOpForward<gpu>(const nnvm::NodeAttrs& attrs,
   }
   if (actual_axis == 0) {
     if (param.mode == take_::kClip) {
-      Kernel<TakeGPU<true>, gpu>::Launch(s, oshape.Size(),
-                                         outputs[take_::kOut].dptr<DType>(),
-                                         inputs[take_::kArr].dptr<DType>(),
-                                         inputs[take_::kIdx].dptr<IType>(),
-                                         oshape.Size()/idxshape.Size(), arrshape[0]);
+      Kernel<TakeZeroAxisGPU<true>, gpu>::Launch(s, oshape.Size(),
+                                                 outputs[take_::kOut].dptr<DType>(),
+                                                 inputs[take_::kArr].dptr<DType>(),
+                                                 inputs[take_::kIdx].dptr<IType>(),
+                                                 oshape.Size()/idxshape.Size(), arrshape[0]);
     } else {
-      Kernel<TakeGPU<false>, gpu>::Launch(s, oshape.Size(),
-                                          outputs[take_::kOut].dptr<DType>(),
-                                          inputs[take_::kArr].dptr<DType>(),
-                                          inputs[take_::kIdx].dptr<IType>(),
-                                          oshape.Size()/idxshape.Size(), arrshape[0]);
+      Kernel<TakeZeroAxisGPU<false>, gpu>::Launch(s, oshape.Size(),
+                                                  outputs[take_::kOut].dptr<DType>(),
+                                                  inputs[take_::kArr].dptr<DType>(),
+                                                  inputs[take_::kIdx].dptr<IType>(),
+                                                  oshape.Size()/idxshape.Size(), arrshape[0]);
     }
   } else {
     mshadow::Shape<10> in_strides;
@@ -526,19 +523,27 @@ void TakeOpForward<gpu>(const nnvm::NodeAttrs& attrs,
       out_strides[i] = stride;
     }
     if (param.mode == take_::kClip) {
-      Kernel<Take<true>, gpu>::Launch(s, oshape.Size(),
-                                      outputs[take_::kOut].dptr<DType>(),
-                                      inputs[take_::kArr].dptr<DType>(),
-                                      inputs[take_::kIdx].dptr<IType>(),
-                                      in_strides, out_strides, arrshape.ndim(), oshape.ndim(),
-                                      idxshape.ndim(), arrshape[actual_axis], actual_axis);
+      Kernel<TakeNonzeroAxis<true>, gpu>::Launch(s, oshape.Size(),
+                                                 outputs[take_::kOut].dptr<DType>(),
+                                                 inputs[take_::kArr].dptr<DType>(),
+                                                 inputs[take_::kIdx].dptr<IType>(),
+                                                 out_strides[actual_axis-1],
+                                                 in_strides[actual_axis-1],
+                                                 in_strides[actual_axis],
+                                                 arrshape.ndim(), oshape.ndim(),
+                                                 idxshape.ndim(), arrshape[actual_axis],
+                                                 actual_axis);
     } else {
-      Kernel<Take<false>, gpu>::Launch(s, oshape.Size(),
-                                       outputs[take_::kOut].dptr<DType>(),
-                                       inputs[take_::kArr].dptr<DType>(),
-                                       inputs[take_::kIdx].dptr<IType>(),
-                                       in_strides, out_strides, arrshape.ndim(), oshape.ndim(),
-                                       idxshape.ndim(), arrshape[actual_axis], actual_axis);
+      Kernel<TakeNonzeroAxis<false>, gpu>::Launch(s, oshape.Size(),
+                                                  outputs[take_::kOut].dptr<DType>(),
+                                                  inputs[take_::kArr].dptr<DType>(),
+                                                  inputs[take_::kIdx].dptr<IType>(),
+                                                  out_strides[actual_axis-1],
+                                                  in_strides[actual_axis-1],
+                                                  in_strides[actual_axis],
+                                                  arrshape.ndim(), oshape.ndim(),
+                                                  idxshape.ndim(), arrshape[actual_axis],
+                                                  actual_axis);
     }
   }
   });
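One backend difference visible in the Launch calls: the CPU path dispatches idxshape.Size() work items, each copying a whole row, while the GPU path dispatches oshape.Size() threads, one per output element, which suits the much wider GPU parallelism. A per-element sketch under that assumption (clip mode shown; hypothetical names):

```cpp
#include <cstdint>

// Sketch: per-element zero-axis take, matching a launch of oshape.Size() threads.
inline void TakeZeroAxisElement(int64_t i, float* out_data, const float* in_data,
                                const int64_t* idx, int64_t M, int64_t K) {
  const int64_t n = i / M;                   // which index entry
  const int64_t m = i % M;                   // which column within the row
  int64_t j = idx[n];
  j = (j < 0) ? 0 : ((j >= K) ? K - 1 : j);  // clip; wrap mode is analogous
  out_data[i] = in_data[j * M + m];
}
```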
28 changes: 14 additions & 14 deletions src/operator/tensor/indexing_op.h
@@ -296,11 +296,11 @@ inline bool SparseEmbeddingOpBackwardStorageType(const nnvm::NodeAttrs& attrs,
   return dispatched;
 }
 
-/*! \brief name the struct Take instead of take
- * to avoid conflict with the take function in mshadow
+/*! \brief name the struct TakeNonzeroAxis for general take when
+ * axis is not zero, use TakeZeroAxisGPU or TakeZeroAxisCPU for axis zero
  */
 template<bool clip = true>
-struct Take {
+struct TakeNonzeroAxis {
   /*!
    * \brief Map function for take operator
    * \param i global thread id
@@ -315,28 +315,28 @@ struct Take {
    */
   template<typename DType, typename IType>
   MSHADOW_XINLINE static void Map(index_t i, DType* out_data, const DType* in_data,
-                                  const IType* idx,
-                                  const mshadow::Shape<10> in_stride,
-                                  const mshadow::Shape<10> out_stride,
+                                  const IType* idx, const int out_prev_stride,
+                                  const int in_prev_stride, const int in_stride,
                                   const int in_ndims, const int out_ndims, const int idx_ndims,
                                   const int axis_dim, const int axis) {
     // i is the global flattened index in the output
-    const int64_t out_head_index = (axis == 0) ? 0 : (i / out_stride[axis - 1]);
-    const int64_t out_rest_index = (axis == 0) ? i : (i % out_stride[axis - 1]);
-    const int64_t out_mid_index = out_rest_index / in_stride[axis];
+    const int64_t out_head_index = i / out_prev_stride;
+    const int64_t out_rest_index = i % out_prev_stride;
+    const int64_t out_mid_index = out_rest_index / in_stride;
     const int64_t out_tail_index = (axis == in_ndims - 1) ?
-                                   0 : (out_rest_index % in_stride[axis]);
+                                   0 : (out_rest_index % in_stride);
     int64_t idx_index = static_cast<int64_t>(idx[out_mid_index]);
     if (clip) {
       idx_index = (idx_index < 0) ? 0 : idx_index;
       idx_index = (idx_index > axis_dim - 1) ? (axis_dim - 1) : idx_index;
+    } else {
+      idx_index %= axis_dim;
+      idx_index += (idx_index < 0) ? axis_dim : 0;
     }
-    idx_index %= axis_dim;
-    idx_index += (idx_index < 0) ? axis_dim : 0;
     const int64_t in_tail_index = out_tail_index;
     const int64_t in_head_index = out_head_index;
-    int64_t in_src_index = in_tail_index + idx_index * in_stride[axis];
-    in_src_index += (axis == 0) ? 0 : in_head_index * in_stride[axis - 1];
+    int64_t in_src_index = in_tail_index + idx_index * in_stride;
+    in_src_index += in_head_index * in_prev_stride;
     out_data[i] = in_data[in_src_index];
   }
 };
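To make the index arithmetic in Map concrete: for a hypothetical input of shape (2, 3, 4), indices of shape (2,), and axis = 1, the output has shape (2, 2, 4) and the kernel receives out_prev_stride = 8, in_prev_stride = 12, in_stride = 4. A standalone re-derivation of the source-index computation (illustrative only, not the MXNet code):

```cpp
#include <cstdint>

// Sketch: TakeNonzeroAxis index math for in (2, 3, 4), idx (2,), axis = 1.
inline int64_t TakeSrcIndex(int64_t i, const int64_t* idx) {
  const int64_t out_prev_stride = 8;   // 2 * 4: output elements before axis 1
  const int64_t in_prev_stride = 12;   // 3 * 4: input elements before axis 1
  const int64_t in_stride = 4;         // elements per step along axis 1
  const int64_t axis_dim = 3;

  const int64_t out_head_index = i / out_prev_stride;         // dims before axis
  const int64_t out_rest_index = i % out_prev_stride;
  const int64_t out_mid_index  = out_rest_index / in_stride;  // position on axis
  const int64_t out_tail_index = out_rest_index % in_stride;  // dims after axis

  int64_t idx_index = idx[out_mid_index];
  idx_index = (idx_index < 0) ? 0 : idx_index;                // clip mode
  idx_index = (idx_index > axis_dim - 1) ? axis_dim - 1 : idx_index;

  return out_tail_index + idx_index * in_stride + out_head_index * in_prev_stride;
}
// e.g. with idx = {2, 0}: output element i = 13 decomposes into head 1, mid 1,
// tail 1, so it reads input element 1 * 12 + 0 * 4 + 1 = 13.
```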