Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[MXNET-1413] Adding Large Tensor support for sort operators
Browse files Browse the repository at this point in the history
  • Loading branch information
Rohit Kumar Srivastava committed Jun 7, 2019
1 parent 3f4f3d5 commit 7c4cb0f
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 76 deletions.
6 changes: 6 additions & 0 deletions python/mxnet/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,12 @@ def rand_coord_2d(x_low, x_high, y_low, y_high):
return x, y


def create_2d_tensor(rows, columns):
a = np.arange(0, rows).reshape(rows, 1)
b = np.broadcast_to(a, shape=(a.shape[0], columns))
return mx.nd.array(b, dtype=np.int64)


def np_reduce(dat, axis, keepdims, numpy_reduce_func):
"""Compatible reduce for old version of NumPy.
Expand Down
2 changes: 1 addition & 1 deletion src/operator/mxnet_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -630,7 +630,7 @@ struct Kernel<OP, cpu> {
}
}
#else
for (size_t i = 0; i < N; ++i) {
for (index_t i = 0; i < static_cast<index_t>(N); ++i) {
OP::Map(i, args...);
}
#endif
Expand Down
144 changes: 76 additions & 68 deletions src/operator/tensor/ordering_op-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ struct TopKParam : public dmlc::Parameter<TopKParam> {
DMLC_DECLARE_FIELD(dtype)
.add_enum("uint8", mshadow::kUint8)
.add_enum("int32", mshadow::kInt32)
.add_enum("int64", mshadow::kInt64)
.add_enum("float16", mshadow::kFloat16)
.add_enum("float32", mshadow::kFloat32)
.add_enum("float64", mshadow::kFloat64)
Expand Down Expand Up @@ -118,6 +119,7 @@ struct ArgSortParam : public dmlc::Parameter<ArgSortParam> {
DMLC_DECLARE_FIELD(dtype)
.add_enum("uint8", mshadow::kUint8)
.add_enum("int32", mshadow::kInt32)
.add_enum("int64", mshadow::kInt64)
.add_enum("float16", mshadow::kFloat16)
.add_enum("float32", mshadow::kFloat32)
.add_enum("float64", mshadow::kFloat64)
Expand All @@ -129,8 +131,8 @@ struct ArgSortParam : public dmlc::Parameter<ArgSortParam> {
};

inline void ParseTopKParam(const mxnet::TShape& src_shape, const TopKParam& param,
mxnet::TShape *target_shape, int *batch_size, int *element_num,
int *axis, int *k, bool *do_transpose, bool *is_ascend) {
mxnet::TShape *target_shape, size_t *batch_size, index_t *element_num,
int *axis, index_t *k, bool *do_transpose, bool *is_ascend) {
*do_transpose = false;
*k = param.k;
*is_ascend = param.is_ascend;
Expand Down Expand Up @@ -179,54 +181,54 @@ using namespace mshadow;

struct fill_ind_to_one {
template<typename DType>
MSHADOW_XINLINE static void Map(int i, const int* indices, DType* out) {
MSHADOW_XINLINE static void Map(index_t i, const index_t* indices, DType* out) {
out[indices[i]] = static_cast<DType>(1);
}
};

struct fill_ind {
template<typename DType>
MSHADOW_XINLINE static void Map(int i, const int* indices, const DType* val,
MSHADOW_XINLINE static void Map(index_t i, const index_t* indices, const DType* val,
int req, DType* out) {
KERNEL_ASSIGN(out[indices[i]], req, val[i]);
}
};

template<typename DType>
MSHADOW_FORCE_INLINE void TopKSort(const Tensor<cpu, 1, DType>& dat,
const Tensor<cpu, 1, int>& ind,
const Tensor<cpu, 1, index_t>& ind,
const Tensor<cpu, 1, char>& work,
int K, int N, bool is_ascend,
index_t K, index_t N, bool is_ascend,
Stream<cpu> *s) {
// Use full sort when K is relatively large.
const bool full_sort(K*8 > N);
// Batch size.
const int M(work.size(0)/(sizeof(DType)*N));
const int omp_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount());
#pragma omp parallel for num_threads(omp_threads)
for (int i = 0; i < M; ++i) {
for (index_t i = 0; i < M; ++i) {
// Tensor `work` stores the flattened source data, while `dat` stores the sorted result.
DType *vals = reinterpret_cast<DType*>(work.dptr_);
DType *sorted_vals = dat.dptr_+i*N;
int *indices = ind.dptr_+i*N;
index_t *indices = ind.dptr_+i*N;
if (is_ascend) {
if (full_sort) {
std::sort(indices, indices+N,
[&](const int& i1, const int& i2){ return vals[i1] < vals[i2]; });
} else {
std::partial_sort(indices, indices+K, indices+N,
[&](const int& i1, const int& i2){ return vals[i1] < vals[i2]; });
[&](const index_t& i1, const index_t& i2){ return vals[i1] < vals[i2]; });
}
} else {
if (full_sort) {
std::sort(indices, indices+N,
[&](const int& i1, const int& i2){ return vals[i1] > vals[i2]; });
[&](const index_t& i1, const index_t& i2){ return vals[i1] > vals[i2]; });
} else {
std::partial_sort(indices, indices+K, indices+N,
[&](const int& i1, const int& i2){ return vals[i1] > vals[i2]; });
[&](const index_t& i1, const index_t& i2){ return vals[i1] > vals[i2]; });
}
}
for (int j = 0; j < K; ++j) {
for (index_t j = 0; j < K; ++j) {
sorted_vals[j] = vals[indices[j]];
}
}
Expand Down Expand Up @@ -364,9 +366,8 @@ MSHADOW_FORCE_INLINE void TopKSort(const Tensor<gpu, 1, DType>& dat,
* \param param the topk parameters
* \tparam xpu the device type.
* \tparam DType type of the output value/mask.
* \tparam IDType type of the output indices.
*/
template<typename xpu, typename DType, typename IDType>
template<typename xpu, typename DType>
void TopKImpl(const RunContext &ctx,
const Resource &resource,
const std::vector<OpReqType>& req,
Expand All @@ -380,20 +381,22 @@ void TopKImpl(const RunContext &ctx,
Tensor<xpu, 1, char> workspace;
Tensor<xpu, 1, char> temp_workspace;
Tensor<xpu, 1, DType> sorted_dat;
Tensor<xpu, 1, int> indices, sel_indices;
int batch_size, element_num; // number of batches + the size of each batch
Tensor<xpu, 1, index_t> indices, sel_indices;
size_t batch_size;
index_t element_num; // number of batches + the size of each batch
int axis = 0;
bool do_transpose = false;
bool is_ascend = false;
int k = 0;
index_t k = 0;
size_t alignment = std::max(sizeof(DType), sizeof(int));
mxnet::TShape target_shape;
ParseTopKParam(src.shape_, param,
&target_shape, &batch_size, &element_num, &axis, &k, &do_transpose, &is_ascend);
CHECK_LE(element_num, mxnet::common::MaxIntegerValue<IDType>())
<< "'IDType' does not have a sufficient precision to represent the indices of the input array. "
<< "The total element_num is " << element_num << ", but the selected IDType can only represent "
<< mxnet::common::MaxIntegerValue<IDType>() << " elements";
CHECK_LE(element_num, mxnet::common::MaxIntegerValue<index_t>())
<< "'index_t' does not have a sufficient precision to represent the indices of "
<< "the input array. The total element_num is "
<< element_num << ", but the selected index_t can only represent "
<< mxnet::common::MaxIntegerValue<index_t>() << " elements";
Tensor<xpu, 3, DType> dat = src.FlatTo3D<xpu, DType>(axis, axis, s);
size_t temp_size = 0;
// Temp space needed by the gpu-based full sorts.
Expand All @@ -417,12 +420,12 @@ void TopKImpl(const RunContext &ctx,
sorted_dat = Tensor<xpu, 1, DType>(reinterpret_cast<DType*>(workspace_curr_ptr),
Shape1(src.Size()), s); // contain sorted dat
workspace_curr_ptr += PadBytes(sizeof(DType) * src.Size(), alignment);
indices = Tensor<xpu, 1, int>(reinterpret_cast<int*>(workspace_curr_ptr),
indices = Tensor<xpu, 1, index_t>(reinterpret_cast<index_t*>(workspace_curr_ptr),
Shape1(src.Size()), s); // indices in the original matrix
workspace_curr_ptr += PadBytes(sizeof(int) * src.Size(), alignment);

if (param.ret_typ == topk_enum::kReturnMask) {
sel_indices = Tensor<xpu, 1, int>(reinterpret_cast<int*>(workspace_curr_ptr),
sel_indices = Tensor<xpu, 1, index_t>(reinterpret_cast<index_t*>(workspace_curr_ptr),
Shape1(batch_size * k), s);
workspace_curr_ptr += PadBytes(sizeof(int) * batch_size * k, alignment);
CHECK_EQ(sel_indices.CheckContiguous(), true);
Expand Down Expand Up @@ -454,8 +457,8 @@ void TopKImpl(const RunContext &ctx,
workspace_curr_ptr += temp_size;
}

mxnet_op::Kernel<range_fwd, xpu>::Launch(s, batch_size * element_num, 1, 0, 1,
kWriteTo, indices.dptr_);
mxnet_op::Kernel<range_fwd, xpu>::Launch(s, batch_size * element_num, 1, static_cast<index_t>(0),
static_cast<index_t>(1), kWriteTo, indices.dptr_);
CHECK_EQ(indices.CheckContiguous(), true);

// 2. Perform inplace batch sort.
Expand Down Expand Up @@ -494,30 +497,30 @@ void TopKImpl(const RunContext &ctx,
}
} else if (param.ret_typ == topk_enum::kReturnIndices) {
if (do_transpose) {
Tensor<xpu, 3, IDType> ret_indices = ret[0].FlatTo3D<xpu, IDType>(axis, axis, s);
ASSIGN_DISPATCH(ret_indices, req[0], tcast<IDType>(F<mshadow_op::mod>(transpose(
Tensor<xpu, 3, index_t> ret_indices = ret[0].FlatTo3D<xpu, index_t>(axis, axis, s);
ASSIGN_DISPATCH(ret_indices, req[0], tcast<index_t>(F<mshadow_op::mod>(transpose(
slice<2>(inplace_reshape(indices,
Shape3(ret_indices.shape_[0],
ret_indices.shape_[2],
element_num)),
0, k),
Shape3(0, 2, 1)), element_num)));
} else {
Tensor<xpu, 2, IDType> ret_indices =
ret[0].get_with_shape<xpu, 2, IDType>(Shape2(batch_size, k), s);
ASSIGN_DISPATCH(ret_indices, req[0], tcast<IDType>(F<mshadow_op::mod>(slice<1>(
Tensor<xpu, 2, index_t> ret_indices =
ret[0].get_with_shape<xpu, 2, index_t>(Shape2(batch_size, k), s);
ASSIGN_DISPATCH(ret_indices, req[0], tcast<index_t>(F<mshadow_op::mod>(slice<1>(
inplace_reshape(indices, Shape2(batch_size, element_num)), 0, k),
element_num)));
}
} else {
if (do_transpose) {
Tensor<xpu, 3, DType> ret_value = ret[0].FlatTo3D<xpu, DType>(axis, axis, s);
Tensor<xpu, 3, IDType> ret_indices = ret[1].FlatTo3D<xpu, IDType>(axis, axis, s);
Tensor<xpu, 3, index_t> ret_indices = ret[1].FlatTo3D<xpu, index_t>(axis, axis, s);
ASSIGN_DISPATCH(ret_value, req[0], transpose(
slice<2>(inplace_reshape(sorted_dat,
Shape3(ret_value.shape_[0], ret_value.shape_[2], element_num)),
0, k), Shape3(0, 2, 1)));
ASSIGN_DISPATCH(ret_indices, req[1], tcast<IDType>(F<mshadow_op::mod>(transpose(
ASSIGN_DISPATCH(ret_indices, req[1], tcast<index_t>(F<mshadow_op::mod>(transpose(
slice<2>(inplace_reshape(indices,
Shape3(ret_indices.shape_[0],
ret_indices.shape_[2],
Expand All @@ -526,11 +529,11 @@ void TopKImpl(const RunContext &ctx,
} else {
Tensor<xpu, 2, DType> ret_value =
ret[0].get_with_shape<xpu, 2, DType>(Shape2(batch_size, k), s);
Tensor<xpu, 2, IDType> ret_indices =
ret[1].get_with_shape<xpu, 2, IDType>(Shape2(batch_size, k), s);
Tensor<xpu, 2, index_t> ret_indices =
ret[1].get_with_shape<xpu, 2, index_t>(Shape2(batch_size, k), s);
ASSIGN_DISPATCH(ret_value, req[0],
slice<1>(inplace_reshape(sorted_dat, Shape2(batch_size, element_num)), 0, k));
ASSIGN_DISPATCH(ret_indices, req[1], tcast<IDType>(F<mshadow_op::mod>(slice<1>(
ASSIGN_DISPATCH(ret_indices, req[1], tcast<index_t>(F<mshadow_op::mod>(slice<1>(
inplace_reshape(indices, Shape2(batch_size, element_num)), 0, k), element_num)));
}
}
Expand All @@ -545,13 +548,13 @@ void TopK(const nnvm::NodeAttrs& attrs,
const TopKParam& param = nnvm::get<TopKParam>(attrs.parsed);
if (param.ret_typ == topk_enum::kReturnIndices || param.ret_typ == topk_enum::kReturnBoth) {
MXNET_NO_FLOAT16_TYPE_SWITCH(inputs[0].type_flag_, DType, {
MSHADOW_TYPE_SWITCH(param.dtype, IDType, {
TopKImpl<xpu, DType, IDType>(ctx.run_ctx, ctx.requested[0], req, inputs[0], outputs, param);
MSHADOW_TYPE_SWITCH(param.dtype, index_t, {
TopKImpl<xpu, DType>(ctx.run_ctx, ctx.requested[0], req, inputs[0], outputs, param);
})
});
} else {
MXNET_NO_FLOAT16_TYPE_SWITCH(inputs[0].type_flag_, DType, {
TopKImpl<xpu, DType, int>(ctx.run_ctx, ctx.requested[0], req, inputs[0], outputs, param);
TopKImpl<xpu, DType>(ctx.run_ctx, ctx.requested[0], req, inputs[0], outputs, param);
});
}
}
Expand All @@ -569,7 +572,7 @@ void Sort(const nnvm::NodeAttrs& attrs,
topk_param.k = 0;
topk_param.ret_typ = topk_enum::kReturnValue;
MXNET_NO_FLOAT16_TYPE_SWITCH(inputs[0].type_flag_, DType, {
TopKImpl<xpu, DType, int>(ctx.run_ctx, ctx.requested[0], req, inputs[0], outputs, topk_param);
TopKImpl<xpu, DType>(ctx.run_ctx, ctx.requested[0], req, inputs[0], outputs, topk_param);
});
}

Expand All @@ -587,14 +590,14 @@ void ArgSort(const nnvm::NodeAttrs& attrs,
topk_param.dtype = param.dtype;
topk_param.ret_typ = topk_enum::kReturnIndices;
MXNET_NO_FLOAT16_TYPE_SWITCH(inputs[0].type_flag_, DType, {
MSHADOW_TYPE_SWITCH(param.dtype, IDType, {
TopKImpl<xpu, DType, IDType>(ctx.run_ctx,
MSHADOW_TYPE_SWITCH(param.dtype, index_t, {
TopKImpl<xpu, DType>(ctx.run_ctx,
ctx.requested[0], req, inputs[0], outputs, topk_param);
});
});
}

template<typename xpu, typename DType, typename IDType>
template<typename xpu, typename DType>
void TopKBackwardImpl(const OpContext &ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
Expand All @@ -605,47 +608,51 @@ void TopKBackwardImpl(const OpContext &ctx,
using namespace mshadow::expr;
Stream<xpu> *s = ctx.run_ctx.get_stream<xpu>();
CHECK(param.ret_typ == topk_enum::kReturnValue || param.ret_typ == topk_enum::kReturnBoth);
int batch_size, element_num; // number of batches + the size of each batch
size_t batch_size;
index_t element_num; // number of batches + the size of each batch
int axis = 0;
bool do_transpose = false;
bool is_ascend = false;
int k = 0;
index_t k = 0;
mxnet::TShape target_shape;
ParseTopKParam(outputs[0].shape_, param,
&target_shape, &batch_size, &element_num, &axis, &k, &do_transpose, &is_ascend);
CHECK_LE(element_num, mxnet::common::MaxIntegerValue<IDType>())
<< "'IDType' does not have a sufficient precision to represent the indices of the input array. "
<< "The total element_num is " << element_num << ", but the selected IDType can only represent "
<< mxnet::common::MaxIntegerValue<IDType>() << " elements";
Tensor<xpu, 1, int> workspace =
ctx.requested[0].get_space_typed<xpu, 1, int>(Shape1(batch_size * k + batch_size), s);
Tensor<xpu, 1, int> sel_indices =
Tensor<xpu, 1, int>(workspace.dptr_, Shape1(batch_size * k), s);
Tensor<xpu, 1, int> batch_shift =
Tensor<xpu, 1, int>(workspace.dptr_ + batch_size * k, Shape1(batch_size), s);
CHECK_LE(element_num, mxnet::common::MaxIntegerValue<index_t>())
<< "'index_t' does not have a sufficient precision to represent "
<< "the indices of the input array. The total element_num is "
<< element_num << ", but the selected index_t can only represent "
<< mxnet::common::MaxIntegerValue<index_t>() << " elements";
Tensor<xpu, 1, index_t> workspace =
ctx.requested[0].get_space_typed<xpu, 1, index_t>(Shape1(batch_size * k + batch_size), s);
Tensor<xpu, 1, index_t> sel_indices =
Tensor<xpu, 1, index_t>(workspace.dptr_, Shape1(batch_size * k), s);
Tensor<xpu, 1, index_t> batch_shift =
Tensor<xpu, 1, index_t>(workspace.dptr_ + batch_size * k, Shape1(batch_size), s);

Tensor<xpu, 2, DType> out_grad =
inputs[0].get_with_shape<xpu, 2, DType>(Shape2(inputs[0].shape_.Size(), 1), s);
Tensor<xpu, 2, DType> in_grad =
outputs[0].get_with_shape<xpu, 2, DType>(Shape2(outputs[0].shape_.Size(), 1), s);
mxnet_op::Kernel<range_fwd, xpu>::Launch(s, batch_size, 1, 0, element_num, kWriteTo,
mxnet_op::Kernel<range_fwd, xpu>::Launch(s, batch_size, 1,
static_cast<index_t>(0),
element_num, kWriteTo,
batch_shift.dptr_);
if (do_transpose) {
Tensor<xpu, 1, IDType> indices = inputs[2].FlatTo1D<xpu, IDType>(s);
Tensor<xpu, 1, index_t> indices = inputs[2].FlatTo1D<xpu, index_t>(s);
mxnet::TShape src_shape = outputs[0].shape_.FlatTo3D(axis);
sel_indices = reshape(transpose(
broadcast_to(inplace_reshape(batch_shift,
Shape3(src_shape[0], src_shape[2], 1)),
mxnet::TShape(Shape3(src_shape[0], src_shape[2], k))),
Shape3(0, 2, 1)),
Shape1(batch_size * k));
sel_indices += tcast<int>(indices);
sel_indices += indices;
sel_indices = transpose_indices(sel_indices, Shape3(src_shape[0], src_shape[2], src_shape[1]),
Shape3(0, 2, 1));
} else {
Tensor<xpu, 2, IDType> indices =
inputs[2].get_with_shape<xpu, 2, IDType>(Shape2(batch_size, k), s);
sel_indices = reshape(tcast<int>(indices) +
Tensor<xpu, 2, index_t> indices =
inputs[2].get_with_shape<xpu, 2, index_t>(Shape2(batch_size, k), s);
sel_indices = reshape(indices +
broadcast_to(inplace_reshape(batch_shift, Shape2(batch_size, 1)),
mxnet::TShape(Shape2(batch_size, k))),
Shape1(batch_size * k));
Expand Down Expand Up @@ -674,13 +681,13 @@ void TopKBackward_(const nnvm::NodeAttrs& attrs,
const TopKParam& param = nnvm::get<TopKParam>(attrs.parsed);
if (param.ret_typ == topk_enum::kReturnBoth) {
MXNET_NO_FLOAT16_TYPE_SWITCH(inputs[0].type_flag_, DType, {
MSHADOW_TYPE_SWITCH(param.dtype, IDType, {
TopKBackwardImpl<xpu, DType, IDType>(ctx, inputs, req, outputs, param);
MSHADOW_TYPE_SWITCH(param.dtype, index_t, {
TopKBackwardImpl<xpu, DType>(ctx, inputs, req, outputs, param);
});
});
} else if (param.ret_typ == topk_enum::kReturnValue) {
MXNET_NO_FLOAT16_TYPE_SWITCH(inputs[0].type_flag_, DType, {
TopKBackwardImpl<xpu, DType, int>(ctx, inputs, req, outputs, param);
TopKBackwardImpl<xpu, DType>(ctx, inputs, req, outputs, param);
});
} else {
LOG(FATAL) << "Not Implemented";
Expand Down Expand Up @@ -717,7 +724,7 @@ inline bool TopKType(const nnvm::NodeAttrs& attrs,
CHECK(out_size == 1 || out_size == 2);
if (out_size > 1) {
if (param.ret_typ == topk_enum::kReturnValue) {
CHECK(type_assign(&(*out_attrs)[1], mshadow::kInt32))
CHECK(type_assign(&(*out_attrs)[1], mshadow::kInt64))
<< "Failed to set the type of ret_indices.";
} else {
CHECK(type_assign(&(*out_attrs)[1], param.dtype))
Expand Down Expand Up @@ -752,11 +759,12 @@ inline bool TopKShapeImpl(const TopKParam& param,
CHECK_EQ(out_attrs->size(), 2U);
}
mxnet::TShape& in_shape = (*in_attrs)[0];
int batch_size, element_num; // number of batches + the size of each batch
size_t batch_size;
index_t element_num; // number of batches + the size of each batch
int axis = 0;
bool do_transpose = false;
bool is_ascend = false;
int k = 0;
index_t k = 0;
mxnet::TShape target_shape;
ParseTopKParam(in_shape, param,
&target_shape, &batch_size, &element_num, &axis, &k, &do_transpose, &is_ascend);
Expand Down Expand Up @@ -785,8 +793,8 @@ inline bool SortType(const nnvm::NodeAttrs& attrs,
size_t out_size = out_attrs->size();
CHECK_EQ(in_size, 1);
CHECK_EQ(out_size, 2);
CHECK(type_assign(&(*out_attrs)[1], mshadow::kInt32))
<< "Failed to set the type of ret_indices to int32.";
CHECK(type_assign(&(*out_attrs)[1], mshadow::kInt64))
<< "Failed to set the type of ret_indices to int64.";
CHECK(type_assign(&data_type, (*in_attrs)[0])) << "Incompatible dtype of input, in_attrs[0]="
<< (*in_attrs)[0];
CHECK(type_assign(&data_type, (*out_attrs)[0])) << "Incompatible dtype of output, out_attrs[0]="
Expand Down
Loading

0 comments on commit 7c4cb0f

Please sign in to comment.