Implement utility to recover marginalized variables from MarginalModel #285

Merged: 6 commits, Dec 25, 2023
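
This PR adds a `recover_marginals` utility to `MarginalModel` that computes posterior log-probabilities of previously marginalized finite discrete variables, and optionally draws posterior samples of them, from the posterior group of an existing `InferenceData`. A minimal usage sketch, reusing the names `m`, `x`, and `idata` from the docstring example further down:

```python
# After m.marginalize([x]) and idata = pm.sample():
idata = m.recover_marginals(idata, var_names=["x"], return_samples=True)
idata.posterior["x"]     # posterior draws of the recovered variable
idata.posterior["lp_x"]  # normalized log-probabilities over its domain
```
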
Changes from all commits
pymc_experimental/model/marginal_model.py (249 additions, 9 deletions)
@@ -2,27 +2,37 @@
from typing import Sequence, Tuple, Union

import numpy as np
import pymc
import pytensor.tensor as pt
from arviz import dict_to_dataset
from pymc import SymbolicRandomVariable
from pymc.backends.arviz import coords_and_dims_for_inferencedata
from pymc.distributions.discrete import Bernoulli, Categorical, DiscreteUniform
from pymc.distributions.transforms import Chain
from pymc.logprob.abstract import _logprob
from pymc.logprob.basic import conditional_logp
from pymc.logprob.transforms import IntervalTransform
from pymc.model import Model
-from pymc.pytensorf import constant_fold, inputvars
+from pymc.pytensorf import compile_pymc, constant_fold, inputvars
from pymc.util import _get_seeds_per_chain, dataset_to_point_list, treedict
from pytensor import Mode
from pytensor.compile import SharedVariable
from pytensor.compile.builders import OpFromGraph
-from pytensor.graph import Constant, FunctionGraph, ancestors, clone_replace
+from pytensor.graph import (
+    Constant,
+    FunctionGraph,
+    ancestors,
+    clone_replace,
+    vectorize_graph,
+)
from pytensor.scan import map as scan_map
from pytensor.tensor import TensorVariable
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.shape import Shape
from pytensor.tensor.special import log_softmax

__all__ = ["MarginalModel"]

-from pytensor.tensor.shape import Shape


class MarginalModel(Model):
"""Subclass of PyMC Model that implements functionality for automatic
@@ -74,6 +84,7 @@ class MarginalModel(Model):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.marginalized_rvs = []
self._marginalized_named_vars_to_dims = treedict()

def _delete_rv_mappings(self, rv: TensorVariable) -> None:
"""Remove all model mappings referring to rv
@@ -205,8 +216,9 @@ def clone(self):
vars = self.basic_RVs + self.potentials + self.deterministics + self.marginalized_rvs
cloned_vars = clone_replace(vars)
vars_to_clone = {var: cloned_var for var, cloned_var in zip(vars, cloned_vars)}
m.vars_to_clone = vars_to_clone

-m.named_vars = {name: vars_to_clone[var] for name, var in self.named_vars.items()}
+m.named_vars = treedict({name: vars_to_clone[var] for name, var in self.named_vars.items()})
m.named_vars_to_dims = self.named_vars_to_dims
m.values_to_rvs = {i: vars_to_clone[rv] for i, rv in self.values_to_rvs.items()}
m.rvs_to_values = {vars_to_clone[rv]: i for rv, i in self.rvs_to_values.items()}
@@ -220,11 +232,18 @@ def clone(self):
m.deterministics = [vars_to_clone[det] for det in self.deterministics]

m.marginalized_rvs = [vars_to_clone[rv] for rv in self.marginalized_rvs]
m._marginalized_named_vars_to_dims = self._marginalized_named_vars_to_dims
return m

-def marginalize(self, rvs_to_marginalize: Union[TensorVariable, Sequence[TensorVariable]]):
+def marginalize(
+    self,
+    rvs_to_marginalize: Union[TensorVariable, str, Sequence[TensorVariable], Sequence[str]],
+):
if not isinstance(rvs_to_marginalize, Sequence):
rvs_to_marginalize = (rvs_to_marginalize,)
rvs_to_marginalize = [
self[var] if isinstance(var, str) else var for var in rvs_to_marginalize
]

supported_dists = (Bernoulli, Categorical, DiscreteUniform)
for rv_to_marginalize in rvs_to_marginalize:
@@ -238,12 +257,233 @@ def marginalize(self, rvs_to_marginalize: Union[TensorVariable, Sequence[TensorVariable]]):
f"Supported distribution include {supported_dists}"
)

if rv_to_marginalize.name in self.named_vars_to_dims:
dims = self.named_vars_to_dims[rv_to_marginalize.name]
self._marginalized_named_vars_to_dims[rv_to_marginalize.name] = dims

self._delete_rv_mappings(rv_to_marginalize)
self.marginalized_rvs.append(rv_to_marginalize)

# Raise errors and warnings immediately
self.clone()._marginalize(user_warnings=True)

def _to_transformed(self):
"Create a function from the untransformed space to the transformed space"
transformed_rvs = []
transformed_names = []

for rv in self.free_RVs:
transform = self.rvs_to_transforms.get(rv)
if transform is None:
transformed_rvs.append(rv)
transformed_names.append(rv.name)
else:
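# Forward-transform the draw and store it under the name of the
# transformed value variable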
transformed_rv = transform.forward(rv, *rv.owner.inputs)
transformed_rvs.append(transformed_rv)
transformed_names.append(self.rvs_to_values[rv].name)

fn = self.compile_fn(inputs=self.free_RVs, outs=transformed_rvs)
return fn, transformed_names

def unmarginalize(self, rvs_to_unmarginalize):
for rv in rvs_to_unmarginalize:
self.marginalized_rvs.remove(rv)
if rv.name in self._marginalized_named_vars_to_dims:
dims = self._marginalized_named_vars_to_dims.pop(rv.name)
else:
dims = None
self.register_rv(rv, name=rv.name, dims=dims)

def recover_marginals(
self,
idata,
var_names=None,
return_samples=True,
extend_inferencedata=True,
random_seed=None,
):
"""Computes posterior log-probabilities and samples of marginalized variables
conditioned on parameters of the model given InferenceData with posterior group

When there are multiple marginalized variables, each marginalized variable is
conditioned on both the parameters and the other variables still marginalized

All log-probabilities are within the transformed space

Parameters
----------
idata : InferenceData
InferenceData with posterior group
var_names : sequence of str, optional
List of variable names for which to compute posterior log-probabilities and samples. Defaults to all marginalized variables
return_samples : bool, default True
If True, also return samples of the marginalized variables
extend_inferencedata : bool, default True
Whether to extend the original InferenceData or return a new one
random_seed : int, array-like of int or SeedSequence, optional
Seed used when generating samples

Returns
-------
idata : InferenceData
InferenceData where, for each marginalized variable in var_names, entries {varname} and lp_{varname} have been added to the posterior group

.. code-block:: python

import pymc as pm
from pymc_experimental import MarginalModel

with MarginalModel() as m:
    p = pm.Beta("p", 1, 1)
    x = pm.Bernoulli("x", p=p, shape=(3,))
    y = pm.Normal("y", pm.math.switch(x, -10, 10), observed=[10, 10, -10])

    m.marginalize([x])

    idata = pm.sample()
    m.recover_marginals(idata, var_names=["x"])
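    # idata.posterior now also contains "x" (draws of the recovered
    # variable) and "lp_x" (its normalized log-probabilities)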


"""
if var_names is None:
var_names = [var.name for var in self.marginalized_rvs]

var_names = [var if isinstance(var, str) else var.name for var in var_names]
vars_to_recover = [v for v in self.marginalized_rvs if v.name in var_names]
missing_names = [name for name in var_names if name not in [v.name for v in vars_to_recover]]
if missing_names:
raise ValueError(f"Unrecognized var_names: {missing_names}")

if return_samples and random_seed is not None:
seeds = _get_seeds_per_chain(random_seed, len(vars_to_recover))
else:
seeds = [None] * len(vars_to_recover)

posterior = idata.posterior

# Remove Deterministics
posterior_values = posterior[
[rv.name for rv in self.free_RVs if rv not in self.marginalized_rvs]
]

sample_dims = ("chain", "draw")
posterior_pts, stacked_dims = dataset_to_point_list(posterior_values, sample_dims)

# Handle Transforms
transform_fn, transform_names = self._to_transformed()

def transform_input(inputs):
return dict(zip(transform_names, transform_fn(inputs)))

posterior_pts = [transform_input(vs) for vs in posterior_pts]

rv_dict = {}
rv_dims = {}
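# Recover one marginalized RV at a time; any other marginalized RVs stay
# marginalized and are summed out of the joint logp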
for seed, rv in zip(seeds, vars_to_recover):
supported_dists = (Bernoulli, Categorical, DiscreteUniform)
if not isinstance(rv.owner.op, supported_dists):
raise NotImplementedError(
f"RV with distribution {rv.owner.op} cannot be recovered. "
f"Supported distribution include {supported_dists}"
)

m = self.clone()
rv = m.vars_to_clone[rv]
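# Temporarily unmarginalize the RV in a clone of the model so its value
# variable and the logp of its dependents become available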
m.unmarginalize([rv])
dependent_vars = find_conditional_dependent_rvs(rv, m.basic_RVs)
joint_logps = m.logp(vars=dependent_vars + [rv], sum=False)

marginalized_value = m.rvs_to_values[rv]
other_values = [v for v in m.value_vars if v is not marginalized_value]

# Handle batch dims for marginalized value and its dependent RVs
joint_logp = joint_logps[-1]
for dv in joint_logps[:-1]:
dbcast = dv.type.broadcastable
mbcast = marginalized_value.type.broadcastable
mbcast = (True,) * (len(dbcast) - len(mbcast)) + mbcast
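# Sum the dependent RV's logp over axes where the marginalized value
# broadcasts, leaving one joint term per element of the marginalized RV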
values_axis_bcast = [
i for i, (m, v) in enumerate(zip(mbcast, dbcast)) if m and not v
]
joint_logp += dv.sum(values_axis_bcast)

rv_shape = constant_fold(tuple(rv.shape))
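# Enumerate the RV's finite support along a new leading axis; the logp
# graph is then vectorized over that axis to evaluate every possible value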
rv_domain = get_domain_of_finite_discrete_rv(rv)
rv_domain_tensor = pt.moveaxis(
pt.full(
(*rv_shape, len(rv_domain)),
rv_domain,
dtype=rv.dtype,
),
-1,
0,
)

joint_logps = vectorize_graph(
joint_logp,
replace={marginalized_value: rv_domain_tensor},
)
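# Move the domain axis last so it indexes the candidate values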
joint_logps = pt.moveaxis(joint_logps, 0, -1)

rv_loglike_fn = None
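# Normalize over the domain: log posterior probability of each possible
# value, conditioned on the remaining free variables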
joint_logps_norm = log_softmax(joint_logps, axis=-1)
if return_samples:
sample_rv_outs = pymc.Categorical.dist(logit_p=joint_logps)
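# Categorical draws start at 0, so shift them onto the support of a
# marginalized DiscreteUniform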
if isinstance(rv.owner.op, DiscreteUniform):
sample_rv_outs += rv_domain[0]

rv_loglike_fn = compile_pymc(
inputs=other_values,
outputs=[joint_logps_norm, sample_rv_outs],
on_unused_input="ignore",
random_seed=seed,
)
else:
rv_loglike_fn = compile_pymc(
inputs=other_values,
outputs=joint_logps_norm,
on_unused_input="ignore",
random_seed=seed,
)

logvs = [rv_loglike_fn(**vs) for vs in posterior_pts]

logps = None
samples = None
if return_samples:
logps, samples = zip(*logvs)
logps = np.array(logps)
samples = np.array(samples)
rv_dict[rv.name] = samples.reshape(
tuple(len(coord) for coord in stacked_dims.values()) + samples.shape[1:],
)
else:
logps = np.array(logvs)

rv_dict["lp_" + rv.name] = logps.reshape(
tuple(len(coord) for coord in stacked_dims.values()) + logps.shape[1:],
)
if rv.name in m.named_vars_to_dims:
rv_dims[rv.name] = list(m.named_vars_to_dims[rv.name])
rv_dims["lp_" + rv.name] = rv_dims[rv.name] + ["lp_" + rv.name + "_dim"]

coords, dims = coords_and_dims_for_inferencedata(self)
dims.update(rv_dims)
rv_dataset = dict_to_dataset(
rv_dict,
library=pymc,
dims=dims,
coords=coords,
default_dims=list(sample_dims),
skip_event_dims=True,
)

if extend_inferencedata:
idata.posterior = idata.posterior.assign(rv_dataset)
return idata
else:
return rv_dataset


class MarginalRV(SymbolicRandomVariable):
"""Base class for Marginalized RVs"""
Expand Down Expand Up @@ -444,14 +684,14 @@ def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
# PyMC does not allow RVs in the logp graph, even if we are just using the shape
marginalized_rv_shape = constant_fold(tuple(marginalized_rv.shape))
marginalized_rv_domain = get_domain_of_finite_discrete_rv(marginalized_rv)
-marginalized_rv_domain_tensor = pt.swapaxes(
+marginalized_rv_domain_tensor = pt.moveaxis(
pt.full(
(*marginalized_rv_shape, len(marginalized_rv_domain)),
marginalized_rv_domain,
dtype=marginalized_rv.dtype,
),
-    axis1=0,
-    axis2=-1,
+    -1,
+    0,
)

# Arbitrary cutoff to switch to Scan implementation to keep graph size under control