Skip to content

Commit 72be4a1

Browse files
committed
inference: (slightly) improve type stability of capturing closures
As an idea to improve type stability for capturing closures, such as in #31909, I tried this idea of propagating the closure object as a `PartialStruct` whose `fields` include captured variables of which types are (partially) known. By performing const-prop on this `closure::PartialStruct`, we can achieve a certain level of type stability. Specifically, I made some modifications to `abstract_eval_new` to allow creating `PartialStruct` even for `:new` objects that are `!isconcretedispatch` (since `PartialStruct` can now represent abstract elements). I also adjusted `const_prop_argument_heuristic` to perform aggressive constant propagation using such `closure::PartialStruct`. As a result, the following code now achieves type stability: ```julia julia> Base.infer_return_type((Bool,Int,)) do b, y x = b ? 1 : missing inner = y -> x + y return inner(y) end Any # master Union{Missing, Int64} # this commit ``` However, this alone was not enough to fully resolve #31909. The call graph of `map` is extremely complex, and simply applying constant propagation everywhere does not achieve the type safety requested in the issue. Nevertheless this commit alone would still improve type stability for some cases, so I will go ahead and submit it as a PR.
1 parent 1edc6f1 commit 72be4a1

File tree

3 files changed

+82
-34
lines changed

3 files changed

+82
-34
lines changed

Compiler/src/abstractinterpretation.jl

