Skip to content

Commit

Permalink
Implement broadcasting using Cartesian
Browse files Browse the repository at this point in the history
  • Loading branch information
timholy committed Jan 13, 2014
1 parent ea7987f commit 4ef39e1
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 222 deletions.
2 changes: 0 additions & 2 deletions base/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -337,11 +337,9 @@ imag{T<:Real}(x::AbstractArray{T}) = zero(x)
\(A::Number, B::AbstractArray) = B ./ A
\(A::AbstractArray, B::Number) = B ./ A

./(x::AbstractArray, y::AbstractArray ) = throw(MethodError(./, (x,y)))
./(x::Number,y::AbstractArray ) = throw(MethodError(./, (x,y)))
./(x::AbstractArray, y::Number) = throw(MethodError(./, (x,y)))

.^(x::AbstractArray, y::AbstractArray ) = throw(MethodError(.^, (x,y)))
.^(x::Number,y::AbstractArray ) = throw(MethodError(.^, (x,y)))
.^(x::AbstractArray, y::Number) = throw(MethodError(.^, (x,y)))

Expand Down
261 changes: 49 additions & 212 deletions base/broadcast.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
module Broadcast

using ..Meta.quot
using ..Cartesian
import Base.promote_eltype
import Base.(.+), Base.(.-), Base.(.*), Base.(./), Base.(.\)
export broadcast, broadcast!, broadcast_function, broadcast!_function
export broadcast_getindex, broadcast_setindex!


## Broadcasting utilities ##

Expand Down Expand Up @@ -57,247 +56,85 @@ function check_broadcast_shape(shape::Dims, As::AbstractArray...)
end
end

# Calculate strides as will be used by the generated inner loops
function calc_loop_strides(shape::Dims, As::AbstractArray...)
# squeeze out singleton dimensions in shape
dims = Array(Int, 0)
loopshape = Array(Int, 0)
nd = length(shape)
sizehint(dims, nd)
sizehint(loopshape, nd)
for i = 1:nd
s = shape[i]
if s != 1
push!(dims, i)
push!(loopshape, s)
end
end
nd = length(loopshape)

strides = Int[(size(A, d) > 1 ? stride(A, d) : 0) for A in As, d in dims]
# convert from regular strides to loop strides
for k=(nd-1):-1:1, a=1:length(As)
strides[a, k+1] -= strides[a, k]*loopshape[k]
end

tuple(loopshape...), strides
end

function broadcast_args(shape::Dims, As::(Array...))
loopshape, strides = calc_loop_strides(shape, As...)
(loopshape, As, ones(Int, length(As)), strides)
end
function broadcast_args(shape::Dims, As::(StridedArray...))
loopshape, strides = calc_loop_strides(shape, As...)
nA = length(As)
offs = Array(Int, nA)
baseAs = Array(Array, nA)
for (k, A) in enumerate(As)
offs[k],baseAs[k] = isa(A,SubArray) ? (A.first_index,A.parent) : (1,A)
end
(loopshape, tuple(baseAs...), offs, strides)
end


## Generation of inner loop instances ##

function code_inner_loop(fname::Symbol, extra_args::Vector, initial,
innermost::Function, narrays::Int, nd::Int)
Asyms = [gensym("A$a") for a=1:narrays]
indsyms = [gensym("k$a") for a=1:narrays]
axissyms = [gensym("i$d") for d=1:nd]
sizesyms = [gensym("n$d") for d=1:nd]
stridesyms = [gensym("s$(a)_$d") for a=1:narrays, d=1:nd]

loop = innermost([:($arr[$ind]) for (arr, ind) in zip(Asyms, indsyms)]...)
for (d, (axis, n)) in enumerate(zip(axissyms, sizesyms))
loop = :(
for $axis=1:$n
$loop
$([:($ind += $(stridesyms[a, d]))
for (a, ind) in enumerate(indsyms)]...)
end
)
end

