Skip to content

Commit

Permalink
compiler: Patch estimate_cost (now distinguishes INT ops correctly)
Browse files Browse the repository at this point in the history
  • Loading branch information
FabioLuporini committed Feb 7, 2022
1 parent ec57319 commit 17e6b05
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 60 deletions.
128 changes: 94 additions & 34 deletions devito/symbolics/inspection.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
from collections import Counter
from functools import singledispatch

import numpy as np
from sympy import Function, Indexed, Integer, Mul, Number, Pow, S, Symbol

from devito.logger import warning
from devito.symbolics.extended_sympy import INT, Cast
from devito.symbolics.queries import q_routine
from devito.symbolics.search import retrieve_xops, search
from devito.tools import as_tuple, flatten
from devito.symbolics.search import search
from devito.tools import as_tuple

__all__ = ['compare_ops', 'count', 'estimate_cost']

Expand Down Expand Up @@ -84,25 +88,18 @@ def estimate_cost(exprs, estimate=False):
return 0
except AttributeError:
pass
try:
# Is it a dict ?
exprs = exprs.values()
except AttributeError:
try:
# Could still be a list of dicts
exprs = flatten([i.values() for i in exprs])
except (AttributeError, TypeError):
pass
try:
# At this point it must be a list of SymPy objects
# We don't use SymPy's count_ops because we do not count integer arithmetic
# (e.g., array index functions such as i+1 in A[i+1])
# Also, the routine below is *much* faster than count_ops
exprs = [i.rhs if i.is_Equality else i for i in as_tuple(exprs)]
operations = flatten(retrieve_xops(i) for i in exprs)
flops = 0
for op in operations:
flops += _estimate_cost(op, estimate)
for expr in as_tuple(exprs):
if expr.is_Equality:
e = expr.rhs
else:
e = expr
flops += _estimate_cost(e, estimate)[0]
return flops
except:
warning("Cannot estimate cost of `%s`" % str(exprs))
Expand All @@ -118,26 +115,89 @@ def estimate_cost(exprs, estimate=False):

@singledispatch
def _estimate_cost(expr, estimate):
if expr.is_Function:
if estimate and q_routine(expr):
return estimate_values['elementary']
else:
return 1
elif expr.is_Pow:
# Retval: flops (int), flag (bool)
# The flag tells wether it's an integer expression (implying flops==0) or not
flops, flags = zip(*[_estimate_cost(a, estimate) for a in expr.args])
flops = sum(flops)
if all(flags):
# `expr` is an operation involving integer operands only
# NOTE: one of the operands may contain, internally, non-integer
# operations, e.g. the `a*b` in `2 + INT(a*b)`
return flops, True
else:
return flops + (len(expr.args) - 1), False


@_estimate_cost.register(Integer)
def _(expr, estimate):
return 0, True


@_estimate_cost.register(Number)
def _(expr, estimate):
return 0, False


@_estimate_cost.register(Symbol)
@_estimate_cost.register(Indexed)
def _(expr, estimate):
try:
if issubclass(expr.dtype, np.integer):
return 0, True
except:
pass
return 0, False


@_estimate_cost.register(Mul)
def _(expr, estimate):
flops, flags = _estimate_cost.registry[object](expr, estimate)
if {S.One, S.NegativeOne}.intersection(expr.args):
flops -= 1
return flops, flags


@_estimate_cost.register(INT)
def _(expr, estimate):
return _estimate_cost(expr.base, estimate)[0], True


@_estimate_cost.register(Cast)
def _(expr, estimate):
return _estimate_cost(expr.base, estimate)[0], False


