Skip to content

Commit

Permalink
Use containertype to determine array type for array broadcast
Browse files Browse the repository at this point in the history
  • Loading branch information
andreasnoack committed Dec 29, 2016
1 parent b561cfb commit 7767d8b
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 1 deletion.
1 change: 0 additions & 1 deletion base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ export broadcast_getindex, broadcast_setindex!
# fallbacks for some special cases
@inline broadcast(f, x::Number...) = f(x...)
@inline broadcast{N}(f, t::NTuple{N}, ts::Vararg{NTuple{N}}) = map(f, t, ts...)
@inline broadcast(f, As::AbstractArray...) = broadcast_c(f, Array, As...)

# special cases for "X .= ..." (broadcast!) assignments
broadcast!(::typeof(identity), X::AbstractArray, x::Number) = fill!(X, x)
Expand Down
35 changes: 35 additions & 0 deletions test/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -373,3 +373,38 @@ end
@test (+).(Ref(1), Ref(2)) == fill(3)
@test (+).([[0,2], [1,3]], [1,-1]) == [[1,3], [0,2]]
@test (+).([[0,2], [1,3]], Ref{Vector{Int}}([1,-1])) == [[1,1], [2,2]]

# broadcasting for custom AbstractArray
immutable Array19745{T,N} <: AbstractArray{T,N}
data::Array{T,N}
end
Base.getindex(A::Array19745, i::Integer...) = A.data[i...]
Base.size(A::Array19745) = size(A.data)

Base.Broadcast.containertype{T<:Array19745}(::Type{T}) = Array19745

Base.Broadcast.promote_containertype(::Type{Array19745}, ::Type{Array19745}) = Array19745
Base.Broadcast.promote_containertype(::Type{Array19745}, ::Type{Array}) = Array19745
Base.Broadcast.promote_containertype(::Type{Array19745}, ct) = Array19745
Base.Broadcast.promote_containertype(::Type{Array}, ::Type{Array19745}) = Array19745
Base.Broadcast.promote_containertype(ct, ::Type{Array19745}) = Array19745

Base.Broadcast.broadcast_indices(::Type{Array19745}, A) = indices(A)
Base.Broadcast.broadcast_indices(::Type{Array19745}, A::Ref) = ()

getfield19745(x::Array19745) = x.data
getfield19745(x) = x

Base.Broadcast.broadcast_c(f, ::Type{Array19745}, A, Bs...) =
Array19745(Base.Broadcast.broadcast_c(f, Array, getfield19745(A), map(getfield19745, Bs)...))

@testset "broadcasting for custom AbstractArray" begin

a = randn(10)
aa = Array19745(a)
@test a .+ 1 == @inferred(aa .+ 1)
@test a .* a' == @inferred(aa .* aa')
@test isa(aa .+ 1, Array19745)
@test isa(aa .* aa', Array19745)

end

0 comments on commit 7767d8b

Please sign in to comment.