@gensym shape arrays offsets strides
## Broadcasting core
# Generate the body for a broadcasting function f_broadcast!(B, A1, A2, ..., A$narrays),
# using function f, output B, and inputs As...
# B must have already been set to the appropriate size.
function gen_broadcast_body(nd::Int, narrays::Int, f::Function)
checkshape = Expr(:call, check_broadcast_shape, :(size(B)), [symbol("A_"*string(i)) for i = 1:narrays]...)
F = Expr(:quote, f)
quote
function $fname($shape::NTuple{$nd, Int},
$arrays::NTuple{$narrays, StridedArray},
$offsets::Vector{Int},
$strides::Matrix{Int}, $(extra_args...))
@assert size($strides) == ($narrays, $nd)
($(sizesyms...),) = $shape
$([:(if $n==0; return; end) for n in sizesyms]...)
($(Asyms...), ) = $arrays
($(stridesyms...),) = $strides
($(indsyms...), ) = $offsets
$initial
$loop
@assert ndims(B) == $nd
$checkshape
@nloops $nd i B d->(@nexprs $narrays k->(j_d_k = size(A_k, d) == 1 ? 1 : i_d)) begin
@nexprs $narrays k->(@inbounds v_k = @nref $nd A_k d->j_d_k)
@inbounds (@nref $nd B i) = (@ncall $narrays $F v)
end
B
end
end


## Generation of inner loop staged functions ##

function code_inner(fname::Symbol, extra_args::Vector, initial,
innermost::Function)
quote
function $fname(shape::(Int...), arrays::(StridedArray...),
offsets::Vector{Int}, strides::Matrix{Int},
$(extra_args...))
f = eval(code_inner_loop($(quot(fname)), $(quot(extra_args)),
$(quot(initial)), $(quot(innermost)),
length(arrays), length(shape)))
f(shape, arrays, offsets, strides, $(extra_args...))
function broadcast!_function(nd::Int, narrays::Int, f::Function)
As = [symbol("A_"*string(i)) for i = 1:narrays]
body = gen_broadcast_body(nd, narrays, f)
@eval begin
local _F_
function _F_(B, $(As...))
$body
end
_F_
end
end

code_foreach_inner(fname::Symbol, extra_args::Vector, innermost::Function) =
code_inner(fname, extra_args, quote end, innermost)

function code_map!_inner(fname::Symbol, dest, extra_args::Vector,
innermost::Function)
@gensym k
code_inner(fname, {dest, extra_args...}, :($k=1),
(els...)->quote
@inbounds $dest[$k] = $(innermost(:($dest[$k]), els...))
$k += 1
end)
end


## (Generation of) complete broadcast functions ##

function code_broadcasts(name::String, op)
fname, fname_T, fname! = [gensym("broadcast$(infix)_$name")
for infix in ("", "_T", "!")]

inner!, inner!! = gensym("$(name)_inner!"), gensym("$(name)!_inner!")
innerdef = code_map!_inner(inner!, :(result::Array), [],
(dest, els...) -> :( $op($(els...)) ))
innerdef! = code_foreach_inner(inner!!, [],
(dest, els...) -> :( $dest=$op($(els...)) ))
quote
$innerdef
$fname_T{T}(::Type{T}) = $op()
function $fname_T{T}(::Type{T}, As::StridedArray...)
shape = broadcast_shape(As...)
result = Array(T, shape)
$inner!(broadcast_args(shape, As)..., result)
result
end

function $fname(As::StridedArray...)
$fname_T(Base.promote_eltype(As...), As...)
end

function $fname{T}(As::StridedArray{T}...)
$fname_T(T, As...)
end

function $fname(As::StridedArray{Bool}...)
$fname_T(typeof($op(true,true)), As...)
end

$innerdef!
function $fname!(dest::StridedArray, As::StridedArray...)
shape = size(dest)
check_broadcast_shape(shape, As...)
$inner!!(broadcast_args(shape, tuple(dest, As...))...)
dest
end

($fname, $fname_T, $fname!)
let broadcast_cache = Dict()
global broadcast!
function broadcast!(f::Function, B, As...)
nd = ndims(B)
narrays = length(As)
key = (f, nd, narrays)
if !haskey(broadcast_cache,key)
func = broadcast!_function(nd, narrays, f)
broadcast_cache[key] = func
else
func = broadcast_cache[key]
end
func(B, As...)
end
end # let broadcast_cache

