Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Revert default return type for indices in argsort() and topk() back t…
Browse files Browse the repository at this point in the history
…o float32
  • Loading branch information
Rohit Kumar Srivastava committed Jun 25, 2019
1 parent 7fe478a commit df1df8f
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 18 deletions.
12 changes: 2 additions & 10 deletions src/operator/tensor/ordering_op-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,7 @@ struct TopKParam : public dmlc::Parameter<TopKParam> {
.add_enum("float16", mshadow::kFloat16)
.add_enum("float32", mshadow::kFloat32)
.add_enum("float64", mshadow::kFloat64)
#if MXNET_USE_INT64_TENSOR_SIZE == 1
.set_default(mshadow::kInt64)
#else
.set_default(mshadow::kInt32)
#endif
.set_default(mshadow::kFloat32)
.describe("DType of the output indices when ret_typ is \"indices\" or \"both\". "
"An error will be raised if the selected data type cannot precisely represent the "
"indices.");
Expand Down Expand Up @@ -129,11 +125,7 @@ struct ArgSortParam : public dmlc::Parameter<ArgSortParam> {
.add_enum("float16", mshadow::kFloat16)
.add_enum("float32", mshadow::kFloat32)
.add_enum("float64", mshadow::kFloat64)
#if USE_INT64_TENSOR_SIZE == 1
.set_default(mshadow::kInt64)
#else
.set_default(mshadow::kInt32)
#endif
.set_default(mshadow::kFloat32)
.describe("DType of the output indices. It is only valid when ret_typ is \"indices\" or"
" \"both\". An error will be raised if the selected data type cannot precisely "
"represent the indices.");
Expand Down
10 changes: 2 additions & 8 deletions tests/python/unittest/test_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -822,10 +822,7 @@ def get_large_matrix():
# test for ret_typ=indices
nd_ret_topk = mx.nd.topk(a_nd, axis=1, ret_typ="indices", k=3, is_ascend=True).asnumpy()
# Test the default dtype
if is_large_tensor_enabled:
assert nd_ret_topk.dtype == np.int64
else:
assert nd_ret_topk.dtype == np.int32
assert nd_ret_topk.dtype == np.float32
gt = gt_topk(a_npy, axis=1, ret_typ="indices", k=3, is_ascend=True)
assert_almost_equal(nd_ret_topk, gt)
nd_ret_topk = mx.nd.topk(a_nd, axis=3, ret_typ="indices", k=2, is_ascend=False, dtype=np.float64).asnumpy()
Expand Down Expand Up @@ -866,10 +863,7 @@ def get_large_matrix():
nd_ret_topk_val = nd_ret_topk_val.asnumpy()
nd_ret_topk_ind = nd_ret_topk_ind.asnumpy()
assert nd_ret_topk_val.dtype == dtype
if is_large_tensor_enabled:
assert nd_ret_topk_ind.dtype == np.int64
else:
assert nd_ret_topk_ind.dtype == np.int32
assert nd_ret_topk_ind.dtype == np.float32
gt_val = gt_topk(a_npy, axis=1, ret_typ="value", k=3, is_ascend=True)
gt_ind = gt_topk(a_npy, axis=1, ret_typ="indices", k=3, is_ascend=True)
assert_almost_equal(nd_ret_topk_val, gt_val)
Expand Down

0 comments on commit df1df8f

Please sign in to comment.