@_estimate_cost.register(Function)
def _(expr, estimate):
if q_routine(expr):
flops, _ = zip(*[_estimate_cost(a, estimate) for a in expr.args])
flops = sum(flops)
if estimate:
if expr.exp.is_Number:
if expr.exp < 0:
return estimate_values['div']
elif expr.exp == 0 or expr.exp == 1:
return 0
elif expr.exp.is_Integer:
# Natural pows a**b are estimated as b-1 Muls
return int(expr.exp) - 1
else:
return estimate_values['pow']
flops += estimate_values['elementary']
else:
flops += 1
else:
flops = 0
return flops, False


@_estimate_cost.register(Pow)
def _(expr, estimate):
flops, _ = zip(*[_estimate_cost(a, estimate) for a in expr.args])
flops = sum(flops)
if estimate:
if expr.exp.is_Number:
if expr.exp < 0:
flops += estimate_values['div']
elif expr.exp == 0 or expr.exp == 1:
flops += 0
elif expr.exp.is_Integer:
# Natural pows a**b are estimated as b-1 Muls
flops += int(expr.exp) - 1
else:
return estimate_values['pow']
flops += estimate_values['pow']
else:
return 1
flops += estimate_values['pow']
else:
return len(expr.args) - (1 + sum(True for i in expr.args if i.is_Integer))
flops += 1
return flops, False
10 changes: 2 additions & 8 deletions devito/symbolics/search.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from devito.symbolics.queries import (q_indexed, q_function, q_terminal, q_leaf, q_xop,
from devito.symbolics.queries import (q_indexed, q_function, q_terminal, q_leaf,
q_symbol, q_dimension)
from devito.tools import as_tuple

__all__ = ['retrieve_indexed', 'retrieve_functions', 'retrieve_function_carriers',
'retrieve_terminals', 'retrieve_xops', 'retrieve_symbols',
'retrieve_dimensions', 'search']
'retrieve_terminals', 'retrieve_symbols', 'retrieve_dimensions', 'search']


class Search(object):
Expand Down Expand Up @@ -168,11 +167,6 @@ def retrieve_terminals(exprs, mode='all', deep=False):
return search(exprs, q_terminal, mode, 'dfs', deep)


def retrieve_xops(exprs):
"""Shorthand to retrieve the arithmetic operations within ``exprs``."""
return search(exprs, q_xop, 'all', 'dfs')


def retrieve_dimensions(exprs, mode='all', deep=False):
"""Shorthand to retrieve the dimensions in ``exprs``."""
return search(exprs, q_dimension, mode, 'dfs', deep)
45 changes: 27 additions & 18 deletions tests/test_dse.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from devito.passes.clusters.aliases import collect
from devito.passes.clusters.cse import Temp, _cse
from devito.passes.iet.parpragma import VExpanded
from devito.symbolics import estimate_cost, pow_to_mul, indexify
from devito.symbolics import INT, FLOAT, estimate_cost, pow_to_mul, indexify # noqa
from devito.tools import as_tuple, generator
from devito.types import Scalar, Array

