Skip to content

Logprob derivation of Max for Discrete IID distributions #6790

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Oct 24, 2023
39 changes: 34 additions & 5 deletions pymc/logprob/order.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
_logprob_helper,
)
from pymc.logprob.rewriting import measurable_ir_rewrites_db
from pymc.math import logdiffexp
from pymc.pytensorf import constant_fold


Expand All @@ -66,6 +67,13 @@ class MeasurableMax(Max):
MeasurableVariable.register(MeasurableMax)


class MeasurableMaxDiscrete(Max):
    """A placeholder used to specify a log-likelihood for sub-graphs of maxima of discrete variables.

    Kept separate from ``MeasurableMax`` because the log-probability of the
    maximum of discrete i.i.d. variables uses a pmf-style difference of CDFs,
    ``log(F(x)^n - F(x-1)^n)``, rather than the continuous density formula.
    """


# Register so the rewriting machinery treats this Op as a measurable variable
# with its own `_logprob` dispatch (see `max_logprob_discrete`).
MeasurableVariable.register(MeasurableMaxDiscrete)


@node_rewriter([Max])
def find_measurable_max(fgraph: FunctionGraph, node: Node) -> Optional[List[TensorVariable]]:
rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)
Expand All @@ -87,10 +95,6 @@ def find_measurable_max(fgraph: FunctionGraph, node: Node) -> Optional[List[Tens
if not (isinstance(base_var.owner.op, RandomVariable) and base_var.owner.op.ndim_supp == 0):
return None

# TODO: We are currently only supporting continuous rvs
if isinstance(base_var.owner.op, RandomVariable) and base_var.owner.op.dtype.startswith("int"):
return None

# univariate i.i.d. test which also rules out other distributions
for params in base_var.owner.inputs[3:]:
if params.type.ndim != 0:
Expand All @@ -102,7 +106,12 @@ def find_measurable_max(fgraph: FunctionGraph, node: Node) -> Optional[List[Tens
if axis != base_var_dims:
return None

measurable_max = MeasurableMax(list(axis))
# distinguish measurable discrete and continuous (because logprob is different)
if base_var.owner.op.dtype.startswith("int"):
measurable_max = MeasurableMaxDiscrete(list(axis))
else:
measurable_max = MeasurableMax(list(axis))

max_rv_node = measurable_max.make_node(base_var)
max_rv = max_rv_node.outputs

Expand Down Expand Up @@ -131,6 +140,26 @@ def max_logprob(op, values, base_rv, **kwargs):
return logprob


@_logprob.register(MeasurableMaxDiscrete)
def max_logprob_discrete(op, values, base_rv, **kwargs):
    r"""Compute the log-likelihood graph for the `Max` operation on discrete i.i.d. variables.

    The formula used here is:
    .. math::
        \ln(P_{(n)}(x)) = \ln(F(x)^n - F(x-1)^n)
    where $P_{(n)}(x)$ is the p.m.f. of the maximum statistic and $F(x)$ is
    the c.d.f. of the i.i.d. variables.
    """
    (value,) = values

    # Log-CDF of a single draw, evaluated at the value and one step below it.
    logcdf_at_value = _logcdf_helper(base_rv, value)
    logcdf_below = _logcdf_helper(base_rv, value - 1)

    # Number of i.i.d. draws the maximum is taken over; the rewrite that
    # builds this Op guarantees the max ranges over every dimension.
    (n,) = constant_fold([base_rv.size])

    # log(F(x)^n - F(x-1)^n), computed stably in log-space.
    return logdiffexp(n * logcdf_at_value, n * logcdf_below)


class MeasurableMaxNeg(Max):
"""A placeholder used to specify a log-likelihood for a max(neg(x)) sub-graph.
This shows up in the graph of min, which is (neg(max(neg(x)))."""
Expand Down
24 changes: 24 additions & 0 deletions tests/logprob/test_order.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import numpy as np
import pytensor.tensor as pt
import pytest
import scipy.stats as sp

import pymc as pm

Expand Down Expand Up @@ -230,3 +231,26 @@ def test_min_non_mul_elemwise_fails():
x_min_value = pt.vector("x_min_value")
with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")):
x_min_logprob = logp(x_min, x_min_value)


@pytest.mark.parametrize(
    "mu, size, value, axis",
    [(2, 3, 1, -1), (2, 3, 1, 0), (1, 2, 2, None), (0, 4, 0, 0)],
)
def test_max_discrete(mu, size, value, axis):
    # Build the logprob graph for the max of `size` i.i.d. Poisson draws.
    rv = pm.Poisson.dist(name="x", mu=mu, size=size)
    rv_max = pt.max(rv, axis=axis)
    rv_max_value = pt.scalar("x_max_value")
    rv_max_logprob = logp(rv_max, rv_max_value)

    # Reference value from the order-statistic identity:
    # P(max = v) = F(v)^n - F(v-1)^n
    cdf_at_value = sp.poisson(mu).cdf(value) ** size
    cdf_below = sp.poisson(mu).cdf(value - 1) ** size
    expected_logprob = np.log(cdf_at_value - cdf_below)

    np.testing.assert_allclose(
        expected_logprob,
        rv_max_logprob.eval({rv_max_value: value}),
        rtol=1e-06,
    )