Skip to content

Commit

Permalink
Implement utility to change value variable transforms
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Jul 17, 2023
1 parent 430c3c8 commit e3f5828
Show file tree
Hide file tree
Showing 3 changed files with 212 additions and 3 deletions.
11 changes: 11 additions & 0 deletions pymc_experimental/model_transform/basic.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from typing import List, Sequence, Union

from pymc import Model
from pytensor import Variable
from pytensor.graph import ancestors

from pymc_experimental.utils.model_fgraph import (
Expand All @@ -8,6 +11,8 @@
model_from_fgraph,
)

ModelVariable = Union[Variable, str]


def prune_vars_detached_from_observed(model: Model) -> Model:
"""Prune model variables that are not related to any observed variable in the Model."""
Expand All @@ -33,3 +38,9 @@ def prune_vars_detached_from_observed(model: Model) -> Model:
for node_to_remove in nodes_to_remove:
fgraph.remove_node(node_to_remove)
return model_from_fgraph(fgraph)


def parse_vars(model: Model, vars: Union[ModelVariable, Sequence[ModelVariable]]) -> List[Variable]:
if not isinstance(vars, (list, tuple)):
vars = (vars,)
return [model[var] if isinstance(var, str) else var for var in vars]
140 changes: 138 additions & 2 deletions pymc_experimental/model_transform/conditioning.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,23 @@
from typing import Any, Dict, List, Sequence, Union
from typing import Any, Dict, List, Optional, Sequence, Union

from pymc import Model
from pymc.logprob.transforms import RVTransform
from pymc.pytensorf import _replace_vars_in_graphs
from pymc.util import get_transformed_name, get_untransformed_name
from pytensor.tensor import TensorVariable

