Skip to content

Commit

Permalink
julia: fix argmax for NDArray (apache#13871)
Browse files Browse the repository at this point in the history
- fix 0-based index output to 1-based index

close apache#13786
  • Loading branch information
iblislin authored and stephenrawls committed Feb 16, 2019
1 parent 8147b7a commit d44ef61
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 0 deletions.
65 changes: 65 additions & 0 deletions julia/src/ndarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1273,6 +1273,67 @@ Base.prod(x::NDArray; dims = :) = _prod(x, dims)
@_remap _prod(x::NDArray, ::Colon) prod(x)
@_remap _prod(x::NDArray, dims) prod(x; axis = 0 .- dims, keepdims = true)

# TODO: support CartesianIndex ?
"""
argmax(x::NDArray; dims) -> indices
Note that `NaN` is skipped during comparison.
This is different from Julia `Base.argmax`.
## Examples
```julia-repl
julia> x = NDArray([0. 1 2; 3 4 5])
2×3 NDArray{Float64,2} @ CPU0:
0.0 1.0 2.0
3.0 4.0 5.0
julia> argmax(x, dims = 1)
1×3 NDArray{Float64,2} @ CPU0:
2.0 2.0 2.0
julia> argmax(x, dims = 2)
2×1 NDArray{Float64,2} @ CPU0:
3.0
3.0
```
See also [`argmin`](@ref mx.argmin).
"""
Base.argmax(x::NDArray; dims = :) = _argmax(x, dims) .+ 1
@_remap _argmax(x::NDArray, ::Colon) argmax(x)
@_remap _argmax(x::NDArray, dims) argmax(x; axis = 0 .- dims, keepdims = true)

"""
argmin(x::NDArray; dims) -> indices
Note that `NaN` is skipped during comparison.
This is different from Julia `Base.argmin`.
## Examples
```julia-repl
julia> x = NDArray([0. 1 2; 3 4 5])
2×3 NDArray{Float64,2} @ CPU0:
0.0 1.0 2.0
3.0 4.0 5.0
julia> argmax(x, dims = 1)
1×3 NDArray{Float64,2} @ CPU0:
2.0 2.0 2.0
julia> argmax(x, dims = 2)
2×1 NDArray{Float64,2} @ CPU0:
3.0
3.0
```
See also [`argmax`](@ref mx.argmax).
"""
Base.argmin(x::NDArray; dims = :) = _argmin(x, dims) .+ 1
@_remap _argmin(x::NDArray, ::Colon) argmin(x)
@_remap _argmin(x::NDArray, dims) argmin(x; axis = 0 .- dims, keepdims = true)

_nddoc[:clip] = _nddoc[:clip!] =
"""
clip(x::NDArray, min, max)
Expand Down Expand Up @@ -1734,6 +1795,10 @@ const _op_import_bl = [ # import black list; do not import these funcs
"broadcast_axis",
"broadcast_axes",
"broadcast_hypot",

# reduction
"argmax",
"argmin",
]

macro _import_ndarray_functions()
Expand Down
46 changes: 46 additions & 0 deletions julia/test/unittest/ndarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1434,6 +1434,50 @@ function test_hypot()
@test copy(z) == C
end # function test_hypot

function test_argmax()
@info "NDArray::argmax"
let
A = [1. 5 3;
4 2 6]
x = NDArray(A)

@test copy(argmax(x, dims = 1)) == [2 1 2]
@test copy(argmax(x, dims = 2)) == reshape([2, 3], :, 1)
end

@info "NDArray::argmax::NaN"
let
A = [1. 5 3;
NaN 2 6]
x = NDArray(A)

@test copy(argmax(x, dims = 1)) == [1 1 2]
@test copy(argmax(x, dims = 2)) == reshape([2, 3], :, 1)
end
end

function test_argmin()
@info "NDArray::argmin"
let
A = [1. 5 3;
4 2 6]
x = NDArray(A)

@test copy(argmin(x, dims = 1)) == [1 2 1]
@test copy(argmin(x, dims = 2)) == reshape([1, 2], :, 1)
end

@info "NDArray::argmin::NaN"
let
A = [1. 5 3;
NaN 2 6]
x = NDArray(A)

@test copy(argmin(x, dims = 1)) == [1 2 1]
@test copy(argmin(x, dims = 2)) == reshape([1, 2], :, 1)
end
end

################################################################################
# Run tests
################################################################################
Expand Down Expand Up @@ -1479,6 +1523,8 @@ end # function test_hypot
test_broadcast_to()
test_broadcast_axis()
test_hypot()
test_argmax()
test_argmin()
end

end

0 comments on commit d44ef61

Please sign in to comment.