Skip to content

Commit

Permalink
bpart: Start tracking backedges for bindings (JuliaLang#57213)
Browse files Browse the repository at this point in the history
This PR adds limited backedge support for Bindings. There are two
classes of bindings that get backedges:

1. Cross-module `GlobalRef` bindings (new in this PR)
2. Any globals accesses through intrinsics (i.e. those with forward
edges from JuliaLang#57009)

This is a time/space trade-off for invalidation. As a result of the
first category, invalidating a binding now only needs to scan all the
methods defined in the same module as the binding. At the same time, it
is anticipated that most binding references are to bindings in the same
module, keeping the list of bindings that need explicit (back)edges
small.
  • Loading branch information
Keno authored Feb 2, 2025
1 parent dbc6681 commit e485be8
Show file tree
Hide file tree
Showing 23 changed files with 311 additions and 112 deletions.
2 changes: 1 addition & 1 deletion Compiler/src/Compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ using Base: @_foldable_meta, @_gc_preserve_begin, @_gc_preserve_end, @nospeciali
partition_restriction, quoted, rename_unionall, rewrap_unionall, specialize_method,
structdiff, tls_world_age, unconstrain_vararg_length, unionlen, uniontype_layout,
uniontypes, unsafe_convert, unwrap_unionall, unwrapva, vect, widen_diagonal,
_uncompressed_ir
_uncompressed_ir, maybe_add_binding_backedge!
using Base.Order

import Base: ==, _topmod, append!, convert, copy, copy!, findall, first, get, get!,
Expand Down
24 changes: 16 additions & 8 deletions Compiler/src/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(fun
end
if const_edge !== nothing
edge = const_edge
update_valid_age!(sv, world_range(const_edge))
end
end

Expand Down Expand Up @@ -2330,6 +2331,7 @@ function abstract_invoke(interp::AbstractInterpreter, arginfo::ArgInfo, si::Stmt
end
if const_edge !== nothing
edge = const_edge
update_valid_age!(sv, world_range(const_edge))
end
end
rt = from_interprocedural!(interp, rt, sv, arginfo′, sig)
Expand Down Expand Up @@ -2396,8 +2398,9 @@ function abstract_eval_getglobal(interp::AbstractInterpreter, sv::AbsIntState, s
if M isa Const && s isa Const
M, s = M.val, s.val
if M isa Module && s isa Symbol
(ret, bpart) = abstract_eval_globalref(interp, GlobalRef(M, s), saw_latestworld, sv)
return CallMeta(ret, bpart === nothing ? NoCallInfo() : GlobalAccessInfo(bpart))
gr = GlobalRef(M, s)
(ret, bpart) = abstract_eval_globalref(interp, gr, saw_latestworld, sv)
return CallMeta(ret, bpart === nothing ? NoCallInfo() : GlobalAccessInfo(convert(Core.Binding, gr), bpart))
end
return CallMeta(Union{}, TypeError, EFFECTS_THROWS, NoCallInfo())
elseif !hasintersect(widenconst(M), Module) || !hasintersect(widenconst(s), Symbol)
Expand Down Expand Up @@ -2475,8 +2478,9 @@ function abstract_eval_setglobal!(interp::AbstractInterpreter, sv::AbsIntState,
if isa(M, Const) && isa(s, Const)
M, s = M.val, s.val
if M isa Module && s isa Symbol
(rt, exct), partition = global_assignment_rt_exct(interp, sv, saw_latestworld, GlobalRef(M, s), v)
return CallMeta(rt, exct, Effects(setglobal!_effects, nothrow=exct===Bottom), GlobalAccessInfo(partition))
gr = GlobalRef(M, s)
(rt, exct), partition = global_assignment_rt_exct(interp, sv, saw_latestworld, gr, v)
return CallMeta(rt, exct, Effects(setglobal!_effects, nothrow=exct===Bottom), GlobalAccessInfo(convert(Core.Binding, gr), partition))
end
return CallMeta(Union{}, TypeError, EFFECTS_THROWS, NoCallInfo())
end
Expand Down Expand Up @@ -2564,14 +2568,15 @@ function abstract_eval_replaceglobal!(interp::AbstractInterpreter, sv::AbsIntSta
M, s = M.val, s.val
M isa Module || return CallMeta(Union{}, TypeError, EFFECTS_THROWS, NoCallInfo())
s isa Symbol || return CallMeta(Union{}, TypeError, EFFECTS_THROWS, NoCallInfo())
partition = abstract_eval_binding_partition!(interp, GlobalRef(M, s), sv)
gr = GlobalRef(M, s)
partition = abstract_eval_binding_partition!(interp, gr, sv)
rte = abstract_eval_partition_load(interp, partition)
if binding_kind(partition) == BINDING_KIND_GLOBAL
T = partition_restriction(partition)
end
exct = Union{rte.exct, global_assignment_binding_rt_exct(interp, partition, v)[2]}
effects = merge_effects(rte.effects, Effects(setglobal!_effects, nothrow=exct===Bottom))
sg = CallMeta(Any, exct, effects, GlobalAccessInfo(partition))
sg = CallMeta(Any, exct, effects, GlobalAccessInfo(convert(Core.Binding, gr), partition))
else
sg = abstract_eval_setglobal!(interp, sv, saw_latestworld, M, s, v)
end
Expand Down Expand Up @@ -2791,6 +2796,7 @@ function abstract_call_opaque_closure(interp::AbstractInterpreter,
end
if const_edge !== nothing
edge = const_edge
update_valid_age!(sv, world_range(const_edge))
end
end
end
Expand Down Expand Up @@ -3225,7 +3231,8 @@ function abstract_eval_isdefinedglobal(interp::AbstractInterpreter, mod::Module,
end

effects = EFFECTS_TOTAL
partition = lookup_binding_partition!(interp, GlobalRef(mod, sym), sv)
gr = GlobalRef(mod, sym)
partition = lookup_binding_partition!(interp, gr, sv)
if allow_import !== true && is_some_imported(binding_kind(partition))
if allow_import === false
rt = Const(false)
Expand All @@ -3243,7 +3250,7 @@ function abstract_eval_isdefinedglobal(interp::AbstractInterpreter, mod::Module,
effects = Effects(generic_isdefinedglobal_effects, nothrow=true)
end
end
return CallMeta(RTEffects(rt, Union{}, effects), GlobalAccessInfo(partition))
return CallMeta(RTEffects(rt, Union{}, effects), GlobalAccessInfo(convert(Core.Binding, gr), partition))
end

function abstract_eval_isdefinedglobal(interp::AbstractInterpreter, @nospecialize(M), @nospecialize(s), @nospecialize(allow_import_arg), @nospecialize(order_arg), saw_latestworld::Bool, sv::AbsIntState)
Expand Down Expand Up @@ -3454,6 +3461,7 @@ end

world_range(ir::IRCode) = ir.valid_worlds
world_range(ci::CodeInfo) = WorldRange(ci.min_world, ci.max_world)
world_range(ci::CodeInstance) = WorldRange(ci.min_world, ci.max_world)
world_range(compact::IncrementalCompact) = world_range(compact.ir)

function force_binding_resolution!(g::GlobalRef, world::UInt)
Expand Down
2 changes: 1 addition & 1 deletion Compiler/src/bootstrap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,6 @@ function activate!(; reflection=true, codegen=false)
Base.REFLECTION_COMPILER[] = Compiler
end
if codegen
activate_codegen!()
bootstrap!()
end
end
5 changes: 4 additions & 1 deletion Compiler/src/ssair/verify.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,10 @@ function check_op(ir::IRCode, domtree::DomTree, @nospecialize(op), use_bb::Int,
imported_binding = partition_restriction(bpart)::Core.Binding
bpart = lookup_binding_partition(min_world(ir.valid_worlds), imported_binding)
end
if !is_defined_const_binding(binding_kind(bpart)) || (bpart.max_world < max_world(ir.valid_worlds))
if (!is_defined_const_binding(binding_kind(bpart)) || (bpart.max_world < max_world(ir.valid_worlds))) &&
(op.mod !== Core) && (op.mod !== Base)
# Core and Base are excluded because the frontend uses them for intrinsics, etc.
# TODO: Decide which way to go with these.
@verify_error "Unbound or partitioned GlobalRef not allowed in value position"
raise_error()
end
Expand Down
8 changes: 5 additions & 3 deletions Compiler/src/stmtinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -489,10 +489,12 @@ Represents access to a global through runtime reflection, rather than as a manif
perform such accesses.
"""
struct GlobalAccessInfo <: CallInfo
b::Core.Binding
bpart::Core.BindingPartition
end
GlobalAccessInfo(::Nothing) = NoCallInfo()
add_edges_impl(edges::Vector{Any}, info::GlobalAccessInfo) =
push!(edges, info.bpart)
GlobalAccessInfo(::Core.Binding, ::Nothing) = NoCallInfo()
function add_edges_impl(edges::Vector{Any}, info::GlobalAccessInfo)
push!(edges, info.b)
end

@specialize
3 changes: 2 additions & 1 deletion Compiler/src/typeinfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -544,8 +544,9 @@ function store_backedges(caller::CodeInstance, edges::SimpleVector)
# ignore `Method`-edges (from e.g. failed `abstract_call_method`)
i += 1
continue
elseif isa(item, Core.BindingPartition)
elseif isa(item, Core.Binding)
i += 1
maybe_add_binding_backedge!(item, caller)
continue
end
if isa(item, CodeInstance)
Expand Down
2 changes: 1 addition & 1 deletion Compiler/test/ssair.jl
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ let code = Any[
Expr(:boundscheck),
Compiler.GotoIfNot(SSAValue(1), 6),
# block 2
Expr(:call, GlobalRef(Base, :size), Compiler.Argument(3)),
Expr(:call, size, Compiler.Argument(3)),
Compiler.ReturnNode(),
# block 3
Core.PhiNode(),
Expand Down
4 changes: 4 additions & 0 deletions base/Base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ let os = ccall(:jl_get_UNAME, Any, ())
end
end

# subarrays
include("subarray.jl")
include("views.jl")

# numeric operations
include("hashing.jl")
include("rounding.jl")
Expand Down
2 changes: 0 additions & 2 deletions base/Base_compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -231,8 +231,6 @@ include("indices.jl")
include("genericmemory.jl")
include("array.jl")
include("abstractarray.jl")
include("subarray.jl")
include("views.jl")
include("baseext.jl")

include("c.jl")
Expand Down
142 changes: 97 additions & 45 deletions base/invalidation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,6 @@ function foreach_module_mtable(visit, m::Module, world::UInt)
visit(mt) || return false
end
end
elseif isa(v, Module) && v !== m && parentmodule(v) === m && _nameof(v) === name
# this is the original/primary binding for the submodule
foreach_module_mtable(visit, v, world) || return false
elseif isa(v, Core.MethodTable) && v.module === m && v.name === name
# this is probably an external method table here, so let's
# assume so as there is no way to precisely distinguish them
Expand All @@ -48,83 +45,138 @@ function foreach_module_mtable(visit, m::Module, world::UInt)
return true
end

function foreach_reachable_mtable(visit, world::UInt)
visit(TYPE_TYPE_MT) || return
visit(NONFUNCTION_MT) || return
for mod in loaded_modules_array()
foreach_module_mtable(visit, mod, world)
function foreachgr(visit, src::CodeInfo)
stmts = src.code
for i = 1:length(stmts)
stmt = stmts[i]
isa(stmt, GlobalRef) && visit(stmt)
for ur in Compiler.userefs(stmt)
arg = ur[]
isa(arg, GlobalRef) && visit(arg)
end
end
end

function should_invalidate_code_for_globalref(gr::GlobalRef, src::CodeInfo)
found_any = false
labelchangemap = nothing
function anygr(visit, src::CodeInfo)
stmts = src.code
isgr(g::GlobalRef) = gr.mod == g.mod && gr.name === g.name
isgr(g) = false
for i = 1:length(stmts)
stmt = stmts[i]
if isgr(stmt)
found_any = true
if isa(stmt, GlobalRef)
visit(stmt) && return true
continue
end
for ur in Compiler.userefs(stmt)
arg = ur[]
# If any of the GlobalRefs in this stmt match the one that
# we are about, we need to move out all GlobalRefs to preserve
# effect order, in case we later invalidate a different GR
if isa(arg, GlobalRef)
if isgr(arg)
@assert !isa(stmt, PhiNode)
found_any = true
break
end
end
isa(arg, GlobalRef) && visit(arg) && return true
end
end
return found_any
return false
end

function should_invalidate_code_for_globalref(gr::GlobalRef, src::CodeInfo)
isgr(g::GlobalRef) = gr.mod == g.mod && gr.name === g.name
isgr(g) = false
return anygr(isgr, src)
end

function scan_edge_list(ci::Core.CodeInstance, bpart::Core.BindingPartition)
function scan_edge_list(ci::Core.CodeInstance, binding::Core.Binding)
isdefined(ci, :edges) || return false
edges = ci.edges
i = 1
while i <= length(edges)
if isassigned(edges, i) && edges[i] === bpart
if isassigned(edges, i) && edges[i] === binding
return true
end
i += 1
end
return false
end

function invalidate_method_for_globalref!(gr::GlobalRef, method::Method, invalidated_bpart::Core.BindingPartition, new_max_world::UInt)
if isdefined(method, :source)
src = _uncompressed_ir(method)
binding = convert(Core.Binding, gr)
old_stmts = src.code
invalidate_all = should_invalidate_code_for_globalref(gr, src)
for mi in specializations(method)
isdefined(mi, :cache) || continue
ci = mi.cache
while true
if ci.max_world > new_max_world && (invalidate_all || scan_edge_list(ci, binding))
ccall(:jl_invalidate_code_instance, Cvoid, (Any, UInt), ci, new_max_world)
end
isdefined(ci, :next) || break
ci = ci.next
end
end
end
end

function invalidate_code_for_globalref!(gr::GlobalRef, invalidated_bpart::Core.BindingPartition, new_max_world::UInt)
try
valid_in_valuepos = false
foreach_reachable_mtable(new_max_world) do mt::Core.MethodTable
foreach_module_mtable(gr.mod, new_max_world) do mt::Core.MethodTable
for method in MethodList(mt)
if isdefined(method, :source)
src = _uncompressed_ir(method)
old_stmts = src.code
invalidate_all = should_invalidate_code_for_globalref(gr, src)
for mi in specializations(method)
isdefined(mi, :cache) || continue
ci = mi.cache
while true
if ci.max_world > new_max_world && (invalidate_all || scan_edge_list(ci, invalidated_bpart))
ccall(:jl_invalidate_code_instance, Cvoid, (Any, UInt), ci, new_max_world)
end
isdefined(ci, :next) || break
ci = ci.next
end
end
end
invalidate_method_for_globalref!(gr, method, invalidated_bpart, new_max_world)
end
return true
end
b = convert(Core.Binding, gr)
if isdefined(b, :backedges)
for edge in b.backedges
if isa(edge, CodeInstance)
ccall(:jl_invalidate_code_instance, Cvoid, (Any, UInt), edge, new_max_world)
else
invalidate_method_for_globalref!(gr, edge::Method, invalidated_bpart, new_max_world)
end
end
end
catch err
bt = catch_backtrace()
invokelatest(Base.println, "Internal Error during invalidation:")
invokelatest(Base.display_error, err, bt)
end
end

gr_needs_backedge_in_module(gr::GlobalRef, mod::Module) = gr.mod !== mod

# N.B.: This needs to match jl_maybe_add_binding_backedge
function maybe_add_binding_backedge!(b::Core.Binding, edge::Union{Method, CodeInstance})
method = isa(edge, Method) ? edge : edge.def.def::Method
gr_needs_backedge_in_module(b.globalref, method.module) || return
if !isdefined(b, :backedges)
b.backedges = Any[]
end
!isempty(b.backedges) && b.backedges[end] === edge && return
push!(b.backedges, edge)
end

function binding_was_invalidated(b::Core.Binding)
# At least one partition is required for invalidation
!isdefined(b, :partitions) && return false
b.partitions.min_world > unsafe_load(cglobal(:jl_require_world, UInt))
end

function scan_new_method!(methods_with_invalidated_source::IdSet{Method}, method::Method)
isdefined(method, :source) || return
src = _uncompressed_ir(method)
mod = method.module
foreachgr(src) do gr::GlobalRef
b = convert(Core.Binding, gr)
binding_was_invalidated(b) && push!(methods_with_invalidated_source, method)
maybe_add_binding_backedge!(b, method)
end
end

function scan_new_methods(extext_methods::Vector{Any}, internal_methods::Vector{Any})
methods_with_invalidated_source = IdSet{Method}()
for method in internal_methods
if isa(method, Method)
scan_new_method!(methods_with_invalidated_source, method)
end
end
for tme::Core.TypeMapEntry in extext_methods
scan_new_method!(methods_with_invalidated_source, tme.func::Method)
end
return methods_with_invalidated_source
end
6 changes: 3 additions & 3 deletions base/iterators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ using .Base:
using Core: @doc

using .Base:
cld, fld, SubArray, view, resize!, IndexCartesian
cld, fld, resize!, IndexCartesian
using .Base.Checked: checked_mul

import .Base:
Expand Down Expand Up @@ -1327,7 +1327,7 @@ eltype(::Type{PartitionIterator{T}}) where {T} = Vector{eltype(T)}
# Arrays use a generic `view`-of-a-`vec`, so we cannot exactly predict what we'll get back
eltype(::Type{PartitionIterator{T}}) where {T<:AbstractArray} = AbstractVector{eltype(T)}
# But for some common implementations in Base we know the answer exactly
eltype(::Type{PartitionIterator{T}}) where {T<:Vector} = SubArray{eltype(T), 1, T, Tuple{UnitRange{Int}}, true}
eltype(::Type{PartitionIterator{T}}) where {T<:Vector} = Base.SubArray{eltype(T), 1, T, Tuple{UnitRange{Int}}, true}

IteratorEltype(::Type{PartitionIterator{T}}) where {T} = IteratorEltype(T)
IteratorEltype(::Type{PartitionIterator{T}}) where {T<:AbstractArray} = EltypeUnknown()
Expand All @@ -1353,7 +1353,7 @@ end
function iterate(itr::PartitionIterator{<:AbstractArray}, state = firstindex(itr.c))
state > lastindex(itr.c) && return nothing
r = min(state + itr.n - 1, lastindex(itr.c))
return @inbounds view(itr.c, state:r), r + 1
return @inbounds Base.view(itr.c, state:r), r + 1
end

struct IterationCutShort; end
Expand Down
Loading

0 comments on commit e485be8

Please sign in to comment.