Skip to content

Commit a0f5598

Browse files
committed
inference: Model type propagation through exceptions
Currently the type of a caught exception is always modeled as `Any`. This isn't a huge problem, because control flow in Julia is generally assumed to be somewhat slow, so the extra type imprecision of not knowing the return type does not matter all that much. However, there are a few situations where it matters. For example: ``` maybe_getindex(A, i) = try; A[i]; catch e; isa(e, BoundsError) && return nothing; rethrow(); end ``` At present, we cannot infer :nothrow for this method, even if that is the only error type that `A[i]` can throw. This is particularly noticable, since we can now optimize away `:nothrow` exception frames entirely (#51674). Note that this PR still does not make the above example particularly efficient (at least interprocedurally), though specialized codegen could be added on top of this to make that happen. It does however improve the inference result. A second major motivation of this change is that reasoning about exception types is likely to be a major aspect of any future work on interface checking (since interfaces imply the absence of MethodErrors), so this PR lays the groundwork for appropriate modeling of these error paths. Note that this PR adds all the required plumbing, but does not yet have a particularly precise model of error types for our builtins, bailing to `Any` for any builtin not known to be `:nothrow`. This can be improved in follow up PRs as required.
1 parent ad86772 commit a0f5598

File tree

21 files changed

+370
-199
lines changed

21 files changed

+370
-199
lines changed

base/boot.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -471,13 +471,13 @@ eval(Core, quote
471471
end)
472472

