From 2fbed4d2960eac239c56f45db3f49dba2cc284e6 Mon Sep 17 00:00:00 2001 From: George Bisbas Date: Mon, 9 Dec 2024 19:28:53 +0000 Subject: [PATCH] compiler: Fix misc review comments --- devito/ir/iet/nodes.py | 20 +- devito/mpi/halo_scheme.py | 29 +- devito/passes/iet/mpi.py | 30 +- .../seismic/tutorials/09_viscoelastic.ipynb | 2 +- examples/seismic/viscoelastic/operators.py | 2 +- tests/test_mpi.py | 368 +++++++++--------- 6 files changed, 229 insertions(+), 222 deletions(-) diff --git a/devito/ir/iet/nodes.py b/devito/ir/iet/nodes.py index 86e432ad71..23f293ad1e 100644 --- a/devito/ir/iet/nodes.py +++ b/devito/ir/iet/nodes.py @@ -16,7 +16,7 @@ Forward, WithLock, PrefetchUpdate, detect_io) from devito.symbolics import ListInitializer, CallFromPointer, ccode from devito.tools import (Signer, as_tuple, filter_ordered, filter_sorted, flatten, - ctypes_to_cstr, OrderedSet) + ctypes_to_cstr) from devito.types.basic import (AbstractFunction, AbstractSymbol, Basic, Indexed, Symbol) from devito.types.object import AbstractObject, LocalObject @@ -1438,20 +1438,7 @@ def DummyExpr(*args, init=False): # Nodes required for distributed-memory halo exchange -class HaloMixin: - - def __repr__(self): - fstrings = [] - for f in self.fmapper.keys(): - loc_indices = OrderedSet(*(self.fmapper[f].loc_indices.values())) - loc_indices_str = str(list(loc_indices)) if loc_indices else "" - fstrings.append("%s%s" % (f.name, loc_indices_str)) - - functions = ",".join(fstrings) - return "<%s(%s)>" % (self.__class__.__name__, functions) - - -class HaloSpot(HaloMixin, Node): +class HaloSpot(Node): """ A halo exchange operation (e.g., send, recv, wait, ...) required to @@ -1508,6 +1495,9 @@ def body(self): def functions(self): return tuple(self.fmapper) + def __repr__(self): + funcs = self.halo_scheme.__reprfuncs__() + return "<%s(%s)>" % (self.__class__.__name__, funcs) # Utility classes diff --git a/devito/mpi/halo_scheme.py b/devito/mpi/halo_scheme.py index f1bbb79003..d80127fae2 100644 --- a/devito/mpi/halo_scheme.py +++ b/devito/mpi/halo_scheme.py @@ -9,10 +9,9 @@ from devito import configuration from devito.data import CORE, OWNED, LEFT, CENTER, RIGHT from devito.ir.support import Forward, Scope -from devito.ir.iet.nodes import HaloMixin from devito.symbolics.manipulation import _uxreplace_registry from devito.tools import (Reconstructable, Tag, as_tuple, filter_ordered, flatten, - frozendict, is_integer, filter_sorted) + frozendict, is_integer, filter_sorted, OrderedSet) from devito.types import Grid __all__ = ['HaloScheme', 'HaloSchemeEntry', 'HaloSchemeException', 'HaloTouch'] @@ -36,8 +35,8 @@ class HaloSchemeEntry(Reconstructable): def __init__(self, loc_indices, loc_dirs, halos, dims): self.loc_indices = frozendict(loc_indices) self.loc_dirs = frozendict(loc_dirs) - self.halos = halos - self.dims = dims + self.halos = frozenset(halos) + self.dims = frozenset(dims) def __eq__(self, other): if not isinstance(other, HaloSchemeEntry): @@ -48,10 +47,10 @@ def __eq__(self, other): self.dims == other.dims) def __hash__(self): - return hash((frozenset(self.loc_indices.items()), - frozenset(self.loc_dirs.items()), - frozenset(self.halos), - frozenset(self.dims))) + return hash((tuple(self.loc_indices.items()), + tuple(self.loc_dirs.items()), + self.halos, + self.dims)) def __repr__(self): return (f"HaloSchemeEntry(loc_indices={self.loc_indices}, " @@ -63,7 +62,7 @@ def __repr__(self): OMapper = namedtuple('OMapper', 'core owned') -class HaloScheme(HaloMixin): +class HaloScheme(): """ A HaloScheme describes a set of halo exchanges through a mapper: @@ -121,6 +120,18 @@ def __init__(self, exprs, ispace): self._honored[i.root] = frozenset([(ltk, rtk)]) self._honored = frozendict(self._honored) + def __reprfuncs__(self): + fstrings = [] + for f in self.fmapper.keys(): + loc_indices = OrderedSet(*(self.fmapper[f].loc_indices.values())) + loc_indices_str = str(list(loc_indices)) if loc_indices else "" + fstrings.append("%s%s" % (f.name, loc_indices_str)) + + return ",".join(fstrings) + + def __repr__(self): + return "<%s(%s)>" % (self.__class__.__name__, self.__reprfuncs__()) + def __eq__(self, other): return (isinstance(other, HaloScheme) and self._mapper == other._mapper and diff --git a/devito/passes/iet/mpi.py b/devito/passes/iet/mpi.py index 204fcc2cae..030dfacfad 100644 --- a/devito/passes/iet/mpi.py +++ b/devito/passes/iet/mpi.py @@ -12,7 +12,7 @@ from devito.mpi.reduction_scheme import DistReduce from devito.mpi.routines import HaloExchangeBuilder, ReductionBuilder from devito.passes.iet.engine import iet_pass -from devito.tools import generator, frozendict +from devito.tools import generator __all__ = ['mpiize'] @@ -94,7 +94,6 @@ def _hoist_invariant(iet): continue for f, v in hs1.fmapper.items(): - if f not in hs0.functions: continue @@ -114,7 +113,7 @@ def _hoist_invariant(iet): else: raw_loc_indices[d] = v - hse = hse._rebuild(loc_indices=frozendict(raw_loc_indices)) + hse = hse._rebuild(loc_indices=raw_loc_indices) hs1.halo_scheme.fmapper[f] = hse hsmapper[hs1] = hsmapper.get(hs1, hs1.halo_scheme).drop(f) @@ -348,20 +347,27 @@ def _filter_iter_mapper(iet): Given an IET, return a mapper from Iterations to the HaloSpots. Additionally, filter out Iterations that are not of interest. """ - iter_mapper = MapNodes(Iteration, HaloSpot, 'immediate').visit(iet) - iter_mapper = {k: [hs for hs in v if not hs.halo_scheme.is_void] - for k, v in iter_mapper.items()} - iter_mapper = {k: v for k, v in iter_mapper.items() if k is not None} - iter_mapper = {k: v for k, v in iter_mapper.items() if len(v) > 1} + iter_mapper = {} + for k, v in MapNodes(Iteration, HaloSpot, 'immediate').visit(iet).items(): + filtered_hs = [hs for hs in v if not hs.halo_scheme.is_void] + if k is not None and len(filtered_hs) > 1: + iter_mapper[k] = filtered_hs return iter_mapper def _make_cond_mapper(iet): - cond_mapper = MapHaloSpots().visit(iet) - return {hs: {i for i in v if i.is_Conditional and - not isinstance(i.condition, GuardFactorEq)} - for hs, v in cond_mapper.items()} + + cond_mapper = {} + for hs, v in MapHaloSpots().visit(iet).items(): + conditionals = set() + for i in v: + if i.is_Conditional and not isinstance(i.condition, GuardFactorEq): + conditionals.add(i) + + cond_mapper[hs] = conditionals + + return cond_mapper def _check_control_flow(hs0, hs1, cond_mapper): diff --git a/examples/seismic/tutorials/09_viscoelastic.ipynb b/examples/seismic/tutorials/09_viscoelastic.ipynb index 5d7cabdd75..0d1d524613 100644 --- a/examples/seismic/tutorials/09_viscoelastic.ipynb +++ b/examples/seismic/tutorials/09_viscoelastic.ipynb @@ -457,7 +457,7 @@ "source": [ "# References\n", "\n", - "[1] Johan O. A. Roberston, *et.al.* (1994). \"Viscoelatic finite-difference modeling\" GEOPHYSICS, 59(9), 1444-1456.\n", + "[1] Johan O. A. Roberston, *et.al.* (1994). \"Viscoelastic finite-difference modeling\" GEOPHYSICS, 59(9), 1444-1456.\n", "\n", "\n", "[2] https://janth.home.xs4all.nl/Software/fdelmodcManual.pdf" diff --git a/examples/seismic/viscoelastic/operators.py b/examples/seismic/viscoelastic/operators.py index fdf1110004..9e0269665a 100644 --- a/examples/seismic/viscoelastic/operators.py +++ b/examples/seismic/viscoelastic/operators.py @@ -64,4 +64,4 @@ def ForwardOperator(model, geometry, space_order=4, save=False, **kwargs): # Substitute spacing terms to reduce flops return Operator([u_v, u_r, u_t] + src_rec_expr, subs=model.spacing_map, - name='ViscoElForward', **kwargs) + name='ViscoIsoElasticForward', **kwargs) diff --git a/tests/test_mpi.py b/tests/test_mpi.py index 7d28dfe7bc..05a5b4b82d 100644 --- a/tests/test_mpi.py +++ b/tests/test_mpi.py @@ -1008,183 +1008,6 @@ def test_avoid_haloupdate_if_distr_but_sequential(self, mode): calls = FindNodes(Call).visit(op) assert len(calls) == 0 - @pytest.fixture - def setup(self): - shape = (2,) - so = 2 - tn = 30 - - grid = Grid(shape=shape) - - # Velocity and pressure fields - v = TimeFunction(name='v', grid=grid, space_order=so) - tau = TimeFunction(name='tau', grid=grid, space_order=so) - - # First order elastic-like dependencies equations - pde_v = v.dt - (tau.dx) - pde_tau = tau.dt - ((v.forward).dx) - u_v = Eq(v.forward, solve(pde_v, v.forward)) - u_tau = Eq(tau.forward, solve(pde_tau, tau.forward)) - - # Receiver - rec = SparseTimeFunction(name="rec", grid=grid, npoint=1, nt=tn) - rec.coordinates.data[:, 0] = np.linspace(0., shape[0], num=1) - - return grid, v, tau, u_v, u_tau, rec - - @pytest.mark.parallel(mode=1) - def test_issue_2448_I(self, mode, setup): - _, v, tau, u_v, u_tau, rec = setup - - rec_term0 = rec.interpolate(expr=v) - - op0 = Operator([u_v, u_tau, rec_term0]) - - calls = [i for i in FindNodes(Call).visit(op0) if isinstance(i, HaloUpdateCall)] - - assert len(calls) == 3 - assert len(FindNodes(HaloUpdateCall).visit(op0.body.body[1].body[1].body[0])) == 1 - assert len(FindNodes(HaloUpdateCall).visit(op0.body.body[1].body[1].body[1])) == 2 - assert calls[0].arguments[0] is v - assert calls[1].arguments[0] is tau - assert calls[2].arguments[0] is v - - @pytest.mark.parallel(mode=1) - def test_issue_2448_II(self, mode, setup): - _, v, tau, u_v, u_tau, rec = setup - - rec_term1 = rec.interpolate(expr=v.forward) - - op1 = Operator([u_v, u_tau, rec_term1]) - - calls = [i for i in FindNodes(Call).visit(op1) if isinstance(i, HaloUpdateCall)] - - assert len(calls) == 3 - assert len(FindNodes(HaloUpdateCall).visit(op1.body.body[1].body[0])) == 0 - assert len(FindNodes(HaloUpdateCall).visit(op1.body.body[1].body[1])) == 3 - assert calls[0].arguments[0] is tau - assert calls[1].arguments[0] is v - assert calls[2].arguments[0] is v - - @pytest.mark.parallel(mode=1) - def test_issue_2448_III(self, mode, setup): - grid, v, tau, u_v, u_tau, rec = setup - - # Additional velocity and pressure fields - v2 = TimeFunction(name='v2', grid=grid, space_order=2) - tau2 = TimeFunction(name='tau2', grid=grid, space_order=2) - - # First order elastic-like dependencies equations - pde_v2 = v2.dt - (tau2.dx) - pde_tau2 = tau2.dt - ((v2.forward).dx) - u_v2 = Eq(v2.forward, solve(pde_v2, v2.forward)) - u_tau2 = Eq(tau2.forward, solve(pde_tau2, tau2.forward)) - - # Receiver - rec2 = SparseTimeFunction(name="rec2", grid=grid, npoint=1, nt=30) - rec2.coordinates.data[:, 0] = np.linspace(0., grid.shape[0], num=1) - - rec_term0 = rec.interpolate(expr=v) - rec_term2 = rec2.interpolate(expr=v2) - - op2 = Operator([u_v, u_v2, u_tau, u_tau2, rec_term0, rec_term2]) - - calls = [i for i in FindNodes(Call).visit(op2) if isinstance(i, HaloUpdateCall)] - - assert len(calls) == 5 - assert len(FindNodes(HaloUpdateCall).visit(op2.body.body[1].body[1].body[0])) == 2 - assert len(FindNodes(HaloUpdateCall).visit(op2.body.body[1].body[1].body[1])) == 3 - assert calls[0].arguments[0] is v - assert calls[1].arguments[0] is v2 - assert calls[2].arguments[0] is tau - assert calls[2].arguments[1] is tau2 - assert calls[3].arguments[0] is v - assert calls[4].arguments[0] is v2 - - @pytest.mark.parallel(mode=1) - def test_issue_2448_IV(self, mode, setup): - grid, v, tau, u_v, u_tau, rec = setup - - # Additional velocity and pressure fields - v2 = TimeFunction(name='v2', grid=grid, space_order=2) - tau2 = TimeFunction(name='tau2', grid=grid, space_order=2) - - # First order elastic-like dependencies equations - pde_v2 = v2.dt - (tau2.dx) - pde_tau2 = tau2.dt - ((v2.forward).dx) - u_v2 = Eq(v2.forward, solve(pde_v2, v2.forward)) - u_tau2 = Eq(tau2.forward, solve(pde_tau2, tau2.forward)) - - # Receiver - rec2 = SparseTimeFunction(name="rec2", grid=grid, npoint=1, nt=30) - rec2.coordinates.data[:, 0] = np.linspace(0., grid.shape[0], num=1) - - rec_term0 = rec.interpolate(expr=v) - rec_term3 = rec2.interpolate(expr=v2.forward) - - op3 = Operator([u_v, u_v2, u_tau, u_tau2, rec_term0, rec_term3]) - - calls = [i for i in FindNodes(Call).visit(op3) if isinstance(i, HaloUpdateCall)] - - assert len(calls) == 5 - assert len(FindNodes(HaloUpdateCall).visit(op3.body.body[1].body[1].body[0])) == 1 - assert len(FindNodes(HaloUpdateCall).visit(op3.body.body[1].body[1].body[1])) == 4 - assert calls[0].arguments[0] is v - assert calls[1].arguments[0] is tau - assert calls[1].arguments[1] is tau2 - assert calls[2].arguments[0] is v - assert calls[3].arguments[0] is v2 - assert calls[4].arguments[0] is v2 - - @pytest.mark.parallel(mode=1) - def test_issue_2448_backward(self, mode): - ''' - Similar to test_issue_2448, but with backward instead of forward - so that the hoisted halo has different starting point - ''' - shape = (2,) - so = 2 - - grid = Grid(shape=shape) - t = grid.stepping_dim - - tn = 7 - - # Velocity and pressure fields - v = TimeFunction(name='v', grid=grid, space_order=so) - v.data_with_halo[0, :] = 1. - v.data_with_halo[1, :] = 3. - - tau = TimeFunction(name='tau', grid=grid, space_order=so) - tau.data_with_halo[:] = 1. - - # First order elastic-like dependencies equations - pde_v = v.dt - (tau.dx) - pde_tau = tau.dt - ((v.backward).dx) - - u_v = Eq(v.backward, solve(pde_v, v)) - u_tau = Eq(tau.backward, solve(pde_tau, tau)) - - # Test two variants of receiver interpolation - nrec = 1 - rec = SparseTimeFunction(name="rec", grid=grid, npoint=nrec, nt=tn) - rec.coordinates.data[:, 0] = np.linspace(0., shape[0], num=nrec) - - # Test receiver interpolation 0, here we have a halo exchange hoisted - op0 = Operator([u_v] + [u_tau] + rec.interpolate(expr=v)) - - calls = [i for i in FindNodes(Call).visit(op0) - if isinstance(i, HaloUpdateCall)] - - # The correct we want - assert len(calls) == 3 - assert len(FindNodes(HaloUpdateCall).visit(op0.body.body[1].body[1].body[0])) == 1 - assert len(FindNodes(HaloUpdateCall).visit(op0.body.body[1].body[1].body[1])) == 2 - assert calls[0].arguments[0] is v - assert calls[0].arguments[3].args[0] is t.symbolic_max - assert calls[1].arguments[0] is tau - assert calls[2].arguments[0] is v - @pytest.mark.parallel(mode=1) def test_avoid_haloupdate_with_subdims(self, mode): grid = Grid(shape=(4,)) @@ -1400,7 +1223,7 @@ def test_avoid_fullmode_if_crossloop_dep(self, mode): assert np.all(f.data[:] == 2.) @pytest.mark.parallel(mode=2) - def test_avoid_halopudate_if_flowdep_along_other_dim(self, mode): + def test_avoid_haloupdate_if_flowdep_along_other_dim(self, mode): grid = Grid(shape=(10,)) x = grid.dimensions[0] t = grid.stepping_dim @@ -1513,8 +1336,10 @@ def test_merge_haloupdate_if_diff_locindices_v1(self, mode): * the second and third Eqs cannot be fused in the same loop - In the IET we end up with *one* HaloSpots, placed right before the - second Eq. The third Eq will seamlessy find its halo up-to-date. + In the IET we end up with *two* HaloSpots, one placed before the + time loop, and one placed before the second Eq. The third Eq, + reading from f[t0], will seamlessy find its halo up-to-date, + due to the f[t1] being updated in the previous time iteration. """ grid = Grid(shape=(10,)) x = grid.dimensions[0] @@ -1545,6 +1370,7 @@ def test_merge_haloupdate_if_diff_locindices_v1(self, mode): assert len(FindNodes(HaloUpdateCall).visit(op.body.body[1].body[2])) == 0 op.apply(time_M=1) + glb_pos_map = f.grid.distributor.glb_pos_map R = 1e-07 # Can't use np.all due to rounding error at the tails if LEFT in glb_pos_map[x]: @@ -2911,7 +2737,7 @@ def test_adjoint_F_no_omp(self, mode): self.run_adjoint_F(3) -class TestElastic: +class TestElasticLike: @pytest.mark.parallel(mode=[(1, 'diag')]) def test_elastic_structure(self, mode): @@ -2927,7 +2753,6 @@ def test_elastic_structure(self, mode): mu = Function(name='mu', grid=grid) ro = Function(name='b', grid=grid) - # The receiver rec = SparseTimeFunction(name="rec", grid=grid, npoint=1, nt=10) rec_term = rec.interpolate(expr=v[0] + v[1]) @@ -2945,7 +2770,6 @@ def test_elastic_structure(self, mode): calls = [i for i in FindNodes(Call).visit(op) if isinstance(i, HaloUpdateCall)] - # The correct we want assert len(calls) == 5 assert len(FindNodes(HaloUpdateCall).visit(op.body.body[1].body[1].body[0])) == 1 @@ -2960,12 +2784,188 @@ def test_elastic_structure(self, mode): assert calls[4].arguments[0] is v[0] assert calls[4].arguments[1] is v[1] + @pytest.fixture + def setup(self): + """ + This fixture sets up the grid, fields, elastic-like + equations and receivers for test_issue_2448_*. + """ + shape = (2,) + so = 2 + tn = 30 + + grid = Grid(shape=shape) + + # Velocity and pressure fields + v = TimeFunction(name='v', grid=grid, space_order=so) + tau = TimeFunction(name='tau', grid=grid, space_order=so) + + # First order elastic-like dependencies equations + pde_v = v.dt - (tau.dx) + pde_tau = tau.dt - ((v.forward).dx) + u_v = Eq(v.forward, solve(pde_v, v.forward)) + u_tau = Eq(tau.forward, solve(pde_tau, tau.forward)) + + rec = SparseTimeFunction(name="rec", grid=grid, npoint=1, nt=tn) + rec.coordinates.data[:, 0] = np.linspace(0., shape[0], num=1) + + return grid, v, tau, u_v, u_tau, rec + + @pytest.mark.parallel(mode=1) + def test_issue_2448_v0(self, mode, setup): + _, v, tau, u_v, u_tau, rec = setup + + rec_term0 = rec.interpolate(expr=v) + + op0 = Operator([u_v, u_tau, rec_term0]) + + calls = [i for i in FindNodes(Call).visit(op0) if isinstance(i, HaloUpdateCall)] + + assert len(calls) == 3 + assert len(FindNodes(HaloUpdateCall).visit(op0.body.body[1].body[1].body[0])) == 1 + assert len(FindNodes(HaloUpdateCall).visit(op0.body.body[1].body[1].body[1])) == 2 + assert calls[0].arguments[0] is v + assert calls[1].arguments[0] is tau + assert calls[2].arguments[0] is v + + @pytest.mark.parallel(mode=1) + def test_issue_2448_v1(self, mode, setup): + _, v, tau, u_v, u_tau, rec = setup + + rec_term1 = rec.interpolate(expr=v.forward) + + op1 = Operator([u_v, u_tau, rec_term1]) + + calls = [i for i in FindNodes(Call).visit(op1) if isinstance(i, HaloUpdateCall)] + + assert len(calls) == 3 + assert len(FindNodes(HaloUpdateCall).visit(op1.body.body[1].body[0])) == 0 + assert len(FindNodes(HaloUpdateCall).visit(op1.body.body[1].body[1])) == 3 + assert calls[0].arguments[0] is tau + assert calls[1].arguments[0] is v + assert calls[2].arguments[0] is v + + @pytest.mark.parallel(mode=1) + def test_issue_2448_v2(self, mode, setup): + grid, v, tau, u_v, u_tau, rec = setup + + # Additional velocity and pressure fields + v2 = TimeFunction(name='v2', grid=grid, space_order=2) + tau2 = TimeFunction(name='tau2', grid=grid, space_order=2) + + # First order elastic-like dependencies equations + pde_v2 = v2.dt - (tau2.dx) + pde_tau2 = tau2.dt - ((v2.forward).dx) + u_v2 = Eq(v2.forward, solve(pde_v2, v2.forward)) + u_tau2 = Eq(tau2.forward, solve(pde_tau2, tau2.forward)) + + rec2 = SparseTimeFunction(name="rec2", grid=grid, npoint=1, nt=30) + rec2.coordinates.data[:, 0] = np.linspace(0., grid.shape[0], num=1) + + rec_term0 = rec.interpolate(expr=v) + rec_term2 = rec2.interpolate(expr=v2) + + op2 = Operator([u_v, u_v2, u_tau, u_tau2, rec_term0, rec_term2]) + + calls = [i for i in FindNodes(Call).visit(op2) if isinstance(i, HaloUpdateCall)] + + assert len(calls) == 5 + assert len(FindNodes(HaloUpdateCall).visit(op2.body.body[1].body[1].body[0])) == 2 + assert len(FindNodes(HaloUpdateCall).visit(op2.body.body[1].body[1].body[1])) == 3 + assert calls[0].arguments[0] is v + assert calls[1].arguments[0] is v2 + assert calls[2].arguments[0] is tau + assert calls[2].arguments[1] is tau2 + assert calls[3].arguments[0] is v + assert calls[4].arguments[0] is v2 + + @pytest.mark.parallel(mode=1) + def test_issue_2448_v3(self, mode, setup): + grid, v, tau, u_v, u_tau, rec = setup + + # Additional velocity and pressure fields + v2 = TimeFunction(name='v2', grid=grid, space_order=2) + tau2 = TimeFunction(name='tau2', grid=grid, space_order=2) + + # First order elastic-like dependencies equations + pde_v2 = v2.dt - (tau2.dx) + pde_tau2 = tau2.dt - ((v2.forward).dx) + u_v2 = Eq(v2.forward, solve(pde_v2, v2.forward)) + u_tau2 = Eq(tau2.forward, solve(pde_tau2, tau2.forward)) + + rec2 = SparseTimeFunction(name="rec2", grid=grid, npoint=1, nt=30) + rec2.coordinates.data[:, 0] = np.linspace(0., grid.shape[0], num=1) + + rec_term0 = rec.interpolate(expr=v) + rec_term3 = rec2.interpolate(expr=v2.forward) + + op3 = Operator([u_v, u_v2, u_tau, u_tau2, rec_term0, rec_term3]) + + calls = [i for i in FindNodes(Call).visit(op3) if isinstance(i, HaloUpdateCall)] + + assert len(calls) == 5 + assert len(FindNodes(HaloUpdateCall).visit(op3.body.body[1].body[1].body[0])) == 1 + assert len(FindNodes(HaloUpdateCall).visit(op3.body.body[1].body[1].body[1])) == 4 + assert calls[0].arguments[0] is v + assert calls[1].arguments[0] is tau + assert calls[1].arguments[1] is tau2 + assert calls[2].arguments[0] is v + assert calls[3].arguments[0] is v2 + assert calls[4].arguments[0] is v2 + + @pytest.mark.parallel(mode=1) + def test_issue_2448_backward(self, mode): + ''' + Similar to test_issue_2448, but with backward instead of forward + so that the hoisted halo has different starting point + ''' + shape = (2,) + so = 2 + + grid = Grid(shape=shape) + t = grid.stepping_dim + + tn = 7 + + # Velocity and pressure fields + v = TimeFunction(name='v', grid=grid, space_order=so) + v.data_with_halo[0, :] = 1. + v.data_with_halo[1, :] = 3. + + tau = TimeFunction(name='tau', grid=grid, space_order=so) + tau.data_with_halo[:] = 1. + + # First order elastic-like dependencies equations + pde_v = v.dt - (tau.dx) + pde_tau = tau.dt - ((v.backward).dx) + + u_v = Eq(v.backward, solve(pde_v, v)) + u_tau = Eq(tau.backward, solve(pde_tau, tau)) + + # Test two variants of receiver interpolation + nrec = 1 + rec = SparseTimeFunction(name="rec", grid=grid, npoint=nrec, nt=tn) + rec.coordinates.data[:, 0] = np.linspace(0., shape[0], num=nrec) + + # Test receiver interpolation 0, here we have a halo exchange hoisted + op0 = Operator([u_v] + [u_tau] + rec.interpolate(expr=v)) + + calls = [i for i in FindNodes(Call).visit(op0) + if isinstance(i, HaloUpdateCall)] + + assert len(calls) == 3 + assert len(FindNodes(HaloUpdateCall).visit(op0.body.body[1].body[1].body[0])) == 1 + assert len(FindNodes(HaloUpdateCall).visit(op0.body.body[1].body[1].body[1])) == 2 + assert calls[0].arguments[0] is v + assert calls[0].arguments[3].args[0] is t.symbolic_max + assert calls[1].arguments[0] is tau + assert calls[2].arguments[0] is v + class TestTTIOp: @pytest.mark.parallel(mode=1) def test_halo_structure(self, mode): - solver = TestTTI().tti_operator(opt='advanced', space_order=8) op = solver.op_fwd(save=False)