Skip to content

Commit

Permalink
initial attempt of auto serialization of globals in a cluster
Browse files Browse the repository at this point in the history
  • Loading branch information
amitmurthy committed Dec 20, 2016
1 parent 03cc3a8 commit 6b6f423
Show file tree
Hide file tree
Showing 3 changed files with 225 additions and 36 deletions.
88 changes: 85 additions & 3 deletions base/clusterserialize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,21 @@

import .Serializer: known_object_data, object_number, serialize_cycle, deserialize_cycle, writetag,
__deserialized_types__, serialize_typename, deserialize_typename,
TYPENAME_TAG, object_numbers
TYPENAME_TAG, GLOBALREF_TAG, object_numbers,
serialize_global_from_main, deserialize_global_from_main

type ClusterSerializer{I<:IO} <: AbstractSerializer
io::I
counter::Int
table::ObjectIdDict

pid::Int # Worker we are connected to.
sent_objects::Set{UInt64} # used by serialize (track objects sent)
sent_globals::Dict

ClusterSerializer(io::I) = new(io, 0, ObjectIdDict(), Set{UInt64}())
ClusterSerializer(io::I) = new(io, 0, ObjectIdDict(),
Base.worker_id_from_socket(io),
Set{UInt64}(), Dict())
end
ClusterSerializer(io::IO) = ClusterSerializer{typeof(io)}(io)

Expand Down Expand Up @@ -43,6 +48,83 @@ function serialize(s::ClusterSerializer, t::TypeName)
serialize_typename(s, t)
push!(s.sent_objects, identifier)
end
# println(t.module, ":", t.name, ", id:", identifier, send_whole ? " sent" : " NOT sent")
# #println(t.module, ":", t.name, ", id:", identifier, send_whole ? " sent" : " NOT sent")
nothing
end

const FLG_SER_VAL = UInt8(1)
const FLG_ISCONST_VAL = UInt8(2)
# True when every bit set in `flg` is also set in `v`.
isflagged(v, flg) = ((v & flg) == flg)

# We will send/resend a global object if
# a) it has not been sent previously, i.e., we are seeing this object_id
#    for the first time, or,
# b) its hash value has changed

# Serialize the Main binding referenced by `g` over `s`.
#
# Written to the stream, in order: the binding name, one flags byte, and - only
# when FLG_SER_VAL is set in that byte - the current value of the binding.
# isbits values are always sent and never tracked. Other values are tracked in
# `s.sent_globals` (object_id => hash) and are sent only on first sight or when
# their hash differs from the recorded one.
function serialize_global_from_main(s::ClusterSerializer, g::GlobalRef)
    v = getfield(Main, g.name)

    serialize(s, g.name)

    flags = UInt8(0)
    if isbits(v)
        flags |= FLG_SER_VAL
    else
        oid = object_id(v)
        if haskey(s.sent_globals, oid)
            # We have sent this object before, see if it has changed.
            prev_hash = s.sent_globals[oid]
            new_hash = hash(v)
            if new_hash != prev_hash
                flags |= FLG_SER_VAL
                s.sent_globals[oid] = new_hash

                # No need to setup a new finalizer as only the hash
                # value and not the object itself has changed.
            end
        else
            flags |= FLG_SER_VAL
            try
                # Drop the tracking entry once the object is garbage collected.
                finalizer(v, x->delete_global_tracker(s,x))
                s.sent_globals[oid] = hash(v)
            catch
                # Deliberate best effort: objects that cannot be finalized are
                # not tracked, so they are simply resent on every call.
            end
        end
    end
    isconst(Main, g.name) && (flags |= FLG_ISCONST_VAL)

    write(s.io, flags)
    isflagged(flags, FLG_SER_VAL) && serialize(s, v)
end

# Read what `serialize_global_from_main(::ClusterSerializer, ...)` wrote: the
# binding name, the flags byte and, when FLG_SER_VAL is set, the value.
# When a value was sent, the binding is created/updated under Main (as a `const`
# if FLG_ISCONST_VAL is set). Always returns `GlobalRef(Main, sym)`.
function deserialize_global_from_main(s::ClusterSerializer)
    sym = deserialize(s)::Symbol
    flags = read(s.io, UInt8)

    # create/update binding under Main only if the value has been sent
    if isflagged(flags, FLG_SER_VAL)
        # Quote the value before splicing it into the assignment expression.
        # Without this, a value that is itself code (a Symbol or an Expr) would
        # be evaluated by `eval` instead of being bound as data.
        v = QuoteNode(deserialize(s))
        if isflagged(flags, FLG_ISCONST_VAL)
            eval(Main, :(const $sym = $v))
        else
            eval(Main, :($sym = $v))
        end
    end

    return GlobalRef(Main, sym)
