Skip to content

Commit

Permalink
add support for mapreduce, foreach, broadcast... (#276)
Browse files Browse the repository at this point in the history
supported are now: `asyncmap`, `broadcast`, `broadcast!`, `foreach`, `map`, `mapfoldl`, `mapfoldr`, `mapreduce`, `pmap` and `reduce`
  • Loading branch information
MarcMush authored Sep 12, 2023
1 parent 9e5002e commit e6015a3
Show file tree
Hide file tree
Showing 2 changed files with 126 additions and 20 deletions.
45 changes: 33 additions & 12 deletions src/ProgressMeter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -871,7 +871,9 @@ interval in seconds between updates to the user. You may optionally
supply a custom message to be printed that specifies the computation
being performed.
`@showprogress` works for loops, comprehensions, map, reduce, and pmap.
`@showprogress` works for loops, comprehensions, `asyncmap`,
`broadcast`, `broadcast!`, `foreach`, `map`, `mapfoldl`,
`mapfoldr`, `mapreduce`, `pmap` and `reduce`.
"""
macro showprogress(args...)
showprogress(args...)
Expand All @@ -892,7 +894,8 @@ function showprogress(args...)
return expr
end
metersym = gensym("meter")
mapfuns = (:map, :asyncmap, :reduce, :pmap)
mapfuns = (:asyncmap, :broadcast, :broadcast!, :foreach, :map,
:mapfoldl, :mapfoldr, :mapreduce, :pmap, :reduce)
kind = :invalid # :invalid, :loop, or :map

if isa(expr, Expr)
Expand Down Expand Up @@ -994,7 +997,7 @@ function showprogress(args...)

# get args to map to determine progress length
mapargs = collect(Any, filter(call.args[2:end]) do a
return isa(a, Symbol) || !(a.head in (:kw, :parameters))
return isa(a, Symbol) || isa(a, Number) || !(a.head in (:kw, :parameters))
end)
if expr.head == :do
insert!(mapargs, 1, :nothing) # to make args for ncalls line up
Expand Down Expand Up @@ -1035,6 +1038,7 @@ function progress_map(args...; mapfun=map,
progress=Progress(ncalls(mapfun, args)),
channel_bufflen=min(1000, ncalls(mapfun, args)),
kwargs...)
isempty(args) && return mapfun(; kwargs...)
f = first(args)
other_args = args[2:end]
channel = RemoteChannel(()->Channel{Bool}(channel_bufflen), 1)
Expand Down Expand Up @@ -1069,17 +1073,34 @@ progress_pmap(args...; kwargs...) = progress_map(args...; mapfun=pmap, kwargs...
"""
Infer the number of calls to the mapped function (i.e. the length of the returned array) given the input arguments to map, reduce or pmap.
"""
function ncalls(mapfun::Function, map_args)
if mapfun == pmap && length(map_args) >= 2 && isa(map_args[2], AbstractWorkerPool)
relevant = map_args[3:end]
else
relevant = map_args[2:end]
end
if isempty(relevant)
error("Unable to determine number of calls in $mapfun. Too few arguments?")
function ncalls(::typeof(broadcast), map_args)
length(map_args) < 2 && return 1
return prod(length, Broadcast.combine_axes(map_args[2:end]...))
end

function ncalls(::typeof(broadcast!), map_args)
length(map_args) < 2 && return 1
return length(map_args[2])
end

function ncalls(::Union{typeof(mapreduce),typeof(mapfoldl),typeof(mapfoldr)}, map_args)
length(map_args) < 3 && return 1
return minimum(length, map_args[3:end])
end

function ncalls(::typeof(pmap), map_args)
if length(map_args) 2 && map_args[2] isa AbstractWorkerPool
length(map_args) < 3 && return 1
return minimum(length, map_args[3:end])
else
return maximum(length(arg) for arg in relevant)
length(map_args) < 2 && return 1
return minimum(length, map_args[2:end])
end
end

function ncalls(mapfun::Function, map_args)
length(map_args) < 2 && return 1
return minimum(length, map_args[2:end])
end

end # module
101 changes: 93 additions & 8 deletions test/test_map.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
using Test
using Distributed
import ProgressMeter.ncalls

procs = addprocs(2)
wp = WorkerPool(procs)
@everywhere using ProgressMeter

@testset "map tests" begin
Expand Down Expand Up @@ -50,6 +53,30 @@ procs = addprocs(2)
end
println()

# test ncalls
@test ncalls(map, (+, 1:10)) == 10
@test ncalls(pmap, (+, 1:10, 1:100)) == 10
@test ncalls(pmap, (+, wp, 1:10)) == 10
@test ncalls(reduce, (+, 1:10)) == 10
@test ncalls(mapreduce, (+, +, 1:10, (1:10)')) == 10
@test ncalls(mapfoldl, (+, +, 1:10, (1:10)')) == 10
@test ncalls(mapfoldr, (+, +, 1:10, (1:10)')) == 10
@test ncalls(foreach, (+, 1:10)) == 10
@test ncalls(broadcast, (+, 1:10, 1:10)) == 10
@test ncalls(broadcast, (+, 1:8, (1:7)', 1)) == 8*7
@test ncalls(broadcast, (+, 1:3, (1:5)', ones(1,1,2))) == 3*5*2
@test ncalls(broadcast!, (+, zeros(10,8))) == 80
@test ncalls(broadcast!, (+, zeros(10,8,7), 1:10)) == 10*8*7

@test ncalls(map, (time,)) == 1
@test ncalls(foreach, (time,)) == 1
@test ncalls(broadcast, (time,)) == 1
@test ncalls(broadcast!, (time, [1])) == 1
@test ncalls(mapreduce, (time, +)) == 1

@test_throws DimensionMismatch ncalls(broadcast, (+, 1:10, 1:100))
@test_throws DimensionMismatch ncalls(broadcast, (+, 1:100, 1:10))

# @showprogress
vals = @showprogress map(1:10) do x
return x^2
Expand All @@ -66,11 +93,53 @@ procs = addprocs(2)
end
@test vals == map(x->x^2, 1:10)

vals = @showprogress pmap(wp, 1:10) do x
x^2
end
@test vals == map(x->x^2, 1:10)

val = @showprogress reduce(1:10) do x, y
return x + y
end
@test val == reduce((x, y)->x+y, 1:10)

val = @showprogress mapreduce(+, 1:10) do x
return x^2
end
@test val == mapreduce(x->x^2, +, 1:10)

val = @showprogress mapfoldl(-, 1:10) do x
return x^2
end
@test val == mapfoldl(x->x^2, -, 1:10)

val = @showprogress mapfoldr(-, 1:10) do x
return x^2
end
@test val == mapfoldr(x->x^2, -, 1:10)

@showprogress foreach(1:10) do x
print(x)
end

val = @showprogress broadcast(1:10, (1:10)') do x,y
return x+y
end
@test val == broadcast(+, 1:10, (1:10)')

A = zeros(10,8)
@showprogress broadcast!(A, 1:10, (1:8)') do x,y
return x+y
end
@test A == broadcast(+, 1:10, (1:8)')

@showprogress broadcast!(A, 1:10) do x
return x
end
@test A == repeat(1:10, 1, 8)



# function passed by name
function testfun(x)
return x^2
Expand All @@ -79,8 +148,32 @@ procs = addprocs(2)
@test vals == map(testfun, 1:10)
vals = @showprogress pmap(testfun, 1:10)
@test vals == map(testfun, 1:10)
vals = @showprogress pmap(testfun, wp, 1:10)
@test vals == map(testfun, 1:10)
val = @showprogress reduce(+, 1:10)
@test val == reduce(+, 1:10)
val = @showprogress mapreduce(testfun, +, 1:10)
@test val == mapreduce(testfun, +, 1:10)
val = @showprogress mapfoldl(testfun, -, 1:10)
@test val == mapfoldl(testfun, -, 1:10)
val = @showprogress mapfoldr(testfun, -, 1:10)
@test val == mapfoldr(testfun, -, 1:10)
@showprogress foreach(print, 1:10)
println()
val = @showprogress broadcast(+, 1:10, (1:12)')
@test val == broadcast(+, 1:10, (1:12)')
@showprogress broadcast!(+, A, 1:10, 1:10, (1:8)', 3)
@test A == broadcast(+, 1:10, 1:10, (1:8)', 3)

# test function with no arg
function constfun()
return 42
end
@test map(constfun) == @showprogress map(constfun)
@test broadcast(constfun) == @showprogress broadcast(constfun)
#@test mapreduce(constfun, error) == @showprogress mapreduce(constfun, error) # julia 1.2+
@showprogress foreach(printlnconstfun)


# #136: make sure mid progress shows up even without sleep
println("Verify that intermediate progress is displayed:")
Expand All @@ -95,15 +188,7 @@ procs = addprocs(2)



# abstract worker pool arg
wp = WorkerPool(procs)
vals = @showprogress pmap(testfun, wp, 1:10)
@test vals == map(testfun, 1:10)

vals = @showprogress pmap(wp, 1:10) do x
x^2
end
@test vals == map(testfun, 1:10)



Expand Down

0 comments on commit e6015a3

Please sign in to comment.