Skip to content

Commit

Permalink
compiler: add flag manipulation utilities (JuliaLang#52269)
Browse files Browse the repository at this point in the history
Right now, we're checking if a flag exists using bare operations like
`&` and `==`. This works fine, but I think unifying these interfaces
could make our code easier to maintain. So, this commit introduces some
new high-level operators and refactored the code to incorporate them:
- `has_flag(curr::UInt32, flag::UInt32)`
- `has_flag(inst::Instruction, flag::UInt32)`
- `add_flag!(inst::Instruction, flag::UInt32)`
- `sub_flag!(inst::Instruction, flag::UInt32)`
  • Loading branch information
aviatesk authored Nov 23, 2023
1 parent 418423b commit 44b8983
Show file tree
Hide file tree
Showing 8 changed files with 91 additions and 72 deletions.
6 changes: 3 additions & 3 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2705,7 +2705,7 @@ end

function stmt_taints_inbounds_consistency(sv::AbsIntState)
propagate_inbounds(sv) && return true
return (get_curr_ssaflag(sv) & IR_FLAG_INBOUNDS) != 0
return has_curr_ssaflag(sv, IR_FLAG_INBOUNDS)
end

function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e), vtypes::VarTable, sv::InferenceState)
Expand All @@ -2718,7 +2718,7 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e),
else
(; rt, exct, effects) = abstract_eval_statement_expr(interp, e, vtypes, sv)
if effects.noub === NOUB_IF_NOINBOUNDS
if !iszero(get_curr_ssaflag(sv) & IR_FLAG_INBOUNDS)
if has_curr_ssaflag(sv, IR_FLAG_INBOUNDS)
effects = Effects(effects; noub=ALWAYS_FALSE)
elseif !propagate_inbounds(sv)
# The callee read our inbounds flag, but unless we propagate inbounds,
Expand Down Expand Up @@ -3258,7 +3258,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
if exct !== Union{}
update_exc_bestguess!(exct, frame, ipo_lattice(interp))
end
if (get_curr_ssaflag(frame) & IR_FLAG_NOTHROW) != IR_FLAG_NOTHROW
if !has_curr_ssaflag(frame, IR_FLAG_NOTHROW)
propagate_to_error_handler!(currstate, frame, 𝕃ᵢ)
end
if rt === Bottom
Expand Down
7 changes: 5 additions & 2 deletions base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -888,6 +888,9 @@ end
get_curr_ssaflag(sv::InferenceState) = sv.src.ssaflags[sv.currpc]
get_curr_ssaflag(sv::IRInterpretationState) = sv.ir.stmts[sv.curridx][:flag]

has_curr_ssaflag(sv::InferenceState, flag::UInt32) = has_flag(sv.src.ssaflags[sv.currpc], flag)
has_curr_ssaflag(sv::IRInterpretationState, flag::UInt32) = has_flag(sv.ir.stmts[sv.curridx][:flag], flag)

