Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RFC: Refactor MethodInstance to allow for more general specialization #54373

Draft
wants to merge 6 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions base/boot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,8 @@ eval(Core, quote
PartialOpaque(@nospecialize(typ), @nospecialize(env), parent::MethodInstance, source) = $(Expr(:new, :PartialOpaque, :typ, :env, :parent, :source))
InterConditional(slot::Int, @nospecialize(thentype), @nospecialize(elsetype)) = $(Expr(:new, :InterConditional, :slot, :thentype, :elsetype))
MethodMatch(@nospecialize(spec_types), sparams::SimpleVector, method::Method, fully_covers::Bool) = $(Expr(:new, :MethodMatch, :spec_types, :sparams, :method, :fully_covers))
DefaultSpec(sparam_vals::SimpleVector, inInference::Bool, cache_with_orig::Bool, precompiled::Bool) =
$(Expr(:new, DefaultSpec, :sparam_vals, :inInference, :cache_with_orig, :precompiled))
end)

const NullDebugInfo = DebugInfo(:none)
Expand All @@ -499,15 +501,14 @@ struct LineInfoNode # legacy support for aiding Serializer.deserialize of old IR
LineInfoNode(mod::Module, @nospecialize(method), file::Symbol, line::Int32, inlined_at::Int32) = new(mod, method, file, line, inlined_at)
end


function CodeInstance(
mi::MethodInstance, owner, @nospecialize(rettype), @nospecialize(exctype), @nospecialize(inferred_const),
mi::MethodSpecialization, @nospecialize(rettype), @nospecialize(exctype), @nospecialize(inferred_const),
@nospecialize(inferred), const_flags::Int32, min_world::UInt, max_world::UInt,
ipo_effects::UInt32, effects::UInt32, @nospecialize(analysis_results),
relocatability::UInt8, edges::DebugInfo)
return ccall(:jl_new_codeinst, Ref{CodeInstance},
(Any, Any, Any, Any, Any, Any, Int32, UInt, UInt, UInt32, UInt32, Any, UInt8, Any),
mi, owner, rettype, exctype, inferred_const, inferred, const_flags, min_world, max_world,
(Any, Any, Any, Any, Any, Int32, UInt, UInt, UInt32, UInt32, Any, UInt8, Any),
mi, rettype, exctype, inferred_const, inferred, const_flags, min_world, max_world,
ipo_effects, effects, analysis_results, relocatability, edges)
end
GlobalRef(m::Module, s::Symbol) = ccall(:jl_module_globalref, Ref{GlobalRef}, (Any, Any), m, s)
Expand Down Expand Up @@ -647,12 +648,12 @@ Symbol(s::Symbol) = s
# module providing the IR object model
module IR

export CodeInfo, MethodInstance, CodeInstance, GotoNode, GotoIfNot, ReturnNode,
export CodeInfo, MethodSpecialization, MethodInstance, CodeInstance, GotoNode, GotoIfNot, ReturnNode,
NewvarNode, SSAValue, SlotNumber, Argument,
PiNode, PhiNode, PhiCNode, UpsilonNode, DebugInfo,
Const, PartialStruct, InterConditional, EnterNode

using Core: CodeInfo, MethodInstance, CodeInstance, GotoNode, GotoIfNot, ReturnNode,
using Core: CodeInfo, MethodSpecialization, MethodInstance, CodeInstance, GotoNode, GotoIfNot, ReturnNode,
NewvarNode, SSAValue, SlotNumber, Argument,
PiNode, PhiNode, PhiCNode, UpsilonNode, DebugInfo,
Const, PartialStruct, InterConditional, EnterNode
Expand Down Expand Up @@ -1004,6 +1005,9 @@ const check_top_bit = check_sign_bit
EnterNode(old::EnterNode, new_dest::Int) = isdefined(old, :scope) ?
EnterNode(new_dest, old.scope) : EnterNode(new_dest)

