Skip to content

Commit 357f760

Browse files
committed
compiler: Improve estimate_cost
1 parent 8f45ba0 commit 357f760

File tree

2 files changed

+43
-22
lines changed

2 files changed

+43
-22
lines changed

devito/symbolics/inspection.py

+41-22
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ def estimate_cost(exprs, estimate=False):
9191
# We don't use SymPy's count_ops because we do not count integer arithmetic
9292
# (e.g., array index functions such as i+1 in A[i+1])
9393
# Also, the routine below is *much* faster than count_ops
94+
seen = {}
9495
flops = 0
9596
for expr in as_tuple(exprs):
9697
# TODO: this if-then should be part of singledispatch too, but because
@@ -103,7 +104,7 @@ def estimate_cost(exprs, estimate=False):
103104
else:
104105
e = expr
105106

106-
flops += _estimate_cost(e, estimate)[0]
107+
flops += _estimate_cost(e, estimate, seen)[0]
107108

108109
return flops
109110
except:
@@ -121,11 +122,27 @@ def estimate_cost(exprs, estimate=False):
121122
}
122123

123124

125+
def dont_count_if_seen(func):
126+
"""
127+
This decorator is used to avoid counting the same expression multiple
128+
times. This is necessary because the same expression may appear multiple
129+
times in the same expression tree or even across different expressions.
130+
"""
131+
def wrapper(expr, estimate, seen):
132+
try:
133+
_, flags = seen[expr]
134+
flops = 0
135+
except KeyError:
136+
flops, flags = seen[expr] = func(expr, estimate, seen)
137+
return flops, flags
138+
return wrapper
139+
140+
124141
@singledispatch
125-
def _estimate_cost(expr, estimate):
142+
def _estimate_cost(expr, estimate, seen):
126143
# Retval: flops (int), flag (bool)
127144
# The flag tells wether it's an integer expression (implying flops==0) or not
128-
flops, flags = zip(*[_estimate_cost(a, estimate) for a in expr.args])
145+
flops, flags = zip(*[_estimate_cost(a, estimate, seen) for a in expr.args])
129146
flops = sum(flops)
130147
if all(flags):
131148
# `expr` is an operation involving integer operands only
@@ -138,28 +155,28 @@ def _estimate_cost(expr, estimate):
138155

139156
@_estimate_cost.register(Tuple)
140157
@_estimate_cost.register(CallFromPointer)
141-
def _(expr, estimate):
158+
def _(expr, estimate, seen):
142159
try:
143-
flops, flags = zip(*[_estimate_cost(a, estimate) for a in expr.args])
160+
flops, flags = zip(*[_estimate_cost(a, estimate, seen) for a in expr.args])
144161
except ValueError:
145162
flops, flags = [], []
146163
return sum(flops), all(flags)
147164

148165

149166
@_estimate_cost.register(Integer)
150-
def _(expr, estimate):
167+
def _(expr, estimate, seen):
151168
return 0, True
152169

153170

154171
@_estimate_cost.register(Number)
155172
@_estimate_cost.register(ReservedWord)
156-
def _(expr, estimate):
173+
def _(expr, estimate, seen):
157174
return 0, False
158175

159176

160177
@_estimate_cost.register(Symbol)
161178
@_estimate_cost.register(Indexed)
162-
def _(expr, estimate):
179+
def _(expr, estimate, seen):
163180
try:
164181
if issubclass(expr.dtype, np.integer):
165182
return 0, True
@@ -169,27 +186,27 @@ def _(expr, estimate):
169186

170187

171188
@_estimate_cost.register(Mul)
172-
def _(expr, estimate):
173-
flops, flags = _estimate_cost.registry[object](expr, estimate)
189+
def _(expr, estimate, seen):
190+
flops, flags = _estimate_cost.registry[object](expr, estimate, seen)
174191
if {S.One, S.NegativeOne}.intersection(expr.args):
175192
flops -= 1
176193
return flops, flags
177194

178195

179196
@_estimate_cost.register(INT)
180-
def _(expr, estimate):
181-
return _estimate_cost(expr.base, estimate)[0], True
197+
def _(expr, estimate, seen):
198+
return _estimate_cost(expr.base, estimate, seen)[0], True
182199

183200

184201
@_estimate_cost.register(Cast)
185-
def _(expr, estimate):
186-
return _estimate_cost(expr.base, estimate)[0], False
202+
def _(expr, estimate, seen):
203+
return _estimate_cost(expr.base, estimate, seen)[0], False
187204

188205

189206
@_estimate_cost.register(Function)
190-
def _(expr, estimate):
207+
def _(expr, estimate, seen):
191208
if q_routine(expr):
192-
flops, _ = zip(*[_estimate_cost(a, estimate) for a in expr.args])
209+
flops, _ = zip(*[_estimate_cost(a, estimate, seen) for a in expr.args])
193210
flops = sum(flops)
194211
if isinstance(expr, DefFunction):
195212
# Bypass user-defined or language-specific functions
@@ -207,8 +224,8 @@ def _(expr, estimate):
207224

208225

209226
@_estimate_cost.register(Pow)
210-
def _(expr, estimate):
211-
flops, _ = zip(*[_estimate_cost(a, estimate) for a in expr.args])
227+
def _(expr, estimate, seen):
228+
flops, _ = zip(*[_estimate_cost(a, estimate, seen) for a in expr.args])
212229
flops = sum(flops)
213230
if estimate:
214231
if expr.exp.is_Number:
@@ -229,13 +246,15 @@ def _(expr, estimate):
229246

230247

231248
@_estimate_cost.register(Derivative)
232-
def _(expr, estimate):
233-
return _estimate_cost(expr._evaluate(expand=False), estimate)
249+
@dont_count_if_seen
250+
def _(expr, estimate, seen):
251+
return _estimate_cost(expr._evaluate(expand=False), estimate, seen)
234252

235253

236254
@_estimate_cost.register(IndexDerivative)
237-
def _(expr, estimate):
238-
flops, _ = _estimate_cost(expr.expr, estimate)
255+
@dont_count_if_seen
256+
def _(expr, estimate, seen):
257+
flops, _ = _estimate_cost(expr.expr, estimate, seen)
239258

240259
# It's an increment
241260
flops += 1

tests/test_dse.py

+2
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,8 @@ def test_factorize(expr, expected):
270270
('Eq(fb, fd.dx)', 10, True),
271271
('Eq(fb, fd.dx._evaluate(expand=False))', 10, False),
272272
('Eq(fb, fd.dx.dy + fa.dx)', 66, False),
273+
# Ensure redundancies aren't counted
274+
('Eq(t0, fd.dx.dy + fa*fd.dx.dy)', 62, True),
273275
])
274276
def test_estimate_cost(expr, expected, estimate):
275277
# Note: integer arithmetic isn't counted

0 commit comments

Comments
 (0)