Skip to content

Commit

Permalink
Cache binding pointer in GlobalRef
Browse files Browse the repository at this point in the history
On certain workloads, profiling shows a surprisingly high fraction of
inference time spent looking up bindings in modules. Bindings use
a hash table, so they're expected to be quite fast, but certainly
not zero. A big contributor to the problem is that we do basically
treat it as zero, looking up bindings for GlobalRefs multiple times
for each statement (e.g. in `isconst`, `isdefined`, to get the types,
etc). This PR attempts to improve the situation by adding an extra
field to GlobalRef that caches this lookup. This field is not serialized
and if not set, we fall back to the previous behavior. I would expect
the memory overhead to be relatively small, since we do intern GlobalRefs
in memory to only have one per binding (rather than one per use).

 # Benchmarks

The benchmarks look quite promising. Consider this artificial example
(though it's actually not all that far-fetched, given some of the
generated code we see):

```
using Core.Intrinsics: add_int
const ONE = 1
@eval function f(x, y)
	z = 0
	$([:(z = add_int(x, add_int(z, ONE))) for _ = 1:10000]...)
	return add_int(z, y)
end
g(y) = f(ONE, y)
```

On master:
```
julia> @time @code_typed g(1)
  1.427227 seconds (1.31 M allocations: 58.809 MiB, 5.73% gc time, 99.96% compilation time)
CodeInfo(
1 ─ %1 = invoke Main.f(Main.ONE::Int64, y::Int64)::Int64
└──      return %1
) => Int64
```

On this PR:
```
julia> @time @code_typed g(1)
  0.503151 seconds (1.19 M allocations: 53.641 MiB, 5.10% gc time, 33.59% compilation time)
CodeInfo(
1 ─ %1 = invoke Main.f(Main.ONE::Int64, y::Int64)::Int64
└──      return %1
) => Int64
```

I don't expect the same speedup on other workloads, but there should be
a few % speedup on most workloads still.

 # Future extensions

The other motivation for this is that with a view towards #40399,
we will want to more clearly define when bindings get resolved. The
idea here would then be that binding resolution replaces generic
`GlobalRef`s by `GlobalRef`s with a set binding cache, and any
unresolved bindings would be treated conservatively by inference
and optimization.
  • Loading branch information
