diff --git a/base/compiler/abstractinterpretation.jl b/base/compiler/abstractinterpretation.jl index 0af689460bb20..3825f33e09f72 100644 --- a/base/compiler/abstractinterpretation.jl +++ b/base/compiler/abstractinterpretation.jl @@ -331,15 +331,15 @@ function find_union_split_method_matches(interp::AbstractInterpreter, argtypes:: end valid_worlds = intersect(valid_worlds, matches.valid_worlds) thisfullmatch = any(match::MethodMatch->match.fully_covers, matches) - found = false + mt_found = false for (i, mt′) in enumerate(mts) if mt′ === mt fullmatches[i] &= thisfullmatch - found = true + mt_found = true break end end - if !found + if !mt_found push!(mts, mt) push!(fullmatches, thisfullmatch) end diff --git a/base/compiler/ssair/inlining.jl b/base/compiler/ssair/inlining.jl index a53b19d43ffb0..727e015b67062 100644 --- a/base/compiler/ssair/inlining.jl +++ b/base/compiler/ssair/inlining.jl @@ -1399,14 +1399,7 @@ function compute_inlining_cases(@nospecialize(info::CallInfo), flag::UInt32, sig if !fully_covered atype = argtypes_to_type(sig.argtypes) # We will emit an inline MethodError so we need a backedge to the MethodTable - unwrapped_info = info isa ConstCallInfo ? info.call : info - if unwrapped_info isa UnionSplitInfo - for (fullmatch, mt) in zip(unwrapped_info.fullmatches, unwrapped_info.mts) - !fullmatch && push!(state.edges, mt, atype) - end - elseif unwrapped_info isa MethodMatchInfo - push!(state.edges, unwrapped_info.mt, atype) - else @assert false end + add_uncovered_edges!(state.edges, info, atype) end elseif !isempty(cases) # if we've not seen all candidates, union split is valid only for dispatch tuples diff --git a/base/compiler/stmtinfo.jl b/base/compiler/stmtinfo.jl index d637f6f67d5de..33fca90b6261e 100644 --- a/base/compiler/stmtinfo.jl +++ b/base/compiler/stmtinfo.jl @@ -39,6 +39,7 @@ end nsplit_impl(info::MethodMatchInfo) = 1 getsplit_impl(info::MethodMatchInfo, idx::Int) = (@assert idx == 1; info.results) getresult_impl(::MethodMatchInfo, ::Int) = nothing +add_uncovered_edges_impl(edges::Vector{Any}, info::MethodMatchInfo, @nospecialize(atype)) = (!info.fullmatch && push!(edges, info.mt, atype); ) """ info::UnionSplitInfo <: CallInfo @@ -66,6 +67,11 @@ end nsplit_impl(info::UnionSplitInfo) = length(info.matches) getsplit_impl(info::UnionSplitInfo, idx::Int) = info.matches[idx] getresult_impl(::UnionSplitInfo, ::Int) = nothing +function add_uncovered_edges_impl(edges::Vector{Any}, info::UnionSplitInfo, @nospecialize(atype)) + for (mt, fullmatch) in zip(info.mts, info.fullmatches) + !fullmatch && push!(edges, mt, atype) + end +end abstract type ConstResult end @@ -109,6 +115,7 @@ end nsplit_impl(info::ConstCallInfo) = nsplit(info.call) getsplit_impl(info::ConstCallInfo, idx::Int) = getsplit(info.call, idx) getresult_impl(info::ConstCallInfo, idx::Int) = info.results[idx] +add_uncovered_edges_impl(edges::Vector{Any}, info::ConstCallInfo, @nospecialize(atype)) = add_uncovered_edges!(edges, info.call, atype) """ info::MethodResultPure <: CallInfo diff --git a/base/compiler/types.jl b/base/compiler/types.jl index f315b7968fd9b..f3b02337c509f 100644 --- a/base/compiler/types.jl +++ b/base/compiler/types.jl @@ -451,9 +451,12 @@ abstract type CallInfo end nsplit(info::CallInfo) = nsplit_impl(info)::Union{Nothing,Int} getsplit(info::CallInfo, idx::Int) = getsplit_impl(info, idx)::MethodLookupResult getresult(info::CallInfo, idx::Int) = getresult_impl(info, idx) +add_uncovered_edges!(edges::Vector{Any}, info::CallInfo, @nospecialize(atype)) = add_uncovered_edges_impl(edges, info, atype) + nsplit_impl(::CallInfo) = nothing getsplit_impl(::CallInfo, ::Int) = error("unexpected call into `getsplit`") getresult_impl(::CallInfo, ::Int) = nothing +add_uncovered_edges_impl(edges::Vector{Any}, info::CallInfo, @nospecialize(atype)) = nothing @specialize diff --git a/test/compiler/AbstractInterpreter.jl b/test/compiler/AbstractInterpreter.jl index d95354cefa80c..e92b67f980942 100644 --- a/test/compiler/AbstractInterpreter.jl +++ b/test/compiler/AbstractInterpreter.jl @@ -409,6 +409,7 @@ end CC.nsplit_impl(info::NoinlineCallInfo) = CC.nsplit(info.info) CC.getsplit_impl(info::NoinlineCallInfo, idx::Int) = CC.getsplit(info.info, idx) CC.getresult_impl(info::NoinlineCallInfo, idx::Int) = CC.getresult(info.info, idx) +CC.add_uncovered_edges_impl(edges::Vector{Any}, info::NoinlineCallInfo, @nospecialize(atype)) = CC.add_uncovered_edges!(edges, info.info, atype) function CC.abstract_call(interp::NoinlineInterpreter, arginfo::CC.ArgInfo, si::CC.StmtInfo, sv::CC.InferenceState, max_methods::Int) @@ -431,6 +432,8 @@ end @inline function inlined_usually(x, y, z) return x * y + z end +foo_split(x::Float64) = 1 +foo_split(x::Int) = 2 # check if the inlining algorithm works as expected let src = code_typed1((Float64,Float64,Float64)) do x, y, z @@ -444,6 +447,7 @@ let NoinlineModule = Module() main_func(x, y, z) = inlined_usually(x, y, z) @eval NoinlineModule noinline_func(x, y, z) = $inlined_usually(x, y, z) @eval OtherModule other_func(x, y, z) = $inlined_usually(x, y, z) + @eval NoinlineModule bar_split_error() = $foo_split(Core.compilerbarrier(:type, nothing)) interp = NoinlineInterpreter(Set((NoinlineModule,))) @@ -473,6 +477,11 @@ let NoinlineModule = Module() @test count(isinvoke(:inlined_usually), src.code) == 0 @test count(iscall((src, inlined_usually)), src.code) == 0 end + + let src = code_typed1(NoinlineModule.bar_split_error) + @test count(iscall((src, foo_split)), src.code) == 0 + @test count(iscall((src, Core.throw_methoderror)), src.code) > 0 + end end # Make sure that Core.Compiler has enough NamedTuple infrastructure