
Tanh numerical instability #7

Closed
ikostrikov opened this issue Apr 16, 2021 · 4 comments
ikostrikov commented Apr 16, 2021

The Tanh bijector with Transformed seems to be less numerically stable than its tfp counterpart:

import distrax
import jax
import jax.numpy as jnp

BATCH_SIZE = 8
means = jnp.zeros((BATCH_SIZE,))
log_stds = jnp.full_like(means, 10)

base_dist = distrax.Normal(
    loc=means, scale=jnp.exp(log_stds))

dist = distrax.Transformed(distribution=base_dist,
                            bijector=distrax.Tanh())

@jax.jit
def test(seed):
  samples = dist.sample(seed=seed)
  return dist.log_prob(samples)

print(test(seed=jax.random.PRNGKey(1)))
print(test(seed=jax.random.PRNGKey(2)))

Prints NaNs.

from tensorflow_probability.substrates import jax as tfp
import jax
import jax.numpy as jnp

tfd = tfp.distributions
tfb = tfp.bijectors

BATCH_SIZE = 8
means = jnp.zeros((BATCH_SIZE,))
log_stds = jnp.full_like(means, 10)

base_dist = tfd.Normal(loc=means, scale=jnp.exp(log_stds))
dist = tfd.TransformedDistribution(distribution=base_dist,
                                    bijector=tfb.Tanh())

@jax.jit
def test(seed):
  samples = dist.sample(seed=seed)
  return dist.log_prob(samples)

print(test(seed=jax.random.PRNGKey(1)))
print(test(seed=jax.random.PRNGKey(2)))

Prints regular numbers.

@franrruiz (Collaborator)
Thank you for raising this. We will look into it. In the meantime, you can probably get rid of the NaNs by using sample_and_log_prob instead:

@jax.jit
def test(seed):
  _, log_prob = dist.sample_and_log_prob(seed=seed)
  return log_prob

@franrruiz (Collaborator)
Closing this issue. When the output of the bijector saturates outside a certain range, it is unfortunately not possible to recover the input within machine precision. We have added a note to the docstrings of the Tanh and Sigmoid bijectors to indicate this, and suggested sample_and_log_prob as a workaround for that particular case.
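
The saturation can be seen directly in float32 arithmetic (a minimal NumPy sketch; jax.numpy behaves the same way): once |x| exceeds roughly 9, tanh(x) rounds to exactly ±1, and arctanh of that is ±inf, which turns the log-density of a round-tripped sample into NaN.

```python
import numpy as np

x = np.float32(10.0)   # a typical pre-tanh sample when scale = exp(10)
y = np.tanh(x)         # 1 - 2*exp(-20) ~ 1 - 8e-9, which rounds to 1.0 in float32
print(y == np.float32(1.0))        # True: the input x is no longer recoverable
with np.errstate(divide="ignore"):
    print(np.arctanh(y))           # inf, so log_prob(arctanh(y)) becomes NaN
```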


jgsimard commented Sep 4, 2022

Will this be looked at again? It is the reason I am still using TensorFlow Probability over distrax.

@kevinzakka

Hi, any updates on this?