eval(Core, :((MS::Type{<:MethodSpecialization})(def::Union{Method, Module, MethodSpecialization}, abi::Type{<:Tuple}) =
$(Expr(:new, :MS, :def, :abi))))

include(Core, "optimized_generics.jl")

ccall(:jl_set_istopmod, Cvoid, (Any, Bool), Core, true)
4 changes: 2 additions & 2 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2299,9 +2299,9 @@ function sp_type_rewrap(@nospecialize(T), mi::MethodInstance, isreturn::Bool)
if isa(mi.def, Method)
spsig = mi.def.sig
if isa(spsig, UnionAll)
if !isempty(mi.sparam_vals)
if !isempty(mi.data.sparam_vals)
sparam_vals = Any[isvarargtype(v) ? TypeVar(:N, Union{}, Any) :
v for v in mi.sparam_vals]
v for v in mi.data.sparam_vals]
T = ccall(:jl_instantiate_type_in_env, Any, (Any, Any, Ptr{Any}), T, spsig, sparam_vals)
isref && isreturn && T === Any && return Bottom # catch invalid return Ref{T} where T = Any
for v in sparam_vals
Expand Down
46 changes: 37 additions & 9 deletions base/compiler/cicache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,35 @@
"""
struct InternalCodeCache

Internally, each `MethodInstance` keep a unique global cache of code instances
that have been created for the given method instance, stratified by world age
ranges. This struct abstracts over access to this cache.
The internal code cache is keyed on type specializations, represented by
MethodSpecialization{DefaultSpec} aka MethodInstance. External abstract
interpreters may use this same structure by using a different `Spec` for their
`MethodSpecialization{Spec}`. `InternalCodeCache` will match such specializations
by type. Additionally, it is possible to specialize methods on properties other
than types, but this requires custom caching logic. `InternalCodeCache` currently
only supports type-based specialization.
"""
struct InternalCodeCache
owner::Any # `jl_egal` is used for comparison
mitype::DataType # <: MethodSpecialization, but stored as DataType for efficient ===
InternalCodeCache(T::Type{<:MethodSpecialization}) =
new(T)
end
InternalCodeCache() = InternalCodeCache(MethodInstance)

function setindex!(cache::InternalCodeCache, ci::CodeInstance, mi::MethodInstance)
@assert ci.owner === cache.owner
ccall(:jl_mi_cache_insert, Cvoid, (Any, Any), mi, ci)
ms::MethodSpecialization = mi
while typeof(ms) !== cache.mitype
if !isdefined(ms, :next)
# No specialization for this spec. Try to allocate it now.
newms = cache.mitype(mi.def, mi.specTypes)
if @atomiconce :sequentially_consistent (ms.next = newms)
ms = newms
break
end
end
ms = @atomic :acquire ms.next
end
ccall(:jl_mi_cache_insert, Cvoid, (Any, Any), ms, ci)
return cache
end

Expand Down Expand Up @@ -50,11 +68,21 @@ WorldView(wvc::WorldView, wr::WorldRange) = WorldView(wvc.cache, wr)
WorldView(wvc::WorldView, args...) = WorldView(wvc.cache, args...)

function haskey(wvc::WorldView{InternalCodeCache}, mi::MethodInstance)
return ccall(:jl_rettype_inferred, Any, (Any, Any, UInt, UInt), wvc.cache.owner, mi, first(wvc.worlds), last(wvc.worlds)) !== nothing
ms::MethodSpecialization = mi
while typeof(ms) !== wvc.cache.mitype
isdefined(ms, :next) || return false
ms = ms.next
end
return ccall(:jl_rettype_inferred, Any, (Any, UInt, UInt), ms, first(wvc.worlds), last(wvc.worlds)) !== nothing
end

