Skip to content

Commit

Permalink
Make return type of map inferrable with heterogeneous arrays
Browse files Browse the repository at this point in the history
Inference is not able to detect the element type automatically,
but we can do it manually since we know promote_typejoin is used for widening.
This is similar to the approach used for `broadcast` at #30485.
  • Loading branch information
nalimilan committed Aug 29, 2021
1 parent 6e91085 commit b96dd48
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 51 deletions.
9 changes: 7 additions & 2 deletions base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -775,8 +775,13 @@ function collect(itr::Generator)
return _array_for(et, itr.iter, isz)
end
v1, st = y
arr = _array_for(typeof(v1), itr.iter, isz, shape)
return collect_to_with_first!(arr, v1, itr, st)
dest = _array_for(typeof(v1), itr.iter, isz, shape)
# The typeassert gives inference a helping hand on the element type and dimensionality
# (work-around for #28382)
ElType = promote_typejoin_union(et)
ElType′ = ElType <: Type ? Type : ElType
RT = dest isa AbstractArray ? AbstractArray{<:ElType′, ndims(dest)} : Any
collect_to_with_first!(dest, v1, itr, st)::RT
end
end

Expand Down
46 changes: 1 addition & 45 deletions base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ Module containing the broadcasting implementation.
module Broadcast

using .Base.Cartesian
using .Base: Indices, OneTo, tail, to_shape, isoperator, promote_typejoin, @pure,
using .Base: Indices, OneTo, tail, to_shape, isoperator, promote_typejoin, promote_typejoin_union, @pure,
_msk_end, unsafe_bitgetindex, bitcache_chunks, bitcache_size, dumpbitcache, unalias
import .Base: copy, copyto!, axes
export broadcast, broadcast!, BroadcastStyle, broadcast_axes, broadcastable, dotview, @__dot__, BroadcastFunction
Expand Down Expand Up @@ -713,50 +713,6 @@ eltypes(t::Tuple{Any}) = Tuple{_broadcast_getindex_eltype(t[1])}
eltypes(t::Tuple{Any,Any}) = Tuple{_broadcast_getindex_eltype(t[1]), _broadcast_getindex_eltype(t[2])}
eltypes(t::Tuple) = Tuple{_broadcast_getindex_eltype(t[1]), eltypes(tail(t)).types...}

function promote_typejoin_union(::Type{T}) where T
if T === Union{}
return Union{}
elseif T isa UnionAll
return Any # TODO: compute more precise bounds
elseif T isa Union
return promote_typejoin(promote_typejoin_union(T.a), promote_typejoin_union(T.b))
elseif T <: Tuple
return typejoin_union_tuple(T)
else
return T
end
end

@pure function typejoin_union_tuple(T::Type)
u = Base.unwrap_unionall(T)
u isa Union && return typejoin(
typejoin_union_tuple(Base.rewrap_unionall(u.a, T)),
typejoin_union_tuple(Base.rewrap_unionall(u.b, T)))
p = (u::DataType).parameters
lr = length(p)::Int
if lr == 0
return Tuple{}
end
c = Vector{Any}(undef, lr)
for i = 1:lr
pi = p[i]
U = Core.Compiler.unwrapva(pi)
if U === Union{}
ci = Union{}
elseif U isa Union
ci = typejoin(U.a, U.b)
else
ci = U
end
if i == lr && Core.Compiler.isvarargtype(pi)
c[i] = isdefined(pi, :N) ? Vararg{ci, pi.N} : Vararg{ci}
else
c[i] = ci
end
end
return Base.rewrap_unionall(Tuple{c...}, T)
end

# Inferred eltype of result of broadcast(f, args...)
combine_eltypes(f, args::Tuple) =
promote_typejoin_union(Base._return_type(f, eltypes(args)))
Expand Down
44 changes: 44 additions & 0 deletions base/promotion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,50 @@ function promote_typejoin(@nospecialize(a), @nospecialize(b))
end
_promote_typesubtract(@nospecialize(a)) = typesplit(a, Union{Nothing, Missing})

