Skip to content

Commit

Permalink
Inline statically known method errors.
Browse files Browse the repository at this point in the history
This replaces the Expr(:call, ...) with a call to Core.throw_methoderror

Co-authored-by: Cody Tapscott <[email protected]>
  • Loading branch information
gbaraldi and topolarity committed Sep 6, 2024
1 parent 3a7a52b commit 0d854f5
Show file tree
Hide file tree
Showing 6 changed files with 89 additions and 62 deletions.
35 changes: 18 additions & 17 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -209,8 +209,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
rettype = exctype = Any
all_effects = Effects()
else
if (matches isa MethodMatches ? (!matches.fullmatch || any_ambig(matches)) :
(!all(matches.fullmatches) || any_ambig(matches)))
if !fully_covering(matches) || any_ambig(matches)
# Account for the fact that we may encounter a MethodError with a non-covered or ambiguous signature.
all_effects = Effects(all_effects; nothrow=false)
exctype = exctype ₚ MethodError
Expand Down Expand Up @@ -275,21 +274,23 @@ struct MethodMatches
applicable::Vector{Any}
info::MethodMatchInfo
valid_worlds::WorldRange
mt::MethodTable
fullmatch::Bool
end
any_ambig(info::MethodMatchInfo) = info.results.ambig
any_ambig(result::MethodLookupResult) = result.ambig
any_ambig(info::MethodMatchInfo) = any_ambig(info.results)
any_ambig(m::MethodMatches) = any_ambig(m.info)
fully_covering(info::MethodMatchInfo) = info.fullmatch
fully_covering(m::MethodMatches) = fully_covering(m.info)

struct UnionSplitMethodMatches
applicable::Vector{Any}
applicable_argtypes::Vector{Vector{Any}}
info::UnionSplitInfo
valid_worlds::WorldRange
mts::Vector{MethodTable}
fullmatches::Vector{Bool}
end
any_ambig(m::UnionSplitMethodMatches) = any(any_ambig, m.info.matches)
any_ambig(info::UnionSplitInfo) = any(any_ambig, info.matches)
any_ambig(m::UnionSplitMethodMatches) = any_ambig(m.info)
fully_covering(info::UnionSplitInfo) = all(info.fullmatches)
fully_covering(m::UnionSplitMethodMatches) = fully_covering(m.info)

