diff --git a/src/operator/tensor/ordering_op-inl.h b/src/operator/tensor/ordering_op-inl.h index fd124c0e2a63..3889b5953569 100644 --- a/src/operator/tensor/ordering_op-inl.h +++ b/src/operator/tensor/ordering_op-inl.h @@ -414,18 +414,12 @@ void TopKImpl(const RunContext &ctx, << element_num << ", but the selected index_t can only represent " << mxnet::common::MaxIntegerValue() << " elements"; Tensor dat = src.FlatTo3D(axis, axis, s); - size_t temp_size = 0; - // Temp space needed by the gpu-based full sorts. - temp_size = std::max(temp_size, - mxnet::op::SortByKeyWorkspaceSize(src.Size())); - temp_size = std::max(temp_size, - mxnet::op::SortByKeyWorkspaceSize(src.Size())); - temp_size = std::max(temp_size, - mxnet::op::SortByKeyWorkspaceSize(src.Size())); - // Additional temp space for gpu full sorts for batch ids. - temp_size += PadBytes(sizeof(index_t) * src.Size(), alignment); - // Temp space for cpu sorts. - temp_size = std::max(temp_size, static_cast(sizeof(DType) * src.Size())); + // Temp space needed by the full sorts. + size_t temp_size = std::max( + mxnet::op::SortByKeyWorkspaceSize(src.Size()), + mxnet::op::SortByKeyWorkspaceSize(src.Size()) + ); + size_t workspace_size = temp_size + PadBytes(sizeof(DType) * src.Size(), alignment) + PadBytes(sizeof(index_t) * src.Size(), alignment); if (param.ret_typ == topk_enum::kReturnMask) {