Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions src/interpreter/abstract_interpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ function _show_interp(io::IO, ::MIME"text/plain", ::MooncakeInterpreter)
return print(io, "MooncakeInterpreter()")
end

MooncakeInterpreter() = MooncakeInterpreter(DefaultCtx)
MooncakeInterpreter(; world=Base.get_world_counter()) = MooncakeInterpreter(DefaultCtx; world)

context_type(::MooncakeInterpreter{C}) where {C} = C

Expand Down Expand Up @@ -203,16 +203,16 @@ Globally cached interpreter. Should only be accessed via `get_interpreter`.
const GLOBAL_INTERPRETER = Ref(MooncakeInterpreter())

"""
get_interpreter()
get_interpreter(; world=Base.get_world_counter())

Returns a `MooncakeInterpreter` appropriate for the current world age. Will use a cached
interpreter if one already exists for the current world age, otherwise creates a new one.

This should be preferred over constructing a `MooncakeInterpreter` directly.
"""
function get_interpreter()
if GLOBAL_INTERPRETER[].world != Base.get_world_counter()
GLOBAL_INTERPRETER[] = MooncakeInterpreter()
function get_interpreter(; world=Base.get_world_counter())
if GLOBAL_INTERPRETER[].world != world
GLOBAL_INTERPRETER[] = MooncakeInterpreter(; world)
end
return GLOBAL_INTERPRETER[]
end
106 changes: 84 additions & 22 deletions src/interpreter/s2s_reverse_mode_ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,7 @@ function make_ad_stmts! end
`nothing` as a statement in Julia IR indicates the presence of a line which will later be
removed. We emit a no-op on both the forwards- and reverse-passes. No shared data.
=#
function make_ad_stmts!(::Nothing, line::ID, ::ADInfo)
function make_ad_stmts!(::Nothing, line::ID, ::ADInfo; world=Base.get_world_counter())
return ad_stmt_info(line, nothing, nothing, nothing)
end

Expand All @@ -424,7 +424,7 @@ end
For cases 2 and 3, we also insert a call to `typeassert` to ensure that `info.fwd_ret_type`
is respected. A similar check for `info.rvs_ret_type` is handled elsewhere.
=#
function make_ad_stmts!(stmt::ReturnNode, line::ID, info::ADInfo)
function make_ad_stmts!(stmt::ReturnNode, line::ID, info::ADInfo; world=Base.get_world_counter())
if !is_reachable_return_node(stmt)
return ad_stmt_info(line, nothing, inc_args(stmt), nothing)
end
Expand All @@ -451,12 +451,12 @@ function make_ad_stmts!(stmt::ReturnNode, line::ID, info::ADInfo)
end

# Identity forwards-pass, no-op reverse. No shared data.
function make_ad_stmts!(stmt::IDGotoNode, line::ID, ::ADInfo)
function make_ad_stmts!(stmt::IDGotoNode, line::ID, ::ADInfo; world=Base.get_world_counter())
return ad_stmt_info(line, nothing, inc_args(stmt), nothing)
end

# Identity forwards-pass, no-op reverse. No shared data.
function make_ad_stmts!(stmt::IDGotoIfNot, line::ID, ::ADInfo)
function make_ad_stmts!(stmt::IDGotoIfNot, line::ID, ::ADInfo; world=Base.get_world_counter())
stmt = inc_args(stmt)

# If cond is not going to be wrapped in a `CoDual`, just return the stmt.
Expand All @@ -472,7 +472,7 @@ function make_ad_stmts!(stmt::IDGotoIfNot, line::ID, ::ADInfo)
end

# Identity forwards-pass, no-op reverse. No shared data.
function make_ad_stmts!(stmt::IDPhiNode, line::ID, info::ADInfo)
function make_ad_stmts!(stmt::IDPhiNode, line::ID, info::ADInfo; world=Base.get_world_counter())
vals = stmt.values
new_vals = Vector{Any}(undef, length(vals))
for n in eachindex(vals)
Expand All @@ -489,7 +489,7 @@ function make_ad_stmts!(stmt::IDPhiNode, line::ID, info::ADInfo)
return ad_stmt_info(line, nothing, _inst, nothing)
end

function make_ad_stmts!(stmt::PiNode, line::ID, info::ADInfo)
function make_ad_stmts!(stmt::PiNode, line::ID, info::ADInfo; world=Base.get_world_counter())