function set_curr_ssaflag!(sv::InferenceState, flag::UInt32, mask::UInt32=typemax(UInt32))
curr_flag = sv.src.ssaflags[sv.currpc]
sv.src.ssaflags[sv.currpc] = (curr_flag & ~mask) | flag
Expand All @@ -898,10 +901,10 @@ function set_curr_ssaflag!(sv::IRInterpretationState, flag::UInt32, mask::UInt32
end

add_curr_ssaflag!(sv::InferenceState, flag::UInt32) = sv.src.ssaflags[sv.currpc] |= flag
add_curr_ssaflag!(sv::IRInterpretationState, flag::UInt32) = sv.ir.stmts[sv.curridx][:flag] |= flag
add_curr_ssaflag!(sv::IRInterpretationState, flag::UInt32) = add_flag!(sv.ir.stmts[sv.curridx], flag)

sub_curr_ssaflag!(sv::InferenceState, flag::UInt32) = sv.src.ssaflags[sv.currpc] &= ~flag
sub_curr_ssaflag!(sv::IRInterpretationState, flag::UInt32) = sv.ir.stmts[sv.curridx][:flag] &= ~flag
sub_curr_ssaflag!(sv::IRInterpretationState, flag::UInt32) = sub_flag!(sv.ir.stmts[sv.curridx], flag)

function merge_effects!(::AbstractInterpreter, caller::InferenceState, effects::Effects)
if effects.effect_free === EFFECT_FREE_GLOBALLY
Expand Down
44 changes: 23 additions & 21 deletions base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,36 +15,38 @@ const SLOT_USEDUNDEF = 32 # slot has uses that might raise UndefVarError

# NOTE make sure to sync the flag definitions below with julia.h and `jl_code_info_set_ir` in method.c

const IR_FLAG_NULL = UInt32(0)
const IR_FLAG_NULL = zero(UInt32)
# This statement is marked as @inbounds by user.
# Ff replaced by inlining, any contained boundschecks may be removed.
const IR_FLAG_INBOUNDS = UInt32(1) << 0
const IR_FLAG_INBOUNDS = one(UInt32) << 0
# This statement is marked as @inline by user
const IR_FLAG_INLINE = UInt32(1) << 1
const IR_FLAG_INLINE = one(UInt32) << 1
# This statement is marked as @noinline by user
const IR_FLAG_NOINLINE = UInt32(1) << 2
const IR_FLAG_THROW_BLOCK = UInt32(1) << 3
const IR_FLAG_NOINLINE = one(UInt32) << 2
const IR_FLAG_THROW_BLOCK = one(UInt32) << 3
# This statement was proven :effect_free
const IR_FLAG_EFFECT_FREE = UInt32(1) << 4
const IR_FLAG_EFFECT_FREE = one(UInt32) << 4
# This statement was proven not to throw
const IR_FLAG_NOTHROW = UInt32(1) << 5
const IR_FLAG_NOTHROW = one(UInt32) << 5
# This is :consistent
const IR_FLAG_CONSISTENT = UInt32(1) << 6
const IR_FLAG_CONSISTENT = one(UInt32) << 6
# An optimization pass has updated this statement in a way that may
# have exposed information that inference did not see. Re-running
# inference on this statement may be profitable.
const IR_FLAG_REFINED = UInt32(1) << 7
const IR_FLAG_REFINED = one(UInt32) << 7
# This is :noub == ALWAYS_TRUE
const IR_FLAG_NOUB = UInt32(1) << 8
const IR_FLAG_NOUB = one(UInt32) << 8

# TODO: Both of these should eventually go away once
# This is :effect_free == EFFECT_FREE_IF_INACCESSIBLEMEMONLY
const IR_FLAG_EFIIMO = UInt32(1) << 9
const IR_FLAG_EFIIMO = one(UInt32) << 9
# This is :inaccessiblememonly == INACCESSIBLEMEM_OR_ARGMEMONLY
const IR_FLAG_INACCESSIBLE_OR_ARGMEM = UInt32(1) << 10
const IR_FLAG_INACCESSIBLE_OR_ARGMEM = one(UInt32) << 10

const IR_FLAGS_EFFECTS = IR_FLAG_EFFECT_FREE | IR_FLAG_NOTHROW | IR_FLAG_CONSISTENT | IR_FLAG_NOUB

has_flag(curr::UInt32, flag::UInt32) = (curr & flag) == flag

const TOP_TUPLE = GlobalRef(Core, :tuple)

# This corresponds to the type of `CodeInfo`'s `inlining_cost` field
Expand Down Expand Up @@ -218,9 +220,9 @@ end

_topmod(sv::OptimizationState) = _topmod(sv.mod)

is_stmt_inline(stmt_flag::UInt32) = stmt_flag & IR_FLAG_INLINE 0
is_stmt_noinline(stmt_flag::UInt32) = stmt_flag & IR_FLAG_NOINLINE 0
is_stmt_throw_block(stmt_flag::UInt32) = stmt_flag & IR_FLAG_THROW_BLOCK 0
is_stmt_inline(stmt_flag::UInt32) = has_flag(stmt_flag, IR_FLAG_INLINE)
is_stmt_noinline(stmt_flag::UInt32) = has_flag(stmt_flag, IR_FLAG_NOINLINE)
is_stmt_throw_block(stmt_flag::UInt32) = has_flag(stmt_flag, IR_FLAG_THROW_BLOCK)

function new_expr_effect_flags(𝕃ₒ::AbstractLattice, args::Vector{Any}, src::Union{IRCode,IncrementalCompact}, pattern_match=nothing)
Targ = args[1]
Expand Down Expand Up @@ -468,7 +470,7 @@ end

function any_stmt_may_throw(ir::IRCode, bb::Int)
for stmt in ir.cfg.blocks[bb].stmts
if (ir[SSAValue(stmt)][:flag] & IR_FLAG_NOTHROW) != 0
if has_flag(ir[SSAValue(stmt)], IR_FLAG_NOTHROW)
return true
end
end
Expand Down Expand Up @@ -702,7 +704,7 @@ function scan_non_dataflow_flags!(inst::Instruction, sv::PostOptAnalysisState)
# ignore control flow node – they are not removable on their own and thus not
# have `IR_FLAG_EFFECT_FREE` but still do not taint `:effect_free`-ness of
# the whole method invocation
sv.all_effect_free &= !iszero(flag & IR_FLAG_EFFECT_FREE)
sv.all_effect_free &= has_flag(flag, IR_FLAG_EFFECT_FREE)
end
elseif sv.all_effect_free
if (isexpr(stmt, :invoke) || isexpr(stmt, :new) ||
Expand All @@ -714,8 +716,8 @@ function scan_non_dataflow_flags!(inst::Instruction, sv::PostOptAnalysisState)
sv.all_effect_free = false
end
end
sv.all_nothrow &= !iszero(flag & IR_FLAG_NOTHROW)
if iszero(flag & IR_FLAG_NOUB)
sv.all_nothrow &= has_flag(flag, IR_FLAG_NOTHROW)
if !has_flag(flag, IR_FLAG_NOUB)
# Special case: `:boundscheck` into `getfield` or memory operations is `:noub_if_noinbounds`
if is_conditional_noub(inst, sv)
sv.any_conditional_ub = true
Expand Down Expand Up @@ -960,11 +962,11 @@ function convert_to_ircode(ci::CodeInfo, sv::OptimizationState)
((block + 1) != destblock) && cfg_delete_edge!(sv.cfg, block, destblock)
expr = Expr(:call, Core.typeassert, expr.cond, Bool)
elseif i + 1 in sv.unreachable
@assert (ci.ssaflags[i] & IR_FLAG_NOTHROW) != 0
@assert has_flag(ci.ssaflags[i], IR_FLAG_NOTHROW)
cfg_delete_edge!(sv.cfg, block, block + 1)
expr = GotoNode(expr.dest)
elseif expr.dest in sv.unreachable
@assert (ci.ssaflags[i] & IR_FLAG_NOTHROW) != 0
@assert has_flag(ci.ssaflags[i], IR_FLAG_NOTHROW)
cfg_delete_edge!(sv.cfg, block, block_for_inst(sv.cfg, expr.dest))
expr = nothing
end
Expand Down
8 changes: 4 additions & 4 deletions base/compiler/ssair/EscapeAnalysis/EscapeAnalysis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ using ._TOP_MOD: # Base definitions
unwrap_unionall, !, !=, !==, &, *, +, -, :, <, <<, =>, >, |, , , , , , , ,
using Core.Compiler: # Core.Compiler specific definitions
Bottom, IRCode, IR_FLAG_NOTHROW, InferenceResult, SimpleInferenceLattice,
argextype, check_effect_free!, fieldcount_noerror, hasintersect, intrinsic_nothrow,
is_meta_expr_head, isbitstype, isexpr, println, setfield!_nothrow, singleton_type,
try_compute_field, try_compute_fieldidx, widenconst, , AbstractLattice
argextype, check_effect_free!, fieldcount_noerror, hasintersect, has_flag,
intrinsic_nothrow, is_meta_expr_head, isbitstype, isexpr, println, setfield!_nothrow,
singleton_type, try_compute_field, try_compute_fieldidx, widenconst, , AbstractLattice

include(x) = _TOP_MOD.include(@__MODULE__, x)
if _TOP_MOD === Core.Compiler
Expand Down Expand Up @@ -975,7 +975,7 @@ end
error("unexpected assignment found: inspect `Main.pc` and `Main.pc`")
end

is_nothrow(ir::IRCode, pc::Int) = ir[SSAValue(pc)][:flag] & IR_FLAG_NOTHROW 0
is_nothrow(ir::IRCode, pc::Int) = has_flag(ir[SSAValue(pc)], IR_FLAG_NOTHROW)

# NOTE if we don't maintain the alias set that is separated from the lattice state, we can do
# something like below: it essentially incorporates forward escape propagation in our default
Expand Down
19 changes: 10 additions & 9 deletions base/compiler/ssair/inlining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,7 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector

ssa_substitute = ir_prepare_inlining!(InsertHere(compact), compact, item.ir, item.mi, inlined_at, argexprs)

boundscheck = iszero(compact.result[idx][:flag] & IR_FLAG_INBOUNDS) ? boundscheck : :off
boundscheck = has_flag(compact.result[idx], IR_FLAG_INBOUNDS) ? :off : boundscheck

# If the iterator already moved on to the next basic block,
# temporarily re-open in again.
Expand Down Expand Up @@ -1032,7 +1032,7 @@ function handle_single_case!(todo::Vector{Pair{Int,Any}},
stmt.head = :invoke
pushfirst!(stmt.args, case.invoke)
end
ir[SSAValue(idx)][:flag] |= flags_for_effects(case.effects)
add_flag!(ir[SSAValue(idx)], flags_for_effects(case.effects))
elseif case === nothing
# Do, well, nothing
else
Expand Down Expand Up @@ -1257,21 +1257,22 @@ function check_effect_free!(ir::IRCode, idx::Int, @nospecialize(stmt), @nospecia
end
function check_effect_free!(ir::IRCode, idx::Int, @nospecialize(stmt), @nospecialize(rt), 𝕃ₒ::AbstractLattice)
(consistent, effect_free_and_nothrow, nothrow) = stmt_effect_flags(𝕃ₒ, stmt, rt, ir)
inst = ir.stmts[idx]
if consistent
ir.stmts[idx][:flag] |= IR_FLAG_CONSISTENT
add_flag!(inst, IR_FLAG_CONSISTENT)
end
if effect_free_and_nothrow
ir.stmts[idx][:flag] |= IR_FLAG_EFFECT_FREE | IR_FLAG_NOTHROW
add_flag!(inst, IR_FLAG_EFFECT_FREE | IR_FLAG_NOTHROW)
elseif nothrow
ir.stmts[idx][:flag] |= IR_FLAG_NOTHROW
add_flag!(inst, IR_FLAG_NOTHROW)
end
if !(isexpr(stmt, :call) || isexpr(stmt, :invoke))
# There is a bit of a subtle point here, which is that some non-call
# statements (e.g. PiNode) can be UB:, however, we consider it
# illegal to introduce such statements that actually cause UB (for any
# input). Ideally that'd be handled at insertion time (TODO), but for
# the time being just do that here.
ir.stmts[idx][:flag] |= IR_FLAG_NOUB
add_flag!(inst, IR_FLAG_NOUB)
end
return effect_free_and_nothrow
end
Expand Down Expand Up @@ -1583,7 +1584,7 @@ function handle_cases!(todo::Vector{Pair{Int,Any}}, ir::IRCode, idx::Int, stmt::
end
push!(todo, idx=>UnionSplit(fully_covered, atype, cases))
else
ir[SSAValue(idx)][:flag] |= flags_for_effects(joint_effects)
add_flag!(ir[SSAValue(idx)], flags_for_effects(joint_effects))
end
return nothing
end
Expand Down Expand Up @@ -1682,7 +1683,7 @@ function inline_const_if_inlineable!(inst::Instruction)
inst[:stmt] = quoted(rt.val)
return true
end
inst[:flag] |= IR_FLAG_EFFECT_FREE | IR_FLAG_NOTHROW
add_flag!(inst, IR_FLAG_EFFECT_FREE | IR_FLAG_NOTHROW)
return false
end

Expand Down Expand Up @@ -1879,7 +1880,7 @@ function ssa_substitute_op!(insert_node!::Inserter, subst_inst::Instruction, @no
return quoted(val)
else
flag = subst_inst[:flag]
maybe_undef = (flag & IR_FLAG_NOTHROW) == 0 && isa(val, TypeVar)
maybe_undef = !has_flag(flag, IR_FLAG_NOTHROW) && isa(val, TypeVar)
(ret, tcheck_not) = insert_spval!(insert_node!, ssa_substitute.spvals_ssa::SSAValue, spidx, maybe_undef)
if maybe_undef
insert_node!(
Expand Down
30 changes: 20 additions & 10 deletions base/compiler/ssair/ir.jl
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,10 @@ function setindex!(node::Instruction, newval::Instruction)
return node
end

has_flag(inst::Instruction, flag::UInt32) = has_flag(inst[:flag], flag)
add_flag!(inst::Instruction, flag::UInt32) = inst[:flag] |= flag
sub_flag!(inst::Instruction, flag::UInt32) = inst[:flag] &= ~flag

struct NewNodeInfo
# Insertion position (interpretation depends on which array this is in)
pos::Int
Expand Down Expand Up @@ -334,18 +338,24 @@ function NewInstruction(inst::Instruction;
return NewInstruction(stmt, type, info, line, flag)
end
@specialize
effect_free_and_nothrow(newinst::NewInstruction) = NewInstruction(newinst; flag=add_flag(newinst, IR_FLAG_EFFECT_FREE | IR_FLAG_NOTHROW))
with_flags(newinst::NewInstruction, flags::UInt32) = NewInstruction(newinst; flag=add_flag(newinst, flags))
without_flags(newinst::NewInstruction, flags::UInt32) = NewInstruction(newinst; flag=sub_flag(newinst, flags))
effect_free_and_nothrow(newinst::NewInstruction) = add_flag(newinst, IR_FLAG_EFFECT_FREE | IR_FLAG_NOTHROW)
function add_flag(newinst::NewInstruction, newflag::UInt32)
flag = newinst.flag
flag === nothing && return newflag
return flag | newflag
if flag === nothing
flag = newflag
else
flag |= newflag
end
return NewInstruction(newinst; flag)
end
function sub_flag(newinst::NewInstruction, newflag::UInt32)
flag = newinst.flag
flag === nothing && return IR_FLAG_NULL
return flag & ~newflag
if flag === nothing
flag = IR_FLAG_NULL
else
flag &= ~newflag
end
return NewInstruction(newinst; flag)
end

struct IRCode
Expand Down Expand Up @@ -1325,7 +1335,7 @@ function process_node!(compact::IncrementalCompact, result_idx::Int, inst::Instr
elseif isa(stmt, GlobalRef)
total_flags = IR_FLAG_CONSISTENT | IR_FLAG_EFFECT_FREE
flag = result[result_idx][:flag]
if (flag & total_flags) == total_flags
if has_flag(flag, total_flags)
ssa_rename[idx] = stmt
else
ssa_rename[idx] = SSAValue(result_idx)
Expand Down Expand Up @@ -1532,7 +1542,7 @@ function process_node!(compact::IncrementalCompact, result_idx::Int, inst::Instr
else
# Constant assign, replace uses of this ssa value with its result
end
if (inst[:flag] & IR_FLAG_REFINED) != 0 && !isa(stmt, Refined)
if has_flag(inst, IR_FLAG_REFINED) && !isa(stmt, Refined)
# If we're compacting away an instruction that was marked as refined,
# leave a marker in the ssa_rename, so we can taint any users.
stmt = Refined(stmt)
Expand Down Expand Up @@ -1767,7 +1777,7 @@ function maybe_erase_unused!(callback::Function, compact::IncrementalCompact, id
stmt = inst[:stmt]
stmt === nothing && return false
inst[:type] === Bottom && return false
effect_free = (inst[:flag] & (IR_FLAG_EFFECT_FREE | IR_FLAG_NOTHROW)) == IR_FLAG_EFFECT_FREE | IR_FLAG_NOTHROW
effect_free = has_flag(inst, (IR_FLAG_EFFECT_FREE | IR_FLAG_NOTHROW))
effect_free || return false
foreachssa(stmt) do val::SSAValue
if compact.used_ssas[val.id] == 1
Expand Down
22 changes: 11 additions & 11 deletions base/compiler/ssair/irinterp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ function reprocess_instruction!(interp::AbstractInterpreter, inst::Instruction,
if bb === nothing
bb = block_for_inst(ir, idx)
end
inst[:flag] |= IR_FLAG_NOTHROW
add_flag!(inst, IR_FLAG_NOTHROW)
if condval
inst[:stmt] = nothing
inst[:type] = Any
Expand All @@ -156,14 +156,14 @@ function reprocess_instruction!(interp::AbstractInterpreter, inst::Instruction,
head = stmt.head
if head === :call || head === :foreigncall || head === :new || head === :splatnew || head === :static_parameter || head === :isdefined || head === :boundscheck
(; rt, effects) = abstract_eval_statement_expr(interp, stmt, nothing, irsv)
inst[:flag] |= flags_for_effects(effects)
add_flag!(inst, flags_for_effects(effects))
elseif head === :invoke
rt, (nothrow, noub) = concrete_eval_invoke(interp, stmt, stmt.args[1]::MethodInstance, irsv)
if nothrow
inst[:flag] |= IR_FLAG_NOTHROW
add_flag!(inst, IR_FLAG_NOTHROW)
end
if noub
inst[:flag] |= IR_FLAG_NOUB
add_flag!(inst, IR_FLAG_NOUB)
end
elseif head === :throw_undef_if_not
condval = maybe_extract_const_bool(argextype(stmt.args[2], ir))
Expand Down Expand Up @@ -197,7 +197,7 @@ function reprocess_instruction!(interp::AbstractInterpreter, inst::Instruction,
if rt !== nothing
if isa(rt, Const)
inst[:type] = rt
if is_inlineable_constant(rt.val) && (inst[:flag] & (IR_FLAG_EFFECT_FREE | IR_FLAG_NOTHROW)) == IR_FLAG_EFFECT_FREE | IR_FLAG_NOTHROW
if is_inlineable_constant(rt.val) && has_flag(inst, (IR_FLAG_EFFECT_FREE | IR_FLAG_NOTHROW))
inst[:stmt] = quoted(rt.val)
end
return true
Expand Down Expand Up @@ -299,9 +299,9 @@ function _ir_abstract_constant_propagation(interp::AbstractInterpreter, irsv::IR
typ = inst[:type]
flag = inst[:flag]
any_refined = false
if (flag & IR_FLAG_REFINED) != 0
if has_flag(flag, IR_FLAG_REFINED)
any_refined = true
inst[:flag] &= ~IR_FLAG_REFINED
sub_flag!(inst, IR_FLAG_REFINED)
end
for ur in userefs(stmt)
val = ur[]
Expand Down Expand Up @@ -351,8 +351,8 @@ function _ir_abstract_constant_propagation(interp::AbstractInterpreter, irsv::IR
irsv.curridx = idx
stmt = inst[:stmt]
flag = inst[:flag]
if (flag & IR_FLAG_REFINED) != 0
inst[:flag] &= ~IR_FLAG_REFINED
if has_flag(flag, IR_FLAG_REFINED)
sub_flag!(inst, IR_FLAG_REFINED)
push!(stmt_ip, idx)
end
check_ret!(stmt, idx)
Expand Down Expand Up @@ -408,8 +408,8 @@ function _ir_abstract_constant_propagation(interp::AbstractInterpreter, irsv::IR
nothrow = noub = true
for idx = 1:length(ir.stmts)
flag = ir[SSAValue(idx)][:flag]
nothrow &= !iszero(flag & IR_FLAG_NOTHROW)
noub &= !iszero(flag & IR_FLAG_NOUB)
nothrow &= has_flag(flag, IR_FLAG_NOTHROW)
noub &= has_flag(flag, IR_FLAG_NOUB)
(nothrow | noub) || break
end

Expand Down
Loading

0 comments on commit 44b8983

Please sign in to comment.