function get(wvc::WorldView{InternalCodeCache}, mi::MethodInstance, default)
r = ccall(:jl_rettype_inferred, Any, (Any, Any, UInt, UInt), wvc.cache.owner, mi, first(wvc.worlds), last(wvc.worlds))
ms::MethodSpecialization = mi
while typeof(ms) !== wvc.cache.mitype
isdefined(ms, :next) || return default
ms = ms.next
end
r = ccall(:jl_rettype_inferred, Any, (Any, UInt, UInt), ms, first(wvc.worlds), last(wvc.worlds))
if r === nothing
return default
end
Expand All @@ -73,7 +101,7 @@ function setindex!(wvc::WorldView{InternalCodeCache}, ci::CodeInstance, mi::Meth
end

function code_cache(interp::AbstractInterpreter)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this now be code_cache(interp:: NativeInterpreter, ...)? Otherwise foreign interp will silently leak information into the native cache.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, we can force external absint to explicitly declare their cache

cache = InternalCodeCache(cache_owner(interp))
cache = InternalCodeCache()
worlds = WorldRange(get_inference_world(interp))
return WorldView(cache, worlds)
end
2 changes: 2 additions & 0 deletions base/compiler/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@ using Core.Intrinsics, Core.IR

import Core: print, println, show, write, unsafe_write, stdout, stderr,
_apply_iterate, svec, apply_type, Builtin, IntrinsicFunction,
MethodSpecialization,
MethodInstance, CodeInstance, MethodTable, MethodMatch, PartialOpaque,
TypeofVararg

const getproperty = Core.getfield
const setproperty! = Core.setfield!
const setpropertyonce! = Core.setfieldonce!
const swapproperty! = Core.swapfield!
const modifyproperty! = Core.modifyfield!
const replaceproperty! = Core.replacefield!
Expand Down
4 changes: 2 additions & 2 deletions base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -691,7 +691,7 @@ function sptypes_from_meth_instance(mi::MethodInstance)
def = mi.def
isa(def, Method) || return EMPTY_SPTYPES # toplevel
sig = def.sig
if isempty(mi.sparam_vals)
if isempty(mi.data.sparam_vals)
isa(sig, UnionAll) || return EMPTY_SPTYPES
# mi is unspecialized
spvals = Any[]
Expand All @@ -701,7 +701,7 @@ function sptypes_from_meth_instance(mi::MethodInstance)
sig′ = sig′.body
end
else
spvals = mi.sparam_vals
spvals = mi.data.sparam_vals
end
nvals = length(spvals)
sptypes = Vector{VarState}(undef, nvals)
Expand Down
16 changes: 8 additions & 8 deletions base/compiler/ssair/inlining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ function ir_prepare_inlining!(insert_node!::Inserter, inline_target::Union{IRCod
insert_node!(NewInstruction(Expr(:code_coverage_effect), Nothing, topline))
end
spvals_ssa = nothing
if !validate_sparams(mi.sparam_vals)
if !validate_sparams(mi.data.sparam_vals)
# N.B. This works on the caller-side argexprs, (i.e. before the va fixup below)
spvals_ssa = insert_node!(
removable_if_unused(NewInstruction(Expr(:call, Core._compute_sparams, def, argexprs...), SimpleVector, topline)))
Expand Down Expand Up @@ -775,7 +775,7 @@ end
function compileable_specialization(mi::MethodInstance, effects::Effects,
et::InliningEdgeTracker, @nospecialize(info::CallInfo); compilesig_invokes::Bool=true)
mi_invoke = mi
method, atype, sparams = mi.def::Method, mi.specTypes, mi.sparam_vals
method, atype, sparams = mi.def::Method, mi.specTypes, mi.data.sparam_vals
if compilesig_invokes
new_atype = get_compileable_sig(method, atype, sparams)
new_atype === nothing && return nothing
Expand All @@ -790,7 +790,7 @@ function compileable_specialization(mi::MethodInstance, effects::Effects,
# If this caller does not want us to optimize calls to use their
# declared compilesig, then it is also likely they would handle sparams
# incorrectly if there were any unknown typevars, so we conservatively return nothing
if any(@nospecialize(t)->isa(t, TypeVar), mi.sparam_vals)
if any(@nospecialize(t)->isa(t, TypeVar), mi.data.sparam_vals)
return nothing
end
end
Expand Down Expand Up @@ -1173,7 +1173,7 @@ function handle_invoke_call!(todo::Vector{Pair{Int,Any}},
argtypes = invoke_rewrite(sig.argtypes)
if isa(result, ConstPropResult)
mi = result.result.linfo
validate_sparams(mi.sparam_vals) || return nothing
validate_sparams(mi.data.sparam_vals) || return nothing
if Union{} !== argtypes_to_type(argtypes) <: mi.def.sig
item = resolve_todo(mi, result.result, info, flag, state; invokesig)
handle_single_case!(todo, ir, idx, stmt, item, true)
Expand Down Expand Up @@ -1430,7 +1430,7 @@ function handle_const_prop_result!(cases::Vector{InliningCase}, result::ConstPro
allow_typevars::Bool)
mi = result.result.linfo
spec_types = match.spec_types
if !validate_sparams(mi.sparam_vals)
if !validate_sparams(mi.data.sparam_vals)
(allow_typevars && !may_have_fcalls(mi.def::Method)) || return false
end
item = resolve_todo(mi, result.result, info, flag, state)
Expand Down Expand Up @@ -1466,7 +1466,7 @@ function handle_semi_concrete_result!(cases::Vector{InliningCase}, result::SemiC
match::MethodMatch, @nospecialize(info::CallInfo), flag::UInt32, state::InliningState)
mi = result.mi
spec_types = match.spec_types
validate_sparams(mi.sparam_vals) || return false
validate_sparams(mi.data.sparam_vals) || return false
item = semiconcrete_result_item(result, info, flag, state)
item === nothing && return false
push!(cases, InliningCase(spec_types, item))
Expand Down Expand Up @@ -1521,7 +1521,7 @@ function handle_opaque_closure_call!(todo::Vector{Pair{Int,Any}},
result = info.result
if isa(result, ConstPropResult)
mi = result.result.linfo
validate_sparams(mi.sparam_vals) || return nothing
validate_sparams(mi.data.sparam_vals) || return nothing
item = resolve_todo(mi, result.result, info, flag, state)
elseif isa(result, ConcreteResult)
item = concrete_result_item(result, info, state)
Expand Down Expand Up @@ -1793,7 +1793,7 @@ function ssa_substitute_op!(insert_node!::Inserter, subst_inst::Instruction, @no
if isa(val, Expr)
e = val::Expr
head = e.head
sparam_vals = ssa_substitute.mi.sparam_vals
sparam_vals = ssa_substitute.mi.data.sparam_vals
if head === :static_parameter
spidx = e.args[1]::Int
val = sparam_vals[spidx]
Expand Down
8 changes: 4 additions & 4 deletions base/compiler/typeinfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ function CodeInstance(interp::AbstractInterpreter, result::InferenceResult;
if !@isdefined edges
edges = DebugInfo(result.linfo)
end
return CodeInstance(result.linfo, owner,
return CodeInstance(result.linfo,
widenconst(result_type), widenconst(result.exc_result), rettype_const, inferred_result,
const_flags, first(result.valid_worlds), last(result.valid_worlds),
# TODO: Actually do something with non-IPO effects
Expand All @@ -348,7 +348,7 @@ function maybe_compress_codeinfo(interp::AbstractInterpreter, mi::MethodInstance
isa(def, Method) || return ci # don't compress toplevel code
cache_the_tree = true
if can_discard_trees
cache_the_tree = is_inlineable(ci) || isa_compileable_sig(mi.specTypes, mi.sparam_vals, def)
cache_the_tree = is_inlineable(ci) || isa_compileable_sig(mi.specTypes, mi.data.sparam_vals, def)
end
if cache_the_tree
if may_compress(interp)
Expand Down Expand Up @@ -935,7 +935,7 @@ more details.
"""
function codeinstance_for_const_with_code(interp::AbstractInterpreter, code::CodeInstance)
src = codeinfo_for_const(interp, code.def, code.rettype_const)
return CodeInstance(code.def, cache_owner(interp), code.rettype, code.exctype, code.rettype_const, src,
return CodeInstance(code.def, code.rettype, code.exctype, code.rettype_const, src,
Int32(0x3), code.min_world, code.max_world,
code.ipo_purity_bits, code.purity_bits, code.analysis_results,
code.relocatability, src.debuginfo)
Expand Down Expand Up @@ -1109,7 +1109,7 @@ function typeinf_ext(interp::AbstractInterpreter, mi::MethodInstance, source_mod
if ccall(:jl_get_module_infer, Cint, (Any,), def.module) == 0 && !generating_output(#=incremental=#false)
src = retrieve_code_info(mi, get_inference_world(interp))
src isa CodeInfo || return nothing
return CodeInstance(mi, cache_owner(interp), Any, Any, nothing, src, Int32(0),
return CodeInstance(mi, Any, Any, nothing, src, Int32(0),
get_inference_world(interp), get_inference_world(interp),
UInt32(0), UInt32(0), nothing, UInt8(0), src.debuginfo)
end
Expand Down
6 changes: 3 additions & 3 deletions base/compiler/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ For the `NativeInterpreter`, we don't need to do an actual cache query to know i
was already inferred. If we reach this point, but the inference flag has been turned off,
then it's in the cache. This is purely for a performance optimization.
"""
already_inferred_quick_test(interp::NativeInterpreter, mi::MethodInstance) = !mi.inInference
already_inferred_quick_test(interp::NativeInterpreter, mi::MethodInstance) = !mi.data.inInference
already_inferred_quick_test(interp::AbstractInterpreter, mi::MethodInstance) = false

"""
Expand All @@ -424,13 +424,13 @@ already includes detection and restriction on recursion, so it is hopefully most
benign problem, since it should really only happen during the first phase of bootstrapping
that we encounter this flag.
"""
lock_mi_inference(::NativeInterpreter, mi::MethodInstance) = (mi.inInference = true; nothing)
lock_mi_inference(::NativeInterpreter, mi::MethodInstance) = (ccall(:jl_lock_mi, Cvoid, (Any,), mi); nothing)
lock_mi_inference(::AbstractInterpreter, ::MethodInstance) = return

"""
See `lock_mi_inference`.
"""
unlock_mi_inference(::NativeInterpreter, mi::MethodInstance) = (mi.inInference = false; nothing)
unlock_mi_inference(::NativeInterpreter, mi::MethodInstance) = (ccall(:jl_unlock_mi, Cvoid, (Any,), mi); nothing)
unlock_mi_inference(::AbstractInterpreter, ::MethodInstance) = return

"""
Expand Down
2 changes: 1 addition & 1 deletion base/compiler/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ function get_staged(mi::MethodInstance, world::UInt)
end

function get_cached_uninferred(mi::MethodInstance, world::UInt)
ccall(:jl_cached_uninferred, Any, (Any, UInt), mi.cache, world)::CodeInstance
ccall(:jl_cached_uninferred, Any, (Any, UInt), mi, world)::CodeInstance
end

function retrieve_code_info(mi::MethodInstance, world::UInt)
Expand Down
15 changes: 15 additions & 0 deletions base/deprecated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -505,3 +505,18 @@ end
@deprecate invpermute!!(a, p::AbstractVector{<:Integer}) invpermute!(a, p) false

# END 1.11 deprecations

# BEGIN 1.12 deprecations

# This interface is internal, but relied on in some packages, so
# this allows for smoother upgrades. To be removed when packages have
# migrated.
function Base.getproperty(mi::MethodInstance, s::Symbol)
if s === :sparam_vals
return getfield(getfield(mi, :data), s)
else
return getfield(mi, s)
end
end

# END 1.12 deprecations
2 changes: 1 addition & 1 deletion base/reflection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1507,7 +1507,7 @@ Unlike normal functions, the compilation heuristics still can't generate good di
in some cases, but this may still allow inference not to fall over in some limited cases.
"""
function may_invoke_generator(mi::MethodInstance)
return may_invoke_generator(mi.def::Method, mi.specTypes, mi.sparam_vals)
return may_invoke_generator(mi.def::Method, mi.specTypes, mi.data.sparam_vals)
end
function may_invoke_generator(method::Method, @nospecialize(atype), sparams::SimpleVector)
# If we have complete information, we may always call the generator
Expand Down
3 changes: 2 additions & 1 deletion src/builtins.c
Original file line number Diff line number Diff line change
Expand Up @@ -2380,7 +2380,6 @@ jl_fptr_args_t jl_get_builtin_fptr(jl_datatype_t *dt)
jl_typemap_entry_t *entry = (jl_typemap_entry_t*)jl_atomic_load_relaxed(&dt->name->mt->defs);
jl_method_instance_t *mi = jl_atomic_load_relaxed(&entry->func.method->unspecialized);
jl_code_instance_t *ci = jl_atomic_load_relaxed(&mi->cache);
assert(ci->owner == jl_nothing);
return jl_atomic_load_relaxed(&ci->specptr.fptr1);
}

Expand Down Expand Up @@ -2494,6 +2493,8 @@ void jl_init_primitives(void) JL_GC_DISABLED
add_builtin("IntrinsicFunction", (jl_value_t*)jl_intrinsic_type);
add_builtin("Function", (jl_value_t*)jl_function_type);
add_builtin("Builtin", (jl_value_t*)jl_builtin_type);
add_builtin("DefaultSpec", (jl_value_t*)jl_default_spec_type);
add_builtin("MethodSpecialization", (jl_value_t*)jl_method_specialization_type);
add_builtin("MethodInstance", (jl_value_t*)jl_method_instance_type);
add_builtin("CodeInfo", (jl_value_t*)jl_code_info_type);
add_builtin("LLVMPtr", (jl_value_t*)jl_llvmpointer_type);
Expand Down
6 changes: 3 additions & 3 deletions src/ccall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1531,7 +1531,7 @@ static jl_cgval_t emit_ccall(jl_codectx_t &ctx, jl_value_t **args, size_t nargs)
std::string err = verify_ccall_sig(
/* inputs: */
rt, at, unionall,
ctx.spvals_ptr == NULL ? ctx.linfo->sparam_vals : NULL,
ctx.spvals_ptr == NULL ? jl_mi_default_spec_data(ctx.linfo)->sparam_vals : NULL,
&ctx.emission_context,
/* outputs: */
lrt, ctx.builder.getContext(),
Expand Down Expand Up @@ -1982,8 +1982,8 @@ jl_cgval_t function_sig_t::emit_a_ccall(
// so that the julia_to_native type checks are more likely to be doable (e.g. concrete types) at compile-time
jl_value_t *jargty_in_env = jargty;
if (ctx.spvals_ptr == NULL && !toboxed && unionall_env && jl_has_typevar_from_unionall(jargty, unionall_env) &&
jl_svec_len(ctx.linfo->sparam_vals) > 0) {
jargty_in_env = jl_instantiate_type_in_env(jargty_in_env, unionall_env, jl_svec_data(ctx.linfo->sparam_vals));
jl_svec_len(jl_mi_default_spec_data(ctx.linfo)->sparam_vals) > 0) {
jargty_in_env = jl_instantiate_type_in_env(jargty_in_env, unionall_env, jl_svec_data(jl_mi_default_spec_data(ctx.linfo)->sparam_vals));
if (jargty_in_env != jargty)
jargty_in_env = jl_ensure_rooted(ctx, jargty_in_env);
}
Expand Down
Loading
Loading