From 9baa31dbae56e8a920a2c282e993508527a351d1 Mon Sep 17 00:00:00 2001
From: Hao Jin <hjjn.amzn@gmail.com>
Date: Sat, 16 Mar 2019 00:34:35 +0000
Subject: [PATCH] speedup SequenceMask on GPU

---
 src/operator/sequence_mask-inl.h | 79 +++++++------------------------
 src/operator/sequence_mask.cc    | 64 ++++++++++++++++++++++++++
 src/operator/sequence_mask.cu    | 59 ++++++++++++++++++++++++
 3 files changed, 140 insertions(+), 62 deletions(-)

diff --git a/src/operator/sequence_mask-inl.h b/src/operator/sequence_mask-inl.h
index 372cf57e03dc..05a9424fd891 100644
--- a/src/operator/sequence_mask-inl.h
+++ b/src/operator/sequence_mask-inl.h
@@ -65,70 +65,24 @@ struct SequenceMaskParam : public dmlc::Parameter<SequenceMaskParam> {
   }
 };
 
-// (seqlen, batch, rest) case
-template <int req>
-struct SequenceMask0Kernel {
-  template <typename DType, typename IType>
-  MSHADOW_XINLINE static void Map(int b, DType *in, const IType *idx,
-                                  index_t max_s_len, index_t batch_size,
-                                  index_t restsize, DType value) {
-    const index_t seqpos = static_cast<index_t>(idx[b]);
-#pragma unroll
-    for (index_t s = seqpos; s < max_s_len; ++s) {
-      index_t incr = (s * batch_size * restsize) + (b * restsize);
-#pragma unroll
-      for (index_t r = 0; r < restsize; ++r)
-        KERNEL_ASSIGN(in[incr + r], req, value);
-    }
-  }
-};
-
-// (batch, seqlen, rest) case
-template <int req>
-struct SequenceMask1Kernel {
-  template <typename DType, typename IType>
-  MSHADOW_XINLINE static void Map(int b, DType *in, const IType *idx,
-                                  index_t max_s_len, index_t batch_size,
-                                  index_t restsize, DType value) {
-    const index_t seqpos = static_cast<index_t>(idx[b]);
-#pragma unroll
-    for (index_t s = seqpos; s < max_s_len; ++s) {
-      index_t incr = (b * max_s_len * restsize) + (s * restsize);
-#pragma unroll
-      for (index_t r = 0; r < restsize; ++r)
-        KERNEL_ASSIGN(in[incr + r], req, value);
-    }
-  }
-};
+template <typename DType, typename IType>
+void SequenceMaskExec(const mshadow::Tensor<cpu, 3, DType> &data,
+                      const mshadow::Tensor<cpu, 1, IType> &indices,
+                      const OpReqType req, mshadow::Stream<cpu> *const s,
+                      int axis, DType val);
+#ifdef __CUDACC__
+template <typename DType, typename IType>
+void SequenceMaskExec(const mshadow::Tensor<gpu, 3, DType> &data,
+                      const mshadow::Tensor<gpu, 1, IType> &indices,
+                      const OpReqType req, mshadow::Stream<gpu> *const s,
+                      int axis, DType val);
+#endif
 
 template <typename xpu, typename DType, typename IType>
 class SequenceMaskOp : public Operator {
  public:
   explicit SequenceMaskOp(SequenceMaskParam p) { this->param_ = p; }
 
-  void sequence_mask(const mshadow::Tensor<xpu, 3, DType> &data,
-                     const mshadow::Tensor<xpu, 1, IType> &indices,
-                     const OpReqType req, mshadow::Stream<xpu> *const s,
-                     DType val) {
-    using namespace mshadow;
-    using namespace mshadow::expr;
-
-    index_t batch = indices.size(0);
-    index_t max_seq_len = data.size(param_.axis);
-    index_t restsize = data.size(2);
-
-    MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
-      if (param_.axis == 1)
-        mxnet_op::Kernel<SequenceMask1Kernel<req_type>, xpu>::Launch(
-            s, batch, data.dptr_, indices.dptr_, max_seq_len, batch, restsize,
-            val);
-      else
-        mxnet_op::Kernel<SequenceMask0Kernel<req_type>, xpu>::Launch(
-            s, batch, data.dptr_, indices.dptr_, max_seq_len, batch, restsize,
-            val);
-    });
-  }
-
   virtual void Forward(const OpContext &ctx, const std::vector<TBlob> &in_data,
                        const std::vector<OpReqType> &req,
                        const std::vector<TBlob> &out_data,
@@ -155,8 +109,8 @@ class SequenceMaskOp : public Operator {
     if (param_.use_sequence_length) {
       Tensor<xpu, 1, IType> indices =
           in_data[seq_mask::kSequenceLength].get<xpu, 1, IType>(s);
-      sequence_mask(out, indices, req[seq_mask::kOut], s,
-                    static_cast<DType>(param_.value));
+      SequenceMaskExec(out, indices, req[seq_mask::kOut], s,
+                       param_.axis, static_cast<DType>(param_.value));
     }
   }
 
@@ -198,11 +152,12 @@ class SequenceMaskOp : public Operator {
               s3, s);
       out_g_temp = F<mshadow_op::identity>(out_g);
       out_g = out_g_temp;
-      sequence_mask(out_g, indices, kWriteInplace, s, DType(0.));
+      SequenceMaskExec(out_g, indices, kWriteInplace, s, param_.axis, DType(0.));
       Assign(data_g, kAddTo, F<mshadow_op::identity>(out_g));
     } else {
       Assign(data_g, req[seq_mask::kData], F<mshadow_op::identity>(out_g));
-      sequence_mask(data_g, indices, req[seq_mask::kData], s, DType(0.));
+      SequenceMaskExec(
+          data_g, indices, req[seq_mask::kData], s, param_.axis, DType(0.));
     }
   }
 