# PiNodes of the form `π (nothing, Union{})` have started appearing in 1.11. These nodes
# appear in unreachable sections of code, and appear to serve no purpose. Consequently,
Expand Down Expand Up @@ -528,7 +528,7 @@ end
# Constant GlobalRefs are handled. See const_codual. Non-constant GlobalRefs are handled by
# assuming that they are constant, and creating a CoDual with the value. We then check at
# run-time that the value has not changed.
function make_ad_stmts!(stmt::GlobalRef, line::ID, info::ADInfo)
function make_ad_stmts!(stmt::GlobalRef, line::ID, info::ADInfo; world=Base.get_world_counter())
isconst(stmt) && return const_ad_stmt(stmt, line, info)

const_id, globalref_id = ID(), ID()
Expand All @@ -547,10 +547,10 @@ end
end

# QuoteNodes are constant.
make_ad_stmts!(stmt::QuoteNode, line::ID, info::ADInfo) = const_ad_stmt(stmt, line, info)
make_ad_stmts!(stmt::QuoteNode, line::ID, info::ADInfo; world=Base.get_world_counter()) = const_ad_stmt(stmt, line, info)

# Literal constant.
make_ad_stmts!(stmt, line::ID, info::ADInfo) = const_ad_stmt(stmt, line, info)
make_ad_stmts!(stmt, line::ID, info::ADInfo; world=Base.get_world_counter()) = const_ad_stmt(stmt, line, info)

"""
const_ad_stmt(stmt, line::ID, info::ADInfo)
Expand Down Expand Up @@ -621,12 +621,12 @@ get_const_primal_value(x::QuoteNode) = x.value
get_const_primal_value(x) = x

# Mooncake does not yet handle `PhiCNode`s. Throw an error if one is encountered.
function make_ad_stmts!(stmt::Core.PhiCNode, ::ID, ::ADInfo)
function make_ad_stmts!(stmt::Core.PhiCNode, ::ID, ::ADInfo; world=Base.get_world_counter())
return unhandled_feature("Encountered PhiCNode: $stmt")
end

# Mooncake does not yet handle `UpsilonNode`s. Throw an error if one is encountered.
function make_ad_stmts!(stmt::Core.UpsilonNode, ::ID, ::ADInfo)
function make_ad_stmts!(stmt::Core.UpsilonNode, ::ID, ::ADInfo; world=Base.get_world_counter())
return unhandled_feature(
"Encountered UpsilonNode: $stmt. These are generated as part of some try / catch " *
"/ finally blocks. At the present time, Mooncake.jl cannot differentiate through " *
Expand All @@ -640,7 +640,7 @@ end

# There are quite a number of possible `Expr`s that can be encountered. Each case has its
# own comment, explaining what is going on.
function make_ad_stmts!(stmt::Expr, line::ID, info::ADInfo)
function make_ad_stmts!(stmt::Expr, line::ID, info::ADInfo; world=Base.get_world_counter())
is_invoke = Meta.isexpr(stmt, :invoke)
if Meta.isexpr(stmt, :call) || is_invoke

Expand Down Expand Up @@ -669,7 +669,7 @@ function make_ad_stmts!(stmt::Expr, line::ID, info::ADInfo)
rrule!! # intrinsic / builtin / thing we provably have rule for
elseif is_invoke
mi = stmt.args[1]::Core.MethodInstance
LazyDerivedRule(mi, info.debug_mode) # Static dispatch
LazyDerivedRule(mi, info.debug_mode; world) # Static dispatch
else
DynamicDerivedRule(info.debug_mode) # Dynamic dispatch
end
Expand Down Expand Up @@ -1066,6 +1066,68 @@ struct DerivedRuleInfo
isva::Bool
end

function generated_build_rrule_body(world::UInt, lnn, this, sig)
sig isa Type{<:Type{<:Tuple}} || error()
sig = sig.parameters[1]
interp = MooncakeInterpreter(DefaultCtx; world)
rule = build_rrule(interp, sig; world)

ci = expr_to_codeinfo(@__MODULE__(), [Symbol("#self#"), :sig], [], (), :(return $rule))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's probably true that just returning this is unsafe because the DerivedRule contains various bits of state which get read from / written to in specific orders during the forwards- and reverse passes. The caching functionality in build_rrule therefore makes a call to _copy, which ensures that each time we ask for a rule, it gets fresh state which is definitely not shared with any other rule for this signature. Ensuring that no two instances of a rule share the same state ensures that we avoid race conditions if we're running the same rule on multiple threads, and makes recursion work correctly (the way that we handle recursion is a bit weird).

Anyway, we'll have to do something similar here. You could simply replace the :(return $rule) with :(return _copy($rule)). This will obviously have worse performance (not allocation-free), but is probably going to be better than what we have at the minute, because it will still avoid the type-unstable code here.

In order to do better, we would need to avoid _copy altogether. I suspect that we'll have to do some clever caching here to make this work well. I suspect that it can be done, but maybe it's something to do in a follow-up PR?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah I see, makes sense. I didn't know there was state inside the rule as well, that's unfortunate.


# Attach edges from the MethodInstances of `f` to this CodeInfo.
# This should make it so that adding methods to `f`
# triggers recompilation, fixing the #265 equivalent for generated functions.
matches = Base._methods_by_ftype(sig, -1, world)
if !isnothing(matches)
ci.edges = Core.MethodInstance[]
for match in Base._methods_by_ftype(sig, -1, world)
mi = Core.Compiler.specialize_method(match) # in v1.12 Core.Compiler.specialize_method is Base.specialize_method
push!(ci.edges, mi)
end
end
return ci
end

"""
    expr_to_codeinfo(m::Module, argnames, spnames, sp, e::Expr)

