-
-
Notifications
You must be signed in to change notification settings - Fork 5.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Generalize broadcast to handle tuples and scalars #16986
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,25 +10,44 @@ export broadcast_getindex, broadcast_setindex! | |
|
||
## Broadcasting utilities ## | ||
|
||
# fallback routines for broadcasting with no arguments or with scalars | ||
# to just produce a scalar result: | ||
# fallback for broadcasting with zero arguments and some special cases | ||
broadcast(f) = f() | ||
broadcast(f, x::Number...) = f(x...) | ||
@inline broadcast(f, x::Number...) = f(x...) | ||
@inline broadcast{N}(f, t::NTuple{N}, ts::Vararg{NTuple{N}}) = map(f, t, ts...) | ||
@inline broadcast(f, As::AbstractArray...) = broadcast_t(f, promote_eltype_op(f, As...), As...) | ||
|
||
# special cases for "X .= ..." (broadcast!) assignments | ||
broadcast!(::typeof(identity), X::AbstractArray, x::Number) = fill!(X, x) | ||
broadcast!(f, X::AbstractArray) = fill!(X, f()) | ||
broadcast!(f, X::AbstractArray, x::Number...) = fill!(X, f(x...)) | ||
function broadcast!{T,S,N}(::typeof(identity), x::AbstractArray{T,N}, y::AbstractArray{S,N}) | ||
check_broadcast_shape(size(x), size(y)) | ||
check_broadcast_shape(broadcast_indices(x), broadcast_indices(y)) | ||
copy!(x, y) | ||
end | ||
|
||
## Calculate the broadcast shape of the arguments, or error if incompatible | ||
# logic for deciding the resulting container type | ||
containertype(x) = containertype(typeof(x)) | ||
containertype(::Type) = Any | ||
containertype{T<:Tuple}(::Type{T}) = Tuple | ||
containertype{T<:AbstractArray}(::Type{T}) = Array | ||
containertype(ct1, ct2) = promote_containertype(containertype(ct1), containertype(ct2)) | ||
@inline containertype(ct1, ct2, cts...) = promote_containertype(containertype(ct1), containertype(ct2, cts...)) | ||
|
||
promote_containertype(::Type{Array}, ::Type{Array}) = Array | ||
promote_containertype(::Type{Array}, ct) = Array | ||
promote_containertype(ct, ::Type{Array}) = Array | ||
promote_containertype(::Type{Tuple}, ::Type{Any}) = Tuple | ||
promote_containertype(::Type{Any}, ::Type{Tuple}) = Tuple | ||
promote_containertype{T}(::Type{T}, ::Type{T}) = T | ||
|
||
## Calculate the broadcast indices of the arguments, or error if incompatible | ||
# array inputs | ||
broadcast_shape() = () | ||
broadcast_shape(A) = indices(A) | ||
@inline broadcast_shape(A, B...) = broadcast_shape((), indices(A), map(indices, B)...) | ||
broadcast_indices() = () | ||
broadcast_indices(A) = broadcast_indices(containertype(A), A) | ||
broadcast_indices(::Type{Any}, A) = () | ||
broadcast_indices(::Type{Tuple}, A) = (OneTo(length(A)),) | ||
broadcast_indices(::Type{Array}, A) = indices(A) | ||
@inline broadcast_indices(A, B...) = broadcast_shape((), broadcast_indices(A), map(broadcast_indices, B)...) | ||
# shape (i.e., tuple-of-indices) inputs | ||
broadcast_shape(shape::Tuple) = shape | ||
@inline broadcast_shape(shape::Tuple, shape1::Tuple, shapes::Tuple...) = broadcast_shape(_bcs((), shape, shape1), shapes...) | ||
|
@@ -50,24 +69,21 @@ _bcsm(a, b) = a == b || length(b) == 1 | |
_bcsm(a, b::Number) = b == 1 | ||
_bcsm(a::Number, b::Number) = a == b || b == 1 | ||
|
||
## Check that all arguments are broadcast compatible with shape | ||
## Check that all arguments are broadcast compatible with shape | ||
# comparing one input against a shape | ||
check_broadcast_shape(::Tuple{}) = nothing | ||
check_broadcast_shape(::Tuple{}, A::Union{AbstractArray,Number}) = check_broadcast_shape((), indices(A)) | ||
check_broadcast_shape(shp) = nothing | ||
check_broadcast_shape(shp, A) = check_broadcast_shape(shp, indices(A)) | ||
check_broadcast_shape(::Tuple{}, ::Tuple{}) = nothing | ||
check_broadcast_shape(shp, ::Tuple{}) = nothing | ||
check_broadcast_shape(::Tuple{}, ::Tuple{}) = nothing | ||
check_broadcast_shape(::Tuple{}, Ashp::Tuple) = throw(DimensionMismatch("cannot broadcast array to have fewer dimensions")) | ||
function check_broadcast_shape(shp, Ashp::Tuple) | ||
_bcsm(shp[1], Ashp[1]) || throw(DimensionMismatch("array could not be broadcast to match destination")) | ||
check_broadcast_shape(tail(shp), tail(Ashp)) | ||
end | ||
check_broadcast_indices(shp, A) = check_broadcast_shape(shp, broadcast_indices(A)) | ||
# comparing many inputs | ||
@inline function check_broadcast_shape(shp, A, As...) | ||
check_broadcast_shape(shp, A) | ||
check_broadcast_shape(shp, As...) | ||
@inline function check_broadcast_indices(shp, A, As...) | ||
check_broadcast_indices(shp, A) | ||
check_broadcast_indices(shp, As...) | ||
end | ||
|
||
## Indexing manipulations | ||
|
@@ -83,14 +99,13 @@ end | |
|
||
# newindexer(shape, A) generates `keep` and `Idefault` (for use by | ||
# `newindex` above) for a particular array `A`, given the | ||
# broadcast_shape `shape` | ||
# broadcast_indices `shape` | ||
# `keep` is equivalent to map(==, indices(A), shape) (but see #17126) | ||
newindexer(shape, x::Number) = (), () | ||
@inline newindexer(shape, A) = newindexer(shape, indices(A)) | ||
@inline newindexer(shape, indsA::Tuple{}) = (), () | ||
@inline function newindexer(shape, indsA::Tuple) | ||
@inline newindexer(shape, A) = shapeindexer(shape, broadcast_indices(A)) | ||
@inline shapeindexer(shape, indsA::Tuple{}) = (), () | ||
@inline function shapeindexer(shape, indsA::Tuple) | ||
ind1 = indsA[1] | ||
keep, Idefault = newindexer(tail(shape), tail(indsA)) | ||
keep, Idefault = shapeindexer(tail(shape), tail(indsA)) | ||
(shape[1] == ind1, keep...), (first(ind1), Idefault...) | ||
end | ||
|
||
|
@@ -110,6 +125,10 @@ const bitcache_size = 64 * bitcache_chunks # do not change this | |
dumpbitcache(Bc::Vector{UInt64}, bind::Int, C::Vector{Bool}) = | ||
Base.copy_to_bitarray_chunks!(Bc, ((bind - 1) << 6) + 1, C, 1, min(bitcache_size, (length(Bc)-bind+1) << 6)) | ||
|
||
@inline _broadcast_getindex(A, I) = _broadcast_getindex(containertype(A), A, I) | ||
@inline _broadcast_getindex(::Type{Any}, A, I) = A | ||
@inline _broadcast_getindex(::Any, A, I) = A[I] | ||
|
||
## Broadcasting core | ||
# nargs encodes the number of As arguments (which matches the number | ||
# of keeps). The first two type parameters are to ensure specialization. | ||
|
@@ -124,7 +143,7 @@ dumpbitcache(Bc::Vector{UInt64}, bind::Int, C::Vector{Bool}) = | |
# reverse-broadcast the indices | ||
@nexprs $nargs i->(I_i = newindex(I, keep_i, Idefault_i)) | ||
# extract array values | ||
@nexprs $nargs i->(@inbounds val_i = A_i[I_i]) | ||
@nexprs $nargs i->(@inbounds val_i = _broadcast_getindex(A_i, I_i)) | ||
# call the function and store the result | ||
@inbounds B[I] = @ncall $nargs f val | ||
end | ||
|
@@ -148,7 +167,7 @@ end | |
# reverse-broadcast the indices | ||
@nexprs $nargs i->(I_i = newindex(I, keep_i, Idefault_i)) | ||
# extract array values | ||
@nexprs $nargs i->(@inbounds val_i = A_i[I_i]) | ||
@nexprs $nargs i->(@inbounds val_i = _broadcast_getindex(A_i, I_i)) | ||
# call the function and store the result | ||
@inbounds C[ind] = @ncall $nargs f val | ||
ind += 1 | ||
|
@@ -176,7 +195,7 @@ as in `broadcast!(f, A, A, B)` to perform `A[:] = broadcast(f, A, B)`. | |
""" | ||
@inline function broadcast!{nargs}(f, B::AbstractArray, As::Vararg{Any,nargs}) | ||
shape = indices(B) | ||
check_broadcast_shape(shape, As...) | ||
check_broadcast_indices(shape, As...) | ||
keeps, Idefaults = map_newindexer(shape, As) | ||
_broadcast!(f, B, keeps, Idefaults, As, Val{nargs}) | ||
B | ||
|
@@ -196,7 +215,7 @@ end | |
# reverse-broadcast the indices | ||
@nexprs $nargs i->(I_i = newindex(I, keep_i, Idefault_i)) | ||
# extract array values | ||
@nexprs $nargs i->(@inbounds val_i = A_i[I_i]) | ||
@nexprs $nargs i->(@inbounds val_i = _broadcast_getindex(A_i, I_i)) | ||
# call the function | ||
V = @ncall $nargs f val | ||
S = typeof(V) | ||
|
@@ -219,7 +238,7 @@ end | |
end | ||
|
||
function broadcast_t(f, ::Type{Any}, As...) | ||
shape = broadcast_shape(As...) | ||
shape = broadcast_indices(As...) | ||
iter = CartesianRange(shape) | ||
if isempty(iter) | ||
return similar(Array{Any}, shape) | ||
|
@@ -228,19 +247,46 @@ function broadcast_t(f, ::Type{Any}, As...) | |
keeps, Idefaults = map_newindexer(shape, As) | ||
st = start(iter) | ||
I, st = next(iter, st) | ||
val = f([ As[i][newindex(I, keeps[i], Idefaults[i])] for i=1:nargs ]...) | ||
val = f([ _broadcast_getindex(As[i], newindex(I, keeps[i], Idefaults[i])) for i=1:nargs ]...) | ||
B = similar(Array{typeof(val)}, shape) | ||
B[I] = val | ||
return _broadcast!(f, B, keeps, Idefaults, As, Val{nargs}, iter, st, 1) | ||
end | ||
|
||
@inline broadcast_t(f, T, As...) = broadcast!(f, similar(Array{T}, broadcast_shape(As...)), As...) | ||
@inline broadcast_t(f, T, As...) = broadcast!(f, similar(Array{T}, broadcast_indices(As...)), As...) | ||
|
||
@generated function broadcast_tup{AT,nargs}(f, As::AT, ::Type{Val{nargs}}, n) | ||
quote | ||
ntuple(n -> (@ncall $nargs f i->_broadcast_getindex(As[i], n)), Val{n}) | ||
end | ||
end | ||
|
||
function broadcast_c(f, ::Type{Tuple}, As...) | ||
shape = broadcast_indices(As...) | ||
check_broadcast_indices(shape, As...) | ||
n = length(shape[1]) | ||
nargs = length(As) | ||
return broadcast_tup(f, As, Val{nargs}, n) | ||
end | ||
@inline broadcast_c(f, ::Type{Any}, a...) = f(a...) | ||
@inline broadcast_c(f, ::Type{Array}, As...) = broadcast_t(f, promote_eltype_op(f, As...), As...) | ||
|
||
""" | ||
broadcast(f, As...) | ||
|
||
Broadcasts the arrays `As` to a common size by expanding singleton dimensions, and returns | ||
an array of the results `f(as...)` for each position. | ||
Broadcasts the arrays, tuples and/or scalars `As` to a container of the | ||
appropriate type and dimensions. In this context, anything that is not a | ||
subtype of `AbstractArray` or `Tuple` is considered a scalar. The resulting | ||
container is stablished by the following rules: | ||
|
||
- If all the arguments are scalars, it returns a scalar. | ||
- If the arguments are tuples and zero or more scalars, it returns a tuple. | ||
- If there is at least an array in the arguments, it returns an array | ||
(and treats tuples as 1-dimensional arrays) expanding singleton dimensions. | ||
|
||
A special syntax exists for broadcasting: `f.(args...)` is equivalent to | ||
`broadcast(f, args...)`, and nested `f.(g.(args...))` calls are fused into a | ||
single broadcast loop. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You need to run Also, there is a manual section on broadcasting, and it would be good to include an example acting on strings or similar, e.g. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. |
||
|
||
```jldoctest | ||
julia> A = [1, 2, 3, 4, 5] | ||
|
@@ -266,27 +312,32 @@ julia> broadcast(+, A, B) | |
8 9 | ||
11 12 | ||
14 15 | ||
``` | ||
""" | ||
@inline broadcast(f, As...) = broadcast_t(f, promote_eltype_op(f, As...), As...) | ||
|
||
# alternate, more compact implementation; unfortunately slower. | ||
# also the `collect` machinery doesn't yet support arbitrary index bases. | ||
#= | ||
@generated function _broadcast{nargs}(f, keeps, As, ::Type{Val{nargs}}, iter) | ||
quote | ||
collect((@ncall $nargs f i->As[i][newindex(I, keeps[i])]) for I in iter) | ||
end | ||
end | ||
julia> parse.(Int, ["1", "2"]) | ||
2-element Array{Int64,1}: | ||
1 | ||
2 | ||
|
||
function broadcast(f, As...) | ||
shape = broadcast_shape(As...) | ||
iter = CartesianRange(shape) | ||
keeps, Idefaults = map_newindexer(shape, As) | ||
naT = Val{nfields(As)} | ||
_broadcast(f, keeps, Idefaults, As, naT, iter) | ||
end | ||
=# | ||
julia> abs.((1, -2)) | ||
(1,2) | ||
|
||
julia> broadcast(+, 1.0, (0, -2.0)) | ||
(1.0,-1.0) | ||
|
||
julia> broadcast(+, 1.0, (0, -2.0), [1]) | ||
2-element Array{Float64,1}: | ||
2.0 | ||
0.0 | ||
|
||
julia> string.(("one","two","three","four"), ": ", 1:4) | ||
4-element Array{String,1}: | ||
"one: 1" | ||
"two: 2" | ||
"three: 3" | ||
"four: 4" | ||
``` | ||
""" | ||
@inline broadcast(f, As...) = broadcast_c(f, containertype(As...), As...) | ||
|
||
""" | ||
bitbroadcast(f, As...) | ||
|
@@ -304,7 +355,7 @@ julia> bitbroadcast(isodd,[1,2,3,4,5]) | |
true | ||
``` | ||
""" | ||
@inline bitbroadcast(f, As...) = broadcast!(f, similar(BitArray, broadcast_shape(As...)), As...) | ||
@inline bitbroadcast(f, As...) = broadcast!(f, similar(BitArray, broadcast_indices(As...)), As...) | ||
|
||
""" | ||
broadcast_getindex(A, inds...) | ||
|
@@ -345,13 +396,13 @@ julia> broadcast_getindex(C,[1,2,10]) | |
15 | ||
``` | ||
""" | ||
broadcast_getindex(src::AbstractArray, I::AbstractArray...) = broadcast_getindex!(similar(Array{eltype(src)}, broadcast_shape(I...)), src, I...) | ||
broadcast_getindex(src::AbstractArray, I::AbstractArray...) = broadcast_getindex!(similar(Array{eltype(src)}, broadcast_indices(I...)), src, I...) | ||
@generated function broadcast_getindex!(dest::AbstractArray, src::AbstractArray, I::AbstractArray...) | ||
N = length(I) | ||
Isplat = Expr[:(I[$d]) for d = 1:N] | ||
quote | ||
@nexprs $N d->(I_d = I[d]) | ||
check_broadcast_shape(indices(dest), $(Isplat...)) # unnecessary if this function is never called directly | ||
check_broadcast_indices(indices(dest), $(Isplat...)) # unnecessary if this function is never called directly | ||
checkbounds(src, $(Isplat...)) | ||
@nexprs $N d->(@nexprs $N k->(Ibcast_d_k = indices(I_k, d) == OneTo(1))) | ||
@nloops $N i dest d->(@nexprs $N k->(j_d_k = Ibcast_d_k ? 1 : i_d)) begin | ||
|
@@ -374,7 +425,7 @@ position in `X` at the indices in `A` given by the same positions in `inds`. | |
quote | ||
@nexprs $N d->(I_d = I[d]) | ||
checkbounds(A, $(Isplat...)) | ||
shape = broadcast_shape($(Isplat...)) | ||
shape = broadcast_indices($(Isplat...)) | ||
@nextract $N shape d->(length(shape) < d ? OneTo(1) : shape[d]) | ||
@nexprs $N d->(@nexprs $N k->(Ibcast_d_k = indices(I_k, d) == 1:1)) | ||
if !isa(x, AbstractArray) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this generated function definition is illegal (they may not use closures / inner functions, as that generates a new type and corrupts inference)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
should this be reverted then? is that restriction fully documented or enforced?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note that this is a very strong restriction which includes
@threads
(possibly other macros) and comprehensions.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would changing this to
(((@ncall $nargs f i->_broadcast_getindex(As[i], _)) for _ = 1:n)...)
be enough to fix this?EDIT: No, some of the changes from #18413 are making all variations of fixes I tried, fail.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In the future, I believe we should fix the comprehension issue by introducing sealed anonymous functions (so that they get uniqued based upon their code contents rather than their definition position / gensym-name). But yes, that's currently invalid.