Skip to content

Commit

Permalink
Support and use broadcast with mapreduce.
Browse files Browse the repository at this point in the history
  • Loading branch information
maleadt committed May 7, 2020
1 parent ff47dcc commit ac2dc24
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 13 deletions.
31 changes: 20 additions & 11 deletions src/host/mapreduce.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
# map-reduce

const AbstractArrayOrBroadcasted = Union{AbstractArray,Broadcast.Broadcasted}

# GPUArrays' mapreduce methods build on `Base.mapreducedim!`, but with an additional
# argument `init` value to avoid eager initialization of `R` (if set to something).
mapreducedim!(f, op, R::AbstractGPUArray, As::AbstractArray...; init=nothing) = error("Not implemented") # COV_EXCL_LINE
mapreducedim!(f, op, R::AbstractGPUArray, A::AbstractArrayOrBroadcasted;
init=nothing) = error("Not implemented") # COV_EXCL_LINE
# resolve ambiguities
Base.mapreducedim!(f, op, R::AbstractGPUArray, A::AbstractArray) = mapreducedim!(f, op, R, A)
Base.mapreducedim!(f, op, R::AbstractGPUArray, A::Broadcast.Broadcasted) = mapreducedim!(f, op, R, A)

neutral_element(op, T) =
error("""GPUArrays.jl needs to know the neutral element for your operator `$op`.
Expand All @@ -18,11 +23,19 @@ neutral_element(::typeof(Base.mul_prod), T) = one(T)
neutral_element(::typeof(Base.min), T) = typemax(T)
neutral_element(::typeof(Base.max), T) = typemin(T)

function Base.mapreduce(f, op, As::AbstractGPUArray...; dims=:, init=nothing)
# resolve ambiguities
Base.mapreduce(f, op, A::AbstractGPUArray, As::AbstractArrayOrBroadcasted...;
dims=:, init=nothing) = _mapreduce(f, op, A, As...; dims=dims, init=init)
Base.mapreduce(f, op, A::Broadcast.Broadcasted{<:AbstractGPUArrayStyle}, As::AbstractArrayOrBroadcasted...;
dims=:, init=nothing) = _mapreduce(f, op, A, As...; dims=dims, init=init)

function _mapreduce(f, op, As...; dims, init)
bc = Broadcast.instantiate(Broadcast.broadcasted(f, As...))

# figure out the destination container type by looking at the initializer element,
# or by relying on inference to reason through the map and reduce functions.
if init === nothing
ET = Base.promote_op(f, map(eltype, As)...)
ET = Broadcast.combine_eltypes(bc.f, bc.args)
ET = Base.promote_op(op, ET, ET)
(ET === Union{} || ET === Any) &&
error("mapreduce cannot figure the output element type, please pass an explicit init value")
Expand All @@ -32,14 +45,10 @@ function Base.mapreduce(f, op, As::AbstractGPUArray...; dims=:, init=nothing)
ET = typeof(init)
end

# TODO: Broadcast-semantics after JuliaLang-julia#31020
A = first(As)
all(B -> size(A) == size(B), As) || throw(DimensionMismatch("dimensions of containers must be identical"))

sz = size(A)
red = ntuple(i->(dims==Colon() || i in dims) ? 1 : sz[i], ndims(A))
R = similar(A, ET, red)
mapreducedim!(f, op, R, As...; init=init)
sz = size(bc)
red = ntuple(i->(dims==Colon() || i in dims) ? 1 : sz[i], length(sz))
R = similar(bc, ET, red)
mapreducedim!(identity, op, R, bc; init=init)

if dims==Colon()
@allowscalar R[]
Expand Down
5 changes: 3 additions & 2 deletions src/reference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -298,11 +298,12 @@ Adapt.adapt_storage(::Adaptor, x::JLArray{T,N}) where {T,N} =
GPUArrays.unsafe_reinterpret(::Type{T}, A::JLArray, size::Tuple) where T =
reshape(reinterpret(T, A.data), size)

function GPUArrays.mapreducedim!(f, op, R::JLArray, As::AbstractArray...; init=nothing)
function GPUArrays.mapreducedim!(f, op, R::JLArray, A::Union{AbstractArray,Broadcast.Broadcasted};
init=nothing)
if init !== nothing
fill!(R, init)
end
@allowscalar Base.reducedim!(op, R.data, map(f, As...))
@allowscalar Base.reducedim!(op, R.data, map(f, A))
end

end
9 changes: 9 additions & 0 deletions test/testsuite/mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,15 @@ function test_mapreduce(AT)
ET <: Complex || @test compare(minimum, AT,rand(range, dims))
end
end
@testset "broadcasting behavior" begin
@test compare((x,y)->mapreduce(+, +, x, y), AT,
rand(range, 1), rand(range, 2, 2))
@test compare(AT, rand(range, 1), rand(range, 2, 2)) do x, y
bc = Broadcast.instantiate(Broadcast.broadcasted(*, x, y))
reduce(bc)
end

end
end
end
@testset "any all ==" begin
Expand Down

0 comments on commit ac2dc24

Please sign in to comment.