Skip to content

Commit 0512906

Browse files
committed
compiler: Fix detection of global distributed reductions
1 parent 1f16737 commit 0512906

File tree

4 files changed

+36
-13
lines changed

4 files changed

+36
-13
lines changed

devito/ir/clusters/algorithms.py

+16-5
Original file line number | Diff line number | Diff line change
@@ -447,14 +447,14 @@ def reduction_comms(clusters):
447447
processed = []
448448
fifo = []
449449
for c in clusters:
450-
# Schedule the global reductions encountered before `c`, if the
451-
# IterationSpace of `c` is such that the reduction can be carried out
450+
# Schedule the global distributed reductions encountered before `c`,
451+
# if `c`'s IterationSpace is such that the reduction can be carried out
452452
found, fifo = split(fifo, lambda dr: dr.ispace.is_subset(c.ispace))
453453
if found:
454454
exprs = [Eq(dr.var, dr) for dr in found]
455455
processed.append(c.rebuild(exprs=exprs))
456456

457-
# Detect the global reductions in `c`
457+
# Detect the global distributed reductions in `c`
458458
for e in c.exprs:
459459
op = e.operation
460460
if op is None or c.is_sparse:
@@ -465,12 +465,23 @@ def reduction_comms(clusters):
465465
if grid is None:
466466
continue
467467

468-
# The IterationSpace within which the global reduction is carried out
468+
# Is Inc/Max/Min/... actually used for a reduction?
469469
ispace = c.ispace.project(lambda d: d in var.free_symbols)
470470
if ispace.itdims == c.ispace.itdims:
471-
# Inc/Max/Min/... being used for a non-reduction operation
472471
continue
473472

473+
# The reduced Dimensions
474+
rdims = set(c.ispace.itdims) - set(ispace.itdims)
475+
476+
# The reduced Dimensions inducing a global distributed reduction
477+
grdims = {d for d in rdims if d._defines & c.dist_dimensions}
478+
if not grdims:
479+
continue
480+
481+
# The IterationSpace within which the global distributed reduction
482+
# must be carried out
483+
ispace = c.ispace.prefix(lambda d: d in var.free_symbols)
484+
474485
fifo.append(DistReduce(var, op=op, grid=grid, ispace=ispace))
475486

476487
processed.append(c)

devito/ir/clusters/cluster.py

+13
Original file line number | Diff line number | Diff line change
@@ -167,6 +167,19 @@ def used_dimensions(self):
167167
idims = set.union(*[set(e.implicit_dims) for e in self.exprs])
168168
return {i for i in self.free_symbols if i.is_Dimension} | idims
169169

170+
@cached_property
171+
def dist_dimensions(self):
172+
"""
173+
The Cluster's distributed Dimensions.
174+
"""
175+
ret = set()
176+
for f in self.functions:
177+
try:
178+
ret.update(f._dist_dimensions)
179+
except AttributeError:
180+
pass
181+
return frozenset(ret)
182+
170183
@cached_property
171184
def scope(self):
172185
return Scope(self.exprs)

devito/ir/support/space.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -954,7 +954,7 @@ def prefix(self, key):
954954
try:
955955
i = self.project(key)[-1]
956956
except IndexError:
957-
return None
957+
return null_ispace
958958

959959
return self[:self.index(i.dim) + 1]
960960

tests/test_mpi.py

+6-7
Original file line number | Diff line number | Diff line change
@@ -4,17 +4,17 @@
44

55
from conftest import _R, assert_blocking, assert_structure
66
from devito import (Grid, Constant, Function, TimeFunction, SparseFunction,
7-
SparseTimeFunction, Dimension, ConditionalDimension, SubDimension,
8-
SubDomain, Eq, Ne, Inc, NODE, Operator, norm, inner, configuration,
9-
switchconfig, generic_derivative, PrecomputedSparseFunction,
10-
DefaultDimension)
7+
SparseTimeFunction, Dimension, ConditionalDimension,
8+
SubDimension, SubDomain, Eq, Ne, Inc, NODE, Operator, norm,
9+
inner, configuration, switchconfig, generic_derivative,
10+
PrecomputedSparseFunction, DefaultDimension)
1111
from devito.arch.compiler import OneapiCompiler
1212
from devito.data import LEFT, RIGHT
1313
from devito.ir.iet import (Call, Conditional, Iteration, FindNodes, FindSymbols,
1414
retrieve_iteration_tree)
1515
from devito.mpi import MPI
1616
from devito.mpi.routines import (HaloUpdateCall, HaloUpdateList, MPICall,
17-
ComputeCall, AllreduceCall)
17+
ComputeCall)
1818
from devito.mpi.distributed import CustomTopology
1919
from devito.tools import Bunch
2020

@@ -929,8 +929,7 @@ def test_avoid_haloupdate_as_nostencil_advanced(self, mode):
929929

930930
# No stencil in the expressions, so no halo update required!
931931
calls = FindNodes(Call).visit(op)
932-
assert len(calls) == 2
933-
assert all(isinstance(i, AllreduceCall) for i in calls)
932+
assert len(calls) == 0
934933

935934
@pytest.mark.parallel(mode=1)
936935
def test_avoid_redundant_haloupdate(self, mode):

0 commit comments

Comments
 (0)