Wrap the expression `e` in a `:lambda` expression whose argument list is `argnames`, and
pass it to the runtime function `jl_expand_and_resolve` together with the module `m` to
obtain a `Core.CodeInfo`.

When `spnames` is neither `nothing` nor empty, the lambda is additionally wrapped in a
`with-static-parameters` expression listing `spnames`. The elements of `sp` are splatted
into the `Core.SimpleVector` handed to the runtime call.

An assertion fails if the runtime call does not return a `Core.CodeInfo`.
"""
function expr_to_codeinfo(m::Module, argnames, spnames, sp, e::Expr)
    # This trick comes from https://github.com/NHDaly/StagedFunctions.jl/commit/22fc72740093892baa442850a1fd61d9cd61b4cd
    # Build the lambda from the inside out: `e` sits in a block, which is returned from
    # the block forming the scope-block body of the lambda.
    returned = Expr(:return, Expr(:block, e))
    scoped_body = Expr(Symbol("scope-block"), Expr(:block, returned))
    lambda = Expr(:lambda, argnames, scoped_body)

    # Only mention static parameters when there are some to mention.
    no_sparams = isnothing(spnames) || isempty(spnames)
    ex = no_sparams ? lambda : Expr(Symbol("with-static-parameters"), lambda, spnames...)

    #=
        !!!!
        In v1.12+ `jl_expand_and_resolve` was deleted. We'll need to use
        Base.generated_body_to_codeinfo instead.
        See the details in the changes base/expr.jl and src/method.c in commit
        https://github.com/JuliaLang/julia/commit/ce507a7e70603631b7c9e242528c1b611f918b01
        !!!!
    =#

    # Get the code-info for the generator body in order to use it for generating a dummy
    # code info object.
    sparam_vals = Core.svec(sp...)
    ci = ccall(
        :jl_expand_and_resolve, Any, (Any, Any, Core.SimpleVector), ex, m, sparam_vals
    )
    @assert ci isa Core.CodeInfo "Failed to create a CodeInfo from the given expression. This might mean it contains a closure or comprehension?\n Offending expression: $e"
    return ci
end

# Define (or redefine) the method `generated_build_rrule(sig)` via `@eval`. The body of the
# emitted method consists solely of two `:meta` expressions: `:generated_only`, and
# `:generated` carrying the function `generated_build_rrule_body`, which is interpolated
# into the expression each time this function is called.
# NOTE(review): presumably this makes `generated_build_rrule` a generated function whose
# generator is `generated_build_rrule_body` -- confirm against how Julia lowers
# `@generated` functions for the supported Julia versions.
function refresh_generated_build_rrule_body()
@eval function generated_build_rrule(sig)
$(Expr(:meta, :generated_only))
$(Expr(:meta, :generated, generated_build_rrule_body))
end
end
#handy util for when working on generate_codeinfo
# Called at top level so that `generated_build_rrule` is defined as soon as this file is
# evaluated.
refresh_generated_build_rrule_body()


"""
build_rrule(interp::MooncakeInterpreter{C}, sig_or_mi; debug_mode=false) where {C}

