Skip to content

Commit

Permalink
refactoring on SROA passes
Browse files Browse the repository at this point in the history
All changes are cosmetic and do not change the basic functionality:
- Added the interface type to the callbacks received by `simple_walker`
  to clarify which objects are passed as callbacks to `simple_walker`.
- Replaced ambiguous names like `idx` with more descriptive ones like
  `defidx` to make the algorithm easier to understand.
  • Loading branch information
aviatesk committed Jul 29, 2024
1 parent 4dfce5d commit 01aadd7
Showing 1 changed file with 88 additions and 69 deletions.
157 changes: 88 additions & 69 deletions base/compiler/ssair/passes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ end
function find_curblock(domtree::DomTree, allblocks::BitSet, curblock::Int)
# TODO: This can be much faster by looking at current level and only
# searching for those blocks in a sorted order
while !(curblock in allblocks) && curblock !== 0
while curblock allblocks && curblock 0
curblock = domtree.idoms_bb[curblock]
end
return curblock
Expand Down Expand Up @@ -190,18 +190,21 @@ function collect_leaves(compact::IncrementalCompact, @nospecialize(val), @nospec
return walk_to_defs(compact, val, typeconstraint, predecessors, 𝕃ₒ)
end

function trivial_walker(@nospecialize(pi), @nospecialize(idx))
return nothing
end
abstract type WalkerCallback end

function pi_walker(@nospecialize(pi), @nospecialize(idx))
if isa(pi, PiNode)
return LiftedValue(pi.val)
struct TrivialWalker <: WalkerCallback end
(::TrivialWalker)(@nospecialize(def), @nospecialize(defssa::AnySSAValue)) = nothing

struct PiWalker <: WalkerCallback end
function (::PiWalker)(@nospecialize(def), @nospecialize(defssa::AnySSAValue))
if isa(def, PiNode)
return LiftedValue(def.val)
end
return nothing
end

function simple_walk(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSAValue=#), callback=trivial_walker)
function simple_walk(compact::IncrementalCompact, @nospecialize(defssa::AnySSAValue),
walker_callback::WalkerCallback=TrivialWalker())
while true
if isa(defssa, OldSSAValue)
if already_inserted(compact, defssa)
Expand All @@ -218,15 +221,15 @@ function simple_walk(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSA
end
def = compact[defssa][:stmt]
if isa(def, AnySSAValue)
callback(def, defssa)
walker_callback(def, defssa)
if isa(def, SSAValue)
is_old(compact, defssa) && (def = OldSSAValue(def.id))
end
defssa = def
elseif isa(def, Union{PhiNode, PhiCNode, GlobalRef})
return defssa
else
new_def = callback(def, defssa)
new_def = walker_callback(def, defssa)
if new_def === nothing
return defssa
end
Expand All @@ -241,16 +244,21 @@ function simple_walk(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSA
end
end

function simple_walk_constraint(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSAValue=#),
@nospecialize(typeconstraint))
callback = function (@nospecialize(pi), @nospecialize(idx))
if isa(pi, PiNode)
typeconstraint = typeintersect(typeconstraint, widenconst(pi.typ))
return LiftedValue(pi.val)
end
return nothing
mutable struct TypeConstrainingWalker <: WalkerCallback
typeconstraint::Any
TypeConstrainingWalker(@nospecialize(typeconstraint::Any)) = new(typeconstraint)
end
function (walker_callback::TypeConstrainingWalker)(@nospecialize(def), @nospecialize(defssa::AnySSAValue))
if isa(def, PiNode)
walker_callback.typeconstraint =
typeintersect(walker_callback.typeconstraint, widenconst(def.typ))
return LiftedValue(def.val)
end
def = simple_walk(compact, defssa, callback)
return nothing
end
function simple_walk_constraint(compact::IncrementalCompact, @nospecialize(val::AnySSAValue),
@nospecialize(typeconstraint))
def = simple_walk(compact, val, TypeConstrainingWalker(typeconstraint))
return Pair{Any, Any}(def, typeconstraint)
end

Expand Down Expand Up @@ -638,15 +646,17 @@ end

struct SkipToken end; const SKIP_TOKEN = SkipToken()

function lifted_value(compact::IncrementalCompact, @nospecialize(old_node_ssa#=::AnySSAValue=#), @nospecialize(old_value),
lifted_philikes::Vector{LiftedPhilike}, lifted_leaves::Union{LiftedLeaves, LiftedDefs}, reverse_mapping::IdDict{AnySSAValue, Int},
walker_callback)
function lifted_value(compact::IncrementalCompact, @nospecialize(old_node_ssa::AnySSAValue),
@nospecialize(old_value), lifted_philikes::Vector{LiftedPhilike},
lifted_leaves::Union{LiftedLeaves, LiftedDefs},
reverse_mapping::IdDict{AnySSAValue, Int},
walker_callback::WalkerCallback)
val = old_value
if is_old(compact, old_node_ssa) && isa(val, SSAValue)
val = OldSSAValue(val.id)
end
if isa(val, AnySSAValue)
val = simple_walk(compact, val, def_walker(lifted_leaves, reverse_mapping, walker_callback))
val = simple_walk(compact, val, LiftedLeaveWalker(lifted_leaves, reverse_mapping, walker_callback))
end
if val in keys(lifted_leaves)
lifted_val = lifted_leaves[val]
Expand All @@ -656,7 +666,7 @@ function lifted_value(compact::IncrementalCompact, @nospecialize(old_node_ssa#=:
lifted_val === nothing && return UNDEF_TOKEN
val = lifted_val.val
if isa(val, AnySSAValue)
val = simple_walk(compact, val, pi_walker)
val = simple_walk(compact, val, PiWalker())
end
return val
elseif isa(val, AnySSAValue) && val in keys(reverse_mapping)
Expand All @@ -673,7 +683,7 @@ function is_old(compact, @nospecialize(old_node_ssa))
return true
end

struct PhiNest{C}
struct PhiNest{C<:WalkerCallback}
visited_philikes::Vector{AnySSAValue}
lifted_philikes::Vector{LiftedPhilike}
lifted_leaves::Union{LiftedLeaves, LiftedDefs}
Expand Down Expand Up @@ -743,20 +753,29 @@ function finish_phi_nest!(compact::IncrementalCompact, nest::PhiNest)
end
end

function def_walker(lifted_leaves::Union{LiftedLeaves, LiftedDefs}, reverse_mapping::IdDict{AnySSAValue, Int}, walker_callback)
function (@nospecialize(walk_def), @nospecialize(defssa))
if (defssa in keys(lifted_leaves)) || (isa(defssa, AnySSAValue) && defssa in keys(reverse_mapping))
return nothing
end
isa(walk_def, PiNode) && return LiftedValue(walk_def.val)
return walker_callback(walk_def, defssa)
struct LiftedLeaveWalker{C<:WalkerCallback} <: WalkerCallback
lifted_leaves::Union{LiftedLeaves, LiftedDefs}
reverse_mapping::IdDict{AnySSAValue, Int}
inner_walker_callback::C
function LiftedLeaveWalker(@nospecialize(lifted_leaves::Union{LiftedLeaves, LiftedDefs}),
@nospecialize(reverse_mapping::IdDict{AnySSAValue, Int}),
inner_walker_callback::C) where C<:WalkerCallback
return new{C}(lifted_leaves, reverse_mapping, inner_walker_callback)
end
end
function (walker_callback::LiftedLeaveWalker)(@nospecialize(def), @nospecialize(defssa::AnySSAValue))
(; lifted_leaves, reverse_mapping, inner_walker_callback) = walker_callback
if defssa in keys(lifted_leaves) || defssa in keys(reverse_mapping)
return nothing
end
isa(def, PiNode) && return LiftedValue(def.val)
return inner_walker_callback(def, defssa)
end

function perform_lifting!(compact::IncrementalCompact,
visited_philikes::Vector{AnySSAValue}, @nospecialize(cache_key),
@nospecialize(result_t), lifted_leaves::Union{LiftedLeaves, LiftedDefs}, @nospecialize(stmt_val),
lazydomtree::Union{LazyDomtree,Nothing}, walker_callback = trivial_walker)
lazydomtree::Union{LazyDomtree,Nothing}, walker_callback::WalkerCallback = TrivialWalker())
reverse_mapping = IdDict{AnySSAValue, Int}()
for id in 1:length(visited_philikes)
reverse_mapping[visited_philikes[id]] = id
Expand Down Expand Up @@ -839,7 +858,7 @@ function perform_lifting!(compact::IncrementalCompact,

# Fixup the stmt itself
if isa(stmt_val, Union{SSAValue, OldSSAValue})
stmt_val = simple_walk(compact, stmt_val, def_walker(lifted_leaves, reverse_mapping, walker_callback))
stmt_val = simple_walk(compact, stmt_val, LiftedLeaveWalker(lifted_leaves, reverse_mapping, walker_callback))
end

if stmt_val in keys(lifted_leaves)
Expand Down Expand Up @@ -948,6 +967,17 @@ function keyvalue_predecessors(@nospecialize(key), 𝕃ₒ::AbstractLattice)
end
end

struct KeyValueWalker <: WalkerCallback
compact::IncrementalCompact
end
function (walker_callback::KeyValueWalker)(@nospecialize(def), @nospecialize(defssa::AnySSAValue))
if is_known_invoke_or_call(def, Core.OptimizedGenerics.KeyValue.set, walker_callback.compact)
@assert length(def.args) in (5, 6)
return LiftedValue(def.args[end-2])
end
return nothing
end

function lift_keyvalue_get!(compact::IncrementalCompact, idx::Int, stmt::Expr, 𝕃ₒ::AbstractLattice)
collection = stmt.args[end-1]
key = stmt.args[end]
Expand All @@ -964,16 +994,9 @@ function lift_keyvalue_get!(compact::IncrementalCompact, idx::Int, stmt::Expr,
result_t = tmerge(𝕃ₒ, result_t, argextype(v.val, compact))
end

function keyvalue_walker(@nospecialize(def), _)
if is_known_invoke_or_call(def, Core.OptimizedGenerics.KeyValue.set, compact)
@assert length(def.args) in (5, 6)
return LiftedValue(def.args[end-2])
end
return nothing
end
(lifted_val, nest) = perform_lifting!(compact,
visited_philikes, key, result_t, lifted_leaves, collection, nothing,
keyvalue_walker)
KeyValueWalker(compact))

compact[idx] = lifted_val === nothing ? nothing : Expr(:call, GlobalRef(Core, :tuple), lifted_val.val)
finish_phi_nest!(compact, nest)
Expand Down Expand Up @@ -1139,13 +1162,11 @@ end
# which can be very large sometimes, and program counters in question are often very sparse
const SPCSet = IdSet{Int}

struct IntermediaryCollector
struct IntermediaryCollector <: WalkerCallback
intermediaries::SPCSet
end
function (this::IntermediaryCollector)(@nospecialize(pi), @nospecialize(ssa))
if !isa(pi, Expr)
push!(this.intermediaries, ssa.id)
end
function (walker_callback::IntermediaryCollector)(@nospecialize(def), @nospecialize(defssa::AnySSAValue))
isa(def, Expr) || push!(walker_callback.intermediaries, defssa.id)
return nothing
end

Expand Down Expand Up @@ -1242,7 +1263,7 @@ function sroa_pass!(ir::IRCode, inlining::Union{Nothing,InliningState}=nothing)
update_scope_mapping!(scope_mapping, bb+1, bbs)
end
# check whether this statement is `getfield` / `setfield!` (or other "interesting" statement)
is_setfield = is_isdefined = is_finalizer = is_keyvalue_get = false
is_setfield = is_isdefined = is_finalizer = false
field_ordering = :unspecified
if is_known_call(stmt, setfield!, compact)
4 <= length(stmt.args) <= 5 || continue
Expand Down Expand Up @@ -1371,8 +1392,7 @@ function sroa_pass!(ir::IRCode, inlining::Union{Nothing,InliningState}=nothing)
if ismutabletypename(struct_typ_name)
isa(val, SSAValue) || continue
let intermediaries = SPCSet()
callback = IntermediaryCollector(intermediaries)
def = simple_walk(compact, val, callback)
def = simple_walk(compact, val, IntermediaryCollector(intermediaries))
# Mutable stuff here
isa(def, SSAValue) || continue
if defuses === nothing
Expand Down Expand Up @@ -1680,24 +1700,23 @@ end
function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse}}, used_ssas::Vector{Int}, lazydomtree::LazyDomtree, inlining::Union{Nothing, InliningState})
𝕃ₒ = inlining === nothing ? SimpleInferenceLattice.instance : optimizer_lattice(inlining.interp)
lazypostdomtree = LazyPostDomtree(ir)
for (idx, (intermediaries, defuse)) in defuses
for (defidx, (intermediaries, defuse)) in defuses
intermediaries = collect(intermediaries)
# Check if there are any uses we did not account for. If so, the variable
# escapes and we cannot eliminate the allocation. This works, because we're guaranteed
# not to include any intermediaries that have dead uses. As a result, missing uses will only ever
# show up in the nuses_total count.
nleaves = length(defuse.uses) + length(defuse.defs)
nuses = 0
for idx in intermediaries
nuses += used_ssas[idx]
for iidx in intermediaries
nuses += used_ssas[iidx]
end
nuses_total = used_ssas[idx] + nuses - length(intermediaries)
nuses_total = used_ssas[defidx] + nuses - length(intermediaries)
nleaves == nuses_total || continue
# Find the type for this allocation
defexpr = ir[SSAValue(idx)][:stmt]
defexpr = ir[SSAValue(defidx)][:stmt]
isexpr(defexpr, :new) || continue
newidx = idx
typ = unwrap_unionall(ir.stmts[newidx][:type])
typ = unwrap_unionall(ir.stmts[defidx][:type])
# Could still end up here if we tried to setfield! on an immutable, which would
# error at runtime, but is not illegal to have in the IR.
typ = widenconst(typ)
Expand All @@ -1713,7 +1732,7 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse
end
end
if finalizer_idx !== nothing && inlining !== nothing
try_resolve_finalizer!(ir, idx, finalizer_idx, defuse, inlining,
try_resolve_finalizer!(ir, defidx, finalizer_idx, defuse, inlining,
lazydomtree, lazypostdomtree, ir[SSAValue(finalizer_idx)][:info])
continue
end
Expand Down Expand Up @@ -1752,11 +1771,11 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse
# but we should come up with semantics for well defined semantics
# for uninitialized fields first.
ndefuse = length(fielddefuse)
blocks = Vector{Tuple{#=phiblocks=# Vector{Int}, #=allblocks=# BitSet}}(undef, ndefuse)
blocks = Vector{Tuple{#=phiblocks=#Vector{Int},#=allblocks=#BitSet}}(undef, ndefuse)
for fidx in 1:ndefuse
du = fielddefuse[fidx]
isempty(du.uses) && continue
push!(du.defs, newidx)
push!(du.defs, defidx)
ldu = compute_live_ins(ir.cfg, du)
if isempty(ldu.live_in_bbs)
phiblocks = Int[]
Expand All @@ -1769,7 +1788,7 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse
for i = 1:length(du.uses)
use = du.uses[i]
if use.kind === :isdefined
if has_safe_def(ir, get!(lazydomtree), allblocks, du, newidx, use.idx)
if has_safe_def(ir, get!(lazydomtree), allblocks, du, defidx, use.idx)
ir[SSAValue(use.idx)][:stmt] = true
else
all_eliminated = false
Expand All @@ -1782,7 +1801,7 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse
continue
end
end
has_safe_def(ir, get!(lazydomtree), allblocks, du, newidx, use.idx) || @goto skip
has_safe_def(ir, get!(lazydomtree), allblocks, du, defidx, use.idx) || @goto skip
end
else # always have some definition at the allocation site
for i = 1:length(du.uses)
Expand Down Expand Up @@ -1849,19 +1868,19 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse
# all "usages" (i.e. `getfield` and `isdefined` calls) are eliminated,
# now eliminate "definitions" (i.e. `setfield!`) calls
# (NOTE the allocation itself will be eliminated by DCE pass later)
for idx in du.defs
idx == newidx && continue # this is allocation
for didx in du.defs
didx == defidx && continue # this is allocation
# verify this statement won't throw, otherwise it can't be eliminated safely
ssa = SSAValue(idx)
if is_nothrow(ir, ssa)
ir[ssa][:stmt] = nothing
setfield_ssa = SSAValue(didx)
if is_nothrow(ir, setfield_ssa)
ir[setfield_ssa][:stmt] = nothing
else
# We can't eliminate this statement, because it might still
# throw an error, but we can mark it as effect-free since we
# know we have removed all uses of the mutable allocation.
# As a result, if we ever do prove nothrow, we can delete
# this statement then.
add_flag!(ir[ssa], IR_FLAG_EFFECT_FREE)
add_flag!(ir[setfield_ssa], IR_FLAG_EFFECT_FREE)
end
end
end
Expand All @@ -1870,7 +1889,7 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse
# this means all ccall preserves have been replaced with forwarded loads
# so we can potentially eliminate the allocation, otherwise we must preserve
# the whole allocation.
push!(intermediaries, newidx)
push!(intermediaries, defidx)
end
# Insert the new preserves
for (useidx, new_preserves) in preserve_uses
Expand Down

0 comments on commit 01aadd7

Please sign in to comment.