From 35d280d1c8560c41aca60ebeea8f72c2684496ca Mon Sep 17 00:00:00 2001 From: Rafael Fourquet Date: Tue, 8 Aug 2017 20:48:18 +0200 Subject: [PATCH 1/3] add few missing methods for set-like operations --- base/array.jl | 32 ------ base/bitset.jl | 41 ++------ base/set.jl | 181 ++++++++++++++++++++-------------- doc/src/stdlib/collections.md | 4 +- test/bitset.jl | 2 +- test/sets.jl | 153 ++++++++++++++++------------ 6 files changed, 209 insertions(+), 204 deletions(-) diff --git a/base/array.jl b/base/array.jl index 4a059886b674a..f4f21791768c1 100644 --- a/base/array.jl +++ b/base/array.jl @@ -2253,22 +2253,6 @@ function union(vs...) end # setdiff only accepts two args -""" - setdiff(a, b) - -Construct the set of elements in `a` but not `b`. Maintains order with arrays. Note that -both arguments must be collections, and both will be iterated over. In particular, -`setdiff(set,element)` where `element` is a potential member of `set`, will not work in -general. - -# Examples -```jldoctest -julia> setdiff([1,2,3],[3,4,5]) -2-element Array{Int64,1}: - 1 - 2 -``` -""" function setdiff(a, b) args_type = promote_type(eltype(a), eltype(b)) bset = Set(b) @@ -2287,21 +2271,5 @@ end # recursing. Has the advantage of keeping order, too, but # not as fast as other methods that make a single pass and # store counts with a Dict. -symdiff(a) = a symdiff(a, b) = union(setdiff(a,b), setdiff(b,a)) -""" - symdiff(a, b, rest...) - -Construct the symmetric difference of elements in the passed in sets or arrays. -Maintains order with arrays. - -# Examples -```jldoctest -julia> symdiff([1,2,3],[3,4,5],[4,5,6]) -3-element Array{Int64,1}: - 1 - 2 - 6 -``` -""" symdiff(a, b, rest...) 
= symdiff(a, symdiff(b, rest...)) diff --git a/base/bitset.jl b/base/bitset.jl index b01e5b09a095b..f2c07235e3ae1 100644 --- a/base/bitset.jl +++ b/base/bitset.jl @@ -49,6 +49,8 @@ function copy!(dest::BitSet, src::BitSet) dest end +copymutable(s::BitSet) = copy(s) + eltype(s::BitSet) = Int sizehint!(s::BitSet, n::Integer) = (sizehint!(s.bits, (n+63) >> 6); s) @@ -253,49 +255,18 @@ isempty(s::BitSet) = _check0(s.bits, 1, length(s.bits)) # Mathematical set functions: union!, intersect!, setdiff!, symdiff! -union(s::BitSet) = copy(s) -union(s1::BitSet, s2::BitSet) = union!(copy(s1), s2) -union(s1::BitSet, ss::BitSet...) = union(s1, union(ss...)) -union(s::BitSet, ns) = union!(copy(s), ns) -union!(s::BitSet, ns) = (for n in ns; push!(s, n); end; s) +union(s::BitSet, sets...) = union!(copy(s), sets...) +union!(s::BitSet, ns) = foldl(push!, s, ns) union!(s1::BitSet, s2::BitSet) = _matched_map!(|, s1, s2) -intersect(s1::BitSet) = copy(s1) -intersect(s1::BitSet, ss::BitSet...) = intersect(s1, intersect(ss...)) -function intersect(s1::BitSet, ns) - s = BitSet() - for n in ns - n in s1 && push!(s, n) - end - s -end intersect(s1::BitSet, s2::BitSet) = length(s1.bits) < length(s2.bits) ? intersect!(copy(s1), s2) : intersect!(copy(s2), s1) -""" - intersect!(s1::BitSet, s2::BitSet) -Intersects sets `s1` and `s2` and overwrites the set `s1` with the result. If needed, `s1` -will be expanded to the size of `s2`. -""" intersect!(s1::BitSet, s2::BitSet) = _matched_map!(&, s1, s2) -setdiff(s::BitSet, ns) = setdiff!(copy(s), ns) -setdiff!(s::BitSet, ns) = (for n in ns; delete!(s, n); end; s) setdiff!(s1::BitSet, s2::BitSet) = _matched_map!((p, q) -> p & ~q, s1, s2) -symdiff(s::BitSet, ns) = symdiff!(copy(s), ns) -""" - symdiff!(s, itr) - -For each element in `itr`, destructively toggle its inclusion in set `s`. 
-""" -symdiff!(s::BitSet, ns) = (for n in ns; int_symdiff!(s, n); end; s) -""" - symdiff!(s, n) - -The set `s` is destructively modified to toggle the inclusion of integer `n`. -""" -symdiff!(s::BitSet, n::Integer) = int_symdiff!(s, n) +symdiff!(s::BitSet, ns) = foldl(int_symdiff!, s, ns) function int_symdiff!(s::BitSet, n::Integer) n0 = _check_bitset_bounds(n) @@ -306,6 +277,8 @@ end symdiff!(s1::BitSet, s2::BitSet) = _matched_map!(xor, s1, s2) +filter!(f, s::BitSet) = unsafe_filter!(f, s) + @inline in(n::Int, s::BitSet) = _bits_getindex(s.bits, n, s.offset) @inline in(n::Integer, s::BitSet) = _is_convertible_Int(n) ? in(Int(n), s) : false diff --git a/base/set.jl b/base/set.jl index dea43bc4b665f..6544c9a8c32bd 100644 --- a/base/set.jl +++ b/base/set.jl @@ -7,10 +7,12 @@ mutable struct Set{T} <: AbstractSet{T} Set{T}() where {T} = new(Dict{T,Void}()) Set{T}(s::Set{T}) where {T} = new(Dict{T,Void}(s.dict)) - Set{T}(itr) where {T} = union!(new(Dict{T,Void}()), itr) end + +Set{T}(itr) where {T} = union!(Set{T}(), itr) Set() = Set{Any}() + """ Set([itr]) @@ -51,7 +53,10 @@ end delete!(s::Set, x) = (delete!(s.dict, x); s) -copy(s::Set{T}) where T = Set{T}(s) +copy(s::Set) = copymutable(s) + +# Set is the default mutable fall-back +copymutable(s::AbstractSet{T}) where {T} = Set{T}(s) sizehint!(s::Set, newsz) = (sizehint!(s.dict, newsz); s) empty!(s::Set) = (empty!(s.dict); s) @@ -63,10 +68,10 @@ done(s::Set, state) = done(s.dict, state) next(s::Set, i) = (s.dict.keys[i], skip_deleted(s.dict, i+1)) """ - union(s1,s2...) - ∪(s1,s2...) + union(s, itrs...) + ∪(s, itrs...) -Construct the union of two or more sets. Maintains order with arrays. +Construct the union of sets. Maintain order with arrays. # Examples ```jldoctest @@ -92,21 +97,14 @@ julia> union([4, 2], [1, 2]) """ function union end -union(s::Set) = copy(s) -function union(s::Set, sets...) 
- u = Set{join_eltype(s, sets...)}() - union!(u, s) - for t in sets - union!(u, t) - end - return u -end +union(s::Set, sets...) = union!(Set{join_eltype(s, sets...)}(), s, sets...) + const ∪ = union """ - union!(s, iterable) + union!(s::AbstractSet, itrs...) -Union each element of `iterable` into set `s` in-place. +Construct the union of passed in sets and overwrite `s` with the result. # Examples ```jldoctest @@ -118,24 +116,26 @@ julia> a Set([7, 4, 3, 5, 1]) ``` """ -function union!(s::Set{T}, xs) where T - haslength(xs) && sizehint!(s, length(xs)) - for x=xs +function union!(s::Set{T}, itr) where T + haslength(itr) && sizehint!(s, length(itr)) + for x=itr push!(s, x) length(s) == max_values(T) && break end s end +union!(s::AbstractSet, sets...) = foldl(union!, s, sets) + join_eltype() = Bottom join_eltype(v1, vs...) = typejoin(eltype(v1), join_eltype(vs...)) """ - intersect(s1,s2...) - ∩(s1,s2) + intersect(s, itrs...) + ∩(s, itrs...) -Construct the intersection of two or more sets. -Maintains order and multiplicity of the first argument for arrays and ranges. +Construct the intersection of sets. +Maintain order and multiplicity of the first argument for arrays and ranges. # Examples ```jldoctest @@ -152,37 +152,52 @@ julia> intersect([1, 4, 4, 5, 6], [4, 6, 6, 7, 8]) """ function intersect end -intersect(s::Set) = copy(s) -function intersect(s::Set, sets...) - i = similar(s) - for x in s - inall = true - for t in sets - if !in(x, t) - inall = false - break - end - end - inall && push!(i, x) +intersect(s) = copymutable(s) + +# enable when as fast as the loop: +# intersect(s::AbstractSet, itr) = IntSet(x for x in itr if x in s) +function intersect(s::AbstractSet, itr) + t = similar(s) + for x in itr + x in s && push!(t, x) end - return i + t end + +intersect(s::AbstractSet, itr, itrs...) = intersect!(intersect(s, itr), itrs...) 
+ const ∩ = intersect -function setdiff(a::Set, b) - d = similar(a) - for x in a - if !(x in b) - push!(d, x) - end - end - d -end +""" + intersect!(s::AbstractSet, itrs...) + +Intersect all passed in sets and overwrite `s` with the result. +""" +intersect!(s::AbstractSet, s2::AbstractSet) = filter!(x -> x in s2, s) +intersect!(s::AbstractSet, itr) = intersect!(s, union!(similar(s), itr)) +intersect!(s::AbstractSet, itrs...) = foldl(intersect!, s, itrs) + +""" + setdiff(s, itrs...) + +Construct the set of elements in `s` but not in any of the iterables in `itrs`. +Maintain order with arrays. + +# Examples +```jldoctest +julia> setdiff([1,2,3], [3,4,5]) +2-element Array{Int64,1}: + 1 + 2 +``` +""" +setdiff(s::AbstractSet, itrs...) = setdiff!(copymutable(s), itrs...) +setdiff(s) = copymutable(s) """ - setdiff!(s, iterable) + setdiff!(s, itrs...) -Remove each element of `iterable` from set `s` in-place. +Remove from set `s` (in-place) each element of each iterable from `itrs`. # Examples ```jldoctest @@ -194,7 +209,41 @@ julia> a Set([4]) ``` """ -setdiff!(s::Set, xs) = (for x=xs; delete!(s, x); end; s) +setdiff!(s::AbstractSet, itrs...) = foldl(setdiff!, s, itrs) +setdiff!(s::AbstractSet, itr) = foldl(delete!, s, itr) + + +""" + symdiff(s, itrs...) + +Construct the symmetric difference of elements in the passed in sets. +Maintains order with arrays. + +# Examples +```jldoctest +julia> symdiff([1,2,3], [3,4,5], [4,5,6]) +3-element Array{Int64,1}: + 1 + 2 + 6 +``` +""" +symdiff(s::AbstractSet, sets...) = symdiff!(copymutable(s), sets...) +symdiff(s) = copymutable(s) # remove when method above becomes as efficient + +""" + symdiff!(s::AbstractSet, itrs...) + +Construct the symmetric difference of the passed in sets, and overwrite `s` with the result. +""" +symdiff!(s::AbstractSet, itrs...) = foldl(symdiff!, s, itrs) + +function symdiff!(s::AbstractSet, itr) + for x in itr + x in s ? 
delete!(s, x) : push!(s, x) + end + s +end ==(l::Set, r::Set) = (length(l) == length(r)) && (l <= r) <( l::Set, r::Set) = (length(l) < length(r)) && (l <= r) @@ -217,14 +266,7 @@ julia> issubset([1, 2, 3], [1, 2]) false ``` """ -function issubset(l, r) - for elt in l - if !in(elt, r) - return false - end - end - return true -end +issubset(l, r) = all(x -> x in r, l) const ⊆ = issubset ⊊(l::Set, r::Set) = <(l, r) ⊈(l::Set, r::Set) = !⊆(l, r) @@ -232,6 +274,11 @@ const ⊆ = issubset ⊉(l::Set, r::Set) = !⊇(l, r) ⊋(l::Set, r::Set) = <(r, l) +⊊(l::T, r::T) where {T<:AbstractSet} = <(l, r) +⊈(l::T, r::T) where {T<:AbstractSet} = !⊆(l, r) +⊉(l::T, r::T) where {T<:AbstractSet} = !⊇(l, r) +⊋(l::T, r::T) where {T<:AbstractSet} = <(r, l) + """ unique(itr) @@ -445,23 +492,11 @@ allunique(::Set) = true allunique(r::AbstractRange{T}) where {T} = (step(r) != zero(T)) || (length(r) <= 1) -function filter(f, s::Set) - u = similar(s) - for x in s - if f(x) - push!(u, x) - end - end - return u -end -function filter!(f, s::Set) - for x in s - if !f(x) - delete!(s, x) - end - end - return s -end +filter(f, s::AbstractSet) = foldl(push!, similar(s), Iterators.filter(f, s)) + +# it must be safe to delete the current element while iterating over s +unsafe_filter!(f, s::AbstractSet) = foldl(delete!, s, Iterators.filter(!f, s)) +filter!(f, s::Set) = unsafe_filter!(f, s) const hashs_seed = UInt === UInt64 ? 0x852ada37cfe8e0ce : 0xcfe8e0ce function hash(s::Set, h::UInt) diff --git a/doc/src/stdlib/collections.md b/doc/src/stdlib/collections.md index b3a5a96deb01c..fd140e8e6ff44 100644 --- a/doc/src/stdlib/collections.md +++ b/doc/src/stdlib/collections.md @@ -234,9 +234,7 @@ Base.intersect Base.setdiff Base.setdiff! Base.symdiff -Base.symdiff!(::BitSet, ::Integer) -Base.symdiff!(::BitSet, ::Any) -Base.symdiff!(::BitSet, ::BitSet) +Base.symdiff! Base.intersect! 
Base.issubset ``` diff --git a/test/bitset.jl b/test/bitset.jl index 86e6c96b09ce5..3c7eeffc60997 100644 --- a/test/bitset.jl +++ b/test/bitset.jl @@ -195,7 +195,7 @@ end @test intersect(BitSet([1,2,3])) == BitSet([1,2,3]) @test intersect(BitSet(1:7), BitSet(3:10)) == intersect(BitSet(3:10), BitSet(1:7)) == BitSet(3:7) - @test intersect(BitSet(1:10), BitSet(1:4), 1:5, [2,3,10]) == [2,3] + @test intersect(BitSet(1:10), BitSet(1:4), 1:5, [2,3,10]) == BitSet([2,3]) end @testset "setdiff, symdiff" begin diff --git a/test/sets.jl b/test/sets.jl index 8d025b773eeab..179aaecc5fb9f 100644 --- a/test/sets.jl +++ b/test/sets.jl @@ -157,42 +157,64 @@ end end @testset "union" begin - @test isequal(union(Set([1])),Set([1])) - s = ∪(Set([1,2]), Set([3,4])) - @test isequal(s, Set([1,2,3,4])) - s = union(Set([5,6,7,8]), Set([7,8,9])) - @test isequal(s, Set([5,6,7,8,9])) - s = Set([1,3,5,7]) - union!(s,(2,3,4,5)) - @test isequal(s,Set([1,2,3,4,5,7])) - @test ===(typeof(union(Set([1]), BitSet())), Set{Int}) - @test isequal(union(Set([1,2,3]), 2:4), Set([1,2,3,4])) - @test isequal(union(Set([1,2,3]), [2,3,4]), Set([1,2,3,4])) - @test isequal(union(Set([1,2,3]), [2,3,4], Set([5])), Set([1,2,3,4,5])) + for S in (Set, BitSet) + s = ∪(S([1,2]), S([3,4])) + @test isequal(s, S([1,2,3,4])) + s = union(S([5,6,7,8]), S([7,8,9])) + @test isequal(s, S([5,6,7,8,9])) + s = S([1,3,5,7]) + union!(s,(2,3,4,5)) + @test isequal(s,S([1,2,3,4,5,7])) + let s1 = S([1, 2, 3]) + @test s1 !== union(s1) == s1 + @test s1 !== union(s1, 2:4) == S([1,2,3,4]) + @test s1 !== union(s1, [2,3,4]) == S([1,2,3,4]) + @test s1 !== union(s1, [2,3,4], S([5])) == S([1,2,3,4,5]) + @test s1 === union!(s1, [2,3,4], S([5])) == S([1,2,3,4,5]) + end + end + @test typeof(union(Set([1]), BitSet())) === Set{Int} + @test typeof(union(BitSet([1]), Set())) === BitSet end + @testset "intersect" begin - @test isequal(intersect(Set([1])),Set([1])) - s = ∩(Set([1,2]), Set([3,4])) - @test isequal(s, Set()) - s = intersect(Set([5,6,7,8]), 
Set([7,8,9])) - @test isequal(s, Set([7,8])) - @test isequal(intersect(Set([2,3,1]), Set([4,2,3]), Set([5,4,3,2])), Set([2,3])) - @test ===(typeof(intersect(Set([1]), BitSet())), Set{Int}) - @test isequal(intersect(Set([1,2,3]), 2:10), Set([2,3])) - @test isequal(intersect(Set([1,2,3]), [2,3,4]), Set([2,3])) - @test isequal(intersect(Set([1,2,3]), [2,3,4], 3:4), Set([3])) + for S in (Set, BitSet) + s = ∩(S([1,2]), S([3,4])) + @test isequal(s, S()) + s = intersect(S([5,6,7,8]), S([7,8,9])) + @test isequal(s, S([7,8])) + @test isequal(intersect(S([2,3,1]), S([4,2,3]), S([5,4,3,2])), S([2,3])) + let s1 = S([1,2,3]) + @test s1 !== intersect(s1) == s1 + @test s1 !== intersect(s1, 2:10) == S([2,3]) + @test s1 !== intersect(s1, [2,3,4]) == S([2,3]) + @test s1 !== intersect(s1, [2,3,4], 3:4) == S([3]) + @test s1 === intersect!(s1, [2,3,4], 3:4) == S([3]) + end + end + @test typeof(intersect(Set([1]), BitSet())) === Set{Int} + @test typeof(intersect(BitSet([1]), Set())) === BitSet end + @testset "setdiff" begin - @test isequal(setdiff(Set([1,2,3]), Set()), Set([1,2,3])) - @test isequal(setdiff(Set([1,2,3]), Set([1])), Set([2,3])) - @test isequal(setdiff(Set([1,2,3]), Set([1,2])), Set([3])) - @test isequal(setdiff(Set([1,2,3]), Set([1,2,3])), Set()) - @test isequal(setdiff(Set([1,2,3]), Set([4])), Set([1,2,3])) - @test isequal(setdiff(Set([1,2,3]), Set([4,1])), Set([2,3])) - @test ===(typeof(setdiff(Set([1]), BitSet())), Set{Int}) - @test isequal(setdiff(Set([1,2,3]), 2:10), Set([1])) - @test isequal(setdiff(Set([1,2,3]), [2,3,4]), Set([1])) - @test_throws MethodError setdiff(Set([1,2,3]), Set([2,3,4]), Set([1])) + for S in (Set, BitSet) + @test isequal(setdiff(S([1,2,3]), S()), S([1,2,3])) + @test isequal(setdiff(S([1,2,3]), S([1])), S([2,3])) + @test isequal(setdiff(S([1,2,3]), S([1,2])), S([3])) + @test isequal(setdiff(S([1,2,3]), S([1,2,3])), S()) + @test isequal(setdiff(S([1,2,3]), S([4])), S([1,2,3])) + @test isequal(setdiff(S([1,2,3]), S([4,1])), S([2,3])) + let s1 = 
S([1, 2, 3]) + @test s1 !== setdiff(s1) == s1 + @test s1 !== setdiff(s1, 2:10) == S([1]) + @test s1 !== setdiff(s1, [2,3,4]) == S([1]) + @test s1 !== setdiff(s1, S([2,3,4]), S([1])) == S() + @test s1 === setdiff!(s1, S([2,3,4]), S([1])) == S() + end + end + @test typeof(setdiff(Set([1]), BitSet())) === Set{Int} + @test typeof(setdiff(BitSet([1]), Set())) === BitSet + s = Set([1,3,5,7]) setdiff!(s,(3,5)) @test isequal(s,Set([1,7])) @@ -200,6 +222,7 @@ end setdiff!(s, Set([2,4,5,6])) @test isequal(s,Set([1,3])) end + @testset "ordering" begin @test Set() < Set([1]) @test Set([1]) < Set([1,2]) @@ -215,33 +238,42 @@ end @test !(Set([1,2,3]) >= Set([1,2,4])) @test !(Set([1,2,3]) <= Set([1,2,4])) end + @testset "issubset, symdiff" begin - for (l,r) in ((Set([1,2]), Set([3,4])), - (Set([5,6,7,8]), Set([7,8,9])), - (Set([1,2]), Set([3,4])), - (Set([5,6,7,8]), Set([7,8,9])), - (Set([1,2,3]), Set()), - (Set([1,2,3]), Set([1])), - (Set([1,2,3]), Set([1,2])), - (Set([1,2,3]), Set([1,2,3])), - (Set([1,2,3]), Set([4])), - (Set([1,2,3]), Set([4,1]))) - @test issubset(intersect(l,r), l) - @test issubset(intersect(l,r), r) - @test issubset(l, union(l,r)) - @test issubset(r, union(l,r)) - @test isequal(union(intersect(l,r),symdiff(l,r)), union(l,r)) + for S in (Set, BitSet) + for (l,r) in ((S([1,2]), S([3,4])), + (S([5,6,7,8]), S([7,8,9])), + (S([1,2]), S([3,4])), + (S([5,6,7,8]), S([7,8,9])), + (S([1,2,3]), S()), + (S([1,2,3]), S([1])), + (S([1,2,3]), S([1,2])), + (S([1,2,3]), S([1,2,3])), + (S([1,2,3]), S([4])), + (S([1,2,3]), S([4,1]))) + @test issubset(intersect(l,r), l) + @test issubset(intersect(l,r), r) + @test issubset(l, union(l,r)) + @test issubset(r, union(l,r)) + @test isequal(union(intersect(l,r),symdiff(l,r)), union(l,r)) + end + @test ⊆(S([1]), S([1,2])) + @test ⊊(S([1]), S([1,2])) + @test !⊊(S([1]), S([1])) + @test ⊈(S([1]), S([2])) + @test ⊇(S([1,2]), S([1])) + @test ⊋(S([1,2]), S([1])) + @test !⊋(S([1]), S([1])) + @test ⊉(S([1]), S([2])) + let s1 = S([1,2,3,4]) + 
@test s1 !== symdiff(s1) == s1 + @test s1 !== symdiff(s1, S([2,4,5,6])) == S([1,3,5,6]) + @test s1 !== symdiff(s1, S([2,4,5,6]), [1,6,7]) == S([3,5,7]) + @test s1 === symdiff!(s1, S([2,4,5,6]), [1,6,7]) == S([3,5,7]) + end end - @test ⊆(Set([1]), Set([1,2])) - @test ⊊(Set([1]), Set([1,2])) - @test !⊊(Set([1]), Set([1])) - @test ⊈(Set([1]), Set([2])) - @test ⊇(Set([1,2]), Set([1])) - @test ⊋(Set([1,2]), Set([1])) - @test !⊋(Set([1]), Set([1])) - @test ⊉(Set([1]), Set([2])) - @test symdiff(Set([1,2,3,4]), Set([2,4,5,6])) == Set([1,3,5,6]) end + @testset "unique" begin u = unique([1, 1, 2]) @test in(1, u) @@ -307,11 +339,10 @@ end @test allunique(4:-1:5) # empty range @test allunique(7:-1:1) # negative step end -@testset "filter" begin - s = Set([1,2,3,4]) - @test isequal(filter(isodd,s), Set([1,3])) - filter!(isodd, s) - @test isequal(s, Set([1,3])) +@testset "filter(f, ::$S)" for S = (Set, BitSet) + s = S([1,2,3,4]) + @test s !== filter( isodd, s) == S([1,3]) + @test s === filter!(isodd, s) == S([1,3]) end @testset "first" begin @test_throws ArgumentError first(Set()) From de8774abbdb42d02e743c700c5a55d76a8a280ca Mon Sep 17 00:00:00 2001 From: Rafael Fourquet Date: Fri, 1 Sep 2017 15:31:42 +0200 Subject: [PATCH 2/3] fix perf regression in filter/filter! --- base/set.jl | 28 +++++++++++++--------------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/base/set.jl b/base/set.jl index 6544c9a8c32bd..b9caaf586d212 100644 --- a/base/set.jl +++ b/base/set.jl @@ -153,17 +153,7 @@ julia> intersect([1, 4, 4, 5, 6], [4, 6, 6, 7, 8]) function intersect end intersect(s) = copymutable(s) - -# enable when as fast as the loop: -# intersect(s::AbstractSet, itr) = IntSet(x for x in itr if x in s) -function intersect(s::AbstractSet, itr) - t = similar(s) - for x in itr - x in s && push!(t, x) - end - t -end - +intersect(s::AbstractSet, itr) = mapfilter(x->in(x, s), push!, itr, similar(s)) intersect(s::AbstractSet, itr, itrs...) = intersect!(intersect(s, itr), itrs...) 
const ∩ = intersect @@ -492,12 +482,20 @@ allunique(::Set) = true allunique(r::AbstractRange{T}) where {T} = (step(r) != zero(T)) || (length(r) <= 1) -filter(f, s::AbstractSet) = foldl(push!, similar(s), Iterators.filter(f, s)) - -# it must be safe to delete the current element while iterating over s -unsafe_filter!(f, s::AbstractSet) = foldl(delete!, s, Iterators.filter(!f, s)) +filter(pred, s::AbstractSet) = mapfilter(pred, push!, s, similar(s)) filter!(f, s::Set) = unsafe_filter!(f, s) +# it must be safe to delete the current element while iterating over s: +unsafe_filter!(pred, s::AbstractSet) = mapfilter(!pred, delete!, s, s) + +# TODO: delete mapfilter in favor of comprehensions/foldl/filter when competitive +function mapfilter(pred, f, itr, res) + for x in itr + pred(x) && f(res, x) + end + res +end + const hashs_seed = UInt === UInt64 ? 0x852ada37cfe8e0ce : 0xcfe8e0ce function hash(s::Set, h::UInt) hv = hashs_seed From 6bf69739d9b7f575f12b024cff40751c7c4aeb77 Mon Sep 17 00:00:00 2001 From: Rafael Fourquet Date: Sat, 2 Sep 2017 14:20:16 +0200 Subject: [PATCH 3/3] add mutating set-like operations for AbstractVector --- NEWS.md | 14 +++++++ base/abstractarray.jl | 6 ++- base/array.jl | 92 ++++++++++++++++++------------------------- base/bitset.jl | 8 ++-- base/set.jl | 87 ++++++++++++++++++++++++++++------------ test/arrayops.jl | 2 +- test/sets.jl | 92 +++++++++++++++++++++++++++---------------- 7 files changed, 184 insertions(+), 117 deletions(-) diff --git a/NEWS.md b/NEWS.md index 06b4e218fde95..9396bea11240a 100644 --- a/NEWS.md +++ b/NEWS.md @@ -469,6 +469,20 @@ Library improvements linear-to-cartesian conversion ([#24715]) - It has a new constructor taking an array + * several missing set-like operations have been added ([#23528]): + `union`, `intersect`, `symdiff`, `setdiff` are now implemented for + all collections with arbitrarily many arguments, as well as the + mutating counterparts (`union!` etc.). 
The performance is also + much better in many cases. Note that this change is slightly + breaking: all the non-mutating functions always return a new + object even if only one argument is passed. Moreover the semantics + of `intersect` and `symdiff` is changed for vectors: + + `intersect` doesn't preserve the multiplicity anymore (use `filter` for + the old behavior) + + `symdiff` has been made consistent with the corresponding methods for + other containers, by taking the multiplicity of the arguments into account. + Use `unique` to get the old behavior. + * The type `LinearIndices` has been added, providing conversion from cartesian incices to linear indices using the normal indexing operation. ([#24715]) diff --git a/base/abstractarray.jl b/base/abstractarray.jl index 0b0bc3ab3afe1..5f1386b729500 100644 --- a/base/abstractarray.jl +++ b/base/abstractarray.jl @@ -580,8 +580,10 @@ julia> empty([1.0, 2.0, 3.0], String) 0-element Array{String,1} ``` """ -empty(a::AbstractVector) = empty(a, eltype(a)) -empty(a::AbstractVector, ::Type{T}) where {T} = Vector{T}() +empty(a::AbstractVector{T}, ::Type{U}=T) where {T,U} = Vector{U}() + +# like empty, but should return a mutable collection, a Vector by default +emptymutable(a::AbstractVector{T}, ::Type{U}=T) where {T,U} = Vector{U}() ## from general iterable to any array diff --git a/base/array.jl b/base/array.jl index f4f21791768c1..bea192e129ea3 100644 --- a/base/array.jl +++ b/base/array.jl @@ -2209,67 +2209,53 @@ function filter!(f, a::AbstractVector) return a end -function filter(f, a::Vector) - r = Vector{eltype(a)}() - for ai in a - if f(ai) - push!(r, ai) - end - end - return r -end +filter(f, a::Vector) = mapfilter(f, push!, a, similar(a, 0)) # set-like operators for vectors # These are moderately efficient, preserve order, and remove dupes. -function intersect(v1, vs...) 
- ret = Vector{promote_eltype(v1, vs...)}() - for v_elem in v1 - inall = true - for vsi in vs - if !in(v_elem, vsi) - inall=false; break - end - end - if inall - push!(ret, v_elem) - end +_unique_filter!(pred, update!, state) = function (x) + if pred(x, state) + update!(state, x) + true + else + false end - ret end -function union(vs...) - ret = Vector{promote_eltype(vs...)}() - seen = Set() - for v in vs - for v_elem in v - if !in(v_elem, seen) - push!(ret, v_elem) - push!(seen, v_elem) - end - end +_grow_filter!(seen) = _unique_filter!(∉, push!, seen) +_shrink_filter!(keep) = _unique_filter!(∈, pop!, keep) + +function _grow!(pred!, v::AbstractVector, itrs) + filter!(pred!, v) # uniquify v + foldl(v, itrs) do v, itr + mapfilter(pred!, push!, itr, v) end - ret end -# setdiff only accepts two args -function setdiff(a, b) - args_type = promote_type(eltype(a), eltype(b)) - bset = Set(b) - ret = Vector{args_type}() - seen = Set{eltype(a)}() - for a_elem in a - if !in(a_elem, seen) && !in(a_elem, bset) - push!(ret, a_elem) - push!(seen, a_elem) - end - end - ret +union!(v::AbstractVector{T}, itrs...) where {T} = + _grow!(_grow_filter!(sizehint!(Set{T}(), length(v))), v, itrs) + +symdiff!(v::AbstractVector{T}, itrs...) where {T} = + _grow!(_shrink_filter!(symdiff!(Set{T}(), v, itrs...)), v, itrs) + +function _shrink!(shrinker!, v::AbstractVector, itrs) + seen = Set{eltype(v)}() + filter!(_grow_filter!(seen), v) + shrinker!(seen, itrs...) + filter!(_in(seen), v) end -# symdiff is associative, so a relatively clean -# way to implement this is by using setdiff and union, and -# recursing. Has the advantage of keeping order, too, but -# not as fast as other methods that make a single pass and -# store counts with a Dict. -symdiff(a, b) = union(setdiff(a,b), setdiff(b,a)) -symdiff(a, b, rest...) = symdiff(a, symdiff(b, rest...)) + +intersect!(v::AbstractVector, itrs...) = _shrink!(intersect!, v, itrs) +setdiff!( v::AbstractVector, itrs...) 
= _shrink!(setdiff!, v, itrs) + +vectorfilter(f, v::AbstractVector) = filter(f, v) # TODO: do we want this special case? +vectorfilter(f, v) = [x for x in v if f(x)] + +function _shrink(shrinker!, itr, itrs) + keep = shrinker!(Set(itr), itrs...) + vectorfilter(_shrink_filter!(keep), itr) +end + +intersect(itr, itrs...) = _shrink(intersect!, itr, itrs) +setdiff( itr, itrs...) = _shrink(setdiff!, itr, itrs) diff --git a/base/bitset.jl b/base/bitset.jl index f2c07235e3ae1..234faa8a28ce2 100644 --- a/base/bitset.jl +++ b/base/bitset.jl @@ -32,7 +32,12 @@ BitSet(itr) = union!(BitSet(), itr) eltype(::Type{BitSet}) = Int similar(s::BitSet) = BitSet() + +empty(s::BitSet, ::Type{Int}=Int) = BitSet() +emptymutable(s::BitSet, ::Type{Int}=Int) = BitSet() + copy(s1::BitSet) = copy!(BitSet(), s1) +copymutable(s::BitSet) = copy(s) """ copy!(dst, src) @@ -49,8 +54,6 @@ function copy!(dest::BitSet, src::BitSet) dest end -copymutable(s::BitSet) = copy(s) - eltype(s::BitSet) = Int sizehint!(s::BitSet, n::Integer) = (sizehint!(s.bits, (n+63) >> 6); s) @@ -256,7 +259,6 @@ isempty(s::BitSet) = _check0(s.bits, 1, length(s.bits)) # Mathematical set functions: union!, intersect!, setdiff!, symdiff! union(s::BitSet, sets...) = union!(copy(s), sets...) 
-union!(s::BitSet, ns) = foldl(push!, s, ns) union!(s1::BitSet, s2::BitSet) = _matched_map!(|, s1, s2) intersect(s1::BitSet, s2::BitSet) = diff --git a/base/set.jl b/base/set.jl index b9caaf586d212..f1de13afd08c7 100644 --- a/base/set.jl +++ b/base/set.jl @@ -27,8 +27,13 @@ function Set(g::Generator) return Set{T}(g) end -similar(s::Set{T}) where {T} = Set{T}() -similar(s::Set, T::Type) = Set{T}() +similar(s::Set{T}, ::Type{U}=T) where {T,U} = Set{U}() + +empty(s::Set{T}, ::Type{U}=T) where {T,U} = Set{U}() + +# return an empty set with eltype T, which is mutable (can be grown) +# by default, a Set is returned +emptymutable(s::AbstractSet{T}, ::Type{U}=T) where {T,U} = Set{U}() function show(io::IO, s::Set) print(io, "Set(") @@ -88,23 +93,30 @@ julia> union([1, 2], [2, 4]) 2 4 -julia> union([4, 2], [1, 2]) +julia> union([4, 2], 1:2) 3-element Array{Int64,1}: 4 2 1 + +julia> union(Set([1, 2]), 2:3) +Set([2, 3, 1]) ``` """ function union end -union(s::Set, sets...) = union!(Set{join_eltype(s, sets...)}(), s, sets...) +_in(itr) = x -> x in itr + +union(s, sets...) = union!(emptymutable(s, promote_eltype(s, sets...)), s, sets...) +union(s::AbstractSet) = copy(s) const ∪ = union """ - union!(s::AbstractSet, itrs...) + union!(s::Union{AbstractSet,AbstractVector}, itrs...) Construct the union of passed in sets and overwrite `s` with the result. +Maintain order with arrays. # Examples ```jldoctest @@ -116,6 +128,11 @@ julia> a Set([7, 4, 3, 5, 1]) ``` """ +union!(s::AbstractSet, sets...) = foldl(union!, s, sets) + +# default generic 2-args implementation with push! +union!(s::AbstractSet, itr) = foldl(push!, s, itr) + function union!(s::Set{T}, itr) where T haslength(itr) && sizehint!(s, length(itr)) for x=itr @@ -125,17 +142,13 @@ function union!(s::Set{T}, itr) where T s end -union!(s::AbstractSet, sets...) = foldl(union!, s, sets) - -join_eltype() = Bottom -join_eltype(v1, vs...) = typejoin(eltype(v1), join_eltype(vs...)) """ intersect(s, itrs...) ∩(s, itrs...) 
Construct the intersection of sets. -Maintain order and multiplicity of the first argument for arrays and ranges. +Maintain order with arrays. # Examples ```jldoctest @@ -144,28 +157,29 @@ julia> intersect([1, 2, 3], [3, 4, 5]) 3 julia> intersect([1, 4, 4, 5, 6], [4, 6, 6, 7, 8]) -3-element Array{Int64,1}: - 4 +2-element Array{Int64,1}: 4 6 + +julia> intersect(Set([1, 2]), BitSet([2, 3])) +Set([2]) ``` """ -function intersect end - -intersect(s) = copymutable(s) -intersect(s::AbstractSet, itr) = mapfilter(x->in(x, s), push!, itr, similar(s)) intersect(s::AbstractSet, itr, itrs...) = intersect!(intersect(s, itr), itrs...) +intersect(s) = union(s) +intersect(s::AbstractSet, itr) = mapfilter(_in(s), push!, itr, emptymutable(s)) const ∩ = intersect """ - intersect!(s::AbstractSet, itrs...) + intersect!(s::Union{AbstractSet,AbstractVector}, itrs...) Intersect all passed in sets and overwrite `s` with the result. +Maintain order with arrays. """ -intersect!(s::AbstractSet, s2::AbstractSet) = filter!(x -> x in s2, s) -intersect!(s::AbstractSet, itr) = intersect!(s, union!(similar(s), itr)) intersect!(s::AbstractSet, itrs...) = foldl(intersect!, s, itrs) +intersect!(s::AbstractSet, s2::AbstractSet) = filter!(_in(s2), s) +intersect!(s::AbstractSet, itr) = intersect!(s, union!(emptymutable(s), itr)) """ setdiff(s, itrs...) @@ -182,12 +196,13 @@ julia> setdiff([1,2,3], [3,4,5]) ``` """ setdiff(s::AbstractSet, itrs...) = setdiff!(copymutable(s), itrs...) -setdiff(s) = copymutable(s) +setdiff(s) = union(s) """ setdiff!(s, itrs...) Remove from set `s` (in-place) each element of each iterable from `itrs`. +Maintain order with arrays. # Examples ```jldoctest @@ -207,7 +222,8 @@ setdiff!(s::AbstractSet, itr) = foldl(delete!, s, itr) symdiff(s, itrs...) Construct the symmetric difference of elements in the passed in sets. -Maintains order with arrays. +When `s` is not an `AbstractSet`, the order is maintained. +Note that in this case the multiplicity of elements matters. 
# Examples ```jldoctest @@ -216,15 +232,25 @@ julia> symdiff([1,2,3], [3,4,5], [4,5,6]) 1 2 6 + +julia> symdiff([1,2,1], [2, 1, 2]) +2-element Array{Int64,1}: + 1 + 2 + +julia> symdiff(unique([1,2,1]), unique([2, 1, 2])) +0-element Array{Int64,1} ``` """ -symdiff(s::AbstractSet, sets...) = symdiff!(copymutable(s), sets...) -symdiff(s) = copymutable(s) # remove when method above becomes as efficient +symdiff(s, sets...) = symdiff!(emptymutable(s, promote_eltype(s, sets...)), s, sets...) +symdiff(s) = symdiff!(copy(s)) """ - symdiff!(s::AbstractSet, itrs...) + symdiff!(s::Union{AbstractSet,AbstractVector}, itrs...) Construct the symmetric difference of the passed in sets, and overwrite `s` with the result. +When `s` is an array, the order is maintained. +Note that in this case the multiplicity of elements matters. """ symdiff!(s::AbstractSet, itrs...) = foldl(symdiff!, s, itrs) @@ -256,7 +282,18 @@ julia> issubset([1, 2, 3], [1, 2]) false ``` """ -issubset(l, r) = all(x -> x in r, l) +function issubset(l, r) + for elt in l + if !in(elt, r) + return false + end + end + return true +end + +# use the implementation below when it becomes as efficient +# issubset(l, r) = all(_in(r), l) + const ⊆ = issubset ⊊(l::Set, r::Set) = <(l, r) ⊈(l::Set, r::Set) = !⊆(l, r) diff --git a/test/arrayops.jl b/test/arrayops.jl index 6db4148f13694..bdd96ff1831a4 100644 --- a/test/arrayops.jl +++ b/test/arrayops.jl @@ -932,7 +932,7 @@ end @test isequal(setdiff([1,2,3,4], [7,8,9]), [1,2,3,4]) @test isequal(setdiff([1,2,3,4], Int64[]), Int64[1,2,3,4]) @test isequal(setdiff([1,2,3,4], [1,2,3,4,5]), Int64[]) - @test isequal(symdiff([1,2,3], [4,3,4]), [1,2,4]) + @test isequal(symdiff([1,2,3], [4,3,4]), [1,2]) @test isequal(symdiff(['e','c','a'], ['b','a','d']), ['e','c','b','d']) @test isequal(symdiff([1,2,3], [4,3], [5]), [1,2,4,5]) @test isequal(symdiff([1,2,3,4,5], [1,2,3], [3,4]), [3,5]) diff --git a/test/sets.jl b/test/sets.jl index 179aaecc5fb9f..66dcd34664c8a 100644 --- a/test/sets.jl +++ 
b/test/sets.jl @@ -157,14 +157,14 @@ end end @testset "union" begin - for S in (Set, BitSet) + for S in (Set, BitSet, Vector) s = ∪(S([1,2]), S([3,4])) - @test isequal(s, S([1,2,3,4])) + @test s == S([1,2,3,4]) s = union(S([5,6,7,8]), S([7,8,9])) - @test isequal(s, S([5,6,7,8,9])) + @test s == S([5,6,7,8,9]) s = S([1,3,5,7]) - union!(s,(2,3,4,5)) - @test isequal(s,S([1,2,3,4,5,7])) + union!(s, (2,3,4,5)) + @test s == S([1,3,5,7,2,4]) # order matters for Vector let s1 = S([1, 2, 3]) @test s1 !== union(s1) == s1 @test s1 !== union(s1, 2:4) == S([1,2,3,4]) @@ -173,17 +173,21 @@ end @test s1 === union!(s1, [2,3,4], S([5])) == S([1,2,3,4,5]) end end - @test typeof(union(Set([1]), BitSet())) === Set{Int} - @test typeof(union(BitSet([1]), Set())) === BitSet + @test union(Set([1]), BitSet()) isa Set{Int} + @test union(BitSet([1]), Set()) isa BitSet + @test union([1], BitSet()) isa Vector{Int} + # union must uniquify + @test union([1, 2, 1]) == union!([1, 2, 1]) == [1, 2] + @test union([1, 2, 1], [2, 2]) == union!([1, 2, 1], [2, 2]) == [1, 2] end @testset "intersect" begin - for S in (Set, BitSet) - s = ∩(S([1,2]), S([3,4])) - @test isequal(s, S()) + for S in (Set, BitSet, Vector) + s = S([1,2]) ∩ S([3,4]) + @test s == S() s = intersect(S([5,6,7,8]), S([7,8,9])) - @test isequal(s, S([7,8])) - @test isequal(intersect(S([2,3,1]), S([4,2,3]), S([5,4,3,2])), S([2,3])) + @test s == S([7,8]) + @test intersect(S([2,3,1]), S([4,2,3]), S([5,4,3,2])) == S([2,3]) let s1 = S([1,2,3]) @test s1 !== intersect(s1) == s1 @test s1 !== intersect(s1, 2:10) == S([2,3]) @@ -192,18 +196,22 @@ end @test s1 === intersect!(s1, [2,3,4], 3:4) == S([3]) end end - @test typeof(intersect(Set([1]), BitSet())) === Set{Int} - @test typeof(intersect(BitSet([1]), Set())) === BitSet + @test intersect(Set([1]), BitSet()) isa Set{Int} + @test intersect(BitSet([1]), Set()) isa BitSet + @test intersect([1], BitSet()) isa Vector{Int} + # intersect must uniquify + @test intersect([1, 2, 1]) == intersect!([1, 2, 1]) 
== [1, 2] + @test intersect([1, 2, 1], [2, 2]) == intersect!([1, 2, 1], [2, 2]) == [2] end @testset "setdiff" begin - for S in (Set, BitSet) - @test isequal(setdiff(S([1,2,3]), S()), S([1,2,3])) - @test isequal(setdiff(S([1,2,3]), S([1])), S([2,3])) - @test isequal(setdiff(S([1,2,3]), S([1,2])), S([3])) - @test isequal(setdiff(S([1,2,3]), S([1,2,3])), S()) - @test isequal(setdiff(S([1,2,3]), S([4])), S([1,2,3])) - @test isequal(setdiff(S([1,2,3]), S([4,1])), S([2,3])) + for S in (Set, BitSet, Vector) + @test setdiff(S([1,2,3]), S()) == S([1,2,3]) + @test setdiff(S([1,2,3]), S([1])) == S([2,3]) + @test setdiff(S([1,2,3]), S([1,2])) == S([3]) + @test setdiff(S([1,2,3]), S([1,2,3])) == S() + @test setdiff(S([1,2,3]), S([4])) == S([1,2,3]) + @test setdiff(S([1,2,3]), S([4,1])) == S([2,3]) let s1 = S([1, 2, 3]) @test s1 !== setdiff(s1) == s1 @test s1 !== setdiff(s1, 2:10) == S([1]) @@ -212,8 +220,13 @@ end @test s1 === setdiff!(s1, S([2,3,4]), S([1])) == S() end end - @test typeof(setdiff(Set([1]), BitSet())) === Set{Int} - @test typeof(setdiff(BitSet([1]), Set())) === BitSet + + @test setdiff(Set([1]), BitSet()) isa Set{Int} + @test setdiff(BitSet([1]), Set()) isa BitSet + @test setdiff([1], BitSet()) isa Vector{Int} + # setdiff must uniquify + @test setdiff([1, 2, 1]) == setdiff!([1, 2, 1]) == [1, 2] + @test setdiff([1, 2, 1], [2, 2]) == setdiff!([1, 2, 1], [2, 2]) == [1] s = Set([1,3,5,7]) setdiff!(s,(3,5)) @@ -240,7 +253,7 @@ end end @testset "issubset, symdiff" begin - for S in (Set, BitSet) + for S in (Set, BitSet, Vector) for (l,r) in ((S([1,2]), S([3,4])), (S([5,6,7,8]), S([7,8,9])), (S([1,2]), S([3,4])), @@ -255,16 +268,22 @@ end @test issubset(intersect(l,r), r) @test issubset(l, union(l,r)) @test issubset(r, union(l,r)) - @test isequal(union(intersect(l,r),symdiff(l,r)), union(l,r)) + if S === Vector + @test sort(union(intersect(l,r),symdiff(l,r))) == sort(union(l,r)) + else + @test union(intersect(l,r),symdiff(l,r)) == union(l,r) + end + end + if S !== 
Vector + @test ⊆(S([1]), S([1,2])) + @test ⊊(S([1]), S([1,2])) + @test !⊊(S([1]), S([1])) + @test ⊈(S([1]), S([2])) + @test ⊇(S([1,2]), S([1])) + @test ⊋(S([1,2]), S([1])) + @test !⊋(S([1]), S([1])) + @test ⊉(S([1]), S([2])) end - @test ⊆(S([1]), S([1,2])) - @test ⊊(S([1]), S([1,2])) - @test !⊊(S([1]), S([1])) - @test ⊈(S([1]), S([2])) - @test ⊇(S([1,2]), S([1])) - @test ⊋(S([1,2]), S([1])) - @test !⊋(S([1]), S([1])) - @test ⊉(S([1]), S([2])) let s1 = S([1,2,3,4]) @test s1 !== symdiff(s1) == s1 @test s1 !== symdiff(s1, S([2,4,5,6])) == S([1,3,5,6]) @@ -272,6 +291,13 @@ end @test s1 === symdiff!(s1, S([2,4,5,6]), [1,6,7]) == S([3,5,7]) end end + @test symdiff(Set([1,2,3,4]), Set([2,4,5,6])) == Set([1,3,5,6]) + @test symdiff(Set([1]), BitSet()) isa Set{Int} + @test symdiff(BitSet([1]), Set{Int}()) isa BitSet + @test symdiff([1], BitSet()) isa Vector{Int} + # symdiff must NOT uniquify + @test symdiff([1, 2, 1]) == symdiff!([1, 2, 1]) == [2] + @test symdiff([1, 2, 1], [2, 2]) == symdiff!([1, 2, 1], [2, 2]) == [2] end @testset "unique" begin