From 73a692e252fa5625bf002366653f31f46b1751e8 Mon Sep 17 00:00:00 2001 From: Lin Yuan Date: Thu, 22 Aug 2019 22:02:59 -0700 Subject: [PATCH] Fix a memory misalignment in topk operator (#15948) * fix alignment * use correct type for shape index * clean up unnecessary space in topk * fix lint * add additional temp space * address reviewer comment * fix incorrect nidex type --- 3rdparty/mshadow/mshadow/tensor.h | 6 +++--- src/operator/tensor/ordering_op-inl.h | 24 ++++++++++++------------ 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/3rdparty/mshadow/mshadow/tensor.h b/3rdparty/mshadow/mshadow/tensor.h index 0d662621aa4d..ad29e751a050 100755 --- a/3rdparty/mshadow/mshadow/tensor.h +++ b/3rdparty/mshadow/mshadow/tensor.h @@ -69,7 +69,7 @@ struct Shape { * \param idx dimension index * \return the corresponding dimension size */ - MSHADOW_XINLINE index_t &operator[](index_t idx) { + MSHADOW_XINLINE index_t &operator[](int idx) { return shape_[idx]; } /*! @@ -77,7 +77,7 @@ struct Shape { * \param idx dimension index * \return the corresponding dimension size */ - MSHADOW_XINLINE const index_t &operator[](index_t idx) const { + MSHADOW_XINLINE const index_t &operator[](int idx) const { return shape_[idx]; } /*! @@ -484,7 +484,7 @@ struct Tensor: public TRValue, * \param idx the dimension count from the highest dimensin * \return the size */ - MSHADOW_XINLINE index_t size(index_t idx) const { + MSHADOW_XINLINE index_t size(int idx) const { return shape_[idx]; } /*! diff --git a/src/operator/tensor/ordering_op-inl.h b/src/operator/tensor/ordering_op-inl.h index 880acf1f4cae..b36d79acfc7b 100644 --- a/src/operator/tensor/ordering_op-inl.h +++ b/src/operator/tensor/ordering_op-inl.h @@ -404,7 +404,7 @@ void TopKImpl(const RunContext &ctx, bool do_transpose = false; bool is_ascend = false; index_t k = 0; - size_t alignment = std::max(sizeof(DType), sizeof(int)); + size_t alignment = std::max(sizeof(DType), sizeof(index_t)); mxnet::TShape target_shape; ParseTopKParam(src.shape_, param, &target_shape, &batch_size, &element_num, &axis, &k, &do_transpose, &is_ascend); @@ -414,30 +414,30 @@ void TopKImpl(const RunContext &ctx, << element_num << ", but the selected index_t can only represent " << mxnet::common::MaxIntegerValue() << " elements"; Tensor dat = src.FlatTo3D(axis, axis, s); - size_t temp_size = 0; - // Temp space needed by the gpu-based full sorts. - temp_size = std::max(temp_size, - mxnet::op::SortByKeyWorkspaceSize(src.Size())); - temp_size = std::max(temp_size, - mxnet::op::SortByKeyWorkspaceSize(src.Size())); + // Temp space needed by the full sorts. + size_t temp_size = std::max( + mxnet::op::SortByKeyWorkspaceSize(src.Size()), + mxnet::op::SortByKeyWorkspaceSize(src.Size())); + temp_size = std::max(temp_size, - mxnet::op::SortByKeyWorkspaceSize(src.Size())); + mxnet::op::SortByKeyWorkspaceSize(src.Size())); // Additional temp space for gpu full sorts for batch ids. temp_size += PadBytes(sizeof(index_t) * src.Size(), alignment); // Temp space for cpu sorts. - temp_size = std::max(temp_size, static_cast(sizeof(DType) * src.Size())); + temp_size = std::max(temp_size, sizeof(DType) * src.Size()); + size_t workspace_size = temp_size + PadBytes(sizeof(DType) * src.Size(), alignment) + PadBytes(sizeof(index_t) * src.Size(), alignment); if (param.ret_typ == topk_enum::kReturnMask) { - workspace_size += PadBytes(sizeof(int) * batch_size * k, alignment); + workspace_size += PadBytes(sizeof(index_t) * batch_size * k, alignment); } workspace = resource.get_space_typed(Shape1(workspace_size), s); char* workspace_curr_ptr = workspace.dptr_; sorted_dat = Tensor(reinterpret_cast(workspace_curr_ptr), - Shape1(src.Size()), s); // contain sorted dat + Shape1(src.Size()), s); // contain sorted dat workspace_curr_ptr += PadBytes(sizeof(DType) * src.Size(), alignment); indices = Tensor(reinterpret_cast(workspace_curr_ptr), - Shape1(src.Size()), s); // indices in the original matrix + Shape1(src.Size()), s); // indices in the original matrix workspace_curr_ptr += PadBytes(sizeof(index_t) * src.Size(), alignment); if (param.ret_typ == topk_enum::kReturnMask) {