Skip to content

Enable second-order differentiation via forward-over-reverse#878

Merged
yebai merged 43 commits intomainfrom
sunxd/f-o-r
Dec 30, 2025
Merged

Enable second-order differentiation via forward-over-reverse#878
yebai merged 43 commits intomainfrom
sunxd/f-o-r

Conversation

@sunxd3
Copy link
Collaborator

@sunxd3 sunxd3 commented Dec 6, 2025

Ref: #826

This PR aims to enable forward-over-reverse for second order derivatives. Many of the necessary plumbings are already implemented before this PR. This PR smooths out the remaining edges.

Fixes:

  1. World age fix for MistyClosure: Forward mode now uses MistyClosure's original world age instead of current world, fixing failures when zero_tangent(pb_mc) was called.
  2. Union{} type issue (this is somewhat 1.10 specific) Union{} can showup when type inference fails, the approach here is to replace them with Any
  3. Add some necessary rules to avoid hanging when testing some particular functions
    1. Core.memorynew
    2. LazyDerivedRule: avoid diff through get_interpreter (jl_get_world_counter)
    3. jl_alloc_array_1d/2d/3d

Pointers to some possible future work

Closes #632

@sunxd3 sunxd3 mentioned this pull request Dec 6, 2025
@github-actions
Copy link
Contributor

github-actions bot commented Dec 7, 2025

Mooncake.jl documentation for PR #878 is available at:
https://chalk-lab.github.io/Mooncake.jl/previews/PR878/

@github-actions
Copy link
Contributor

github-actions bot commented Dec 7, 2025

Performance Ratio:
Ratio of time to compute gradient and time to compute function.
Warning: results are very approximate! See here for more context.

┌────────────────────────────┬──────────┬──────────┬─────────────┬─────────┬─────────────┬────────┐
│                      Label │   Primal │ Mooncake │ MooncakeFwd │  Zygote │ ReverseDiff │ Enzyme │
│                     String │   String │   String │      String │  String │      String │ String │
├────────────────────────────┼──────────┼──────────┼─────────────┼─────────┼─────────────┼────────┤
│                   sum_1000 │ 110.0 ns │     1.83 │         2.0 │    1.09 │        5.19 │   10.2 │
│                  _sum_1000 │ 952.0 ns │     6.71 │        1.01 │  1620.0 │        33.6 │   1.08 │
│               sum_sin_1000 │  6.58 μs │     2.61 │        1.42 │    1.75 │        10.7 │   1.93 │
│              _sum_sin_1000 │  5.34 μs │     3.18 │        2.15 │   271.0 │        13.0 │   2.45 │
│                   kron_sum │ 391.0 μs │     7.66 │        2.95 │    4.72 │       178.0 │   11.1 │
│              kron_view_sum │ 389.0 μs │     8.46 │        4.34 │    19.1 │       191.0 │   5.67 │
│      naive_map_sin_cos_exp │  2.18 μs │     2.45 │        1.44 │ missing │        7.06 │   2.32 │
│            map_sin_cos_exp │  2.17 μs │     2.89 │        1.48 │    1.71 │        6.02 │   2.91 │
│      broadcast_sin_cos_exp │  2.32 μs │     2.48 │        1.44 │    3.88 │        1.47 │   2.25 │
│                 simple_mlp │ 441.0 μs │     4.38 │        2.93 │    1.62 │        6.77 │   3.23 │
│                     gp_lml │ 195.0 μs │     10.5 │        2.48 │    4.28 │     missing │   6.63 │
│ turing_broadcast_benchmark │  1.94 ms │     4.36 │        3.15 │ missing │        26.2 │   2.09 │
│         large_single_block │ 391.0 ns │     4.43 │         2.0 │  4420.0 │        31.2 │    2.2 │
└────────────────────────────┴──────────┴──────────┴─────────────┴─────────┴─────────────┴────────┘

@Technici4n Technici4n self-requested a review December 7, 2025 10:37
@sunxd3 sunxd3 requested a review from yebai December 8, 2025 11:15
@sunxd3
Copy link
Collaborator Author

sunxd3 commented Dec 8, 2025

@yebai @Technici4n this is ready for a review.

Integration with DI is not complete, I have some guess of reasons, but not certain yet. The PR is probably big enough.

One thing I don't know how to resolve: examples like x -> sum(x .* x) will hang on 1.10. Function like these uses DerivedFRule and frule_type on 1.10 hangs. A hot fix is something like 03eb3c0. it works, but creates allocations.
x -> sum(x .* x) is working fine on 1.10 in fact, I mistake "taking a bit long" as "hangs"

@yebai
Copy link
Member

yebai commented Dec 8, 2025

One thing I don't know how to resolve: examples like x -> sum(x .* x) will hang on 1.10. Function like these uses DerivedFRule and frule_type on 1.10 hangs.

This might be an actual Julia compiler bug. Can you try to create a MWE so we can report and patch over it?

Integration with DI is not complete, I have some guess of reasons, but not certain yet. The PR is probably big enough.

Probably okay to keep fixing DI related issues in this PR if they are not major. Also happy to start a seprate PR otherwise.

@gdalle
Copy link
Collaborator

gdalle commented Dec 8, 2025

Very cool that you found a way to do this without touching DI internals @sunxd3, congrats and thanks for the effort!!!
Let me know when this is in a good enough state, I'll put it through some more extensive DI testing

@sunxd3
Copy link
Collaborator Author

sunxd3 commented Dec 8, 2025

@yebai

This might be an actual Julia compiler bug.

