diff --git a/devito/ir/clusters/algorithms.py b/devito/ir/clusters/algorithms.py index b2b552530df..e3f815a2f43 100644 --- a/devito/ir/clusters/algorithms.py +++ b/devito/ir/clusters/algorithms.py @@ -447,14 +447,14 @@ def reduction_comms(clusters): processed = [] fifo = [] for c in clusters: - # Schedule the global reductions encountered before `c`, if the - # IterationSpace of `c` is such that the reduction can be carried out + # Schedule the global distributed reductions encountered before `c`, + # if `c`'s IterationSpace is such that the reduction can be carried out found, fifo = split(fifo, lambda dr: dr.ispace.is_subset(c.ispace)) if found: exprs = [Eq(dr.var, dr) for dr in found] processed.append(c.rebuild(exprs=exprs)) - # Detect the global reductions in `c` + # Detect the global distributed reductions in `c` for e in c.exprs: op = e.operation if op is None or c.is_sparse: @@ -465,12 +465,23 @@ def reduction_comms(clusters): if grid is None: continue - # The IterationSpace within which the global reduction is carried out + # Is Inc/Max/Min/... actually used for a reduction? ispace = c.ispace.project(lambda d: d in var.free_symbols) if ispace.itdims == c.ispace.itdims: - # Inc/Max/Min/... being used for a non-reduction operation continue + # The reduced Dimensions + rdims = set(c.ispace.itdims) - set(ispace.itdims) + + # The reduced Dimensions inducing a global distributed reduction + grdims = {d for d in rdims if d._defines & c.dist_dimensions} + if not grdims: + continue + + # The IterationSpace within which the global distributed reduction + # must be carried out + ispace = c.ispace.prefix(lambda d: d in var.free_symbols) + fifo.append(DistReduce(var, op=op, grid=grid, ispace=ispace)) processed.append(c) diff --git a/devito/ir/clusters/cluster.py b/devito/ir/clusters/cluster.py index 05ba8c82fa9..dcb4ad22e54 100644 --- a/devito/ir/clusters/cluster.py +++ b/devito/ir/clusters/cluster.py @@ -167,6 +167,19 @@ def used_dimensions(self): idims = set.union(*[set(e.implicit_dims) for e in self.exprs]) return {i for i in self.free_symbols if i.is_Dimension} | idims + @cached_property + def dist_dimensions(self): + """ + The Cluster's distributed Dimensions. + """ + ret = set() + for f in self.functions: + try: + ret.update(f._dist_dimensions) + except AttributeError: + pass + return frozenset(ret) + @cached_property def scope(self): return Scope(self.exprs) diff --git a/devito/ir/support/space.py b/devito/ir/support/space.py index 2c773e08a5c..b322a1402f1 100644 --- a/devito/ir/support/space.py +++ b/devito/ir/support/space.py @@ -954,7 +954,7 @@ def prefix(self, key): try: i = self.project(key)[-1] except IndexError: - return None + return null_ispace return self[:self.index(i.dim) + 1] diff --git a/tests/test_mpi.py b/tests/test_mpi.py index c31e91d9861..fa70b9a0881 100644 --- a/tests/test_mpi.py +++ b/tests/test_mpi.py @@ -4,17 +4,17 @@ from conftest import _R, assert_blocking, assert_structure from devito import (Grid, Constant, Function, TimeFunction, SparseFunction, - SparseTimeFunction, Dimension, ConditionalDimension, SubDimension, - SubDomain, Eq, Ne, Inc, NODE, Operator, norm, inner, configuration, - switchconfig, generic_derivative, PrecomputedSparseFunction, - DefaultDimension) + SparseTimeFunction, Dimension, ConditionalDimension, + SubDimension, SubDomain, Eq, Ne, Inc, NODE, Operator, norm, + inner, configuration, switchconfig, generic_derivative, + PrecomputedSparseFunction, DefaultDimension) from devito.arch.compiler import OneapiCompiler from devito.data import LEFT, RIGHT from devito.ir.iet import (Call, Conditional, Iteration, FindNodes, FindSymbols, retrieve_iteration_tree) from devito.mpi import MPI from devito.mpi.routines import (HaloUpdateCall, HaloUpdateList, MPICall, - ComputeCall, AllreduceCall) + ComputeCall) from devito.mpi.distributed import CustomTopology from devito.tools import Bunch @@ -929,8 +929,7 @@ def test_avoid_haloupdate_as_nostencil_advanced(self, mode): # No stencil in the expressions, so no halo update required! calls = FindNodes(Call).visit(op) - assert len(calls) == 2 - assert all(isinstance(i, AllreduceCall) for i in calls) + assert len(calls) == 0 @pytest.mark.parallel(mode=1) def test_avoid_redundant_haloupdate(self, mode):