diff --git a/src/operator/tensor/broadcast_reduce_op.h b/src/operator/tensor/broadcast_reduce_op.h index 1edcb5a74a77..f22f263b6ee9 100644 --- a/src/operator/tensor/broadcast_reduce_op.h +++ b/src/operator/tensor/broadcast_reduce_op.h @@ -88,11 +88,11 @@ struct ReduceAxisParam : public dmlc::Parameter { dmlc::optional axis; bool keepdims; DMLC_DECLARE_PARAMETER(ReduceAxisParam) { - DMLC_DECLARE_FIELD(axis).set_default(dmlc::optional()) + DMLC_DECLARE_FIELD(axis).set_default(dmlc::optional(-1)) .describe("The axis along which to perform the reduction. " "Negative values means indexing from right to left. " - "``Requires axis to be set as int, because global reduction " - "is not supported yet.``"); + "``The axis need to be set as an int. If the axis is " + "not set, the rightmost axis will be reduced.``"); DMLC_DECLARE_FIELD(keepdims).set_default(false) .describe("If this is set to `True`, the reduced axis is left " "in the result as dimension with size one."); diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py index 0aa48553901b..897079dc39c8 100644 --- a/tests/python/unittest/test_ndarray.py +++ b/tests/python/unittest/test_ndarray.py @@ -477,7 +477,6 @@ def test_dot(): C = mx.nd.dot(A, B, transpose_a=True, transpose_b=True) assert_almost_equal(c, C.asnumpy(), atol=atol) - @with_seed() def test_reduce(): sample_num = 200 @@ -524,6 +523,31 @@ def test_reduce_inner(numpy_reduce_func, nd_reduce_func, multi_axes): keepdims:np_reduce(np.float32(data), axis, keepdims, np.argmin), mx.nd.argmin, False) +@with_seed() +def test_argmax_argmin(): + # test optional parameters + # test name : input data, argmax result, argmin result + tests = { + 'axis_0' : [[[1, 2, 3], [4, 5, 6]], [1, 1, 1], [0, 0, 0]], + 'keep_dims' : [[[1, 2, 3], [4, 5, 6]], [[2], [2]], [[0], [0]]], + 'axis_none' : [[1, 2, 3, 4], 3, 0] + } + + arg_max = mx.nd.array(tests['axis_0'][0]).argmax(axis=0) + arg_min = mx.nd.array(tests['axis_0'][0]).argmin(axis=0) + assert_almost_equal(arg_max.asnumpy(), tests['axis_0'][1]) + assert_almost_equal(arg_min.asnumpy(), tests['axis_0'][2]) + + arg_max = mx.nd.array(tests['keep_dims'][0]).argmax(axis=1, keepdims=True) + arg_min = mx.nd.array(tests['keep_dims'][0]).argmin(axis=1, keepdims=True) + assert_almost_equal(arg_max.asnumpy(), tests['keep_dims'][1]) + assert_almost_equal(arg_min.asnumpy(), tests['keep_dims'][2]) + + arg_max = mx.nd.array(tests['axis_none'][0]).argmax() + arg_min = mx.nd.array(tests['axis_none'][0]).argmin() + assert_almost_equal(arg_max.asnumpy(), tests['axis_none'][1]) + assert_almost_equal(arg_min.asnumpy(), tests['axis_none'][2]) + @with_seed() def test_broadcast(): sample_num = 1000