diff --git a/src/interpreter/abstract_interpretation.jl b/src/interpreter/abstract_interpretation.jl index fc70125c4a..9097c0064c 100644 --- a/src/interpreter/abstract_interpretation.jl +++ b/src/interpreter/abstract_interpretation.jl @@ -75,7 +75,7 @@ function _show_interp(io::IO, ::MIME"text/plain", ::MooncakeInterpreter) return print(io, "MooncakeInterpreter()") end -MooncakeInterpreter() = MooncakeInterpreter(DefaultCtx) +MooncakeInterpreter(; world=Base.get_world_counter()) = MooncakeInterpreter(DefaultCtx; world) context_type(::MooncakeInterpreter{C}) where {C} = C @@ -203,16 +203,16 @@ Globally cached interpreter. Should only be accessed via `get_interpreter`. const GLOBAL_INTERPRETER = Ref(MooncakeInterpreter()) """ - get_interpreter() + get_interpreter(; world=Base.get_world_counter) Returns a `MooncakeInterpreter` appropriate for the current world age. Will use a cached interpreter if one already exists for the current world age, otherwise creates a new one. This should be prefered over constructing a `MooncakeInterpreter` directly. """ -function get_interpreter() - if GLOBAL_INTERPRETER[].world != Base.get_world_counter() - GLOBAL_INTERPRETER[] = MooncakeInterpreter() +function get_interpreter(; world=Base.get_world_counter()) + if GLOBAL_INTERPRETER[].world != world + GLOBAL_INTERPRETER[] = MooncakeInterpreter(; world) end return GLOBAL_INTERPRETER[] end diff --git a/src/interpreter/s2s_reverse_mode_ad.jl b/src/interpreter/s2s_reverse_mode_ad.jl index 8e5d9386f8..2e10fc1d40 100644 --- a/src/interpreter/s2s_reverse_mode_ad.jl +++ b/src/interpreter/s2s_reverse_mode_ad.jl @@ -401,7 +401,7 @@ function make_ad_stmts! end `nothing` as a statement in Julia IR indicates the presence of a line which will later be removed. We emit a no-op on both the forwards- and reverse-passes. No shared data. =# -function make_ad_stmts!(::Nothing, line::ID, ::ADInfo) +function make_ad_stmts!(::Nothing, line::ID, ::ADInfo; world=Base.get_world_counter()) return ad_stmt_info(line, nothing, nothing, nothing) end @@ -424,7 +424,7 @@ end For cases 2 and 3, we also insert a call to `typeassert` to ensure that `info.fwd_ret_type` is respected. A similar check for `info.rvs_ret_type` is handled elsewhere. =# -function make_ad_stmts!(stmt::ReturnNode, line::ID, info::ADInfo) +function make_ad_stmts!(stmt::ReturnNode, line::ID, info::ADInfo; world=Base.get_world_counter()) if !is_reachable_return_node(stmt) return ad_stmt_info(line, nothing, inc_args(stmt), nothing) end @@ -451,12 +451,12 @@ function make_ad_stmts!(stmt::ReturnNode, line::ID, info::ADInfo) end # Identity forwards-pass, no-op reverse. No shared data. -function make_ad_stmts!(stmt::IDGotoNode, line::ID, ::ADInfo) +function make_ad_stmts!(stmt::IDGotoNode, line::ID, ::ADInfo; world=Base.get_world_counter()) return ad_stmt_info(line, nothing, inc_args(stmt), nothing) end # Identity forwards-pass, no-op reverse. No shared data. -function make_ad_stmts!(stmt::IDGotoIfNot, line::ID, ::ADInfo) +function make_ad_stmts!(stmt::IDGotoIfNot, line::ID, ::ADInfo; world=Base.get_world_counter()) stmt = inc_args(stmt) # If cond is not going to be wrapped in a `CoDual`, so just return the stmt. @@ -472,7 +472,7 @@ function make_ad_stmts!(stmt::IDGotoIfNot, line::ID, ::ADInfo) end # Identity forwards-pass, no-op reverse. No shared data. -function make_ad_stmts!(stmt::IDPhiNode, line::ID, info::ADInfo) +function make_ad_stmts!(stmt::IDPhiNode, line::ID, info::ADInfo; world=Base.get_world_counter()) vals = stmt.values new_vals = Vector{Any}(undef, length(vals)) for n in eachindex(vals) @@ -489,7 +489,7 @@ function make_ad_stmts!(stmt::IDPhiNode, line::ID, info::ADInfo) return ad_stmt_info(line, nothing, _inst, nothing) end -function make_ad_stmts!(stmt::PiNode, line::ID, info::ADInfo) +function make_ad_stmts!(stmt::PiNode, line::ID, info::ADInfo; world=Base.get_world_counter()) # PiNodes of the form `π (nothing, Union{})` have started appearing in 1.11. These nodes # appear in unreachable sections of code, and appear to serve no purpose. Consequently, @@ -528,7 +528,7 @@ end # Constant GlobalRefs are handled. See const_codual. Non-constant GlobalRefs are handled by # assuming that they are constant, and creating a CoDual with the value. We then check at # run-time that the value has not changed. -function make_ad_stmts!(stmt::GlobalRef, line::ID, info::ADInfo) +function make_ad_stmts!(stmt::GlobalRef, line::ID, info::ADInfo; world=Base.get_world_counter()) isconst(stmt) && return const_ad_stmt(stmt, line, info) const_id, globalref_id = ID(), ID() @@ -547,10 +547,10 @@ end end # QuoteNodes are constant. -make_ad_stmts!(stmt::QuoteNode, line::ID, info::ADInfo) = const_ad_stmt(stmt, line, info) +make_ad_stmts!(stmt::QuoteNode, line::ID, info::ADInfo; world=Base.get_world_counter()) = const_ad_stmt(stmt, line, info) # Literal constant. -make_ad_stmts!(stmt, line::ID, info::ADInfo) = const_ad_stmt(stmt, line, info) +make_ad_stmts!(stmt, line::ID, info::ADInfo; world=Base.get_world_counter()) = const_ad_stmt(stmt, line, info) """ const_ad_stmt(stmt, line::ID, info::ADInfo) @@ -621,12 +621,12 @@ get_const_primal_value(x::QuoteNode) = x.value get_const_primal_value(x) = x # Mooncake does not yet handle `PhiCNode`s. Throw an error if one is encountered. -function make_ad_stmts!(stmt::Core.PhiCNode, ::ID, ::ADInfo) +function make_ad_stmts!(stmt::Core.PhiCNode, ::ID, ::ADInfo; world=Base.get_world_counter()) return unhandled_feature("Encountered PhiCNode: $stmt") end # Mooncake does not yet handle `UpsilonNode`s. Throw an error if one is encountered. -function make_ad_stmts!(stmt::Core.UpsilonNode, ::ID, ::ADInfo) +function make_ad_stmts!(stmt::Core.UpsilonNode, ::ID, ::ADInfo; world=Base.get_world_counter()) return unhandled_feature( "Encountered UpsilonNode: $stmt. These are generated as part of some try / catch " * "/ finally blocks. At the present time, Mooncake.jl cannot differentiate through " * @@ -640,7 +640,7 @@ end # There are quite a number of possible `Expr`s that can be encountered. Each case has its # own comment, explaining what is going on. -function make_ad_stmts!(stmt::Expr, line::ID, info::ADInfo) +function make_ad_stmts!(stmt::Expr, line::ID, info::ADInfo; world=Base.get_world_counter()) is_invoke = Meta.isexpr(stmt, :invoke) if Meta.isexpr(stmt, :call) || is_invoke @@ -669,7 +669,7 @@ function make_ad_stmts!(stmt::Expr, line::ID, info::ADInfo) rrule!! # intrinsic / builtin / thing we provably have rule for elseif is_invoke mi = stmt.args[1]::Core.MethodInstance - LazyDerivedRule(mi, info.debug_mode) # Static dispatch + LazyDerivedRule(mi, info.debug_mode; world) # Static dispatch else DynamicDerivedRule(info.debug_mode) # Dynamic dispatch end @@ -1066,6 +1066,68 @@ struct DerivedRuleInfo isva::Bool end +function generated_build_rrule_body(world::UInt, lnn, this, sig) + sig isa Type{<:Type{<:Tuple}} || error() + sig = sig.parameters[1] + interp = MooncakeInterpreter(DefaultCtx; world) + rule = build_rrule(interp, sig; world) + + ci = expr_to_codeinfo(@__MODULE__(), [Symbol("#self#"), :sig], [], (), :(return $rule)) + + # Attached edges from MethodInstrances of `f` to to this CodeInfo. + # This should make it so that adding methods to `f` will + # triggers recompilation, fixing the #265 equivalent for generated functions. + matches = Base._methods_by_ftype(sig, -1, world) + if !isnothing(matches) + ci.edges = Core.MethodInstance[] + for match in Base._methods_by_ftype(sig, -1, world) + mi = Core.Compiler.specialize_method(match) # in v1.12 Core.Compiler.specialize_method is Base.specialize_method + push!(ci.edges, mi) + end + end + return ci +end + +function expr_to_codeinfo(m::Module, argnames, spnames, sp, e::Expr) + # This trick comes from https://github.com/NHDaly/StagedFunctions.jl/commit/22fc72740093892baa442850a1fd61d9cd61b4cd + lam = Expr(:lambda, argnames, + Expr(Symbol("scope-block"), + Expr(:block, + Expr(:return, + Expr(:block, + e, + ))))) + ex = if spnames === nothing || isempty(spnames) + lam + else + Expr(Symbol("with-static-parameters"), lam, spnames...) + end + + #= + !!!! + In v1.12+ `jl_expand_and_resolve` was deleted. We'll need to use Base.generated_body_to_codeinfo instead. + See the details in the changes base/expr.jl and src/method.c in commit + https://github.com/JuliaLang/julia/commit/ce507a7e70603631b7c9e242528c1b611f918b01 + !!!! + =# + + # Get the code-info for the generatorbody in order to use it for generating a dummy + # code info object. + ci = ccall(:jl_expand_and_resolve, Any, (Any, Any, Core.SimpleVector), ex, m, Core.svec(sp...)) + @assert ci isa Core.CodeInfo "Failed to create a CodeInfo from the given expression. This might mean it contains a closure or comprehension?\n Offending expression: $e" + ci +end + +function refresh_generated_build_rrule_body() + @eval function generated_build_rrule(sig) + $(Expr(:meta, :generated_only)) + $(Expr(:meta, :generated, generated_build_rrule_body)) + end +end +#handy util for when working on generate_codeinfo +refresh_generated_build_rrule_body() + + """ build_rrule(interp::MooncakeInterpreter{C}, sig_or_mi; debug_mode=false) where {C} @@ -1075,12 +1137,12 @@ docstring for `rrule!!` for more info. If `debug_mode` is `true`, then all calls to rules are replaced with calls to `DebugRRule`s. """ function build_rrule( - interp::MooncakeInterpreter{C}, sig_or_mi; debug_mode=false, silence_debug_messages=true -) where {C} + interp::MooncakeInterpreter{C}, sig_or_mi; debug_mode=false, silence_debug_messages=true, world=Base.get_world_counter() + ) where {C} # To avoid segfaults, ensure that we bail out if the interpreter's world age is greater # than the current world age. - if Base.get_world_counter() > interp.world + if world > interp.world throw( ArgumentError( "World age associated to interp is behind current world age. Please " * @@ -1111,7 +1173,7 @@ function build_rrule( return _copy(interp.oc_cache[oc_cache_key]) else # Derive forwards- and reverse-pass IR, and shove in `MistyClosure`s. - dri = generate_ir(interp, sig_or_mi; debug_mode) + dri = generate_ir(interp, sig_or_mi; debug_mode, world) fwd_oc = misty_closure(dri.fwd_ret_type, dri.fwd_ir, dri.shared_data...) rvs_oc = misty_closure(dri.rvs_ret_type, dri.rvs_ir, dri.shared_data...) @@ -1147,7 +1209,7 @@ end Used by `build_rrule`, and the various debugging tools: primal_ir, fwds_ir, adjoint_ir. """ function generate_ir( - interp::MooncakeInterpreter, sig_or_mi; debug_mode=false, do_inline=true + interp::MooncakeInterpreter, sig_or_mi; debug_mode=false, do_inline=true, world=Base.get_world_counter() ) # Reset id count. This ensures that the IDs generated are the same each time this # function runs. @@ -1160,7 +1222,7 @@ function generate_ir( rvs_ret_type = pullback_ret_type(ir) # Normalise the IR, and generated BBCode version of it. - isva, spnames = is_vararg_and_sparam_names(sig_or_mi) + isva, spnames = is_vararg_and_sparam_names(sig_or_mi; world) ir = normalise!(ir, spnames) primal_ir = remove_unreachable_blocks!(BBCode(ir)) @@ -1172,7 +1234,7 @@ function generate_ir( ad_stmts_blocks = map(primal_ir.blocks) do primal_blk ids = primal_blk.inst_ids primal_stmts = map(x -> x.stmt, primal_blk.insts) - return (primal_blk.id, make_ad_stmts!.(primal_stmts, ids, Ref(info))) + return (primal_blk.id, make_ad_stmts!.(primal_stmts, ids, Ref(info); world)) end # Make shared data, and construct BBCode for forwards-pass and pullback. @@ -1805,8 +1867,8 @@ mutable struct LazyDerivedRule{primal_sig,Trule} debug_mode::Bool mi::Core.MethodInstance rule::Trule - function LazyDerivedRule(mi::Core.MethodInstance, debug_mode::Bool) - interp = get_interpreter() + function LazyDerivedRule(mi::Core.MethodInstance, debug_mode::Bool; world=Base.get_world_counter()) + interp = get_interpreter(; world) return new{mi.specTypes,rule_type(interp, mi;debug_mode)}(debug_mode, mi) end function LazyDerivedRule{Tprimal_sig,Trule}( diff --git a/src/utils.jl b/src/utils.jl index 5fb9e7983d..a9b77ee10a 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -166,8 +166,7 @@ is_vararg_and_sparam_names(m::Method) = m.isva, sparam_names(m) Finds the method associated to `sig`, and calls `is_vararg_and_sparam_names` on it. """ -function is_vararg_and_sparam_names(sig)::Tuple{Bool,Vector{Symbol}} - world = Base.get_world_counter() +function is_vararg_and_sparam_names(sig; world=Base.get_world_counter())::Tuple{Bool,Vector{Symbol}} min = Base.RefValue{UInt}(typemin(UInt)) max = Base.RefValue{UInt}(typemax(UInt)) ms = Base._methods_by_ftype(