Potential reasons for NaN during training (#317)
Do you have a hunch where we need so much precision that float32 is not sufficient? It sounds like a suitable transform of different units might be able to take care of this and could save valuable GPU memory. If you have an example, I could look into it. I just saw that NEURON uses double as well.
This notebook will give NaN if either one uses float32 or if one changes the …
I do not have a clear idea of why the NaNs occur. In our channels, e.g. in the sodium part of `HH`, … especially when using …
A very simple reproducible example should look something like this (but you need to turn off 64-bit precision; see the snippet below):

```python
import jaxley as jx
from jaxley.channels import HH

dt = 0.025    # time step in ms (example value; not given in the original snippet)
t_max = 20.0  # simulation length in ms (example value)

comp = jx.Compartment()

# Strong negative current
current = jx.step_current(0.0, 10.0, -10.0, dt, t_max)
comp.stimulate(current)
comp.insert(HH())
comp.record()
```
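For reference, a minimal sketch of turning off 64-bit mode; it uses the same JAX flag that appears (with `False`) in the training example further down in this thread:

```python
from jax import config

# Run everything in float32 to reproduce the NaN; with x64 enabled the
# example above reportedly does not produce NaN.
config.update("jax_enable_x64", False)
```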
That being said, I have also observed …
Interesting, might look into this. Thanks for all the hints.
I spent some time looking into this.

I don't exactly know why one gets NaNs yet, but I was at least able to trace it to the checkpointing. Here is a minimally reproducing example that you can run in the notebook above if you have set up the network:

```python
# `net`, `transform`, `dt`, and the dataset come from the notebook setup.
import pickle

import jaxley as jx
import tensorflow_datasets as tfds
from jax import config, grad
from jax.scipy.special import logsumexp

config.update("jax_enable_x64", False)

# expose checkpointing and stim_duration as kwargs in loss
def cross_entropy_loss(opt_params, image, label, t_dur=2.301, checkpoint_lengths=None):
    params = transform.forward(opt_params)

    def simulate(params, image):
        tau = 500.0
        i_amp = 10.0 / tau
        currents = jx.datapoint_to_step_currents(0.1, 1.0, i_amp * image, dt, t_dur)
        data_stimuli = net[range(784), 0, 0].data_stimulate(currents, None)
        return jx.integrate(
            net,
            params=params,
            data_stimuli=data_stimuli,
            tridiag_solver="thomas",
            checkpoint_lengths=checkpoint_lengths,
        )

    vs = simulate(params, image)
    prediction = vs[:, -1]
    prediction += 60.0
    prediction /= 10.0
    log_prob = prediction[label] - logsumexp(prediction)
    return -log_prob

# For a set of parameters ...
with open(f"results/parameters/tmp_state_2.pkl", "rb") as handle:
    opt_params, batch = pickle.load(handle)

# ... on a specific training pair
image_batch, label_batch = tfds.as_numpy(batch)
image, label = image_batch[1], label_batch[1]

# computing the cross_entropy_loss works
l = cross_entropy_loss(opt_params, image, label, t_dur=2.301)

# BUT: computing the gradient fails, raising a `FloatingPointError`
try:
    grads = grad(cross_entropy_loss, argnums=0)(opt_params, image, label)
except FloatingPointError as e:
    print(e)

# Either shortening the time ...
grads = grad(cross_entropy_loss, argnums=0)(opt_params, image, label, t_dur=2.300)

# ... or using multiple checkpoints fixes this issue
grads = grad(cross_entropy_loss, argnums=0)(
    opt_params, image, label, t_dur=2.301, checkpoint_lengths=[103, 4]
)
```

The model checkpoint where this happens is the one loaded above (`results/parameters/tmp_state_2.pkl`), which you also get by running the training in the notebook. I have not looked at why changing the … helps. The relevant code is the `lax.scan` call in `jaxley/jaxley/utils/jax_utils.py`, lines 65 to 66, at commit 72278f8.
Looking forward to hearing your thoughts on this.
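For readers who have not used `checkpoint_lengths` before: the idea behind nesting the scan is that gradient checkpointing recomputes chunks of the simulation during the backward pass instead of storing every step. Below is a generic sketch of that recipe, assuming the scanned inputs are a pytree of arrays whose leading dimension equals the product of the chunk lengths; it is not necessarily Jaxley's exact implementation in `jax_utils.py`.

```python
import jax


def nested_checkpoint_scan(body_fun, init, xs, lengths):
    """Scan `body_fun` over `xs` in nested, rematerialized chunks.

    `lengths` is e.g. [103, 4]: the outer scan runs 103 checkpointed chunks,
    each of which scans over 4 steps, so prod(lengths) must equal the leading
    dimension of `xs`. Sketch only, not Jaxley's actual code.
    """
    if len(lengths) == 1:
        return jax.lax.scan(body_fun, init, xs)

    outer, rest = lengths[0], lengths[1:]
    # Split the leading time axis into (outer, inner) chunks.
    xs = jax.tree_util.tree_map(lambda x: x.reshape((outer, -1) + x.shape[1:]), xs)

    @jax.checkpoint
    def chunk(carry, x_chunk):
        # Each outer step is rematerialized (recomputed) on the backward pass,
        # so roughly only the outer carries need to be stored.
        return nested_checkpoint_scan(body_fun, carry, x_chunk, rest)

    carry, ys = jax.lax.scan(chunk, init, xs)
    # Undo the chunking of the stacked per-step outputs.
    ys = jax.tree_util.tree_map(lambda y: y.reshape((-1,) + y.shape[2:]), ys)
    return carry, ys
```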
I did even more digging: I looked into one particular example (see plot). When I run the jitted version of `jx.integrate`, no NaNs are returned. However, running the unjitted version returns NaNs after some point. As pointed out above, the issue was somewhere in `lax.scan`. Running `lax.scan` on the unjitted `body_fun` yields NaNs, while running `lax.scan` until the last good output and then running `body_fun` once does not yield NaNs in the output; running `lax.scan` or `jit(body_fun)` on the problematic inputs does (thanks @manuelgloeckler for the pointer)! Thanks @michaeldeistler for the hint about the tridiax solvers, because this issue is only present for the implicit solve. I have not gotten further than this yet, though. I have not checked with `unroll=True` since it takes ages. You can find the notebook here if you want to have a look.
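To make the "scan until the last good output, then apply the body once" procedure concrete, here is a rough step-by-step sketch of that kind of check. The names `body_fun`, `init_carry`, and `inputs` are placeholders for whatever `jx.integrate` scans over internally, so this illustrates the approach rather than the notebook's actual code.

```python
import jax
import jax.numpy as jnp


def first_nan_step(body_fun, init_carry, inputs):
    """Apply the scan body step by step (eagerly) and return the index of the
    first step whose output contains a NaN, plus the carry just before it."""
    carry = init_carry
    for t in range(inputs.shape[0]):
        new_carry, out = body_fun(carry, inputs[t])
        if bool(jnp.any(jnp.isnan(out))):
            return t, carry
        carry = new_carry
    return None, carry


# The interesting comparison from the comment above: at the offending step,
# the eager call may stay finite while the jitted version produces NaN.
# t, carry = first_nan_step(body_fun, init_carry, inputs)
# eager_out = body_fun(carry, inputs[t])[1]
# jitted_out = jax.jit(body_fun)(carry, inputs[t])[1]
```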
From diffrax: …
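Related to tracking these down, and presumably what turned the NaN into the `FloatingPointError` seen in the gradient example above: JAX has a built-in flag that raises at the first NaN-producing operation, which can help localize the offending step.

```python
from jax import config

# Raise a FloatingPointError as soon as any op produces a NaN
# (jitted code is re-run op-by-op to point at the responsible operation).
config.update("jax_debug_nans", True)
```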
Below are several reasons for experiencing `NaN`, ranked from most to least likely. If you have tried out all of the options below, we would be happy to receive a bug report (as an issue on GitHub) with the following information: did the `NaN` occur during training or during simulation?

1. You are using `float32`.

In most cases where I have encountered `NaN` so far, the `NaN` were resolved by switching to `float64`:
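A minimal way to do that is the standard JAX flag, i.e. the inverse of the `config.update("jax_enable_x64", False)` call used in the reproduction above:

```python
from jax import config

# Enable double precision before building the network and running jx.integrate.
config.update("jax_enable_x64", True)
```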
2. The mechanisms (channels or synapses) are unstable.

The most likely reason for this is that the channels you are using contain a `jnp.exp()` whose input gets very large (>100.0), such that the result will be `inf`. For example, our initial implementation of channels had given `NaN` when a strong negative stimulus was inserted, such that the neuron was very strongly hyperpolarized (voltages below -200 mV). You can prevent this by changing `jnp.exp()` to `save_exp` from `jaxley.solver_gate.py`.
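As a sketch of what such a swap looks like inside a custom gating function, assuming `save_exp` is a drop-in, input-clipping replacement for `jnp.exp` (the rate expression below is the textbook HH sodium activation rate, used only as an example, not Jaxley's exact code):

```python
from jaxley.solver_gate import save_exp


def alpha_m(v):
    # Textbook HH sodium activation rate; with a plain jnp.exp the exponent
    # can grow very large at extreme voltages, so use the clipped exponential.
    return 0.1 * (v + 40.0) / (1.0 - save_exp(-(v + 40.0) / 10.0))
```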
3. The `ParamTransform` saturates.

To debug this, print the maximum value of the transformed params after every gradient update:
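A minimal sketch of such a check, assuming `opt_params` is the parameter pytree the optimizer updates and `transform` is the `ParamTransform` (as in the training example above); which of the two parameter spaces `max_val` should be measured in depends on the direction of your transform:

```python
import jax
import jax.numpy as jnp


def max_abs(pytree):
    """Largest absolute entry across all arrays in a parameter pytree."""
    return max(float(jnp.max(jnp.abs(leaf))) for leaf in jax.tree_util.tree_leaves(pytree))


# After every gradient update:
max_val = max_abs(transform.forward(opt_params))
print(f"max transformed parameter: {max_val}")
```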
If `max_val > 50.0` then you are probably in trouble.