Skip to content

Commit

Permalink
compiler: Split processing of elementaries and divs in CIRE
Browse files Browse the repository at this point in the history
  • Loading branch information
FabioLuporini committed May 11, 2021
1 parent b586057 commit 20d5870
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 54 deletions.
54 changes: 34 additions & 20 deletions devito/passes/clusters/aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,16 @@ def cire(clusters, mode, sregistry, options, platform):
t1 = 3.0*t2[x,y,z+1]
"""
modes = {
CireInvariants.optname: CireInvariants,
CireSops.optname: CireSops,
'invariants': [CireInvariantsElementary, CireInvariantsDivs],
'sops': [CireSops],
}

return modes[mode](sregistry, options, platform).process(clusters)
for cls in modes[mode]:
transformer = cls(sregistry, options, platform)

clusters = transformer.process(clusters)

return clusters


class CireTransformer(object):
Expand All @@ -95,8 +100,6 @@ class CireTransformer(object):
Abstract base class for transformers implementing a CIRE variant.
"""

optname = None

def __init__(self, sregistry, options, platform):
self.sregistry = sregistry
self.platform = platform
Expand Down Expand Up @@ -221,8 +224,6 @@ def process(self, clusters):

class CireInvariants(CireTransformer, Queue):

optname = 'invariants'

def __init__(self, sregistry, options, platform):
super().__init__(sregistry, options, platform)

Expand Down Expand Up @@ -262,9 +263,24 @@ def callback(self, clusters, prefix, xtracted=None):

return processed

def _lookup_key(self, c, d):
ispace = c.ispace.reset()
dintervals = c.dspace.intervals.drop(d).reset()
properties = frozendict({d: relax_properties(v) for d, v in c.properties.items()})

return AliasKey(ispace, dintervals, c.dtype, None, properties)

def _eval_variants_delta(self, delta_flops, delta_ws):
# Always prefer the Variant with fewer temporaries
return ((delta_ws == 0 and delta_flops < 0) or
(delta_ws > 0))


class CireInvariantsElementary(CireInvariants):

def _generate(self, exprs, exclude):
# E.g., extract `sin(x)` and `cos(x)` from `a*sin(x)*cos(x)`
rule = lambda e: e.is_Function or (e.is_Pow and e.exp.is_Number and e.exp < 1)
# E.g., extract `sin(x)` and `sqrt(x)` from `a*sin(x)*sqrt(x)`
rule = lambda e: e.is_Function or (e.is_Pow and e.exp.is_Number and 0 < e.exp < 1)
cbk_search = lambda e: search(e, rule, 'all', 'bfs_first_hit')
basextr = self._do_generate(exprs, exclude, cbk_search)
if not basextr:
Expand All @@ -282,23 +298,21 @@ def cbk_search(expr):
cbk_compose = lambda e: split(e.args, lambda a: a in basextr)[0]
yield self._do_generate(exprs, exclude, cbk_search, cbk_compose)

def _lookup_key(self, c, d):
ispace = c.ispace.reset()
dintervals = c.dspace.intervals.drop(d).reset()
properties = frozendict({d: relax_properties(v) for d, v in c.properties.items()})

return AliasKey(ispace, dintervals, c.dtype, None, properties)
class CireInvariantsDivs(CireInvariants):

def _eval_variants_delta(self, delta_flops, delta_ws):
# Always prefer the Variant with fewer temporaries
return ((delta_ws == 0 and delta_flops < 0) or
(delta_ws > 0))
def _generate(self, exprs, exclude):
# E.g., extract `1/h_x`
rule = lambda e: e.is_Pow and e.exp.is_Number and e.exp < 0
cbk_search = lambda e: search(e, rule, 'all', 'bfs_first_hit')
basextr = self._do_generate(exprs, exclude, cbk_search)
if not basextr:
return
yield basextr


class CireSops(CireTransformer):

optname = 'sops'

def __init__(self, sregistry, options, platform):
super().__init__(sregistry, options, platform)