473473
function CodeInstance(
474-
mi::MethodInstance, @nospecialize(rettype), @nospecialize(inferred_const),
474+
mi::MethodInstance, @nospecialize(rettype), @nospecialize(exctype), @nospecialize(inferred_const),
475475
@nospecialize(inferred), const_flags::Int32, min_world::UInt, max_world::UInt,
476476
ipo_effects::UInt32, effects::UInt32, @nospecialize(argescapes#=::Union{Nothing,Vector{ArgEscapeInfo}}=#),
477477
relocatability::UInt8)
478478
return ccall(:jl_new_codeinst, Ref{CodeInstance},
479-
(Any, Any, Any, Any, Int32, UInt, UInt, UInt32, UInt32, Any, UInt8),
480-
mi, rettype, inferred_const, inferred, const_flags, min_world, max_world,
479+
(Any, Any, Any, Any, Any, Int32, UInt, UInt, UInt32, UInt32, Any, UInt8),
480+
mi, rettype, exctype, inferred_const, inferred, const_flags, min_world, max_world,
481481
ipo_effects, effects, argescapes,
482482
relocatability)
483483
end

base/compiler/abstractinterpretation.jl

Lines changed: 173 additions & 97 deletions
Large diffs are not rendered by default.

base/compiler/effects.jl

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,60 @@ function Effects(effects::Effects = _EFFECTS_UNKNOWN;
171171
nonoverlayed)
172172
end
173173

174+
function better_effects(new::Effects, old::Effects)
175+
any_improved = false
176+
if new.consistent == ALWAYS_TRUE
177+
any_improved |= old.consistent != ALWAYS_TRUE
178+
elseif new.consistent != old.consistent
179+
return false
180+
end
181+
if new.effect_free == ALWAYS_TRUE
182+
any_improved |= old.consistent != ALWAYS_TRUE
183+
elseif new.effect_free == EFFECT_FREE_IF_INACCESSIBLEMEMONLY
184+
old.effect_free == ALWAYS_TRUE && return false
185+
any_improved |= old.effect_free != EFFECT_FREE_IF_INACCESSIBLEMEMONLY
186+
elseif new.effect_free != old.effect_free
187+
return false
188+
end
189+
if new.nothrow
190+
any_improved |= !old.nothrow
191+
elseif new.nothrow != old.nothrow
192+
return false
193+
end
194+
if new.terminates
195+
any_improved |= !old.terminates
196+
elseif new.terminates != old.terminates
197+
return false
198+
end
199+
if new.notaskstate
200+
any_improved |= !old.notaskstate
201+
elseif new.notaskstate != old.notaskstate
202+
return false
203+
end
204+
if new.inaccessiblememonly == ALWAYS_TRUE
205+
any_improved |= old.inaccessiblememonly != ALWAYS_TRUE
206+
elseif new.inaccessiblememonly == INACCESSIBLEMEM_OR_ARGMEMONLY
207+
old.inaccessiblememonly == ALWAYS_TRUE && return false
208+
any_improved |= old.inaccessiblememonly != INACCESSIBLEMEM_OR_ARGMEMONLY
209+
elseif new.inaccessiblememonly != old.inaccessiblememonly
210+
return false
211+
end
212+
if new.noub == ALWAYS_TRUE
213+
any_improved |= old.noub != ALWAYS_TRUE
214+
elseif new.noub == NOUB_IF_NOINBOUNDS
215+
old.noub == ALWAYS_TRUE && return false
216+
any_improved |= old.noub != NOUB_IF_NOINBOUNDS
217+
elseif new.noub != old.noub
218+
return false
219+
end
220+
if new.nonoverlayed
221+
any_improved |= !old.nonoverlayed
222+
elseif new.nonoverlayed != old.nonoverlayed
223+
return false
224+
end
225+
return any_improved
226+
end
227+
174228
function merge_effects(old::Effects, new::Effects)
175229
return Effects(
176230
merge_effectbits(old.consistent, new.consistent),

base/compiler/inferencestate.jl

Lines changed: 35 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,11 @@ to enable flow-sensitive analysis.
198198
"""
199199
const VarTable = Vector{VarState}
200200

201+
mutable struct TryCatchFrame
202+
exct
203+
const enter_idx
204+
end
205+
201206
mutable struct InferenceState
202207
#= information about this method instance =#
203208
linfo::MethodInstance
@@ -213,7 +218,8 @@ mutable struct InferenceState
213218
currbb::Int
214219
currpc::Int
215220
ip::BitSet#=TODO BoundedMinPrioritySet=# # current active instruction pointers
216-
handler_at::Vector{Int} # current exception handler info
221+
handlers::Vector{TryCatchFrame}
222+
handler_at::Vector{Tuple{Int, Int}} # tuple of current (handler, excecption stack) value at the pc
217223
ssavalue_uses::Vector{BitSet} # ssavalue sparsity and restart info
218224
# TODO: Could keep this sparsely by doing structural liveness analysis ahead of time.
219225
bb_vartables::Vector{Union{Nothing,VarTable}} # nothing if not analyzed yet
@@ -234,6 +240,7 @@ mutable struct InferenceState
234240
unreachable::BitSet # statements that were found to be statically unreachable
235241
valid_worlds::WorldRange
236242
bestguess #::Type
243+
exc_bestguess
237244
ipo_effects::Effects
238245

239246
#= flags =#
@@ -261,7 +268,7 @@ mutable struct InferenceState
261268

262269
currbb = currpc = 1
263270
ip = BitSet(1) # TODO BitSetBoundedMinPrioritySet(1)
264-
handler_at = compute_trycatch(code, BitSet())
271+
handler_at, handlers = compute_trycatch(code, BitSet())
265272
nssavalues = src.ssavaluetypes::Int
266273
ssavalue_uses = find_ssavalue_uses(code, nssavalues)
267274
nstmts = length(code)
@@ -291,6 +298,7 @@ mutable struct InferenceState
291298

292299
valid_worlds = WorldRange(src.min_world, src.max_world == typemax(UInt) ? get_world_counter() : src.max_world)
293300
bestguess = Bottom
301+
exc_bestguess = Bottom
294302
ipo_effects = EFFECTS_TOTAL
295303

296304
insert_coverage = should_insert_coverage(mod, src)
@@ -311,9 +319,9 @@ mutable struct InferenceState
311319

312320
return new(
313321
linfo, world, mod, sptypes, slottypes, src, cfg, method_info,
314-
currbb, currpc, ip, handler_at, ssavalue_uses, bb_vartables, ssavaluetypes, stmt_edges, stmt_info,
322+
currbb, currpc, ip, handlers, handler_at, ssavalue_uses, bb_vartables, ssavaluetypes, stmt_edges, stmt_info,
315323
pclimitations, limitations, cycle_backedges, callers_in_cycle, dont_work_on_me, parent,
316-
result, unreachable, valid_worlds, bestguess, ipo_effects,
324+
result, unreachable, valid_worlds, bestguess, exc_bestguess, ipo_effects,
317325
restrict_abstract_call_sites, cache_mode, insert_coverage,
318326
interp)
319327
end
@@ -343,16 +351,19 @@ function compute_trycatch(code::Vector{Any}, ip::BitSet)
343351
empty!(ip)
344352
ip.offset = 0 # for _bits_findnext
345353
push!(ip, n + 1)
346-
handler_at = fill(0, n)
354+
handler_at = fill((0, 0), n)
355+
handlers = TryCatchFrame[]
347356

348357
# start from all :enter statements and record the location of the try
349358
for pc = 1:n
350359
stmt = code[pc]
351360
if isexpr(stmt, :enter)
352361
l = stmt.args[1]::Int
353-
handler_at[pc + 1] = pc
362+
push!(handlers, TryCatchFrame(Bottom, pc))
363+
handler_id = length(handlers)
364+
handler_at[pc + 1] = (handler_id, 0)
354365
push!(ip, pc + 1)
355-
handler_at[l] = pc
366+
handler_at[l] = (handler_id, handler_id)
356367
push!(ip, l)
357368
end
358369
end
@@ -365,25 +376,26 @@ function compute_trycatch(code::Vector{Any}, ip::BitSet)
365376
while true # inner loop optimizes the common case where it can run straight from pc to pc + 1
366377
pc´ = pc + 1 # next program-counter (after executing instruction)
367378
delete!(ip, pc)
368-
cur_hand = handler_at[pc]
369-
@assert cur_hand != 0 "unbalanced try/catch"
379+
cur_stacks = handler_at[pc]
380+
@assert cur_stacks != (0, 0) "unbalanced try/catch"
370381
stmt = code[pc]
371382
if isa(stmt, GotoNode)
372383
pc´ = stmt.label
373384
elseif isa(stmt, GotoIfNot)
374385
l = stmt.dest::Int
375-
if handler_at[l] != cur_hand
376-
@assert handler_at[l] == 0 "unbalanced try/catch"
377-
handler_at[l] = cur_hand
386+
if handler_at[l] != cur_stacks
387+
@assert handler_at[l][1] == 0 || handler_at[l][1] == cur_stacks[1] "unbalanced try/catch"
388+
handler_at[l] = cur_stacks
378389
push!(ip, l)
379390
end
380391
elseif isa(stmt, ReturnNode)
381-
@assert !isdefined(stmt, :val) "unbalanced try/catch"
392+
@assert !isdefined(stmt, :val) || cur_stacks[1] == 0 "unbalanced try/catch"
382393
break
383394
elseif isa(stmt, Expr)
384395
head = stmt.head
385396
if head === :enter
386-
cur_hand = pc
397+
# Already set aboves
398+
cur_stacks = (handler_at[pc´][1], cur_stacks[2])
387399
elseif head === :leave
388400
l = 0
389401
for j = 1:length(stmt.args)
@@ -399,19 +411,20 @@ function compute_trycatch(code::Vector{Any}, ip::BitSet)
399411
end
400412
l += 1
401413
end
414+
cur_hand = cur_stacks[1]
402415
for i = 1:l
403-
cur_hand = handler_at[cur_hand]
416+
cur_hand = handler_at[handlers[cur_hand].enter_idx][1]
404417
end
405-
cur_hand == 0 && break
418+
cur_stacks = (cur_hand, cur_stacks[2])
419+
cur_stacks == (0, 0) && break
420+
elseif head === :pop_exception
421+
cur_stacks = (cur_stacks[1], handler_at[(stmt.args[1]::SSAValue).id][2])
406422
end
407423
end
408424

409425
pc´ > n && break # can't proceed with the fast-path fall-through
410-
if handler_at[pc´] != cur_hand
411-
if handler_at[pc´] != 0
412-
@assert false "unbalanced try/catch"
413-
end
414-
handler_at[pc´] = cur_hand
426+
if handler_at[pc´] != cur_stacks
427+
handler_at[pc´] = cur_stacks
415428
elseif !in(pc´, ip)
416429
break # already visited
417430
end
@@ -420,7 +433,7 @@ function compute_trycatch(code::Vector{Any}, ip::BitSet)
420433
end
421434

422435
@assert first(ip) == n + 1
423-
return handler_at
436+
return handler_at, handlers
424437
end
425438

426439
# check if coverage mode is enabled

base/compiler/ssair/irinterp.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,9 @@ end
4646

4747
function abstract_call(interp::AbstractInterpreter, arginfo::ArgInfo, irsv::IRInterpretationState)
4848
si = StmtInfo(true) # TODO better job here?
49-
(; rt, effects, info) = abstract_call(interp, arginfo, si, irsv)
49+
(; rt, exct, effects, info) = abstract_call(interp, arginfo, si, irsv)
5050
irsv.ir.stmts[irsv.curridx][:info] = info
51-
return RTEffects(rt, effects)
51+
return RTEffects(rt, exct, effects)
5252
end
5353

5454
function update_phi!(irsv::IRInterpretationState, from::Int, to::Int)

base/compiler/ssair/slot2ssa.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -584,7 +584,7 @@ function construct_ssa!(ci::CodeInfo, ir::IRCode, sv::OptimizationState,
584584
end
585585

586586
# Record the correct exception handler for all critical sections
587-
handler_at = compute_trycatch(code, BitSet())
587+
handler_at, handlers = compute_trycatch(code, BitSet())
588588

589589
phi_slots = Vector{Int}[Int[] for _ = 1:length(ir.cfg.blocks)]
590590
live_slots = Vector{Int}[Int[] for _ = 1:length(ir.cfg.blocks)]
@@ -810,8 +810,8 @@ function construct_ssa!(ci::CodeInfo, ir::IRCode, sv::OptimizationState,
810810
incoming_vals[id] = Pair{Any, Any}(thisval, thisdef)
811811
has_pinode[id] = false
812812
enter_idx = idx
813-
while handler_at[enter_idx] != 0
814-
enter_idx = handler_at[enter_idx]
813+
while handler_at[enter_idx][1] != 0
814+
(; enter_idx) = handlers[handler_at[enter_idx][1]]
815815
leave_block = block_for_inst(cfg, code[enter_idx].args[1]::Int)
816816
cidx = findfirst((; slot)::NewPhiCNode2->slot_id(slot)==id, new_phic_nodes[leave_block])
817817
if cidx !== nothing

base/compiler/stmtinfo.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ and any additional information (`call.info`) for a given generic call.
1010
"""
1111
struct CallMeta
1212
rt::Any
13+
exct::Any
1314
effects::Effects
1415
info::CallInfo
1516
end

0 commit comments

Comments
 (0)