end

# Forget the hash recorded for `v` in `s.sent_globals`, if any.
# Registered as the finalizer callback for tracked globals.
function delete_global_tracker(s::ClusterSerializer, v)
    key = object_id(v)
    tracked = s.sent_globals

    # TODO: Should release memory from the remote nodes.
    haskey(tracked, key) ? delete!(tracked, key) : nothing
end

28 changes: 25 additions & 3 deletions base/serialize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -381,21 +381,36 @@ end

function serialize(s::AbstractSerializer, g::GlobalRef)
writetag(s.io, GLOBALREF_TAG)
if g.mod === Main && isdefined(g.mod, g.name) && isconst(g.mod, g.name)
if g.mod === Main && isdefined(g.mod, g.name)
v = getfield(g.mod, g.name)
if isa(v, DataType) && v === v.name.primary && should_send_whole_type(s, v)
# handle references to types in Main by sending the whole type.
# needed to be able to send nested functions (#15451).
write(s.io, UInt8(1))
serialize(s, v)
return
elseif g.name in names(Main, false, false)
# FIXME :
# 1. There must be a better way to detect if a binding has been imported
# into Main or has been primarily defined here.
# 2. Handle bindings in Main pointing to bindings in Base, e.g., my_foo=myid.
write(s.io, UInt8(2))
serialize_global_from_main(s, g)
return
end
end

write(s.io, UInt8(0))
serialize_global_ref(s, g)
end

# Write a GlobalRef as its module followed by its name; the bound value is not sent.
serialize_global_ref(s::AbstractSerializer, g::GlobalRef) =
    (serialize(s, g.mod); serialize(s, g.name))

# default impl only serializes the module and symbol, not the value.
serialize_global_from_main(s::AbstractSerializer, g::GlobalRef) = serialize_global_ref(s, g)

function serialize(s::AbstractSerializer, t::TypeName)
serialize_cycle(s, t) && return
Expand Down Expand Up @@ -730,13 +745,20 @@ end
function deserialize(s::AbstractSerializer, ::Type{GlobalRef})
kind = read(s.io, UInt8)
if kind == 0
return GlobalRef(deserialize(s)::Module, deserialize(s)::Symbol)
else
return deserialize_global_ref(s)
elseif kind == 1
ty = deserialize(s)
return GlobalRef(ty.name.module, ty.name.name)
else # kind == 2
return deserialize_global_from_main(s)
end
end

# Rebuild a GlobalRef from a module followed by a symbol, read in that order.
function deserialize_global_ref(s::AbstractSerializer)
    mod = deserialize(s)::Module
    name = deserialize(s)::Symbol
    return GlobalRef(mod, name)
end

# default impl is same as any global ref, i.e., only the module and symbol.
# The serializer must be forwarded: `deserialize_global_ref` has no
# zero-argument method, so calling it without `s` raises a MethodError.
deserialize_global_from_main(s::AbstractSerializer) = deserialize_global_ref(s)

function deserialize(s::AbstractSerializer, ::Type{Union})
types = deserialize(s)
Union{types...}
Expand Down
145 changes: 115 additions & 30 deletions test/parallel_exec.jl
Original file line number Diff line number Diff line change
Expand Up @@ -470,27 +470,25 @@ for T in [Void, ShmemFoo]
end

# Issue #14664
let
local d = SharedArray(Int,10)
@sync @parallel for i=1:10
d[i] = i
end
d = SharedArray(Int,10)
@sync @parallel for i=1:10
d[i] = i
end

for (x,i) in enumerate(d)
@test x == i
end
for (x,i) in enumerate(d)
@test x == i
end

# complex
local sd = SharedArray(Int,10)
local se = SharedArray(Int,10)
@sync @parallel for i=1:10
sd[i] = i
se[i] = i
end
sc = complex(sd,se)
for (x,i) in enumerate(sc)
@test i == complex(x,x)
end
# complex
sd = SharedArray(Int,10)
se = SharedArray(Int,10)
@sync @parallel for i=1:10
sd[i] = i
se[i] = i
end
sc = complex(sd,se)
for (x,i) in enumerate(sc)
@test i == complex(x,x)
end

# Once finalized accessing remote references and shared arrays should result in exceptions.
Expand Down Expand Up @@ -1128,23 +1126,24 @@ remotecall_fetch(()->eval(:(f16091a() = 2)), wid)
f16091b = () -> 1
remotecall_fetch(()->eval(:(f16091b = () -> 2)), wid)
@test remotecall_fetch(f16091b, 2) === 1
@test remotecall_fetch((myid)->remotecall_fetch(f16091b, myid), wid, myid()) === 2

