Forward-mode AD #389

Merged
yebai merged 226 commits into chalk-lab:main from gdalle:gd/forward
Aug 12, 2025
Conversation

@gdalle
Collaborator

@gdalle gdalle commented Nov 24, 2024

This is a very rough backbone of forward mode AD, based on #386 and the existing reverse mode implementation.

Will's edits (apologies for editing your thing @gdalle -- I just want to make sure that the todo list is at the top of the PR):

Todo:

  • make FunctionWrappers work correctly (not going to do this in this PR)
  • add support for MistyClosures
  • add tests for Hessian-vector products
  • define is_primitive separately for the forwards and reverse passes.
  • do a complete pass to review the design -- are there any high-level things we ought to modify?
  • improve DRY-ness of the code, particularly in the testing infrastructure.
  • check GPU compatibility, make sure no major design issues prevent future GPU compatibility, and be explicit about what needs to be done in the future.
  • decide what name we should use for @from_rule: @from_chainrules or @from_chain_rule, see comments below.
  • add support for UpsilonNodes and PhiCNodes.
  • get all tests passing
  • bump to version 0.5 (actually not needed)

Once the above are complete, I'll request reviews.

@codecov

codecov bot commented Nov 24, 2024

Codecov Report

Attention: Patch coverage is 94.04070% with 82 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
src/interpreter/s2s_forward_mode_ad.jl 88.77% 22 Missing ⚠️
src/test_utils.jl 86.66% 16 Missing ⚠️
src/rrules/foreigncall.jl 75.75% 8 Missing ⚠️
src/rrules/memory.jl 87.69% 8 Missing ⚠️
src/utils.jl 76.92% 6 Missing ⚠️
src/rrules/tasks.jl 64.28% 5 Missing ⚠️
src/dual.jl 85.71% 3 Missing ⚠️
src/rrules/builtins.jl 97.82% 3 Missing ⚠️
src/developer_tools.jl 0.00% 2 Missing ⚠️
src/interpreter/s2s_reverse_mode_ad.jl 71.42% 2 Missing ⚠️
... and 5 more
Files with missing lines Coverage Δ
src/Mooncake.jl 100.00% <ø> (ø)
src/interpreter/ir_utils.jl 89.68% <100.00%> (+2.81%) ⬆️
src/rrules/array_legacy.jl 100.00% <100.00%> (ø)
src/rrules/avoiding_non_differentiable_code.jl 100.00% <100.00%> (ø)
src/rrules/blas.jl 99.64% <100.00%> (+0.84%) ⬆️
src/rrules/fastmath.jl 100.00% <100.00%> (ø)
src/rrules/lapack.jl 100.00% <100.00%> (+0.56%) ⬆️
src/rrules/linear_algebra.jl 100.00% <100.00%> (ø)
src/rrules/low_level_maths.jl 100.00% <100.00%> (ø)
src/rrules/new.jl 91.30% <100.00%> (+2.84%) ⬆️
... and 20 more

... and 2 files with indirect coverage changes


Collaborator

@willtebbutt willtebbutt left a comment


This is great. I've left a few comments, but if you're planning to do a bunch of additional stuff, then maybe they're redundant. Either way, don't feel the need to respond to them.

@gdalle
Collaborator Author

gdalle commented Nov 26, 2024

@willtebbutt following our discussion yesterday I scratched my head some more, and I decided that it would be infinitely simpler to enforce the invariant that one line of primal IR maps to one line of dual IR. While this may require additional fallbacks in the Julia code itself, I hope it will make our lives much easier on the IR side. What do you think?

@willtebbutt
Collaborator

I think this could work.

You could just replace the frule!! calls with a call to a function call_frule!! which would be something like

@inline function call_frule!!(rule::R, fargs::Vararg{Any,N}) where {R,N}
    # Arguments that are already Duals pass through; anything else gets a zero tangent.
    return rule(map(x -> x isa Dual ? x : zero_dual(x), fargs)...)
end

The optimisation pass will lower this to what we were thinking about writing out in the IR anyway.
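
As a quick illustration (assuming the call_frule!! above and a zero_dual helper are in scope; this is a sketch, not code from the PR), a mixed argument list would be handled like this:

# A throwaway "rule" that just returns what it receives, so the wrapping is visible.
fake_rule(args...) = args

call_frule!!(fake_rule, zero_dual(3.0), 2.0)
# Both values reach the rule as Duals: the first already was one, while the
# literal 2.0 is wrapped with a zero tangent on the way in.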

I think the other important kinds of nodes would be largely straightforward to handle.

@gdalle
Collaborator Author

gdalle commented Nov 26, 2024

I think we might need to be slightly more subtle. If an argument to the :call or :invoke expression is a CC.Argument or a CC.SSAValue, we don't wrap it in a Dual because we assume it will already be one, right?

@willtebbutt
Collaborator

willtebbutt commented Nov 26, 2024

Yes. I think my proposed code handles this though, or am I missing something?

@gdalle
Collaborator Author

gdalle commented Nov 26, 2024

In the spirit of higher-order AD, we may encounter Dual inputs that we want to wrap with a second Dual, and Dual inputs that we want to leave as-is. So I think this wrapping needs to be decided from the type of each argument in the IR?

@willtebbutt
Collaborator

Very good point.

So I think this wrapping needs to be decided from the type of each argument in the IR?

Agreed. Specifically, I think we need to distinguish between literals / QuoteNodes / GlobalRefs, and Argument / SSAValues?

@gdalle
Collaborator Author

gdalle commented Nov 26, 2024

I still need to dig into the different node types we might encounter (and I still don't understand QuoteNodes) but yeah, Argument and SSAValue don't need to be wrapped.
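
As a rough sketch of what that per-argument decision could look like during the IR transformation (the helper name is hypothetical, and the placement of zero_dual in Mooncake is an assumption rather than this PR's final code):

import Mooncake  # assumes a zero_dual helper in Mooncake, as in the snippet above

# Arguments and SSAValues already carry Duals at runtime, so leave them untouched;
# anything else (literals, QuoteNodes, GlobalRefs) is rewritten into a call that
# wraps it with a zero tangent.
maybe_wrap(x::Union{Core.Argument,Core.SSAValue}) = x
maybe_wrap(x) = Expr(:call, Mooncake.zero_dual, x)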

@gdalle gdalle mentioned this pull request Nov 27, 2024
@willtebbutt
Collaborator

I was reviewing the design docs and realised that, sadly, the "one line of primal IR maps to one line of dual IR" won't work for Core.GotoIfNot nodes. See https://compintell.github.io/Mooncake.jl/previews/PR386/developer_documentation/forwards_mode_design/#Statement-Transformation .

@gdalle
Collaborator Author

gdalle commented Nov 27, 2024

I think that's okay; the main trouble is adding new lines that introduce new variables, because that requires manual renumbering. A GotoIfNot should be much simpler.

@willtebbutt
Collaborator

Were the difficulties around renumbering etc not resolved by not compact!ing until the end? I feel like I might be missing something.

@gdalle
Collaborator Author

gdalle commented Nov 27, 2024

No they weren't. I experimented with compact! in various places and I was struggling a lot, so I asked Frames for advice. She agreed that insertion should usually be avoided.
If we have to insert something for GoTo, I think it will still be easier because we're not defining a new SSAValue so we don't have to adapt future statements that refer to it.

@willtebbutt
Collaborator

willtebbutt commented Nov 27, 2024

Ah, right, but we do need to insert a new SSAValue. Suppose that the GotoIfNot of interest is

GotoIfNot(%5, #3)

i.e. jump to block 3 if not %5. In the forwards-mode IR this would become

%new_ssa = Expr(:call, primal, %5)
GotoIfNot(%new_ssa, #3)

Does this not cause the same kind of problems?

@gdalle
Collaborator Author

gdalle commented Nov 27, 2024

Oh yes, you're probably right. Although it might be slightly less of a hassle because the new SSA is only used in one spot, right after it. I'll take a look.

@gdalle
Collaborator Author

gdalle commented Nov 27, 2024

Do you know what I should do about expressions of type :code_coverage_effect? I assume they're inserted automatically and they're alone on their lines?

@willtebbutt
Collaborator

willtebbutt commented Nov 27, 2024

Yup -- I just strip them out of the IR entirely in reverse-mode. See https://github.com/compintell/Mooncake.jl/blob/0f37c079bd1ae064e7b84696eed4a1f7eb763f1f/src/interpreter/s2s_reverse_mode_ad.jl#L728

The way to remove an instruction from an IRCode is just to replace the instruction with nothing.
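
A minimal sketch of that stripping step (assuming an ir::Core.Compiler.IRCode is already in hand; this mirrors the linked reverse-mode code rather than copying it):

const CC = Core.Compiler

for k in 1:length(ir.stmts)
    inst = ir[CC.SSAValue(k)][:stmt]
    if Meta.isexpr(inst, :code_coverage_effect)
        # Removing a statement from IRCode just means overwriting it with `nothing`.
        ir[CC.SSAValue(k)][:stmt] = nothing
    end
end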

@gdalle
Collaborator Author

gdalle commented Nov 27, 2024

I think this works for GotoIfNot:

  1. make all the insertions necessary
  2. compact! once to make sure they are applied
  3. shift the condition of each GotoIfNot node to refer to the statement right before it (where we get the primal value of the condition)

MWE (requires this branch of Mooncake):

const CC = Core.Compiler
using Mooncake
using MistyClosures

f(x) = x > 1 ? 2x : 3 + x
ir = Base.code_ircode(f, (Float64,))[1][1]
initial_ir = copy(ir)
get_primal_inst = CC.NewInstruction(Expr(:call, +, 1, 2), Any)  # placeholder for get_primal
CC.insert_node!(ir, CC.SSAValue(3), get_primal_inst, false)
ir = CC.compact!(ir)
for k in 1:length(ir.stmts)
    inst = ir[CC.SSAValue(k)][:stmt]
    if inst isa Core.GotoIfNot
        Mooncake.replace_call!(ir, CC.SSAValue(k), Core.GotoIfNot(CC.SSAValue(k - 1), inst.dest))
    end
end
ir
julia> initial_ir
5 1 ─ %1 = Base.lt_float(1.0, _2)::Bool
  │   %2 = Base.or_int(%1, false)::Bool
  └──      goto #3 if not %2
  2 ─ %4 = Base.mul_float(2.0, _2)::Float64
  └──      return %4
  3 ─ %6 = Base.add_float(3.0, _2)::Float64
  └──      return %6

julia> ir
5 1 ─ %1 = Base.lt_float(1.0, _2)::Bool
  │        Base.or_int(%1, false)::Bool
  │   %3 = (+)(1, 2)::Any
  └──      goto #3 if not %3
  2 ─ %5 = Base.mul_float(2.0, _2)::Float64
  └──      return %5
  3 ─ %7 = Base.add_float(3.0, _2)::Float64
  └──      return %7

@willtebbutt
Collaborator

Okay. I think I've now addressed all of @gdalle 's feedback, and this PR is in a state that's basically ready to go.

I'm now going to be offline for a couple of weeks, so I'm happy for @gdalle or @yebai to handle any remaining issues and merge if there's a strong need to merge before I'm back.

@gdalle
Collaborator Author

gdalle commented Aug 8, 2025

After I review again, I'd be in favor of merging but explicitly marking this as experimental, so that DI can add the functionality and people can then stress-test it on real problems. This will probably be a much better way to find bugs than continuing to refine a 4k-LOC PR.

@yebai
Member

yebai commented Aug 8, 2025

Great work, @willtebbutt! I’m also supportive of merging this PR now as an experimental feature.

@yebai yebai requested a review from Copilot August 8, 2025 21:58
Contributor

Copilot AI left a comment


Pull Request Overview

This pull request implements forward-mode automatic differentiation (AD) support for Mooncake.jl. The implementation provides a complete framework for forward-mode AD that operates alongside the existing reverse-mode capabilities.

Key changes include:

  • Addition of Dual type for forward-mode computations and comprehensive frule!! infrastructure
  • Introduction of mode-specific is_primitive function and unified @zero_derivative/@from_chainrules macros
  • Implementation of forward-mode rules (frule!!) for all existing primitives and mathematical operations

Reviewed Changes

Copilot reviewed 82 out of 83 changed files in this pull request and generated 2 comments.

File Description
src/tools_for_rules.jl Adds @zero_derivative and @from_chainrules macros, frule_wrapper functionality, and mooncake_tangent conversion
src/test_utils.jl Updates testing infrastructure to support both forward and reverse mode testing with new test_rule interface
src/rrules/*.jl Implements frule!! methods for all mathematical operations, linear algebra, memory operations, and built-ins
test/*.jl Updates all test files to use new testing interface and mode-specific functionality
Comments suppressed due to low confidence (3)

src/tools_for_rules.jl:162

  • [nitpick] The explicit return type annotation ::Tuple{Bool,Vector{Symbol}} is unnecessary and adds visual clutter. Julia's type inference can determine this automatically from the function body.
passing all of its arguments (including the function itself) to this function. For example:

src/tools_for_rules.jl:194

  • [nitpick] The explicit return type annotation ::Tuple{Bool,Vector{Symbol}} is redundant since it's already specified in the previous method and Julia can infer this type.
```jldoctest

src/rrules/lapack.jl:97

  • This commented-out code should be removed rather than left as a comment, as it adds clutter and may cause confusion about the intended behavior.
    # Restore initial state.

@yebai
Member

yebai commented Aug 11, 2025

@sunxd3, can you help take a look at Turing integration Test failures?

yebai and others added 3 commits August 12, 2025 19:18
Signed-off-by: Hong Ge <3279477+yebai@users.noreply.github.com>
@yebai yebai merged commit 83b4ff7 into chalk-lab:main Aug 12, 2025
88 checks passed
Comment on lines +3 to +4
## Public Interface
- Mooncake offers forward mode AD.
Collaborator

@penelopeysm penelopeysm Aug 12, 2025


That's amazing work, though it's not clear how one can use it? I assume that ADTypes / DifferentiationInterface support will take a bit of time to arrive, but in the meantime do I just replace value_and_gradient!! with value_and_derivative!!?

And congratulations to all involved! 🎉

Collaborator Author

DI support landing today

@gdalle gdalle deleted the gd/forward branch August 13, 2025 15:20
@yebai yebai mentioned this pull request Aug 16, 2025
Technici4n added a commit that referenced this pull request Nov 6, 2025
Seems to have been mistakenly duplicated as part of #389.

Signed-off-by: Bruno Ploumhans <13494793+Technici4n@users.noreply.github.com>
yebai pushed a commit that referenced this pull request Nov 6, 2025
Seems to have been mistakenly duplicated as part of #389.

Signed-off-by: Bruno Ploumhans <13494793+Technici4n@users.noreply.github.com>
penelopeysm pushed a commit that referenced this pull request Nov 10, 2025
Seems to have been mistakenly duplicated as part of #389.

Signed-off-by: Bruno Ploumhans <13494793+Technici4n@users.noreply.github.com>