Skip to content
This repository was archived by the owner on Nov 17, 2025. It is now read-only.

Commit fd2a534

Browse files
Track valued/bound variables using an in-graph Op
1 parent d0c009d commit fd2a534

16 files changed

+383
-385
lines changed

aeppl/abstract.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,14 @@
33
from functools import singledispatch
44
from typing import Callable, List, Tuple
55

6+
import aesara.tensor as at
7+
from aesara.gradient import grad_undefined
68
from aesara.graph.basic import Apply, Variable
79
from aesara.graph.op import Op
810
from aesara.graph.utils import MetaType
911
from aesara.tensor.elemwise import Elemwise
1012
from aesara.tensor.random.op import RandomVariable
13+
from aesara.tensor.type import TensorType
1114

1215

1316
class MeasurableVariable(abc.ABC):
@@ -134,3 +137,55 @@ def __init__(self, scalar_op, *args, **kwargs):
134137

135138

136139
MeasurableVariable.register(MeasurableElemwise)
140+
141+
142+
class ValuedVariable(Op):
143+
r"""Represents the association of a measurable variable and its value.
144+
145+
A `ValuedVariable` node represents the pair :math:`(Y, y)`, where
146+
:math:`Y` is a random variable and :math:`y \sim Y`.
147+
148+
Log-probability (densities) are functions over these pairs, which makes
149+
these nodes in a graph an intermediate form that serves to construct a
150+
log-probability from a model graph.
151+
152+
This intermediate form can be used as the target for rewrites that
153+
otherwise wouldn't make sense to apply to--say--a random variable node
154+
directly. An example is `BroadcastTo` lifting through `RandomVariable`\s.
155+
"""
156+
157+
default_output = 0
158+
view_map = {0: [0]}
159+
160+
def make_node(self, rv, value):
161+
162+
assert isinstance(rv.type, TensorType)
163+
out_rv = rv.type()
164+
165+
vv = at.as_tensor_variable(value)
166+
assert isinstance(vv.type, TensorType)
167+
168+
# TODO: We should probably check the `Type`s of `out_rv` and `vv`
169+
if vv.type.dtype != rv.type.dtype:
170+
raise TypeError(
171+
f"Value type {vv.type} does not match random variable type {out_rv.type}"
172+
)
173+
174+
return Apply(self, [rv, vv], [out_rv])
175+
176+
def perform(self, node, inputs, out):
177+
out[0][0] = inputs[0]
178+
179+
def grad(self, inputs, outputs):
180+
return [
181+
grad_undefined(self, k, inp, "No gradient defined for `ValuedVariable`")
182+
for k, inp in enumerate(inputs)
183+
]
184+
185+
def infer_shape(self, fgraph, node, input_shapes):
186+
return [input_shapes[0]]
187+
188+
189+
MeasurableVariable.register(ValuedVariable)
190+
191+
valued_variable = ValuedVariable()

aeppl/censoring.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from aeppl.abstract import (
1414
MeasurableElemwise,
1515
MeasurableVariable,
16+
ValuedVariable,
1617
assign_custom_measurable_outputs,
1718
)
1819
from aeppl.logprob import CheckParameterValue, _logcdf, _logprob, logdiffexp
@@ -34,10 +35,6 @@ def find_measurable_clips(
3435
) -> Optional[List[MeasurableClip]]:
3536
# TODO: Canonicalize x[x>ub] = ub -> clip(x, x, ub)
3637

37-
rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)
38-
if rv_map_feature is None:
39-
return None # pragma: no cover
40-
4138
if isinstance(node.op, MeasurableClip):
4239
return None # pragma: no cover
4340