# FIXME: What is being tested here? Why the difference between named and anonymous functions?
@test remotecall_fetch((myid)->remotecall_fetch(f16091b, myid), wid, myid()) === 1



# issue #16451
let
local rng=RandomDevice()
retval = @parallel (+) for _ in 1:10
rand(rng)
end
@test retval > 0.0 && retval < 10.0
rng=RandomDevice()
retval = @parallel (+) for _ in 1:10
rand(rng)
end
@test retval > 0.0 && retval < 10.0

rand(rng)
retval = @parallel (+) for _ in 1:10
rand(rng)
retval = @parallel (+) for _ in 1:10
rand(rng)
end
@test retval > 0.0 && retval < 10.0
end
@test retval > 0.0 && retval < 10.0

# serialization tests
wrkr1 = workers()[1]
Expand Down Expand Up @@ -1261,3 +1260,89 @@ function test_add_procs_threaded_blas()
rmprocs(processes_added)
end
test_add_procs_threaded_blas()

# Auto serialization of globals from Main.
# Each closure below refers to a global by name and is run on `id_other`
# (defined earlier in the file; presumably the id of another process - confirm).
# bitstypes
global v1 = 1
@test remotecall_fetch(()->v1, id_other) == v1
# the binding must now exist under Main on the remote side as well
@test remotecall_fetch(()->isdefined(Main, :v1), id_other) == true
v1 = 2
# after rebinding locally, the next remote call must see the new value
@test remotecall_fetch(()->v1, id_other) == 2

# non-bitstypes
global v2 = ones(10)
for i in 1:5
# mutate the same array in place, then check the remote side sees the update
v2[i] = i
@test remotecall_fetch(()->v2, id_other) == v2
end

# nested anon functions
# f2 refers to the global f1, so running f2 remotely needs f1 there too
global f1 = x->x
global f2 = x->f1(x)
v = rand()
@test remotecall_fetch(f2, id_other, v) == v
@test remotecall_fetch(x->f2(x), id_other, v) == v

# consts
const c1 = ones(10)
@test remotecall_fetch(()->c1, id_other) == c1
# constness of the binding must be preserved on the remote side
@test remotecall_fetch(()->isconst(Main, :c1), id_other) == true

# Test same call with local vars
# Runs the bitstype, non-bitstype and nested anonymous function checks using
# function-local variables captured by the closures instead of globals in Main.
function wrapped_var_ser_tests()
# bitstypes
local lv1 = 1
@test remotecall_fetch(()->lv1, id_other) == lv1
# a captured local must not show up as a binding in Main on the remote side
@test remotecall_fetch(()->isdefined(Main, :lv1), id_other) == false
lv1 = 2
@test remotecall_fetch(()->lv1, id_other) == 2

# non-bitstypes
local lv2 = ones(10)
for i in 1:5
# mutate in place, then check the remote call sees the updated contents
lv2[i] = i
@test remotecall_fetch(()->lv2, id_other) == lv2
end

# nested anon functions
# lf2 captures lf1, so lf1 has to travel with it
local lf1 = x->x
local lf2 = x->lf1(x)
v = rand()
@test remotecall_fetch(lf2, id_other, v) == v
@test remotecall_fetch(x->lf2(x), id_other, v) == v
end

wrapped_var_ser_tests()

# reported github issues - Mostly tests with globals and various parallel macros
#2669, #5390
v2669=10
# The spawned expression is 1 + v2669, i.e. 11 (the previous expectation of 10
# could never hold).
@test fetch(@spawn (1+v2669)) == 11

#12367
refs = []
if true
    n = 10
    for p in procs()
        # `refs` starts out empty, so each future is appended; the previous code
        # indexed into an undefined variable `ref` and never filled `refs`.
        push!(refs, @spawnat p begin
            @sync for i in 1:n
                nothing
            end
        end)
    end
end
foreach(wait, refs)

#14399
# the closure passed to pmap refers to the global SharedArray `s`
s = convert(SharedArray, [1,2,3,4]);
@test pmap(i->length(s), 1:2) == [4,4]

#6760
# `a` is assigned inside a top-level `if` block and used in the @parallel body
if true
a = 2
x = @parallel (vcat) for k=1:2
sin(a)
end
end
# both iterations compute sin(2)
@test x == map(_->sin(2), 1:2)

0 comments on commit 6b6f423

Please sign in to comment.