Expand Down
10 changes: 5 additions & 5 deletions examples/performance/00_overview.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,7 @@
"float (*r0)[y_size][z_size];\n",
"posix_memalign((void**)&r0, 64, sizeof(float[x_size][y_size][z_size]));\n",
"\n",
"float r1 = 1.0F/h_y;\n",
"/* Begin section0 */\n",
"START_TIMER(section0)\n",
"#pragma omp parallel num_threads(nthreads)\n",
Expand All @@ -528,7 +529,6 @@
"}\n",
"STOP_TIMER(section0,timers)\n",
"/* End section0 */\n",
"float r1 = 1.0F/h_y;\n",
"for (int time = time_m, t0 = (time)%(2), t1 = (time + 1)%(2); time <= time_M; time += 1, t0 = (time)%(2), t1 = (time + 1)%(2))\n",
"{\n",
" /* Begin section1 */\n",
Expand Down Expand Up @@ -1295,6 +1295,7 @@
" _MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_ON);\n",
" _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON);\n",
"\n",
" float r1 = 1.0F/h_y;\n",
" /* Begin section0 */\n",
" START_TIMER(section0)\n",
" #pragma omp parallel num_threads(nthreads)\n",
Expand All @@ -1314,7 +1315,6 @@
" }\n",
" STOP_TIMER(section0,timers)\n",
" /* End section0 */\n",
" float r1 = 1.0F/h_y;\n",
" for (int time = time_m, t0 = (time)%(2), t1 = (time + 1)%(2); time <= time_M; time += 1, t0 = (time)%(2), t1 = (time + 1)%(2))\n",
" {\n",
" /* Begin section1 */\n",
Expand Down Expand Up @@ -1585,6 +1585,7 @@
" _MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_ON);\n",
" _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON);\n",
"\n",
" float r1 = 1.0F/h_y;\n",
" /* Begin section0 */\n",
" START_TIMER(section0)\n",
" #pragma omp parallel num_threads(nthreads)\n",
Expand All @@ -1604,7 +1605,6 @@
" }\n",
" STOP_TIMER(section0,timers)\n",
" /* End section0 */\n",
" float r1 = 1.0F/h_y;\n",
" for (int time = time_m, t0 = (time)%(2), t1 = (time + 1)%(2); time <= time_M; time += 1, t0 = (time)%(2), t1 = (time + 1)%(2))\n",
" {\n",
" /* Begin section1 */\n",
Expand Down Expand Up @@ -1741,6 +1741,8 @@
" _MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_ON);\n",
" _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON);\n",
"\n",
" float r1 = 1.0F/h_x;\n",
" float r2 = 1.0F/h_y;\n",
" /* Begin section0 */\n",
" START_TIMER(section0)\n",
" #pragma omp parallel num_threads(nthreads)\n",
Expand All @@ -1760,8 +1762,6 @@
" }\n",
" STOP_TIMER(section0,timers)\n",
" /* End section0 */\n",
" float r1 = 1.0F/h_x;\n",
" float r2 = 1.0F/h_y;\n",
" for (int time = time_m, t0 = (time)%(2), t1 = (time + 1)%(2); time <= time_M; time += 1, t0 = (time)%(2), t1 = (time + 1)%(2))\n",
" {\n",
" /* Begin section1 */\n",
Expand Down
86 changes: 57 additions & 29 deletions tests/test_dse.py
Original file line number Diff line number Diff line change
Expand Up @@ -1021,6 +1021,59 @@ def test_nested_invariants_v3(self):
exprs = FindNodes(Expression).visit(op)
assert len(exprs[-1].expr.find(sin)) == 0

def test_nested_invariant_v4(self):
"""
Check that nested aliases are optimized away.
"""
grid = Grid((10, 10))

a = Function(name="a", grid=grid, space_order=4)
b = Function(name="b", grid=grid, space_order=4)

e = TimeFunction(name="e", grid=grid, space_order=4)
f = TimeFunction(name="f", grid=grid, space_order=4)

subexpr0 = sqrt(1. + 1./a)
subexpr1 = 1/(8.*subexpr0 - 8./b)
eqns = [Eq(e.forward, e + 1),
Eq(f.forward, f*subexpr0 - f*subexpr1 + e.forward.dx)]

op = Operator(eqns)

trees = retrieve_iteration_tree(op)
assert len(trees) == 3
arrays = [i for i in FindSymbols().visit(trees[0].root) if i.is_Array]
assert len(arrays) == 1
assert all(i._mem_heap and not i._mem_external for i in arrays)

