Conversation
This comment was marked as outdated.
This comment was marked as outdated.
|
Mooncake.jl documentation for PR #833 is available at: |
|
Performance Ratio: |
This reverts commit 933efc4.
cf3ca03 to
c383c05
Compare
|
this is still very drafty, but I think it's somewhat ready for a quick look. @Technici4n sorry for tagging, but maybe you are the best arbiter right now. if you do find some time to look over, please don't hold back |
|
A more complicated example would blow up the compilation. I try to push a little bit, but I couldn't find the exact boundary of making a more complex function work and marking too many things non-differentiable |
test/ext/differentiation_interface/differentiation_interface.jl
Outdated
Show resolved
Hide resolved
|
Can't say this is ready, but it seems to pass DITest's Hessian test |
|
@gdalle might want to take a look given that this PR interacts deeply with DI. No pressure. |
|
Thanks for the ping, I wasn't aware of this PR. |
|
Also, why would it be necessary to define arithmetic on Dual numbers? Isn't that what frules are for? |
| @@ -0,0 +1,147 @@ | |||
| module MooncakeDifferentiationInterfaceExt | |||
There was a problem hiding this comment.
Let's move this file to the DI repo.
There was a problem hiding this comment.
Until @sunxd3 explains why it is necessary, I'm not convinced this file should be anywhere.
| import DifferentiationInterface as DI | ||
|
|
||
| # Mark shuffled_gradient as forward-mode primitive to avoid expensive type inference hang. | ||
| # This prevents build_frule from trying to derive rules for the complex gradient closure. |
There was a problem hiding this comment.
Can you explain clearly with a small example, why Mooncake/Julia struggle with shuffled_gradient.
| return Dual{@isdefined(P) ? Type{P} : typeof(x),NoTangent}(x, dx) | ||
| end | ||
|
|
||
| # Dual of numeric types is self-tangent |
There was a problem hiding this comment.
Can you explain when tangent_type(::Type{Dual}) will be called during forward-over-reverse? It is not self-evident why this is needed.
|
Thanks @gdalle for taking a look and the acute observations! In short, This is very much a Hail Mary—I had a ton of trouble trying to nail down the correct set of frules that would let forward mode differentiate through reverse mode without blowing up compilation. That wasn't productive, so I thought: maybe reverse-over-forward could work? I'm not convinced what's here is correct and the right approach, which is why I haven't written it up in detail yet. The (old) title is misleading for the current implementation (I started by trying to work out forward-over-reverse, but that's not what the current code does). And the DI extension is not necessary—I was hacking the interface to pretend to do forward-over-reverse for testing purposes. And I should say I am very very far from being fluent in Differentiable programming, so this is mixed with LLM's ideas. The idea is making To see how it works, we pass With Forward pass: Backward pass So Let me know if this make sense and worth pursuing at all. |
|
This is not the way to go and commit/conversation histories are getting too confusing, I started #878 instead. |
Ref: #826