Mooncake Inside Problems #1152
base: master
Conversation
Okay. I believe this is now ready for review.
Note: formatting failures appear unrelated to the changes in this PR.
@@ -211,6 +211,9 @@ function adjointdiffcache(g::G, sensealg, discrete, sol, dgdu::DG1, dgdp::DG2, f
paramjac_config = get_paramjac_config(autojacvec, p, f, y, _p, _t; numindvar, alg)
pf = get_pf(autojacvec; _f = unwrappedf, isinplace = isinplace, isRODE = isRODE)
paramjac_config = (paramjac_config..., Enzyme.make_zero(pf))
elseif autojacvec isa MooncakeVJP
@ChrisRackauckas, it is probably better to refactor these hard-coded branches (e.g., define an interface function that other packages can overload). That would help:
- autograd tools integrate with SciMLSensitivity easily
- move some existing autograd glue code into package extensions to avoid hard deps
It might also help to switch to DI where possible, to avoid duplicated glue code across the ecosystem. @gdalle
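A minimal sketch of what such an overloadable interface could look like. All names here are hypothetical illustrations, not actual SciMLSensitivity API:

```julia
# Hypothetical sketch of an overloadable VJP interface; none of these
# names are real SciMLSensitivity functions or types.
abstract type AbstractVJP end
struct SomeVJP <: AbstractVJP end

# Fallback: VJP types without registered glue code produce a clear error.
build_vjp_config(vjp::AbstractVJP, args...) =
    error("No adjoint glue code registered for $(typeof(vjp)). " *
          "Load the corresponding package extension.")

# An AD package (or a package extension) opts in by adding a method,
# instead of SciMLSensitivity growing another hard-coded elseif branch.
build_vjp_config(::SomeVJP, args...) = (:some_config, args...)
```

With this shape, the `elseif autojacvec isa MooncakeVJP` branch in `adjointdiffcache` would collapse into a single dispatch call.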
autograd tools to integrate with SciMLSensitivity easily
Is that really a high priority right now? How many more autograd packages are you going to write this year that will be useful?
move some existing autograd glue code into package extensions to avoid hard deps
That doesn't necessarily make sense. Most of the methods are used in the default method, so they would need to be loaded by default anyway?
It might also help to switch to DI where possible to avoid duplicate glue code in the ecosystem. @gdalle
That's the plan when it's able to handle this case well. Currently it's not able to.
Is that really a high priority right now? How many more autograd packages are you going to write this year that will be useful?
Mooncake is getting a new forward mode (an attempt to improve ForwardDiff with GPU compatibility and fewer constraints; see here for more details), so @willtebbutt will likely need to modify these again in the near term.
Wouldn't that just require modifying https://github.com/SciML/SciMLSensitivity.jl/pull/1152/files#diff-1a15b4b5711133c125548ef7f1ca88f761bb124cffc8bfde8c13336968aaccd6R466 ? I don't see why that would touch this function rather than just dispatching there.
I mean, if someone wants to do a refactor here that's perfectly fine. But I also don't see why it would be a high priority: new AD systems aren't added every year, and modifications to existing ones rarely touch this part of the code. I would think the time would be better spent getting DI up to speed than refactoring this old code.
We can discuss DI integration in #1040 if you want
return out
end
elseif isinplace
function (out, u, _p, t)
f(out, u, _p, t)
return out
end
elseif !isinplace && isRODE
function (out, u, _p, t, W)
out .= f(u, _p, t, W)
return out
end
else
# !isinplace
function (out, u, _p, t)
out .= f(u, _p, t)
return out
why force a return of out instead of nothing?
This is so that I can make use of Mooncake's __value_and_pullback!! function. I might make a non-breaking PR in future if this changes.
(I agree that it's a bit unfortunate that I'm doing this though -- if I made use of lower-level Mooncake interfaces I probably wouldn't have to do it, but this way saves on a lot of code)
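To illustrate the constraint being discussed, here is a generic sketch (not Mooncake's actual API): a combined value-and-pullback style entry point needs the wrapped in-place function to return its output buffer, because that return value is the primal the caller reports.

```julia
# Generic sketch (not Mooncake's actual API): why an in-place wrapper that
# returns `out` composes with a value-and-pullback style interface, while
# one returning `nothing` would not.
function value_and_result(f!, out, args...)
    y = f!(out, args...)  # the wrapper's return value *is* the primal we report
    return y
end

f_wrapped(out, u, p, t) = (out .= u .* p .+ t; out)      # returns out
f_nothing(out, u, p, t) = (out .= u .* p .+ t; nothing)  # returns nothing

y = value_and_result(f_wrapped, zeros(2), [1.0, 2.0], 3.0, 0.5)
# y is the filled buffer; with f_nothing, y would be `nothing`.
```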
src/adjoint_common.jl (outdated)
_p = Mooncake.CoDual(p, Mooncake.set_to_zero!!(p_grad))
_t = Mooncake.zero_codual(t)
λ_mem .= λ
dy, _ = Mooncake.__value_and_pullback!!(rule, λ_mem, _pf, _dy_mem, _y, _p, _t) |
Is the allocation of the dy necessary?
Probably not -- I'm working on a Mooncake PR at the minute that should make this redundant. If I manage to get that out of the door before we're ready to merge this I'll modify this PR. Otherwise I'll make a small non-breaking follow-up next week.
This looks good! Though my main question is why the vector is returned since it's just mutated, and so normally we would just return nothing.
The docstring just needs to be added to the list in the manual part of the docs, with a quick description (which I believe is already in there).
Thanks for the advice @ChrisRackauckas. I'm still slightly concerned that the tests don't appear to ever hit the out-of-place ODE implementations (the
Ohhhhhhh I see. 🤦 Thanks for the pointer. I'll add Mooncake to those tests tomorrow.
Tests added -- @ChrisRackauckas I think CI should pass if you trigger a run.
Does this require #1151 to be merged first?
No, those two are completely independent.
@@ -24,6 +24,7 @@ GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" |
We should probably make Mooncake an optional dependency, since it's not used in the default algorithm. We plan to make a breaking change in the near future that moves more dependencies to be optional, so this is just a bit forward-thinking. However, if you do think there's a case where we should default to Mooncake, I'd be happy to see the benchmarks and coverage: right now I'd be a little scared, but TBF the error messages are much nicer than Enzyme's, so using it in the default alg wouldn't be such a bad idea. It also has ChainRules fallbacks, right?
I'd be very happy to add it as an optional dep. I'm not 100% clear how I would do so in the current design though, because I've had to insert code directly into existing functions, rather than adding methods, e.g. in adjointdiffcache. Do you have a view on how we should make it optional?
It also has ChainRules fallbacks right?
What exactly do you mean by this? We largely don't use ChainRules because most of the rules are only useful if you don't support mutation properly. We do permit users to say "please use this rrule" though, via @from_rrule.
What exactly do you mean by this? We largely don't use ChainRules because most of the rules are only useful if you don't support mutation properly. We do permit users to say "please use this rrule" though, via @from_rrule .
I see, so yeah it's probably not ready to be in the default set if it doesn't have coverage of NNLib, SciML tooling (LinearSolve), sparse linear algebra (ARPACK), etc.
I'd be very happy to add it as an optional dep. I'm not 100% clear how I would do so in the current design though, because I've had to insert code directly into existing functions, rather than adding methods. e.g. in adjointdiffcache. Do you have a view on how we should make it optional?
This isn't correct though? All of the new code is already in new functions?
If I'm not mistaken, those are the only two calls and both are already isolated. Since it's internal and you know the first is a tuple, you can just do:
function mooncake_run_ad(paramjac_config, y, p, t, λ)
    error("MooncakeVJP requires Mooncake.jl to be loaded. Install the package and do `using Mooncake` to use this functionality.")
end
and then just change the dispatch in the extension to function mooncake_run_ad(paramjac_config::Tuple, y, p, t, λ)
?
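Put together, the stub-plus-extension pattern sketched in this comment might look like the following (the extension module name and placeholder body are assumptions; only `mooncake_run_ad` and its Tuple dispatch come from the discussion above):

```julia
# In SciMLSensitivity itself: a fallback that errors unless the extension
# providing the real method is loaded.
function mooncake_run_ad(paramjac_config, y, p, t, λ)
    error("MooncakeVJP requires Mooncake.jl to be loaded. " *
          "Install the package and do `using Mooncake`.")
end

# In a hypothetical ext/SciMLSensitivityMooncakeExt.jl: the real method,
# dispatching on the Tuple-typed config that only the Mooncake path builds.
function mooncake_run_ad(paramjac_config::Tuple, y, p, t, λ)
    # ... the actual Mooncake VJP computation would live here ...
    return paramjac_config  # placeholder return for illustration
end
```

Because only the Mooncake code path constructs a Tuple config, loading the extension is enough to route calls to the real implementation without any `if`/`elseif` branching.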
I see, so yeah it's probably not ready to be in the default set if it doesn't have coverage of NNLib, SciML tooling (LinearSolve), sparse linear algebra (ARPACK), etc.
Completely reasonable. We actually do have coverage of NNlib + LuxLib for CPU stuff, but I've not hooked into LinearSolve or ARPACK -- perhaps that's something we could discuss once this PR is merged.
This isn't correct though? All of the new code is already in new functions?
This is a great point -- I'll do exactly this.
This looks good to go, though we should discuss the extension / default alg part before committing to taking it on as a hard dep, since currently we're looking to go in the other direction as much as possible!
Checklist
- I have read the contributor guidelines, in particular the SciML Style Guide and COLPRAC.
Additional context
This addresses #1105.
I've verified locally that the performance on the cases in test/adjoint.jl is competitive with Enzyme.jl, which is nice.