From 9a0a1d2d9ef0b0691c42b90c194ab377c05a097a Mon Sep 17 00:00:00 2001 From: Shuhei Kadowaki Date: Fri, 19 Apr 2024 16:13:38 +0900 Subject: [PATCH] optimizer: handle `GlobalRef` based on `assume_bindings_static` MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Following up #53750. External abstract interpreters with `assume_bindings_static` enabled assume nothrow-ness of defined `GlobalRef`s no matter if they are constant during abstract interpretation. However, post-#53750, the optimizer’s verification switched to stricter checks, causing errors in such interpreters. This commit makes `IRCode` hold `assume_bindings_static::Bool` info to ensure that inlining and verification behaviors are synced with abstract interpretation. --- base/compiler/abstractinterpretation.jl | 4 +--- base/compiler/optimize.jl | 32 ++++++++++++++++--------- base/compiler/ssair/ir.jl | 12 ++++++---- base/compiler/ssair/verify.jl | 2 +- 4 files changed, 31 insertions(+), 19 deletions(-) diff --git a/base/compiler/abstractinterpretation.jl b/base/compiler/abstractinterpretation.jl index 46e15d0c3ad79..eda70befb7506 100644 --- a/base/compiler/abstractinterpretation.jl +++ b/base/compiler/abstractinterpretation.jl @@ -2836,7 +2836,7 @@ function abstract_eval_globalref(interp::AbstractInterpreter, g::GlobalRef, sv:: rt = abstract_eval_globalref_type(g) consistent = inaccessiblememonly = ALWAYS_FALSE nothrow = false - if isa(rt, Const) + if isa(rt, Const) # implies `isdefinedconst_globalref(g) === true` consistent = ALWAYS_TRUE nothrow = true if is_mutation_free_argtype(rt) @@ -2849,8 +2849,6 @@ function abstract_eval_globalref(interp::AbstractInterpreter, g::GlobalRef, sv:: else rt = Union{} end - elseif isdefinedconst_globalref(g) - nothrow = true end return RTEffects(rt, nothrow ? Union{} : UndefVarError, Effects(EFFECTS_TOTAL; consistent, nothrow, inaccessiblememonly)) end diff --git a/base/compiler/optimize.jl b/base/compiler/optimize.jl index ec206e47103ed..537cdc0a23dc9 100644 --- a/base/compiler/optimize.jl +++ b/base/compiler/optimize.jl @@ -288,6 +288,9 @@ function new_expr_effect_flags(𝕃ₒ::AbstractLattice, args::Vector{Any}, src: return (false, true, true) end +assume_bindings_static(ir::IRCode) = ir.assume_bindings_static +assume_bindings_static(compact::IncrementalCompact) = assume_bindings_static(compact.ir) + # Returns a tuple of `(:consistent, :removable, :nothrow)` flags for a given statement. function stmt_effect_flags(𝕃ₒ::AbstractLattice, @nospecialize(stmt), @nospecialize(rt), src::Union{IRCode,IncrementalCompact}) # TODO: We're duplicating analysis from inference here. @@ -298,7 +301,12 @@ function stmt_effect_flags(𝕃ₒ::AbstractLattice, @nospecialize(stmt), @nospe isa(stmt, GotoNode) && return (true, false, true) isa(stmt, GotoIfNot) && return (true, false, βŠ‘(𝕃ₒ, argextype(stmt.cond, src), Bool)) if isa(stmt, GlobalRef) - nothrow = consistent = isdefinedconst_globalref(stmt) + if assume_bindings_static(src) + nothrow = isdefined_globalref(stmt) + consistent = nothrow & isconst(stmt) + else + nothrow = consistent = isdefinedconst_globalref(stmt) + end return (consistent, nothrow, nothrow) elseif isa(stmt, Expr) (; head, args) = stmt @@ -1129,7 +1137,6 @@ function convert_to_ircode(ci::CodeInfo, sv::OptimizationState) # Go through and add an unreachable node after every # Union{} call. Then reindex labels. stmtinfo = sv.stmt_info - meta = Expr[] idx = 1 oldidx = 1 nstmts = length(code) @@ -1229,9 +1236,7 @@ function convert_to_ircode(ci::CodeInfo, sv::OptimizationState) renumber_cfg_stmts!(sv.cfg, blockchangemap) end - for i = 1:length(code) - code[i] = process_meta!(meta, code[i]) - end + meta = process_meta!(code) strip_trailing_junk!(code, ssavaluetypes, ssaflags, di, sv.cfg, stmtinfo) types = Any[] stmts = InstructionStream(code, types, stmtinfo, codelocs, ssaflags) @@ -1239,15 +1244,20 @@ function convert_to_ircode(ci::CodeInfo, sv::OptimizationState) # types of call arguments only once `slot2reg` converts this `IRCode` to the SSA form # and eliminates slots (see below) argtypes = sv.slottypes - return IRCode(stmts, sv.cfg, di, argtypes, meta, sv.sptypes) + return IRCode(stmts, sv.cfg, di, argtypes, meta, sv.sptypes, + InferenceParams(sv.inlining.interp).assume_bindings_static) end -function process_meta!(meta::Vector{Expr}, @nospecialize stmt) - if isexpr(stmt, :meta) && length(stmt.args) β‰₯ 1 - push!(meta, stmt) - return nothing +function process_meta!(code::Vector{Any}) + meta = Expr[] + for i = 1:length(code) + stmt = code[i] + if isexpr(stmt, :meta) && length(stmt.args) β‰₯ 1 + push!(meta, stmt) + code[i] = nothing + end end - return stmt + return meta end function slot2reg(ir::IRCode, ci::CodeInfo, sv::OptimizationState) diff --git a/base/compiler/ssair/ir.jl b/base/compiler/ssair/ir.jl index c665c5bef299e..8083a1e5410df 100644 --- a/base/compiler/ssair/ir.jl +++ b/base/compiler/ssair/ir.jl @@ -428,14 +428,17 @@ struct IRCode cfg::CFG new_nodes::NewNodeStream meta::Vector{Expr} + assume_bindings_static::Bool # XXX propagate `interp::AbstractInterpreter` here? - function IRCode(stmts::InstructionStream, cfg::CFG, debuginfo::DebugInfoStream, argtypes::Vector{Any}, meta::Vector{Expr}, sptypes::Vector{VarState}) - return new(stmts, argtypes, sptypes, debuginfo, cfg, NewNodeStream(), meta) + function IRCode(stmts::InstructionStream, cfg::CFG, debuginfo::DebugInfoStream, + argtypes::Vector{Any}, meta::Vector{Expr}, sptypes::Vector{VarState}, + assume_bindings_static::Bool=false) + return new(stmts, argtypes, sptypes, debuginfo, cfg, NewNodeStream(), meta, assume_bindings_static) end function IRCode(ir::IRCode, stmts::InstructionStream, cfg::CFG, new_nodes::NewNodeStream) di = ir.debuginfo @assert di.codelocs === stmts.line - return new(stmts, ir.argtypes, ir.sptypes, di, cfg, new_nodes, ir.meta) + return new(stmts, ir.argtypes, ir.sptypes, di, cfg, new_nodes, ir.meta, ir.assume_bindings_static) end global function copy(ir::IRCode) di = ir.debuginfo @@ -443,7 +446,8 @@ struct IRCode di = copy(di) di.edges = copy(di.edges) di.codelocs = stmts.line - return new(stmts, copy(ir.argtypes), copy(ir.sptypes), di, copy(ir.cfg), copy(ir.new_nodes), copy(ir.meta)) + return new(stmts, copy(ir.argtypes), copy(ir.sptypes), di, copy(ir.cfg), + copy(ir.new_nodes), copy(ir.meta), ir.assume_bindings_static) end end diff --git a/base/compiler/ssair/verify.jl b/base/compiler/ssair/verify.jl index bfa5ab82143cb..9e7f6de3ce9fc 100644 --- a/base/compiler/ssair/verify.jl +++ b/base/compiler/ssair/verify.jl @@ -56,7 +56,7 @@ function check_op(ir::IRCode, domtree::DomTree, @nospecialize(op), use_bb::Int, error("") end elseif isa(op, GlobalRef) - if !isdefined(op.mod, op.name) || !isconst(op.mod, op.name) + if !(assume_bindings_static(ir) ? isdefined_globalref(op) : isdefinedconst_globalref(op)) @verify_error "Unbound GlobalRef not allowed in value position" error("") end