Expand Down Expand Up @@ -51,12 +51,12 @@ def test_scheduling_after_rewrite():
(['Eq(tu, 2/(t0 + t1))', 'Eq(ti0, t0 + t1)', 'Eq(ti1, t0 + t1)'],
['t0 + t1', '2/r0', 'r0', 'r0']),
(['Eq(tu, 2/(t0 + t1))', 'Eq(ti0, 2/(t0 + t1) + 1)', 'Eq(ti1, 2/(t0 + t1) + 1)'],
['1/(t0 + t1)', '2*r5', 'r3 + 1', 'r3', 'r2', 'r2']),
['2/(t0 + t1)', 'r1 + 1', 'r1', 'r0', 'r0']),
(['Eq(tu, (tv + tw + 5.)*(ti0 + ti1) + (t0 + t1)*(ti0 + ti1))'],
['ti0[x, y, z] + ti1[x, y, z]',
'r0*(t0 + t1) + r0*(tv[t, x, y, z] + tw[t, x, y, z] + 5.0)']),
(['Eq(tu, t0/t1)', 'Eq(ti0, 2 + t0/t1)', 'Eq(ti1, 2 + t0/t1)'],
['t0/t1', 'r2 + 2', 'r2', 'r1', 'r1']),
['t0/t1', 'r1 + 2', 'r1', 'r0', 'r0']),
# Across expressions
(['Eq(tu, tv*4 + tw*5 + tw*5*t0)', 'Eq(tv, tw*5)'],
['5*tw[t, x, y, z]', 'r0 + 5*t0*tw[t, x, y, z] + 4*tv[t, x, y, z]', 'r0']),
Expand Down Expand Up @@ -175,9 +175,18 @@ def test_pow_to_mul(expr, expected):
('Eq(t0, 3.2/h_x)', 6, True), # seen as `3.2*(1/h_x)`, so counts as 2
('Eq(t0, 3.2/h_x*fa + 2.4/h_x*fb)', 15, True), # `pow(...constants...)` counts as 1
# Integer arithmetic should not count
('Eq(t0, INT(t1))', 0, True),
('Eq(t0, INT(t1*t0))', 1, True),
('Eq(t0, 2 + INT(t1*t0))', 1, True),
('Eq(t0, FLOAT(t1))', 0, True),
('Eq(t0, FLOAT(t1*t2*t3))', 2, True),
('Eq(t0, 1 + FLOAT(t1*t2*t3))', 3, True), # The 1 gets casted to float
('Eq(t0, 1 + t3)', 0, False),
('Eq(t0, 1 + t3)', 0, True),
#('Eq(t0, t3 + t4)', 0, True),
('Eq(t0, t3 + t4)', 0, True),
('Eq(t0, 2*t3)', 0, True),
('Eq(t0, 2*t1)', 1, True),
('Eq(t0, -4 + INT(t3*t4 + t3))', 0, True),
])
def test_estimate_cost(expr, expected, estimate):
# Note: integer arithmetic isn't counted
Expand Down Expand Up @@ -1151,7 +1160,7 @@ def d1(field):

# Also check against expected operation count to make sure
# all redundancies have been detected correctly
assert sum(i.ops for i in summary1.values()) == 69
assert sum(i.ops for i in summary1.values()) == 75

@pytest.mark.parametrize('rotate', [False, True])
def test_from_different_nests(self, rotate):
Expand Down Expand Up @@ -1689,7 +1698,7 @@ def test_extraction_from_lifted_ispace(self, rotate):
# all redundancies have been detected correctly
assert summary[('section0', None)].ops == 93

@pytest.mark.parametrize('so_ops', [(4, 108)])
@pytest.mark.parametrize('so_ops', [(4, 113)])
@switchconfig(profiling='advanced')
def test_tti_J_akin_bb0(self, so_ops):
grid = Grid(shape=(16, 16, 16))
Expand Down Expand Up @@ -1772,7 +1781,7 @@ def g3_tilde(field, phi):
assert len([i for i in FindSymbols().visit(bns['x0_blk0']) if i.is_Array]) == 6
assert len(FindNodes(VExpanded).visit(pbs['x0_blk0'])) == 3

@pytest.mark.parametrize('so_ops', [(4, 48)])
@pytest.mark.parametrize('so_ops', [(4, 49)])
@switchconfig(profiling='advanced')
def test_tti_J_akin_bb2(self, so_ops):
grid = Grid(shape=(16, 16, 16))
Expand Down Expand Up @@ -1816,7 +1825,7 @@ def g2_tilde(field, phi, theta):
assert len([i for i in FindSymbols().visit(bns['x0_blk0']) if i.is_Array]) == 7
assert len(FindNodes(VExpanded).visit(pbs['x0_blk0'])) == 3

@pytest.mark.parametrize('so_ops', [(4, 144), (8, 208)])
@pytest.mark.parametrize('so_ops', [(4, 146), (8, 210)])
@switchconfig(profiling='advanced')
def test_tti_J_akin_complete(self, so_ops):
grid = Grid(shape=(16, 16, 16))
Expand Down Expand Up @@ -2038,7 +2047,7 @@ def test_nested_first_derivatives(self, rotate):

