Skip to content

Commit e21e77d

Browse files
committed
inference: Model type propagation through exceptions
Currently the type of a caught exception is always modeled as `Any`. This isn't a huge problem, because control flow in Julia is generally assumed to be somewhat slow, so the extra type imprecision of not knowing the return type does not matter all that much. However, there are a few situations where it matters. For example: ``` maybe_getindex(A, i) = try; A[i]; catch e; isa(e, BoundsError) && return nothing; rethrow(); end ``` At present, we cannot infer :nothrow for this method, even if that is the only error type that `A[i]` can throw. This is particularly noticable, since we can now optimize away `:nothrow` exception frames entirely (#51674). Note that this PR still does not make the above example particularly efficient (at least interprocedurally), though specialized codegen could be added on top of this to make that happen. It does however improve the inference result. A second major motivation of this change is that reasoning about exception types is likely to be a major aspect of any future work on interface checking (since interfaces imply the absence of MethodErrors), so this PR lays the groundwork for appropriate modeling of these error paths. Note that this PR adds all the required plumbing, but does not yet have a particularly precise model of error types for our builtins, bailing to `Any` for any builtin not known to be `:nothrow`. This can be improved in follow up PRs as required.
1 parent 0cf2bf1 commit e21e77d

23 files changed

+396
-212
lines changed

base/boot.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -478,13 +478,13 @@ eval(Core, quote
478478
end)
479479