Expand All @@ -1075,12 +1137,12 @@ docstring for `rrule!!` for more info.
If `debug_mode` is `true`, then all calls to rules are replaced with calls to `DebugRRule`s.
"""
function build_rrule(
interp::MooncakeInterpreter{C}, sig_or_mi; debug_mode=false, silence_debug_messages=true
) where {C}
interp::MooncakeInterpreter{C}, sig_or_mi; debug_mode=false, silence_debug_messages=true, world=Base.get_world_counter()
) where {C}

# To avoid segfaults, ensure that we bail out if the interpreter's world age is greater
# than the current world age.
if Base.get_world_counter() > interp.world
if world > interp.world
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Before merging, we should probably tweak the way that Mooncake handles world ages generally. The current approach is something of a hack (notice, for example, that DynamicDerivedRules don't accept a MooncakeInterpreter, and therefore don't have a pinned world age -- they just always grab a new interpreter on-the-fly, and use it to derive a rule). We can worry about that a bit later though.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, that'd be a good idea. I think it'd also potentially be a good idea to keep around a dictionary of all the different interpreters from different worlds, rather than just the latest interpreter. Not sure how you feel about that though

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea -- what do you have in mind? Make get_interpreter accept a world age + make the constant containing the interpreter a dictionary?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah exactly.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would be great. My thinking is that we should wind up removing the above check entirely, and just rely on correctly propagating the world age through to everything via the interpreter.

Is this something you would be interested in doing as part of this PR, or would you rather it were done as part of a separate one?

throw(
ArgumentError(
"World age associated to interp is behind current world age. Please " *
Expand Down Expand Up @@ -1111,7 +1173,7 @@ function build_rrule(
return _copy(interp.oc_cache[oc_cache_key])
else
# Derive forwards- and reverse-pass IR, and shove in `MistyClosure`s.
dri = generate_ir(interp, sig_or_mi; debug_mode)
dri = generate_ir(interp, sig_or_mi; debug_mode, world)
fwd_oc = misty_closure(dri.fwd_ret_type, dri.fwd_ir, dri.shared_data...)
rvs_oc = misty_closure(dri.rvs_ret_type, dri.rvs_ir, dri.shared_data...)

Expand Down Expand Up @@ -1147,7 +1209,7 @@ end
Used by `build_rrule`, and the various debugging tools: primal_ir, fwds_ir, adjoint_ir.
"""
function generate_ir(
interp::MooncakeInterpreter, sig_or_mi; debug_mode=false, do_inline=true
interp::MooncakeInterpreter, sig_or_mi; debug_mode=false, do_inline=true, world=Base.get_world_counter()
)
# Reset id count. This ensures that the IDs generated are the same each time this
# function runs.
Expand All @@ -1160,7 +1222,7 @@ function generate_ir(
rvs_ret_type = pullback_ret_type(ir)

# Normalise the IR, and generated BBCode version of it.
isva, spnames = is_vararg_and_sparam_names(sig_or_mi)
isva, spnames = is_vararg_and_sparam_names(sig_or_mi; world)
ir = normalise!(ir, spnames)
primal_ir = remove_unreachable_blocks!(BBCode(ir))

Expand All @@ -1172,7 +1234,7 @@ function generate_ir(
ad_stmts_blocks = map(primal_ir.blocks) do primal_blk
ids = primal_blk.inst_ids
primal_stmts = map(x -> x.stmt, primal_blk.insts)
return (primal_blk.id, make_ad_stmts!.(primal_stmts, ids, Ref(info)))
return (primal_blk.id, make_ad_stmts!.(primal_stmts, ids, Ref(info); world))
end

# Make shared data, and construct BBCode for forwards-pass and pullback.
Expand Down Expand Up @@ -1805,8 +1867,8 @@ mutable struct LazyDerivedRule{primal_sig,Trule}
debug_mode::Bool
mi::Core.MethodInstance
rule::Trule
function LazyDerivedRule(mi::Core.MethodInstance, debug_mode::Bool)
interp = get_interpreter()
function LazyDerivedRule(mi::Core.MethodInstance, debug_mode::Bool; world=Base.get_world_counter())
interp = get_interpreter(; world)
return new{mi.specTypes,rule_type(interp, mi;debug_mode)}(debug_mode, mi)
end
function LazyDerivedRule{Tprimal_sig,Trule}(
Expand Down
3 changes: 1 addition & 2 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -166,8 +166,7 @@ is_vararg_and_sparam_names(m::Method) = m.isva, sparam_names(m)

Finds the method associated to `sig`, and calls `is_vararg_and_sparam_names` on it.
"""
function is_vararg_and_sparam_names(sig)::Tuple{Bool,Vector{Symbol}}
world = Base.get_world_counter()
function is_vararg_and_sparam_names(sig; world=Base.get_world_counter())::Tuple{Bool,Vector{Symbol}}
min = Base.RefValue{UInt}(typemin(UInt))
max = Base.RefValue{UInt}(typemax(UInt))
ms = Base._methods_by_ftype(
Expand Down
Loading