Commit 6faaa61

Added the GumbelCDF bijector originally developed in #36.
PiperOrigin-RevId: 424848125
franrruiz authored and DistraxDev committed Jan 28, 2022
1 parent bce1753 commit 6faaa61
Showing 3 changed files with 96 additions and 194 deletions.
distrax/__init__.py: 2 additions, 0 deletions
@@ -19,6 +19,7 @@
 from distrax._src.bijectors.bijector import BijectorLike
 from distrax._src.bijectors.block import Block
 from distrax._src.bijectors.chain import Chain
+from distrax._src.bijectors.gumbel_cdf import GumbelCDF
 from distrax._src.bijectors.inverse import Inverse
 from distrax._src.bijectors.lambda_bijector import Lambda
 from distrax._src.bijectors.lower_upper_triangular_affine import LowerUpperTriangularAffine
@@ -87,6 +88,7 @@
     "Gamma",
     "Greedy",
     "Gumbel",
+    "GumbelCDF",
     "HMM",
     "importance_sampling_ratios",
     "Independent",
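Note (not part of the commit): the two added lines expose the new bijector at the package top level. A minimal import sketch under that assumption:

import distrax

bijector = distrax.GumbelCDF()  # same class as distrax._src.bijectors.gumbel_cdf.GumbelCDF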
distrax/_src/bijectors/gumbel_cdf.py: 16 additions, 29 deletions
@@ -14,52 +14,39 @@
 # ==============================================================================
 """GumbelCDF bijector."""
 
-from typing import Tuple, Union
+from typing import Tuple
 
 from distrax._src.bijectors import bijector as base
-from distrax._src.utils import conversion
 import jax.numpy as jnp
 
 Array = base.Array
-Numeric = Union[Array, float]
 
 
 class GumbelCDF(base.Bijector):
-  """A bijector that computes the Gumbel cumulative density function (CDF)."""
+  """A bijector that computes the Gumbel cumulative density function (CDF).
-
-  def __init__(self, loc: Numeric = 0., scale: Numeric = 1.):
-    """Initializes a Gumbel bijector."""
-    super().__init__(event_ndims_in=0)
-    self._loc = conversion.as_float_array(loc)
-    self._scale = conversion.as_float_array(scale)
 
-  @property
-  def loc(self) -> Numeric:
-    """The bijector's location."""
-    return self._loc
+  The Gumbel CDF is given by `y = f(x) = exp(-exp(-x))` for a scalar input `x`.
+  Its inverse is `x = -log(-log(y))`. The log-det Jacobian of the transformation
+  is `log df/dx = -exp(-x) - x`.
+  """
 
-  @property
-  def scale(self) -> Numeric:
-    """The bijector's scale."""
-    return self._scale
+  def __init__(self):
+    """Initializes a GumbelCDF bijector."""
+    super().__init__(event_ndims_in=0)
 
   def forward_and_log_det(self, x: Array) -> Tuple[Array, Array]:
     """Computes y = f(x) and log|det J(f)(x)|."""
-    z = (x - self._loc) / self._scale
-    y = jnp.exp(-jnp.exp(-z))
-    log_det = -z - jnp.exp(-z) - jnp.log(self._scale)
+    exp_neg_x = jnp.exp(-x)
+    y = jnp.exp(-exp_neg_x)
+    log_det = - x - exp_neg_x
     return y, log_det
 
   def inverse_and_log_det(self, y: Array) -> Tuple[Array, Array]:
     """Computes x = f^{-1}(y) and log|det J(f^{-1})(y)|."""
-    x = self._loc - self._scale * jnp.log(-jnp.log(y))
-    return x, jnp.log(self._scale / (-jnp.log(y) * y))
+    log_y = jnp.log(y)
+    x = -jnp.log(-log_y)
+    return x, x - log_y
 
   def same_as(self, other: base.Bijector) -> bool:
     """Returns True if this bijector is guaranteed to be the same as `other`."""
-    if type(other) is GumbelCDF:  # pylint: disable=unidiomatic-typecheck
-      return all((
-          self.loc is other.loc,
-          self.scale is other.scale,
-      ))
-    return False
+    return type(other) is GumbelCDF  # pylint: disable=unidiomatic-typecheck
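Note (not part of the commit): an illustrative sketch that checks the formulas stated in the class docstring, assuming the distrax package from this revision is importable.

import jax.numpy as jnp
from distrax._src.bijectors.gumbel_cdf import GumbelCDF

bijector = GumbelCDF()
x = jnp.array([-1.0, 0.0, 2.5])

# Forward map: y = exp(-exp(-x)), the CDF of a standard Gumbel distribution.
y, fwd_log_det = bijector.forward_and_log_det(x)
assert jnp.allclose(y, jnp.exp(-jnp.exp(-x)))

# Log-det Jacobian of the forward map: log df/dx = -x - exp(-x).
assert jnp.allclose(fwd_log_det, -x - jnp.exp(-x))

# Inverse map: x = -log(-log(y)) recovers the input, and its log-det Jacobian
# is the negation of the forward log-det at the recovered x.
x_rec, inv_log_det = bijector.inverse_and_log_det(y)
assert jnp.allclose(x_rec, x, atol=1e-5)
assert jnp.allclose(inv_log_det, -fwd_log_det, atol=1e-5)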
distrax/_src/bijectors/gumbel_cdf_test.py: 78 additions, 165 deletions
@@ -19,36 +19,15 @@
 
 import chex
 from distrax._src.bijectors import gumbel_cdf
-from distrax._src.distributions import normal
-from distrax._src.distributions import transformed
-from distrax._src.utils import conversion
+from distrax._src.bijectors import tanh
 import jax
 import jax.numpy as jnp
 import numpy as np
-from tensorflow_probability.substrates import jax as tfp
-
-
-tfd = tfp.distributions
-tfb = tfp.bijectors
 
 
-RTOL = 1e-2
-
-
-def _with_additional_parameters(params, all_named_parameters):
-  """Convenience function for appending a cartesian product of parameters."""
-  for name, param in params:
-    for named_params in all_named_parameters:
-      yield (f'{named_params[0]}; {name}',) + named_params[1:] + (param,)
-
-
-def _with_base_dists(*all_named_parameters):
-  """Partial of _with_additional_parameters to specify distrax and TFP base."""
-  base_dists = (
-      ('tfp_base', tfd.Normal),
-      ('distrax_base', normal.Normal),
-  )
-  return _with_additional_parameters(base_dists, all_named_parameters)
+RTOL = 1e-5
 
 
 class GumbelCDFTest(parameterized.TestCase):
@@ -58,143 +37,96 @@ def setUp(self):
     self.seed = jax.random.PRNGKey(1234)
 
   def test_properties(self):
-    bijector = gumbel_cdf.GumbelCDF(loc=-1., scale=0.1)
+    bijector = gumbel_cdf.GumbelCDF()
     self.assertEqual(bijector.event_ndims_in, 0)
     self.assertEqual(bijector.event_ndims_out, 0)
     self.assertFalse(bijector.is_constant_jacobian)
     self.assertFalse(bijector.is_constant_log_det)
-    np.testing.assert_allclose(bijector.loc, -1.)
-    np.testing.assert_allclose(bijector.scale, 0.1)
 
-  @parameterized.named_parameters(_with_base_dists(
-      ('1d std normal', 0, 1),
-      ('2d std normal', np.zeros(2), np.ones(2)),
-      ('broadcasted loc', 0, np.ones(3)),
-      ('broadcasted scale', np.ones(3), 1),
-  ))
-  def test_event_shape(self, mu, sigma, base_dist):
-    base = base_dist(mu, sigma)
-
+  @chex.all_variants
+  @parameterized.parameters(
+      {'x_shape': (2,)},
+      {'x_shape': (2, 3)},
+      {'x_shape': (2, 3, 4)})
+  def test_forward_shapes(self, x_shape):
+    x = jnp.zeros(x_shape)
     bijector = gumbel_cdf.GumbelCDF()
-    dist = transformed.Transformed(base, bijector)
-
-    tfp_bijector = tfb.GumbelCDF()
-    tfp_dist = tfd.TransformedDistribution(
-        conversion.to_tfp(base), tfp_bijector)
-
-    assert dist.event_shape == tfp_dist.event_shape
+    y1 = self.variant(bijector.forward)(x)
+    logdet1 = self.variant(bijector.forward_log_det_jacobian)(x)
+    y2, logdet2 = self.variant(bijector.forward_and_log_det)(x)
+    self.assertEqual(y1.shape, x_shape)
+    self.assertEqual(y2.shape, x_shape)
+    self.assertEqual(logdet1.shape, x_shape)
+    self.assertEqual(logdet2.shape, x_shape)
 
   @chex.all_variants
-  @parameterized.named_parameters(_with_base_dists(
-      ('1d std normal, no shape', 0, 1, ()),
-      ('1d std normal, int shape', 0, 1, 1),
-      ('1d std normal, 1-tuple shape', 0, 1, (1,)),
-      ('1d std normal, 2-tuple shape', 0, 1, (2, 2)),
-      ('2d std normal, no shape', np.zeros(2), np.ones(2), ()),
-      ('2d std normal, int shape', [0, 0], [1, 1], 1),
-      ('2d std normal, 1-tuple shape', np.zeros(2), np.ones(2), (1,)),
-      ('2d std normal, 2-tuple shape', [0, 0], [1, 1], (2, 2)),
-      ('rank 2 std normal, 2-tuple shape', np.zeros(
-          (3, 2)), np.ones((3, 2)), (2, 2)),
-      ('broadcasted loc', 0, np.ones(3), (2, 2)),
-      ('broadcasted scale', np.ones(3), 1, ()),
-  ))
-  def test_sample_shape(self, mu, sigma, sample_shape, base_dist):
-    base = base_dist(mu, sigma)
+  @parameterized.parameters(
+      {'y_shape': (2,)},
+      {'y_shape': (2, 3)},
+      {'y_shape': (2, 3, 4)})
+  def test_inverse_shapes(self, y_shape):
+    y = jnp.zeros(y_shape)
     bijector = gumbel_cdf.GumbelCDF()
-    dist = transformed.Transformed(base, bijector)
-    def sample_fn(seed, sample_shape):
-      return dist.sample(seed=seed, sample_shape=sample_shape)
-    samples = self.variant(sample_fn, ignore_argnums=(1,), static_argnums=1)(
-        self.seed, sample_shape)
+    x1 = self.variant(bijector.inverse)(y)
+    logdet1 = self.variant(bijector.inverse_log_det_jacobian)(y)
+    x2, logdet2 = self.variant(bijector.inverse_and_log_det)(y)
+    self.assertEqual(x1.shape, y_shape)
+    self.assertEqual(x2.shape, y_shape)
+    self.assertEqual(logdet1.shape, y_shape)
+    self.assertEqual(logdet2.shape, y_shape)
 
-    tfp_bijector = tfb.GumbelCDF()
-    tfp_dist = tfd.TransformedDistribution(
-        conversion.to_tfp(base), tfp_bijector)
-    tfp_samples = tfp_dist.sample(sample_shape=sample_shape, seed=self.seed)
-
-    chex.assert_equal_shape([samples, tfp_samples])
+  @chex.all_variants
+  def test_forward(self):
+    x = jax.random.normal(self.seed, (100,))
+    bijector = gumbel_cdf.GumbelCDF()
+    y = self.variant(bijector.forward)(x)
+    expected_y = jnp.exp(-jnp.exp(-x))
+    np.testing.assert_allclose(y, expected_y, rtol=RTOL)
 
   @chex.all_variants
-  @parameterized.named_parameters(_with_base_dists(
-      ('1d dist, 1d value', 0, 1, 1.),
-      ('1d dist, 1d value int', 0, 1, 1),
-      ('1d dist, 2d value', 0., 1., np.array([1., 2.])),
-      ('1d dist, 2d value int', 0., 1., np.array([1, 2], dtype=np.int32)),
-      ('2d dist, 1d value', np.zeros(2), np.ones(2), 1.),
-      ('2d broadcasted dist, 1d value', np.zeros(2), 1, 1.),
-      ('2d dist, 2d value', np.zeros(2), np.ones(2), np.array([1., 2.])),
-      ('1d dist, 1d value, edge case', 0, 1, 200.),
-  ))
-  def test_log_prob(self, mu, sigma, value, base_dist):
-    base = base_dist(mu, sigma)
+  def test_forward_log_det_jacobian(self):
+    x = jax.random.normal(self.seed, (100,))
     bijector = gumbel_cdf.GumbelCDF()
-    dist = transformed.Transformed(base, bijector)
-    actual = self.variant(dist.log_prob)(value)
+    fwd_logdet = self.variant(bijector.forward_log_det_jacobian)(x)
+    actual = jnp.log(jax.vmap(jax.grad(bijector.forward))(x))
+    np.testing.assert_allclose(fwd_logdet, actual, rtol=1e-3)
 
-    tfp_bijector = tfb.GumbelCDF()
-    tfp_dist = tfd.TransformedDistribution(
-        conversion.to_tfp(base), tfp_bijector)
-    expected = tfp_dist.log_prob(value)
-    np.testing.assert_allclose(actual, expected, atol=1e-6)
+  @chex.all_variants
+  def test_forward_and_log_det(self):
+    x = jax.random.normal(self.seed, (100,))
+    bijector = gumbel_cdf.GumbelCDF()
+    y1 = self.variant(bijector.forward)(x)
+    logdet1 = self.variant(bijector.forward_log_det_jacobian)(x)
+    y2, logdet2 = self.variant(bijector.forward_and_log_det)(x)
+    np.testing.assert_allclose(y1, y2, rtol=RTOL)
+    np.testing.assert_allclose(logdet1, logdet2, rtol=RTOL)
 
   @chex.all_variants
-  @parameterized.named_parameters(_with_base_dists(
-      ('1d dist, 1d value', 0, 1, 1.),
-      ('1d dist, 1d value int', 0, 1, 1),
-      ('1d dist, 2d value', 0., 1., np.array([1., 2.])),
-      ('1d dist, 2d value int', 0., 1., np.array([1, 2], dtype=np.int32)),
-      ('2d dist, 1d value', np.zeros(2), np.ones(2), 1.),
-      ('2d broadcasted dist, 1d value', np.zeros(2), 1, 1.),
-      ('2d dist, 2d value', np.zeros(2), np.ones(2), np.array([1., 2.])),
-      ('1d dist, 1d value, edge case', 0, 1, 200.),
-  ))
-  def test_prob(self, mu, sigma, value, base_dist):
-    base = base_dist(mu, sigma)
+  def test_inverse(self):
+    x = jax.random.normal(self.seed, (100,))
     bijector = gumbel_cdf.GumbelCDF()
-    dist = transformed.Transformed(base, bijector)
-    actual = self.variant(dist.prob)(value)
+    y = self.variant(bijector.forward)(x)
+    x_rec = self.variant(bijector.inverse)(y)
+    np.testing.assert_allclose(x_rec, x, rtol=1e-3)
 
-    tfp_bijector = tfb.GumbelCDF()
-    tfp_dist = tfd.TransformedDistribution(
-        conversion.to_tfp(base), tfp_bijector)
-    expected = tfp_dist.prob(value)
-    np.testing.assert_allclose(actual, expected, atol=1e-9)
+  @chex.all_variants
+  def test_inverse_log_det_jacobian(self):
+    x = jax.random.normal(self.seed, (100,))
+    bijector = gumbel_cdf.GumbelCDF()
+    y = self.variant(bijector.forward)(x)
+    fwd_logdet = self.variant(bijector.forward_log_det_jacobian)(x)
+    inv_logdet = self.variant(bijector.inverse_log_det_jacobian)(y)
+    np.testing.assert_allclose(inv_logdet, -fwd_logdet, rtol=1e-3)
 
   @chex.all_variants
-  @parameterized.named_parameters(_with_base_dists(
-      ('1d std normal, no shape', 0, 1, ()),
-      ('1d std normal, int shape', 0, 1, 1),
-      ('1d std normal, 1-tuple shape', 0, 1, (1,)),
-      ('1d std normal, 2-tuple shape', 0, 1, (2, 2)),
-      ('2d std normal, no shape', np.zeros(2), np.ones(2), ()),
-      ('2d std normal, int shape', [0, 0], [1, 1], 1),
-      ('2d std normal, 1-tuple shape', np.zeros(2), np.ones(2), (1,)),
-      ('2d std normal, 2-tuple shape', [0, 0], [1, 1], (2, 2)),
-      ('rank 2 std normal, 2-tuple shape', np.zeros(
-          (3, 2)), np.ones((3, 2)), (2, 2)),
-      ('broadcasted loc', 0, np.ones(3), (2, 2)),
-      ('broadcasted scale', np.ones(3), 1, ()),
-  ))
-  def test_sample_and_log_prob(self, mu, sigma, sample_shape, base_dist):
-    base = base_dist(mu, sigma)
-    bijector = gumbel_cdf.GumbelCDF(mu, sigma)
-    dist = transformed.Transformed(base, bijector)
-    def sample_and_log_prob_fn(seed, sample_shape):
-      return dist.sample_and_log_prob(seed=seed, sample_shape=sample_shape)
-    samples, log_prob = self.variant(
-        sample_and_log_prob_fn, ignore_argnums=(1,), static_argnums=(1,))(
-            self.seed, sample_shape)
-    expected_samples = bijector.forward(
-        base.sample(seed=self.seed, sample_shape=sample_shape))
-
-    tfp_bijector = tfb.GumbelCDF(mu, sigma)
-    tfp_dist = tfd.TransformedDistribution(
-        conversion.to_tfp(base), tfp_bijector)
-    tfp_samples = tfp_dist.sample(seed=self.seed, sample_shape=sample_shape)
-    tfp_log_prob = tfp_dist.log_prob(tfp_samples)
-
-    chex.assert_equal_shape([samples, tfp_samples])
-    np.testing.assert_allclose(log_prob, tfp_log_prob, rtol=RTOL)
-    np.testing.assert_allclose(samples, expected_samples, rtol=RTOL)
+  def test_inverse_and_log_det(self):
+    y = jax.random.uniform(self.seed, (100,))
+    bijector = gumbel_cdf.GumbelCDF()
+    x1 = self.variant(bijector.inverse)(y)
+    logdet1 = self.variant(bijector.inverse_log_det_jacobian)(y)
+    x2, logdet2 = self.variant(bijector.inverse_and_log_det)(y)
+    np.testing.assert_allclose(x1, x2, rtol=RTOL)
+    np.testing.assert_allclose(logdet1, logdet2, rtol=RTOL)
 
   @chex.all_variants
   def test_stability(self):
@@ -211,22 +143,6 @@ def test_stability(self):
     ildj_ = self.variant(bijector.inverse_log_det_jacobian)(y)
     np.testing.assert_allclose(ildj_, ildj, rtol=RTOL)
 
-  @chex.all_variants
-  @parameterized.named_parameters(
-      ('int16', np.array([0, 0], dtype=np.int16)),
-      ('int32', np.array([0, 0], dtype=np.int32)),
-      ('int64', np.array([0, 0], dtype=np.int64)),
-  )
-  def test_integer_inputs(self, inputs):
-    bijector = gumbel_cdf.GumbelCDF()
-    output, log_det = self.variant(bijector.forward_and_log_det)(inputs)
-
-    expected_out = (jnp.exp(-jnp.exp(inputs))).astype(jnp.float32)
-    expected_log_det = (-inputs - jnp.exp(-inputs)).astype(jnp.float32)
-
-    np.testing.assert_array_equal(output, expected_out)
-    np.testing.assert_array_equal(log_det, expected_log_det)
-
   def test_jittable(self):
     @jax.jit
     def f(x, b):

@@ -236,14 +152,11 @@ def f(x, b):
     x = np.zeros(())
     f(x, bijector)
 
-  @chex.all_variants(with_pmap=False)
   def test_same_as(self):
-    loc1 = np.array([0., -1., 1.])
-    loc2 = np.array([1., 1., 1.])
-    scale = np.array([1., 2., 3.])
-    bijector1 = gumbel_cdf.GumbelCDF(loc1, scale)
-    bijector2 = gumbel_cdf.GumbelCDF(loc2, scale)
-    self.assertFalse(self.variant(bijector1.same_as)(bijector2))
+    bijector = gumbel_cdf.GumbelCDF()
+    self.assertTrue(bijector.same_as(bijector))
+    self.assertTrue(bijector.same_as(gumbel_cdf.GumbelCDF()))
+    self.assertFalse(bijector.same_as(tanh.Tanh()))
 
 
 if __name__ == '__main__':
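Note (not part of the commit): the deleted tests exercised the bijector through transformed distributions against TFP. A sketch of that kind of composition using the existing distrax.Gumbel and distrax.Transformed APIs; by the probability integral transform, pushing a standard Gumbel sample through its own CDF gives a Uniform(0, 1) variable, so the transformed log-density is approximately zero on its support.

import jax
import distrax

key = jax.random.PRNGKey(0)

# Standard Gumbel base pushed through its own CDF.
base = distrax.Gumbel(loc=0., scale=1.)
dist = distrax.Transformed(base, distrax.GumbelCDF())

samples = dist.sample(seed=key, sample_shape=(5,))   # values lie in (0, 1)
log_probs = dist.log_prob(samples)                   # approximately zero everywhere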
