From a93074c3dd23d2b19eee43c2f765f07ff0cd6d40 Mon Sep 17 00:00:00 2001 From: Keno Fischer Date: Mon, 8 Jan 2024 00:32:15 +0000 Subject: [PATCH] WIP: Allow CodeInstance in Expr(:invoke) This is a quick experiment to see what it would look like to switch Expr(:invoke) to use CodeInstance rather than MethodInstance. There is some unresolved semantic questions here about whether this is a good idea or if there's some other representation that's better, but discussion might be easier with an implementation. The larger context here is the question of whether and how to do more general method specialization. I had written some thoughts in [1], but as mentioned there remains ongoing discussion of whether this is the correct direction or not. [1] https://hackmd.io/@Og2_pcUySm6R_RPqbZ06JA/S1bqP1D_6 --- .../ssair/EscapeAnalysis/EscapeAnalysis.jl | 3 +- base/compiler/ssair/inlining.jl | 59 ++++-- base/compiler/ssair/irinterp.jl | 4 +- base/compiler/ssair/passes.jl | 12 +- base/compiler/ssair/show.jl | 3 +- src/codegen.cpp | 169 ++++++++++-------- src/interpreter.c | 13 +- test/compiler/inline.jl | 12 +- test/compiler/irutils.jl | 4 +- 9 files changed, 166 insertions(+), 113 deletions(-) diff --git a/base/compiler/ssair/EscapeAnalysis/EscapeAnalysis.jl b/base/compiler/ssair/EscapeAnalysis/EscapeAnalysis.jl index c1f3cf3f97885..338120e2249b4 100644 --- a/base/compiler/ssair/EscapeAnalysis/EscapeAnalysis.jl +++ b/base/compiler/ssair/EscapeAnalysis/EscapeAnalysis.jl @@ -1060,7 +1060,8 @@ end # escape statically-resolved call, i.e. `Expr(:invoke, ::MethodInstance, ...)` function escape_invoke!(astate::AnalysisState, pc::Int, args::Vector{Any}) - mi = first(args)::MethodInstance + arg1 = first(args) + mi = isa(arg1, Core.CodeInstance) ? arg1.def : arg1::MethodInstance first_idx, last_idx = 2, length(args) # TODO inspect `astate.ir.stmts[pc][:info]` and use const-prop'ed `InferenceResult` if available cache = astate.get_escape_cache(mi) diff --git a/base/compiler/ssair/inlining.jl b/base/compiler/ssair/inlining.jl index 650af8248883c..37a0110cfe6d8 100644 --- a/base/compiler/ssair/inlining.jl +++ b/base/compiler/ssair/inlining.jl @@ -33,7 +33,7 @@ struct SomeCase end struct InvokeCase - invoke::MethodInstance + invoke::Union{MethodInstance, CodeInstance} effects::Effects info::CallInfo end @@ -831,6 +831,17 @@ function compileable_specialization(mi::MethodInstance, effects::Effects, return InvokeCase(mi_invoke, effects, info) end +function compileable_specialization(ci::CodeInstance, effects::Effects, + et::InliningEdgeTracker, @nospecialize(info::CallInfo); compilesig_invokes::Bool=true) + if compilesig_invokes ? !isa_compileable_sig(ci.def.specTypes, ci.def.sparam_vals, ci.def.def) : + any(@nospecialize(t)->isa(t, TypeVar), ci.def.sparam_vals) + return compileable_specialization(ci.def, effects, et, info; compilesig_invokes) + end + add_inlining_backedge!(et, ci.def) # to the dispatch lookup + push!(et.edges, ci.def.def.sig, ci.def) # add_inlining_backedge to the invoke call + return InvokeCase(ci, effects, info) +end + function compileable_specialization(match::MethodMatch, effects::Effects, et::InliningEdgeTracker, @nospecialize(info::CallInfo); compilesig_invokes::Bool=true) mi = specialize_method(match) @@ -838,22 +849,26 @@ function compileable_specialization(match::MethodMatch, effects::Effects, end struct InferredResult + ci::Union{Nothing, CodeInstance} src::Any effects::Effects - InferredResult(@nospecialize(src), effects::Effects) = new(src, effects) + InferredResult(ci::Union{Nothing, CodeInstance}, @nospecialize(src), effects::Effects) = new(ci, src, effects) +end +@inline function get_cached_result(code::CodeInstance) + if use_const_api(code) + # in this case function can be inlined to a constant + return ConstantCase(quoted(code.rettype_const)) + end + src = @atomic :monotonic code.inferred + effects = decode_effects(code.ipo_purity_bits) + return InferredResult(code, src, effects) end @inline function get_cached_result(state::InliningState, mi::MethodInstance) code = get(code_cache(state), mi, nothing) if code isa CodeInstance - if use_const_api(code) - # in this case function can be inlined to a constant - return ConstantCase(quoted(code.rettype_const)) - end - src = @atomic :monotonic code.inferred - effects = decode_effects(code.ipo_purity_bits) - return InferredResult(src, effects) + return get_cached_result(code) end - return InferredResult(nothing, Effects()) + return InferredResult(nothing, nothing, Effects()) end @inline function get_local_result(inf_result::InferenceResult) effects = inf_result.ipo_effects @@ -864,11 +879,11 @@ end return ConstantCase(quoted(res.val)) end end - return InferredResult(inf_result.src, effects) + return InferredResult(nothing, inf_result.src, effects) end # the general resolver for usual and const-prop'ed calls -function resolve_todo(mi::MethodInstance, result::Union{Nothing,InferenceResult,VolatileInferenceResult}, +function resolve_todo(mi::MethodInstance, result::Union{Nothing,InferenceResult,VolatileInferenceResult,CodeInstance}, @nospecialize(info::CallInfo), flag::UInt32, state::InliningState; invokesig::Union{Nothing,Vector{Any}}=nothing) et = InliningEdgeTracker(state, invokesig) @@ -887,17 +902,18 @@ function resolve_todo(mi::MethodInstance, result::Union{Nothing,InferenceResult, add_inlining_backedge!(et, mi) return inferred_result end - (; src, effects) = inferred_result + (; src, ci, effects) = inferred_result + invoke_target = ci === nothing ? mi : ci # the duplicated check might have been done already within `analyze_method!`, but still # we need it here too since we may come here directly using a constant-prop' result if !OptimizationParams(state.interp).inlining || is_stmt_noinline(flag) - return compileable_specialization(mi, effects, et, info; + return compileable_specialization(invoke_target, effects, et, info; compilesig_invokes=OptimizationParams(state.interp).compilesig_invokes) end src = inlining_policy(state.interp, src, info, flag) - src === nothing && return compileable_specialization(mi, effects, et, info; + src === nothing && return compileable_specialization(invoke_target, effects, et, info; compilesig_invokes=OptimizationParams(state.interp).compilesig_invokes) add_inlining_backedge!(et, mi) @@ -906,7 +922,7 @@ function resolve_todo(mi::MethodInstance, result::Union{Nothing,InferenceResult, end # the special resolver for :invoke-d call -function resolve_todo(mi::MethodInstance, @nospecialize(info::CallInfo), flag::UInt32, +function resolve_todo(mici::Union{MethodInstance, CodeInstance}, @nospecialize(info::CallInfo), flag::UInt32, state::InliningState) if !OptimizationParams(state.interp).inlining || is_stmt_noinline(flag) return nothing @@ -914,7 +930,13 @@ function resolve_todo(mi::MethodInstance, @nospecialize(info::CallInfo), flag::U et = InliningEdgeTracker(state) - cached_result = get_cached_result(state, mi) + if isa(mici, CodeInstance) + cached_result = get_cached_result(mici) + mi = mici.def + else + cached_result = get_cached_result(state, mici) + mi = mici + end if cached_result isa ConstantCase add_inlining_backedge!(et, mi) return cached_result @@ -1671,8 +1693,7 @@ end function handle_invoke_expr!(todo::Vector{Pair{Int,Any}}, ir::IRCode, idx::Int, stmt::Expr, @nospecialize(info::CallInfo), flag::UInt32, sig::Signature, state::InliningState) - mi = stmt.args[1]::MethodInstance - case = resolve_todo(mi, info, flag, state) + case = resolve_todo(stmt.args[1], info, flag, state) handle_single_case!(todo, ir, idx, stmt, case, false) return nothing end diff --git a/base/compiler/ssair/irinterp.jl b/base/compiler/ssair/irinterp.jl index 7e454612a4eb6..bb3f534468940 100644 --- a/base/compiler/ssair/irinterp.jl +++ b/base/compiler/ssair/irinterp.jl @@ -158,7 +158,9 @@ function reprocess_instruction!(interp::AbstractInterpreter, inst::Instruction, (; rt, effects) = abstract_eval_statement_expr(interp, stmt, nothing, irsv) add_flag!(inst, flags_for_effects(effects)) elseif head === :invoke - rt, (nothrow, noub) = concrete_eval_invoke(interp, stmt, stmt.args[1]::MethodInstance, irsv) + arg1 = stmt.args[1] + mi = isa(arg1, CodeInstance) ? arg1.def : arg1::MethodInstance + rt, (nothrow, noub) = concrete_eval_invoke(interp, stmt, mi, irsv) if nothrow add_flag!(inst, IR_FLAG_NOTHROW) end diff --git a/base/compiler/ssair/passes.jl b/base/compiler/ssair/passes.jl index 62dd4305243ec..98a61d7f12527 100644 --- a/base/compiler/ssair/passes.jl +++ b/base/compiler/ssair/passes.jl @@ -1460,9 +1460,15 @@ end # NOTE we resolve the inlining source here as we don't want to serialize `Core.Compiler` # data structure into the global cache (see the comment in `handle_finalizer_call!`) function try_inline_finalizer!(ir::IRCode, argexprs::Vector{Any}, idx::Int, - mi::MethodInstance, @nospecialize(info::CallInfo), inlining::InliningState, + code_or_mi::Union{MethodInstance, CodeInstance}, @nospecialize(info::CallInfo), inlining::InliningState, attach_after::Bool) - code = get(code_cache(inlining), mi, nothing) + if isa(code_or_mi, CodeInstance) + code = code_or_mi + mi = code.def + else + mi = code_or_mi + code = get(code_cache(inlining), mi, nothing) + end et = InliningEdgeTracker(inlining) if code isa CodeInstance if use_const_api(code) @@ -1629,7 +1635,7 @@ function try_resolve_finalizer!(ir::IRCode, idx::Int, finalizer_idx::Int, defuse if inline === nothing # No code in the function - Nothing to do else - mi = finalizer_stmt.args[5]::MethodInstance + mi = finalizer_stmt.args[5]::Union{MethodInstance, CodeInstance} if inline::Bool && try_inline_finalizer!(ir, argexprs, loc, mi, info, inlining, attach_after) # the finalizer body has been inlined else diff --git a/base/compiler/ssair/show.jl b/base/compiler/ssair/show.jl index 3936a82a6560e..0083f7058a3fc 100644 --- a/base/compiler/ssair/show.jl +++ b/base/compiler/ssair/show.jl @@ -52,7 +52,8 @@ function print_stmt(io::IO, idx::Int, @nospecialize(stmt), used::BitSet, maxleng stmt = stmt::Expr # TODO: why is this here, and not in Base.show_unquoted print(io, "invoke ") - linfo = stmt.args[1]::Core.MethodInstance + arg1 = stmt.args[1] + linfo = isa(arg1, Core.CodeInstance) ? arg1.def : arg1::Core.MethodInstance show_unquoted(io, stmt.args[2], indent) print(io, "(") # XXX: this is wrong if `sig` is not a concretetype method diff --git a/src/codegen.cpp b/src/codegen.cpp index ca96e1bcc5545..7913f5a9447fc 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -4639,95 +4639,106 @@ static jl_cgval_t emit_invoke(jl_codectx_t &ctx, const jl_cgval_t &lival, const bool handled = false; jl_cgval_t result; if (lival.constant) { - jl_method_instance_t *mi = (jl_method_instance_t*)lival.constant; - assert(jl_is_method_instance(mi)); - if (mi == ctx.linfo) { - // handle self-recursion specially - jl_returninfo_t::CallingConv cc = jl_returninfo_t::CallingConv::Boxed; - FunctionType *ft = ctx.f->getFunctionType(); - StringRef protoname = ctx.f->getName(); - if (ft == ctx.types().T_jlfunc) { - result = emit_call_specfun_boxed(ctx, ctx.rettype, protoname, nullptr, argv, nargs, rt); - handled = true; - } - else if (ft != ctx.types().T_jlfuncparams) { - unsigned return_roots = 0; - result = emit_call_specfun_other(ctx, mi, ctx.rettype, protoname, nullptr, argv, nargs, &cc, &return_roots, rt); - handled = true; - } - } - else { - jl_value_t *ci = ctx.params->lookup(mi, ctx.world, ctx.world); // TODO: need to use the right pair world here - if (ci != jl_nothing) { - jl_code_instance_t *codeinst = (jl_code_instance_t*)ci; - auto invoke = jl_atomic_load_acquire(&codeinst->invoke); - // check if we know how to handle this specptr - if (invoke == jl_fptr_const_return_addr) { - result = mark_julia_const(ctx, codeinst->rettype_const); + jl_code_instance_t *codeinst = NULL; + jl_method_instance_t *mi = NULL; + if (jl_is_method_instance(lival.constant)) { + mi = (jl_method_instance_t*)lival.constant; + assert(jl_is_method_instance(mi)); + if (mi == ctx.linfo) { + // handle self-recursion specially + jl_returninfo_t::CallingConv cc = jl_returninfo_t::CallingConv::Boxed; + FunctionType *ft = ctx.f->getFunctionType(); + StringRef protoname = ctx.f->getName(); + if (ft == ctx.types().T_jlfunc) { + result = emit_call_specfun_boxed(ctx, ctx.rettype, protoname, nullptr, argv, nargs, rt); handled = true; } - else if (invoke != jl_fptr_sparam_addr) { - bool specsig, needsparams; - std::tie(specsig, needsparams) = uses_specsig(mi, codeinst->rettype, ctx.params->prefer_specsig); - std::string name; - StringRef protoname; - bool need_to_emit = true; - bool cache_valid = ctx.use_cache || ctx.external_linkage; - bool external = false; - - // Check if we already queued this up - auto it = ctx.call_targets.find(codeinst); - if (need_to_emit && it != ctx.call_targets.end()) { - protoname = it->second.decl->getName(); - need_to_emit = cache_valid = false; - } + else if (ft != ctx.types().T_jlfuncparams) { + unsigned return_roots = 0; + result = emit_call_specfun_other(ctx, mi, ctx.rettype, protoname, nullptr, argv, nargs, &cc, &return_roots, rt); + handled = true; + } + goto done; + } else { + jl_value_t *ci = ctx.params->lookup(mi, ctx.world, ctx.world); // TODO: need to use the right pair world here + if (ci == jl_nothing) + goto done; + codeinst = (jl_code_instance_t*)ci; + } + } else { + assert(jl_is_code_instance(lival.constant)); + codeinst = (jl_code_instance_t*)lival.constant; + // TODO: Separate copy of the callsig in the CodeInstance + mi = codeinst->def; + } - // Check if it is already compiled (either JIT or externally) - if (cache_valid) { - // optimization: emit the correct name immediately, if we know it - // TODO: use `emitted` map here too to try to consolidate names? - // WARNING: isspecsig is protected by the codegen-lock. If that lock is removed, then the isspecsig load needs to be properly atomically sequenced with this. - auto fptr = jl_atomic_load_relaxed(&codeinst->specptr.fptr); - if (fptr) { - while (!(jl_atomic_load_acquire(&codeinst->specsigflags) & 0b10)) { - jl_cpu_pause(); - } - invoke = jl_atomic_load_relaxed(&codeinst->invoke); - if (specsig ? jl_atomic_load_relaxed(&codeinst->specsigflags) & 0b1 : invoke == jl_fptr_args_addr) { - protoname = jl_ExecutionEngine->getFunctionAtAddress((uintptr_t)fptr, codeinst); - if (ctx.external_linkage) { - // TODO: Add !specsig support to aotcompile.cpp - // Check that the codeinst is containing native code - if (specsig && jl_atomic_load_relaxed(&codeinst->specsigflags) & 0b100) { - external = true; - need_to_emit = false; - } - } - else { // ctx.use_cache - need_to_emit = false; - } + auto invoke = jl_atomic_load_acquire(&codeinst->invoke); + // check if we know how to handle this specptr + if (invoke == jl_fptr_const_return_addr) { + result = mark_julia_const(ctx, codeinst->rettype_const); + handled = true; + } + else if (invoke != jl_fptr_sparam_addr) { + bool specsig, needsparams; + std::tie(specsig, needsparams) = uses_specsig(mi, codeinst->rettype, ctx.params->prefer_specsig); + std::string name; + StringRef protoname; + bool need_to_emit = true; + bool cache_valid = ctx.use_cache || ctx.external_linkage; + bool external = false; + + // Check if we already queued this up + auto it = ctx.call_targets.find(codeinst); + if (need_to_emit && it != ctx.call_targets.end()) { + protoname = it->second.decl->getName(); + need_to_emit = cache_valid = false; + } + + // Check if it is already compiled (either JIT or externally) + if (cache_valid) { + // optimization: emit the correct name immediately, if we know it + // TODO: use `emitted` map here too to try to consolidate names? + // WARNING: isspecsig is protected by the codegen-lock. If that lock is removed, then the isspecsig load needs to be properly atomically sequenced with this. + auto fptr = jl_atomic_load_relaxed(&codeinst->specptr.fptr); + if (fptr) { + while (!(jl_atomic_load_acquire(&codeinst->specsigflags) & 0b10)) { + jl_cpu_pause(); + } + invoke = jl_atomic_load_relaxed(&codeinst->invoke); + if (specsig ? jl_atomic_load_relaxed(&codeinst->specsigflags) & 0b1 : invoke == jl_fptr_args_addr) { + protoname = jl_ExecutionEngine->getFunctionAtAddress((uintptr_t)fptr, codeinst); + if (ctx.external_linkage) { + // TODO: Add !specsig support to aotcompile.cpp + // Check that the codeinst is containing native code + if (specsig && jl_atomic_load_relaxed(&codeinst->specsigflags) & 0b100) { + external = true; + need_to_emit = false; } } - } - if (need_to_emit) { - raw_string_ostream(name) << (specsig ? "j_" : "j1_") << name_from_method_instance(mi) << "_" << jl_atomic_fetch_add(&globalUniqueGeneratedNames, 1); - protoname = StringRef(name); - } - jl_returninfo_t::CallingConv cc = jl_returninfo_t::CallingConv::Boxed; - unsigned return_roots = 0; - if (specsig) - result = emit_call_specfun_other(ctx, mi, codeinst->rettype, protoname, external ? codeinst : nullptr, argv, nargs, &cc, &return_roots, rt); - else - result = emit_call_specfun_boxed(ctx, codeinst->rettype, protoname, external ? codeinst : nullptr, argv, nargs, rt); - handled = true; - if (need_to_emit) { - Function *trampoline_decl = cast(jl_Module->getNamedValue(protoname)); - ctx.call_targets[codeinst] = {cc, return_roots, trampoline_decl, specsig}; + else { // ctx.use_cache + need_to_emit = false; + } } } } + if (need_to_emit) { + raw_string_ostream(name) << (specsig ? "j_" : "j1_") << name_from_method_instance(mi) << "_" << jl_atomic_fetch_add(&globalUniqueGeneratedNames, 1); + protoname = StringRef(name); + } + jl_returninfo_t::CallingConv cc = jl_returninfo_t::CallingConv::Boxed; + unsigned return_roots = 0; + if (specsig) + result = emit_call_specfun_other(ctx, mi, codeinst->rettype, protoname, external ? codeinst : nullptr, argv, nargs, &cc, &return_roots, rt); + else + result = emit_call_specfun_boxed(ctx, codeinst->rettype, protoname, external ? codeinst : nullptr, argv, nargs, rt); + handled = true; + if (need_to_emit) { + Function *trampoline_decl = cast(jl_Module->getNamedValue(protoname)); + ctx.call_targets[codeinst] = {cc, return_roots, trampoline_decl, specsig}; + } } } +done: if (!handled) { Value *r = emit_jlcall(ctx, jlinvoke_func, boxed(ctx, lival), argv, nargs, julia_call2); result = mark_julia_type(ctx, r, true, rt); diff --git a/src/interpreter.c b/src/interpreter.c index 5102d1417c939..8271080bbe610 100644 --- a/src/interpreter.c +++ b/src/interpreter.c @@ -136,9 +136,16 @@ static jl_value_t *do_invoke(jl_value_t **args, size_t nargs, interpreter_state size_t i; for (i = 1; i < nargs; i++) argv[i] = eval_value(args[i], s); - jl_method_instance_t *meth = (jl_method_instance_t*)args[0]; - assert(jl_is_method_instance(meth)); - jl_value_t *result = jl_invoke(argv[1], &argv[2], nargs - 2, meth); + jl_value_t *arg0 = (jl_method_instance_t*)args[0]; + if (jl_is_code_instance(arg0)) { + jl_code_instance_t *codeinst = (jl_code_instance_t*)arg0; + jl_callptr_t invoke = jl_atomic_load_acquire(&codeinst->invoke); + if (invoke != NULL) { + return invoke(argv[1], &argv[2], nargs - 2, codeinst); + } + } + assert(jl_is_method_instance(arg0)); + jl_value_t *result = jl_invoke(argv[1], &argv[2], nargs - 2, (jl_method_instance_t*)arg0); JL_GC_POP(); return result; } diff --git a/test/compiler/inline.jl b/test/compiler/inline.jl index 9e58f23fd755c..b9dce078f7cb4 100644 --- a/test/compiler/inline.jl +++ b/test/compiler/inline.jl @@ -466,7 +466,8 @@ end @test all(code) do @nospecialize x !isinvoke(:simple_caller, x) && !isinvoke(x) do mi - startswith(string(mi.def.name), '#') + m = isa(mi, CodeInstance) ? mi.def.def : mi.def + startswith(string(m.name), '#') end end end @@ -474,7 +475,8 @@ end # the anonymous function that the `do` block created shouldn't be inlined here @test any(code) do @nospecialize x isinvoke(x) do mi - startswith(string(mi.def.name), '#') + m = isa(mi, CodeInstance) ? mi.def.def : mi.def + startswith(string(m.name), '#') end end end @@ -722,8 +724,10 @@ mktempdir() do dir let ci, rt = only(code_typed(issue42246)) if any(ci.code) do stmt - Meta.isexpr(stmt, :invoke) && - stmt.args[1].def.name === nameof(IOBuffer) + Meta.isexpr(stmt, :invoke) || return false + mici = stmt.args[1] + m = isa(mici, Core.CodeInstance) ? mici.def.def : mici.def + m.name === nameof(IOBuffer) end exit(0) else diff --git a/test/compiler/irutils.jl b/test/compiler/irutils.jl index 788d7bbc721ee..b954a03199450 100644 --- a/test/compiler/irutils.jl +++ b/test/compiler/irutils.jl @@ -35,8 +35,8 @@ end # check if `x` is a statically-resolved call of a function whose name is `sym` isinvoke(y) = @nospecialize(x) -> isinvoke(y, x) -isinvoke(sym::Symbol, @nospecialize(x)) = isinvoke(mi->mi.def.name===sym, x) -isinvoke(pred::Function, @nospecialize(x)) = isexpr(x, :invoke) && pred(x.args[1]::MethodInstance) +isinvoke(sym::Symbol, @nospecialize(x)) = isinvoke(mi->(isa(mi, CodeInstance) ? mi.def.def : mi.def).name===sym, x) +isinvoke(pred::Function, @nospecialize(x)) = isexpr(x, :invoke) && pred(x.args[1]::Union{MethodInstance, CodeInstance}) fully_eliminated(@nospecialize args...; retval=(@__FILE__), kwargs...) = fully_eliminated(code_typed1(args...; kwargs...); retval)