@@ -115,40 +115,37 @@ def simplex_cont_transform(op, rv):
115115
116116
117117def quaddist_matrix (cov = None , chol = None , tau = None , lower = True , * args , ** kwargs ):
118- if chol is not None and not lower :
119- chol = chol .T
120-
121118 if len ([i for i in [tau , cov , chol ] if i is not None ]) != 1 :
122119 raise ValueError ("Incompatible parameterization. Specify exactly one of tau, cov, or chol." )
123120
124121 if cov is not None :
125122 cov = pt .as_tensor_variable (cov )
126- if cov .ndim != 2 :
127- raise ValueError ("cov must be two dimensional." )
123+ if cov .ndim < 2 :
124+ raise ValueError ("cov must be at least two dimensional." )
128125 elif tau is not None :
129126 tau = pt .as_tensor_variable (tau )
130- if tau .ndim != 2 :
131- raise ValueError ("tau must be two dimensional." )
132- # TODO: What's the correct order/approach (in the non-square case)?
133- # `pytensor.tensor.nlinalg.tensorinv`?
127+ if tau .ndim < 2 :
128+ raise ValueError ("tau must be at least two dimensional." )
134129 cov = matrix_inverse (tau )
135130 else :
136- # TODO: What's the correct order/approach (in the non-square case)?
137131 chol = pt .as_tensor_variable (chol )
138- if chol .ndim != 2 :
139- raise ValueError ("chol must be two dimensional." )
132+ if chol .ndim < 2 :
133+ raise ValueError ("chol must be at least two dimensional." )
134+
135+ if not lower :
136+ chol = pt .swapaxes (chol , - 1 , - 2 )
140137
141138 # tag as lower triangular to enable pytensor rewrites of chol(l.l') -> l
142139 chol .tag .lower_triangular = True
143- cov = chol . dot (chol . T )
140+ cov = pt . matmul (chol , pt . swapaxes ( chol , - 1 , - 2 ) )
144141
145142 return cov
146143
147144
148- def quaddist_parse (value , mu , cov , mat_type = "cov" ):
145+ def quaddist_chol (value , mu , cov ):
149146 """Compute (x - mu).T @ Sigma^-1 @ (x - mu) and the logdet of Sigma."""
150- if value .ndim > 2 or value . ndim == 0 :
151- raise ValueError ("Invalid dimension for value: %s" % value . ndim )
147+ if value .ndim == 0 :
148+ raise ValueError ("Value can't be a scalar" )
152149 if value .ndim == 1 :
153150 onedim = True
154151 value = value [None , :]
@@ -157,42 +154,21 @@ def quaddist_parse(value, mu, cov, mat_type="cov"):
157154
158155 delta = value - mu
159156 chol_cov = nan_lower_cholesky (cov )
160- if mat_type != "tau" :
161- dist , logdet , ok = quaddist_chol (delta , chol_cov )
162- else :
163- dist , logdet , ok = quaddist_tau (delta , chol_cov )
164- if onedim :
165- return dist [0 ], logdet , ok
166-
167- return dist , logdet , ok
168-
169157
170- def quaddist_chol (delta , chol_mat ):
171- diag = pt .diag (chol_mat )
158+ diag = pt .diagonal (chol_cov , axis1 = - 2 , axis2 = - 1 )
172159 # Check if the covariance matrix is positive definite.
173- ok = pt .all (diag > 0 )
160+ ok = pt .all (diag > 0 , axis = - 1 )
174161 # If not, replace the diagonal. We return -inf later, but
175162 # need to prevent solve_lower from throwing an exception.
176- chol_cov = pt .switch (ok , chol_mat , 1 )
177-
178- delta_trans = solve_lower (chol_cov , delta .T ).T
163+ chol_cov = pt .switch (ok [..., None , None ], chol_cov , 1 )
164+ delta_trans = solve_lower (chol_cov , delta , b_ndim = 1 )
179165 quaddist = (delta_trans ** 2 ).sum (axis = - 1 )
180- logdet = pt .sum (pt .log (diag ))
181- return quaddist , logdet , ok
182-
183-
184- def quaddist_tau (delta , chol_mat ):
185- diag = pt .nlinalg .diag (chol_mat )
186- # Check if the precision matrix is positive definite.
187- ok = pt .all (diag > 0 )
188- # If not, replace the diagonal. We return -inf later, but
189- # need to prevent solve_lower from throwing an exception.
190- chol_tau = pt .switch (ok , chol_mat , 1 )
166+ logdet = pt .log (diag ).sum (axis = - 1 )
191167
192- delta_trans = pt . dot ( delta , chol_tau )
193- quaddist = ( delta_trans ** 2 ). sum ( axis = - 1 )
194- logdet = - pt . sum ( pt . log ( diag ))
195- return quaddist , logdet , ok
168+ if onedim :
169+ return quaddist [ 0 ], logdet , ok
170+ else :
171+ return quaddist , logdet , ok
196172
197173
198174class MvNormal (Continuous ):
@@ -266,10 +242,11 @@ def dist(cls, mu, cov=None, tau=None, chol=None, lower=True, **kwargs):
266242 mu = pt .as_tensor_variable (mu )
267243 cov = quaddist_matrix (cov , chol , tau , lower )
268244 # PyTensor is stricter about the shape of mu, than PyMC used to be
269- mu = pt .broadcast_arrays (mu , cov [..., - 1 ])[ 0 ]
245+ mu , _ = pt .broadcast_arrays (mu , cov [..., - 1 ])
270246 return super ().dist ([mu , cov ], ** kwargs )
271247
272248 def moment (rv , size , mu , cov ):
249+ # mu is broadcasted to the potential length of cov in `dist`
273250 moment = mu
274251 if not rv_size_is_none (size ):
275252 moment_size = pt .concatenate ([size , [mu .shape [- 1 ]]])
@@ -290,7 +267,7 @@ def logp(value, mu, cov):
290267 -------
291268 TensorVariable
292269 """
293- quaddist , logdet , ok = quaddist_parse (value , mu , cov )
270+ quaddist , logdet , ok = quaddist_chol (value , mu , cov )
294271 k = floatX (value .shape [- 1 ])
295272 norm = - 0.5 * k * pm .floatX (np .log (2 * np .pi ))
296273 return check_parameters (
@@ -307,22 +284,6 @@ class MvStudentTRV(RandomVariable):
307284 dtype = "floatX"
308285 _print_name = ("MvStudentT" , "\\ operatorname{MvStudentT}" )
309286
310- def make_node (self , rng , size , dtype , nu , mu , cov ):
311- nu = pt .as_tensor_variable (nu )
312- if not nu .ndim == 0 :
313- raise ValueError ("nu must be a scalar (ndim=0)." )
314-
315- return super ().make_node (rng , size , dtype , nu , mu , cov )
316-
317- def __call__ (self , nu , mu = None , cov = None , size = None , ** kwargs ):
318- dtype = pytensor .config .floatX if self .dtype == "floatX" else self .dtype
319-
320- if mu is None :
321- mu = np .array ([0.0 ], dtype = dtype )
322- if cov is None :
323- cov = np .array ([[1.0 ]], dtype = dtype )
324- return super ().__call__ (nu , mu , cov , size = size , ** kwargs )
325-
326287 def _supp_shape_from_params (self , dist_params , param_shapes = None ):
327288 return supp_shape_from_ref_param_shape (
328289 ndim_supp = self .ndim_supp ,
@@ -333,14 +294,21 @@ def _supp_shape_from_params(self, dist_params, param_shapes=None):
333294
334295 @classmethod
335296 def rng_fn (cls , rng , nu , mu , cov , size ):
297+ if size is None :
298+ # When size is implicit, we need to broadcast parameters correctly,
299+ # so that the MvNormal draws and the chisquare draws have the same number of batch dimensions.
300+ # nu broadcasts mu and cov
301+ if np .ndim (nu ) > max (mu .ndim - 1 , cov .ndim - 2 ):
302+ _ , mu , cov = broadcast_params ((nu , mu , cov ), ndims_params = cls .ndims_params )
303+ # nu is broadcasted by either mu or cov
304+ elif np .ndim (nu ) < max (mu .ndim - 1 , cov .ndim - 2 ):
305+ nu , _ , _ = broadcast_params ((nu , mu , cov ), ndims_params = cls .ndims_params )
306+
336307 mv_samples = multivariate_normal .rng_fn (rng = rng , mean = np .zeros_like (mu ), cov = cov , size = size )
337308
338309 # Take chi2 draws and add an axis of length 1 to the right for correct broadcasting below
339310 chi2_samples = np .sqrt (rng .chisquare (nu , size = size ) / nu )[..., None ]
340311
341- if size :
342- mu = np .broadcast_to (mu , size + (mu .shape [- 1 ],))
343-
344312 return (mv_samples / chi2_samples ) + mu
345313
346314
@@ -390,7 +358,7 @@ class MvStudentT(Continuous):
390358 rv_op = mv_studentt
391359
392360 @classmethod
393- def dist (cls , nu , Sigma = None , mu = None , scale = None , tau = None , chol = None , lower = True , ** kwargs ):
361+ def dist (cls , nu , * , Sigma = None , mu , scale = None , tau = None , chol = None , lower = True , ** kwargs ):
394362 cov = kwargs .pop ("cov" , None )
395363 if cov is not None :
396364 warnings .warn (
@@ -407,11 +375,13 @@ def dist(cls, nu, Sigma=None, mu=None, scale=None, tau=None, chol=None, lower=Tr
407375 mu = pt .as_tensor_variable (floatX (mu ))
408376 scale = quaddist_matrix (scale , chol , tau , lower )
409377 # PyTensor is stricter about the shape of mu, than PyMC used to be
410- mu = pt .broadcast_arrays (mu , scale [..., - 1 ])[ 0 ]
378+ mu , _ = pt .broadcast_arrays (mu , scale [..., - 1 ])
411379
412380 return super ().dist ([nu , mu , scale ], ** kwargs )
413381
414382 def moment (rv , size , nu , mu , scale ):
383+ # mu is broadcasted to the potential length of scale in `dist`
384+ mu , _ = pt .random .utils .broadcast_params ([mu , nu ], ndims_params = [1 , 0 ])
415385 moment = mu
416386 if not rv_size_is_none (size ):
417387 moment_size = pt .concatenate ([size , [mu .shape [- 1 ]]])
@@ -432,7 +402,7 @@ def logp(value, nu, mu, scale):
432402 -------
433403 TensorVariable
434404 """
435- quaddist , logdet , ok = quaddist_parse (value , mu , scale )
405+ quaddist , logdet , ok = quaddist_chol (value , mu , scale )
436406 k = floatX (value .shape [- 1 ])
437407
438408 norm = gammaln ((nu + k ) / 2.0 ) - gammaln (nu / 2.0 ) - 0.5 * k * pt .log (nu * np .pi )
0 commit comments