diff --git a/src/operator/tensor/ordering_op-inl.h b/src/operator/tensor/ordering_op-inl.h index 98bca3a43c60..fdc915eb4ac0 100644 --- a/src/operator/tensor/ordering_op-inl.h +++ b/src/operator/tensor/ordering_op-inl.h @@ -88,11 +88,7 @@ struct TopKParam : public dmlc::Parameter { .add_enum("float16", mshadow::kFloat16) .add_enum("float32", mshadow::kFloat32) .add_enum("float64", mshadow::kFloat64) -#if MXNET_USE_INT64_TENSOR_SIZE == 1 - .set_default(mshadow::kInt64) -#else - .set_default(mshadow::kInt32) -#endif + .set_default(mshadow::kFloat32) .describe("DType of the output indices when ret_typ is \"indices\" or \"both\". " "An error will be raised if the selected data type cannot precisely represent the " "indices."); @@ -129,11 +125,7 @@ struct ArgSortParam : public dmlc::Parameter { .add_enum("float16", mshadow::kFloat16) .add_enum("float32", mshadow::kFloat32) .add_enum("float64", mshadow::kFloat64) -#if USE_INT64_TENSOR_SIZE == 1 - .set_default(mshadow::kInt64) -#else - .set_default(mshadow::kInt32) -#endif + .set_default(mshadow::kFloat32) .describe("DType of the output indices. It is only valid when ret_typ is \"indices\" or" " \"both\". An error will be raised if the selected data type cannot precisely " "represent the indices."); @@ -748,8 +740,17 @@ inline bool TopKType(const nnvm::NodeAttrs& attrs, // out_attr[0] -> stores value // out_attr[1] -> stores indices if (out_size > 1) { - CHECK(type_assign(&(*out_attrs)[1], param.dtype)) - << "Failed to set the type of ret_indices."; + if (param.ret_typ == topk_enum::kReturnValue) { +#if USE_INT64_TENSOR_SIZE == 1 + CHECK(type_assign(&(*out_attrs)[1], mshadow::kInt64)) +#else + CHECK(type_assign(&(*out_attrs)[1], mshadow::kInt32)) +#endif + << "Failed to set the type of ret_indices."; + } else { + CHECK(type_assign(&(*out_attrs)[1], param.dtype)) + << "Failed to set the type of ret_indices."; + } } if (param.ret_typ == topk_enum::kReturnIndices) { CHECK(type_assign(&(*out_attrs)[0], param.dtype)) diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py index d84b4f082b63..f40bb3053358 100644 --- a/tests/python/unittest/test_ndarray.py +++ b/tests/python/unittest/test_ndarray.py @@ -822,10 +822,7 @@ def get_large_matrix(): # test for ret_typ=indices nd_ret_topk = mx.nd.topk(a_nd, axis=1, ret_typ="indices", k=3, is_ascend=True).asnumpy() # Test the default dtype - if is_large_tensor_enabled: - assert nd_ret_topk.dtype == np.int64 - else: - assert nd_ret_topk.dtype == np.int32 + assert nd_ret_topk.dtype == np.float32 gt = gt_topk(a_npy, axis=1, ret_typ="indices", k=3, is_ascend=True) assert_almost_equal(nd_ret_topk, gt) nd_ret_topk = mx.nd.topk(a_nd, axis=3, ret_typ="indices", k=2, is_ascend=False, dtype=np.float64).asnumpy() @@ -866,10 +863,7 @@ def get_large_matrix(): nd_ret_topk_val = nd_ret_topk_val.asnumpy() nd_ret_topk_ind = nd_ret_topk_ind.asnumpy() assert nd_ret_topk_val.dtype == dtype - if is_large_tensor_enabled: - assert nd_ret_topk_ind.dtype == np.int64 - else: - assert nd_ret_topk_ind.dtype == np.int32 + assert nd_ret_topk_ind.dtype == np.float32 gt_val = gt_topk(a_npy, axis=1, ret_typ="value", k=3, is_ascend=True) gt_ind = gt_topk(a_npy, axis=1, ret_typ="indices", k=3, is_ascend=True) assert_almost_equal(nd_ret_topk_val, gt_val)