Skip to content

Commit

Permalink
simplify the fields of UnionSplitInfo (#55815)
Browse files Browse the repository at this point in the history
  • Loading branch information
aviatesk authored Sep 20, 2024
1 parent b30f80d commit 7f7a472
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 63 deletions.
68 changes: 36 additions & 32 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -280,17 +280,41 @@ any_ambig(info::MethodMatchInfo) = any_ambig(info.results)
any_ambig(m::MethodMatches) = any_ambig(m.info)
fully_covering(info::MethodMatchInfo) = info.fullmatch
fully_covering(m::MethodMatches) = fully_covering(m.info)
function add_uncovered_edges!(sv::AbsIntState, info::MethodMatchInfo, @nospecialize(atype))
fully_covering(info) || add_mt_backedge!(sv, info.mt, atype)
nothing
end
add_uncovered_edges!(sv::AbsIntState, matches::MethodMatches, @nospecialize(atype)) =
add_uncovered_edges!(sv, matches.info, atype)

struct UnionSplitMethodMatches
applicable::Vector{Any}
applicable_argtypes::Vector{Vector{Any}}
info::UnionSplitInfo
valid_worlds::WorldRange
end
any_ambig(info::UnionSplitInfo) = any(any_ambig, info.matches)
any_ambig(info::UnionSplitInfo) = any(any_ambig, info.split)
any_ambig(m::UnionSplitMethodMatches) = any_ambig(m.info)
fully_covering(info::UnionSplitInfo) = all(info.fullmatches)
fully_covering(info::UnionSplitInfo) = all(fully_covering, info.split)
fully_covering(m::UnionSplitMethodMatches) = fully_covering(m.info)
function add_uncovered_edges!(sv::AbsIntState, info::UnionSplitInfo, @nospecialize(atype))
all(fully_covering, info.split) && return nothing
# add mt backedges with removing duplications
for mt in uncovered_method_tables(info)
add_mt_backedge!(sv, mt, atype)
end
end
add_uncovered_edges!(sv::AbsIntState, matches::UnionSplitMethodMatches, @nospecialize(atype)) =
add_uncovered_edges!(sv, matches.info, atype)
function uncovered_method_tables(info::UnionSplitInfo)
mts = MethodTable[]
for mminfo in info.split
fully_covering(mminfo) && continue
any(mt′::MethodTable->mt′===mminfo.mt, mts) && continue
push!(mts, mminfo.mt)
end
return mts
end

function find_method_matches(interp::AbstractInterpreter, argtypes::Vector{Any}, @nospecialize(atype);
max_union_splitting::Int = InferenceParams(interp).max_union_splitting,
Expand All @@ -308,43 +332,30 @@ is_union_split_eligible(𝕃::AbstractLattice, argtypes::Vector{Any}, max_union_
function find_union_split_method_matches(interp::AbstractInterpreter, argtypes::Vector{Any},
@nospecialize(atype), max_methods::Int)
split_argtypes = switchtupleunion(typeinf_lattice(interp), argtypes)
infos = MethodLookupResult[]
infos = MethodMatchInfo[]
applicable = Any[]
applicable_argtypes = Vector{Any}[] # arrays like `argtypes`, including constants, for each match
valid_worlds = WorldRange()
mts = MethodTable[]
fullmatches = Bool[]
for i in 1:length(split_argtypes)
arg_n = split_argtypes[i]::Vector{Any}
sig_n = argtypes_to_type(arg_n)
mt = ccall(:jl_method_table_for, Any, (Any,), sig_n)
mt === nothing && return FailedMethodMatch("Could not identify method table for call")
mt = mt::MethodTable
matches = findall(sig_n, method_table(interp); limit = max_methods)
if matches === nothing
thismatches = findall(sig_n, method_table(interp); limit = max_methods)
if thismatches === nothing
return FailedMethodMatch("For one of the union split cases, too many methods matched")
end
push!(infos, matches)
for m in matches
for m in thismatches
push!(applicable, m)
push!(applicable_argtypes, arg_n)
end
valid_worlds = intersect(valid_worlds, matches.valid_worlds)
thisfullmatch = any(match::MethodMatch->match.fully_covers, matches)
mt_found = false
for (i, mt′) in enumerate(mts)
if mt′ === mt
fullmatches[i] &= thisfullmatch
mt_found = true
break
end
end
if !mt_found
push!(mts, mt)
push!(fullmatches, thisfullmatch)
end
valid_worlds = intersect(valid_worlds, thismatches.valid_worlds)
thisfullmatch = any(match::MethodMatch->match.fully_covers, thismatches)
thisinfo = MethodMatchInfo(thismatches, mt, thisfullmatch)
push!(infos, thisinfo)
end
info = UnionSplitInfo(infos, mts, fullmatches)
info = UnionSplitInfo(infos)
return UnionSplitMethodMatches(
applicable, applicable_argtypes, info, valid_worlds)
end
Expand Down Expand Up @@ -583,14 +594,7 @@ function add_call_backedges!(interp::AbstractInterpreter, @nospecialize(rettype)
end
# also need an edge to the method table in case something gets
# added that did not intersect with any existing method
if isa(matches, MethodMatches)
fully_covering(matches) || add_mt_backedge!(sv, matches.info.mt, atype)
else
matches::UnionSplitMethodMatches
for (thisfullmatch, mt) in zip(matches.info.fullmatches, matches.info.mts)
thisfullmatch || add_mt_backedge!(sv, mt, atype)
end
end
add_uncovered_edges!(sv, matches, atype)
return nothing
end

Expand Down
23 changes: 13 additions & 10 deletions base/compiler/stmtinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,10 @@ end
nsplit_impl(info::MethodMatchInfo) = 1
getsplit_impl(info::MethodMatchInfo, idx::Int) = (@assert idx == 1; info.results)
getresult_impl(::MethodMatchInfo, ::Int) = nothing
add_uncovered_edges_impl(edges::Vector{Any}, info::MethodMatchInfo, @nospecialize(atype)) = (!info.fullmatch && push!(edges, info.mt, atype); )
function add_uncovered_edges_impl(edges::Vector{Any}, info::MethodMatchInfo, @nospecialize(atype))
fully_covering(info) || push!(edges, info.mt, atype)
nothing
end

"""
info::UnionSplitInfo <: CallInfo
Expand All @@ -51,25 +54,25 @@ each partition (`info.matches::Vector{MethodMatchInfo}`).
This info is illegal on any statement that is not a call to a generic function.
"""
struct UnionSplitInfo <: CallInfo
matches::Vector{MethodLookupResult}
mts::Vector{MethodTable}
fullmatches::Vector{Bool}
split::Vector{MethodMatchInfo}
end

nmatches(info::MethodMatchInfo) = length(info.results)
function nmatches(info::UnionSplitInfo)
n = 0
for mminfo in info.matches
n += length(mminfo)
for mminfo in info.split
n += nmatches(mminfo)
end
return n
end
nsplit_impl(info::UnionSplitInfo) = length(info.matches)
getsplit_impl(info::UnionSplitInfo, idx::Int) = info.matches[idx]
nsplit_impl(info::UnionSplitInfo) = length(info.split)
getsplit_impl(info::UnionSplitInfo, idx::Int) = getsplit(info.split[idx], 1)
getresult_impl(::UnionSplitInfo, ::Int) = nothing
function add_uncovered_edges_impl(edges::Vector{Any}, info::UnionSplitInfo, @nospecialize(atype))
for (mt, fullmatch) in zip(info.mts, info.fullmatches)
!fullmatch && push!(edges, mt, atype)
all(fully_covering, info.split) && return nothing
# add mt backedges with removing duplications
for mt in uncovered_method_tables(info)
push!(edges, mt, atype)
end
end

Expand Down
32 changes: 11 additions & 21 deletions base/compiler/tfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2979,33 +2979,23 @@ function abstract_applicable(interp::AbstractInterpreter, argtypes::Vector{Any},
else
(; valid_worlds, applicable) = matches
update_valid_age!(sv, valid_worlds)

# also need an edge to the method table in case something gets
# added that did not intersect with any existing method
if isa(matches, MethodMatches)
fully_covering(matches) || add_mt_backedge!(sv, matches.info.mt, atype)
else
for (thisfullmatch, mt) in zip(matches.info.fullmatches, matches.info.mts)
thisfullmatch || add_mt_backedge!(sv, mt, atype)
end
end

napplicable = length(applicable)
if napplicable == 0
rt = Const(false) # never any matches
elseif !fully_covering(matches) || any_ambig(matches)
# Account for the fact that we may encounter a MethodError with a non-covered or ambiguous signature.
rt = Bool
else
rt = Const(true) # has applicable matches
for i in 1:napplicable
match = applicable[i]::MethodMatch
edge = specialize_method(match)::MethodInstance
add_backedge!(sv, edge)
end

if !fully_covering(matches) || any_ambig(matches)
# Account for the fact that we may encounter a MethodError with a non-covered or ambiguous signature.
rt = Bool
end
end
for i in 1:napplicable
match = applicable[i]::MethodMatch
edge = specialize_method(match)::MethodInstance
add_backedge!(sv, edge)
end
# also need an edge to the method table in case something gets
# added that did not intersect with any existing method
add_uncovered_edges!(sv, matches, atype)
end
return CallMeta(rt, Union{}, EFFECTS_TOTAL, NoCallInfo())
end
Expand Down

0 comments on commit 7f7a472

Please sign in to comment.