1- import warnings
2- from collections import deque
31from typing import Dict , List , Optional , Tuple , Union
42
53import aesara .tensor as at
6- from aesara import config
7- from aesara .graph .basic import graph_inputs , io_toposort
8- from aesara .graph .op import compute_test_value
94from aesara .graph .rewriting .basic import GraphRewriter , NodeRewriter
105from aesara .tensor .var import TensorVariable
116
12- from aeppl .abstract import get_measurable_outputs
7+ from aeppl .abstract import ValuedVariable , get_measurable_outputs
138from aeppl .logprob import _logprob
149from aeppl .rewriting import construct_ir_fgraph
15- from aeppl .utils import rvs_to_value_vars
1610
1711
1812def conditional_logprob (
@@ -22,7 +16,7 @@ def conditional_logprob(
2216 ir_rewriter : Optional [GraphRewriter ] = None ,
2317 extra_rewrites : Optional [Union [GraphRewriter , NodeRewriter ]] = None ,
2418 ** kwargs ,
25- ) -> Tuple [Dict [TensorVariable , TensorVariable ], List [TensorVariable ]]:
19+ ) -> Tuple [Dict [TensorVariable , TensorVariable ], Tuple [TensorVariable , ... ]]:
2620 r"""Create a map between random variables and their conditional log-probabilities.
2721
2822 The list of measurable variables implicitly defines a joint probability that
@@ -106,133 +100,71 @@ def conditional_logprob(
106100 # graphs. We can thus use them to recover the original random variables to index the
107101 # maps to the logprob graphs and value variables before returning them.
108102 rv_values = {** original_rv_values , ** realized }
109- vv_to_original_rvs = {vv : rv for rv , vv in rv_values .items ()}
103+ # vv_to_original_rvs = {vv: rv for rv, vv in rv_values.items()}
110104
111- fgraph , rv_values , _ = construct_ir_fgraph (rv_values , ir_rewriter = ir_rewriter )
105+ fgraph , _ , memo = construct_ir_fgraph (rv_values , ir_rewriter = ir_rewriter )
112106
113- # The interface for transformations assumes that the value variables are in
114- # the transformed space. To get the correct `shape` and `dtype` for the
115- # value variables we return we need to apply the forward transformation to
116- # our RV copies, and return the type of the resulting variable as a value
117- # variable.
118- vv_remapper = {}
119107 if extra_rewrites is not None :
120- extra_rewrites .add_requirements (fgraph , { ** original_rv_values , ** realized } )
108+ extra_rewrites .add_requirements (fgraph , rv_values )
121109 extra_rewrites .apply (fgraph )
122- vv_remapper = fgraph .values_to_untransformed
123-
124- rv_remapper = fgraph .preserve_rv_mappings
125-
126- # This is the updated random-to-value-vars map with the lifted/rewritten
127- # variables. The rewrites are supposed to produce new
128- # `MeasurableVariable`s that are amenable to `_logprob`.
129- updated_rv_values = rv_remapper .rv_values
130-
131- # Some rewrites also transform the original value variables. This is the
132- # updated map from the new value variables to the original ones, which
133- # we want to use as the keys in the final dictionary output
134- original_values = rv_remapper .original_values
135-
136- # When a `_logprob` has been produced for a `MeasurableVariable` node, all
137- # other references to it need to be replaced with its value-variable all
138- # throughout the `_logprob`-produced graphs. The following `dict`
139- # cumulatively maintains remappings for all the variables/nodes that needed
140- # to be recreated after replacing `MeasurableVariable`s with their
141- # value-variables. Since these replacements work in topological order, all
142- # the necessary value-variable replacements should be present for each
143- # node.
144- replacements = updated_rv_values .copy ()
145-
146- # To avoid cloning the value variables, we map them to themselves in the
147- # `replacements` `dict` (i.e. entries already existing in `replacements`
148- # aren't cloned)
149- replacements .update ({v : v for v in rv_values .values ()})
150-
151- # Walk the graph from its inputs to its outputs and construct the
152- # log-probability
153- q = deque (fgraph .toposort ())
154110
155111 logprob_vars = {}
156- value_variables = {}
157112
158- while q :
159- node = q . popleft ()
113+ for out , old_out in zip ( fgraph . outputs , rv_values . keys ()) :
114+ node = out . owner
160115
161- outputs = get_measurable_outputs (node .op , node )
162- if not outputs :
163- continue
164-
165- if any (o not in updated_rv_values for o in outputs ):
166- if warn_missing_rvs :
167- warnings .warn (
168- "Found a random variable that is not assigned a value variable: "
169- f"{ node .outputs } "
170- )
171- continue
172-
173- q_value_vars = [replacements [q_rv_var ] for q_rv_var in outputs ]
174-
175- if not q_value_vars :
176- continue
177-
178- # Replace `RandomVariable`s in the inputs with value variables.
179- # Also, store the results in the `replacements` map for the nodes
180- # that follow.
181- remapped_vars , _ = rvs_to_value_vars (
182- q_value_vars + list (node .inputs ),
183- initial_replacements = replacements ,
184- )
185- q_value_vars = remapped_vars [: len (q_value_vars )]
186- q_rv_inputs = remapped_vars [len (q_value_vars ) :]
187-
188- q_logprob_vars = _logprob (
189- node .op ,
190- q_value_vars ,
191- * q_rv_inputs ,
192- ** kwargs ,
193- )
116+ assert isinstance (node .op , ValuedVariable )
194117
195- if not isinstance (q_logprob_vars , (list , tuple )):
196- q_logprob_vars = [q_logprob_vars ]
118+ rv_var , val_var = node .inputs
197119
198- for q_value_var , q_logprob_var in zip (q_value_vars , q_logprob_vars ):
120+ rv_node = rv_var .owner
121+ outputs = get_measurable_outputs (rv_node .op , rv_node )
199122
200- q_value_var = original_values [q_value_var ]
201- q_rv = vv_to_original_rvs [q_value_var ]
202-
203- if q_rv .name :
204- q_logprob_var .name = f"{ q_rv .name } _logprob"
123+ if not outputs :
124+ raise ValueError (f"Couldn't derive a log-probability for { out } " )
125+
126+ # TODO: This probably needs to be done outside of this loop.
127+ # if warn_missing_rvs:
128+ # warnings.warn(
129+ # "Found a random variable that is not assigned a value variable: "
130+ # f"{node.outputs}"
131+ # )
132+ rv_logprob = _logprob (
133+ rv_node .op ,
134+ [val_var ],
135+ * rv_node .inputs ,
136+ ** kwargs ,
137+ )
205138
206- if q_rv in logprob_vars :
207- raise ValueError (
208- f"More than one logprob factor was assigned to the random variable { q_rv } "
209- )
139+ if isinstance (rv_logprob , (tuple , list )):
140+ (rv_logprob ,) = rv_logprob
210141
211- logprob_vars [q_rv ] = q_logprob_var
142+ if old_out .name :
143+ rv_logprob .name = f"{ old_out .name } _logprob"
212144
213- q_value_var = vv_remapper .get (q_value_var , q_value_var )
214- value_variables [q_rv ] = q_value_var
145+ logprob_vars [old_out ] = rv_logprob
215146
216- # Recompute test values for the changes introduced by the
217- # replacements above.
218- if config .compute_test_value != "off" :
219- for node in io_toposort (graph_inputs (q_logprob_vars ), q_logprob_vars ):
220- compute_test_value (node )
147+ # # Recompute test values for the changes introduced by the
148+ # # replacements above.
149+ # if config.compute_test_value != "off":
150+ # for node in io_toposort(graph_inputs([rv_logprob] ), q_logprob_vars):
151+ # compute_test_value(node)
221152
222- missing_value_terms = set (vv_to_original_rvs .values ()) - set (logprob_vars .keys ())
223- if missing_value_terms :
224- raise RuntimeError (
225- f"The logprob terms of the following random variables could not be derived: { missing_value_terms } "
226- )
153+ # missing_value_terms = set(vv_to_original_rvs.values()) - set(logprob_vars.keys())
154+ # if missing_value_terms:
155+ # raise RuntimeError(
156+ # f"The logprob terms of the following random variables could not be derived: {missing_value_terms}"
157+ # )
227158
228- return logprob_vars , [value_variables [rv ] for rv in original_rv_values .keys ()]
159+ value_vars = tuple (memo [vv ] for rv , vv in rv_values .items () if rv not in realized )
160+ return logprob_vars , value_vars
229161
230162
231163def joint_logprob (
232164 * random_variables : List [TensorVariable ],
233165 realized : Dict [TensorVariable , TensorVariable ] = {},
234166 ** kwargs ,
235- ) -> Optional [Tuple [TensorVariable , List [TensorVariable ]]]:
167+ ) -> Optional [Tuple [TensorVariable , Tuple [TensorVariable , ... ]]]:
236168 """Create a graph representing the joint log-probability/measure of a graph.
237169
238170 This function calls `factorized_joint_logprob` and returns the combined
0 commit comments