diff --git a/devito/ir/iet/nodes.py b/devito/ir/iet/nodes.py index 72ff6e3c04..96b1b89c57 100644 --- a/devito/ir/iet/nodes.py +++ b/devito/ir/iet/nodes.py @@ -1466,7 +1466,7 @@ def __repr__(self): functions = ",".join(fstrings) - return f"<{self.__class__.__name__}({functions})>" + return "<%s(%s)>" % (self.__class__.__name__, functions) @property def halo_scheme(self): diff --git a/devito/mpi/halo_scheme.py b/devito/mpi/halo_scheme.py index 64a47623ff..266bcbb15e 100644 --- a/devito/mpi/halo_scheme.py +++ b/devito/mpi/halo_scheme.py @@ -131,7 +131,7 @@ def __repr__(self): functions = ",".join(fstrings) - return f"<{self.__class__.__name__}({functions})>" + return "<%s(%s)>" % (self.__class__.__name__, functions) def __eq__(self, other): return (isinstance(other, HaloScheme) and @@ -677,8 +677,8 @@ def _uxreplace_dispatch_haloscheme(hs0, rule): # Nope, let's try with the next Indexed, if any continue - hse = hse0.rebuild(loc_indices=frozendict(loc_indices), - loc_dirs=frozendict(loc_dirs)) + hse = hse0._rebuild(loc_indices=frozendict(loc_indices), + loc_dirs=frozendict(loc_dirs)) else: continue diff --git a/devito/passes/iet/mpi.py b/devito/passes/iet/mpi.py index 72e351bc84..5595695eb1 100644 --- a/devito/passes/iet/mpi.py +++ b/devito/passes/iet/mpi.py @@ -1,6 +1,7 @@ from collections import defaultdict from sympy import S +from itertools import combinations from devito.ir.iet import (Call, Expression, HaloSpot, Iteration, FindNodes, MapNodes, MapHaloSpots, Transformer, @@ -19,11 +20,11 @@ @iet_pass def optimize_halospots(iet, **kwargs): """ - Optimize the HaloSpots in ``iet``. HaloSpots may be dropped, merged and moved - around in order to improve the halo exchange performance. + Optimize the HaloSpots in ``iet``. HaloSpots may be dropped, hoisted, + merged and moved around in order to improve the halo exchange performance. """ - iet = _drop_halospots(iet) - iet = _hoist_halospots(iet) + iet = _drop_reduction_halospots(iet) + iet = _hoist_invariant(iet) iet = _merge_halospots(iet) iet = _drop_if_unwritten(iet, **kwargs) iet = _mark_overlappable(iet) @@ -31,7 +32,7 @@ def optimize_halospots(iet, **kwargs): return iet, {} -def _drop_halospots(iet): +def _drop_reduction_halospots(iet): """ Remove HaloSpots that: @@ -48,17 +49,19 @@ def _drop_halospots(iet): mapper[hs].add(f) # Transform the IET introducing the "reduced" HaloSpots - subs = {hs: hs._rebuild(halo_scheme=hs.halo_scheme.drop(mapper[hs])) - for hs in FindNodes(HaloSpot).visit(iet)} - iet = Transformer(subs, nested=True).visit(iet) + mapper = {hs: hs._rebuild(halo_scheme=hs.halo_scheme.drop(mapper[hs])) + for hs in FindNodes(HaloSpot).visit(iet)} + iet = Transformer(mapper, nested=True).visit(iet) return iet -def _hoist_halospots(iet): +def _hoist_invariant(iet): """ Hoist HaloSpots from inner to outer Iterations where all data dependencies - would be honored. + would be honored. This pass is particularly useful to avoid redundant + halo exchanges when the same data is redundantly exchanged within the + same Iteration tree level. Example: haloupd v[t0] @@ -80,108 +83,123 @@ def _hoist_halospots(iet): hsmapper = {} imapper = defaultdict(list) - # Look for parent Iterations of children HaloSpots - for iters, halo_spots in MapNodes(Iteration, HaloSpot, 'groupby').visit(iet).items(): - for i, hs0 in enumerate(halo_spots): + iter_mapper = MapNodes(Iteration, HaloSpot, 'immediate').visit(iet) + + # Drop void `halo_scheme`s from the analysis + iter_mapper = {k: [hs for hs in v if not hs.halo_scheme.is_void] + for k, v in iter_mapper.items()} + + # Drop pairs that have keys that are None + iter_mapper = {k: v for k, v in iter_mapper.items() if k is not None} + + # Drop iter_mapper pairs where len(halo_spots) <= 1 + iter_mapper = {k: v for k, v in iter_mapper.items() if len(v) > 1} + + for it, halo_spots in iter_mapper.items(): - # Nothing to do if the HaloSpot is void - if hs0.halo_scheme.is_void: + for hs0, hs1 in combinations(halo_spots, r=2): + + if ensure_control_flow(hs0, hs1, cond_mapper): continue - for hs1 in halo_spots[i+1:]: - # If there are Conditionals involved, both `hs0` and `hs1` must be - # within the same Conditional, otherwise we would break the control - if cond_mapper.get(hs0) != cond_mapper.get(hs1): - continue + # If there are overlapping time accesses, skip + hs0_mdims = hs0.halo_scheme.loc_values + hs1_mdims = hs1.halo_scheme.loc_values + if hs0_mdims.intersection(hs1_mdims): + continue - # If there are overlapping time accesses, skip - if hs0.halo_scheme.loc_values.intersection(hs1.halo_scheme.loc_values): + # Loop over the functions in the HaloSpots + for f, v in hs1.fmapper.items(): + + # If the function is not in both HaloSpots, skip + if f not in hs0.functions: continue - # Loop over the functions in the HaloSpots - for f, v in hs1.fmapper.items(): - # If no time accesses, skip - if not hs1.halo_scheme.fmapper[f].loc_indices: - continue + for dep in scopes[it].d_flow.project(f): + if not any(r(dep, hs1, v.loc_indices) for r in motion_rules()): + break + else: + # hs1 is lifted out of `it`, and we need to get + # the new indexing for the HaloSpot + hse = hs1.halo_scheme.fmapper[f] + raw_loc_indices = {} + + for d in hse.loc_indices: + md = hse.loc_indices[d] + if md in it.uindices: + new_min = md.symbolic_min.subs(it.dim, + it.dim.symbolic_min) + raw_loc_indices[d] = new_min + else: + raw_loc_indices[d] = md + + hse = hse._rebuild(loc_indices=frozendict(raw_loc_indices)) + hs1.halo_scheme.fmapper[f] = hse - # If the function is not in both HaloSpots, skip - if f not in hs0.functions: - continue + hsmapper[hs1] = hsmapper.get(hs1, hs1.halo_scheme).drop(f) - for it in iters: - for dep in scopes[it].d_flow.project(f): - if not any(r(dep, hs1, v.loc_indices) for r in merge_rules()): - break - else: - hse = hs1.halo_scheme.fmapper[f] - raw_loc_indices = {} - # Entering here means we can lift, and we need to update - # the loc_indices with known values - # TODO: Can I get this in a more elegant way? - for d in hse.loc_indices: - md = hse.loc_indices[d] - if md.is_Symbol: - root = md.root - hse_min = md.symbolic_min - new_min = hse_min.subs(root, root.symbolic_min) - raw_loc_indices[d] = new_min - else: - # md is in form of an expression - assert d.symbolic_min in md.free_symbols - raw_loc_indices[d] = md - - hse = hse._rebuild(loc_indices=frozendict(raw_loc_indices)) - hs1.halo_scheme.fmapper[f] = hse - - hsmapper[hs1] = hsmapper.get(hs1, hs1.halo_scheme).drop(f) - imapper[it].append(hs1.halo_scheme.project(f)) + imapper[it].append(hs1.halo_scheme.project(f)) mapper = {i: HaloSpot(i._rebuild(), HaloScheme.union(hss)) for i, hss in imapper.items()} + mapper.update({i: i.body if hs.is_void else i._rebuild(halo_scheme=hs) for i, hs in hsmapper.items()}) iet = Transformer(mapper, nested=True).visit(iet) - return iet def _merge_halospots(iet): """ Merge HaloSpots on the same Iteration tree level where all data dependencies - would be honored. + would be honored. Helps to avoid redundant halo exchanges when the same data is + redundantly exchanged within the same Iteration tree level as well as to initiate + multiple halo exchanges at once. + + Example: + + for time for time + haloupd v[t0] haloupd v[t0], h + W v[t1]- R v[t0] W v[t1]- R v[t0] + haloupd v[t0], h + W g[t1]- R v[t0], h W g[t1]- R v[t0], h + """ # Analysis cond_mapper = _make_cond_mapper(iet) mapper = {} - for iter, halo_spots in MapNodes(Iteration, HaloSpot, 'immediate').visit(iet).items(): - if iter is None or len(halo_spots) <= 1: - continue - scope = Scope([e.expr for e in FindNodes(Expression).visit(iter)]) + iter_mapper = MapNodes(Iteration, HaloSpot, 'immediate').visit(iet) + + # Drop pairs that have keys that are None + iter_mapper = {k: v for k, v in iter_mapper.items() if k is not None} + + # Drop iter_mapper pairs where len(halo_spots) <= 1 + iter_mapper = {k: v for k, v in iter_mapper.items() if len(v) > 1} + + for it, halo_spots in iter_mapper.items(): + + scope = Scope([e.expr for e in FindNodes(Expression).visit(it)]) hs0 = halo_spots[0] - mapper[hs0] = hs0.halo_scheme for hs1 in halo_spots[1:]: - mapper[hs1] = hs1.halo_scheme - # If there are Conditionals involved, both `hs0` and `hs1` must be - # within the same Conditional, otherwise we would break the control - # flow semantics - if cond_mapper.get(hs0) != cond_mapper.get(hs1): + if ensure_control_flow(hs0, hs1, cond_mapper): continue for f, v in hs1.fmapper.items(): for dep in scope.d_flow.project(f): - if not any(r(dep, hs1, v.loc_indices) for r in merge_rules()): + if not any(r(dep, hs1, v.loc_indices) for r in motion_rules()): break else: + # hs1 is merged with hs0 hs = hs1.halo_scheme.project(f) - mapper[hs0] = HaloScheme.union([mapper[hs0], hs]) - mapper[hs1] = mapper[hs1].drop(f) + mapper[hs0] = HaloScheme.union([mapper.get(hs0, hs0.halo_scheme), hs]) + mapper[hs1] = mapper.get(hs1, hs1.halo_scheme).drop(f) # Post-process analysis mapper = {i: i.body if hs.is_void else i._rebuild(halo_scheme=hs) @@ -353,7 +371,7 @@ def mpiize(graph, **kwargs): make_reductions(graph, mpimode=mpimode, **kwargs) -# Utility functions to avoid code duplication +# *** Utilities def _make_cond_mapper(iet): cond_mapper = MapHaloSpots().visit(iet) @@ -362,9 +380,20 @@ def _make_cond_mapper(iet): for hs, v in cond_mapper.items()} -def merge_rules(): - # Merge rules -- if the retval is True, then it means the input `dep` is not - # a stopper to halo merging +def ensure_control_flow(hs0, hs1, cond_mapper): + """ + # If there are Conditionals involved, both `hs0` and `hs1` must be + # within the same Conditional, otherwise we would break the control + """ + cond0 = cond_mapper.get(hs0) + cond1 = cond_mapper.get(hs1) + + return cond0 != cond1 + + +def motion_rules(): + # Code motion rules -- if the retval is True, then it means the input `dep` is not + # a stopper to moving the HaloSpot `hs` around def rule0(dep, hs, loc_indices): # E.g., `dep=W -> R` => True @@ -372,16 +401,10 @@ def rule0(dep, hs, loc_indices): for d in dep.cause) def rule1(dep, hs, loc_indices): - # TODO This is apparently never hit, but feeling uncomfortable to remove it - return (dep.is_regular and - dep.read is not None and - all(not any(dep.read.touched_halo(d.root)) for d in dep.cause)) - - def rule2(dep, hs, loc_indices): # E.g., `dep=W -> R` and `loc_indices={t: t0}` => True return any(dep.distance_mapper[d] == 0 and dep.source[d] is not v for d, v in loc_indices.items()) - rules = [rule0, rule1, rule2] + rules = [rule0, rule1] return rules diff --git a/tests/test_mpi.py b/tests/test_mpi.py index c02d122731..f4c17b5f0b 100644 --- a/tests/test_mpi.py +++ b/tests/test_mpi.py @@ -5,11 +5,10 @@ from conftest import _R, assert_blocking, assert_structure from devito import (Grid, Constant, Function, TimeFunction, SparseFunction, SparseTimeFunction, VectorTimeFunction, TensorTimeFunction, - Dimension, ConditionalDimension, div, + Dimension, ConditionalDimension, div, solve, diag, grad, SubDimension, SubDomain, Eq, Ne, Inc, NODE, Operator, norm, inner, configuration, switchconfig, generic_derivative, - PrecomputedSparseFunction, DefaultDimension, Buffer, - solve, diag, grad) + PrecomputedSparseFunction, DefaultDimension, Buffer) from devito.arch.compiler import OneapiCompiler from devito.data import LEFT, RIGHT from devito.ir.iet import (Call, Conditional, Iteration, FindNodes, FindSymbols, @@ -19,7 +18,6 @@ ComputeCall) from devito.mpi.distributed import CustomTopology from devito.tools import Bunch -from devito.types.dimension import SpaceDimension from examples.seismic.acoustic import acoustic_setup from examples.seismic import demo_model @@ -1013,13 +1011,10 @@ def test_avoid_haloupdate_if_distr_but_sequential(self, mode): @pytest.mark.parallel(mode=1) def test_issue_2448(self, mode): - extent = (10.,) shape = (2,) so = 2 - x = SpaceDimension(name='x', spacing=Constant(name='h_x', - value=extent[0]/(shape[0]-1))) - grid = Grid(extent=extent, shape=shape, dimensions=(x,)) + grid = Grid(shape=shape) # Time related tn = 30 @@ -1038,7 +1033,7 @@ def test_issue_2448(self, mode): # Test two variants of receiver interpolation nrec = 1 rec = SparseTimeFunction(name="rec", grid=grid, npoint=nrec, nt=tn) - rec.coordinates.data[:, 0] = np.linspace(0., extent[0], num=nrec) + rec.coordinates.data[:, 0] = np.linspace(0., shape[0], num=nrec) # The receiver 0 rec_term0 = rec.interpolate(expr=v) @@ -1077,6 +1072,70 @@ def test_issue_2448(self, mode): assert calls[1].arguments[0] is v assert calls[2].arguments[0] is v + # Further complicate/stree-test adding an artifical example + # with two hoisting opportunities + + # Velocity and pressure fields + v2 = TimeFunction(name='v2', grid=grid, space_order=so) + tau2 = TimeFunction(name='tau2', grid=grid, space_order=so) + + # First order elastic-like dependencies equations + pde_v2 = v2.dt - (tau2.dx) + pde_tau2 = (tau2.dt - ((v2.forward).dx)) + u_v2 = Eq(v2.forward, solve(pde_v2, v2.forward)) + + u_tau2 = Eq(tau2.forward, solve(pde_tau2, tau2.forward)) + + # Test two variants of receiver interpolation + nrec = 1 + rec2 = SparseTimeFunction(name="rec2", grid=grid, npoint=nrec, nt=tn) + rec2.coordinates.data[:, 0] = np.linspace(0., shape[0], num=nrec) + + # The receiver 0 + rec_term2 = rec2.interpolate(expr=v2) + + # The receiver 1 + rec_term3 = rec2.interpolate(expr=v2.forward) + + # Test receiver interpolation 0, here we have a halo exchange hoisted + op2 = Operator([u_v] + [u_v2] + [u_tau] + [u_tau2] + rec_term0 + rec_term2) + + calls = [i for i in FindNodes(Call).visit(op2) + if isinstance(i, HaloUpdateCall)] + + # The correct we want + assert len(calls) == 5 + + assert len(FindNodes(HaloUpdateCall).visit(op2.body.body[1].body[1].body[0])) == 2 + assert len(FindNodes(HaloUpdateCall).visit(op2.body.body[1].body[1].body[1])) == 3 + + assert calls[0].arguments[0] is v + assert calls[1].arguments[0] is v2 + assert calls[2].arguments[0] is tau + assert calls[2].arguments[1] is tau2 + + assert calls[3].arguments[0] is v + assert calls[4].arguments[0] is v2 + + # Test receiver interpolation 0, here we have a halo exchange hoisted + op3 = Operator([u_v] + [u_v2] + [u_tau] + [u_tau2] + rec_term0 + rec_term3) + + calls = [i for i in FindNodes(Call).visit(op3) + if isinstance(i, HaloUpdateCall)] + + # The correct we want + assert len(calls) == 5 + + assert len(FindNodes(HaloUpdateCall).visit(op2.body.body[1].body[1].body[0])) == 1 + assert len(FindNodes(HaloUpdateCall).visit(op2.body.body[1].body[1].body[1])) == 4 + + assert calls[0].arguments[0] is v + assert calls[1].arguments[0] is tau + assert calls[1].arguments[1] is tau2 + assert calls[2].arguments[0] is v + assert calls[3].arguments[0] is v2 + assert calls[4].arguments[0] is v2 + @pytest.mark.parallel(mode=1) def test_avoid_haloupdate_with_subdims(self, mode): grid = Grid(shape=(4,))