from pymc_experimental.model_transform.basic import prune_vars_detached_from_observed
from pymc_experimental.model_transform.basic import (
ModelVariable,
parse_vars,
prune_vars_detached_from_observed,
)
from pymc_experimental.utils.model_fgraph import (
ModelDeterministic,
ModelFreeRV,
extract_dims,
fgraph_from_model,
model_deterministic,
model_free_rv,
model_from_fgraph,
model_named,
model_observed_rv,
Expand Down Expand Up @@ -206,3 +213,132 @@ def do(
if prune_vars:
return prune_vars_detached_from_observed(model)
return model


def change_value_transforms(
model: Model,
vars_to_transforms: Dict[ModelVariable, Union[RVTransform, None]],
) -> Model:
"""Change the value variables transforms in the model
Parameters
----------
model: Model
vars_to_transforms: Dict
Mapping between RVs and new transforms to be applied to the respective value variables
Returns
-------
new_model: Model
Model with the updated transformed value variables
Examples
--------
Extract untransformed space Hessian after finding transformed space MAP
.. code-block:: python
import pymc as pm
from pymc.distributions.transforms import logodds
from pymc_experimental.model_transform.conditioning import change_value_transforms
with pm.Model() as base_m:
p = pm.Uniform("p", 0, 1, transform=None)
w = pm.Binomial("w", n=9, p=p, observed=6)
with change_value_transforms(base_m, {"p": logodds}) as transformed_p:
mean_q = pm.find_MAP()
with change_value_transforms(transformed_p, {"p": None}) as untransformed_p:
new_p = untransformed_p['p']
std_q = ((1 / pm.find_hessian(mean_q, vars=[new_p])) ** 0.5)[0]
print(f" Mean, Standard deviation\np {mean_q['p']:.2}, {std_q[0]:.2}")
# Mean, Standard deviation
# p 0.67, 0.16
"""
vars_to_transforms = {
parse_vars(model, var)[0]: transform for var, transform in vars_to_transforms.items()
}

if set(vars_to_transforms.keys()) - set(model.free_RVs):
raise ValueError(f"All keys must be free variables in the model: {vars_to_transforms}")

fgraph, memo = fgraph_from_model(model)

vars_to_transforms = {memo[var]: transform for var, transform in vars_to_transforms.items()}
replacements = {}
for node in fgraph.apply_nodes:
if not isinstance(node.op, ModelFreeRV):
continue

[dummy_rv] = node.outputs
if dummy_rv not in vars_to_transforms:
continue

transform = vars_to_transforms[dummy_rv]

rv, value, *dims = node.inputs

new_value = rv.type()
try:
untransformed_name = get_untransformed_name(value.name)
except ValueError:
untransformed_name = value.name
if transform:
new_name = get_transformed_name(untransformed_name, transform)
else:
new_name = untransformed_name
new_value.name = new_name

new_dummy_rv = model_free_rv(rv, new_value, transform, *dims)
replacements[dummy_rv] = new_dummy_rv

toposort_replace(fgraph, tuple(replacements.items()))
return model_from_fgraph(fgraph)


def remove_value_transforms(
model: Model,
vars: Optional[Sequence[ModelVariable]] = None,
) -> Model:
"""Remove the value variables transforms in the model
Parameters
----------
model: Model
vars: Model variables, optional
Model variables for which to remove transforms. Defaults to all transformed variables
Returns
-------
new_model: Model
Model with the removed transformed value variables
Examples
--------
Extract untransformed space Hessian after finding transformed space MAP
.. code-block:: python
import pymc as pm
from pymc_experimental.model_transform.conditioning import remove_value_transforms
with pm.Model() as transformed_m:
p = pm.Uniform("p", 0, 1)
w = pm.Binomial("w", n=9, p=p, observed=6)
mean_q = pm.find_MAP()
with remove_value_transforms(transformed_m) as untransformed_m:
new_p = untransformed_m["p"]
std_q = ((1 / pm.find_hessian(mean_q, vars=[new_p])) ** 0.5)[0]
print(f" Mean, Standard deviation\np {mean_q['p']:.2}, {std_q[0]:.2}")
# Mean, Standard deviation
# p 0.67, 0.16
"""
if vars is None:
vars = model.free_RVs
return change_value_transforms(model, {var: None for var in vars})
64 changes: 63 additions & 1 deletion pymc_experimental/tests/model_transform/test_conditioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,16 @@
import numpy as np
import pymc as pm
import pytest
from pymc.distributions.transforms import logodds
from pymc.variational.minibatch_rv import create_minibatch_rv
from pytensor import config

from pymc_experimental.model_transform.conditioning import do, observe
from pymc_experimental.model_transform.conditioning import (
change_value_transforms,
do,
observe,
remove_value_transforms,
)


def test_observe():
Expand Down Expand Up @@ -214,3 +220,59 @@ def test_do_prune(prune):
assert set(do_m.named_vars) == {"x1", "z", "llike"}
else:
assert set(do_m.named_vars) == orig_named_vars


def test_change_value_transforms():
with pm.Model() as base_m:
p = pm.Uniform("p", 0, 1, transform=None)
w = pm.Binomial("w", n=9, p=p, observed=6)
assert base_m.rvs_to_transforms == {p: None, w: None}

with change_value_transforms(base_m, {"p": logodds}) as transformed_p:
new_p = transformed_p["p"]
new_w = transformed_p["w"]
assert transformed_p.rvs_to_transforms == {new_p: logodds, new_w: None}
mean_q = pm.find_MAP(progressbar=False)

with change_value_transforms(transformed_p, {"p": None}) as untransformed_p:
new_p = untransformed_p["p"]
new_w = untransformed_p["w"]
assert untransformed_p.rvs_to_transforms == {new_p: None, new_w: None}
std_q = ((1 / pm.find_hessian(mean_q, vars=[new_p])) ** 0.5)[0]

assert np.round(mean_q["p"], 2) == 0.67
assert np.round(std_q[0], 2) == 0.16


def test_change_value_transforms_error():
with pm.Model() as m:
x = pm.Uniform("x", observed=5.0)

with pytest.raises(ValueError, match="All keys must be free variables in the model"):
change_value_transforms(m, {x: logodds})


def test_remove_value_transforms():
with pm.Model() as base_m:
p = pm.Uniform("p", transform=logodds)
q = pm.Uniform("q", transform=logodds)

new_m = remove_value_transforms(base_m)
new_p = new_m["p"]
new_q = new_m["q"]
assert new_m.rvs_to_transforms == {new_p: None, new_q: None}

new_m = remove_value_transforms(base_m, [p, q])
new_p = new_m["p"]
new_q = new_m["q"]
assert new_m.rvs_to_transforms == {new_p: None, new_q: None}

new_m = remove_value_transforms(base_m, [p])
new_p = new_m["p"]
new_q = new_m["q"]
assert new_m.rvs_to_transforms == {new_p: None, new_q: logodds}

new_m = remove_value_transforms(base_m, ["q"])
new_p = new_m["p"]
new_q = new_m["q"]
assert new_m.rvs_to_transforms == {new_p: logodds, new_q: None}

0 comments on commit e3f5828

Please sign in to comment.