Skip to content

Commit

Permalink
inference: forward Conditional inter-procedurally
Browse files Browse the repository at this point in the history
The PR #38905 only "back-propagates" conditional constraint
(from callee to caller), but currently we don't "forward" it
(caller to callee), and so inter-procedural constraint propagation
won't happen for e.g.:
```julia
ifelselike(cnd, x, y) = cnd ? x : y
@test Base.return_types((Any,Int,)) do x, y
    ifelselike(isa(x, Int), x, y)
end |> only == Int
```

This commit complements #38905 and enables further inter-procedural
conditional constraint propagation by forwarding `Conditional` to
callees when it imposes a constraint on any other argument,
during constant propagation.
  • Loading branch information
aviatesk committed Oct 7, 2021
1 parent f147333 commit c3a0aef
Show file tree
Hide file tree
Showing 5 changed files with 165 additions and 50 deletions.
118 changes: 75 additions & 43 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ function is_improvable(@nospecialize(rtype))
end

function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
fargs::Union{Nothing,Vector{Any}}, argtypes::Vector{Any}, @nospecialize(atype),
(; fargs, argtypes)::ArgInfo, @nospecialize(atype),
sv::InferenceState, max_methods::Int = InferenceParams(interp).MAX_METHODS)
if sv.params.unoptimize_throw_blocks && is_stmt_throw_block(get_curr_ssaflag(sv))
add_remark!(interp, sv, "Skipped call in throw block")
Expand Down Expand Up @@ -85,7 +85,8 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
push!(edges, edge)
end
this_argtypes = isa(matches, MethodMatches) ? argtypes : matches.applicable_argtypes[i]
const_result = abstract_call_method_with_const_args(interp, result, f, this_argtypes, match, sv, false)
arginfo = ArgInfo(fargs, this_argtypes)
const_result = abstract_call_method_with_const_args(interp, result, f, arginfo, match, sv, false)
if const_result !== nothing
const_rt, const_result = const_result
if const_rt !== rt && const_rt rt
Expand All @@ -110,7 +111,8 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
# try constant propagation with argtypes for this match
# this is in preparation for inlining, or improving the return result
this_argtypes = isa(matches, MethodMatches) ? argtypes : matches.applicable_argtypes[i]
const_result = abstract_call_method_with_const_args(interp, result, f, this_argtypes, match, sv, false)
arginfo = ArgInfo(fargs, this_argtypes)
const_result = abstract_call_method_with_const_args(interp, result, f, arginfo, match, sv, false)
if const_result !== nothing
const_this_rt, const_result = const_result
if const_this_rt !== this_rt && const_this_rt this_rt
Expand Down Expand Up @@ -523,13 +525,13 @@ struct MethodCallResult
end