Keno committed Sep 13, 2022
1 parent 545da22 commit 20688f6
Show file tree
Hide file tree
Showing 11 changed files with 74 additions and 23 deletions.
4 changes: 3 additions & 1 deletion base/boot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,7 @@ eval(Core, quote
end
LineInfoNode(mod::Module, @nospecialize(method), file::Symbol, line::Int32, inlined_at::Int32) =
$(Expr(:new, :LineInfoNode, :mod, :method, :file, :line, :inlined_at))
GlobalRef(m::Module, s::Symbol) = $(Expr(:new, :GlobalRef, :m, :s))
GlobalRef(m::Module, s::Symbol, binding::Ptr{Nothing}) = $(Expr(:new, :GlobalRef, :m, :s, :binding))
SlotNumber(n::Int) = $(Expr(:new, :SlotNumber, :n))
TypedSlot(n::Int, @nospecialize(t)) = $(Expr(:new, :TypedSlot, :n, :t))
PhiNode(edges::Array{Int32, 1}, values::Array{Any, 1}) = $(Expr(:new, :PhiNode, :edges, :values))
Expand Down Expand Up @@ -812,6 +812,8 @@ Unsigned(x::Union{Float16, Float32, Float64, Bool}) = UInt(x)
Integer(x::Integer) = x
Integer(x::Union{Float16, Float32, Float64}) = Int(x)

# Two-argument convenience constructor: forwards to the three-argument form with a
# null pointer in the `binding` slot.
GlobalRef(m::Module, s::Symbol) = GlobalRef(m, s, bitcast(Ptr{Nothing}, 0))

# Binding for the julia parser, called as
#
# Core._parse(text, filename, lineno, offset, options)
Expand Down
23 changes: 15 additions & 8 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1982,7 +1982,7 @@ function abstract_eval_special_value(interp::AbstractInterpreter, @nospecialize(
return sv.argtypes[e.n]
end
elseif isa(e, GlobalRef)
return abstract_eval_global(interp, e.mod, e.name, sv)
return abstract_eval_globalref(interp, e, sv)
end

return Const(e)
Expand Down Expand Up @@ -2256,17 +2256,24 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e),
return rt
end

function abstract_eval_global(M::Module, s::Symbol)
if isdefined(M, s) && isconst(M, s)
return Const(getglobal(M, s))
# Report whether the global named by `g` currently has a value. When `g` carries a
# non-null cached binding pointer it is queried directly via `jl_binding_boundp`;
# otherwise the name is looked up in the module with `isdefined`.
function isdefined_globalref(g::GlobalRef)
    if g.binding == C_NULL
        return isdefined(g.mod, g.name)
    end
    return ccall(:jl_binding_boundp, Cint, (Ptr{Cvoid},), g.binding) != 0
end

function abstract_eval_globalref(g::GlobalRef)
if isdefined_globalref(g) && isconst(g)
g.binding != C_NULL && return Const(ccall(:jl_binding_value, Any, (Ptr{Cvoid},), g.binding))
return Const(getglobal(g.mod, g.name))
end
ty = ccall(:jl_binding_type, Any, (Any, Any), M, s)
ty = ccall(:jl_binding_type, Any, (Any, Any), g.mod, g.name)
ty === nothing && return Any
return ty
end
# Module/symbol entry point: wraps the pair in a `GlobalRef` and delegates to
# `abstract_eval_globalref`.
abstract_eval_global(M::Module, s::Symbol) = abstract_eval_globalref(GlobalRef(M, s))

function abstract_eval_global(interp::AbstractInterpreter, M::Module, s::Symbol, frame::Union{InferenceState, IRCode})
rt = abstract_eval_global(M, s)
function abstract_eval_globalref(interp::AbstractInterpreter, g::GlobalRef, frame::Union{InferenceState, IRCode})
rt = abstract_eval_globalref(g)
consistent = inaccessiblememonly = ALWAYS_FALSE
nothrow = false
if isa(rt, Const)
Expand All @@ -2277,7 +2284,7 @@ function abstract_eval_global(interp::AbstractInterpreter, M::Module, s::Symbol,
else
nothrow = true
end
elseif isdefined(M,s)
elseif isdefined_globalref(g)
nothrow = true
end
merge_effects!(interp, frame, Effects(EFFECTS_TOTAL; consistent, nothrow, inaccessiblememonly))
Expand Down
2 changes: 1 addition & 1 deletion base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ function argextype(
elseif isa(x, QuoteNode)
return Const(x.value)
elseif isa(x, GlobalRef)
return abstract_eval_global(x.mod, x.name)
return abstract_eval_globalref(x)
elseif isa(x, PhiNode)
return Any
elseif isa(x, PiNode)
Expand Down
2 changes: 1 addition & 1 deletion base/compiler/ssair/slot2ssa.jl
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ function typ_for_val(@nospecialize(x), ci::CodeInfo, sptypes::Vector{Any}, idx::
end
return (ci.ssavaluetypes::Vector{Any})[idx]
end
isa(x, GlobalRef) && return abstract_eval_global(x.mod, x.name)
isa(x, GlobalRef) && return abstract_eval_globalref(x)
isa(x, SSAValue) && return (ci.ssavaluetypes::Vector{Any})[x.id]
isa(x, Argument) && return slottypes[x.n]
isa(x, NewSSAValue) && return DelayedTyp(x)
Expand Down
2 changes: 1 addition & 1 deletion base/compiler/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,7 @@ function is_throw_call(e::Expr)
if e.head === :call
f = e.args[1]
if isa(f, GlobalRef)
ff = abstract_eval_global(f.mod, f.name)
ff = abstract_eval_globalref(f)
if isa(ff, Const) && ff.val === Core.throw
return true
end
Expand Down
5 changes: 5 additions & 0 deletions base/reflection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,11 @@ Determine whether a global is declared `const` in a given module `m`.
isconst(m::Module, s::Symbol) =
ccall(:jl_is_const, Cint, (Any, Any), m, s) != 0

"""
    isconst(g::GlobalRef)

Determine whether the global referenced by `g` is declared `const`. The binding
pointer stored in `g` is consulted when it is non-null; otherwise this falls back
to `isconst(g.mod, g.name)`.
"""
function isconst(g::GlobalRef)
    if g.binding == C_NULL
        return isconst(g.mod, g.name)
    end
    return ccall(:jl_binding_is_const, Cint, (Ptr{Cvoid},), g.binding) != 0
end

"""
isconst(t::DataType, s::Union{Int,Symbol}) -> Bool
Expand Down
3 changes: 3 additions & 0 deletions src/dump.c
Original file line number Diff line number Diff line change
Expand Up @@ -2084,6 +2084,9 @@ static void jl_deserialize_struct(jl_serializer_state *s, jl_value_t *v) JL_GC_D
entry->min_world = 1;
entry->max_world = 0;
}
} else if (dt == jl_globalref_type) {
jl_globalref_t *r = (jl_globalref_t*)v;
r->bnd_cache = jl_get_binding(r->mod, r->name);
}
}

Expand Down
12 changes: 6 additions & 6 deletions src/jltypes.c
Original file line number Diff line number Diff line change
Expand Up @@ -2392,12 +2392,6 @@ void jl_init_types(void) JL_GC_DISABLED
jl_svec(1, jl_slotnumber_type),
jl_emptysvec, 0, 0, 1);

jl_globalref_type =
jl_new_datatype(jl_symbol("GlobalRef"), core, jl_any_type, jl_emptysvec,
jl_perm_symsvec(2, "mod", "name"),
jl_svec(2, jl_module_type, jl_symbol_type),
jl_emptysvec, 0, 0, 2);

jl_code_info_type =
jl_new_datatype(jl_symbol("CodeInfo"), core,
jl_any_type, jl_emptysvec,
Expand Down Expand Up @@ -2694,6 +2688,12 @@ void jl_init_types(void) JL_GC_DISABLED

jl_value_t *pointer_void = jl_apply_type1((jl_value_t*)jl_pointer_type, (jl_value_t*)jl_nothing_type);

jl_globalref_type =
jl_new_datatype(jl_symbol("GlobalRef"), core, jl_any_type, jl_emptysvec,
jl_perm_symsvec(3, "mod", "name", "binding"),
jl_svec(3, jl_module_type, jl_symbol_type, pointer_void),
jl_emptysvec, 0, 0, 3);

tv = jl_svec2(tvar("A"), tvar("R"));
jl_opaque_closure_type = (jl_unionall_t*)jl_new_datatype(jl_symbol("OpaqueClosure"), core, jl_function_type, tv,
jl_perm_symsvec(5, "captures", "world", "source", "invoke", "specptr"),
Expand Down
9 changes: 9 additions & 0 deletions src/julia.h
Original file line number Diff line number Diff line change
Expand Up @@ -603,6 +603,13 @@ typedef struct _jl_module_t {
jl_mutex_t lock;
} jl_module_t;

// C-side view of a `GlobalRef` object: the referenced module and symbol, followed
// by a cached binding pointer.
typedef struct {
jl_module_t *mod; // module the reference points into
jl_sym_t *name; // symbol looked up in `mod`
// Not serialized. Caches the value of jl_get_binding(mod, name).
jl_binding_t *bnd_cache;
} jl_globalref_t;

// one Type-to-Value entry
typedef struct _jl_typemap_entry_t {
JL_DATA_TYPE
Expand Down Expand Up @@ -1626,6 +1633,8 @@ JL_DLLEXPORT int jl_boundp(jl_module_t *m, jl_sym_t *var);
JL_DLLEXPORT int jl_defines_or_exports_p(jl_module_t *m, jl_sym_t *var);
JL_DLLEXPORT int jl_binding_resolved_p(jl_module_t *m, jl_sym_t *var);
JL_DLLEXPORT int jl_is_const(jl_module_t *m, jl_sym_t *var);
JL_DLLEXPORT int jl_binding_is_const(jl_binding_t *b);
JL_DLLEXPORT int jl_binding_boundp(jl_binding_t *b);
JL_DLLEXPORT jl_value_t *jl_get_global(jl_module_t *m JL_PROPAGATES_ROOT, jl_sym_t *var);
JL_DLLEXPORT void jl_set_global(jl_module_t *m JL_ROOTING_ARGUMENT, jl_sym_t *var, jl_value_t *val JL_ROOTED_ARGUMENT);
JL_DLLEXPORT void jl_set_const(jl_module_t *m JL_ROOTING_ARGUMENT, jl_sym_t *var, jl_value_t *val JL_ROOTED_ARGUMENT);
Expand Down
31 changes: 26 additions & 5 deletions src/module.c
Original file line number Diff line number Diff line change
Expand Up @@ -413,14 +413,17 @@ JL_DLLEXPORT jl_binding_t *jl_get_binding_or_error(jl_module_t *m, jl_sym_t *var
JL_DLLEXPORT jl_value_t *jl_module_globalref(jl_module_t *m, jl_sym_t *var)
{
JL_LOCK(&m->lock);
jl_binding_t *b = (jl_binding_t*)ptrhash_get(&m->bindings, var);
jl_binding_t *b = _jl_get_module_binding(m, var);
if (b == HT_NOTFOUND) {
JL_UNLOCK(&m->lock);
return jl_new_struct(jl_globalref_type, m, var);
return jl_new_struct(jl_globalref_type, m, var, jl_box_voidpointer(NULL));
}
jl_value_t *globalref = jl_atomic_load_relaxed(&b->globalref);
if (globalref == NULL) {
jl_value_t *newref = jl_new_struct(jl_globalref_type, m, var);
jl_value_t *newref = jl_new_struct(jl_globalref_type, m, var, jl_box_voidpointer(NULL));
if (b->owner) {
((jl_globalref_t*)newref)->bnd_cache = b->owner == m ? b : _jl_get_module_binding(b->owner, var);
}
if (jl_atomic_cmpswap_relaxed(&b->globalref, &globalref, newref)) {
JL_GC_PROMISE_ROOTED(newref);
globalref = newref;
Expand Down Expand Up @@ -662,12 +665,18 @@ JL_DLLEXPORT jl_binding_t *jl_get_module_binding(jl_module_t *m JL_PROPAGATES_RO
return b == HT_NOTFOUND ? NULL : b;
}


// Return the value stored in binding `b`, read directly from `b->value`.
JL_DLLEXPORT jl_value_t *jl_binding_value(jl_binding_t *b JL_PROPAGATES_ROOT)
{
return b->value;
}

JL_DLLEXPORT jl_value_t *jl_get_global(jl_module_t *m, jl_sym_t *var)
{
jl_binding_t *b = jl_get_binding(m, var);
if (b == NULL) return NULL;
if (b->deprecated) jl_binding_deprecation_warning(m, b);
return b->value;
return jl_binding_value(b);
}

JL_DLLEXPORT void jl_set_global(jl_module_t *m JL_ROOTING_ARGUMENT, jl_sym_t *var, jl_value_t *val JL_ROOTED_ARGUMENT)
Expand Down Expand Up @@ -696,10 +705,22 @@ JL_DLLEXPORT void jl_set_const(jl_module_t *m JL_ROOTING_ARGUMENT, jl_sym_t *var
jl_symbol_name(bp->name));
}

// Return the `constp` flag of binding `b`. `b` must not be NULL (checked by assert).
JL_DLLEXPORT int jl_binding_is_const(jl_binding_t *b)
{
    assert(b != NULL);
    return b->constp;
}

// Return 1 when binding `b` holds a value (its `value` slot is non-null), else 0.
// `b` must not be NULL (checked by assert).
JL_DLLEXPORT int jl_binding_boundp(jl_binding_t *b)
{
    assert(b != NULL);
    return b->value != NULL;
}

JL_DLLEXPORT int jl_is_const(jl_module_t *m, jl_sym_t *var)
{
jl_binding_t *b = jl_get_binding(m, var);
return b && b->constp;
return b && jl_binding_is_const(b);
}

// set the deprecated flag for a binding:
Expand Down
4 changes: 4 additions & 0 deletions src/staticdata.c
Original file line number Diff line number Diff line change
Expand Up @@ -576,6 +576,10 @@ static void jl_serialize_value__(jl_serializer_state *s, jl_value_t *v, int recu
jl_serialize_value(s, tn->partial);
}
else if (t->layout->nfields > 0) {
if (jl_typeis(v, jl_globalref_type)) {
// Don't save the cached binding reference in staticdata
((jl_globalref_t*)v)->bnd_cache = NULL;
}
char *data = (char*)jl_data_ptr(v);
size_t i, np = t->layout->npointers;
for (i = 0; i < np; i++) {
Expand Down

0 comments on commit 20688f6

Please sign in to comment.