-
-
Notifications
You must be signed in to change notification settings - Fork 71
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
SDE support for GaussAdjoint #945
base: master
Are you sure you want to change the base?
Conversation
You will need to add up the vjp contributions from the drift and the diffusion in the integration. |
Just added a new commit where I include the diffusion vjp (I think?), but still doesn't seem to work |
What's left here? |
@frankschae Could you take a quick look? I think the basics are in there, but still debugging why results are off |
src/gauss_adjoint.jl
Outdated
if sensealg.autojacvec isa ZygoteVJP | ||
if W === nothing | ||
_dy, back = Zygote.pullback(y, p) do u, p | ||
vec(g(u, p, t)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think you want to compute the following:
vec(g(u, p, t)*dW)
(sorry for the slow responses. I'm catching up rn -- got a bad cold.)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I tried this too (see most recent commit) but don't seem to get the right results
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Did you check if it's the correct dW that's extracted?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess I am a but confused because
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, in your current version, you assume dW
to be constant. If you use several Gauss points for the integral (how many do you use?), you should probably compute dW appropriately for those times.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
they do have an order -- EM is 0.5 strong/ 1 weak 😅
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok thanks! Should these be the
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think for now the best thing to do to check the implementation is to make sure it only uses
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
that makes sense, thanks. I will try to implement this now
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I can force the solver stop at the
@ChrisRackauckas, is there a way to access all the |
It should be there if you did |
Ah I see, this is what is stored in sol.W.u. I will try to build a linear interpolant with this |
IIRC dW already has an interpolation on it. |
There only seems to be interpolation for |
dW is the difference of the W's . You wouldn't ever want to interpolate that. |
Should I use the bridge function in |
@frankschae do you think you can find time for this? |
Not complete yet. Currently works for for gradients with respect to initial condition but not parameters.