From b96dd48b8b0401b252ef79d1c6a565f149c04d04 Mon Sep 17 00:00:00 2001 From: Milan Bouchet-Valat Date: Sun, 29 Aug 2021 16:24:55 +0200 Subject: [PATCH] Make return type of map inferrable with heterogeneous arrays Inference is not able to detect the element type automatically, but we can do it manually since we know promote_typejoin is used for widening. This is similar to the approach used for `broadcast` at #30485. --- base/array.jl | 9 ++++++-- base/broadcast.jl | 46 +-------------------------------------- base/promotion.jl | 44 +++++++++++++++++++++++++++++++++++++ test/broadcast.jl | 4 ---- test/generic_map_tests.jl | 21 ++++++++++++++++++ 5 files changed, 73 insertions(+), 51 deletions(-) diff --git a/base/array.jl b/base/array.jl index 6a16db66a8cfd..976846549df2e 100644 --- a/base/array.jl +++ b/base/array.jl @@ -775,8 +775,13 @@ function collect(itr::Generator) return _array_for(et, itr.iter, isz) end v1, st = y - arr = _array_for(typeof(v1), itr.iter, isz, shape) - return collect_to_with_first!(arr, v1, itr, st) + dest = _array_for(typeof(v1), itr.iter, isz, shape) + # The typeassert gives inference a helping hand on the element type and dimensionality + # (work-around for #28382) + ElType = promote_typejoin_union(et) + ElType′ = ElType <: Type ? Type : ElType + RT = dest isa AbstractArray ? AbstractArray{<:ElType′, ndims(dest)} : Any + collect_to_with_first!(dest, v1, itr, st)::RT end end diff --git a/base/broadcast.jl b/base/broadcast.jl index b34a73041708b..90479189ffee4 100644 --- a/base/broadcast.jl +++ b/base/broadcast.jl @@ -8,7 +8,7 @@ Module containing the broadcasting implementation. module Broadcast using .Base.Cartesian -using .Base: Indices, OneTo, tail, to_shape, isoperator, promote_typejoin, @pure, +using .Base: Indices, OneTo, tail, to_shape, isoperator, promote_typejoin, promote_typejoin_union, @pure, _msk_end, unsafe_bitgetindex, bitcache_chunks, bitcache_size, dumpbitcache, unalias import .Base: copy, copyto!, axes export broadcast, broadcast!, BroadcastStyle, broadcast_axes, broadcastable, dotview, @__dot__, BroadcastFunction @@ -713,50 +713,6 @@ eltypes(t::Tuple{Any}) = Tuple{_broadcast_getindex_eltype(t[1])} eltypes(t::Tuple{Any,Any}) = Tuple{_broadcast_getindex_eltype(t[1]), _broadcast_getindex_eltype(t[2])} eltypes(t::Tuple) = Tuple{_broadcast_getindex_eltype(t[1]), eltypes(tail(t)).types...} -function promote_typejoin_union(::Type{T}) where T - if T === Union{} - return Union{} - elseif T isa UnionAll - return Any # TODO: compute more precise bounds - elseif T isa Union - return promote_typejoin(promote_typejoin_union(T.a), promote_typejoin_union(T.b)) - elseif T <: Tuple - return typejoin_union_tuple(T) - else - return T - end -end - -@pure function typejoin_union_tuple(T::Type) - u = Base.unwrap_unionall(T) - u isa Union && return typejoin( - typejoin_union_tuple(Base.rewrap_unionall(u.a, T)), - typejoin_union_tuple(Base.rewrap_unionall(u.b, T))) - p = (u::DataType).parameters - lr = length(p)::Int - if lr == 0 - return Tuple{} - end - c = Vector{Any}(undef, lr) - for i = 1:lr - pi = p[i] - U = Core.Compiler.unwrapva(pi) - if U === Union{} - ci = Union{} - elseif U isa Union - ci = typejoin(U.a, U.b) - else - ci = U - end - if i == lr && Core.Compiler.isvarargtype(pi) - c[i] = isdefined(pi, :N) ? Vararg{ci, pi.N} : Vararg{ci} - else - c[i] = ci - end - end - return Base.rewrap_unionall(Tuple{c...}, T) -end - # Inferred eltype of result of broadcast(f, args...) combine_eltypes(f, args::Tuple) = promote_typejoin_union(Base._return_type(f, eltypes(args))) diff --git a/base/promotion.jl b/base/promotion.jl index cae7d07f06da6..21245f0e05c70 100644 --- a/base/promotion.jl +++ b/base/promotion.jl @@ -161,6 +161,50 @@ function promote_typejoin(@nospecialize(a), @nospecialize(b)) end _promote_typesubtract(@nospecialize(a)) = typesplit(a, Union{Nothing, Missing}) +function promote_typejoin_union(::Type{T}) where T + if T === Union{} + return Union{} + elseif T isa UnionAll + return Any # TODO: compute more precise bounds + elseif T isa Union + return promote_typejoin(promote_typejoin_union(T.a), promote_typejoin_union(T.b)) + elseif T <: Tuple + return typejoin_union_tuple(T) + else + return T + end +end + +function typejoin_union_tuple(T::Type) + @_pure_meta + u = Base.unwrap_unionall(T) + u isa Union && return typejoin( + typejoin_union_tuple(Base.rewrap_unionall(u.a, T)), + typejoin_union_tuple(Base.rewrap_unionall(u.b, T))) + p = (u::DataType).parameters + lr = length(p)::Int + if lr == 0 + return Tuple{} + end + c = Vector{Any}(undef, lr) + for i = 1:lr + pi = p[i] + U = Core.Compiler.unwrapva(pi) + if U === Union{} + ci = Union{} + elseif U isa Union + ci = typejoin(U.a, U.b) + else + ci = U + end + if i == lr && Core.Compiler.isvarargtype(pi) + c[i] = isdefined(pi, :N) ? Vararg{ci, pi.N} : Vararg{ci} + else + c[i] = ci + end + end + return Base.rewrap_unionall(Tuple{c...}, T) +end # Returns length, isfixed function full_va_len(p) diff --git a/test/broadcast.jl b/test/broadcast.jl index 66c215aee9293..329bcc602206b 100644 --- a/test/broadcast.jl +++ b/test/broadcast.jl @@ -991,10 +991,6 @@ end @test Core.Compiler.return_type(broadcast, Tuple{typeof(+), Vector{Int}, Vector{Union{Float64, Missing}}}) == Union{Vector{Missing}, Vector{Union{Missing, Float64}}, Vector{Float64}} - @test isequal([1, 2] + [3.0, missing], [4.0, missing]) - @test Core.Compiler.return_type(+, Tuple{Vector{Int}, - Vector{Union{Float64, Missing}}}) == - Union{Vector{Missing}, Vector{Union{Missing, Float64}}, Vector{Float64}} @test Core.Compiler.return_type(+, Tuple{Vector{Int}, Vector{Union{Float64, Missing}}}) == Union{Vector{Missing}, Vector{Union{Missing, Float64}}, Vector{Float64}} diff --git a/test/generic_map_tests.jl b/test/generic_map_tests.jl index 8e77533362fe3..9915280e0abb6 100644 --- a/test/generic_map_tests.jl +++ b/test/generic_map_tests.jl @@ -53,6 +53,27 @@ function generic_map_tests(mapf, inplace_mapf=nothing) @test A == map(x->x*x*x, Float64[1:10...]) @test A === B end + + # Issue #28382: inferrability of map with Union eltype + @test isequal(map(+, [1, 2], [3.0, missing]), [4.0, missing]) + @test Core.Compiler.return_type(map, Tuple{typeof(+), Vector{Int}, + Vector{Union{Float64, Missing}}}) == + Union{Vector{Missing}, Vector{Union{Missing, Float64}}, Vector{Float64}} + @test isequal(map(tuple, [1, 2], [3.0, missing]), [(1, 3.0), (2, missing)]) + @test Core.Compiler.return_type(map, Tuple{typeof(tuple), Vector{Int}, + Vector{Union{Float64, Missing}}}) == + Vector{<:Tuple{Int, Any}} + # Check that corner cases do not throw an error + @test isequal(map(x -> x === 1 ? nothing : x, [1, 2, missing]), + [nothing, 2, missing]) + @test isequal(map(x -> x === 1 ? nothing : x, Any[1, 2, 3.0, missing]), + [nothing, 2, 3, missing]) + @test map((x,y)->(x==1 ? 1.0 : x, y), [1, 2, 3], ["a", "b", "c"]) == + [(1.0, "a"), (2, "b"), (3, "c")] + @test map(typeof, [iszero, isdigit]) == [typeof(iszero), typeof(isdigit)] + @test map(typeof, [iszero, iszero]) == [typeof(iszero), typeof(iszero)] + @test isequal(map(identity, Vector{<:Union{Int, Missing}}[[1, 2],[missing, 1]]), + [[1, 2],[missing, 1]]) end function testmap_equivalence(mapf, f, c...)