Skip to content

Commit

Permalink
Stop using @pure with promote_typejoin
Browse files Browse the repository at this point in the history
  • Loading branch information
nalimilan committed Aug 13, 2020
1 parent 8a1e65f commit 999f17f
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 15 deletions.
41 changes: 26 additions & 15 deletions base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ Module containing the broadcasting implementation.
module Broadcast

using .Base.Cartesian
using .Base: Indices, OneTo, tail, to_shape, isoperator, promote_typejoin,
using .Base: Indices, OneTo, tail, to_shape, isoperator, promote_typejoin, @pure,
_msk_end, unsafe_bitgetindex, bitcache_chunks, bitcache_size, dumpbitcache, unalias
import .Base: copy, copyto!, axes
export broadcast, broadcast!, BroadcastStyle, broadcast_axes, broadcastable, dotview, @__dot__, broadcast_preserving_zero_d
Expand Down Expand Up @@ -137,7 +137,7 @@ BroadcastStyle(a::AbstractArrayStyle, ::Style{Tuple}) = a
BroadcastStyle(::A, ::A) where A<:ArrayStyle = A()
BroadcastStyle(::ArrayStyle, ::ArrayStyle) = Unknown()
BroadcastStyle(::A, ::A) where A<:AbstractArrayStyle = A()
Base.@pure function BroadcastStyle(a::A, b::B) where {A<:AbstractArrayStyle{M},B<:AbstractArrayStyle{N}} where {M,N}
@pure function BroadcastStyle(a::A, b::B) where {A<:AbstractArrayStyle{M},B<:AbstractArrayStyle{N}} where {M,N}
if Base.typename(A) === Base.typename(B)
return A(Val(max(M, N)))
end
Expand Down Expand Up @@ -697,24 +697,35 @@ function promote_typejoin_union(::Type{T}) where T
elseif T isa Union
promote_typejoin(promote_typejoin_union(T.a), promote_typejoin_union(T.b))
elseif T <: Tuple
p = T.parameters
lr = length(p)::Int
if lr == 0
return Tuple{}
end
lf, fixed = Core.Compiler.full_va_len(p)
c = Vector{Any}(undef, lf)
for i = 1:lf
pi = p[i]
ci = promote_typejoin_union(Core.Compiler.unwrapva(pi))
c[i] = i == lf && Core.Compiler.isvarargtype(pi) ? Vararg{ci} : ci
end
return Tuple{c...}
typejoin_union_tuple(T)
else
T
end
end

@pure function typejoin_union_tuple(::Type{T}) where {T}
p = T.parameters
lr = length(p)::Int
if lr == 0
return Tuple{}
end
lf, fixed = Core.Compiler.full_va_len(p)
c = Vector{Any}(undef, lf)
for i = 1:lf
pi = p[i]
U = Core.Compiler.unwrapva(pi)
if U === Union{}
ci = Union{}
elseif U isa Union
ci = typejoin(U.a, U.b)
else
ci = U
end
c[i] = i == lf && Core.Compiler.isvarargtype(pi) ? Vararg{ci} : ci
end
return Tuple{c...}
end

# Inferred eltype of result of broadcast(f, args...)
combine_eltypes(f, args::Tuple) =
promote_typejoin_union(Base._return_type(f, eltypes(args)))
Expand Down
4 changes: 4 additions & 0 deletions test/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -944,6 +944,10 @@ ret = @macroexpand @.([Int, Number] >: Real)
@test Core.Compiler.return_type(+, Tuple{Vector{Int},
Vector{Union{Float64, Missing}}}) ==
Vector{<:Union{Float64, Missing}}
@test isequal(tuple.([1, 2], [3.0, missing]), [(1, 3.0), (2, missing)])
@test Core.Compiler.return_type(broadcast, Tuple{typeof(tuple), Vector{Int},
Vector{Union{Float64, Missing}}}) ==
Vector{<:Tuple{Int, Any}}
# Check that corner cases do not throw an error
@test isequal(broadcast(x -> x === 1 ? nothing : x, [1, 2, missing]),
[nothing, 2, missing])
Expand Down

0 comments on commit 999f17f

Please sign in to comment.