Skip to content

Commit

Permalink
add more broadcast, fix #22
Browse files Browse the repository at this point in the history
  • Loading branch information
SimonDanisch committed Jun 6, 2017
1 parent 32feef4 commit 4f332ea
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 8 deletions.
11 changes: 3 additions & 8 deletions src/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -213,16 +213,11 @@ end
function Base.broadcast!(f::Function, A::AbstractAccArray, B)
acc_broadcast!(f, A, (B,))
end
function Base.broadcast!(f::Function, A::AbstractAccArray, B::AbstractAccArray, args)
acc_broadcast!(f, A, (B, args))
end
function Base.broadcast!(f::Function, A::AbstractAccArray, B::AbstractAccArray, C::AbstractAccArray, D, E...)
acc_broadcast!(f, A, (B, C, D, E...))
end
function Base.broadcast!(f::Function, A::AbstractAccArray, B::AbstractAccArray, C::AbstractAccArray, D)
acc_broadcast!(f, A, (B, C, D))
function Base.broadcast!(f::Function, A::AbstractAccArray, B::AbstractAccArray, args...)
acc_broadcast!(f, A, (B, args...))
end


# TODO check size
function Base.map!(f::Function, A::AbstractAccArray, args::AbstractAccArray...)
acc_broadcast!(f, A, (args...))
Expand Down
12 changes: 12 additions & 0 deletions test/shared.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,15 @@ end
end
@test Array(A) a
end


@allbackends "broadcast" backend begin
u0 = GPUArray(rand(Float32, 32, 32))
tmp = ones(u0)
uprev = ones(u0)
k1 = ones(u0)
a = Float32(2.0)
tmp .= uprev .+ a .* k1
comparison = ones(Float32, 32, 32) .+ a .* ones(Float32, 32, 32)
@test comparison == Array(tmp)
end

0 comments on commit 4f332ea

Please sign in to comment.