diff --git a/src/pairwise.jl b/src/pairwise.jl index 97e51a3f1..822a6d703 100644 --- a/src/pairwise.jl +++ b/src/pairwise.jl @@ -122,6 +122,33 @@ function _pairwise!(f, dest::AbstractMatrix, x, y; return _pairwise!(Val(skipmissing), f, dest, x′, y′, symmetric) end +if VERSION >= v"1.6.0-DEV" + # Function has moved in Julia 1.7 + if isdefined(Base, :typejoin_union_tuple) + using Base: typejoin_union_tuple + else + using Base.Broadcast: typejoin_union_tuple + end +else + typejoin_union_tuple(::Type) = Any +end + +# Identical to `Base.promote_typejoin` except that it uses `promote_type` +# instead of `typejoin` to combine members of `Union` types +function promote_type_union(::Type{T}) where T + if T === Union{} + return Union{} + elseif T isa UnionAll + return Any # TODO: compute more precise bounds + elseif T isa Union + return promote_type(promote_type_union(T.a), promote_type_union(T.b)) + elseif T <: Tuple + return typejoin_union_tuple(T) + else + return T + end +end + function _pairwise(::Val{skipmissing}, f, x, y, symmetric::Bool) where {skipmissing} x′ = x isa Union{AbstractArray, Tuple, NamedTuple} ? x : collect(x) y′ = y isa Union{AbstractArray, Tuple, NamedTuple} ? y : collect(y) @@ -148,10 +175,11 @@ function _pairwise(::Val{skipmissing}, f, x, y, symmetric::Bool) where {skipmiss if isconcretetype(eltype(dest)) return dest else - # Final eltype depends on actual contents (consistent with map and broadcast) + # Final eltype depends on actual contents (consistent with `map` and `broadcast` + # but using `promote_type` rather than `promote_typejoin`) U = mapreduce(typeof, promote_type, dest) # V is inferred (contrary to U), but it only gives an upper bound for U - V = promote_type(T, Tsm) + V = promote_type_union(Union{T, Tsm}) return convert(Matrix{U}, dest)::Matrix{<:V} end end diff --git a/test/pairwise.jl b/test/pairwise.jl index 09699b276..59e081b08 100644 --- a/test/pairwise.jl +++ b/test/pairwise.jl @@ -258,4 +258,34 @@ arbitrary_fun(x, y) = cor(x, y) end end end + + @testset "promote_type_union" begin + @test StatsBase.promote_type_union(Int) === Int + @test StatsBase.promote_type_union(Real) === Real + @test StatsBase.promote_type_union(Union{Int, Float64}) === Float64 + @test StatsBase.promote_type_union(Union{Int, Missing}) === Union{Int, Missing} + @test StatsBase.promote_type_union(Union{Int, String}) === Any + @test StatsBase.promote_type_union(Vector) === Any + @test StatsBase.promote_type_union(Union{}) === Union{} + if VERSION >= v"1.6.0-DEV" + @test StatsBase.promote_type_union(Tuple{Union{Int, Float64}}) === + Tuple{Real} + else + @test StatsBase.promote_type_union(Tuple{Union{Int, Float64}}) === + Any + end + end + + @testset "type-unstable corner case (#771)" begin + v = [rand(5) for _=1:10] + function f(v) + pairwise(v) do x, y + (x[1] < 0 ? nothing : + x[1] > y[1] ? 1 : 1.5, + 0) + end + end + res = f(v) + @test res isa Matrix{Tuple{Real, Int}} + end end \ No newline at end of file