diff --git a/base/multi.jl b/base/multi.jl index 38aab509dc45f..5b9b5ad789c3b 100644 --- a/base/multi.jl +++ b/base/multi.jl @@ -37,19 +37,36 @@ end hash(r::RRID, h::UInt) = hash(r.whence, hash(r.id, h)) ==(r::RRID, s::RRID) = (r.whence==s.whence && r.id==s.id) +## Wire format description +# +# Each message has three parts, which are written in order to the worker's stream. +# 1) A header of type MsgHeader is serialized to the stream (via `serialize`). +# 2) A message of type AbstractMsg is then serialized. +# 3) Finally, a fixed boundary of 10 bytes is written. + +# Message header stored separately from body to be able to send back errors if +# a deserialization error occurs when reading the message body. +type MsgHeader + response_oid::RRID + notify_oid::RRID +end + +# Special oid (0,0) used to indicate a null ID. +# Used instead of Nullable to decrease wire size of header. +null_id(id) = id == RRID(0, 0) + +MsgHeader(;response_oid::RRID=RRID(0,0), notify_oid::RRID=RRID(0,0)) = + MsgHeader(response_oid, notify_oid) type CallMsg{Mode} <: AbstractMsg f::Function args::Tuple kwargs::Array - response_oid::RRID end type CallWaitMsg <: AbstractMsg f::Function args::Tuple kwargs::Array - response_oid::RRID - notify_oid::RRID end type RemoteDoMsg <: AbstractMsg f::Function @@ -57,10 +74,10 @@ type RemoteDoMsg <: AbstractMsg kwargs::Array end type ResultMsg <: AbstractMsg - response_oid::RRID value::Any end + # Worker initialization messages type IdentifySocketMsg <: AbstractMsg from_pid::Int @@ -70,34 +87,32 @@ end type JoinPGRPMsg <: AbstractMsg self_pid::Int other_workers::Array - notify_oid::RRID topology::Symbol worker_pool end type JoinCompleteMsg <: AbstractMsg - notify_oid::RRID cpu_cores::Int ospid::Int end -function send_msg_unknown(s::IO, msg) +function send_msg_unknown(s::IO, header, msg) error("attempt to send to unknown socket") end -function send_msg(s::IO, msg) +function send_msg(s::IO, header, msg) id = worker_id_from_socket(s) if id > -1 - return 
send_msg(worker_from_id(id), msg) + return send_msg(worker_from_id(id), header, msg) end - send_msg_unknown(s, msg) + send_msg_unknown(s, header, msg) end -function send_msg_now(s::IO, msg::AbstractMsg) +function send_msg_now(s::IO, msghdr, msg::AbstractMsg) id = worker_id_from_socket(s) if id > -1 - return send_msg_now(worker_from_id(id), msg) + return send_msg_now(worker_from_id(id), msghdr, msg) end - send_msg_unknown(s, msg) + send_msg_unknown(s, msghdr, msg) end abstract ClusterManager @@ -197,12 +212,12 @@ function set_worker_state(w, state) notify(w.c_state; all=true) end -function send_msg_now(w::Worker, msg) - send_msg_(w, msg, true) +function send_msg_now(w::Worker, msghdr, msg) + send_msg_(w, msghdr, msg, true) end -function send_msg(w::Worker, msg) - send_msg_(w, msg, false) +function send_msg(w::Worker, msghdr, msg) + send_msg_(w, msghdr, msg, false) end function flush_gc_msgs(w::Worker) @@ -241,14 +256,20 @@ function check_worker_state(w::Worker) end end +# Boundary inserted between messages on the wire, used for recovering +# from deserialization errors. Picked arbitrarily. +# A size of 10 bytes indicates ~1e24 possible boundaries, so chance of collision with message contents is trivial. +const MSG_BOUNDARY = UInt8[0x79, 0x8e, 0x8e, 0xf5, 0x6e, 0x9b, 0x2e, 0x97, 0xd5, 0x7d] -function send_msg_(w::Worker, msg, now::Bool) +function send_msg_(w::Worker, header, msg, now::Bool) check_worker_state(w) io = w.w_stream lock(io.lock) try reset_state(w.w_serializer) + serialize(w.w_serializer, header) serialize(w.w_serializer, msg) # io is wrapped in w_serializer + write(io, MSG_BOUNDARY) if !now && w.gcflag flush_gc_msgs(w) @@ -768,7 +789,6 @@ function showerror(io::IO, re::RemoteException) showerror(io, re.captured) end - function run_work_thunk(thunk, print_error) local result try @@ -811,7 +831,7 @@ end function remotecall(f, w::Worker, args...; kwargs...) 
rr = Future(w) #println("$(myid()) asking for $rr") - send_msg(w, CallMsg{:call}(f, args, kwargs, remoteref_id(rr))) + send_msg(w, MsgHeader(response_oid=remoteref_id(rr)), CallMsg{:call}(f, args, kwargs)) rr end @@ -829,7 +849,7 @@ function remotecall_fetch(f, w::Worker, args...; kwargs...) oid = RRID() rv = lookup_ref(oid) rv.waitingfor = w.id - send_msg(w, CallMsg{:call_fetch}(f, args, kwargs, oid)) + send_msg(w, MsgHeader(response_oid=oid), CallMsg{:call_fetch}(f, args, kwargs)) v = take!(rv) delete!(PGRP.refs, oid) isa(v, RemoteException) ? throw(v) : v @@ -846,7 +866,7 @@ function remotecall_wait(f, w::Worker, args...; kwargs...) rv = lookup_ref(prid) rv.waitingfor = w.id rr = Future(w) - send_msg(w, CallWaitMsg(f, args, kwargs, remoteref_id(rr), prid)) + send_msg(w, MsgHeader(response_oid=remoteref_id(rr), notify_oid=prid), CallWaitMsg(f, args, kwargs)) v = fetch(rv.c) delete!(PGRP.refs, prid) isa(v, RemoteException) && throw(v) @@ -866,7 +886,7 @@ function remote_do(f, w::LocalProcess, args...; kwargs...) end function remote_do(f, w::Worker, args...; kwargs...) - send_msg(w, RemoteDoMsg(f, args, kwargs)) + send_msg(w, MsgHeader(), RemoteDoMsg(f, args, kwargs)) nothing end @@ -952,13 +972,13 @@ close(rr::RemoteChannel) = call_on_owner(close_ref, rr) function deliver_result(sock::IO, msg, oid, value) #print("$(myid()) sending result $oid\n") - if is(msg,:call_fetch) || isa(value, RemoteException) + if is(msg, :call_fetch) || isa(value, RemoteException) val = value else val = :OK end try - send_msg_now(sock, ResultMsg(oid, val)) + send_msg_now(sock, MsgHeader(response_oid=oid), ResultMsg(val)) catch e # terminate connection in case of serialization error # otherwise the reading end would hang @@ -996,28 +1016,73 @@ function process_messages(r_stream::IO, w_stream::IO, incoming=true) end function message_handler_loop(r_stream::IO, w_stream::IO, incoming::Bool) + wpid=0 # the worker r_stream is connected to. 
+ boundary = similar(MSG_BOUNDARY) try version = process_hdr(r_stream, incoming) serializer = ClusterSerializer(r_stream) + + # The first message will associate wpid with r_stream + msghdr = deserialize(serializer) + msg = deserialize(serializer) + readbytes!(r_stream, boundary, length(MSG_BOUNDARY)) + + handle_msg(msg, msghdr, r_stream, w_stream, version) + wpid = worker_id_from_socket(r_stream) + + @assert wpid > 0 + while true reset_state(serializer) - msg = deserialize(serializer) - # println("got msg: ", msg) - handle_msg(msg, r_stream, w_stream, version) + msghdr = deserialize(serializer) +# println("msghdr: ", msghdr) + + try + msg = deserialize(serializer) + catch e + # Deserialization error; discard bytes in stream until boundary found + boundary_idx = 1 + while true + # This may throw an EOF error if the terminal boundary was not written + # correctly, triggering the higher-scoped catch block below + byte = read(r_stream, UInt8) + if byte == MSG_BOUNDARY[boundary_idx] + boundary_idx += 1 + if boundary_idx > length(MSG_BOUNDARY) + break + end + else + boundary_idx = 1 + end + end + # println("Deserialization error.") + remote_err = RemoteException(myid(), CapturedException(e, catch_backtrace())) + if !null_id(msghdr.response_oid) + ref = lookup_ref(msghdr.response_oid) + put!(ref, remote_err) + end + if !null_id(msghdr.notify_oid) + deliver_result(w_stream, :call_fetch, msghdr.notify_oid, remote_err) + end + continue + end + readbytes!(r_stream, boundary, length(MSG_BOUNDARY)) + + # println("got msg: ", typeof(msg)) + handle_msg(msg, msghdr, r_stream, w_stream, version) end catch e # println(STDERR, "Process($(myid())) - Exception ", e) - iderr = worker_id_from_socket(r_stream) - if (iderr < 1) + if (wpid < 1) println(STDERR, e) println(STDERR, "Process($(myid())) - Unknown remote, closing connection.") else - werr = worker_from_id(iderr) + werr = worker_from_id(wpid) oldstate = werr.state set_worker_state(werr, W_TERMINATED) - # If error occured talking 
to pid 1, commit harakiri - if iderr == 1 + # If unhandleable error occurred talking to pid 1, exit + if wpid == 1 if isopen(w_stream) print(STDERR, "fatal error on ", myid(), ": ") display_error(e, catch_backtrace()) @@ -1028,15 +1093,15 @@ function message_handler_loop(r_stream::IO, w_stream::IO, incoming::Bool) # Will treat any exception as death of node and cleanup # since currently we do not have a mechanism for workers to reconnect # to each other on unhandled errors - deregister_worker(iderr) + deregister_worker(wpid) end isopen(r_stream) && close(r_stream) isopen(w_stream) && close(w_stream) - if (myid() == 1) && (iderr > 1) + if (myid() == 1) && (wpid > 1) if oldstate != W_TERMINATING - println(STDERR, "Worker $iderr terminated.") + println(STDERR, "Worker $wpid terminated.") rethrow(e) end end @@ -1071,44 +1136,44 @@ function process_hdr(s, validate_cookie) return VersionNumber(strip(String(version))) end -function handle_msg(msg::CallMsg{:call}, r_stream, w_stream, version) - schedule_call(msg.response_oid, ()->msg.f(msg.args...; msg.kwargs...)) +function handle_msg(msg::CallMsg{:call}, msghdr, r_stream, w_stream, version) + schedule_call(msghdr.response_oid, ()->msg.f(msg.args...; msg.kwargs...)) end -function handle_msg(msg::CallMsg{:call_fetch}, r_stream, w_stream, version) +function handle_msg(msg::CallMsg{:call_fetch}, msghdr, r_stream, w_stream, version) @schedule begin v = run_work_thunk(()->msg.f(msg.args...; msg.kwargs...), false) - deliver_result(w_stream, :call_fetch, msg.response_oid, v) + deliver_result(w_stream, :call_fetch, msghdr.response_oid, v) end end -function handle_msg(msg::CallWaitMsg, r_stream, w_stream, version) +function handle_msg(msg::CallWaitMsg, msghdr, r_stream, w_stream, version) @schedule begin - rv = schedule_call(msg.response_oid, ()->msg.f(msg.args...; msg.kwargs...)) - deliver_result(w_stream, :call_wait, msg.notify_oid, fetch(rv.c)) + rv = schedule_call(msghdr.response_oid, ()->msg.f(msg.args...; msg.kwargs...)) + 
deliver_result(w_stream, :call_wait, msghdr.notify_oid, fetch(rv.c)) end end -function handle_msg(msg::RemoteDoMsg, r_stream, w_stream, version) +function handle_msg(msg::RemoteDoMsg, msghdr, r_stream, w_stream, version) @schedule run_work_thunk(()->msg.f(msg.args...; msg.kwargs...), true) end -function handle_msg(msg::ResultMsg, r_stream, w_stream, version) - put!(lookup_ref(msg.response_oid), msg.value) +function handle_msg(msg::ResultMsg, msghdr, r_stream, w_stream, version) + put!(lookup_ref(msghdr.response_oid), msg.value) end -function handle_msg(msg::IdentifySocketMsg, r_stream, w_stream, version) +function handle_msg(msg::IdentifySocketMsg, msghdr, r_stream, w_stream, version) # register a new peer worker connection w=Worker(msg.from_pid, r_stream, w_stream, cluster_manager; version=version) send_connection_hdr(w, false) - send_msg_now(w, IdentifySocketAckMsg()) + send_msg_now(w, MsgHeader(), IdentifySocketAckMsg()) end -function handle_msg(msg::IdentifySocketAckMsg, r_stream, w_stream, version) +function handle_msg(msg::IdentifySocketAckMsg, msghdr, r_stream, w_stream, version) w = map_sock_wrkr[r_stream] w.version = version end -function handle_msg(msg::JoinPGRPMsg, r_stream, w_stream, version) +function handle_msg(msg::JoinPGRPMsg, msghdr, r_stream, w_stream, version) LPROC.id = msg.self_pid controller = Worker(1, r_stream, w_stream, cluster_manager; version=version) register_worker(LPROC) @@ -1129,7 +1194,7 @@ function handle_msg(msg::JoinPGRPMsg, r_stream, w_stream, version) set_default_worker_pool(msg.worker_pool) send_connection_hdr(controller, false) - send_msg_now(controller, JoinCompleteMsg(msg.notify_oid, Sys.CPU_CORES, getpid())) + send_msg_now(controller, MsgHeader(notify_oid=msghdr.notify_oid), JoinCompleteMsg(Sys.CPU_CORES, getpid())) end function connect_to_peer(manager::ClusterManager, rpid::Int, wconfig::WorkerConfig) @@ -1138,7 +1203,7 @@ function connect_to_peer(manager::ClusterManager, rpid::Int, wconfig::WorkerConf w = Worker(rpid, 
r_s, w_s, manager; config=wconfig) process_messages(w.r_stream, w.w_stream, false) send_connection_hdr(w, true) - send_msg_now(w, IdentifySocketMsg(myid())) + send_msg_now(w, MsgHeader(), IdentifySocketMsg(myid())) catch e display_error(e, catch_backtrace()) println(STDERR, "Error [$e] on $(myid()) while connecting to peer $rpid. Exiting.") @@ -1146,7 +1211,7 @@ function connect_to_peer(manager::ClusterManager, rpid::Int, wconfig::WorkerConf end end -function handle_msg(msg::JoinCompleteMsg, r_stream, w_stream, version) +function handle_msg(msg::JoinCompleteMsg, msghdr, r_stream, w_stream, version) w = map_sock_wrkr[r_stream] environ = get(w.config.environ, Dict()) environ[:cpu_cores] = msg.cpu_cores @@ -1154,7 +1219,7 @@ function handle_msg(msg::JoinCompleteMsg, r_stream, w_stream, version) w.config.ospid = msg.ospid w.version = version - ntfy_channel = lookup_ref(msg.notify_oid) + ntfy_channel = lookup_ref(msghdr.notify_oid) put!(ntfy_channel, w.id) push!(default_worker_pool(), w) @@ -1478,7 +1543,7 @@ function create_worker(manager, wconfig) all_locs = map(x -> isa(x, Worker) ? 
(get(x.config.connect_at, ()), x.id) : ((), x.id, true), join_list) send_connection_hdr(w, true) - send_msg_now(w, JoinPGRPMsg(w.id, all_locs, ntfy_oid, PGRP.topology, default_worker_pool())) + send_msg_now(w, MsgHeader(notify_oid=ntfy_oid), JoinPGRPMsg(w.id, all_locs, PGRP.topology, default_worker_pool())) @schedule manage(w.manager, w.id, w.config, :register) wait(rr_ntfy_join) diff --git a/test/parallel_exec.jl b/test/parallel_exec.jl index 978204b210628..4c5eb78c5cd73 100644 --- a/test/parallel_exec.jl +++ b/test/parallel_exec.jl @@ -1025,3 +1025,13 @@ f_myid = ()->myid() @test wrkr1 == remotecall_fetch(f_myid, wrkr1) @test wrkr2 == remotecall_fetch(f_myid, wrkr2) @test wrkr2 == remotecall_fetch((f, p)->remotecall_fetch(f, p), wrkr1, f_myid, wrkr2) + +# Deserialization error recovery test +let + bad_thunk = ()->NonexistantModule.f() + @test_throws RemoteException remotecall_fetch(bad_thunk, 2) + # Test that the stream is still usable + @test remotecall_fetch(()->:test,2) == :test + ref = remotecall(bad_thunk, 2) + @test_throws RemoteException fetch(ref) +end