diff --git a/pymc_experimental/distributions/timeseries.py b/pymc_experimental/distributions/timeseries.py index 0e865991..d4cd9435 100644 --- a/pymc_experimental/distributions/timeseries.py +++ b/pymc_experimental/distributions/timeseries.py @@ -20,7 +20,12 @@ from pymc.logprob.abstract import _logprob from pymc.logprob.basic import logp from pymc.pytensorf import constant_fold, intX -from pymc.util import check_dist_not_registered +from pymc.step_methods import STEP_METHODS +from pymc.step_methods.arraystep import ArrayStep +from pymc.step_methods.compound import Competence +from pymc.step_methods.metropolis import CategoricalGibbsMetropolis +from pymc.util import check_dist_not_registered, get_value_vars_from_user_vars +from pytensor import Mode from pytensor.graph.basic import Node from pytensor.tensor import TensorVariable from pytensor.tensor.random.op import RandomVariable @@ -101,10 +106,15 @@ class DiscreteMarkovChain(Distribution): Create a Markov Chain of length 100 with 3 states. The number of states is given by the shape of P, 3 in this case. - >>> with pm.Model() as markov_chain: - >>> P = pm.Dirichlet("P", a=[1, 1, 1], size=(3,)) - >>> init_dist = pm.Categorical.dist(p = np.full(3, 1 / 3)) - >>> markov_chain = pm.DiscreteMarkovChain("markov_chain", P=P, init_dist=init_dist, shape=(100,)) + .. code-block:: python + + import pymc as pm + import pymc_experimental as pmx + + with pm.Model() as markov_chain: + P = pm.Dirichlet("P", a=[1, 1, 1], size=(3,)) + init_dist = pm.Categorical.dist(p = np.full(3, 1 / 3)) + markov_chain = pmx.DiscreteMarkovChain("markov_chain", P=P, init_dist=init_dist, shape=(100,)) """ @@ -266,3 +276,69 @@ def discrete_mc_logp(op, values, P, steps, init_dist, state_rng, **kwargs): "P must sum to 1 along the last axis, " "First dimension of init_dist must be n_lags", ) + + +class DiscreteMarkovChainGibbsMetropolis(CategoricalGibbsMetropolis): + name = "discrete_markov_chain_gibbs_metropolis" + + def __init__(self, vars, proposal="uniform", order="random", model=None): + model = pm.modelcontext(model) + vars = get_value_vars_from_user_vars(vars, model) + initial_point = model.initial_point() + + dimcats = [] + # The above variable is a list of pairs (aggregate dimension, number + # of categories). For example, if vars = [x, y] with x being a 2-D + # variable with M categories and y being a 3-D variable with N + # categories, we will have dimcats = [(0, M), (1, M), (2, N), (3, N), (4, N)]. + for v in vars: + v_init_val = initial_point[v.name] + rv_var = model.values_to_rvs[v] + rv_op = rv_var.owner.op + + if not isinstance(rv_op, DiscreteMarkovChainRV): + raise TypeError("All variables must be DiscreteMarkovChainRV") + + k_graph = rv_var.owner.inputs[0].shape[-1] + (k_graph,) = model.replace_rvs_by_values((k_graph,)) + k = model.compile_fn( + k_graph, + inputs=model.value_vars, + on_unused_input="ignore", + mode=Mode(linker="py", optimizer=None), + )(initial_point) + start = len(dimcats) + dimcats += [(dim, k) for dim in range(start, start + v_init_val.size)] + + if order == "random": + self.shuffle_dims = True + self.dimcats = dimcats + else: + if sorted(order) != list(range(len(dimcats))): + raise ValueError("Argument 'order' has to be a permutation") + self.shuffle_dims = False + self.dimcats = [dimcats[j] for j in order] + + if proposal == "uniform": + self.astep = self.astep_unif + elif proposal == "proportional": + # Use the optimized "Metropolized Gibbs Sampler" described in Liu96. + self.astep = self.astep_prop + else: + raise ValueError("Argument 'proposal' should either be 'uniform' or 'proportional'") + + # Doesn't actually tune, but it's required to emit a sampler stat + # that indicates whether a draw was done in a tuning phase. + self.tune = True + + # We bypass CategoryGibbsMetropolis's __init__ to avoid it's specialiazed initialization logic + ArrayStep.__init__(self, vars, [model.compile_logp()]) + + @staticmethod + def competence(var): + if isinstance(var.owner.op, DiscreteMarkovChainRV): + return Competence.IDEAL + return Competence.INCOMPATIBLE + + +STEP_METHODS.append(DiscreteMarkovChainGibbsMetropolis) diff --git a/tests/distributions/test_discrete_markov_chain.py b/tests/distributions/test_discrete_markov_chain.py index b2b1d979..0d855ef4 100644 --- a/tests/distributions/test_discrete_markov_chain.py +++ b/tests/distributions/test_discrete_markov_chain.py @@ -5,10 +5,15 @@ import pytensor.tensor as pt import pytest +from pymc.distributions import Categorical from pymc.distributions.shape_utils import change_dist_size from pymc.logprob.utils import ParameterValueError +from pymc.sampling.mcmc import assign_step_methods -from pymc_experimental.distributions.timeseries import DiscreteMarkovChain +from pymc_experimental.distributions.timeseries import ( + DiscreteMarkovChain, + DiscreteMarkovChainGibbsMetropolis, +) def transition_probability_tests(steps, n_states, n_lags, n_draws, atol): @@ -216,3 +221,36 @@ def test_change_size_univariate(self): new_rw = change_dist_size(chain, new_size=(4, 3), expand=True) assert tuple(new_rw.shape.eval()) == (4, 3, 100, 5) + + def test_mcmc_sampling(self): + with pm.Model(coords={"step": range(100)}) as model: + init_dist = Categorical.dist(p=[0.5, 0.5]) + DiscreteMarkovChain( + "markov_chain", + P=[[0.1, 0.9], [0.1, 0.9]], + init_dist=init_dist, + shape=(100,), + dims="step", + ) + + step_method = assign_step_methods(model) + assert isinstance(step_method, DiscreteMarkovChainGibbsMetropolis) + + # Sampler needs no tuning + idata = pm.sample( + tune=0, chains=4, draws=250, progressbar=False, compute_convergence_checks=False + ) + + np.testing.assert_allclose( + idata.posterior["markov_chain"].isel(step=0).mean(("chain", "draw")), + 0.5, + atol=0.05, + ) + + np.testing.assert_allclose( + idata.posterior["markov_chain"].isel(step=slice(1, None)).mean(("chain", "draw")), + 0.9, + atol=0.05, + ) + + assert pm.stats.ess(idata, method="tail").min() > 950