From 42746bc73e8bcb75bfcadd1398e6f71bc170fa10 Mon Sep 17 00:00:00 2001 From: Lin Yuan Date: Thu, 22 Aug 2019 22:02:59 -0700 Subject: [PATCH] Fix a memory misalignment in topk operator (#15948) * fix alignment * use correct type for shape index * clean up unnecessary space in topk * fix lint * add additional temp space * address reviewer comment * fix incorrect nidex type --- src/operator/tensor/ordering_op-inl.h | 30 +++++++++++++-------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/src/operator/tensor/ordering_op-inl.h b/src/operator/tensor/ordering_op-inl.h index 1dda90104205..bd27441c1c73 100644 --- a/src/operator/tensor/ordering_op-inl.h +++ b/src/operator/tensor/ordering_op-inl.h @@ -385,8 +385,8 @@ void TopKImpl(const RunContext &ctx, int axis = 0; bool do_transpose = false; bool is_ascend = false; - int k = 0; - size_t alignment = std::max(sizeof(DType), sizeof(int)); + index_t k = 0; + size_t alignment = std::max(sizeof(DType), sizeof(index_t)); mxnet::TShape target_shape; ParseTopKParam(src.shape_, param, &target_shape, &batch_size, &element_num, &axis, &k, &do_transpose, &is_ascend); @@ -395,31 +395,31 @@ void TopKImpl(const RunContext &ctx, << "The total element_num is " << element_num << ", but the selected IDType can only represent " << mxnet::common::MaxIntegerValue() << " elements"; Tensor dat = src.FlatTo3D(axis, axis, s); - size_t temp_size = 0; - // Temp space needed by the gpu-based full sorts. - temp_size = std::max(temp_size, - mxnet::op::SortByKeyWorkspaceSize(src.Size())); - temp_size = std::max(temp_size, - mxnet::op::SortByKeyWorkspaceSize(src.Size())); + // Temp space needed by the full sorts. + size_t temp_size = std::max( + mxnet::op::SortByKeyWorkspaceSize(src.Size()), + mxnet::op::SortByKeyWorkspaceSize(src.Size())); + temp_size = std::max(temp_size, - mxnet::op::SortByKeyWorkspaceSize(src.Size())); + mxnet::op::SortByKeyWorkspaceSize(src.Size())); // Additional temp space for gpu full sorts for batch ids. temp_size += PadBytes(sizeof(int) * src.Size(), alignment); // Temp space for cpu sorts. - temp_size = std::max(temp_size, static_cast(sizeof(DType) * src.Size())); + temp_size = std::max(temp_size, sizeof(DType) * src.Size()); + size_t workspace_size = temp_size + PadBytes(sizeof(DType) * src.Size(), alignment) + PadBytes(sizeof(int) * src.Size(), alignment); if (param.ret_typ == topk_enum::kReturnMask) { - workspace_size += PadBytes(sizeof(int) * batch_size * k, alignment); + workspace_size += PadBytes(sizeof(index_t) * batch_size * k, alignment); } workspace = resource.get_space_typed(Shape1(workspace_size), s); char* workspace_curr_ptr = workspace.dptr_; sorted_dat = Tensor(reinterpret_cast(workspace_curr_ptr), - Shape1(src.Size()), s); // contain sorted dat + Shape1(src.Size()), s); // contain sorted dat workspace_curr_ptr += PadBytes(sizeof(DType) * src.Size(), alignment); - indices = Tensor(reinterpret_cast(workspace_curr_ptr), - Shape1(src.Size()), s); // indices in the original matrix - workspace_curr_ptr += PadBytes(sizeof(int) * src.Size(), alignment); + indices = Tensor(reinterpret_cast(workspace_curr_ptr), + Shape1(src.Size()), s); // indices in the original matrix + workspace_curr_ptr += PadBytes(sizeof(index_t) * src.Size(), alignment); if (param.ret_typ == topk_enum::kReturnMask) { sel_indices = Tensor(reinterpret_cast(workspace_curr_ptr),