Skip to content

Commit

Permalink
Merge pull request #270 from JuliaGPU/tb/mapreduce_broadcast
Browse files Browse the repository at this point in the history
Support and use broadcast with mapreduce.
  • Loading branch information
maleadt authored May 8, 2020
2 parents 82be58e + d7d648f commit 8494cfd
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 21 deletions.
44 changes: 37 additions & 7 deletions src/host/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ end
axes(dest) == axes(bc) || Broadcast.throwdm(axes(dest), axes(bc))
bc′ = Broadcast.preprocess(dest, bc)
gpu_call(dest, bc′; name="broadcast") do ctx, dest, bc′
let I = CartesianIndex(@cartesianidx(dest))
let I = @cartesianidx(dest)
@inbounds dest[I] = bc′[I]
end
return
Expand All @@ -81,12 +81,42 @@ end
# Return `true` when all arguments compare `==` to their neighbour.
# A single argument is trivially all-equal; more arguments are checked
# pairwise with short-circuiting.
allequal(x) = true
function allequal(x, y, z...)
    x == y || return false
    return allequal(y, z...)
end

function Base.map!(f, y::GPUDestArray, xs::AbstractArray...)
@assert allequal(size.((y, xs...))...)
return y .= f.(xs...)
function Base.map!(f, dest::GPUDestArray, xs::AbstractArray...)
    # only as many elements as the smallest container provides get written,
    # matching Base's iterator-like `map!` semantics
    common_length = minimum(length.(LinearIndices.((dest, xs...))))

    # build the broadcast over the sources alone, deliberately ignoring any
    # container size mismatch (this avoids the reshape + view that our
    # mapreduce implementation has to do)
    bc = Broadcast.instantiate(Broadcast.broadcasted(f, xs...))
    bc′ = Broadcast.preprocess(dest, bc)

    # one thread per element of the common range; each thread translates its
    # linear index into a cartesian index over the broadcast's axes
    gpu_call(dest, bc′; name="map!", total_threads=common_length) do ctx, dest, bc′
        i = linear_index(ctx)
        if i <= common_length
            @inbounds dest[i] = bc′[CartesianIndices(axes(bc′))[i]]
        end
        return
    end

    return dest
end

function Base.map(f, y::GPUDestArray, xs::AbstractArray...)
@assert allequal(size.((y, xs...))...)
return f.(y, xs...)
function Base.map(f, x::GPUDestArray, xs::AbstractArray...)
    xs = (x, xs...)

    # when all argument sizes match, broadcast preserves their shape
    allequal(size.(xs)...) && return f.(xs...)

    # mismatched sizes: behave like `map` over iterators and stop at the
    # shortest argument
    common_length = minimum(length.(LinearIndices.(xs)))

    # use the broadcast machinery to determine the destination container
    # type and element type
    bc = Broadcast.instantiate(Broadcast.broadcasted(f, xs...))
    ElType = Broadcast.combine_eltypes(bc.f, bc.args)
    isbitstype(ElType) || error("Cannot map function returning non-isbits $ElType.")
    dest = similar(bc, ElType, common_length)

    return map!(f, dest, xs...)
end
44 changes: 32 additions & 12 deletions src/host/mapreduce.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
# map-reduce

# argument types that the reduction machinery can consume directly
const AbstractArrayOrBroadcasted = Union{AbstractArray,Broadcast.Broadcasted}

# GPUArrays' mapreduce methods build on `Base.mapreducedim!`, but with an additional
# argument `init` value to avoid eager initialization of `R` (if set to something).
# Backends are expected to override this stub with a concrete implementation.
mapreducedim!(f, op, R::AbstractGPUArray, As::AbstractArray...; init=nothing) = error("Not implemented") # COV_EXCL_LINE
mapreducedim!(f, op, R::AbstractGPUArray, A::AbstractArrayOrBroadcasted;
init=nothing) = error("Not implemented") # COV_EXCL_LINE
# resolve ambiguities with Base's own `mapreducedim!` methods by forwarding
# both array and broadcast arguments to the GPU-specific implementation above
Base.mapreducedim!(f, op, R::AbstractGPUArray, A::AbstractArray) = mapreducedim!(f, op, R, A)
Base.mapreducedim!(f, op, R::AbstractGPUArray, A::Broadcast.Broadcasted) = mapreducedim!(f, op, R, A)

