Faster general take (#16615)
* Sped up the take op when axis != 0

* Formatting and syntax fixes

* Rename Take to specify axis

* Fix line length lint errors
blchu authored and ptrendx committed Oct 26, 2019
1 parent 7862738 commit 0712f00
Showing 3 changed files with 79 additions and 69 deletions.
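For context: `take` gathers slices of an input array along one axis, selected by an index array. A minimal single-threaded C++ sketch of the semantics for a 3-D input and axis = 1 (illustrative only; `TakeAxis1` is not part of the codebase):

```cpp
#include <cstdint>
#include <vector>

// Reference semantics of take along axis 1 of a 3-D array:
//   out(i, j, k) = in(i, idx(j), k)
// in has shape (d0, d1, d2), idx has shape (n,), out has shape (d0, n, d2).
// Indices are assumed in range here; the real operator also offers clip and
// wrap modes for out-of-range indices.
std::vector<float> TakeAxis1(const std::vector<float>& in,
                             const std::vector<int64_t>& idx,
                             int64_t d0, int64_t d1, int64_t d2) {
  const int64_t n = static_cast<int64_t>(idx.size());
  std::vector<float> out(static_cast<size_t>(d0 * n * d2));
  for (int64_t i = 0; i < d0; ++i)
    for (int64_t j = 0; j < n; ++j)
      for (int64_t k = 0; k < d2; ++k)
        out[(i * n + j) * d2 + k] = in[(i * d1 + idx[j]) * d2 + k];
  return out;
}
```

The kernels in this diff perform this gather in parallel; what the commit changes is how the general-axis kernel computes each input offset from the flattened output index.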
src/operator/tensor/indexing_op.cc: 32 additions & 27 deletions
```diff
@@ -29,7 +29,7 @@ namespace mxnet {
 namespace op {

 template<bool clip = true>
-struct TakeCPU {
+struct TakeZeroAxisCPU {
   // assume that idx have been flattened to a 1-D tensor (N,)
   // assume that out_data and in_data have been flattened to 2-D tensors, (N, M) and (K, M)
   // M is the number of columns of in_data and out_data
@@ -88,8 +88,9 @@ void EmbeddingOpForwardDnsImpl<cpu>(mshadow::Stream<cpu>* s,
     Tensor<cpu, 2, DType> wmat = weight.get<cpu, 2, DType>(s);
     Tensor<cpu, 2, DType> out = output.get_with_shape<cpu, 2, DType>(
         Shape2(oshape.ProdShape(0, oshape.ndim()-1), oshape[oshape.ndim()-1]), s);
-    Kernel<TakeCPU<true>, cpu>::Launch(s, oshape.Size() / wmat.shape_[1], out.dptr_, wmat.dptr_,
-                                       idx.dptr_, wmat.shape_[1], wmat.shape_[0]);
+    Kernel<TakeZeroAxisCPU<true>, cpu>::Launch(s, oshape.Size() / wmat.shape_[1], out.dptr_,
+                                               wmat.dptr_, idx.dptr_,
+                                               wmat.shape_[1], wmat.shape_[0]);
   });
 });
 }
@@ -308,17 +309,17 @@ void TakeOpForward<cpu>(const nnvm::NodeAttrs& attrs,
     }
     if (actual_axis == 0) {
       if (param.mode == take_::kClip) {
-        Kernel<TakeCPU<true>, cpu>::Launch(s, idxshape.Size(),
-                                           outputs[take_::kOut].dptr<DType>(),
-                                           inputs[take_::kArr].dptr<DType>(),
-                                           inputs[take_::kIdx].dptr<IType>(),
-                                           oshape.Size()/idxshape.Size(), arrshape[0]);
+        Kernel<TakeZeroAxisCPU<true>, cpu>::Launch(s, idxshape.Size(),
+                                                   outputs[take_::kOut].dptr<DType>(),
+                                                   inputs[take_::kArr].dptr<DType>(),
+                                                   inputs[take_::kIdx].dptr<IType>(),
+                                                   oshape.Size()/idxshape.Size(), arrshape[0]);
       } else {
-        Kernel<TakeCPU<false>, cpu>::Launch(s, idxshape.Size(),
-                                            outputs[take_::kOut].dptr<DType>(),
-                                            inputs[take_::kArr].dptr<DType>(),
-                                            inputs[take_::kIdx].dptr<IType>(),
-                                            oshape.Size()/idxshape.Size(), arrshape[0]);
+        Kernel<TakeZeroAxisCPU<false>, cpu>::Launch(s, idxshape.Size(),
+                                                    outputs[take_::kOut].dptr<DType>(),
+                                                    inputs[take_::kArr].dptr<DType>(),
+                                                    inputs[take_::kIdx].dptr<IType>(),
+                                                    oshape.Size()/idxshape.Size(), arrshape[0]);
       }
     } else {
       mshadow::Shape<10> in_strides;
@@ -332,21 +333,25 @@ void TakeOpForward<cpu>(const nnvm::NodeAttrs& attrs,
         out_strides[i] = stride;
       }
       if (param.mode == take_::kClip) {
-        Kernel<Take<true>, cpu>::Launch(s, oshape.Size(),
-                                        outputs[take_::kOut].dptr<DType>(),
-                                        inputs[take_::kArr].dptr<DType>(),
-                                        inputs[take_::kIdx].dptr<IType>(),
-                                        in_strides, out_strides, arrshape.ndim(),
-                                        oshape.ndim(), idxshape.ndim(),
-                                        arrshape[actual_axis], actual_axis);
+        Kernel<TakeNonzeroAxis<true>, cpu>::Launch(s, oshape.Size(),
+                                                   outputs[take_::kOut].dptr<DType>(),
+                                                   inputs[take_::kArr].dptr<DType>(),
+                                                   inputs[take_::kIdx].dptr<IType>(),
+                                                   out_strides[actual_axis-1],
+                                                   in_strides[actual_axis-1],
+                                                   in_strides[actual_axis], arrshape.ndim(),
+                                                   oshape.ndim(), idxshape.ndim(),
+                                                   arrshape[actual_axis], actual_axis);
       } else {
-        Kernel<Take<false>, cpu>::Launch(s, oshape.Size(),
-                                         outputs[take_::kOut].dptr<DType>(),
-                                         inputs[take_::kArr].dptr<DType>(),
-                                         inputs[take_::kIdx].dptr<IType>(),
-                                         in_strides, out_strides, arrshape.ndim(),
-                                         oshape.ndim(), idxshape.ndim(),
-                                         arrshape[actual_axis], actual_axis);
+        Kernel<TakeNonzeroAxis<false>, cpu>::Launch(s, oshape.Size(),
+                                                    outputs[take_::kOut].dptr<DType>(),
+                                                    inputs[take_::kArr].dptr<DType>(),
+                                                    inputs[take_::kIdx].dptr<IType>(),
+                                                    out_strides[actual_axis-1],
+                                                    in_strides[actual_axis-1],
+                                                    in_strides[actual_axis], arrshape.ndim(),
+                                                    oshape.ndim(), idxshape.ndim(),
+                                                    arrshape[actual_axis], actual_axis);
       }
     }
   });
```
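The `in_strides`/`out_strides` loops in `TakeOpForward` build ordinary row-major stride tables. A simplified sketch of that computation (the real code fills fixed `mshadow::Shape<10>` structs; `RowMajorStrides` is an illustrative name):

```cpp
#include <cstdint>
#include <vector>

// strides[k] = number of elements skipped when index k advances by one,
// i.e. the product of all dimensions after k (row-major layout).
std::vector<int64_t> RowMajorStrides(const std::vector<int64_t>& shape) {
  std::vector<int64_t> strides(shape.size());
  int64_t stride = 1;
  for (int64_t k = static_cast<int64_t>(shape.size()) - 1; k >= 0; --k) {
    strides[k] = stride;
    stride *= shape[k];
  }
  return strides;
}
```

Of these tables, the renamed kernel now consumes only three scalars: `out_strides[actual_axis-1]`, `in_strides[actual_axis-1]`, and `in_strides[actual_axis]`. Passing those instead of two full `Shape<10>` structs by value is presumably where the speedup comes from: the old `Take` kernel's per-element `Shape<10>` lookups disappear and the kernel's argument payload shrinks.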
src/operator/tensor/indexing_op.cu: 33 additions & 28 deletions
```diff
@@ -116,11 +116,8 @@ struct AddTakeGradRspDeterministicKernel {
   }
 };

-/*! \brief name the struct Take instead of take
- * to avoid conflict with the take function in mshadow
- */
 template<bool clip = true>
-struct TakeGPU {
+struct TakeZeroAxisGPU {
   // assume that idx have been flattened to a 1-D tensor (N,)
   // assume that out_data and in_data have been flattened to 2-D tensors, (N, M) and (K, M)
   // M is the number of columns of in_data and out_data
@@ -180,8 +177,8 @@ void EmbeddingOpForwardDnsImpl<gpu>(mshadow::Stream<gpu>* s,
     Tensor<gpu, 2, DType> wmat = weight.get<gpu, 2, DType>(s);
     Tensor<gpu, 2, DType> out = output.get_with_shape<gpu, 2, DType>(
         Shape2(oshape.ProdShape(0, oshape.ndim()-1), oshape[oshape.ndim()-1]), s);
-    Kernel<TakeGPU<true>, gpu>::Launch(s, oshape.Size(), out.dptr_, wmat.dptr_,
-                                       idx.dptr_, wmat.shape_[1], wmat.shape_[0]);
+    Kernel<TakeZeroAxisGPU<true>, gpu>::Launch(s, oshape.Size(), out.dptr_, wmat.dptr_,
+                                               idx.dptr_, wmat.shape_[1], wmat.shape_[0]);
   });
 });
 }
@@ -502,17 +499,17 @@ void TakeOpForward<gpu>(const nnvm::NodeAttrs& attrs,
     }
     if (actual_axis == 0) {
       if (param.mode == take_::kClip) {
-        Kernel<TakeGPU<true>, gpu>::Launch(s, oshape.Size(),
-                                           outputs[take_::kOut].dptr<DType>(),
-                                           inputs[take_::kArr].dptr<DType>(),
-                                           inputs[take_::kIdx].dptr<IType>(),
-                                           oshape.Size()/idxshape.Size(), arrshape[0]);
+        Kernel<TakeZeroAxisGPU<true>, gpu>::Launch(s, oshape.Size(),
+                                                   outputs[take_::kOut].dptr<DType>(),
+                                                   inputs[take_::kArr].dptr<DType>(),
+                                                   inputs[take_::kIdx].dptr<IType>(),
+                                                   oshape.Size()/idxshape.Size(), arrshape[0]);
       } else {
-        Kernel<TakeGPU<false>, gpu>::Launch(s, oshape.Size(),
-                                            outputs[take_::kOut].dptr<DType>(),
-                                            inputs[take_::kArr].dptr<DType>(),
-                                            inputs[take_::kIdx].dptr<IType>(),
-                                            oshape.Size()/idxshape.Size(), arrshape[0]);
+        Kernel<TakeZeroAxisGPU<false>, gpu>::Launch(s, oshape.Size(),
+                                                    outputs[take_::kOut].dptr<DType>(),
+                                                    inputs[take_::kArr].dptr<DType>(),
+                                                    inputs[take_::kIdx].dptr<IType>(),
+                                                    oshape.Size()/idxshape.Size(), arrshape[0]);
       }
     } else {
       mshadow::Shape<10> in_strides;
@@ -526,19 +523,27 @@ void TakeOpForward<gpu>(const nnvm::NodeAttrs& attrs,
         out_strides[i] = stride;
       }
       if (param.mode == take_::kClip) {
-        Kernel<Take<true>, gpu>::Launch(s, oshape.Size(),
-                                        outputs[take_::kOut].dptr<DType>(),
-                                        inputs[take_::kArr].dptr<DType>(),
-                                        inputs[take_::kIdx].dptr<IType>(),
-                                        in_strides, out_strides, arrshape.ndim(), oshape.ndim(),
-                                        idxshape.ndim(), arrshape[actual_axis], actual_axis);
+        Kernel<TakeNonzeroAxis<true>, gpu>::Launch(s, oshape.Size(),
+                                                   outputs[take_::kOut].dptr<DType>(),
+                                                   inputs[take_::kArr].dptr<DType>(),
+                                                   inputs[take_::kIdx].dptr<IType>(),
+                                                   out_strides[actual_axis-1],
+                                                   in_strides[actual_axis-1],
+                                                   in_strides[actual_axis],
+                                                   arrshape.ndim(), oshape.ndim(),
+                                                   idxshape.ndim(), arrshape[actual_axis],
+                                                   actual_axis);
       } else {
-        Kernel<Take<false>, gpu>::Launch(s, oshape.Size(),
-                                         outputs[take_::kOut].dptr<DType>(),
-                                         inputs[take_::kArr].dptr<DType>(),
-                                         inputs[take_::kIdx].dptr<IType>(),
-                                         in_strides, out_strides, arrshape.ndim(), oshape.ndim(),
-                                         idxshape.ndim(), arrshape[actual_axis], actual_axis);
+        Kernel<TakeNonzeroAxis<false>, gpu>::Launch(s, oshape.Size(),
+                                                    outputs[take_::kOut].dptr<DType>(),
+                                                    inputs[take_::kArr].dptr<DType>(),
+                                                    inputs[take_::kIdx].dptr<IType>(),
+                                                    out_strides[actual_axis-1],
+                                                    in_strides[actual_axis-1],
+                                                    in_strides[actual_axis],
+                                                    arrshape.ndim(), oshape.ndim(),
+                                                    idxshape.ndim(), arrshape[actual_axis],
+                                                    actual_axis);
       }
     }
   });
```
src/operator/tensor/indexing_op.h: 14 additions & 14 deletions
```diff
@@ -296,11 +296,11 @@ inline bool SparseEmbeddingOpBackwardStorageType(const nnvm::NodeAttrs& attrs,
   return dispatched;
 }

-/*! \brief name the struct Take instead of take
- * to avoid conflict with the take function in mshadow
+/*! \brief name the struct TakeNonzeroAxis for general take when
+ * axis is not zero, use TakeZeroAxisGPU or TakeZeroAxisCPU for axis zero
  */
 template<bool clip = true>
-struct Take {
+struct TakeNonzeroAxis {
   /*!
    * \brief Map function for take operator
    * \param i global thread id
@@ -315,28 +315,28 @@ struct Take {
    */
   template<typename DType, typename IType>
   MSHADOW_XINLINE static void Map(index_t i, DType* out_data, const DType* in_data,
-                                  const IType* idx,
-                                  const mshadow::Shape<10> in_stride,
-                                  const mshadow::Shape<10> out_stride,
+                                  const IType* idx, const int out_prev_stride,
+                                  const int in_prev_stride, const int in_stride,
                                   const int in_ndims, const int out_ndims, const int idx_ndims,
                                   const int axis_dim, const int axis) {
     // i is the global flattened index in the output
-    const int64_t out_head_index = (axis == 0) ? 0 : (i / out_stride[axis - 1]);
-    const int64_t out_rest_index = (axis == 0) ? i : (i % out_stride[axis - 1]);
-    const int64_t out_mid_index = out_rest_index / in_stride[axis];
+    const int64_t out_head_index = i / out_prev_stride;
+    const int64_t out_rest_index = i % out_prev_stride;
+    const int64_t out_mid_index = out_rest_index / in_stride;
     const int64_t out_tail_index = (axis == in_ndims - 1) ?
-                                   0 : (out_rest_index % in_stride[axis]);
+                                   0 : (out_rest_index % in_stride);
     int64_t idx_index = static_cast<int64_t>(idx[out_mid_index]);
     if (clip) {
       idx_index = (idx_index < 0) ? 0 : idx_index;
       idx_index = (idx_index > axis_dim - 1) ? (axis_dim - 1) : idx_index;
+    } else {
+      idx_index %= axis_dim;
+      idx_index += (idx_index < 0) ? axis_dim : 0;
     }
-    idx_index %= axis_dim;
-    idx_index += (idx_index < 0) ? axis_dim : 0;
     const int64_t in_tail_index = out_tail_index;
     const int64_t in_head_index = out_head_index;
-    int64_t in_src_index = in_tail_index + idx_index * in_stride[axis];
-    in_src_index += (axis == 0) ? 0 : in_head_index * in_stride[axis - 1];
+    int64_t in_src_index = in_tail_index + idx_index * in_stride;
+    in_src_index += in_head_index * in_prev_stride;
     out_data[i] = in_data[in_src_index];
   }
 };
```
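Concretely, `Map` splits each flat output index `i` into a head (all axes before `axis`), a mid (the position in the flattened index array), and a tail (all axes after `axis`), then reassembles the input offset from the same pieces. A stand-alone sketch of that arithmetic with a worked example (`TakeSourceIndex` is illustrative, not a function in the codebase):

```cpp
#include <cstdint>
#include <vector>

// Mirrors TakeNonzeroAxis<clip>::Map for a single output element, using the
// same three stride scalars the new kernel receives. Assumes axis != 0, as
// guaranteed by TakeOpForward. When axis is the last one, in_stride is 1 and
// out_rest % in_stride == 0, matching the kernel's special case for that axis.
int64_t TakeSourceIndex(int64_t i,                // flat output index
                        const std::vector<int64_t>& idx,
                        int64_t out_prev_stride,  // out_strides[axis-1]
                        int64_t in_prev_stride,   // in_strides[axis-1]
                        int64_t in_stride,        // in_strides[axis]
                        int64_t axis_dim,         // arrshape[axis]
                        bool clip) {
  const int64_t out_head = i / out_prev_stride;   // axes before `axis`
  const int64_t out_rest = i % out_prev_stride;
  const int64_t out_mid  = out_rest / in_stride;  // position in flattened idx
  const int64_t out_tail = out_rest % in_stride;  // axes after `axis`
  int64_t j = idx[out_mid];
  if (clip) {  // clamp out-of-range indices to the valid range
    j = (j < 0) ? 0 : j;
    j = (j > axis_dim - 1) ? (axis_dim - 1) : j;
  } else {     // wrap-around (modulo) mode
    j %= axis_dim;
    j += (j < 0) ? axis_dim : 0;
  }
  return out_head * in_prev_stride + j * in_stride + out_tail;
}

// Worked example: in shape (2, 3, 4), axis = 1, idx = {2, 0} -> out (2, 2, 4).
// out_prev_stride = out_strides[0] = 2*4 = 8, in_prev_stride = 3*4 = 12,
// in_stride = 4. For i = 13: head = 1, rest = 5, mid = 1, tail = 1, idx[1] = 0,
// so the source offset is 1*12 + 0*4 + 1 = 13, i.e. out(1,1,1) = in(1,0,1).
```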