From a17deeee3a3ed102a40b4ec0a2ec02e6d0fe3616 Mon Sep 17 00:00:00 2001
From: Jeff Bezanson
Date: Tue, 12 Jan 2016 18:22:04 -0500
Subject: [PATCH] fix large performance regression in remotecall on jb/functions

This was mostly due to sending a fresh copy of a closure's type every
time. Now reused using the same mechanism as LambdaStaticData. Also,
toplevel named functions are assumed present on all processors.
---
 base/serialize.jl | 77 +++++++++++++++++++++++++++++++++--------------
 1 file changed, 54 insertions(+), 23 deletions(-)

diff --git a/base/serialize.jl b/base/serialize.jl
index abe24f83a967f..e93ace7efa571 100644
--- a/base/serialize.jl
+++ b/base/serialize.jl
@@ -291,25 +291,26 @@ function serialize(s::SerializationState, m::Module)
     writetag(s.io, EMPTYTUPLE_TAG)
 end
 
-const lambda_numbers = WeakKeyDict()
-lnumber_salt = 0
-function lambda_number(l::LambdaStaticData)
-    global lnumber_salt, lambda_numbers
-    if haskey(lambda_numbers, l)
-        return lambda_numbers[l]
+# TODO: make this bidirectional, so objects can be sent back via the same key
+const object_numbers = WeakKeyDict()
+obj_number_salt = 0
+function object_number(l::ANY)
+    global obj_number_salt, object_numbers
+    if haskey(object_numbers, l)
+        return object_numbers[l]
     end
     # a hash function that always gives the same number to the same
     # object on the same machine, and is unique over all machines.
-    ln = lnumber_salt+(UInt64(myid())<<44)
-    lnumber_salt += 1
-    lambda_numbers[l] = ln
+    ln = obj_number_salt+(UInt64(myid())<<44)
+    obj_number_salt += 1
+    object_numbers[l] = ln
     return ln
 end
 
 function serialize(s::SerializationState, linfo::LambdaStaticData)
     serialize_cycle(s, linfo) && return
     writetag(s.io, LAMBDASTATICDATA_TAG)
-    serialize(s, lambda_number(linfo))
+    serialize(s, object_number(linfo))
     serialize(s, uncompressed_ast(linfo))
     if isdefined(linfo.def, :roots)
         serialize(s, linfo.def.roots::Vector{Any})
@@ -345,6 +346,7 @@ end
 function serialize(s::SerializationState, t::TypeName)
     serialize_cycle(s, t) && return
     writetag(s.io, TYPENAME_TAG)
+    serialize(s, object_number(t))
     serialize(s, t.name)
     serialize(s, t.module)
     serialize(s, t.names)
@@ -364,8 +366,18 @@ end
 
 # decide whether to send all data for a type (instead of just its name)
 function should_send_whole_type(s, t::ANY)
-    # TODO improve somehow
-    return t.name.module === Main
+    tn = t.name
+    if isdefined(tn, :mt)
+        # TODO improve somehow
+        # send whole type for anonymous functions in Main
+        fname = tn.mt.name
+        mod = tn.module
+        toplevel = isdefined(mod, fname) && isdefined(t, :instance) &&
+            getfield(mod, fname) === t.instance
+        ishidden = unsafe_load(unsafe_convert(Ptr{UInt8}, fname))==UInt8('#')
+        return mod === __deserialized_types__ || (mod === Main && (ishidden || !toplevel))
+    end
+    return false
 end
 
 # `type_itself` means we are serializing a type object. when it's false, we are
@@ -451,7 +463,7 @@ function deserialize(s::SerializationState)
     handle_deserialize(s, Int32(read(s.io, UInt8)::UInt8))
 end
 
-function deserialize_cycle(s::SerializationState, x)
+function deserialize_cycle(s::SerializationState, x::ANY)
     if !isimmutable(x) && !typeof(x).pointerfree
         s.table[s.counter] = x
         s.counter += 1
@@ -522,12 +534,12 @@ function deserialize(s::SerializationState, ::Type{Module})
     m
 end
 
-const known_lambda_data = Dict()
+const known_object_data = Dict()
 
 function deserialize(s::SerializationState, ::Type{LambdaStaticData})
     lnumber = deserialize(s)
-    if haskey(known_lambda_data, lnumber)
-        linfo = known_lambda_data[lnumber]::LambdaStaticData
+    if haskey(known_object_data, lnumber)
+        linfo = known_object_data[lnumber]::LambdaStaticData
         makenew = false
     else
         linfo = ccall(:jl_new_lambda_info, Any, (Ptr{Void}, Ptr{Void}, Ptr{Void}, Ptr{Void}), C_NULL, C_NULL, C_NULL, C_NULL)::LambdaStaticData
@@ -555,7 +567,7 @@ function deserialize(s::SerializationState, ::Type{LambdaStaticData})
         linfo.file = file
         linfo.line = line
         linfo.pure = pure
-        known_lambda_data[lnumber] = linfo
+        known_object_data[lnumber] = linfo
     end
     return linfo
 end
@@ -622,13 +634,19 @@ function deserialize(s::SerializationState, ::Type{Union})
 end
 
 function deserialize(s::SerializationState, ::Type{TypeName})
+    number = deserialize(s)
     name = deserialize(s)
     mod = deserialize(s)
-    tn = ccall(:jl_new_typename_in, Any, (Any, Any), name, mod)
+    if haskey(known_object_data, number)
+        tn = known_object_data[number]::TypeName
+        makenew = false
+    else
+        tn = ccall(:jl_new_typename_in, Any, (Any, Any), name, mod)
+        makenew = true
+    end
     deserialize_cycle(s, tn)
 
     names = deserialize(s)
-    tn.names = names
     super = deserialize(s)
     parameters = deserialize(s)
     types = deserialize(s)
@@ -636,12 +654,20 @@ function deserialize(s::SerializationState, ::Type{TypeName})
     abstr = deserialize(s)
     mutable = deserialize(s)
     ninitialized = deserialize(s)
-    tn.primary = ccall(:jl_new_datatype, Any, (Any, Any, Any, Any, Any, Cint, Cint, Cint),
-                       tn, super, parameters, names, types,
-                       abstr, mutable, ninitialized)
+
+    if makenew
+        tn.names = names
+        tn.primary = ccall(:jl_new_datatype, Any, (Any, Any, Any, Any, Any, Cint, Cint, Cint),
+                           tn, super, parameters, names, types,
+                           abstr, mutable, ninitialized)
+        known_object_data[number] = tn
+    end
     tag = Int32(read(s.io, UInt8)::UInt8)
     if tag != UNDEFREF_TAG
-        tn.mt = handle_deserialize(s, tag)
+        mt = handle_deserialize(s, tag)
+        if makenew
+            tn.mt = mt
+        end
     end
     return tn
 
@@ -664,6 +690,11 @@ function deserialize_datatype(s::SerializationState)
             tname.module = __deserialized_types__
             tname.name = newname
             ccall(:jl_set_const, Void, (Any, Any, Any), __deserialized_types__, newname, ty)
+            if !isdefined(ty,:instance)
+                if isempty(ty.parameters) && !ty.abstract && ty.size == 0 && (!ty.mutable || isempty(tname.names))
+                    setfield!(ty, :instance, ccall(:jl_new_struct, Any, (Any,Any...), ty))
+                end
+            end
         end
     else
         name = deserialize(s)::Symbol