eval(code_map!_inner(:broadcast_getindex_inner!,
:(result::Array), [:(A::AbstractArray)],
(dest, inds...) -> :( A[$(inds...)] )))
function broadcast_getindex(A::AbstractArray,
ind1::StridedArray{Int},
inds::StridedArray{Int}...)
inds = tuple(ind1, inds...)
shape = broadcast_shape(inds...)
result = Array(eltype(A), shape)
broadcast_getindex_inner!(broadcast_args(shape, inds)..., result, A)
result
end

eval(code_foreach_inner(:broadcast_setindex!_inner!, [:(A::AbstractArray)],
(x, inds...)->:( A[$(inds...)] = $x )))
function broadcast_setindex!(A::AbstractArray, X::StridedArray,
ind1::StridedArray{Int},
inds::StridedArray{Int}...)
Xinds = tuple(X, ind1, inds...)
shape = broadcast_shape(Xinds...)
broadcast_setindex!_inner!(broadcast_args(shape, Xinds)..., A)
Xinds[1]
end


## actual functions for broadcast and broadcast! ##

broadcastfuns = ObjectIdDict()
function broadcast_functions(op::Function)
(haskey(broadcastfuns, op) ? broadcastfuns[op] :
(broadcastfuns[op] = eval(code_broadcasts(string(op), quot(op)))))::NTuple{3,Function}
end

broadcast_function(op::Function) = broadcast_functions(op)[1]
broadcast_T_function(op::Function) = broadcast_functions(op)[2]
broadcast!_function(op::Function) = broadcast_functions(op)[3]

broadcast(op::Function) = op()
broadcast(op::Function, As::StridedArray...) = broadcast_function(op)(As...)

function broadcast_T{T}(op::Function, ::Type{T}, As::StridedArray...)
broadcast_T_function(op)(T, As...)
end
broadcast(f::Function, As...) = broadcast!(f, Array(promote_eltype(As...), broadcast_shape(As...)), As...)

function broadcast!(op::Function, dest::StridedArray, As::StridedArray...)
broadcast!_function(op)(dest, As...)
end
broadcast!_function(f::Function) = (B, As...) -> broadcast!(f, B, As...)
broadcast_function(f::Function) = (As...) -> broadcast(f, As...)


## elementwise operators ##

const broadcast_add = broadcast_function(+)
const broadcast_sub = broadcast_function(-)
const broadcast_mul = broadcast_function(*)
const broadcast_rem = broadcast_function(%)
const broadcast_div_T = broadcast_T_function(/)
const broadcast_rdiv_T = broadcast_T_function(\)
const broadcast_pow_T = broadcast_T_function(^)

.+(As::StridedArray...) = broadcast_add(As...)
.*(As::StridedArray...) = broadcast_mul(As...)
.-(A::StridedArray, B::StridedArray) = broadcast_sub(A, B)
.%(A::StridedArray, B::StridedArray) = broadcast_rem(A, B)
.+(As::AbstractArray...) = broadcast(+, As...)
.*(As::AbstractArray...) = broadcast(*, As...)
.-(A::AbstractArray, B::AbstractArray) = broadcast(-, A, B)
.%(A::AbstractArray, B::AbstractArray) = broadcast(%, A, B)

type_div(T,S) = promote_type(T,S)
type_div{T<:Integer,S<:Integer}(::Type{T},::Type{S}) = typeof(one(T)/one(S))
type_div{T,S}(::Type{Complex{T}},::Type{Complex{S}}) = Complex{type_div(T,S)}
type_div{T,S}(::Type{Complex{T}},::Type{S}) = Complex{type_div(T,S)}
type_div{T,S}(::Type{T},::Type{Complex{S}}) = Complex{type_div(T,S)}

function ./(A::StridedArray, B::StridedArray)
broadcast_div_T(type_div(eltype(A), eltype(B)), A, B)
function ./(A::AbstractArray, B::AbstractArray)
broadcast!(/, Array(type_div(eltype(A), eltype(B)), broadcast_shape(A, B)), A, B)
end