function find_method_matches(interp::AbstractInterpreter, argtypes::Vector{Any}, @nospecialize(atype);
max_union_splitting::Int = InferenceParams(interp).max_union_splitting,
Expand All @@ -307,7 +308,7 @@ is_union_split_eligible(𝕃::AbstractLattice, argtypes::Vector{Any}, max_union_
function find_union_split_method_matches(interp::AbstractInterpreter, argtypes::Vector{Any},
@nospecialize(atype), max_methods::Int)
split_argtypes = switchtupleunion(typeinf_lattice(interp), argtypes)
infos = MethodMatchInfo[]
infos = MethodLookupResult[]
applicable = Any[]
applicable_argtypes = Vector{Any}[] # arrays like `argtypes`, including constants, for each match
valid_worlds = WorldRange()
Expand All @@ -323,7 +324,7 @@ function find_union_split_method_matches(interp::AbstractInterpreter, argtypes::
if matches === nothing
return FailedMethodMatch("For one of the union split cases, too many methods matched")
end
push!(infos, MethodMatchInfo(matches))
push!(infos, matches)
for m in matches
push!(applicable, m)
push!(applicable_argtypes, arg_n)
Expand All @@ -343,9 +344,9 @@ function find_union_split_method_matches(interp::AbstractInterpreter, argtypes::
push!(fullmatches, thisfullmatch)
end
end
info = UnionSplitInfo(infos)
info = UnionSplitInfo(infos, mts, fullmatches)
return UnionSplitMethodMatches(
applicable, applicable_argtypes, info, valid_worlds, mts, fullmatches)
applicable, applicable_argtypes, info, valid_worlds)
end

function find_simple_method_matches(interp::AbstractInterpreter, @nospecialize(atype), max_methods::Int)
Expand All @@ -360,10 +361,9 @@ function find_simple_method_matches(interp::AbstractInterpreter, @nospecialize(a
# (assume this will always be true, so we don't compute / update valid age in this case)
return FailedMethodMatch("Too many methods matched")
end
info = MethodMatchInfo(matches)
fullmatch = any(match::MethodMatch->match.fully_covers, matches)
return MethodMatches(
matches.matches, info, matches.valid_worlds, mt, fullmatch)
info = MethodMatchInfo(matches, mt, fullmatch)
return MethodMatches(matches.matches, info, matches.valid_worlds)
end

"""
Expand Down Expand Up @@ -584,9 +584,10 @@ function add_call_backedges!(interp::AbstractInterpreter, @nospecialize(rettype)
# also need an edge to the method table in case something gets
# added that did not intersect with any existing method
if isa(matches, MethodMatches)
matches.fullmatch || add_mt_backedge!(sv, matches.mt, atype)
fully_covering(matches) || add_mt_backedge!(sv, matches.info.mt, atype)
else
for (thisfullmatch, mt) in zip(matches.fullmatches, matches.mts)
matches::UnionSplitMethodMatches
for (thisfullmatch, mt) in zip(matches.info.fullmatches, matches.info.mts)
thisfullmatch || add_mt_backedge!(sv, mt, atype)
end
end
Expand Down
56 changes: 34 additions & 22 deletions base/compiler/ssair/inlining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,13 @@ struct InliningCase
end

struct UnionSplit
fully_covered::Bool
handled_all_cases::Bool # All possible dispatches are included in the cases
fully_covered::Bool # All handled cases are fully covering
atype::DataType
cases::Vector{InliningCase}
bbs::Vector{Int}
UnionSplit(fully_covered::Bool, atype::DataType, cases::Vector{InliningCase}) =
new(fully_covered, atype, cases, Int[])
UnionSplit(handled_all_cases::Bool, fully_covered::Bool, atype::DataType, cases::Vector{InliningCase}) =
new(handled_all_cases, fully_covered, atype, cases, Int[])
end

struct InliningEdgeTracker
Expand Down Expand Up @@ -215,7 +216,7 @@ end

function cfg_inline_unionsplit!(ir::IRCode, idx::Int, union_split::UnionSplit,
state::CFGInliningState, params::OptimizationParams)
(; fully_covered, #=atype,=# cases, bbs) = union_split
(; handled_all_cases, fully_covered, #=atype,=# cases, bbs) = union_split
inline_into_block!(state, block_for_inst(ir, idx))
from_bbs = Int[]
delete!(state.split_targets, length(state.new_cfg_blocks))
Expand All @@ -235,7 +236,7 @@ function cfg_inline_unionsplit!(ir::IRCode, idx::Int, union_split::UnionSplit,
end
end
push!(from_bbs, length(state.new_cfg_blocks))
if !(i == length(cases) && fully_covered)
if !(i == length(cases) && (handled_all_cases && fully_covered))
# This block will have the next condition or the final else case
push!(state.new_cfg_blocks, BasicBlock(StmtRange(idx, idx)))
push!(state.new_cfg_blocks[cond_bb].succs, length(state.new_cfg_blocks))
Expand All @@ -244,7 +245,7 @@ function cfg_inline_unionsplit!(ir::IRCode, idx::Int, union_split::UnionSplit,
end
end
# The edge from the fallback block.
fully_covered || push!(from_bbs, length(state.new_cfg_blocks))
!handled_all_cases && push!(from_bbs, length(state.new_cfg_blocks))
# This block will be the block everyone returns to
push!(state.new_cfg_blocks, BasicBlock(StmtRange(idx, idx), from_bbs, orig_succs))
join_bb = length(state.new_cfg_blocks)
Expand Down Expand Up @@ -523,7 +524,7 @@ assuming their order stays the same post-discovery in `ml_matches`.
function ir_inline_unionsplit!(compact::IncrementalCompact, idx::Int, argexprs::Vector{Any},
union_split::UnionSplit, boundscheck::Symbol,
todo_bbs::Vector{Tuple{Int,Int}}, interp::AbstractInterpreter)
(; fully_covered, atype, cases, bbs) = union_split
(; handled_all_cases, fully_covered, atype, cases, bbs) = union_split
stmt, typ, line = compact.result[idx][:stmt], compact.result[idx][:type], compact.result[idx][:line]
join_bb = bbs[end]
pn = PhiNode()
Expand All @@ -538,7 +539,7 @@ function ir_inline_unionsplit!(compact::IncrementalCompact, idx::Int, argexprs::
cond = true
nparams = fieldcount(atype)
@assert nparams == fieldcount(mtype)
if !(i == ncases && fully_covered)
if !(i == ncases && fully_covered && handled_all_cases)
for i = 1:nparams
aft, mft = fieldtype(atype, i), fieldtype(mtype, i)
# If this is always true, we don't need to check for it
Expand Down Expand Up @@ -597,14 +598,18 @@ function ir_inline_unionsplit!(compact::IncrementalCompact, idx::Int, argexprs::
end
bb += 1
# We're now in the fall through block, decide what to do
if !fully_covered
if !handled_all_cases
ssa = insert_node_here!(compact, NewInstruction(stmt, typ, line))
push!(pn.edges, bb)
push!(pn.values, ssa)
insert_node_here!(compact, NewInstruction(GotoNode(join_bb), Any, line))
finish_current_bb!(compact, 0)
elseif !fully_covered
insert_node_here!(compact, NewInstruction(Expr(:call, GlobalRef(Core, :throw_methoderror), argexprs...), Union{}, line))
insert_node_here!(compact, NewInstruction(ReturnNode(), Union{}, line))
finish_current_bb!(compact, 0)
ncases == 0 && return insert_node_here!(compact, NewInstruction(nothing, Any, line))
end

# We're now in the join block.
return insert_node_here!(compact, NewInstruction(pn, typ, line))
end
Expand Down Expand Up @@ -1348,10 +1353,6 @@ function compute_inlining_cases(@nospecialize(info::CallInfo), flag::UInt32, sig
# Too many applicable methods
# Or there is a (partial?) ambiguity
return nothing
elseif length(meth) == 0
# No applicable methods; try next union split
handled_all_cases = false
continue
end
local split_fully_covered = false
for (j, match) in enumerate(meth)
Expand Down Expand Up @@ -1392,22 +1393,33 @@ function compute_inlining_cases(@nospecialize(info::CallInfo), flag::UInt32, sig
handled_all_cases &= handle_any_const_result!(cases,
result, match, argtypes, info, flag, state; allow_typevars=true)
end
if !fully_covered
atype = argtypes_to_type(sig.argtypes)
# We will emit an inline MethodError so we need a backedge to the MethodTable
unwrapped_info = info isa ConstCallInfo ? info.call : info
if unwrapped_info isa UnionSplitInfo
for (fullmatch, mt) in zip(unwrapped_info.fullmatches, unwrapped_info.mts)
!fullmatch && push!(state.edges, mt, atype)
end
elseif unwrapped_info isa MethodMatchInfo
push!(state.edges, unwrapped_info.mt, atype)
else @assert false end
end
elseif !isempty(cases)
# if we've not seen all candidates, union split is valid only for dispatch tuples
filter!(case::InliningCase->isdispatchtuple(case.sig), cases)
end

return cases, (handled_all_cases & fully_covered), joint_effects
return cases, handled_all_cases, fully_covered, joint_effects
end

function handle_call!(todo::Vector{Pair{Int,Any}},
ir::IRCode, idx::Int, stmt::Expr, @nospecialize(info::CallInfo), flag::UInt32, sig::Signature,
state::InliningState)
cases = compute_inlining_cases(info, flag, sig, state)
cases === nothing && return nothing
cases, all_covered, joint_effects = cases
cases, handled_all_cases, fully_covered, joint_effects = cases
atype = argtypes_to_type(sig.argtypes)
handle_cases!(todo, ir, idx, stmt, atype, cases, all_covered, joint_effects)
handle_cases!(todo, ir, idx, stmt, atype, cases, handled_all_cases, fully_covered, joint_effects)
end

function handle_match!(cases::Vector{InliningCase},
Expand Down Expand Up @@ -1496,19 +1508,19 @@ function concrete_result_item(result::ConcreteResult, @nospecialize(info::CallIn
end

function handle_cases!(todo::Vector{Pair{Int,Any}}, ir::IRCode, idx::Int, stmt::Expr,
@nospecialize(atype), cases::Vector{InliningCase}, all_covered::Bool,
@nospecialize(atype), cases::Vector{InliningCase}, handled_all_cases::Bool, fully_covered::Bool,
joint_effects::Effects)
# If we only have one case and that case is fully covered, we may either
# be able to do the inlining now (for constant cases), or push it directly
# onto the todo list
if all_covered && length(cases) == 1
if fully_covered && handled_all_cases && length(cases) == 1
handle_single_case!(todo, ir, idx, stmt, cases[1].item)
elseif length(cases) > 0
elseif length(cases) > 0 || handled_all_cases
isa(atype, DataType) || return nothing
for case in cases
isa(case.sig, DataType) || return nothing
end
push!(todo, idx=>UnionSplit(all_covered, atype, cases))
push!(todo, idx=>UnionSplit(handled_all_cases, fully_covered, atype, cases))
else
add_flag!(ir[SSAValue(idx)], flags_for_effects(joint_effects))
end
Expand Down
10 changes: 7 additions & 3 deletions base/compiler/stmtinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ not a call to a generic function.
"""
struct MethodMatchInfo <: CallInfo
results::MethodLookupResult
mt::MethodTable
fullmatch::Bool
end
nsplit_impl(info::MethodMatchInfo) = 1
getsplit_impl(info::MethodMatchInfo, idx::Int) = (@assert idx == 1; info.results)
Expand All @@ -48,19 +50,21 @@ each partition (`info.matches::Vector{MethodMatchInfo}`).
This info is illegal on any statement that is not a call to a generic function.
"""
struct UnionSplitInfo <: CallInfo
matches::Vector{MethodMatchInfo}
matches::Vector{MethodLookupResult}
mts::Vector{MethodTable}
fullmatches::Vector{Bool}
end

nmatches(info::MethodMatchInfo) = length(info.results)
function nmatches(info::UnionSplitInfo)
n = 0
for mminfo in info.matches
n += nmatches(mminfo)
n += length(mminfo)
end
return n
end
nsplit_impl(info::UnionSplitInfo) = length(info.matches)
getsplit_impl(info::UnionSplitInfo, idx::Int) = getsplit_impl(info.matches[idx], 1)
getsplit_impl(info::UnionSplitInfo, idx::Int) = info.matches[idx]
getresult_impl(::UnionSplitInfo, ::Int) = nothing

abstract type ConstResult end
Expand Down
7 changes: 3 additions & 4 deletions base/compiler/tfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2983,9 +2983,9 @@ function abstract_applicable(interp::AbstractInterpreter, argtypes::Vector{Any},
# also need an edge to the method table in case something gets
# added that did not intersect with any existing method
if isa(matches, MethodMatches)
matches.fullmatch || add_mt_backedge!(sv, matches.mt, atype)
fully_covering(matches) || add_mt_backedge!(sv, matches.info.mt, atype)
else
for (thisfullmatch, mt) in zip(matches.fullmatches, matches.mts)
for (thisfullmatch, mt) in zip(matches.info.fullmatches, matches.info.mts)
thisfullmatch || add_mt_backedge!(sv, mt, atype)
end
end
Expand All @@ -3001,8 +3001,7 @@ function abstract_applicable(interp::AbstractInterpreter, argtypes::Vector{Any},
add_backedge!(sv, edge)
end

if isa(matches, MethodMatches) ? (!matches.fullmatch || any_ambig(matches)) :
(!all(matches.fullmatches) || any_ambig(matches))
if !fully_covering(matches) || any_ambig(matches)
# Account for the fact that we may encounter a MethodError with a non-covered or ambiguous signature.
rt = Bool
end
Expand Down
4 changes: 2 additions & 2 deletions test/compiler/EscapeAnalysis/EscapeAnalysis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2240,13 +2240,13 @@ end
# accounts for ThrownEscape via potential MethodError

# no method error
@noinline identity_if_string(x::SafeRef) = (println("preventing inlining"); nothing)
@noinline identity_if_string(x::SafeRef{<:AbstractString}) = (println("preventing inlining"); nothing)
let result = code_escapes((SafeRef{String},)) do x
identity_if_string(x)
end
@test has_no_escape(ignore_argescape(result.state[Argument(2)]))
end
let result = code_escapes((Union{SafeRef{String},Nothing},)) do x
let result = code_escapes((SafeRef,)) do x
identity_if_string(x)
end
i = only(findall(iscall((result.ir, identity_if_string)), result.ir.stmts.stmt))
Expand Down
Loading

0 comments on commit 0d854f5

Please sign in to comment.