diff --git a/src/abstractarray.jl b/src/abstractarray.jl index 61e892df..0992fc4b 100644 --- a/src/abstractarray.jl +++ b/src/abstractarray.jl @@ -213,16 +213,11 @@ end function Base.broadcast!(f::Function, A::AbstractAccArray, B) acc_broadcast!(f, A, (B,)) end -function Base.broadcast!(f::Function, A::AbstractAccArray, B::AbstractAccArray, args) - acc_broadcast!(f, A, (B, args)) -end -function Base.broadcast!(f::Function, A::AbstractAccArray, B::AbstractAccArray, C::AbstractAccArray, D, E...) - acc_broadcast!(f, A, (B, C, D, E...)) -end -function Base.broadcast!(f::Function, A::AbstractAccArray, B::AbstractAccArray, C::AbstractAccArray, D) - acc_broadcast!(f, A, (B, C, D)) +function Base.broadcast!(f::Function, A::AbstractAccArray, B::AbstractAccArray, args...) + acc_broadcast!(f, A, (B, args...)) end + # TODO check size function Base.map!(f::Function, A::AbstractAccArray, args::AbstractAccArray...) acc_broadcast!(f, A, (args...)) diff --git a/test/shared.jl b/test/shared.jl index 65971cdd..adf7b684 100644 --- a/test/shared.jl +++ b/test/shared.jl @@ -84,3 +84,15 @@ end end @test Array(A) ≈ a end + + +@allbackends "broadcast" backend begin + u0 = GPUArray(rand(Float32, 32, 32)) + tmp = ones(u0) + uprev = ones(u0) + k1 = ones(u0) + a = Float32(2.0) + tmp .= uprev .+ a .* k1 + comparison = ones(Float32, 32, 32) .+ a .* ones(Float32, 32, 32) + @test comparison == Array(tmp) +end