Skip to content

Commit

Permalink
Add survival and log-survival function to distrax.Distribution base…
Browse files Browse the repository at this point in the history
… class.

Note that by default we use a numerically not necessarily stable definition
of the survival function in terms of the CDF: sf(x) = 1-cdf(x) and similarly
for the log survival function.
More stable definitions should be implemented in future CLs for distributions,
for which they exist.

Added a test to `distribution.Uniform` where we will use the default behaviour.

PiperOrigin-RevId: 443160366
  • Loading branch information
DistraxDev authored and DistraxDev committed Apr 22, 2022
1 parent cc452c5 commit 8a2fd38
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 0 deletions.
33 changes: 33 additions & 0 deletions distrax/_src/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,39 @@ def cdf(self, value: Array) -> Array:
"""
return jnp.exp(self.log_cdf(value))

def survival_function(self, value: Array) -> Array:
"""Evaluates the survival function at `value`.
Args:
value: An event.
Returns:
The survival function evaluated at `value`, i.e. P[X > value]
"""
if not self.event_shape:
# Defined for univariate distributions only.
return 1. - self.cdf(value)
else:
raise NotImplementedError('`survival_function` is not defined for '
f'distribution {self.name}')

def log_survival_function(self, value: Array) -> Array:
"""Evaluates the log of the survival function at `value`.
Args:
value: An event.
Returns:
The log of the survival function evaluated at `value`, i.e.
log P[X > value]
"""
if not self.event_shape:
# Defined for univariate distributions only.
return jnp.log1p(-self.cdf(value))
else:
raise NotImplementedError('`log_survival_function` is not defined for '
f'distribution {self.name}')

def mean(self) -> Array:
"""Calculates the mean."""
raise NotImplementedError(
Expand Down
6 changes: 6 additions & 0 deletions distrax/_src/distributions/distribution_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,12 @@ def test_to_batch_shape_index_raises(self, index):
distribution.to_batch_shape_index(
batch_shape=(2, 3, 4), index=index)

def test_multivariate_survival_function_raises(self):
mult_dist = DummyMultivariateDist(42)
with self.assertRaises(NotImplementedError):
mult_dist.survival_function(jnp.zeros(42))
with self.assertRaises(NotImplementedError):
mult_dist.log_survival_function(jnp.zeros(42))

if __name__ == '__main__':
absltest.main()
2 changes: 2 additions & 0 deletions distrax/_src/distributions/uniform_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ def test_sample_and_log_prob(self, distr_params, sample_shape):
('log_prob', 'log_prob'),
('prob', 'prob'),
('cdf', 'cdf'),
('survival_function', 'survival_function'),
('log_survival_function', 'log_survival_function')
)
def test_method_with_inputs(self, function_string):
inputs = 10. * np.random.normal(size=(100,))
Expand Down

0 comments on commit 8a2fd38

Please sign in to comment.