
Commit

Accelerate the performance of topk for CPU side (#12085)
* Accelerate the performance of topk for CPU side

* Add comments for the code changes
ciyongch authored and szha committed Aug 13, 2018
1 parent 9933d7a commit 95dd95c
Showing 1 changed file with 43 additions and 18 deletions.
61 changes: 43 additions & 18 deletions src/operator/tensor/ordering_op-inl.h
@@ -170,11 +170,13 @@ MSHADOW_FORCE_INLINE void TopKSort<cpu>(const Tensor<cpu, 1, real_t>& dat,
// Use full sort when K is relatively large.
const bool full_sort(K*8 > N);
// Batch size.
const int M(dat.size(0)/N);
const int M(work.size(0)/(sizeof(real_t)*N));
const int omp_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount());
#pragma omp parallel for num_threads(omp_threads)
for (int i = 0; i < M; ++i) {
real_t *vals = dat.dptr_;
// Tensor `work` stores the flattened source data, while `dat` stores the sorted result.
real_t *vals = reinterpret_cast<real_t*>(work.dptr_);
real_t *sorted_vals = dat.dptr_+i*N;
int *indices = ind.dptr_+i*N;
if (is_ascend) {
if (full_sort) {
@@ -193,11 +195,9 @@ MSHADOW_FORCE_INLINE void TopKSort<cpu>(const Tensor<cpu, 1, real_t>& dat,
[&](const int& i1, const int& i2){ return vals[i1] > vals[i2]; });
}
}
real_t *buff = reinterpret_cast<real_t*>(work.dptr_)+i*K;
for (int j = 0; j < K; ++j) {
buff[j] = vals[indices[j]];
sorted_vals[j] = vals[indices[j]];
}
std::copy(buff, buff+K, &vals[i*N]);
}
}

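The hunk above keeps the existing selection strategy (a full `std::sort` of the index array when `K*8 > N`, otherwise a `std::partial_sort` of only the first K positions) but now writes the selected values straight into the output buffer instead of staging them in `work` and copying them back. A minimal standalone sketch of that per-batch pattern; `select_topk`, `out_vals`, and `out_idx` are illustrative names, not part of the MXNet source:

#include <algorithm>
#include <numeric>
#include <vector>

// Select the top-K elements of one batch of length N.
void select_topk(const float* vals, int N, int K, bool is_ascend,
                 float* out_vals, int* out_idx) {
  std::vector<int> indices(N);
  std::iota(indices.begin(), indices.end(), 0);
  auto cmp = [&](int i1, int i2) {
    return is_ascend ? vals[i1] < vals[i2] : vals[i1] > vals[i2];
  };
  // Heuristic from the diff: a full sort tends to win once K*8 > N,
  // otherwise sorting only the first K positions is cheaper.
  if (K * 8 > N) {
    std::sort(indices.begin(), indices.end(), cmp);
  } else {
    std::partial_sort(indices.begin(), indices.begin() + K, indices.end(), cmp);
  }
  // Write results straight into the destination, as the new code does with
  // `sorted_vals[j] = vals[indices[j]]`; no staging buffer, no std::copy back.
  for (int j = 0; j < K; ++j) {
    out_idx[j] = indices[j];
    out_vals[j] = vals[indices[j]];
  }
}

Each batch is independent, which is what lets the surrounding loop in `TopKSort<cpu>` be parallelized with OpenMP.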
@@ -380,16 +380,7 @@ void TopKImpl(RunContext ctx,
indices = Tensor<xpu, 1, int>(reinterpret_cast<int*>(workspace_curr_ptr),
Shape1(src.Size()), s); // indices in the original matrix
workspace_curr_ptr += sizeof(int) * src.Size();
if (do_transpose) {
sorted_dat = reshape(transpose(dat, Shape3(0, 2, 1)), Shape1(src.Size()));
} else {
sorted_dat = reshape(dat, Shape1(src.Size()));
}
mxnet_op::Kernel<range_fwd, xpu>::Launch(s, batch_size * element_num, 1, 0, 1,
kWriteTo, indices.dptr_);

CHECK_EQ(sorted_dat.CheckContiguous(), true);
CHECK_EQ(indices.CheckContiguous(), true);
if (param.ret_typ == topk_enum::kReturnMask) {
sel_indices = Tensor<xpu, 1, int>(reinterpret_cast<int*>(workspace_curr_ptr),
Shape1(batch_size * k), s);
@@ -401,15 +392,47 @@ void TopKImpl(RunContext ctx,
CHECK_EQ(sel_indices.CheckContiguous(), true);
CHECK_EQ(mask_val.CheckContiguous(), true);
}
temp_workspace = Tensor<xpu, 1, char>(workspace_curr_ptr, Shape1(temp_size), s); // temp space
workspace_curr_ptr += temp_size;

if (std::is_same<xpu, cpu>::value) {
Tensor<xpu, 1, real_t> flattened_data;
if (do_transpose) {
flattened_data = Tensor<xpu, 1, real_t>(reinterpret_cast<real_t*>(workspace_curr_ptr),
Shape1(src.Size()), s);
workspace_curr_ptr += sizeof(real_t) * src.Size();
flattened_data = reshape(transpose(dat, Shape3(0, 2, 1)), Shape1(src.Size()));
CHECK_EQ(flattened_data.CheckContiguous(), true);
} else {
flattened_data = src.FlatTo1D<xpu, real_t>(s);
}
// `temp_workspace` stores the flattened data
temp_workspace = Tensor<xpu, 1, char>(reinterpret_cast<char*>(flattened_data.dptr_),
Shape1(sizeof(real_t)*src.Size()), s);
CHECK_EQ(temp_workspace.CheckContiguous(), true);
} else {
if (do_transpose) {
sorted_dat = reshape(transpose(dat, Shape3(0, 2, 1)), Shape1(src.Size()));
} else {
sorted_dat = reshape(dat, Shape1(src.Size()));
}
CHECK_EQ(sorted_dat.CheckContiguous(), true);
temp_workspace = Tensor<xpu, 1, char>(workspace_curr_ptr, Shape1(temp_size), s); // temp space
workspace_curr_ptr += temp_size;
}

mxnet_op::Kernel<range_fwd, xpu>::Launch(s, batch_size * element_num, 1, 0, 1,
kWriteTo, indices.dptr_);
CHECK_EQ(indices.CheckContiguous(), true);

// 2. Perform inplace batch sort.
// After sorting, each batch in `sorted_dat` will be sorted in the requested order
// up to the k-th element, and `indices` will contain the corresponding positions in `sorted_dat`.
// `temp_workspace` stores the flattened source data on the CPU device, and serves as
// a temporary buffer on the GPU device.
TopKSort(sorted_dat, indices, temp_workspace, k, element_num, is_ascend, s);

// 3. Assign results to the ret blob
// When returning indices, only update (modulo) the required top-k elements instead of all
// elements, to avoid redundant calculation.
if (param.ret_typ == topk_enum::kReturnMask) {
Tensor<xpu, 2, real_t> ret_mask =
ret[0].get_with_shape<xpu, 2, real_t>(Shape2(ret[0].Size(), 1), s);
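The `std::is_same<xpu, cpu>::value` branch above removes the old unconditional copy of the source into `sorted_dat` on the CPU: when no transpose is needed the sort reads the source tensor in place (via `FlatTo1D`) and `temp_workspace` simply aliases that flattened buffer, so only the transposed case still pays for an extra flatten into workspace memory. A minimal sketch of the same decision with plain pointers; `flatten_for_sort` and its parameters are illustrative names, not MXNet API:

#include <cstddef>
#include <vector>

// Return a pointer to data laid out with each batch's elements contiguous.
// If the layout already matches, reuse the source buffer (zero copy);
// otherwise materialize a transposed copy in caller-provided workspace.
const float* flatten_for_sort(const float* src,
                              int batch, int element_num, int stride,
                              bool do_transpose,
                              std::vector<float>* workspace) {
  if (!do_transpose) {
    // Mirrors `flattened_data = src.FlatTo1D<xpu, real_t>(s)`: no copy at all.
    return src;
  }
  // Mirrors `reshape(transpose(dat, Shape3(0, 2, 1)), Shape1(src.Size()))`:
  // swap the last two axes of a (batch, element_num, stride) view so that
  // element_num becomes the innermost, contiguous dimension.
  workspace->resize(static_cast<size_t>(batch) * element_num * stride);
  for (int b = 0; b < batch; ++b)
    for (int e = 0; e < element_num; ++e)
      for (int t = 0; t < stride; ++t)
        (*workspace)[(static_cast<size_t>(b) * stride + t) * element_num + e] =
            src[(static_cast<size_t>(b) * element_num + e) * stride + t];
  return workspace->data();
}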
@@ -427,7 +450,6 @@ void TopKImpl(RunContext ctx,
}
IndexFill(ret_mask, sel_indices, mask_val);
} else if (param.ret_typ == topk_enum::kReturnIndices) {
indices = F<mshadow_op::mod>(indices, element_num);
if (do_transpose) {
Tensor<xpu, 3, real_t> ret_indices = ret[0].FlatTo3D<xpu, real_t>(axis, axis, s);
ret_indices = tcast<real_t>(transpose(
@@ -437,14 +459,15 @@ void TopKImpl(RunContext ctx,
element_num)),
0, k),
Shape3(0, 2, 1)));
ret_indices = F<mshadow_op::mod>(ret_indices, element_num);
} else {
Tensor<xpu, 2, real_t> ret_indices =
ret[0].get_with_shape<xpu, 2, real_t>(Shape2(batch_size, k), s);
ret_indices = tcast<real_t>(slice<1>(
inplace_reshape(indices, Shape2(batch_size, element_num)), 0, k));
ret_indices = F<mshadow_op::mod>(ret_indices, element_num);
}
} else {
indices = F<mshadow_op::mod>(indices, element_num);
if (do_transpose) {
Tensor<xpu, 3, real_t> ret_value = ret[0].FlatTo3D<xpu, real_t>(axis, axis, s);
Tensor<xpu, 3, real_t> ret_indices = ret[1].FlatTo3D<xpu, real_t>(axis, axis, s);
@@ -460,6 +483,7 @@ void TopKImpl(RunContext ctx,
element_num)),
0, k),
Shape3(0, 2, 1)));
ret_indices = F<mshadow_op::mod>(ret_indices, element_num);
} else {
Tensor<xpu, 2, real_t> ret_value =
ret[0].get_with_shape<xpu, 2, real_t>(Shape2(batch_size, k), s);
@@ -468,6 +492,7 @@ void TopKImpl(RunContext ctx,
ret_value = slice<1>(inplace_reshape(sorted_dat, Shape2(batch_size, element_num)), 0, k);
ret_indices = tcast<real_t>(slice<1>(
inplace_reshape(indices, Shape2(batch_size, element_num)), 0, k));
ret_indices = F<mshadow_op::mod>(ret_indices, element_num);
}
}
}
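The return-path changes above move the `mod` from the full `indices` tensor onto the already-sliced `ret_indices`: instead of reducing all `batch_size * element_num` flat indices modulo `element_num`, only the `batch_size * k` values actually returned are converted from global positions to per-batch positions. A rough illustration of the saving; `topk_indices` and its arguments are illustrative names only:

#include <cstdint>
#include <vector>

// Convert globally flattened indices to per-batch positions after slicing,
// so only batch*k values pay for the modulo instead of batch*element_num.
std::vector<int64_t> topk_indices(const std::vector<int64_t>& flat_indices,
                                  int batch, int element_num, int k) {
  std::vector<int64_t> out;
  out.reserve(static_cast<size_t>(batch) * k);
  for (int b = 0; b < batch; ++b) {
    for (int j = 0; j < k; ++j) {
      // flat_indices holds positions into the flattened (batch, element_num)
      // data; `% element_num` maps them back into the batch's own range.
      out.push_back(flat_indices[static_cast<size_t>(b) * element_num + j] % element_num);
    }
  }
  return out;
}

With `element_num` large and `k` small, this removes most of the modulo work the old code spent on the full `indices` tensor.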
