Skip to content

Commit

Permalink
optimizer: handle GlobalRef based on assume_bindings_static
Browse files Browse the repository at this point in the history
Following up #53750.
External abstract interpreters with `assume_bindings_static` enabled
assume nothrow-ness of defined `GlobalRef`s no matter if they are
constant during abstract interpretation. However, post-#53750, the
optimizer’s verification switched to stricter checks, causing errors
in such interpreters. This commit makes `IRCode` hold
`assume_bindings_static::Bool` info to ensure that inlining and
verification behaviors are synced with abstract interpretation.
  • Loading branch information
aviatesk committed May 24, 2024
1 parent 424ac6e commit 9a0a1d2
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 19 deletions.
4 changes: 1 addition & 3 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2836,7 +2836,7 @@ function abstract_eval_globalref(interp::AbstractInterpreter, g::GlobalRef, sv::
rt = abstract_eval_globalref_type(g)
consistent = inaccessiblememonly = ALWAYS_FALSE
nothrow = false
if isa(rt, Const)
if isa(rt, Const) # implies `isdefinedconst_globalref(g) === true`
consistent = ALWAYS_TRUE
nothrow = true
if is_mutation_free_argtype(rt)
Expand All @@ -2849,8 +2849,6 @@ function abstract_eval_globalref(interp::AbstractInterpreter, g::GlobalRef, sv::
else
rt = Union{}
end
elseif isdefinedconst_globalref(g)
nothrow = true
end
return RTEffects(rt, nothrow ? Union{} : UndefVarError, Effects(EFFECTS_TOTAL; consistent, nothrow, inaccessiblememonly))
end
Expand Down
32 changes: 21 additions & 11 deletions base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,9 @@ function new_expr_effect_flags(𝕃ₒ::AbstractLattice, args::Vector{Any}, src:
return (false, true, true)
end

assume_bindings_static(ir::IRCode) = ir.assume_bindings_static
assume_bindings_static(compact::IncrementalCompact) = assume_bindings_static(compact.ir)

# Returns a tuple of `(:consistent, :removable, :nothrow)` flags for a given statement.
function stmt_effect_flags(𝕃ₒ::AbstractLattice, @nospecialize(stmt), @nospecialize(rt), src::Union{IRCode,IncrementalCompact})
# TODO: We're duplicating analysis from inference here.
Expand All @@ -298,7 +301,12 @@ function stmt_effect_flags(𝕃ₒ::AbstractLattice, @nospecialize(stmt), @nospe
isa(stmt, GotoNode) && return (true, false, true)
isa(stmt, GotoIfNot) && return (true, false, (𝕃ₒ, argextype(stmt.cond, src), Bool))
if isa(stmt, GlobalRef)
nothrow = consistent = isdefinedconst_globalref(stmt)
if assume_bindings_static(src)
nothrow = isdefined_globalref(stmt)
consistent = nothrow & isconst(stmt)
else
nothrow = consistent = isdefinedconst_globalref(stmt)
end
return (consistent, nothrow, nothrow)
elseif isa(stmt, Expr)
(; head, args) = stmt
Expand Down Expand Up @@ -1129,7 +1137,6 @@ function convert_to_ircode(ci::CodeInfo, sv::OptimizationState)
# Go through and add an unreachable node after every
# Union{} call. Then reindex labels.
stmtinfo = sv.stmt_info
meta = Expr[]
idx = 1
oldidx = 1
nstmts = length(code)
Expand Down Expand Up @@ -1229,25 +1236,28 @@ function convert_to_ircode(ci::CodeInfo, sv::OptimizationState)
renumber_cfg_stmts!(sv.cfg, blockchangemap)
end

for i = 1:length(code)
code[i] = process_meta!(meta, code[i])
end
meta = process_meta!(code)
strip_trailing_junk!(code, ssavaluetypes, ssaflags, di, sv.cfg, stmtinfo)
types = Any[]
stmts = InstructionStream(code, types, stmtinfo, codelocs, ssaflags)
# NOTE this `argtypes` contains types of slots yet: it will be modified to contain the
# types of call arguments only once `slot2reg` converts this `IRCode` to the SSA form
# and eliminates slots (see below)
argtypes = sv.slottypes
return IRCode(stmts, sv.cfg, di, argtypes, meta, sv.sptypes)
return IRCode(stmts, sv.cfg, di, argtypes, meta, sv.sptypes,
InferenceParams(sv.inlining.interp).assume_bindings_static)
end

function process_meta!(meta::Vector{Expr}, @nospecialize stmt)
if isexpr(stmt, :meta) && length(stmt.args) 1
push!(meta, stmt)
return nothing
function process_meta!(code::Vector{Any})
meta = Expr[]
for i = 1:length(code)
stmt = code[i]
if isexpr(stmt, :meta) && length(stmt.args) 1
push!(meta, stmt)
code[i] = nothing
end
end
return stmt
return meta
end

function slot2reg(ir::IRCode, ci::CodeInfo, sv::OptimizationState)
Expand Down
12 changes: 8 additions & 4 deletions base/compiler/ssair/ir.jl
Original file line number Diff line number Diff line change
Expand Up @@ -428,22 +428,26 @@ struct IRCode
cfg::CFG
new_nodes::NewNodeStream
meta::Vector{Expr}
assume_bindings_static::Bool # XXX propagate `interp::AbstractInterpreter` here?

function IRCode(stmts::InstructionStream, cfg::CFG, debuginfo::DebugInfoStream, argtypes::Vector{Any}, meta::Vector{Expr}, sptypes::Vector{VarState})
return new(stmts, argtypes, sptypes, debuginfo, cfg, NewNodeStream(), meta)
function IRCode(stmts::InstructionStream, cfg::CFG, debuginfo::DebugInfoStream,
argtypes::Vector{Any}, meta::Vector{Expr}, sptypes::Vector{VarState},
assume_bindings_static::Bool=false)
return new(stmts, argtypes, sptypes, debuginfo, cfg, NewNodeStream(), meta, assume_bindings_static)
end
function IRCode(ir::IRCode, stmts::InstructionStream, cfg::CFG, new_nodes::NewNodeStream)
di = ir.debuginfo
@assert di.codelocs === stmts.line
return new(stmts, ir.argtypes, ir.sptypes, di, cfg, new_nodes, ir.meta)
return new(stmts, ir.argtypes, ir.sptypes, di, cfg, new_nodes, ir.meta, ir.assume_bindings_static)
end
global function copy(ir::IRCode)
di = ir.debuginfo
stmts = copy(ir.stmts)
di = copy(di)
di.edges = copy(di.edges)
di.codelocs = stmts.line
return new(stmts, copy(ir.argtypes), copy(ir.sptypes), di, copy(ir.cfg), copy(ir.new_nodes), copy(ir.meta))
return new(stmts, copy(ir.argtypes), copy(ir.sptypes), di, copy(ir.cfg),
copy(ir.new_nodes), copy(ir.meta), ir.assume_bindings_static)
end
end

Expand Down
2 changes: 1 addition & 1 deletion base/compiler/ssair/verify.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ function check_op(ir::IRCode, domtree::DomTree, @nospecialize(op), use_bb::Int,
error("")
end
elseif isa(op, GlobalRef)
if !isdefined(op.mod, op.name) || !isconst(op.mod, op.name)
if !(assume_bindings_static(ir) ? isdefined_globalref(op) : isdefinedconst_globalref(op))
@verify_error "Unbound GlobalRef not allowed in value position"
error("")
end
Expand Down

0 comments on commit 9a0a1d2

Please sign in to comment.