diff --git a/src/interpreter/forward_mode.jl b/src/interpreter/forward_mode.jl index 66598f27b3..eaee2710cc 100644 --- a/src/interpreter/forward_mode.jl +++ b/src/interpreter/forward_mode.jl @@ -46,7 +46,7 @@ function build_frule( else # Derive forward-pass IR, and shove in a `MistyClosure`. dual_ir, captures, info = generate_dual_ir(interp, sig_or_mi; debug_mode) - dual_oc = optimized_misty_closure( + dual_oc = misty_closure( info.dual_ret_type, dual_ir, captures...; do_compile=true ) raw_rule = DerivedFRule{sig,typeof(dual_oc),info.isva,info.nargs}(dual_oc) @@ -109,6 +109,9 @@ function generate_dual_ir( # Grab code associated to the primal. primal_ir, _ = lookup_ir(interp, sig_or_mi) + @static if VERSION > v"1.12-" + primal_ir = set_valid_world!(primal_ir, interp.world) + end nargs = length(primal_ir.argtypes) # Normalise the IR. diff --git a/src/interpreter/ir_utils.jl b/src/interpreter/ir_utils.jl index 043347a353..8423336a82 100644 --- a/src/interpreter/ir_utils.jl +++ b/src/interpreter/ir_utils.jl @@ -136,7 +136,7 @@ function __infer_ir!(ir, interp::CC.AbstractInterpreter, mi::CC.MethodInstance) propagate_inbounds = true spec_info = CC.SpecInfo(nargs, isva, propagate_inbounds, nothing) min_world = world = get_inference_world(interp) - max_world = Base.get_world_counter() + max_world = Base.get_world_counter() # TODO: why do we allow other worlds? irsv = CC.IRInterpretationState( interp, spec_info, ir, mi, ir.argtypes, world, min_world, max_world ) @@ -285,6 +285,37 @@ function lookup_ir(::CC.AbstractInterpreter, mc::MistyClosure; optimize_until=no return mc.ir[], return_type(mc.oc) end +@static if VERSION > v"1.12-" + """ + set_valid_world!(ir::IRCode, world::UInt)::IRCode + + (1.12+ only) + Create a shallow copy of the given IR code, with its `valid_worlds` field updated + to a single valid world. This allows the compiler to perform more inlining. + + In particular, if the IR comes from say a function `f` which makes a call to another + function `g` which only got defined after `f`, then at the min_world when `f` was + defined, `g` was not available yet. If we restrict the IR to a world where `g` is + available then `g` can be inlined. + + Will error if `world` is not in the existing `valid_worlds` of `ir`. + """ + function set_valid_world!(ir::IRCode, world::UInt) + if world ∉ ir.valid_worlds + error("World $world is not valid for this IRCode: $(ir.valid_worlds).") + end + return CC.IRCode( + ir.stmts, + ir.cfg, + ir.debuginfo, + ir.argtypes, + ir.meta, + ir.sptypes, + CC.WorldRange(world, world), + ) + end +end + return_type(::Core.OpaqueClosure{A,B}) where {A,B} = B """ diff --git a/src/interpreter/reverse_mode.jl b/src/interpreter/reverse_mode.jl index b80be6f3e1..e9b50aca4f 100644 --- a/src/interpreter/reverse_mode.jl +++ b/src/interpreter/reverse_mode.jl @@ -1134,10 +1134,10 @@ function build_rrule( else # Derive forwards- and reverse-pass IR, and shove in `MistyClosure`s. dri = generate_ir(interp, sig_or_mi; debug_mode) - fwd_oc = optimized_misty_closure( + fwd_oc = misty_closure( dri.fwd_ret_type, dri.fwd_ir, dri.shared_data... ) - rvs_oc = optimized_misty_closure( + rvs_oc = misty_closure( dri.rvs_ret_type, dri.rvs_ir, dri.shared_data... ) @@ -1181,6 +1181,9 @@ function generate_ir( # Grab code associated to the primal. ir, _ = lookup_ir(interp, sig_or_mi) + @static if VERSION > v"1.12-" + ir = set_valid_world!(ir, interp.world) + end Treturn = compute_ir_rettype(ir) fwd_ret_type = forwards_ret_type(ir) rvs_ret_type = pullback_ret_type(ir) diff --git a/src/utils.jl b/src/utils.jl index b5c01fe61e..5745b3b4cc 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -354,55 +354,6 @@ function opaque_closure( )::Core.OpaqueClosure{sig,ret_type} end -function optimized_opaque_closure(rtype, ir::IRCode, env...; kwargs...) - oc = opaque_closure(rtype, ir, env...; kwargs...) - world = UInt(oc.world) - set_world_bounds_for_optimization!(oc) - optimized_oc = optimize_opaque_closure(oc, rtype, env...; kwargs...) - return optimized_oc -end - -function optimize_opaque_closure(oc::Core.OpaqueClosure, rtype, env...; kwargs...) - method = oc.source - ci = method.specializations.cache - world = UInt(oc.world) - ir = reinfer_and_inline(ci, world) - ir === nothing && return oc # nothing to optimize - return opaque_closure(rtype, ir, env...; kwargs...) -end - -# Allows optimization to make assumptions about binding access, -# enabling inlining and other optimizations. -function set_world_bounds_for_optimization!(oc::Core.OpaqueClosure) - ci = oc.source.specializations.cache - ci.inferred === nothing && return nothing - ci.inferred.min_world = oc.world - ci.inferred.max_world = oc.world -end - -function reinfer_and_inline(ci::Core.CodeInstance, world::UInt) - interp = CC.NativeInterpreter(world) - mi = get_mi(ci) - argtypes = collect(Any, mi.specTypes.parameters) - irsv = CC.IRInterpretationState(interp, ci, mi, argtypes, world) - irsv === nothing && return nothing - for stmt in irsv.ir.stmts - inst = stmt[:inst] - if isexpr(inst, :loopinfo) || - isexpr(inst, :pop_exception) || - isa(inst, CC.GotoIfNot) || - isa(inst, CC.GotoNode) || - isexpr(inst, :copyast) - continue - end - stmt[:flag] |= CC.IR_FLAG_REFINED - end - CC.ir_abstract_constant_propagation(interp, irsv) - state = CC.InliningState(interp) - ir = CC.ssa_inlining_pass!(irsv.ir, state, CC.propagate_inbounds(irsv)) - return ir -end - """ misty_closure( ret_type::Type, @@ -425,18 +376,6 @@ function misty_closure( return MistyClosure(opaque_closure(ret_type, ir, env...; isva, do_compile), Ref(ir)) end -function optimized_misty_closure( - ret_type::Type, - ir::IRCode, - @nospecialize env...; - isva::Bool=false, - do_compile::Bool=true, -) - return MistyClosure( - optimized_opaque_closure(ret_type, ir, env...; isva, do_compile), Ref(ir) - ) -end - @static if VERSION > v"1.12-" compute_ir_rettype(ir) = CC.compute_ir_rettype(ir) compute_oc_signature(ir, nargs, isva) = CC.compute_oc_signature(ir, nargs, isva) diff --git a/test/interpreter/ir_utils.jl b/test/interpreter/ir_utils.jl index e4d91472f0..1e7b5f64ed 100644 --- a/test/interpreter/ir_utils.jl +++ b/test/interpreter/ir_utils.jl @@ -20,8 +20,6 @@ end @test length(stmt(ir.stmts)) == length(src.code) oc = Mooncake.opaque_closure(rettype, ir) @test oc(args...) == f(args...) - oc = Mooncake.optimized_opaque_closure(rettype, ir) - @test oc(args...) == f(args...) end @testset "infer_ir!" begin @@ -44,8 +42,6 @@ end ir.argtypes[1] = Tuple oc = Mooncake.opaque_closure(Float64, ir) @test oc(5.0) == cos(sin(5.0)) - oc = Mooncake.optimized_opaque_closure(Float64, ir) - @test oc(5.0) == cos(sin(5.0)) end @testset "lookup_ir" begin tt = Tuple{typeof(sin),Float64} diff --git a/test/utils.jl b/test/utils.jl index 09ee3d251f..6d075e2488 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -116,14 +116,9 @@ oc2 = Mooncake.opaque_closure(Any, ir) @test oc2 isa Core.OpaqueClosure{Tuple{Float64},Any} @test oc2(5.0) == sin(5.0) - oc2 = Mooncake.optimized_opaque_closure(Any, ir) - @test oc2 isa Core.OpaqueClosure{Tuple{Float64},Any} - @test oc2(5.0) == sin(5.0) # Check that we can get a MistyClosure also. mc = Mooncake.misty_closure(Float64, ir) @test mc(5.0) == sin(5.0) - mc = Mooncake.optimized_misty_closure(Float64, ir) - @test mc(5.0) == sin(5.0) end end