From e343279bbc8fc670e7ed7d358cf4ff3127000e50 Mon Sep 17 00:00:00 2001 From: Shuhei Kadowaki Date: Fri, 14 Jan 2022 05:21:10 +0900 Subject: [PATCH] optimizer: Julia-level escape analysis MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This commit ports [EscapeAnalysis.jl](https://github.com/aviatesk/EscapeAnalysis.jl) into Julia base. You can find the documentation of this escape analysis at [this GitHub page](https://aviatesk.github.io/EscapeAnalysis.jl/dev/)[^1]. [^1]: The same documentation will be included into Julia's developer documentation by this commit. This escape analysis will hopefully be an enabling technology for various memory-related optimizations at Julia's high level compilation pipeline. Possible target optimization includes mutable ϕ-node aware SROA, `mutating_arrayfreeze` optimization (#42465), stack allocation of mutables, dead array elimination, finalizer elision and so on[^2]. [^2]: It would be also interesting if LLVM-level optimizations can consume some IPO information derived by this escape analysis to broaden optimization possibilities. The primary motivation for porting EA in this PR is to check its impact on latency as well as to get feedbacks from a broader range of developers. If there are no serious problems, EA will be used by #42465 for `mutating_arrayfreeze` optimization for `ImmutableArray` construction. This commit simply defines EA inside Julia base compiler and enables the existing test suite with it. The plan is to validate EA's accuracy and correctness and to check the latency impact by running EA before the existing SROA pass and checking if it can detect all objects eliminated by the SROA pass as "SROA-eliminatable". --- base/Base.jl | 3 + base/compiler/bootstrap.jl | 2 +- base/compiler/compiler.jl | 2 + base/compiler/optimize.jl | 86 +- base/compiler/ssair/EscapeAnalysis/EAUtils.jl | 403 ++++ .../ssair/EscapeAnalysis/EscapeAnalysis.jl | 1604 +++++++++++++ .../ssair/EscapeAnalysis/disjoint_set.jl | 143 ++ base/compiler/ssair/driver.jl | 6 +- base/compiler/ssair/passes.jl | 13 +- base/compiler/utilities.jl | 4 + base/exports.jl | 1 + doc/make.jl | 1 + doc/src/devdocs/EscapeAnalysis/README.md | 322 +++ doc/src/devdocs/llvm.md | 2 +- .../InteractiveUtils/src/InteractiveUtils.jl | 2 +- stdlib/InteractiveUtils/src/macros.jl | 12 +- test/choosetests.jl | 2 +- .../compiler/EscapeAnalysis/EscapeAnalysis.jl | 2052 +++++++++++++++++ test/compiler/EscapeAnalysis/setup.jl | 89 + test/compiler/irpasses.jl | 2 +- 20 files changed, 4699 insertions(+), 52 deletions(-) create mode 100644 base/compiler/ssair/EscapeAnalysis/EAUtils.jl create mode 100644 base/compiler/ssair/EscapeAnalysis/EscapeAnalysis.jl create mode 100644 base/compiler/ssair/EscapeAnalysis/disjoint_set.jl create mode 100644 doc/src/devdocs/EscapeAnalysis/README.md create mode 100644 test/compiler/EscapeAnalysis/EscapeAnalysis.jl create mode 100644 test/compiler/EscapeAnalysis/setup.jl diff --git a/base/Base.jl b/base/Base.jl index 8fd0780133818..82b1b43822d7b 100644 --- a/base/Base.jl +++ b/base/Base.jl @@ -397,6 +397,9 @@ include("util.jl") include("asyncmap.jl") +include("compiler/ssair/EscapeAnalysis/EAUtils.jl") +using .EAUtils: code_escapes + # deprecated functions include("deprecated.jl") diff --git a/base/compiler/bootstrap.jl b/base/compiler/bootstrap.jl index 2517b181d2804..75ec987656509 100644 --- a/base/compiler/bootstrap.jl +++ b/base/compiler/bootstrap.jl @@ -14,7 +14,7 @@ let fs = Any[ # we first create caches for the optimizer, because they contain many loop constructions # and they're better to not run in interpreter even during bootstrapping - run_passes, + analyze_escapes, run_passes, # then we create caches for inference entries typeinf_ext, typeinf, typeinf_edge, ] diff --git a/base/compiler/compiler.jl b/base/compiler/compiler.jl index 5f2f5614ba209..a0a2432de351a 100644 --- a/base/compiler/compiler.jl +++ b/base/compiler/compiler.jl @@ -95,6 +95,8 @@ ntuple(f, n) = (Any[f(i) for i = 1:n]...,) # core docsystem include("docs/core.jl") +import Core.Compiler.CoreDocs +Core.atdoc!(CoreDocs.docm) # sorting function sort end diff --git a/base/compiler/optimize.jl b/base/compiler/optimize.jl index 56f23ca7c2b39..d2f61b0064455 100644 --- a/base/compiler/optimize.jl +++ b/base/compiler/optimize.jl @@ -1,5 +1,45 @@ # This file is a part of Julia. License is MIT: https://julialang.org/license +############# +# constants # +############# + +# The slot has uses that are not statically dominated by any assignment +# This is implied by `SLOT_USEDUNDEF`. +# If this is not set, all the uses are (statically) dominated by the defs. +# In particular, if a slot has `AssignedOnce && !StaticUndef`, it is an SSA. +const SLOT_STATICUNDEF = 1 # slot might be used before it is defined (structurally) +const SLOT_ASSIGNEDONCE = 16 # slot is assigned to only once +const SLOT_USEDUNDEF = 32 # slot has uses that might raise UndefVarError +# const SLOT_CALLED = 64 + +# NOTE make sure to sync the flag definitions below with julia.h and `jl_code_info_set_ir` in method.c + +const IR_FLAG_NULL = 0x00 +# This statement is marked as @inbounds by user. +# Ff replaced by inlining, any contained boundschecks may be removed. +const IR_FLAG_INBOUNDS = 0x01 << 0 +# This statement is marked as @inline by user +const IR_FLAG_INLINE = 0x01 << 1 +# This statement is marked as @noinline by user +const IR_FLAG_NOINLINE = 0x01 << 2 +const IR_FLAG_THROW_BLOCK = 0x01 << 3 +# This statement may be removed if its result is unused. In particular it must +# thus be both pure and effect free. +const IR_FLAG_EFFECT_FREE = 0x01 << 4 + +# known to be always effect-free (in particular nothrow) +const _PURE_BUILTINS = Any[tuple, svec, ===, typeof, nfields] + +# known to be effect-free if the are nothrow +const _PURE_OR_ERROR_BUILTINS = [ + fieldtype, apply_type, isa, UnionAll, + getfield, arrayref, const_arrayref, arraysize, isdefined, Core.sizeof, + Core.kwfunc, Core.ifelse, Core._typevar, (<:) +] + +const TOP_TUPLE = GlobalRef(Core, :tuple) + ##################### # OptimizationState # ##################### @@ -52,7 +92,10 @@ function inlining_policy(interp::AbstractInterpreter, @nospecialize(src), stmt_f return nothing end +function argextype end # imported by EscapeAnalysis include("compiler/ssair/driver.jl") +using .EscapeAnalysis +import .EscapeAnalysis: EscapeState mutable struct OptimizationState linfo::MethodInstance @@ -121,46 +164,6 @@ function ir_to_codeinf!(opt::OptimizationState) return src end -############# -# constants # -############# - -# The slot has uses that are not statically dominated by any assignment -# This is implied by `SLOT_USEDUNDEF`. -# If this is not set, all the uses are (statically) dominated by the defs. -# In particular, if a slot has `AssignedOnce && !StaticUndef`, it is an SSA. -const SLOT_STATICUNDEF = 1 # slot might be used before it is defined (structurally) -const SLOT_ASSIGNEDONCE = 16 # slot is assigned to only once -const SLOT_USEDUNDEF = 32 # slot has uses that might raise UndefVarError -# const SLOT_CALLED = 64 - -# NOTE make sure to sync the flag definitions below with julia.h and `jl_code_info_set_ir` in method.c - -const IR_FLAG_NULL = 0x00 -# This statement is marked as @inbounds by user. -# Ff replaced by inlining, any contained boundschecks may be removed. -const IR_FLAG_INBOUNDS = 0x01 << 0 -# This statement is marked as @inline by user -const IR_FLAG_INLINE = 0x01 << 1 -# This statement is marked as @noinline by user -const IR_FLAG_NOINLINE = 0x01 << 2 -const IR_FLAG_THROW_BLOCK = 0x01 << 3 -# This statement may be removed if its result is unused. In particular it must -# thus be both pure and effect free. -const IR_FLAG_EFFECT_FREE = 0x01 << 4 - -# known to be always effect-free (in particular nothrow) -const _PURE_BUILTINS = Any[tuple, svec, ===, typeof, nfields] - -# known to be effect-free if the are nothrow -const _PURE_OR_ERROR_BUILTINS = [ - fieldtype, apply_type, isa, UnionAll, - getfield, arrayref, const_arrayref, arraysize, isdefined, Core.sizeof, - Core.kwfunc, Core.ifelse, Core._typevar, (<:) -] - -const TOP_TUPLE = GlobalRef(Core, :tuple) - ######### # logic # ######### @@ -514,7 +517,8 @@ function run_passes(ci::CodeInfo, sv::OptimizationState) @timeit "Inlining" ir = ssa_inlining_pass!(ir, ir.linetable, sv.inlining, ci.propagate_inbounds) # @timeit "verify 2" verify_ir(ir) @timeit "compact 2" ir = compact!(ir) - @timeit "SROA" ir = sroa_pass!(ir) + nargs = let def = sv.linfo.def; isa(def, Method) ? Int(def.nargs) : 0; end + @timeit "SROA" ir = sroa_pass!(ir, nargs) @timeit "ADCE" ir = adce_pass!(ir) @timeit "type lift" ir = type_lift_pass!(ir) @timeit "compact 3" ir = compact!(ir) diff --git a/base/compiler/ssair/EscapeAnalysis/EAUtils.jl b/base/compiler/ssair/EscapeAnalysis/EAUtils.jl new file mode 100644 index 0000000000000..490834d4df123 --- /dev/null +++ b/base/compiler/ssair/EscapeAnalysis/EAUtils.jl @@ -0,0 +1,403 @@ +const EA_AS_PKG = Symbol(@__MODULE__) !== :Base # develop EA as an external package + +module EAUtils + +import ..EA_AS_PKG +if EA_AS_PKG + import ..EscapeAnalysis +else + import Core.Compiler.EscapeAnalysis: EscapeAnalysis + Base.getindex(estate::EscapeAnalysis.EscapeState, @nospecialize(x)) = + Core.Compiler.getindex(estate, x) +end +const EA = EscapeAnalysis +const CC = Core.Compiler + +# entries +# ------- + +@static if EA_AS_PKG +import InteractiveUtils: gen_call_with_extracted_types_and_kwargs + +@doc """ + @code_escapes [options...] f(args...) + +Evaluates the arguments to the function call, determines its types, and then calls +[`code_escapes`](@ref) on the resulting expression. +As with `@code_typed` and its family, any of `code_escapes` keyword arguments can be given +as the optional arguments like `@code_escapes interp=myinterp myfunc(myargs...)`. +""" +macro code_escapes(ex0...) + return gen_call_with_extracted_types_and_kwargs(__module__, :code_escapes, ex0) +end +end # @static if EA_AS_PKG + +""" + code_escapes(f, argtypes=Tuple{}; [world], [interp]) -> result::EscapeResult + code_escapes(tt::Type{<:Tuple}; [world], [interp]) -> result::EscapeResult + +Runs the escape analysis on optimized IR of a generic function call with the given type signature. +Note that the escape analysis runs after inlining, but before any other optimizations. + +```julia +julia> mutable struct SafeRef{T} + x::T + end + +julia> Base.getindex(x::SafeRef) = x.x; + +julia> Base.isassigned(x::SafeRef) = true; + +julia> get′(x) = isassigned(x) ? x[] : throw(x); + +julia> result = code_escapes((String,String,String)) do s1, s2, s3 + r1 = Ref(s1) + r2 = Ref(s2) + r3 = SafeRef(s3) + try + s1 = get′(r1) + ret = sizeof(s1) + catch err + global g = err # will definitely escape `r1` + end + s2 = get′(r2) # still `r2` doesn't escape fully + s3 = get′(r3) # still `r2` doesn't escape fully + return s2, s3 + end +#3(X _2::String, ↑ _3::String, ↑ _4::String) in Main at REPL[7]:2 +2 X 1 ── %1 = %new(Base.RefValue{String}, _2)::Base.RefValue{String} │╻╷╷ Ref +3 *′ │ %2 = %new(Base.RefValue{String}, _3)::Base.RefValue{String} │╻╷╷ Ref +4 ✓′ └─── %3 = %new(SafeRef{String}, _4)::SafeRef{String} │╻╷ SafeRef +5 ◌ 2 ── %4 = \$(Expr(:enter, #8)) │ + ✓′ │ %5 = ϒ (%3)::SafeRef{String} │ + *′ └─── %6 = ϒ (%2)::Base.RefValue{String} │ +6 ◌ 3 ── %7 = Base.isdefined(%1, :x)::Bool │╻╷ get′ + ◌ └─── goto #5 if not %7 ││ + X 4 ── Base.getfield(%1, :x)::String ││╻ getindex + ◌ └─── goto #6 ││ + ◌ 5 ── Main.throw(%1)::Union{} ││ + ◌ └─── unreachable ││ +7 ◌ 6 ── nothing::typeof(Core.sizeof) │╻ sizeof + ◌ │ nothing::Int64 ││ + ◌ └─── \$(Expr(:leave, 1)) │ + ◌ 7 ── goto #10 │ + ✓′ 8 ── %17 = φᶜ (%5)::SafeRef{String} │ + *′ │ %18 = φᶜ (%6)::Base.RefValue{String} │ + ◌ └─── \$(Expr(:leave, 1)) │ + X 9 ── %20 = \$(Expr(:the_exception))::Any │ +9 ◌ │ (Main.g = %20)::Any │ + ◌ └─── \$(Expr(:pop_exception, :(%4)))::Any │ +11 ✓′ 10 ┄ %23 = φ (#7 => %3, #9 => %17)::SafeRef{String} │ + *′ │ %24 = φ (#7 => %2, #9 => %18)::Base.RefValue{String} │ + ◌ │ %25 = Base.isdefined(%24, :x)::Bool ││╻ isassigned + ◌ └─── goto #12 if not %25 ││ + ↑ 11 ─ %27 = Base.getfield(%24, :x)::String │││╻ getproperty + ◌ └─── goto #13 ││ + ◌ 12 ─ Main.throw(%24)::Union{} ││ + ◌ └─── unreachable ││ +12 ↑ 13 ─ %31 = Base.getfield(%23, :x)::String │╻╷╷ get′ +13 ↑ │ %32 = Core.tuple(%27, %31)::Tuple{String, String} │ + ◌ └─── return %32 │ +``` + +The symbols in the side of each call argument and SSA statements represents the following meaning: +- `◌`: this value is not analyzed because escape information of it won't be used anyway (when the object is `isbitstype` for example) +- `✓`: this value never escapes (`has_no_escape(result.state[x])` holds) +- `↑`: this value can escape to the caller via return (`has_return_escape(result.state[x])` holds) +- `X`: this value can escape to somewhere the escape analysis can't reason about like escapes to a global memory (`has_all_escape(result.state[x])` holds) +- `*`: this value's escape state is between the `ReturnEscape` and `AllEscape` in the `EscapeInfo`, e.g. it has unhandled `ThrownEscape` +and additional `′` indicates that field analysis has been done successfully on that value. + +For testing, escape information of each call argument and SSA value can be inspected programmatically as like: +```julia +julia> result.state[Core.Argument(3)] +ReturnEscape + +julia> result.state[Core.SSAValue(3)] +NoEscape′ +``` +""" +function code_escapes(@nospecialize(args...); + world = get_world_counter(), + interp = Core.Compiler.NativeInterpreter(world)) + interp = EscapeAnalyzer(interp) + results = code_typed(args...; optimize=true, world, interp) + isone(length(results)) || throw(ArgumentError("`code_escapes` only supports single analysis result")) + return EscapeResult(interp.ir, interp.state, interp.linfo) +end + +# AbstractInterpreter +# ------------------- + +# imports +import .CC: + AbstractInterpreter, + NativeInterpreter, + WorldView, + WorldRange, + InferenceParams, + OptimizationParams, + get_world_counter, + get_inference_cache, + lock_mi_inference, + unlock_mi_inference, + add_remark!, + may_optimize, + may_compress, + may_discard_trees, + verbose_stmt_info, + code_cache, + @timeit, + get_inference_cache, + convert_to_ircode, + slot2reg, + compact!, + ssa_inlining_pass!, + sroa_pass!, + adce_pass!, + type_lift_pass!, + JLOptions, + verify_ir, + verify_linetable +# usings +import Core: + CodeInstance, MethodInstance, CodeInfo +import .CC: + OptimizationState, IRCode +import .EA: + analyze_escapes, cache_escapes! + +mutable struct EscapeAnalyzer{State} <: AbstractInterpreter + native::NativeInterpreter + ir::IRCode + state::State + linfo::MethodInstance + EscapeAnalyzer(native::NativeInterpreter) = new{EscapeState}(native) +end + +CC.InferenceParams(interp::EscapeAnalyzer) = InferenceParams(interp.native) +CC.OptimizationParams(interp::EscapeAnalyzer) = OptimizationParams(interp.native) +CC.get_world_counter(interp::EscapeAnalyzer) = get_world_counter(interp.native) + +CC.lock_mi_inference(::EscapeAnalyzer, ::MethodInstance) = nothing +CC.unlock_mi_inference(::EscapeAnalyzer, ::MethodInstance) = nothing + +CC.add_remark!(interp::EscapeAnalyzer, sv, s) = add_remark!(interp.native, sv, s) + +CC.may_optimize(interp::EscapeAnalyzer) = may_optimize(interp.native) +CC.may_compress(interp::EscapeAnalyzer) = may_compress(interp.native) +CC.may_discard_trees(interp::EscapeAnalyzer) = may_discard_trees(interp.native) +CC.verbose_stmt_info(interp::EscapeAnalyzer) = verbose_stmt_info(interp.native) + +CC.get_inference_cache(interp::EscapeAnalyzer) = get_inference_cache(interp.native) + +const GLOBAL_CODE_CACHE = IdDict{MethodInstance,CodeInstance}() +__clear_code_cache!() = empty!(GLOBAL_CODE_CACHE) + +function CC.code_cache(interp::EscapeAnalyzer) + worlds = WorldRange(get_world_counter(interp)) + return WorldView(GlobalCache(), worlds) +end + +struct GlobalCache end + +CC.haskey(wvc::WorldView{GlobalCache}, mi::MethodInstance) = haskey(GLOBAL_CODE_CACHE, mi) + +CC.get(wvc::WorldView{GlobalCache}, mi::MethodInstance, default) = get(GLOBAL_CODE_CACHE, mi, default) + +CC.getindex(wvc::WorldView{GlobalCache}, mi::MethodInstance) = getindex(GLOBAL_CODE_CACHE, mi) + +function CC.setindex!(wvc::WorldView{GlobalCache}, ci::CodeInstance, mi::MethodInstance) + GLOBAL_CODE_CACHE[mi] = ci + add_callback!(mi) # register the callback on invalidation + return nothing +end + +function add_callback!(linfo) + if !isdefined(linfo, :callbacks) + linfo.callbacks = Any[invalidate_cache!] + else + if !any(@nospecialize(cb)->cb===invalidate_cache!, linfo.callbacks) + push!(linfo.callbacks, invalidate_cache!) + end + end + return nothing +end + +function invalidate_cache!(replaced, max_world, depth = 0) + delete!(GLOBAL_CODE_CACHE, replaced) + + if isdefined(replaced, :backedges) + for mi in replaced.backedges + mi = mi::MethodInstance + if !haskey(GLOBAL_CODE_CACHE, mi) + continue # otherwise fall into infinite loop + end + invalidate_cache!(mi, max_world, depth+1) + end + end + return nothing +end + +function CC.optimize(interp::EscapeAnalyzer, opt::OptimizationState, params::OptimizationParams, @nospecialize(result)) + ir = run_passes_with_ea(interp, opt.src, opt) + return CC.finish(interp, opt, params, ir, result) +end + +function run_passes_with_ea(interp::EscapeAnalyzer, ci::CodeInfo, sv::OptimizationState) + @timeit "convert" ir = convert_to_ircode(ci, sv) + @timeit "slot2reg" ir = slot2reg(ir, ci, sv) + # TODO: Domsorting can produce an updated domtree - no need to recompute here + @timeit "compact 1" ir = compact!(ir) + @timeit "Inlining" ir = ssa_inlining_pass!(ir, ir.linetable, sv.inlining, ci.propagate_inbounds) + # @timeit "verify 2" verify_ir(ir) + @timeit "compact 2" ir = compact!(ir) + nargs = let def = sv.linfo.def; isa(def, Method) ? Int(def.nargs) : 0; end + local state + try + @timeit "collect escape information" state = analyze_escapes(ir, nargs) + catch err + @info "error happened within `analyze_escapes`, insepct `Main.ir` and `Main.nargs`" + @eval Main (ir = $ir; nargs = $nargs) + rethrow(err) + end + cacheir = Core.Compiler.copy(ir) + # cache this result + cache_escapes!(sv.linfo, state, cacheir) + # return back the result + interp.ir = cacheir + interp.state = state + interp.linfo = sv.linfo + @timeit "SROA" ir = sroa_pass!(ir, nargs) + @timeit "ADCE" ir = adce_pass!(ir) + @timeit "type lift" ir = type_lift_pass!(ir) + @timeit "compact 3" ir = compact!(ir) + if JLOptions().debug_level == 2 + @timeit "verify 3" (verify_ir(ir); verify_linetable(ir.linetable)) + end + return ir +end + +# printing +# -------- + +import Core: Argument, SSAValue +import .CC: widenconst, singleton_type +import .EA: EscapeInfo, EscapeState, ⊑, ⊏ + +# in order to run a whole analysis from ground zero (e.g. for benchmarking, etc.) +__clear_caches!() = (__clear_code_cache!(); EA.__clear_escape_cache!()) + +function get_name_color(x::EscapeInfo, symbol::Bool = false) + getname(x) = string(nameof(x)) + if x === EA.⊥ + name, color = (getname(EA.NotAnalyzed), "◌"), :plain + elseif EA.has_no_escape(x) + name, color = (getname(EA.NoEscape), "✓"), :green + elseif EA.has_all_escape(x) + name, color = (getname(EA.AllEscape), "X"), :red + elseif EA.NoEscape() ⊏ (EA.ignore_thrownescapes ∘ EA.ignore_aliasinfo)(x) ⊑ EA.AllReturnEscape() + name = (getname(EA.ReturnEscape), "↑") + color = EA.has_thrown_escape(x) ? :yellow : :cyan + else + name = (nothing, "*") + color = EA.has_thrown_escape(x) ? :yellow : :bold + end + name = symbol ? last(name) : first(name) + if name !== nothing && !isa(x.AliasInfo, Bool) + name = string(name, "′") + end + return name, color +end + +# pcs = sprint(show, collect(x.EscapeSites); context=:limit=>true) +function Base.show(io::IO, x::EscapeInfo) + name, color = get_name_color(x) + if isnothing(name) + Base.@invoke show(io::IO, x::Any) + else + printstyled(io, name; color) + end +end +function Base.show(io::IO, ::MIME"application/prs.juno.inline", x::EscapeInfo) + name, color = get_name_color(x) + if isnothing(name) + return x # use fancy tree-view + else + printstyled(io, name; color) + end +end + +struct EscapeResult + ir::IRCode + state::EscapeState + linfo::Union{Nothing,MethodInstance} + EscapeResult(ir::IRCode, state::EscapeState, linfo::Union{Nothing,MethodInstance} = nothing) = + new(ir, state, linfo) +end +Base.show(io::IO, result::EscapeResult) = print_with_info(io, result.ir, result.state, result.linfo) +@eval Base.iterate(res::EscapeResult, state=1) = + return state > $(fieldcount(EscapeResult)) ? nothing : (getfield(res, state), state+1) + +# adapted from https://github.com/JuliaDebug/LoweredCodeUtils.jl/blob/4612349432447e868cf9285f647108f43bd0a11c/src/codeedges.jl#L881-L897 +function print_with_info(io::IO, + ir::IRCode, state::EscapeState, linfo::Union{Nothing,MethodInstance}) + # print escape information on SSA values + function preprint(io::IO) + ft = ir.argtypes[1] + f = singleton_type(ft) + if f === nothing + f = widenconst(ft) + end + print(io, f, '(') + for i in 1:state.nargs + arg = state[Argument(i)] + i == 1 && continue + c, color = get_name_color(arg, true) + printstyled(io, c, ' ', '_', i, "::", ir.argtypes[i]; color) + i ≠ state.nargs && print(io, ", ") + end + print(io, ')') + if !isnothing(linfo) + def = linfo.def + printstyled(io, " in ", (isa(def, Module) ? (def,) : (def.module, " at ", def.file, ':', def.line))...; color=:bold) + end + println(io) + end + + # print escape information on SSA values + # nd = ndigits(length(ssavalues)) + function preprint(io::IO, idx::Int) + c, color = get_name_color(state[SSAValue(idx)], true) + # printstyled(io, lpad(idx, nd), ' ', c, ' '; color) + printstyled(io, rpad(c, 2), ' '; color) + end + + print_with_info(preprint, (args...)->nothing, io, ir) +end + +function print_with_info(preprint, postprint, io::IO, ir::IRCode) + io = IOContext(io, :displaysize=>displaysize(io)) + used = Base.IRShow.stmts_used(io, ir) + # line_info_preprinter = Base.IRShow.lineinfo_disabled + line_info_preprinter = function (io::IO, indent::String, idx::Int) + r = Base.IRShow.inline_linfo_printer(ir)(io, indent, idx) + idx ≠ 0 && preprint(io, idx) + return r + end + line_info_postprinter = Base.IRShow.default_expr_type_printer + preprint(io) + bb_idx_prev = bb_idx = 1 + for idx = 1:length(ir.stmts) + preprint(io, idx) + bb_idx = Base.IRShow.show_ir_stmt(io, ir, idx, line_info_preprinter, line_info_postprinter, used, ir.cfg, bb_idx) + postprint(io, idx, bb_idx != bb_idx_prev) + bb_idx_prev = bb_idx + end + max_bb_idx_size = ndigits(length(ir.cfg.blocks)) + line_info_preprinter(io, " "^(max_bb_idx_size + 2), 0) + postprint(io) + return nothing +end + +end # module EAUtils diff --git a/base/compiler/ssair/EscapeAnalysis/EscapeAnalysis.jl b/base/compiler/ssair/EscapeAnalysis/EscapeAnalysis.jl new file mode 100644 index 0000000000000..8b5c3c2d3fbde --- /dev/null +++ b/base/compiler/ssair/EscapeAnalysis/EscapeAnalysis.jl @@ -0,0 +1,1604 @@ +baremodule EscapeAnalysis + +export + analyze_escapes, + cache_escapes!, + getaliases, + isaliased, + has_no_escape, + has_arg_escape, + has_return_escape, + has_thrown_escape, + has_all_escape + +const _TOP_MOD = ccall(:jl_base_relative_to, Any, (Any,), EscapeAnalysis)::Module + +# imports +import ._TOP_MOD: ==, getindex, setindex! +# usings +import Core: + MethodInstance, Const, Argument, SSAValue, PiNode, PhiNode, UpsilonNode, PhiCNode, + ReturnNode, GotoNode, GotoIfNot, SimpleVector, sizeof, ifelse, arrayset, arrayref, + arraysize +import ._TOP_MOD: # Base definitions + @__MODULE__, @eval, @assert, @nospecialize, @inbounds, @inline, @noinline, @label, @goto, + !, !==, !=, ≠, +, -, ≤, <, ≥, >, &, |, include, error, missing, copy, + Vector, BitSet, IdDict, IdSet, UnitRange, ∪, ⊆, ∩, :, ∈, ∉, in, length, get, first, last, + isempty, isassigned, pop!, push!, pushfirst!, empty!, max, min, Csize_t +import Core.Compiler: # Core.Compiler specific definitions + isbitstype, isexpr, is_meta_expr_head, println, + IRCode, IR_FLAG_EFFECT_FREE, widenconst, argextype, singleton_type, fieldcount_noerror, + try_compute_field, try_compute_fieldidx, hasintersect, ⊑ as ⊑ₜ, intrinsic_nothrow, + array_builtin_common_typecheck, arrayset_typecheck, setfield!_nothrow, compute_trycatch + +if _TOP_MOD !== Core.Compiler + include(@__MODULE__, "disjoint_set.jl") +else + include(@__MODULE__, "compiler/ssair/EscapeAnalysis/disjoint_set.jl") +end + +const AInfo = BitSet # XXX better to be IdSet{Int}? +struct Indexable + array::Bool + infos::Vector{AInfo} +end +struct Unindexable + array::Bool + info::AInfo +end +function merge_to_unindexable(info::AInfo, infos::Vector{AInfo}) + for i = 1:length(infos) + info = info ∪ infos[i] + end + return info +end +merge_to_unindexable(infos::Vector{AInfo}) = merge_to_unindexable(AInfo(), infos) + +const LivenessSet = BitSet + +""" + x::EscapeInfo + +A lattice for escape information, which holds the following properties: +- `x.Analyzed::Bool`: not formally part of the lattice, only indicates `x` has not been analyzed or not +- `x.ReturnEscape::Bool`: indicates `x` can escape to the caller via return +- `x.ThrownEscape::BitSet`: records SSA statements numbers where `x` can be thrown as exception: + this information will be used by `escape_exception!` to propagate potential escapes via exception +- `x.AliasInfo::Union{Indexable,Unindexable,Bool}`: maintains all possible values + that can be aliased to fields or array elements of `x`: + * `x.AliasInfo === false` indicates the fields/elements of `x` isn't analyzed yet + * `x.AliasInfo === true` indicates the fields/elements of `x` can't be analyzed, + e.g. the type of `x` is not known or is not concrete and thus its fields/elements + can't be known precisely + * `x.AliasInfo::Indexable` records all the possible values that can be aliased to fields/elements of `x` with precise index information + * `x.AliasInfo::Unindexable` records all the possible values that can be aliased to fields/elements of `x` without precise index information +- `x.Liveness::BitSet`: records SSA statement numbers where `x` should be live, e.g. + to be used as a call argument, to be returned to a caller, or preserved for `:foreigncall`. + `0 ∈ x.Liveness` has the special meaning that it's a call argument of the currently analyzed + call frame (and thus it's visible from the caller immediately). +- `x.ArgEscape::Int` (not implemented yet): indicates it will escape to the caller through + `setfield!` on argument(s) + * `-1` : no escape + * `0` : unknown or multiple + * `n` : through argument N + +There are utility constructors to create common `EscapeInfo`s, e.g., +- `NoEscape()`: the bottom(-like) element of this lattice, meaning it won't escape to anywhere +- `AllEscape()`: the topmost element of this lattice, meaning it will escape to everywhere + +`analyze_escapes` will transition these elements from the bottom to the top, +in the same direction as Julia's native type inference routine. +An abstract state will be initialized with the bottom(-like) elements: +- the call arguments are initialized as `ArgEscape()`, whose `Liveness` property includes `0` + to indicate that it is passed as a call argument and visible from a caller immediately +- the other states are initialized as `NotAnalyzed()`, which is a special lattice element that + is slightly lower than `NoEscape`, but at the same time doesn't represent any meaning + other than it's not analyzed yet (thus it's not formally part of the lattice) +""" +struct EscapeInfo + Analyzed::Bool + ReturnEscape::Bool + ThrownEscape::LivenessSet + AliasInfo #::Union{Indexable,Unindexable,Bool} + Liveness::LivenessSet + # TODO: ArgEscape::Int + + function EscapeInfo( + Analyzed::Bool, + ReturnEscape::Bool, + ThrownEscape::LivenessSet, + AliasInfo#=::Union{Indexable,Unindexable,Bool}=#, + Liveness::LivenessSet, + ) + @nospecialize AliasInfo + return new( + Analyzed, + ReturnEscape, + ThrownEscape, + AliasInfo, + Liveness, + ) + end + function EscapeInfo( + x::EscapeInfo, + # non-concrete fields should be passed as default arguments + # in order to avoid allocating non-concrete `NamedTuple`s + AliasInfo#=::Union{Indexable,Unindexable,Bool}=# = x.AliasInfo; + Analyzed::Bool = x.Analyzed, + ReturnEscape::Bool = x.ReturnEscape, + ThrownEscape::LivenessSet = x.ThrownEscape, + Liveness::LivenessSet = x.Liveness, + ) + @nospecialize AliasInfo + return new( + Analyzed, + ReturnEscape, + ThrownEscape, + AliasInfo, + Liveness, + ) + end +end + +# precomputed default values in order to eliminate computations at each callsite +const BOT_THROWN_ESCAPE = LivenessSet() +const TOP_THROWN_ESCAPE = LivenessSet(1:100_000) + +const BOT_ALIAS_INFO = false +const TOP_ALIAS_INFO = true + +const BOT_LIVENESS = LivenessSet() +const TOP_LIVENESS = LivenessSet(0:100_000) +const ARG_LIVENESS = LivenessSet(0) + +# the constructors +NotAnalyzed() = EscapeInfo(false, false, BOT_THROWN_ESCAPE, BOT_ALIAS_INFO, BOT_LIVENESS) # not formally part of the lattice +NoEscape() = EscapeInfo(true, false, BOT_THROWN_ESCAPE, BOT_ALIAS_INFO, BOT_LIVENESS) +ArgEscape() = EscapeInfo(true, false, BOT_THROWN_ESCAPE, TOP_ALIAS_INFO, ARG_LIVENESS) # TODO allow interprocedural alias analysis? +ReturnEscape(pc::Int) = EscapeInfo(true, true, BOT_THROWN_ESCAPE, BOT_ALIAS_INFO, LivenessSet(pc)) +AllReturnEscape() = EscapeInfo(true, true, BOT_THROWN_ESCAPE, BOT_ALIAS_INFO, TOP_LIVENESS) +ThrownEscape(pc::Int) = EscapeInfo(true, false, LivenessSet(pc), BOT_ALIAS_INFO, BOT_LIVENESS) +AllEscape() = EscapeInfo(true, true, TOP_THROWN_ESCAPE, TOP_ALIAS_INFO, TOP_LIVENESS) + +const ⊥, ⊤ = NotAnalyzed(), AllEscape() + +# Convenience names for some ⊑ queries +has_no_escape(x::EscapeInfo) = !x.ReturnEscape && isempty(x.ThrownEscape) +has_arg_escape(x::EscapeInfo) = 0 in x.Liveness +has_return_escape(x::EscapeInfo) = x.ReturnEscape +has_return_escape(x::EscapeInfo, pc::Int) = x.ReturnEscape && pc in x.Liveness +has_thrown_escape(x::EscapeInfo) = !isempty(x.ThrownEscape) +has_thrown_escape(x::EscapeInfo, pc::Int) = pc in x.ThrownEscape +has_all_escape(x::EscapeInfo) = ⊤ ⊑ x + +# utility lattice constructors +ignore_thrownescapes(x::EscapeInfo) = EscapeInfo(x; ThrownEscape=BOT_THROWN_ESCAPE) +ignore_aliasinfo(x::EscapeInfo) = EscapeInfo(x, BOT_ALIAS_INFO) +ignore_liveness(x::EscapeInfo) = EscapeInfo(x; Liveness=BOT_LIVENESS) + +# we need to make sure this `==` operator corresponds to lattice equality rather than object equality, +# otherwise `propagate_changes` can't detect the convergence +x::EscapeInfo == y::EscapeInfo = begin + # fast pass: better to avoid top comparison + x === y && return true + x.Analyzed === y.Analyzed || return false + x.ReturnEscape === y.ReturnEscape || return false + xt, yt = x.ThrownEscape, y.ThrownEscape + if xt === TOP_THROWN_ESCAPE + yt === TOP_THROWN_ESCAPE || return false + elseif yt === TOP_THROWN_ESCAPE + return false # x.ThrownEscape === TOP_THROWN_ESCAPE + else + xt == yt || return false + end + xa, ya = x.AliasInfo, y.AliasInfo + if isa(xa, Bool) + xa === ya || return false + elseif isa(xa, Indexable) + isa(ya, Indexable) || return false + xa.array === ya.array || return false + xa.infos == ya.infos || return false + else + xa = xa::Unindexable + isa(ya, Unindexable) || return false + xa.array === ya.array || return false + xa.info == ya.info || return false + end + xl, yl = x.Liveness, y.Liveness + if xl === TOP_LIVENESS + yl === TOP_LIVENESS || return false + elseif yl === TOP_LIVENESS + return false # x.Liveness === TOP_LIVENESS + else + xl == yl || return false + end + return true +end + +""" + x::EscapeInfo ⊑ y::EscapeInfo -> Bool + +The non-strict partial order over `EscapeInfo`. +""" +x::EscapeInfo ⊑ y::EscapeInfo = begin + # fast pass: better to avoid top comparison + if y === ⊤ + return true + elseif x === ⊤ + return false # return y === ⊤ + elseif x === ⊥ + return true + elseif y === ⊥ + return false # return x === ⊥ + end + x.Analyzed ≤ y.Analyzed || return false + x.ReturnEscape ≤ y.ReturnEscape || return false + xt, yt = x.ThrownEscape, y.ThrownEscape + if xt === TOP_THROWN_ESCAPE + yt !== TOP_THROWN_ESCAPE && return false + elseif yt !== TOP_THROWN_ESCAPE + xt ⊆ yt || return false + end + xa, ya = x.AliasInfo, y.AliasInfo + if isa(xa, Bool) + xa && ya !== true && return false + elseif isa(xa, Indexable) + if isa(ya, Indexable) + xa.array === ya.array || return false + xinfos, yinfos = xa.infos, ya.infos + xn, yn = length(xinfos), length(yinfos) + xn > yn && return false + for i in 1:xn + xinfos[i] ⊆ yinfos[i] || return false + end + elseif isa(ya, Unindexable) + xa.array === ya.array || return false + xinfos, yinfo = xa.infos, ya.info + for i = length(xf) + xinfos[i] ⊆ yinfo || return false + end + else + ya === true || return false + end + else + xa = xa::Unindexable + if isa(ya, Unindexable) + xa.array === ya.array || return false + xinfo, yinfo = xa.info, ya.info + xinfo ⊆ yinfo || return false + else + ya === true || return false + end + end + xl, yl = x.Liveness, y.Liveness + if xl === TOP_LIVENESS + yl !== TOP_LIVENESS && return false + elseif yl !== TOP_LIVENESS + xl ⊆ yl || return false + end + return true +end + +""" + x::EscapeInfo ⊏ y::EscapeInfo -> Bool + +The strict partial order over `EscapeInfo`. +This is defined as the irreflexive kernel of `⊏`. +""" +x::EscapeInfo ⊏ y::EscapeInfo = x ⊑ y && !(y ⊑ x) + +""" + x::EscapeInfo ⋤ y::EscapeInfo -> Bool + +This order could be used as a slightly more efficient version of the strict order `⊏`, +where we can safely assume `x ⊑ y` holds. +""" +x::EscapeInfo ⋤ y::EscapeInfo = !(y ⊑ x) + +""" + x::EscapeInfo ⊔ y::EscapeInfo -> EscapeInfo + +Computes the join of `x` and `y` in the partial order defined by `EscapeInfo`. +""" +x::EscapeInfo ⊔ y::EscapeInfo = begin + # fast pass: better to avoid top join + if x === ⊤ || y === ⊤ + return ⊤ + elseif x === ⊥ + return y + elseif y === ⊥ + return x + end + xt, yt = x.ThrownEscape, y.ThrownEscape + if xt === TOP_THROWN_ESCAPE || yt === TOP_THROWN_ESCAPE + ThrownEscape = TOP_THROWN_ESCAPE + elseif xt === BOT_THROWN_ESCAPE + ThrownEscape = yt + elseif yt === BOT_THROWN_ESCAPE + ThrownEscape = xt + else + ThrownEscape = xt ∪ yt + end + xa, ya = x.AliasInfo, y.AliasInfo + if xa === true || ya === true + AliasInfo = true + elseif xa === false + AliasInfo = ya + elseif ya === false + AliasInfo = xa + elseif isa(xa, Indexable) + if isa(ya, Indexable) && xa.array === ya.array + xinfos, yinfos = xa.infos, ya.infos + xn, yn = length(xinfos), length(yinfos) + nmax, nmin = max(xn, yn), min(xn, yn) + infos = Vector{AInfo}(undef, nmax) + for i in 1:nmax + if i > nmin + infos[i] = (xn > yn ? xinfos : yinfos)[i] + else + infos[i] = xinfos[i] ∪ yinfos[i] + end + end + AliasInfo = Indexable(xa.array, infos) + elseif isa(ya, Unindexable) && xa.array === ya.array + xinfos, yinfo = xa.infos, ya.info + info = merge_to_unindexable(yinfo, xinfos) + AliasInfo = Unindexable(xa.array, info) + else + AliasInfo = true # handle conflicting case conservatively + end + else + xa = xa::Unindexable + if isa(ya, Indexable) && xa.array === ya.array + xinfo, yinfos = xa.info, ya.infos + info = merge_to_unindexable(xinfo, yinfos) + AliasInfo = Unindexable(xa.array, info) + elseif isa(ya, Unindexable) && xa.array === ya.array + xinfo, yinfo = xa.info, ya.info + info = xinfo ∪ yinfo + AliasInfo = Unindexable(xa.array, info) + else + AliasInfo = true # handle conflicting case conservatively + end + end + xl, yl = x.Liveness, y.Liveness + if xl === TOP_LIVENESS || yl === TOP_LIVENESS + Liveness = TOP_LIVENESS + elseif xl === BOT_LIVENESS + Liveness = yl + elseif yl === BOT_LIVENESS + Liveness = xl + else + Liveness = xl ∪ yl + end + return EscapeInfo( + x.Analyzed | y.Analyzed, + x.ReturnEscape | y.ReturnEscape, + ThrownEscape, + AliasInfo, + Liveness, + ) +end + +# TODO setup a more effient struct for cache +# which can discard escape information on SSS values and arguments that don't join dispatch signature + +const AliasSet = IntDisjointSet{Int} + +""" + estate::EscapeState + +Extended lattice that maps arguments and SSA values to escape information represented as `EscapeInfo`. +Escape information imposed on SSA IR element `x` can be retrieved by `estate[x]`. +""" +struct EscapeState + escapes::Vector{EscapeInfo} + aliasset::AliasSet + nargs::Int +end +function EscapeState(nargs::Int, nstmts::Int) + escapes = EscapeInfo[ + 1 ≤ i ≤ nargs ? ArgEscape() : ⊥ for i in 1:(nargs+nstmts)] + aliaset = AliasSet(nargs+nstmts) + return EscapeState(escapes, aliaset, nargs) +end +function getindex(estate::EscapeState, @nospecialize(x)) + if isa(x, Argument) || isa(x, SSAValue) + return estate.escapes[iridx(x, estate)] + else + return nothing + end +end +function setindex!(estate::EscapeState, v::EscapeInfo, @nospecialize(x)) + if isa(x, Argument) || isa(x, SSAValue) + estate.escapes[iridx(x, estate)] = v + end + return estate +end + +""" + iridx(x, estate::EscapeState) -> xidx::Union{Int,Nothing} + +Tries to convert analyzable IR element `x::Union{Argument,SSAValue}` to +its unique identifier number `xidx` that is valid in the analysis context of `estate`. +Returns `nothing` if `x` isn't maintained by `estate` and thus unanalyzable (e.g. `x::GlobalRef`). + +`irval` is the inverse function of `iridx` (not formally), i.e. +`irval(iridx(x::Union{Argument,SSAValue}, state), state) === x`. +""" +function iridx(@nospecialize(x), estate::EscapeState) + if isa(x, Argument) + xidx = x.n + @assert 1 ≤ xidx ≤ estate.nargs "invalid Argument" + elseif isa(x, SSAValue) + xidx = x.id + estate.nargs + else + return nothing + end + return xidx +end + +""" + irval(xidx::Int, estate::EscapeState) -> x::Union{Argument,SSAValue} + +Converts its unique identifier number `xidx` to the original IR element `x::Union{Argument,SSAValue}` +that is analyzable in the context of `estate`. + +`iridx` is the inverse function of `irval` (not formally), i.e. +`iridx(irval(xidx, state), state) === xidx`. +""" +function irval(xidx::Int, estate::EscapeState) + x = xidx > estate.nargs ? SSAValue(xidx-estate.nargs) : Argument(xidx) + return x +end + +function getaliases(x::Union{Argument,SSAValue}, estate::EscapeState) + xidx = iridx(x, estate) + aliases = getaliases(xidx, estate) + aliases === nothing && return nothing + return Union{Argument,SSAValue}[irval(aidx, estate) for aidx in aliases] +end +function getaliases(xidx::Int, estate::EscapeState) + aliasset = estate.aliasset + root = find_root!(aliasset, xidx) + if xidx ≠ root || aliasset.ranks[xidx] > 0 + # the size of this alias set containing `key` is larger than 1, + # collect the entire alias set + aliases = Int[] + for aidx in 1:length(aliasset.parents) + if aliasset.parents[aidx] == root + push!(aliases, aidx) + end + end + return aliases + else + return nothing + end +end + +isaliased(x::Union{Argument,SSAValue}, y::Union{Argument,SSAValue}, estate::EscapeState) = + isaliased(iridx(x, estate), iridx(y, estate), estate) +isaliased(xidx::Int, yidx::Int, estate::EscapeState) = + in_same_set(estate.aliasset, xidx, yidx) + +""" + ArgEscapeInfo(x::EscapeInfo) -> x′::ArgEscapeInfo + +The data structure for caching `x::EscapeInfo` for interprocedural propagation, +which is slightly more efficient than the original `x::EscapeInfo` object. +""" +struct ArgEscapeInfo + AllEscape::Bool + ReturnEscape::Bool + ThrownEscape::Bool + function ArgEscapeInfo(x::EscapeInfo) + x === ⊤ && return new(true, true, true) + ThrownEscape = isempty(x.ThrownEscape) ? false : true + return new(false, x.ReturnEscape, ThrownEscape) + end +end + +""" + cache_escapes!(linfo::MethodInstance, estate::EscapeState, _::IRCode) + +Transforms escape information of `estate` for interprocedural propagation, +and caches it in a global cache that can then be looked up later when +`linfo` callsite is seen again. +""" +function cache_escapes! end + +# when working outside of Core.Compiler, cache as much as information for later inspection and debugging +if _TOP_MOD !== Core.Compiler + struct EscapeCache + cache::Vector{ArgEscapeInfo} + state::EscapeState # preserved just for debugging purpose + ir::IRCode # preserved just for debugging purpose + end + const GLOBAL_ESCAPE_CACHE = IdDict{MethodInstance,EscapeCache}() + function cache_escapes!(linfo::MethodInstance, estate::EscapeState, cacheir::IRCode) + cache = EscapeCache(to_interprocedural(estate), estate, cacheir) + GLOBAL_ESCAPE_CACHE[linfo] = cache + return cache + end + argescapes_from_cache(cache::EscapeCache) = cache.cache +else + const GLOBAL_ESCAPE_CACHE = IdDict{MethodInstance,Vector{ArgEscapeInfo}}() + function cache_escapes!(linfo::MethodInstance, estate::EscapeState, _::IRCode) + cache = to_interprocedural(estate) + GLOBAL_ESCAPE_CACHE[linfo] = cache + return cache + end + argescapes_from_cache(cache::Vector{ArgEscapeInfo}) = cache +end + +function to_interprocedural(estate::EscapeState) + cache = Vector{ArgEscapeInfo}(undef, estate.nargs) + for i = 1:estate.nargs + cache[i] = ArgEscapeInfo(estate.escapes[i]) + end + return cache +end + +__clear_escape_cache!() = empty!(GLOBAL_ESCAPE_CACHE) + +abstract type Change end +struct EscapeChange <: Change + xidx::Int + xinfo::EscapeInfo +end +struct AliasChange <: Change + xidx::Int + yidx::Int +end +struct LivenessChange <: Change + xidx::Int + livepc::Int +end +const Changes = Vector{Change} + +struct AnalysisState + ir::IRCode + estate::EscapeState + changes::Changes +end + +function getinst(ir::IRCode, idx::Int) + nstmts = length(ir.stmts) + if idx ≤ nstmts + return ir.stmts[idx] + else + return ir.new_nodes.stmts[idx - nstmts] + end +end + +""" + analyze_escapes(ir::IRCode, nargs::Int) -> estate::EscapeState + +Analyzes escape information in `ir`. +`nargs` is the number of actual arguments of the analyzed call. +""" +function analyze_escapes(ir::IRCode, nargs::Int) + stmts = ir.stmts + nstmts = length(stmts) + length(ir.new_nodes.stmts) + + # only manage a single state, some flow-sensitivity is encoded as `EscapeInfo` properties + estate = EscapeState(nargs, nstmts) + changes = Changes() # stashes changes that happen at current statement + tryregions = compute_tryregions(ir) + astate = AnalysisState(ir, estate, changes) + + local debug_itr_counter = 0 + while true + local anyupdate = false + + for pc in nstmts:-1:1 + stmt = getinst(ir, pc)[:inst] + + # collect escape information + if isa(stmt, Expr) + head = stmt.head + if head === :call + escape_call!(astate, pc, stmt.args) + elseif head === :invoke + escape_invoke!(astate, pc, stmt.args) + elseif head === :new || head === :splatnew + escape_new!(astate, pc, stmt.args) + elseif head === :(=) + lhs, rhs = stmt.args + if isa(lhs, GlobalRef) # global store + add_escape_change!(astate, rhs, ⊤) + else + unexpected_assignment!(ir, pc) + end + elseif head === :foreigncall + escape_foreigncall!(astate, pc, stmt.args) + elseif head === :throw_undef_if_not # XXX when is this expression inserted ? + add_escape_change!(astate, stmt.args[1], ThrownEscape(pc)) + elseif is_meta_expr_head(head) + # meta expressions doesn't account for any usages + continue + elseif head === :enter || head === :leave || head === :the_exception || head === :pop_exception + # ignore these expressions since escapes via exceptions are handled by `escape_exception!` + # `escape_exception!` conservatively propagates `AllEscape` anyway, + # and so escape information imposed on `:the_exception` isn't computed + continue + elseif head === :static_parameter || # this exists statically, not interested in its escape + head === :copyast || # XXX can this account for some escapes? + head === :undefcheck || # XXX can this account for some escapes? + head === :isdefined || # just returns `Bool`, nothing accounts for any escapes + head === :gc_preserve_begin || # `GC.@preserve` expressions themselves won't be used anywhere + head === :gc_preserve_end # `GC.@preserve` expressions themselves won't be used anywhere + continue + else + for x in stmt.args + add_escape_change!(astate, x, ⊤) + end + end + elseif isa(stmt, ReturnNode) + if isdefined(stmt, :val) + add_escape_change!(astate, stmt.val, ReturnEscape(pc)) + end + elseif isa(stmt, PhiNode) + escape_edges!(astate, pc, stmt.values) + elseif isa(stmt, PiNode) + escape_val_ifdefined!(astate, pc, stmt) + elseif isa(stmt, PhiCNode) + escape_edges!(astate, pc, stmt.values) + elseif isa(stmt, UpsilonNode) + escape_val_ifdefined!(astate, pc, stmt) + elseif isa(stmt, GlobalRef) # global load + add_escape_change!(astate, SSAValue(pc), ⊤) + elseif isa(stmt, SSAValue) + escape_val!(astate, pc, stmt) + elseif isa(stmt, Argument) + escape_val!(astate, pc, stmt) + else # otherwise `stmt` can be GotoNode, GotoIfNot, and inlined values etc. + continue + end + + isempty(changes) && continue + + anyupdate |= propagate_changes!(estate, changes) + + empty!(changes) + end + + tryregions !== nothing && escape_exception!(astate, tryregions) + + debug_itr_counter += 1 + + anyupdate || break + end + + # if debug_itr_counter > 2 + # println("[EA] excessive iteration count found ", debug_itr_counter, " (", singleton_type(ir.argtypes[1]), ")") + # end + + return estate +end + +# propagate changes, and check convergence +function propagate_changes!(estate::EscapeState, changes::Changes) + local anychanged = false + for change in changes + if isa(change, EscapeChange) + anychanged |= propagate_escape_change!(estate, change) + elseif isa(change, LivenessChange) + anychanged |= propagate_liveness_change!(estate, change) + else + change = change::AliasChange + anychanged |= propagate_alias_change!(estate, change) + end + end + return anychanged +end + +@inline propagate_escape_change!(estate::EscapeState, change::EscapeChange) = + propagate_escape_change!(⊔, estate, change) + +# allows this to work as lattice join as well as lattice meet +@inline function propagate_escape_change!(@nospecialize(op), + estate::EscapeState, change::EscapeChange) + (; xidx, xinfo) = change + anychanged = _propagate_escape_change!(op, estate, xidx, xinfo) + aliases = getaliases(xidx, estate) + if aliases !== nothing + for aidx in aliases + anychanged |= _propagate_escape_change!(op, estate, aidx, xinfo) + end + end + return anychanged +end + +@inline function _propagate_escape_change!(@nospecialize(op), + estate::EscapeState, xidx::Int, info::EscapeInfo) + old = estate.escapes[xidx] + new = op(old, info) + if old ≠ new + estate.escapes[xidx] = new + return true + end + return false +end + +# propagate Liveness changes separately in order to avoid constructing too many LivenessSet +@inline function propagate_liveness_change!(estate::EscapeState, change::LivenessChange) + (; xidx, livepc) = change + info = estate.escapes[xidx] + Liveness = info.Liveness + Liveness === TOP_LIVENESS && return false + livepc in Liveness && return false + if Liveness === BOT_LIVENESS || Liveness === ARG_LIVENESS + # if this Liveness is a constant, we shouldn't modify it and propagate this change as a new EscapeInfo + Liveness = copy(Liveness) + push!(Liveness, livepc) + estate.escapes[xidx] = EscapeInfo(info; Liveness) + return true + else + # directly modify Liveness property in order to avoid excessive copies + push!(Liveness, livepc) + return true + end +end + +@inline function propagate_alias_change!(estate::EscapeState, change::AliasChange) + (; xidx, yidx) = change + xroot = find_root!(estate.aliasset, xidx) + yroot = find_root!(estate.aliasset, yidx) + if xroot ≠ yroot + union!(estate.aliasset, xroot, yroot) + xinfo = estate.escapes[xidx] + yinfo = estate.escapes[yidx] + xyinfo = xinfo ⊔ yinfo + estate.escapes[xidx] = xyinfo + estate.escapes[yidx] = xyinfo + return true + end + return false +end + +function add_escape_change!(astate::AnalysisState, @nospecialize(x), xinfo::EscapeInfo) + xinfo === ⊥ && return nothing # performance optimization + xidx = iridx(x, astate.estate) + if xidx !== nothing + if !isbitstype(widenconst(argextype(x, astate.ir))) + push!(astate.changes, EscapeChange(xidx, xinfo)) + end + end + return nothing +end + +function add_liveness_change!(astate::AnalysisState, @nospecialize(x), livepc::Int) + xidx = iridx(x, astate.estate) + if xidx !== nothing + if !isbitstype(widenconst(argextype(x, astate.ir))) + push!(astate.changes, LivenessChange(xidx, livepc)) + end + end + return nothing +end + +function add_alias_change!(astate::AnalysisState, @nospecialize(x), @nospecialize(y)) + if isa(x, GlobalRef) + return add_escape_change!(astate, y, ⊤) + elseif isa(y, GlobalRef) + return add_escape_change!(astate, x, ⊤) + end + estate = astate.estate + xidx = iridx(x, estate) + yidx = iridx(y, estate) + if xidx !== nothing && yidx !== nothing && !isaliased(xidx, yidx, astate.estate) + pushfirst!(astate.changes, AliasChange(xidx, yidx)) + end + return nothing +end + +function escape_edges!(astate::AnalysisState, pc::Int, edges::Vector{Any}) + ret = SSAValue(pc) + for i in 1:length(edges) + if isassigned(edges, i) + v = edges[i] + add_alias_change!(astate, ret, v) + end + end +end + +function escape_val_ifdefined!(astate::AnalysisState, pc::Int, x) + if isdefined(x, :val) + escape_val!(astate, pc, x.val) + end +end + +function escape_val!(astate::AnalysisState, pc::Int, @nospecialize(val)) + ret = SSAValue(pc) + add_alias_change!(astate, ret, val) +end + +# NOTE if we don't maintain the alias set that is separated from the lattice state, we can do +# soemthing like below: it essentially incorporates forward escape propagation in our default +# backward propagation, and leads to inefficient convergence that requires more iterations +# # lhs = rhs: propagate escape information of `rhs` to `lhs` +# function escape_alias!(astate::AnalysisState, @nospecialize(lhs), @nospecialize(rhs)) +# if isa(rhs, SSAValue) || isa(rhs, Argument) +# vinfo = astate.estate[rhs] +# else +# return +# end +# add_escape_change!(astate, lhs, vinfo) +# end + +# linear scan to find regions in which potential throws will be caught +function compute_tryregions(ir::IRCode) + tryregions = nothing + for idx in 1:length(ir.stmts) + stmt = ir.stmts[idx][:inst] + if isexpr(stmt, :enter) + tryregions === nothing && (tryregions = UnitRange{Int}[]) + leave_block = stmt.args[1]::Int + leave_pc = first(ir.cfg.blocks[leave_block].stmts) + push!(tryregions, idx:leave_pc) + end + end + for idx in 1:length(ir.new_nodes.stmts) + stmt = ir.new_nodes.stmts[idx][:inst] + @assert !isexpr(stmt, :enter) "try/catch inside new_nodes unsupported" + end + return tryregions +end + +""" + escape_exception!(astate::AnalysisState, tryregions::Vector{UnitRange{Int}}) + +Propagates escapes via exceptions that can happen in `tryregions`. + +Naively it seems enough to propagate escape information imposed on `:the_exception` object, +but actually there are several other ways to access to the exception object such as +`Base.current_exceptions` and manual catch of `rethrow`n object. +For example, escape analysis needs to account for potential escape of the allocated object +via `rethrow_escape!()` call in the example below: +```julia +const Gx = Ref{Any}() +@noinline function rethrow_escape!() + try + rethrow() + catch err + Gx[] = err + end +end +unsafeget(x) = isassigned(x) ? x[] : throw(x) + +code_escapes() do + r = Ref{String}() + try + t = unsafeget(r) + catch err + t = typeof(err) # `err` (which `r` may alias to) doesn't escape here + rethrow_escape!() # `r` can escape here + end + return t +end +``` + +As indicated by the above example, it requires a global analysis in addition to a base escape +analysis to reason about all possible escapes via existing exception interfaces correctly. +For now we conservatively always propagate `AllEscape` to all potentially thrown objects, +since such an additional analysis might not be worthwhile to do given that exception handlings +and error paths usually don't need to be very performance sensitive, and optimizations of +error paths might be very ineffective anyway since they are sometimes "unoptimized" +intentionally for latency reasons. +""" +function escape_exception!(astate::AnalysisState, tryregions::Vector{UnitRange{Int}}) + estate = astate.estate + # NOTE if `:the_exception` is the only way to access the exception, we can do: + # exc = SSAValue(pc) + # excinfo = estate[exc] + excinfo = ⊤ + escapes = estate.escapes + for i in 1:length(escapes) + x = escapes[i] + xt = x.ThrownEscape + xt === TOP_THROWN_ESCAPE && @goto propagate_exception_escape # fast pass + for pc in x.ThrownEscape + for region in tryregions + pc in region && @goto propagate_exception_escape # early break because of AllEscape + end + end + continue + @label propagate_exception_escape + xval = irval(i, estate) + add_escape_change!(astate, xval, excinfo) + end +end + +function escape_invoke!(astate::AnalysisState, pc::Int, args::Vector{Any}) + linfo = first(args)::MethodInstance + cache = get(GLOBAL_ESCAPE_CACHE, linfo, nothing) + if cache === nothing + for i in 2:length(args) + x = args[i] + add_escape_change!(astate, x, ⊤) + end + else + argescapes = argescapes_from_cache(cache) + ret = SSAValue(pc) + retinfo = astate.estate[ret] # escape information imposed on the call statement + method = linfo.def::Method + nargs = Int(method.nargs) + for i in 2:length(args) + arg = args[i] + if i-1 ≤ nargs + argi = i-1 + else # handle isva signature: COMBAK will this be invalid once we take alias information into account ? + argi = nargs + end + arginfo = argescapes[argi] + info = from_interprocedural(arginfo, retinfo, pc) + if arginfo.ReturnEscape + # if this argument can be "returned", in addition to propagating + # the escape information imposed on this call argument within the callee, + # we should also account for possible aliasing of this argument and the returned value + add_escape_change!(astate, arg, info) + add_alias_change!(astate, ret, arg) + else + # if this is simply passed as the call argument, we can just propagate + # the escape information imposed on this call argument within the callee + add_escape_change!(astate, arg, info) + end + end + end +end + +""" + from_interprocedural(arginfo::ArgEscapeInfo, retinfo::EscapeInfo, pc::Int) -> x::EscapeInfo + +Reinterprets the escape information imposed on the call argument which is cached as `arginfo` +in the context of the caller frame, where `retinfo` is the escape information imposed on +the return value and `pc` is the SSA statement number of the return value. +""" +function from_interprocedural(arginfo::ArgEscapeInfo, retinfo::EscapeInfo, pc::Int) + arginfo.AllEscape && return ⊤ + + ThrownEscape = arginfo.ThrownEscape ? LivenessSet(pc) : BOT_THROWN_ESCAPE + + return EscapeInfo( + #=Analyzed=#true, #=ReturnEscape=#false, ThrownEscape, + # FIXME implement interprocedural memory effect-analysis + # currently, this essentially disables the entire field analysis + # it might be okay from the SROA point of view, since we can't remove the allocation + # as far as it's passed to a callee anyway, but still we may want some field analysis + # for e.g. stack allocation or some other IPO optimizations + #=AliasInfo=#TOP_ALIAS_INFO, #=Liveness=#LivenessSet(pc)) +end + +@noinline function unexpected_assignment!(ir::IRCode, pc::Int) + @eval Main (ir = $ir; pc = $pc) + error("unexpected assignment found: inspect `Main.pc` and `Main.pc`") +end + +function escape_new!(astate::AnalysisState, pc::Int, args::Vector{Any}) + obj = SSAValue(pc) + objinfo = astate.estate[obj] + AliasInfo = objinfo.AliasInfo + nargs = length(args) + if isa(AliasInfo, Bool) + @goto conservative_propagation + elseif isa(AliasInfo, Indexable) && !AliasInfo.array + # fields are known precisely: propagate escape information imposed on recorded possibilities to the exact field values + infos = AliasInfo.infos + nf = length(infos) + objinfo = ignore_aliasinfo(objinfo) + for i in 2:nargs + i-1 > nf && break # may happen when e.g. ϕ-node merges values with different types + arg = args[i] + add_alias_escapes!(astate, arg, infos[i-1]) + push!(infos[i-1], -pc) # record def + # propagate the escape information of this object ignoring field information + add_escape_change!(astate, arg, objinfo) + add_liveness_change!(astate, arg, pc) + end + elseif isa(AliasInfo, Unindexable) && !AliasInfo.array + # fields are known partially: propagate escape information imposed on recorded possibilities to all fields values + info = AliasInfo.info + objinfo = ignore_aliasinfo(objinfo) + for i in 2:nargs + arg = args[i] + add_alias_escapes!(astate, arg, info) + push!(info, -pc) # record def + # propagate the escape information of this object ignoring field information + add_escape_change!(astate, arg, objinfo) + add_liveness_change!(astate, arg, pc) + end + else + # this object has been used as array, but it is allocated as struct here (i.e. should throw) + # update obj's field information and just handle this case conservatively + objinfo = escape_unanalyzable_obj!(astate, obj, objinfo) + @label conservative_propagation + # the fields couldn't be analyzed precisely: propagate the entire escape information + # of this object to all its fields as the most conservative propagation + for i in 2:nargs + arg = args[i] + add_escape_change!(astate, arg, objinfo) + add_liveness_change!(astate, arg, pc) + end + end + if !(getinst(astate.ir, pc)[:flag] & IR_FLAG_EFFECT_FREE ≠ 0) + add_thrown_escapes!(astate, pc, args) + end +end + +function add_alias_escapes!(astate::AnalysisState, @nospecialize(v), ainfo::AInfo) + estate = astate.estate + for aidx in ainfo + aidx < 0 && continue # ignore def + x = SSAValue(aidx) # obviously this won't be true once we implement ArgEscape + add_alias_change!(astate, v, x) + end +end + +function escape_unanalyzable_obj!(astate::AnalysisState, @nospecialize(obj), objinfo::EscapeInfo) + objinfo = EscapeInfo(objinfo, TOP_ALIAS_INFO) + add_escape_change!(astate, obj, objinfo) + return objinfo +end + +function add_thrown_escapes!(astate::AnalysisState, pc::Int, args::Vector{Any}, + first_idx::Int = 1, last_idx::Int = length(args)) + info = ThrownEscape(pc) + for i in first_idx:last_idx + add_escape_change!(astate, args[i], info) + end +end + +function add_liveness_changes!(astate::AnalysisState, pc::Int, args::Vector{Any}, + first_idx::Int = 1, last_idx::Int = length(args)) + for i in first_idx:last_idx + arg = args[i] + add_liveness_change!(astate, arg, pc) + end +end + +function add_fallback_changes!(astate::AnalysisState, pc::Int, args::Vector{Any}, + first_idx::Int = 1, last_idx::Int = length(args)) + info = ThrownEscape(pc) + for i in first_idx:last_idx + arg = args[i] + add_escape_change!(astate, arg, info) + add_liveness_change!(astate, arg, pc) + end +end + +# escape every argument `(args[6:length(args[3])])` and the name `args[1]` +# TODO: we can apply a similar strategy like builtin calls to specialize some foreigncalls +function escape_foreigncall!(astate::AnalysisState, pc::Int, args::Vector{Any}) + nargs = length(args) + if nargs < 6 + # invalid foreigncall, just escape everything + for i = 1:length(args) + add_escape_change!(astate, args[i], ⊤) + end + return + end + argtypes = args[3]::SimpleVector + nargs = length(argtypes) + name = args[1] + nn = normalize(name) + if isa(nn, Symbol) + boundserror_ninds = array_resize_info(nn) + if boundserror_ninds !== nothing + boundserror, ninds = boundserror_ninds + escape_array_resize!(boundserror, ninds, astate, pc, args) + return + end + if is_array_copy(nn) + escape_array_copy!(astate, pc, args) + return + elseif is_array_isassigned(nn) + escape_array_isassigned!(astate, pc, args) + return + end + # if nn === :jl_gc_add_finalizer_th + # # TODO add `FinalizerEscape` ? + # end + end + # NOTE array allocations might have been proven as nothrow (https://github.com/JuliaLang/julia/pull/43565) + nothrow = astate.ir.stmts[pc][:flag] & IR_FLAG_EFFECT_FREE ≠ 0 + name_info = nothrow ? ⊥ : ThrownEscape(pc) + add_escape_change!(astate, name, name_info) + add_liveness_change!(astate, name, pc) + for i = 1:nargs + # we should escape this argument if it is directly called, + # otherwise just impose ThrownEscape if not nothrow + if argtypes[i] === Any + arg_info = ⊤ + else + arg_info = nothrow ? ⊥ : ThrownEscape(pc) + end + add_escape_change!(astate, args[5+i], arg_info) + add_liveness_change!(astate, args[5+i], pc) + end + for i = (5+nargs):length(args) + arg = args[i] + add_escape_change!(astate, arg, ⊥) + add_liveness_change!(astate, arg, pc) + end +end + +normalize(@nospecialize x) = isa(x, QuoteNode) ? x.value : x + +function escape_call!(astate::AnalysisState, pc::Int, args::Vector{Any}) + ir = astate.ir + ft = argextype(first(args), ir, ir.sptypes, ir.argtypes) + f = singleton_type(ft) + if isa(f, Core.IntrinsicFunction) + # XXX somehow `:call` expression can creep in here, ideally we should be able to do: + # argtypes = Any[argextype(args[i], astate.ir) for i = 2:length(args)] + argtypes = Any[] + for i = 2:length(args) + arg = args[i] + push!(argtypes, isexpr(arg, :call) ? Any : argextype(arg, ir)) + end + if intrinsic_nothrow(f, argtypes) + add_liveness_changes!(astate, pc, args, 2) + else + add_fallback_changes!(astate, pc, args, 2) + end + return # TODO accounts for pointer operations? + end + result = escape_builtin!(f, astate, pc, args) + if result === missing + # if this call hasn't been handled by any of pre-defined handlers, + # we escape this call conservatively + for i in 2:length(args) + add_escape_change!(astate, args[i], ⊤) + end + add_escape_change!(astate, SSAValue(pc), ⊤) + return + elseif result === true + add_liveness_changes!(astate, pc, args, 2) + return # ThrownEscape is already checked + else + # we escape statements with the `ThrownEscape` property using the effect-freeness + # computed by `stmt_effect_free` invoked within inlining + # TODO throwness ≠ "effect-free-ness" + if getinst(astate.ir, pc)[:flag] & IR_FLAG_EFFECT_FREE ≠ 0 + add_liveness_changes!(astate, pc, args, 2) + else + add_fallback_changes!(astate, pc, args, 2) + end + return + end +end + +escape_builtin!(@nospecialize(f), _...) = return missing + +# safe builtins +escape_builtin!(::typeof(isa), _...) = return false +escape_builtin!(::typeof(typeof), _...) = return false +escape_builtin!(::typeof(sizeof), _...) = return false +escape_builtin!(::typeof(===), _...) = return false +# not really safe, but `ThrownEscape` will be imposed later +escape_builtin!(::typeof(isdefined), _...) = return false +escape_builtin!(::typeof(throw), _...) = return false + +function escape_builtin!(::typeof(ifelse), astate::AnalysisState, pc::Int, args::Vector{Any}) + length(args) == 4 || return false + f, cond, th, el = args + ret = SSAValue(pc) + condt = argextype(cond, astate.ir) + if isa(condt, Const) && (cond = condt.val; isa(cond, Bool)) + if cond + add_alias_change!(astate, th, ret) + else + add_alias_change!(astate, el, ret) + end + else + add_alias_change!(astate, th, ret) + add_alias_change!(astate, el, ret) + end + return false +end + +function escape_builtin!(::typeof(typeassert), astate::AnalysisState, pc::Int, args::Vector{Any}) + length(args) == 3 || return false + f, obj, typ = args + ret = SSAValue(pc) + add_alias_change!(astate, ret, obj) + return false +end + +function escape_builtin!(::typeof(tuple), astate::AnalysisState, pc::Int, args::Vector{Any}) + escape_new!(astate, pc, args) + return false +end + +function analyze_fields(ir::IRCode, @nospecialize(typ), @nospecialize(fld)) + nfields = fieldcount_noerror(typ) + if nfields === nothing + return Unindexable(false, AInfo()), 0 + end + if isa(typ, DataType) + fldval = try_compute_field(ir, fld) + fidx = try_compute_fieldidx(typ, fldval) + else + fidx = nothing + end + if fidx === nothing + return Unindexable(false, AInfo()), 0 + end + return Indexable(false, AInfo[AInfo() for _ in 1:nfields]), fidx +end + +function reanalyze_fields(ir::IRCode, AliasInfo::Indexable, @nospecialize(typ), @nospecialize(fld)) + infos = AliasInfo.infos + nfields = fieldcount_noerror(typ) + if nfields === nothing + return Unindexable(false, merge_to_unindexable(infos)), 0 + end + if isa(typ, DataType) + fldval = try_compute_field(ir, fld) + fidx = try_compute_fieldidx(typ, fldval) + else + fidx = nothing + end + if fidx === nothing + return Unindexable(false, merge_to_unindexable(infos)), 0 + end + ninfos = length(infos) + if nfields > ninfos + for _ in 1:(nfields-ninfos) + push!(infos, AInfo()) + end + end + return AliasInfo, fidx +end + +function escape_builtin!(::typeof(getfield), astate::AnalysisState, pc::Int, args::Vector{Any}) + length(args) ≥ 3 || return false + ir, estate = astate.ir, astate.estate + obj = args[2] + typ = widenconst(argextype(obj, ir)) + if hasintersect(typ, Module) # global load + add_escape_change!(astate, SSAValue(pc), ⊤) + end + if isa(obj, SSAValue) || isa(obj, Argument) + objinfo = estate[obj] + else + return false + end + AliasInfo = objinfo.AliasInfo + if isa(AliasInfo, Bool) + AliasInfo && @goto conservative_propagation + # the fields of this object haven't been analyzed yet: analyze them now + AliasInfo, fidx = analyze_fields(ir, typ, args[3]) + if isa(AliasInfo, Indexable) + @goto record_indexable_use + else + @goto record_unindexable_use + end + elseif isa(AliasInfo, Indexable) && !AliasInfo.array + AliasInfo, fidx = reanalyze_fields(ir, AliasInfo, typ, args[3]) + isa(AliasInfo, Unindexable) && @goto record_unindexable_use + @label record_indexable_use + push!(AliasInfo.infos[fidx], pc) # record use + objinfo = EscapeInfo(objinfo, AliasInfo) + add_escape_change!(astate, obj, objinfo) + elseif isa(AliasInfo, Unindexable) && !AliasInfo.array + @label record_unindexable_use + push!(AliasInfo.info, pc) # record use + objinfo = EscapeInfo(objinfo, AliasInfo) + add_escape_change!(astate, obj, objinfo) + else + # this object has been used as array, but it is used as struct here (i.e. should throw) + # update obj's field information and just handle this case conservatively + objinfo = escape_unanalyzable_obj!(astate, obj, objinfo) + @label conservative_propagation + # the field couldn't be analyzed precisely: propagate the escape information + # imposed on the return value of this `getfield` call to the object itself + # as the most conservative propagation + ssainfo = estate[SSAValue(pc)] + add_escape_change!(astate, obj, ssainfo) + end + return false +end + +function escape_builtin!(::typeof(setfield!), astate::AnalysisState, pc::Int, args::Vector{Any}) + length(args) ≥ 4 || return false + ir, estate = astate.ir, astate.estate + obj = args[2] + val = args[4] + if isa(obj, SSAValue) || isa(obj, Argument) + objinfo = estate[obj] + else + # unanalyzable object (e.g. obj::GlobalRef): escape field value conservatively + add_escape_change!(astate, val, ⊤) + @goto add_thrown_escapes + end + AliasInfo = objinfo.AliasInfo + if isa(AliasInfo, Bool) + AliasInfo && @goto conservative_propagation + # the fields of this object haven't been analyzed yet: analyze them now + typ = widenconst(argextype(obj, ir)) + AliasInfo, fidx = analyze_fields(ir, typ, args[3]) + if isa(AliasInfo, Indexable) + @goto escape_indexable_def + else + @goto escape_unindexable_def + end + elseif isa(AliasInfo, Indexable) && !AliasInfo.array + typ = widenconst(argextype(obj, ir)) + AliasInfo, fidx = reanalyze_fields(ir, AliasInfo, typ, args[3]) + isa(AliasInfo, Unindexable) && @goto escape_unindexable_def + @label escape_indexable_def + add_alias_escapes!(astate, val, AliasInfo.infos[fidx]) + push!(AliasInfo.infos[fidx], -pc) # record def + objinfo = EscapeInfo(objinfo, AliasInfo) + add_escape_change!(astate, obj, objinfo) + # propagate the escape information of this object ignoring field information + add_escape_change!(astate, val, ignore_aliasinfo(objinfo)) + elseif isa(AliasInfo, Unindexable) && !AliasInfo.array + info = AliasInfo.info + @label escape_unindexable_def + add_alias_escapes!(astate, val, AliasInfo.info) + push!(AliasInfo.info, -pc) # record def + objinfo = EscapeInfo(objinfo, AliasInfo) + add_escape_change!(astate, obj, objinfo) + # propagate the escape information of this object ignoring field information + add_escape_change!(astate, val, ignore_aliasinfo(objinfo)) + else + # this object has been used as array, but it is used as struct here (i.e. should throw) + # update obj's field information and just handle this case conservatively + objinfo = escape_unanalyzable_obj!(astate, obj, objinfo) + @label conservative_propagation + # the field couldn't be analyzed: propagate the entire escape information + # of this object to the value being assigned as the most conservative propagation + add_escape_change!(astate, val, objinfo) + end + # also propagate escape information imposed on the return value of this `setfield!` + ssainfo = estate[SSAValue(pc)] + add_escape_change!(astate, val, ssainfo) + # compute the throwness of this setfield! call here since builtin_nothrow doesn't account for that + @label add_thrown_escapes + argtypes = Any[] + for i = 2:length(args) + push!(argtypes, argextype(args[i], ir)) + end + setfield!_nothrow(argtypes) || add_thrown_escapes!(astate, pc, args, 2) + return true +end + +function escape_builtin!(::typeof(arrayref), astate::AnalysisState, pc::Int, args::Vector{Any}) + length(args) ≥ 4 || return false + # check potential thrown escapes from this arrayref call + argtypes = Any[argextype(args[i], astate.ir) for i in 2:length(args)] + boundcheckt = argtypes[1] + aryt = argtypes[2] + if !array_builtin_common_typecheck(boundcheckt, aryt, argtypes, 3) + add_thrown_escapes!(astate, pc, args, 2) + end + ary = args[3] + inbounds = isa(boundcheckt, Const) && !boundcheckt.val::Bool + inbounds || add_escape_change!(astate, ary, ThrownEscape(pc)) + # we don't track precise index information about this array and thus don't know what values + # can be referenced here: directly propagate the escape information imposed on the return + # value of this `arrayref` call to the array itself as the most conservative propagation + # but also with updated index information + # TODO enable index analysis when constant values are available? + estate = astate.estate + if isa(ary, SSAValue) || isa(ary, Argument) + aryinfo = estate[ary] + else + return true + end + AliasInfo = aryinfo.AliasInfo + if isa(AliasInfo, Bool) + AliasInfo && @goto conservative_propagation + # the elements of this array haven't been analyzed yet: set AliasInfo now + AliasInfo = Unindexable(true, AInfo()) + @goto record_unindexable_use + elseif isa(AliasInfo, Indexable) && AliasInfo.array + throw("array index analysis unsupported") + elseif isa(AliasInfo, Unindexable) && AliasInfo.array + @label record_unindexable_use + push!(AliasInfo.info, pc) # record use + add_escape_change!(astate, ary, EscapeInfo(aryinfo, AliasInfo)) + else + # this object has been used as struct, but it is used as array here (thus should throw) + # update ary's element information and just handle this case conservatively + aryinfo = escape_unanalyzable_obj!(astate, ary, aryinfo) + @label conservative_propagation + ssainfo = estate[SSAValue(pc)] + add_escape_change!(astate, ary, ssainfo) + end + return true +end + +function escape_builtin!(::typeof(arrayset), astate::AnalysisState, pc::Int, args::Vector{Any}) + length(args) ≥ 5 || return false + # check potential escapes from this arrayset call + # NOTE here we essentially only need to account for TypeError, assuming that + # UndefRefError or BoundsError don't capture any of the arguments here + argtypes = Any[argextype(args[i], astate.ir) for i in 2:length(args)] + boundcheckt = argtypes[1] + aryt = argtypes[2] + valt = argtypes[3] + if !(array_builtin_common_typecheck(boundcheckt, aryt, argtypes, 4) && + arrayset_typecheck(aryt, valt)) + add_thrown_escapes!(astate, pc, args, 2) + end + ary = args[3] + val = args[4] + inbounds = isa(boundcheckt, Const) && !boundcheckt.val::Bool + inbounds || add_escape_change!(astate, ary, ThrownEscape(pc)) + # we don't track precise index information about this array and won't record what value + # is being assigned here: directly propagate the escape information of this array to + # the value being assigned as the most conservative propagation + # TODO enable index analysis when constant values are available? + estate = astate.estate + if isa(ary, SSAValue) || isa(ary, Argument) + aryinfo = estate[ary] + else + # unanalyzable object (e.g. obj::GlobalRef): escape field value conservatively + add_escape_change!(astate, val, ⊤) + return true + end + AliasInfo = aryinfo.AliasInfo + if isa(AliasInfo, Bool) + AliasInfo && @goto conservative_propagation + # the elements of this array haven't been analyzed yet: set AliasInfo now + AliasInfo = Unindexable(true, AInfo()) + @goto escape_unindexable_def + elseif isa(AliasInfo, Indexable) && AliasInfo.array + throw("array index analysis unsupported") + elseif isa(AliasInfo, Unindexable) && AliasInfo.array + @label escape_unindexable_def + add_alias_escapes!(astate, val, AliasInfo.info) + push!(AliasInfo.info, -pc) # record def + add_escape_change!(astate, ary, EscapeInfo(aryinfo, AliasInfo)) + # propagate the escape information of this array ignoring elements information + add_escape_change!(astate, val, ignore_aliasinfo(aryinfo)) + else + # this object has been used as struct, but it is used as array here (thus should throw) + # update ary's element information and just handle this case conservatively + aryinfo = escape_unanalyzable_obj!(astate, ary, aryinfo) + @label conservative_propagation + add_escape_change!(astate, val, aryinfo) + end + # also propagate escape information imposed on the return value of this `arrayset` + ssainfo = estate[SSAValue(pc)] + add_escape_change!(astate, ary, ssainfo) + return true +end + +function escape_builtin!(::typeof(arraysize), astate::AnalysisState, pc::Int, args::Vector{Any}) + length(args) == 3 || return false + ary = args[2] + dim = args[3] + if !arraysize_typecheck(ary, dim, astate.ir) + add_escape_change!(astate, ary, ThrownEscape(pc)) + add_escape_change!(astate, dim, ThrownEscape(pc)) + end + # NOTE we may still see "arraysize: dimension out of range", but it doesn't capture anything + return true +end + +function arraysize_typecheck(@nospecialize(ary), @nospecialize(dim), ir::IRCode) + aryt = argextype(ary, ir) + aryt ⊑ₜ Array || return false + dimt = argextype(dim, ir) + dimt ⊑ₜ Int || return false + return true +end + +# returns nothing if this isn't array resizing operation, +# otherwise returns true if it can throw BoundsError and false if not +function array_resize_info(name::Symbol) + if name === :jl_array_grow_beg || name === :jl_array_grow_end + return false, 1 + elseif name === :jl_array_del_beg || name === :jl_array_del_end + return true, 1 + elseif name === :jl_array_grow_at || name === :jl_array_del_at + return true, 2 + else + return nothing + end +end + +# NOTE may potentially throw "cannot resize array with shared data" error, +# but just ignore it since it doesn't capture anything +function escape_array_resize!(boundserror::Bool, ninds::Int, + astate::AnalysisState, pc::Int, args::Vector{Any}) + length(args) ≥ 6+ninds || return add_fallback_changes!(astate, pc, args) + ary = args[6] + aryt = argextype(ary, astate.ir) + aryt ⊑ₜ Array || return add_fallback_changes!(astate, pc, args) + for i in 1:ninds + ind = args[i+6] + indt = argextype(ind, astate.ir) + indt ⊑ₜ Integer || return add_fallback_changes!(astate, pc, args) + end + if boundserror + # this array resizing can potentially throw `BoundsError`, impose it now + add_escape_change!(astate, ary, ThrownEscape(pc)) + end + add_liveness_changes!(astate, pc, args, 6) +end + +is_array_copy(name::Symbol) = name === :jl_array_copy + +# FIXME this implementation is very conservative, improve the accuracy and solve broken test cases +function escape_array_copy!(astate::AnalysisState, pc::Int, args::Vector{Any}) + length(args) ≥ 6 || return add_fallback_changes!(astate, pc, args) + ary = args[6] + aryt = argextype(ary, astate.ir) + aryt ⊑ₜ Array || return add_fallback_changes!(astate, pc, args) + if isa(ary, SSAValue) || isa(ary, Argument) + newary = SSAValue(pc) + aryinfo = astate.estate[ary] + newaryinfo = astate.estate[newary] + add_escape_change!(astate, newary, aryinfo) + add_escape_change!(astate, ary, newaryinfo) + end + add_liveness_changes!(astate, pc, args, 6) +end + +is_array_isassigned(name::Symbol) = name === :jl_array_isassigned + +function escape_array_isassigned!(astate::AnalysisState, pc::Int, args::Vector{Any}) + if !array_isassigned_nothrow(args, astate.ir) + add_thrown_escapes!(astate, pc, args) + end + add_liveness_changes!(astate, pc, args, 6) +end + +function array_isassigned_nothrow(args::Vector{Any}, src::IRCode) + # if !validate_foreigncall_args(args, + # :jl_array_isassigned, Cint, svec(Any,Csize_t), 0, :ccall) + # return false + # end + length(args) ≥ 7 || return false + arytype = argextype(args[6], src) + arytype ⊑ₜ Array || return false + idxtype = argextype(args[7], src) + idxtype ⊑ₜ Csize_t || return false + return true +end + +# # COMBAK do we want to enable this (and also backport this to Base for array allocations?) +# import Core.Compiler: Cint, svec +# function validate_foreigncall_args(args::Vector{Any}, +# name::Symbol, @nospecialize(rt), argtypes::SimpleVector, nreq::Int, convension::Symbol) +# length(args) ≥ 5 || return false +# normalize(args[1]) === name || return false +# args[2] === rt || return false +# args[3] === argtypes || return false +# args[4] === vararg || return false +# normalize(args[5]) === convension || return false +# return true +# end + +if isdefined(Core, :ImmutableArray) + +import Core: ImmutableArray, arrayfreeze, mutating_arrayfreeze, arraythaw + +escape_builtin!(::typeof(arrayfreeze), astate::AnalysisState, pc::Int, args::Vector{Any}) = + is_safe_immutable_array_op(Array, astate, args) +escape_builtin!(::typeof(mutating_arrayfreeze), astate::AnalysisState, pc::Int, args::Vector{Any}) = + is_safe_immutable_array_op(Array, astate, args) +escape_builtin!(::typeof(arraythaw), astate::AnalysisState, pc::Int, args::Vector{Any}) = + is_safe_immutable_array_op(ImmutableArray, astate, args) +function is_safe_immutable_array_op(@nospecialize(arytype), astate::AnalysisState, args::Vector{Any}) + length(args) == 2 || return false + argextype(args[2], astate.ir) ⊑ₜ arytype || return false + return true +end + +end # if isdefined(Core, :ImmutableArray) + +# NOTE define fancy package utilities when developing EA as an external package +if _TOP_MOD !== Core.Compiler + include(@__MODULE__, "EAUtils.jl") + using .EAUtils: code_escapes, @code_escapes + export code_escapes, @code_escapes +end + +end # baremodule EscapeAnalysis diff --git a/base/compiler/ssair/EscapeAnalysis/disjoint_set.jl b/base/compiler/ssair/EscapeAnalysis/disjoint_set.jl new file mode 100644 index 0000000000000..915bc214d7c3c --- /dev/null +++ b/base/compiler/ssair/EscapeAnalysis/disjoint_set.jl @@ -0,0 +1,143 @@ +# A disjoint set implementation adapted from +# https://github.com/JuliaCollections/DataStructures.jl/blob/f57330a3b46f779b261e6c07f199c88936f28839/src/disjoint_set.jl +# under the MIT license: https://github.com/JuliaCollections/DataStructures.jl/blob/master/License.md + +# imports +import ._TOP_MOD: + length, + eltype, + union!, + push! +# usings +import ._TOP_MOD: + OneTo, collect, zero, zeros, one, typemax + +# Disjoint-Set + +############################################################ +# +# A forest of disjoint sets of integers +# +# Since each element is an integer, we can use arrays +# instead of dictionary (for efficiency) +# +# Disjoint sets over other key types can be implemented +# based on an IntDisjointSet through a map from the key +# to an integer index +# +############################################################ + +_intdisjointset_bounds_err_msg(T) = "the maximum number of elements in IntDisjointSet{$T} is $(typemax(T))" + +""" + IntDisjointSet{T<:Integer}(n::Integer) + +A forest of disjoint sets of integers, which is a data structure +(also called a union–find data structure or merge–find set) +that tracks a set of elements partitioned +into a number of disjoint (non-overlapping) subsets. +""" +mutable struct IntDisjointSet{T<:Integer} + parents::Vector{T} + ranks::Vector{T} + ngroups::T +end + +IntDisjointSet(n::T) where {T<:Integer} = IntDisjointSet{T}(collect(OneTo(n)), zeros(T, n), n) +IntDisjointSet{T}(n::Integer) where {T<:Integer} = IntDisjointSet{T}(collect(OneTo(T(n))), zeros(T, T(n)), T(n)) +length(s::IntDisjointSet) = length(s.parents) + +""" + num_groups(s::IntDisjointSet) + +Get a number of groups. +""" +num_groups(s::IntDisjointSet) = s.ngroups +eltype(::Type{IntDisjointSet{T}}) where {T<:Integer} = T + +# find the root element of the subset that contains x +# path compression is implemented here +function find_root_impl!(parents::Vector{T}, x::Integer) where {T<:Integer} + p = parents[x] + @inbounds if parents[p] != p + parents[x] = p = _find_root_impl!(parents, p) + end + return p +end + +# unsafe version of the above +function _find_root_impl!(parents::Vector{T}, x::Integer) where {T<:Integer} + @inbounds p = parents[x] + @inbounds if parents[p] != p + parents[x] = p = _find_root_impl!(parents, p) + end + return p +end + +""" + find_root!(s::IntDisjointSet{T}, x::T) + +Find the root element of the subset that contains an member `x`. +Path compression happens here. +""" +find_root!(s::IntDisjointSet{T}, x::T) where {T<:Integer} = find_root_impl!(s.parents, x) + +""" + in_same_set(s::IntDisjointSet{T}, x::T, y::T) + +Returns `true` if `x` and `y` belong to the same subset in `s`, and `false` otherwise. +""" +in_same_set(s::IntDisjointSet{T}, x::T, y::T) where {T<:Integer} = find_root!(s, x) == find_root!(s, y) + +""" + union!(s::IntDisjointSet{T}, x::T, y::T) + +Merge the subset containing `x` and that containing `y` into one +and return the root of the new set. +""" +function union!(s::IntDisjointSet{T}, x::T, y::T) where {T<:Integer} + parents = s.parents + xroot = find_root_impl!(parents, x) + yroot = find_root_impl!(parents, y) + return xroot != yroot ? root_union!(s, xroot, yroot) : xroot +end + +""" + root_union!(s::IntDisjointSet{T}, x::T, y::T) + +Form a new set that is the union of the two sets whose root elements are +`x` and `y` and return the root of the new set. +Assume `x ≠ y` (unsafe). +""" +function root_union!(s::IntDisjointSet{T}, x::T, y::T) where {T<:Integer} + parents = s.parents + rks = s.ranks + @inbounds xrank = rks[x] + @inbounds yrank = rks[y] + + if xrank < yrank + x, y = y, x + elseif xrank == yrank + rks[x] += one(T) + end + @inbounds parents[y] = x + s.ngroups -= one(T) + return x +end + +""" + push!(s::IntDisjointSet{T}) + +Make a new subset with an automatically chosen new element `x`. +Returns the new element. Throw an `ArgumentError` if the +capacity of the set would be exceeded. +""" +function push!(s::IntDisjointSet{T}) where {T<:Integer} + l = length(s) + l < typemax(T) || throw(ArgumentError(_intdisjointset_bounds_err_msg(T))) + x = l + one(T) + push!(s.parents, x) + push!(s.ranks, zero(T)) + s.ngroups += one(T) + return x +end diff --git a/base/compiler/ssair/driver.jl b/base/compiler/ssair/driver.jl index e54a09fe351b3..7329dafcb1121 100644 --- a/base/compiler/ssair/driver.jl +++ b/base/compiler/ssair/driver.jl @@ -14,8 +14,10 @@ include("compiler/ssair/basicblock.jl") include("compiler/ssair/domtree.jl") include("compiler/ssair/ir.jl") include("compiler/ssair/slot2ssa.jl") -include("compiler/ssair/passes.jl") include("compiler/ssair/inlining.jl") include("compiler/ssair/verify.jl") include("compiler/ssair/legacy.jl") -#@isdefined(Base) && include("compiler/ssair/show.jl") +function try_compute_field end # imported by EscapeAnalysis +include("compiler/ssair/EscapeAnalysis/EscapeAnalysis.jl") +include("compiler/ssair/passes.jl") +# @isdefined(Base) && include("compiler/ssair/show.jl") diff --git a/base/compiler/ssair/passes.jl b/base/compiler/ssair/passes.jl index 51b5ec076f25d..c6913dd077d60 100644 --- a/base/compiler/ssair/passes.jl +++ b/base/compiler/ssair/passes.jl @@ -641,7 +641,7 @@ its argument). In a case when all usages are fully eliminated, `struct` allocation may also be erased as a result of succeeding dead code elimination. """ -function sroa_pass!(ir::IRCode) +function sroa_pass!(ir::IRCode, nargs::Int) compact = IncrementalCompact(ir) defuses = nothing # will be initialized once we encounter mutability in order to reduce dynamic allocations lifting_cache = IdDict{Pair{AnySSAValue, Any}, AnySSAValue}() @@ -813,7 +813,7 @@ function sroa_pass!(ir::IRCode) used_ssas = copy(compact.used_ssas) simple_dce!(compact, (x::SSAValue) -> used_ssas[x.id] -= 1) ir = complete(compact) - sroa_mutables!(ir, defuses, used_ssas) + sroa_mutables!(ir, defuses, used_ssas, nargs) return ir else simple_dce!(compact) @@ -821,9 +821,10 @@ function sroa_pass!(ir::IRCode) end end -function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse}}, used_ssas::Vector{Int}) +function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse}}, used_ssas::Vector{Int}, nargs::Int) # initialization of domtree is delayed to avoid the expensive computation in many cases local domtree = nothing + estate = analyze_escapes(ir, nargs) for (idx, (intermediaries, defuse)) in defuses intermediaries = collect(intermediaries) # Check if there are any uses we did not account for. If so, the variable @@ -899,6 +900,7 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse end end end + is_load_forwardable(estate[SSAValue(idx)]) || println("[EA] bad EA: ", ir.argtypes[1:nargs], " at ", idx) # Everything accounted for. Go field by field and perform idf: # Compute domtree now, needed below, now that we have finished compacting the IR. # This needs to be after we iterate through the IR with `IncrementalCompact` @@ -957,6 +959,11 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse end end +function is_load_forwardable(x::EscapeAnalysis.EscapeInfo) + AliasInfo = x.AliasInfo + return isa(AliasInfo, EscapeAnalysis.Indexable) && !AliasInfo.array +end + function form_new_preserves(origex::Expr, intermediates::Vector{Int}, new_preserves::Vector{Any}) newex = Expr(:foreigncall) nccallargs = length(origex.args[3]::SimpleVector) diff --git a/base/compiler/utilities.jl b/base/compiler/utilities.jl index e97441495f16b..9b1106e964919 100644 --- a/base/compiler/utilities.jl +++ b/base/compiler/utilities.jl @@ -19,6 +19,8 @@ function _any(@nospecialize(f), a) end return false end +any(@nospecialize(f), itr) = _any(f, itr) +any(itr) = _any(identity, itr) function _all(@nospecialize(f), a) for x in a @@ -26,6 +28,8 @@ function _all(@nospecialize(f), a) end return true end +all(@nospecialize(f), itr) = _all(f, itr) +all(itr) = _all(identity, itr) function contains_is(itr, @nospecialize(x)) for y in itr diff --git a/base/exports.jl b/base/exports.jl index c43e66eecb74c..7d71915909a5b 100644 --- a/base/exports.jl +++ b/base/exports.jl @@ -774,6 +774,7 @@ export # help and reflection code_typed, code_lowered, + code_escapes, fullname, functionloc, isconst, diff --git a/doc/make.jl b/doc/make.jl index 8be3b807400d1..214644f7d7473 100644 --- a/doc/make.jl +++ b/doc/make.jl @@ -148,6 +148,7 @@ DevDocs = [ "devdocs/require.md", "devdocs/inference.md", "devdocs/ssair.md", + "devdocs/EscapeAnalysis/README.md", "devdocs/gc-sa.md", ], "Developing/debugging Julia's C code" => [ diff --git a/doc/src/devdocs/EscapeAnalysis/README.md b/doc/src/devdocs/EscapeAnalysis/README.md new file mode 100644 index 0000000000000..a0c2340fc08d0 --- /dev/null +++ b/doc/src/devdocs/EscapeAnalysis/README.md @@ -0,0 +1,322 @@ +[![CI](https://github.com/aviatesk/EscapeAnalysis.jl/actions/workflows/ci.yml/badge.svg)](https://github.com/aviatesk/EscapeAnalysis.jl/actions/workflows/ci.yml) +[![codecov](https://codecov.io/gh/aviatesk/EscapeAnalysis.jl/branch/master/graph/badge.svg?token=ADEKPZRUJH)](https://codecov.io/gh/aviatesk/EscapeAnalysis.jl) +[![](https://img.shields.io/badge/docs-dev-blue.svg)](https://aviatesk.github.io/EscapeAnalysis.jl/dev/) + +`EscapeAnalysis` is a simple module that collects escape information in +[Julia's SSA optimization IR](@ref Julia-SSA-form-IR) a.k.a. `IRCode`. + +You can give a try to the escape analysis with the convenience entries that +`EscapeAnalysis` exports for testing and debugging purposes: +```@docs +Base.code_escapes +InteractiveUtils.@code_escapes +``` + +## Analysis Design + +### Lattice Design + +`EscapeAnalysis` is implemented as a [data-flow analysis](https://en.wikipedia.org/wiki/Data-flow_analysis) +that works on a lattice of `x::EscapeInfo`, which is composed of the following properties: +- `x.Analyzed::Bool`: not formally part of the lattice, only indicates `x` has not been analyzed or not +- `x.ReturnEscape::BitSet`: records SSA statements where `x` can escape to the caller via return +- `x.ThrownEscape::BitSet`: records SSA statements where `x` can be thrown as exception + (used for the [exception handling](@ref EA-Exception-Handling) described below) +- `x.AliasInfo`: maintains all possible values that can be aliased to fields or array elements of `x` + (used for the [alias analysis](@ref EA-Alias-Analysis) described below) +- `x.ArgEscape::Int` (not implemented yet): indicates it will escape to the caller through + `setfield!` on argument(s) + +These attributes can be combined to create a partial lattice that has a finite height, given +the invariant that an input program has a finite number of statements, which is assured by Julia's semantics. +The clever part of this lattice design is that it enables a simpler implementation of +lattice operations by allowing them to handle each lattice property separately[^LatticeDesign]. + +### Backward Escape Propagation + +This escape analysis implementation is based on the data-flow algorithm described in the paper[^MM02]. +The analysis works on the lattice of `EscapeInfo` and transitions lattice elements from the +bottom to the top until every lattice element gets converged to a fixed point by maintaining +a (conceptual) working set that contains program counters corresponding to remaining SSA +statements to be analyzed. The analysis manages a single global state that tracks +`EscapeInfo` of each argument and SSA statement, but also note that some flow-sensitivity +is encoded as program counters recorded in `EscapeInfo`'s `ReturnEscape` property, +which can be combined with domination analysis later to reason about flow-sensitivity if necessary. + +One distinctive design of this escape analysis is that it is fully _backward_, +i.e. escape information flows _from usages to definitions_. +For example, in the code snippet below, EA first analyzes the statement `return %1` and +imposes `ReturnEscape` on `%1` (corresponding to `obj`), and then it analyzes +`%1 = %new(Base.RefValue{String, _2}))` and propagates the `ReturnEscape` imposed on `%1` +to the call argument `_2` (corresponding to `s`): +```julia +julia> code_escapes((String,)) do s + obj = Ref(s) + return obj + end +#1(↑ _2::String) in Main at REPL[2]:2 +2 ↑ 1 ─ %1 = %new(Base.RefValue{String}, _2)::Base.RefValue{String} │╻╷╷ Ref +3 ◌ └── return %1 │ +``` + +The key observation here is that this backward analysis allows escape information to flow +naturally along the use-def chain rather than control-flow[^BackandForth]. +As a result this scheme enables a simple implementation of escape analysis, +e.g. `PhiNode` for example can be handled simply by propagating escape information +imposed on a `PhiNode` to its predecessor values: +```julia +julia> code_escapes((Bool, String, String)) do cnd, s, t + if cnd + obj = Ref(s) + else + obj = Ref(t) + end + return obj + end + #3(↑ _2::Bool, ↑ _3::String, ↑ _4::String) in Main at REPL[3]:2 +2 ◌ 1 ─ goto #3 if not _2 │ +3 ↑ 2 ─ %2 = %new(Base.RefValue{String}, _3)::Base.RefValue{String} │╻╷╷ Ref + ◌ └── goto #4 │ +5 ↑ 3 ─ %4 = %new(Base.RefValue{String}, _4)::Base.RefValue{String} │╻╷╷ Ref +7 ↑ 4 ┄ %5 = φ (#2 => %2, #3 => %4)::Base.RefValue{String} │ + ◌ └── return %5 │ +``` + +### [Alias Analysis](@id EA-Alias-Analysis) + +`EscapeAnalysis` implements a backward field analysis in order to reason about escapes +imposed on object fields with certain accuracy, +and `x::EscapeInfo`'s `x.AliasInfo` property exists for this purpose. +It records all possible values that can be aliased to fields of `x` at "usage" sites, +and then the escape information of that recorded values are propagated to the actual field values later at "definition" sites. +More specifically, the analysis records a value that may be aliased to a field of object by analyzing `getfield` call, +and then it propagates its escape information to the field when analyzing `%new(...)` expression or `setfield!` call[^Dynamism]. +```julia +julia> mutable struct SafeRef{T} + x::T + end + +julia> Base.getindex(x::SafeRef) = x.x; + +julia> Base.setindex!(x::SafeRef, v) = x.x = v; + +julia> code_escapes((String,)) do s + obj = SafeRef("init") + obj[] = s + v = obj[] + return v + end +#5(↑ _2::String) in Main at REPL[7]:2 +2 ✓′ 1 ─ %1 = %new(SafeRef{String}, "init")::SafeRef{String} │╻╷ SafeRef +3 ◌ │ Base.setfield!(%1, :x, _2)::String │╻╷ setindex! +4 ↑ │ %3 = Base.getfield(%1, :x)::String │╻╷ getindex +5 ◌ └── return %3 │ +``` +In the example above, `ReturnEscape` imposed on `%3` (corresponding to `v`) is _not_ directly +propagated to `%1` (corresponding to `obj`) but rather that `ReturnEscape` is only propagated +to `_2` (corresponding to `s`). Here `%3` is recorded in `%1`'s `AliasInfo` property as +it can be aliased to the first field of `%1`, and then when analyzing `Base.setfield!(%1, :x, _2)::String`, +that escape information is propagated to `_2` but not to `%1`. + +So `EscapeAnalysis` tracks which IR elements can be aliased across a `getfield`-`%new`/`setfield!` chain +in order to analyze escapes of object fields, but actually this alias analysis needs to be +generalized to handle other IR elements as well. This is because in Julia IR the same +object is sometimes represented by different IR elements and so we should make sure that those +different IR elements that actually can represent the same object share the same escape information. +IR elements that return the same object as their operand(s), such as `PiNode` and `typeassert`, +can cause that IR-level aliasing and thus requires escape information imposed on any of such +aliased values to be shared between them. +More interestingly, it is also needed for correctly reasoning about mutations on `PhiNode`. +Let's consider the following example: +```julia +julia> code_escapes((Bool, String,)) do cond, x + if cond + ϕ2 = ϕ1 = SafeRef("foo") + else + ϕ2 = ϕ1 = SafeRef("bar") + end + ϕ2[] = x + y = ϕ1[] + return y + end +#7(↑ _2::Bool, ↑ _3::String) in Main at REPL[8]:2 +2 ◌ 1 ─ goto #3 if not _2 │ +3 ✓′ 2 ─ %2 = %new(SafeRef{String}, "foo")::SafeRef{String} │╻╷ SafeRef + ◌ └── goto #4 │ +5 ✓′ 3 ─ %4 = %new(SafeRef{String}, "bar")::SafeRef{String} │╻╷ SafeRef +7 ✓′ 4 ┄ %5 = φ (#2 => %2, #3 => %4)::SafeRef{String} │ + ✓′ │ %6 = φ (#2 => %2, #3 => %4)::SafeRef{String} │ + ◌ │ Base.setfield!(%5, :x, _3)::String │╻ setindex! +8 ↑ │ %8 = Base.getfield(%6, :x)::String │╻╷ getindex +9 ◌ └── return %8 │ +``` +`ϕ1 = %5` and `ϕ2 = %6` are aliased and thus `ReturnEscape` imposed on `%8 = Base.getfield(%6, :x)::String` (corresponding to `y = ϕ1[]`) +needs to be propagated to `Base.setfield!(%5, :x, _3)::String` (corresponding to `ϕ2[] = x`). +In order for such escape information to be propagated correctly, the analysis should recognize that +the _predecessors_ of `ϕ1` and `ϕ2` can be aliased as well and equalize their escape information. + +One interesting property of such aliasing information is that it is not known at "usage" site +but can only be derived at "definition" site (as aliasing is conceptually equivalent to assignment), +and thus it doesn't naturally fit in a backward analysis. In order to efficiently propagate escape +information between related values, EscapeAnalysis.jl uses an approach inspired by the escape +analysis algorithm explained in an old JVM paper[^JVM05]. That is, in addition to managing +escape lattice elements, the analysis also maintains an "equi"-alias set, a disjoint set of +aliased arguments and SSA statements. The alias set manages values that can be aliased to +each other and allows escape information imposed on any of such aliased values to be equalized +between them. + +Lastly, this scheme of alias/field analysis can also be generalized to analyze array operations. +`EscapeAnalysis` currently reasons about escapes imposed on array elements using +an imprecise version of the field analysis described above, where `AliasInfo` doesn't +try to track precise array index but rather simply records all possible values that can be +aliased any elements of the array. + +### [Exception Handling](@id EA-Exception-Handling) + +It would be also worth noting how `EscapeAnalysis` handles possible escapes via exceptions. +Naively it seems enough to propagate escape information imposed on `:the_exception` object to +all values that may be thrown in a corresponding `try` block. +But there are actually several other ways to access to the exception object in Julia, +such as `Base.current_exceptions` and manual catch of `rethrow`n object. +For example, escape analysis needs to account for potential escape of `r` in the example below: +```julia +julia> const Gx = Ref{Any}(); + +julia> @noinline function rethrow_escape!() + try + rethrow() + catch err + Gx[] = err + end + end; + +julia> get′(x) = isassigned(x) ? x[] : throw(x); + +julia> code_escapes() do + r = Ref{String}() + local t + try + t = get′(r) + catch err + t = typeof(err) # `err` (which `r` aliases to) doesn't escape here + rethrow_escape!() # but `r` escapes here + end + return t + end +#9() in Main at REPL[12]:2 +2 X 1 ── %1 = %new(Base.RefValue{String})::Base.RefValue{String} │╻╷ Ref +4 ◌ 2 ── %2 = $(Expr(:enter, #8)) │ +5 ◌ 3 ── %3 = Base.isdefined(%1, :x)::Bool │╻╷ get′ + ◌ └─── goto #5 if not %3 ││ + ↑ 4 ── %5 = Base.getfield(%1, :x)::String ││╻ getindex + ◌ └─── goto #6 ││ + ◌ 5 ── Main.throw(%1)::Union{} ││ + ◌ └─── unreachable ││ + ◌ 6 ── $(Expr(:leave, 1)) │ + ◌ 7 ── goto #10 │ + ◌ 8 ── $(Expr(:leave, 1)) │ + ◌ 9 ── %12 = $(Expr(:the_exception))::Any │ +7 ↑ │ %13 = Main.typeof(%12)::DataType │ +8 ◌ │ invoke Main.rethrow_escape!()::Any │ + ◌ └─── $(Expr(:pop_exception, :(%2)))::Any │ +10 ↑ 10 ┄ %16 = φ (#7 => %5, #9 => %13)::Union{DataType, String} │ + ◌ └─── return %16 │ +``` + +It requires a global analysis in order to correctly reason about all possible escapes via +existing exception interfaces. For now we always propagate the topmost escape information to +all potentially thrown objects conservatively, since such an additional analysis might not be +worthwhile to do given that exception handling and error path usually don't need to be +very performance sensitive, and also optimizations of error paths might be very ineffective anyway +since they are often even "unoptimized" intentionally for latency reasons. + +`x::EscapeInfo`'s `x.ThrownEscape` property records SSA statements where `x` can be thrown as an exception. +Using this information `EscapeAnalysis` can propagate possible escapes via exceptions limitedly +to only those may be thrown in each `try` region: +```julia +julia> result = code_escapes((String,String)) do s1, s2 + r1 = Ref(s1) + r2 = Ref(s2) + local ret + try + s1 = get′(r1) + ret = sizeof(s1) + catch err + global g = err # will definitely escape `r1` + end + s2 = get′(r2) # still `r2` doesn't escape fully + return s2 + end +#11(X _2::String, ↑ _3::String) in Main at REPL[13]:2 +2 X 1 ── %1 = %new(Base.RefValue{String}, _2)::Base.RefValue{String} │╻╷╷ Ref +3 *′ └─── %2 = %new(Base.RefValue{String}, _3)::Base.RefValue{String} │╻╷╷ Ref +5 ◌ 2 ── %3 = $(Expr(:enter, #8)) │ + *′ └─── %4 = ϒ (%2)::Base.RefValue{String} │ +6 ◌ 3 ── %5 = Base.isdefined(%1, :x)::Bool │╻╷ get′ + ◌ └─── goto #5 if not %5 ││ + X 4 ── Base.getfield(%1, :x)::String ││╻ getindex + ◌ └─── goto #6 ││ + ◌ 5 ── Main.throw(%1)::Union{} ││ + ◌ └─── unreachable ││ +7 ◌ 6 ── nothing::typeof(Core.sizeof) │╻ sizeof + ◌ │ nothing::Int64 ││ + ◌ └─── $(Expr(:leave, 1)) │ + ◌ 7 ── goto #10 │ + *′ 8 ── %15 = φᶜ (%4)::Base.RefValue{String} │ + ◌ └─── $(Expr(:leave, 1)) │ + X 9 ── %17 = $(Expr(:the_exception))::Any │ +9 ◌ │ (Main.g = %17)::Any │ + ◌ └─── $(Expr(:pop_exception, :(%3)))::Any │ +11 *′ 10 ┄ %20 = φ (#7 => %2, #9 => %15)::Base.RefValue{String} │ + ◌ │ %21 = Base.isdefined(%20, :x)::Bool ││╻ isassigned + ◌ └─── goto #12 if not %21 ││ + ↑ 11 ─ %23 = Base.getfield(%20, :x)::String │││╻ getproperty + ◌ └─── goto #13 ││ + ◌ 12 ─ Main.throw(%20)::Union{} ││ + ◌ └─── unreachable ││ +12 ◌ 13 ─ return %23 │ +``` + +## Analysis Usage + +When using `EscapeAnalysis` in Julia's high-level compilation pipeline, we can run +`analyze_escapes(ir::IRCode) -> estate::EscapeState` to analyze escape information of each SSA-IR element in `ir`. + +Note that it should be most effective if `analyze_escapes` runs after inlining, +as `EscapeAnalysis`'s interprocedural escape information handling is limited at this moment. + +Since the computational cost of `analyze_escapes` is not that cheap, +it is more ideal if it runs once and succeeding optimization passes incrementally update + the escape information upon IR transformation. + +```@docs +Core.Compiler.EscapeAnalysis.analyze_escapes +Core.Compiler.EscapeAnalysis.EscapeState +Core.Compiler.EscapeAnalysis.cache_escapes! +``` + +[^LatticeDesign]: Our type inference implementation takes the alternative approach, + where each lattice property is represented by a special lattice element type object. + It turns out that it started to complicate implementations of the lattice operations + mainly because it often requires conversion rules between each lattice element type object. + And we are working on [overhauling our type inference lattice implementation](https://github.com/JuliaLang/julia/pull/42596) + with `EscapeInfo`-like lattice design. + +[^MM02]: _A Graph-Free approach to Data-Flow Analysis_. + Markas Mohnen, 2002, April. + . + +[^BackandForth]: Our type inference algorithm in contrast is implemented as a forward analysis, + because type information usually flows from "definition" to "usage" and it is more + natural and effective to propagate such information in a forward way. + +[^Dynamism]: In some cases, however, object fields can't be analyzed precisely. + For example, object may escape to somewhere `EscapeAnalysis` can't account for possible memory effects on it, + or fields of the objects simply can't be known because of the lack of type information. + In such cases `AliasInfo` property is raised to the topmost element within its own lattice order, + and it causes succeeding field analysis to be conservative and escape information imposed on + fields of an unanalyzable object to be propagated to the object itself. + +[^JVM05]: _Escape Analysis in the Context of Dynamic Compilation and Deoptimization_. + Thomas Kotzmann and Hanspeter Mössenböck, 2005, June. + . diff --git a/doc/src/devdocs/llvm.md b/doc/src/devdocs/llvm.md index 1e983949ea0b6..840822f136004 100644 --- a/doc/src/devdocs/llvm.md +++ b/doc/src/devdocs/llvm.md @@ -28,7 +28,7 @@ The difference between an intrinsic and a builtin is that a builtin is a first c that can be used like any other Julia function. An intrinsic can operate only on unboxed data, and therefore its arguments must be statically typed. -### Alias Analysis +### [Alias Analysis](@id LLVM-Alias-Analysis) Julia currently uses LLVM's [Type Based Alias Analysis](https://llvm.org/docs/LangRef.html#tbaa-metadata). To find the comments that document the inclusion relationships, look for `static MDNode*` in diff --git a/stdlib/InteractiveUtils/src/InteractiveUtils.jl b/stdlib/InteractiveUtils/src/InteractiveUtils.jl index 1df8a2ca8f93b..76edec0d0cc57 100644 --- a/stdlib/InteractiveUtils/src/InteractiveUtils.jl +++ b/stdlib/InteractiveUtils/src/InteractiveUtils.jl @@ -6,7 +6,7 @@ Base.Experimental.@optlevel 1 export apropos, edit, less, code_warntype, code_llvm, code_native, methodswith, varinfo, versioninfo, subtypes, supertypes, @which, @edit, @less, @functionloc, @code_warntype, - @code_typed, @code_lowered, @code_llvm, @code_native, @time_imports, clipboard + @code_typed, @code_lowered, @code_llvm, @code_native, @code_escapes, @time_imports, clipboard import Base.Docs.apropos diff --git a/stdlib/InteractiveUtils/src/macros.jl b/stdlib/InteractiveUtils/src/macros.jl index 0a1fd848b0208..659111b9b2bf2 100644 --- a/stdlib/InteractiveUtils/src/macros.jl +++ b/stdlib/InteractiveUtils/src/macros.jl @@ -208,7 +208,7 @@ macro which(ex0::Symbol) return :(which($__module__, $ex0)) end -for fname in [:code_warntype, :code_llvm, :code_native] +for fname in [:code_warntype, :code_llvm, :code_native, :code_escapes] @eval begin macro ($fname)(ex0...) gen_call_with_extracted_types_and_kwargs(__module__, $(Expr(:quote, fname)), ex0) @@ -350,6 +350,16 @@ See also: [`code_native`](@ref), [`@code_llvm`](@ref), [`@code_typed`](@ref) and """ :@code_native +""" + @code_escapes [options...] f(args...) + +Evaluates the arguments to the function call, determines its types, and then calls +[`code_escapes`](@ref) on the resulting expression. +As with `@code_typed` and its family, any of `code_escapes` keyword arguments can be given +as the optional arguments like `@code_escpase interp=myinterp myfunc(myargs...)`. +""" +:@code_escapes + """ @time_imports diff --git a/test/choosetests.jl b/test/choosetests.jl index e00aedffdd42e..0415481a49b44 100644 --- a/test/choosetests.jl +++ b/test/choosetests.jl @@ -142,7 +142,7 @@ function choosetests(choices = []) filtertests!(tests, "subarray") filtertests!(tests, "compiler", ["compiler/inference", "compiler/validation", "compiler/ssair", "compiler/irpasses", "compiler/codegen", - "compiler/inline", "compiler/contextual"]) + "compiler/inline", "compiler/contextual", "compiler/EscapeAnalysis/EscapeAnalysis"]) filtertests!(tests, "stdlib", STDLIBS) # do ambiguous first to avoid failing if ambiguities are introduced by other tests filtertests!(tests, "ambiguous") diff --git a/test/compiler/EscapeAnalysis/EscapeAnalysis.jl b/test/compiler/EscapeAnalysis/EscapeAnalysis.jl new file mode 100644 index 0000000000000..001b513a25e3f --- /dev/null +++ b/test/compiler/EscapeAnalysis/EscapeAnalysis.jl @@ -0,0 +1,2052 @@ +@isdefined(EA_AS_PKG) || include(normpath(@__DIR__, "setup.jl")) + +@testset "basics" begin + let # arg return + result = code_escapes((Any,)) do a # return to caller + return nothing + end + @test has_arg_escape(result.state[Argument(2)]) + # return + result = code_escapes((Any,)) do a + return a + end + i = only(findall(isreturn, result.ir.stmts.inst)) + @test has_arg_escape(result.state[Argument(1)]) # self + @test !has_return_escape(result.state[Argument(1)], i) # self + @test has_arg_escape(result.state[Argument(2)]) # a + @test has_return_escape(result.state[Argument(2)], i) # a + end + let # global store + result = code_escapes((Any,)) do a + global aa = a + nothing + end + @test has_all_escape(result.state[Argument(2)]) + end + let # global load + result = code_escapes() do + global gr + return gr + end + i = only(findall(has_return_escape, map(i->result.state[SSAValue(i)], 1:length(result.ir.stmts)))) + @test has_all_escape(result.state[SSAValue(i)]) + end + let # global store / load (https://github.com/aviatesk/EscapeAnalysis.jl/issues/56) + result = code_escapes((Any,)) do s + global v + v = s + return v + end + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) + end + let # :gc_preserve_begin / :gc_preserve_end + result = code_escapes((String,)) do s + m = SafeRef(s) + GC.@preserve m begin + return nothing + end + end + i = findfirst(isT(SafeRef{String}), result.ir.stmts.type) # find allocation statement + @test !isnothing(i) + @test has_no_escape(result.state[SSAValue(i)]) + end + let # :isdefined + result = code_escapes((String, Bool, )) do a, b + if b + s = Ref(a) + end + return @isdefined(s) + end + i = findfirst(isT(Base.RefValue{String}), result.ir.stmts.type) # find allocation statement + @test !isnothing(i) + @test has_no_escape(result.state[SSAValue(i)]) + end + let # ϕ-node + result = code_escapes((Bool,Any,Any)) do cond, a, b + c = cond ? a : b # ϕ(a, b) + return c + end + @assert any(@nospecialize(x)->isa(x, Core.PhiNode), result.ir.stmts.inst) + i = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(3)], i) # a + @test has_return_escape(result.state[Argument(4)], i) # b + end + let # π-node + result = code_escapes((Any,)) do a + if isa(a, Regex) # a::π(Regex) + return a + end + return nothing + end + @assert any(@nospecialize(x)->isa(x, Core.PiNode), result.ir.stmts.inst) + @test any(findall(isreturn, result.ir.stmts.inst)) do i + has_return_escape(result.state[Argument(2)], i) + end + end + let # φᶜ-node / ϒ-node + result = code_escapes((Any,String)) do a, b + local x::String + try + x = a + catch err + x = b + end + return x + end + @assert any(@nospecialize(x)->isa(x, Core.PhiCNode), result.ir.stmts.inst) + @assert any(@nospecialize(x)->isa(x, Core.UpsilonNode), result.ir.stmts.inst) + i = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], i) + @test has_return_escape(result.state[Argument(3)], i) + end + let # branching + result = code_escapes((Any,Bool,)) do a, c + if c + return nothing # a doesn't escape in this branch + else + return a # a escapes to a caller + end + end + @test has_return_escape(result.state[Argument(2)]) + end + let # loop + result = code_escapes((Int,)) do n + c = SafeRef{Bool}(false) + while n > 0 + rand(Bool) && return c + end + nothing + end + i = only(findall(isnew, result.ir.stmts.inst)) + @test has_return_escape(result.state[SSAValue(i)]) + end + let # try/catch + result = code_escapes((Any,)) do a + try + nothing + catch err + return a # return escape + end + end + @test has_return_escape(result.state[Argument(2)]) + end + let result = code_escapes((Any,)) do a + try + nothing + finally + return a # return escape + end + end + @test has_return_escape(result.state[Argument(2)]) + end + let # :foreigncall + result = code_escapes((Any,)) do x + ccall(:some_ccall, Any, (Any,), x) + end + @test has_all_escape(result.state[Argument(2)]) + end +end + +let # simple allocation + result = code_escapes((Bool,)) do c + mm = SafeRef{Bool}(c) # just allocated, never escapes + return mm[] ? nothing : 1 + end + i = only(findall(isnew, result.ir.stmts.inst)) + @test has_no_escape(result.state[SSAValue(i)]) +end + +@testset "inter-procedural" begin + # FIXME currently we can't prove the effect-freeness of `getfield(RefValue{String}, :x)` + # because of this check https://github.com/JuliaLang/julia/blob/94b9d66b10e8e3ebdb268e4be5f7e1f43079ad4e/base/compiler/tfuncs.jl#L745 + # and thus it leads to the following two broken tests + let result = @eval Module() begin + @noinline broadcast_NoEscape(a) = (broadcast(identity, a); nothing) + $code_escapes() do + broadcast_NoEscape(Ref("Hi")) + end + end + i = only(findall(isnew, result.ir.stmts.inst)) + @test_broken has_no_escape(result.state[SSAValue(i)]) + end + let result = @eval Module() begin + @noinline broadcast_NoEscape2(b) = broadcast(identity, b) + $code_escapes() do + broadcast_NoEscape2(Ref("Hi")) + end + end + i = only(findall(isnew, result.ir.stmts.inst)) + @test_broken has_no_escape(result.state[SSAValue(i)]) + end + let result = @eval Module() begin + @noinline f_GlobalEscape_a(a) = (global globala = a) # obvious escape + $code_escapes() do + f_GlobalEscape_a(Ref("Hi")) + end + end + i = only(findall(isnew, result.ir.stmts.inst)) + @test has_return_escape(result.state[SSAValue(i)]) && has_thrown_escape(result.state[SSAValue(i)]) + end + # if we can't determine the matching method statically, we should be conservative + let result = @eval Module() $code_escapes((Ref{Any},)) do a + may_exist(a) + end + @test has_all_escape(result.state[Argument(2)]) + end + let result = @eval Module() begin + @noinline broadcast_NoEscape(a) = (broadcast(identity, a); nothing) + $code_escapes((Ref{Any},)) do a + Base.@invokelatest broadcast_NoEscape(a) + end + end + @test has_all_escape(result.state[Argument(2)]) + end + + # handling of simple union-split (just exploit the inliner's effort) + let T = Union{Int,Nothing} + result = @eval Module() begin + @noinline unionsplit_NoEscape_a(a) = string(nothing) + @noinline unionsplit_NoEscape_a(a::Int) = a + 10 + $code_escapes(($T,)) do x + s = $SafeRef{$T}(x) + unionsplit_NoEscape_a(s[]) + return nothing + end + end + inds = findall(isT(SafeRef{T}), result.ir.stmts.type) # find allocation statement + @assert !isempty(inds) + for i in inds + @test has_no_escape(result.state[SSAValue(i)]) + end + end + + # appropriate conversion of inter-procedural context + # https://github.com/aviatesk/EscapeAnalysis.jl/issues/7 + let M = Module() + @eval M @noinline f_NoEscape_a(a) = (println("prevent inlining"); Base.inferencebarrier(nothing)) + + result = @eval M $code_escapes() do + a = Ref("foo") # shouldn't be "return escape" + b = f_NoEscape_a(a) + nothing + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test !has_return_escape(result.state[SSAValue(i)], r) + + result = @eval M $code_escapes() do + a = Ref("foo") # still should be "return escape" + b = f_NoEscape_a(a) + return a + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[SSAValue(i)], r) + end + + # should propagate escape information imposed on return value to the aliased call argument + let result = @eval Module() begin + @noinline f_ReturnEscape_a(a) = (println("prevent inlining"); a) + $code_escapes() do + obj = Ref("foo") # should be "return escape" + ret = f_ReturnEscape_a(obj) + return ret # alias of `obj` + end + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[SSAValue(i)], r) + end + let result = @eval Module() begin + @noinline f_NoReturnEscape_a(a) = (println("prevent inlining"); identity("hi")) + $code_escapes() do + obj = Ref("foo") # better to not be "return escape" + ret = f_NoReturnEscape_a(obj) + return ret # must not alias to `obj` + end + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test !has_return_escape(result.state[SSAValue(i)], r) + end +end + +@testset "builtins" begin + let # throw + r = code_escapes((Any,)) do a + throw(a) + end + @test has_thrown_escape(r.state[Argument(2)]) + end + + let # implicit throws + r = code_escapes((Any,)) do a + getfield(a, :may_not_field) + end + @test has_thrown_escape(r.state[Argument(2)]) + + r = code_escapes((Any,)) do a + sizeof(a) + end + @test has_thrown_escape(r.state[Argument(2)]) + end + + let # :=== + result = code_escapes((Bool, String)) do cond, s + m = cond ? SafeRef(s) : nothing + c = m === nothing + return c + end + i = only(findall(isnew, result.ir.stmts.inst)) + @test has_no_escape(result.state[SSAValue(i)]) + end + + let # sizeof + ary = [0,1,2] + result = @eval code_escapes() do + ary = $(QuoteNode(ary)) + sizeof(ary) + end + i = only(findall(isT(Core.Const(ary)), result.ir.stmts.type)) + @test has_no_escape(result.state[SSAValue(i)]) + end + + let # ifelse + result = code_escapes((Bool,)) do c + r = ifelse(c, Ref("yes"), Ref("no")) + return r + end + inds = findall(isnew, result.ir.stmts.inst) + @assert !isempty(inds) + for i in inds + @test has_return_escape(result.state[SSAValue(i)]) + end + end + let # ifelse (with constant condition) + result = code_escapes() do + r = ifelse(true, Ref("yes"), Ref(nothing)) + return r + end + for i in 1:length(result.ir.stmts) + if isnew(result.ir.stmts.inst[i]) && isT(Base.RefValue{String})(result.ir.stmts.type[i]) + @test has_return_escape(result.state[SSAValue(i)]) + elseif isnew(result.ir.stmts.inst[i]) && isT(Base.RefValue{Nothing})(result.ir.stmts.type[i]) + @test has_no_escape(result.state[SSAValue(i)]) + end + end + end + + let # typeassert + result = code_escapes((Any,)) do x + y = x::String + return y + end + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) + @test !has_all_escape(result.state[Argument(2)]) + end + + let # isdefined + result = code_escapes((Any,)) do x + isdefined(x, :foo) ? x : throw("undefined") + end + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) + @test !has_all_escape(result.state[Argument(2)]) + + result = code_escapes((Module,)) do m + isdefined(m, 10) # throws + end + @test has_thrown_escape(result.state[Argument(2)]) + end +end + +@testset "flow-sensitivity" begin + # ReturnEscape + let result = code_escapes((Bool,)) do cond + r = Ref("foo") + if cond + return cond + end + return r + end + i = only(findall(isnew, result.ir.stmts.inst)) + rts = findall(isreturn, result.ir.stmts.inst) + @assert length(rts) == 2 + @test count(rt->has_return_escape(result.state[SSAValue(i)], rt), rts) == 1 + end + let result = code_escapes((Bool,)) do cond + r = Ref("foo") + cnt = 0 + while rand(Bool) + cnt += 1 + rand(Bool) && return r + end + rand(Bool) && return r + return cnt + end + i = only(findall(isnew, result.ir.stmts.inst)) + rts = findall(isreturn, result.ir.stmts.inst) # return statement + @assert length(rts) == 3 + @test count(rt->has_return_escape(result.state[SSAValue(i)], rt), rts) == 2 + end +end + +@testset "escape through exceptions" begin + M = @eval Module() begin + unsafeget(x) = isassigned(x) ? x[] : throw(x) + @noinline function rethrow_escape!() + try + rethrow() + catch err + Gx[] = err + end + end + @noinline function current_exceptions_escape!() + excs = Base.current_exceptions() + Gx[] = excs + end + const Gx = Ref{Any}() + @__MODULE__ + end + + let # simple: return escape + result = @eval M $code_escapes() do + r = Ref{String}() + local ret + try + s = unsafeget(r) + ret = sizeof(s) + catch err + ret = err + end + return ret + end + i = only(findall(isnew, result.ir.stmts.inst)) + @test has_return_escape(result.state[SSAValue(i)]) + end + + let # simple: global escape + result = @eval M $code_escapes() do + r = Ref{String}() + local ret # prevent DCE + try + s = unsafeget(r) + ret = sizeof(s) + catch err + global g = err + end + nothing + end + i = only(findall(isnew, result.ir.stmts.inst)) + @test has_all_escape(result.state[SSAValue(i)]) + end + + let # account for possible escapes via nested throws + result = @eval M $code_escapes() do + r = Ref{String}() + try + try + unsafeget(r) + catch err1 + throw(err1) + end + catch err2 + Gx[] = err2 + end + end + i = only(findall(isnew, result.ir.stmts.inst)) + @test has_all_escape(result.state[SSAValue(i)]) + end + let # account for possible escapes via `rethrow` + result = @eval M $code_escapes() do + r = Ref{String}() + try + try + unsafeget(r) + catch err1 + rethrow(err1) + end + catch err2 + Gx[] = err2 + end + end + i = only(findall(isnew, result.ir.stmts.inst)) + @test has_all_escape(result.state[SSAValue(i)]) + end + let # account for possible escapes via `rethrow` + result = @eval M $code_escapes() do + try + r = Ref{String}() + unsafeget(r) + catch + rethrow_escape!() + end + end + i = only(findall(isnew, result.ir.stmts.inst)) + @test has_all_escape(result.state[SSAValue(i)]) + end + let # account for possible escapes via `rethrow` + result = @eval M $code_escapes() do + local t + try + r = Ref{String}() + t = unsafeget(r) + catch err + t = typeof(err) + rethrow_escape!() + end + return t + end + i = only(findall(isnew, result.ir.stmts.inst)) + @test has_all_escape(result.state[SSAValue(i)]) + end + let # account for possible escapes via `Base.current_exceptions` + result = @eval M $code_escapes() do + try + r = Ref{String}() + unsafeget(r) + catch + Gx[] = Base.current_exceptions() + end + end + i = only(findall(isnew, result.ir.stmts.inst)) + @test has_all_escape(result.state[SSAValue(i)]) + end + let # account for possible escapes via `Base.current_exceptions` + result = @eval M $code_escapes() do + try + r = Ref{String}() + unsafeget(r) + catch + current_exceptions_escape!() + end + end + i = only(findall(isnew, result.ir.stmts.inst)) + @test has_all_escape(result.state[SSAValue(i)]) + end + + let # contextual: escape information imposed on `err` shouldn't propagate to `r2`, but only to `r1` + result = @eval M $code_escapes() do + r1 = Ref{String}() + r2 = Ref{String}() + local ret + try + s1 = unsafeget(r1) + ret = sizeof(s1) + catch err + global g = err + end + s2 = unsafeget(r2) + return s2, r2 + end + is = findall(isnew, result.ir.stmts.inst) + @test length(is) == 2 + i1, i2 = is + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_all_escape(result.state[SSAValue(i1)]) + @test !has_all_escape(result.state[SSAValue(i2)]) + @test has_return_escape(result.state[SSAValue(i2)], r) + end + + # XXX test cases below are currently broken because of the technical reason described in `escape_exception!` + + let # limited propagation: exception is caught within a frame => doesn't escape to a caller + result = @eval M $code_escapes() do + r = Ref{String}() + local ret + try + s = unsafeget(r) + ret = sizeof(s) + catch + ret = nothing + end + return ret + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test_broken !has_return_escape(result.state[SSAValue(i)], r) + end + let # sequential: escape information imposed on `err1` and `err2 should propagate separately + result = @eval M $code_escapes() do + r1 = Ref{String}() + r2 = Ref{String}() + local ret + try + s1 = unsafeget(r1) + ret = sizeof(s1) + catch err1 + global g = err1 + end + try + s2 = unsafeget(r2) + ret = sizeof(s2) + catch err2 + ret = err2 + end + return ret + end + is = findall(isnew, result.ir.stmts.inst) + @test length(is) == 2 + i1, i2 = is + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_all_escape(result.state[SSAValue(i1)]) + @test has_return_escape(result.state[SSAValue(i2)], r) + @test_broken !has_all_escape(result.state[SSAValue(i2)]) + end + let # nested: escape information imposed on `inner` shouldn't propagate to `s` + result = @eval M $code_escapes() do + r = Ref{String}() + local ret + try + s = unsafeget(r) + try + ret = sizeof(s) + catch inner + return inner + end + catch outer + ret = nothing + end + return ret + end + i = only(findall(isnew, result.ir.stmts.inst)) + @test_broken !has_return_escape(result.state[SSAValue(i)]) + end + let # merge: escape information imposed on `err1` and `err2 should be merged + result = @eval M $code_escapes() do + r = Ref{String}() + local ret + try + s = unsafeget(r) + ret = sizeof(s) + catch err1 + return err1 + end + try + s = unsafeget(r) + ret = sizeof(s) + catch err2 + return err2 + end + nothing + end + i = only(findall(isnew, result.ir.stmts.inst)) + rs = findall(isreturn, result.ir.stmts.inst) + @test_broken !has_all_escape(result.state[SSAValue(i)]) + for r in rs + @test has_return_escape(result.state[SSAValue(i)], r) + end + end + let # no exception handling: should keep propagating the escape + result = @eval M $code_escapes() do + r = Ref{String}() + local ret + try + s = unsafeget(r) + ret = sizeof(s) + finally + if !@isdefined(ret) + ret = 42 + end + end + return ret + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test_broken !has_return_escape(result.state[SSAValue(i)], r) + end +end + +@testset "field analysis / alias analysis" begin + # escaped allocations + # ------------------- + + # escaped object should escape its fields as well + let result = code_escapes((Any,)) do a + global g = SafeRef{Any}(a) + nothing + end + i = only(findall(isnew, result.ir.stmts.inst)) + @test has_all_escape(result.state[SSAValue(i)]) + @test has_all_escape(result.state[Argument(2)]) + end + let result = code_escapes((Any,)) do a + global g = (a,) + nothing + end + i = only(findall(issubT(Tuple), result.ir.stmts.type)) + @test has_all_escape(result.state[SSAValue(i)]) + @test has_all_escape(result.state[Argument(2)]) + end + let result = code_escapes((Any,)) do a + o0 = SafeRef{Any}(a) + global g = SafeRef(o0) + nothing + end + is = findall(isnew, result.ir.stmts.inst) + @test length(is) == 2 + i0, i1 = is + @test has_all_escape(result.state[SSAValue(i0)]) + @test has_all_escape(result.state[SSAValue(i1)]) + @test has_all_escape(result.state[Argument(2)]) + end + let result = code_escapes((Any,)) do a + t0 = (a,) + global g = (t0,) + nothing + end + inds = findall(issubT(Tuple), result.ir.stmts.type) + @assert length(inds) == 2 + for i in inds; @test has_all_escape(result.state[SSAValue(i)]); end + @test has_all_escape(result.state[Argument(2)]) + end + # global escape through `setfield!` + let result = code_escapes((Any,)) do a + r = SafeRef{Any}(:init) + global g = r + r[] = a + nothing + end + i = only(findall(isnew, result.ir.stmts.inst)) + @test has_all_escape(result.state[SSAValue(i)]) + @test has_all_escape(result.state[Argument(2)]) + end + let result = code_escapes((Any,Any)) do a, b + r = SafeRef{Any}(a) + global g = r + r[] = b + nothing + end + i = only(findall(isnew, result.ir.stmts.inst)) + @test has_all_escape(result.state[SSAValue(i)]) + @test has_all_escape(result.state[Argument(2)]) # a + @test has_all_escape(result.state[Argument(3)]) # b + end + let result = @eval EATModule() begin + const Rx = SafeRef{String}("Rx") + $code_escapes((String,)) do s + Rx[] = s + Core.sizeof(Rx[]) + end + end + @test has_all_escape(result.state[Argument(2)]) + end + let result = @eval EATModule() begin + const Rx = SafeRef{String}("Rx") + $code_escapes((String,)) do s + setfield!(Rx, :x, s) + Core.sizeof(Rx[]) + end + end + @test has_all_escape(result.state[Argument(2)]) + end + let M = EATModule() + @eval M module ___xxx___ + import ..SafeRef + const Rx = SafeRef("Rx") + end + result = @eval M begin + $code_escapes((String,)) do s + rx = getfield(___xxx___, :Rx) + rx[] = s + nothing + end + end + @test has_all_escape(result.state[Argument(2)]) + end + + # field escape + # ------------ + + # field escape should propagate to :new arguments + let result = code_escapes((String,)) do a + o = SafeRef(a) + f = o[] + return f + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + let result = code_escapes((String,)) do a + t = (a,) + f = t[1] + return f + end + i = only(findall(iscall((result.ir, tuple)), result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + let result = code_escapes((String, String)) do a, b + obj = SafeRefs(a, b) + fld1 = obj[1] + fld2 = obj[2] + return (fld1, fld2) + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) # a + @test has_return_escape(result.state[Argument(3)], r) # b + @test is_load_forwardable(result.state[SSAValue(i)]) + end + + # field escape should propagate to `setfield!` argument + let result = code_escapes((String,)) do a + o = SafeRef("foo") + o[] = a + f = o[] + return f + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + # propagate escape information imposed on return value of `setfield!` call + let result = code_escapes((String,)) do a + obj = SafeRef("foo") + return (obj[] = a) + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + + # nested allocations + let result = code_escapes((String,)) do a + o1 = SafeRef(a) + o2 = SafeRef(o1) + return o2[] + end + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) + for i in 1:length(result.ir.stmts) + if isnew(result.ir.stmts.inst[i]) && isT(SafeRef{String})(result.ir.stmts.type[i]) + @test has_return_escape(result.state[SSAValue(i)], r) + elseif isnew(result.ir.stmts.inst[i]) && isT(SafeRef{SafeRef{String}})(result.ir.stmts.type[i]) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + end + end + let result = code_escapes((String,)) do a + o1 = (a,) + o2 = (o1,) + return o2[1] + end + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) + for i in 1:length(result.ir.stmts) + if isnew(result.ir.stmts.inst[i]) && isT(Tuple{String})(result.ir.stmts.type[i]) + @test has_return_escape(result.state[SSAValue(i)], r) + elseif isnew(result.ir.stmts.inst[i]) && isT(Tuple{Tuple{String}})(result.ir.stmts.type[i]) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + end + end + let result = code_escapes((String,)) do a + o1 = SafeRef(a) + o2 = SafeRef(o1) + o1′ = o2[] + a′ = o1′[] + return a′ + end + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) + for i in findall(isnew, result.ir.stmts.inst) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + end + let result = code_escapes() do + o1 = SafeRef("foo") + o2 = SafeRef(o1) + return o2 + end + r = only(findall(isreturn, result.ir.stmts.inst)) + for i in findall(isnew, result.ir.stmts.inst) + @test has_return_escape(result.state[SSAValue(i)], r) + end + end + let result = code_escapes() do + o1 = SafeRef("foo") + o2′ = SafeRef(nothing) + o2 = SafeRef{SafeRef}(o2′) + o2[] = o1 + return o2 + end + r = only(findall(isreturn, result.ir.stmts.inst)) + findall(1:length(result.ir.stmts)) do i + if isnew(result.ir.stmts[i][:inst]) + t = result.ir.stmts[i][:type] + return t === SafeRef{String} || # o1 + t === SafeRef{SafeRef} # o2 + end + return false + end |> x->foreach(x) do i + @test has_return_escape(result.state[SSAValue(i)], r) + end + end + let result = code_escapes((String,)) do x + broadcast(identity, Ref(x)) + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + + # ϕ-node allocations + let result = code_escapes((Bool,Any,Any)) do cond, x, y + if cond + ϕ = SafeRef{Any}(x) + else + ϕ = SafeRef{Any}(y) + end + return ϕ[] + end + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(3)], r) # x + @test has_return_escape(result.state[Argument(4)], r) # y + i = only(findall(isϕ, result.ir.stmts.inst)) + @test is_load_forwardable(result.state[SSAValue(i)]) + for i in findall(isnew, result.ir.stmts.inst) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + end + let result = code_escapes((Bool,Any,Any)) do cond, x, y + if cond + ϕ2 = ϕ1 = SafeRef{Any}(x) + else + ϕ2 = ϕ1 = SafeRef{Any}(y) + end + return ϕ1[], ϕ2[] + end + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(3)], r) # x + @test has_return_escape(result.state[Argument(4)], r) # y + for i in findall(isϕ, result.ir.stmts.inst) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + for i in findall(isnew, result.ir.stmts.inst) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + end + # when ϕ-node merges values with different types + let result = code_escapes((Bool,String,String,String)) do cond, x, y, z + local out + if cond + ϕ = SafeRef(x) + out = ϕ[] + else + ϕ = SafeRefs(z, y) + end + return @isdefined(out) ? out : throw(ϕ) + end + r = only(findall(isreturn, result.ir.stmts.inst)) + t = only(findall(iscall((result.ir, throw)), result.ir.stmts.inst)) + ϕ = only(findall(isT(Union{SafeRef{String},SafeRefs{String,String}}), result.ir.stmts.type)) + @test has_return_escape(result.state[Argument(3)], r) # x + @test !has_return_escape(result.state[Argument(4)], r) # y + @test has_return_escape(result.state[Argument(5)], r) # z + @test has_thrown_escape(result.state[SSAValue(ϕ)], t) + end + + # alias analysis + # -------------- + + # alias via getfield & Expr(:new) + let result = code_escapes((String,)) do s + r = SafeRef(s) + return r[] + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + val = (result.ir.stmts.inst[r]::ReturnNode).val::SSAValue + @test isaliased(Argument(2), val, result.state) + @test !isaliased(Argument(2), SSAValue(i), result.state) + end + let result = code_escapes((String,)) do s + r1 = SafeRef(s) + r2 = SafeRef(r1) + return r2[] + end + i1, i2 = findall(isnew, result.ir.stmts.inst) + r = only(findall(isreturn, result.ir.stmts.inst)) + val = (result.ir.stmts.inst[r]::ReturnNode).val::SSAValue + @test !isaliased(SSAValue(i1), SSAValue(i2), result.state) + @test isaliased(SSAValue(i1), val, result.state) + @test !isaliased(SSAValue(i2), val, result.state) + end + let result = code_escapes((String,)) do s + r1 = SafeRef(s) + r2 = SafeRef(r1) + return r2[][] + end + r = only(findall(isreturn, result.ir.stmts.inst)) + val = (result.ir.stmts.inst[r]::ReturnNode).val::SSAValue + @test isaliased(Argument(2), val, result.state) + for i in findall(isnew, result.ir.stmts.inst) + @test !isaliased(SSAValue(i), val, result.state) + end + end + let result = @eval EATModule() begin + const Rx = SafeRef("Rx") + $code_escapes((String,)) do s + r = SafeRef(Rx) + rx = r[] # rx aliased to Rx + rx[] = s + nothing + end + end + i = findfirst(isnew, result.ir.stmts.inst) + @test has_all_escape(result.state[Argument(2)]) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + # alias via getfield & setfield! + let result = code_escapes((String,)) do s + r = Ref{String}() + r[] = s + return r[] + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + val = (result.ir.stmts.inst[r]::ReturnNode).val::SSAValue + @test isaliased(Argument(2), val, result.state) + @test !isaliased(Argument(2), SSAValue(i), result.state) + end + let result = code_escapes((String,)) do s + r1 = Ref(s) + r2 = Ref{Base.RefValue{String}}() + r2[] = r1 + return r2[] + end + i1, i2 = findall(isnew, result.ir.stmts.inst) + r = only(findall(isreturn, result.ir.stmts.inst)) + val = (result.ir.stmts.inst[r]::ReturnNode).val::SSAValue + @test !isaliased(SSAValue(i1), SSAValue(i2), result.state) + @test isaliased(SSAValue(i1), val, result.state) + @test !isaliased(SSAValue(i2), val, result.state) + end + let result = code_escapes((String,)) do s + r1 = Ref{String}() + r2 = Ref{Base.RefValue{String}}() + r2[] = r1 + r1[] = s + return r2[][] + end + r = only(findall(isreturn, result.ir.stmts.inst)) + val = (result.ir.stmts.inst[r]::ReturnNode).val::SSAValue + @test isaliased(Argument(2), val, result.state) + for i in findall(isnew, result.ir.stmts.inst) + @test !isaliased(SSAValue(i), val, result.state) + end + result = code_escapes((String,)) do s + r1 = Ref{String}() + r2 = Ref{Base.RefValue{String}}() + r1[] = s + r2[] = r1 + return r2[][] + end + r = only(findall(isreturn, result.ir.stmts.inst)) + val = (result.ir.stmts.inst[r]::ReturnNode).val::SSAValue + @test isaliased(Argument(2), val, result.state) + for i in findall(isnew, result.ir.stmts.inst) + @test !isaliased(SSAValue(i), val, result.state) + end + end + let result = @eval EATModule() begin + const Rx = SafeRef("Rx") + $code_escapes((SafeRef{String}, String,)) do _rx, s + r = SafeRef(_rx) + r[] = Rx + rx = r[] # rx aliased to Rx + rx[] = s + nothing + end + end + i = findfirst(isnew, result.ir.stmts.inst) + @test has_all_escape(result.state[Argument(3)]) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + # alias via typeassert + let result = code_escapes((Any,)) do a + r = a::String + return r + end + r = only(findall(isreturn, result.ir.stmts.inst)) + val = (result.ir.stmts.inst[r]::ReturnNode).val::SSAValue + @test has_return_escape(result.state[Argument(2)], r) # a + @test isaliased(Argument(2), val, result.state) # a <-> r + end + let result = code_escapes((Any,)) do a + global g + (g::SafeRef{Any})[] = a + nothing + end + @test has_all_escape(result.state[Argument(2)]) + end + # alias via ifelse + let result = code_escapes((Bool,Any,Any)) do c, a, b + r = ifelse(c, a, b) + return r + end + r = only(findall(isreturn, result.ir.stmts.inst)) + val = (result.ir.stmts.inst[r]::ReturnNode).val::SSAValue + @test has_return_escape(result.state[Argument(3)], r) # a + @test has_return_escape(result.state[Argument(4)], r) # b + @test !isaliased(Argument(2), val, result.state) # c r + @test isaliased(Argument(3), val, result.state) # a <-> r + @test isaliased(Argument(4), val, result.state) # b <-> r + end + let result = @eval EATModule() begin + const Lx, Rx = SafeRef("Lx"), SafeRef("Rx") + $code_escapes((Bool,String,)) do c, a + r = ifelse(c, Lx, Rx) + r[] = a + nothing + end + end + @test has_all_escape(result.state[Argument(3)]) # a + end + # alias via ϕ-node + let result = code_escapes((Bool,String)) do cond, x + if cond + ϕ2 = ϕ1 = SafeRef("foo") + else + ϕ2 = ϕ1 = SafeRef("bar") + end + ϕ2[] = x + return ϕ1[] + end + r = only(findall(isreturn, result.ir.stmts.inst)) + val = (result.ir.stmts.inst[r]::ReturnNode).val::SSAValue + @test has_return_escape(result.state[Argument(3)], r) # x + @test isaliased(Argument(3), val, result.state) # x + for i in findall(isϕ, result.ir.stmts.inst) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + for i in findall(isnew, result.ir.stmts.inst) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + end + let result = code_escapes((Bool,Bool,String)) do cond1, cond2, x + if cond1 + ϕ2 = ϕ1 = SafeRef("foo") + else + ϕ2 = ϕ1 = SafeRef("bar") + end + cond2 && (ϕ2[] = x) + return ϕ1[] + end + r = only(findall(isreturn, result.ir.stmts.inst)) + val = (result.ir.stmts.inst[r]::ReturnNode).val::SSAValue + @test has_return_escape(result.state[Argument(4)], r) # x + @test isaliased(Argument(4), val, result.state) # x + for i in findall(isϕ, result.ir.stmts.inst) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + for i in findall(isnew, result.ir.stmts.inst) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + end + # alias via π-node + let result = code_escapes((Any,)) do x + if isa(x, String) + return x + end + throw("error!") + end + r = only(findall(isreturn, result.ir.stmts.inst)) + rval = (result.ir.stmts.inst[r]::ReturnNode).val::SSAValue + @test has_return_escape(result.state[Argument(2)], r) # x + @test isaliased(Argument(2), rval, result.state) + end + let result = code_escapes((String,)) do x + global g + l = g + if isa(l, SafeRef{String}) + l[] = x + end + nothing + end + @test has_all_escape(result.state[Argument(2)]) # x + end + + # dynamic semantics + # ----------------- + + # conservatively handle untyped objects + let result = @eval code_escapes((Any,Any,)) do T, x + obj = $(Expr(:new, :T, :x)) + end + t = only(findall(isnew, result.ir.stmts.inst)) + @test #=T=# has_thrown_escape(result.state[Argument(2)], t) # T + @test #=x=# has_thrown_escape(result.state[Argument(3)], t) # x + end + let result = @eval code_escapes((Any,Any,Any,Any)) do T, x, y, z + obj = $(Expr(:new, :T, :x, :y)) + return getfield(obj, :x) + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test #=x=# has_return_escape(result.state[Argument(3)], r) + @test #=y=# has_return_escape(result.state[Argument(4)], r) + @test #=z=# !has_return_escape(result.state[Argument(5)], r) + end + let result = @eval code_escapes((Any,Any,Any,Any)) do T, x, y, z + obj = $(Expr(:new, :T, :x)) + setfield!(obj, :x, y) + return getfield(obj, :x) + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test #=x=# has_return_escape(result.state[Argument(3)], r) + @test #=y=# has_return_escape(result.state[Argument(4)], r) + @test #=z=# !has_return_escape(result.state[Argument(5)], r) + end + + # conservatively handle unknown field: + # all fields should be escaped, but the allocation itself doesn't need to be escaped + let result = code_escapes((String, Symbol)) do a, fld + obj = SafeRef(a) + return getfield(obj, fld) + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) # a + @test !is_load_forwardable(result.state[SSAValue(i)]) # obj + end + let result = code_escapes((String, String, Symbol)) do a, b, fld + obj = SafeRefs(a, b) + return getfield(obj, fld) # should escape both `a` and `b` + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) # a + @test has_return_escape(result.state[Argument(3)], r) # b + @test !is_load_forwardable(result.state[SSAValue(i)]) # obj + end + let result = code_escapes((String, String, Int)) do a, b, idx + obj = SafeRefs(a, b) + return obj[idx] # should escape both `a` and `b` + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) # a + @test has_return_escape(result.state[Argument(3)], r) # b + @test !is_load_forwardable(result.state[SSAValue(i)]) # obj + end + let result = code_escapes((String, String, Symbol)) do a, b, fld + obj = SafeRefs("a", "b") + setfield!(obj, fld, a) + return obj[2] # should escape `a` + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) # a + @test !has_return_escape(result.state[Argument(3)], r) # b + @test !is_load_forwardable(result.state[SSAValue(i)]) # obj + end + let result = code_escapes((String, Symbol)) do a, fld + obj = SafeRefs("a", "b") + setfield!(obj, fld, a) + return obj[1] # this should escape `a` + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) # a + @test !is_load_forwardable(result.state[SSAValue(i)]) # obj + end + let result = code_escapes((String, String, Int)) do a, b, idx + obj = SafeRefs("a", "b") + obj[idx] = a + return obj[2] # should escape `a` + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) # a + @test !has_return_escape(result.state[Argument(3)], r) # b + @test !is_load_forwardable(result.state[SSAValue(i)]) # obj + end + + # interprocedural + # --------------- + + let result = @eval EATModule() begin + @noinline getx(obj) = obj[] + $code_escapes((String,)) do a + obj = SafeRef(a) + fld = getx(obj) + return fld + end + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) + # NOTE we can't scalar replace `obj`, but still we may want to stack allocate it + @test_broken is_load_forwardable(result.state[SSAValue(i)]) + end + + # TODO interprocedural field analysis + let result = code_escapes((SafeRef{String},)) do s + s[] = "bar" + global g = s[] + nothing + end + @test_broken !has_all_escape(result.state[Argument(2)]) + end + + # TODO flow-sensitivity? + # ---------------------- + + let result = code_escapes((Any,Any)) do a, b + r = SafeRef{Any}(a) + r[] = b + return r[] + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test_broken !has_return_escape(result.state[Argument(2)], r) # a + @test has_return_escape(result.state[Argument(3)], r) # b + @test is_load_forwardable(result.state[SSAValue(i)]) + end + let result = code_escapes((Any,Any)) do a, b + r = SafeRef{Any}(:init) + r[] = a + r[] = b + return r[] + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test_broken !has_return_escape(result.state[Argument(2)], r) # a + @test has_return_escape(result.state[Argument(3)], r) # b + @test is_load_forwardable(result.state[SSAValue(i)]) + end + let result = code_escapes((Any,Any,Bool)) do a, b, cond + r = SafeRef{Any}(:init) + if cond + r[] = a + return r[] + else + r[] = b + return nothing + end + end + i = only(findall(isnew, result.ir.stmts.inst)) + @test is_load_forwardable(result.state[SSAValue(i)]) + r = only(findall(result.ir.stmts.inst) do @nospecialize x + isreturn(x) && isa(x.val, Core.SSAValue) + end) + @test has_return_escape(result.state[Argument(2)], r) # a + @test_broken !has_return_escape(result.state[Argument(3)], r) # b + end + + # handle conflicting field information correctly + let result = code_escapes((Bool,String,String,)) do cnd, baz, qux + if cnd + o = SafeRef("foo") + else + o = SafeRefs("bar", baz) + r = getfield(o, 2) + end + if cnd + o = o::SafeRef + setfield!(o, 1, qux) + r = getfield(o, 1) + end + r + end + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(3)], r) # baz + @test has_return_escape(result.state[Argument(4)], r) # qux + for new in findall(isnew, result.ir.stmts.inst) + @test is_load_forwardable(result.state[SSAValue(new)]) + end + end + let result = code_escapes((Bool,String,String,)) do cnd, baz, qux + if cnd + o = SafeRefs("foo", "bar") + r = setfield!(o, 2, baz) + else + o = SafeRef(qux) + end + if !cnd + o = o::SafeRef + r = getfield(o, 1) + end + r + end + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(3)], r) # baz + @test has_return_escape(result.state[Argument(4)], r) # qux + end + + # foreigncall should disable field analysis + let result = code_escapes((Any,Nothing,Int,UInt)) do t, mt, lim, world + ambig = false + min = Ref{UInt}(typemin(UInt)) + max = Ref{UInt}(typemax(UInt)) + has_ambig = Ref{Int32}(0) + mt = ccall(:jl_matching_methods, Any, + (Any, Any, Cint, Cint, UInt, Ptr{UInt}, Ptr{UInt}, Ref{Int32}), + t, mt, lim, ambig, world, min, max, has_ambig)::Union{Array{Any,1}, Bool} + return mt, has_ambig[] + end + for i in findall(isnew, result.ir.stmts.inst) + @test !is_load_forwardable(result.state[SSAValue(i)]) + end + end +end + +# demonstrate the power of our field / alias analysis with a realistic end to end example +abstract type AbstractPoint{T} end +mutable struct MPoint{T} <: AbstractPoint{T} + x::T + y::T +end +add(a::P, b::P) where P<:AbstractPoint = P(a.x + b.x, a.y + b.y) +function compute(T, ax, ay, bx, by) + a = T(ax, ay) + b = T(bx, by) + for i in 0:(100000000-1) + a = add(add(a, b), b) + end + a.x, a.y +end +function compute(a, b) + for i in 0:(100000000-1) + a = add(add(a, b), b) # unreplaceable, since it can be the call argument + end + a.x, a.y +end +function compute!(a, b) + for i in 0:(100000000-1) + a′ = add(add(a, b), b) + a.x = a′.x + a.y = a′.y + end +end +let result = @code_escapes compute(MPoint, 1+.5im, 2+.5im, 2+.25im, 4+.75im) + for i in findall(isnew, result.ir.stmts.inst) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + for i in findall(isϕ, result.ir.stmts.inst) + @test is_load_forwardable(result.state[SSAValue(i)]) + end +end +let result = @code_escapes compute(MPoint(1+.5im, 2+.5im), MPoint(2+.25im, 4+.75im)) + for i in findall(1:length(result.ir.stmts)) do i + isϕ(result.ir.stmts[i][:inst]) && isT(MPoint{ComplexF64})(result.ir.stmts[i][:type]) + end + @test !is_load_forwardable(result.state[SSAValue(i)]) + end +end +let result = @code_escapes compute!(MPoint(1+.5im, 2+.5im), MPoint(2+.25im, 4+.75im)) + for i in findall(isnew, result.ir.stmts.inst) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + for i in findall(isϕ, result.ir.stmts.inst) + @test is_load_forwardable(result.state[SSAValue(i)]) + end +end + +@testset "array primitives" begin + inbounds = Base.JLOptions().check_bounds == 0 + + # arrayref + let result = code_escapes((Vector{String},Int)) do xs, i + s = Base.arrayref(true, xs, i) + return s + end + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) # xs + @test has_thrown_escape(result.state[Argument(2)]) # xs + @test !has_return_escape(result.state[Argument(3)], r) # i + end + let result = code_escapes((Vector{String},Int)) do xs, i + s = Base.arrayref(false, xs, i) + return s + end + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) # xs + @test !has_thrown_escape(result.state[Argument(2)]) # xs + @test !has_return_escape(result.state[Argument(3)], r) # i + end + inbounds && let result = code_escapes((Vector{String},Int)) do xs, i + s = @inbounds xs[i] + return s + end + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) # xs + @test !has_thrown_escape(result.state[Argument(2)]) # xs + @test !has_return_escape(result.state[Argument(3)], r) # i + end + let result = code_escapes((Vector{String},Bool)) do xs, i + c = Base.arrayref(true, xs, i) # TypeError will happen here + return c + end + t = only(findall(iscall((result.ir, Base.arrayref)), result.ir.stmts.inst)) + @test has_thrown_escape(result.state[Argument(2)], t) # xs + end + let result = code_escapes((String,Int)) do xs, i + c = Base.arrayref(true, xs, i) # TypeError will happen here + return c + end + t = only(findall(iscall((result.ir, Base.arrayref)), result.ir.stmts.inst)) + @test has_thrown_escape(result.state[Argument(2)], t) # xs + end + let result = code_escapes((AbstractVector{String},Int)) do xs, i + c = Base.arrayref(true, xs, i) # TypeError may happen here + return c + end + t = only(findall(iscall((result.ir, Base.arrayref)), result.ir.stmts.inst)) + @test has_thrown_escape(result.state[Argument(2)], t) # xs + end + let result = code_escapes((Vector{String},Any)) do xs, i + c = Base.arrayref(true, xs, i) # TypeError may happen here + return c + end + t = only(findall(iscall((result.ir, Base.arrayref)), result.ir.stmts.inst)) + @test has_thrown_escape(result.state[Argument(2)], t) # xs + end + + # arrayset + let result = code_escapes((Vector{String},String,Int,)) do xs, x, i + Base.arrayset(true, xs, x, i) + return xs + end + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) # xs + @test has_thrown_escape(result.state[Argument(2)]) # xs + @test has_return_escape(result.state[Argument(3)], r) # x + end + let result = code_escapes((Vector{String},String,Int,)) do xs, x, i + Base.arrayset(false, xs, x, i) + return xs + end + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) # xs + @test !has_thrown_escape(result.state[Argument(2)]) # xs + @test has_return_escape(result.state[Argument(3)], r) # x + end + inbounds && let result = code_escapes((Vector{String},String,Int,)) do xs, x, i + @inbounds xs[i] = x + return xs + end + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) # xs + @test !has_thrown_escape(result.state[Argument(2)]) # xs + @test has_return_escape(result.state[Argument(3)], r) # x + end + let result = code_escapes((String,String,String,)) do s, t, u + xs = Vector{String}(undef, 3) + Base.arrayset(true, xs, s, 1) + Base.arrayset(true, xs, t, 2) + Base.arrayset(true, xs, u, 3) + return xs + end + i = only(findall(isarrayalloc, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[SSAValue(i)], r) + for i in 2:result.state.nargs + @test has_return_escape(result.state[Argument(i)], r) + end + end + let result = code_escapes((Vector{String},String,Bool,)) do xs, x, i + Base.arrayset(true, xs, x, i) # TypeError will happen here + return xs + end + t = only(findall(iscall((result.ir, Base.arrayset)), result.ir.stmts.inst)) + @test has_thrown_escape(result.state[Argument(2)], t) # xs + @test has_thrown_escape(result.state[Argument(3)], t) # x + end + let result = code_escapes((String,String,Int,)) do xs, x, i + Base.arrayset(true, xs, x, i) # TypeError will happen here + return xs + end + t = only(findall(iscall((result.ir, Base.arrayset)), result.ir.stmts.inst)) + @test has_thrown_escape(result.state[Argument(2)], t) # xs::String + @test has_thrown_escape(result.state[Argument(3)], t) # x::String + end + let result = code_escapes((AbstractVector{String},String,Int,)) do xs, x, i + Base.arrayset(true, xs, x, i) # TypeError may happen here + return xs + end + t = only(findall(iscall((result.ir, Base.arrayset)), result.ir.stmts.inst)) + @test has_thrown_escape(result.state[Argument(2)], t) # xs + @test has_thrown_escape(result.state[Argument(3)], t) # x + end + let result = code_escapes((Vector{String},AbstractString,Int,)) do xs, x, i + Base.arrayset(true, xs, x, i) # TypeError may happen here + return xs + end + t = only(findall(iscall((result.ir, Base.arrayset)), result.ir.stmts.inst)) + @test has_thrown_escape(result.state[Argument(2)], t) # xs + @test has_thrown_escape(result.state[Argument(3)], t) # x + end + + # arrayref and arrayset + let result = code_escapes() do + a = Vector{Vector{Any}}(undef, 1) + b = Any[] + a[1] = b + return a[1] + end + r = only(findall(isreturn, result.ir.stmts.inst)) + ai = only(findall(result.ir.stmts.inst) do @nospecialize x + isarrayalloc(x) && x.args[2] === Vector{Vector{Any}} + end) + bi = only(findall(result.ir.stmts.inst) do @nospecialize x + isarrayalloc(x) && x.args[2] === Vector{Any} + end) + @test !has_return_escape(result.state[SSAValue(ai)], r) + @test has_return_escape(result.state[SSAValue(bi)], r) + end + let result = code_escapes() do + a = Vector{Vector{Any}}(undef, 1) + b = Any[] + a[1] = b + return a + end + r = only(findall(isreturn, result.ir.stmts.inst)) + ai = only(findall(result.ir.stmts.inst) do @nospecialize x + isarrayalloc(x) && x.args[2] === Vector{Vector{Any}} + end) + bi = only(findall(result.ir.stmts.inst) do @nospecialize x + isarrayalloc(x) && x.args[2] === Vector{Any} + end) + @test has_return_escape(result.state[SSAValue(ai)], r) + @test has_return_escape(result.state[SSAValue(bi)], r) + end + let result = code_escapes((Vector{Any},String,Int,Int)) do xs, s, i, j + x = SafeRef(s) + xs[i] = x + xs[j] # potential error + end + i = only(findall(isnew, result.ir.stmts.inst)) + t = only(findall(iscall((result.ir, Base.arrayref)), result.ir.stmts.inst)) + @test has_thrown_escape(result.state[Argument(3)], t) # s + @test has_thrown_escape(result.state[SSAValue(i)], t) # x + end + + # arraysize + let result = code_escapes((Vector{Any},)) do xs + Core.arraysize(xs, 1) + end + t = only(findall(iscall((result.ir, Core.arraysize)), result.ir.stmts.inst)) + @test !has_thrown_escape(result.state[Argument(2)], t) + end + let result = code_escapes((Vector{Any},Int,)) do xs, dim + Core.arraysize(xs, dim) + end + t = only(findall(iscall((result.ir, Core.arraysize)), result.ir.stmts.inst)) + @test !has_thrown_escape(result.state[Argument(2)], t) + end + let result = code_escapes((Any,)) do xs + Core.arraysize(xs, 1) + end + t = only(findall(iscall((result.ir, Core.arraysize)), result.ir.stmts.inst)) + @test has_thrown_escape(result.state[Argument(2)], t) + end + + # arraylen + let result = code_escapes((Vector{Any},)) do xs + Base.arraylen(xs) + end + t = only(findall(iscall((result.ir, Base.arraylen)), result.ir.stmts.inst)) + @test !has_thrown_escape(result.state[Argument(2)], t) # xs + end + let result = code_escapes((String,)) do xs + Base.arraylen(xs) + end + t = only(findall(iscall((result.ir, Base.arraylen)), result.ir.stmts.inst)) + @test has_thrown_escape(result.state[Argument(2)], t) # xs + end + let result = code_escapes((Vector{Any},)) do xs + Base.arraylen(xs, 1) + end + t = only(findall(iscall((result.ir, Base.arraylen)), result.ir.stmts.inst)) + @test has_thrown_escape(result.state[Argument(2)], t) # xs + end + + # array resizing + # without BoundsErrors + let result = code_escapes((Vector{Any},String)) do xs, x + @ccall jl_array_grow_beg(xs::Any, 2::UInt)::Cvoid + xs[1] = x + xs + end + t = only(findall(isarrayresize, result.ir.stmts.inst)) + @test !has_thrown_escape(result.state[Argument(2)], t) # xs + @test !has_thrown_escape(result.state[Argument(3)], t) # x + end + let result = code_escapes((Vector{Any},String)) do xs, x + @ccall jl_array_grow_end(xs::Any, 2::UInt)::Cvoid + xs[1] = x + xs + end + t = only(findall(isarrayresize, result.ir.stmts.inst)) + @test !has_thrown_escape(result.state[Argument(2)], t) # xs + @test !has_thrown_escape(result.state[Argument(3)], t) # x + end + # with possible BoundsErrors + let result = code_escapes((String,)) do x + xs = Any[1,2,3] + xs[3] = x + @ccall jl_array_del_beg(xs::Any, 2::UInt)::Cvoid # can potentially throw + xs + end + i = only(findall(isarrayalloc, result.ir.stmts.inst)) + t = only(findall(isarrayresize, result.ir.stmts.inst)) + @test has_thrown_escape(result.state[SSAValue(i)], t) # xs + @test has_thrown_escape(result.state[Argument(2)], t) # x + end + let result = code_escapes((String,)) do x + xs = Any[1,2,3] + xs[1] = x + @ccall jl_array_del_end(xs::Any, 2::UInt)::Cvoid # can potentially throw + xs + end + i = only(findall(isarrayalloc, result.ir.stmts.inst)) + t = only(findall(isarrayresize, result.ir.stmts.inst)) + @test has_thrown_escape(result.state[SSAValue(i)], t) # xs + @test has_thrown_escape(result.state[Argument(2)], t) # x + end + let result = code_escapes((String,)) do x + xs = Any[x] + @ccall jl_array_grow_at(xs::Any, 1::UInt, 2::UInt)::Cvoid # can potentially throw + end + i = only(findall(isarrayalloc, result.ir.stmts.inst)) + t = only(findall(isarrayresize, result.ir.stmts.inst)) + @test has_thrown_escape(result.state[SSAValue(i)], t) # xs + @test has_thrown_escape(result.state[Argument(2)], t) # x + end + let result = code_escapes((String,)) do x + xs = Any[x] + @ccall jl_array_del_at(xs::Any, 1::UInt, 2::UInt)::Cvoid # can potentially throw + end + i = only(findall(isarrayalloc, result.ir.stmts.inst)) + t = only(findall(isarrayresize, result.ir.stmts.inst)) + @test has_thrown_escape(result.state[SSAValue(i)], t) # xs + @test has_thrown_escape(result.state[Argument(2)], t) # x + end + inbounds && let result = code_escapes((String,)) do x + xs = @inbounds Any[x] + @ccall jl_array_del_at(xs::Any, 1::UInt, 2::UInt)::Cvoid # can potentially throw + end + i = only(findall(isarrayalloc, result.ir.stmts.inst)) + t = only(findall(isarrayresize, result.ir.stmts.inst)) + @test has_thrown_escape(result.state[SSAValue(i)], t) # xs + @test has_thrown_escape(result.state[Argument(2)], t) # x + end + + # array copy + let result = code_escapes((Vector{Any},)) do xs + return copy(xs) + end + i = only(findall(isarraycopy, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[SSAValue(i)], r) + @test_broken !has_return_escape(result.state[Argument(2)], r) + end + let result = code_escapes((String,)) do s + xs = String[s] + xs′ = copy(xs) + return xs′[1] + end + i1 = only(findall(isarrayalloc, result.ir.stmts.inst)) + i2 = only(findall(isarraycopy, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test !has_return_escape(result.state[SSAValue(i1)]) + @test !has_return_escape(result.state[SSAValue(i2)]) + @test has_return_escape(result.state[Argument(2)], r) # s + end + let result = code_escapes((Vector{Any},)) do xs + xs′ = copy(xs) + return xs′[1] # may potentially throw BoundsError, should escape `xs` conservatively (i.e. escape its elements) + end + i = only(findall(isarraycopy, result.ir.stmts.inst)) + ref = only(findall(iscall((result.ir, Base.arrayref)), result.ir.stmts.inst)) + ret = only(findall(isreturn, result.ir.stmts.inst)) + @test_broken !has_thrown_escape(result.state[SSAValue(i)], ref) + @test_broken !has_return_escape(result.state[SSAValue(i)], ret) + @test has_thrown_escape(result.state[Argument(2)], ref) + @test has_return_escape(result.state[Argument(2)], ret) + end + let result = code_escapes((String,)) do s + xs = Vector{String}(undef, 1) + xs[1] = s + xs′ = copy(xs) + length(xs′) > 2 && throw(xs′) + return xs′ + end + i1 = only(findall(isarrayalloc, result.ir.stmts.inst)) + i2 = only(findall(isarraycopy, result.ir.stmts.inst)) + t = only(findall(iscall((result.ir, throw)), result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test_broken !has_thrown_escape(result.state[SSAValue(i1)], t) + @test_broken !has_return_escape(result.state[SSAValue(i1)], r) + @test has_thrown_escape(result.state[SSAValue(i2)], t) + @test has_return_escape(result.state[SSAValue(i2)], r) + @test has_thrown_escape(result.state[Argument(2)], t) + @test has_return_escape(result.state[Argument(2)], r) + end + + # isassigned + let result = code_escapes((Vector{Any},Int)) do xs, i + return isassigned(xs, i) + end + r = only(findall(isreturn, result.ir.stmts.inst)) + @test !has_return_escape(result.state[Argument(2)], r) + @test !has_thrown_escape(result.state[Argument(2)]) + end +end + +# demonstrate array primitive support with a realistic end to end example +let result = code_escapes((Int,String,)) do n,s + xs = String[] + for i in 1:n + push!(xs, s) + end + xs + end + i = only(findall(isarrayalloc, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[SSAValue(i)], r) + Base.JLOptions().check_bounds ≠ 0 && @test has_thrown_escape(result.state[SSAValue(i)]) + @test has_return_escape(result.state[Argument(3)], r) # s + Base.JLOptions().check_bounds ≠ 0 && @test has_thrown_escape(result.state[Argument(3)]) # s +end +let result = code_escapes((Int,String,)) do n,s + xs = String[] + for i in 1:n + pushfirst!(xs, s) + end + xs + end + i = only(findall(isarrayalloc, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[SSAValue(i)], r) # xs + @test has_thrown_escape(result.state[SSAValue(i)]) # xs + @test has_return_escape(result.state[Argument(3)], r) # s + @test has_thrown_escape(result.state[Argument(3)]) # s +end +let result = code_escapes((String,String,String)) do s, t, u + xs = String[] + resize!(xs, 3) + xs[1] = s + xs[1] = t + xs[1] = u + xs + end + i = only(findall(isarrayalloc, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[SSAValue(i)], r) + @test has_thrown_escape(result.state[SSAValue(i)]) # xs + @test has_return_escape(result.state[Argument(2)], r) # s + @test has_return_escape(result.state[Argument(3)], r) # t + @test has_return_escape(result.state[Argument(4)], r) # u +end + +@static if isdefined(Core, :ImmutableArray) + +import Core: ImmutableArray, arrayfreeze, mutating_arrayfreeze, arraythaw + +@testset "ImmutableArray" begin + # arrayfreeze + let result = code_escapes((Vector{Any},)) do xs + arrayfreeze(xs) + end + @test !has_thrown_escape(result.state[Argument(2)]) + end + let result = code_escapes((Vector,)) do xs + arrayfreeze(xs) + end + @test !has_thrown_escape(result.state[Argument(2)]) + end + let result = code_escapes((Any,)) do xs + arrayfreeze(xs) + end + @test has_thrown_escape(result.state[Argument(2)]) + end + let result = code_escapes((ImmutableArray{Any,1},)) do xs + arrayfreeze(xs) + end + @test has_thrown_escape(result.state[Argument(2)]) + end + let result = code_escapes() do + xs = Any[] + arrayfreeze(xs) + end + i = only(findall(isarrayalloc, result.ir.stmts.inst)) + @test has_no_escape(result.state[SSAValue(1)]) + end + + # mutating_arrayfreeze + let result = code_escapes((Vector{Any},)) do xs + mutating_arrayfreeze(xs) + end + @test !has_thrown_escape(result.state[Argument(2)]) + end + let result = code_escapes((Vector,)) do xs + mutating_arrayfreeze(xs) + end + @test !has_thrown_escape(result.state[Argument(2)]) + end + let result = code_escapes((Any,)) do xs + mutating_arrayfreeze(xs) + end + @test has_thrown_escape(result.state[Argument(2)]) + end + let result = code_escapes((ImmutableArray{Any,1},)) do xs + mutating_arrayfreeze(xs) + end + @test has_thrown_escape(result.state[Argument(2)]) + end + let result = code_escapes() do + xs = Any[] + mutating_arrayfreeze(xs) + end + i = only(findall(isarrayalloc, result.ir.stmts.inst)) + @test has_no_escape(result.state[SSAValue(1)]) + end + + # arraythaw + let result = code_escapes((ImmutableArray{Any,1},)) do xs + arraythaw(xs) + end + @test !has_thrown_escape(result.state[Argument(2)]) + end + let result = code_escapes((ImmutableArray,)) do xs + arraythaw(xs) + end + @test !has_thrown_escape(result.state[Argument(2)]) + end + let result = code_escapes((Any,)) do xs + arraythaw(xs) + end + @test has_thrown_escape(result.state[Argument(2)]) + end + let result = code_escapes((Vector{Any},)) do xs + arraythaw(xs) + end + @test has_thrown_escape(result.state[Argument(2)]) + end + let result = code_escapes() do + xs = ImmutableArray(Any[]) + arraythaw(xs) + end + i = only(findall(isarrayalloc, result.ir.stmts.inst)) + @test has_no_escape(result.state[SSAValue(1)]) + end +end + +# demonstrate some arrayfreeze optimizations +# !has_return_escape(ary) means ary is eligible for arrayfreeze to mutating_arrayfreeze optimization +let result = code_escapes((Int,)) do n + xs = collect(1:n) + ImmutableArray(xs) + end + i = only(findall(isarrayalloc, result.ir.stmts.inst)) + @test !has_return_escape(result.state[SSAValue(i)]) +end +let result = code_escapes((Vector{Float64},)) do xs + ys = sin.(xs) + ImmutableArray(ys) + end + i = only(findall(isarrayalloc, result.ir.stmts.inst)) + @test !has_return_escape(result.state[SSAValue(i)]) +end +let result = code_escapes((Vector{Pair{Int,String}},)) do xs + n = maximum(first, xs) + ys = Vector{String}(undef, n) + for (i, s) in xs + ys[i] = s + end + ImmutableArray(xs) + end + i = only(findall(isarrayalloc, result.ir.stmts.inst)) + @test !has_return_escape(result.state[SSAValue(i)]) +end + +end # @static if isdefined(Core, :ImmutableArray) + +# demonstrate a simple type level analysis can sometimes improve the analysis accuracy +# by compensating the lack of yet unimplemented analyses +@testset "special-casing bitstype" begin + let result = code_escapes((Nothing,)) do a + global bb = a + end + @test !(has_all_escape(result.state[Argument(2)])) + end + + let result = code_escapes((Int,)) do a + o = SafeRef(a) + f = o[] + return f + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test !has_return_escape(result.state[SSAValue(i)], r) + end + + # an escaped tuple stmt will not propagate to its Int argument (since `Int` is of bitstype) + let result = code_escapes((Int,Any,)) do a, b + t = tuple(a, b) + return t + end + i = only(findall(iscall((result.ir, tuple)), result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test !has_return_escape(result.state[Argument(2)], r) + @test has_return_escape(result.state[Argument(3)], r) + end +end + +# # TODO implement a finalizer elision pass +# mutable struct WithFinalizer +# v +# function WithFinalizer(v) +# x = new(v) +# f(t) = @async println("Finalizing $t.") +# return finalizer(x, x) +# end +# end +# make_m(v = 10) = MyMutable(v) +# function simple(cond) +# m = make_m() +# if cond +# # println(m.v) +# return nothing # <= insert `finalize` call here +# end +# return m +# end + +@static @isdefined(EA_AS_PKG) && @testset "code quality" begin + using JET + + # assert that our main routine are free from (unnecessary) runtime dispatches + + function function_filter(@nospecialize(ft)) + ft === typeof(Core.Compiler.widenconst) && return false # `widenconst` is very untyped, ignore + ft === typeof(EscapeAnalysis.escape_builtin!) && return false # `escape_builtin!` is very untyped, ignore + return true + end + target_modules = (EscapeAnalysis,) + test_opt(only(methods(EscapeAnalysis.analyze_escapes)).sig; + function_filter, + target_modules, + # skip_nonconcrete_calls=false, + ) + for m in methods(EscapeAnalysis.escape_builtin!) + Base._methods_by_ftype(m.sig, 1, Base.get_world_counter()) === false && continue + test_opt(m.sig; + function_filter, + target_modules, + # skip_nonconcrete_calls=false, + ) + end +end diff --git a/test/compiler/EscapeAnalysis/setup.jl b/test/compiler/EscapeAnalysis/setup.jl new file mode 100644 index 0000000000000..4ae10529c4150 --- /dev/null +++ b/test/compiler/EscapeAnalysis/setup.jl @@ -0,0 +1,89 @@ +using Test +if @isdefined(EA_AS_PKG) + import EscapeAnalysis: code_escapes, @code_escapes + using EscapeAnalysis +else + using Core.Compiler.EscapeAnalysis + import Base: code_escapes + import InteractiveUtils: @code_escapes +end +import Core: Argument, SSAValue, ReturnNode + +@static if isdefined(Core.Compiler, :alloc_array_ndims) + import Core.Compiler: alloc_array_ndims +else + function alloc_array_ndims(name::Symbol) + if name === :jl_alloc_array_1d + return 1 + elseif name === :jl_alloc_array_2d + return 2 + elseif name === :jl_alloc_array_3d + return 3 + elseif name === :jl_new_array + return 0 + end + return nothing + end +end + +isT(T) = (@nospecialize x) -> x === T +issubT(T) = (@nospecialize x) -> x <: T +isreturn(@nospecialize x) = isa(x, Core.ReturnNode) && isdefined(x, :val) +isthrow(@nospecialize x) = Meta.isexpr(x, :call) && Core.Compiler.is_throw_call(x) +isnew(@nospecialize x) = Meta.isexpr(x, :new) +isϕ(@nospecialize x) = isa(x, Core.PhiNode) +function with_normalized_name(@nospecialize(f), @nospecialize(x)) + if Meta.isexpr(x, :foreigncall) + name = x.args[1] + nn = EscapeAnalysis.normalize(name) + return isa(nn, Symbol) && f(nn) + end + return false +end +isarrayalloc(@nospecialize x) = with_normalized_name(nn->!isnothing(alloc_array_ndims(nn)), x) +isarrayresize(@nospecialize x) = with_normalized_name(nn->!isnothing(EscapeAnalysis.array_resize_info(nn)), x) +isarraycopy(@nospecialize x) = with_normalized_name(nn->EscapeAnalysis.is_array_copy(nn), x) +import Core.Compiler: argextype, singleton_type +const EMPTY_SPTYPES = Any[] +iscall(y) = @nospecialize(x) -> iscall(y, x) +function iscall((ir, f), @nospecialize(x)) + return iscall(x) do @nospecialize x + Core.Compiler.singleton_type(Core.Compiler.argextype(x, ir, EMPTY_SPTYPES)) === f + end +end +iscall(pred::Function, @nospecialize(x)) = Meta.isexpr(x, :call) && pred(x.args[1]) + +""" + is_load_forwardable(x::EscapeInfo) -> Bool + +Queries if `x` is elibigle for store-to-load forwarding optimization. +""" +function is_load_forwardable(x::EscapeAnalysis.EscapeInfo) + AliasInfo = x.AliasInfo + AliasInfo === false && return true # allows this query to work for immutables since we don't impose escape on them + # NOTE technically we also need to check `!has_thrown_escape(x)` here as well, + # but we can also do equivalent check during forwarding + return isa(AliasInfo, EscapeAnalysis.Indexable) && !AliasInfo.array +end + +let setup_ex = quote + mutable struct SafeRef{T} + x::T + end + Base.getindex(s::SafeRef) = getfield(s, 1) + Base.setindex!(s::SafeRef, x) = setfield!(s, 1, x) + + mutable struct SafeRefs{S,T} + x1::S + x2::T + end + Base.getindex(s::SafeRefs, idx::Int) = getfield(s, idx) + Base.setindex!(s::SafeRefs, x, idx::Int) = setfield!(s, idx, x) + end + global function EATModule(setup_ex = setup_ex) + M = Module() + Core.eval(M, setup_ex) + return M + end + Core.eval(@__MODULE__, setup_ex) +end diff --git a/test/compiler/irpasses.jl b/test/compiler/irpasses.jl index f355a42d7b08c..cbbf4375541e1 100644 --- a/test/compiler/irpasses.jl +++ b/test/compiler/irpasses.jl @@ -407,7 +407,7 @@ let m = Meta.@lower 1 + 1 src.ssaflags = fill(Int32(0), nstmts) ir = Core.Compiler.inflate_ir(src, Any[], Any[Any, Any]) @test Core.Compiler.verify_ir(ir) === nothing - ir = @test_nowarn Core.Compiler.sroa_pass!(ir) + ir = @test_nowarn Core.Compiler.sroa_pass!(ir, 0) @test Core.Compiler.verify_ir(ir) === nothing end