
Fix a memory misalignment in topk operator (#15948)
* fix alignment

* use correct type for shape index

* clean up unnecessary space in topk

* fix lint

* add additional temp space

* address reviewer comment

* fix incorrect index type
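
As a concrete illustration of the "fix alignment" change, here is a minimal standalone sketch (not code from the patch; the types, sizes, and the pad_bytes helper are illustrative stand-ins for mshadow's DType, index_t, and PadBytes) showing why padding workspace offsets to sizeof(int) misaligns the index buffer once index_t is 64-bit:

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>

// Illustrative stand-ins: 4-byte data, 8-byte indices (large-tensor build).
using DType   = float;
using index_t = int64_t;

// Round num_bytes up to a multiple of alignment -- the role PadBytes() plays
// in ordering_op-inl.h (behaviour assumed here, not copied from the source).
inline size_t pad_bytes(size_t num_bytes, size_t alignment) {
  return ((num_bytes + alignment - 1) / alignment) * alignment;
}

int main() {
  const size_t n = 3;  // elements in the first (DType) buffer

  // Old code: alignment only covers DType and int.
  const size_t old_alignment = std::max(sizeof(DType), sizeof(int));        // 4
  const size_t old_offset    = pad_bytes(sizeof(DType) * n, old_alignment); // 12

  // New code: alignment also covers index_t.
  const size_t new_alignment = std::max(sizeof(DType), sizeof(index_t));    // 8
  const size_t new_offset    = pad_bytes(sizeof(DType) * n, new_alignment); // 16

  // On a typical 64-bit platform (alignof(int64_t) == 8) an index buffer
  // carved at old_offset is misaligned; at new_offset it is not.
  assert(old_offset % alignof(index_t) != 0);
  assert(new_offset % alignof(index_t) == 0);
  return 0;
}

Under the old padding the second buffer would start 12 bytes into the workspace, so every 8-byte index load from it is a misaligned access; rounding to 8 bytes pushes it to offset 16.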
apeforest authored and sxjscience committed Aug 23, 2019
1 parent fade159 commit 73a692e
Showing 2 changed files with 15 additions and 15 deletions.
6 changes: 3 additions & 3 deletions 3rdparty/mshadow/mshadow/tensor.h
@@ -69,15 +69,15 @@ struct Shape {
    * \param idx dimension index
    * \return the corresponding dimension size
    */
-  MSHADOW_XINLINE index_t &operator[](index_t idx) {
+  MSHADOW_XINLINE index_t &operator[](int idx) {
     return shape_[idx];
   }
   /*!
    * \brief get corresponding index
    * \param idx dimension index
    * \return the corresponding dimension size
    */
-  MSHADOW_XINLINE const index_t &operator[](index_t idx) const {
+  MSHADOW_XINLINE const index_t &operator[](int idx) const {
     return shape_[idx];
   }
   /*!
@@ -484,7 +484,7 @@ struct Tensor: public TRValue<Tensor<Device, dimension, DType>,
    * \param idx the dimension count from the highest dimensin
    * \return the size
    */
-  MSHADOW_XINLINE index_t size(index_t idx) const {
+  MSHADOW_XINLINE index_t size(int idx) const {
     return shape_[idx];
   }
   /*!
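
The tensor.h hunks above only change the type of the subscript parameter. A stripped-down sketch of the idea (the real mshadow::Shape has many more members, and index_t's width here is an assumption for illustration): the argument indexes a dimension position, which is bounded by the dimension count, so a plain int suffices, while the stored extents remain index_t.

#include <cstdint>

using index_t = int64_t;  // assumed 64-bit, as in a large-tensor build

template <int dimension>
struct ShapeSketch {
  index_t shape_[dimension];
  // Dimension positions are small, so the index parameter is int;
  // the extents themselves stay index_t.
  index_t &operator[](int idx) { return shape_[idx]; }
  const index_t &operator[](int idx) const { return shape_[idx]; }
};

int main() {
  ShapeSketch<3> s{{2, 3, 4}};
  return (s[0] * s[1] * s[2] == 24) ? 0 : 1;
}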
24 changes: 12 additions & 12 deletions src/operator/tensor/ordering_op-inl.h
@@ -404,7 +404,7 @@ void TopKImpl(const RunContext &ctx,
   bool do_transpose = false;
   bool is_ascend = false;
   index_t k = 0;
-  size_t alignment = std::max(sizeof(DType), sizeof(int));
+  size_t alignment = std::max(sizeof(DType), sizeof(index_t));
   mxnet::TShape target_shape;
   ParseTopKParam(src.shape_, param,
                  &target_shape, &batch_size, &element_num, &axis, &k, &do_transpose, &is_ascend);
@@ -414,30 +414,30 @@ void TopKImpl(const RunContext &ctx,
     << element_num << ", but the selected index_t can only represent "
     << mxnet::common::MaxIntegerValue<index_t>() << " elements";
   Tensor<xpu, 3, DType> dat = src.FlatTo3D<xpu, DType>(axis, axis, s);
-  size_t temp_size = 0;
-  // Temp space needed by the gpu-based full sorts.
-  temp_size = std::max(temp_size,
-    mxnet::op::SortByKeyWorkspaceSize<int, int, xpu>(src.Size()));
-  temp_size = std::max(temp_size,
-    mxnet::op::SortByKeyWorkspaceSize<int, DType, xpu>(src.Size()));
+  // Temp space needed by the full sorts.
+  size_t temp_size = std::max(
+    mxnet::op::SortByKeyWorkspaceSize<index_t, DType, xpu>(src.Size()),
+    mxnet::op::SortByKeyWorkspaceSize<DType, index_t, xpu>(src.Size()));

   temp_size = std::max(temp_size,
-    mxnet::op::SortByKeyWorkspaceSize<DType, int, xpu>(src.Size()));
+    mxnet::op::SortByKeyWorkspaceSize<index_t, index_t, xpu>(src.Size()));
+  // Additional temp space for gpu full sorts for batch ids.
+  temp_size += PadBytes(sizeof(index_t) * src.Size(), alignment);
   // Temp space for cpu sorts.
-  temp_size = std::max(temp_size, static_cast<size_t>(sizeof(DType) * src.Size()));
+  temp_size = std::max(temp_size, sizeof(DType) * src.Size());

   size_t workspace_size = temp_size + PadBytes(sizeof(DType) * src.Size(), alignment)
                                     + PadBytes(sizeof(index_t) * src.Size(), alignment);
   if (param.ret_typ == topk_enum::kReturnMask) {
-    workspace_size += PadBytes(sizeof(int) * batch_size * k, alignment);
+    workspace_size += PadBytes(sizeof(index_t) * batch_size * k, alignment);
   }
   workspace = resource.get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
   char* workspace_curr_ptr = workspace.dptr_;
   sorted_dat = Tensor<xpu, 1, DType>(reinterpret_cast<DType*>(workspace_curr_ptr),
-                                     Shape1(src.Size()), s);  // contain sorted dat
+                                     Shape1(src.Size()), s);  // contain sorted dat
   workspace_curr_ptr += PadBytes(sizeof(DType) * src.Size(), alignment);
   indices = Tensor<xpu, 1, index_t>(reinterpret_cast<index_t*>(workspace_curr_ptr),
-                                    Shape1(src.Size()), s);  // indices in the original matrix
+                                    Shape1(src.Size()), s);  // indices in the original matrix
   workspace_curr_ptr += PadBytes(sizeof(index_t) * src.Size(), alignment);

   if (param.ret_typ == topk_enum::kReturnMask) {
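
Putting the hunk above together, a compact sketch of the carving pattern TopKImpl relies on (simplified and standalone; the buffer names mirror the patch, while the mshadow tensors and the backend-specific sort workspace terms are omitted): every view's byte offset is advanced by a padded size so the next reinterpret_cast lands on an adequately aligned address.

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

using DType   = float;
using index_t = int64_t;

// Same rounding idea as PadBytes() in the patch (assumed behaviour).
inline size_t pad_bytes(size_t num_bytes, size_t alignment) {
  return ((num_bytes + alignment - 1) / alignment) * alignment;
}

int main() {
  const size_t n = 100;            // stand-in for src.Size()
  const size_t batch_size = 4, k = 5;
  const bool return_mask = true;   // stand-in for param.ret_typ == kReturnMask

  const size_t alignment = std::max(sizeof(DType), sizeof(index_t));

  // Mirrors the workspace_size arithmetic, minus the sort temp space.
  size_t workspace_size = pad_bytes(sizeof(DType) * n, alignment)
                        + pad_bytes(sizeof(index_t) * n, alignment);
  if (return_mask) {
    workspace_size += pad_bytes(sizeof(index_t) * batch_size * k, alignment);
  }

  std::vector<char> workspace(workspace_size);
  char* ptr = workspace.data();

  DType* sorted_dat = reinterpret_cast<DType*>(ptr);   // sorted values
  ptr += pad_bytes(sizeof(DType) * n, alignment);
  index_t* indices = reinterpret_cast<index_t*>(ptr);  // original positions
  ptr += pad_bytes(sizeof(index_t) * n, alignment);

  (void)sorted_dat;
  (void)indices;
  return 0;
}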