# Also check against expected operation count to make sure
# all redundancies have been detected correctly
assert summary1[('section0', None)].ops == 14
assert summary1[('section0', None)].ops == 16

def test_undestroyed_preevaluated_derivatives_v1(self):
grid = Grid(shape=(10, 10))
Expand Down Expand Up @@ -2073,9 +2082,9 @@ def test_nested_first_derivatives_unbalanced(self):
('v.dx.dx + p.dx.dx',
(2, 2, (0, 2)), (61, 61, 25)),
('(v.dx + v.dy).dx - (v.dx + v.dy).dy + 2*f.dx.dx + f*f.dy.dy + f.dx.dx(x0=1)',
(3, 3, (0, 3)), (217, 201, 74)),
(3, 3, (0, 3)), (218, 202, 74)),
('(g*(1 + f)*v.dx).dx + (2*g*f*v.dx).dx',
(1, 2, (0, 1)), (50, 65, 18)),
(1, 2, (0, 1)), (52, 70, 20)),
('g*(f.dx.dx + g.dx.dx)',
(1, 2, (0, 1)), (47, 62, 17)),
])
Expand Down Expand Up @@ -2527,7 +2536,7 @@ def test_fullopt(self):
bns, _ = assert_blocking(op1, {'x0_blk0'}) # due to loop blocking

assert summary0[('section0', None)].ops == 50
assert summary0[('section1', None)].ops == 139
assert summary0[('section1', None)].ops == 140
assert np.isclose(summary0[('section0', None)].oi, 2.851, atol=0.001)

assert summary1[('section0', None)].ops == 31
Expand Down Expand Up @@ -2578,7 +2587,7 @@ def tti_noopt(self):
# Make sure no opts were applied
op = wavesolver.op_fwd(False)
assert len(op._func_table) == 0
assert summary[('section0', None)].ops == 737
assert summary[('section0', None)].ops == 743

return v, rec

Expand All @@ -2591,8 +2600,8 @@ def test_fullopt(self):
assert np.allclose(self.tti_noopt[1].data, rec.data, atol=10e-1)

# Check expected opcount/oi
assert summary[('section1', None)].ops == 90
assert np.isclose(summary[('section1', None)].oi, 2.031, atol=0.001)
assert summary[('section1', None)].ops == 92
assert np.isclose(summary[('section1', None)].oi, 2.074, atol=0.001)

# With optimizations enabled, there should be exactly four BlockDimensions
op = wavesolver.op_fwd()
Expand Down Expand Up @@ -2639,7 +2648,7 @@ def test_fullopt_w_mpi(self):

@switchconfig(profiling='advanced')
@pytest.mark.parametrize('space_order,expected', [
(8, 152), (16, 270)
(8, 154), (16, 272)
])
def test_opcounts(self, space_order, expected):
op = self.tti_operator(opt='advanced', space_order=space_order)
Expand All @@ -2648,7 +2657,7 @@ def test_opcounts(self, space_order, expected):

@switchconfig(profiling='advanced')
@pytest.mark.parametrize('space_order,expected', [
(4, 111),
(4, 121),
])
def test_opcounts_adjoint(self, space_order, expected):
wavesolver = self.tti_operator(opt=('advanced', {'openmp': False}))
Expand All @@ -2662,7 +2671,7 @@ class TestTTIv2(object):

@switchconfig(profiling='advanced')
@pytest.mark.parametrize('space_order,expected', [
(4, 191), (12, 383)
(4, 200), (12, 392)
])
def test_opcounts(self, space_order, expected):
grid = Grid(shape=(3, 3, 3))
Expand Down

0 comments on commit 17e6b05

Please sign in to comment.