function abstract_call_method_with_const_args(interp::AbstractInterpreter, result::MethodCallResult,
@nospecialize(f), argtypes::Vector{Any}, match::MethodMatch,
@nospecialize(f), arginfo::ArgInfo, match::MethodMatch,
sv::InferenceState, va_override::Bool)
mi = maybe_get_const_prop_profitable(interp, result, f, argtypes, match, sv)
mi = maybe_get_const_prop_profitable(interp, result, f, arginfo, match, sv)
mi === nothing && return nothing
# try constant prop'
inf_cache = get_inference_cache(interp)
inf_result = cache_lookup(mi, argtypes, inf_cache)
inf_result = cache_lookup(mi, arginfo.argtypes, inf_cache)
if inf_result === nothing
# if there might be a cycle, check to make sure we don't end up
# calling ourselves here.
Expand All @@ -545,7 +547,7 @@ function abstract_call_method_with_const_args(interp::AbstractInterpreter, resul
return nothing
end
end
inf_result = InferenceResult(mi, argtypes, va_override)
inf_result = InferenceResult(mi; arginfo, va_override)
if !any(inf_result.overridden_by_const)
add_remark!(interp, sv, "[constprop] Could not handle constant info in matching_cache_argtypes")
return nothing
Expand All @@ -565,7 +567,7 @@ end
# if there's a possibility we could get a better result (hopefully without doing too much work)
# returns `MethodInstance` with constant arguments, returns nothing otherwise
function maybe_get_const_prop_profitable(interp::AbstractInterpreter, result::MethodCallResult,
@nospecialize(f), argtypes::Vector{Any}, match::MethodMatch,
@nospecialize(f), arginfo::ArgInfo, match::MethodMatch,
sv::InferenceState)
if !InferenceParams(interp).ipo_constant_propagation
add_remark!(interp, sv, "[constprop] Disabled by parameter")
Expand All @@ -580,14 +582,14 @@ function maybe_get_const_prop_profitable(interp::AbstractInterpreter, result::Me
force || const_prop_entry_heuristic(interp, result, sv) || return nothing
nargs::Int = method.nargs
method.isva && (nargs -= 1)
length(argtypes) < nargs && return nothing
if !(const_prop_argument_heuristic(interp, argtypes) || const_prop_rettype_heuristic(interp, result.rt))
length(arginfo.argtypes) < nargs && return nothing
if !(const_prop_argument_heuristic(interp, arginfo) || const_prop_rettype_heuristic(interp, result.rt))
add_remark!(interp, sv, "[constprop] Disabled by argument and rettype heuristics")
return nothing
end
allconst = is_allconst(argtypes)
allconst = is_allconst(arginfo)
if !force
if !const_prop_function_heuristic(interp, f, argtypes, nargs, allconst)
if !const_prop_function_heuristic(interp, f, arginfo, nargs, allconst)
add_remark!(interp, sv, "[constprop] Disabled by function heuristic")
return nothing
end
Expand All @@ -599,7 +601,7 @@ function maybe_get_const_prop_profitable(interp::AbstractInterpreter, result::Me
return nothing
end
mi = mi::MethodInstance
if !force && !const_prop_methodinstance_heuristic(interp, match, mi, argtypes, sv)
if !force && !const_prop_methodinstance_heuristic(interp, match, mi, arginfo, sv)
add_remark!(interp, sv, "[constprop] Disabled by method instance heuristic")
return nothing
end
Expand All @@ -617,8 +619,13 @@ function const_prop_entry_heuristic(interp::AbstractInterpreter, result::MethodC
end

# see if propagating constants may be worthwhile
function const_prop_argument_heuristic(interp::AbstractInterpreter, argtypes::Vector{Any})
function const_prop_argument_heuristic(interp::AbstractInterpreter, (; fargs, argtypes)::ArgInfo)
for a in argtypes
if isa(a, Conditional) && fargs !== nothing
if is_const_prop_profitable_conditional(a, fargs)
return true
end
end
a = widenconditional(a)
if has_nontrivial_const_info(a) && is_const_prop_profitable_arg(a)
return true
Expand All @@ -642,13 +649,34 @@ function is_const_prop_profitable_arg(@nospecialize(arg))
return isa(val, Symbol) || isa(val, Type) || (!isa(val, String) && !ismutable(val))
end

function is_const_prop_profitable_conditional(cnd::Conditional, fargs::Vector{Any})
slotid = find_constrained_arg(cnd, fargs)
if slotid !== nothing
return true
end
return is_const_prop_profitable_arg(widenconditional(cnd))
end

function find_constrained_arg(cnd::Conditional, fargs::Vector{Any})
slot = cnd.var
return findfirst(fargs) do @nospecialize(x)
x === slot
end
end

function const_prop_rettype_heuristic(interp::AbstractInterpreter, @nospecialize(rettype))
return improvable_via_constant_propagation(rettype)
end

function is_allconst(argtypes::Vector{Any})
function is_allconst((; fargs, argtypes)::ArgInfo)
for a in argtypes
if isa(a, Conditional) && fargs !== nothing
if is_const_prop_profitable_conditional(a, fargs)
continue
end
end
a = widenconditional(a)
# TODO unify these condition with `has_nontrivial_const_info`
if !isa(a, Const) && !isconstType(a) && !isa(a, PartialStruct) && !isa(a, PartialOpaque)
return false
end
Expand All @@ -663,7 +691,9 @@ function force_const_prop(interp::AbstractInterpreter, @nospecialize(f), method:
istopfunction(f, :setproperty!)
end

function const_prop_function_heuristic(interp::AbstractInterpreter, @nospecialize(f), argtypes::Vector{Any}, nargs::Int, allconst::Bool)
function const_prop_function_heuristic(
interp::AbstractInterpreter, @nospecialize(f), (; argtypes)::ArgInfo,
nargs::Int, allconst::Bool)
if nargs > 1
if istopfunction(f, :getindex) || istopfunction(f, :setindex!)
arrty = argtypes[2]
Expand Down Expand Up @@ -705,7 +735,7 @@ end
# result anyway.
function const_prop_methodinstance_heuristic(
interp::AbstractInterpreter, match::MethodMatch, mi::MethodInstance,
argtypes::Vector{Any}, sv::InferenceState)
(; argtypes)::ArgInfo, sv::InferenceState)
method = match.method
if method.is_for_opaque_closure
# Not inlining an opaque closure can be very expensive, so be generous
Expand Down Expand Up @@ -835,7 +865,7 @@ function abstract_iteration(interp::AbstractInterpreter, @nospecialize(itft), @n
return Any[Vararg{Any}], nothing
end
@assert !isvarargtype(itertype)
call = abstract_call_known(interp, iteratef, nothing, Any[itft, itertype], sv)
call = abstract_call_known(interp, iteratef, ArgInfo(nothing, Any[itft, itertype]), sv)
stateordonet = call.rt
info = call.info
# Return Bottom if this is not an iterator.
Expand Down Expand Up @@ -871,7 +901,7 @@ function abstract_iteration(interp::AbstractInterpreter, @nospecialize(itft), @n
valtype = getfield_tfunc(stateordonet, Const(1))
push!(ret, valtype)
statetype = nstatetype
call = abstract_call_known(interp, iteratef, nothing, Any[Const(iteratef), itertype, statetype], sv)
call = abstract_call_known(interp, iteratef, ArgInfo(nothing, Any[Const(iteratef), itertype, statetype]), sv)
stateordonet = call.rt
push!(calls, call)
end
Expand Down Expand Up @@ -905,7 +935,7 @@ function abstract_iteration(interp::AbstractInterpreter, @nospecialize(itft), @n
end
valtype = tmerge(valtype, nounion.parameters[1])
statetype = tmerge(statetype, nounion.parameters[2])
stateordonet = abstract_call_known(interp, iteratef, nothing, Any[Const(iteratef), itertype, statetype], sv).rt
stateordonet = abstract_call_known(interp, iteratef, ArgInfo(nothing, Any[Const(iteratef), itertype, statetype]), sv).rt
stateordonet = widenconst(stateordonet)
end
if valtype !== Union{}
Expand Down Expand Up @@ -994,7 +1024,7 @@ function abstract_apply(interp::AbstractInterpreter, argtypes::Vector{Any}, sv::
break
end
end
call = abstract_call(interp, nothing, ct, sv, max_methods)
call = abstract_call(interp, ArgInfo(nothing, ct), sv, max_methods)
push!(retinfos, ApplyCallInfo(call.info, arginfo))
res = tmerge(res, call.rt)
if bail_out_apply(interp, res, sv)
Expand Down Expand Up @@ -1058,8 +1088,8 @@ function argtype_tail(argtypes::Vector{Any}, i::Int)
return argtypes[i:n]
end

function abstract_call_builtin(interp::AbstractInterpreter, f::Builtin, fargs::Union{Nothing,Vector{Any}},
argtypes::Vector{Any}, sv::InferenceState, max_methods::Int)
function abstract_call_builtin(interp::AbstractInterpreter, f::Builtin, (; fargs, argtypes)::ArgInfo,
sv::InferenceState, max_methods::Int)
@nospecialize f
la = length(argtypes)
if f === ifelse && fargs isa Vector{Any} && la == 4
Expand Down Expand Up @@ -1191,7 +1221,7 @@ function abstract_call_unionall(argtypes::Vector{Any})
return Any
end

function abstract_invoke(interp::AbstractInterpreter, argtypes::Vector{Any}, sv::InferenceState)
function abstract_invoke(interp::AbstractInterpreter, (; fargs, argtypes)::ArgInfo, sv::InferenceState)
ft′ = argtype_by_index(argtypes, 2)
ft = widenconst(ft′)
ft === Bottom && return CallMeta(Bottom, false)
Expand Down Expand Up @@ -1219,14 +1249,17 @@ function abstract_invoke(interp::AbstractInterpreter, argtypes::Vector{Any}, sv:
# since some checks within `abstract_call_method_with_const_args` seem a bit costly
const_prop_entry_heuristic(interp, result, sv) || return CallMeta(rt, InvokeCallInfo(match, nothing))
argtypes′ = argtypes[4:end]
const_prop_argument_heuristic(interp, argtypes′) || const_prop_rettype_heuristic(interp, rt) || return CallMeta(rt, InvokeCallInfo(match, nothing))
pushfirst!(argtypes′, ft)
fargs′ = fargs[4:end]
pushfirst!(fargs′, fargs[1])
arginfo = ArgInfo(fargs′, argtypes′)
const_prop_argument_heuristic(interp, arginfo) || const_prop_rettype_heuristic(interp, rt) || return CallMeta(rt, InvokeCallInfo(match, nothing))
# # typeintersect might have narrowed signature, but the accuracy gain doesn't seem worth the cost involved with the lattice comparisons
# for i in 1:length(argtypes′)
# t, a = ti.parameters[i], argtypes′[i]
# argtypes′[i] = t ⊑ a ? t : a
# end
const_result = abstract_call_method_with_const_args(interp, result, singleton_type(ft′), argtypes′, match, sv, false)
const_result = abstract_call_method_with_const_args(interp, result, singleton_type(ft′), arginfo, match, sv, false)
if const_result !== nothing
const_rt, const_result = const_result
if const_rt !== rt && const_rt rt
Expand All @@ -1238,21 +1271,20 @@ end

# call where the function is known exactly
function abstract_call_known(interp::AbstractInterpreter, @nospecialize(f),
fargs::Union{Nothing,Vector{Any}}, argtypes::Vector{Any},
sv::InferenceState,
arginfo::ArgInfo, sv::InferenceState,
max_methods::Int = InferenceParams(interp).MAX_METHODS)

(; fargs, argtypes) = arginfo
la = length(argtypes)

if isa(f, Builtin)
if f === _apply_iterate
return abstract_apply(interp, argtypes, sv, max_methods)
elseif f === invoke
return abstract_invoke(interp, argtypes, sv)
return abstract_invoke(interp, arginfo, sv)
elseif f === modifyfield!
return abstract_modifyfield!(interp, argtypes, sv)
end
return CallMeta(abstract_call_builtin(interp, f, fargs, argtypes, sv, max_methods), false)
return CallMeta(abstract_call_builtin(interp, f, arginfo, sv, max_methods), false)
elseif f === Core.kwfunc
if la == 2
ft = widenconst(argtypes[2])
Expand Down Expand Up @@ -1285,12 +1317,12 @@ function abstract_call_known(interp::AbstractInterpreter, @nospecialize(f),
# handle Conditional propagation through !Bool
aty = argtypes[2]
if isa(aty, Conditional)
call = abstract_call_gf_by_type(interp, f, fargs, Any[Const(f), Bool], Tuple{typeof(f), Bool}, sv) # make sure we've inferred `!(::Bool)`
call = abstract_call_gf_by_type(interp, f, ArgInfo(fargs, Any[Const(f), Bool]), Tuple{typeof(f), Bool}, sv) # make sure we've inferred `!(::Bool)`
return CallMeta(Conditional(aty.var, aty.elsetype, aty.vtype), call.info)
end
elseif la == 3 && istopfunction(f, :!==)
# mark !== as exactly a negated call to ===
rty = abstract_call_known(interp, (===), fargs, argtypes, sv).rt
rty = abstract_call_known(interp, (===), arginfo, sv).rt
if isa(rty, Conditional)
return CallMeta(Conditional(rty.var, rty.elsetype, rty.vtype), false) # swap if-else
elseif isa(rty, Const)
Expand All @@ -1306,7 +1338,7 @@ function abstract_call_known(interp::AbstractInterpreter, @nospecialize(f),
fargs = nothing
end
argtypes = Any[typeof(<:), argtypes[3], argtypes[2]]
return CallMeta(abstract_call_known(interp, <:, fargs, argtypes, sv).rt, false)
return CallMeta(abstract_call_known(interp, <:, ArgInfo(fargs, argtypes), sv).rt, false)
elseif la == 2 &&
(a2 = argtypes[2]; isa(a2, Const)) && (svecval = a2.val; isa(svecval, SimpleVector)) &&
istopfunction(f, :length)
Expand All @@ -1329,7 +1361,7 @@ function abstract_call_known(interp::AbstractInterpreter, @nospecialize(f),
return CallMeta(val === false ? Type : val, MethodResultPure())
end
atype = argtypes_to_type(argtypes)
return abstract_call_gf_by_type(interp, f, fargs, argtypes, atype, sv, max_methods)
return abstract_call_gf_by_type(interp, f, arginfo, atype, sv, max_methods)
end

function abstract_call_opaque_closure(interp::AbstractInterpreter, closure::PartialOpaque, argtypes::Vector{Any}, sv::InferenceState)
Expand All @@ -1342,8 +1374,8 @@ function abstract_call_opaque_closure(interp::AbstractInterpreter, closure::Part
match = MethodMatch(sig, Core.svec(), closure.source, sig <: rewrap_unionall(sigT, tt))
info = OpaqueClosureCallInfo(match)
if !result.edgecycle
const_result = abstract_call_method_with_const_args(interp, result, closure, argtypes,
match, sv, closure.isva)
const_result = abstract_call_method_with_const_args(interp, result, closure,
ArgInfo(nothing, argtypes), match, sv, closure.isva)
if const_result !== nothing
const_rettype, const_result = const_result
if const_rettype rt
Expand All @@ -1366,9 +1398,9 @@ function most_general_argtypes(closure::PartialOpaque)
end

# call where the function is any lattice element
function abstract_call(interp::AbstractInterpreter, fargs::Union{Nothing,Vector{Any}}, argtypes::Vector{Any},
function abstract_call(interp::AbstractInterpreter, arginfo::ArgInfo,
sv::InferenceState, max_methods::Int = InferenceParams(interp).MAX_METHODS)
#print("call ", e.args[1], argtypes, "\n\n")
argtypes = arginfo.argtypes
ft = argtypes[1]
f = singleton_type(ft)
if isa(ft, PartialOpaque)
Expand All @@ -1382,9 +1414,9 @@ function abstract_call(interp::AbstractInterpreter, fargs::Union{Nothing,Vector{
add_remark!(interp, sv, "Could not identify method table for call")
return CallMeta(Any, false)
end
return abstract_call_gf_by_type(interp, nothing, fargs, argtypes, argtypes_to_type(argtypes), sv, max_methods)
return abstract_call_gf_by_type(interp, nothing, arginfo, argtypes_to_type(argtypes), sv, max_methods)
end
return abstract_call_known(interp, f, fargs, argtypes, sv, max_methods)
return abstract_call_known(interp, f, arginfo, sv, max_methods)
end

function sp_type_rewrap(@nospecialize(T), linfo::MethodInstance, isreturn::Bool)
Expand Down Expand Up @@ -1434,7 +1466,7 @@ function abstract_eval_cfunction(interp::AbstractInterpreter, e::Expr, vtypes::V
# this may be the wrong world for the call,
# but some of the result is likely to be valid anyways
# and that may help generate better codegen
abstract_call(interp, nothing, at, sv)
abstract_call(interp, ArgInfo(nothing, at), sv)
nothing
end

Expand Down Expand Up @@ -1512,7 +1544,7 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e),
if argtypes === nothing
t = Bottom
else
callinfo = abstract_call(interp, ea, argtypes, sv)
callinfo = abstract_call(interp, ArgInfo(ea, argtypes), sv)
sv.stmt_info[sv.currpc] = callinfo.info
t = callinfo.rt
end
Expand Down
Loading

0 comments on commit c3a0aef

Please sign in to comment.