@@ -91,6 +91,7 @@ def estimate_cost(exprs, estimate=False):
91
91
# We don't use SymPy's count_ops because we do not count integer arithmetic
92
92
# (e.g., array index functions such as i+1 in A[i+1])
93
93
# Also, the routine below is *much* faster than count_ops
94
+ seen = {}
94
95
flops = 0
95
96
for expr in as_tuple (exprs ):
96
97
# TODO: this if-then should be part of singledispatch too, but because
@@ -103,7 +104,7 @@ def estimate_cost(exprs, estimate=False):
103
104
else :
104
105
e = expr
105
106
106
- flops += _estimate_cost (e , estimate )[0 ]
107
+ flops += _estimate_cost (e , estimate , seen )[0 ]
107
108
108
109
return flops
109
110
except :
@@ -121,11 +122,27 @@ def estimate_cost(exprs, estimate=False):
121
122
}
122
123
123
124
125
def dont_count_if_seen(func):
    """
    Decorator ensuring that the flops of an expression are counted only once.

    The same expression may appear multiple times within one expression tree,
    or even across different expressions. The first time ``expr`` is met, the
    decorated handler runs and its ``(flops, flag)`` result is stored in
    ``seen[expr]``. On any later occurrence the handler is skipped and
    ``(0, flag)`` is returned, where ``flag`` is the one recorded the first
    time.

    Parameters
    ----------
    func : callable
        A cost handler with signature ``func(expr, estimate, seen)`` returning
        a ``(flops, flag)`` pair.

    Returns
    -------
    callable
        A wrapper with the same ``(expr, estimate, seen)`` signature. ``expr``
        must be hashable, since it is used as a key of the ``seen`` mapping.
    """
    # Local import: only `singledispatch` is known to be imported at file level
    from functools import wraps

    @wraps(func)
    def wrapper(expr, estimate, seen):
        try:
            _, flags = seen[expr]
        except KeyError:
            # First occurrence: compute the cost and remember it. The entry is
            # recorded only once the result has been successfully unpacked
            result = func(expr, estimate, seen)
            flops, flags = result
            seen[expr] = result
        else:
            # Already accounted for elsewhere: contributes no extra flops
            flops = 0
        return flops, flags

    return wrapper
139
+
140
+
124
141
@singledispatch
125
- def _estimate_cost (expr , estimate ):
142
+ def _estimate_cost (expr , estimate , seen ):
126
143
# Retval: flops (int), flag (bool)
127
144
# The flag tells whether it's an integer expression (implying flops==0) or not
128
- flops , flags = zip (* [_estimate_cost (a , estimate ) for a in expr .args ])
145
+ flops , flags = zip (* [_estimate_cost (a , estimate , seen ) for a in expr .args ])
129
146
flops = sum (flops )
130
147
if all (flags ):
131
148
# `expr` is an operation involving integer operands only
@@ -138,28 +155,28 @@ def _estimate_cost(expr, estimate):
138
155
139
156
@_estimate_cost.register(Tuple)
@_estimate_cost.register(CallFromPointer)
def _(expr, estimate, seen):
    """
    Cost of a container-like expression: the sum of the costs of its arguments.

    Returns ``(flops, flag)``, where ``flag`` is True only if every argument
    was flagged as an integer expression.
    """
    try:
        costs = [_estimate_cost(arg, estimate, seen) for arg in expr.args]
        flops, flags = zip(*costs)
    except ValueError:
        # E.g. `expr.args` is empty, so there is nothing to unpack: this
        # yields zero flops and a True flag below
        flops, flags = (), ()
    return sum(flops), all(flags)
147
164
148
165
149
166
@_estimate_cost.register(Integer)
def _(expr, estimate, seen):
    """An Integer costs no flops and is flagged as an integer expression."""
    flops, is_integer = 0, True
    return flops, is_integer
