Skip to content

Commit

Permalink
Merge 27ea06e into 001c666
Browse files Browse the repository at this point in the history
  • Loading branch information
aviatesk authored Nov 12, 2024
2 parents 001c666 + 27ea06e commit 218eeba
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 15 deletions.
46 changes: 31 additions & 15 deletions Compiler/src/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1010,9 +1010,12 @@ function maybe_get_const_prop_profitable(interp::AbstractInterpreter,
# N.B. remarks are emitted within `const_prop_rettype_heuristic`
return nothing
end
if !const_prop_argument_heuristic(interp, arginfo, sv)
arg_result = const_prop_argument_heuristic(interp, arginfo, sv)
if arg_result === nothing
add_remark!(interp, sv, "[constprop] Disabled by argument heuristics")
return nothing
else
force |= arg_result
end
all_overridden = is_all_overridden(interp, arginfo, sv)
if !force && !const_prop_function_heuristic(interp, f, arginfo, all_overridden, sv)
Expand Down Expand Up @@ -1079,16 +1082,19 @@ end
function const_prop_argument_heuristic(interp::AbstractInterpreter, arginfo::ArgInfo, sv::AbsIntState)
𝕃ᵢ = typeinf_lattice(interp)
argtypes = arginfo.argtypes
for i in 1:length(argtypes)
for i = 1:length(argtypes)
a = argtypes[i]
if has_conditional(𝕃ᵢ, sv) && isa(a, Conditional) && arginfo.fargs !== nothing
is_const_prop_profitable_conditional(a, arginfo.fargs, sv) && return true
is_const_prop_profitable_conditional(a, arginfo.fargs, sv) && return false
else
a = widenslotwrapper(a)
has_nontrivial_extended_info(𝕃ᵢ, a) && is_const_prop_profitable_arg(𝕃ᵢ, a) && return true
if has_nontrivial_extended_info(𝕃ᵢ, a) && is_const_prop_profitable_arg(𝕃ᵢ, a)
# force const-prop' if the function object itself has some profitable information
return i == 1 || widenconst(a) <: Function
end
end
end
return false
return nothing
end

function is_const_prop_profitable_conditional(cnd::Conditional, fargs::Vector{Any}, sv::InferenceState)
Expand Down Expand Up @@ -2920,18 +2926,15 @@ function abstract_eval_new(interp::AbstractInterpreter, e::Expr, vtypes::Union{V
else
consistent = ALWAYS_TRUE # immutable allocation is consistent
end
if isconcretedispatch(rt)
nothrow = true
@assert fcount !== nothing && fcount nargs "malformed :new expression" # syntactically enforced by the front-end
ats = Vector{Any}(undef, nargs)
local anyrefine = false
local allconst = true
function compute_fields(@nospecialize(rt))
local ats = Vector{Any}(undef, nargs)
local anyrefine, allconst, nothrow = false, true, true
for i = 1:nargs
at = widenslotwrapper(abstract_eval_value(interp, e.args[i+1], vtypes, sv))
ft = fieldtype(rt, i)
nothrow && (nothrow = (𝕃ᵢ, at, ft))
at = tmeet(𝕃ᵢ, at, ft)
at === Bottom && return RTEffects(Bottom, TypeError, EFFECTS_THROWS)
at === Bottom && return nothing
if ismutable && !isconst(rt, i)
ats[i] = ft # can't constrain this field (as it may be modified later)
continue
Expand All @@ -2943,6 +2946,13 @@ function abstract_eval_new(interp::AbstractInterpreter, e::Expr, vtypes::Union{V
end
ats[i] = at
end
return ats, anyrefine, allconst, nothrow
end
if isconcretedispatch(rt)
@assert fcount !== nothing && fcount nargs "malformed :new expression" # syntactically enforced by the front-end
ret = compute_fields(rt)
ret === nothing && return RTEffects(Bottom, TypeError, EFFECTS_THROWS)
ats, anyrefine, allconst, nothrow = ret
if fcount == nargs && consistent === ALWAYS_TRUE && allconst
argvals = Vector{Any}(undef, nargs)
for j in 1:nargs
Expand All @@ -2957,6 +2967,12 @@ function abstract_eval_new(interp::AbstractInterpreter, e::Expr, vtypes::Union{V
rt = PartialStruct(𝕃ᵢ, rt, ats)
end
else
ret = compute_fields(rt)
ret === nothing && return RTEffects(Bottom, TypeError, EFFECTS_THROWS)
ats, anyrefine, _... = ret
if anyrefine || nargs > datatype_min_ninitialized(ut)
rt = PartialStruct(𝕃ᵢ, rt, ats)
end
rt = refine_partial_type(rt)
nothrow = false
end
Expand Down Expand Up @@ -2984,12 +3000,12 @@ function abstract_eval_splatnew(interp::AbstractInterpreter, e::Expr, vtypes::Un
nothrow = isexact
rt = Const(ccall(:jl_new_structt, Any, (Any, Any), rt, at.val))
elseif (isa(at, PartialStruct) && (𝕃ᵢ, at, Tuple) && n > 0 &&
n == length(at.fields::Vector{Any}) && !isvarargtype(at.fields[end]) &&
n == length(at.fields) && !isvarargtype(at.fields[end]) &&
(let t = rt, at = at
all(i::Int -> (𝕃ᵢ, (at.fields::Vector{Any})[i], fieldtype(t, i)), 1:n)
all(i::Int -> (𝕃ᵢ, (at.fields)[i], fieldtype(t, i)), 1:n)
end))
nothrow = isexact
rt = PartialStruct(𝕃ᵢ, rt, at.fields::Vector{Any})
rt = PartialStruct(𝕃ᵢ, rt, at.fields)
end
else
rt = refine_partial_type(rt)
Expand Down
16 changes: 16 additions & 0 deletions test/compiler/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6088,3 +6088,19 @@ function issue56387(nt::NamedTuple, field::Symbol=:a)
types[index]
end
@test Base.infer_return_type(issue56387, (typeof((;a=1)),)) == Type{Int}

@test Base.infer_return_type((Bool,Int,)) do b, y
x = b ? 1 : missing
inner = y -> x + y
return inner(y)
end == Union{Int,Missing}

function issue31909(ys)
x = if @noinline rand(Bool)
1
else
missing
end
map(y -> x + y, ys)
end
@test_broken Base.infer_return_type(issue31909, (Vector{Int},)) == Vector{Union{Int,Missing}}

0 comments on commit 218eeba

Please sign in to comment.