
Speed up SequenceReverse #14627

Merged
merged 2 commits on Apr 29, 2019
Changes from all commits
57 changes: 28 additions & 29 deletions src/operator/sequence_reverse-inl.h
@@ -64,40 +64,37 @@ struct SequenceReverseParam : public dmlc::Parameter<SequenceReverseParam> {
   }
 };
 
+template <OpReqType req>
 struct ReverseKernel {
   template <typename DType, typename IType>
   MSHADOW_XINLINE static void Map(const int i, DType *const out_data,
                                   const DType *const in_data,
-                                  const OpReqType req,
                                   const index_t max_seq_len,
                                   const index_t batch_size,
                                   const index_t other_dim, const index_t numel,
                                   const IType *const indices) {
-    for (index_t batch = 0; batch < batch_size; ++batch) {
-      const index_t num_seq =
-          indices ? static_cast<index_t>(indices[batch]) : max_seq_len;
-      const index_t padded_periods = max_seq_len - num_seq;
-      // padded part
-      if (padded_periods > 0 && i < static_cast<int>(padded_periods)) {
-        const int padded_in_offset =
-            (i + num_seq) * batch_size * other_dim + batch * other_dim;
-
-        for (index_t j = 0; j < other_dim; ++j) {
-          KERNEL_ASSIGN(out_data[padded_in_offset + j], req,
-                        in_data[padded_in_offset + j]);
-        }
-      }
-      // unpadded part
-      if (i < static_cast<int>(num_seq)) {
-        const int in_offset = i * batch_size * other_dim + batch * other_dim;
-        const int out_offset =
-            numel - (i + 1 + padded_periods) * batch_size * other_dim +
-            batch * other_dim;
-
-        for (index_t j = 0; j < other_dim; ++j) {
-          KERNEL_ASSIGN(out_data[out_offset + j], req, in_data[in_offset + j]);
-        }
-      }
-    }
+    const index_t batch = i / (max_seq_len * other_dim);
+    const int id = (i / other_dim) % max_seq_len;
+    const index_t j = i % other_dim;
+    const index_t num_seq =
+        indices ? static_cast<index_t>(indices[batch]) : max_seq_len;
+    const index_t padded_periods = max_seq_len - num_seq;
+    // padded part
+    if (padded_periods > 0 && id < static_cast<int>(padded_periods)) {
+      const int padded_in_offset =
+          (id + num_seq) * batch_size * other_dim + batch * other_dim;
+
+      KERNEL_ASSIGN(out_data[padded_in_offset + j], req,
+                    in_data[padded_in_offset + j]);
+    }
+    // unpadded part
+    if (id < static_cast<int>(num_seq)) {
+      const int in_offset = id * batch_size * other_dim + batch * other_dim;
+      const int out_offset =
+          numel - (id + 1 + padded_periods) * batch_size * other_dim +
+          batch * other_dim;
+
+      KERNEL_ASSIGN(out_data[out_offset + j], req, in_data[in_offset + j]);
+    }
   }
 };
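
The new kernel assigns one thread per element instead of one per sequence position: the flat index i is decomposed as i = batch * (max_seq_len * other_dim) + id * other_dim + j, while the memory offsets address a row-major [max_seq_len, batch_size, other_dim] tensor. Below is a minimal standalone sketch (the sizes are illustrative, not from the patch) checking that this pairing visits every element exactly once; since consecutive values of i differ only in j, adjacent threads also touch adjacent addresses, which keeps GPU memory accesses coalesced.

    // Standalone sketch: verify the kernel's index decomposition is a bijection.
    // Sizes here are illustrative, not taken from the patch.
    #include <cassert>
    #include <vector>

    int main() {
      const int max_seq_len = 4, batch_size = 3, other_dim = 5;
      const int numel = max_seq_len * batch_size * other_dim;
      std::vector<int> hits(numel, 0);
      for (int i = 0; i < numel; ++i) {
        // Same arithmetic as the new ReverseKernel::Map:
        const int batch = i / (max_seq_len * other_dim);
        const int id = (i / other_dim) % max_seq_len;
        const int j = i % other_dim;
        // Row-major offset into a [max_seq_len, batch_size, other_dim] tensor,
        // matching the in_offset / out_offset computations in the kernel:
        ++hits[id * batch_size * other_dim + batch * other_dim + j];
      }
      for (const int h : hits) assert(h == 1);  // every element written exactly once
      return 0;
    }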
@@ -118,9 +115,11 @@ class SequenceReverseOp : public Operator {
     const index_t other_dim = data.size(2);
     const index_t tensor_numel = data.shape_.Size();
 
-    mxnet_op::Kernel<ReverseKernel, xpu>::Launch(
-        s, max_seq_len, out.dptr_, data.dptr_, req, max_seq_len, batch_size,
-        other_dim, tensor_numel, indices);
+    MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
+      mxnet_op::Kernel<ReverseKernel<req_type>, xpu>::Launch(
+          s, max_seq_len * batch_size * other_dim, out.dptr_, data.dptr_,
+          max_seq_len, batch_size, other_dim, tensor_numel, indices);
+    });
   }
 
   virtual void Forward(const OpContext &ctx, const std::vector<TBlob> &in_data,
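
MXNET_ASSIGN_REQ_SWITCH lifts the runtime OpReqType into the compile-time template parameter req_type, so the write-vs-accumulate branch inside KERNEL_ASSIGN is fixed per kernel instantiation rather than evaluated for each element. A minimal sketch of this dispatch pattern, with hypothetical names (OpReq, Kernel, Launch here are illustrative stand-ins, not MXNet's actual definitions):

    // Sketch of runtime-enum-to-template dispatch: the enum is switched once,
    // outside the loop, and each instantiation compiles its branch away.
    #include <cstdio>

    enum OpReq { kWriteTo, kAddTo };

    template <OpReq req>
    struct Kernel {
      static void Map(int i, float *out, const float *in) {
        if (req == kAddTo) {  // compile-time constant per instantiation
          out[i] += in[i];
        } else {
          out[i] = in[i];
        }
      }
    };

    void Launch(OpReq req, int n, float *out, const float *in) {
      switch (req) {  // one runtime branch instead of n of them
        case kWriteTo:
          for (int i = 0; i < n; ++i) Kernel<kWriteTo>::Map(i, out, in);
          break;
        case kAddTo:
          for (int i = 0; i < n; ++i) Kernel<kAddTo>::Map(i, out, in);
          break;
      }
    }

    int main() {
      float in[3] = {1, 2, 3}, out[3] = {0, 0, 0};
      Launch(kAddTo, 3, out, in);
      std::printf("%g %g %g\n", out[0], out[1], out[2]);  // prints: 1 2 3
      return 0;
    }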