Skip to content

Commit

Permalink
Speed up SequenceReverse (apache#14627)
Browse files (browse the repository at this point in the history)
  • Loading branch information
ptrendx authored and szha committed Apr 29, 2019
1 parent c18381d commit 64287dd
Showing 1 changed file with 28 additions and 29 deletions.
57 changes: 28 additions & 29 deletions src/operator/sequence_reverse-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,40 +64,37 @@ struct SequenceReverseParam : public dmlc::Parameter<SequenceReverseParam> {
}
};

// NOTE(review): this span is a rendered diff whose +/- markers were stripped,
// so removed and added lines of ReverseKernel::Map appear back to back and the
// text is not compilable as shown. The old/new labels below are inferred from
// the launch site later in this file (one call passes `req` as an argument and
// launches `max_seq_len` work items; the other instantiates
// ReverseKernel<req_type> and launches max_seq_len * batch_size * other_dim
// work items) -- confirm against the real commit before relying on them.
//
// Purpose: copy `in_data` to `out_data` with the first `num_seq` time steps of
// each batch entry written in reverse order; steps past `num_seq` are copied
// to the same offset they were read from.

// Presumably added by the commit: the write-request type as a template
// parameter (matches the ReverseKernel<req_type> launch).
template <OpReqType req>
struct ReverseKernel {
template <typename DType, typename IType>
MSHADOW_XINLINE static void Map(const int i, DType *const out_data,
const DType *const in_data,
// Presumably removed by the commit: `req` as a runtime argument (it would
// clash with the template parameter of the same name).
const OpReqType req,
const index_t max_seq_len,
const index_t batch_size,
const index_t other_dim, const index_t numel,
const IType *const indices) {
// ---- Presumably the REMOVED body: `i` is a time-step index, and this single
// ---- work item loops over every batch entry and every element of other_dim.
for (index_t batch = 0; batch < batch_size; ++batch) {
// Per-batch valid length; a null `indices` means every step is valid.
const index_t num_seq =
indices ? static_cast<index_t>(indices[batch]) : max_seq_len;
const index_t padded_periods = max_seq_len - num_seq;
// padded part
if (padded_periods > 0 && i < static_cast<int>(padded_periods)) {
// Source and destination share one offset: step (i + num_seq) is copied
// through unchanged.
const int padded_in_offset =
(i + num_seq) * batch_size * other_dim + batch * other_dim;

for (index_t j = 0; j < other_dim; ++j) {
KERNEL_ASSIGN(out_data[padded_in_offset + j], req,
in_data[padded_in_offset + j]);
}
}
// unpadded part
if (i < static_cast<int>(num_seq)) {
const int in_offset = i * batch_size * other_dim + batch * other_dim;
// Assuming numel == max_seq_len * batch_size * other_dim (TODO confirm),
// this lands on step (num_seq - 1 - i), i.e. the mirrored position.
const int out_offset =
numel - (i + 1 + padded_periods) * batch_size * other_dim +
batch * other_dim;

for (index_t j = 0; j < other_dim; ++j) {
KERNEL_ASSIGN(out_data[out_offset + j], req, in_data[in_offset + j]);
}
}
// ---- Presumably the ADDED body: `i` is a flat element index, split into
// ---- (batch, id, j), so each work item performs at most one assignment per
// ---- branch instead of looping.
const index_t batch = i / (max_seq_len * other_dim);
// `id` plays the role the time-step index `i` had in the loop version above.
const int id = (i / other_dim) % max_seq_len;
// Position inside the innermost dimension; consecutive `i` values with the
// same (batch, id) address consecutive elements.
const index_t j = i % other_dim;
// Per-batch valid length; a null `indices` means every step is valid.
const index_t num_seq =
indices ? static_cast<index_t>(indices[batch]) : max_seq_len;
const index_t padded_periods = max_seq_len - num_seq;
// padded part
if (padded_periods > 0 && id < static_cast<int>(padded_periods)) {
// Same offset for read and write: step (id + num_seq) is copied unchanged.
// NOTE(review): the offset is stored in `int` although its operands are
// index_t; this may narrow for very large tensors -- confirm index_t width.
const int padded_in_offset =
(id + num_seq) * batch_size * other_dim + batch * other_dim;

KERNEL_ASSIGN(out_data[padded_in_offset + j], req,
in_data[padded_in_offset + j]);
}
// unpadded part
if (id < static_cast<int>(num_seq)) {
const int in_offset = id * batch_size * other_dim + batch * other_dim;
// Assuming numel == max_seq_len * batch_size * other_dim (TODO confirm),
// this lands on step (num_seq - 1 - id), i.e. the mirrored position.
const int out_offset =
numel - (id + 1 + padded_periods) * batch_size * other_dim +
batch * other_dim;

KERNEL_ASSIGN(out_data[out_offset + j], req, in_data[in_offset + j]);
}
}
};
Expand All @@ -118,9 +115,11 @@ class SequenceReverseOp : public Operator {
const index_t other_dim = data.size(2);
const index_t tensor_numel = data.shape_.Size();

mxnet_op::Kernel<ReverseKernel, xpu>::Launch(
s, max_seq_len, out.dptr_, data.dptr_, req, max_seq_len, batch_size,
other_dim, tensor_numel, indices);
MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
mxnet_op::Kernel<ReverseKernel<req_type>, xpu>::Launch(
s, max_seq_len * batch_size * other_dim, out.dptr_, data.dptr_,
max_seq_len, batch_size, other_dim, tensor_numel, indices);
});
}

virtual void Forward(const OpContext &ctx, const std::vector<TBlob> &in_data,
Expand Down

0 comments on commit 64287dd

Please sign in to comment.