# Adding adjoints #366
Yeah, nice in principle, but I don't think it will work. Even on a single-compartment HH model I found that one needs a step size ~100x smaller than for the forward pass. For multicompartment models I guess it is really difficult (if not impossible) to even write down a solver for the adjoint.
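(For reference, the standard continuous-adjoint equations from the neural-ODE literature make the step-size concern concrete: the adjoint state $a(t) = \partial L / \partial y(t)$ obeys its own ODE, solved backwards in time,

$$\frac{da}{dt} = -a(t)^\top \frac{\partial f}{\partial y}\big(y(t), \theta, t\big), \qquad \frac{dL}{d\theta} = -\int_{t_1}^{t_0} a(t)^\top \frac{\partial f}{\partial \theta}\,dt,$$

whose dynamics are governed by the transposed Jacobian of the forward vector field, so a stiff forward problem like HH generally yields a comparably stiff adjoint.)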
Interesting, I would not have suspected that this would be the case. I saw a bunch of papers that used adjoints for pretty complex biophysical models (non-neuro though) and thought this might also work for multicompartment HH. https://www.biorxiv.org/content/biorxiv/early/2018/02/28/272005.full.pdf

Also, from my limited testing in diffrax, the following seems to work (although it might be different for backward Euler):

```python
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from diffrax import DirectAdjoint, Euler, ODETerm, SaveAt, diffeqsolve

# Hodgkin-Huxley gating-variable rate functions.
def αm(V, VT):
    return -0.32 * (V - VT - 13.0) / (jnp.exp(-(V - VT - 13.0) / 4.0) - 1.0)

def βm(V, VT):
    return 0.28 * (V - VT - 40.0) / (jnp.exp((V - VT - 40.0) / 5.0) - 1.0)

def αn(V, VT):
    return -0.032 * (V - VT - 15.0) / (jnp.exp(-(V - VT - 15.0) / 5.0) - 1.0)

def βn(V, VT):
    return 0.5 * jnp.exp(-(V - VT - 10.0) / 40.0)

def αh(V, VT):
    return 0.128 * jnp.exp(-(V - VT - 17.0) / 18.0)

def βh(V, VT):
    return 4.0 / (1.0 + jnp.exp(-(V - VT - 40.0) / 5.0))

def vector_field(t, y, args):
    V, n, m, h = y
    gNa, gK, gL = args
    ENa, EK, EL, VT = 53, -107, -70, -60
    INa = gNa * m**3 * h * (ENa - V)
    IK = gK * n**4 * (EK - V)
    IL = gL * (EL - V)
    # Step-current input, on for 10 < t < 90.
    Iin = 3 * jnp.logical_and(t > 10, t < 90)
    dm = αm(V, VT) * (1 - m) - βm(V, VT) * m
    dn = αn(V, VT) * (1 - n) - βn(V, VT) * n
    dh = αh(V, VT) * (1 - h) - βh(V, VT) * h
    dV = Iin + INa + IK + IL
    return dV, dn, dm, dh

term = ODETerm(vector_field)
solver = Euler()
t0 = 0
t1 = 100
dt0 = 0.025
y0 = (-70.0, 0.0, 0.0, 1.0)
args = (25.0, 7.0, 0.1)  # (gNa, gK, gL)
saveat = SaveAt(ts=jnp.linspace(t0, t1, 1000))

sol = diffeqsolve(term, solver, t0, t1, dt0, y0, args=args, saveat=saveat,
                  adjoint=DirectAdjoint())
plt.plot(sol.ts, sol.ys[0], label="V")
plt.show()

def g_adjoint(args):
    # NB: DirectAdjoint differentiates through the solver steps directly
    # ("discretise-then-optimise"); it does not solve the continuous adjoint ODE.
    sol = diffeqsolve(term, solver, t0, t1, dt0, y0, args=args, saveat=saveat,
                      adjoint=DirectAdjoint())
    return jnp.sum(sol.ys[-1])

print("Adjoint gradient: ", jax.grad(g_adjoint)(args))

def g_backprop(args):
    # Default adjoint: backprop through the solver with recursive checkpointing.
    sol = diffeqsolve(term, solver, t0, t1, dt0, y0, args=args, saveat=saveat)
    return jnp.sum(sol.ys[-1])

print("Backprop gradient: ", jax.grad(g_backprop)(args))
```
Did you check the gradient against backprop through the solver?
Yes, now I did. Looks the same to me.
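(A minimal way to make that check explicit, assuming the `g_adjoint` / `g_backprop` functions from the snippet above:)

```python
grads_adjoint = jax.grad(g_adjoint)(args)
grads_backprop = jax.grad(g_backprop)(args)
# Compare the two gradient tuples element-wise.
for ga, gb in zip(grads_adjoint, grads_backprop):
    print(float(ga), float(gb), bool(jnp.allclose(ga, gb, rtol=1e-3)))
```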
I think adding the option to solve ODEs using adjoints, rather than backprop + checkpointing, could be an interesting addition to the toolbox:
https://papers.nips.cc/paper_files/paper/2021/file/adf8d7f8c53c8688e63a02bfb3055497-Paper.pdf
http://proceedings.mlr.press/v139/kidger21a/kidger21a.pdf
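If this goes through diffrax, the trick from the second paper (controlling the backward solve's step size with a seminorm that ignores the adjoint variables) appears to be exposed there already. A hedged sketch of what enabling it might look like; an adaptive solver such as `Tsit5` is assumed, since `Euler` provides no error estimate for step-size control:

```python
import diffrax

# Step-size controller for the backward (adjoint) pass; the seminorm from
# Kidger et al. 2021 ignores the adjoint variables when measuring error.
adjoint_controller = diffrax.PIDController(
    rtol=1e-3, atol=1e-6, norm=diffrax.adjoint_rms_seminorm
)
sol = diffrax.diffeqsolve(
    term, diffrax.Tsit5(), t0, t1, dt0, y0, args=args,
    stepsize_controller=diffrax.PIDController(rtol=1e-3, atol=1e-6),
    adjoint=diffrax.BacksolveAdjoint(stepsize_controller=adjoint_controller),
)
```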