Skip to content

Commit cd0414f

Browse files
authored
Return of the threaded detection (#328)
1 parent e60297b commit cd0414f

File tree

2 files changed

+92
-15
lines changed

2 files changed

+92
-15
lines changed

src/ProgressMeter.jl

+15-2
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,8 @@ Base.@kwdef mutable struct ProgressCore
8181
numprintedvalues::Int = 0 # num values printed below progress in last iteration
8282
prev_update_count::Int = 1 # counter at last update
8383
printed::Bool = false # true if we have issued at least one status update
84-
safe_lock::Bool = Threads.nthreads() > 1 # set to false for non-threaded tight loops
84+
safe_lock::Int = 2*(Threads.nthreads()>1) # 0: no lock, 1: lock, 2: detect
85+
thread_id::Int = Threads.threadid() # id of the thread that created the progressmeter
8586
tinit::Float64 = time() # time meter was initialized
8687
tlast::Float64 = time() # time of last update
8788
tsecond::Float64 = time() # ignore the first loop given usually uncharacteristically slow
@@ -448,8 +449,20 @@ end
448449

449450
predicted_updates_per_dt_have_passed(p::AbstractProgress) = p.counter - p.prev_update_count >= p.check_iterations
450451

452+
function is_threading(p::AbstractProgress)
453+
p.safe_lock == 0 && return false
454+
p.safe_lock == 1 && return true
455+
if p.thread_id != Threads.threadid()
456+
lock(p.lock) do
457+
p.safe_lock = 1
458+
end
459+
return true
460+
end
461+
return false
462+
end
463+
451464
function lock_if_threading(f::Function, p::AbstractProgress)
452-
if p.safe_lock
465+
if is_threading(p)
453466
lock(p.lock) do
454467
f()
455468
end

test/core.jl

+77-13
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ for ns in [1, 9, 10, 99, 100, 999, 1_000, 9_999, 10_000, 99_000, 100_000, 999_99
2525
end
2626

2727
# Performance test (from #171, #323)
28-
function prog_perf(n; dt=0.1, enabled=true, force=false, safe_lock=false)
28+
function prog_perf(n; dt=0.1, enabled=true, force=false, safe_lock=0)
2929
prog = Progress(n; dt, enabled, safe_lock)
3030
x = 0.0
3131
for i in 1:n
@@ -43,38 +43,85 @@ function noprog_perf(n)
4343
return x
4444
end
4545

46+
function prog_threaded(n; dt=0.1, enabled=true, force=false, safe_lock=2)
47+
prog = Progress(n; dt, enabled, safe_lock)
48+
x = Threads.Atomic{Float64}(0.0)
49+
Threads.@threads for i in 1:n
50+
Threads.atomic_add!(x, rand())
51+
next!(prog; force)
52+
end
53+
return x
54+
end
55+
56+
function noprog_threaded(n)
57+
x = Threads.Atomic{Float64}(0.0)
58+
Threads.@threads for i in 1:n
59+
Threads.atomic_add!(x, rand())
60+
end
61+
return x
62+
end
63+
4664
println("Performance tests...")
4765

4866
#precompile
4967
noprog_perf(10)
5068
prog_perf(10)
51-
prog_perf(10; safe_lock=true)
52-
prog_perf(10; dt=9999)
69+
prog_perf(10; safe_lock=1)
70+
prog_perf(10; dt=9999.9)
5371
prog_perf(10; enabled=false)
54-
prog_perf(10; enabled=false, safe_lock=true)
72+
prog_perf(10; enabled=false, safe_lock=1)
5573
prog_perf(10; force=true)
5674

57-
t_noprog = (@elapsed noprog_perf(10^8))/10^8
58-
t_prog = (@elapsed prog_perf(10^8))/10^8
59-
t_lock = (@elapsed prog_perf(10^8; safe_lock=true))/10^8
60-
t_noprint = (@elapsed prog_perf(10^8; dt=9999))/10^8
61-
t_disabled = (@elapsed prog_perf(10^8; enabled=false))/10^8
62-
t_disabled_lock = (@elapsed prog_perf(10^8; enabled=false, safe_lock=true))/10^8
63-
t_force = (@elapsed prog_perf(10^2; force=true))/10^2
75+
noprog_threaded(2*Threads.nthreads())
76+
prog_threaded(2*Threads.nthreads())
77+
prog_threaded(2*Threads.nthreads(); safe_lock=1)
78+
prog_threaded(2*Threads.nthreads(); dt=9999)
79+
prog_threaded(2*Threads.nthreads(); enabled=false)
80+
prog_threaded(2*Threads.nthreads(); force=true)
81+
82+
N = 10^8
83+
N_force = 1000
84+
t_noprog = (@elapsed noprog_perf(N))/N
85+
t_prog = (@elapsed prog_perf(N))/N
86+
t_lock = (@elapsed prog_perf(N; safe_lock=1))/N
87+
t_detect = (@elapsed prog_perf(N; safe_lock=2))/N
88+
t_noprint = (@elapsed prog_perf(N; dt=9999.9))/N
89+
t_disabled = (@elapsed prog_perf(N; enabled=false))/N
90+
t_disabled_lock = (@elapsed prog_perf(N; enabled=false, safe_lock=1))/N
91+
t_force = (@elapsed prog_perf(N_force; force=true))/N_force
92+
93+
Nth = Threads.nthreads() * 10^6
94+
Nth_force = Threads.nthreads() * 100
95+
th_noprog = (@elapsed noprog_threaded(Nth))/Nth
96+
th_detect = (@elapsed prog_threaded(Nth))/Nth
97+
th_lock = (@elapsed prog_threaded(Nth; safe_lock=1))/Nth
98+
th_noprint = (@elapsed prog_threaded(Nth; dt=9999.9))/Nth
99+
th_disabled = (@elapsed prog_threaded(Nth; enabled=false))/Nth
100+
th_force = (@elapsed prog_threaded(Nth_force; force=true))/Nth_force
64101

65102
println("Performance results:")
66103
println("without progress: ", ProgressMeter.speedstring(t_noprog))
67-
println("with defaults: ", ProgressMeter.speedstring(t_prog))
104+
println("with no lock: ", ProgressMeter.speedstring(t_prog))
68105
println("with no printing: ", ProgressMeter.speedstring(t_noprint))
69106
println("with disabled: ", ProgressMeter.speedstring(t_disabled))
70107
println("with lock: ", ProgressMeter.speedstring(t_lock))
108+
println("with automatic lock: ", ProgressMeter.speedstring(t_detect))
71109
println("with lock, disabled: ", ProgressMeter.speedstring(t_disabled_lock))
72110
println("with force: ", ProgressMeter.speedstring(t_force))
111+
println()
112+
println("Threaded performance results: ($(Threads.nthreads()) threads)")
113+
println("without progress: ", ProgressMeter.speedstring(th_noprog))
114+
println("with automatic lock: ", ProgressMeter.speedstring(th_detect))
115+
println("with forced lock: ", ProgressMeter.speedstring(th_lock))
116+
println("with no printing: ", ProgressMeter.speedstring(th_noprint))
117+
println("with disabled: ", ProgressMeter.speedstring(th_disabled))
118+
println("with force: ", ProgressMeter.speedstring(th_force))
73119

74120
if get(ENV, "CI", "false") == "false" # CI environment is too unreliable for performance tests
75121
@test t_prog < 9*t_noprog
76122
end
77123

124+
78125
# Avoid a NaN due to the estimated print time compensation
79126
# https://github.com/timholy/ProgressMeter.jl/issues/209
80127
prog = Progress(10)
@@ -116,7 +163,24 @@ function simple_sum(n; safe_lock = true)
116163
return s
117164
end
118165
p = Progress(10)
119-
@test p.safe_lock == (Threads.nthreads() > 1)
166+
@test (p.safe_lock) == 2*(Threads.nthreads() > 1)
120167
p = Progress(10; safe_lock = false)
121168
@test p.safe_lock == false
122169
@test simple_sum(10; safe_lock = true) simple_sum(10; safe_lock = false)
170+
171+
172+
# Brute-force thread safety
173+
174+
function test_thread(N)
175+
p = Progress(N)
176+
Threads.@threads for _ in 1:N
177+
next!(p)
178+
end
179+
end
180+
181+
println("Brute-forcing thread safety... ($(Threads.nthreads()) threads)")
182+
@time for i in 1:10^5
183+
test_thread(Threads.nthreads())
184+
end
185+
186+

0 commit comments

Comments
 (0)