diff --git a/src/operator/sequence_mask.cc b/src/operator/sequence_mask.cc
index c3bf12d3a862..d28fbdf1a25c 100644
--- a/src/operator/sequence_mask.cc
+++ b/src/operator/sequence_mask.cc
@@ -27,6 +27,70 @@
 
 namespace mxnet {
 namespace op {
+
+// (seqlen, batch, rest) case
+template <int req>
+struct SequenceMask0CPUKernel {
+  template <typename DType, typename IType>
+  MSHADOW_XINLINE static void Map(int b, DType *in, const IType *idx,
+                                  index_t max_s_len, index_t batch_size,
+                                  index_t restsize, DType value) {
+    const index_t seqpos = static_cast<index_t>(idx[b]);
+#pragma unroll
+    for (index_t s = seqpos; s < max_s_len; ++s) {
+      index_t incr = (s * batch_size * restsize) + (b * restsize);
+#pragma unroll
+      for (index_t r = 0; r < restsize; ++r)
+        KERNEL_ASSIGN(in[incr + r], req, value);
+    }
+  }
+};
+
+// (batch, seqlen, rest) case
+template <int req>
+struct SequenceMask1CPUKernel {
+  template <typename DType, typename IType>
+  MSHADOW_XINLINE static void Map(int b, DType *in, const IType *idx,
+                                  index_t max_s_len, index_t batch_size,
+                                  index_t restsize, DType value) {
+    const index_t seqpos = static_cast<index_t>(idx[b]);
+#pragma unroll
+    for (index_t s = seqpos; s < max_s_len; ++s) {
+      index_t incr = (b * max_s_len * restsize) + (s * restsize);
+#pragma unroll
+      for (index_t r = 0; r < restsize; ++r)
+        KERNEL_ASSIGN(in[incr + r], req, value);
+    }
+  }
+};
+
+template <typename DType, typename IType>
+void SequenceMaskExec(
+    const mshadow::Tensor<cpu, 3, DType> &data,
+    const mshadow::Tensor<cpu, 1, IType> &indices,
+    const OpReqType req, mshadow::Stream<cpu> *const s,
+    int axis, DType val) {
+  using namespace mshadow;
+  using namespace mshadow::expr;
+  using namespace mxnet_op;
+
+  index_t batch = indices.size(0);
+  index_t max_seq_len = data.size(axis);
+  index_t restsize = data.size(2);
+
+  MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
+    if (axis == 1) {
+      Kernel<SequenceMask1CPUKernel<req_type>, cpu>::Launch(
+          s, batch, data.dptr_, indices.dptr_, max_seq_len, batch, restsize,
+          val);
+    } else {
+      Kernel<SequenceMask0CPUKernel<req_type>, cpu>::Launch(
+          s, batch, data.dptr_, indices.dptr_, max_seq_len, batch, restsize,
+          val);
+    }
+  });
+}
+
 template <>
 Operator *CreateOp<cpu>(SequenceMaskParam param, int dtype, int itype) {
   Operator *op = nullptr;
diff --git a/src/operator/sequence_mask.cu b/src/operator/sequence_mask.cu
index cec627c4c697..6728cf958860 100644
--- a/src/operator/sequence_mask.cu
+++ b/src/operator/sequence_mask.cu
@@ -29,6 +29,65 @@
 
 namespace mxnet {
 namespace op {
+
+// (seqlen, batch, rest) case
+template <int req>
+struct SequenceMask0GPUKernel {
+  template <typename DType, typename IType>
+  MSHADOW_XINLINE static void Map(int i, DType *in, const IType *idx,
+                                  index_t max_s_len, index_t batch_size,
+                                  index_t restsize, DType value) {
+    index_t b = i / restsize % batch_size;
+    const index_t seqpos = static_cast<index_t>(idx[b]);
+    index_t s = i / restsize / batch_size;
+    if (s >= seqpos) {
+      KERNEL_ASSIGN(in[i], req, value);
+    }
+  }
+};
+
+// (batch, seqlen, rest) case
+template <int req>
+struct SequenceMask1GPUKernel {
+  template <typename DType, typename IType>
+  MSHADOW_XINLINE static void Map(int i, DType *in, const IType *idx,
+                                  index_t max_s_len, index_t batch_size,
+                                  index_t restsize, DType value) {
+    index_t b = i / restsize / max_s_len;
+    const index_t seqpos = static_cast<index_t>(idx[b]);
+    index_t s = i / restsize % max_s_len;
+    if (s >= seqpos) {
+      KERNEL_ASSIGN(in[i], req, value);
+    }
+  }
+};
+
+template <typename DType, typename IType>
+void SequenceMaskExec(
+    const mshadow::Tensor<gpu, 3, DType> &data,
+    const mshadow::Tensor<gpu, 1, IType> &indices,
+    const OpReqType req, mshadow::Stream<gpu> *const s,
+    int axis, DType val) {
+  using namespace mshadow;
+  using namespace mshadow::expr;
+  using namespace mxnet_op;
+
+  index_t batch = indices.size(0);
+  index_t max_seq_len = data.size(axis);
+  index_t restsize = data.size(2);
+
+  MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
+    if (axis == 1) {
+      Kernel<SequenceMask1GPUKernel<req_type>, gpu>::Launch(
+          s, data.shape_.Size(), data.dptr_, indices.dptr_, max_seq_len, batch, restsize,
+          val);
+    } else {
+      Kernel<SequenceMask0GPUKernel<req_type>, gpu>::Launch(
+          s, data.shape_.Size(), data.dptr_, indices.dptr_, max_seq_len, batch, restsize,
+          val);
+    }
+  });
+}
+
 template <>
 Operator *CreateOp<gpu>(SequenceMaskParam param, int dtype, int itype) {
   Operator *op = NULL;
   MSHADOW_TYPE_SWITCH(dtype, DType, {