diff --git a/julia/src/ndarray.jl b/julia/src/ndarray.jl index dad9b59e8210..6987d572ea7a 100644 --- a/julia/src/ndarray.jl +++ b/julia/src/ndarray.jl @@ -1273,6 +1273,67 @@ Base.prod(x::NDArray; dims = :) = _prod(x, dims) @_remap _prod(x::NDArray, ::Colon) prod(x) @_remap _prod(x::NDArray, dims) prod(x; axis = 0 .- dims, keepdims = true) +# TODO: support CartesianIndex ? +""" + argmax(x::NDArray; dims) -> indices + +Note that `NaN` is skipped during comparison. +This is different from Julia `Base.argmax`. + +## Examples + +```julia-repl +julia> x = NDArray([0. 1 2; 3 4 5]) +2×3 NDArray{Float64,2} @ CPU0: + 0.0 1.0 2.0 + 3.0 4.0 5.0 + +julia> argmax(x, dims = 1) +1×3 NDArray{Float64,2} @ CPU0: + 2.0 2.0 2.0 + +julia> argmax(x, dims = 2) +2×1 NDArray{Float64,2} @ CPU0: + 3.0 + 3.0 +``` + +See also [`argmin`](@ref mx.argmin). +""" +Base.argmax(x::NDArray; dims = :) = _argmax(x, dims) .+ 1 +@_remap _argmax(x::NDArray, ::Colon) argmax(x) +@_remap _argmax(x::NDArray, dims) argmax(x; axis = 0 .- dims, keepdims = true) + +""" + argmin(x::NDArray; dims) -> indices + +Note that `NaN` is skipped during comparison. +This is different from Julia `Base.argmin`. + +## Examples + +```julia-repl +julia> x = NDArray([0. 1 2; 3 4 5]) +2×3 NDArray{Float64,2} @ CPU0: + 0.0 1.0 2.0 + 3.0 4.0 5.0 + +julia> argmax(x, dims = 1) +1×3 NDArray{Float64,2} @ CPU0: + 2.0 2.0 2.0 + +julia> argmax(x, dims = 2) +2×1 NDArray{Float64,2} @ CPU0: + 3.0 + 3.0 +``` + +See also [`argmax`](@ref mx.argmax). +""" +Base.argmin(x::NDArray; dims = :) = _argmin(x, dims) .+ 1 +@_remap _argmin(x::NDArray, ::Colon) argmin(x) +@_remap _argmin(x::NDArray, dims) argmin(x; axis = 0 .- dims, keepdims = true) + _nddoc[:clip] = _nddoc[:clip!] = """ clip(x::NDArray, min, max) @@ -1734,6 +1795,10 @@ const _op_import_bl = [ # import black list; do not import these funcs "broadcast_axis", "broadcast_axes", "broadcast_hypot", + + # reduction + "argmax", + "argmin", ] macro _import_ndarray_functions() diff --git a/julia/test/unittest/ndarray.jl b/julia/test/unittest/ndarray.jl index 9ca4ba206027..85328ff21bc8 100644 --- a/julia/test/unittest/ndarray.jl +++ b/julia/test/unittest/ndarray.jl @@ -1434,6 +1434,50 @@ function test_hypot() @test copy(z) == C end # function test_hypot +function test_argmax() + @info "NDArray::argmax" + let + A = [1. 5 3; + 4 2 6] + x = NDArray(A) + + @test copy(argmax(x, dims = 1)) == [2 1 2] + @test copy(argmax(x, dims = 2)) == reshape([2, 3], :, 1) + end + + @info "NDArray::argmax::NaN" + let + A = [1. 5 3; + NaN 2 6] + x = NDArray(A) + + @test copy(argmax(x, dims = 1)) == [1 1 2] + @test copy(argmax(x, dims = 2)) == reshape([2, 3], :, 1) + end +end + +function test_argmin() + @info "NDArray::argmin" + let + A = [1. 5 3; + 4 2 6] + x = NDArray(A) + + @test copy(argmin(x, dims = 1)) == [1 2 1] + @test copy(argmin(x, dims = 2)) == reshape([1, 2], :, 1) + end + + @info "NDArray::argmin::NaN" + let + A = [1. 5 3; + NaN 2 6] + x = NDArray(A) + + @test copy(argmin(x, dims = 1)) == [1 2 1] + @test copy(argmin(x, dims = 2)) == reshape([1, 2], :, 1) + end +end + ################################################################################ # Run tests ################################################################################ @@ -1479,6 +1523,8 @@ end # function test_hypot test_broadcast_to() test_broadcast_axis() test_hypot() + test_argmax() + test_argmin() end end