Skip to content

Commit

Permalink
enable/improve constant propagation through varargs methods
Browse files Browse the repository at this point in the history
- Store varargs type information in the InferenceResult object, such that the info can be used during inference/optimization

- Hack in a more precise return type for getfield of a vararg tuple. Ideally, we would handle this by teaching inference to track the types of the individual fields of a Tuple, which would make this unnecessary, but until then, this hack is helpful.

- Spoof parents as well as children during recursion limiting, so that higher degree cycles are appropriately spoofed

- A broadcast test marked as broken is now no longer broken, presumably due to the optimizations in this commit

- Fix relationship between depth/mindepth in limit_type_size/is_derived_type. The relationship should have been inverse over the domain in which they overlap, but was not maintained consistently. An example of problematic case was:

t = Tuple{X,X} where X<:Tuple{Tuple{Int64,Vararg{Int64,N} where N},Tuple{Int64,Vararg{Int64,N} where N}}
c = Tuple{X,X} where X<:Tuple{Int64,Vararg{Int64,N} where N}

because is_derived_type was computing the depth of usage rather than the depth of definition. This change thus makes the depth/mindepth calculations more consistent, and causes the limiting heuristic to return strictly wider types than it did before.

- Move the optimizer's "varargs types to tuple type" rewrite to after cache lookup.Inference is populating the InferenceResult cache using the varargs form, so the optimizer needs to do the lookup before writing the atypes in order to avoid cache misses.

Co-authored-by: Jameson Nash <[email protected]>
Co-authored-by: Keno Fischer <[email protected]>
  • Loading branch information