def test_nested_invariant_v5(self):
"""
Check that nested aliases are optimized away.
"""
grid = Grid((10, 10))
x, y = grid.dimensions
hx = x.spacing
dt = grid.stepping_dim.spacing

xright = SubDimension.right(name='xright', parent=x, thickness=4)

a = Function(name="a", grid=grid)
b = Function(name="b", grid=grid)
e = TimeFunction(name="e", grid=grid)

expr = 1./(5.*dt*sqrt(a)*b/hx + 2.*dt**2*b**2*a/hx**2 + 3.)
eq = Eq(e.forward, 2.*expr*sqrt(a) + 3.*expr + e*sqrt(a)).subs({x: xright})

op = Operator(eq, openmp=False)

# Check generated code
arrays = [i for i in FindSymbols().visit(op) if i.is_Array]
assert len(arrays) == 1
exprs = FindNodes(Expression).visit(op)
assert len(exprs) == 2
assert exprs[0].write is arrays[0]
assert exprs[0].expr.rhs.is_Pow

@switchconfig(profiling='advanced')
def test_twin_sops(self):
"""
Expand Down Expand Up @@ -1065,7 +1118,7 @@ def d1(field):
# We expect two temporary Arrays which have in common a sub-expression
# stemming from `d0(v, p0, p1)`
arrays = [i for i in FindSymbols().visit(op1._func_table['bf0']) if i.is_Array]
assert len(arrays) == 7
assert len(arrays) == 6
vexpandeds = FindNodes(VExpanded).visit(op1._func_table['bf0'])
assert len(vexpandeds) == (2 if configuration['language'] == 'openmp' else 0)
assert all(i._mem_heap and not i._mem_external for i in arrays)
Expand All @@ -1083,7 +1136,7 @@ def d1(field):

# Also check against expected operation count to make sure
# all redundancies have been detected correctly
assert sum(i.ops for i in summary1.values()) == 76
assert sum(i.ops for i in summary1.values()) == 69

@pytest.mark.parametrize('rotate', [False, True])
def test_from_different_nests(self, rotate):
Expand Down Expand Up @@ -1237,9 +1290,9 @@ def test_catch_best_invariant_v2(self):
assert len(arrays) == 4

exprs = FindNodes(Expression).visit(op)
sqrt_exprs = exprs[:2]
sqrt_exprs = exprs[2:4]
assert all(e.write in arrays for e in sqrt_exprs)
assert all(e.expr.rhs.args[0].is_Pow for e in sqrt_exprs)
assert all(e.expr.rhs.is_Pow for e in sqrt_exprs)
assert all(e.write._mem_heap and not e.write._mem_external for e in sqrt_exprs)

tmp_exprs = exprs[4:6]
Expand Down Expand Up @@ -1420,31 +1473,6 @@ def test_catch_duplicate_from_different_clusters(self):
assert len(arrays) == 3
assert all(i._mem_heap and not i._mem_external for i in arrays)

def test_hoisting_if_coupled(self):
"""
Test that coupled aliases are successfully hoisted out of the time loop.
"""
grid = Grid((10, 10))

a = Function(name="a", grid=grid, space_order=4)
b = Function(name="b", grid=grid, space_order=4)

e = TimeFunction(name="e", grid=grid, space_order=4)
f = TimeFunction(name="f", grid=grid, space_order=4)

subexpr0 = sqrt(1. + 1./a)
subexpr1 = 1/(8.*subexpr0 - 8./b)
eqns = [Eq(e.forward, e + 1),
Eq(f.forward, f*subexpr0 - f*subexpr1 + e.forward.dx)]

op = Operator(eqns, opt=('advanced', {'cire-mingain': 28}))

trees = retrieve_iteration_tree(op)
assert len(trees) == 3
arrays = [i for i in FindSymbols().visit(trees[0].root) if i.is_Array]
assert len(arrays) == 2
assert all(i._mem_heap and not i._mem_external for i in arrays)

def test_discarded_compound(self):
"""
Test that compound aliases may be ignored if part of a bigger alias.
Expand Down

0 comments on commit 20d5870

Please sign in to comment.