Commit
Add support for discrete rvs
Dhruvanshu-Joshi committed Sep 4, 2023
1 parent dd7730e commit be1ee9f
Showing 2 changed files with 28 additions and 24 deletions.
28 changes: 15 additions & 13 deletions pymc/logprob/order.py
@@ -45,6 +45,8 @@
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.var import TensorVariable

import pymc as pm

from pymc.logprob.abstract import (
MeasurableVariable,
_logcdf_helper,
@@ -63,7 +65,7 @@ class MeasurableMax(Max):


class MeasurableMaxDiscrete(Max):
"""A placeholder used to specify a log-likelihood for a cmax sub-graph."""
"""A placeholder used to specify a log-likelihood for sub-graphs of maxima of discrete variables"""


MeasurableVariable.register(MeasurableMaxDiscrete)
@@ -101,14 +103,14 @@ def find_measurable_max(fgraph: FunctionGraph, node: Node) -> Optional[List[Tens
if axis != base_var_dims:
return None

# logprob for discrete distribution
if isinstance(base_var.owner.op, RandomVariable) and base_var.owner.op.dtype.startswith("int"):
measurable_max = MeasurableMaxDiscrete(list(axis))
max_rv_node = measurable_max.make_node(base_var)
max_rv = max_rv_node.outputs
# distinguish measurable discrete and continuous (because logprob is different)
if base_var.owner.op.dtype.startswith("int"):
if isinstance(base_var.owner.op, RandomVariable):
measurable_max = MeasurableMaxDiscrete(list(axis))
max_rv_node = measurable_max.make_node(base_var)
max_rv = max_rv_node.outputs

return max_rv
# logprob for continuous distribution
return max_rv
else:
measurable_max = MeasurableMax(list(axis))
max_rv_node = measurable_max.make_node(base_var)
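
The branch above hinges on the dtype of the RandomVariable op: integer dtypes are routed to MeasurableMaxDiscrete, everything else stays with MeasurableMax. A minimal sketch of that check, using Poisson and Normal purely as illustrative distributions:

import pymc as pm

# Discrete RVs carry an integer dtype on their RandomVariable op -> MeasurableMaxDiscrete branch
x = pm.Poisson.dist(mu=2, size=3)
print(x.owner.op.dtype.startswith("int"))  # True

# Continuous RVs keep the MeasurableMax branch
y = pm.Normal.dist(0, 1, size=3)
print(y.owner.op.dtype.startswith("int"))  # False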
@@ -145,16 +147,16 @@ def max_logprob_discrete(op, values, base_rv, **kwargs):
r"""Compute the log-likelihood graph for the `Max` operation.
The formula that we use here is:
\ln(f_{(n)}(x)) = \ln(F(x)^n - F(x-1)^n)
where f(x) represents the p.d.f. and F(x) represents the c.d.f. of the distribution respectively.
.. math::
\ln(P_{(n)}(x)) = \ln(F(x)^n - F(x-1)^n)
where $P_{(n)}(x)$ represents the p.m.f of the maximum statistic and $F(x)$ represents the c.d.f of the i.i.d. variables.
"""
(value,) = values
logprob = _logprob_helper(base_rv, value)
logcdf = _logcdf_helper(base_rv, value)
logcdf_prev = _logcdf_helper(base_rv, value - 1)

n = base_rv.size
[n] = constant_fold([base_rv.size])

logprob = pt.log((pt.exp(logcdf)) ** n - (pt.exp(logcdf_prev)) ** n)
logprob = pm.math.logdiffexp(n * logcdf, n * logcdf_prev)

return logprob
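
As a quick sanity check of the order-statistic identity used in max_logprob_discrete, the closed-form P(max = x) = F(x)^n - F(x-1)^n can be compared against Monte Carlo draws. A minimal sketch with assumed example values (mu = 2, n = 3, x = 1):

import numpy as np
import scipy.stats as sp

mu, n, x = 2, 3, 1  # assumed example values

# Closed form used above: P(max of n i.i.d. Poisson(mu) equals x) = F(x)^n - F(x-1)^n
cdf = sp.poisson(mu).cdf
pmf_max = cdf(x) ** n - cdf(x - 1) ** n

# Monte Carlo estimate of the same probability
draws = np.random.default_rng(0).poisson(mu, size=(200_000, n)).max(axis=1)
print(pmf_max, (draws == x).mean())  # the two values should agree closely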
24 changes: 13 additions & 11 deletions tests/logprob/test_order.py
@@ -39,6 +39,7 @@
import numpy as np
import pytensor.tensor as pt
import pytest
import scipy.stats as sp

import pymc as pm

@@ -149,23 +150,24 @@ def test_max_logprob(shape, value, axis):
)


def test_max_discrete():
x = pm.DiscreteUniform.dist(0, 1, size=(3,))
x.name = "x"
x_max = pt.max(x, axis=-1)
@pytest.mark.parametrize(
"mu, size, value, axis",
[(2, 3, 0.85, -1), (2, 3, 0.01, 0), (1, 2, 0.2, None), (0, 4, 0, 0)],
)
def test_max_discrete(mu, size, value, axis):
x = pm.Poisson.dist(name="x", mu=mu, size=(size))
x_max = pt.max(x, axis=axis)
x_max_value = pt.scalar("x_max_value")
x_max_logprob = logp(x_max, x_max_value)

discrete_logprob = _logprob_helper(x, x_max_value)
discrete_logcdf = _logcdf_helper(x, x_max_value)
discrete_logcdf_prev = _logcdf_helper(x, x_max_value - 1)
n = x.size
discrete_logprob = pt.log((pt.exp(discrete_logcdf)) ** n - (pt.exp(discrete_logcdf_prev)) ** n)
test_value = value

test_value = 0.85
n = size
exp_rv = np.exp(sp.poisson(mu).logcdf(test_value)) ** n
exp_rv_prev = np.exp(sp.poisson(mu).logcdf(test_value - 1)) ** n

np.testing.assert_allclose(
discrete_logprob.eval({x_max_value: test_value}),
np.log(exp_rv - exp_rv_prev),
(x_max_logprob.eval({x_max_value: test_value})),
rtol=1e-06,
)
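
End to end, the change makes the maximum of i.i.d. discrete RVs measurable through pm.logp. A minimal usage sketch (Poisson(2) with three draws is only an illustrative choice):

import numpy as np
import pymc as pm
import pytensor.tensor as pt

x = pm.Poisson.dist(mu=2, size=3)  # three i.i.d. Poisson(2) variables
x_max = pt.max(x, axis=-1)

value = pt.scalar("value")
logp_max = pm.logp(x_max, value)

# P(max == 1) should equal F(1)**3 - F(0)**3 for the Poisson(2) CDF F
print(np.exp(logp_max.eval({value: 1.0})))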
