Skip to content

Commit

Permalink
Move nargs/isva to CodeInfo
Browse files Browse the repository at this point in the history
This changes the canonical source of truth for va handling from
`Method` to `CodeInfo`. There are multiple goals for this change:

1. This addresses a longstanding complaint about the way that
   CodeInfo-returning generated functions work. Previously, the
   va-ness or not of the returned CodeInfo always had to match that
   of the generator. For Cassette-like transforms that generally have
   one big generator function that is varargs (while then looking up
   lowered code that is not varargs), this could become quite annoying.
   It's possible to work around, but there is really no good reason to tie
   the two together. As we observed when we implemented OpaqueClosures, the
   vararg-ness of the signature and the `vararg arguments`->`tuple`
   transformation are mostly independent concepts. With this PR, generated
   functions can return CodeInfos with whatever combination of nargs/isva
   is convenient.

2. This change requires clarifying where the va processing boundary is
   in inference. #54076 was already moving in that direction for irinterp,
   and this essentially does much of the same for regular inference. As a
   consequence the constprop cache is now using non-va-cooked signatures,
   which I think is preferable.

3. This further decouples codegen from the presence of a `Method` (which
   is already not assumed, since the code being generated could be a
   toplevel thunk, but some codegen features are only available to things
   that come from Methods). There are a number of upcoming features that
   will require codegen of things that are not quite method specializations
   (See design doc linked in #52797 and things like #50641). This helps
   pave the road for that.

4. I've previously considered expanding the kinds of vararg signatures that
   can be described (see e.g. #53851), which also requires a decoupling of
   the signature and ast notions of vararg. This again lays the groundwork
   for that, although I have no immediate plans to implement this change.

Impact wise, this adds an internal field, which is not too breaking,
but downstream clients vary in how they construct their `CodeInfo`s and
the current way they're doing it will likely be incorrect after this change,
so they will require a small two-line adjustment. We should perhaps consider
pulling out some of the more common patterns into a more stable package, since
we have changed this interface in most of the last few releases, but that's a separate issue.
  • Loading branch information
Keno committed May 3, 2024
1 parent 0b70d26 commit 5320bd9
Show file tree
Hide file tree
Showing 15 changed files with 189 additions and 153 deletions.
43 changes: 17 additions & 26 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -479,7 +479,12 @@ function conditional_argtype(𝕃ᵢ::AbstractLattice, @nospecialize(rt), @nospe
if isa(rt, InterConditional) && rt.slot == i
return rt
else
thentype = elsetype = tmeet(𝕃ᵢ, widenslotwrapper(argtypes[i]), fieldtype(sig, i))
argt = widenslotwrapper(argtypes[i])
if isvarargtype(argt)
@assert fieldcount(sig) == i
argt = unwrapva(argt)
end
thentype = elsetype = tmeet(𝕃ᵢ, argt, fieldtype(sig, i))
condval = maybe_extract_const_bool(rt)
condval === true && (elsetype = Bottom)
condval === false && (thentype = Bottom)
Expand Down Expand Up @@ -986,15 +991,12 @@ function maybe_get_const_prop_profitable(interp::AbstractInterpreter,
# N.B. remarks are emitted within `const_prop_entry_heuristic`
return nothing
end
nargs::Int = method.nargs
method.isva && (nargs -= 1)
length(arginfo.argtypes) < nargs && return nothing
if !const_prop_argument_heuristic(interp, arginfo, sv)
add_remark!(interp, sv, "[constprop] Disabled by argument and rettype heuristics")
return nothing
end
all_overridden = is_all_overridden(interp, arginfo, sv)
if !force && !const_prop_function_heuristic(interp, f, arginfo, nargs, all_overridden, sv)
if !force && !const_prop_function_heuristic(interp, f, arginfo, all_overridden, sv)
add_remark!(interp, sv, "[constprop] Disabled by function heuristic")
return nothing
end
Expand Down Expand Up @@ -1113,9 +1115,9 @@ function force_const_prop(interp::AbstractInterpreter, @nospecialize(f), method:
end

function const_prop_function_heuristic(interp::AbstractInterpreter, @nospecialize(f),
arginfo::ArgInfo, nargs::Int, all_overridden::Bool, sv::AbsIntState)
arginfo::ArgInfo, all_overridden::Bool, sv::AbsIntState)
argtypes = arginfo.argtypes
if nargs > 1
if length(argtypes) > 1
𝕃ᵢ = typeinf_lattice(interp)
if istopfunction(f, :getindex) || istopfunction(f, :setindex!)
arrty = argtypes[2]
Expand Down Expand Up @@ -1349,20 +1351,6 @@ function matching_cache_argtypes(𝕃::AbstractLattice, mi::MethodInstance,
end
given_argtypes[i] = widenslotwrapper(argtype)
end
if condargs !== nothing
given_argtypes = let condargs=condargs
va_process_argtypes(𝕃, given_argtypes, mi) do isva_given_argtypes::Vector{Any}, last::Int
# invalidate `Conditional` imposed on varargs
for (slotid, i) in condargs
if slotid ≥ last && (1 ≤ i ≤ length(isva_given_argtypes)) # `Conditional` is already widened to vararg-tuple otherwise
isva_given_argtypes[i] = widenconditional(isva_given_argtypes[i])
end
end
end
end
else
given_argtypes = va_process_argtypes(𝕃, given_argtypes, mi)
end
return pick_const_args!(𝕃, given_argtypes, cache_argtypes)
end

Expand Down Expand Up @@ -1721,7 +1709,7 @@ function abstract_apply(interp::AbstractInterpreter, argtypes::Vector{Any}, si::
return CallMeta(res, exct, effects, retinfo)
end

function argtype_by_index(argtypes::Vector{Any}, i::Int)
function argtype_by_index(argtypes::Vector{Any}, i::Integer)
n = length(argtypes)
na = argtypes[n]
if isvarargtype(na)
Expand Down Expand Up @@ -2890,12 +2878,12 @@ end
struct BestguessInfo{Interp<:AbstractInterpreter}
interp::Interp
bestguess
nargs::Int
nargs::UInt
slottypes::Vector{Any}
changes::VarTable
function BestguessInfo(interp::Interp, @nospecialize(bestguess), nargs::Int,
function BestguessInfo(interp::Interp, @nospecialize(bestguess), nargs::UInt,
slottypes::Vector{Any}, changes::VarTable) where Interp<:AbstractInterpreter
new{Interp}(interp, bestguess, nargs, slottypes, changes)
new{Interp}(interp, bestguess, Int(nargs), slottypes, changes)
end
end

Expand Down Expand Up @@ -2970,7 +2958,7 @@ end
# pick up the first "interesting" slot, convert `rt` to its `Conditional`
# TODO: ideally we want `Conditional` and `InterConditional` to convey
# constraints on multiple slots
for slot_id = 1:info.nargs
for slot_id = 1:Int(info.nargs)
rt = bool_rt_to_conditional(rt, slot_id, info)
rt isa InterConditional && break
end
Expand All @@ -2981,6 +2969,9 @@ end
⊑ᵢ = ⊑(typeinf_lattice(info.interp))
old = info.slottypes[slot_id]
new = widenslotwrapper(info.changes[slot_id].typ) # avoid nested conditional
if isvarargtype(old) || isvarargtype(new)
return rt
end
if new ⊑ᵢ old && !(old ⊑ᵢ new)
if isa(rt, Const)
val = rt.val
Expand Down
183 changes: 88 additions & 95 deletions base/compiler/inferenceresult.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,27 +24,53 @@ function matching_cache_argtypes(𝕃::AbstractLattice, mi::MethodInstance,
for i = 1:length(argtypes)
given_argtypes[i] = widenslotwrapper(argtypes[i])
end
given_argtypes = va_process_argtypes(𝕃, given_argtypes, mi)
return pick_const_args!(𝕃, given_argtypes, cache_argtypes)
end

function pick_const_arg(𝕃::AbstractLattice, @nospecialize(given_argtype), @nospecialize(cache_argtype))
if !is_argtype_match(𝕃, given_argtype, cache_argtype, false)
# prefer the argtype we were given over the one computed from `mi`
if (isa(given_argtype, PartialStruct) && isa(cache_argtype, Type) &&
!⊑(𝕃, given_argtype, cache_argtype))
# if the type information of this `PartialStruct` is less strict than
# declared method signature, narrow it down using `tmeet`
given_argtype = tmeet(𝕃, given_argtype, cache_argtype)
end
else
given_argtype = cache_argtype
end
return given_argtype
end

function pick_const_args!(𝕃::AbstractLattice, given_argtypes::Vector{Any}, cache_argtypes::Vector{Any})
nargtypes = length(given_argtypes)
@assert nargtypes == length(cache_argtypes) #= == nargs =# "invalid `given_argtypes` for `mi`"
for i = 1:nargtypes
given_argtype = given_argtypes[i]
cache_argtype = cache_argtypes[i]
if !is_argtype_match(𝕃, given_argtype, cache_argtype, false)
# prefer the argtype we were given over the one computed from `mi`
if (isa(given_argtype, PartialStruct) && isa(cache_argtype, Type) &&
!⊑(𝕃, given_argtype, cache_argtype))
# if the type information of this `PartialStruct` is less strict than
# declared method signature, narrow it down using `tmeet`
given_argtypes[i] = tmeet(𝕃, given_argtype, cache_argtype)
end
if length(given_argtypes) == 0 || length(cache_argtypes) == 0
return Any[]
end
given_va = given_argtypes[end]
cache_va = cache_argtypes[end]
if isvarargtype(given_va)
if isvarargtype(cache_va)
# Process the common prefix, then join
nprocessargs = max(length(given_argtypes)-1, length(cache_argtypes)-1)
resize!(given_argtypes, nprocessargs+1)
given_argtypes[end] = Vararg{pick_const_arg(𝕃, unwrapva(given_va), unwrapva(cache_va))}
else
given_argtypes[i] = cache_argtype
nprocessargs = length(cache_argtypes)
resize!(given_argtypes, nprocessargs)
end
elseif isvarargtype(cache_va)
nprocessargs = length(given_argtypes)
resize!(given_argtypes, nprocessargs)
else
@assert length(given_argtypes) == length(cache_argtypes)
nprocessargs = length(given_argtypes)
resize!(given_argtypes, nprocessargs)
end
for i = 1:nprocessargs
given_argtype = argtype_by_index(given_argtypes, i)
cache_argtype = argtype_by_index(cache_argtypes, i)
given_argtype = pick_const_arg(𝕃, given_argtype, cache_argtype)
given_argtypes[i] = given_argtype
end
return given_argtypes
end
Expand All @@ -60,25 +86,33 @@ function is_argtype_match(𝕃::AbstractLattice,
end
end

va_process_argtypes(𝕃::AbstractLattice, given_argtypes::Vector{Any}, mi::MethodInstance) =
va_process_argtypes(Returns(nothing), 𝕃, given_argtypes, mi)
function va_process_argtypes(@specialize(va_handler!), 𝕃::AbstractLattice, given_argtypes::Vector{Any}, mi::MethodInstance)
def = mi.def::Method
isva = def.isva
nargs = Int(def.nargs)
if isva || isvarargtype(given_argtypes[end])
isva_given_argtypes = Vector{Any}(undef, nargs)
function va_process_argtypes(𝕃::AbstractLattice, given_argtypes::Vector{Any}, nargs::UInt, isva::Bool)
if isva || (!isempty(given_argtypes) && isvarargtype(given_argtypes[end]))
isva_given_argtypes = Vector{Any}(undef, Int(nargs))
for i = 1:(nargs-isva)
isva_given_argtypes[i] = argtype_by_index(given_argtypes, i)
newarg = argtype_by_index(given_argtypes, i)
if isva && has_conditional(𝕃) && isa(newarg, Conditional)
if newarg.slotid > (nargs-isva)
newarg = widenconditional(newarg)
end
end
isva_given_argtypes[i] = newarg
end
if isva
if length(given_argtypes) < nargs && isvarargtype(given_argtypes[end])
last = length(given_argtypes)
else
last = nargs
if has_conditional(𝕃)
for i = last:length(given_argtypes)
newarg = given_argtypes[i]
if isa(newarg, Conditional) && newarg.slotid > (nargs-isva)
given_argtypes[i] = widenconditional(newarg)
end
end
end
end
isva_given_argtypes[nargs] = tuple_tfunc(𝕃, given_argtypes[last:end])
va_handler!(isva_given_argtypes, last)
end
return isva_given_argtypes
end
Expand All @@ -87,84 +121,44 @@ function va_process_argtypes(@specialize(va_handler!), 𝕃::AbstractLattice, gi
end

function most_general_argtypes(method::Union{Method,Nothing}, @nospecialize(specTypes))
toplevel = method === nothing
isva = !toplevel && method.isva
mi_argtypes = Any[(unwrap_unionall(specTypes)::DataType).parameters...]
nargs::Int = toplevel ? 0 : method.nargs
cache_argtypes = Vector{Any}(undef, nargs)
# First, if we're dealing with a varargs method, then we set the last element of `args`
# to the appropriate `Tuple` type or `PartialStruct` instance.
mi_argtypes_length = length(mi_argtypes)
if !toplevel && isva
if specTypes::Type == Tuple
mi_argtypes = Any[Any for i = 1:nargs]
if nargs > 1
mi_argtypes[end] = Tuple
end
vargtype = Tuple
else
if nargs > mi_argtypes_length
va = mi_argtypes[mi_argtypes_length]
if isvarargtype(va)
new_va = rewrap_unionall(unconstrain_vararg_length(va), specTypes)
vargtype = Tuple{new_va}
else
vargtype = Tuple{}
end
else
vargtype_elements = Any[]
for i in nargs:mi_argtypes_length
p = mi_argtypes[i]
p = unwraptv(isvarargtype(p) ? unconstrain_vararg_length(p) : p)
push!(vargtype_elements, elim_free_typevars(rewrap_unionall(p, specTypes)))
end
for i in 1:length(vargtype_elements)
atyp = vargtype_elements[i]
if issingletontype(atyp)
# replace singleton types with their equivalent Const object
vargtype_elements[i] = Const(atyp.instance)
elseif isconstType(atyp)
vargtype_elements[i] = Const(atyp.parameters[1])
end
end
vargtype = tuple_tfunc(fallback_lattice, vargtype_elements)
end
end
cache_argtypes[nargs] = vargtype
nargs -= 1
nargtypes = length(mi_argtypes)
nargs = isa(method, Method) ? method.nargs : 0
if length(mi_argtypes) < nargs && isvarargtype(mi_argtypes[end])
resize!(mi_argtypes, nargs)
end
# Now, we propagate type info from `mi_argtypes` into `cache_argtypes`, improving some
# type info as we go (where possible). Note that if we're dealing with a varargs method,
# we already handled the last element of `cache_argtypes` (and decremented `nargs` so that
# we don't overwrite the result of that work here).
if mi_argtypes_length > 0
tail_index = nargtypes = min(mi_argtypes_length, nargs)
local lastatype
for i = 1:nargtypes
atyp = mi_argtypes[i]
if i == nargtypes && isvarargtype(atyp)
atyp = unwrapva(atyp)
tail_index -= 1
end
atyp = unwraptv(atyp)
if issingletontype(atyp)
# replace singleton types with their equivalent Const object
atyp = Const(atyp.instance)
elseif isconstType(atyp)
atyp = Const(atyp.parameters[1])
else
atyp = elim_free_typevars(rewrap_unionall(atyp, specTypes))
end
i == nargtypes && (lastatype = atyp)
cache_argtypes[i] = atyp
tail_index = min(nargtypes, nargs)
local lastatype
for i = 1:nargtypes
atyp = mi_argtypes[i]
wasva = false
if i == nargtypes && isvarargtype(atyp)
wasva = true
atyp = unwrapva(atyp)
end
for i = (tail_index+1):nargs
cache_argtypes[i] = lastatype
atyp = unwraptv(atyp)
if issingletontype(atyp)
# replace singleton types with their equivalent Const object
atyp = Const(atyp.instance)
elseif isconstType(atyp)
atyp = Const(atyp.parameters[1])
else
atyp = elim_free_typevars(rewrap_unionall(atyp, specTypes))
end
else
@assert nargs == 0 "invalid specialization of method" # wrong number of arguments
mi_argtypes[i] = atyp
if wasva
lastatype = atyp
mi_argtypes[end] = Vararg{atyp}
end
end
for i = (tail_index+1):(nargs-1)
mi_argtypes[i] = lastatype
end
return cache_argtypes
return mi_argtypes
end

# eliminate free `TypeVar`s in order to make the life much easier down the road:
Expand All @@ -184,7 +178,6 @@ function cache_lookup(𝕃::AbstractLattice, mi::MethodInstance, given_argtypes:
cache::Vector{InferenceResult})
method = mi.def::Method
nargtypes = length(given_argtypes)
@assert nargtypes == Int(method.nargs) "invalid `given_argtypes` for `mi`"
for cached_result in cache
cached_result.linfo === mi || @goto next_cache
cache_argtypes = cached_result.argtypes
Expand Down
10 changes: 6 additions & 4 deletions base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,9 @@ mutable struct InferenceState
bb_vartables = Union{Nothing,VarTable}[ nothing for i = 1:length(cfg.blocks) ]
bb_vartable1 = bb_vartables[1] = VarTable(undef, nslots)
argtypes = result.argtypes

argtypes = va_process_argtypes(typeinf_lattice(interp), argtypes, src.nargs, src.isva)

nargtypes = length(argtypes)
for i = 1:nslots
argtyp = (i > nargtypes) ? Bottom : argtypes[i]
Expand Down Expand Up @@ -766,10 +769,9 @@ function print_callstack(sv::InferenceState)
end

function narguments(sv::InferenceState, include_va::Bool=true)
def = sv.linfo.def
nargs = length(sv.result.argtypes)
nargs = sv.src.nargs
if !include_va
nargs -= isa(def, Method) && def.isva
nargs -= sv.src.isva
end
return nargs
end
Expand Down Expand Up @@ -831,7 +833,7 @@ function IRInterpretationState(interp::AbstractInterpreter,
end
method_info = MethodInfo(src)
ir = inflate_ir(src, mi)
argtypes = va_process_argtypes(optimizer_lattice(interp), argtypes, mi)
argtypes = va_process_argtypes(optimizer_lattice(interp), argtypes, src.nargs, src.isva)
return IRInterpretationState(interp, method_info, ir, mi, argtypes, world,
codeinst.min_world, codeinst.max_world)
end
Expand Down
5 changes: 2 additions & 3 deletions base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1264,14 +1264,13 @@ end
function slot2reg(ir::IRCode, ci::CodeInfo, sv::OptimizationState)
# need `ci` for the slot metadata, IR for the code
svdef = sv.linfo.def
nargs = isa(svdef, Method) ? Int(svdef.nargs) : 0
@timeit "domtree 1" domtree = construct_domtree(ir)
defuse_insts = scan_slot_def_use(nargs, ci, ir.stmts.stmt)
defuse_insts = scan_slot_def_use(ci.nargs, ci, ir.stmts.stmt)
𝕃ₒ = optimizer_lattice(sv.inlining.interp)
@timeit "construct_ssa" ir = construct_ssa!(ci, ir, sv, domtree, defuse_insts, 𝕃ₒ) # consumes `ir`
# NOTE now we have converted `ir` to the SSA form and eliminated slots
# let's resize `argtypes` now and remove unnecessary types for the eliminated slots
resize!(ir.argtypes, nargs)
resize!(ir.argtypes, ci.nargs)
return ir
end

Expand Down
Loading

0 comments on commit 5320bd9

Please sign in to comment.