+65-33
Original file line numberDiff line numberDiff line change
@@ -1010,9 +1010,12 @@ function maybe_get_const_prop_profitable(interp::AbstractInterpreter,
10101010
# N.B. remarks are emitted within `const_prop_rettype_heuristic`
10111011
return nothing
10121012
end
1013-
if !const_prop_argument_heuristic(interp, arginfo, sv)
1013+
arg_result = const_prop_argument_heuristic(interp, arginfo, sv)
1014+
if arg_result === nothing
10141015
add_remark!(interp, sv, "[constprop] Disabled by argument heuristics")
10151016
return nothing
1017+
else
1018+
force |= arg_result
10161019
end
10171020
all_overridden = is_all_overridden(interp, arginfo, sv)
10181021
if !force && !const_prop_function_heuristic(interp, f, arginfo, all_overridden, sv)
@@ -1079,16 +1082,19 @@ end
10791082
function const_prop_argument_heuristic(interp::AbstractInterpreter, arginfo::ArgInfo, sv::AbsIntState)
10801083
𝕃ᵢ = typeinf_lattice(interp)
10811084
argtypes = arginfo.argtypes
1082-
for i in 1:length(argtypes)
1085+
for i = 1:length(argtypes)
10831086
a = argtypes[i]
10841087
if has_conditional(𝕃ᵢ, sv) && isa(a, Conditional) && arginfo.fargs !== nothing
1085-
is_const_prop_profitable_conditional(a, arginfo.fargs, sv) && return true
1088+
is_const_prop_profitable_conditional(a, arginfo.fargs, sv) && return false
10861089
else
10871090
a = widenslotwrapper(a)
1088-
has_nontrivial_extended_info(𝕃ᵢ, a) && is_const_prop_profitable_arg(𝕃ᵢ, a) && return true
1091+
if has_nontrivial_extended_info(𝕃ᵢ, a) && is_const_prop_profitable_arg(𝕃ᵢ, a)
1092+
# force const-prop' if the function object itself has some profitable information
1093+
return i == 1 || widenconst(a) <: Function
1094+
end
10891095
end
10901096
end
1091-
return false
1097+
return nothing
10921098
end
10931099

10941100
function is_const_prop_profitable_conditional(cnd::Conditional, fargs::Vector{Any}, sv::InferenceState)
@@ -2899,16 +2905,16 @@ end
28992905

29002906
function abstract_eval_new(interp::AbstractInterpreter, e::Expr, vtypes::Union{VarTable,Nothing},
29012907
sv::AbsIntState)
2902-
𝕃ᵢ = typeinf_lattice(interp)
2908+
lat = typeinf_lattice(interp)
29032909
rt, isexact = instanceof_tfunc(abstract_eval_value(interp, e.args[1], vtypes, sv), true)
29042910
ut = unwrap_unionall(rt)
29052911
exct = Union{ErrorException,TypeError}
29062912
if isa(ut, DataType) && !isabstracttype(ut)
29072913
ismutable = ismutabletype(ut)
29082914
fcount = datatype_fieldcount(ut)
29092915
nargs = length(e.args) - 1
2910-
has_any_uninitialized = (fcount === nothing || (fcount > nargs && (let t = rt
2911-
any(i::Int -> !is_undefref_fieldtype(fieldtype(t, i)), (nargs+1):fcount)
2916+
has_any_uninitialized = (fcount === nothing || (fcount > nargs && (let boxed = Core.Box(rt)
2917+
any(i::Int -> !is_undefref_fieldtype(fieldtype(boxed.contents, i)), (nargs+1):fcount)
29122918
end)))
29132919
if has_any_uninitialized
29142920
# allocation with undefined field is inconsistent always
@@ -2920,43 +2926,69 @@ function abstract_eval_new(interp::AbstractInterpreter, e::Expr, vtypes::Union{V
29202926
else
29212927
consistent = ALWAYS_TRUE # immutable allocation is consistent
29222928
end
2923-
if isconcretedispatch(rt)
2924-
nothrow = true
2925-
@assert fcount !== nothing && fcount nargs "malformed :new expression" # syntactically enforced by the front-end
2926-
ats = Vector{Any}(undef, nargs)
2927-
local anyrefine = false
2928-
local allconst = true
2929+
@inline function compute_fields_info(@nospecialize(rt))
2930+
local anyrefine, allconst, nothrow = false, true, true
2931+
, , = partialorder(lat), strictneqpartialorder(lat), meet(lat)
29292932
for i = 1:nargs
29302933
at = widenslotwrapper(abstract_eval_value(interp, e.args[i+1], vtypes, sv))
29312934
ft = fieldtype(rt, i)
2932-
nothrow && (nothrow = (𝕃ᵢ, at, ft))
2933-
at = tmeet(𝕃ᵢ, at, ft)
2934-
at === Bottom && return RTEffects(Bottom, TypeError, EFFECTS_THROWS)
2935+
nothrow && (nothrow = at ft)
2936+
at = at ft
2937+
at === Bottom && return nothing
29352938
if ismutable && !isconst(rt, i)
2936-
ats[i] = ft # can't constrain this field (as it may be modified later)
2939+
# can't constrain this field (as it may be modified later)
2940+
allconst = false
29372941
continue
29382942
end
29392943
allconst &= isa(at, Const)
2940-
if !anyrefine
2941-
anyrefine = has_nontrivial_extended_info(𝕃ᵢ, at) || # extended lattice information
2942-
(𝕃ᵢ, at, ft) # just a type-level information, but more precise than the declared type
2944+
anyrefine || (anyrefine =
2945+
has_nontrivial_extended_info(lat, at) || # extended lattice information
2946+
at ft) # just a type-level information, but more precise than the declared type
2947+
end
2948+
return anyrefine, allconst, nothrow
2949+
end
2950+
@noinline function compute_fields(@nospecialize(rt), unwrap_const::Bool=false)
2951+
local fields = Vector{Any}(undef, nargs)
2952+
= meet(lat)
2953+
for i = 1:nargs
2954+
at = widenslotwrapper(abstract_eval_value(interp, e.args[i+1], vtypes, sv))
2955+
ft = fieldtype(rt, i)
2956+
if ismutable && !isconst(rt, i)
2957+
@assert !unwrap_const
2958+
fields[i] = ft # can't constrain this field (as it may be modified later)
2959+
else
2960+
at = at ft
2961+
if unwrap_const
2962+
fields[i] = (at::Const).val
2963+
else
2964+
fields[i] = at
2965+
end
29432966
end
2944-
ats[i] = at
29452967
end
2968+
return fields
2969+
end
2970+
if isconcretedispatch(rt)
2971+
@assert fcount !== nothing && fcount nargs "malformed :new expression" # syntactically enforced by the front-end
2972+
ret = compute_fields_info(rt)
2973+
ret === nothing && return RTEffects(Bottom, TypeError, EFFECTS_THROWS)
2974+
anyrefine, allconst, nothrow = ret
29462975
if fcount == nargs && consistent === ALWAYS_TRUE && allconst
2947-
argvals = Vector{Any}(undef, nargs)
2948-
for j in 1:nargs
2949-
argvals[j] = (ats[j]::Const).val
2950-
end
2976+
argvals = compute_fields(rt, #=unwrap_const=#true)
29512977
rt = Const(ccall(:jl_new_structv, Any, (Any, Ptr{Cvoid}, UInt32), rt, argvals, nargs))
29522978
elseif anyrefine || nargs > datatype_min_ninitialized(rt)
29532979
# propagate partially initialized struct as `PartialStruct` when:
29542980
# - any refinement information is available (`anyrefine`), or when
29552981
# - `nargs` is greater than `n_initialized` derived from the struct type
29562982
# information alone
2957-
rt = PartialStruct(𝕃ᵢ, rt, ats)
2983+
rt = PartialStruct(lat, rt, compute_fields(rt))
29582984
end
29592985
else
2986+
ret = compute_fields_info(rt)
2987+
ret === nothing && return RTEffects(Bottom, TypeError, EFFECTS_THROWS)
2988+
anyrefine = ret[1]
2989+
if anyrefine || nargs > datatype_min_ninitialized(ut)
2990+
rt = PartialStruct(lat, rt, compute_fields(rt))
2991+
end
29602992
rt = refine_partial_type(rt)
29612993
nothrow = false
29622994
end
@@ -2978,18 +3010,18 @@ function abstract_eval_splatnew(interp::AbstractInterpreter, e::Expr, vtypes::Un
29783010
at = abstract_eval_value(interp, e.args[2], vtypes, sv)
29793011
n = fieldcount(rt)
29803012
if (isa(at, Const) && isa(at.val, Tuple) && n == length(at.val::Tuple) &&
2981-
(let t = rt, at = at
2982-
all(i::Int -> getfield(at.val::Tuple, i) isa fieldtype(t, i), 1:n)
3013+
(let boxed = Core.Box(rt)
3014+
all(i::Int -> getfield(at.val::Tuple, i) isa fieldtype(boxed.contents, i), 1:n)
29833015
end))
29843016
nothrow = isexact
29853017
rt = Const(ccall(:jl_new_structt, Any, (Any, Any), rt, at.val))
29863018
elseif (isa(at, PartialStruct) && (𝕃ᵢ, at, Tuple) && n > 0 &&
2987-
n == length(at.fields::Vector{Any}) && !isvarargtype(at.fields[end]) &&
2988-
(let t = rt, at = at
2989-
all(i::Int -> (𝕃ᵢ, (at.fields::Vector{Any})[i], fieldtype(t, i)), 1:n)
3019+
n == length(at.fields) && !isvarargtype(at.fields[end]) &&
3020+
(let boxed = Core.Box(rt)
3021+
all(i::Int -> (𝕃ᵢ, (at.fields)[i], fieldtype(boxed.contents, i)), 1:n)
29903022
end))
29913023
nothrow = isexact
2992-
rt = PartialStruct(𝕃ᵢ, rt, at.fields::Vector{Any})
3024+
rt = PartialStruct(𝕃ᵢ, rt, at.fields)
29933025
end
29943026
else
29953027
rt = refine_partial_type(rt)

Compiler/test/inference.jl

+16
Original file line numberDiff line numberDiff line change
@@ -6088,3 +6088,19 @@ function issue56387(nt::NamedTuple, field::Symbol=:a)
60886088
types[index]
60896089
end
60906090
@test Base.infer_return_type(issue56387, (typeof((;a=1)),)) == Type{Int}
6091+
6092+
@test Base.infer_return_type((Bool,Int,)) do b, y
6093+
x = b ? 1 : missing
6094+
inner = y -> x + y
6095+
return inner(y)
6096+
end == Union{Int,Missing}
6097+
6098+
function issue31909(ys)
6099+
x = if @noinline rand(Bool)
6100+
1
6101+
else
6102+
missing
6103+
end
6104+
map(y -> x + y, ys)
6105+
end
6106+
@test_broken Base.infer_return_type(issue31909, (Vector{Int},)) == Vector{Union{Int,Missing}}

base/reflection.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -348,7 +348,7 @@ end
348348

349349
const REFLECTION_COMPILER = RefValue{Union{Nothing, Module}}(nothing)
350350

351-
function invoke_in_typeinf_world(args...)
351+
function invoke_in_typeinf_world(@nospecialize args...)
352352
vargs = Any[args...]
353353
return ccall(:jl_call_in_typeinf_world, Any, (Ptr{Any}, Cint), vargs, length(vargs))
354354
end

0 commit comments

Comments
 (0)