Skip to content

Commit

Permalink
Implement step method sampler for DiscreteMarkovChain
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Sep 15, 2024
1 parent c0c4931 commit f61e161
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 6 deletions.
86 changes: 81 additions & 5 deletions pymc_experimental/distributions/timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,12 @@
from pymc.logprob.abstract import _logprob
from pymc.logprob.basic import logp
from pymc.pytensorf import constant_fold, intX
from pymc.util import check_dist_not_registered
from pymc.step_methods import STEP_METHODS
from pymc.step_methods.arraystep import ArrayStep
from pymc.step_methods.compound import Competence
from pymc.step_methods.metropolis import CategoricalGibbsMetropolis
from pymc.util import check_dist_not_registered, get_value_vars_from_user_vars
from pytensor import Mode
from pytensor.graph.basic import Node
from pytensor.tensor import TensorVariable
from pytensor.tensor.random.op import RandomVariable
Expand Down Expand Up @@ -101,10 +106,15 @@ class DiscreteMarkovChain(Distribution):
Create a Markov Chain of length 100 with 3 states. The number of states is given by the shape of P,
3 in this case.
>>> with pm.Model() as markov_chain:
>>> P = pm.Dirichlet("P", a=[1, 1, 1], size=(3,))
>>> init_dist = pm.Categorical.dist(p = np.full(3, 1 / 3))
>>> markov_chain = pm.DiscreteMarkovChain("markov_chain", P=P, init_dist=init_dist, shape=(100,))
.. code-block:: python
import pymc as pm
import pymc_experimental as pmx
with pm.Model() as markov_chain:
P = pm.Dirichlet("P", a=[1, 1, 1], size=(3,))
init_dist = pm.Categorical.dist(p = np.full(3, 1 / 3))
markov_chain = pmx.DiscreteMarkovChain("markov_chain", P=P, init_dist=init_dist, shape=(100,))
"""

Expand Down Expand Up @@ -266,3 +276,69 @@ def discrete_mc_logp(op, values, P, steps, init_dist, state_rng, **kwargs):
"P must sum to 1 along the last axis, "
"First dimension of init_dist must be n_lags",
)


class DiscreteMarkovChainGibbsMetropolis(CategoricalGibbsMetropolis):
name = "discrete_markov_chain_gibbs_metropolis"

def __init__(self, vars, proposal="uniform", order="random", model=None):
model = pm.modelcontext(model)
vars = get_value_vars_from_user_vars(vars, model)
initial_point = model.initial_point()

dimcats = []
# The above variable is a list of pairs (aggregate dimension, number
# of categories). For example, if vars = [x, y] with x being a 2-D
# variable with M categories and y being a 3-D variable with N
# categories, we will have dimcats = [(0, M), (1, M), (2, N), (3, N), (4, N)].
for v in vars:
v_init_val = initial_point[v.name]
rv_var = model.values_to_rvs[v]
rv_op = rv_var.owner.op

if not isinstance(rv_op, DiscreteMarkovChainRV):
raise TypeError("All variables must be DiscreteMarkovChainRV")

k_graph = rv_var.owner.inputs[0].shape[-1]
(k_graph,) = model.replace_rvs_by_values((k_graph,))
k = model.compile_fn(
k_graph,
inputs=model.value_vars,
on_unused_input="ignore",
mode=Mode(linker="py", optimizer=None),
)(initial_point)
start = len(dimcats)
dimcats += [(dim, k) for dim in range(start, start + v_init_val.size)]

if order == "random":
self.shuffle_dims = True
self.dimcats = dimcats
else:
if sorted(order) != list(range(len(dimcats))):
raise ValueError("Argument 'order' has to be a permutation")
self.shuffle_dims = False
self.dimcats = [dimcats[j] for j in order]

if proposal == "uniform":
self.astep = self.astep_unif
elif proposal == "proportional":
# Use the optimized "Metropolized Gibbs Sampler" described in Liu96.
self.astep = self.astep_prop
else:
raise ValueError("Argument 'proposal' should either be 'uniform' or 'proportional'")

# Doesn't actually tune, but it's required to emit a sampler stat
# that indicates whether a draw was done in a tuning phase.
self.tune = True

# We bypass CategoryGibbsMetropolis's __init__ to avoid it's specialiazed initialization logic
ArrayStep.__init__(self, vars, [model.compile_logp()])

@staticmethod
def competence(var):
if isinstance(var.owner.op, DiscreteMarkovChainRV):
return Competence.IDEAL
return Competence.INCOMPATIBLE


STEP_METHODS.append(DiscreteMarkovChainGibbsMetropolis)
40 changes: 39 additions & 1 deletion tests/distributions/test_discrete_markov_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,15 @@
import pytensor.tensor as pt
import pytest

from pymc.distributions import Categorical
from pymc.distributions.shape_utils import change_dist_size
from pymc.logprob.utils import ParameterValueError
from pymc.sampling.mcmc import assign_step_methods

from pymc_experimental.distributions.timeseries import DiscreteMarkovChain
from pymc_experimental.distributions.timeseries import (
DiscreteMarkovChain,
DiscreteMarkovChainGibbsMetropolis,
)


def transition_probability_tests(steps, n_states, n_lags, n_draws, atol):
Expand Down Expand Up @@ -216,3 +221,36 @@ def test_change_size_univariate(self):

new_rw = change_dist_size(chain, new_size=(4, 3), expand=True)
assert tuple(new_rw.shape.eval()) == (4, 3, 100, 5)

def test_mcmc_sampling(self):
with pm.Model(coords={"step": range(100)}) as model:
init_dist = Categorical.dist(p=[0.5, 0.5])
DiscreteMarkovChain(
"markov_chain",
P=[[0.1, 0.9], [0.1, 0.9]],
init_dist=init_dist,
shape=(100,),
dims="step",
)

step_method = assign_step_methods(model)
assert isinstance(step_method, DiscreteMarkovChainGibbsMetropolis)

# Sampler needs no tuning
idata = pm.sample(
tune=0, chains=4, draws=250, progressbar=False, compute_convergence_checks=False
)

np.testing.assert_allclose(
idata.posterior["markov_chain"].isel(step=0).mean(("chain", "draw")),
0.5,
atol=0.05,
)

np.testing.assert_allclose(
idata.posterior["markov_chain"].isel(step=slice(1, None)).mean(("chain", "draw")),
0.9,
atol=0.05,
)

assert pm.stats.ess(idata, method="tail").min() > 950

0 comments on commit f61e161

Please sign in to comment.