Skip to content

Commit

Permalink
Merge pull request #29623 from JuliaLang/jb/syncdistributed
Browse files Browse the repository at this point in the history
add `Event`; use it to fix race in Distributed setup
  • Loading branch information
JeffBezanson authored Oct 19, 2018
2 parents bf8135d + 9ee0b14 commit f6344d3
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 3 deletions.
58 changes: 56 additions & 2 deletions base/locks.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license

import .Base: _uv_hook_close, unsafe_convert,
lock, trylock, unlock, islocked
lock, trylock, unlock, islocked, wait, notify

export SpinLock, RecursiveSpinLock, Mutex
export SpinLock, RecursiveSpinLock, Mutex, Event


##########################################
Expand Down Expand Up @@ -238,3 +238,57 @@ end
function islocked(m::Mutex)
return m.ownertid != 0
end

"""
Event()
Create a level-triggered event source. Tasks that call [`wait`](@ref) on an
`Event` are suspended and queued until `notify` is called on the `Event`.
After `notify` is called, the `Event` remains in a signaled state and
tasks will no longer block when waiting for it.
"""
mutable struct Event
lock::Mutex
q::Vector{Task}
set::Bool
# TODO: use a Condition with its paired lock
Event() = new(Mutex(), Task[], false)
end

function wait(e::Event)
e.set && return
lock(e.lock)
while !e.set
ct = current_task()
push!(e.q, ct)
unlock(e.lock)
try
wait()
catch
filter!(x->x!==ct, e.q)
rethrow()
end
lock(e.lock)
end
unlock(e.lock)
return nothing
end

function notify(e::Event)
lock(e.lock)
if !e.set
e.set = true
for t in e.q
schedule(t)
end
empty!(e.q)
end
unlock(e.lock)
return nothing
end

# TODO: decide what to call this
#function clear(e::Event)
# e.set = false
# return nothing
#end
2 changes: 2 additions & 0 deletions stdlib/Distributed/src/cluster.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ mutable struct Worker
manager::ClusterManager
config::WorkerConfig
version::Union{VersionNumber, Nothing} # Julia version of the remote process
initialized::Threads.Event

function Worker(id::Int, r_stream::IO, w_stream::IO, manager::ClusterManager;
version::Union{VersionNumber, Nothing}=nothing,
Expand All @@ -90,6 +91,7 @@ mutable struct Worker
return map_pid_wrkr[id]
end
w=new(id, [], [], false, W_CREATED, Condition(), time(), conn_func)
w.initialized = Threads.Event()
register_worker(w)
w
end
Expand Down
3 changes: 3 additions & 0 deletions stdlib/Distributed/src/messages.jl
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,9 @@ end

function send_msg_(w::Worker, header, msg, now::Bool)
check_worker_state(w)
if myid() != 1 && !isa(msg, IdentifySocketMsg) && !isa(msg, IdentifySocketAckMsg)
wait(w.initialized)
end
io = w.w_stream
lock(io.lock)
try
Expand Down
5 changes: 4 additions & 1 deletion stdlib/Distributed/src/process_messages.jl
Original file line number Diff line number Diff line change
Expand Up @@ -288,9 +288,10 @@ end

function handle_msg(msg::IdentifySocketMsg, header, r_stream, w_stream, version)
# register a new peer worker connection
w=Worker(msg.from_pid, r_stream, w_stream, cluster_manager; version=version)
w = Worker(msg.from_pid, r_stream, w_stream, cluster_manager; version=version)
send_connection_hdr(w, false)
send_msg_now(w, MsgHeader(), IdentifySocketAckMsg())
notify(w.initialized)
end

function handle_msg(msg::IdentifySocketAckMsg, header, r_stream, w_stream, version)
Expand All @@ -301,6 +302,7 @@ end
function handle_msg(msg::JoinPGRPMsg, header, r_stream, w_stream, version)
LPROC.id = msg.self_pid
controller = Worker(1, r_stream, w_stream, cluster_manager; version=version)
notify(controller.initialized)
register_worker(LPROC)
topology(msg.topology)

Expand Down Expand Up @@ -340,6 +342,7 @@ function connect_to_peer(manager::ClusterManager, rpid::Int, wconfig::WorkerConf
process_messages(w.r_stream, w.w_stream, false)
send_connection_hdr(w, true)
send_msg_now(w, MsgHeader(), IdentifySocketMsg(myid()))
notify(w.initialized)
catch e
@error "Error on $(myid()) while connecting to peer $rpid, exiting" exception=e,catch_backtrace()
exit(1)
Expand Down
13 changes: 13 additions & 0 deletions test/threads.jl
Original file line number Diff line number Diff line change
Expand Up @@ -503,3 +503,16 @@ function test_thread_too_few_iters()
@test !(true in found[nthreads():end])
end
test_thread_too_few_iters()

let e = Event()
done = false
t = @async (wait(e); done = true)
sleep(0.1)
@test done == false
notify(e)
wait(t)
@test done == true
blocked = true
wait(@async (wait(e); blocked = false))
@test !blocked
end

0 comments on commit f6344d3

Please sign in to comment.