Skip to content

Commit

Permalink
Fix pairwise for type-unstable corner case function (#772)
Browse files Browse the repository at this point in the history
`promote_type` is not a completely correct way of computing an upper bound
for the return eltype. Use the same strategy as `map` and `broadcast` in
Base instead, but with `promote_eltype` rather than `promote_typejoin`.
We can still use `typejoin_union_tuple` since promotion does not happen
inside tuple types.

On Julia versions before 1.6 we would have to copy the full definition of
`typejoin_union_tuple`, which is quite complex, so instead fall back to
inferring eltype `Any`.
  • Loading branch information
nalimilan authored Mar 15, 2022
1 parent f9cfd12 commit a1b02d8
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 2 deletions.
32 changes: 30 additions & 2 deletions src/pairwise.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,33 @@ function _pairwise!(f, dest::AbstractMatrix, x, y;
return _pairwise!(Val(skipmissing), f, dest, x′, y′, symmetric)
end

if VERSION >= v"1.6.0-DEV"
# Function has moved in Julia 1.7
if isdefined(Base, :typejoin_union_tuple)
using Base: typejoin_union_tuple
else
using Base.Broadcast: typejoin_union_tuple
end
else
typejoin_union_tuple(::Type) = Any
end

# Identical to `Base.promote_typejoin` except that it uses `promote_type`
# instead of `typejoin` to combine members of `Union` types
function promote_type_union(::Type{T}) where T
if T === Union{}
return Union{}
elseif T isa UnionAll
return Any # TODO: compute more precise bounds
elseif T isa Union
return promote_type(promote_type_union(T.a), promote_type_union(T.b))
elseif T <: Tuple
return typejoin_union_tuple(T)
else
return T
end
end

function _pairwise(::Val{skipmissing}, f, x, y, symmetric::Bool) where {skipmissing}
x′ = x isa Union{AbstractArray, Tuple, NamedTuple} ? x : collect(x)
y′ = y isa Union{AbstractArray, Tuple, NamedTuple} ? y : collect(y)
Expand All @@ -148,10 +175,11 @@ function _pairwise(::Val{skipmissing}, f, x, y, symmetric::Bool) where {skipmiss
if isconcretetype(eltype(dest))
return dest
else
# Final eltype depends on actual contents (consistent with map and broadcast)
# Final eltype depends on actual contents (consistent with `map` and `broadcast`
# but using `promote_type` rather than `promote_typejoin`)
U = mapreduce(typeof, promote_type, dest)
# V is inferred (contrary to U), but it only gives an upper bound for U
V = promote_type(T, Tsm)
V = promote_type_union(Union{T, Tsm})
return convert(Matrix{U}, dest)::Matrix{<:V}
end
end
Expand Down
30 changes: 30 additions & 0 deletions test/pairwise.jl
Original file line number Diff line number Diff line change
Expand Up @@ -258,4 +258,34 @@ arbitrary_fun(x, y) = cor(x, y)
end
end
end

@testset "promote_type_union" begin
@test StatsBase.promote_type_union(Int) === Int
@test StatsBase.promote_type_union(Real) === Real
@test StatsBase.promote_type_union(Union{Int, Float64}) === Float64
@test StatsBase.promote_type_union(Union{Int, Missing}) === Union{Int, Missing}
@test StatsBase.promote_type_union(Union{Int, String}) === Any
@test StatsBase.promote_type_union(Vector) === Any
@test StatsBase.promote_type_union(Union{}) === Union{}
if VERSION >= v"1.6.0-DEV"
@test StatsBase.promote_type_union(Tuple{Union{Int, Float64}}) ===
Tuple{Real}
else
@test StatsBase.promote_type_union(Tuple{Union{Int, Float64}}) ===
Any
end
end

@testset "type-unstable corner case (#771)" begin
v = [rand(5) for _=1:10]
function f(v)
pairwise(v) do x, y
(x[1] < 0 ? nothing :
x[1] > y[1] ? 1 : 1.5,
0)
end
end
res = f(v)
@test res isa Matrix{Tuple{Real, Int}}
end
end

0 comments on commit a1b02d8

Please sign in to comment.