152
169
153
170
154
171
@_estimate_cost.register(Number)
@_estimate_cost.register(ReservedWord)
def _(expr, estimate, seen):
    """
    A Number or ReservedWord costs no flops and is flagged as a non-integer
    expression.
    """
    flops, is_integer = 0, False
    return flops, is_integer
158
175
159
176
160
177
@_estimate_cost .register (Symbol )
161
178
@_estimate_cost .register (Indexed )
162
- def _ (expr , estimate ):
179
+ def _ (expr , estimate , seen ):
163
180
try :
164
181
if issubclass (expr .dtype , np .integer ):
165
182
return 0 , True
@@ -169,27 +186,27 @@ def _(expr, estimate):
169
186
170
187
171
188
@_estimate_cost.register(Mul)
def _(expr, estimate, seen):
    """
    Cost of a multiplication.

    The count is delegated to the generic (``object``) handler, then reduced
    by one if either ``S.One`` or ``S.NegativeOne`` is among the factors.
    Returns ``(flops, flag)``, with the flag taken from the generic handler.
    """
    generic_handler = _estimate_cost.registry[object]
    flops, flags = generic_handler(expr, estimate, seen)

    unit_factors = {S.One, S.NegativeOne}
    if not unit_factors.isdisjoint(expr.args):
        # Presumably a product by +/-1 is not regarded as an actual operation
        flops -= 1

    return flops, flags
177
194
178
195
179
196
@_estimate_cost.register(INT)
def _(expr, estimate, seen):
    """
    Cost of an INT: the flops of its ``base``, always flagged as an integer
    expression.
    """
    base_cost = _estimate_cost(expr.base, estimate, seen)
    return base_cost[0], True
182
199
183
200
184
201
@_estimate_cost.register(Cast)
def _(expr, estimate, seen):
    """
    Cost of a Cast: the flops of its ``base``, always flagged as a non-integer
    expression.
    """
    base_cost = _estimate_cost(expr.base, estimate, seen)
    return base_cost[0], False
187
204
188
205
189
206
@_estimate_cost .register (Function )
190
- def _ (expr , estimate ):
207
+ def _ (expr , estimate , seen ):
191
208
if q_routine (expr ):
192
- flops , _ = zip (* [_estimate_cost (a , estimate ) for a in expr .args ])
209
+ flops , _ = zip (* [_estimate_cost (a , estimate , seen ) for a in expr .args ])
193
210
flops = sum (flops )
194
211
if isinstance (expr , DefFunction ):
195
212
# Bypass user-defined or language-specific functions
@@ -207,8 +224,8 @@ def _(expr, estimate):
207
224
208
225
209
226
@_estimate_cost .register (Pow )
210
- def _ (expr , estimate ):
211
- flops , _ = zip (* [_estimate_cost (a , estimate ) for a in expr .args ])
227
+ def _ (expr , estimate , seen ):
228
+ flops , _ = zip (* [_estimate_cost (a , estimate , seen ) for a in expr .args ])
212
229
flops = sum (flops )
213
230
if estimate :
214
231
if expr .exp .is_Number :
@@ -229,13 +246,15 @@ def _(expr, estimate):
229
246
230
247
231
248
@_estimate_cost.register(Derivative)
@dont_count_if_seen
def _(expr, estimate, seen):
    """
    Cost of a Derivative: the cost of the expression obtained by evaluating
    it without expansion.

    Being wrapped by ``dont_count_if_seen``, a Derivative already recorded in
    ``seen`` contributes zero flops.
    """
    evaluated = expr._evaluate(expand=False)
    return _estimate_cost(evaluated, estimate, seen)
234
252
235
253
236
254
@_estimate_cost .register (IndexDerivative )
237
- def _ (expr , estimate ):
238
- flops , _ = _estimate_cost (expr .expr , estimate )
255
+ @dont_count_if_seen
256
+ def _ (expr , estimate , seen ):
257
+ flops , _ = _estimate_cost (expr .expr , estimate , seen )
239
258
240
259
# It's an increment
241
260
flops += 1
0 commit comments