Skip to content

Commit

Permalink
Fix broadcast_indices
Browse files Browse the repository at this point in the history
This fixes a regression introduced in 4f1b479. broadcast_indices() needs to
be overloaded by packages for custom types, so it cannot be hidden under
_broadcast_indices(). Also, ::Type is incorrect since the method only applies
to scalars.

Make the tests more complex to be closer to actual implementations in packages
so that regressions like this will be noticed in the future.
  • Loading branch information
nalimilan committed May 30, 2017
1 parent efca045 commit cc74744
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 7 deletions.
10 changes: 5 additions & 5 deletions base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,12 @@ promote_containertype(::Type{T}, ::Type{T}) where {T} = T
## Calculate the broadcast indices of the arguments, or error if incompatible
# array inputs
broadcast_indices() = ()
broadcast_indices(A) = _broadcast_indices(containertype(A), A)
broadcast_indices(A) = broadcast_indices(containertype(A), A)
@inline broadcast_indices(A, B...) = broadcast_shape(broadcast_indices(A), broadcast_indices(B...))
_broadcast_indices(::Type, A) = ()
_broadcast_indices(::Type{Tuple}, A) = (OneTo(length(A)),)
_broadcast_indices(::Type{Array}, A::Ref) = ()
_broadcast_indices(::Type{Array}, A) = indices(A)
broadcast_indices(::ScalarType, A) = ()
broadcast_indices(::Type{Tuple}, A) = (OneTo(length(A)),)
broadcast_indices(::Type{Array}, A::Ref) = ()
broadcast_indices(::Type{Array}, A) = indices(A)

# shape (i.e., tuple-of-indices) inputs
broadcast_shape(shape::Tuple) = shape
Expand Down
9 changes: 7 additions & 2 deletions test/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,7 @@ struct Array19745{T,N} <: AbstractArray{T,N}
data::Array{T,N}
end
Base.getindex(A::Array19745, i::Integer...) = A.data[i...]
Base.setindex!(A::Array19745, v::Any, i::Integer...) = setindex!(A.data, v, i...)
Base.size(A::Array19745) = size(A.data)

Base.Broadcast._containertype{T<:Array19745}(::Type{T}) = Array19745
Expand All @@ -435,8 +436,12 @@ Base.Broadcast.broadcast_indices(::Type{Array19745}, A::Ref) = ()
getfield19745(x::Array19745) = x.data
getfield19745(x) = x

Base.Broadcast.broadcast_c(f, ::Type{Array19745}, A, Bs...) =
Array19745(Base.Broadcast.broadcast_c(f, Array, getfield19745(A), map(getfield19745, Bs)...))
function Base.Broadcast.broadcast_c(f, ::Type{Array19745}, A, Bs...)
T = Base.Broadcast._broadcast_eltype(f, A, Bs...)
shape = Base.Broadcast.broadcast_indices(A, Bs...)
dest = Array19745(Array{T}(Base.index_lengths(shape...)))
return broadcast!(f, dest, A, Bs...)
end

@testset "broadcasting for custom AbstractArray" begin
a = randn(10)
Expand Down

0 comments on commit cc74744

Please sign in to comment.