Skip to content

Commit

Permalink
optimizer: fix up the inlining algorithm to use correct nargs/isva
Browse files Browse the repository at this point in the history
It appears that inlining.jl was not updated in #54341.
Specifically, using `nargs`/`isva` from `mi.def::Method` in
`ir_prepare_inlining!` causes the following error to occur:
```julia
function generate_lambda_ex(world::UInt, source::LineNumberNode,
                            argnames, spnames, @nospecialize body)
    stub = Core.GeneratedFunctionStub(identity, Core.svec(argnames...), Core.svec(spnames...))
    return stub(world, source, body)
end
function overdubbee54341(a, b)
    return a + b
end
const overdubee_codeinfo54341 = code_lowered(overdubbee54341, Tuple{Any, Any})[1]
function overdub_generator54341(world::UInt, source::LineNumberNode, selftype, fargtypes)
    if length(fargtypes) != 2
        return generate_lambda_ex(world, source,
            (:overdub54341, :args), (), :(error("Wrong number of arguments")))
    else
        return copy(overdubee_codeinfo54341)
    end
end
@eval function overdub54341(args...)
    $(Expr(:meta, :generated, overdub_generator54341))
    $(Expr(:meta, :generated_only))
end
topfunc(x) = overdub54341(x, 2)
```
```julia
julia> topfunc(1)
Internal error: during type inference of
topfunc(Int64)
Encountered unexpected error in runtime:
BoundsError(a=Array{Any, 1}(dims=(2,), mem=Memory{Any}(8, 0x10632e780)[SSAValue(2), SSAValue(3), #<null>, #<null>, #<null>, #<null>, #<null>, #<null>]), i=(3,))
throw_boundserror at ./essentials.jl:14
getindex at ./essentials.jl:909 [inlined]
ssa_substitute_op! at ./compiler/ssair/inlining.jl:1798
ssa_substitute_op! at ./compiler/ssair/inlining.jl:1852
ir_inline_item! at ./compiler/ssair/inlining.jl:386
...
```

This commit updates the abstract interpretation and inlining algorithm
to use the `nargs`/`isva` values held by `CodeInfo`. Similar
modifications have also been made to EscapeAnalysis.jl.
  • Loading branch information
