Skip to content

Commit

Permalink
Fix broadcast_indices (#22130)
Browse files Browse the repository at this point in the history
This fixes a regression introduced in 4f1b479. broadcast_indices() needs to
be overloaded by packages for custom types, so it cannot be hidden under
_broadcast_indices(). Also, ::Type is incorrect since the method only applies
to scalars.

Make the tests more complex to be closer to actual implementations in packages
so that regressions like this will be noticed in the future.
  • Loading branch information
nalimilan authored and KristofferC committed Jun 1, 2017
1 parent 8758659 commit fb81c34
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 11 deletions.
10 changes: 5 additions & 5 deletions base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,12 @@ promote_containertype(::Type{T}, ::Type{T}) where {T} = T
## Calculate the broadcast indices of the arguments, or error if incompatible
# array inputs
broadcast_indices() = ()
broadcast_indices(A) = _broadcast_indices(containertype(A), A)
broadcast_indices(A) = broadcast_indices(containertype(A), A)
@inline broadcast_indices(A, B...) = broadcast_shape(broadcast_indices(A), broadcast_indices(B...))
_broadcast_indices(::Type, A) = ()
_broadcast_indices(::Type{Tuple}, A) = (OneTo(length(A)),)
_broadcast_indices(::Type{Array}, A::Ref) = ()
_broadcast_indices(::Type{Array}, A) = indices(A)
broadcast_indices(::ScalarType, A) = ()
broadcast_indices(::Type{Tuple}, A) = (OneTo(length(A)),)
broadcast_indices(::Type{Array}, A::Ref) = ()
broadcast_indices(::Type{Array}, A) = indices(A)

# shape (i.e., tuple-of-indices) inputs
broadcast_shape(shape::Tuple) = shape
Expand Down
7 changes: 3 additions & 4 deletions base/sparse/higherorderfns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@ module HigherOrderFns
# particularly map[!]/broadcast[!] for SparseVectors and SparseMatrixCSCs at present.
import Base: map, map!, broadcast, broadcast!
import Base.Broadcast: _containertype, promote_containertype,
broadcast_indices, _broadcast_indices,
broadcast_c, broadcast_c!
broadcast_indices, broadcast_c, broadcast_c!

using Base: front, tail, to_shape
using ..SparseArrays: SparseVector, SparseMatrixCSC, AbstractSparseVector,
Expand Down Expand Up @@ -901,7 +900,7 @@ end
# (10) broadcast[!] over combinations of broadcast scalars and sparse vectors/matrices

# broadcast shape promotion for combinations of sparse arrays and other types
_broadcast_indices(::Type{AbstractSparseArray}, A) = indices(A)
broadcast_indices(::Type{AbstractSparseArray}, A) = indices(A)
# broadcast container type promotion for combinations of sparse arrays and other types
_containertype(::Type{<:SparseVecOrMat}) = AbstractSparseArray
# combinations of sparse arrays with broadcast scalars should yield sparse arrays
Expand Down Expand Up @@ -985,7 +984,7 @@ struct PromoteToSparse end
# broadcast containertype definitions for structured matrices
StructuredMatrix = Union{Diagonal,Bidiagonal,Tridiagonal,SymTridiagonal}
_containertype(::Type{<:StructuredMatrix}) = PromoteToSparse
_broadcast_indices(::Type{PromoteToSparse}, A) = indices(A)
broadcast_indices(::Type{PromoteToSparse}, A) = indices(A)

# combinations explicitly involving Tuples and PromoteToSparse collections
# divert to the generic AbstractArray broadcast code
Expand Down
9 changes: 7 additions & 2 deletions test/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,7 @@ struct Array19745{T,N} <: AbstractArray{T,N}
data::Array{T,N}
end
Base.getindex(A::Array19745, i::Integer...) = A.data[i...]
Base.setindex!(A::Array19745, v::Any, i::Integer...) = setindex!(A.data, v, i...)
Base.size(A::Array19745) = size(A.data)

Base.Broadcast._containertype{T<:Array19745}(::Type{T}) = Array19745
Expand All @@ -435,8 +436,12 @@ Base.Broadcast.broadcast_indices(::Type{Array19745}, A::Ref) = ()
getfield19745(x::Array19745) = x.data
getfield19745(x) = x

Base.Broadcast.broadcast_c(f, ::Type{Array19745}, A, Bs...) =
Array19745(Base.Broadcast.broadcast_c(f, Array, getfield19745(A), map(getfield19745, Bs)...))
function Base.Broadcast.broadcast_c(f, ::Type{Array19745}, A, Bs...)
T = Base.Broadcast._broadcast_eltype(f, A, Bs...)
shape = Base.Broadcast.broadcast_indices(A, Bs...)
dest = Array19745(Array{T}(Base.index_lengths(shape...)))
return broadcast!(f, dest, A, Bs...)
end

@testset "broadcasting for custom AbstractArray" begin
a = randn(10)
Expand Down

0 comments on commit fb81c34

Please sign in to comment.