diff --git a/pymc_experimental/model/marginal_model.py b/pymc_experimental/model/marginal_model.py
index 89796c0c..1eb23ff2 100644
--- a/pymc_experimental/model/marginal_model.py
+++ b/pymc_experimental/model/marginal_model.py
@@ -2,27 +2,37 @@ from typing import Sequence, Tuple, Union
 
 import numpy as np
+import pymc
 import pytensor.tensor as pt
+from arviz import dict_to_dataset
 from pymc import SymbolicRandomVariable
+from pymc.backends.arviz import coords_and_dims_for_inferencedata
 from pymc.distributions.discrete import Bernoulli, Categorical, DiscreteUniform
 from pymc.distributions.transforms import Chain
 from pymc.logprob.abstract import _logprob
 from pymc.logprob.basic import conditional_logp
 from pymc.logprob.transforms import IntervalTransform
 from pymc.model import Model
-from pymc.pytensorf import constant_fold, inputvars
+from pymc.pytensorf import compile_pymc, constant_fold, inputvars
+from pymc.util import _get_seeds_per_chain, dataset_to_point_list, treedict
 from pytensor import Mode
 from pytensor.compile import SharedVariable
 from pytensor.compile.builders import OpFromGraph
-from pytensor.graph import Constant, FunctionGraph, ancestors, clone_replace
+from pytensor.graph import (
+    Constant,
+    FunctionGraph,
+    ancestors,
+    clone_replace,
+    vectorize_graph,
+)
 from pytensor.scan import map as scan_map
 from pytensor.tensor import TensorVariable
 from pytensor.tensor.elemwise import Elemwise
+from pytensor.tensor.shape import Shape
+from pytensor.tensor.special import log_softmax
 
 __all__ = ["MarginalModel"]
 
-from pytensor.tensor.shape import Shape
-
 
 class MarginalModel(Model):
     """Subclass of PyMC Model that implements functionality for automatic
@@ -74,6 +84,7 @@ class MarginalModel(Model):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.marginalized_rvs = []
+        self._marginalized_named_vars_to_dims = treedict()
 
     def _delete_rv_mappings(self, rv: TensorVariable) -> None:
         """Remove all model mappings referring to rv
@@ -205,8 +216,9 @@ def clone(self):
         vars = self.basic_RVs + self.potentials + self.deterministics + self.marginalized_rvs
         cloned_vars = clone_replace(vars)
         vars_to_clone = {var: cloned_var for var, cloned_var in zip(vars, cloned_vars)}
+        m.vars_to_clone = vars_to_clone
 
-        m.named_vars = {name: vars_to_clone[var] for name, var in self.named_vars.items()}
+        m.named_vars = treedict({name: vars_to_clone[var] for name, var in self.named_vars.items()})
         m.named_vars_to_dims = self.named_vars_to_dims
         m.values_to_rvs = {i: vars_to_clone[rv] for i, rv in self.values_to_rvs.items()}
         m.rvs_to_values = {vars_to_clone[rv]: i for rv, i in self.rvs_to_values.items()}
@@ -220,11 +232,18 @@ def clone(self):
         m.deterministics = [vars_to_clone[det] for det in self.deterministics]
 
         m.marginalized_rvs = [vars_to_clone[rv] for rv in self.marginalized_rvs]
+        m._marginalized_named_vars_to_dims = self._marginalized_named_vars_to_dims
         return m
 
-    def marginalize(self, rvs_to_marginalize: Union[TensorVariable, Sequence[TensorVariable]]):
+    def marginalize(
+        self,
+        rvs_to_marginalize: Union[TensorVariable, str, Sequence[TensorVariable], Sequence[str]],
+    ):
         if not isinstance(rvs_to_marginalize, Sequence):
             rvs_to_marginalize = (rvs_to_marginalize,)
+        rvs_to_marginalize = [
+            self[var] if isinstance(var, str) else var for var in rvs_to_marginalize
+        ]
 
         supported_dists = (Bernoulli, Categorical, DiscreteUniform)
         for rv_to_marginalize in rvs_to_marginalize:
@@ -238,12 +257,233 @@ def marginalize(self, rvs_to_marginalize: Union[TensorVariable, Sequence[TensorV
                     f"Supported distributions include {supported_dists}"
                 )
 
+            if rv_to_marginalize.name in self.named_vars_to_dims:
+                dims = self.named_vars_to_dims[rv_to_marginalize.name]
+                self._marginalized_named_vars_to_dims[rv_to_marginalize.name] = dims
+
             self._delete_rv_mappings(rv_to_marginalize)
             self.marginalized_rvs.append(rv_to_marginalize)
 
         # Raise errors and warnings immediately
         self.clone()._marginalize(user_warnings=True)
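A minimal usage sketch of the string-based `marginalize` API added above (standalone, not part of the diff; the package-root `MarginalModel` import matches the one used in the docstring further down):

```python
import numpy as np
import pymc as pm
from pymc_experimental import MarginalModel

with MarginalModel() as m:
    idx = pm.Bernoulli("idx", p=0.7)
    y = pm.Normal("y", mu=idx, sigma=1.0, observed=np.zeros(3))

# strings are resolved through self[var], so this is
# equivalent to m.marginalize([idx])
m.marginalize(["idx"])
```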
 
+    def _to_transformed(self):
+        "Create a function that maps values from the untransformed space to the transformed space"
+        transformed_rvs = []
+        transformed_names = []
+
+        for rv in self.free_RVs:
+            transform = self.rvs_to_transforms.get(rv)
+            if transform is None:
+                transformed_rvs.append(rv)
+                transformed_names.append(rv.name)
+            else:
+                transformed_rv = transform.forward(rv, *rv.owner.inputs)
+                transformed_rvs.append(transformed_rv)
+                transformed_names.append(self.rvs_to_values[rv].name)
+
+        fn = self.compile_fn(inputs=self.free_RVs, outs=transformed_rvs)
+        return fn, transformed_names
+
+    def unmarginalize(self, rvs_to_unmarginalize):
+        for rv in rvs_to_unmarginalize:
+            self.marginalized_rvs.remove(rv)
+            if rv.name in self._marginalized_named_vars_to_dims:
+                dims = self._marginalized_named_vars_to_dims.pop(rv.name)
+            else:
+                dims = None
+            self.register_rv(rv, name=rv.name, dims=dims)
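`_to_transformed` exists because posterior draws are stored on the constrained (untransformed) scale, while the compiled logp graph used by `recover_marginals` expects value variables on the transformed scale. A rough standalone illustration for the common case of a HalfNormal with PyMC's default log transform (a sketch under that assumption, not the general mechanism):

```python
import numpy as np

# the posterior stores sigma on the positive scale; the corresponding value
# variable "sigma_log__" lives on the log scale, so points must be mapped
# forward before being fed to the logp function
posterior_point = {"sigma": np.array(1.3)}
transformed_point = {"sigma_log__": np.log(posterior_point["sigma"])}
```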
+
+    def recover_marginals(
+        self,
+        idata,
+        var_names=None,
+        return_samples=True,
+        extend_inferencedata=True,
+        random_seed=None,
+    ):
+        """Compute posterior log-probabilities and samples of marginalized variables,
+        conditioned on the parameters of the model, given an InferenceData with a
+        posterior group
+
+        When there are multiple marginalized variables, each recovered variable is
+        conditioned on both the parameters and the other variables that remain
+        marginalized
+
+        All log-probabilities are computed in the transformed space
+
+        Parameters
+        ----------
+        idata : InferenceData
+            InferenceData with posterior group
+        var_names : sequence of str, optional
+            List of variable names for which to compute posterior log-probabilities
+            and samples. Defaults to all marginalized variables
+        return_samples : bool, default True
+            If True, also return samples of the marginalized variables
+        extend_inferencedata : bool, default True
+            Whether to extend the original InferenceData or return a new one
+        random_seed : int, array-like of int or SeedSequence, optional
+            Seed used to generate samples
+
+        Returns
+        -------
+        idata : InferenceData
+            InferenceData where a lp_{varname} (and, if return_samples=True, a
+            {varname}) entry for each marginalized variable in var_names has been
+            added to the posterior group
+
+        .. code-block:: python
+
+            import pymc as pm
+            from pymc_experimental import MarginalModel
+
+            with MarginalModel() as m:
+                p = pm.Beta("p", 1, 1)
+                x = pm.Bernoulli("x", p=p, shape=(3,))
+                y = pm.Normal("y", pm.math.switch(x, -10, 10), observed=[10, 10, -10])
+
+                m.marginalize([x])
+
+                idata = pm.sample()
+                m.recover_marginals(idata, var_names=["x"])
+
+        """
+        if var_names is None:
+            var_names = [var.name for var in self.marginalized_rvs]
+
+        var_names = [var if isinstance(var, str) else var.name for var in var_names]
+        vars_to_recover = [v for v in self.marginalized_rvs if v.name in var_names]
+        missing_names = [
+            name for name in var_names if name not in {v.name for v in vars_to_recover}
+        ]
+        if missing_names:
+            raise ValueError(f"Unrecognized var_names: {missing_names}")
+
+        if return_samples and random_seed is not None:
+            seeds = _get_seeds_per_chain(random_seed, len(vars_to_recover))
+        else:
+            seeds = [None] * len(vars_to_recover)
+
+        posterior = idata.posterior
+
+        # Remove Deterministics
+        posterior_values = posterior[
+            [rv.name for rv in self.free_RVs if rv not in self.marginalized_rvs]
+        ]
+
+        sample_dims = ("chain", "draw")
+        posterior_pts, stacked_dims = dataset_to_point_list(posterior_values, sample_dims)
+
+        # Handle Transforms
+        transform_fn, transform_names = self._to_transformed()
+
+        def transform_input(inputs):
+            return dict(zip(transform_names, transform_fn(inputs)))
+
+        posterior_pts = [transform_input(vs) for vs in posterior_pts]
+
+        rv_dict = {}
+        rv_dims = {}
+        for seed, rv in zip(seeds, vars_to_recover):
+            supported_dists = (Bernoulli, Categorical, DiscreteUniform)
+            if not isinstance(rv.owner.op, supported_dists):
+                raise NotImplementedError(
+                    f"RV with distribution {rv.owner.op} cannot be recovered. "
+                    f"Supported distributions include {supported_dists}"
+                )
+
+            m = self.clone()
+            rv = m.vars_to_clone[rv]
+            m.unmarginalize([rv])
+            dependent_vars = find_conditional_dependent_rvs(rv, m.basic_RVs)
+            joint_logps = m.logp(vars=dependent_vars + [rv], sum=False)
+
+            marginalized_value = m.rvs_to_values[rv]
+            other_values = [v for v in m.value_vars if v is not marginalized_value]
+
+            # Handle batch dims for marginalized value and its dependent RVs
+            joint_logp = joint_logps[-1]
+            for dv in joint_logps[:-1]:
+                dbcast = dv.type.broadcastable
+                mbcast = marginalized_value.type.broadcastable
+                mbcast = (True,) * (len(dbcast) - len(mbcast)) + mbcast
+                values_axis_bcast = [
+                    i for i, (mb, vb) in enumerate(zip(mbcast, dbcast)) if mb and not vb
+                ]
+                joint_logp += dv.sum(values_axis_bcast)
+
+            rv_shape = constant_fold(tuple(rv.shape))
+            rv_domain = get_domain_of_finite_discrete_rv(rv)
+            rv_domain_tensor = pt.moveaxis(
+                pt.full(
+                    (*rv_shape, len(rv_domain)),
+                    rv_domain,
+                    dtype=rv.dtype,
+                ),
+                -1,
+                0,
+            )
+
+            # Evaluate the joint logp at every value in the domain of the marginalized RV
+            joint_logps = vectorize_graph(
+                joint_logp,
+                replace={marginalized_value: rv_domain_tensor},
+            )
+            joint_logps = pt.moveaxis(joint_logps, 0, -1)
+
+            # Normalize over the domain axis to get the conditional log-probabilities
+            joint_logps_norm = log_softmax(joint_logps, axis=-1)
+            if return_samples:
+                sample_rv_outs = pymc.Categorical.dist(logit_p=joint_logps)
+                if isinstance(rv.owner.op, DiscreteUniform):
+                    sample_rv_outs += rv_domain[0]
+
+                rv_loglike_fn = compile_pymc(
+                    inputs=other_values,
+                    outputs=[joint_logps_norm, sample_rv_outs],
+                    on_unused_input="ignore",
+                    random_seed=seed,
+                )
+            else:
+                rv_loglike_fn = compile_pymc(
+                    inputs=other_values,
+                    outputs=joint_logps_norm,
+                    on_unused_input="ignore",
+                    random_seed=seed,
+                )
+
+            logvs = [rv_loglike_fn(**vs) for vs in posterior_pts]
+
+            if return_samples:
+                logps, samples = zip(*logvs)
+                logps = np.array(logps)
+                samples = np.array(samples)
+                rv_dict[rv.name] = samples.reshape(
+                    tuple(len(coord) for coord in stacked_dims.values()) + samples.shape[1:],
+                )
+            else:
+                logps = np.array(logvs)
+
+            rv_dict["lp_" + rv.name] = logps.reshape(
+                tuple(len(coord) for coord in stacked_dims.values()) + logps.shape[1:],
+            )
+            if rv.name in m.named_vars_to_dims:
+                rv_dims[rv.name] = list(m.named_vars_to_dims[rv.name])
+                rv_dims["lp_" + rv.name] = rv_dims[rv.name] + ["lp_" + rv.name + "_dim"]
+
+        coords, dims = coords_and_dims_for_inferencedata(self)
+        dims.update(rv_dims)
+        rv_dataset = dict_to_dataset(
+            rv_dict,
+            library=pymc,
+            dims=dims,
+            coords=coords,
+            default_dims=list(sample_dims),
+            skip_event_dims=True,
+        )
+
+        if extend_inferencedata:
+            idata.posterior = idata.posterior.assign(rv_dataset)
+            return idata
+        else:
+            return rv_dataset
+
 
 class MarginalRV(SymbolicRandomVariable):
     """Base class for Marginalized RVs"""
@@ -444,14 +684,14 @@ def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
 
     # PyMC does not allow RVs in the logp graph, even if we are just using the shape
     marginalized_rv_shape = constant_fold(tuple(marginalized_rv.shape))
     marginalized_rv_domain = get_domain_of_finite_discrete_rv(marginalized_rv)
-    marginalized_rv_domain_tensor = pt.swapaxes(
+    marginalized_rv_domain_tensor = pt.moveaxis(
         pt.full(
             (*marginalized_rv_shape, len(marginalized_rv_domain)),
             marginalized_rv_domain,
             dtype=marginalized_rv.dtype,
         ),
-        axis1=0,
-        axis2=-1,
+        -1,
+        0,
     )
 
     # Arbitrary cutoff to switch to Scan implementation to keep graph size under control
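What `recover_marginals` computes for a single posterior draw, stripped down to NumPy/SciPy (a standalone sketch mirroring `test_recover_marginals_basic` below; the variable names are illustrative): evaluate the joint log-density at each value in the finite domain of the marginalized variable, normalize with a log-softmax over that axis, and sample a categorical from the result.

```python
import numpy as np
from scipy.special import log_softmax
from scipy.stats import norm

p = np.array([0.5, 0.2, 0.3])      # prior over the marginalized k
mu = np.array([-3.0, 0.0, 3.0])    # component means
y_draw, sigma_draw = 0.1, 1.0      # one posterior draw of the remaining parameters

# joint log-density of (k, y) evaluated at every k in the domain
joint_logp = np.log(p) + norm.logpdf(y_draw, loc=mu, scale=sigma_draw)

lp_k = log_softmax(joint_logp)     # log P(k | y, sigma); exponentiates to a pmf
k_draw = np.random.default_rng(0).choice(len(p), p=np.exp(lp_k))
```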
diff --git a/pymc_experimental/tests/model/test_marginal_model.py b/pymc_experimental/tests/model/test_marginal_model.py
index f8ed5718..b667635c 100644
--- a/pymc_experimental/tests/model/test_marginal_model.py
+++ b/pymc_experimental/tests/model/test_marginal_model.py
@@ -6,11 +6,13 @@
 import pymc as pm
 import pytensor.tensor as pt
 import pytest
+from arviz import InferenceData, dict_to_dataset
 from pymc import ImputationWarning, inputvars
 from pymc.distributions import transforms
 from pymc.logprob.abstract import _logprob
 from pymc.util import UNSET
-from scipy.special import logsumexp
+from scipy.special import log_softmax, logsumexp
+from scipy.stats import halfnorm, norm
 
 from pymc_experimental.model.marginal_model import (
     FiniteDiscreteMarginalRV,
@@ -166,6 +168,21 @@ def test_multiple_dependent_marginalized_rvs():
     np.testing.assert_array_almost_equal(logp_x_y, ref_logp_x_y)
 
 
+def test_rv_dependent_multiple_marginalized_rvs():
+    """Test when random variables depend on multiple marginalized variables"""
+    with MarginalModel() as m:
+        x = pm.Bernoulli("x", 0.1)
+        y = pm.Bernoulli("y", 0.3)
+        z = pm.DiracDelta("z", c=x + y)
+
+    m.marginalize([x, y])
+    logp = m.compile_logp()
+
+    np.testing.assert_allclose(np.exp(logp({"z": 0})), 0.9 * 0.7)
+    np.testing.assert_allclose(np.exp(logp({"z": 1})), 0.9 * 0.3 + 0.1 * 0.7)
+    np.testing.assert_allclose(np.exp(logp({"z": 2})), 0.1 * 0.3)
+
+
 @pytest.mark.filterwarnings("error")
 def test_nested_marginalized_rvs():
     """Test that marginalization works when there are nested marginalized RVs"""
@@ -251,6 +268,161 @@ def test_marginalized_change_point_model_sampling(disaster_model):
     )
 
 
+def test_recover_marginals_basic():
+    with MarginalModel() as m:
+        sigma = pm.HalfNormal("sigma")
+        p = np.array([0.5, 0.2, 0.3])
+        k = pm.Categorical("k", p=p)
+        mu = np.array([-3.0, 0.0, 3.0])
+        mu_ = pt.as_tensor_variable(mu)
+        y = pm.Normal("y", mu=mu_[k], sigma=sigma)
+
+    m.marginalize([k])
+
+    rng = np.random.default_rng(211)
+
+    with m:
+        prior = pm.sample_prior_predictive(
+            samples=20,
+            random_seed=rng,
+            return_inferencedata=False,
+        )
+        idata = InferenceData(posterior=dict_to_dataset(prior))
+
+    idata = m.recover_marginals(idata, return_samples=True)
+    post = idata.posterior
+    assert "k" in post
+    assert "lp_k" in post
+    assert post.k.shape == post.y.shape
+    assert post.lp_k.shape == post.k.shape + (len(p),)
+
+    def true_logp(y, sigma):
+        y = y.repeat(len(p)).reshape(len(y), -1)
+        sigma = sigma.repeat(len(p)).reshape(len(sigma), -1)
+        return log_softmax(
+            np.log(p)
+            + norm.logpdf(y, loc=mu, scale=sigma)
+            + halfnorm.logpdf(sigma)
+            + np.log(sigma),
+            axis=1,
+        )
+
+    np.testing.assert_almost_equal(
+        true_logp(post.y.values.flatten(), post.sigma.values.flatten()),
+        post.lp_k[0].values,
+    )
+    np.testing.assert_almost_equal(logsumexp(post.lp_k, axis=-1), 0)
+
+
+def test_recover_marginals_coords():
+    """Test that coords are recovered when the marginalized variable originally had them"""
+    with MarginalModel(coords={"year": [1990, 1991, 1992]}) as m:
+        sigma = pm.HalfNormal("sigma")
+        idx = pm.Bernoulli("idx", p=0.75, dims="year")
+        x = pm.Normal("x", mu=idx, sigma=sigma, dims="year")
+
+    m.marginalize([idx])
+    rng = np.random.default_rng(211)
+
+    with m:
+        prior = pm.sample_prior_predictive(
+            samples=20,
+            random_seed=rng,
+            return_inferencedata=False,
+        )
+        idata = InferenceData(
+            posterior=dict_to_dataset({k: np.expand_dims(prior[k], axis=0) for k in prior})
+        )
+
+    idata = m.recover_marginals(idata, return_samples=True)
+    post = idata.posterior
+    assert post.idx.dims == ("chain", "draw", "year")
+    assert post.lp_idx.dims == ("chain", "draw", "year", "lp_idx_dim")
+
+
+def test_recover_batched_marginal():
+    """Test that marginalization works for batched random variables"""
+    with MarginalModel() as m:
+        sigma = pm.HalfNormal("sigma")
+        idx = pm.Bernoulli("idx", p=0.7, shape=(3, 2))
+        y = pm.Normal("y", mu=idx, sigma=sigma, shape=(3, 2))
+
+    m.marginalize([idx])
+
+    rng = np.random.default_rng(211)
+
+    with m:
+        prior = pm.sample_prior_predictive(
+            samples=20,
+            random_seed=rng,
+            return_inferencedata=False,
+        )
+        idata = InferenceData(
+            posterior=dict_to_dataset({k: np.expand_dims(prior[k], axis=0) for k in prior})
+        )
+
+    idata = m.recover_marginals(idata, return_samples=True)
+    post = idata.posterior
+    assert "idx" in post
+    assert "lp_idx" in post
+    assert post.idx.shape == post.y.shape
+    assert post.lp_idx.shape == post.idx.shape + (2,)
+
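In the nested test that follows, each recovered conditional marginalizes over the other variable that is still marginalized. A standalone check of where its hand-computed mixture terms come from (not part of the diff; `p_idx` and `p_sub` restate the test's priors):

```python
import numpy as np
from scipy.stats import norm

p_idx = np.array([0.25, 0.75])                  # P(idx)
p_sub = np.array([[0.85, 0.15], [0.05, 0.95]])  # P(sub_idx | idx)

def joint(y):
    # joint[i, s] = P(idx=i) * P(sub_idx=s | idx=i) * N(y; i + s, 1)
    i, s = np.meshgrid([0, 1], [0, 1], indexing="ij")
    return p_idx[:, None] * p_sub * norm.pdf(y, loc=i + s)

jy = joint(0.3)
p_idx_post = jy.sum(axis=1) / jy.sum()   # matches exp(lp_idx) at y = 0.3
p_sub_post = jy.sum(axis=0) / jy.sum()   # matches exp(lp_sub_idx)
```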
+def test_nested_recover_marginals():
+    """Test that recovery works when there are nested marginalized RVs"""
+
+    with MarginalModel() as m:
+        idx = pm.Bernoulli("idx", p=0.75)
+        sub_idx = pm.Bernoulli("sub_idx", p=pt.switch(pt.eq(idx, 0), 0.15, 0.95))
+        sub_dep = pm.Normal("y", mu=idx + sub_idx, sigma=1.0)
+
+    m.marginalize([idx, sub_idx])
+
+    rng = np.random.default_rng(211)
+
+    with m:
+        prior = pm.sample_prior_predictive(
+            samples=20,
+            random_seed=rng,
+            return_inferencedata=False,
+        )
+        idata = InferenceData(posterior=dict_to_dataset(prior))
+
+    idata = m.recover_marginals(idata, return_samples=True)
+    post = idata.posterior
+    assert "idx" in post
+    assert "lp_idx" in post
+    assert post.idx.shape == post.y.shape
+    assert post.lp_idx.shape == post.idx.shape + (2,)
+    assert "sub_idx" in post
+    assert "lp_sub_idx" in post
+    assert post.sub_idx.shape == post.y.shape
+    assert post.lp_sub_idx.shape == post.sub_idx.shape + (2,)
+
+    def true_idx_logp(y):
+        idx_0 = np.log(0.85 * 0.25 * norm.pdf(y, loc=0) + 0.15 * 0.25 * norm.pdf(y, loc=1))
+        idx_1 = np.log(0.05 * 0.75 * norm.pdf(y, loc=1) + 0.95 * 0.75 * norm.pdf(y, loc=2))
+        return log_softmax(np.stack([idx_0, idx_1]).T, axis=1)
+
+    np.testing.assert_almost_equal(
+        true_idx_logp(post.y.values.flatten()),
+        post.lp_idx[0].values,
+    )
+
+    def true_sub_idx_logp(y):
+        sub_idx_0 = np.log(0.85 * 0.25 * norm.pdf(y, loc=0) + 0.05 * 0.75 * norm.pdf(y, loc=1))
+        sub_idx_1 = np.log(0.15 * 0.25 * norm.pdf(y, loc=1) + 0.95 * 0.75 * norm.pdf(y, loc=2))
+        return log_softmax(np.stack([sub_idx_0, sub_idx_1]).T, axis=1)
+
+    np.testing.assert_almost_equal(
+        true_sub_idx_logp(post.y.values.flatten()),
+        post.lp_sub_idx[0].values,
+    )
+    np.testing.assert_almost_equal(logsumexp(post.lp_idx, axis=-1), 0)
+    np.testing.assert_almost_equal(logsumexp(post.lp_sub_idx, axis=-1), 0)
+
+
 @pytest.mark.filterwarnings("error")
 def test_not_supported_marginalized():
     """Marginalized graphs with non-Elemwise Operations are not supported as they