diff --git a/base/array.jl b/base/array.jl index 1e3ea06975a83c..34993c842bd529 100644 --- a/base/array.jl +++ b/base/array.jl @@ -1723,48 +1723,61 @@ findlast(testf::Function, A) = findprev(testf, A, endof(A)) """ find(f::Function, A) -Return a vector `I` of the linear indexes of `A` where `f(A[I])` returns `true`. +Return a vector `I` of the indices or keys of `A` where `f(A[I])` returns `true`. If there are no such elements of `A`, return an empty array. +Indices or keys are of the same type as those returned by [`keys(A)`](@ref) +and [`pairs(A)`](@ref) for `AbstractArray` and `Associative` objects, +and are linear indices of type `Int` for other iterables. + # Examples ```jldoctest -julia> A = [1 2 0; 3 4 0] -2×3 Array{Int64,2}: - 1 2 0 - 3 4 0 +julia> x = [1, 3, 4] +3-element Array{Int64,1}: + 1 + 3 + 4 -julia> find(isodd, A) +julia> find(isodd, x) 2-element Array{Int64,1}: 1 2 + julia> A = [1 2 0; 3 4 0] + 2×3 Array{Int64,2}: + 1 2 0 + 3 4 0 + + julia> find(isodd, A) + 2-element Array{CartesianIndex{2},1}: + CartesianIndex(1, 1) + CartesianIndex(2, 1) + julia> find(!iszero, A) -4-element Array{Int64,1}: - 1 - 2 - 3 - 4 +4-element Array{CartesianIndex{2},1}: +CartesianIndex(1, 1) +CartesianIndex(2, 1) +CartesianIndex(1, 2) +CartesianIndex(2, 2) + +julia> d = Dict(:A => 10, :B => -1, :C => 0) +Dict{Symbol,Int64} with 3 entries: + :A => 10 + :B => -1 + :C => 0 + +julia> find(x -> x >= 0, d) +2-element Array{Symbol,1}: + :A + :C -julia> find(isodd, [2, 4]) -0-element Array{Int64,1} ``` """ -function find(testf::Function, A) - # use a dynamic-length array to store the indexes, then copy to a non-padded - # array for the return - tmpI = Vector{Int}() - inds = _index_remapper(A) - for (i,a) = enumerate(A) - if testf(a) - push!(tmpI, inds[i]) - end - end - I = Vector{Int}(uninitialized, length(tmpI)) - copy!(I, tmpI) - return I -end -_index_remapper(A::AbstractArray) = linearindices(A) -_index_remapper(iter) = OneTo(typemax(Int)) # safe for objects that don't implement length +find(testf::Function, A) = + collect(first(p) for p in _pairs(A) if testf(last(p))) + +_pairs(A::Union{AbstractArray, Associative}) = pairs(A) +_pairs(iter) = zip(OneTo(typemax(Int)), iter) # safe for objects that don't implement length """ find(A) @@ -1789,22 +1802,10 @@ julia> find(falses(3)) ``` """ function find(A) - nnzA = count(t -> t != 0, A) - I = Vector{Int}(uninitialized, nnzA) - cnt = 1 - inds = _index_remapper(A) - warned = false - for (i,a) in enumerate(A) - if !warned && !(a isa Bool) - depwarn("In the future `find(A)` will only work on boolean collections. Use `find(x->x!=0, A)` instead.", :find) - warned = true - end - if a != 0 - I[cnt] = inds[i] - cnt += 1 - end + if !(eltype(A) === Bool) && !all(x -> x isa Bool, A) + depwarn("In the future `find(A)` will only work on boolean collections. Use `find(x->x!=0, A)` instead.", :find) end - return I + collect(first(p) for p in _pairs(A) if last(p) != 0) end find(x::Bool) = x ? [1] : Vector{Int}() diff --git a/base/sparse/sparsematrix.jl b/base/sparse/sparsematrix.jl index d9af41a9ab18b6..0e4ba14fb1b50b 100644 --- a/base/sparse/sparsematrix.jl +++ b/base/sparse/sparsematrix.jl @@ -1264,7 +1264,7 @@ function find(p::Function, S::SparseMatrixCSC) end sz = size(S) I, J = _findn(p, S) - return sub2ind(sz, I, J) + return CartesianIndex.(I, J) end findn(S::SparseMatrixCSC{Tv,Ti}) where {Tv,Ti} = _findn(x->x!=0, S) diff --git a/test/arrayops.jl b/test/arrayops.jl index 5a8e2c2107cc07..7306a8ff0ded65 100644 --- a/test/arrayops.jl +++ b/test/arrayops.jl @@ -464,6 +464,16 @@ end @test findprev(isodd, [2,4,5,3,9,2,0], 7) == 5 @test findprev(isodd, [2,4,5,3,9,2,0], 2) == 0 end +@testset "find with Matrix" begin + A = [1 2 0; 3 4 0] + @test find(isodd, A) == [CartesianIndex(1, 1), CartesianIndex(2, 1)] + @test find(!iszero, A) == [CartesianIndex(1, 1), CartesianIndex(2, 1), + CartesianIndex(1, 2), CartesianIndex(2, 2)] +end +@testset "find with Dict" begin + d = Dict(:A => 10, :B => -1, :C => 0) + @test find(x -> x >= 0, d) == [:A, :C] +end @testset "find with general iterables" begin s = "julia" @test find(c -> c == 'l', s) == [3]