Conversation
willtebbutt left a comment
This is great. I've left a few comments, but if you're planning to do a bunch of additional stuff, then maybe they're redundant. Either way, don't feel the need to respond to them.
Co-authored-by: Will Tebbutt <willtebbutt00@gmail.com> Signed-off-by: Guillaume Dalle <22795598+gdalle@users.noreply.github.com>
|
@willtebbutt following our discussion yesterday I scratched my head some more, and I decided that it would be infinitely simpler to enforce the invariant that one line of primal IR maps to one line of dual IR. While this may require additional fallbacks in the Julia code itself, I hope it will make our lives much easier on the IR side. What do you think? |
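A schematic illustration of that invariant (hypothetical IR, not actual Mooncake-generated code; the rule names are made up):

```julia
# Hypothetical, schematic IR: each primal statement maps to exactly one dual
# statement with the same SSA number, so no renumbering is ever required.
#
#   primal IR                 dual IR
#   %1 = sin(_2)       =>     %1 = frule_sin(_2)
#   %2 = *(%1, _2)     =>     %2 = frule_mul(%1, _2)
#   return %2          =>     return %2
```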
|
I think this could work. You could just replace the rule call with something like

```julia
@inline function call_frule!!(rule::R, fargs::Vararg{Any, N}) where {R, N}
    return rule(map(x -> x isa Dual ? x : zero_dual(x), fargs)...)
end
```

The optimisation pass will lower this to what we were thinking about writing out in the IR anyway. I think the other important kinds of nodes would be largely straightforward to handle.
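For concreteness, a self-contained toy version of what such a wrapper does (all names here are made up, not Mooncake's API):

```julia
# Toy illustration of the promotion performed by the wrapper above: non-dual
# (constant) arguments are wrapped with a zero tangent before the rule runs.
struct ToyDual{P,T}
    primal::P
    tangent::T
end

toy_zero_dual(x) = ToyDual(x, zero(x))

# Forward rule for multiplication: product rule on the tangents.
toy_mul_rule(a::ToyDual, b::ToyDual) =
    ToyDual(a.primal * b.primal, a.tangent * b.primal + a.primal * b.tangent)

# Analogue of call_frule!!: promote any non-dual argument, then apply the rule.
call_toy_rule(rule, args...) = rule(map(x -> x isa ToyDual ? x : toy_zero_dual(x), args)...)

call_toy_rule(toy_mul_rule, ToyDual(3.0, 1.0), 2.0)  # ToyDual(6.0, 2.0)
```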
|
I think we might need to be slightly more subtle. If an argument to the |
|
Yes. I think my proposed code handles this though, or am I missing something?
|
In the spirit of higher-order AD, we may encounter |
|
Very good point.
Agreed. Specifically, I think we need to distinguish between literals / |
|
I still need to dig into the different node types we might encounter (and I still don't understand |
|
I was reviewing the design docs and realised that, sadly, the "one line of primal IR maps to one line of dual IR" won't work for `GotoIfNot` nodes.
|
I think that's okay; the main trouble is adding new lines that introduce new variables, because that requires manual renumbering. A GoTo should be much simpler.
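A schematic picture of that renumbering problem (made-up IR, for illustration only):

```julia
# Inserting a statement shifts every subsequent SSAValue by one, so all later
# use sites have to be rewritten by hand (or via a pass such as compact!).
#
#   before                 after inserting a new statement at position 2
#   %1 = f(_2)             %1 = f(_2)
#   %2 = g(%1)             %2 = new_stmt(%1)
#   return %2              %3 = g(%1)      # was %2
#                          return %3       # use site had to be renumbered
```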
|
Were the difficulties around renumbering etc not resolved by not |
|
No they weren't. I experimented with |
|
Ah, right, but we do need to insert a new SSAValue. Suppose the primal IR contains `GotoIfNot(%5, #3)` (i.e. jump to block 3 if not `%5`). In the dual IR we would need something like

```
%new_ssa = Expr(:call, primal, %5)
GotoIfNot(%new_ssa, #3)
```

Does this not cause the same kind of problems?
|
Oh yes you're probably right. Although it might be slightly less of a hassle because the new SSA is only used in one spot, right after. I'll take a look |
|
Do you know what I should do about expressions of type |
|
Yup -- I just strip them out of the IR entirely in reverse-mode. See https://github.com/compintell/Mooncake.jl/blob/0f37c079bd1ae064e7b84696eed4a1f7eb763f1f/src/interpreter/s2s_reverse_mode_ad.jl#L728. The way to remove an instruction from an
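For illustration, a hedged sketch of one common way to do this (a generic pattern, not necessarily what the linked Mooncake code does): overwrite the statement with `nothing`, so the statement slot, and hence every later `SSAValue`, is preserved.

```julia
# Generic sketch: replace statements matching a predicate with `nothing`,
# keeping the statement count (and therefore the SSA numbering) unchanged.
# `should_strip` is a hypothetical predicate, e.g.
# stmt -> Meta.isexpr(stmt, :code_coverage_effect).
const CC = Core.Compiler

function strip_stmts!(ir::CC.IRCode, should_strip)
    for k in 1:length(ir.stmts)
        if should_strip(ir[CC.SSAValue(k)][:stmt])
            ir[CC.SSAValue(k)][:stmt] = nothing
        end
    end
    return ir
end
```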
|
I think this works for `GotoIfNot`.
MWE (requires this branch of Mooncake):

```julia
const CC = Core.Compiler
using Mooncake
using MistyClosures

f(x) = x > 1 ? 2x : 3 + x

ir = Base.code_ircode(f, (Float64,))[1][1]
initial_ir = copy(ir)
get_primal_inst = CC.NewInstruction(Expr(:call, +, 1, 2), Any)  # placeholder for get_primal
CC.insert_node!(ir, CC.SSAValue(3), get_primal_inst, false)
ir = CC.compact!(ir)
for k in 1:length(ir.stmts)
    inst = ir[CC.SSAValue(k)][:stmt]
    if inst isa Core.GotoIfNot
        Mooncake.replace_call!(ir, CC.SSAValue(k), Core.GotoIfNot(CC.SSAValue(k - 1), inst.dest))
    end
end
ir
```

```
julia> initial_ir
5 1 ─ %1 = Base.lt_float(1.0, _2)::Bool      │╻╷╷ >
  │   %2 = Base.or_int(%1, false)::Bool      ││╻   <
  └──      goto #3 if not %2                 │
  2 ─ %4 = Base.mul_float(2.0, _2)::Float64  ││╻   *
  └──      return %4                         │
  3 ─ %6 = Base.add_float(3.0, _2)::Float64  ││╻   +
  └──      return %6                         │

julia> ir
5 1 ─ %1 = Base.lt_float(1.0, _2)::Bool      │╻╷╷ >
  │        Base.or_int(%1, false)::Bool      ││╻   <
  │   %3 = (+)(1, 2)::Any                    │
  └──      goto #3 if not %3                 │
  2 ─ %5 = Base.mul_float(2.0, _2)::Float64  ││╻   *
  └──      return %5                         │
  3 ─ %7 = Base.add_float(3.0, _2)::Float64  ││╻   +
  └──      return %7
```
|
After I review again, I'd be in favor of merging but explicitly marking this as experimental, so that DI can add the functionality and people can then stress-test it on real problems. This will probably be a much better way to find bugs than refining a 4k-LOC PR.
|
Great work, @willtebbutt! I’m also supportive of merging this PR now as an experimental feature. |
Pull Request Overview
This pull request implements forward-mode automatic differentiation (AD) support for Mooncake.jl. The implementation provides a complete framework for forward-mode AD that operates alongside the existing reverse-mode capabilities.
Key changes include:
- Addition of `Dual` type for forward-mode computations and comprehensive `frule!!` infrastructure
- Introduction of mode-specific `is_primitive` function and unified `@zero_derivative` / `@from_chainrules` macros
- Implementation of forward-mode rules (`frule!!`) for all existing primitives and mathematical operations
Reviewed Changes
Copilot reviewed 82 out of 83 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| src/tools_for_rules.jl | Adds @zero_derivative and @from_chainrules macros, frule_wrapper functionality, and mooncake_tangent conversion |
| src/test_utils.jl | Updates testing infrastructure to support both forward and reverse mode testing with new test_rule interface |
| src/rrules/*.jl | Implements frule!! methods for all mathematical operations, linear algebra, memory operations, and built-ins |
| test/*.jl | Updates all test files to use new testing interface and mode-specific functionality |
Comments suppressed due to low confidence (3)
src/tools_for_rules.jl:162
- [nitpick] The explicit return type annotation `::Tuple{Bool,Vector{Symbol}}` is unnecessary and adds visual clutter. Julia's type inference can determine this automatically from the function body.
passing all of its arguments (including the function itself) to this function. For example:
src/tools_for_rules.jl:194
- [nitpick] The explicit return type annotation `::Tuple{Bool,Vector{Symbol}}` is redundant since it's already specified in the previous method and Julia can infer this type.
```jldoctest
src/rrules/lapack.jl:97
- This commented-out code should be removed rather than left as a comment, as it adds clutter and may cause confusion about the intended behavior.
# Restore initial state.
|
@sunxd3, can you help take a look at the Turing integration test failures?
Signed-off-by: Hong Ge <3279477+yebai@users.noreply.github.com>
> ## Public Interface
> - Mooncake offers forward mode AD.
That's amazing work, though it's not clear how one can use it. I assume that ADTypes / DifferentiationInterface support will take a bit of time to arrive, but in the meantime do I just replace `value_and_gradient!!` with `value_and_derivative!!`?
And congratulations to all involved! 🎉
DI support landing today
Seems to have been mistakenly duplicated as part of #389. Signed-off-by: Bruno Ploumhans <13494793+Technici4n@users.noreply.github.com>
This is a very rough backbone of forward mode AD, based on #386 and the existing reverse mode implementation.
Will's edits (apologies for editing your thing @gdalle -- I just want to make sure that the todo list is at the top of the PR):
Todo:
- make `FunctionWrappers` work correctly (not going to do this in this PR)
- `is_primitive` separately for forwards and reverse pass.
- `@from_rule`: `@from_chainrules` or `@from_chain_rule`, see comments below.
- `UpsilonNode`s and `PhiCNode`s.
- bump to version 0.5 (actually not needed)

Once the above are complete, I'll request reviews.