WIP refactor MarginalModel
ricardoV94 committed Nov 4, 2024
1 parent d447e0e commit 6eb8af1
Showing 4 changed files with 444 additions and 614 deletions.
56 changes: 56 additions & 0 deletions pymc_experimental/model/marginal/distributions.py
@@ -4,6 +4,7 @@
 import pytensor.tensor as pt
 
 from pymc.distributions import Bernoulli, Categorical, DiscreteUniform
+from pymc.distributions.distribution import _support_point, support_point
 from pymc.logprob.abstract import MeasurableOp, _logprob
 from pymc.logprob.basic import conditional_logp, logp
 from pymc.pytensorf import constant_fold
@@ -15,6 +16,7 @@
 from pytensor.scan import map as scan_map
 from pytensor.scan import scan
 from pytensor.tensor import TensorVariable
+from pytensor.tensor.random.type import RandomType
 
 from pymc_experimental.distributions import DiscreteMarkovChain
 
@@ -44,6 +46,60 @@ def support_axes(self) -> tuple[tuple[int]]:
         return tuple(support_axes_vars)
 
 
+@_support_point.register
+def support_point_marginal_rv(op: MarginalRV, rv, *inputs):
+    """Support point for a marginalized RV.
+
+    The support point of a marginalized RV is the support point of the inner RV,
+    conditioned on the marginalized RV taking its support point.
+    """
+    outputs = rv.owner.outputs
+    rv_idx = outputs.index(rv)
+    inner_rv = op.inner_outputs[rv_idx]
+    other_inner_rvs = [
+        out
+        for out in op.inner_outputs
+        if not isinstance(out.type, RandomType) and out is not inner_rv
+    ]
+
+    # Replace references to inner rvs by the dummy variables (including the marginalized RV)
+    # This is necessary because the inner RVs may depend on each other
+    other_inner_rv_to_dummies = {
+        other_inner_rv: other_inner_rv.clone() for other_inner_rv in other_inner_rvs
+    }
+    inner_rv = clone_replace(inner_rv, other_inner_rv_to_dummies)
+    inner_rv_support_point = support_point(inner_rv)
+
+    # Replace the dummy marginalized RV by the support point of the marginalized RV
+    marginalized_rv = other_inner_rvs[0]
+    marginalized_rv_support_point = support_point(marginalized_rv)
+    dummy_marginalized_rv = other_inner_rv_to_dummies[marginalized_rv]
+    inner_rv_support_point = clone_replace(
+        inner_rv_support_point,
+        {dummy_marginalized_rv: marginalized_rv_support_point},
+    )
+
+    # Replace the remaining dummy variables by outer RVs
+    rv_support_point = graph_replace(
+        inner_rv_support_point,
+        replace={
+            v: outputs[op.inner_outputs.index(k)]
+            for k, v in other_inner_rv_to_dummies.items()
+            if k is not marginalized_rv
+        },
+        strict=False,
+    )
+
+    # Make it a function of any remaining outer inputs
+    rv_support_point = graph_replace(
+        rv_support_point,
+        replace=tuple(zip(op.inner_inputs, inputs)),
+        strict=False,
+    )
+
+    return rv_support_point
+
+
 class MarginalFiniteDiscreteRV(MarginalRV):
     """Base class for Marginalized Finite Discrete RVs"""
 
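For context, this dispatch is what lets initial points be derived for models whose discrete variables have been marginalized out. A minimal sketch of the intended behavior, assuming the `MarginalModel` API exported by `pymc_experimental` at the time of this commit (the exact value depends on each distribution's registered support point):

```python
import pymc as pm
import pytensor.tensor as pt

from pymc.distributions.distribution import support_point
from pymc_experimental import MarginalModel

with MarginalModel() as m:
    idx = pm.Bernoulli("idx", p=0.7)
    mu = pt.as_tensor([-1.0, 1.0])
    y = pm.Normal("y", mu=mu[idx], sigma=1.0)

m.marginalize([idx])

# "y" is now an output of a MarginalRV: its support point is the inner
# Normal's support point evaluated at idx's own support point
# (Bernoulli(0.7) rounds to 1), so we would expect mu[1] == 1.0 here.
print(support_point(m["y"]).eval())
```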
10 changes: 7 additions & 3 deletions pymc_experimental/model/marginal/graph_analysis.py
@@ -4,6 +4,7 @@
 from itertools import zip_longest
 
 from pymc import SymbolicRandomVariable
+from pymc.model.fgraph import ModelVar
 from pytensor.compile import SharedVariable
 from pytensor.graph import Constant, Variable, ancestors
 from pytensor.graph.basic import io_toposort
@@ -35,12 +36,12 @@ def static_shape_ancestors(vars):
 
 def find_conditional_input_rvs(output_rvs, all_rvs):
     """Find conditionally indepedent input RVs."""
-    blockers = [other_rv for other_rv in all_rvs if other_rv not in output_rvs]
-    blockers += static_shape_ancestors(tuple(all_rvs) + tuple(output_rvs))
+    other_rvs = [other_rv for other_rv in all_rvs if other_rv not in output_rvs]
+    blockers = other_rvs + static_shape_ancestors(tuple(all_rvs) + tuple(output_rvs))
     return [
         var
         for var in ancestors(output_rvs, blockers=blockers)
-        if var in blockers or (var.owner is None and not isinstance(var, Constant | SharedVariable))
+        if var in other_rvs
     ]
 
 
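The rewritten filter means only other RVs now count as conditional inputs; root inputs that are neither `Constant` nor `SharedVariable` are no longer returned. A rough sketch of the resulting behavior on a chain of RVs (variable names are illustrative):

```python
import pymc as pm

from pymc_experimental.model.marginal.graph_analysis import (
    find_conditional_input_rvs,
)

with pm.Model():
    x = pm.Normal("x")
    y = pm.Normal("y", mu=x)
    z = pm.Normal("z", mu=y)

# Traversal from z stops at the first RV it meets: y blocks the path to x,
# so only y is reported as a conditional input of z.
print(find_conditional_input_rvs([z], all_rvs=[x, y, z]))  # expected: [y]
```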
@@ -141,6 +142,9 @@ def _subgraph_batch_dim_connection(var_dims: VAR_DIMS, input_vars, output_vars)
             # None of the inputs are related to the batch_axes of the input_vars
             continue
 
+        elif isinstance(node.op, ModelVar):
+            var_dims[node.outputs[0]] = inputs_dims[0]
+
         elif isinstance(node.op, DimShuffle):
             [input_dims] = inputs_dims
             output_dims = tuple(None if i == "x" else input_dims[i] for i in node.op.new_order)
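The new `ModelVar` branch covers the wrapper `Op`s used by pymc's fgraph representation of models (`ModelFreeRV`, `ModelObservedRV`, `ModelDeterministic`, ...). Each wraps its underlying variable as its first input, so the batch-dim mapping already computed for that input is forwarded unchanged to the wrapper's output. A small illustration of where such nodes come from (a sketch, not part of this commit):

```python
import pymc as pm

from pymc.model.fgraph import ModelVar, fgraph_from_model

with pm.Model() as model:
    x = pm.Normal("x", shape=(3,))

fgraph, memo = fgraph_from_model(model)

# In the fgraph representation, "x" is the output of a ModelFreeRV node
# whose first input is the underlying normal RV. For batch-dim tracking,
# such nodes are pure pass-throughs, which is what the new elif encodes.
model_var_nodes = [node for node in fgraph.toposort() if isinstance(node.op, ModelVar)]
print(model_var_nodes)
```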