-
Notifications
You must be signed in to change notification settings - Fork 29
Demo of compile-time rule generation #593
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -401,7 +401,7 @@ function make_ad_stmts! end | |
| `nothing` as a statement in Julia IR indicates the presence of a line which will later be | ||
| removed. We emit a no-op on both the forwards- and reverse-passes. No shared data. | ||
| =# | ||
| function make_ad_stmts!(::Nothing, line::ID, ::ADInfo) | ||
| function make_ad_stmts!(::Nothing, line::ID, ::ADInfo; world=Base.get_world_counter()) | ||
| return ad_stmt_info(line, nothing, nothing, nothing) | ||
| end | ||
|
|
||
|
|
@@ -424,7 +424,7 @@ end | |
| For cases 2 and 3, we also insert a call to `typeassert` to ensure that `info.fwd_ret_type` | ||
| is respected. A similar check for `info.rvs_ret_type` is handled elsewhere. | ||
| =# | ||
| function make_ad_stmts!(stmt::ReturnNode, line::ID, info::ADInfo) | ||
| function make_ad_stmts!(stmt::ReturnNode, line::ID, info::ADInfo; world=Base.get_world_counter()) | ||
| if !is_reachable_return_node(stmt) | ||
| return ad_stmt_info(line, nothing, inc_args(stmt), nothing) | ||
| end | ||
|
|
@@ -451,12 +451,12 @@ function make_ad_stmts!(stmt::ReturnNode, line::ID, info::ADInfo) | |
| end | ||
|
|
||
| # Identity forwards-pass, no-op reverse. No shared data. | ||
| function make_ad_stmts!(stmt::IDGotoNode, line::ID, ::ADInfo) | ||
| function make_ad_stmts!(stmt::IDGotoNode, line::ID, ::ADInfo; world=Base.get_world_counter()) | ||
| return ad_stmt_info(line, nothing, inc_args(stmt), nothing) | ||
| end | ||
|
|
||
| # Identity forwards-pass, no-op reverse. No shared data. | ||
| function make_ad_stmts!(stmt::IDGotoIfNot, line::ID, ::ADInfo) | ||
| function make_ad_stmts!(stmt::IDGotoIfNot, line::ID, ::ADInfo; world=Base.get_world_counter()) | ||
| stmt = inc_args(stmt) | ||
|
|
||
| # If cond is not going to be wrapped in a `CoDual`, just return the stmt. | ||
|
|
@@ -472,7 +472,7 @@ function make_ad_stmts!(stmt::IDGotoIfNot, line::ID, ::ADInfo) | |
| end | ||
|
|
||
| # Identity forwards-pass, no-op reverse. No shared data. | ||
| function make_ad_stmts!(stmt::IDPhiNode, line::ID, info::ADInfo) | ||
| function make_ad_stmts!(stmt::IDPhiNode, line::ID, info::ADInfo; world=Base.get_world_counter()) | ||
| vals = stmt.values | ||
| new_vals = Vector{Any}(undef, length(vals)) | ||
| for n in eachindex(vals) | ||
|
|
@@ -489,7 +489,7 @@ function make_ad_stmts!(stmt::IDPhiNode, line::ID, info::ADInfo) | |
| return ad_stmt_info(line, nothing, _inst, nothing) | ||
| end | ||
|
|
||
| function make_ad_stmts!(stmt::PiNode, line::ID, info::ADInfo) | ||
| function make_ad_stmts!(stmt::PiNode, line::ID, info::ADInfo; world=Base.get_world_counter()) | ||
|
|
||
| # PiNodes of the form `π (nothing, Union{})` have started appearing in 1.11. These nodes | ||
| # appear in unreachable sections of code, and appear to serve no purpose. Consequently, | ||
|
|
@@ -528,7 +528,7 @@ end | |
| # Constant GlobalRefs are handled. See const_codual. Non-constant GlobalRefs are handled by | ||
| # assuming that they are constant, and creating a CoDual with the value. We then check at | ||
| # run-time that the value has not changed. | ||
| function make_ad_stmts!(stmt::GlobalRef, line::ID, info::ADInfo) | ||
| function make_ad_stmts!(stmt::GlobalRef, line::ID, info::ADInfo; world=Base.get_world_counter()) | ||
| isconst(stmt) && return const_ad_stmt(stmt, line, info) | ||
|
|
||
| const_id, globalref_id = ID(), ID() | ||
|
|
@@ -547,10 +547,10 @@ end | |
| end | ||
|
|
||
| # QuoteNodes are constant. | ||
| make_ad_stmts!(stmt::QuoteNode, line::ID, info::ADInfo) = const_ad_stmt(stmt, line, info) | ||
| make_ad_stmts!(stmt::QuoteNode, line::ID, info::ADInfo; world=Base.get_world_counter()) = const_ad_stmt(stmt, line, info) | ||
|
|
||
| # Literal constant. | ||
| make_ad_stmts!(stmt, line::ID, info::ADInfo) = const_ad_stmt(stmt, line, info) | ||
| make_ad_stmts!(stmt, line::ID, info::ADInfo; world=Base.get_world_counter()) = const_ad_stmt(stmt, line, info) | ||
|
|
||
| """ | ||
| const_ad_stmt(stmt, line::ID, info::ADInfo) | ||
|
|
@@ -621,12 +621,12 @@ get_const_primal_value(x::QuoteNode) = x.value | |
| get_const_primal_value(x) = x | ||
|
|
||
| # Mooncake does not yet handle `PhiCNode`s. Throw an error if one is encountered. | ||
| function make_ad_stmts!(stmt::Core.PhiCNode, ::ID, ::ADInfo) | ||
| function make_ad_stmts!(stmt::Core.PhiCNode, ::ID, ::ADInfo; world=Base.get_world_counter()) | ||
| return unhandled_feature("Encountered PhiCNode: $stmt") | ||
| end | ||
|
|
||
| # Mooncake does not yet handle `UpsilonNode`s. Throw an error if one is encountered. | ||
| function make_ad_stmts!(stmt::Core.UpsilonNode, ::ID, ::ADInfo) | ||
| function make_ad_stmts!(stmt::Core.UpsilonNode, ::ID, ::ADInfo; world=Base.get_world_counter()) | ||
| return unhandled_feature( | ||
| "Encountered UpsilonNode: $stmt. These are generated as part of some try / catch " * | ||
| "/ finally blocks. At the present time, Mooncake.jl cannot differentiate through " * | ||
|
|
@@ -640,7 +640,7 @@ end | |
|
|
||
| # There are quite a number of possible `Expr`s that can be encountered. Each case has its | ||
| # own comment, explaining what is going on. | ||
| function make_ad_stmts!(stmt::Expr, line::ID, info::ADInfo) | ||
| function make_ad_stmts!(stmt::Expr, line::ID, info::ADInfo; world=Base.get_world_counter()) | ||
| is_invoke = Meta.isexpr(stmt, :invoke) | ||
| if Meta.isexpr(stmt, :call) || is_invoke | ||
|
|
||
|
|
@@ -669,7 +669,7 @@ function make_ad_stmts!(stmt::Expr, line::ID, info::ADInfo) | |
| rrule!! # intrinsic / builtin / thing we provably have rule for | ||
| elseif is_invoke | ||
| mi = stmt.args[1]::Core.MethodInstance | ||
| LazyDerivedRule(mi, info.debug_mode) # Static dispatch | ||
| LazyDerivedRule(mi, info.debug_mode; world) # Static dispatch | ||
| else | ||
| DynamicDerivedRule(info.debug_mode) # Dynamic dispatch | ||
| end | ||
|
|
@@ -1066,6 +1066,68 @@ struct DerivedRuleInfo | |
| isva::Bool | ||
| end | ||
|
|
||
| function generated_build_rrule_body(world::UInt, lnn, this, sig) | ||
| sig isa Type{<:Type{<:Tuple}} || error() | ||
| sig = sig.parameters[1] | ||
| interp = MooncakeInterpreter(DefaultCtx; world) | ||
| rule = build_rrule(interp, sig; world) | ||
|
|
||
| ci = expr_to_codeinfo(@__MODULE__(), [Symbol("#self#"), :sig], [], (), :(return $rule)) | ||
|
|
||
| # Attach edges from the MethodInstances of `f` to this CodeInfo. | ||
| # This should make it so that adding methods to `f` will | ||
| # trigger recompilation, fixing the #265 equivalent for generated functions. | ||
| matches = Base._methods_by_ftype(sig, -1, world) | ||
| if !isnothing(matches) | ||
| ci.edges = Core.MethodInstance[] | ||
| for match in Base._methods_by_ftype(sig, -1, world) | ||
| mi = Core.Compiler.specialize_method(match) # in v1.12 Core.Compiler.specialize_method is Base.specialize_method | ||
| push!(ci.edges, mi) | ||
| end | ||
| end | ||
| return ci | ||
| end | ||
|
|
||
| function expr_to_codeinfo(m::Module, argnames, spnames, sp, e::Expr) | ||
| # This trick comes from https://github.com/NHDaly/StagedFunctions.jl/commit/22fc72740093892baa442850a1fd61d9cd61b4cd | ||
| lam = Expr(:lambda, argnames, | ||
| Expr(Symbol("scope-block"), | ||
| Expr(:block, | ||
| Expr(:return, | ||
| Expr(:block, | ||
| e, | ||
| ))))) | ||
| ex = if spnames === nothing || isempty(spnames) | ||
| lam | ||
| else | ||
| Expr(Symbol("with-static-parameters"), lam, spnames...) | ||
| end | ||
|
|
||
| #= | ||
| !!!! | ||
| In v1.12+ `jl_expand_and_resolve` was deleted. We'll need to use Base.generated_body_to_codeinfo instead. | ||
| See the details in the changes to base/expr.jl and src/method.c in commit | ||
| https://github.com/JuliaLang/julia/commit/ce507a7e70603631b7c9e242528c1b611f918b01 | ||
| !!!! | ||
| =# | ||
|
|
||
| # Get the CodeInfo for the generator body in order to use it for generating a dummy | ||
| # CodeInfo object. | ||
| ci = ccall(:jl_expand_and_resolve, Any, (Any, Any, Core.SimpleVector), ex, m, Core.svec(sp...)) | ||
| @assert ci isa Core.CodeInfo "Failed to create a CodeInfo from the given expression. This might mean it contains a closure or comprehension?\n Offending expression: $e" | ||
| ci | ||
| end | ||
|
|
||
| function refresh_generated_build_rrule_body() | ||
| @eval function generated_build_rrule(sig) | ||
| $(Expr(:meta, :generated_only)) | ||
| $(Expr(:meta, :generated, generated_build_rrule_body)) | ||
| end | ||
| end | ||
| #handy util for when working on generate_codeinfo | ||
| refresh_generated_build_rrule_body() | ||
|
|
||
|
|
||
| """ | ||
| build_rrule(interp::MooncakeInterpreter{C}, sig_or_mi; debug_mode=false) where {C} | ||
|
|
||
|
|
@@ -1075,12 +1137,12 @@ docstring for `rrule!!` for more info. | |
| If `debug_mode` is `true`, then all calls to rules are replaced with calls to `DebugRRule`s. | ||
| """ | ||
| function build_rrule( | ||
| interp::MooncakeInterpreter{C}, sig_or_mi; debug_mode=false, silence_debug_messages=true | ||
| ) where {C} | ||
| interp::MooncakeInterpreter{C}, sig_or_mi; debug_mode=false, silence_debug_messages=true, world=Base.get_world_counter() | ||
| ) where {C} | ||
|
|
||
| # To avoid segfaults, ensure that we bail out if the interpreter's world age is behind | ||
| # the current world age. | ||
| if Base.get_world_counter() > interp.world | ||
| if world > interp.world | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Before merging, we should probably tweak the way that Mooncake handles world ages generally. The current approach is something of a hack (notice, for example, that
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, that'd be a good idea. I think it'd also potentially be a good idea to keep around a dictionary of all the different interpreters from different worlds, rather than just the latest interpreter. Not sure how you feel about that though
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good idea -- what do you have in mind? Make
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yeah exactly.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This would be great. My thinking is that we should wind up removing the above check entirely, and just rely on correctly propagating the world age through to everything via the interpreter. Is this something you would be interested in doing as part of this PR, or would you rather it were done as part of a separate one? |
||
| throw( | ||
| ArgumentError( | ||
| "World age associated to interp is behind current world age. Please " * | ||
|
|
@@ -1111,7 +1173,7 @@ function build_rrule( | |
| return _copy(interp.oc_cache[oc_cache_key]) | ||
| else | ||
| # Derive forwards- and reverse-pass IR, and shove in `MistyClosure`s. | ||
| dri = generate_ir(interp, sig_or_mi; debug_mode) | ||
| dri = generate_ir(interp, sig_or_mi; debug_mode, world) | ||
| fwd_oc = misty_closure(dri.fwd_ret_type, dri.fwd_ir, dri.shared_data...) | ||
| rvs_oc = misty_closure(dri.rvs_ret_type, dri.rvs_ir, dri.shared_data...) | ||
|
|
||
|
|
@@ -1147,7 +1209,7 @@ end | |
| Used by `build_rrule`, and the various debugging tools: primal_ir, fwds_ir, adjoint_ir. | ||
| """ | ||
| function generate_ir( | ||
| interp::MooncakeInterpreter, sig_or_mi; debug_mode=false, do_inline=true | ||
| interp::MooncakeInterpreter, sig_or_mi; debug_mode=false, do_inline=true, world=Base.get_world_counter() | ||
| ) | ||
| # Reset id count. This ensures that the IDs generated are the same each time this | ||
| # function runs. | ||
|
|
@@ -1160,7 +1222,7 @@ function generate_ir( | |
| rvs_ret_type = pullback_ret_type(ir) | ||
|
|
||
| # Normalise the IR, and generate a BBCode version of it. | ||
| isva, spnames = is_vararg_and_sparam_names(sig_or_mi) | ||
| isva, spnames = is_vararg_and_sparam_names(sig_or_mi; world) | ||
| ir = normalise!(ir, spnames) | ||
| primal_ir = remove_unreachable_blocks!(BBCode(ir)) | ||
|
|
||
|
|
@@ -1172,7 +1234,7 @@ function generate_ir( | |
| ad_stmts_blocks = map(primal_ir.blocks) do primal_blk | ||
| ids = primal_blk.inst_ids | ||
| primal_stmts = map(x -> x.stmt, primal_blk.insts) | ||
| return (primal_blk.id, make_ad_stmts!.(primal_stmts, ids, Ref(info))) | ||
| return (primal_blk.id, make_ad_stmts!.(primal_stmts, ids, Ref(info); world)) | ||
| end | ||
|
|
||
| # Make shared data, and construct BBCode for forwards-pass and pullback. | ||
|
|
@@ -1805,8 +1867,8 @@ mutable struct LazyDerivedRule{primal_sig,Trule} | |
| debug_mode::Bool | ||
| mi::Core.MethodInstance | ||
| rule::Trule | ||
| function LazyDerivedRule(mi::Core.MethodInstance, debug_mode::Bool) | ||
| interp = get_interpreter() | ||
| function LazyDerivedRule(mi::Core.MethodInstance, debug_mode::Bool; world=Base.get_world_counter()) | ||
| interp = get_interpreter(; world) | ||
| return new{mi.specTypes,rule_type(interp, mi;debug_mode)}(debug_mode, mi) | ||
| end | ||
| function LazyDerivedRule{Tprimal_sig,Trule}( | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it's probably true that just returning this is unsafe because the `DerivedRule` contains various bits of state which get read from / written to in specific orders during the forwards- and reverse-passes. The caching functionality in `build_rrule` therefore makes a call to `_copy`, which ensures that each time we ask for a rule, it gets fresh state which is definitely not shared with any other rule for this signature. Ensuring that no two instances of a rule share the same state ensures that we avoid race conditions if we're running the same rule on multiple threads, and makes recursion work correctly (the way that we handle recursion is a bit weird).

Anyway, we'll have to do something similar here. You could simply replace the `:(return $rule)` with `:(return _copy($rule))`. This will obviously have worse performance (not allocation-free), but is probably going to be better than what we have at the minute (because it will still avoid the type unstable code here).

In order to do better, we would need to avoid `_copy` altogether. I suspect that we'll have to do some clever caching here to make this work well. I suspect that it can be done, but maybe it's something to do in a follow-up PR?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah I see, makes sense. I didn't know there was state inside the rule as well, that's unfortunate.