diff --git a/base/broadcast.jl b/base/broadcast.jl index b55051d82546d..34c2241864fd6 100644 --- a/base/broadcast.jl +++ b/base/broadcast.jl @@ -8,7 +8,7 @@ Module containing the broadcasting implementation. module Broadcast using .Base.Cartesian -using .Base: Indices, OneTo, tail, to_shape, isoperator, promote_typejoin, +using .Base: Indices, OneTo, tail, to_shape, isoperator, promote_typejoin, @pure, _msk_end, unsafe_bitgetindex, bitcache_chunks, bitcache_size, dumpbitcache, unalias import .Base: copy, copyto!, axes export broadcast, broadcast!, BroadcastStyle, broadcast_axes, broadcastable, dotview, @__dot__, broadcast_preserving_zero_d, BroadcastFunction @@ -691,8 +691,52 @@ eltypes(t::Tuple{Any}) = Tuple{_broadcast_getindex_eltype(t[1])} eltypes(t::Tuple{Any,Any}) = Tuple{_broadcast_getindex_eltype(t[1]), _broadcast_getindex_eltype(t[2])} eltypes(t::Tuple) = Tuple{_broadcast_getindex_eltype(t[1]), eltypes(tail(t)).types...} +function promote_typejoin_union(::Type{T}) where T + if T === Union{} + return Union{} + elseif T isa Union + return promote_typejoin(promote_typejoin_union(T.a), promote_typejoin_union(T.b)) + elseif T <: Tuple + return typejoin_union_tuple(T) + else + return T + end +end + +@pure function typejoin_union_tuple(T::Type) + u = Base.unwrap_unionall(T) + u isa Union && return typejoin( + typejoin_union_tuple(Base.rewrap_unionall(u.a, T)), + typejoin_union_tuple(Base.rewrap_unionall(u.b, T))) + p = (u::DataType).parameters + lr = length(p)::Int + if lr == 0 + return Tuple{} + end + c = Vector{Any}(undef, lr) + for i = 1:lr + pi = p[i] + U = Core.Compiler.unwrapva(pi) + if U === Union{} + ci = Union{} + elseif U isa Union + ci = typejoin(U.a, U.b) + else + ci = U + end + if i == lr && Core.Compiler.isvarargtype(pi) + N = (Base.unwrap_unionall(pi)::DataType).parameters[2] + c[i] = Base.rewrap_unionall(Vararg{ci, N}, pi) + else + c[i] = ci + end + end + return Base.rewrap_unionall(Tuple{c...}, T) +end + # Inferred eltype of result of broadcast(f, args...) -combine_eltypes(f, args::Tuple) = Base._return_type(f, eltypes(args)) +combine_eltypes(f, args::Tuple) = + promote_typejoin_union(Base._return_type(f, eltypes(args))) ## Broadcasting core @@ -877,7 +921,11 @@ const NonleafHandlingStyles = Union{DefaultArrayStyle,ArrayConflict} dest = similar(bc′, typeof(val)) @inbounds dest[I] = val # Now handle the remaining values - return copyto_nonleaf!(dest, bc′, iter, state, 1) + # The typeassert gives inference a helping hand on the element type and dimensionality + # (work-around for #28382) + ElType′ = ElType <: Type ? Type : ElType + RT = dest isa AbstractArray ? AbstractArray{<:ElType′, ndims(dest)} : Any + return copyto_nonleaf!(dest, bc′, iter, state, 1)::RT end ## general `copyto!` methods diff --git a/test/broadcast.jl b/test/broadcast.jl index 7fcc793fb6844..8a5b003fca8f3 100644 --- a/test/broadcast.jl +++ b/test/broadcast.jl @@ -365,7 +365,7 @@ end let f17314 = x -> x < 0 ? false : x @test eltype(broadcast(f17314, 1:3)) === Int @test eltype(broadcast(f17314, -1:1)) === Integer - @test eltype(broadcast(f17314, Int[])) == Union{Bool,Int} + @test eltype(broadcast(f17314, Int[])) === Integer end let io = IOBuffer() broadcast(x->print(io,x), 1:5) # broadcast with side effects @@ -950,3 +950,41 @@ p0 = copy(p) @test map(.+, [[1,2], [3,4]], [5, 6]) == [[6,7], [9,10]] @test repr(.!) == "Base.Broadcast.BroadcastFunction(!)" @test eval(:(.+)) == Base.BroadcastFunction(+) + +@testset "Issue #28382: inferrability of broadcast with Union eltype" begin + @test isequal([1, 2] .+ [3.0, missing], [4.0, missing]) + @test_broken Core.Compiler.return_type(broadcast, Tuple{typeof(+), Vector{Int}, + Vector{Union{Float64, Missing}}}) == + Vector{<:Union{Float64, Missing}} + @test Core.Compiler.return_type(broadcast, Tuple{typeof(+), Vector{Int}, + Vector{Union{Float64, Missing}}}) == + AbstractVector{<:Union{Float64, Missing}} + @test isequal([1, 2] + [3.0, missing], [4.0, missing]) + @test_broken Core.Compiler.return_type(+, Tuple{Vector{Int}, + Vector{Union{Float64, Missing}}}) == + Vector{<:Union{Float64, Missing}} + @test Core.Compiler.return_type(+, Tuple{Vector{Int}, + Vector{Union{Float64, Missing}}}) == + AbstractVector{<:Union{Float64, Missing}} + @test_broken Core.Compiler.return_type(+, Tuple{Vector{Int}, + Vector{Union{Float64, Missing}}}) == + Vector{<:Union{Float64, Missing}} + @test isequal(tuple.([1, 2], [3.0, missing]), [(1, 3.0), (2, missing)]) + @test_broken Core.Compiler.return_type(broadcast, Tuple{typeof(tuple), Vector{Int}, + Vector{Union{Float64, Missing}}}) == + Vector{<:Tuple{Int, Any}} + @test Core.Compiler.return_type(broadcast, Tuple{typeof(tuple), Vector{Int}, + Vector{Union{Float64, Missing}}}) == + AbstractVector{<:Tuple{Int, Any}} + # Check that corner cases do not throw an error + @test isequal(broadcast(x -> x === 1 ? nothing : x, [1, 2, missing]), + [nothing, 2, missing]) + @test isequal(broadcast(x -> x === 1 ? nothing : x, Any[1, 2, 3.0, missing]), + [nothing, 2, 3, missing]) + @test broadcast((x,y)->(x==1 ? 1.0 : x, y), [1 2 3], ["a", "b", "c"]) == + [(1.0, "a") (2, "a") (3, "a") + (1.0, "b") (2, "b") (3, "b") + (1.0, "c") (2, "c") (3, "c")] + @test typeof.([iszero, isdigit]) == [typeof(iszero), typeof(isdigit)] + @test typeof.([iszero, iszero]) == [typeof(iszero), typeof(iszero)] +end