Potential reasons for NaN during training #317

Open
michaeldeistler opened this issue Apr 4, 2024 · 10 comments
Labels
documentation Improvements or additions to documentation

Comments

@michaeldeistler (Contributor) commented Apr 4, 2024

Below are several reasons for experiencing NaN, ranked from most to least likely. If you have tried out all of the options below, we would be happy to receive a bug report (as an issue on Github) with the following information:

  • are you using your own channel and/or synapse models? If yes, ideally share your channel and synapse models.
  • do the NaN occur during training or during simulation?
  • are you using point neurons or neurons with morphology? If the neurons have a morphology, did you define the morphology from scratch or did you read it from an SWC file?

1. You are using float32.

In most cases where I have encountered NaN so far, switching to float64 resolved them:

from jax import config
config.update("jax_enable_x64", True)

2. The mechanisms (channels or synapses) are unstable.

The most likely reason for this is that the channels you are using contain a jnp.exp() whose input gets very large (>100.0), such that the result will be inf. For example, our initial implementation of channels gave NaN when a strong negative stimulus was inserted, such that the neuron was very strongly hyperpolarized (voltages below -200 mV). You can prevent this by changing jnp.exp() to save_exp from jaxley.solver_gate.
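To make the failure mode concrete, here is a minimal sketch (plain jax.numpy, not Jaxley's internal code; the thresholds are standard IEEE-754 limits):

import jax.numpy as jnp

# float32 can represent values up to ~3.4e38, so exp() overflows to inf
# once its argument exceeds ~88 (about 709 for float64).
print(jnp.exp(jnp.float32(80.0)))   # ~5.5e34, still finite
print(jnp.exp(jnp.float32(90.0)))   # inf -> turns into NaN downstream

# Clipping the argument before exponentiating (the idea behind save_exp)
# keeps the result finite:
def clipped_exp(x, max_value=20.0):
    return jnp.exp(jnp.clip(x, None, max_value))

print(clipped_exp(jnp.float32(90.0)))  # exp(20) ~ 4.9e8, finite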

3. The ParamTransform saturates:

import jax.numpy as jnp
from jax import value_and_grad
import jaxley as jx

transform = jx.ParamTransform(lowers={"x": -1.0}, uppers={"x": 1.0})

def tf(params):
    return jnp.sum(transform.forward(params)[0]["x"])

# Interestingly, only negative values return `NaN` gradient, positive values return `0`.
p = [{"x": jnp.asarray([-100.0])}]

tf_grad_fn = value_and_grad(tf)
print(tf(p))
print(tf_grad_fn(p))

To debug this, print the maximum absolute value of the optimization parameters (i.e., the inverse-transformed parameters) after every gradient update:

from jax import tree_util
import jax.numpy as jnp

parameters = net.get_parameters()
opt_params = transform.inverse(parameters)

leaves, _ = tree_util.tree_flatten(opt_params)
max_val = jnp.max(jnp.array([jnp.max(jnp.abs(leaf)) for leaf in leaves]))

If max_val > 50.0 then you are probably in trouble.
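A sketch of how this check could be dropped into a training loop (the update function, optimizer state, and dataloader below are placeholders, not Jaxley API):

from jax import tree_util
import jax.numpy as jnp

def max_abs_opt_param(opt_params):
    leaves, _ = tree_util.tree_flatten(opt_params)
    return jnp.max(jnp.array([jnp.max(jnp.abs(leaf)) for leaf in leaves]))

for step, batch in enumerate(dataloader):                         # hypothetical loop
    opt_params, opt_state = update(opt_params, opt_state, batch)  # hypothetical update
    if max_abs_opt_param(opt_params) > 50.0:
        print(f"step {step}: transform is close to saturating, expect NaN gradients")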

michaeldeistler added the bug and documentation labels and removed the bug label on Apr 5, 2024
@jnsbck (Contributor) commented Apr 15, 2024

Do you have a hunch where we need so much precision that float32 is not sufficient? It sounds like a suitable transform of the units might be able to take care of this and could save valuable GPU memory. If you have an example, I could look into it. I just saw that NEURON uses double as well.

@michaeldeistler (Contributor Author)

This notebook will give NaN if one either uses float32 or changes the max_value of save_exp, e.g., to 100 or more.

@michaeldeistler (Contributor Author) commented Apr 15, 2024

I do not have a clear idea of why the NaN happen. The most concrete thing I have observed so far is the following (and it is the reason I introduced the save_exp):

In our channels, e.g. in the sodium part of HH(), we often take an exp(). This value can become huge if the neuron is pushed beyond the normal range of voltages. For example, if v = -300 mV, one can get gigantic values in this exp:

v = -300
x = -(v + 35) / 10 + 1  # = 27.5
# exp(x) -> huge

Especially when using float32, these huge values are a problem and make things unstable. Such strongly negative voltages can happen if extremely strong inhibitory synapses are being learned.
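Plugging the numbers in (just a quick sanity check with jax.numpy, not Jaxley code):

import jax.numpy as jnp

v = -300.0
x = -(v + 35.0) / 10.0 + 1.0       # 27.5
print(jnp.exp(jnp.float32(x)))     # ~8.8e11, already huge
print(jnp.exp(jnp.float32(88.0)))  # ~1.7e38, one step away from inf in float32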

@michaeldeistler (Contributor Author)

A very simple reproducible example should look something like this (but you need to turn off the save_exp in the channels; I have not tested this right now):

import jaxley as jx
from jaxley.channels import HH

dt = 0.025  # ms (example values)
t_max = 20.0  # ms

comp = jx.Compartment()
comp.insert(HH())
comp.record()

# Strong negative current to hyperpolarize the compartment.
current = jx.step_current(0.0, 10.0, -10.0, dt, t_max)
comp.stimulate(current)

v = jx.integrate(comp, delta_t=dt)

@michaeldeistler (Contributor Author)

That being said, I have also observed NaN in single neurons (with morphological detail) when using float32, but I don't have an example right now.

@jnsbck (Contributor) commented Apr 15, 2024

Interesting, might look into this. Thanks for all the hints.

@jnsbck (Contributor) commented Apr 17, 2024

I took a deep dive into this, using your example.

import jax.numpy as jnp
import jaxley as jx

# Stimulus.
i_delay = 3.0  # ms
i_amp = 0.05  # nA
i_dur = 2.0  # ms

# Duration and step size.
dt = 0.025  # ms
t_max = 10.0  # ms

time_vec = jnp.arange(0.0, t_max + dt, dt)

current = jx.step_current(i_delay, i_dur, i_amp, dt, t_max)

comp = jx.Compartment()
comp.stimulate(current)
# ModHH: a modified HH channel (defined elsewhere by the author) where clip_exp
# toggles between jnp.exp and a clipped (save_exp-style) exponential.
comp.insert(ModHH(clip_exp=clip_exp))
comp.record()

This leads to NaNs at some point, with higher precision essentially just delaying it (see the attached plot). The safe exponential does not have this problem, since it prevents the floating point format from topping out after a while.

Here is what it essentially comes down to:

def solve_gate_exponential(
    x: jnp.ndarray,
    dt: float,
    alpha: jnp.ndarray,
    beta: jnp.ndarray,
):
    tau = 1 / (alpha + beta) # <--- alpha can become inf here, hence tau becomes 0
    xinf = alpha * tau # <--- this means tau*alpha = inf*0, which is nan
    return exponential_euler(x, dt, xinf, tau)

The crux of the issue I think is alpha topping out the floating point format in the gates already, as you pointed out as well.
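The inf * 0 step is just standard IEEE-754 behaviour, which a few lines reproduce:

import jax.numpy as jnp

alpha = jnp.exp(jnp.float32(100.0))  # overflows to inf in float32
tau = 1.0 / (alpha + 1.0)            # 1 / inf = 0.0
print(alpha * tau)                   # inf * 0.0 = nan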

EDIT: Idea for better safe_exp:

import numpy as np
import jax.numpy as jnp

def save_exp(x):
    """Clip the input below the maximum value that the dtype can support."""
    # Floored log of the maximum value that can be represented by the dtype.
    max_value = 88.0 if x.dtype == np.float32 else 709.0
    x = jnp.clip(x, a_max=max_value)
    return jnp.exp(x)

max_value can be obtained from: max_dtype_val = np.finfo(dtype).max; max_value = np.floor(np.log(max_dtype_val))
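A quick check of that suggestion (plain numpy, no Jaxley assumptions):

import numpy as np

for dtype in (np.float32, np.float64):
    max_dtype_val = np.finfo(dtype).max
    print(dtype.__name__, np.floor(np.log(max_dtype_val)))  # 88.0 and 709.0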

@jnsbck (Contributor) commented Apr 19, 2024

I spent some time looking into this comment from above:

This notebook will give NaN if one either uses float32 or changes the max_value of save_exp, e.g., to 100 or more.

I don't exactly know why one gets NaNs yet, but I was at least able to trace it to the checkpointing. Here is a minimal reproducing example that you can run in the notebook above once you have set up the network:

config.update("jax_enable_x64", False)

# expose checkpointing and stim_duration as kwargs in loss
def cross_entropy_loss(opt_params, image, label, t_dur=2.301, checkpoint_lengths=None):
    params = transform.forward(opt_params)

    def simulate(params, image):
        tau = 500.0
        i_amp = 10.0 / tau
        currents = jx.datapoint_to_step_currents(0.1, 1.0, i_amp*image, dt, t_dur)
        data_stimuli = net[range(784), 0, 0].data_stimulate(currents, None)
        return jx.integrate(net, params=params, data_stimuli=data_stimuli, tridiag_solver="thomas", checkpoint_lengths=checkpoint_lengths)

    vs = simulate(params, image)
    prediction = vs[:, -1]
    prediction += 60.0
    prediction /= 10.0
    log_prob = prediction[label] - logsumexp(prediction)
    return -log_prob

# For a set of parameters
with open(f"results/parameters/tmp_state_2.pkl", "rb") as handle:
    opt_params, batch = pickle.load(handle)

# ... on a specific training pair
image_batch, label_batch = tfds.as_numpy(batch)
image, label = image_batch[1], label_batch[1]

# computing the cross_entropy_loss works
l = cross_entropy_loss(opt_params, image, label, t_dur=2.301)

# BUT: Computing the gradient fails, raising a `FloatingPointError`
try:
    grads = grad(cross_entropy_loss, argnums=0)(opt_params, image, label)
except FloatingPointError as e:
    print(e)

# Either shortening the time or 
grads = grad(cross_entropy_loss, argnums=0)(opt_params, image, label, t_dur=2.300)
# ... using multiple checkpoints fixes this issue
grads = grad(cross_entropy_loss, argnums=0)(opt_params, image, label, t_dur=2.301, checkpoint_lengths=[103,4])

The model checkpoint where this happens is attached below; you also get it by running the training in the notebook.
tmp_state_2.zip

I have not looked at why changing the save_exp affects this, but it's somehow linked to checkpointing, as adding multiple levels seems to fix the issue. Also, the FloatingPointError originates here:

return scan_fn(f, init, xs, lengths[0])

and happens somewhere in lax.scan.

Looking forward to hearing your thoughts on this.

@jnsbck (Contributor) commented Apr 24, 2024

I did even more NaN chasing, prompted by @manuelgloeckler, who found that https://github.com/mackelab/jaxley_experiments/blob/main/nex/smc/smc_allen_experimental.py sometimes yielded NaNs in the simulations.

I looked into one particular example (see the attached plot).

When I run the jitted version of jx.integrate, no NaNs are returned. However, running the unjitted version returns NaNs after some point. As pointed out above, the issue is somewhere in lax.scan. Running lax.scan on the unjitted body_fun yields NaNs. However, running lax.scan until the last good output and then running body_fun once on the problematic inputs does not yield NaNs, while running lax.scan or jit(body_fun) on those same inputs does (thanks @manuelgloeckler for the pointer)! Thanks @michaeldeistler for the hint about tridiax solvers, because this issue is only present for the implicit solve. I have not gotten further than this yet.

I have not checked with unroll=True since it takes ages.

You can find the notebook here if you want to have a look: https://github.com/mackelab/jaxley_experiments/blob/fix_nans/nex/smc/nan_issues.ipynb
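For reference, the step-by-step strategy described above can be sketched roughly like this (body_fun, init_carry, and xs are placeholders for whatever jx.integrate scans over internally; this is not the actual Jaxley code):

import jax
import jax.numpy as jnp

def first_nan_step(body_fun, init_carry, xs):
    """Step the scan body eagerly (no lax.scan, no jit) and report the first
    step whose output contains NaN, so the offending inputs can be inspected
    and compared against lax.scan / jit(body_fun) on the same inputs."""
    num_steps = jax.tree_util.tree_leaves(xs)[0].shape[0]
    carry = init_carry
    for step in range(num_steps):
        x_t = jax.tree_util.tree_map(lambda leaf: leaf[step], xs)
        carry, out = body_fun(carry, x_t)
        if any(bool(jnp.any(jnp.isnan(leaf))) for leaf in jax.tree_util.tree_leaves(out)):
            return step, x_t, carry
    return None, None, carry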

@jnsbck (Contributor) commented May 28, 2024

From diffrax (screenshot attached).
