Enable second-order differentiation via forward-over-reverse#878
Enable second-order differentiation via forward-over-reverse#878
Conversation
|
Mooncake.jl documentation for PR #878 is available at: |
|
Performance Ratio: |
|
@yebai @Technici4n this is ready for a review. Integration with DI is not complete, I have some guess of reasons, but not certain yet. The PR is probably big enough.
|
This might be an actual Julia compiler bug. Can you try to create a MWE so we can report and patch over it?
Probably okay to keep fixing DI related issues in this PR if they are not major. Also happy to start a seprate PR otherwise. |
test/ext/differentiation_interface/differentiation_interface.jl
Outdated
Show resolved
Hide resolved
|
Very cool that you found a way to do this without touching DI internals @sunxd3, congrats and thanks for the effort!!! |
It might be, but whatever it is, it's fixed in 1.11. Curious what it is, though. My guess is that type inference failed, produce I would prefer to stop here and fix DI integration in another PR. |
|
Thanks @gdalle, unfortunately, the DI test doesn't work yet. I wanted to keep it simple, and use Mooncake internal functions to make minimal f-o-r work. I might start DI integration in another PR, I don't think we would need extension, but def need more investigation. |
|
I agree, 2nd order should work inside of Mooncake instead of requiring DI hacks. I just meant that DI has more possible tests to run than what is currently inside your integration tests |
Technici4n
left a comment
There was a problem hiding this comment.
This is really cool! I will have a deeper look later this week. As always, I am a bit suspicious of new rules that need to be added. 😉
src/interpreter/reverse_mode.jl
Outdated
| # Forward-mode primitive for _build_rule! on LazyDerivedRule. | ||
| # This avoids differentiating through get_interpreter which has a ccall to jl_get_world_counter. | ||
| # The tangent propagation happens through the fwds_oc MistyClosure call, not the rule building. | ||
| # Only primitive in ForwardMode - reverse mode uses derived rule. |
There was a problem hiding this comment.
Is said derived rule working? If not it's likely better to define a rule that throws rather than letting Mooncake try to differentiate through it and fail.
There was a problem hiding this comment.
added a rrule saying we don't support reverse-o-r yet
Codecov Report❌ Patch coverage is
📢 Thoughts on this report? Let us know! |
This certainly means Mooncake can have more leeway for optimizations. I just made a commit with some frules that, at least on my local machine, passes the DITest second order tests. |
Resolve conflict: move interface.jl from rrules to rules directory
test/ext/differentiation_interface/differentiation_interface.jl
Outdated
Show resolved
Hide resolved
|
I suggest we move the rules for |
This seems correct to me. I introduced the DI.inner_preparation_behavior(::AutoForwardDiff) = DI.PrepareInnerOverload()In the case of Mooncake, we're dealing with source transformation, so |
|
The DI second-order test failure on 1.11 is a bit confusing to me:
sorry to tag @gdalle have you seen this before, it does seem to me OpaqueClosure can be a bit flaky on 1.11? |
|
Is it just hitting a 60-minute time limit on the CI job? If you take a look at DI's GitHub Actions log, you'll see that they can go above one hour, depending on what you put in them |
I was curious about this, according to Github the time limit is 6 hours. I don't think Mooncake has a time limit set. But this prompted me to break the second-order DITest out. Unfortunately, it still failed at around 57 mins. (Since then, I limit the DI-second-order to only run the Without digging deep (which I plan to do, potentially in a another optimization PR), I suspect the current implementation is quite taxing on the system, so the segfault might be legitimate OOM. |
|
@copilot Can you bump the patch version? |
* Initial plan * Bump version to 0.4.193 Co-authored-by: sunxd3 <5433119+sunxd3@users.noreply.github.com> --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: sunxd3 <5433119+sunxd3@users.noreply.github.com>
Ref: #826
This PR aims to enable forward-over-reverse for second order derivatives. Many of the necessary plumbings are already implemented before this PR. This PR smooths out the remaining edges.
Fixes:
zero_tangent(pb_mc)was called.Union{}type issue (this is somewhat 1.10 specific)Union{}can showup when type inference fails, the approach here is to replace them withAnyCore.memorynewLazyDerivedRule: avoid diff throughget_interpreter(jl_get_world_counter)jl_alloc_array_1d/2d/3dPointers to some possible future work
second_derivativeDITest fails on 1.11 CI (pass locally on laptop) Enable second-order differentiation via forward-over-reverse #878 (comment)DI.inner_preparation_behaviorEnable second-order differentiation via forward-over-reverse #878 (comment)Closes #632