3 people committed May 3, 2018
1 parent 7e2ce0e commit 1a9f8e8
Show file tree
Hide file tree
Showing 10 changed files with 146 additions and 83 deletions.
43 changes: 34 additions & 9 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,7 @@ function abstract_call_method_with_const_args(@nospecialize(f), argtypes::Vector
method.isva && (nargs -= 1)
length(argtypes) >= nargs || return Any # probably limit_tuple_type made this non-matching method apparently match
haveconst = false
for i in 1:nargs
a = argtypes[i]
for a in argtypes
if isa(a, Const) && !isdefined(typeof(a.val), :instance) && !(isa(a.val, Type) && issingletontype(a.val))
# have new information from argtypes that wasn't available from the signature
haveconst = true
Expand All @@ -144,8 +143,7 @@ function abstract_call_method_with_const_args(@nospecialize(f), argtypes::Vector
tm = _topmod(sv)
if !istopfunction(tm, f, :getproperty) && !istopfunction(tm, f, :setproperty!)
# in this case, see if all of the arguments are constants
for i in 1:nargs
a = argtypes[i]
for a in argtypes
if !isa(a, Const) && !isconstType(a)
return Any
end
Expand All @@ -156,6 +154,17 @@ function abstract_call_method_with_const_args(@nospecialize(f), argtypes::Vector
if inf_result === nothing
inf_result = InferenceResult(code)
atypes = get_argtypes(inf_result)
if method.isva
vargs = argtypes[(nargs + 1):end]
for i in 1:length(vargs)
a = vargs[i]
if i > length(inf_result.vargs)
push!(inf_result.vargs, a)
elseif a isa Const
inf_result.vargs[i] = a
end
end
end
for i in 1:nargs
a = argtypes[i]
if a isa Const
Expand Down Expand Up @@ -187,7 +196,13 @@ function abstract_call_method(method::Method, @nospecialize(sig), sparams::Simpl
cyclei = 0
infstate = sv
edgecycle = false
# The `method_for_inference_heuristics` will expand the given method's generator if
# necessary in order to retrieve this field from the generated `CodeInfo`, if it exists.
# The other `CodeInfo`s we inspect will already have this field inflated, so we just
# access it directly instead (to avoid regeneration).
method2 = method_for_inference_heuristics(method, sig, sparams, sv.params.world) # Union{Method, Nothing}
sv_method2 = sv.src.method_for_inference_limit_heuristics # limit only if user token match
sv_method2 isa Method || (sv_method2 = nothing) # Union{Method, Nothing}
while !(infstate === nothing)
infstate = infstate::InferenceState
if method === infstate.linfo.def
Expand All @@ -208,7 +223,9 @@ function abstract_call_method(method::Method, @nospecialize(sig), sparams::Simpl
for parent in infstate.callers_in_cycle
# check in the cycle list first
# all items in here are mutual parents of all others
if parent.linfo.def === sv.linfo.def
parent_method2 = parent.src.method_for_inference_limit_heuristics # limit only if user token match
parent_method2 isa Method || (parent_method2 = nothing) # Union{Method, Nothing}
if parent.linfo.def === sv.linfo.def && sv_method2 === parent_method2
topmost = infstate
edgecycle = true
break
Expand All @@ -218,7 +235,9 @@ function abstract_call_method(method::Method, @nospecialize(sig), sparams::Simpl
# then check the parent link
if topmost === nothing && parent !== nothing
parent = parent::InferenceState
if parent.cached && parent.linfo.def === sv.linfo.def
parent_method2 = parent.src.method_for_inference_limit_heuristics # limit only if user token match
parent_method2 isa Method || (parent_method2 = nothing) # Union{Method, Nothing}
if parent.cached && parent.linfo.def === sv.linfo.def && sv_method2 === parent_method2
topmost = infstate
edgecycle = true
end
Expand Down Expand Up @@ -315,8 +334,8 @@ function precise_container_type(@nospecialize(arg), @nospecialize(typ), vtypes::
end

arg = ssa_def_expr(arg, sv)
if is_specializable_vararg_slot(arg, sv)
return Any[rewrap_unionall(p, sv.linfo.specTypes) for p in sv.vararg_type_container.parameters]
if is_specializable_vararg_slot(arg, sv.nargs, sv.result.vargs)
return sv.result.vargs
end

tti0 = widenconst(typ)
Expand Down Expand Up @@ -482,7 +501,13 @@ function abstract_call(@nospecialize(f), fargs::Union{Tuple{},Vector{Any}}, argt
tm = _topmod(sv)
if isa(f, Builtin) || isa(f, IntrinsicFunction)
rt = builtin_tfunction(f, argtypes[2:end], sv)
if (rt === Bool || (isa(rt, Const) && isa(rt.val, Bool))) && isa(fargs, Vector{Any})
if f === getfield && isa(fargs, Vector{Any}) && length(argtypes) == 3 && isa(argtypes[3], Const) && isa(argtypes[3].val, Int) && argtypes[2] Tuple
cti = precise_container_type(fargs[2], argtypes[2], vtypes, sv)
idx = argtypes[3].val
if 1 <= idx <= length(cti)
rt = unwrapva(cti[idx])
end
elseif (rt === Bool || (isa(rt, Const) && isa(rt.val, Bool))) && isa(fargs, Vector{Any})
# perform very limited back-propagation of type information for `is` and `isa`
if f === isa
a = ssa_def_expr(fargs[2], sv)
Expand Down
48 changes: 35 additions & 13 deletions base/compiler/inferenceresult.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ const EMPTY_VECTOR = Vector{Any}()
mutable struct InferenceResult
linfo::MethodInstance
args::Vector{Any}
vargs::Vector{Any}
result # ::Type, or InferenceState if WIP
src::Union{CodeInfo, Nothing} # if inferred copy is available
function InferenceResult(linfo::MethodInstance)
Expand All @@ -13,7 +14,7 @@ mutable struct InferenceResult
else
result = linfo.rettype
end
return new(linfo, EMPTY_VECTOR, result, nothing)
return new(linfo, EMPTY_VECTOR, Any[], result, nothing)
end
end

Expand All @@ -31,7 +32,35 @@ function get_argtypes(result::InferenceResult)
end
vararg_type = Tuple
else
vararg_type = rewrap(tupleparam_tail(atypes, nargs), linfo.specTypes)
laty = length(atypes)
if nargs > laty
va = atypes[laty]
if isvarargtype(va)
new_va = rewrap_unionall(unconstrain_vararg_length(va), linfo.specTypes)
vararg_type_vec = Any[new_va]
vararg_type = Tuple{new_va}
else
vararg_type_vec = Any[]
vararg_type = Tuple{}
end
else
vararg_type_vec = Any[]
for p in atypes[nargs:laty]
p = isvarargtype(p) ? unconstrain_vararg_length(p) : p
push!(vararg_type_vec, rewrap_unionall(p, linfo.specTypes))
end
vararg_type = tuple_tfunc(Tuple{vararg_type_vec...})
for i in 1:length(vararg_type_vec)
atyp = vararg_type_vec[i]
if isa(atyp, DataType) && isdefined(atyp, :instance)
# replace singleton types with their equivalent Const object
vararg_type_vec[i] = Const(atyp.instance)
elseif isconstType(atyp)
vararg_type_vec[i] = Const(atyp.parameters[1])
end
end
end
result.vargs = vararg_type_vec
end
args[nargs] = vararg_type
nargs -= 1
Expand Down Expand Up @@ -80,19 +109,12 @@ function cache_lookup(code::MethodInstance, argtypes::Vector{Any}, cache::Vector
for cache_code in cache
# try to search cache first
cache_args = cache_code.args
if cache_code.linfo === code && length(cache_args) >= nargs
cache_vargs = cache_code.vargs
if cache_code.linfo === code && length(argtypes) === (length(cache_vargs) + nargs)
cache_match = true
# verify that the trailing args (va) aren't Const
for i in (nargs + 1):length(cache_args)
if isa(cache_args[i], Const)
cache_match = false
break
end
end
cache_match || continue
for i in 1:nargs
for i in 1:length(argtypes)
a = argtypes[i]
ca = cache_args[i]
ca = i <= nargs ? cache_args[i] : cache_vargs[i - nargs]
# verify that all Const argument types match between the call and cache
if (isa(a, Const) || isa(ca, Const)) && !(a === ca)
cache_match = false
Expand Down
16 changes: 3 additions & 13 deletions base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ mutable struct InferenceState
# ssavalue sparsity and restart info
ssavalue_uses::Vector{BitSet}
ssavalue_defs::Vector{LineNum}
vararg_type_container #::Type

backedges::Vector{Tuple{InferenceState, LineNum}} # call-graph backedges connecting from callee to caller
callers_in_cycle::Vector{InferenceState}
Expand Down Expand Up @@ -102,19 +101,11 @@ mutable struct InferenceState
# initial types
nslots = length(src.slotnames)
argtypes = get_argtypes(result)
vararg_type_container = nothing
nargs = length(argtypes)
s_argtypes = VarTable(undef, nslots)
src.slottypes = Vector{Any}(undef, nslots)
for i in 1:nslots
at = (i > nargs) ? Bottom : argtypes[i]
if !toplevel && linfo.def.isva && i == nargs
if !(at == Tuple) # would just be a no-op
vararg_type_container = unwrap_unionall(at)
vararg_type = tuple_tfunc(vararg_type_container) # returns a Const object, if applicable
at = rewrap(vararg_type, linfo.specTypes)
end
end
s_argtypes[i] = VarState(at, i > nargs)
src.slottypes[i] = at
end
Expand Down Expand Up @@ -152,7 +143,7 @@ mutable struct InferenceState
nargs, s_types, s_edges,
Union{}, W, 1, n,
cur_hand, handler_at, n_handlers,
ssavalue_uses, ssavalue_defs, vararg_type_container,
ssavalue_uses, ssavalue_defs,
Vector{Tuple{InferenceState,LineNum}}(), # backedges
Vector{InferenceState}(), # callers_in_cycle
#=parent=#nothing,
Expand Down Expand Up @@ -238,9 +229,8 @@ function add_mt_backedge!(mt::Core.MethodTable, @nospecialize(typ), caller::Infe
nothing
end

function is_specializable_vararg_slot(@nospecialize(arg), sv::InferenceState)
return (isa(arg, Slot) && slot_id(arg) == sv.nargs &&
isa(sv.vararg_type_container, DataType))
function is_specializable_vararg_slot(@nospecialize(arg), nargs, vargs)
return (isa(arg, Slot) && slot_id(arg) == nargs && !isempty(vargs))
end

function print_callstack(sv::InferenceState)
Expand Down
33 changes: 18 additions & 15 deletions base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

mutable struct OptimizationState
linfo::MethodInstance
vararg_type_container #::Type
result_vargs::Vector{Any}
backedges::Vector{Any}
src::CodeInfo
mod::Module
Expand All @@ -23,7 +23,7 @@ mutable struct OptimizationState
end
src = frame.src
next_label = max(label_counter(src.code), length(src.code)) + 10
return new(frame.linfo, frame.vararg_type_container,
return new(frame.linfo, frame.result.vargs,
s_edges::Vector{Any},
src, frame.mod, frame.nargs,
next_label, frame.min_valid, frame.max_valid,
Expand Down Expand Up @@ -53,8 +53,8 @@ mutable struct OptimizationState
nargs = 0
end
next_label = max(label_counter(src.code), length(src.code)) + 10
vararg_type_container = nothing # if you want something more accurate, set it yourself :P
return new(linfo, vararg_type_container,
result_vargs = Any[] # if you want something more accurate, set it yourself :P
return new(linfo, result_vargs,
s_edges::Vector{Any},
src, inmodule, nargs,
next_label,
Expand Down Expand Up @@ -100,11 +100,6 @@ function add_backedge!(li::MethodInstance, caller::OptimizationState)
nothing
end

function is_specializable_vararg_slot(@nospecialize(arg), sv::OptimizationState)
return (isa(arg, Slot) && slot_id(arg) == sv.nargs &&
isa(sv.vararg_type_container, DataType))
end

###########
# structs #
###########
Expand Down Expand Up @@ -1333,22 +1328,30 @@ function inlineable(@nospecialize(f), @nospecialize(ft), e::Expr, atypes::Vector
if invoke_api(linfo) == 2
# in this case function can be inlined to a constant
add_backedge!(linfo, sv)
# XXX: @vtjnash thinks this should be `argexprs0`, but doing so exposes a
# downstream optimizer problem that breaks tests, so we're going to avoid
# changing it for now. ref
# https://github.com/JuliaLang/julia/pull/26826#issuecomment-386381103
return inline_as_constant(linfo.inferred_const, argexprs, sv, invoke_data)
end

# see if the method has a InferenceResult in the current cache
# See if the method has a InferenceResult in the current cache
# or an existing inferred code info store in `.inferred`
#
# Above, we may have rewritten trailing varargs in `atypes` to a tuple type. However,
# inference populates the cache with the pre-rewrite version (`atypes0`), so here, we
# check against that instead.
haveconst = false
for i in 1:length(atypes)
a = atypes[i]
for i in 1:length(atypes0)
a = atypes0[i]
if isa(a, Const) && !isdefined(typeof(a.val), :instance) && !(isa(a.val, Type) && issingletontype(a.val))
# have new information from argtypes that wasn't available from the signature
haveconst = true
break
end
end
if haveconst
inf_result = cache_lookup(linfo, atypes, sv.params.cache) # Union{Nothing, InferenceResult}
inf_result = cache_lookup(linfo, atypes0, sv.params.cache) # Union{Nothing, InferenceResult}
else
inf_result = nothing
end
Expand Down Expand Up @@ -2003,8 +2006,8 @@ function inline_call(e::Expr, sv::OptimizationState, stmts::Vector{Any}, boundsc
tmpv = newvar!(sv, t)
push!(newstmts, Expr(:(=), tmpv, aarg))
end
if is_specializable_vararg_slot(aarg, sv)
tp = sv.vararg_type_container.parameters
if is_specializable_vararg_slot(aarg, sv.nargs, sv.result_vargs)
tp = sv.result_vargs
else
tp = t.parameters
end
Expand Down
2 changes: 1 addition & 1 deletion base/compiler/tfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ add_tfunc(===, 2, 2,
end
return Bool
end, 1)
function isdefined_tfunc(args...)
function isdefined_tfunc(@nospecialize(args...))
arg1 = args[1]
if isa(arg1, Const)
a1 = typeof(arg1.val)
Expand Down
14 changes: 6 additions & 8 deletions base/compiler/typelimits.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,12 @@ function is_derived_type(@nospecialize(t), @nospecialize(c), mindepth::Int)
if t === c
return mindepth == 0
end
if isa(c, TypeVar)
# see if it is replacing a TypeVar upper bound with something simpler
return is_derived_type(t, c.ub, mindepth)
elseif isa(c, Union)
if isa(c, Union)
# see if it is one of the elements of the union
return is_derived_type(t, c.a, mindepth + 1) || is_derived_type(t, c.b, mindepth + 1)
elseif isa(c, UnionAll)
# see if it is derived from the body
return is_derived_type(t, c.body, mindepth)
return is_derived_type(t, c.var.ub, mindepth) || is_derived_type(t, c.body, mindepth + 1)
elseif isa(c, DataType)
if isa(t, DataType)
# see if it is one of the supertypes of a parameter
Expand Down Expand Up @@ -96,7 +93,8 @@ function is_derived_type_from_any(@nospecialize(t), sources::SimpleVector, minde
return false
end

# type vs. comparison or which was derived from source
# The goal of this function is to return a type of greater "size" and less "complexity" than
# both `t` or `c` over the lattice defined by `sources`, `depth`, and `allowed_tuplelen`.
function _limit_type_size(@nospecialize(t), @nospecialize(c), sources::SimpleVector, depth::Int, allowed_tuplelen::Int)
if t === c
return t # quick egal test
Expand Down Expand Up @@ -140,9 +138,9 @@ function _limit_type_size(@nospecialize(t), @nospecialize(c), sources::SimpleVec
lb = Bottom
end
v2 = TypeVar(tv.name, lb, ub)
return UnionAll(v2, _limit_type_size(t{v2}, c{v2}, sources, depth + 1, allowed_tuplelen))
return UnionAll(v2, _limit_type_size(t{v2}, c{v2}, sources, depth, allowed_tuplelen))
end
tbody = _limit_type_size(t.body, c, sources, depth + 1, allowed_tuplelen)
tbody = _limit_type_size(t.body, c, sources, depth, allowed_tuplelen)
tbody === t.body && return t
return UnionAll(t.var, tbody)
elseif isa(c, UnionAll)
Expand Down
14 changes: 0 additions & 14 deletions base/compiler/typeutils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,20 +99,6 @@ function tuple_tail_elem(@nospecialize(init), ct)
return Vararg{widenconst(foldl((a, b) -> tmerge(a, tvar_extent(unwrapva(b))), init, ct))}
end

# t[n:end]
function tupleparam_tail(t::SimpleVector, n)
lt = length(t)
if n > lt
va = t[lt]
if isvarargtype(va)
# assumes that we should never see Vararg{T, x}, where x is a constant (should be guaranteed by construction)
return Tuple{va}
end
return Tuple{}
end
return Tuple{t[n:lt]...}
end

# take a Tuple where one or more parameters are Unions
# and return an array such that those Unions are removed
# and `Union{return...} == ty`
Expand Down
8 changes: 8 additions & 0 deletions base/essentials.jl
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,14 @@ function unwrapva(@nospecialize(t))
return isvarargtype(t2) ? rewrap_unionall(t2.parameters[1], t) : t
end

function unconstrain_vararg_length(@nospecialize(va))
# construct a new Vararg type where its length is unconstrained,
# but its element type still captures any dependencies the input
# element type may have had on the input length
T = unwrap_unionall(va).parameters[1]
return rewrap_unionall(Vararg{T}, va)
end

typename(a) = error("typename does not apply to this type")
typename(a::DataType) = a.name
function typename(a::Union)
Expand Down
Loading

0 comments on commit 1a9f8e8

Please sign in to comment.