Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generalize broadcast to handle tuples and scalars #16986

Merged
merged 4 commits into from
Sep 18, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ This section lists changes that do not have deprecation warnings.
for `real(z) < 0`, which differs from `log(gamma(z))` by multiples of 2π
in the imaginary part ([#18330]).

* `broadcast` now handles tuples, and treats any argument that is not a tuple
or an array as a "scalar" ([#16986]).

Library improvements
--------------------

Expand Down Expand Up @@ -638,6 +641,7 @@ Language tooling improvements
[#16854]: https://github.com/JuliaLang/julia/issues/16854
[#16953]: https://github.com/JuliaLang/julia/issues/16953
[#16972]: https://github.com/JuliaLang/julia/issues/16972
[#16986]: https://github.com/JuliaLang/julia/issues/16986
[#17033]: https://github.com/JuliaLang/julia/issues/17033
[#17037]: https://github.com/JuliaLang/julia/issues/17037
[#17075]: https://github.com/JuliaLang/julia/issues/17075
Expand Down
4 changes: 2 additions & 2 deletions base/arraymath.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ promote_array_type{S<:Integer}(::typeof(.\), ::Type{S}, ::Type{Bool}, T::Type) =
promote_array_type{S<:Integer}(F, ::Type{S}, ::Type{Bool}, T::Type) = T

for f in (:+, :-, :div, :mod, :&, :|, :$)
@eval ($f){R,S}(A::AbstractArray{R}, B::AbstractArray{S}) =
_elementwise($f, promote_op($f, R, S), A, B)
@eval ($f)(A::AbstractArray, B::AbstractArray) =
_elementwise($f, promote_eltype_op($f, A, B), A, B)
end
function _elementwise(op, ::Type{Any}, A::AbstractArray, B::AbstractArray)
promote_shape(A, B) # check size compatibility
Expand Down
159 changes: 105 additions & 54 deletions base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,25 +10,44 @@ export broadcast_getindex, broadcast_setindex!

## Broadcasting utilities ##

# fallback routines for broadcasting with no arguments or with scalars
# to just produce a scalar result:
# fallback for broadcasting with zero arguments and some special cases
broadcast(f) = f()
broadcast(f, x::Number...) = f(x...)
@inline broadcast(f, x::Number...) = f(x...)
@inline broadcast{N}(f, t::NTuple{N}, ts::Vararg{NTuple{N}}) = map(f, t, ts...)
@inline broadcast(f, As::AbstractArray...) = broadcast_t(f, promote_eltype_op(f, As...), As...)

# special cases for "X .= ..." (broadcast!) assignments
broadcast!(::typeof(identity), X::AbstractArray, x::Number) = fill!(X, x)
broadcast!(f, X::AbstractArray) = fill!(X, f())
broadcast!(f, X::AbstractArray, x::Number...) = fill!(X, f(x...))
function broadcast!{T,S,N}(::typeof(identity), x::AbstractArray{T,N}, y::AbstractArray{S,N})
check_broadcast_shape(size(x), size(y))
check_broadcast_shape(broadcast_indices(x), broadcast_indices(y))
copy!(x, y)
end

## Calculate the broadcast shape of the arguments, or error if incompatible
# logic for deciding the resulting container type
containertype(x) = containertype(typeof(x))
containertype(::Type) = Any
containertype{T<:Tuple}(::Type{T}) = Tuple
containertype{T<:AbstractArray}(::Type{T}) = Array
containertype(ct1, ct2) = promote_containertype(containertype(ct1), containertype(ct2))
@inline containertype(ct1, ct2, cts...) = promote_containertype(containertype(ct1), containertype(ct2, cts...))

promote_containertype(::Type{Array}, ::Type{Array}) = Array
promote_containertype(::Type{Array}, ct) = Array
promote_containertype(ct, ::Type{Array}) = Array
promote_containertype(::Type{Tuple}, ::Type{Any}) = Tuple
promote_containertype(::Type{Any}, ::Type{Tuple}) = Tuple
promote_containertype{T}(::Type{T}, ::Type{T}) = T

## Calculate the broadcast indices of the arguments, or error if incompatible
# array inputs
broadcast_shape() = ()
broadcast_shape(A) = indices(A)
@inline broadcast_shape(A, B...) = broadcast_shape((), indices(A), map(indices, B)...)
broadcast_indices() = ()
broadcast_indices(A) = broadcast_indices(containertype(A), A)
broadcast_indices(::Type{Any}, A) = ()
broadcast_indices(::Type{Tuple}, A) = (OneTo(length(A)),)
broadcast_indices(::Type{Array}, A) = indices(A)
@inline broadcast_indices(A, B...) = broadcast_shape((), broadcast_indices(A), map(broadcast_indices, B)...)
# shape (i.e., tuple-of-indices) inputs
broadcast_shape(shape::Tuple) = shape
@inline broadcast_shape(shape::Tuple, shape1::Tuple, shapes::Tuple...) = broadcast_shape(_bcs((), shape, shape1), shapes...)
Expand All @@ -50,24 +69,21 @@ _bcsm(a, b) = a == b || length(b) == 1
_bcsm(a, b::Number) = b == 1
_bcsm(a::Number, b::Number) = a == b || b == 1

## Check that all arguments are broadcast compatible with shape
## Check that all arguments are broadcast compatible with shape
# comparing one input against a shape
check_broadcast_shape(::Tuple{}) = nothing
check_broadcast_shape(::Tuple{}, A::Union{AbstractArray,Number}) = check_broadcast_shape((), indices(A))
check_broadcast_shape(shp) = nothing
check_broadcast_shape(shp, A) = check_broadcast_shape(shp, indices(A))
check_broadcast_shape(::Tuple{}, ::Tuple{}) = nothing
check_broadcast_shape(shp, ::Tuple{}) = nothing
check_broadcast_shape(::Tuple{}, ::Tuple{}) = nothing
check_broadcast_shape(::Tuple{}, Ashp::Tuple) = throw(DimensionMismatch("cannot broadcast array to have fewer dimensions"))
function check_broadcast_shape(shp, Ashp::Tuple)
_bcsm(shp[1], Ashp[1]) || throw(DimensionMismatch("array could not be broadcast to match destination"))
check_broadcast_shape(tail(shp), tail(Ashp))
end
check_broadcast_indices(shp, A) = check_broadcast_shape(shp, broadcast_indices(A))
# comparing many inputs
@inline function check_broadcast_shape(shp, A, As...)
check_broadcast_shape(shp, A)
check_broadcast_shape(shp, As...)
@inline function check_broadcast_indices(shp, A, As...)
check_broadcast_indices(shp, A)
check_broadcast_indices(shp, As...)
end

## Indexing manipulations
Expand All @@ -83,14 +99,13 @@ end

# newindexer(shape, A) generates `keep` and `Idefault` (for use by
# `newindex` above) for a particular array `A`, given the
# broadcast_shape `shape`
# broadcast_indices `shape`
# `keep` is equivalent to map(==, indices(A), shape) (but see #17126)
newindexer(shape, x::Number) = (), ()
@inline newindexer(shape, A) = newindexer(shape, indices(A))
@inline newindexer(shape, indsA::Tuple{}) = (), ()
@inline function newindexer(shape, indsA::Tuple)
@inline newindexer(shape, A) = shapeindexer(shape, broadcast_indices(A))
@inline shapeindexer(shape, indsA::Tuple{}) = (), ()
@inline function shapeindexer(shape, indsA::Tuple)
ind1 = indsA[1]
keep, Idefault = newindexer(tail(shape), tail(indsA))
keep, Idefault = shapeindexer(tail(shape), tail(indsA))
(shape[1] == ind1, keep...), (first(ind1), Idefault...)
end

Expand All @@ -110,6 +125,10 @@ const bitcache_size = 64 * bitcache_chunks # do not change this
dumpbitcache(Bc::Vector{UInt64}, bind::Int, C::Vector{Bool}) =
Base.copy_to_bitarray_chunks!(Bc, ((bind - 1) << 6) + 1, C, 1, min(bitcache_size, (length(Bc)-bind+1) << 6))

@inline _broadcast_getindex(A, I) = _broadcast_getindex(containertype(A), A, I)
@inline _broadcast_getindex(::Type{Any}, A, I) = A
@inline _broadcast_getindex(::Any, A, I) = A[I]

## Broadcasting core
# nargs encodes the number of As arguments (which matches the number
# of keeps). The first two type parameters are to ensure specialization.
Expand All @@ -124,7 +143,7 @@ dumpbitcache(Bc::Vector{UInt64}, bind::Int, C::Vector{Bool}) =
# reverse-broadcast the indices
@nexprs $nargs i->(I_i = newindex(I, keep_i, Idefault_i))
# extract array values
@nexprs $nargs i->(@inbounds val_i = A_i[I_i])
@nexprs $nargs i->(@inbounds val_i = _broadcast_getindex(A_i, I_i))
# call the function and store the result
@inbounds B[I] = @ncall $nargs f val
end
Expand All @@ -148,7 +167,7 @@ end
# reverse-broadcast the indices
@nexprs $nargs i->(I_i = newindex(I, keep_i, Idefault_i))
# extract array values
@nexprs $nargs i->(@inbounds val_i = A_i[I_i])
@nexprs $nargs i->(@inbounds val_i = _broadcast_getindex(A_i, I_i))
# call the function and store the result
@inbounds C[ind] = @ncall $nargs f val
ind += 1
Expand Down Expand Up @@ -176,7 +195,7 @@ as in `broadcast!(f, A, A, B)` to perform `A[:] = broadcast(f, A, B)`.
"""
@inline function broadcast!{nargs}(f, B::AbstractArray, As::Vararg{Any,nargs})
shape = indices(B)
check_broadcast_shape(shape, As...)
check_broadcast_indices(shape, As...)
keeps, Idefaults = map_newindexer(shape, As)
_broadcast!(f, B, keeps, Idefaults, As, Val{nargs})
B
Expand All @@ -196,7 +215,7 @@ end
# reverse-broadcast the indices
@nexprs $nargs i->(I_i = newindex(I, keep_i, Idefault_i))
# extract array values
@nexprs $nargs i->(@inbounds val_i = A_i[I_i])
@nexprs $nargs i->(@inbounds val_i = _broadcast_getindex(A_i, I_i))
# call the function
V = @ncall $nargs f val
S = typeof(V)
Expand All @@ -219,7 +238,7 @@ end
end

function broadcast_t(f, ::Type{Any}, As...)
shape = broadcast_shape(As...)
shape = broadcast_indices(As...)
iter = CartesianRange(shape)
if isempty(iter)
return similar(Array{Any}, shape)
Expand All @@ -228,19 +247,46 @@ function broadcast_t(f, ::Type{Any}, As...)
keeps, Idefaults = map_newindexer(shape, As)
st = start(iter)
I, st = next(iter, st)
val = f([ As[i][newindex(I, keeps[i], Idefaults[i])] for i=1:nargs ]...)
val = f([ _broadcast_getindex(As[i], newindex(I, keeps[i], Idefaults[i])) for i=1:nargs ]...)
B = similar(Array{typeof(val)}, shape)
B[I] = val
return _broadcast!(f, B, keeps, Idefaults, As, Val{nargs}, iter, st, 1)
end

@inline broadcast_t(f, T, As...) = broadcast!(f, similar(Array{T}, broadcast_shape(As...)), As...)
@inline broadcast_t(f, T, As...) = broadcast!(f, similar(Array{T}, broadcast_indices(As...)), As...)

@generated function broadcast_tup{AT,nargs}(f, As::AT, ::Type{Val{nargs}}, n)
quote
ntuple(n -> (@ncall $nargs f i->_broadcast_getindex(As[i], n)), Val{n})
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this generated function definition is illegal (they may not use closures / inner functions, as that generates a new type and corrupts inference)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should this be reverted then? is that restriction fully documented or enforced?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that this is a very strong restriction which includes @threads (possibly other macros) and comprehensions.

Copy link
Contributor Author

@pabloferz pabloferz Sep 19, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would changing this to (((@ncall $nargs f i->_broadcast_getindex(As[i], _)) for _ = 1:n)...) be enough to fix this?

EDIT: No, some of the changes from #18413 are making all variations of fixes I tried, fail.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the future, I believe we should fix the comprehension issue by introducing sealed anonymous functions (so that they get uniqued based upon their code contents rather than their definition position / gensym-name). But yes, that's currently invalid.

end
end

function broadcast_c(f, ::Type{Tuple}, As...)
shape = broadcast_indices(As...)
check_broadcast_indices(shape, As...)
n = length(shape[1])
nargs = length(As)
return broadcast_tup(f, As, Val{nargs}, n)
end
@inline broadcast_c(f, ::Type{Any}, a...) = f(a...)
@inline broadcast_c(f, ::Type{Array}, As...) = broadcast_t(f, promote_eltype_op(f, As...), As...)

"""
broadcast(f, As...)

Broadcasts the arrays `As` to a common size by expanding singleton dimensions, and returns
an array of the results `f(as...)` for each position.
Broadcasts the arrays, tuples and/or scalars `As` to a container of the
appropriate type and dimensions. In this context, anything that is not a
subtype of `AbstractArray` or `Tuple` is considered a scalar. The resulting
container is stablished by the following rules:

- If all the arguments are scalars, it returns a scalar.
- If the arguments are tuples and zero or more scalars, it returns a tuple.
- If there is at least an array in the arguments, it returns an array
(and treats tuples as 1-dimensional arrays) expanding singleton dimensions.

A special syntax exists for broadcasting: `f.(args...)` is equivalent to
`broadcast(f, args...)`, and nested `f.(g.(args...))` calls are fused into a
single broadcast loop.
Copy link
Member

@stevengj stevengj Sep 6, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You need to run ./julia doc/genstdlib.jl to update the manual with this docstring.

Also, there is a manual section on broadcasting, and it would be good to include an example acting on strings or similar, e.g. string.(("one","two","three","four"),": ",1:4).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.


```jldoctest
julia> A = [1, 2, 3, 4, 5]
Expand All @@ -266,27 +312,32 @@ julia> broadcast(+, A, B)
8 9
11 12
14 15
```
"""
@inline broadcast(f, As...) = broadcast_t(f, promote_eltype_op(f, As...), As...)

# alternate, more compact implementation; unfortunately slower.
# also the `collect` machinery doesn't yet support arbitrary index bases.
#=
@generated function _broadcast{nargs}(f, keeps, As, ::Type{Val{nargs}}, iter)
quote
collect((@ncall $nargs f i->As[i][newindex(I, keeps[i])]) for I in iter)
end
end
julia> parse.(Int, ["1", "2"])
2-element Array{Int64,1}:
1
2

function broadcast(f, As...)
shape = broadcast_shape(As...)
iter = CartesianRange(shape)
keeps, Idefaults = map_newindexer(shape, As)
naT = Val{nfields(As)}
_broadcast(f, keeps, Idefaults, As, naT, iter)
end
=#
julia> abs.((1, -2))
(1,2)

julia> broadcast(+, 1.0, (0, -2.0))
(1.0,-1.0)

julia> broadcast(+, 1.0, (0, -2.0), [1])
2-element Array{Float64,1}:
2.0
0.0

julia> string.(("one","two","three","four"), ": ", 1:4)
4-element Array{String,1}:
"one: 1"
"two: 2"
"three: 3"
"four: 4"
```
"""
@inline broadcast(f, As...) = broadcast_c(f, containertype(As...), As...)

"""
bitbroadcast(f, As...)
Expand All @@ -304,7 +355,7 @@ julia> bitbroadcast(isodd,[1,2,3,4,5])
true
```
"""
@inline bitbroadcast(f, As...) = broadcast!(f, similar(BitArray, broadcast_shape(As...)), As...)
@inline bitbroadcast(f, As...) = broadcast!(f, similar(BitArray, broadcast_indices(As...)), As...)

"""
broadcast_getindex(A, inds...)
Expand Down Expand Up @@ -345,13 +396,13 @@ julia> broadcast_getindex(C,[1,2,10])
15
```
"""
broadcast_getindex(src::AbstractArray, I::AbstractArray...) = broadcast_getindex!(similar(Array{eltype(src)}, broadcast_shape(I...)), src, I...)
broadcast_getindex(src::AbstractArray, I::AbstractArray...) = broadcast_getindex!(similar(Array{eltype(src)}, broadcast_indices(I...)), src, I...)
@generated function broadcast_getindex!(dest::AbstractArray, src::AbstractArray, I::AbstractArray...)
N = length(I)
Isplat = Expr[:(I[$d]) for d = 1:N]
quote
@nexprs $N d->(I_d = I[d])
check_broadcast_shape(indices(dest), $(Isplat...)) # unnecessary if this function is never called directly
check_broadcast_indices(indices(dest), $(Isplat...)) # unnecessary if this function is never called directly
checkbounds(src, $(Isplat...))
@nexprs $N d->(@nexprs $N k->(Ibcast_d_k = indices(I_k, d) == OneTo(1)))
@nloops $N i dest d->(@nexprs $N k->(j_d_k = Ibcast_d_k ? 1 : i_d)) begin
Expand All @@ -374,7 +425,7 @@ position in `X` at the indices in `A` given by the same positions in `inds`.
quote
@nexprs $N d->(I_d = I[d])
checkbounds(A, $(Isplat...))
shape = broadcast_shape($(Isplat...))
shape = broadcast_indices($(Isplat...))
@nextract $N shape d->(length(shape) < d ? OneTo(1) : shape[d])
@nexprs $N d->(@nexprs $N k->(Ibcast_d_k = indices(I_k, d) == 1:1))
if !isa(x, AbstractArray)
Expand Down
1 change: 1 addition & 0 deletions base/multidimensional.jl
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,7 @@ maybe_oneto() = OneTo(1)

### From abstractarray.jl: Internal multidimensional indexing definitions ###
getindex(x::Number, i::CartesianIndex{0}) = x
getindex(t::Tuple, I...) = getindex(t, IteratorsMD.flatten(I)...)

# These are not defined on directly on getindex to avoid
# ambiguities for AbstractArray subtypes. See the note in abstractarray.jl
Expand Down
2 changes: 2 additions & 0 deletions base/number.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ zero{T<:Number}(::Type{T}) = convert(T,0)
one(x::Number) = oftype(x,1)
one{T<:Number}(::Type{T}) = convert(T,1)

_default_type(::Type{Number}) = Int

"""
factorial(n)

Expand Down
4 changes: 2 additions & 2 deletions base/reducedim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ function _mapreducedim!{T,N}(f, op, R::AbstractArray, A::AbstractArray{T,N})
return R
end
indsAt, indsRt = safe_tail(indices(A)), safe_tail(indices(R)) # handle d=1 manually
keep, Idefault = Broadcast.newindexer(indsAt, indsRt)
keep, Idefault = Broadcast.shapeindexer(indsAt, indsRt)
if reducedim1(R, A)
# keep the accumulator as a local variable when reducing along the first dimension
i1 = first(indices1(R))
Expand Down Expand Up @@ -331,7 +331,7 @@ function findminmax!{T,N}(f, Rval, Rind, A::AbstractArray{T,N})
# If we're reducing along dimension 1, for efficiency we can make use of a temporary.
# Otherwise, keep the result in Rval/Rind so that we traverse A in storage order.
indsAt, indsRt = safe_tail(indices(A)), safe_tail(indices(Rval))
keep, Idefault = Broadcast.newindexer(indsAt, indsRt)
keep, Idefault = Broadcast.shapeindexer(indsAt, indsRt)
k = 0
if reducedim1(Rval, A)
i1 = first(indices1(Rval))
Expand Down
2 changes: 1 addition & 1 deletion base/sparse/sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import Base: @get!, acos, acosd, acot, acotd, acsch, asech, asin, asind, asinh,
rotl90, rotr90, round, scale!, setindex!, similar, size, transpose, tril,
triu, vec, permute!

import Base.Broadcast: broadcast_shape
import Base.Broadcast: broadcast_indices

export AbstractSparseArray, AbstractSparseMatrix, AbstractSparseVector,
SparseMatrixCSC, SparseVector, blkdiag, dense, droptol!, dropzeros!, dropzeros,
Expand Down
Loading