 import logging
+from typing import cast

 from pytensor.graph.rewriting.basic import node_rewriter
-from pytensor.tensor import basic as at
+from pytensor.tensor.basic import TensorVariable, extract_diag, swapaxes
 from pytensor.tensor.blas import Dot22
+from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.elemwise import DimShuffle
 from pytensor.tensor.math import Dot, Prod, log, prod
-from pytensor.tensor.nlinalg import Det, MatrixInverse
+from pytensor.tensor.nlinalg import MatrixInverse, det
 from pytensor.tensor.rewriting.basic import (
     register_canonicalize,
     register_specialize,
 logger = logging.getLogger(__name__)


+def is_matrix_transpose(x: TensorVariable) -> bool:
+    """Check if a variable corresponds to a transpose of the last two axes"""
+    node = x.owner
+    if (
+        node
+        and isinstance(node.op, DimShuffle)
+        and not (node.op.drop or node.op.augment)
+    ):
+        [inp] = node.inputs
+        ndims = inp.type.ndim
+        if ndims < 2:
+            return False
+        transpose_order = tuple(range(ndims - 2)) + (ndims - 1, ndims - 2)
+        return cast(bool, node.op.new_order == transpose_order)
+    return False
+
+
+def T(x: TensorVariable) -> TensorVariable:
+    """Matrix transpose for potentially higher dimensionality tensors"""
+    return swapaxes(x, -1, -2)
+
+
 @register_canonicalize
 @node_rewriter([DimShuffle])
 def transinv_to_invtrans(fgraph, node):
-    if isinstance(node.op, DimShuffle):
-        if node.op.new_order == (1, 0):
-            (A,) = node.inputs
-            if A.owner:
-                if isinstance(A.owner.op, MatrixInverse):
-                    (X,) = A.owner.inputs
-                    return [A.owner.op(node.op(X))]
+    if is_matrix_transpose(node.outputs[0]):
+        (A,) = node.inputs
+        if (
+            A.owner
+            and isinstance(A.owner.op, Blockwise)
+            and isinstance(A.owner.op.core_op, MatrixInverse)
+        ):
+            (X,) = A.owner.inputs
+            return [A.owner.op(node.op(X))]


 @register_stabilize
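
A rough usage sketch of the new helpers (illustrative, not part of the diff; it assumes only the imports shown above plus the standard pytensor.tensor constructors):

import pytensor.tensor as pt
from pytensor.tensor.basic import swapaxes

x = pt.tensor3("x")        # illustrative batched stack of matrices, shape (batch, m, n)
xt = swapaxes(x, -1, -2)   # a DimShuffle with new_order == (0, 2, 1)

# is_matrix_transpose(xt) -> True: only the last two axes are permuted and
#                            nothing is dropped or broadcast in.
# is_matrix_transpose(x)  -> False: x is not the output of such a DimShuffle.
# T(x) builds the same last-two-axes transpose that the rewrites above emit.
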
@@ -37,87 +63,98 @@ def inv_as_solve(fgraph, node):
3763 """
3864 if isinstance (node .op , (Dot , Dot22 )):
3965 l , r = node .inputs
40- if l .owner and isinstance (l .owner .op , MatrixInverse ):
66+ if (
67+ l .owner
68+ and isinstance (l .owner .op , Blockwise )
69+ and isinstance (l .owner .op .core_op , MatrixInverse )
70+ ):
4171 return [solve (l .owner .inputs [0 ], r )]
42- if r .owner and isinstance (r .owner .op , MatrixInverse ):
72+ if (
73+ r .owner
74+ and isinstance (r .owner .op , Blockwise )
75+ and isinstance (r .owner .op .core_op , MatrixInverse )
76+ ):
4377 x = r .owner .inputs [0 ]
4478 if getattr (x .tag , "symmetric" , None ) is True :
45- return [solve (x , l . T ). T ]
79+ return [T ( solve (x , T ( l ))) ]
4680 else :
47- return [solve (x . T , l . T ). T ]
81+ return [T ( solve (T ( x ), T ( l ))) ]
4882
4983
5084@register_stabilize
5185@register_canonicalize
52- @node_rewriter ([Solve ])
86+ @node_rewriter ([Blockwise ])
5387def tag_solve_triangular (fgraph , node ):
5488 """
5589 If a general solve() is applied to the output of a cholesky op, then
5690 replace it with a triangular solve.
5791
5892 """
59- if isinstance (node .op , Solve ):
60- if node .op .assume_a == "gen" :
93+ if isinstance (node .op . core_op , Solve ) and node . op . core_op . b_ndim == 1 :
94+ if node .op .core_op . assume_a == "gen" :
6195 A , b = node .inputs # result is solution Ax=b
62- if A .owner and isinstance (A .owner .op , Cholesky ):
63- if A .owner .op .lower :
64- return [Solve (assume_a = "sym" , lower = True )(A , b )]
65- else :
66- return [Solve (assume_a = "sym" , lower = False )(A , b )]
6796 if (
6897 A .owner
69- and isinstance (A .owner .op , DimShuffle )
70- and A .owner .op .new_order == ( 1 , 0 )
98+ and isinstance (A .owner .op , Blockwise )
99+ and isinstance ( A .owner .op .core_op , Cholesky )
71100 ):
101+ if A .owner .op .core_op .lower :
102+ return [solve (A , b , assume_a = "sym" , lower = True )]
103+ else :
104+ return [solve (A , b , assume_a = "sym" , lower = False )]
105+ if is_matrix_transpose (A ):
72106 (A_T ,) = A .owner .inputs
73- if A_T .owner and isinstance (A_T .owner .op , Cholesky ):
107+ if (
108+ A_T .owner
109+ and isinstance (A_T .owner .op , Blockwise )
110+ and isinstance (A_T .owner .op , Cholesky )
111+ ):
74112 if A_T .owner .op .lower :
75- return [Solve ( assume_a = "sym" , lower = False )( A , b )]
113+ return [solve ( A , b , assume_a = "sym" , lower = False )]
76114 else :
77- return [Solve ( assume_a = "sym" , lower = True )( A , b )]
115+ return [solve ( A , b , assume_a = "sym" , lower = True )]
78116
79117
80118@register_canonicalize
81119@register_stabilize
82120@register_specialize
83121@node_rewriter ([DimShuffle ])
84122def no_transpose_symmetric (fgraph , node ):
85- if isinstance (node .op , DimShuffle ):
123+ if is_matrix_transpose (node .outputs [ 0 ] ):
86124 x = node .inputs [0 ]
87- if x .type .ndim == 2 and getattr (x .tag , "symmetric" , None ) is True :
88- if node .op .new_order == [1 , 0 ]:
89- return [x ]
125+ if getattr (x .tag , "symmetric" , None ):
126+ return [x ]
90127
91128
92129@register_stabilize
93- @node_rewriter ([Solve ])
130+ @node_rewriter ([Blockwise ])
94131def psd_solve_with_chol (fgraph , node ):
95132 """
96133 This utilizes a boolean `psd` tag on matrices.
97134 """
98- if isinstance (node .op , Solve ):
135+ if isinstance (node .op . core_op , Solve ) and node . op . core_op . b_ndim == 2 :
99136 A , b = node .inputs # result is solution Ax=b
100137 if getattr (A .tag , "psd" , None ) is True :
101138 L = cholesky (A )
102139 # N.B. this can be further reduced to a yet-unwritten cho_solve Op
103- # __if__ no other Op makes use of the the L matrix during the
140+ # __if__ no other Op makes use of the L matrix during the
104141 # stabilization
105- Li_b = Solve ( assume_a = "sym" , lower = True )( L , b )
106- x = Solve ( assume_a = "sym" , lower = False )( L . T , Li_b )
142+ Li_b = solve ( L , b , assume_a = "sym" , lower = True , b_ndim = 2 )
143+ x = solve ( T ( L ), Li_b , assume_a = "sym" , lower = False , b_ndim = 2 )
107144 return [x ]
108145
109146
110147@register_canonicalize
111148@register_stabilize
112- @node_rewriter ([Cholesky ])
149+ @node_rewriter ([Blockwise ])
113150def cholesky_ldotlt (fgraph , node ):
114151 """
115152 rewrite cholesky(dot(L, L.T), lower=True) = L, where L is lower triangular,
116153 or cholesky(dot(U.T, U), upper=True) = U where U is upper triangular.
117154
118155 This utilizes a boolean `lower_triangular` or `upper_triangular` tag on matrices.
119156 """
120- if not isinstance (node .op , Cholesky ):
157+ if not isinstance (node .op . core_op , Cholesky ):
121158 return
122159
123160 A = node .inputs [0 ]
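
For reference, the identity psd_solve_with_chol exploits can be checked numerically: for a symmetric positive-definite A with lower Cholesky factor L (A = L L^T), solving L y = b and then L^T x = y yields the same x as solving A x = b directly. A minimal NumPy/SciPy sketch with made-up values (illustrative only, not PyTensor code):

import numpy as np
from scipy.linalg import solve_triangular

rng = np.random.default_rng(0)                 # illustrative data
M = rng.normal(size=(4, 4))
A = M @ M.T + 4 * np.eye(4)                    # symmetric positive-definite
b = rng.normal(size=(4, 2))

L = np.linalg.cholesky(A)                      # lower triangular, A == L @ L.T
y = solve_triangular(L, b, lower=True)         # solve L y = b
x = solve_triangular(L.T, y, lower=False)      # solve L.T x = y
np.testing.assert_allclose(x, np.linalg.solve(A, b))
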
@@ -129,43 +166,38 @@ def cholesky_ldotlt(fgraph, node):
     # cholesky(dot(L,L.T)) case
     if (
         getattr(l.tag, "lower_triangular", False)
-        and r.owner
-        and isinstance(r.owner.op, DimShuffle)
-        and r.owner.op.new_order == (1, 0)
+        and is_matrix_transpose(r)
         and r.owner.inputs[0] == l
     ):
-        if node.op.lower:
+        if node.op.core_op.lower:
             return [l]
         return [r]

     # cholesky(dot(U.T,U)) case
     if (
         getattr(r.tag, "upper_triangular", False)
-        and l.owner
-        and isinstance(l.owner.op, DimShuffle)
-        and l.owner.op.new_order == (1, 0)
+        and is_matrix_transpose(l)
         and l.owner.inputs[0] == r
     ):
-        if node.op.lower:
+        if node.op.core_op.lower:
             return [l]
         return [r]


 @register_stabilize
 @register_specialize
-@node_rewriter([Det])
+@node_rewriter([det])
 def local_det_chol(fgraph, node):
     """
     If we have det(X) and there is already an L=cholesky(X)
     floating around, then we can use prod(diag(L)) to get the determinant.

     """
-    if isinstance(node.op, Det):
-        (x,) = node.inputs
-        for cl, xpos in fgraph.clients[x]:
-            if isinstance(cl.op, Cholesky):
-                L = cl.outputs[0]
-                return [prod(at.extract_diag(L) ** 2)]
+    (x,) = node.inputs
+    for cl, xpos in fgraph.clients[x]:
+        if isinstance(cl.op, Blockwise) and isinstance(cl.op.core_op, Cholesky):
+            L = cl.outputs[0]
+            return [prod(extract_diag(L) ** 2, axis=(-1, -2))]


 @register_canonicalize
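
The determinant identity behind local_det_chol: if X = L L^T with L = cholesky(X), then det(X) = det(L) * det(L^T) = prod(diag(L)) ** 2. A quick NumPy check with made-up values (illustrative only):

import numpy as np

rng = np.random.default_rng(1)                 # illustrative data
M = rng.normal(size=(5, 5))
X = M @ M.T + 5 * np.eye(5)                    # symmetric positive-definite
L = np.linalg.cholesky(X)                      # X == L @ L.T
np.testing.assert_allclose(np.prod(np.diag(L)) ** 2, np.linalg.det(X))
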
@@ -176,16 +208,15 @@ def local_log_prod_sqr(fgraph, node):
176208 """
177209 This utilizes a boolean `positive` tag on matrices.
178210 """
179- if node .op == log :
180- (x ,) = node .inputs
181- if x .owner and isinstance (x .owner .op , Prod ):
182- # we cannot always make this substitution because
183- # the prod might include negative terms
184- p = x .owner .inputs [0 ]
185-
186- # p is the matrix we're reducing with prod
187- if getattr (p .tag , "positive" , None ) is True :
188- return [log (p ).sum (axis = x .owner .op .axis )]
189-
190- # TODO: have a reduction like prod and sum that simply
191- # returns the sign of the prod multiplication.
211+ (x ,) = node .inputs
212+ if x .owner and isinstance (x .owner .op , Prod ):
213+ # we cannot always make this substitution because
214+ # the prod might include negative terms
215+ p = x .owner .inputs [0 ]
216+
217+ # p is the matrix we're reducing with prod
218+ if getattr (p .tag , "positive" , None ) is True :
219+ return [log (p ).sum (axis = x .owner .op .axis )]
220+
221+ # TODO: have a reduction like prod and sum that simply
222+ # returns the sign of the prod multiplication.
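
The rewrite above relies on log(prod(p)) == sum(log(p)) for strictly positive p, which also avoids overflow and underflow in the intermediate product. A quick NumPy check with made-up values (illustrative only):

import numpy as np

p = np.array([1e-3, 2.5, 7.0, 1e2])            # strictly positive entries
np.testing.assert_allclose(np.log(np.prod(p)), np.sum(np.log(p)))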