diff --git a/src/operator/tensor/init_op.h b/src/operator/tensor/init_op.h index fd491534f83a..c7a10541adbd 100644 --- a/src/operator/tensor/init_op.h +++ b/src/operator/tensor/init_op.h @@ -487,7 +487,7 @@ void EyeFill(const nnvm::NodeAttrs& attrs, struct range_fwd { template - MSHADOW_XINLINE static void Map(index_t i, int repeat, DType start, DType step, + MSHADOW_XINLINE static void Map(index_t i, index_t repeat, DType start, DType step, int req, DType* out) { KERNEL_ASSIGN(out[i], req, start + (i/repeat) * step); } diff --git a/src/operator/tensor/ordering_op-inl.h b/src/operator/tensor/ordering_op-inl.h index 1dda90104205..98bca3a43c60 100644 --- a/src/operator/tensor/ordering_op-inl.h +++ b/src/operator/tensor/ordering_op-inl.h @@ -81,12 +81,18 @@ struct TopKParam : public dmlc::Parameter { .describe("Whether to choose k largest or k smallest elements." " Top K largest elements will be chosen if set to false."); DMLC_DECLARE_FIELD(dtype) + // TODO(srivrohi): remove support for real data type in mxnet-2.0 .add_enum("uint8", mshadow::kUint8) .add_enum("int32", mshadow::kInt32) + .add_enum("int64", mshadow::kInt64) .add_enum("float16", mshadow::kFloat16) .add_enum("float32", mshadow::kFloat32) .add_enum("float64", mshadow::kFloat64) - .set_default(mshadow::kFloat32) +#if MXNET_USE_INT64_TENSOR_SIZE == 1 + .set_default(mshadow::kInt64) +#else + .set_default(mshadow::kInt32) +#endif .describe("DType of the output indices when ret_typ is \"indices\" or \"both\". " "An error will be raised if the selected data type cannot precisely represent the " "indices."); @@ -116,21 +122,33 @@ struct ArgSortParam : public dmlc::Parameter { DMLC_DECLARE_FIELD(is_ascend).set_default(true) .describe("Whether to sort in ascending or descending order."); DMLC_DECLARE_FIELD(dtype) + // TODO(srivrohi): remove support for real data type in mxnet-2.0 .add_enum("uint8", mshadow::kUint8) .add_enum("int32", mshadow::kInt32) + .add_enum("int64", mshadow::kInt64) .add_enum("float16", mshadow::kFloat16) .add_enum("float32", mshadow::kFloat32) .add_enum("float64", mshadow::kFloat64) - .set_default(mshadow::kFloat32) +#if USE_INT64_TENSOR_SIZE == 1 + .set_default(mshadow::kInt64) +#else + .set_default(mshadow::kInt32) +#endif .describe("DType of the output indices. It is only valid when ret_typ is \"indices\" or" " \"both\". An error will be raised if the selected data type cannot precisely " "represent the indices."); } }; -inline void ParseTopKParam(const mxnet::TShape& src_shape, const TopKParam& param, - mxnet::TShape *target_shape, int *batch_size, int *element_num, - int *axis, int *k, bool *do_transpose, bool *is_ascend) { +inline void ParseTopKParam(const TShape& src_shape, + const TopKParam& param, + TShape *target_shape, + size_t *batch_size, + index_t *element_num, + int *axis, + index_t *k, + bool *do_transpose, + bool *is_ascend) { *do_transpose = false; *k = param.k; *is_ascend = param.is_ascend; @@ -179,14 +197,14 @@ using namespace mshadow; struct fill_ind_to_one { template - MSHADOW_XINLINE static void Map(int i, const int* indices, DType* out) { + MSHADOW_XINLINE static void Map(int i, const index_t* indices, DType* out) { out[indices[i]] = static_cast(1); } }; struct fill_ind { template - MSHADOW_XINLINE static void Map(int i, const int* indices, const DType* val, + MSHADOW_XINLINE static void Map(int i, const index_t* indices, const DType* val, int req, DType* out) { KERNEL_ASSIGN(out[indices[i]], req, val[i]); } @@ -194,39 +212,43 @@ struct fill_ind { template MSHADOW_FORCE_INLINE void TopKSort(const Tensor& dat, - const Tensor& ind, + const Tensor& ind, const Tensor& work, - int K, int N, bool is_ascend, + index_t K, index_t N, bool is_ascend, Stream *s) { // Use full sort when K is relatively large. const bool full_sort(K*8 > N); // Batch size. - const int M(work.size(0)/(sizeof(DType)*N)); + const index_t M(work.size(0)/(sizeof(DType)*N)); const int omp_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount()); #pragma omp parallel for num_threads(omp_threads) - for (int i = 0; i < M; ++i) { + for (index_t i = 0; i < M; ++i) { // Tensor `work` stores the flattened source data, while `dat` stores the sorted result. DType *vals = reinterpret_cast(work.dptr_); DType *sorted_vals = dat.dptr_+i*N; - int *indices = ind.dptr_+i*N; + index_t *indices = ind.dptr_+i*N; if (is_ascend) { if (full_sort) { std::sort(indices, indices+N, - [&](const int& i1, const int& i2){ return vals[i1] < vals[i2]; }); + [&](const index_t& i1, const index_t& i2){ + return vals[i1] < vals[i2]; }); } else { std::partial_sort(indices, indices+K, indices+N, - [&](const int& i1, const int& i2){ return vals[i1] < vals[i2]; }); + [&](const index_t& i1, const index_t& i2){ + return vals[i1] < vals[i2]; }); } } else { if (full_sort) { std::sort(indices, indices+N, - [&](const int& i1, const int& i2){ return vals[i1] > vals[i2]; }); + [&](const index_t& i1, const index_t& i2){ + return vals[i1] > vals[i2]; }); } else { std::partial_sort(indices, indices+K, indices+N, - [&](const int& i1, const int& i2){ return vals[i1] > vals[i2]; }); + [&](const index_t& i1, const index_t& i2){ + return vals[i1] > vals[i2]; }); } } - for (int j = 0; j < K; ++j) { + for (index_t j = 0; j < K; ++j) { sorted_vals[j] = vals[indices[j]]; } } @@ -235,18 +257,19 @@ MSHADOW_FORCE_INLINE void TopKSort(const Tensor& dat, #ifdef __CUDACC__ template -MSHADOW_XINLINE bool TopKCompare(DType val1, int ind1, DType val2, int ind2, bool is_ascend) { +MSHADOW_XINLINE bool TopKCompare(DType val1, index_t ind1, DType val2, index_t ind2, + bool is_ascend) { // Negative indices denote undefined values which are considered arbitrary small resp. large. return (ind2 < 0) || (ind1 >= 0 && ((is_ascend && val1 < val2) || (!is_ascend && val1 > val2))); } template -MSHADOW_XINLINE void MergeTopK(int K, DType *val1, int *ind1, DType *val2, int *ind2, +MSHADOW_XINLINE void MergeTopK(index_t K, DType *val1, index_t *ind1, DType *val2, index_t *ind2, bool is_ascend) { // In-place merge of two sorted top-K lists into val1/ind1. First determine the intervals // [0,..,i1], [0,..i2] of the two lists that will be part of the merged list. - int i1(K-1), i2(K-1); - for (int i = 0; i < K; ++i) { + index_t i1(K-1), i2(K-1); + for (index_t i = 0; i < K; ++i) { if (TopKCompare(val1[i1], ind1[i1], val2[i2], ind2[i2], is_ascend)) { --i2; } else { @@ -254,7 +277,7 @@ MSHADOW_XINLINE void MergeTopK(int K, DType *val1, int *ind1, DType *val2, int * } } // Now merge the lists from back to front. - for (int i = K; i--;) { + for (index_t i = K; i--;) { if (i2 < 0 || i1 >= 0 && TopKCompare(val2[i2], ind2[i2], val1[i1], ind1[i1], is_ascend)) { val1[i] = val1[i1]; ind1[i] = ind1[i1]; @@ -268,28 +291,29 @@ MSHADOW_XINLINE void MergeTopK(int K, DType *val1, int *ind1, DType *val2, int * } template -__global__ void PartialSortSmallK(int K, int N, DType *val, int *ind, bool is_ascend) { +__global__ void PartialSortSmallK(index_t K, index_t N, DType *val, index_t *ind, bool is_ascend) { // Buffer for pairwise reduction. - extern __shared__ int buff[]; + extern __shared__ index_t buff[]; // Start of buffer sections associated with this thread. - const int offset(threadIdx.x*K); - int *ind_buff = &buff[offset]; + const index_t offset(threadIdx.x*K); + index_t *ind_buff = &buff[offset]; DType *val_buff = reinterpret_cast(&buff[blockDim.x*K])+offset; // Initialize top-K values for this thread. - for (int i = 0; i < K; ++i) { + for (index_t i = 0; i < K; ++i) { ind_buff[i] = -1; } // Range of values this thread cares about. Each thread block processes // a different batch item (i.e. a different set of ind/val where we // have to select the top-K elements). All threads within the same // block work on the same batch item. - const int first(blockIdx.x*N+threadIdx.x), last((blockIdx.x+1)*N); + const index_t first(blockIdx.x*N+threadIdx.x), last((blockIdx.x+1)*N); // Select top-K from this range and store it sorted in the buffer. // We assume a small K, so linear insertion is o.k. - for (int i = first; i < last; i += blockDim.x) { + for (index_t i = first; i < last; i += blockDim.x) { DType cur_val(val[i]); - int cur_ind(ind[i]); - for (int j = K; j-- && TopKCompare(cur_val, cur_ind, val_buff[j], ind_buff[j], is_ascend); ) { + index_t cur_ind(ind[i]); + for (index_t j = K; j-- && TopKCompare(cur_val, cur_ind, val_buff[j], + ind_buff[j], is_ascend); ) { if (j+1 < K) { val_buff[j+1] = val_buff[j]; ind_buff[j+1] = ind_buff[j]; @@ -300,7 +324,7 @@ __global__ void PartialSortSmallK(int K, int N, DType *val, int *ind, bool is_as } // Recursive merge of sorted lists for this thread block. Note that blockDim.x is not // necessary a power of two, therefore the additional checks for last_s. - for (unsigned int s = (blockDim.x+1)/2, last_s = blockDim.x; + for (index_t s = (blockDim.x+1)/2, last_s = blockDim.x; last_s > 1; last_s = s, s = (s+1)/2) { __syncthreads(); if (threadIdx.x < s && threadIdx.x+s < last_s) { @@ -309,7 +333,7 @@ __global__ void PartialSortSmallK(int K, int N, DType *val, int *ind, bool is_as } // Final updates on master thread. if (threadIdx.x == 0) { - for (int i = 0; i < K; ++i) { + for (index_t i = 0; i < K; ++i) { ind[blockIdx.x*N+i] = ind_buff[i]; val[blockIdx.x*N+i] = val_buff[i]; } @@ -318,20 +342,21 @@ __global__ void PartialSortSmallK(int K, int N, DType *val, int *ind, bool is_as template MSHADOW_FORCE_INLINE void TopKSort(const Tensor& dat, - const Tensor& ind, + const Tensor& ind, const Tensor& work, - int K, int N, bool is_ascend, + index_t K, index_t N, bool is_ascend, Stream *s) { // Use full sort for all but very small K for which we // can do a partial sort entirely within shared memory. const bool full_sort(K > 5); // Batch size. - const int M(dat.size(0)/N); + const index_t M(dat.size(0)/N); if (full_sort) { // Divide workspace into two parts. The first one is needed to store batch ids. - size_t alignment = std::max(sizeof(DType), sizeof(int)); - size_t id_size = PadBytes(sizeof(int) * ind.size(0), alignment); - Tensor batch_id(reinterpret_cast(work.dptr_), Shape1(ind.size(0)), s); + size_t alignment = std::max(sizeof(DType), sizeof(index_t)); + size_t id_size = PadBytes(sizeof(index_t) * ind.size(0), alignment); + Tensor batch_id(reinterpret_cast(work.dptr_), + Shape1(ind.size(0)), s); Tensor sort_work(work.dptr_+id_size, Shape1(work.size(0)-id_size), s); mxnet::op::SortByKey(dat, ind, is_ascend, &sort_work); if (M > 1) { @@ -380,20 +405,22 @@ void TopKImpl(const RunContext &ctx, Tensor workspace; Tensor temp_workspace; Tensor sorted_dat; - Tensor indices, sel_indices; - int batch_size, element_num; // number of batches + the size of each batch + Tensor indices, sel_indices; + size_t batch_size = 0; + index_t element_num = 0; // number of batches + the size of each batch int axis = 0; bool do_transpose = false; bool is_ascend = false; - int k = 0; + index_t k = 0; size_t alignment = std::max(sizeof(DType), sizeof(int)); mxnet::TShape target_shape; ParseTopKParam(src.shape_, param, &target_shape, &batch_size, &element_num, &axis, &k, &do_transpose, &is_ascend); - CHECK_LE(element_num, mxnet::common::MaxIntegerValue()) - << "'IDType' does not have a sufficient precision to represent the indices of the input array. " - << "The total element_num is " << element_num << ", but the selected IDType can only represent " - << mxnet::common::MaxIntegerValue() << " elements"; + CHECK_LE(element_num, mxnet::common::MaxIntegerValue()) + << "'index_t' does not have a sufficient precision to represent " + << "the indices of the input array. The total element_num is " + << element_num << ", but the selected index_t can only represent " + << mxnet::common::MaxIntegerValue() << " elements"; Tensor dat = src.FlatTo3D(axis, axis, s); size_t temp_size = 0; // Temp space needed by the gpu-based full sorts. @@ -404,11 +431,11 @@ void TopKImpl(const RunContext &ctx, temp_size = std::max(temp_size, mxnet::op::SortByKeyWorkspaceSize(src.Size())); // Additional temp space for gpu full sorts for batch ids. - temp_size += PadBytes(sizeof(int) * src.Size(), alignment); + temp_size += PadBytes(sizeof(index_t) * src.Size(), alignment); // Temp space for cpu sorts. temp_size = std::max(temp_size, static_cast(sizeof(DType) * src.Size())); size_t workspace_size = temp_size + PadBytes(sizeof(DType) * src.Size(), alignment) - + PadBytes(sizeof(int) * src.Size(), alignment); + + PadBytes(sizeof(index_t) * src.Size(), alignment); if (param.ret_typ == topk_enum::kReturnMask) { workspace_size += PadBytes(sizeof(int) * batch_size * k, alignment); } @@ -417,14 +444,14 @@ void TopKImpl(const RunContext &ctx, sorted_dat = Tensor(reinterpret_cast(workspace_curr_ptr), Shape1(src.Size()), s); // contain sorted dat workspace_curr_ptr += PadBytes(sizeof(DType) * src.Size(), alignment); - indices = Tensor(reinterpret_cast(workspace_curr_ptr), + indices = Tensor(reinterpret_cast(workspace_curr_ptr), Shape1(src.Size()), s); // indices in the original matrix - workspace_curr_ptr += PadBytes(sizeof(int) * src.Size(), alignment); + workspace_curr_ptr += PadBytes(sizeof(index_t) * src.Size(), alignment); if (param.ret_typ == topk_enum::kReturnMask) { - sel_indices = Tensor(reinterpret_cast(workspace_curr_ptr), + sel_indices = Tensor(reinterpret_cast(workspace_curr_ptr), Shape1(batch_size * k), s); - workspace_curr_ptr += PadBytes(sizeof(int) * batch_size * k, alignment); + workspace_curr_ptr += PadBytes(sizeof(index_t) * batch_size * k, alignment); CHECK_EQ(sel_indices.CheckContiguous(), true); } @@ -454,7 +481,7 @@ void TopKImpl(const RunContext &ctx, workspace_curr_ptr += temp_size; } - mxnet_op::Kernel::Launch(s, batch_size * element_num, 1, 0, 1, + mxnet_op::Kernel::Launch(s, batch_size * element_num, 1, index_t{0}, index_t{1}, kWriteTo, indices.dptr_); CHECK_EQ(indices.CheckContiguous(), true); @@ -551,7 +578,7 @@ void TopK(const nnvm::NodeAttrs& attrs, }); } else { MXNET_NO_FLOAT16_TYPE_SWITCH(inputs[0].type_flag_, DType, { - TopKImpl(ctx.run_ctx, ctx.requested[0], req, inputs[0], outputs, param); + TopKImpl(ctx.run_ctx, ctx.requested[0], req, inputs[0], outputs, param); }); } } @@ -569,7 +596,8 @@ void Sort(const nnvm::NodeAttrs& attrs, topk_param.k = 0; topk_param.ret_typ = topk_enum::kReturnValue; MXNET_NO_FLOAT16_TYPE_SWITCH(inputs[0].type_flag_, DType, { - TopKImpl(ctx.run_ctx, ctx.requested[0], req, inputs[0], outputs, topk_param); + TopKImpl(ctx.run_ctx, ctx.requested[0], req, inputs[0], + outputs, topk_param); }); } @@ -605,30 +633,32 @@ void TopKBackwardImpl(const OpContext &ctx, using namespace mshadow::expr; Stream *s = ctx.run_ctx.get_stream(); CHECK(param.ret_typ == topk_enum::kReturnValue || param.ret_typ == topk_enum::kReturnBoth); - int batch_size, element_num; // number of batches + the size of each batch + size_t batch_size = 0; + index_t element_num = 0; // number of batches + the size of each batch int axis = 0; bool do_transpose = false; bool is_ascend = false; - int k = 0; + index_t k = 0; mxnet::TShape target_shape; ParseTopKParam(outputs[0].shape_, param, &target_shape, &batch_size, &element_num, &axis, &k, &do_transpose, &is_ascend); CHECK_LE(element_num, mxnet::common::MaxIntegerValue()) - << "'IDType' does not have a sufficient precision to represent the indices of the input array. " - << "The total element_num is " << element_num << ", but the selected IDType can only represent " + << "'IDType' does not have a sufficient precision to represent " + << "the indices of the input array. The total element_num is " << element_num + << ", but the selected index_t can only represent " << mxnet::common::MaxIntegerValue() << " elements"; - Tensor workspace = - ctx.requested[0].get_space_typed(Shape1(batch_size * k + batch_size), s); - Tensor sel_indices = - Tensor(workspace.dptr_, Shape1(batch_size * k), s); - Tensor batch_shift = - Tensor(workspace.dptr_ + batch_size * k, Shape1(batch_size), s); + Tensor workspace = + ctx.requested[0].get_space_typed(Shape1(batch_size * k + batch_size), s); + Tensor sel_indices = + Tensor(workspace.dptr_, Shape1(batch_size * k), s); + Tensor batch_shift = + Tensor(workspace.dptr_ + batch_size * k, Shape1(batch_size), s); Tensor out_grad = inputs[0].get_with_shape(Shape2(inputs[0].shape_.Size(), 1), s); Tensor in_grad = outputs[0].get_with_shape(Shape2(outputs[0].shape_.Size(), 1), s); - mxnet_op::Kernel::Launch(s, batch_size, 1, 0, element_num, kWriteTo, + mxnet_op::Kernel::Launch(s, batch_size, 1, index_t{0}, element_num, kWriteTo, batch_shift.dptr_); if (do_transpose) { Tensor indices = inputs[2].FlatTo1D(s); @@ -639,13 +669,13 @@ void TopKBackwardImpl(const OpContext &ctx, mxnet::TShape(Shape3(src_shape[0], src_shape[2], k))), Shape3(0, 2, 1)), Shape1(batch_size * k)); - sel_indices += tcast(indices); + sel_indices += tcast(indices); sel_indices = transpose_indices(sel_indices, Shape3(src_shape[0], src_shape[2], src_shape[1]), Shape3(0, 2, 1)); } else { Tensor indices = inputs[2].get_with_shape(Shape2(batch_size, k), s); - sel_indices = reshape(tcast(indices) + + sel_indices = reshape(tcast(indices) + broadcast_to(inplace_reshape(batch_shift, Shape2(batch_size, 1)), mxnet::TShape(Shape2(batch_size, k))), Shape1(batch_size * k)); @@ -680,7 +710,7 @@ void TopKBackward_(const nnvm::NodeAttrs& attrs, }); } else if (param.ret_typ == topk_enum::kReturnValue) { MXNET_NO_FLOAT16_TYPE_SWITCH(inputs[0].type_flag_, DType, { - TopKBackwardImpl(ctx, inputs, req, outputs, param); + TopKBackwardImpl(ctx, inputs, req, outputs, param); }); } else { LOG(FATAL) << "Not Implemented"; @@ -715,14 +745,11 @@ inline bool TopKType(const nnvm::NodeAttrs& attrs, size_t out_size = out_attrs->size(); CHECK_EQ(in_size, 1); CHECK(out_size == 1 || out_size == 2); + // out_attr[0] -> stores value + // out_attr[1] -> stores indices if (out_size > 1) { - if (param.ret_typ == topk_enum::kReturnValue) { - CHECK(type_assign(&(*out_attrs)[1], mshadow::kInt32)) + CHECK(type_assign(&(*out_attrs)[1], param.dtype)) << "Failed to set the type of ret_indices."; - } else { - CHECK(type_assign(&(*out_attrs)[1], param.dtype)) - << "Failed to set the type of ret_indices."; - } } if (param.ret_typ == topk_enum::kReturnIndices) { CHECK(type_assign(&(*out_attrs)[0], param.dtype)) @@ -752,11 +779,12 @@ inline bool TopKShapeImpl(const TopKParam& param, CHECK_EQ(out_attrs->size(), 2U); } mxnet::TShape& in_shape = (*in_attrs)[0]; - int batch_size, element_num; // number of batches + the size of each batch + size_t batch_size = 0; + index_t element_num = 0; // number of batches + the size of each batch int axis = 0; bool do_transpose = false; bool is_ascend = false; - int k = 0; + index_t k = 0; mxnet::TShape target_shape; ParseTopKParam(in_shape, param, &target_shape, &batch_size, &element_num, &axis, &k, &do_transpose, &is_ascend); @@ -785,8 +813,12 @@ inline bool SortType(const nnvm::NodeAttrs& attrs, size_t out_size = out_attrs->size(); CHECK_EQ(in_size, 1); CHECK_EQ(out_size, 2); +#if MXNET_USE_INT64_TENSOR_SIZE == 1 + CHECK(type_assign(&(*out_attrs)[1], mshadow::kInt64)) +#else CHECK(type_assign(&(*out_attrs)[1], mshadow::kInt32)) - << "Failed to set the type of ret_indices to int32."; +#endif + << "Failed to set the type of ret_indices"; CHECK(type_assign(&data_type, (*in_attrs)[0])) << "Incompatible dtype of input, in_attrs[0]=" << (*in_attrs)[0]; CHECK(type_assign(&data_type, (*out_attrs)[0])) << "Incompatible dtype of output, out_attrs[0]=" @@ -816,7 +848,7 @@ inline bool ArgSortType(const nnvm::NodeAttrs& attrs, std::vector *out_attrs) { const ArgSortParam& param = nnvm::get(attrs.parsed); CHECK(type_assign(&(*out_attrs)[0], param.dtype)) - << "Failed to set the type of ret_indices to int32."; + << "Failed to set the type of ret_indices."; return true; } diff --git a/tests/nightly/test_large_array.py b/tests/nightly/test_large_array.py index cbba608d5d2f..0df481a01987 100644 --- a/tests/nightly/test_large_array.py +++ b/tests/nightly/test_large_array.py @@ -326,6 +326,32 @@ def test_softmax(): assert_almost_equal(output.asnumpy(), true_output, rtol=1e-5, atol=1e-5) +def test_argsort(): + b = create_2d_tensor(rows=LARGE_X, columns=SMALL_Y) + s = nd.argsort(b, axis=0, is_ascend=False, dtype=np.int64) + mx.nd.waitall() + assert (s[0].asnumpy() == (LARGE_X - 1)).all() + + +def test_sort(): + b = create_2d_tensor(rows=LARGE_X, columns=SMALL_Y) + s = nd.sort(b, axis=0, is_ascend=False) + assert np.sum(s[-1][SMALL_Y//2:SMALL_Y].asnumpy() == 0).all() + s = nd.sort(b, is_ascend=False) + assert np.sum(s[0].asnumpy() == 0).all() + + +def test_topk(): + b = create_2d_tensor(rows=LARGE_X, columns=SMALL_Y) + k = nd.topk(b, k=10, axis=0, dtype=np.int64) + assert np.sum(k.asnumpy() == (LARGE_X - 1)) == SMALL_Y + ind, val = mx.nd.topk(b, k=3, axis=0, dtype=np.int64, ret_typ="both", is_ascend=False) + assert np.all(ind == val) + b = create_2d_tensor(rows=SMALL_Y, columns=LARGE_X) + l = nd.topk(b, k=1, axis=-1, dtype=np.int64, ret_typ="value") + assert l.sum() == np.sum(np.arange(0, SMALL_Y)) + + if __name__ == '__main__': import nose nose.runmodule() diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py index e5315900c725..d84b4f082b63 100644 --- a/tests/python/unittest/test_ndarray.py +++ b/tests/python/unittest/test_ndarray.py @@ -29,6 +29,7 @@ from mxnet.test_utils import np_reduce from mxnet.test_utils import same from mxnet.test_utils import random_sample, rand_shape_nd +from mxnet import runtime from numpy.testing import assert_allclose import mxnet.autograd @@ -747,6 +748,7 @@ def test_linspace(): def test_order(): ctx = default_context() dat_size = 5 + is_large_tensor_enabled = runtime.Features().is_enabled('INT64_TENSOR_SIZE') def gt_topk(dat, axis, ret_typ, k, is_ascend): if ret_typ == "indices": if is_ascend: @@ -819,7 +821,11 @@ def get_large_matrix(): # test for ret_typ=indices nd_ret_topk = mx.nd.topk(a_nd, axis=1, ret_typ="indices", k=3, is_ascend=True).asnumpy() - assert nd_ret_topk.dtype == np.float32 # Test the default dtype + # Test the default dtype + if is_large_tensor_enabled: + assert nd_ret_topk.dtype == np.int64 + else: + assert nd_ret_topk.dtype == np.int32 gt = gt_topk(a_npy, axis=1, ret_typ="indices", k=3, is_ascend=True) assert_almost_equal(nd_ret_topk, gt) nd_ret_topk = mx.nd.topk(a_nd, axis=3, ret_typ="indices", k=2, is_ascend=False, dtype=np.float64).asnumpy() @@ -860,7 +866,10 @@ def get_large_matrix(): nd_ret_topk_val = nd_ret_topk_val.asnumpy() nd_ret_topk_ind = nd_ret_topk_ind.asnumpy() assert nd_ret_topk_val.dtype == dtype - assert nd_ret_topk_ind.dtype == np.float32 + if is_large_tensor_enabled: + assert nd_ret_topk_ind.dtype == np.int64 + else: + assert nd_ret_topk_ind.dtype == np.int32 gt_val = gt_topk(a_npy, axis=1, ret_typ="value", k=3, is_ascend=True) gt_ind = gt_topk(a_npy, axis=1, ret_typ="indices", k=3, is_ascend=True) assert_almost_equal(nd_ret_topk_val, gt_val)