From 1c4be51dc449856c7d410fb2f1d28622a500fde0 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Tue, 28 Mar 2023 08:28:19 +0000 Subject: [PATCH 1/9] tests: Rearrange unexpansion tests --- tests/conftest.py | 42 +++++- tests/test_dse.py | 267 +++++++++----------------------------- tests/test_unexpansion.py | 191 +++++++++++++++++++++++++++ 3 files changed, 289 insertions(+), 211 deletions(-) create mode 100644 tests/test_unexpansion.py diff --git a/tests/conftest.py b/tests/conftest.py index 55afa7cede..9f1246a294 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,15 +1,17 @@ import os +import sys from subprocess import check_call import pytest -import sys +from sympy import Add from devito import Eq, configuration, Revolver # noqa from devito.checkpointing import NoopRevolver from devito.finite_differences.differentiable import EvalDerivative from devito.arch import Cpu64, Device, sniff_mpi_distro, Arm from devito.arch.compiler import compiler_registry, IntelCompiler, NvidiaCompiler -from devito.ir.iet import retrieve_iteration_tree, FindNodes, Iteration, ParallelBlock +from devito.ir.iet import (FindNodes, FindSymbols, Iteration, ParallelBlock, + retrieve_iteration_tree) from devito.tools import as_tuple try: @@ -335,3 +337,39 @@ def assert_blocking(operator, exp_nests): ('advanced', {'blocklevels': 1, 'skewing': True}), ('advanced', {'blocklevels': 1, 'skewing': True, 'blockinner': True})] + + +# More utilities for testing + + +def get_params(op, *names): + ret = [] + for i in names: + for p in op.parameters: + if i == p.name: + ret.append(p) + return tuple(ret) + + +def get_arrays(iet): + return [i for i in FindSymbols().visit(iet) + if i.is_Array and i._mem_heap] + + +def check_array(array, exp_halo, exp_shape, rotate=False): + assert len(array.dimensions) == len(exp_halo) + + shape = [] + for i in array.symbolic_shape: + if i.is_Number or i.is_Symbol: + shape.append(i) + else: + assert i.is_Add + shape.append(Add(*i.args)) + + if rotate: + exp_shape = (sum(exp_halo[0]) + 1,) + tuple(exp_shape[1:]) + exp_halo = ((0, 0),) + tuple(exp_halo[1:]) + + assert tuple(array.halo) == exp_halo + assert tuple(shape) == tuple(exp_shape) diff --git a/tests/test_dse.py b/tests/test_dse.py index 572fc90056..d11f25f16e 100644 --- a/tests/test_dse.py +++ b/tests/test_dse.py @@ -1,9 +1,9 @@ -from sympy import Add import numpy as np import pytest from cached_property import cached_property -from conftest import skipif, EVAL, _R, assert_structure, assert_blocking # noqa +from conftest import (skipif, EVAL, _R, assert_structure, assert_blocking, # noqa + get_params, get_arrays, check_array) from devito import (NODE, Eq, Inc, Constant, Function, TimeFunction, SparseTimeFunction, # noqa Dimension, SubDimension, ConditionalDimension, DefaultDimension, Grid, Operator, norm, grad, div, dimensions, switchconfig, configuration, @@ -496,36 +496,6 @@ def test_collection(self, exprs, expected): assert len(aliases) == len(expected) assert all(i.pivot in expected for i in aliases) - def get_params(self, op, *names): - ret = [] - for i in names: - for p in op.parameters: - if i == p.name: - ret.append(p) - return tuple(ret) - - def get_arrays(self, iet): - return [i for i in FindSymbols().visit(iet) - if i.is_Array and i._mem_heap] - - def check_array(self, array, exp_halo, exp_shape, rotate=False): - assert len(array.dimensions) == len(exp_halo) - - shape = [] - for i in array.symbolic_shape: - if i.is_Number or i.is_Symbol: - shape.append(i) - else: - assert i.is_Add - shape.append(Add(*i.args)) - - if rotate: - exp_shape = (sum(exp_halo[0]) + 1,) + tuple(exp_shape[1:]) - exp_halo = ((0, 0),) + tuple(exp_halo[1:]) - - assert tuple(array.halo) == exp_halo - assert tuple(shape) == tuple(exp_shape) - @pytest.mark.parametrize('rotate', [False, True]) def test_full_shape(self, rotate): """ @@ -553,11 +523,11 @@ def test_full_shape(self, rotate): # Check code generation bns, pbs = assert_blocking(op1, {'x0_blk0'}) - xs, ys, zs = self.get_params(op1, 'x0_blk0_size', 'y0_blk0_size', 'z_size') + xs, ys, zs = get_params(op1, 'x0_blk0_size', 'y0_blk0_size', 'z_size') arrays = [i for i in FindSymbols().visit(bns['x0_blk0']) if i.is_Array] assert len(arrays) == 1 assert len(FindNodes(VExpanded).visit(pbs['x0_blk0'])) == 1 - self.check_array(arrays[0], ((1, 1), (1, 1), (1, 1)), (xs+2, ys+2, zs+2), rotate) + check_array(arrays[0], ((1, 1), (1, 1), (1, 1)), (xs+2, ys+2, zs+2), rotate) # Check numerical output op0(time_M=1) @@ -590,11 +560,11 @@ def test_contracted_shape(self, rotate): # Check code generation bns, pbs = assert_blocking(op1, {'x0_blk0'}) - ys, zs = self.get_params(op1, 'y0_blk0_size', 'z_size') + ys, zs = get_params(op1, 'y0_blk0_size', 'z_size') arrays = [i for i in FindSymbols().visit(bns['x0_blk0']) if i.is_Array] assert len(arrays) == 1 assert len(FindNodes(VExpanded).visit(pbs['x0_blk0'])) == 1 - self.check_array(arrays[0], ((1, 1), (1, 1)), (ys+2, zs+2), rotate) + check_array(arrays[0], ((1, 1), (1, 1)), (ys+2, zs+2), rotate) # Check numerical output op0(time_M=1) @@ -628,11 +598,11 @@ def test_uncontracted_shape(self, rotate): # Check code generation bns, pbs = assert_blocking(op1, {'x0_blk0'}) - xs, ys, zs = self.get_params(op1, 'x0_blk0_size', 'y0_blk0_size', 'z_size') + xs, ys, zs = get_params(op1, 'x0_blk0_size', 'y0_blk0_size', 'z_size') arrays = [i for i in FindSymbols().visit(bns['x0_blk0']) if i.is_Array] assert len(arrays) == 1 assert len(FindNodes(VExpanded).visit(pbs['x0_blk0'])) == 1 - self.check_array(arrays[0], ((1, 1), (1, 1), (0, 0)), (xs+2, ys+2, zs), rotate) + check_array(arrays[0], ((1, 1), (1, 1), (0, 0)), (xs+2, ys+2, zs), rotate) # Check numerical output op0(time_M=1) @@ -666,10 +636,10 @@ def func(f): op1 = Operator(eqn, opt=('advanced', {'openmp': True})) # Check code generation - xs, ys, zs = self.get_params(op1, 'x_size', 'y_size', 'z_size') + xs, ys, zs = get_params(op1, 'x_size', 'y_size', 'z_size') arrays = [i for i in FindSymbols().visit(op1) if i.is_Array] assert len(arrays) == 1 - self.check_array(arrays[0], ((0, 0), (0, 0), (1, 0)), (xs, ys, zs+1)) + check_array(arrays[0], ((0, 0), (0, 0), (1, 0)), (xs, ys, zs+1)) # Check numerical output op0(time_M=1) @@ -702,11 +672,11 @@ def test_full_shape_w_subdims(self, rotate): # Check code generation bns, pbs = assert_blocking(op1, {'i0x0_blk0'}) - xs, ys, zs = self.get_params(op1, 'i0x0_blk0_size', 'i0y0_blk0_size', 'z_size') + xs, ys, zs = get_params(op1, 'i0x0_blk0_size', 'i0y0_blk0_size', 'z_size') arrays = [i for i in FindSymbols().visit(bns['i0x0_blk0']) if i.is_Array] assert len(arrays) == 1 assert len(FindNodes(VExpanded).visit(pbs['i0x0_blk0'])) == 1 - self.check_array(arrays[0], ((1, 1), (1, 1), (1, 1)), (xs+2, ys+2, zs+2), rotate) + check_array(arrays[0], ((1, 1), (1, 1), (1, 1)), (xs+2, ys+2, zs+2), rotate) # Check numerical output op0(time_M=1) @@ -747,12 +717,12 @@ def test_mixed_shapes(self, rotate): # Check code generation bns, pbs = assert_blocking(op1, {'x0_blk0'}) - xs, ys, zs = self.get_params(op1, 'x0_blk0_size', 'y0_blk0_size', 'z_size') + xs, ys, zs = get_params(op1, 'x0_blk0_size', 'y0_blk0_size', 'z_size') arrays = [i for i in FindSymbols().visit(bns['x0_blk0']) if i.is_Array] assert len(arrays) == 2 assert len(FindNodes(VExpanded).visit(pbs['x0_blk0'])) == 2 - self.check_array(arrays[0], ((1, 0), (1, 0), (0, 0)), (xs+1, ys+1, zs), rotate) - self.check_array(arrays[1], ((1, 1), (0, 0)), (ys+2, zs), rotate) + check_array(arrays[0], ((1, 0), (1, 0), (0, 0)), (xs+1, ys+1, zs), rotate) + check_array(arrays[1], ((1, 1), (0, 0)), (ys+2, zs), rotate) # Check numerical output op0(time_M=1) @@ -790,12 +760,12 @@ def test_min_storage_in_isolation(self): # Check code generation # `min-storage` leads to one 2D and one 3D Arrays - xs, ys, zs = self.get_params(op1, 'x_size', 'y_size', 'z_size') + xs, ys, zs = get_params(op1, 'x_size', 'y_size', 'z_size') arrays = [i for i in FindSymbols().visit(op1) if i.is_Array] assert len(arrays) == 2 assert len(FindNodes(VExpanded).visit(op1)) == 1 - self.check_array(arrays[0], ((2, 2), (0, 0), (0, 0)), (xs+4, ys, zs)) - self.check_array(arrays[1], ((2, 2), (0, 0)), (ys+4, zs)) + check_array(arrays[0], ((2, 2), (0, 0), (0, 0)), (xs+4, ys, zs)) + check_array(arrays[1], ((2, 2), (0, 0)), (ys+4, zs)) # Check that `advanced-fsg` + `min-storage` is incompatible try: @@ -876,12 +846,12 @@ def test_mixed_shapes_v2_w_subdims(self, rotate): # Check code generation bns, pbs = assert_blocking(op1, {'i0x0_blk0'}) - xs, ys, zs = self.get_params(op1, 'i0x0_blk0_size', 'i0y0_blk0_size', 'z_size') + xs, ys, zs = get_params(op1, 'i0x0_blk0_size', 'i0y0_blk0_size', 'z_size') arrays = [i for i in FindSymbols().visit(bns['i0x0_blk0']) if i.is_Array] assert len(arrays) == 2 assert len(FindNodes(VExpanded).visit(pbs['i0x0_blk0'])) == 2 - self.check_array(arrays[0], ((1, 0), (1, 0), (0, 0)), (xs+1, ys+1, zs), rotate) - self.check_array(arrays[1], ((1, 1), (1, 0)), (ys+2, zs+1), rotate) + check_array(arrays[0], ((1, 0), (1, 0), (0, 0)), (xs+1, ys+1, zs), rotate) + check_array(arrays[1], ((1, 1), (1, 0)), (ys+2, zs+1), rotate) # Check numerical output op0(time_M=1) @@ -920,12 +890,12 @@ def test_in_bounds_w_shift(self, rotate): # Check code generation bns, pbs = assert_blocking(op1, {'x0_blk0'}) - xs, ys, zs = self.get_params(op1, 'x0_blk0_size', 'y0_blk0_size', 'z_size') + xs, ys, zs = get_params(op1, 'x0_blk0_size', 'y0_blk0_size', 'z_size') arrays = [i for i in FindSymbols().visit(bns['x0_blk0']) if i.is_Array] assert len(arrays) == 2 assert len(FindNodes(VExpanded).visit(pbs['x0_blk0'])) == 2 - self.check_array(arrays[0], ((1, 0), (1, 1), (0, 0)), (xs+1, ys+2, zs), rotate) - self.check_array(arrays[1], ((1, 0), (1, 1), (0, 0)), (xs+1, ys+2, zs), rotate) + check_array(arrays[0], ((1, 0), (1, 1), (0, 0)), (xs+1, ys+2, zs), rotate) + check_array(arrays[1], ((1, 0), (1, 1), (0, 0)), (xs+1, ys+2, zs), rotate) # Check numerical output op0(time_M=1) @@ -967,13 +937,13 @@ def test_constant_symbolic_distance(self, rotate): # Check code generation bns, pbs = assert_blocking(op1, {'x0_blk0'}) - xs, ys, zs = self.get_params(op1, 'x0_blk0_size', 'y0_blk0_size', 'z_size') + xs, ys, zs = get_params(op1, 'x0_blk0_size', 'y0_blk0_size', 'z_size') arrays = [i for i in FindSymbols().visit(bns['x0_blk0']) if i.is_Array] assert len(arrays) == 3 assert len(FindNodes(VExpanded).visit(pbs['x0_blk0'])) == 3 - self.check_array(arrays[0], ((1, 0), (1, 0)), (xs+1, zs+1), rotate) - self.check_array(arrays[1], ((1, 1), (1, 1)), (ys+2, zs+2), rotate) - self.check_array(arrays[2], ((1, 1), (1, 1)), (ys+2, zs+2), rotate) + check_array(arrays[0], ((1, 0), (1, 0)), (xs+1, zs+1), rotate) + check_array(arrays[1], ((1, 1), (1, 1)), (ys+2, zs+2), rotate) + check_array(arrays[2], ((1, 1), (1, 1)), (ys+2, zs+2), rotate) # Check numerical output op0(time_M=1) @@ -1012,11 +982,11 @@ def test_outlier_with_long_diameter(self, rotate): # Check code generation bns, pbs = assert_blocking(op1, {'x0_blk0'}) - ys, zs = self.get_params(op1, 'y0_blk0_size', 'z_size') + ys, zs = get_params(op1, 'y0_blk0_size', 'z_size') arrays = [i for i in FindSymbols().visit(bns['x0_blk0']) if i.is_Array] assert len(arrays) == 1 assert len(FindNodes(VExpanded).visit(pbs['x0_blk0'])) == 1 - self.check_array(arrays[0], ((1, 1), (1, 0)), (ys+2, zs+1), rotate) + check_array(arrays[0], ((1, 1), (1, 0)), (ys+2, zs+1), rotate) # Check numerical output op0(time_M=1) @@ -1294,12 +1264,12 @@ def test_minimize_remainders_due_to_autopadding(self, rotate): # Check code generation bns, pbs = assert_blocking(op1, {'x0_blk0'}) - xs, ys, zs = self.get_params(op1, 'x0_blk0_size', 'y0_blk0_size', 'z_size') + xs, ys, zs = get_params(op1, 'x0_blk0_size', 'y0_blk0_size', 'z_size') arrays = [i for i in FindSymbols().visit(bns['x0_blk0']) if i.is_Array] assert len(arrays) == 1 assert len(FindNodes(VExpanded).visit(pbs['x0_blk0'])) == 0 assert arrays[0].padding == ((0, 0), (0, 0), (0, 30)) - self.check_array(arrays[0], ((1, 1), (1, 1), (1, 1)), (xs+2, ys+2, zs+32), rotate) + check_array(arrays[0], ((1, 1), (1, 1), (1, 1)), (xs+2, ys+2, zs+32), rotate) # Check loop bounds trees = retrieve_iteration_tree(bns['x0_blk0']) assert len(trees) == 2 @@ -1460,10 +1430,10 @@ def test_space_invariant_v2(self): op = Operator(eq) # Check code generation - ys = self.get_params(op, 'y_size')[0] + ys = get_params(op, 'y_size')[0] arrays = [i for i in FindSymbols().visit(op) if i.is_Array] assert len(arrays) == 1 - self.check_array(arrays[0], ((0, 0),), (ys,)) + check_array(arrays[0], ((0, 0),), (ys,)) trees = retrieve_iteration_tree(op) assert len(trees) == 2 assert trees[0].root.dim is y @@ -1482,12 +1452,12 @@ def test_space_invariant_v3(self): op = Operator(eq) - xs, ys, zs = self.get_params(op, 'x_size', 'y_size', 'z_size') + xs, ys, zs = get_params(op, 'x_size', 'y_size', 'z_size') arrays = [i for i in FindSymbols().visit(op) if i.is_Array] assert len(arrays) == 3 - self.check_array(arrays[0], ((0, 0),), (ys,)) - self.check_array(arrays[1], ((0, 0), (0, 0)), (xs, zs)) - self.check_array(arrays[2], ((0, 0), (0, 0)), (xs, ys)) + check_array(arrays[0], ((0, 0),), (ys,)) + check_array(arrays[1], ((0, 0), (0, 0)), (xs, zs)) + check_array(arrays[2], ((0, 0), (0, 0)), (xs, ys)) def test_space_invariant_v4(self): """ @@ -1507,121 +1477,12 @@ def test_space_invariant_v4(self): op = Operator(eqns) - xs, ys, zs = self.get_params(op, 'x_size', 'y_size', 'z_size') - arrays = self.get_arrays(op) + xs, ys, zs = get_params(op, 'x_size', 'y_size', 'z_size') + arrays = get_arrays(op) assert len(arrays) == 1 - self.check_array(arrays[0], ((1, 0), (1, 0), (0, 0)), (xs+1, ys+1, zs)) + check_array(arrays[0], ((1, 0), (1, 0), (0, 0)), (xs+1, ys+1, zs)) assert op._profiler._sections['section1'].sops == 15 - def test_unexpanded_v0(self): - """ - Without prematurely expanding derivatives. - """ - grid = Grid(shape=(10, 10, 10)) - - f = Function(name='f', grid=grid, space_order=4) - u = TimeFunction(name='u', grid=grid, space_order=4) - u1 = TimeFunction(name='u', grid=grid, space_order=4) - - eqn = Eq(u.forward, (u*cos(f)).dx + 1.) - - op0 = Operator(eqn) - op1 = Operator(eqn, opt=('advanced', {'expand': False})) - - # Check generated code - for op in [op0, op1]: - xs, ys, zs = self.get_params(op, 'x_size', 'y_size', 'z_size') - arrays = [i for i in self.get_arrays(op) if i._mem_heap] - assert len(arrays) == 1 - self.check_array(arrays[0], ((2, 2), (0, 0), (0, 0)), (xs+4, ys, zs)) - - op0.apply(time_M=10) - op1.apply(time_M=10, u=u1) - - assert np.allclose(u.data, u1.data, rtol=10e-6) - - def test_unexpanded_v1(self): - """ - Inspired by test_space_invariant_v5, but now try with unexpanded - derivatives. - """ - grid = Grid(shape=(10, 10, 10)) - - f = Function(name='f', grid=grid, space_order=4) - u = TimeFunction(name='u', grid=grid, space_order=4) - v = TimeFunction(name='v', grid=grid, space_order=4) - u1 = TimeFunction(name='u', grid=grid, space_order=4) - v1 = TimeFunction(name='v', grid=grid, space_order=4) - - eqns = [Eq(u.forward, (u*cos(f)).dx + v + 1.), - Eq(v.forward, (v*cos(f)).dy + u.forward.dx + 1.)] - - op0 = Operator(eqns) - op1 = Operator(eqns, opt=('advanced', {'expand': False})) - - # Check generated code - for op in [op0, op1]: - xs, ys, zs = self.get_params(op, 'x_size', 'y_size', 'z_size') - arrays = self.get_arrays(op) - assert len(arrays) == 1 - self.check_array(arrays[0], ((2, 2), (2, 2), (0, 0)), (xs+4, ys+4, zs)) - assert op1._profiler._sections['section1'].sops == 44 - - op0.apply(time_M=10) - op1.apply(time_M=10, u=u1, v=v1) - - assert np.allclose(u.data, u1.data, rtol=10e-5) - assert np.allclose(v.data, v1.data, rtol=10e-5) - - def test_unexpanded_v2(self): - grid = Grid(shape=(10, 10, 10)) - - u = TimeFunction(name='u', grid=grid, space_order=4) - v = TimeFunction(name='v', grid=grid, space_order=4) - u1 = TimeFunction(name='u', grid=grid, space_order=4) - v1 = TimeFunction(name='v', grid=grid, space_order=4) - - eqns = [Eq(u.forward, (u.dx.dy + v*u.dx + 1.)), - Eq(v.forward, (v.dy.dx + u.dx.dz + 1.))] - - op0 = Operator(eqns) - op1 = Operator(eqns, opt=('advanced', {'expand': False, - 'blocklevels': 0})) - - # Check generated code -- expect maximal fusion! - assert_structure(op1, - ['t,x,y,z', 't,x,y,z,i0', 't,x,y,z,i1', 't,x,y,z,i1,i0'], - 't,x,y,z,i0,i1,i0') - - op0.apply(time_M=5) - op1.apply(time_M=5, u=u1, v=v1) - - assert np.allclose(u.data, u1.data, rtol=10e-3) - assert np.allclose(v.data, v1.data, rtol=10e-3) - - def test_unexpanded_v3(self): - grid = Grid(shape=(10, 10, 10)) - - u = TimeFunction(name='u', grid=grid, space_order=4) - v = TimeFunction(name='v', grid=grid, space_order=4) - u1 = TimeFunction(name='u', grid=grid, space_order=4) - v1 = TimeFunction(name='v', grid=grid, space_order=4) - - eqns = [Eq(u.forward, (u.dx.dy + v*u + 1.)), - Eq(v.forward, (v + u.dx.dy + 1.))] - - op0 = Operator(eqns) - op1 = Operator(eqns, opt=('advanced', {'expand': False})) - - # Check generated code -- redundant IndexDerivatives have been caught! - op1._profiler._sections['section0'].sops == 65 - - op0.apply(time_M=5) - op1.apply(time_M=5, u=u1, v=v1) - - assert np.allclose(u.data, u1.data, rtol=10e-3) - assert np.allclose(v.data, v1.data, rtol=10e-3) - def test_catch_duplicate_from_different_clusters(self): """ Check that the compiler is able to detect redundant aliases when these @@ -1818,10 +1679,10 @@ def test_full_shape_big_temporaries(self): # Check code generation bns, _ = assert_blocking(op1, {'x0_blk0', 'x1_blk0'}) - xs, ys, zs = self.get_params(op1, 'x_size', 'y_size', 'z_size') + xs, ys, zs = get_params(op1, 'x_size', 'y_size', 'z_size') arrays = [i for i in FindSymbols().visit(bns['x0_blk0']) if i.is_Array] assert len(arrays) == 1 - self.check_array(arrays[0], ((1, 1), (1, 1), (1, 1)), (xs+2, ys+2, zs+2)) + check_array(arrays[0], ((1, 1), (1, 1), (1, 1)), (xs+2, ys+2, zs+2)) # Check that `cire-rotate=True` has no effect in this code has there's # no cross-loop blocking @@ -1911,14 +1772,14 @@ def g1_tilde(field, phi): arrays = [i for i in FindSymbols().visit(bns['x0_blk0']) if i.is_Array] assert len(arrays) == 5 assert len(FindNodes(VExpanded).visit(pbs['x0_blk0'])) == 3 - xs, ys, zs = self.get_params(op, 'x0_blk0_size', 'y0_blk0_size', 'z_size') + xs, ys, zs = get_params(op, 'x0_blk0_size', 'y0_blk0_size', 'z_size') # The three kind of derivatives taken -- in x, y, z -- all are over # different expressions, so this leads to three temporaries of dimensionality, # in particular 3D for the x-derivative, 2D for the y-derivative, and 1D # for the z-derivative - self.check_array(arrays[2], ((2, 1), (0, 0), (0, 0)), (xs+3, ys, zs)) - self.check_array(arrays[3], ((2, 1), (0, 0)), (ys+3, zs)) - self.check_array(arrays[4], ((2, 1),), (zs+3,)) + check_array(arrays[2], ((2, 1), (0, 0), (0, 0)), (xs+3, ys, zs)) + check_array(arrays[3], ((2, 1), (0, 0)), (ys+3, zs)) + check_array(arrays[4], ((2, 1),), (zs+3,)) @pytest.mark.parametrize('so_ops', [(4, 51), (8, 95)]) @switchconfig(profiling='advanced') @@ -2171,12 +2032,12 @@ def test_tti_adjoint_akin_v2(self): # Check code generation bns, pbs = assert_blocking(op1, {'x0_blk0'}) - xs, ys, zs = self.get_params(op1, 'x0_blk0_size', 'y0_blk0_size', 'z_size') + xs, ys, zs = get_params(op1, 'x0_blk0_size', 'y0_blk0_size', 'z_size') arrays = [i for i in FindSymbols().visit(bns['x0_blk0']) if i.is_Array] assert len(arrays) == 4 assert len(FindNodes(VExpanded).visit(pbs['x0_blk0'])) == 2 - self.check_array(arrays[2], ((6, 6), (6, 6), (6, 6)), (xs+12, ys+12, zs+12)) - self.check_array(arrays[3], ((3, 3),), (zs+6,)) + check_array(arrays[2], ((6, 6), (6, 6), (6, 6)), (xs+12, ys+12, zs+12)) + check_array(arrays[3], ((3, 3),), (zs+6,)) # Check numerical output op0(time_M=1) @@ -2395,10 +2256,10 @@ def test_maxpar_option_v2(self): # Check code generation bns, _ = assert_blocking(op1, {'x0_blk0'}) - xs, ys, zs = self.get_params(op1, 'x0_blk0_size', 'y0_blk0_size', 'z_size') + xs, ys, zs = get_params(op1, 'x0_blk0_size', 'y0_blk0_size', 'z_size') arrays = [i for i in FindSymbols().visit(bns['x0_blk0']) if i.is_Array] assert len(arrays) == 1 - self.check_array(arrays[0], ((2, 2), (0, 0), (0, 0)), (xs+4, ys, zs)) + check_array(arrays[0], ((2, 2), (0, 0), (0, 0)), (xs+4, ys, zs)) # Check numerical output op0.apply(time_M=2) @@ -2419,11 +2280,11 @@ def test_maxpar_option_v3(self): op = Operator(eq, opt=('advanced', {'cire-maxpar': True})) # Check code generation - xs, ys = self.get_params(op, 'x_size', 'y_size') + xs, ys = get_params(op, 'x_size', 'y_size') arrays = [i for i in FindSymbols().visit(op) if i.is_Array] assert len(arrays) == 2 - self.check_array(arrays[0], ((2, 2), (2, 2)), (xs+4, ys+4)) - self.check_array(arrays[1], ((2, 2), (2, 2)), (xs+4, ys+4)) + check_array(arrays[0], ((2, 2), (2, 2)), (xs+4, ys+4)) + check_array(arrays[1], ((2, 2), (2, 2)), (xs+4, ys+4)) assert_structure(op, ['t,x,y', 't,x,y'], 't,x,y,x,y') @pytest.mark.parametrize('rotate', [False, True]) @@ -2577,12 +2438,12 @@ def test_grouping_fallback(self, rotate): # Check code generation # `min-storage` leads to one 2D and one 3D Arrays bns, pbs = assert_blocking(op1, {'x0_blk0'}) - xs, ys, zs = self.get_params(op1, 'x0_blk0_size', 'y0_blk0_size', 'z_size') + xs, ys, zs = get_params(op1, 'x0_blk0_size', 'y0_blk0_size', 'z_size') arrays = [i for i in FindSymbols().visit(bns['x0_blk0']) if i.is_Array] assert len(arrays) == 3 assert len(FindNodes(VExpanded).visit(pbs['x0_blk0'])) == 2 - self.check_array(arrays[1], ((4, 4), (0, 0)), (ys+8, zs), rotate) - self.check_array(arrays[2], ((4, 4),), (zs+8,)) # On purpose w/o `rotate` + check_array(arrays[1], ((4, 4), (0, 0)), (ys+8, zs), rotate) + check_array(arrays[2], ((4, 4),), (zs+8,)) # On purpose w/o `rotate` # Check numerical output op0.apply(time_M=2) @@ -2743,18 +2604,6 @@ def test_premature_evalderiv_lowering(self): assert len([i for i in FindSymbols().visit(op) if i.is_Array]) == 1 assert op._profiler._sections['section0'].sops == 16 - def test_fusion_after_unexpansion(self): - grid = Grid(shape=(10, 10)) - - u = TimeFunction(name='u', grid=grid, space_order=4) - - eqn = Eq(u.forward, u.dx + u.dy) - - op = Operator(eqn, opt=('advanced', {'expand': False})) - - assert op._profiler._sections['section0'].sops == 21 - assert_structure(op, ['t,x,y', 't,x,y,i0'], 't,x,y,i0') - class TestIsoAcoustic(object): diff --git a/tests/test_unexpansion.py b/tests/test_unexpansion.py new file mode 100644 index 0000000000..b098773fcf --- /dev/null +++ b/tests/test_unexpansion.py @@ -0,0 +1,191 @@ +import numpy as np + +from conftest import assert_structure, get_params, get_arrays, check_array +from devito import Buffer, Eq, Function, TimeFunction, Grid, Operator, cos, sin + + +class TestBasic(object): + + def test_v0(self): + grid = Grid(shape=(10, 10, 10)) + + f = Function(name='f', grid=grid, space_order=4) + u = TimeFunction(name='u', grid=grid, space_order=4) + u1 = TimeFunction(name='u', grid=grid, space_order=4) + + eqn = Eq(u.forward, (u*cos(f)).dx + 1.) + + op0 = Operator(eqn) + op1 = Operator(eqn, opt=('advanced', {'expand': False})) + + # Check generated code + for op in [op0, op1]: + xs, ys, zs = get_params(op, 'x_size', 'y_size', 'z_size') + arrays = [i for i in get_arrays(op) if i._mem_heap] + assert len(arrays) == 1 + check_array(arrays[0], ((2, 2), (0, 0), (0, 0)), (xs+4, ys, zs)) + + op0.apply(time_M=10) + op1.apply(time_M=10, u=u1) + + assert np.allclose(u.data, u1.data, rtol=10e-6) + + def test_fusion_after_unexpansion(self): + grid = Grid(shape=(10, 10)) + + u = TimeFunction(name='u', grid=grid, space_order=4) + + eqn = Eq(u.forward, u.dx + u.dy) + + op = Operator(eqn, opt=('advanced', {'expand': False})) + + assert op._profiler._sections['section0'].sops == 21 + assert_structure(op, ['t,x,y', 't,x,y,i0'], 't,x,y,i0') + + def test_v1(self): + grid = Grid(shape=(10, 10, 10)) + + f = Function(name='f', grid=grid, space_order=4) + u = TimeFunction(name='u', grid=grid, space_order=4) + v = TimeFunction(name='v', grid=grid, space_order=4) + u1 = TimeFunction(name='u', grid=grid, space_order=4) + v1 = TimeFunction(name='v', grid=grid, space_order=4) + + eqns = [Eq(u.forward, (u*cos(f)).dx + v + 1.), + Eq(v.forward, (v*cos(f)).dy + u.forward.dx + 1.)] + + op0 = Operator(eqns) + op1 = Operator(eqns, opt=('advanced', {'expand': False})) + + # Check generated code + for op in [op0, op1]: + xs, ys, zs = get_params(op, 'x_size', 'y_size', 'z_size') + arrays = get_arrays(op) + assert len(arrays) == 1 + check_array(arrays[0], ((2, 2), (2, 2), (0, 0)), (xs+4, ys+4, zs)) + assert op1._profiler._sections['section1'].sops == 44 + + op0.apply(time_M=10) + op1.apply(time_M=10, u=u1, v=v1) + + assert np.allclose(u.data, u1.data, rtol=10e-5) + assert np.allclose(v.data, v1.data, rtol=10e-5) + + def test_v2(self): + grid = Grid(shape=(10, 10, 10)) + + u = TimeFunction(name='u', grid=grid, space_order=4) + v = TimeFunction(name='v', grid=grid, space_order=4) + u1 = TimeFunction(name='u', grid=grid, space_order=4) + v1 = TimeFunction(name='v', grid=grid, space_order=4) + + eqns = [Eq(u.forward, (u.dx.dy + v*u.dx + 1.)), + Eq(v.forward, (v.dy.dx + u.dx.dz + 1.))] + + op0 = Operator(eqns) + op1 = Operator(eqns, opt=('advanced', {'expand': False, + 'blocklevels': 0})) + + # Check generated code -- expect maximal fusion! + assert_structure(op1, + ['t,x,y,z', 't,x,y,z,i0', 't,x,y,z,i1', 't,x,y,z,i1,i0'], + 't,x,y,z,i0,i1,i0') + + op0.apply(time_M=5) + op1.apply(time_M=5, u=u1, v=v1) + + assert np.allclose(u.data, u1.data, rtol=10e-3) + assert np.allclose(v.data, v1.data, rtol=10e-3) + + def test_v3(self): + grid = Grid(shape=(10, 10, 10)) + + u = TimeFunction(name='u', grid=grid, space_order=4) + v = TimeFunction(name='v', grid=grid, space_order=4) + u1 = TimeFunction(name='u', grid=grid, space_order=4) + v1 = TimeFunction(name='v', grid=grid, space_order=4) + + eqns = [Eq(u.forward, (u.dx.dy + v*u + 1.)), + Eq(v.forward, (v + u.dx.dy + 1.))] + + op0 = Operator(eqns) + op1 = Operator(eqns, opt=('advanced', {'expand': False})) + + # Check generated code -- redundant IndexDerivatives have been caught! + op1._profiler._sections['section0'].sops == 65 + + op0.apply(time_M=5) + op1.apply(time_M=5, u=u1, v=v1) + + assert np.allclose(u.data, u1.data, rtol=10e-3) + assert np.allclose(v.data, v1.data, rtol=10e-3) + + def test_v4(self): + grid = Grid(shape=(16, 16, 16)) + t = grid.stepping_dim + x, y, z = grid.dimensions + + so = 4 + + a = Function(name='a', grid=grid, space_order=so) + f = Function(name='f', grid=grid, space_order=so) + e = Function(name='e', grid=grid, space_order=so) + r = Function(name='r', grid=grid, space_order=so) + p0 = TimeFunction(name='p0', grid=grid, time_order=2, space_order=so) + m0 = TimeFunction(name='m0', grid=grid, time_order=2, space_order=so) + + def g1(field, r, e): + return (cos(e) * cos(r) * field.dx(x0=x+x.spacing/2) + + cos(e) * sin(r) * field.dy(x0=y+y.spacing/2) - + sin(e) * field.dz(x0=z+z.spacing/2)) + + def g2(field, r, e): + return - (sin(r) * field.dx(x0=x+x.spacing/2) - + cos(r) * field.dy(x0=y+y.spacing/2)) + + def g3(field, r, e): + return (sin(e) * cos(r) * field.dx(x0=x+x.spacing/2) + + sin(e) * sin(r) * field.dy(x0=y+y.spacing/2) + + cos(e) * field.dz(x0=z+z.spacing/2)) + + def g1_tilde(field, r, e): + return ((cos(e) * cos(r) * field).dx(x0=x-x.spacing/2) + + (cos(e) * sin(r) * field).dy(x0=y-y.spacing/2) - + (sin(e) * field).dz(x0=z-z.spacing/2)) + + def g2_tilde(field, r, e): + return - ((sin(r) * field).dx(x0=x-x.spacing/2) - + (cos(r) * field).dy(x0=y-y.spacing/2)) + + def g3_tilde(field, r, e): + return ((sin(e) * cos(r) * field).dx(x0=x-x.spacing/2) + + (sin(e) * sin(r) * field).dy(x0=y-y.spacing/2) + + (cos(e) * field).dz(x0=z-z.spacing/2)) + + update_p = t.spacing**2 * a**2 / f * \ + (g1_tilde(f * g1(p0, r, e), r, e) + + g2_tilde(f * g2(p0, r, e), r, e) + + g3_tilde(f * g3(p0, r, e) + f * g3(m0, r, e), r, e)) + \ + (2 - t.spacing * a) + + update_m = t.spacing**2 * a**2 / f * \ + (g1_tilde(f * g1(m0, r, e), r, e) + + g2_tilde(f * g2(m0, r, e), r, e) + + g3_tilde(f * g3(m0, r, e) + f * g3(p0, r, e), r, e)) + \ + (2 - t.spacing * a) + + eqns = [Eq(p0.forward, update_p), + Eq(m0.forward, update_m)] + + op = Operator(eqns, subs=grid.spacing_map, + opt=('advanced', {'expand': False})) + + # Check code generation + assert op._profiler._sections['section1'].sops == 1442 + assert_structure(op, ['x,y,z', + 't,x0_blk0,y0_blk0,x,y,z', + 't,x0_blk0,y0_blk0,x,y,z,i1', + 't,x0_blk0,y0_blk0,x,y,z,i1,i0'], + 'x,y,z,t,x0_blk0,y0_blk0,x,y,z,i1,i0') + + op.cfunction From 821af91638266fe8dd564f13442baebbab558e90 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Tue, 28 Mar 2023 11:00:45 +0000 Subject: [PATCH 2/9] compiler: Add Dependence.is_lex_ne --- devito/ir/support/basic.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/devito/ir/support/basic.py b/devito/ir/support/basic.py index 7ac4d739e1..621977f3cd 100644 --- a/devito/ir/support/basic.py +++ b/devito/ir/support/basic.py @@ -562,6 +562,11 @@ def is_lex_equal(self): """ return self.source.timestamp == self.sink.timestamp + @cached_property + def is_lex_ne(self): + """True if the source's and sink's timestamps differ, False otherwise.""" + return self.source.timestamp != self.sink.timestamp + @cached_property def is_lex_negative(self): """ From e5e15cac481936b6c61d6be889eaabad439d878a Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Tue, 28 Mar 2023 11:02:51 +0000 Subject: [PATCH 3/9] compiler: Improve roboustness of unexpansion --- devito/passes/clusters/derivatives.py | 18 ++++++++-- tests/test_unexpansion.py | 50 +++++++++++++++++++++++++++ 2 files changed, 65 insertions(+), 3 deletions(-) diff --git a/devito/passes/clusters/derivatives.py b/devito/passes/clusters/derivatives.py index 44f04eb519..cc1b8000aa 100644 --- a/devito/passes/clusters/derivatives.py +++ b/devito/passes/clusters/derivatives.py @@ -32,16 +32,25 @@ def dump(exprs, c): exprs[:] = [] for c in clusters: + # Can I reuse common IndexDerivatives popping up in different exprs + # within `c`? + # NOTE: this could be refined to rather identify groups of consecutive + # exprs sharing IndexDerivatives, but it's in practice an overkill, since + # we only end up here in artificious cases + unreusable = any(d.is_indep() and d.is_lex_ne for d in c.scope.d_all_gen()) exprs = [] seen = {} for e in c.exprs: expr, v = _lower_index_derivatives_core(e, c, weights, seen, sregistry) - if v: + if v and unreusable: dump(exprs, c) exprs.append(expr) processed.extend(v) + if unreusable: + seen = {} + dump(exprs, c) return processed, weights @@ -79,7 +88,7 @@ def _lower_index_derivatives_core(expr, c, weights, seen, sregistry): # Have I seen this IndexDerivative already? try: return seen[expr], [] - except KeyError: + except (KeyError, TypeError): pass dims = retrieve_dimensions(expr, deep=True) @@ -104,7 +113,10 @@ def _lower_index_derivatives_core(expr, c, weights, seen, sregistry): processed.insert(0, c.rebuild(exprs=expr0, ispace=ispace1)) # Track IndexDerivative to avoid intra-Cluster duplicates - seen[expr] = s + try: + seen[expr] = s + except TypeError: + pass # Transform e.g. `w[i0] -> w[i0 + 2]` for alignment with the # StencilDimensions starting points diff --git a/tests/test_unexpansion.py b/tests/test_unexpansion.py index b098773fcf..9e0899981f 100644 --- a/tests/test_unexpansion.py +++ b/tests/test_unexpansion.py @@ -2,6 +2,7 @@ from conftest import assert_structure, get_params, get_arrays, check_array from devito import Buffer, Eq, Function, TimeFunction, Grid, Operator, cos, sin +from devito.types import Symbol class TestBasic(object): @@ -189,3 +190,52 @@ def g3_tilde(field, r, e): 'x,y,z,t,x0_blk0,y0_blk0,x,y,z,i1,i0') op.cfunction + + def test_v5(self): + grid = Grid(shape=(16, 16)) + + p0 = TimeFunction(name='p0', grid=grid, time_order=2, space_order=4, + save=Buffer(2)) + m0 = TimeFunction(name='m0', grid=grid, time_order=2, space_order=4, + save=Buffer(2)) + + eqns = [Eq(p0.forward, (p0.dx + m0.dx).dx + p0.backward), + Eq(m0.forward, m0.dx.dx + m0.backward)] + + op = Operator(eqns, subs=grid.spacing_map, + opt=('advanced', {'expand': False})) + + # Check code generation + assert op._profiler._sections['section0'].sops == 127 + assert_structure(op, ['t,x,y', 't,x,y,i1', 't,x,y,i1,i0'], 't,x,y,i1,i0') + + op.cfunction + + def test_v6(self): + grid = Grid(shape=(16, 16)) + + f = Function(name='f', grid=grid, space_order=4) + g = Function(name='g', grid=grid, space_order=4) + p0 = TimeFunction(name='p0', grid=grid, time_order=2, space_order=4, + save=Buffer(2)) + m0 = TimeFunction(name='m0', grid=grid, time_order=2, space_order=4, + save=Buffer(2)) + + s0 = Symbol(name='s0', dtype=np.float32) + + eqns = [Eq(p0.forward, (p0.dx + m0.dx).dx + p0.backward), + Eq(s0, 4., implicit_dims=p0.dimensions), + Eq(m0.forward, (m0.dx + s0).dx + f*m0.backward)] + + op = Operator(eqns, subs=grid.spacing_map, + opt=('advanced', {'expand': False})) + + # Check code generation + assert op._profiler._sections['section0'].sops == 183 + assert_structure( + op, + ['t,x,y', 't,x,y,i1', 't,x,y,i1,i0', 't,x,y,i1', 't,x,y,i1,i0'], + 't,x,y,i1,i0,i1,i0' + ) + + op.cfunction From 676066dd9daaf97535122725e911a3145c11f8cd Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Mon, 17 Apr 2023 09:47:39 +0000 Subject: [PATCH 4/9] compiler: Avoid premature lowering of SteppingDimensions This fixes DDA with Buffer(2) due to t-1 and t+1 mapped to the same ModuloDimension --- devito/ir/clusters/algorithms.py | 18 ++++--------- devito/operator/operator.py | 5 +++- devito/passes/iet/misc.py | 45 +++++++++++++++++++++++++++++--- devito/types/dimension.py | 3 +++ 4 files changed, 53 insertions(+), 18 deletions(-) diff --git a/devito/ir/clusters/algorithms.py b/devito/ir/clusters/algorithms.py index 8c294cff2b..0d6c72110d 100644 --- a/devito/ir/clusters/algorithms.py +++ b/devito/ir/clusters/algorithms.py @@ -281,7 +281,7 @@ def callback(self, clusters, prefix): mapper[size][si].add(iaf) # Construct the ModuloDimensions - mds = OrderedDict() + mds = [] for size, v in mapper.items(): for si, iafs in list(v.items()): # Offsets are sorted so that the semantic order (t0, t1, t2) follows @@ -290,15 +290,10 @@ def callback(self, clusters, prefix): # sorting offsets {-1, 0, 1} as {0, -1, 1} assigning -inf to 0 siafs = sorted(iafs, key=lambda i: -np.inf if i - si == 0 else (i - si)) - # Create the ModuloDimensions. Note that if `size < len(iafs)` then - # the same ModuloDimension may be used for multiple offsets - for iaf in siafs[:size]: + for iaf in siafs: name = '%s%d' % (si.name, len(mds)) offset = uxreplace(iaf, {si: d.root}) - md = ModuloDimension(name, si, offset, size, origin=iaf) - - key = lambda i: i.subs(si, 0) % size - mds[md] = [i for i in siafs if key(i) == key(iaf)] + mds.append(ModuloDimension(name, si, offset, size, origin=iaf)) # Replacement rule for ModuloDimensions def rule(size, e): @@ -320,11 +315,8 @@ def rule(size, e): exprs = c.exprs groups = as_mapper(mds, lambda d: d.modulo) for size, v in groups.items(): - mapper = {} - for md in v: - mapper.update({i: md for i in mds[md]}) - - func = partial(xreplace_indices, mapper=mapper, key=partial(rule, size)) + subs = {md.origin: md for md in v} + func = partial(xreplace_indices, mapper=subs, key=partial(rule, size)) exprs = [e.apply(func) for e in exprs] # Augment IterationSpace diff --git a/devito/operator/operator.py b/devito/operator/operator.py index 3b5993d646..ec298cbfef 100644 --- a/devito/operator/operator.py +++ b/devito/operator/operator.py @@ -20,7 +20,7 @@ from devito.mpi import MPI from devito.parameters import configuration from devito.passes import (Graph, lower_index_derivatives, generate_implicit, - generate_macros, unevaluate) + generate_macros, minimize_symbols, unevaluate) from devito.symbolics import estimate_cost from devito.tools import (DAG, OrderedSet, Signer, ReducerMap, as_tuple, flatten, filter_sorted, frozendict, is_integer, split, timed_pass, @@ -458,6 +458,9 @@ def _lower_iet(cls, uiet, profiler=None, **kwargs): # Extract the necessary macros from the symbolic objects generate_macros(graph) + # Target-independent optimizations + minimize_symbols(graph) + return graph.root, graph # Read-only properties exposed to the outside world diff --git a/devito/passes/iet/misc.py b/devito/passes/iet/misc.py index a1abd3c658..4199283e54 100644 --- a/devito/passes/iet/misc.py +++ b/devito/passes/iet/misc.py @@ -4,14 +4,15 @@ import sympy from devito.finite_differences import Max, Min -from devito.ir import (Any, Forward, List, Prodder, FindApplications, FindNodes, - Transformer, filter_iterations, retrieve_iteration_tree) +from devito.ir import (Any, Forward, Iteration, List, Prodder, FindApplications, + FindNodes, Transformer, Uxreplace, filter_iterations, + retrieve_iteration_tree) from devito.passes.iet.engine import iet_pass from devito.symbolics import evalrel, has_integer_args -from devito.tools import split +from devito.tools import as_mapper, split __all__ = ['avoid_denormals', 'hoist_prodders', 'relax_incr_dimensions', - 'generate_macros'] + 'generate_macros', 'minimize_symbols'] @iet_pass @@ -161,3 +162,39 @@ def _(expr): return {('MAX(a,b)', ('(((a) > (b)) ? (a) : (b))'))} else: return set() + + +@iet_pass +def minimize_symbols(iet): + """ + Remove unneccesary symbols. Currently applied sub-passes: + + * Remove redundant ModuloDimensions (e.g., due to using the + `save=Buffer(2)` API) + """ + iet = remove_redundant_moddims(iet) + + return iet, {} + + +def remove_redundant_moddims(iet): + subs0 = {} + subs1 = {} + for n in FindNodes(Iteration).visit(iet): + mds = [d for d in n.uindices + if d.is_Modulo and d.origin is not None] + if not mds: + continue + + mapper = as_mapper(mds, key=lambda md: md.origin % md.modulo) + for k, v in mapper.items(): + chosen = v.pop(0) + subs0.update({d: chosen for d in v}) + + uindices = [d for d in n.uindices if d not in subs0] + subs1[n] = n._rebuild(uindices=uindices) + + iet = Transformer(subs1, nested=True).visit(iet) + iet = Uxreplace(subs0).visit(iet) + + return iet diff --git a/devito/types/dimension.py b/devito/types/dimension.py index e81fa3f4f6..43e7a3d0f3 100644 --- a/devito/types/dimension.py +++ b/devito/types/dimension.py @@ -1502,6 +1502,9 @@ def __sub__(self, other): def __rsub__(self, other): return self.func(other, -self) + def __mod__(self, other): + return sympy.Mod(sympy.Add(*self.args), other) + class AffineIndexAccessFunction(IndexAccessFunction): """ From 32fe68d183b4e1d05ab4c21c8623026968e4353f Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Thu, 30 Mar 2023 09:09:10 +0000 Subject: [PATCH 5/9] compiler: Improve unexpansion --- devito/ir/support/basic.py | 8 ++- devito/passes/clusters/derivatives.py | 96 ++++++++++++++++++--------- devito/passes/clusters/misc.py | 4 ++ tests/test_unexpansion.py | 9 +-- 4 files changed, 75 insertions(+), 42 deletions(-) diff --git a/devito/ir/support/basic.py b/devito/ir/support/basic.py index 621977f3cd..36ea735109 100644 --- a/devito/ir/support/basic.py +++ b/devito/ir/support/basic.py @@ -411,8 +411,12 @@ def distance(self, other): # Indexed representing an arbitrary access along `x`, within the `t` # IterationSpace, while the sink lives within the `tx` IterationSpace if len(self.itintervals[n:]) != len(other.itintervals[n:]): - ret.append(S.Infinity) - return Vector(*ret) + v = Vector(*ret) + if v != 0: + return v + else: + ret.append(S.Infinity) + return Vector(*ret) # It still could be an imaginary dependence, e.g. `a[3] -> a[4]` or, more # nasty, `a[i+1, 3] -> a[i, 4]` diff --git a/devito/passes/clusters/derivatives.py b/devito/passes/clusters/derivatives.py index cc1b8000aa..497468de36 100644 --- a/devito/passes/clusters/derivatives.py +++ b/devito/passes/clusters/derivatives.py @@ -1,5 +1,5 @@ from devito.finite_differences import IndexDerivative -from devito.ir import Interval, IterationSpace +from devito.ir import Interval, IterationSpace, Queue from devito.passes.clusters.misc import fuse from devito.symbolics import (retrieve_dimensions, reuse_if_untouched, q_leaf, uxreplace) @@ -11,7 +11,7 @@ @timed_pass() def lower_index_derivatives(clusters, mode=None, **kwargs): - clusters, weights = _lower_index_derivatives(clusters, **kwargs) + clusters, weights, mapper = _lower_index_derivatives(clusters, **kwargs) if not weights: return clusters @@ -19,12 +19,15 @@ def lower_index_derivatives(clusters, mode=None, **kwargs): if mode != 'noop': clusters = fuse(clusters, toposort='maximal') + clusters = CDE(mapper).process(clusters) + return clusters def _lower_index_derivatives(clusters, sregistry=None, **kwargs): - processed = [] weights = {} + processed = [] + mapper = {} def dump(exprs, c): if exprs: @@ -32,31 +35,20 @@ def dump(exprs, c): exprs[:] = [] for c in clusters: - # Can I reuse common IndexDerivatives popping up in different exprs - # within `c`? - # NOTE: this could be refined to rather identify groups of consecutive - # exprs sharing IndexDerivatives, but it's in practice an overkill, since - # we only end up here in artificious cases - unreusable = any(d.is_indep() and d.is_lex_ne for d in c.scope.d_all_gen()) - exprs = [] - seen = {} for e in c.exprs: - expr, v = _lower_index_derivatives_core(e, c, weights, seen, sregistry) - if v and unreusable: + expr, v = _core(e, c, weights, mapper, sregistry) + if v: dump(exprs, c) + processed.extend(v) exprs.append(expr) - processed.extend(v) - - if unreusable: - seen = {} dump(exprs, c) - return processed, weights + return processed, weights, mapper -def _lower_index_derivatives_core(expr, c, weights, seen, sregistry): +def _core(expr, c, weights, mapper, sregistry): """ Recursively carry out the core of `lower_index_derivatives`. """ @@ -66,7 +58,7 @@ def _lower_index_derivatives_core(expr, c, weights, seen, sregistry): args = [] processed = [] for a in expr.args: - e, clusters = _lower_index_derivatives_core(a, c, weights, seen, sregistry) + e, clusters = _core(a, c, weights, mapper, sregistry) args.append(e) processed.extend(clusters) @@ -85,12 +77,6 @@ def _lower_index_derivatives_core(expr, c, weights, seen, sregistry): w = weights[k] = w0._rebuild(name=name) expr = uxreplace(expr, {w0.indexed: w.indexed}) - # Have I seen this IndexDerivative already? - try: - return seen[expr], [] - except (KeyError, TypeError): - pass - dims = retrieve_dimensions(expr, deep=True) dims = filter_ordered(d for d in dims if isinstance(d, StencilDimension)) @@ -100,7 +86,7 @@ def _lower_index_derivatives_core(expr, c, weights, seen, sregistry): # upper and lower offsets, we honor it dims = tuple(d for d in dims if d not in c.ispace) - intervals = [Interval(d, 0, 0) for d in dims] + intervals = [Interval(d) for d in dims] ispace0 = IterationSpace(intervals) extra = (c.ispace.itdimensions + dims,) @@ -112,16 +98,60 @@ def _lower_index_derivatives_core(expr, c, weights, seen, sregistry): ispace1 = ispace.project(lambda d: d is not dims[-1]) processed.insert(0, c.rebuild(exprs=expr0, ispace=ispace1)) - # Track IndexDerivative to avoid intra-Cluster duplicates - try: - seen[expr] = s - except TypeError: - pass - # Transform e.g. `w[i0] -> w[i0 + 2]` for alignment with the # StencilDimensions starting points subs = {expr.weights: expr.weights.subs(d, d - d._min) for d in dims} expr1 = Inc(s, uxreplace(expr.expr, subs)) processed.append(c.rebuild(exprs=expr1, ispace=ispace)) + # Track lowered IndexDerivative for subsequent optimization by the caller + mapper.setdefault(expr1.rhs, []).append(s) + return s, processed + + +class CDE(Queue): + + """ + Common derivative elimination. + """ + + def __init__(self, mapper): + super().__init__() + + self.mapper = {k: v for k, v in mapper.items() if len(v) > 1} + + def process(self, clusters): + return self._process_fdta(clusters, 1, subs={}, seen=set()) + + def callback(self, clusters, prefix, subs=None, seen=None): + processed = [] + for c in clusters: + if c in seen: + processed.append(c) + continue + + exprs = [] + for e in c.exprs: + k, v = e.args + + if k in subs: + continue + + try: + subs[k] = subs[v] + continue + except KeyError: + pass + + if v in self.mapper: + subs[v] = k + exprs.append(e) + else: + exprs.append(uxreplace(e, subs)) + + processed.append(c.rebuild(exprs=exprs)) + + seen.update(processed) + + return processed diff --git a/devito/passes/clusters/misc.py b/devito/passes/clusters/misc.py index 424e16fa22..ef0e931576 100644 --- a/devito/passes/clusters/misc.py +++ b/devito/passes/clusters/misc.py @@ -1,6 +1,7 @@ from collections import Counter, defaultdict from itertools import groupby, product +from devito.finite_differences import IndexDerivative from devito.ir.clusters import Cluster, ClusterGroup, Queue, cluster_pass from devito.ir.support import (SEQUENTIAL, SEPARABLE, Scope, ReleaseLock, WaitLock, WithLock, FetchUpdate, PrefetchUpdate) @@ -188,6 +189,9 @@ def _key(self, c): # Clusters representing HaloTouches should get merged, if possible key += (c.is_halo_touch,) + # Promoting adjacency of IndexDerivatives will maximize their reuse + key += (any(e.find(IndexDerivative) for e in c.exprs),) + return key def _apply_heuristics(self, clusters): diff --git a/tests/test_unexpansion.py b/tests/test_unexpansion.py index 9e0899981f..d9d01b6fc9 100644 --- a/tests/test_unexpansion.py +++ b/tests/test_unexpansion.py @@ -215,7 +215,6 @@ def test_v6(self): grid = Grid(shape=(16, 16)) f = Function(name='f', grid=grid, space_order=4) - g = Function(name='g', grid=grid, space_order=4) p0 = TimeFunction(name='p0', grid=grid, time_order=2, space_order=4, save=Buffer(2)) m0 = TimeFunction(name='m0', grid=grid, time_order=2, space_order=4, @@ -231,11 +230,7 @@ def test_v6(self): opt=('advanced', {'expand': False})) # Check code generation - assert op._profiler._sections['section0'].sops == 183 - assert_structure( - op, - ['t,x,y', 't,x,y,i1', 't,x,y,i1,i0', 't,x,y,i1', 't,x,y,i1,i0'], - 't,x,y,i1,i0,i1,i0' - ) + assert op._profiler._sections['section0'].sops == 133 + assert_structure(op, ['t,x,y', 't,x,y,i1', 't,x,y,i1,i0'], 't,x,y,i1,i0') op.cfunction From 5ab77fabf52bd44d96f58a49665db35973356843 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Mon, 17 Apr 2023 16:14:14 +0000 Subject: [PATCH 6/9] compiler: Patch CDE pass --- devito/passes/clusters/derivatives.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/devito/passes/clusters/derivatives.py b/devito/passes/clusters/derivatives.py index 497468de36..00e7526f88 100644 --- a/devito/passes/clusters/derivatives.py +++ b/devito/passes/clusters/derivatives.py @@ -122,9 +122,10 @@ def __init__(self, mapper): self.mapper = {k: v for k, v in mapper.items() if len(v) > 1} def process(self, clusters): - return self._process_fdta(clusters, 1, subs={}, seen=set()) + return self._process_fdta(clusters, 1, subs0={}, seen=set()) - def callback(self, clusters, prefix, subs=None, seen=None): + def callback(self, clusters, prefix, subs0=None, seen=None): + subs = {} processed = [] for c in clusters: if c in seen: @@ -135,11 +136,11 @@ def callback(self, clusters, prefix, subs=None, seen=None): for e in c.exprs: k, v = e.args - if k in subs: + if k in subs0: continue try: - subs[k] = subs[v] + subs0[k] = subs[v] continue except KeyError: pass @@ -148,7 +149,7 @@ def callback(self, clusters, prefix, subs=None, seen=None): subs[v] = k exprs.append(e) else: - exprs.append(uxreplace(e, subs)) + exprs.append(uxreplace(e, {**subs0, **subs})) processed.append(c.rebuild(exprs=exprs)) From 9e96bc13e7fd348542049ea01886410dcb2ca002 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Tue, 18 Apr 2023 12:35:53 +0000 Subject: [PATCH 7/9] compiler: Refine topo-fusion --- devito/passes/clusters/misc.py | 61 ++++++++++++++++++++++++---------- 1 file changed, 44 insertions(+), 17 deletions(-) diff --git a/devito/passes/clusters/misc.py b/devito/passes/clusters/misc.py index ef0e931576..9d0e7463bd 100644 --- a/devito/passes/clusters/misc.py +++ b/devito/passes/clusters/misc.py @@ -146,19 +146,34 @@ def callback(self, cgroups, prefix): else: return [ClusterGroup(processed, prefix)] - def _key(self, c): - # Two Clusters/ClusterGroups are fusion candidates if their key is identical + class Key(tuple): - key = (frozenset(c.ispace.itintervals),) + """ + A fusion Key for a Cluster (ClusterGroup) is a hashable tuple such that + two Clusters (ClusterGroups) are topo-fusible if and only if their Key is + identical. + + A Key contains several elements that can logically be split into two + groups -- the `strict` and the `weak` components of the Key. + Two Clusters (ClusterGroups) having same `strict` but different `weak` parts + are, as by definition, not fusible; however, since at least their `strict` + parts match, they can at least be topologically reordered. + """ - # If there are writes to thread-shared object, make it part of the key. - # This will promote fusion of non-adjacent Clusters writing to (some form of) - # shared memory, which in turn will minimize the number of necessary barriers - key += (any(f._mem_shared for f in c.scope.writes),) - # Same story for reads from thread-shared objects - key += (any(f._mem_shared for f in c.scope.reads),) + def __new__(cls, strict, weak): + obj = super().__new__(cls, strict + weak) + obj.strict = tuple(strict) + obj.weak = tuple(weak) - key += (c.guards if any(c.guards) else None,) + return obj + + def _key(self, c): + strict = [] + + strict.extend([ + frozenset(c.ispace.itintervals), + c.guards if any(c.guards) else None + ]) # We allow fusing Clusters/ClusterGroups even in presence of WaitLocks and # WithLocks, but not with any other SyncOps @@ -181,16 +196,28 @@ def _key(self, c): mapper[k].add(type(s)) else: mapper[k].add(s) - mapper[k] = frozenset(mapper[k]) - if any(mapper.values()): - mapper = frozendict(mapper) - key += (mapper,) + if k in mapper: + mapper[k] = frozenset(mapper[k]) + strict.append(frozendict(mapper)) + + weak = [] # Clusters representing HaloTouches should get merged, if possible - key += (c.is_halo_touch,) + weak.append(c.is_halo_touch) + + # If there are writes to thread-shared object, make it part of the key. + # This will promote fusion of non-adjacent Clusters writing to (some form of) + # shared memory, which in turn will minimize the number of necessary barriers + # Same story for reads from thread-shared objects + weak.extend([ + any(f._mem_shared for f in c.scope.writes), + any(f._mem_shared for f in c.scope.reads) + ]) # Promoting adjacency of IndexDerivatives will maximize their reuse - key += (any(e.find(IndexDerivative) for e in c.exprs),) + weak.append(any(e.find(IndexDerivative) for e in c.exprs)) + + key = self.Key(strict, weak) return key @@ -240,7 +267,7 @@ def dump(): def _toposort(self, cgroups, prefix): # Are there any ClusterGroups that could potentially be fused? If # not, do not waste time computing a new topological ordering - counter = Counter(self._key(cg) for cg in cgroups) + counter = Counter(self._key(cg).strict for cg in cgroups) if not any(v > 1 for it, v in counter.most_common()): return ClusterGroup(cgroups, prefix) From 2621ff3e1f40328c8c0d502eb90bcbd66c6b4ff6 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Wed, 19 Apr 2023 08:30:29 +0000 Subject: [PATCH 8/9] compiler: Enhance Reconstructable with variadic args support --- devito/tools/abc.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/devito/tools/abc.py b/devito/tools/abc.py index a1f32e91a9..8789f40b9c 100644 --- a/devito/tools/abc.py +++ b/devito/tools/abc.py @@ -130,7 +130,11 @@ def __init__(self, a, b, c=4): * `a._rebuild(c=5) -> x(3, 5, 5)` * `a._rebuild(1, c=7) -> x(1, 5, 7)` """ - args += tuple(getattr(self, i) for i in self.__rargs__[len(args):]) + for i in self.__rargs__[len(args):]: + if i.startswith('*'): + args += tuple(getattr(self, i[1:])) + else: + args += (getattr(self, i),) kwargs.update({i: getattr(self, i) for i in self.__rkwargs__ if i not in kwargs}) # Should we use a constum reconstructor? From 8077f6e75e47b5da8dfe3a23cbc97b42de4686ac Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Wed, 19 Apr 2023 08:30:57 +0000 Subject: [PATCH 9/9] compiler: Revamp FIndexed for correct reconstruction --- devito/passes/iet/linearization.py | 2 +- devito/types/misc.py | 23 +++++++++++++++-------- tests/test_symbolics.py | 2 ++ 3 files changed, 18 insertions(+), 9 deletions(-) diff --git a/devito/passes/iet/linearization.py b/devito/passes/iet/linearization.py index 9866e755b7..9cd16dc416 100644 --- a/devito/passes/iet/linearization.py +++ b/devito/passes/iet/linearization.py @@ -259,7 +259,7 @@ def _(f, indexeds, tracker, strides, sregistry): if len(i.indices) == i.function.ndim: v = tuple(strides.values())[-n:] - subs[i] = FIndexed(i, pname, strides=v) + subs[i] = FIndexed.from_indexed(i, pname, strides=v) else: # Honour custom indexing subs[i] = i.base[sum(i.indices)] diff --git a/devito/types/misc.py b/devito/types/misc.py index eb0f8b2b8b..031762c69e 100644 --- a/devito/types/misc.py +++ b/devito/types/misc.py @@ -68,20 +68,21 @@ class FIndexed(Indexed, Pickable): `uX[x*ny + y]`, where `X` is a string provided by the caller. """ - __rargs__ = ('indexed', 'pname') + __rargs__ = ('base', '*indices') __rkwargs__ = ('strides',) - def __new__(cls, indexed, pname, strides=None): - plabel = Symbol(name=pname, dtype=indexed.dtype) - base = IndexedData(plabel, None, function=indexed.function) - obj = super().__new__(cls, base, *indexed.indices) - - obj.indexed = indexed - obj.pname = pname + def __new__(cls, base, *args, strides=None): + obj = super().__new__(cls, base, *args) obj.strides = as_tuple(strides) return obj + @classmethod + def from_indexed(cls, indexed, pname, strides=None): + label = Symbol(name=pname, dtype=indexed.dtype) + base = IndexedData(label, None, function=indexed.function) + return FIndexed(base, *indexed.indices, strides=strides) + def __repr__(self): return "%s(%s)" % (self.name, ", ".join(str(i) for i in self.indices)) @@ -90,10 +91,16 @@ def __repr__(self): def _hashable_content(self): return super()._hashable_content() + (self.strides,) + func = Pickable._rebuild + @property def name(self): return self.function.name + @property + def pname(self): + return self.base.name + @property def free_symbols(self): # The functional representation of the FIndexed "hides" the strides, which diff --git a/tests/test_symbolics.py b/tests/test_symbolics.py index 807ed35216..751cb197cb 100644 --- a/tests/test_symbolics.py +++ b/tests/test_symbolics.py @@ -13,6 +13,7 @@ INT, FieldFromComposite, IntDiv, ccode, uxreplace) from devito.tools import as_tuple from devito.types import Array, Bundle, LocalObject, Object, Symbol as dSymbol +from devito.types import Array, Bundle, FIndexed, LocalObject, Object, Symbol as dSymbol # noqa def test_float_indices(): @@ -358,6 +359,7 @@ def test_solve_time(): ('f[x, y+1]', '{f.indexed: g.indexed}', 'g[x, y+1]'), ('cos(f)', '{cos: sin}', 'sin(f)'), ('cos(f + sin(g))', '{cos: sin, sin: cos}', 'sin(f + cos(g))'), + ('FIndexed(f.indexed, x, y)', '{x: 0}', 'FIndexed(f.indexed, 0, y)'), ]) def test_uxreplace(expr, subs, expected): grid = Grid(shape=(4, 4))