Skip to content

Commit

Permalink
Replace broadcast_similar with plain old similar (#27110)
Browse files Browse the repository at this point in the history
Now that we have a first-class object that represents a lazy broadcast, we can just use `similar` itself to ask it what kind of container it should allocate.
  • Loading branch information
mbauman authored May 17, 2018
1 parent 406615f commit 337ee84
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 46 deletions.
44 changes: 14 additions & 30 deletions base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@ using .Base.Cartesian
using .Base: Indices, OneTo, tail, to_shape, isoperator, promote_typejoin,
_msk_end, unsafe_bitgetindex, bitcache_chunks, bitcache_size, dumpbitcache, unalias
import .Base: copy, copyto!
export broadcast, broadcast!, BroadcastStyle, broadcast_axes, broadcast_similar, broadcastable,
dotview, @__dot__
export broadcast, broadcast!, BroadcastStyle, broadcast_axes, broadcastable, dotview, @__dot__

### Objects with customized broadcasting behavior should declare a BroadcastStyle

Expand All @@ -25,8 +24,8 @@ by defining a type/method pair
struct MyContainerStyle <: BroadcastStyle end
Base.BroadcastStyle(::Type{<:MyContainer}) = MyContainerStyle()
One then writes method(s) (at least [`broadcast_similar`](@ref)) operating on
`MyContainerStyle`. There are also several pre-defined subtypes of `BroadcastStyle`
One then writes method(s) (at least [`similar`](@ref)) operating on
`Broadcasted{MyContainerStyle}`. There are also several pre-defined subtypes of `BroadcastStyle`
that you may be able to leverage; see the
[Interfaces chapter](@ref man-interfaces-broadcasting) for more information.
"""
Expand All @@ -38,13 +37,6 @@ parameter `C`. You can use this as an alternative to creating custom subtypes of
for example
Base.BroadcastStyle(::Type{<:MyContainer}) = Broadcast.Style{MyContainer}()
There is a pre-defined [`broadcast_similar`](@ref) method
broadcast_similar(f, ::Style{C}, ::Type{ElType}, inds, args...) =
similar(C, ElType, inds)
Naturally you can specialize this for your particular `C` (e.g., `MyContainer`).
"""
struct Style{T} <: BroadcastStyle end

Expand Down Expand Up @@ -199,23 +191,15 @@ function Base.show(io::IO, bc::Broadcasted{Style}) where {Style}
end

## Allocating the output container
"""
broadcast_similar(::BroadcastStyle, ::Type{ElType}, inds, bc)
Allocate an output object for [`broadcast`](@ref), appropriate for the indicated
[`Broadcast.BroadcastStyle`](@ref). `ElType` and `inds` specify the desired element type and axes of the
container. The final `bc` argument is the `Broadcasted` object representing the fused broadcast operation
and its arguments.
"""
broadcast_similar(::DefaultArrayStyle{N}, ::Type{ElType}, inds::Indices{N}, bc) where {N,ElType} =
similar(Array{ElType}, inds)
broadcast_similar(::DefaultArrayStyle{N}, ::Type{Bool}, inds::Indices{N}, bc) where N =
similar(BitArray, inds)
Base.similar(bc::Broadcasted{DefaultArrayStyle{N}}, ::Type{ElType}) where {N,ElType} =
similar(Array{ElType}, axes(bc))
Base.similar(bc::Broadcasted{DefaultArrayStyle{N}}, ::Type{Bool}) where N =
similar(BitArray, axes(bc))
# In cases of conflict we fall back on Array
broadcast_similar(::ArrayConflict, ::Type{ElType}, inds::Indices, bc) where ElType =
similar(Array{ElType}, inds)
broadcast_similar(::ArrayConflict, ::Type{Bool}, inds::Indices, bc) =
similar(BitArray, inds)
Base.similar(bc::Broadcasted{ArrayConflict}, ::Type{ElType}) where ElType =
similar(Array{ElType}, axes(bc))
Base.similar(bc::Broadcasted{ArrayConflict}, ::Type{Bool}) =
similar(BitArray, axes(bc))

## Computing the result's axes. Most types probably won't need to specialize this.
broadcast_axes() = ()
Expand Down Expand Up @@ -767,7 +751,7 @@ const NonleafHandlingStyles = Union{DefaultArrayStyle,ArrayConflict}
ElType = combine_eltypes(bc.f, bc.args)
if Base.isconcretetype(ElType)
# We can trust it and defer to the simpler `copyto!`
return copyto!(broadcast_similar(Style(), ElType, axes(bc), bc), bc)
return copyto!(similar(bc, ElType), bc)
end
# When ElType is not concrete, use narrowing. Use the first output
# value to determine the starting output eltype; copyto_nonleaf!
Expand All @@ -777,12 +761,12 @@ const NonleafHandlingStyles = Union{DefaultArrayStyle,ArrayConflict}
state = start(iter)
if done(iter, state)
# if empty, take the ElType at face value
return broadcast_similar(Style(), ElType, axes(bc′), bc′)
return similar(bc′, ElType)
end
# Initialize using the first value
I, state = next(iter, state)
@inbounds val = bc′[I]
dest = broadcast_similar(Style(), typeof(val), axes(bc′), bc′)
dest = similar(bc′, typeof(val))
@inbounds dest[I] = val
# Now handle the remaining values
return copyto_nonleaf!(dest, bc′, iter, state, 1)
Expand Down
1 change: 0 additions & 1 deletion doc/src/base/arrays.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ Base.@__dot__
For specializing broadcast on custom types, see
```@docs
Base.BroadcastStyle
Base.broadcast_similar
Base.broadcast_axes
Base.Broadcast.AbstractArrayStyle
Base.Broadcast.ArrayStyle
Expand Down
16 changes: 8 additions & 8 deletions doc/src/manual/interfaces.md
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,7 @@ V = view(A, [1,2,4], :) # is not strided, as the spacing between rows is not f
| Methods to implement | Brief description |
|:-------------------- |:----------------- |
| `Base.BroadcastStyle(::Type{SrcType}) = SrcStyle()` | Broadcasting behavior of `SrcType` |
| `Base.broadcast_similar(::DestStyle, ::Type{ElType}, inds, bc)` | Allocation of output container |
| `Base.similar(bc::Broadcasted{DestStyle}, ::Type{ElType})` | Allocation of output container |
| **Optional methods** | | |
| `Base.BroadcastStyle(::Style1, ::Style2) = Style12()` | Precedence rules for mixing styles |
| `Base.broadcast_axes(::StyleA, A)` | Declaration of the indices of `A` for broadcasting purposes (defaults to [`axes(A)`](@ref)) |
Expand Down Expand Up @@ -512,17 +512,17 @@ For more details, see [below](@ref writing-binary-broadcasting-rules).

The broadcast style is computed for every broadcasting operation to allow for
dispatch and specialization. The actual allocation of the result array is
handled by `Base.broadcast_similar`, using this style as its first argument.
handled by `similar`, using the Broadcasted object as its first argument.

```julia
Base.broadcast_similar(::DestStyle, ::Type{ElType}, inds, bc)
Base.similar(bc::Broadcasted{DestStyle}, ::Type{ElType})
```

The fallback definition is

```julia
broadcast_similar(::DefaultArrayStyle{N}, ::Type{ElType}, inds::Indices{N}, bc) where {N,ElType} =
similar(Array{ElType}, inds)
similar(bc::Broadcasted{DefaultArrayStyle{N}}, ::Type{ElType}) where {N,ElType} =
similar(Array{ElType}, axes(bc))
```

However, if needed you can specialize on any or all of these arguments. The final argument
Expand Down Expand Up @@ -555,13 +555,13 @@ Base.BroadcastStyle(::Type{<:ArrayAndChar}) = Broadcast.ArrayStyle{ArrayAndChar}
```

This means we must also define a corresponding `broadcast_similar` method:
This means we must also define a corresponding `similar` method:
```jldoctest ArrayAndChar; filter = r"(^find_aac \(generic function with 5 methods\)$|^$)"
function Base.broadcast_similar(::Broadcast.ArrayStyle{ArrayAndChar}, ::Type{ElType}, inds, bc) where ElType
function Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{ArrayAndChar}}, ::Type{ElType}) where ElType
# Scan the inputs for the ArrayAndChar:
A = find_aac(bc)
# Use the char field of A to create the output
ArrayAndChar(similar(Array{ElType}, inds), A.char)
ArrayAndChar(similar(Array{ElType}, axes(bc)), A.char)
end
"`A = find_aac(As)` returns the first ArrayAndChar among the arguments."
Expand Down
7 changes: 4 additions & 3 deletions stdlib/LinearAlgebra/src/structuredbroadcast.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
## Broadcast styles
import Base.Broadcast
using Base.Broadcast: DefaultArrayStyle, broadcast_similar, tail
using Base.Broadcast: DefaultArrayStyle, Broadcasted, tail

struct StructuredMatrixStyle{T} <: Broadcast.AbstractArrayStyle{2} end
StructuredMatrixStyle{T}(::Val{2}) where {T} = StructuredMatrixStyle{T}()
Expand Down Expand Up @@ -91,11 +91,12 @@ function fzero(bc::Broadcast.Broadcasted)
return any(ismissing, args) ? missing : bc.f(args...)
end

function Broadcast.broadcast_similar(::StructuredMatrixStyle{T}, ::Type{ElType}, inds, bc) where {T,ElType}
function Base.similar(bc::Broadcasted{StructuredMatrixStyle{T}}, ::Type{ElType}) where {T,ElType}
inds = axes(bc)
if isstructurepreserving(bc) || (fzeropreserving(bc) && !(T <: Union{SymTridiagonal,UnitLowerTriangular,UnitUpperTriangular}))
return structured_broadcast_alloc(bc, T, ElType, length(inds[1]))
end
return broadcast_similar(DefaultArrayStyle{2}(), ElType, inds, bc)
return similar(convert(Broadcasted{DefaultArrayStyle{ndims(bc)}}, bc), ElType)
end

function copyto!(dest::Diagonal, bc::Broadcasted{<:StructuredMatrixStyle})
Expand Down
8 changes: 4 additions & 4 deletions test/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -430,8 +430,8 @@ abstract type ArrayData{T,N} <: AbstractArray{T,N} end
Base.getindex(A::ArrayData, i::Integer...) = A.data[i...]
Base.setindex!(A::ArrayData, v::Any, i::Integer...) = setindex!(A.data, v, i...)
Base.size(A::ArrayData) = size(A.data)
Base.broadcast_similar(::Broadcast.ArrayStyle{A}, ::Type{T}, inds::Tuple, bc) where {A,T} =
A(Array{T}(undef, length.(inds)))
Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{A}}, ::Type{T}) where {A,T} =
A(Array{T}(undef, length.(axes(bc))))

struct Array19745{T,N} <: ArrayData{T,N}
data::Array{T,N}
Expand Down Expand Up @@ -494,8 +494,8 @@ end
struct AD2DimStyle <: Broadcast.AbstractArrayStyle{2}; end
AD2DimStyle(::Val{2}) = AD2DimStyle()
AD2DimStyle(::Val{N}) where {N} = Broadcast.DefaultArrayStyle{N}()
Base.broadcast_similar(::AD2DimStyle, ::Type{T}, inds::Tuple, bc) where {T} =
AD2Dim(Array{T}(undef, length.(inds)))
Base.similar(bc::Broadcast.Broadcasted{AD2DimStyle}, ::Type{T}) where {T} =
AD2Dim(Array{T}(undef, length.(axes(bc))))
Base.BroadcastStyle(::Type{T}) where {T<:AD2Dim} = AD2DimStyle()

@testset "broadcasting for custom AbstractArray" begin
Expand Down

0 comments on commit 337ee84

Please sign in to comment.