Skip to content

Commit d2dd076

Browse files
authored
optimize cfg_simplify! (#51958)
> Before ```julia julia> @benchmark CC.cfg_simplify!(ir) setup=(ir = CC.copy(Main.ir)) BenchmarkTools.Trial: 10000 samples with 1 evaluation. Range (min … max): 196.333 μs … 22.691 ms ┊ GC (min … max): 0.00% … 98.80% Time (median): 208.792 μs ┊ GC (median): 0.00% Time (mean ± σ): 219.824 μs ± 390.126 μs ┊ GC (mean ± σ): 4.18% ± 2.40% ▄█▃ ▁▁▁▁▁▁▁▁▁▂▂▃▆███▄▄▅▇▇▅▄▄▄▅▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▂ 196 μs Histogram: frequency by time 235 μs < Memory estimate: 169.48 KiB, allocs estimate: 4005. ``` > After ```julia julia> @benchmark CC.cfg_simplify!(ir) setup=(ir = CC.copy(Main.ir)) BenchmarkTools.Trial: 10000 samples with 1 evaluation. Range (min … max): 49.541 μs … 31.882 ms ┊ GC (min … max): 0.00% … 99.66% Time (median): 54.042 μs ┊ GC (median): 0.00% Time (mean ± σ): 61.492 μs ± 392.747 μs ┊ GC (mean ± σ): 10.36% ± 1.72% ▄█▇▄▁ ▁▁▁▁▁▂▂▃▄▄▅██████▇▅▅▄▃▃▃▃▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▂ 49.5 μs Histogram: frequency by time 67.2 μs < Memory estimate: 72.19 KiB, allocs estimate: 808. ``` with ```julia julia> ir, = only(Base.code_ircode(sin, (Float64,)); ```
1 parent 65a0fd0 commit d2dd076

File tree

2 files changed

+76
-82
lines changed

2 files changed

+76
-82
lines changed

base/compiler/ssair/passes.jl

Lines changed: 74 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -1917,29 +1917,65 @@ end
19171917

19181918
is_terminator(@nospecialize(stmt)) = isa(stmt, GotoNode) || isa(stmt, GotoIfNot) || isexpr(stmt, :enter)
19191919

1920-
function cfg_simplify!(ir::IRCode)
1921-
bbs = ir.cfg.blocks
1922-
merge_into = zeros(Int, length(bbs))
1923-
merged_succ = zeros(Int, length(bbs))
1924-
dropped_bbs = Vector{Int}() # sorted
1925-
function follow_merge_into(idx::Int)
1926-
while merge_into[idx] != 0
1927-
idx = merge_into[idx]
1928-
end
1929-
return idx
1920+
function follow_map(map::Vector{Int}, idx::Int)
1921+
while map[idx] 0
1922+
idx = map[idx]
1923+
end
1924+
return idx
1925+
end
1926+
1927+
function ascend_eliminated_preds(bbs::Vector{BasicBlock}, pred::Int)
1928+
while pred != 1 && length(bbs[pred].preds) == 1 && length(bbs[pred].succs) == 1
1929+
pred = bbs[pred].preds[1]
19301930
end
1931-
function follow_merged_succ(idx::Int)
1932-
while merged_succ[idx] != 0
1933-
idx = merged_succ[idx]
1931+
return pred
1932+
end
1933+
1934+
# Compute (renamed) successors and predecessors given (renamed) block
1935+
function compute_succs(merged_succ::Vector{Int}, bbs::Vector{BasicBlock}, result_bbs::Vector{Int}, bb_rename_succ::Vector{Int}, i::Int)
1936+
orig_bb = follow_map(merged_succ, result_bbs[i])
1937+
return Int[bb_rename_succ[i] for i in bbs[orig_bb].succs]
1938+
end
1939+
1940+
function compute_preds(bbs::Vector{BasicBlock}, result_bbs::Vector{Int}, bb_rename_pred::Vector{Int}, i::Int)
1941+
orig_bb = result_bbs[i]
1942+
preds = copy(bbs[orig_bb].preds)
1943+
res = Int[]
1944+
while !isempty(preds)
1945+
pred = popfirst!(preds)
1946+
if pred == 0
1947+
push!(res, 0)
1948+
continue
1949+
end
1950+
r = bb_rename_pred[pred]
1951+
(r == -2 || r == -1) && continue
1952+
if r == -3
1953+
prepend!(preds, bbs[pred].preds)
1954+
else
1955+
push!(res, r)
19341956
end
1935-
return idx
19361957
end
1937-
function ascend_eliminated_preds(pred)
1938-
while pred != 1 && length(bbs[pred].preds) == 1 && length(bbs[pred].succs) == 1
1939-
pred = bbs[pred].preds[1]
1958+
return res
1959+
end
1960+
1961+
function add_preds!(all_new_preds::Vector{Int32}, bbs::Vector{BasicBlock}, bb_rename_pred::Vector{Int}, old_edge::Int32)
1962+
preds = copy(bbs[old_edge].preds)
1963+
while !isempty(preds)
1964+
old_edge′ = popfirst!(preds)
1965+
new_edge = bb_rename_pred[old_edge′]
1966+
if new_edge > 0 && new_edge all_new_preds
1967+
push!(all_new_preds, Int32(new_edge))
1968+
elseif new_edge == -3
1969+
prepend!(preds, bbs[old_edge′].preds)
19401970
end
1941-
return pred
19421971
end
1972+
end
1973+
1974+
function cfg_simplify!(ir::IRCode)
1975+
bbs = ir.cfg.blocks
1976+
merge_into = zeros(Int, length(bbs))
1977+
merged_succ = zeros(Int, length(bbs))
1978+
dropped_bbs = Vector{Int}() # sorted
19431979

19441980
# Walk the CFG from the entry block and aggressively combine blocks
19451981
for (idx, bb) in enumerate(bbs)
@@ -1953,21 +1989,21 @@ function cfg_simplify!(ir::IRCode)
19531989
end
19541990
# Prevent cycles by making sure we don't end up back at `idx`
19551991
# by following what is to be merged into `succ`
1956-
if follow_merged_succ(succ) != idx
1992+
if follow_map(merged_succ, succ) != idx
19571993
merge_into[succ] = idx
19581994
merged_succ[idx] = succ
19591995
end
19601996
elseif merge_into[idx] == 0 && is_bb_empty(ir, bb) && is_legal_bb_drop(ir, idx, bb)
19611997
# If this BB is empty, we can still merge it as long as none of our successor's phi nodes
19621998
# reference our predecessors.
19631999
found_interference = false
1964-
preds = Int[ascend_eliminated_preds(pred) for pred in bb.preds]
2000+
preds = Int[ascend_eliminated_preds(bbs, pred) for pred in bb.preds]
19652001
for idx in bbs[succ].stmts
19662002
stmt = ir[SSAValue(idx)][:stmt]
19672003
stmt === nothing && continue
19682004
isa(stmt, PhiNode) || break
19692005
for edge in stmt.edges
1970-
edge = ascend_eliminated_preds(edge)
2006+
edge = ascend_eliminated_preds(bbs, Int(edge))
19712007
for pred in preds
19722008
if pred == edge
19732009
found_interference = true
@@ -1986,7 +2022,7 @@ function cfg_simplify!(ir::IRCode)
19862022

19872023
# Assign new BB numbers in DFS order, dropping unreachable blocks
19882024
max_bb_num = 1
1989-
bb_rename_succ = fill(0, length(bbs))
2025+
bb_rename_succ = zeros(Int, length(bbs))
19902026
worklist = BitSetBoundedMinPrioritySet(length(bbs))
19912027
push!(worklist, 1)
19922028
while !isempty(worklist)
@@ -2025,8 +2061,9 @@ function cfg_simplify!(ir::IRCode)
20252061
push!(worklist, terminator.dest)
20262062
end
20272063
elseif isexpr(terminator, :enter)
2028-
if bb_rename_succ[terminator.args[1]] == 0
2029-
push!(worklist, terminator.args[1])
2064+
enteridx = terminator.args[1]::Int
2065+
if bb_rename_succ[enteridx] == 0
2066+
push!(worklist, enteridx)
20302067
end
20312068
end
20322069
ncurr = curr + 1
@@ -2112,7 +2149,7 @@ function cfg_simplify!(ir::IRCode)
21122149
elseif is_multi
21132150
bb_rename_pred[i] = -3
21142151
else
2115-
bbnum = follow_merge_into(pred)
2152+
bbnum = follow_map(merge_into, pred)
21162153
bb_rename_pred[i] = bb_rename_succ[bbnum]
21172154
end
21182155
end
@@ -2134,48 +2171,12 @@ function cfg_simplify!(ir::IRCode)
21342171
bb_starts[i+1] = bb_starts[i] + result_bbs_lengths[i]
21352172
end
21362173

2137-
cresult_bbs = let result_bbs = result_bbs,
2138-
merged_succ = merged_succ,
2139-
merge_into = merge_into,
2140-
bbs = bbs,
2141-
bb_rename_succ = bb_rename_succ
2142-
2143-
# Compute (renamed) successors and predecessors given (renamed) block
2144-
function compute_succs(i::Int)
2145-
orig_bb = follow_merged_succ(result_bbs[i])
2146-
return Int[bb_rename_succ[i] for i in bbs[orig_bb].succs]
2147-
end
2148-
function compute_preds(i::Int)
2149-
orig_bb = result_bbs[i]
2150-
preds = bbs[orig_bb].preds
2151-
res = Int[]
2152-
function scan_preds!(preds::Vector{Int})
2153-
for pred in preds
2154-
if pred == 0
2155-
push!(res, 0)
2156-
continue
2157-
end
2158-
r = bb_rename_pred[pred]
2159-
(r == -2 || r == -1) && continue
2160-
if r == -3
2161-
scan_preds!(bbs[pred].preds)
2162-
else
2163-
push!(res, r)
2164-
end
2165-
end
2166-
end
2167-
scan_preds!(preds)
2168-
return res
2169-
end
2170-
2171-
BasicBlock[
2172-
BasicBlock(StmtRange(bb_starts[i],
2173-
i+1 > length(bb_starts) ?
2174-
length(compact.result) : bb_starts[i+1]-1),
2175-
compute_preds(i),
2176-
compute_succs(i))
2177-
for i = 1:length(result_bbs)]
2178-
end
2174+
cresult_bbs = BasicBlock[
2175+
BasicBlock(StmtRange(bb_starts[i],
2176+
i+1 > length(bb_starts) ? length(compact.result) : bb_starts[i+1]-1),
2177+
compute_preds(bbs, result_bbs, bb_rename_pred, i),
2178+
compute_succs(merged_succ, bbs, result_bbs, bb_rename_succ, i))
2179+
for i = 1:length(result_bbs)]
21792180

21802181
# Fixup terminators for any blocks that would have caused double edges
21812182
for (bbidx, (new_bb, old_bb)) in enumerate(zip(cresult_bbs, result_bbs))
@@ -2186,7 +2187,7 @@ function cfg_simplify!(ir::IRCode)
21862187
terminator = ir[SSAValue(last(bbs[old_bb2].stmts))]
21872188
@assert terminator[:stmt] isa GotoIfNot
21882189
# N.B.: The dest will be renamed in process_node! below
2189-
terminator[:stmt] = GotoNode(terminator[:stmt].dest)
2190+
terminator[:stmt] = GotoNode(terminator[:stmt].dest::Int)
21902191
pop!(new_bb.succs)
21912192
new_succ = cresult_bbs[new_bb.succs[1]]
21922193
for (i, nsp) in enumerate(new_succ.preds)
@@ -2210,11 +2211,12 @@ function cfg_simplify!(ir::IRCode)
22102211
for i in bbs[ms].stmts
22112212
node = ir.stmts[i]
22122213
compact.result[compact.result_idx] = node
2213-
if isa(node[:stmt], GotoNode) && merged_succ[ms] != 0
2214+
stmt = node[:stmt]
2215+
if isa(stmt, GotoNode) && merged_succ[ms] != 0
22142216
# If we merged a basic block, we need remove the trailing GotoNode (if any)
22152217
compact.result[compact.result_idx][:stmt] = nothing
2216-
elseif isa(node[:stmt], PhiNode)
2217-
phi = node[:stmt]
2218+
elseif isa(stmt, PhiNode)
2219+
phi = stmt
22182220
values = phi.values
22192221
(; ssa_rename, late_fixup, used_ssas, new_new_used_ssas) = compact
22202222
ssa_rename[i] = SSAValue(compact.result_idx)
@@ -2236,17 +2238,7 @@ function cfg_simplify!(ir::IRCode)
22362238
elseif new_edge == -3
22372239
# Multiple predecessors, we need to expand out this phi
22382240
all_new_preds = Int32[]
2239-
function add_preds!(old_edge)
2240-
for old_edge′ in bbs[old_edge].preds
2241-
new_edge = bb_rename_pred[old_edge′]
2242-
if new_edge > 0 && !in(new_edge, all_new_preds)
2243-
push!(all_new_preds, new_edge)
2244-
elseif new_edge == -3
2245-
add_preds!(old_edge′)
2246-
end
2247-
end
2248-
end
2249-
add_preds!(old_edge)
2241+
add_preds!(all_new_preds, bbs, bb_rename_pred, old_edge)
22502242
append!(edges, all_new_preds)
22512243
if isassigned(renamed_values, old_index)
22522244
val = renamed_values[old_index]

test/compiler/irpasses.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1568,3 +1568,5 @@ let m = Meta.@lower 1 + 1
15681568

15691569
Core.Compiler.verify_ir(ir)
15701570
end
1571+
1572+
# JET.test_opt(Core.Compiler.cfg_simplify!, (Core.Compiler.IRCode,))

0 commit comments

Comments
 (0)