480480
function CodeInstance(
481-
mi::MethodInstance, @nospecialize(rettype), @nospecialize(inferred_const),
481+
mi::MethodInstance, @nospecialize(rettype), @nospecialize(exctype), @nospecialize(inferred_const),
482482
@nospecialize(inferred), const_flags::Int32, min_world::UInt, max_world::UInt,
483483
ipo_effects::UInt32, effects::UInt32, @nospecialize(argescapes#=::Union{Nothing,Vector{ArgEscapeInfo}}=#),
484484
relocatability::UInt8)
485485
return ccall(:jl_new_codeinst, Ref{CodeInstance},
486-
(Any, Any, Any, Any, Int32, UInt, UInt, UInt32, UInt32, Any, UInt8),
487-
mi, rettype, inferred_const, inferred, const_flags, min_world, max_world,
486+
(Any, Any, Any, Any, Any, Int32, UInt, UInt, UInt32, UInt32, Any, UInt8),
487+
mi, rettype, exctype, inferred_const, inferred, const_flags, min_world, max_world,
488488
ipo_effects, effects, argescapes,
489489
relocatability)
490490
end

base/compiler/abstractinterpretation.jl

Lines changed: 181 additions & 107 deletions
Large diffs are not rendered by default.

base/compiler/effects.jl

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,60 @@ function Effects(effects::Effects = _EFFECTS_UNKNOWN;
171171
nonoverlayed)
172172
end
173173

174+
function better_effects(new::Effects, old::Effects)
175+
any_improved = false
176+
if new.consistent == ALWAYS_TRUE
177+
any_improved |= old.consistent != ALWAYS_TRUE
178+
elseif new.consistent != old.consistent
179+
return false
180+
end
181+
if new.effect_free == ALWAYS_TRUE
182+
any_improved |= old.consistent != ALWAYS_TRUE
183+
elseif new.effect_free == EFFECT_FREE_IF_INACCESSIBLEMEMONLY
184+
old.effect_free == ALWAYS_TRUE && return false
185+
any_improved |= old.effect_free != EFFECT_FREE_IF_INACCESSIBLEMEMONLY
186+
elseif new.effect_free != old.effect_free
187+
return false
188+
end
189+
if new.nothrow
190+
any_improved |= !old.nothrow
191+
elseif new.nothrow != old.nothrow
192+
return false
193+
end
194+
if new.terminates
195+
any_improved |= !old.terminates
196+
elseif new.terminates != old.terminates
197+
return false
198+
end
199+
if new.notaskstate
200+
any_improved |= !old.notaskstate
201+
elseif new.notaskstate != old.notaskstate
202+
return false
203+
end
204+
if new.inaccessiblememonly == ALWAYS_TRUE
205+
any_improved |= old.inaccessiblememonly != ALWAYS_TRUE
206+
elseif new.inaccessiblememonly == INACCESSIBLEMEM_OR_ARGMEMONLY
207+
old.inaccessiblememonly == ALWAYS_TRUE && return false
208+
any_improved |= old.inaccessiblememonly != INACCESSIBLEMEM_OR_ARGMEMONLY
209+
elseif new.inaccessiblememonly != old.inaccessiblememonly
210+
return false
211+
end
212+
if new.noub == ALWAYS_TRUE
213+
any_improved |= old.noub != ALWAYS_TRUE
214+
elseif new.noub == NOUB_IF_NOINBOUNDS
215+
old.noub == ALWAYS_TRUE && return false
216+
any_improved |= old.noub != NOUB_IF_NOINBOUNDS
217+
elseif new.noub != old.noub
218+
return false
219+
end
220+
if new.nonoverlayed
221+
any_improved |= !old.nonoverlayed
222+
elseif new.nonoverlayed != old.nonoverlayed
223+
return false
224+
end
225+
return any_improved
226+
end
227+
174228
function merge_effects(old::Effects, new::Effects)
175229
return Effects(
176230
merge_effectbits(old.consistent, new.consistent),

base/compiler/inferencestate.jl

Lines changed: 36 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,11 @@ const CACHE_MODE_GLOBAL = 0x01 << 0 # cached globally, optimization allowed
203203
const CACHE_MODE_LOCAL = 0x01 << 1 # cached locally, optimization allowed
204204
const CACHE_MODE_VOLATILE = 0x01 << 2 # not cached, optimization allowed
205205

206+
mutable struct TryCatchFrame
207+
exct
208+
const enter_idx
209+
end
210+
206211
mutable struct InferenceState
207212
#= information about this method instance =#
208213
linfo::MethodInstance
@@ -218,7 +223,8 @@ mutable struct InferenceState
218223
currbb::Int
219224
currpc::Int
220225
ip::BitSet#=TODO BoundedMinPrioritySet=# # current active instruction pointers
221-
handler_at::Vector{Int} # current exception handler info
226+
handlers::Vector{TryCatchFrame}
227+
handler_at::Vector{Tuple{Int, Int}} # tuple of current (handler, exception stack) value at the pc
222228
ssavalue_uses::Vector{BitSet} # ssavalue sparsity and restart info
223229
# TODO: Could keep this sparsely by doing structural liveness analysis ahead of time.
224230
bb_vartables::Vector{Union{Nothing,VarTable}} # nothing if not analyzed yet
@@ -239,6 +245,7 @@ mutable struct InferenceState
239245
unreachable::BitSet # statements that were found to be statically unreachable
240246
valid_worlds::WorldRange
241247
bestguess #::Type
248+
exc_bestguess
242249
ipo_effects::Effects
243250

244251
#= flags =#
@@ -266,7 +273,7 @@ mutable struct InferenceState
266273

267274
currbb = currpc = 1
268275
ip = BitSet(1) # TODO BitSetBoundedMinPrioritySet(1)
269-
handler_at = compute_trycatch(code, BitSet())
276+
handler_at, handlers = compute_trycatch(code, BitSet())
270277
nssavalues = src.ssavaluetypes::Int
271278
ssavalue_uses = find_ssavalue_uses(code, nssavalues)
272279
nstmts = length(code)
@@ -296,6 +303,7 @@ mutable struct InferenceState
296303

297304
valid_worlds = WorldRange(src.min_world, src.max_world == typemax(UInt) ? get_world_counter() : src.max_world)
298305
bestguess = Bottom
306+
exc_bestguess = Bottom
299307
ipo_effects = EFFECTS_TOTAL
300308

301309
insert_coverage = should_insert_coverage(mod, src)
@@ -315,9 +323,9 @@ mutable struct InferenceState
315323

316324
return new(
317325
linfo, world, mod, sptypes, slottypes, src, cfg, method_info,
318-
currbb, currpc, ip, handler_at, ssavalue_uses, bb_vartables, ssavaluetypes, stmt_edges, stmt_info,
326+
currbb, currpc, ip, handlers, handler_at, ssavalue_uses, bb_vartables, ssavaluetypes, stmt_edges, stmt_info,
319327
pclimitations, limitations, cycle_backedges, callers_in_cycle, dont_work_on_me, parent,
320-
result, unreachable, valid_worlds, bestguess, ipo_effects,
328+
result, unreachable, valid_worlds, bestguess, exc_bestguess, ipo_effects,
321329
restrict_abstract_call_sites, cache_mode, insert_coverage,
322330
interp)
323331
end
@@ -347,16 +355,19 @@ function compute_trycatch(code::Vector{Any}, ip::BitSet)
347355
empty!(ip)
348356
ip.offset = 0 # for _bits_findnext
349357
push!(ip, n + 1)
350-
handler_at = fill(0, n)
358+
handler_at = fill((0, 0), n)
359+
handlers = TryCatchFrame[]
351360

352361
# start from all :enter statements and record the location of the try
353362
for pc = 1:n
354363
stmt = code[pc]
355364
if isexpr(stmt, :enter)
356365
l = stmt.args[1]::Int
357-
handler_at[pc + 1] = pc
366+
push!(handlers, TryCatchFrame(Bottom, pc))
367+
handler_id = length(handlers)
368+
handler_at[pc + 1] = (handler_id, 0)
358369
push!(ip, pc + 1)
359-
handler_at[l] = pc
370+
handler_at[l] = (handler_id, handler_id)
360371
push!(ip, l)
361372
end
362373
end
@@ -369,25 +380,26 @@ function compute_trycatch(code::Vector{Any}, ip::BitSet)
369380
while true # inner loop optimizes the common case where it can run straight from pc to pc + 1
370381
pc´ = pc + 1 # next program-counter (after executing instruction)
371382
delete!(ip, pc)
372-
cur_hand = handler_at[pc]
373-
@assert cur_hand != 0 "unbalanced try/catch"
383+
cur_stacks = handler_at[pc]
384+
@assert cur_stacks != (0, 0) "unbalanced try/catch"
374385
stmt = code[pc]
375386
if isa(stmt, GotoNode)
376387
pc´ = stmt.label
377388
elseif isa(stmt, GotoIfNot)
378389
l = stmt.dest::Int
379-
if handler_at[l] != cur_hand
380-
@assert handler_at[l] == 0 "unbalanced try/catch"
381-
handler_at[l] = cur_hand
390+
if handler_at[l] != cur_stacks
391+
@assert handler_at[l][1] == 0 || handler_at[l][1] == cur_stacks[1] "unbalanced try/catch"
392+
handler_at[l] = cur_stacks
382393
push!(ip, l)
383394
end
384395
elseif isa(stmt, ReturnNode)
385-
@assert !isdefined(stmt, :val) "unbalanced try/catch"
396+
@assert !isdefined(stmt, :val) || cur_stacks[1] == 0 "unbalanced try/catch"
386397
break
387398
elseif isa(stmt, Expr)
388399
head = stmt.head
389400
if head === :enter
390-
cur_hand = pc
401+
# Already set above
402+
cur_stacks = (handler_at[pc´][1], cur_stacks[2])
391403
elseif head === :leave
392404
l = 0
393405
for j = 1:length(stmt.args)
@@ -403,19 +415,21 @@ function compute_trycatch(code::Vector{Any}, ip::BitSet)
403415
end
404416
l += 1
405417
end
418+
cur_hand = cur_stacks[1]
406419
for i = 1:l
407-
cur_hand = handler_at[cur_hand]
420+
cur_hand = handler_at[handlers[cur_hand].enter_idx][1]
408421
end
409-
cur_hand == 0 && break
422+
cur_stacks = (cur_hand, cur_stacks[2])
423+
cur_stacks == (0, 0) && break
424+
elseif head === :pop_exception
425+
cur_stacks = (cur_stacks[1], handler_at[(stmt.args[1]::SSAValue).id][2])
426+
cur_stacks == (0, 0) && break
410427
end
411428
end
412429

413430
pc´ > n && break # can't proceed with the fast-path fall-through
414-
if handler_at[pc´] != cur_hand
415-
if handler_at[pc´] != 0
416-
@assert false "unbalanced try/catch"
417-
end
418-
handler_at[pc´] = cur_hand
431+
if handler_at[pc´] != cur_stacks
432+
handler_at[pc´] = cur_stacks
419433
elseif !in(pc´, ip)
420434
break # already visited
421435
end
@@ -424,7 +438,7 @@ function compute_trycatch(code::Vector{Any}, ip::BitSet)
424438
end
425439

426440
@assert first(ip) == n + 1
427-
return handler_at
441+
return handler_at, handlers
428442
end
429443

430444
# check if coverage mode is enabled

base/compiler/ssair/irinterp.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,9 @@ end
4646

4747
function abstract_call(interp::AbstractInterpreter, arginfo::ArgInfo, irsv::IRInterpretationState)
4848
si = StmtInfo(true) # TODO better job here?
49-
(; rt, effects, info) = abstract_call(interp, arginfo, si, irsv)
49+
(; rt, exct, effects, info) = abstract_call(interp, arginfo, si, irsv)
5050
irsv.ir.stmts[irsv.curridx][:info] = info
51-
return RTEffects(rt, effects)
51+
return RTEffects(rt, exct, effects)
5252
end
5353

5454
function update_phi!(irsv::IRInterpretationState, from::Int, to::Int)

base/compiler/ssair/slot2ssa.jl

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -584,7 +584,7 @@ function construct_ssa!(ci::CodeInfo, ir::IRCode, sv::OptimizationState,
584584
end
585585

586586
# Record the correct exception handler for all critical sections
587-
handler_at = compute_trycatch(code, BitSet())
587+
handler_at, handlers = compute_trycatch(code, BitSet())
588588

589589
phi_slots = Vector{Int}[Int[] for _ = 1:length(ir.cfg.blocks)]
590590
live_slots = Vector{Int}[Int[] for _ = 1:length(ir.cfg.blocks)]
@@ -627,10 +627,12 @@ function construct_ssa!(ci::CodeInfo, ir::IRCode, sv::OptimizationState,
627627
# The slot is live-in into this block. We need to
628628
# Create a PhiC node in the catch entry block and
629629
# an upsilon node in the corresponding enter block
630+
varstate = sv.bb_vartables[li]
631+
if varstate === nothing
632+
continue
633+
end
630634
node = PhiCNode(Any[])
631635
insertpoint = first_insert_for_bb(code, cfg, li)
632-
varstate = sv.bb_vartables[li]
633-
@assert varstate !== nothing
634636
vt = varstate[idx]
635637
phic_ssa = NewSSAValue(
636638
insert_node!(ir, insertpoint,
@@ -690,6 +692,9 @@ function construct_ssa!(ci::CodeInfo, ir::IRCode, sv::OptimizationState,
690692
new_nodes = ir.new_nodes
691693
@timeit "SSA Rename" while !isempty(worklist)
692694
(item::Int, pred, incoming_vals) = pop!(worklist)
695+
if sv.bb_vartables[item] === nothing
696+
continue
697+
end
693698
# Rename existing phi nodes first, because their uses occur on the edge
694699
# TODO: This isn't necessary if inlining stops replacing arguments by slots.
695700
for idx in cfg.blocks[item].stmts
@@ -810,8 +815,8 @@ function construct_ssa!(ci::CodeInfo, ir::IRCode, sv::OptimizationState,
810815
incoming_vals[id] = Pair{Any, Any}(thisval, thisdef)
811816
has_pinode[id] = false
812817
enter_idx = idx
813-
while handler_at[enter_idx] != 0
814-
enter_idx = handler_at[enter_idx]
818+
while handler_at[enter_idx][1] != 0
819+
(; enter_idx) = handlers[handler_at[enter_idx][1]]
815820
leave_block = block_for_inst(cfg, code[enter_idx].args[1]::Int)
816821
cidx = findfirst((; slot)::NewPhiCNode2->slot_id(slot)==id, new_phic_nodes[leave_block])
817822
if cidx !== nothing

base/compiler/stmtinfo.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ and any additional information (`call.info`) for a given generic call.
1010
"""
1111
struct CallMeta
1212
rt::Any
13+
exct::Any
1314
effects::Effects
1415
info::CallInfo
1516
end

0 commit comments

Comments
 (0)