It might be, but whatever it is, it's fixed in 1.11. Curious what it is, though.
The tests pass, but is taking quite a long while (https://github.com/chalk-lab/Mooncake.jl/actions/runs/20024598254/job/57418981172?pr=878).

My guess is that type inference failed, produce Union{}, but handled by the Union{} fix, so the test didn't fail. see #878 (comment)

I would prefer to stop here and fix DI integration in another PR.

@sunxd3
Copy link
Collaborator Author

sunxd3 commented Dec 8, 2025

Thanks @gdalle, unfortunately, the DI test doesn't work yet.

I wanted to keep it simple, and use Mooncake internal functions to make minimal f-o-r work.

I might start DI integration in another PR, I don't think we would need extension, but def need more investigation.

@gdalle
Copy link
Collaborator

gdalle commented Dec 8, 2025

I agree, 2nd order should work inside of Mooncake instead of requiring DI hacks. I just meant that DI has more possible tests to run than what is currently inside your integration tests

Copy link
Collaborator

@Technici4n Technici4n left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is really cool! I will have a deeper look later this week. As always, I am a bit suspicious of new rules that need to be added. 😉

# Forward-mode primitive for _build_rule! on LazyDerivedRule.
# This avoids differentiating through get_interpreter which has a ccall to jl_get_world_counter.
# The tangent propagation happens through the fwds_oc MistyClosure call, not the rule building.
# Only primitive in ForwardMode - reverse mode uses derived rule.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is said derived rule working? If not it's likely better to define a rule that throws rather than letting Mooncake try to differentiate through it and fail.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added a rrule saying we don't support reverse-o-r yet

@codecov
Copy link

codecov bot commented Dec 8, 2025

Codecov Report

❌ Patch coverage is 82.85714% with 18 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
src/rules/high_order_derivative_patches.jl 74.60% 16 Missing ⚠️
src/rules/memory.jl 83.33% 2 Missing ⚠️

📢 Thoughts on this report? Let us know!

@sunxd3
Copy link
Collaborator Author

sunxd3 commented Dec 29, 2025

This would involve adding the functions value_and_hessian!! (returns value + Hessian, and fills a gradient buffer) and prepare_hessian_cache, analogous to value_and_gradient!! and prepare_gradient_cache.

This certainly means Mooncake can have more leeway for optimizations.

I just made a commit with some frules that, at least on my local machine, passes the DITest second order tests.

Resolve conflict: move interface.jl from rrules to rules directory
@yebai
Copy link
Member

yebai commented Dec 29, 2025

I suggest we move the rules for _build_rule! and jl_genericmemory_owner into a new file named rules/high_order_derivative_patches.jl. This keeps their utility clear. We can also remove this new patch file later if no longered needed.

@gdalle
Copy link
Collaborator

gdalle commented Dec 29, 2025

From the error, it looks like the forward mode is diffing through the prep. This led me trying to overload (this is not a proposed final solution, just for discussion)

DI.inner_preparation_behavior(::AutoMooncakeForward) = DI.PrepareInnerSimple()

This seems correct to me. I introduced the inner_preparation_behavior trait in DI to handle second order with operator overloading backends like ForwardDiff, where the inner differentiation happens on Dual numbers, and must be prepared accordingly. Hence, we have

DI.inner_preparation_behavior(::AutoForwardDiff) = DI.PrepareInnerOverload()

In the case of Mooncake, we're dealing with source transformation, so DI.PrepareInnerSimple() is more appropriate, good job for finding that.
The default setting is DI.DontPrepareInner() because redoing preparation at each call is the easiest and most natural option, but here it requires differentiating through Mooncake compilation.

@sunxd3
Copy link
Collaborator Author

sunxd3 commented Dec 29, 2025

The DI second-order test failure on 1.11 is a bit confusing to me:

sorry to tag @gdalle have you seen this before, it does seem to me OpaqueClosure can be a bit flaky on 1.11?

@gdalle
Copy link
Collaborator

gdalle commented Dec 30, 2025

Is it just hitting a 60-minute time limit on the CI job? If you take a look at DI's GitHub Actions log, you'll see that they can go above one hour, depending on what you put in them

@sunxd3
Copy link
Collaborator Author

sunxd3 commented Dec 30, 2025

Is it just hitting a 60-minute time limit on the CI job?

I was curious about this, according to Github the time limit is 6 hours. I don't think Mooncake has a time limit set. But this prompted me to break the second-order DITest out. Unfortunately, it still failed at around 57 mins. (Since then, I limit the DI-second-order to only run the :hessian test group as a check.)

Without digging deep (which I plan to do, potentially in a another optimization PR), I suspect the current implementation is quite taxing on the system, so the segfault might be legitimate OOM.

Copy link
Member

@yebai yebai left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @sunxd3 -- I am happy with the PR.

@sunxd3
Copy link
Collaborator Author

sunxd3 commented Dec 30, 2025

@copilot Can you bump the patch version?

Copilot AI mentioned this pull request Dec 30, 2025
Copy link
Contributor

Copilot AI commented Dec 30, 2025

@sunxd3 I've opened a new pull request, #913, to work on those changes. Once the pull request is ready, I'll request review from you.

* Initial plan

* Bump version to 0.4.193

Co-authored-by: sunxd3 <5433119+sunxd3@users.noreply.github.com>

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: sunxd3 <5433119+sunxd3@users.noreply.github.com>
@yebai yebai merged commit f937521 into main Dec 30, 2025
139 of 140 checks passed
@yebai yebai deleted the sunxd/f-o-r branch December 30, 2025 22:15
@yebai yebai mentioned this pull request Dec 30, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

DI.hessian: Tuple field type cannot be Union{}

6 participants