From be1ee9fe4e97af9c9738be79f192b44b001bad75 Mon Sep 17 00:00:00 2001 From: Dhruvanshu-Joshi Date: Mon, 4 Sep 2023 18:59:27 +0530 Subject: [PATCH] Add support for discrete rvs --- pymc/logprob/order.py | 28 +++++++++++++++------------- tests/logprob/test_order.py | 24 +++++++++++++----------- 2 files changed, 28 insertions(+), 24 deletions(-) diff --git a/pymc/logprob/order.py b/pymc/logprob/order.py index 9b54a535b9c..c3f4372667f 100644 --- a/pymc/logprob/order.py +++ b/pymc/logprob/order.py @@ -45,6 +45,8 @@ from pytensor.tensor.random.op import RandomVariable from pytensor.tensor.var import TensorVariable +import pymc as pm + from pymc.logprob.abstract import ( MeasurableVariable, _logcdf_helper, @@ -63,7 +65,7 @@ class MeasurableMax(Max): class MeasurableMaxDiscrete(Max): - """A placeholder used to specify a log-likelihood for a cmax sub-graph.""" + """A placeholder used to specify a log-likelihood for sub-graphs of maxima of discrete variables""" MeasurableVariable.register(MeasurableMaxDiscrete) @@ -101,14 +103,14 @@ def find_measurable_max(fgraph: FunctionGraph, node: Node) -> Optional[List[Tens if axis != base_var_dims: return None - # logprob for discrete distribution - if isinstance(base_var.owner.op, RandomVariable) and base_var.owner.op.dtype.startswith("int"): - measurable_max = MeasurableMaxDiscrete(list(axis)) - max_rv_node = measurable_max.make_node(base_var) - max_rv = max_rv_node.outputs + # distinguish measurable discrete and continuous (because logprob is different) + if base_var.owner.op.dtype.startswith("int"): + if isinstance(base_var.owner.op, RandomVariable): + measurable_max = MeasurableMaxDiscrete(list(axis)) + max_rv_node = measurable_max.make_node(base_var) + max_rv = max_rv_node.outputs - return max_rv - # logprob for continuous distribution + return max_rv else: measurable_max = MeasurableMax(list(axis)) max_rv_node = measurable_max.make_node(base_var) @@ -145,16 +147,16 @@ def max_logprob_discrete(op, values, base_rv, **kwargs): r"""Compute the log-likelihood graph for the `Max` operation. The formula that we use here is : - \ln(f_{(n)}(x)) = \ln(F(x)^n - F(x-1)^n) - where f(x) represents the p.d.f and F(x) represents the c.d.f of the distrivution respectively. + .. math:: + \ln(P_{(n)}(x)) = \ln(F(x)^n - F(x-1)^n) + where $P_{(n)}(x)$ represents the p.m.f of the maximum statistic and $F(x)$ represents the c.d.f of the i.i.d. variables. """ (value,) = values - logprob = _logprob_helper(base_rv, value) logcdf = _logcdf_helper(base_rv, value) logcdf_prev = _logcdf_helper(base_rv, value - 1) - n = base_rv.size + [n] = constant_fold([base_rv.size]) - logprob = pt.log((pt.exp(logcdf)) ** n - (pt.exp(logcdf_prev)) ** n) + logprob = pm.math.logdiffexp(n * logcdf, n * logcdf_prev) return logprob diff --git a/tests/logprob/test_order.py b/tests/logprob/test_order.py index 1b600205bb2..d41a4afb882 100644 --- a/tests/logprob/test_order.py +++ b/tests/logprob/test_order.py @@ -39,6 +39,7 @@ import numpy as np import pytensor.tensor as pt import pytest +import scipy.stats as sp import pymc as pm @@ -149,23 +150,24 @@ def test_max_logprob(shape, value, axis): ) -def test_max_discrete(): - x = pm.DiscreteUniform.dist(0, 1, size=(3,)) - x.name = "x" - x_max = pt.max(x, axis=-1) +@pytest.mark.parametrize( + "mu, size, value, axis", + [(2, 3, 0.85, -1), (2, 3, 0.01, 0), (1, 2, 0.2, None), (0, 4, 0, 0)], +) +def test_max_discrete(mu, size, value, axis): + x = pm.Poisson.dist(name="x", mu=mu, size=(size)) + x_max = pt.max(x, axis=axis) x_max_value = pt.scalar("x_max_value") x_max_logprob = logp(x_max, x_max_value) - discrete_logprob = _logprob_helper(x, x_max_value) - discrete_logcdf = _logcdf_helper(x, x_max_value) - discrete_logcdf_prev = _logcdf_helper(x, x_max_value - 1) - n = x.size - discrete_logprob = pt.log((pt.exp(discrete_logcdf)) ** n - (pt.exp(discrete_logcdf_prev)) ** n) + test_value = value - test_value = 0.85 + n = size + exp_rv = np.exp(sp.poisson(mu).logcdf(test_value)) ** n + exp_rv_prev = np.exp(sp.poisson(mu).logcdf(test_value - 1)) ** n np.testing.assert_allclose( - discrete_logprob.eval({x_max_value: test_value}), + np.log(exp_rv - exp_rv_prev), (x_max_logprob.eval({x_max_value: test_value})), rtol=1e-06, )