Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support and use broadcast with mapreduce. #270

Merged
merged 3 commits into from
May 8, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 37 additions & 7 deletions src/host/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ end
axes(dest) == axes(bc) || Broadcast.throwdm(axes(dest), axes(bc))
bc′ = Broadcast.preprocess(dest, bc)
gpu_call(dest, bc′; name="broadcast") do ctx, dest, bc′
let I = CartesianIndex(@cartesianidx(dest))
let I = @cartesianidx(dest)
@inbounds dest[I] = bc′[I]
end
return
Expand All @@ -81,12 +81,42 @@ end
allequal(x) = true
allequal(x, y, z...) = x == y && allequal(y, z...)

function Base.map!(f, y::GPUDestArray, xs::AbstractArray...)
@assert allequal(size.((y, xs...))...)
return y .= f.(xs...)
function Base.map!(f, dest::GPUDestArray, xs::AbstractArray...)
indices = LinearIndices.((dest, xs...))
common_length = minimum(length.(indices))

# custom broadcast, ignoring the container size mismatches
# (avoids the reshape + view that our mapreduce impl has to do)
bc = Broadcast.instantiate(Broadcast.broadcasted(f, xs...))
bc′ = Broadcast.preprocess(dest, bc)
gpu_call(dest, bc′; name="map!", total_threads=common_length) do ctx, dest, bc′
i = linear_index(ctx)
if i <= common_length
I = CartesianIndices(axes(bc′))[i]
@inbounds dest[i] = bc′[I]
end
return
end

return dest
end

function Base.map(f, y::GPUDestArray, xs::AbstractArray...)
@assert allequal(size.((y, xs...))...)
return f.(y, xs...)
function Base.map(f, x::GPUDestArray, xs::AbstractArray...)
# if argument sizes match, their shape needs to be preserved
xs = (x, xs...)
if allequal(size.(xs)...)
return f.(xs...)
end

# if not, treat them as iterators
indices = LinearIndices.(xs)
common_length = minimum(length.(indices))

# construct a broadcast to figure out the destination container
bc = Broadcast.instantiate(Broadcast.broadcasted(f, xs...))
ElType = Broadcast.combine_eltypes(bc.f, bc.args)
isbitstype(ElType) || error("Cannot map function returning non-isbits $ElType.")
dest = similar(bc, ElType, common_length)

return map!(f, dest, xs...)
end
44 changes: 32 additions & 12 deletions src/host/mapreduce.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
# map-reduce

const AbstractArrayOrBroadcasted = Union{AbstractArray,Broadcast.Broadcasted}

# GPUArrays' mapreduce methods build on `Base.mapreducedim!`, but with an additional
# argument `init` value to avoid eager initialization of `R` (if set to something).
mapreducedim!(f, op, R::AbstractGPUArray, As::AbstractArray...; init=nothing) = error("Not implemented") # COV_EXCL_LINE
mapreducedim!(f, op, R::AbstractGPUArray, A::AbstractArrayOrBroadcasted;
init=nothing) = error("Not implemented") # COV_EXCL_LINE
# resolve ambiguities
maleadt marked this conversation as resolved.
Show resolved Hide resolved
Base.mapreducedim!(f, op, R::AbstractGPUArray, A::AbstractArray) = mapreducedim!(f, op, R, A)
Base.mapreducedim!(f, op, R::AbstractGPUArray, A::Broadcast.Broadcasted) = mapreducedim!(f, op, R, A)

