Skip to content

Commit

Permalink
Implement utility to recover marginalized variables from `MarginalModel` (#285)

Browse files Browse the repository at this point in the history

Adding recover_marginals utility function
  • Loading branch information
zaxtax authored Dec 25, 2023
1 parent 99f30aa commit 4f75687
Show file tree
Hide file tree
Showing 2 changed files with 422 additions and 10 deletions.
258 changes: 249 additions & 9 deletions pymc_experimental/model/marginal_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,27 +2,37 @@
from typing import Sequence, Tuple, Union

import numpy as np
import pymc
import pytensor.tensor as pt
from arviz import dict_to_dataset
from pymc import SymbolicRandomVariable
from pymc.backends.arviz import coords_and_dims_for_inferencedata
from pymc.distributions.discrete import Bernoulli, Categorical, DiscreteUniform
from pymc.distributions.transforms import Chain
from pymc.logprob.abstract import _logprob
from pymc.logprob.basic import conditional_logp
from pymc.logprob.transforms import IntervalTransform
from pymc.model import Model
from pymc.pytensorf import constant_fold, inputvars
from pymc.pytensorf import compile_pymc, constant_fold, inputvars
from pymc.util import _get_seeds_per_chain, dataset_to_point_list, treedict
from pytensor import Mode
from pytensor.compile import SharedVariable
from pytensor.compile.builders import OpFromGraph
from pytensor.graph import Constant, FunctionGraph, ancestors, clone_replace
from pytensor.graph import (
Constant,
FunctionGraph,
ancestors,
clone_replace,
vectorize_graph,
)
from pytensor.scan import map as scan_map
from pytensor.tensor import TensorVariable
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.shape import Shape
from pytensor.tensor.special import log_softmax

__all__ = ["MarginalModel"]

from pytensor.tensor.shape import Shape


class MarginalModel(Model):
"""Subclass of PyMC Model that implements functionality for automatic
Expand Down Expand Up @@ -74,6 +84,7 @@ class MarginalModel(Model):
def __init__(self, *args, **kwargs):
    """Initialize the model and set up bookkeeping for marginalized RVs."""
    super().__init__(*args, **kwargs)
    # RVs that have been removed from the model's free RVs by `marginalize`
    self.marginalized_rvs = []
    # dims recorded for marginalized RVs so `unmarginalize` can restore them
    self._marginalized_named_vars_to_dims = treedict()

def _delete_rv_mappings(self, rv: TensorVariable) -> None:
"""Remove all model mappings referring to rv
Expand Down Expand Up @@ -205,8 +216,9 @@ def clone(self):
vars = self.basic_RVs + self.potentials + self.deterministics + self.marginalized_rvs
cloned_vars = clone_replace(vars)
vars_to_clone = {var: cloned_var for var, cloned_var in zip(vars, cloned_vars)}
m.vars_to_clone = vars_to_clone

m.named_vars = {name: vars_to_clone[var] for name, var in self.named_vars.items()}
m.named_vars = treedict({name: vars_to_clone[var] for name, var in self.named_vars.items()})
m.named_vars_to_dims = self.named_vars_to_dims
m.values_to_rvs = {i: vars_to_clone[rv] for i, rv in self.values_to_rvs.items()}
m.rvs_to_values = {vars_to_clone[rv]: i for rv, i in self.rvs_to_values.items()}
Expand All @@ -220,11 +232,18 @@ def clone(self):
m.deterministics = [vars_to_clone[det] for det in self.deterministics]

m.marginalized_rvs = [vars_to_clone[rv] for rv in self.marginalized_rvs]
m._marginalized_named_vars_to_dims = self._marginalized_named_vars_to_dims
return m

def marginalize(self, rvs_to_marginalize: Union[TensorVariable, Sequence[TensorVariable]]):
def marginalize(
self,
rvs_to_marginalize: Union[TensorVariable, str, Sequence[TensorVariable], Sequence[str]],
):
if not isinstance(rvs_to_marginalize, Sequence):
rvs_to_marginalize = (rvs_to_marginalize,)
rvs_to_marginalize = [
self[var] if isinstance(var, str) else var for var in rvs_to_marginalize
]

supported_dists = (Bernoulli, Categorical, DiscreteUniform)
for rv_to_marginalize in rvs_to_marginalize:
Expand All @@ -238,12 +257,233 @@ def marginalize(self, rvs_to_marginalize: Union[TensorVariable, Sequence[TensorV
f"Supported distribution include {supported_dists}"
)

if rv_to_marginalize.name in self.named_vars_to_dims:
dims = self.named_vars_to_dims[rv_to_marginalize.name]
self._marginalized_named_vars_to_dims[rv_to_marginalize.name] = dims

self._delete_rv_mappings(rv_to_marginalize)
self.marginalized_rvs.append(rv_to_marginalize)

# Raise errors and warnings immediately
self.clone()._marginalize(user_warnings=True)

def _to_transformed(self):
    """Build a compiled function mapping the model's free RVs from the
    untransformed space to the transformed space.

    Returns
    -------
    fn
        Compiled function taking the free RVs and returning their
        transformed counterparts.
    names : list of str
        Name of each output: the RV name when no transform applies,
        otherwise the name of the RV's (transformed) value variable.
    """
    outputs = []
    names = []
    for free_rv in self.free_RVs:
        tr = self.rvs_to_transforms.get(free_rv)
        if tr is None:
            # No transform registered: pass the RV through unchanged.
            outputs.append(free_rv)
            names.append(free_rv.name)
        else:
            outputs.append(tr.forward(free_rv, *free_rv.owner.inputs))
            names.append(self.rvs_to_values[free_rv].name)
    return self.compile_fn(inputs=self.free_RVs, outs=outputs), names

def unmarginalize(self, rvs_to_unmarginalize):
for rv in rvs_to_unmarginalize:
self.marginalized_rvs.remove(rv)
if rv.name in self._marginalized_named_vars_to_dims:
dims = self._marginalized_named_vars_to_dims.pop(rv.name)
else:
dims = None
self.register_rv(rv, name=rv.name, dims=dims)

def recover_marginals(
    self,
    idata,
    var_names=None,
    return_samples=True,
    extend_inferencedata=True,
    random_seed=None,
):
    """Compute posterior log-probabilities and samples of marginalized variables
    conditioned on parameters of the model given InferenceData with posterior group.

    When there are multiple marginalized variables, each marginalized variable is
    conditioned on both the parameters and the other variables still marginalized.

    All log-probabilities are within the transformed space.

    Parameters
    ----------
    idata : InferenceData
        InferenceData with posterior group
    var_names : sequence of str, optional
        List of variable names for which to compute posterior log-probabilities
        and samples. Defaults to all marginalized variables
    return_samples : bool, default True
        If True, also return samples of the marginalized variables
    extend_inferencedata : bool, default True
        Whether to extend the original InferenceData or return a new one
    random_seed : int, array-like of int or SeedSequence, optional
        Seed used to generate samples

    Returns
    -------
    idata : InferenceData
        InferenceData where a lp_{varname} and {varname} for each marginalized
        variable in var_names is added to the posterior group

    Raises
    ------
    ValueError
        If any entry of var_names does not name a marginalized variable.
    NotImplementedError
        If a requested variable's distribution is not supported.

    Examples
    --------
    .. code-block:: python

        import pymc as pm
        from pymc_experimental import MarginalModel

        with MarginalModel() as m:
            p = pm.Beta("p", 1, 1)
            x = pm.Bernoulli("x", p=p, shape=(3,))
            y = pm.Normal("y", pm.math.switch(x, -10, 10), observed=[10, 10, -10])

            m.marginalize([x])
            idata = pm.sample()
            m.recover_marginals(idata, var_names=["x"])
    """
    if var_names is None:
        var_names = [var.name for var in self.marginalized_rvs]

    var_names = [var if isinstance(var, str) else var.name for var in var_names]
    vars_to_recover = [v for v in self.marginalized_rvs if v.name in var_names]
    # BUG FIX: the original filtered `vars_to_recover` against
    # `self.marginalized_rvs`, but `vars_to_recover` is by construction a
    # subset of it, so the check was always empty and unknown names were
    # silently ignored. Compare the requested names against the recovered
    # names instead.
    recovered_names = {v.name for v in vars_to_recover}
    missing_names = [name for name in var_names if name not in recovered_names]
    if missing_names:
        raise ValueError(f"Unrecognized var_names: {missing_names}")

    if return_samples and random_seed is not None:
        seeds = _get_seeds_per_chain(random_seed, len(vars_to_recover))
    else:
        seeds = [None] * len(vars_to_recover)

    posterior = idata.posterior

    # Remove Deterministics
    posterior_values = posterior[
        [rv.name for rv in self.free_RVs if rv not in self.marginalized_rvs]
    ]

    sample_dims = ("chain", "draw")
    posterior_pts, stacked_dims = dataset_to_point_list(posterior_values, sample_dims)

    # Handle Transforms: the logp graphs below expect value variables in the
    # transformed space, but the posterior holds untransformed draws.
    transform_fn, transform_names = self._to_transformed()

    def transform_input(inputs):
        return dict(zip(transform_names, transform_fn(inputs)))

    posterior_pts = [transform_input(vs) for vs in posterior_pts]

    rv_dict = {}
    rv_dims = {}
    for seed, rv in zip(seeds, vars_to_recover):
        supported_dists = (Bernoulli, Categorical, DiscreteUniform)
        if not isinstance(rv.owner.op, supported_dists):
            raise NotImplementedError(
                f"RV with distribution {rv.owner.op} cannot be recovered. "
                f"Supported distribution include {supported_dists}"
            )

        # Work on a clone so unmarginalizing doesn't mutate `self`
        m = self.clone()
        rv = m.vars_to_clone[rv]
        m.unmarginalize([rv])
        dependent_vars = find_conditional_dependent_rvs(rv, m.basic_RVs)
        joint_logps = m.logp(vars=dependent_vars + [rv], sum=False)

        marginalized_value = m.rvs_to_values[rv]
        other_values = [v for v in m.value_vars if v is not marginalized_value]

        # Handle batch dims for marginalized value and its dependent RVs:
        # sum each dependent logp over axes where the marginalized value
        # broadcasts but the dependent value does not.
        joint_logp = joint_logps[-1]
        for dv in joint_logps[:-1]:
            dbcast = dv.type.broadcastable
            mbcast = marginalized_value.type.broadcastable
            mbcast = (True,) * (len(dbcast) - len(mbcast)) + mbcast
            values_axis_bcast = [
                # renamed from (m, v): the original shadowed the model `m`
                i
                for i, (m_bc, v_bc) in enumerate(zip(mbcast, dbcast))
                if m_bc and not v_bc
            ]
            joint_logp += dv.sum(values_axis_bcast)

        rv_shape = constant_fold(tuple(rv.shape))
        rv_domain = get_domain_of_finite_discrete_rv(rv)
        # Tensor enumerating every value in the RV's finite domain, with the
        # domain axis first so it can be vectorized over.
        rv_domain_tensor = pt.moveaxis(
            pt.full(
                (*rv_shape, len(rv_domain)),
                rv_domain,
                dtype=rv.dtype,
            ),
            -1,
            0,
        )

        # Evaluate the joint logp at every domain value at once
        joint_logps = vectorize_graph(
            joint_logp,
            replace={marginalized_value: rv_domain_tensor},
        )
        joint_logps = pt.moveaxis(joint_logps, 0, -1)

        # Normalize over the domain axis to get conditional log-probabilities
        joint_logps_norm = log_softmax(joint_logps, axis=-1)
        if return_samples:
            sample_rv_outs = pymc.Categorical.dist(logit_p=joint_logps)
            if isinstance(rv.owner.op, DiscreteUniform):
                # Categorical samples start at 0; shift to the RV's lower bound
                sample_rv_outs += rv_domain[0]

            rv_loglike_fn = compile_pymc(
                inputs=other_values,
                outputs=[joint_logps_norm, sample_rv_outs],
                on_unused_input="ignore",
                random_seed=seed,
            )
        else:
            rv_loglike_fn = compile_pymc(
                inputs=other_values,
                outputs=joint_logps_norm,
                on_unused_input="ignore",
                random_seed=seed,
            )

        logvs = [rv_loglike_fn(**vs) for vs in posterior_pts]

        if return_samples:
            logps, samples = zip(*logvs)
            logps = np.array(logps)
            samples = np.array(samples)
            # Restore the (chain, draw, ...) leading dims flattened by
            # dataset_to_point_list
            rv_dict[rv.name] = samples.reshape(
                tuple(len(coord) for coord in stacked_dims.values()) + samples.shape[1:],
            )
        else:
            logps = np.array(logvs)

        rv_dict["lp_" + rv.name] = logps.reshape(
            tuple(len(coord) for coord in stacked_dims.values()) + logps.shape[1:],
        )
        if rv.name in m.named_vars_to_dims:
            rv_dims[rv.name] = list(m.named_vars_to_dims[rv.name])
            rv_dims["lp_" + rv.name] = rv_dims[rv.name] + ["lp_" + rv.name + "_dim"]

    coords, dims = coords_and_dims_for_inferencedata(self)
    dims.update(rv_dims)
    rv_dataset = dict_to_dataset(
        rv_dict,
        library=pymc,
        dims=dims,
        coords=coords,
        default_dims=list(sample_dims),
        skip_event_dims=True,
    )

    if extend_inferencedata:
        idata.posterior = idata.posterior.assign(rv_dataset)
        return idata
    else:
        return rv_dataset


class MarginalRV(SymbolicRandomVariable):
"""Base class for Marginalized RVs"""
Expand Down Expand Up @@ -444,14 +684,14 @@ def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
# PyMC does not allow RVs in the logp graph, even if we are just using the shape
marginalized_rv_shape = constant_fold(tuple(marginalized_rv.shape))
marginalized_rv_domain = get_domain_of_finite_discrete_rv(marginalized_rv)
marginalized_rv_domain_tensor = pt.swapaxes(
marginalized_rv_domain_tensor = pt.moveaxis(
pt.full(
(*marginalized_rv_shape, len(marginalized_rv_domain)),
marginalized_rv_domain,
dtype=marginalized_rv.dtype,
),
axis1=0,
axis2=-1,
-1,
0,
)

# Arbitrary cutoff to switch to Scan implementation to keep graph size under control
Expand Down
Loading

0 comments on commit 4f75687

Please sign in to comment.