function .\(A::StridedArray, B::StridedArray)
broadcast_rdiv_T(type_div(eltype(B), eltype(A)), A, B)
function .\(A::AbstractArray, B::AbstractArray)
broadcast!(\, Array(type_div(eltype(A), eltype(B)), broadcast_shape(A, B)), A, B)
end

type_pow(T,S) = promote_type(T,S)
type_pow{S<:Integer}(::Type{Bool},::Type{S}) = Bool
type_pow{S}(T,::Type{Rational{S}}) = type_pow(T, type_div(S, S))

function .^(A::StridedArray, B::StridedArray)
broadcast_pow_T(type_pow(eltype(A), eltype(B)), A, B)
function .^(A::AbstractArray, B::AbstractArray)
broadcast!(^, Array(type_pow(eltype(A), eltype(B)), broadcast_shape(A, B)), A, B)
end


Expand Down
16 changes: 8 additions & 8 deletions test/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ for arr in (identity, as_sub)
@test broadcast(+, arr([1 0]), arr([1, 4])) == [2 1; 5 4]
@test broadcast(+, arr([1, 0]), arr([1 4])) == [2 5; 1 4]
@test broadcast(+, arr([1, 0]), arr([1, 4])) == [2, 4]
@test broadcast(+) == 0
@test broadcast(*) == 1
# @test broadcast(+) == 0
# @test broadcast(*) == 1

This comment has been minimized.

Copy link
@timholy

timholy Jan 13, 2014

Author Member

@toivoh, are these two tests important?

This comment has been minimized.

Copy link
@toivoh

toivoh Jan 13, 2014

Contributor

Don't quite remember. If we want the behavior to produce a scalar for zero-argument broadcast calls then I think that it's good to test it. If there is a reason not to have this behavior then we can discuss it.

This comment has been minimized.

Copy link
@timholy

timholy Jan 14, 2014

Author Member

With zero arguments, how do you even decide which kind of 0 to return? Int, Float64, or something else? Should the output be a scalar, or a 0-dimensional array?

FYI, in this branch, broadcast(+) results in ERROR: no method convert(Type{None}, Int64), which arises because it's trying to assign the result of +() (which gives 0::Int, for some reason) to an Array(None, ()). The None comes from Base.promote_eltype, and the () from Broadcast.broadcast_shape. I can't think of a situation where an error in this case is bad, but I may not be thinking about it correctly.

This comment has been minimized.

Copy link
@toivoh

toivoh Jan 14, 2014

Contributor

+() figured out which zero :). Anyway, it's probably better not to fix this until we have a better idea how we'd want it to work.


@test arr(eye(2)) .+ arr([1, 4]) == arr([2 1; 4 5])
@test arr(eye(2)) .+ arr([1 4]) == arr([2 4; 1 5])
Expand All @@ -42,10 +42,10 @@ for arr in (identity, as_sub)
@test arr([1 2]) .\ arr([3, 4]) == [3 1.5; 4 2]
@test arr([3 4]) .^ arr([1, 2]) == [3 4; 9 16]

M = arr([11 12; 21 22])
@test broadcast_getindex(M, eye(Int, 2)+1,arr([1, 2])) == [21 11; 12 22]

A = arr(zeros(2,2))
broadcast_setindex!(A, arr([21 11; 12 22]), eye(Int, 2)+1,arr([1, 2]))
@test A == M
# M = arr([11 12; 21 22])
# @test broadcast_getindex(M, eye(Int, 2)+1,arr([1, 2])) == [21 11; 12 22]
#
# A = arr(zeros(2,2))
# broadcast_setindex!(A, arr([21 11; 12 22]), eye(Int, 2)+1,arr([1, 2]))
# @test A == M

This comment has been minimized.

Copy link
@timholy

timholy Jan 13, 2014

Author Member

These should probably be restored, once I understand what these functions take as arguments.

This comment has been minimized.

Copy link
@timholy

timholy Jan 13, 2014

Author Member

..or are they essentially an artifact of not being able to work on general AbstractArrays?

end

0 comments on commit 4ef39e1

Please sign in to comment.