Skip to content

Commit

Permalink
fix #39203, 2-arg findmax should return index instead of value (#41076
Browse files Browse the repository at this point in the history
)
  • Loading branch information
JeffBezanson authored Jun 7, 2021
1 parent 355b66a commit a86cb62
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 36 deletions.
48 changes: 24 additions & 24 deletions base/reduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -771,11 +771,11 @@ minimum(a; kw...) = mapreduce(identity, min, a; kw...)
## findmax, findmin, argmax & argmin

"""
findmax(f, domain) -> (f(x), x)
findmax(f, domain) -> (f(x), index)
Returns a pair of a value in the codomain (outputs of `f`) and the corresponding
value in the `domain` (inputs to `f`) such that `f(x)` is maximised. If there
are multiple maximal points, then the first one will be returned.
Returns a pair of a value in the codomain (outputs of `f`) and the index of
the corresponding value in the `domain` (inputs to `f`) such that `f(x)` is maximised.
If there are multiple maximal points, then the first one will be returned.
`domain` must be a non-empty iterable.
Expand All @@ -788,20 +788,20 @@ Values are compared with `isless`.
```jldoctest
julia> findmax(identity, 5:9)
(9, 9)
(9, 5)
julia> findmax(-, 1:10)
(-1, 1)
julia> findmax(first, [(1, :a), (2, :b), (2, :c)])
(2, (2, :b))
julia> findmax(first, [(1, :a), (3, :b), (3, :c)])
(3, 2)
julia> findmax(cos, 0:π/2:2π)
(1.0, 0.0)
(1.0, 1)
```
"""
findmax(f, domain) = mapfoldl(x -> (f(x), x), _rf_findmax, domain)
_rf_findmax((fm, m), (fx, x)) = isless(fm, fx) ? (fx, x) : (fm, m)
findmax(f, domain) = mapfoldl( ((k, v),) -> (f(v), k), _rf_findmax, pairs(domain) )
_rf_findmax((fm, im), (fx, ix)) = isless(fm, fx) ? (fx, ix) : (fm, im)

"""
findmax(itr) -> (x, index)
Expand All @@ -826,14 +826,14 @@ julia> findmax([1, 7, 7, NaN])
```
"""
findmax(itr) = _findmax(itr, :)
_findmax(a, ::Colon) = mapfoldl( ((k, v),) -> (v, k), _rf_findmax, pairs(a) )
_findmax(a, ::Colon) = findmax(identity, a)

"""
findmin(f, domain) -> (f(x), x)
findmin(f, domain) -> (f(x), index)
Returns a pair of a value in the codomain (outputs of `f`) and the corresponding
value in the `domain` (inputs to `f`) such that `f(x)` is minimised. If there
are multiple minimal points, then the first one will be returned.
Returns a pair of a value in the codomain (outputs of `f`) and the index of
the corresponding value in the `domain` (inputs to `f`) such that `f(x)` is minimised.
If there are multiple minimal points, then the first one will be returned.
`domain` must be a non-empty iterable.
Expand All @@ -846,21 +846,21 @@ are multiple minimal points, then the first one will be returned.
```jldoctest
julia> findmin(identity, 5:9)
(5, 5)
(5, 1)
julia> findmin(-, 1:10)
(-10, 10)
julia> findmin(first, [(1, :a), (1, :b), (2, :c)])
(1, (1, :a))
julia> findmin(first, [(2, :a), (2, :b), (3, :c)])
(2, 1)
julia> findmin(cos, 0:π/2:2π)
(-1.0, 3.141592653589793)
(-1.0, 3)
```
"""
findmin(f, domain) = mapfoldl(x -> (f(x), x), _rf_findmin, domain)
_rf_findmin((fm, m), (fx, x)) = isgreater(fm, fx) ? (fx, x) : (fm, m)
findmin(f, domain) = mapfoldl( ((k, v),) -> (f(v), k), _rf_findmin, pairs(domain) )
_rf_findmin((fm, im), (fx, ix)) = isgreater(fm, fx) ? (fx, ix) : (fm, im)

"""
findmin(itr) -> (x, index)
Expand All @@ -885,7 +885,7 @@ julia> findmin([1, 7, 7, NaN])
```
"""
findmin(itr) = _findmin(itr, :)
_findmin(a, ::Colon) = mapfoldl( ((k, v),) -> (v, k), _rf_findmin, pairs(a) )
_findmin(a, ::Colon) = findmin(identity, a)

"""
argmax(f, domain)
Expand All @@ -909,7 +909,7 @@ julia> argmax(cos, 0:π/2:2π)
0.0
```
"""
argmax(f, domain) = findmax(f, domain)[2]
argmax(f, domain) = mapfoldl(x -> (f(x), x), _rf_findmax, domain)[2]

"""
argmax(itr)
Expand Down Expand Up @@ -962,7 +962,7 @@ julia> argmin(acos, 0:0.1:1)
1.0
```
"""
argmin(f, domain) = findmin(f, domain)[2]
argmin(f, domain) = mapfoldl(x -> (f(x), x), _rf_findmin, domain)[2]

"""
argmin(itr)
Expand Down
24 changes: 12 additions & 12 deletions test/reduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -391,22 +391,22 @@ end

@testset "findmin(f, domain)" begin
@test findmin(-, 1:10) == (-10, 10)
@test findmin(identity, [1, 2, 3, missing]) === (missing, missing)
@test findmin(identity, [1, NaN, 3, missing]) === (missing, missing)
@test findmin(identity, [1, missing, NaN, 3]) === (missing, missing)
@test findmin(identity, [1, NaN, 3]) === (NaN, NaN)
@test findmin(identity, [1, 3, NaN]) === (NaN, NaN)
@test all(findmin(cos, 0:π/2:2π) .≈ (-1.0, π))
@test findmin(identity, [1, 2, 3, missing]) === (missing, 4)
@test findmin(identity, [1, NaN, 3, missing]) === (missing, 4)
@test findmin(identity, [1, missing, NaN, 3]) === (missing, 2)
@test findmin(identity, [1, NaN, 3]) === (NaN, 2)
@test findmin(identity, [1, 3, NaN]) === (NaN, 3)
@test findmin(cos, 0:π/2:2π) == (-1.0, 3)
end

@testset "findmax(f, domain)" begin
@test findmax(-, 1:10) == (-1, 1)
@test findmax(identity, [1, 2, 3, missing]) === (missing, missing)
@test findmax(identity, [1, NaN, 3, missing]) === (missing, missing)
@test findmax(identity, [1, missing, NaN, 3]) === (missing, missing)
@test findmax(identity, [1, NaN, 3]) === (NaN, NaN)
@test findmax(identity, [1, 3, NaN]) === (NaN, NaN)
@test findmax(cos, 0:π/2:2π) == (1.0, 0.0)
@test findmax(identity, [1, 2, 3, missing]) === (missing, 4)
@test findmax(identity, [1, NaN, 3, missing]) === (missing, 4)
@test findmax(identity, [1, missing, NaN, 3]) === (missing, 2)
@test findmax(identity, [1, NaN, 3]) === (NaN, 2)
@test findmax(identity, [1, 3, NaN]) === (NaN, 3)
@test findmax(cos, 0:π/2:2π) == (1.0, 1)
end

@testset "argmin(f, domain)" begin
Expand Down

0 comments on commit a86cb62

Please sign in to comment.