@@ -1,29 +1,36 @@
+import warnings
+
 from collections.abc import Sequence
 
 import numpy as np
 import pytensor.tensor as pt
 
 from pymc.distributions import Bernoulli, Categorical, DiscreteUniform
+from pymc.distributions.distribution import _support_point, support_point
 from pymc.logprob.abstract import MeasurableOp, _logprob
 from pymc.logprob.basic import conditional_logp, logp
-from pymc.pytensorf import constant_fold
+from pymc.model.fgraph import ModelVar
+from pymc.pytensorf import constant_fold, StringType
 from pytensor import Variable
 from pytensor.compile.builders import OpFromGraph
 from pytensor.compile.mode import Mode
-from pytensor.graph import Op, vectorize_graph
+from pytensor.graph import FunctionGraph, Op, vectorize_graph
+from pytensor.graph.basic import equal_computations, Apply
 from pytensor.graph.replace import clone_replace, graph_replace
 from pytensor.scan import map as scan_map
 from pytensor.scan import scan
 from pytensor.tensor import TensorVariable
+from pytensor.tensor.random.type import RandomType
 
 from pymc_extras.distributions import DiscreteMarkovChain
 
 
 class MarginalRV(OpFromGraph, MeasurableOp):
     """Base class for Marginalized RVs"""
 
-    def __init__(self, *args, dims_connections: tuple[tuple[int | None]], **kwargs) -> None:
+    def __init__(self, *args, dims_connections: tuple[tuple[int | None], ...], dims: tuple[Variable, ...], **kwargs) -> None:
         self.dims_connections = dims_connections
+        self.dims = dims
         super().__init__(*args, **kwargs)
 
     @property
@@ -43,6 +50,74 @@ def support_axes(self) -> tuple[tuple[int]]:
         )
         return tuple(support_axes_vars)
 
+    def __eq__(self, other):
+        # Just to allow easy testing of equivalent models;
+        # this can be removed once https://github.com/pymc-devs/pytensor/issues/1114 is fixed
+        if type(self) is not type(other):
+            return False
+
+        return equal_computations(
+            self.inner_outputs,
+            other.inner_outputs,
+            self.inner_inputs,
+            other.inner_inputs,
+        )
+
+    def __hash__(self):
+        # Just to allow easy testing of equivalent models;
+        # this can be removed once https://github.com/pymc-devs/pytensor/issues/1114 is fixed
+        return hash((type(self), len(self.inner_inputs), len(self.inner_outputs)))
+
+
+@_support_point.register
+def support_point_marginal_rv(op: MarginalRV, rv, *inputs):
+    """Support point for a marginalized RV.
+
+    The support point of a marginalized RV is the support point of the inner RV,
+    conditioned on the marginalized RV taking its support point.
+    """
+    outputs = rv.owner.outputs
+
+    inner_rv = op.inner_outputs[outputs.index(rv)]
+    marginalized_inner_rv, *other_dependent_inner_rvs = (
+        out
+        for out in op.inner_outputs
+        if out is not inner_rv and not isinstance(out.type, RandomType)
+    )
+
+    # Replace references to inner rvs by the dummy variables (including the marginalized RV)
+    # This is necessary because the inner RVs may depend on each other
+    marginalized_inner_rv_dummy = marginalized_inner_rv.clone()
+    other_dependent_inner_rv_to_dummies = {
+        inner_rv: inner_rv.clone() for inner_rv in other_dependent_inner_rvs
+    }
+    inner_rv = clone_replace(
+        inner_rv,
+        replace={marginalized_inner_rv: marginalized_inner_rv_dummy}
+        | other_dependent_inner_rv_to_dummies,
+    )
+
+    # Get support point of inner RV and marginalized RV
+    inner_rv_support_point = support_point(inner_rv)
+    marginalized_inner_rv_support_point = support_point(marginalized_inner_rv)
+
+    replacements = [
+        # Replace the marginalized RV dummy by its support point
+        (marginalized_inner_rv_dummy, marginalized_inner_rv_support_point),
+        # Replace other dependent RVs dummies by the respective outer outputs.
+        # PyMC will replace them by their support points later
+        *(
+            (v, outputs[op.inner_outputs.index(k)])
+            for k, v in other_dependent_inner_rv_to_dummies.items()
+        ),
+        # Replace outer input RVs
+        *zip(op.inner_inputs, inputs),
+    ]
+    fgraph = FunctionGraph(outputs=[inner_rv_support_point], clone=False)
+    fgraph.replace_all(replacements, import_missing=True)
+    [rv_support_point] = fgraph.outputs
+    return rv_support_point
+
 
 class MarginalFiniteDiscreteRV(MarginalRV):
     """Base class for Marginalized Finite Discrete RVs"""
@@ -132,12 +207,27 @@ def inline_ofg_outputs(op: OpFromGraph, inputs: Sequence[Variable]) -> tuple[Var
     Whereas `OpFromGraph` "wraps" a graph inside a single Op, this function "unwraps"
     the inner graph.
     """
-    return clone_replace(
+    return graph_replace(
         op.inner_outputs,
         replace=tuple(zip(op.inner_inputs, inputs)),
+        strict=False,
     )
 
 
+class NonSeparableLogpWarning(UserWarning):
+    pass
+
+
+def warn_non_separable_logp(values):
+    if len(values) > 1:
+        warnings.warn(
+            "There are multiple dependent variables in a FiniteDiscreteMarginalRV. "
+            f"Their joint logp terms will be assigned to the first value: {values[0]}.",
+            NonSeparableLogpWarning,
+            stacklevel=2,
+        )
+
+
 DUMMY_ZERO = pt.constant(0, name="dummy_zero")
 
 
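
Illustration, not part of the changeset: a sketch of when warn_non_separable_logp fires, under the same assumed marginalize API as above. Two dependent variables share one marginalized discrete RV, so their joint logp term cannot be split per variable.

    import pymc as pm
    from pymc_extras import marginalize

    with pm.Model() as m:
        idx = pm.Bernoulli("idx", p=0.5)
        mu = pm.math.switch(idx, 1.0, -1.0)
        y1 = pm.Normal("y1", mu=mu, sigma=1.0)
        y2 = pm.Normal("y2", mu=mu, sigma=1.0)

    marginal_m = marginalize(m, [idx])
    # Building the logp graph emits NonSeparableLogpWarning: the joint
    # (y1, y2) term is assigned to y1, and y2 gets the dummy zero logp.
    logp_fn = marginal_m.compile_logp()
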
@@ -199,6 +289,7 @@ def logp_fn(marginalized_rv_const, *non_sequences):
     # Align logp with non-collapsed batch dimensions of first RV
     joint_logp = align_logp_dims(dims=op.dims_connections[0], logp=joint_logp)
 
+    warn_non_separable_logp(values)
     # We have to add dummy logps for the remaining value variables, otherwise PyMC will raise
     dummy_logps = (DUMMY_ZERO,) * (len(values) - 1)
     return joint_logp, *dummy_logps
@@ -272,5 +363,6 @@ def step_alpha(logp_emission, log_alpha, log_P):
 
     # If there are multiple emission streams, we have to add dummy logps for the remaining value variables. The first
     # return is the joint probability of everything together, but PyMC still expects one logp for each emission stream.
+    warn_non_separable_logp(values)
     dummy_logps = (DUMMY_ZERO,) * (len(values) - 1)
     return joint_logp, *dummy_logps
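
Illustration, not part of the changeset: the same non-separable situation for a marginalized DiscreteMarkovChain with two emission streams, again assuming the marginalize API used above; the forward-algorithm logp of the whole model is attached to the first emission stream and the second gets a dummy zero.

    import numpy as np
    import pymc as pm
    from pymc_extras import marginalize
    from pymc_extras.distributions import DiscreteMarkovChain

    with pm.Model() as m:
        P = pm.Dirichlet("P", a=np.ones((2, 2)))
        init_dist = pm.Categorical.dist(p=np.ones(2) / 2)
        chain = DiscreteMarkovChain("chain", P=P, init_dist=init_dist, steps=9)
        emission_1 = pm.Normal("emission_1", mu=chain * 2 - 1, sigma=0.5)
        emission_2 = pm.Normal("emission_2", mu=(1 - chain) * 2 - 1, sigma=0.5)

    # Marginalizing the chain runs the forward algorithm once over both
    # emission streams; their joint logp goes to emission_1 and emission_2
    # receives a dummy zero, hence the warning above.
    marginal_m = marginalize(m, [chain])
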