function promote_typejoin_union(::Type{T}) where T
if T === Union{}
return Union{}
elseif T isa UnionAll
return Any # TODO: compute more precise bounds
elseif T isa Union
return promote_typejoin(promote_typejoin_union(T.a), promote_typejoin_union(T.b))
elseif T <: Tuple
return typejoin_union_tuple(T)
else
return T
end
end

function typejoin_union_tuple(T::Type)
@_pure_meta
u = Base.unwrap_unionall(T)
u isa Union && return typejoin(
typejoin_union_tuple(Base.rewrap_unionall(u.a, T)),
typejoin_union_tuple(Base.rewrap_unionall(u.b, T)))
p = (u::DataType).parameters
lr = length(p)::Int
if lr == 0
return Tuple{}
end
c = Vector{Any}(undef, lr)
for i = 1:lr
pi = p[i]
U = Core.Compiler.unwrapva(pi)
if U === Union{}
ci = Union{}
elseif U isa Union
ci = typejoin(U.a, U.b)
else
ci = U
end
if i == lr && Core.Compiler.isvarargtype(pi)
c[i] = isdefined(pi, :N) ? Vararg{ci, pi.N} : Vararg{ci}
else
c[i] = ci
end
end
return Base.rewrap_unionall(Tuple{c...}, T)
end

# Returns length, isfixed
function full_va_len(p)
Expand Down
4 changes: 0 additions & 4 deletions test/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -991,10 +991,6 @@ end
@test Core.Compiler.return_type(broadcast, Tuple{typeof(+), Vector{Int},
Vector{Union{Float64, Missing}}}) ==
Union{Vector{Missing}, Vector{Union{Missing, Float64}}, Vector{Float64}}
@test isequal([1, 2] + [3.0, missing], [4.0, missing])
@test Core.Compiler.return_type(+, Tuple{Vector{Int},
Vector{Union{Float64, Missing}}}) ==
Union{Vector{Missing}, Vector{Union{Missing, Float64}}, Vector{Float64}}
@test Core.Compiler.return_type(+, Tuple{Vector{Int},
Vector{Union{Float64, Missing}}}) ==
Union{Vector{Missing}, Vector{Union{Missing, Float64}}, Vector{Float64}}
Expand Down
21 changes: 21 additions & 0 deletions test/generic_map_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,27 @@ function generic_map_tests(mapf, inplace_mapf=nothing)
@test A == map(x->x*x*x, Float64[1:10...])
@test A === B
end

# Issue #28382: inferrability of map with Union eltype
@test isequal(map(+, [1, 2], [3.0, missing]), [4.0, missing])
@test Core.Compiler.return_type(map, Tuple{typeof(+), Vector{Int},
Vector{Union{Float64, Missing}}}) ==
Union{Vector{Missing}, Vector{Union{Missing, Float64}}, Vector{Float64}}
@test isequal(map(tuple, [1, 2], [3.0, missing]), [(1, 3.0), (2, missing)])
@test Core.Compiler.return_type(map, Tuple{typeof(tuple), Vector{Int},
Vector{Union{Float64, Missing}}}) ==
Vector{<:Tuple{Int, Any}}
# Check that corner cases do not throw an error
@test isequal(map(x -> x === 1 ? nothing : x, [1, 2, missing]),
[nothing, 2, missing])
@test isequal(map(x -> x === 1 ? nothing : x, Any[1, 2, 3.0, missing]),
[nothing, 2, 3, missing])
@test map((x,y)->(x==1 ? 1.0 : x, y), [1, 2, 3], ["a", "b", "c"]) ==
[(1.0, "a"), (2, "b"), (3, "c")]
@test map(typeof, [iszero, isdigit]) == [typeof(iszero), typeof(isdigit)]
@test map(typeof, [iszero, iszero]) == [typeof(iszero), typeof(iszero)]
@test isequal(map(identity, Vector{<:Union{Int, Missing}}[[1, 2],[missing, 1]]),
[[1, 2],[missing, 1]])
end

function testmap_equivalence(mapf, f, c...)
Expand Down

0 comments on commit b96dd48

Please sign in to comment.