aviatesk committed Oct 3, 2024
1 parent 77c5875 commit 9059017
Show file tree
Hide file tree
Showing 10 changed files with 81 additions and 67 deletions.
2 changes: 1 addition & 1 deletion base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1282,7 +1282,7 @@ function semi_concrete_eval_call(interp::AbstractInterpreter,
effects = Effects(effects; noub=ALWAYS_TRUE)
end
exct = refine_exception_type(result.exct, effects)
return ConstCallResults(rt, exct, SemiConcreteResult(mi, ir, effects), effects, mi)
return ConstCallResults(rt, exct, SemiConcreteResult(mi, ir, effects, spec_info(irsv)), effects, mi)
end
end
end
Expand Down
25 changes: 12 additions & 13 deletions base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ mutable struct InferenceState
slottypes::Vector{Any}
src::CodeInfo
cfg::CFG
method_info::MethodInfo
spec_info::SpecInfo

#= intermediate states for local abstract interpretation =#
currbb::Int
Expand Down Expand Up @@ -294,7 +294,7 @@ mutable struct InferenceState
sptypes = sptypes_from_meth_instance(mi)
code = src.code::Vector{Any}
cfg = compute_basic_blocks(code)
method_info = MethodInfo(src)
spec_info = SpecInfo(src)

currbb = currpc = 1
ip = BitSet(1) # TODO BitSetBoundedMinPrioritySet(1)
Expand Down Expand Up @@ -351,7 +351,7 @@ mutable struct InferenceState
restrict_abstract_call_sites = isa(def, Module)

this = new(
mi, world, mod, sptypes, slottypes, src, cfg, method_info,
mi, world, mod, sptypes, slottypes, src, cfg, spec_info,
currbb, currpc, ip, handler_info, ssavalue_uses, bb_vartables, ssavaluetypes, stmt_edges, stmt_info,
tasks, pclimitations, limitations, cycle_backedges, callstack, 0, 0, 0,
result, unreachable, valid_worlds, bestguess, exc_bestguess, ipo_effects,
Expand Down Expand Up @@ -791,7 +791,7 @@ end

# TODO add `result::InferenceResult` and put the irinterp result into the inference cache?
mutable struct IRInterpretationState
const method_info::MethodInfo
const spec_info::SpecInfo
const ir::IRCode
const mi::MethodInstance
const world::UInt
Expand All @@ -809,7 +809,7 @@ mutable struct IRInterpretationState
parentid::Int

function IRInterpretationState(interp::AbstractInterpreter,
method_info::MethodInfo, ir::IRCode, mi::MethodInstance, argtypes::Vector{Any},
spec_info::SpecInfo, ir::IRCode, mi::MethodInstance, argtypes::Vector{Any},
world::UInt, min_world::UInt, max_world::UInt)
curridx = 1
given_argtypes = Vector{Any}(undef, length(argtypes))
Expand All @@ -831,7 +831,7 @@ mutable struct IRInterpretationState
tasks = WorkThunk[]
edges = Any[]
callstack = AbsIntState[]
return new(method_info, ir, mi, world, curridx, argtypes_refined, ir.sptypes, tpdum,
return new(spec_info, ir, mi, world, curridx, argtypes_refined, ir.sptypes, tpdum,
ssa_refined, lazyreachability, valid_worlds, tasks, edges, callstack, 0, 0)
end
end
Expand All @@ -845,14 +845,13 @@ function IRInterpretationState(interp::AbstractInterpreter,
else
isa(src, CodeInfo) || return nothing
end
method_info = MethodInfo(src)
spec_info = SpecInfo(src)
ir = inflate_ir(src, mi)
argtypes = va_process_argtypes(optimizer_lattice(interp), argtypes, src.nargs, src.isva)
return IRInterpretationState(interp, method_info, ir, mi, argtypes, world,
return IRInterpretationState(interp, spec_info, ir, mi, argtypes, world,
codeinst.min_world, codeinst.max_world)
end


# AbsIntState
# ===========

Expand Down Expand Up @@ -927,11 +926,11 @@ is_constproped(::IRInterpretationState) = true
is_cached(sv::InferenceState) = !iszero(sv.cache_mode & CACHE_MODE_GLOBAL)
is_cached(::IRInterpretationState) = false

method_info(sv::InferenceState) = sv.method_info
method_info(sv::IRInterpretationState) = sv.method_info
spec_info(sv::InferenceState) = sv.spec_info
spec_info(sv::IRInterpretationState) = sv.spec_info

propagate_inbounds(sv::AbsIntState) = method_info(sv).propagate_inbounds
method_for_inference_limit_heuristics(sv::AbsIntState) = method_info(sv).method_for_inference_limit_heuristics
propagate_inbounds(sv::AbsIntState) = spec_info(sv).propagate_inbounds
method_for_inference_limit_heuristics(sv::AbsIntState) = spec_info(sv).method_for_inference_limit_heuristics

frame_world(sv::InferenceState) = sv.world
frame_world(sv::IRInterpretationState) = sv.world
Expand Down
14 changes: 7 additions & 7 deletions base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -644,10 +644,10 @@ function ((; code_cache)::GetNativeEscapeCache)(mi::MethodInstance)
return false
end

function refine_effects!(interp::AbstractInterpreter, sv::PostOptAnalysisState)
function refine_effects!(interp::AbstractInterpreter, opt::OptimizationState, sv::PostOptAnalysisState)
if !is_effect_free(sv.result.ipo_effects) && sv.all_effect_free && !isempty(sv.ea_analysis_pending)
ir = sv.ir
nargs = let def = sv.result.linfo.def; isa(def, Method) ? Int(def.nargs) : 0; end
nargs = Int(opt.src.nargs)
estate = EscapeAnalysis.analyze_escapes(ir, nargs, optimizer_lattice(interp), GetNativeEscapeCache(interp))
argescapes = EscapeAnalysis.ArgEscapeCache(estate)
stack_analysis_result!(sv.result, argescapes)
Expand Down Expand Up @@ -939,7 +939,8 @@ function check_inconsistentcy!(sv::PostOptAnalysisState, scanner::BBScanner)
end
end

function ipo_dataflow_analysis!(interp::AbstractInterpreter, ir::IRCode, result::InferenceResult)
function ipo_dataflow_analysis!(interp::AbstractInterpreter, opt::OptimizationState,
ir::IRCode, result::InferenceResult)
if !is_ipo_dataflow_analysis_profitable(result.ipo_effects)
return false
end
Expand Down Expand Up @@ -967,13 +968,13 @@ function ipo_dataflow_analysis!(interp::AbstractInterpreter, ir::IRCode, result:
end
end

return refine_effects!(interp, sv)
return refine_effects!(interp, opt, sv)
end

# run the optimization work
function optimize(interp::AbstractInterpreter, opt::OptimizationState, caller::InferenceResult)
@timeit "optimizer" ir = run_passes_ipo_safe(opt.src, opt, caller)
ipo_dataflow_analysis!(interp, ir, caller)
@timeit "optimizer" ir = run_passes_ipo_safe(opt.src, opt)
ipo_dataflow_analysis!(interp, opt, ir, caller)
return finish(interp, opt, ir, caller)
end

Expand All @@ -995,7 +996,6 @@ matchpass(::Nothing, _, _) = false
function run_passes_ipo_safe(
ci::CodeInfo,
sv::OptimizationState,
caller::InferenceResult,
optimize_until = nothing, # run all passes by default
)
__stage__ = 0 # used by @pass
Expand Down
68 changes: 37 additions & 31 deletions base/compiler/ssair/inlining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ struct InliningTodo
mi::MethodInstance
# The IR of the inlinee
ir::IRCode
# The SpecInfo for the inlinee
spec_info::SpecInfo
# The DebugInfo table for the inlinee
di::DebugInfo
# If the function being inlined is a single basic block we can use a
Expand All @@ -20,8 +22,8 @@ struct InliningTodo
# Effects of the call statement
effects::Effects
end
function InliningTodo(mi::MethodInstance, (ir, di)::Tuple{IRCode, DebugInfo}, effects::Effects)
return InliningTodo(mi, ir, di, linear_inline_eligible(ir), effects)
function InliningTodo(mi::MethodInstance, ir::IRCode, spec_info::SpecInfo, di::DebugInfo, effects::Effects)
return InliningTodo(mi, ir, spec_info, di, linear_inline_eligible(ir), effects)
end

struct ConstantCase
Expand Down Expand Up @@ -321,7 +323,8 @@ function ir_inline_linetable!(debuginfo::DebugInfoStream, inlinee_debuginfo::Deb
end

function ir_prepare_inlining!(insert_node!::Inserter, inline_target::Union{IRCode, IncrementalCompact},
ir::IRCode, di::DebugInfo, mi::MethodInstance, inlined_at::NTuple{3,Int32}, argexprs::Vector{Any})
ir::IRCode, spec_info::SpecInfo, di::DebugInfo, mi::MethodInstance,
inlined_at::NTuple{3,Int32}, argexprs::Vector{Any})
def = mi.def::Method
debuginfo = inline_target isa IRCode ? inline_target.debuginfo : inline_target.ir.debuginfo
topline = new_inlined_at = ir_inline_linetable!(debuginfo, di, inlined_at)
Expand All @@ -334,8 +337,8 @@ function ir_prepare_inlining!(insert_node!::Inserter, inline_target::Union{IRCod
spvals_ssa = insert_node!(
removable_if_unused(NewInstruction(Expr(:call, Core._compute_sparams, def, argexprs...), SimpleVector, topline)))
end
if def.isva
nargs_def = Int(def.nargs::Int32)
if spec_info.isva
nargs_def = spec_info.nargs
if nargs_def > 0
argexprs = fix_va_argexprs!(insert_node!, inline_target, argexprs, nargs_def, topline)
end
Expand All @@ -362,7 +365,7 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector
item::InliningTodo, boundscheck::Symbol, todo_bbs::Vector{Tuple{Int, Int}})
# Ok, do the inlining here
inlined_at = compact.result[idx][:line]
ssa_substitute = ir_prepare_inlining!(InsertHere(compact), compact, item.ir, item.di, item.mi, inlined_at, argexprs)
ssa_substitute = ir_prepare_inlining!(InsertHere(compact), compact, item.ir, item.spec_info, item.di, item.mi, inlined_at, argexprs)
boundscheck = has_flag(compact.result[idx], IR_FLAG_INBOUNDS) ? :off : boundscheck

# If the iterator already moved on to the next basic block,
Expand Down Expand Up @@ -860,15 +863,14 @@ function resolve_todo(mi::MethodInstance, result::Union{Nothing,InferenceResult,
if inferred_result isa ConstantCase
add_inlining_backedge!(et, mi)
return inferred_result
end
if inferred_result isa InferredResult
elseif inferred_result isa InferredResult
(; src, effects) = inferred_result
elseif inferred_result isa CodeInstance
src = @atomic :monotonic inferred_result.inferred
effects = decode_effects(inferred_result.ipo_purity_bits)
else
src = nothing
effects = Effects()
else # there is no cached source available, bail out
return compileable_specialization(mi, Effects(), et, info;
compilesig_invokes=OptimizationParams(state.interp).compilesig_invokes)
end

# the duplicated check might have been done already within `analyze_method!`, but still
Expand All @@ -883,9 +885,12 @@ function resolve_todo(mi::MethodInstance, result::Union{Nothing,InferenceResult,
compilesig_invokes=OptimizationParams(state.interp).compilesig_invokes)

add_inlining_backedge!(et, mi)
ir = inferred_result isa CodeInstance ? retrieve_ir_for_inlining(inferred_result, src) :
retrieve_ir_for_inlining(mi, src, preserve_local_sources)
return InliningTodo(mi, ir, effects)
if inferred_result isa CodeInstance
ir, spec_info, debuginfo = retrieve_ir_for_inlining(inferred_result, src)
else
ir, spec_info, debuginfo = retrieve_ir_for_inlining(mi, src, preserve_local_sources)
end
return InliningTodo(mi, ir, spec_info, debuginfo, effects)
end

# the special resolver for :invoke-d call
Expand All @@ -901,23 +906,17 @@ function resolve_todo(mi::MethodInstance, @nospecialize(info::CallInfo), flag::U
if cached_result isa ConstantCase
add_inlining_backedge!(et, mi)
return cached_result
end
if cached_result isa InferredResult
(; src, effects) = cached_result
elseif cached_result isa CodeInstance
src = @atomic :monotonic cached_result.inferred
effects = decode_effects(cached_result.ipo_purity_bits)
else
src = nothing
effects = Effects()
else # there is no cached source available, bail out
return nothing
end

preserve_local_sources = true
src_inlining_policy(state.interp, src, info, flag) || return nothing
ir = cached_result isa CodeInstance ? retrieve_ir_for_inlining(cached_result, src) :
retrieve_ir_for_inlining(mi, src, preserve_local_sources)
ir, spec_info, debuginfo = retrieve_ir_for_inlining(cached_result, src)
add_inlining_backedge!(et, mi)
return InliningTodo(mi, ir, effects)
return InliningTodo(mi, ir, spec_info, debuginfo, effects)
end

function validate_sparams(sparams::SimpleVector)
Expand Down Expand Up @@ -971,22 +970,29 @@ function analyze_method!(match::MethodMatch, argtypes::Vector{Any},
return resolve_todo(mi, volatile_inf_result, info, flag, state; invokesig)
end

function retrieve_ir_for_inlining(cached_result::CodeInstance, src::MaybeCompressed)
src = _uncompressed_ir(cached_result, src)::CodeInfo
return inflate_ir!(src, cached_result.def), src.debuginfo
function retrieve_ir_for_inlining(cached_result::CodeInstance, src::String)
src = _uncompressed_ir(cached_result, src)
return inflate_ir!(src, cached_result.def), SpecInfo(src), src.debuginfo
end
function retrieve_ir_for_inlining(cached_result::CodeInstance, src::CodeInfo)
return inflate_ir!(copy(src), cached_result.def), SpecInfo(src), src.debuginfo
end
function retrieve_ir_for_inlining(mi::MethodInstance, src::CodeInfo, preserve_local_sources::Bool)
if preserve_local_sources
src = copy(src)
end
return inflate_ir!(src, mi), src.debuginfo
return inflate_ir!(src, mi), SpecInfo(src), src.debuginfo
end
function retrieve_ir_for_inlining(mi::MethodInstance, ir::IRCode, preserve_local_sources::Bool)
if preserve_local_sources
ir = copy(ir)
end
# COMBAK this is not correct, we should make `InferenceResult` propagate `SpecInfo`
spec_info = let m = mi.def::Method
SpecInfo(Int(m.nargs), m.isva, false, nothing)
end
ir.debuginfo.def = mi
return ir, DebugInfo(ir.debuginfo, length(ir.stmts))
return ir, spec_info, DebugInfo(ir.debuginfo, length(ir.stmts))
end

function handle_single_case!(todo::Vector{Pair{Int,Any}},
Expand Down Expand Up @@ -1466,8 +1472,8 @@ function semiconcrete_result_item(result::SemiConcreteResult,

add_inlining_backedge!(et, mi)
preserve_local_sources = OptimizationParams(state.interp).preserve_local_sources
ir = retrieve_ir_for_inlining(mi, result.ir, preserve_local_sources)
return InliningTodo(mi, ir, result.effects)
ir, _, debuginfo = retrieve_ir_for_inlining(mi, result.ir, preserve_local_sources)
return InliningTodo(mi, ir, result.spec_info, debuginfo, result.effects)
end

function handle_semi_concrete_result!(cases::Vector{InliningCase}, result::SemiConcreteResult,
Expand Down
4 changes: 2 additions & 2 deletions base/compiler/ssair/passes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1532,7 +1532,7 @@ function try_inline_finalizer!(ir::IRCode, argexprs::Vector{Any}, idx::Int,
end

src_inlining_policy(inlining.interp, src, info, IR_FLAG_NULL) || return false
src, di = retrieve_ir_for_inlining(code, src)
src, spec_info, di = retrieve_ir_for_inlining(code, src)

# For now: Require finalizer to only have one basic block
length(src.cfg.blocks) == 1 || return false
Expand All @@ -1542,7 +1542,7 @@ function try_inline_finalizer!(ir::IRCode, argexprs::Vector{Any}, idx::Int,

# TODO: Should there be a special line number node for inlined finalizers?
inline_at = ir[SSAValue(idx)][:line]
ssa_substitute = ir_prepare_inlining!(InsertBefore(ir, SSAValue(idx)), ir, src, di, mi, inline_at, argexprs)
ssa_substitute = ir_prepare_inlining!(InsertBefore(ir, SSAValue(idx)), ir, src, spec_info, di, mi, inline_at, argexprs)

# TODO: Use the actual inliner here rather than open coding this special purpose inliner.
ssa_rename = Vector{Any}(undef, length(src.stmts))
Expand Down
1 change: 1 addition & 0 deletions base/compiler/stmtinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ struct SemiConcreteResult <: ConstResult
mi::MethodInstance
ir::IRCode
effects::Effects
spec_info::SpecInfo
end

# XXX Technically this does not represent a result of constant inference, but rather that of
Expand Down
2 changes: 1 addition & 1 deletion base/compiler/typeinfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -941,7 +941,7 @@ function typeinf_ircode(interp::AbstractInterpreter, mi::MethodInstance,
end
(; result) = frame
opt = OptimizationState(frame, interp)
ir = run_passes_ipo_safe(opt.src, opt, result, optimize_until)
ir = run_passes_ipo_safe(opt.src, opt, optimize_until)
rt = widenconst(ignorelimited(result.result))
return ir, rt
end
Expand Down
7 changes: 5 additions & 2 deletions base/compiler/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,14 @@ struct StmtInfo
used::Bool
end

struct MethodInfo
struct SpecInfo
nargs::Int
isva::Bool
propagate_inbounds::Bool
method_for_inference_limit_heuristics::Union{Nothing,Method}
end
MethodInfo(src::CodeInfo) = MethodInfo(
SpecInfo(src::CodeInfo) = SpecInfo(
Int(src.nargs), src.isva,
src.propagate_inbounds,
src.method_for_inference_limit_heuristics::Union{Nothing,Method})

Expand Down
11 changes: 7 additions & 4 deletions test/compiler/EscapeAnalysis/EAUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,12 +116,14 @@ CC.get_inference_world(interp::EscapeAnalyzer) = interp.world
CC.get_inference_cache(interp::EscapeAnalyzer) = interp.inf_cache
CC.cache_owner(::EscapeAnalyzer) = EAToken()

function CC.ipo_dataflow_analysis!(interp::EscapeAnalyzer, ir::IRCode, caller::InferenceResult)
function CC.ipo_dataflow_analysis!(interp::EscapeAnalyzer, opt::OptimizationState,
ir::IRCode, caller::InferenceResult)
# run EA on all frames that have been optimized
nargs = let def = caller.linfo.def; isa(def, Method) ? Int(def.nargs) : 0; end
nargs = Int(opt.src.nargs)
𝕃ₒ = CC.optimizer_lattice(interp)
get_escape_cache = GetEscapeCache(interp)
estate = try
analyze_escapes(ir, nargs, CC.optimizer_lattice(interp), get_escape_cache)
analyze_escapes(ir, nargs, 𝕃ₒ, get_escape_cache)
catch err
@error "error happened within EA, inspect `Main.failed_escapeanalysis`"
Main.failed_escapeanalysis = FailedAnalysis(ir, nargs, get_escape_cache)
Expand All @@ -133,7 +135,8 @@ function CC.ipo_dataflow_analysis!(interp::EscapeAnalyzer, ir::IRCode, caller::I
end
record_escapes!(interp, caller, estate, ir)

@invoke CC.ipo_dataflow_analysis!(interp::AbstractInterpreter, ir::IRCode, caller::InferenceResult)
@invoke CC.ipo_dataflow_analysis!(interp::AbstractInterpreter, opt::OptimizationState,
ir::IRCode, caller::InferenceResult)
end

function record_escapes!(interp::EscapeAnalyzer,
Expand Down
Loading

0 comments on commit 9059017

Please sign in to comment.