Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions base/compiler/ssair/inlining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1379,9 +1379,7 @@ function inline_const_if_inlineable!(inst::Instruction)
end

function assemble_inline_todo!(ir::IRCode, state::InliningState)
# todo = (inline_idx, (isva, isinvoke, na), method, spvals, inline_linetable, inline_ir, lie)
todo = Pair{Int, Any}[]
et = state.et

for idx in 1:length(ir.stmts)
simpleres = process_simple!(ir, idx, state, todo)
Expand Down Expand Up @@ -1586,6 +1584,7 @@ function ssa_substitute_op!(@nospecialize(val), arg_replacements::Vector{Any},
end
end
end
isa(val, Union{SSAValue, NewSSAValue}) && return val # avoid infinite loop
urs = userefs(val)
for op in urs
op[] = ssa_substitute_op!(op[], arg_replacements, spsig, spvals, boundscheck)
Expand Down
104 changes: 41 additions & 63 deletions base/compiler/ssair/ir.jl
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,9 @@ struct UndefToken end; const UNDEF_TOKEN = UndefToken()
isdefined(stmt, :val) || return OOB_TOKEN
op == 1 || return OOB_TOKEN
return stmt.val
elseif isa(stmt, Union{SSAValue, NewSSAValue})
op == 1 || return OOB_TOKEN
return stmt
elseif isa(stmt, UpsilonNode)
isdefined(stmt, :val) || return OOB_TOKEN
op == 1 || return OOB_TOKEN
Expand Down Expand Up @@ -430,6 +433,9 @@ end
elseif isa(stmt, ReturnNode)
op == 1 || throw(BoundsError())
stmt = typeof(stmt)(v)
elseif isa(stmt, Union{SSAValue, NewSSAValue})
op == 1 || throw(BoundsError())
stmt = v
elseif isa(stmt, UpsilonNode)
op == 1 || throw(BoundsError())
stmt = typeof(stmt)(v)
Expand Down Expand Up @@ -457,7 +463,7 @@ end

function userefs(@nospecialize(x))
relevant = (isa(x, Expr) && is_relevant_expr(x)) ||
isa(x, GotoIfNot) || isa(x, ReturnNode) ||
isa(x, GotoIfNot) || isa(x, ReturnNode) || isa(x, SSAValue) || isa(x, NewSSAValue) ||
isa(x, PiNode) || isa(x, PhiNode) || isa(x, PhiCNode) || isa(x, UpsilonNode)
return UseRefIterator(x, relevant)
end
Expand All @@ -480,50 +486,10 @@ end

# This function is used from the show code, which may have a different
# `push!`/`used` type since it's in Base.
function scan_ssa_use!(push!, used, @nospecialize(stmt))
if isa(stmt, SSAValue)
push!(used, stmt.id)
end
for useref in userefs(stmt)
val = useref[]
if isa(val, SSAValue)
push!(used, val.id)
end
end
end
scan_ssa_use!(push!, used, @nospecialize(stmt)) = foreachssa(ssa -> push!(used, ssa.id), stmt)

# Manually specialized copy of the above with push! === Compiler.push!
function scan_ssa_use!(used::IdSet, @nospecialize(stmt))
if isa(stmt, SSAValue)
push!(used, stmt.id)
end
for useref in userefs(stmt)
val = useref[]
if isa(val, SSAValue)
push!(used, val.id)
end
end
end

function ssamap(f, @nospecialize(stmt))
urs = userefs(stmt)
for op in urs
val = op[]
if isa(val, SSAValue)
op[] = f(val)
end
end
return urs[]
end

function foreachssa(f, @nospecialize(stmt))
for op in userefs(stmt)
val = op[]
if isa(val, SSAValue)
f(val)
end
end
end
scan_ssa_use!(used::IdSet, @nospecialize(stmt)) = foreachssa(ssa -> push!(used, ssa.id), stmt)

function insert_node!(ir::IRCode, pos::Int, inst::NewInstruction, attach_after::Bool=false)
node = add!(ir.new_nodes, pos, attach_after)
Expand Down Expand Up @@ -751,20 +717,13 @@ end

function count_added_node!(compact::IncrementalCompact, @nospecialize(v))
needs_late_fixup = false
if isa(v, SSAValue)
compact.used_ssas[v.id] += 1
elseif isa(v, NewSSAValue)
compact.new_new_used_ssas[v.id] += 1
needs_late_fixup = true
else
for ops in userefs(v)
val = ops[]
if isa(val, SSAValue)
compact.used_ssas[val.id] += 1
elseif isa(val, NewSSAValue)
compact.new_new_used_ssas[val.id] += 1
needs_late_fixup = true
end
for ops in userefs(v)
val = ops[]
if isa(val, SSAValue)
compact.used_ssas[val.id] += 1
elseif isa(val, NewSSAValue)
compact.new_new_used_ssas[val.id] += 1
needs_late_fixup = true
end
end
return needs_late_fixup
Expand Down Expand Up @@ -931,6 +890,27 @@ function setindex!(compact::IncrementalCompact, @nospecialize(v), idx::Int)
return compact
end

__set_check_ssa_counts(onoff::Bool) = __check_ssa_counts__[] = onoff
const __check_ssa_counts__ = fill(false)

function _oracle_check(compact::IncrementalCompact)
observed_used_ssas = Core.Compiler.find_ssavalue_uses1(compact)
for i = 1:length(observed_used_ssas)
if observed_used_ssas[i] != compact.used_ssas[i]
return observed_used_ssas
end
end
return nothing
end

function oracle_check(compact::IncrementalCompact)
maybe_oracle_used_ssas = _oracle_check(compact)
if maybe_oracle_used_ssas !== nothing
@eval Main (compact = $compact; oracle_used_ssas = $maybe_oracle_used_ssas)
error("Oracle check failed, inspect Main.compact and Main.oracle_used_ssas")
end
end

getindex(view::TypesView, idx::SSAValue) = getindex(view, idx.id)
function getindex(view::TypesView, idx::Int)
if isa(view.ir, IncrementalCompact) && idx < view.ir.result_idx
Expand Down Expand Up @@ -1425,7 +1405,6 @@ function iterate(compact::IncrementalCompact, (idx, active_bb)::Tuple{Int, Int}=
# result_idx is not, incremented, but that's ok and expected
compact.result[old_result_idx] = compact.ir.stmts[idx]
result_idx = process_node!(compact, old_result_idx, compact.ir.stmts[idx], idx, idx, active_bb, true)
stmt_if_any = old_result_idx == result_idx ? nothing : compact.result[old_result_idx][:inst]
compact.result_idx = result_idx
if idx == last(bb.stmts) && !attach_after_stmt_after(compact, idx)
finish_current_bb!(compact, active_bb, old_result_idx)
Expand Down Expand Up @@ -1464,11 +1443,7 @@ function maybe_erase_unused!(
callback(val)
end
if effect_free
if isa(stmt, SSAValue)
kill_ssa_value(stmt)
else
foreachssa(kill_ssa_value, stmt)
end
foreachssa(kill_ssa_value, stmt)
inst[:inst] = nothing
return true
end
Expand Down Expand Up @@ -1570,6 +1545,9 @@ end
function complete(compact::IncrementalCompact)
result_bbs = resize!(compact.result_bbs, compact.active_result_bb-1)
cfg = CFG(result_bbs, Int[first(result_bbs[i].stmts) for i in 2:length(result_bbs)])
if __check_ssa_counts__[]
oracle_check(compact)
end
return IRCode(compact.ir, compact.result, cfg, compact.new_new_nodes)
end

Expand Down
9 changes: 0 additions & 9 deletions base/compiler/ssair/passes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1151,15 +1151,6 @@ function adce_erase!(phi_uses::Vector{Int}, extra_worklist::Vector{Int}, compact
end
end

function count_uses(@nospecialize(stmt), uses::Vector{Int})
for ur in userefs(stmt)
use = ur[]
if isa(use, SSAValue)
uses[use.id] += 1
end
end
end

function mark_phi_cycles!(compact::IncrementalCompact, safe_phis::SPCSet, phi::Int)
worklist = Int[]
push!(worklist, phi)
Expand Down
3 changes: 0 additions & 3 deletions base/compiler/ssair/slot2ssa.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,6 @@ function make_ssa!(ci::CodeInfo, code::Vector{Any}, idx, slot, @nospecialize(typ
end

function new_to_regular(@nospecialize(stmt), new_offset::Int)
if isa(stmt, NewSSAValue)
return SSAValue(stmt.id + new_offset)
end
urs = userefs(stmt)
for op in urs
val = op[]
Expand Down
53 changes: 53 additions & 0 deletions base/compiler/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,27 @@ end
# SSAValues/Slots #
###################

function ssamap(f, @nospecialize(stmt))
urs = userefs(stmt)
for op in urs
val = op[]
if isa(val, SSAValue)
op[] = f(val)
end
end
return urs[]
end

function foreachssa(f, @nospecialize(stmt))
urs = userefs(stmt)
for op in urs
val = op[]
if isa(val, SSAValue)
f(val)
end
end
end

function find_ssavalue_uses(body::Vector{Any}, nvals::Int)
uses = BitSet[ BitSet() for i = 1:nvals ]
for line in 1:length(body)
Expand Down Expand Up @@ -333,6 +354,38 @@ end
@inline slot_id(s) = isa(s, SlotNumber) ? (s::SlotNumber).id :
isa(s, Argument) ? (s::Argument).n : (s::TypedSlot).id

######################
# IncrementalCompact #
######################

# specifically meant to be used with body1 = compact.result and body2 = compact.new_new_nodes, with nvals == length(compact.used_ssas)
function find_ssavalue_uses1(compact)
body1, body2 = compact.result.inst, compact.new_new_nodes.stmts.inst
nvals = length(compact.used_ssas)
nbody1 = length(body1)
nbody2 = length(body2)

uses = zeros(Int, nvals)
function increment_uses(ssa::SSAValue)
uses[ssa.id] += 1
end

for line in 1:(nbody1 + nbody2)
# index into the right body
if line <= nbody1
isassigned(body1, line) || continue
e = body1[line]
else
line -= nbody1
isassigned(body2, line) || continue
e = body2[line]
end

foreachssa(increment_uses, e)
end
return uses
end

###########
# options #
###########
Expand Down
65 changes: 64 additions & 1 deletion test/compiler/ssair.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
using Base.Meta
using Core.IR
const Compiler = Core.Compiler
using .Compiler: CFG, BasicBlock
using .Compiler: CFG, BasicBlock, NewSSAValue

make_bb(preds, succs) = BasicBlock(Compiler.StmtRange(0, 0), preds, succs)

Expand Down Expand Up @@ -334,3 +334,66 @@ f_if_typecheck() = (if nothing; end; unsafe_load(Ptr{Int}(0)))
stderr = IOBuffer()
success(pipeline(Cmd(cmd); stdout=stdout, stderr=stderr)) && isempty(String(take!(stderr)))
end

let
function test_useref(stmt, v, op)
if isa(stmt, Expr)
@test stmt.args[op] === v
elseif isa(stmt, GotoIfNot)
@test stmt.cond === v
elseif isa(stmt, ReturnNode) || isa(stmt, UpsilonNode)
@test stmt.val === v
elseif isa(stmt, SSAValue) || isa(stmt, NewSSAValue)
@test stmt === v
elseif isa(stmt, PiNode)
@test stmt.val === v && stmt.typ === typeof(stmt)
elseif isa(stmt, PhiNode) || isa(stmt, PhiCNode)
@test stmt.values[op] === v
end
end

function _test_userefs(@nospecialize stmt)
ex = Expr(:call, :+, Core.SSAValue(3), 1)
urs = Core.Compiler.userefs(stmt)::Core.Compiler.UseRefIterator
it = Core.Compiler.iterate(urs)
while it !== nothing
ur = getfield(it, 1)::Core.Compiler.UseRef
op = getfield(it, 2)::Int
v1 = Core.Compiler.getindex(ur)
# set to dummy expression and then back to itself to test `_useref_setindex!`
v2 = Core.Compiler.setindex!(ur, ex)
test_useref(v2, ex, op)
Core.Compiler.setindex!(ur, v1)
@test Core.Compiler.getindex(ur) === v1
it = Core.Compiler.iterate(urs, op)
end
end

function test_userefs(body)
for stmt in body
_test_userefs(stmt)
end
end

# this isn't valid code, we just care about looking at a variety of IR nodes
body = Any[
Expr(:enter, 11),
Expr(:call, :+, SSAValue(3), 1),
Expr(:throw_undef_if_not, :expected, false),
Expr(:leave, 1),
Expr(:(=), SSAValue(1), Expr(:call, :+, SSAValue(3), 1)),
UpsilonNode(),
UpsilonNode(SSAValue(2)),
PhiCNode(Any[SSAValue(5), SSAValue(7), SSAValue(9)]),
PhiCNode(Any[SSAValue(6)]),
PhiNode(Int32[8], Any[SSAValue(7)]),
PiNode(SSAValue(6), GotoNode),
GotoIfNot(SSAValue(3), 10),
GotoNode(5),
SSAValue(7),
NewSSAValue(9),
ReturnNode(SSAValue(11)),
]

test_userefs(body)
end