From 0557e9372535a12b04020530844978deb4ad395c Mon Sep 17 00:00:00 2001 From: Milan Bouchet-Valat Date: Fri, 4 Mar 2022 10:05:43 +0100 Subject: [PATCH 1/3] Fix pairwise for type-unstable corner case function `promote_type` is not a completely correct way of computing an upper bound for the return eltype. Use the same strategy as `map` and `broadcast` in Base instead, but with `promote_eltype` rather than `promote_typejoin`. We can still use `typejoin_union_tuple` since promotion does not happen inside tuple types. On Julia versions before 1.6 we would have to copy the full definition of `typejoin_union_tuple`, which is quite complex, so instead fall back to inferring eltype `Any` (inference wasn't so good anyway on these versions). --- src/pairwise.jl | 31 +++++++++++++++++++++++++++++-- test/pairwise.jl | 13 +++++++++++++ 2 files changed, 42 insertions(+), 2 deletions(-) diff --git a/src/pairwise.jl b/src/pairwise.jl index 97e51a3f1..8bb820c11 100644 --- a/src/pairwise.jl +++ b/src/pairwise.jl @@ -122,6 +122,32 @@ function _pairwise!(f, dest::AbstractMatrix, x, y; return _pairwise!(Val(skipmissing), f, dest, x′, y′, symmetric) end +if VERSION >= v"1.6.0-DEV" + # Function has moved in Julia 1.7 + if isdefined(Base, :typejoin_union_tuple) + using Base: typejoin_union_tuple + else + using Base.Broadcast: typejoin_union_tuple + end + # Identical to `Base.promote_typejoin` except that it uses `promote_type` + # instead of `typejoin` + function promote_type_union(::Type{T}) where T + if T === Union{} + return Union{} + elseif T isa UnionAll + return Any # TODO: compute more precise bounds + elseif T isa Union + return promote_type(promote_type_union(T.a), promote_type_union(T.b)) + elseif T <: Tuple + return typejoin_union_tuple(T) + else + return T + end + end +else + promote_type_union(::Type) = Any +end + function _pairwise(::Val{skipmissing}, f, x, y, symmetric::Bool) where {skipmissing} x′ = x isa Union{AbstractArray, Tuple, NamedTuple} ? x : collect(x) y′ = y isa Union{AbstractArray, Tuple, NamedTuple} ? y : collect(y) @@ -148,10 +174,11 @@ function _pairwise(::Val{skipmissing}, f, x, y, symmetric::Bool) where {skipmiss if isconcretetype(eltype(dest)) return dest else - # Final eltype depends on actual contents (consistent with map and broadcast) + # Final eltype depends on actual contents (consistent with `map` and `broadcast` + # but using `promote_type` rather than `promote_typejoin`) U = mapreduce(typeof, promote_type, dest) # V is inferred (contrary to U), but it only gives an upper bound for U - V = promote_type(T, Tsm) + V = promote_type_union(Union{T, Tsm}) return convert(Matrix{U}, dest)::Matrix{<:V} end end diff --git a/test/pairwise.jl b/test/pairwise.jl index 09699b276..397256c3f 100644 --- a/test/pairwise.jl +++ b/test/pairwise.jl @@ -258,4 +258,17 @@ arbitrary_fun(x, y) = cor(x, y) end end end + + @testset "type-unstable corner case (#771)" begin + v = [rand(5) for _=1:10] + function f(v) + pairwise(v) do x, y + (x[1] < 0 ? nothing : + x[1] > y[1] ? 1 : 1.5, + 0) + end + end + res = f(v) + @test res isa Matrix{Tuple{Real, Int}} + end end \ No newline at end of file From 626438e202252f0d9404261d4dcd86df08a73ae4 Mon Sep 17 00:00:00 2001 From: Milan Bouchet-Valat Date: Tue, 15 Mar 2022 10:30:36 +0100 Subject: [PATCH 2/3] More precise type on Julia < 1.6, more tests --- src/pairwise.jl | 33 +++++++++++++++++---------------- test/pairwise.jl | 15 +++++++++++++++ 2 files changed, 32 insertions(+), 16 deletions(-) diff --git a/src/pairwise.jl b/src/pairwise.jl index 8bb820c11..822a6d703 100644 --- a/src/pairwise.jl +++ b/src/pairwise.jl @@ -129,23 +129,24 @@ if VERSION >= v"1.6.0-DEV" else using Base.Broadcast: typejoin_union_tuple end - # Identical to `Base.promote_typejoin` except that it uses `promote_type` - # instead of `typejoin` - function promote_type_union(::Type{T}) where T - if T === Union{} - return Union{} - elseif T isa UnionAll - return Any # TODO: compute more precise bounds - elseif T isa Union - return promote_type(promote_type_union(T.a), promote_type_union(T.b)) - elseif T <: Tuple - return typejoin_union_tuple(T) - else - return T - end - end else - promote_type_union(::Type) = Any + typejoin_union_tuple(::Type) = Any +end + +# Identical to `Base.promote_typejoin` except that it uses `promote_type` +# instead of `typejoin` to combine members of `Union` types +function promote_type_union(::Type{T}) where T + if T === Union{} + return Union{} + elseif T isa UnionAll + return Any # TODO: compute more precise bounds + elseif T isa Union + return promote_type(promote_type_union(T.a), promote_type_union(T.b)) + elseif T <: Tuple + return typejoin_union_tuple(T) + else + return T + end end function _pairwise(::Val{skipmissing}, f, x, y, symmetric::Bool) where {skipmissing} diff --git a/test/pairwise.jl b/test/pairwise.jl index 397256c3f..241ae44e1 100644 --- a/test/pairwise.jl +++ b/test/pairwise.jl @@ -259,6 +259,21 @@ arbitrary_fun(x, y) = cor(x, y) end end + @testset "promote_type_union" begin + @test StatsBase.promote_type_union(Union{Int, Float64}) === Float64 + @test StatsBase.promote_type_union(Union{Int, Missing}) === Union{Int, Missing} + @test StatsBase.promote_type_union(Union{Int, String}) === Any + @test StatsBase.promote_type_union(Vector) === Any + @test StatsBase.promote_type_union(Union{}) === Union{} + if VERSION >= v"1.6.0-DEV" + @test StatsBase.promote_type_union(Tuple{Union{Int, Float64}}) === + Tuple{Real} + else + @test StatsBase.promote_type_union(Tuple{Union{Int, Float64}}) === + Any + end + end + @testset "type-unstable corner case (#771)" begin v = [rand(5) for _=1:10] function f(v) From 650ca27bd34c25b180a8becf4d3142f2016c5484 Mon Sep 17 00:00:00 2001 From: Milan Bouchet-Valat Date: Tue, 15 Mar 2022 10:48:42 +0100 Subject: [PATCH 3/3] Add tests --- test/pairwise.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/pairwise.jl b/test/pairwise.jl index 241ae44e1..59e081b08 100644 --- a/test/pairwise.jl +++ b/test/pairwise.jl @@ -260,6 +260,8 @@ arbitrary_fun(x, y) = cor(x, y) end @testset "promote_type_union" begin + @test StatsBase.promote_type_union(Int) === Int + @test StatsBase.promote_type_union(Real) === Real @test StatsBase.promote_type_union(Union{Int, Float64}) === Float64 @test StatsBase.promote_type_union(Union{Int, Missing}) === Union{Int, Missing} @test StatsBase.promote_type_union(Union{Int, String}) === Any