Skip to content

Commit

Permalink
Add survival and log-survival function to distrax.Distribution base…
Browse files Browse the repository at this point in the history
… class.

Note that by default we use a numerically not necessarily stable definition
of the survival function in terms of the CDF: sf(x) = 1-cdf(x) and similarly
for the log survival function.
More stable definitions should be implemented in future CLs for distributions,
for which they exist.

Added a test to `distribution.Uniform` where we will use the default behaviour.

PiperOrigin-RevId: 443160366
  • Loading branch information
DistraxDev authored and DistraxDev committed Apr 22, 2022
1 parent cc452c5 commit 10bd0db
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 0 deletions.
43 changes: 43 additions & 0 deletions distrax/_src/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,49 @@ def cdf(self, value: Array) -> Array:
"""
return jnp.exp(self.log_cdf(value))

def survival_function(self, value: Array) -> Array:
"""Evaluates the survival function at `value`.
Note that by default we use a numerically not necessarily stable definition
of the survival function in terms of the CDF.
More stable definitions should be implemented in subclasses for
distributions for which they exist.
Args:
value: An event.
Returns:
The survival function evaluated at `value`, i.e. P[X > value]
"""
if not self.event_shape:
# Defined for univariate distributions only.
return 1. - self.cdf(value)
else:
raise NotImplementedError('`survival_function` is not defined for '
f'distribution `{self.name}`.')

def log_survival_function(self, value: Array) -> Array:
"""Evaluates the log of the survival function at `value`.
Note that by default we use a numerically not necessarily stable definition
of the log of the survival function in terms of the CDF.
More stable definitions should be implemented in subclasses for
distributions for which they exist.
Args:
value: An event.
Returns:
The log of the survival function evaluated at `value`, i.e.
log P[X > value]
"""
if not self.event_shape:
# Defined for univariate distributions only.
return jnp.log1p(-self.cdf(value))
else:
raise NotImplementedError('`log_survival_function` is not defined for '
f'distribution `{self.name}`.')

def mean(self) -> Array:
"""Calculates the mean."""
raise NotImplementedError(
Expand Down
6 changes: 6 additions & 0 deletions distrax/_src/distributions/distribution_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,12 @@ def test_to_batch_shape_index_raises(self, index):
distribution.to_batch_shape_index(
batch_shape=(2, 3, 4), index=index)

def test_multivariate_survival_function_raises(self):
mult_dist = DummyMultivariateDist(42)
with self.assertRaises(NotImplementedError):
mult_dist.survival_function(jnp.zeros(42))
with self.assertRaises(NotImplementedError):
mult_dist.log_survival_function(jnp.zeros(42))

if __name__ == '__main__':
absltest.main()
2 changes: 2 additions & 0 deletions distrax/_src/distributions/uniform_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ def test_sample_and_log_prob(self, distr_params, sample_shape):
('log_prob', 'log_prob'),
('prob', 'prob'),
('cdf', 'cdf'),
('survival_function', 'survival_function'),
('log_survival_function', 'log_survival_function')
)
def test_method_with_inputs(self, function_string):
inputs = 10. * np.random.normal(size=(100,))
Expand Down

0 comments on commit 10bd0db

Please sign in to comment.