From e62d8da4d017cadf0d7b632acabc71dbfcc228c3 Mon Sep 17 00:00:00 2001
From: Rohit Kumar Srivastava
Date: Thu, 6 Jun 2019 23:39:11 +0000
Subject: [PATCH] [MXNET-1413] Adding Large Tensor support for sort operators

---
 python/mxnet/test_utils.py            |   6 +
 src/operator/mxnet_op.h               |   2 +-
 src/operator/tensor/ordering_op-inl.h | 190 +++++++++++++-------------
 tests/nightly/test_large_array.py     |  32 ++++-
 4 files changed, 129 insertions(+), 101 deletions(-)

diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py
index bd102412c6e2..1219624e185a 100644
--- a/python/mxnet/test_utils.py
+++ b/python/mxnet/test_utils.py
@@ -425,6 +425,12 @@ def rand_coord_2d(x_low, x_high, y_low, y_high):
     return x, y
 
 
+def create_2d_tensor(rows, columns):
+    a = np.arange(0, rows).reshape(rows, 1)
+    b = np.broadcast_to(a, shape=(a.shape[0], columns))
+    return mx.nd.array(b, dtype=np.int64)
+
+
 def np_reduce(dat, axis, keepdims, numpy_reduce_func):
     """Compatible reduce for old version of NumPy.
 
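For context while reviewing, here is a minimal usage sketch of the `create_2d_tensor` helper added above. The tiny shapes and printed values are illustrative only, not part of the patch; the nightly tests call it with `LARGE_X`/`SMALL_Y` instead, and the import assumes a build that carries this change.

```python
import numpy as np
import mxnet as mx
from mxnet.test_utils import create_2d_tensor

# Row i of the result holds the value i in every column. Storing the data as
# int64 keeps row indices exact even when `rows` exceeds 2**31 - 1.
t = create_2d_tensor(rows=4, columns=3)
print(t.asnumpy())   # 4x3 array whose i-th row is filled with i
print(t.dtype)       # <class 'numpy.int64'>
```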
diff --git a/src/operator/mxnet_op.h b/src/operator/mxnet_op.h
index f17b708a7687..f5c19c742a03 100644
--- a/src/operator/mxnet_op.h
+++ b/src/operator/mxnet_op.h
@@ -630,7 +630,7 @@ struct Kernel<OP, cpu> {
       }
     }
 #else
-    for (size_t i = 0; i < N; ++i) {
+    for (index_t i = 0; i < static_cast<index_t>(N); ++i) {
       OP::Map(i, args...);
     }
 #endif
diff --git a/src/operator/tensor/ordering_op-inl.h b/src/operator/tensor/ordering_op-inl.h
index 1dda90104205..76c24aba29b0 100644
--- a/src/operator/tensor/ordering_op-inl.h
+++ b/src/operator/tensor/ordering_op-inl.h
@@ -81,11 +81,7 @@ struct TopKParam : public dmlc::Parameter<TopKParam> {
       .describe("Whether to choose k largest or k smallest elements."
                 " Top K largest elements will be chosen if set to false.");
     DMLC_DECLARE_FIELD(dtype)
-    .add_enum("uint8", mshadow::kUint8)
-    .add_enum("int32", mshadow::kInt32)
-    .add_enum("float16", mshadow::kFloat16)
-    .add_enum("float32", mshadow::kFloat32)
-    .add_enum("float64", mshadow::kFloat64)
+    MXNET_ADD_ALL_TYPES
     .set_default(mshadow::kFloat32)
     .describe("DType of the output indices when ret_typ is \"indices\" or \"both\". "
               "An error will be raised if the selected data type cannot precisely represent the "
@@ -116,11 +112,7 @@ struct ArgSortParam : public dmlc::Parameter<ArgSortParam> {
     DMLC_DECLARE_FIELD(is_ascend).set_default(true)
       .describe("Whether to sort in ascending or descending order.");
     DMLC_DECLARE_FIELD(dtype)
-    .add_enum("uint8", mshadow::kUint8)
-    .add_enum("int32", mshadow::kInt32)
-    .add_enum("float16", mshadow::kFloat16)
-    .add_enum("float32", mshadow::kFloat32)
-    .add_enum("float64", mshadow::kFloat64)
+    MXNET_ADD_ALL_TYPES
     .set_default(mshadow::kFloat32)
     .describe("DType of the output indices. It is only valid when ret_typ is \"indices\" or"
               " \"both\". An error will be raised if the selected data type cannot precisely "
@@ -129,8 +121,8 @@ struct ArgSortParam : public dmlc::Parameter<ArgSortParam> {
 };
 
 inline void ParseTopKParam(const mxnet::TShape& src_shape, const TopKParam& param,
-                           mxnet::TShape *target_shape, int *batch_size, int *element_num,
-                           int *axis, int *k, bool *do_transpose, bool *is_ascend) {
+                           mxnet::TShape *target_shape, size_t *batch_size, index_t *element_num,
+                           int *axis, index_t *k, bool *do_transpose, bool *is_ascend) {
   *do_transpose = false;
   *k = param.k;
   *is_ascend = param.is_ascend;
@@ -179,14 +171,14 @@ using namespace mshadow;
 
 struct fill_ind_to_one {
   template<typename DType>
-  MSHADOW_XINLINE static void Map(int i, const int* indices, DType* out) {
+  MSHADOW_XINLINE static void Map(index_t i, const index_t* indices, DType* out) {
     out[indices[i]] = static_cast<DType>(1);
   }
 };
 
 struct fill_ind {
   template<typename DType>
-  MSHADOW_XINLINE static void Map(int i, const int* indices, const DType* val,
+  MSHADOW_XINLINE static void Map(index_t i, const index_t* indices, const DType* val,
                                   int req, DType* out) {
     KERNEL_ASSIGN(out[indices[i]], req, val[i]);
   }
@@ -194,9 +186,9 @@ struct fill_ind {
 
 template<typename DType>
 MSHADOW_FORCE_INLINE void TopKSort(const Tensor<cpu, 1, DType>& dat,
-                                   const Tensor<cpu, 1, int>& ind,
+                                   const Tensor<cpu, 1, index_t>& ind,
                                    const Tensor<cpu, 1, char>& work,
-                                   int K, int N, bool is_ascend,
+                                   index_t K, index_t N, bool is_ascend,
                                    Stream<cpu> *s) {
   // Use full sort when K is relatively large.
   const bool full_sort(K*8 > N);
@@ -204,29 +196,29 @@ MSHADOW_FORCE_INLINE void TopKSort(const Tensor<cpu, 1, DType>& dat,
   const int M(work.size(0)/(sizeof(DType)*N));
   const int omp_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount());
   #pragma omp parallel for num_threads(omp_threads)
-  for (int i = 0; i < M; ++i) {
+  for (index_t i = 0; i < M; ++i) {
     // Tensor `work` stores the flattened source data, while `dat` stores the sorted result.
     DType *vals = reinterpret_cast<DType*>(work.dptr_);
     DType *sorted_vals = dat.dptr_+i*N;
-    int *indices = ind.dptr_+i*N;
+    index_t *indices = ind.dptr_+i*N;
     if (is_ascend) {
       if (full_sort) {
         std::sort(indices, indices+N,
                   [&](const int& i1, const int& i2){ return vals[i1] < vals[i2]; });
       } else {
         std::partial_sort(indices, indices+K, indices+N,
-                          [&](const int& i1, const int& i2){ return vals[i1] < vals[i2]; });
+                          [&](const index_t& i1, const index_t& i2){ return vals[i1] < vals[i2]; });
       }
     } else {
       if (full_sort) {
         std::sort(indices, indices+N,
-                  [&](const int& i1, const int& i2){ return vals[i1] > vals[i2]; });
+                  [&](const index_t& i1, const index_t& i2){ return vals[i1] > vals[i2]; });
       } else {
         std::partial_sort(indices, indices+K, indices+N,
-                          [&](const int& i1, const int& i2){ return vals[i1] > vals[i2]; });
+                          [&](const index_t& i1, const index_t& i2){ return vals[i1] > vals[i2]; });
      }
    }
-    for (int j = 0; j < K; ++j) {
+    for (index_t j = 0; j < K; ++j) {
      sorted_vals[j] = vals[indices[j]];
    }
  }
@@ -235,18 +227,22 @@ MSHADOW_FORCE_INLINE void TopKSort(const Tensor<cpu, 1, DType>& dat,
 
 #ifdef __CUDACC__
 template<typename DType>
-MSHADOW_XINLINE bool TopKCompare(DType val1, int ind1, DType val2, int ind2, bool is_ascend) {
+MSHADOW_XINLINE bool TopKCompare(DType val1,
+                                 index_t ind1,
+                                 DType val2,
+                                 index_t ind2,
+                                 bool is_ascend) {
   // Negative indices denote undefined values which are considered arbitrary small resp. large.
   return (ind2 < 0) || (ind1 >= 0 && ((is_ascend && val1 < val2) || (!is_ascend && val1 > val2)));
 }
 
 template<typename DType>
-MSHADOW_XINLINE void MergeTopK(int K, DType *val1, int *ind1, DType *val2, int *ind2,
+MSHADOW_XINLINE void MergeTopK(int K, DType *val1, index_t *ind1, DType *val2, index_t *ind2,
                                bool is_ascend) {
   // In-place merge of two sorted top-K lists into val1/ind1. First determine the intervals
   // [0,..,i1], [0,..i2] of the two lists that will be part of the merged list.
-  int i1(K-1), i2(K-1);
-  for (int i = 0; i < K; ++i) {
+  index_t i1(K-1), i2(K-1);
+  for (index_t i = 0; i < K; ++i) {
     if (TopKCompare(val1[i1], ind1[i1], val2[i2], ind2[i2], is_ascend)) {
       --i2;
     } else {
@@ -268,15 +264,15 @@ MSHADOW_XINLINE void MergeTopK(int K, DType *val1, int *ind1, DType *val2, int *
 }
 
 template<typename DType>
-__global__ void PartialSortSmallK(int K, int N, DType *val, int *ind, bool is_ascend) {
+__global__ void PartialSortSmallK(index_t K, index_t N, DType *val, index_t *ind, bool is_ascend) {
   // Buffer for pairwise reduction.
-  extern __shared__ int buff[];
+  extern __shared__ index_t buff[];
   // Start of buffer sections associated with this thread.
   const int offset(threadIdx.x*K);
-  int *ind_buff = &buff[offset];
+  index_t *ind_buff = &buff[offset];
   DType *val_buff = reinterpret_cast<DType*>(&buff[blockDim.x*K])+offset;
   // Initialize top-K values for this thread.
-  for (int i = 0; i < K; ++i) {
+  for (index_t i = 0; i < K; ++i) {
     ind_buff[i] = -1;
   }
   // Range of values this thread cares about. Each thread block processes
@@ -286,10 +282,11 @@ __global__ void PartialSortSmallK(int K, int N, DType *val, int *ind, bool is_as
   const int first(blockIdx.x*N+threadIdx.x), last((blockIdx.x+1)*N);
   // Select top-K from this range and store it sorted in the buffer.
   // We assume a small K, so linear insertion is o.k.
-  for (int i = first; i < last; i += blockDim.x) {
+  for (index_t i = first; i < last; i += blockDim.x) {
     DType cur_val(val[i]);
-    int cur_ind(ind[i]);
-    for (int j = K; j-- && TopKCompare(cur_val, cur_ind, val_buff[j], ind_buff[j], is_ascend); ) {
+    index_t cur_ind(ind[i]);
+    for (index_t j = K; j-- && TopKCompare(cur_val, cur_ind, val_buff[j],
+                                           ind_buff[j], is_ascend); ) {
       if (j+1 < K) {
         val_buff[j+1] = val_buff[j];
         ind_buff[j+1] = ind_buff[j];
@@ -309,7 +306,7 @@ __global__ void PartialSortSmallK(int K, int N, DType *val, int *ind, bool is_as
   }
   // Final updates on master thread.
   if (threadIdx.x == 0) {
-    for (int i = 0; i < K; ++i) {
+    for (index_t i = 0; i < K; ++i) {
       ind[blockIdx.x*N+i] = ind_buff[i];
       val[blockIdx.x*N+i] = val_buff[i];
     }
@@ -318,9 +315,9 @@
 template<typename DType>
 MSHADOW_FORCE_INLINE void TopKSort(const Tensor<gpu, 1, DType>& dat,
-                                   const Tensor<gpu, 1, int>& ind,
+                                   const Tensor<gpu, 1, index_t>& ind,
                                    const Tensor<gpu, 1, char>& work,
-                                   int K, int N, bool is_ascend,
+                                   index_t K, index_t N, bool is_ascend,
                                    Stream<gpu> *s) {
   // Use full sort for all but very small K for which we
   // can do a partial sort entirely within shared memory.
@@ -331,7 +328,8 @@ MSHADOW_FORCE_INLINE void TopKSort(const Tensor<gpu, 1, DType>& dat,
   // Divide workspace into two parts. The first one is needed to store batch ids.
   size_t alignment = std::max(sizeof(DType), sizeof(int));
   size_t id_size = PadBytes(sizeof(int) * ind.size(0), alignment);
-  Tensor<gpu, 1, int> batch_id(reinterpret_cast<int*>(work.dptr_), Shape1(ind.size(0)), s);
+  Tensor<gpu, 1, index_t> batch_id(reinterpret_cast<index_t*>(work.dptr_),
+                                   Shape1(ind.size(0)), s);
   Tensor<gpu, 1, char> sort_work(work.dptr_+id_size, Shape1(work.size(0)-id_size), s);
   mxnet::op::SortByKey(dat, ind, is_ascend, &sort_work);
   if (M > 1) {
@@ -364,9 +362,8 @@ MSHADOW_FORCE_INLINE void TopKSort(const Tensor<gpu, 1, DType>& dat,
  * \param param the topk parameters
  * \tparam xpu the device type.
  * \tparam DType type of the output value/mask.
- * \tparam IDType type of the output indices.
  */
-template<typename xpu, typename DType, typename IDType>
+template<typename xpu, typename DType>
 void TopKImpl(const RunContext &ctx,
               const Resource &resource,
               const std::vector<OpReqType>& req,
@@ -380,20 +377,22 @@ void TopKImpl(const RunContext &ctx,
   Tensor<xpu, 1, char> workspace;
   Tensor<xpu, 1, char> temp_workspace;
   Tensor<xpu, 1, DType> sorted_dat;
-  Tensor<xpu, 1, int> indices, sel_indices;
-  int batch_size, element_num;  // number of batches + the size of each batch
+  Tensor<xpu, 1, index_t> indices, sel_indices;
+  size_t batch_size;
+  index_t element_num;  // number of batches + the size of each batch
   int axis = 0;
   bool do_transpose = false;
   bool is_ascend = false;
-  int k = 0;
+  index_t k = 0;
   size_t alignment = std::max(sizeof(DType), sizeof(int));
   mxnet::TShape target_shape;
   ParseTopKParam(src.shape_, param,
                  &target_shape, &batch_size, &element_num, &axis, &k, &do_transpose, &is_ascend);
-  CHECK_LE(element_num, mxnet::common::MaxIntegerValue<IDType>())
-    << "'IDType' does not have a sufficient precision to represent the indices of the input array. "
-    << "The total element_num is " << element_num << ", but the selected IDType can only represent "
-    << mxnet::common::MaxIntegerValue<IDType>() << " elements";
+  CHECK_LE(element_num, mxnet::common::MaxIntegerValue<index_t>())
+    << "'index_t' does not have a sufficient precision to represent the indices of "
+    << "the input array. The total element_num is "
+    << element_num << ", but the selected index_t can only represent "
+    << mxnet::common::MaxIntegerValue<index_t>() << " elements";
   Tensor<xpu, 3, DType> dat = src.FlatTo3D<xpu, DType>(axis, axis, s);
   size_t temp_size = 0;
   // Temp space needed by the gpu-based full sorts.
@@ -417,12 +416,12 @@ void TopKImpl(const RunContext &ctx,
   sorted_dat = Tensor<xpu, 1, DType>(reinterpret_cast<DType*>(workspace_curr_ptr),
                                      Shape1(src.Size()), s);  // contain sorted dat
   workspace_curr_ptr += PadBytes(sizeof(DType) * src.Size(), alignment);
-  indices = Tensor<xpu, 1, int>(reinterpret_cast<int*>(workspace_curr_ptr),
+  indices = Tensor<xpu, 1, index_t>(reinterpret_cast<index_t*>(workspace_curr_ptr),
                                 Shape1(src.Size()), s);  // indices in the original matrix
   workspace_curr_ptr += PadBytes(sizeof(int) * src.Size(), alignment);
   if (param.ret_typ == topk_enum::kReturnMask) {
-    sel_indices = Tensor<xpu, 1, int>(reinterpret_cast<int*>(workspace_curr_ptr),
+    sel_indices = Tensor<xpu, 1, index_t>(reinterpret_cast<index_t*>(workspace_curr_ptr),
                                       Shape1(batch_size * k), s);
     workspace_curr_ptr += PadBytes(sizeof(int) * batch_size * k, alignment);
     CHECK_EQ(sel_indices.CheckContiguous(), true);
@@ -454,8 +453,8 @@ void TopKImpl(const RunContext &ctx,
     workspace_curr_ptr += temp_size;
   }
 
-  mxnet_op::Kernel<range_fwd, xpu>::Launch(s, batch_size * element_num, 1, 0, 1,
-                                           kWriteTo, indices.dptr_);
+  mxnet_op::Kernel<range_fwd, xpu>::Launch(s, batch_size * element_num, 1, index_t{0},
+                                           index_t{1}, kWriteTo, indices.dptr_);
   CHECK_EQ(indices.CheckContiguous(), true);
 
   // 2. Perform inplace batch sort.
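As a side note for reviewers, the following is a rough NumPy-only sketch (a hypothetical helper, not part of the patch) of the index bookkeeping TopKImpl performs above: fill `indices` with the global positions 0..batch_size*element_num-1, sort each length-element_num segment by value, then reduce the selected global positions to within-batch positions.

```python
import numpy as np

def topk_indices_sketch(data_2d, k, is_ascend=False):
    """Mimic TopKImpl's index handling on a (batch_size, element_num) array."""
    batch_size, element_num = data_2d.shape
    # Global positions 0 .. batch_size*element_num - 1, as in the index fill above.
    indices = np.arange(batch_size * element_num, dtype=np.int64)
    flat = data_2d.reshape(-1)
    out = np.empty((batch_size, k), dtype=np.int64)
    for b in range(batch_size):
        seg = indices[b * element_num:(b + 1) * element_num]
        order = np.argsort(flat[seg])        # ascending order of this batch's values
        if not is_ascend:
            order = order[::-1]
        # Reduce the chosen global positions to positions within the batch.
        out[b] = seg[order[:k]] % element_num
    return out

data = np.array([[3.0, 1.0, 2.0],
                 [9.0, 7.0, 8.0]])
print(topk_indices_sketch(data, k=2))        # [[0 2], [0 2]]
```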
@@ -494,8 +493,8 @@ void TopKImpl(const RunContext &ctx,
     }
   } else if (param.ret_typ == topk_enum::kReturnIndices) {
     if (do_transpose) {
-      Tensor<xpu, 3, IDType> ret_indices = ret[0].FlatTo3D<xpu, IDType>(axis, axis, s);
-      ASSIGN_DISPATCH(ret_indices, req[0], tcast<IDType>(F<mshadow_op::mod>(transpose(
+      Tensor<xpu, 3, index_t> ret_indices = ret[0].FlatTo3D<xpu, index_t>(axis, axis, s);
+      ASSIGN_DISPATCH(ret_indices, req[0], tcast<index_t>(F<mshadow_op::mod>(transpose(
                         slice<2>(inplace_reshape(indices,
                                                  Shape3(ret_indices.shape_[0],
                                                         ret_indices.shape_[2],
@@ -503,21 +502,21 @@ void TopKImpl(const RunContext &ctx,
                                  0, k),
                         Shape3(0, 2, 1)), element_num)));
     } else {
-      Tensor<xpu, 2, IDType> ret_indices =
-        ret[0].get_with_shape<xpu, 2, IDType>(Shape2(batch_size, k), s);
-      ASSIGN_DISPATCH(ret_indices, req[0], tcast<IDType>(F<mshadow_op::mod>(slice<1>(
+      Tensor<xpu, 2, index_t> ret_indices =
+        ret[0].get_with_shape<xpu, 2, index_t>(Shape2(batch_size, k), s);
+      ASSIGN_DISPATCH(ret_indices, req[0], tcast<index_t>(F<mshadow_op::mod>(slice<1>(
                   inplace_reshape(indices, Shape2(batch_size, element_num)), 0, k), element_num)));
     }
   } else {
     if (do_transpose) {
       Tensor<xpu, 3, DType> ret_value = ret[0].FlatTo3D<xpu, DType>(axis, axis, s);
-      Tensor<xpu, 3, IDType> ret_indices = ret[1].FlatTo3D<xpu, IDType>(axis, axis, s);
+      Tensor<xpu, 3, index_t> ret_indices = ret[1].FlatTo3D<xpu, index_t>(axis, axis, s);
       ASSIGN_DISPATCH(ret_value, req[0], transpose(
                         slice<2>(inplace_reshape(sorted_dat,
                                                  Shape3(ret_value.shape_[0],
                                                         ret_value.shape_[2],
                                                         element_num)),
                                  0, k),
                         Shape3(0, 2, 1)));
-      ASSIGN_DISPATCH(ret_indices, req[1], tcast<IDType>(F<mshadow_op::mod>(transpose(
+      ASSIGN_DISPATCH(ret_indices, req[1], tcast<index_t>(F<mshadow_op::mod>(transpose(
                         slice<2>(inplace_reshape(indices,
                                                  Shape3(ret_indices.shape_[0],
                                                         ret_indices.shape_[2],
@@ -526,11 +525,11 @@ void TopKImpl(const RunContext &ctx,
     } else {
       Tensor<xpu, 2, DType> ret_value =
         ret[0].get_with_shape<xpu, 2, DType>(Shape2(batch_size, k), s);
-      Tensor<xpu, 2, IDType> ret_indices =
-        ret[1].get_with_shape<xpu, 2, IDType>(Shape2(batch_size, k), s);
+      Tensor<xpu, 2, index_t> ret_indices =
+        ret[1].get_with_shape<xpu, 2, index_t>(Shape2(batch_size, k), s);
       ASSIGN_DISPATCH(ret_value, req[0],
                       slice<1>(inplace_reshape(sorted_dat, Shape2(batch_size, element_num)), 0, k));
-      ASSIGN_DISPATCH(ret_indices, req[1], tcast<IDType>(F<mshadow_op::mod>(slice<1>(
+      ASSIGN_DISPATCH(ret_indices, req[1], tcast<index_t>(F<mshadow_op::mod>(slice<1>(
                   inplace_reshape(indices, Shape2(batch_size, element_num)), 0, k), element_num)));
     }
   }
@@ -545,13 +544,13 @@ void TopK(const nnvm::NodeAttrs& attrs,
   const TopKParam& param = nnvm::get<TopKParam>(attrs.parsed);
   if (param.ret_typ == topk_enum::kReturnIndices || param.ret_typ == topk_enum::kReturnBoth) {
     MXNET_NO_FLOAT16_TYPE_SWITCH(inputs[0].type_flag_, DType, {
-      MSHADOW_TYPE_SWITCH(param.dtype, IDType, {
-        TopKImpl<xpu, DType, IDType>(ctx.run_ctx, ctx.requested[0], req, inputs[0], outputs, param);
+      MSHADOW_TYPE_SWITCH(param.dtype, index_t, {
+        TopKImpl<xpu, DType>(ctx.run_ctx, ctx.requested[0], req, inputs[0], outputs, param);
       })
     });
   } else {
     MXNET_NO_FLOAT16_TYPE_SWITCH(inputs[0].type_flag_, DType, {
-      TopKImpl<xpu, DType, int>(ctx.run_ctx, ctx.requested[0], req, inputs[0], outputs, param);
+      TopKImpl<xpu, DType>(ctx.run_ctx, ctx.requested[0], req, inputs[0], outputs, param);
     });
   }
 }
@@ -569,7 +568,7 @@ void Sort(const nnvm::NodeAttrs& attrs,
   topk_param.k = 0;
   topk_param.ret_typ = topk_enum::kReturnValue;
   MXNET_NO_FLOAT16_TYPE_SWITCH(inputs[0].type_flag_, DType, {
-    TopKImpl<xpu, DType, int>(ctx.run_ctx, ctx.requested[0], req, inputs[0], outputs, topk_param);
+    TopKImpl<xpu, DType>(ctx.run_ctx, ctx.requested[0], req, inputs[0], outputs, topk_param);
   });
 }
 
@@ -587,14 +586,14 @@ void ArgSort(const nnvm::NodeAttrs& attrs,
   topk_param.dtype = param.dtype;
   topk_param.ret_typ = topk_enum::kReturnIndices;
   MXNET_NO_FLOAT16_TYPE_SWITCH(inputs[0].type_flag_, DType, {
-    MSHADOW_TYPE_SWITCH(param.dtype, IDType, {
-      TopKImpl<xpu, DType, IDType>(ctx.run_ctx,
+    MSHADOW_TYPE_SWITCH(param.dtype, index_t, {
+      TopKImpl<xpu, DType>(ctx.run_ctx,
                                    ctx.requested[0], req, inputs[0], outputs,
                                   topk_param);
     });
   });
 }
 
-template<typename xpu, typename DType, typename IDType>
+template<typename xpu, typename DType>
 void TopKBackwardImpl(const OpContext &ctx,
                       const std::vector<TBlob>& inputs,
                       const std::vector<OpReqType>& req,
                       const std::vector<TBlob>& outputs,
@@ -605,33 +604,37 @@ void TopKBackwardImpl(const OpContext &ctx,
   using namespace mshadow::expr;
   Stream<xpu> *s = ctx.run_ctx.get_stream<xpu>();
   CHECK(param.ret_typ == topk_enum::kReturnValue || param.ret_typ == topk_enum::kReturnBoth);
-  int batch_size, element_num;  // number of batches + the size of each batch
+  size_t batch_size;
+  index_t element_num;  // number of batches + the size of each batch
   int axis = 0;
   bool do_transpose = false;
   bool is_ascend = false;
-  int k = 0;
+  index_t k = 0;
   mxnet::TShape target_shape;
   ParseTopKParam(outputs[0].shape_, param,
                  &target_shape, &batch_size, &element_num, &axis, &k, &do_transpose, &is_ascend);
-  CHECK_LE(element_num, mxnet::common::MaxIntegerValue<IDType>())
-    << "'IDType' does not have a sufficient precision to represent the indices of the input array. "
-    << "The total element_num is " << element_num << ", but the selected IDType can only represent "
-    << mxnet::common::MaxIntegerValue<IDType>() << " elements";
-  Tensor<xpu, 1, int> workspace =
-    ctx.requested[0].get_space_typed<xpu, 1, int>(Shape1(batch_size * k + batch_size), s);
-  Tensor<xpu, 1, int> sel_indices =
-    Tensor<xpu, 1, int>(workspace.dptr_, Shape1(batch_size * k), s);
-  Tensor<xpu, 1, int> batch_shift =
-    Tensor<xpu, 1, int>(workspace.dptr_ + batch_size * k, Shape1(batch_size), s);
+  CHECK_LE(element_num, mxnet::common::MaxIntegerValue<index_t>())
+    << "'index_t' does not have a sufficient precision to represent "
+    << "the indices of the input array. The total element_num is "
+    << element_num << ", but the selected index_t can only represent "
+    << mxnet::common::MaxIntegerValue<index_t>() << " elements";
+  Tensor<xpu, 1, index_t> workspace =
+    ctx.requested[0].get_space_typed<xpu, 1, index_t>(Shape1(batch_size * k + batch_size), s);
+  Tensor<xpu, 1, index_t> sel_indices =
+    Tensor<xpu, 1, index_t>(workspace.dptr_, Shape1(batch_size * k), s);
+  Tensor<xpu, 1, index_t> batch_shift =
+    Tensor<xpu, 1, index_t>(workspace.dptr_ + batch_size * k, Shape1(batch_size), s);
   Tensor<xpu, 2, DType> out_grad =
     inputs[0].get_with_shape<xpu, 2, DType>(Shape2(inputs[0].shape_.Size(), 1), s);
   Tensor<xpu, 2, DType> in_grad =
     outputs[0].get_with_shape<xpu, 2, DType>(Shape2(outputs[0].shape_.Size(), 1), s);
-  mxnet_op::Kernel<range_fwd, xpu>::Launch(s, batch_size, 1, 0, element_num, kWriteTo,
+  mxnet_op::Kernel<range_fwd, xpu>::Launch(s, batch_size, 1,
+                                           static_cast<index_t>(0),
+                                           element_num, kWriteTo,
                                            batch_shift.dptr_);
   if (do_transpose) {
-    Tensor<xpu, 1, IDType> indices = inputs[2].FlatTo1D<xpu, IDType>(s);
+    Tensor<xpu, 1, index_t> indices = inputs[2].FlatTo1D<xpu, index_t>(s);
     mxnet::TShape src_shape = outputs[0].shape_.FlatTo3D(axis);
     sel_indices = reshape(transpose(
                             broadcast_to(inplace_reshape(batch_shift,
@@ -639,13 +642,13 @@ void TopKBackwardImpl(const OpContext &ctx,
                                          mxnet::TShape(Shape3(src_shape[0], src_shape[2], k))),
                             Shape3(0, 2, 1)),
                           Shape1(batch_size * k));
-    sel_indices += tcast<int>(indices);
+    sel_indices += indices;
    sel_indices = transpose_indices(sel_indices, Shape3(src_shape[0], src_shape[2], src_shape[1]),
                                    Shape3(0, 2, 1));
   } else {
-    Tensor<xpu, 2, IDType> indices =
-      inputs[2].get_with_shape<xpu, 2, IDType>(Shape2(batch_size, k), s);
-    sel_indices = reshape(tcast<int>(indices) +
+    Tensor<xpu, 2, index_t> indices =
+      inputs[2].get_with_shape<xpu, 2, index_t>(Shape2(batch_size, k), s);
+    sel_indices = reshape(indices +
                           broadcast_to(inplace_reshape(batch_shift, Shape2(batch_size, 1)),
                                        mxnet::TShape(Shape2(batch_size, k))),
                           Shape1(batch_size * k));
@@ -674,13 +677,13 @@ void TopKBackward_(const nnvm::NodeAttrs& attrs,
   const TopKParam& param = nnvm::get<TopKParam>(attrs.parsed);
   if (param.ret_typ == topk_enum::kReturnBoth) {
     MXNET_NO_FLOAT16_TYPE_SWITCH(inputs[0].type_flag_, DType, {
-      MSHADOW_TYPE_SWITCH(param.dtype, IDType, {
-        TopKBackwardImpl<xpu, DType, IDType>(ctx, inputs, req, outputs, param);
+      MSHADOW_TYPE_SWITCH(param.dtype, index_t, {
+        TopKBackwardImpl<xpu, DType>(ctx, inputs, req, outputs, param);
       });
     });
   } else if (param.ret_typ == topk_enum::kReturnValue) {
     MXNET_NO_FLOAT16_TYPE_SWITCH(inputs[0].type_flag_, DType, {
-      TopKBackwardImpl<xpu, DType, int>(ctx, inputs, req, outputs, param);
+      TopKBackwardImpl<xpu, DType>(ctx, inputs, req, outputs, param);
     });
   } else {
     LOG(FATAL) << "Not Implemented";
   }
@@ -717,7 +720,7 @@ inline bool TopKType(const nnvm::NodeAttrs& attrs,
   CHECK(out_size == 1 || out_size == 2);
   if (out_size > 1) {
     if (param.ret_typ == topk_enum::kReturnValue) {
-      CHECK(type_assign(&(*out_attrs)[1], mshadow::kInt32))
+      CHECK(type_assign(&(*out_attrs)[1], mshadow::kInt64))
           << "Failed to set the type of ret_indices.";
     } else {
       CHECK(type_assign(&(*out_attrs)[1], param.dtype))
@@ -752,11 +755,12 @@ inline bool TopKShapeImpl(const TopKParam& param,
     CHECK_EQ(out_attrs->size(), 2U);
   }
   mxnet::TShape& in_shape = (*in_attrs)[0];
-  int batch_size, element_num;  // number of batches + the size of each batch
+  size_t batch_size;
+  index_t element_num;  // number of batches + the size of each batch
   int axis = 0;
   bool do_transpose = false;
   bool is_ascend = false;
-  int k = 0;
+  index_t k = 0;
   mxnet::TShape target_shape;
   ParseTopKParam(in_shape, param,
                  &target_shape, &batch_size, &element_num, &axis, &k, &do_transpose, &is_ascend);
@@ -785,8 +789,8 @@ inline bool SortType(const nnvm::NodeAttrs& attrs,
   size_t out_size = out_attrs->size();
   CHECK_EQ(in_size, 1);
   CHECK_EQ(out_size, 2);
-  CHECK(type_assign(&(*out_attrs)[1], mshadow::kInt32))
-      << "Failed to set the type of ret_indices to int32.";
+  CHECK(type_assign(&(*out_attrs)[1], mshadow::kInt64))
+      << "Failed to set the type of ret_indices to int64.";
   CHECK(type_assign(&data_type, (*in_attrs)[0])) << "Incompatible dtype of input, in_attrs[0]="
                                                  << (*in_attrs)[0];
   CHECK(type_assign(&data_type, (*out_attrs)[0])) << "Incompatible dtype of output, out_attrs[0]="
diff --git a/tests/nightly/test_large_array.py b/tests/nightly/test_large_array.py
index cbba608d5d2f..7a6b6df6238c 100644
--- a/tests/nightly/test_large_array.py
+++ b/tests/nightly/test_large_array.py
@@ -17,7 +17,7 @@
 
 import numpy as np
 import mxnet as mx
-from mxnet.test_utils import rand_ndarray, assert_almost_equal, rand_coord_2d
+from mxnet.test_utils import rand_ndarray, assert_almost_equal, rand_coord_2d, create_2d_tensor
 from mxnet import gluon, nd
 from tests.python.unittest.common import with_seed
 
@@ -292,12 +292,6 @@ def test_unravel_index():
     assert (indices_2d.asnumpy() == np.array(original_2d_indices)).all()
 
 
-def create_2d_tensor(rows, columns):
-    a = np.arange(0, rows).reshape(rows, 1)
-    b = np.broadcast_to(a, shape=(a.shape[0], columns))
-    return nd.array(b, dtype=np.int64)
-
-
 def test_transpose():
     b = create_2d_tensor(rows=LARGE_X, columns=SMALL_Y)
     t = b.T
@@ -326,6 +320,30 @@ def test_softmax():
     assert_almost_equal(output.asnumpy(), true_output, rtol=1e-5, atol=1e-5)
 
 
+def test_argsort():
+    b = create_2d_tensor(rows=LARGE_X, columns=SMALL_Y)
+    s = nd.argsort(b, axis=0, is_ascend=False, dtype=np.int64)
+    mx.nd.waitall()
+    assert (s[0].asnumpy() == (LARGE_X - 1)).all()
+
+
+def test_sort():
+    b = create_2d_tensor(rows=LARGE_X, columns=SMALL_Y)
+    s = nd.sort(b, axis=0, is_ascend=False)
+    assert np.sum(s[-1][SMALL_Y//2:SMALL_Y].asnumpy() == 0).all()
+    s = nd.sort(b, is_ascend=False)
+    assert np.sum(s[0].asnumpy() == 0).all()
+
+
+def test_topk():
+    b = create_2d_tensor(rows=LARGE_X, columns=SMALL_Y)
+    k = nd.topk(b, k=10, axis=0, dtype=np.int64)
+    assert np.sum(k.asnumpy() == (LARGE_X - 1)) == SMALL_Y
+    b = create_2d_tensor(rows=SMALL_Y, columns=LARGE_X)
+    l = nd.topk(b, k=1, axis=-1, dtype=np.int64, ret_typ="value")
+    assert l.sum() == np.sum(np.arange(0, SMALL_Y))
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()
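Finally, a small self-contained illustration of the user-visible effect of the change (a minimal sketch with tiny shapes so it runs anywhere; the nightly tests above use LARGE_X-by-SMALL_Y tensors, where int64 indices actually matter once flat positions exceed 2**31 - 1, and it assumes a build that carries this patch):

```python
import numpy as np
from mxnet import nd

b = nd.array([[0, 0], [1, 1], [2, 2]])

# Indices can now be requested as int64, which the large-tensor tests rely on.
idx = nd.argsort(b, axis=0, is_ascend=False, dtype=np.int64)
print(idx.dtype)          # <class 'numpy.int64'>
print(idx.asnumpy()[0])   # [2 2] -- the row holding the largest values comes first

# topk returns int64 positions along the requested axis as well.
top = nd.topk(b, k=1, axis=0, dtype=np.int64)
print(top.asnumpy())      # [[2 2]]
```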