Adding adjoints #366

Open
jnsbck opened this issue May 28, 2024 · 4 comments
Labels: enhancement (New feature or request)

Comments

@jnsbck
Contributor

jnsbck commented May 28, 2024

I think adding the option to solve the ODE with adjoints rather than backprop + checkpointing could be an interesting addition to the toolbox.

https://papers.nips.cc/paper_files/paper/2021/file/adf8d7f8c53c8688e63a02bfb3055497-Paper.pdf
http://proceedings.mlr.press/v139/kidger21a/kidger21a.pdf
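For context, a minimal sketch of what this toggle looks like in diffrax (the decay ODE is only a stand-in for a real model; RecursiveCheckpointAdjoint is diffrax's backprop + checkpointing default, BacksolveAdjoint its continuous-adjoint implementation):

import jax
import jax.numpy as jnp
from diffrax import BacksolveAdjoint, Euler, ODETerm, RecursiveCheckpointAdjoint, diffeqsolve

term = ODETerm(lambda t, y, a: -a * y)  # dy/dt = -a*y, stand-in dynamics

def loss(a, adjoint):
    sol = diffeqsolve(term, Euler(), 0.0, 1.0, 0.01, 1.0, args=a, adjoint=adjoint)
    return jnp.sum(sol.ys)

print(jax.grad(loss)(0.5, RecursiveCheckpointAdjoint()))  # backprop + checkpointing (default)
print(jax.grad(loss)(0.5, BacksolveAdjoint()))            # continuous adjoint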

@jnsbck added the enhancement label May 28, 2024
@michaeldeistler
Contributor

Yeah, nice in principle, but I don't think it will work.

Even on a single-compartment HH model I found that one needs a step size 100x smaller than for the forward pass. For multicompartment models I guess it is really difficult (if not impossible) to even write down a solver for the adjoint.
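For anyone wanting to probe this claim, a sketch of a step-size sweep in diffrax (the scalar ODE here is only a placeholder for the HH vector field; if the continuous adjoint really needs a much finer step, the two gradients should only start to agree once dt0 is small enough):

import jax
import jax.numpy as jnp
from diffrax import BacksolveAdjoint, Euler, ODETerm, RecursiveCheckpointAdjoint, diffeqsolve

term = ODETerm(lambda t, y, a: -a * y + jnp.sin(t))  # placeholder for the HH vector field

def grad_at(dt0, adjoint):
    def loss(a):
        sol = diffeqsolve(term, Euler(), 0.0, 10.0, dt0, 1.0, args=a,
                          adjoint=adjoint, max_steps=100_000)
        return jnp.sum(sol.ys)
    return jax.grad(loss)(2.0)

# Compare the continuous-adjoint gradient against backprop at several step sizes.
for dt0 in (0.1, 0.01, 0.001):
    print(dt0, grad_at(dt0, BacksolveAdjoint()), grad_at(dt0, RecursiveCheckpointAdjoint()))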

@jnsbck
Contributor Author

jnsbck commented May 28, 2024

Interesting, I would not have suspected that this would be the case. I saw a bunch of papers that used adjoints for pretty complex biophysical models (non-neuro, though) and thought this might also work for multicompartment HH.

https://www.biorxiv.org/content/biorxiv/early/2018/02/28/272005.full.pdf
https://journals.plos.org/ploscompbiol/article?id=10.1371/journal.pcbi.1010783

Also, from my limited testing in diffrax, the following seems to work (although it might be different for backward Euler).

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from diffrax import diffeqsolve, ODETerm, SaveAt, Euler, DirectAdjoint

# Hodgkin-Huxley gating rate functions.
def αm(V, VT):
    return -0.32 * (V - VT - 13.0) / (jnp.exp(-(V - VT - 13.0) / 4.0) - 1.0)

def βm(V, VT):
    return 0.28 * (V - VT - 40.0) / (jnp.exp((V - VT - 40.0) / 5.0) - 1.0)

def αn(V, VT):
    return -0.032 * (V - VT - 15.0) / (jnp.exp(-(V - VT - 15.0) / 5.0) - 1.0)

def βn(V, VT):
    return 0.5 * jnp.exp(-(V - VT - 10.0) / 40.0)

def αh(V, VT):
    return 0.128 * jnp.exp(-(V - VT - 17.0) / 18.0)

def βh(V, VT):
    return 4.0 / (1.0 + jnp.exp(-(V - VT - 40.0) / 5.0))

def vector_field(t, y, args):
    V, n, m, h = y
    gNa, gK, gL = args
    ENa, EK, EL, VT = 53.0, -107.0, -70.0, -60.0
    INa = gNa * m**3 * h * (ENa - V)
    IK = gK * n**4 * (EK - V)
    IL = gL * (EL - V)
    Iin = 3.0 * jnp.logical_and(t > 10, t < 90)  # step current on for 10 < t < 90

    dm = αm(V, VT) * (1 - m) - βm(V, VT) * m
    dn = αn(V, VT) * (1 - n) - βn(V, VT) * n
    dh = αh(V, VT) * (1 - h) - βh(V, VT) * h
    dV = Iin + INa + IK + IL  # membrane capacitance C_m = 1
    return dV, dn, dm, dh

term = ODETerm(vector_field)
solver = Euler()
t0 = 0.0
t1 = 100.0
dt0 = 0.025

y0 = (-70.0, 0.0, 0.0, 1.0)  # initial (V, n, m, h)
args = (25.0, 7.0, 0.1)  # (gNa, gK, gL)
saveat = SaveAt(ts=jnp.linspace(t0, t1, 1000))
sol = diffeqsolve(term, solver, t0, t1, dt0, y0, args=args, saveat=saveat, adjoint=DirectAdjoint())

plt.plot(sol.ts, sol.ys[0], label="V")
plt.legend()
plt.show()

def g_adjoint(args):
    sol = diffeqsolve(term, solver, t0, t1, dt0, y0, args=args, saveat=saveat, adjoint=DirectAdjoint())
    return jnp.sum(sol.ys[-1])
print("Adjoint gradient: ", jax.grad(g_adjoint)(args))

def g_backprop(args):
    # default adjoint: backprop through the solver (recursive checkpointing)
    sol = diffeqsolve(term, solver, t0, t1, dt0, y0, args=args, saveat=saveat)
    return jnp.sum(sol.ys[-1])
print("Backprop gradient: ", jax.grad(g_backprop)(args))

@michaeldeistler
Contributor

Did you check the gradient against backprop through the solver?

@jnsbck
Contributor Author

jnsbck commented Jun 3, 2024

Yes, now I did. Looks the same to me.
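Such a check might look like this (a sketch reusing g_adjoint and g_backprop from the snippet above; the tolerance is an arbitrary choice):

# Compare the two gradient tuples element-wise.
grads_adjoint = jax.grad(g_adjoint)(args)
grads_backprop = jax.grad(g_backprop)(args)
print(all(jnp.allclose(a, b, rtol=1e-3) for a, b in zip(grads_adjoint, grads_backprop)))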
