Mooncake Inside Problems #1152

Open
wants to merge 13 commits into
base: master

Conversation

@willtebbutt willtebbutt commented Nov 26, 2024

Checklist

  • Appropriate tests were added
  • Any code changes were done in a way that does not break public API
  • All documentation related to code changes were updated
  • The new code follows the
    contributor guidelines, in particular the SciML Style Guide and
    COLPRAC.
  • Any new documentation only uses public API

Additional context

This addresses #1105

It's not quite ready for review yet, but is almost there. In particular I need to

  • reduce some code copying
  • apply the SciML style via JuliaFormatter.
  • update the docs -- @ChrisRackauckas I could do with your advice here regarding what I should update.
  • extend the rule to handle non in-place ODE updates. @ChrisRackauckas I seem to be able to pass the tests I've currently extended without actually having to handle out-of-place functions. Am I missing something in the test suite?

I've verified locally that the performance on the cases in test/adjoint.jl is competitive with Enzyme.jl, which is nice.

@willtebbutt
Author

Okay. I believe this is now ready for review.

@willtebbutt
Author

Note: formatting failures appear unrelated to the changes in this PR.

@@ -211,6 +211,9 @@ function adjointdiffcache(g::G, sensealg, discrete, sol, dgdu::DG1, dgdp::DG2, f
        paramjac_config = get_paramjac_config(autojacvec, p, f, y, _p, _t; numindvar, alg)
        pf = get_pf(autojacvec; _f = unwrappedf, isinplace = isinplace, isRODE = isRODE)
        paramjac_config = (paramjac_config..., Enzyme.make_zero(pf))
    elseif autojacvec isa MooncakeVJP

@ChrisRackauckas, it is probably better to refactor these hard-coded branches (e.g., define an interface function that other packages can overload). It would help

  • autograd tools to integrate with SciMLSensitivity easily
  • move some existing autograd glue code into package extensions to avoid hard deps

It might also help to switch to DI where possible to avoid duplicate glue code in the ecosystem. @gdalle
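
As a rough illustration of the interface function suggested above, here is a minimal sketch in plain Julia with no dependency on SciMLSensitivity or any AD package. Every name in it (`AbstractVJPChoice`, `DemoEnzymeVJP`, `DemoMooncakeVJP`, `build_paramjac_config`, `demo_adjointdiffcache`) is hypothetical and chosen only for this example; the real `adjointdiffcache` takes many more arguments and does far more work.

```julia
# Hypothetical stand-ins for the VJP choice types; not the real SciMLSensitivity types.
abstract type AbstractVJPChoice end
struct DemoEnzymeVJP <: AbstractVJPChoice end
struct DemoMooncakeVJP <: AbstractVJPChoice end

# Generic fallback: a backend that has not overloaded the interface fails loudly.
function build_paramjac_config(alg::AbstractVJPChoice, f, y, p, t)
    throw(ArgumentError("no paramjac config defined for $(typeof(alg))"))
end

# Each backend (or its package extension) adds one method instead of one `elseif` branch.
build_paramjac_config(::DemoEnzymeVJP, f, y, p, t) = (backend = :enzyme, buffer = zero(y))
build_paramjac_config(::DemoMooncakeVJP, f, y, p, t) = (backend = :mooncake, buffer = zero(y))

# The caller no longer needs to know which backends exist.
function demo_adjointdiffcache(alg::AbstractVJPChoice, f, y, p, t)
    return build_paramjac_config(alg, f, y, p, t)
end

demo_rhs(u, p, t) = p .* u
cfg = demo_adjointdiffcache(DemoMooncakeVJP(), demo_rhs, [1.0, 2.0], [3.0, 4.0], 0.0)
```

With this shape, supporting another backend would mean adding a method rather than editing the body of the cache constructor, which is the property the comment is asking for.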

Member

> autograd tools to integrate with SciMLSensitivity easily

Is that really a high priority right now? How many more autograd packages are you going to write this year that will be useful?

> move some existing autograd glue code into package extensions to avoid hard deps

That doesn't necessarily make sense. Most of the methods are used in the default method, so they would be required to be loaded by default anyway?

> It might also help to switch to DI where possible to avoid duplicate glue code in the ecosystem. @gdalle

That's the plan when it's able to handle this case well. Currently it's not able to.

@yebai yebai commented Nov 28, 2024

> Is that really a high priority right now? How many more autograd packages are you going to write this year that will be useful?

Mooncake is getting a new forward mode (an attempt to improve ForwardDiff with GPU compatibility and fewer constraints; see here for more details), so @willtebbutt will likely need to modify these again in the near term.

Member

Wouldn't that just require modifying https://github.com/SciML/SciMLSensitivity.jl/pull/1152/files#diff-1a15b4b5711133c125548ef7f1ca88f761bb124cffc8bfde8c13336968aaccd6R466 ? I don't see why that would touch this function; it could just dispatch there instead.

I mean, if someone wants to do a refactor here that's perfectly fine. But I also don't see why it would be a high priority since it's not like new AD systems get added every year, and modifications to existing ones don't really touch this part of the code much. I would think the time would be better spent just trying to get DI up to speed than refactoring this old code.


We can discuss DI integration in #1040 if you want

@yebai yebai mentioned this pull request Nov 28, 2024
8 tasks
Comment on lines +516 to +532
        return out
    end
elseif isinplace
    function (out, u, _p, t)
        f(out, u, _p, t)
        return out
    end
elseif !isinplace && isRODE
    function (out, u, _p, t, W)
        out .= f(u, _p, t, W)
        return out
    end
else
    # !isinplace
    function (out, u, _p, t)
        out .= f(u, _p, t)
        return out
Member

why force a return of out instead of nothing?

Author

This is so that I can make use of Mooncake's __value_and_pullback!! function. I might make a non-breaking PR in future if this changes.

Author

(I agree that it's a bit unfortunate that I'm doing this though -- if I made use of lower-level Mooncake interfaces I probably wouldn't have to do it, but this way saves on a lot of code)
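
The wrapping pattern under discussion can be sketched in isolation. The snippet below is plain Julia and makes no Mooncake calls; `wrap_rhs` and the two toy right-hand sides are hypothetical names invented for this illustration. It normalises both in-place and out-of-place functions to a single form that writes into `out` and also returns it, so a caller that expects a value-returning function can be handed the same closure.

```julia
# Hypothetical helper: normalise an ODE right-hand side to the form
# `(out, u, p, t) -> out`, writing into `out` and also returning it.
function wrap_rhs(f, isinplace::Bool)
    if isinplace
        return function (out, u, p, t)
            f(out, u, p, t)    # user function mutates `out`; its own return value is ignored
            return out
        end
    else
        return function (out, u, p, t)
            out .= f(u, p, t)  # copy the out-of-place result into the buffer
            return out
        end
    end
end

# Toy right-hand sides, one in each style.
inplace_rhs!(du, u, p, t) = (du .= p .* u; nothing)
outofplace_rhs(u, p, t) = p .* u

u, p = [1.0, 2.0], [3.0, 4.0]
buf_a, buf_b = zeros(2), zeros(2)
res_a = wrap_rhs(inplace_rhs!, true)(buf_a, u, p, 0.0)
res_b = wrap_rhs(outofplace_rhs, false)(buf_b, u, p, 0.0)
```

In both cases the returned array is the very buffer that was passed in, so returning it costs no extra allocation; the only difference from the usual mutating style is the non-`nothing` return value.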

_p = Mooncake.CoDual(p, Mooncake.set_to_zero!!(p_grad))
_t = Mooncake.zero_codual(t)
λ_mem .= λ
dy, _ = Mooncake.__value_and_pullback!!(rule, λ_mem, _pf, _dy_mem, _y, _p, _t)
Member

Is the allocation of the dy necessary?

Author

Probably not -- I'm working on a Mooncake PR at the minute that should make this redundant. If I manage to get that out of the door before we're ready to merge this I'll modify this PR. Otherwise I'll make a small non-breaking follow-up next week.

@ChrisRackauckas
Member

This looks good! Though my main question is why the vector is returned since it's just mutated and so normally we just return nothing. If this is an interface thing where it's non-allocating but still wants the vector returned then that's fine but definitely an odd style.

> update the docs -- @ChrisRackauckas I could do with your advice here regarding what I should update.

The docstring just needs to be added to the list in the manual part of the docs, with a quick description (which I believe is already in there)

@willtebbutt
Author

Thanks for the advice @ChrisRackauckas.

I'm still slightly concerned that the tests don't appear to ever hit the out-of-place ODE implementations (the f(y, p, t) versions), but I'm happy for this to be merged if you're happy.

@willtebbutt
Author

Ohhhhhhh I see. 🤦 Thanks for the pointer.

I'll add Mooncake to those tests tomorrow.

@willtebbutt
Author

Tests added -- @ChrisRackauckas I think CI should pass if you trigger a run.

@yebai

yebai commented Dec 2, 2024

Does this require #1151 to be merged first?

@ChrisRackauckas
Member

No, those two are completely independent.

@@ -24,6 +24,7 @@ GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
Member

We should probably make Mooncake an optional dependency since it's not used in the default algorithm. We plan to make a breaking change in the near future that moves more to be optional dependencies, so this is just a bit forward thinking. However, if you do think that there's a case where we should be defaulting to Mooncake I'd be happy to see the benchmarks and coverage: right now I'd be a little scared but TBF the error messages are much nicer than Enzyme so using it in the default alg wouldn't be such a bad idea. It also has ChainRules fallbacks right?

Author

I'd be very happy to add it as an optional dep. I'm not 100% clear how I would do so in the current design though, because I've had to insert code directly into existing functions, rather than adding methods. e.g. in adjointdiffcache. Do you have a view on how we should make it optional?

> It also has ChainRules fallbacks right?

What exactly do you mean by this? We largely don't use ChainRules because most of the rules are only useful if you don't support mutation properly. We do permit users to say "please use this rrule" though, via @from_rrule.

Member

> What exactly do you mean by this? We largely don't use ChainRules because most of the rules are only useful if you don't support mutation properly. We do permit users to say "please use this rrule" though, via @from_rrule.

I see, so yeah it's probably not ready to be in the default set if it doesn't have coverage of NNLib, SciML tooling (LinearSolve), sparse linear algebra (ARPACK), etc.

Member

> I'd be very happy to add it as an optional dep. I'm not 100% clear how I would do so in the current design though, because I've had to insert code directly into existing functions, rather than adding methods. e.g. in adjointdiffcache. Do you have a view on how we should make it optional?

This isn't correct though? All of the new code is already in new functions?

https://github.com/SciML/SciMLSensitivity.jl/pull/1152/files#diff-1a15b4b5711133c125548ef7f1ca88f761bb124cffc8bfde8c13336968aaccd6R466-R471

https://github.com/SciML/SciMLSensitivity.jl/pull/1152/files#diff-1a15b4b5711133c125548ef7f1ca88f761bb124cffc8bfde8c13336968aaccd6R534-R541

If I'm not mistaken, those are the only two calls and both are already isolated. Since it's internal and you know the first is a tuple, you can just do:

function mooncake_run_ad(paramjac_config, y, p, t, λ)
   error("MooncakeVJP requires Mooncake.jl to be loaded. Install the package and do `using Mooncake` to use this functionality")
end

and then just change the dispatch in the extension to function mooncake_run_ad(paramjac_config::Tuple, y, p, t, λ)?
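
The dispatch behaviour this suggestion relies on can be checked in a single file. The sketch below is plain Julia; `demo_run_ad` is a hypothetical stand-in for `mooncake_run_ad`, and the body of the `::Tuple` method is a placeholder, not a real Mooncake call. In an actual package the second method would live in a package extension that loads alongside Mooncake; here it is simply defined later in the same script to imitate that.

```julia
# Fallback defined in the main package: reached only when no more specific method exists.
function demo_run_ad(paramjac_config, y, p, t, λ)
    error("MooncakeVJP requires Mooncake.jl to be loaded.")
end

# Before the "extension" is loaded, even a tuple config hits the fallback.
fallback_hit = try
    demo_run_ad((:rule,), [1.0], [2.0], 0.0, [1.0])
    false
catch err
    err isa ErrorException
end

# Method an extension would add. `::Tuple` is more specific than the untyped
# fallback, so dispatch prefers it whenever the config is a tuple.
# The body is a placeholder computation, not a real Mooncake call.
function demo_run_ad(paramjac_config::Tuple, y, p, t, λ)
    return λ .* p
end

result = demo_run_ad((:rule,), [1.0], [2.0], 0.0, [1.0])
```

Because the fallback stays in the main package, users who select the Mooncake path without loading the package get the explanatory error rather than a `MethodError`.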

Author

> I see, so yeah it's probably not ready to be in the default set if it doesn't have coverage of NNLib, SciML tooling (LinearSolve), sparse linear algebra (ARPACK), etc.

Completely reasonable. We actually do have coverage of NNlib + LuxLib for CPU stuff, but I've not hooked into LinearSolve or ARPACK -- perhaps that's something we could discuss once this PR is merged.

> This isn't correct though? All of the new code is already in new functions?

This is a great point -- I'll do exactly this.

@ChrisRackauckas
Member

This looks good to go, though we should discuss the extension / default alg part before committing to taking it on as a hard dep, since currently we're looking to go in the other direction as much as possible!
