
Fix Hessian #833

Closed
sunxd3 wants to merge 27 commits into main from sunxd/fix_hessian_mwe

Conversation

sunxd3 (Collaborator) commented Nov 2, 2025

Ref: #826

@codecov's comment was marked as outdated.

github-actions bot commented Nov 2, 2025

Mooncake.jl documentation for PR #833 is available at:
https://chalk-lab.github.io/Mooncake.jl/previews/PR833/

github-actions bot commented Nov 2, 2025

Performance Ratio:
Ratio of the time to compute the gradient to the time to compute the function.
Warning: results are very approximate! See here for more context.

┌────────────────────────────┬──────────┬──────────┬─────────────┬─────────┬─────────────┬────────┐
│                      Label │   Primal │ Mooncake │ MooncakeFwd │  Zygote │ ReverseDiff │ Enzyme │
│                     String │   String │   String │      String │  String │      String │ String │
├────────────────────────────┼──────────┼──────────┼─────────────┼─────────┼─────────────┼────────┤
│                   sum_1000 │ 100.0 ns │      1.8 │         1.9 │     1.1 │         5.5 │   8.21 │
│                  _sum_1000 │ 941.0 ns │     6.64 │        1.01 │  1410.0 │        33.8 │   1.08 │
│               sum_sin_1000 │  6.54 μs │     2.48 │        1.36 │    1.69 │        10.6 │   2.17 │
│              _sum_sin_1000 │  5.27 μs │     3.03 │        2.18 │   263.0 │        13.2 │   2.47 │
│                   kron_sum │ 363.0 μs │     41.1 │        2.52 │    4.97 │       192.0 │   13.3 │
│              kron_view_sum │ 358.0 μs │     38.4 │        3.26 │    10.8 │       207.0 │   6.23 │
│      naive_map_sin_cos_exp │  2.16 μs │      2.4 │        1.39 │ missing │        7.08 │   2.34 │
│            map_sin_cos_exp │  2.12 μs │     2.71 │        1.45 │    1.49 │        6.08 │   2.91 │
│      broadcast_sin_cos_exp │  2.26 μs │     2.34 │        1.38 │    2.37 │        1.46 │   2.23 │
│                 simple_mlp │ 209.0 μs │     5.95 │        2.86 │    1.69 │        11.0 │   3.24 │
│                     gp_lml │ 255.0 μs │     8.18 │        2.07 │     3.8 │     missing │   5.67 │
│ turing_broadcast_benchmark │  1.74 ms │     4.87 │        3.47 │ missing │        29.4 │   2.29 │
│         large_single_block │ 380.0 ns │     4.53 │        2.03 │  4230.0 │        31.8 │   2.24 │
└────────────────────────────┴──────────┴──────────┴─────────────┴─────────┴─────────────┴────────┘

@sunxd3 sunxd3 force-pushed the sunxd/fix_hessian_mwe branch 2 times, most recently from cf3ca03 to c383c05 on November 21, 2025 10:12
sunxd3 (Author) commented Nov 21, 2025

This is still very drafty, but I think it's somewhat ready for a quick look.

@Technici4n, sorry for the tag, but you may be the best arbiter right now. If you find some time to look it over, please don't hold back.

sunxd3 (Author) commented Nov 21, 2025

A more complicated example would blow up compilation.

I tried to push a little further, but I couldn't find the exact boundary between making a more complex function work and marking too many things as non-differentiable.

@sunxd3 sunxd3 changed the title from "Fix Hessian (forward-over-reverse) by isolating meta paths and handling exception nodes" to "Fix Hessian (forward-over-reverse)" on Dec 4, 2025
sunxd3 (Author) commented Dec 4, 2025

I can't say this is ready, but it seems to pass DITest's Hessian tests.

@yebai yebai requested a review from Technici4n December 4, 2025 17:23
yebai (Member) commented Dec 4, 2025

@gdalle might want to take a look given that this PR interacts deeply with DI. No pressure.

gdalle (Collaborator) commented Dec 4, 2025

Thanks for the ping, I wasn't aware of this PR.
I don't have a lot of bandwidth at the moment but I don't think it is a good idea to have Mooncake depend on DI, given that DI already depends on Mooncake. This would make things rather confusing for users. @sunxd3 can you explain why the fixes necessarily involve DI functions? Why isn't second-order autodiff fixable within Mooncake itself?

gdalle (Collaborator) commented Dec 4, 2025

Also, why would it be necessary to define arithmetic on Dual numbers? Isn't that what frules are for?

yebai (Member) left a comment:

Thank you, @sunxd3. I've added a few points of clarification below.

@@ -0,0 +1,147 @@
module MooncakeDifferentiationInterfaceExt
Member commented:

Let's move this file to the DI repo.

Collaborator replied:

Until @sunxd3 explains why it is necessary, I'm not convinced this file should be anywhere.

import DifferentiationInterface as DI

# Mark shuffled_gradient as forward-mode primitive to avoid expensive type inference hang.
# This prevents build_frule from trying to derive rules for the complex gradient closure.
Member commented:

Can you explain clearly, with a small example, why Mooncake/Julia struggle with shuffled_gradient?

return Dual{@isdefined(P) ? Type{P} : typeof(x),NoTangent}(x, dx)
end

# Dual of numeric types is self-tangent
Member commented:

Can you explain when tangent_type(::Type{Dual}) will be called during forward-over-reverse? It is not self-evident why this is needed.

@sunxd3 sunxd3 removed the request for review from Technici4n December 5, 2025 14:56
@sunxd3 sunxd3 changed the title from "Fix Hessian (forward-over-reverse)" to "Fix Hessian" on Dec 5, 2025
sunxd3 (Author) commented Dec 5, 2025

Thanks @gdalle for taking a look, and for the acute observations! In short, the current code is conceptually closer to reverse-over-forward (though not implemented as such), and I'm not sure whether that's standard at all: it uses reverse mode to compute an HVP that is mathematically equivalent to reverse-over-forward, but computes it in a single reverse-mode pass.

This is very much a Hail Mary: I had a ton of trouble trying to nail down the correct set of frules that would let forward mode differentiate through reverse mode without blowing up compilation. That wasn't productive, so I thought: maybe reverse-over-forward could work? I'm not convinced that what's here is correct or the right approach, which is why I haven't written it up in detail yet.

The old title is misleading for the current implementation (I started by trying to work out forward-over-reverse, but that's not what the current code does). And the DI extension is not necessary; I was hacking the interface to pretend to do forward-over-reverse for testing purposes.

I should also say that I am very far from fluent in differentiable programming, so this is mixed with an LLM's ideas.


The idea is to make Dual self-tangent, then compute the HVP with a single reverse-mode pass. Dual arithmetic is needed because the forward pass of reverse mode operates on Dual inputs.

To see how it works: we pass Dual.(x, v) into reverse mode. The forward pass uses Dual arithmetic (this is why all the arithmetic operations on Duals are needed), and the backward pass uses rrules for Duals. The result is $\text{Dual}.(\nabla f(x), Hv)$; the tangent component gives the HVP.

With $f(x) = x_1^2 + x_1 x_2$:

Forward pass:

$$ \begin{aligned} t_1 &= \text{Dual}(x_1, v_1)^2 = \text{Dual}(x_1^2,\; 2x_1 v_1) \\ t_2 &= \text{Dual}(x_1, v_1) \cdot \text{Dual}(x_2, v_2) = \text{Dual}(x_1 x_2,\; x_1 v_2 + v_1 x_2) \end{aligned} $$

Backward pass:

$$ \begin{aligned} \bar{x}_1 &= 2 \cdot \text{Dual}(x_1, v_1) \cdot \text{Dual}(1, 0) + \text{Dual}(x_2, v_2) \cdot \text{Dual}(1, 0) \\ &= \text{Dual}(2x_1 + x_2,\; 2v_1 + v_2) \\ \bar{x}_2 &= \text{Dual}(x_1, v_1) \cdot \text{Dual}(1, 0) = \text{Dual}(x_1,\; v_1) \end{aligned} $$

So $Hv = \text{tangent}.([\bar{x}_1, \bar{x}_2])$
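
For concreteness, here is a minimal, self-contained Julia sketch of the same computation. Everything in it is a toy stand-in: the Dual type, its arithmetic, and the hand-written gradient are illustrative only, not Mooncake's actual types or rules.

```julia
# Toy Dual type carrying a primal value and a tangent (directional derivative).
struct Dual{T}
    primal::T
    tangent::T
end

# Forward-mode arithmetic: tangents propagate by the sum and product rules.
Base.:+(a::Dual, b::Dual) = Dual(a.primal + b.primal, a.tangent + b.tangent)
Base.:*(a::Dual, b::Dual) =
    Dual(a.primal * b.primal, a.primal * b.tangent + a.tangent * b.primal)
Base.:*(a::Real, b::Dual) = Dual(a * b.primal, a * b.tangent)

# Reverse-mode gradient of f(x) = x1^2 + x1*x2, written out by hand
# (in the PR this would be produced by Mooncake's reverse pass):
# x̄1 = 2x1 + x2, x̄2 = x1. Evaluating it on Duals reruns the same
# formulas in Dual arithmetic.
grad_f(x1, x2) = (2 * x1 + x2, x1)

x = (3.0, 5.0)  # evaluation point
v = (1.0, 0.0)  # direction for the HVP

duals = grad_f(Dual(x[1], v[1]), Dual(x[2], v[2]))

∇f = map(d -> d.primal, duals)   # (11.0, 3.0) = ∇f(x)
Hv = map(d -> d.tangent, duals)  # (2.0, 1.0)  = H*v, with H = [2 1; 1 0]
```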


Let me know if this makes sense and is worth pursuing at all.

sunxd3 (Author) commented Dec 6, 2025

This is not the way to go, and the commit/conversation histories are getting too confusing, so I started #878 instead.

@sunxd3 sunxd3 closed this Dec 6, 2025
@yebai yebai deleted the sunxd/fix_hessian_mwe branch December 20, 2025 17:05