@@ -50,7 +47,7 @@ def find_measurable_clips(
5047
if not (
5148
base_var.owner
5249
and isinstance(base_var.owner.op, MeasurableVariable)
53-
and base_var not in rv_map_feature.rv_values
50+
and not isinstance(base_var, ValuedVariable)
5451
):
5552
return None
5653

@@ -155,10 +152,6 @@ def find_measurable_roundings(
155152
fgraph: FunctionGraph, node: Node
156153
) -> Optional[List[MeasurableRound]]:
157154

158-
rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)
159-
if rv_map_feature is None:
160-
return None # pragma: no cover
161-
162155
if isinstance(node.op, MeasurableRound):
163156
return None # pragma: no cover
164157

@@ -174,7 +167,7 @@ def find_measurable_roundings(
174167
if not (
175168
base_var.owner
176169
and isinstance(base_var.owner.op, MeasurableVariable)
177-
and base_var not in rv_map_feature.rv_values
170+
and not isinstance(base_var, ValuedVariable)
178171
# Rounding only makes sense for continuous variables
179172
and base_var.dtype.startswith("float")
180173
):

aeppl/cumsum.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,13 @@
44
from aesara.graph.rewriting.basic import node_rewriter
55
from aesara.tensor.extra_ops import CumOp
66

7-
from aeppl.abstract import MeasurableVariable, assign_custom_measurable_outputs
7+
from aeppl.abstract import (
8+
MeasurableVariable,
9+
ValuedVariable,
10+
assign_custom_measurable_outputs,
11+
)
812
from aeppl.logprob import _logprob, logprob
9-
from aeppl.rewriting import PreserveRVMappings, measurable_ir_rewrites_db
13+
from aeppl.rewriting import measurable_ir_rewrites_db
1014

1115

1216
class MeasurableCumsum(CumOp):
@@ -50,20 +54,13 @@ def find_measurable_cumsums(fgraph, node) -> Optional[List[MeasurableCumsum]]:
5054
if isinstance(node.op, MeasurableCumsum):
5155
return None # pragma: no cover
5256

53-
rv_map_feature: Optional[PreserveRVMappings] = getattr(
54-
fgraph, "preserve_rv_mappings", None
55-
)
56-
57-
if rv_map_feature is None:
58-
return None # pragma: no cover
59-
6057
rv = node.outputs[0]
6158

6259
base_rv = node.inputs[0]
6360
if not (
6461
base_rv.owner
6562
and isinstance(base_rv.owner.op, MeasurableVariable)
66-
and base_rv not in rv_map_feature.rv_values
63+
and not isinstance(base_rv, ValuedVariable)
6764
):
6865
return None # pragma: no cover
6966

aeppl/joint_logprob.py

Lines changed: 44 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,12 @@
1-
import warnings
2-
from collections import deque
31
from typing import Dict, List, Optional, Tuple, Union
42

53
import aesara.tensor as at
6-
from aesara import config
7-
from aesara.graph.basic import graph_inputs, io_toposort
8-
from aesara.graph.op import compute_test_value
94
from aesara.graph.rewriting.basic import GraphRewriter, NodeRewriter
105
from aesara.tensor.var import TensorVariable
116

12-
from aeppl.abstract import get_measurable_outputs
7+
from aeppl.abstract import ValuedVariable, get_measurable_outputs
138
from aeppl.logprob import _logprob
149
from aeppl.rewriting import construct_ir_fgraph
15-
from aeppl.utils import rvs_to_value_vars
1610

1711

1812
def conditional_logprob(
@@ -22,7 +16,7 @@ def conditional_logprob(
2216
ir_rewriter: Optional[GraphRewriter] = None,
2317
extra_rewrites: Optional[Union[GraphRewriter, NodeRewriter]] = None,
2418
**kwargs,
25-
) -> Tuple[Dict[TensorVariable, TensorVariable], List[TensorVariable]]:
19+
) -> Tuple[Dict[TensorVariable, TensorVariable], Tuple[TensorVariable, ...]]:
2620
r"""Create a map between random variables and their conditional log-probabilities.
2721
2822
The list of measurable variables implicitly defines a joint probability that
@@ -106,133 +100,71 @@ def conditional_logprob(
106100
# graphs. We can thus use them to recover the original random variables to index the
107101
# maps to the logprob graphs and value variables before returning them.
108102
rv_values = {**original_rv_values, **realized}
109-
vv_to_original_rvs = {vv: rv for rv, vv in rv_values.items()}
103+
# vv_to_original_rvs = {vv: rv for rv, vv in rv_values.items()}
110104

111-
fgraph, rv_values, _ = construct_ir_fgraph(rv_values, ir_rewriter=ir_rewriter)
105+
fgraph, _, memo = construct_ir_fgraph(rv_values, ir_rewriter=ir_rewriter)
112106

113-
# The interface for transformations assumes that the value variables are in
114-
# the transformed space. To get the correct `shape` and `dtype` for the
115-
# value variables we return we need to apply the forward transformation to
116-
# our RV copies, and return the type of the resulting variable as a value
117-
# variable.
118-
vv_remapper = {}
119107
if extra_rewrites is not None:
120-
extra_rewrites.add_requirements(fgraph, {**original_rv_values, **realized})
108+
extra_rewrites.add_requirements(fgraph, rv_values)
121109
extra_rewrites.apply(fgraph)
122-
vv_remapper = fgraph.values_to_untransformed
123-
124-
rv_remapper = fgraph.preserve_rv_mappings
125-
126-
# This is the updated random-to-value-vars map with the lifted/rewritten
127-
# variables. The rewrites are supposed to produce new
128-
# `MeasurableVariable`s that are amenable to `_logprob`.
129-
updated_rv_values = rv_remapper.rv_values
130-
131-
# Some rewrites also transform the original value variables. This is the
132-
# updated map from the new value variables to the original ones, which
133-
# we want to use as the keys in the final dictionary output
134-
original_values = rv_remapper.original_values
135-
136-
# When a `_logprob` has been produced for a `MeasurableVariable` node, all
137-
# other references to it need to be replaced with its value-variable all
138-
# throughout the `_logprob`-produced graphs. The following `dict`
139-
# cumulatively maintains remappings for all the variables/nodes that needed
140-
# to be recreated after replacing `MeasurableVariable`s with their
141-
# value-variables. Since these replacements work in topological order, all
142-
# the necessary value-variable replacements should be present for each
143-
# node.
144-
replacements = updated_rv_values.copy()
145-
146-
# To avoid cloning the value variables, we map them to themselves in the
147-
# `replacements` `dict` (i.e. entries already existing in `replacements`
148-
# aren't cloned)
149-
replacements.update({v: v for v in rv_values.values()})
150-
151-
# Walk the graph from its inputs to its outputs and construct the
152-
# log-probability
153-
q = deque(fgraph.toposort())
154110

155111
logprob_vars = {}
156-
value_variables = {}
157112

158-
while q:
159-
node = q.popleft()
113+
for out, old_out in zip(fgraph.outputs, rv_values.keys()):
114+
node = out.owner
160115

161-
outputs = get_measurable_outputs(node.op, node)
162-
if not outputs:
163-
continue
164-
165-
if any(o not in updated_rv_values for o in outputs):
166-
if warn_missing_rvs:
167-
warnings.warn(
168-
"Found a random variable that is not assigned a value variable: "
169-
f"{node.outputs}"
170-
)
171-
continue
172-
173-
q_value_vars = [replacements[q_rv_var] for q_rv_var in outputs]
174-
175-
if not q_value_vars:
176-
continue
177-
178-
# Replace `RandomVariable`s in the inputs with value variables.
179-
# Also, store the results in the `replacements` map for the nodes
180-
# that follow.
181-
remapped_vars, _ = rvs_to_value_vars(
182-
q_value_vars + list(node.inputs),
183-
initial_replacements=replacements,
184-
)
185-
q_value_vars = remapped_vars[: len(q_value_vars)]
186-
q_rv_inputs = remapped_vars[len(q_value_vars) :]
187-
188-
q_logprob_vars = _logprob(
189-
node.op,
190-
q_value_vars,
191-
*q_rv_inputs,
192-
**kwargs,
193-
)
116+
assert isinstance(node.op, ValuedVariable)
194117

195-
if not isinstance(q_logprob_vars, (list, tuple)):
196-
q_logprob_vars = [q_logprob_vars]
118+
rv_var, val_var = node.inputs
197119

198-
for q_value_var, q_logprob_var in zip(q_value_vars, q_logprob_vars):
120+
rv_node = rv_var.owner
121+
outputs = get_measurable_outputs(rv_node.op, rv_node)
199122

200-
q_value_var = original_values[q_value_var]
201-
q_rv = vv_to_original_rvs[q_value_var]
202-
203-
if q_rv.name:
204-
q_logprob_var.name = f"{q_rv.name}_logprob"
123+
if not outputs:
124+
raise ValueError(f"Couldn't derive a log-probability for {out}")
125+
126+
# TODO: This probably needs to be done outside of this loop.
127+
# if warn_missing_rvs:
128+
# warnings.warn(
129+
# "Found a random variable that is not assigned a value variable: "
130+
# f"{node.outputs}"
131+
# )
132+
rv_logprob = _logprob(
133+
rv_node.op,
134+
[val_var],
135+
*rv_node.inputs,
136+
**kwargs,
137+
)
205138

206-
if q_rv in logprob_vars:
207-
raise ValueError(
208-
f"More than one logprob factor was assigned to the random variable {q_rv}"
209-
)
139+
if isinstance(rv_logprob, (tuple, list)):
140+
(rv_logprob,) = rv_logprob
210141

211-
logprob_vars[q_rv] = q_logprob_var
142+
if old_out.name:
143+
rv_logprob.name = f"{old_out.name}_logprob"
212144

213-
q_value_var = vv_remapper.get(q_value_var, q_value_var)
214-
value_variables[q_rv] = q_value_var
145+
logprob_vars[old_out] = rv_logprob
215146

216-
# Recompute test values for the changes introduced by the
217-
# replacements above.
218-
if config.compute_test_value != "off":
219-
for node in io_toposort(graph_inputs(q_logprob_vars), q_logprob_vars):
220-
compute_test_value(node)
147+
# # Recompute test values for the changes introduced by the
148+
# # replacements above.
149+
# if config.compute_test_value != "off":
150+
# for node in io_toposort(graph_inputs([rv_logprob]), q_logprob_vars):
151+
# compute_test_value(node)
221152

222-
missing_value_terms = set(vv_to_original_rvs.values()) - set(logprob_vars.keys())
223-
if missing_value_terms:
224-
raise RuntimeError(
225-
f"The logprob terms of the following random variables could not be derived: {missing_value_terms}"
226-
)
153+
# missing_value_terms = set(vv_to_original_rvs.values()) - set(logprob_vars.keys())
154+
# if missing_value_terms:
155+
# raise RuntimeError(
156+
# f"The logprob terms of the following random variables could not be derived: {missing_value_terms}"
157+
# )
227158

228-
return logprob_vars, [value_variables[rv] for rv in original_rv_values.keys()]
159+
value_vars = tuple(memo[vv] for rv, vv in rv_values.items() if rv not in realized)
160+
return logprob_vars, value_vars
229161

230162

231163
def joint_logprob(
232164
*random_variables: List[TensorVariable],
233165
realized: Dict[TensorVariable, TensorVariable] = {},
234166
**kwargs,
235-
) -> Optional[Tuple[TensorVariable, List[TensorVariable]]]:
167+
) -> Optional[Tuple[TensorVariable, Tuple[TensorVariable, ...]]]:
236168
"""Create a graph representing the joint log-probability/measure of a graph.
237169
238170
This function calls `factorized_joint_logprob` and returns the combined

0 commit comments

Comments
 (0)