neutral_element(op, T) =
error("""GPUArrays.jl needs to know the neutral element for your operator `$op`.
Expand All @@ -18,11 +23,30 @@ neutral_element(::typeof(Base.mul_prod), T) = one(T)
# the neutral element of an operator leaves any other value unchanged:
# for `min` that is the largest representable value, and vice versa for `max`
neutral_element(::typeof(Base.min), T) = typemax(T)
neutral_element(::typeof(Base.max), T) = typemin(T)

function Base.mapreduce(f, op, As::AbstractGPUArray...; dims=:, init=nothing)
# resolve ambiguities with Base's mapreduce methods: both GPU arrays and
# GPU-style broadcasts funnel into our `_mapreduce` implementation
Base.mapreduce(f, op, A::AbstractGPUArray, As::AbstractArrayOrBroadcasted...;
dims=:, init=nothing) = _mapreduce(f, op, A, As...; dims=dims, init=init)
Base.mapreduce(f, op, A::Broadcast.Broadcasted{<:AbstractGPUArrayStyle}, As::AbstractArrayOrBroadcasted...;
dims=:, init=nothing) = _mapreduce(f, op, A, As...; dims=dims, init=init)

function _mapreduce(f, op, As...; dims, init)
# mapreduce should apply `f` like `map` does, consuming elements like iterators.
bc = if allequal(size.(As)...)
Broadcast.instantiate(Broadcast.broadcasted(f, As...))
else
# TODO: can we avoid the reshape + view?
indices = LinearIndices.(As)
common_length = minimum(length.(indices))
Bs = map(As) do A
view(reshape(A, length(A)), 1:common_length)
end
Broadcast.instantiate(Broadcast.broadcasted(f, Bs...))
end

# figure out the destination container type by looking at the initializer element,
# or by relying on inference to reason through the map and reduce functions.
if init === nothing
ET = Base.promote_op(f, map(eltype, As)...)
ET = Broadcast.combine_eltypes(bc.f, bc.args)
ET = Base.promote_op(op, ET, ET)
(ET === Union{} || ET === Any) &&
error("mapreduce cannot figure the output element type, please pass an explicit init value")
Expand All @@ -32,14 +56,10 @@ function Base.mapreduce(f, op, As::AbstractGPUArray...; dims=:, init=nothing)
ET = typeof(init)
end

# TODO: Broadcast-semantics after JuliaLang-julia#31020
A = first(As)
all(B -> size(A) == size(B), As) || throw(DimensionMismatch("dimensions of containers must be identical"))

sz = size(A)
red = ntuple(i->(dims==Colon() || i in dims) ? 1 : sz[i], ndims(A))
R = similar(A, ET, red)
mapreducedim!(f, op, R, As...; init=init)
sz = size(bc)
red = ntuple(i->(dims==Colon() || i in dims) ? 1 : sz[i], length(sz))
R = similar(bc, ET, red)
mapreducedim!(identity, op, R, bc; init=init)

if dims==Colon()
@allowscalar R[]
Expand All @@ -57,7 +77,7 @@ Base.count(pred::Function, A::AbstractGPUArray) = mapreduce(pred, +, A; init = 0

# `mapreduce` now consumes its arguments like iterators, truncating to the
# common length when the sizes differ. For `==` that would wrongly compare
# only the overlapping elements (and could return `true` for arrays of
# different shape), so reject mismatched axes explicitly first.
function Base.:(==)(A::AbstractGPUArray, B::AbstractGPUArray)
    axes(A) == axes(B) || return false
    return Bool(mapreduce(==, &, A, B))
end

# avoid calling into `initarray!``
# forward straight to `Base.reducedim!`, which reuses the contents of `R`,
# to avoid calling into `initarray!`
Base.sum!(R::AbstractGPUArray, A::AbstractGPUArray) = Base.reducedim!(Base.add_sum, R, A)
Base.prod!(R::AbstractGPUArray, A::AbstractGPUArray) = Base.reducedim!(Base.mul_prod, R, A)
Base.maximum!(R::AbstractGPUArray, A::AbstractGPUArray) = Base.reducedim!(max, R, A)
Expand Down
5 changes: 3 additions & 2 deletions src/reference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -298,11 +298,12 @@ Adapt.adapt_storage(::Adaptor, x::JLArray{T,N}) where {T,N} =
GPUArrays.unsafe_reinterpret(::Type{T}, A::JLArray, size::Tuple) where T =
reshape(reinterpret(T, A.data), size)

function GPUArrays.mapreducedim!(f, op, R::JLArray, As::AbstractArray...; init=nothing)
# reference (CPU) implementation of the backend `mapreducedim!` interface:
# reduce scalar-by-scalar over the JLArray's plain-array storage.
function GPUArrays.mapreducedim!(f, op, R::JLArray, A::Union{AbstractArray,Broadcast.Broadcasted};
init=nothing)
# `init`, when given, seeds the destination; otherwise R's current contents
# participate in the reduction
if init !== nothing
fill!(R, init)
end
# NOTE(review): relies on `map(f, A)` materializing a plain array even for a
# `Broadcasted` argument — confirm this holds for all instantiated broadcasts
@allowscalar Base.reducedim!(op, R.data, map(f, A))
end

end
24 changes: 24 additions & 0 deletions test/testsuite/broadcasting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,30 @@ function broadcasting(AT)
@test compare((A, B) -> A .* B, AT, rand(ET, 40, 40), rand(ET, 40, 40))
@test compare((A, B) -> A .* B .+ ET(10), AT, rand(ET, 40, 40), rand(ET, 40, 40))
end

# `map!`/`map` should follow Base's iterator-like semantics: arguments with
# mismatched container sizes truncate to the shortest one.
@testset "map! $ET" begin
# matching sizes
@test compare(AT, rand(2,2), rand(2,2)) do x,y
map!(+, x, y)
end
# destination smaller than the source
@test compare(AT, rand(2), rand(2,2)) do x,y
map!(+, x, y)
end
# source smaller than the destination
@test compare(AT, rand(2,2), rand(2)) do x,y
map!(+, x, y)
end
end

@testset "map $ET" begin
# matching sizes preserve the shape
@test compare(AT, rand(2,2), rand(2,2)) do x,y
map(+, x, y)
end
# mismatched sizes truncate to the shortest argument
@test compare(AT, rand(2), rand(2,2)) do x,y
map(+, x, y)
end
@test compare(AT, rand(2,2), rand(2)) do x,y
map(+, x, y)
end
end
end

@testset "0D" begin
Expand Down
9 changes: 9 additions & 0 deletions test/testsuite/mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,15 @@ function test_mapreduce(AT)
ET <: Complex || @test compare(minimum, AT,rand(range, dims))
end
end
# reductions must accept differently-sized inputs (iterator semantics,
# truncating to the common length) as well as `Broadcasted` arguments
@testset "broadcasting behavior" begin
@test compare((x,y)->mapreduce(+, +, x, y), AT,
rand(range, 1), rand(range, 2, 2))
# reducing a lazy broadcast directly, without materializing it first
@test compare(AT, rand(range, 1), rand(range, 2, 2)) do x, y
bc = Broadcast.instantiate(Broadcast.broadcasted(*, x, y))
reduce(+, bc)
end

end
end
end
@testset "any all ==" begin
Expand Down

0 comments on commit 8494cfd

Please sign in to comment.