neutral_element(op, T) =
error("""GPUArrays.jl needs to know the neutral element for your operator `$op`.
Expand All @@ -18,11 +23,30 @@ neutral_element(::typeof(Base.mul_prod), T) = one(T)
neutral_element(::typeof(Base.min), T) = typemax(T)
neutral_element(::typeof(Base.max), T) = typemin(T)

function Base.mapreduce(f, op, As::AbstractGPUArray...; dims=:, init=nothing)
# resolve ambiguities
Base.mapreduce(f, op, A::AbstractGPUArray, As::AbstractArrayOrBroadcasted...;
dims=:, init=nothing) = _mapreduce(f, op, A, As...; dims=dims, init=init)
Base.mapreduce(f, op, A::Broadcast.Broadcasted{<:AbstractGPUArrayStyle}, As::AbstractArrayOrBroadcasted...;
dims=:, init=nothing) = _mapreduce(f, op, A, As...; dims=dims, init=init)

function _mapreduce(f, op, As...; dims, init)
# mapreduce should apply `f` like `map` does, consuming elements like iterators.
bc = if allequal(size.(As)...)
Broadcast.instantiate(Broadcast.broadcasted(f, As...))
else
# TODO: can we avoid the reshape + view?
indices = LinearIndices.(As)
common_length = minimum(length.(indices))
Bs = map(As) do A
view(reshape(A, length(A)), 1:common_length)
end
Broadcast.instantiate(Broadcast.broadcasted(f, Bs...))
end

# figure out the destination container type by looking at the initializer element,
# or by relying on inference to reason through the map and reduce functions.
if init === nothing
ET = Base.promote_op(f, map(eltype, As)...)
ET = Broadcast.combine_eltypes(bc.f, bc.args)
ET = Base.promote_op(op, ET, ET)
(ET === Union{} || ET === Any) &&
error("mapreduce cannot figure the output element type, please pass an explicit init value")
Expand All @@ -32,14 +56,10 @@ function Base.mapreduce(f, op, As::AbstractGPUArray...; dims=:, init=nothing)
ET = typeof(init)
end

# TODO: Broadcast-semantics after JuliaLang-julia#31020
A = first(As)
all(B -> size(A) == size(B), As) || throw(DimensionMismatch("dimensions of containers must be identical"))

sz = size(A)
red = ntuple(i->(dims==Colon() || i in dims) ? 1 : sz[i], ndims(A))
R = similar(A, ET, red)
mapreducedim!(f, op, R, As...; init=init)
sz = size(bc)
red = ntuple(i->(dims==Colon() || i in dims) ? 1 : sz[i], length(sz))
R = similar(bc, ET, red)
mapreducedim!(identity, op, R, bc; init=init)

if dims==Colon()
@allowscalar R[]
Expand All @@ -57,7 +77,7 @@ Base.count(pred::Function, A::AbstractGPUArray) = mapreduce(pred, +, A; init = 0

Base.:(==)(A::AbstractGPUArray, B::AbstractGPUArray) = Bool(mapreduce(==, &, A, B))

# avoid calling into `initarray!``
# avoid calling into `initarray!`
Base.sum!(R::AbstractGPUArray, A::AbstractGPUArray) = Base.reducedim!(Base.add_sum, R, A)
Base.prod!(R::AbstractGPUArray, A::AbstractGPUArray) = Base.reducedim!(Base.mul_prod, R, A)
Base.maximum!(R::AbstractGPUArray, A::AbstractGPUArray) = Base.reducedim!(max, R, A)
Expand Down
5 changes: 3 additions & 2 deletions src/reference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -298,11 +298,12 @@ Adapt.adapt_storage(::Adaptor, x::JLArray{T,N}) where {T,N} =
GPUArrays.unsafe_reinterpret(::Type{T}, A::JLArray, size::Tuple) where T =
reshape(reinterpret(T, A.data), size)

function GPUArrays.mapreducedim!(f, op, R::JLArray, As::AbstractArray...; init=nothing)
function GPUArrays.mapreducedim!(f, op, R::JLArray, A::Union{AbstractArray,Broadcast.Broadcasted};
init=nothing)
if init !== nothing
fill!(R, init)
end
@allowscalar Base.reducedim!(op, R.data, map(f, As...))
@allowscalar Base.reducedim!(op, R.data, map(f, A))
end

end
24 changes: 24 additions & 0 deletions test/testsuite/broadcasting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,30 @@ function broadcasting(AT)
@test compare((A, B) -> A .* B, AT, rand(ET, 40, 40), rand(ET, 40, 40))
@test compare((A, B) -> A .* B .+ ET(10), AT, rand(ET, 40, 40), rand(ET, 40, 40))
end

@testset "map! $ET" begin
@test compare(AT, rand(2,2), rand(2,2)) do x,y
map!(+, x, y)
end
@test compare(AT, rand(2), rand(2,2)) do x,y
map!(+, x, y)
end
@test compare(AT, rand(2,2), rand(2)) do x,y
map!(+, x, y)
end
end

@testset "map $ET" begin
@test compare(AT, rand(2,2), rand(2,2)) do x,y
map(+, x, y)
end
@test compare(AT, rand(2), rand(2,2)) do x,y
map(+, x, y)
end
@test compare(AT, rand(2,2), rand(2)) do x,y
map(+, x, y)
end
end
end

@testset "0D" begin
Expand Down
9 changes: 9 additions & 0 deletions test/testsuite/mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,15 @@ function test_mapreduce(AT)
ET <: Complex || @test compare(minimum, AT,rand(range, dims))
end
end
@testset "broadcasting behavior" begin
@test compare((x,y)->mapreduce(+, +, x, y), AT,
rand(range, 1), rand(range, 2, 2))
@test compare(AT, rand(range, 1), rand(range, 2, 2)) do x, y
bc = Broadcast.instantiate(Broadcast.broadcasted(*, x, y))
reduce(+, bc)
end

end
end
end
@testset "any all ==" begin
Expand Down