Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
clean up unnecessary space in topk
Browse files Browse the repository at this point in the history
  • Loading branch information
apeforest committed Aug 20, 2019
1 parent 2f2635e commit b431a59
Showing 1 changed file with 6 additions and 12 deletions.
18 changes: 6 additions & 12 deletions src/operator/tensor/ordering_op-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -414,18 +414,12 @@ void TopKImpl(const RunContext &ctx,
<< element_num << ", but the selected index_t can only represent "
<< mxnet::common::MaxIntegerValue<index_t>() << " elements";
Tensor<xpu, 3, DType> dat = src.FlatTo3D<xpu, DType>(axis, axis, s);
size_t temp_size = 0;
// Temp space needed by the gpu-based full sorts.
temp_size = std::max(temp_size,
mxnet::op::SortByKeyWorkspaceSize<index_t, index_t, xpu>(src.Size()));
temp_size = std::max(temp_size,
mxnet::op::SortByKeyWorkspaceSize<index_t, DType, xpu>(src.Size()));
temp_size = std::max(temp_size,
mxnet::op::SortByKeyWorkspaceSize<DType, index_t, xpu>(src.Size()));
// Additional temp space for gpu full sorts for batch ids.
temp_size += PadBytes(sizeof(index_t) * src.Size(), alignment);
// Temp space for cpu sorts.
temp_size = std::max(temp_size, static_cast<size_t>(sizeof(DType) * src.Size()));
// Temp space needed by the full sorts.
size_t temp_size = std::max(
mxnet::op::SortByKeyWorkspaceSize<index_t, DType, xpu>(src.Size()),
mxnet::op::SortByKeyWorkspaceSize<DType, index_t, xpu>(src.Size())
);

size_t workspace_size = temp_size + PadBytes(sizeof(DType) * src.Size(), alignment)
+ PadBytes(sizeof(index_t) * src.Size(), alignment);
if (param.ret_typ == topk_enum::kReturnMask) {
Expand Down

0 comments on commit b431a59

Please sign in to comment.