Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

mpi: Fix haloupdate with inner dim [v2] #2272

Merged
merged 5 commits into from
Nov 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 29 additions & 30 deletions devito/ir/clusters/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,54 +374,53 @@ class Communications(Queue):

B = Symbol(name='⊥')

@timed_pass(name='schedule')
@timed_pass(name='communications')
def process(self, clusters):
return self._process_fatd(clusters, 1, seen=set())

def callback(self, clusters, prefix, seen=None):
if seen.issuperset(clusters):
if not prefix:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be both, i.e. `not prefix or seen.issuperset(clusters)`?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, the `seen` check is performed a few lines below, on a per-cluster basis.

return clusters

d = prefix[-1].dim

# Construct the mock exprs representing the halo accesses
exprs = []
# Construct a representation of the halo accesses
processed = []
for c in clusters:
if c.properties.is_sequential(d):
if c.properties.is_sequential(d) or \
c in seen:
continue

halo_scheme = HaloScheme(c.exprs, c.ispace)
hs = HaloScheme(c.exprs, c.ispace)
if hs.is_void or \
not d._defines & hs.distributed_aindices:
continue

if not halo_scheme.is_void and \
c.properties.is_parallel_relaxed(d):
points = set()
for f in halo_scheme.fmapper:
for a in c.scope.getreads(f):
points.add(a.access)
points = set()
for f in hs.fmapper:
for a in c.scope.getreads(f):
points.add(a.access)

# We also add all written symbols to ultimately create mock WARs
# with `c`, which will prevent the newly created HaloTouch to ever
# be rescheduled after `c` upon topological sorting
points.update(a.access for a in c.scope.accesses if a.is_write)
# We also add all written symbols to ultimately create mock WARs
# with `c`, which will prevent the newly created HaloTouch to ever
# be rescheduled after `c` upon topological sorting
points.update(a.access for a in c.scope.accesses if a.is_write)

# Sort for determinism
# NOTE: not sorting might impact code generation. The order of
# the args is important because that's what search functions honor!
points = sorted(points, key=str)
# Sort for determinism
# NOTE: not sorting might impact code generation. The order of
# the args is important because that's what search functions honor!
points = sorted(points, key=str)

rhs = HaloTouch(*points, halo_scheme=halo_scheme)
# Construct the HaloTouch Cluster
expr = Eq(self.B, HaloTouch(*points, halo_scheme=hs))

# Insert only if not redundant, to avoid useless pollution
if not any(rhs == e.rhs for e in exprs):
exprs.append(Eq(self.B, rhs))
key = lambda i: i in prefix[:-1] or i in hs.loc_indices
ispace = c.ispace.project(key)

processed = []
if exprs:
ispace = prefix[:prefix.index(d)]
properties = prefix.properties.drop(d)
halo_touch = c.rebuild(exprs=expr, ispace=ispace)

processed.append(Cluster(exprs, ispace, c.guards, properties))
seen.update(clusters)
processed.append(halo_touch)
seen.update({halo_touch, c})

processed.extend(clusters)

Expand Down
5 changes: 4 additions & 1 deletion devito/ir/iet/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,10 @@ def iet_build(stree):
nsections += 1

elif i.is_Halo:
body = HaloSpot(queues.pop(i), i.halo_scheme)
try:
body = HaloSpot(queues.pop(i), i.halo_scheme)
except KeyError:
body = HaloSpot(None, i.halo_scheme)

elif i.is_Sync:
body = SyncSpot(i.sync_ops, body=queues.pop(i, None))
Expand Down
30 changes: 26 additions & 4 deletions devito/ir/stree/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from devito.ir.support import (SEQUENTIAL, Any, Interval, IterationInterval,
IterationSpace, normalize_properties, normalize_syncs)
from devito.mpi.halo_scheme import HaloScheme
from devito.tools import Bunch, DefaultOrderedDict
from devito.tools import Bunch, DefaultOrderedDict, as_mapper

__all__ = ['stree_build']

Expand Down Expand Up @@ -85,6 +85,10 @@ def stree_build(clusters, profiler=None, **kwargs):
if needs_nodehalo(it.dim, c.halo_scheme):
v.bottom.parent = NodeHalo(c.halo_scheme, v.bottom.parent)
break
else:
if c.halo_scheme:
assert not c.exprs # See preprocess() -- we rarely end up here!
tip = NodeHalo(c.halo_scheme, v.bottom)

# Add in NodeExprs
exprs = []
Expand Down Expand Up @@ -150,11 +154,14 @@ def preprocess(clusters, options=None, **kwargs):
for c in clusters:
if c.is_halo_touch:
hs = HaloScheme.union(e.rhs.halo_scheme for e in c.exprs)
queue.append(c.rebuild(halo_scheme=hs))
queue.append(c.rebuild(exprs=[], halo_scheme=hs))

elif c.is_critical_region and c.syncs:
processed.append(c.rebuild(exprs=None, guards=c.guards, syncs=c.syncs))

elif c.is_wild:
continue

else:
dims = set(c.ispace.promote(lambda d: d.is_Block).itdims)

Expand All @@ -181,8 +188,23 @@ def preprocess(clusters, options=None, **kwargs):
ispace = c.ispace.project(syncs)
processed.append(c.rebuild(exprs=[], ispace=ispace, syncs=syncs))

halo_scheme = HaloScheme.union([c1.halo_scheme for c1 in found])
processed.append(c.rebuild(halo_scheme=halo_scheme))
if all(c1.ispace.is_subset(c.ispace) for c1 in found):
# 99% of the cases we end up here
hs = HaloScheme.union([c1.halo_scheme for c1 in found])
processed.append(c.rebuild(halo_scheme=hs))
elif options['mpi']:
# We end up here with e.g. `t,x,y,z,f` where `f` is a sequential
# dimension requiring a loc-index in the HaloScheme. The compiler
# will generate the non-perfect loop nest `t,f ; t,x,y,z,f`, with
# the first nest triggering all necessary halo exchanges along `f`
mapper = as_mapper(found, lambda c1: c1.ispace)
for k, v in mapper.items():
hs = HaloScheme.union([c1.halo_scheme for c1 in v])
processed.append(c.rebuild(exprs=[], ispace=k, halo_scheme=hs))
processed.append(c)
else:
# Avoid ugly empty loops
processed.append(c)

# Sanity check!
try:
Expand Down
13 changes: 13 additions & 0 deletions devito/ir/support/space.py
Original file line number Diff line number Diff line change
Expand Up @@ -966,6 +966,19 @@ def reorder(self, relations=None, mode=None):

return IterationSpace(intervals, self.sub_iterators, self.directions)

def is_subset(self, other):
"""
True if `self` is included within `other`, False otherwise.
"""
if not self:
return True

d = self[-1].dim
try:
return self == other[:other.index(d) + 1]
except ValueError:
return False

def is_compatible(self, other):
"""
A relaxed version of ``__eq__``, in which only non-derived dimensions
Expand Down
2 changes: 2 additions & 0 deletions devito/mpi/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
from devito.types.utils import DimensionTuple


__all__ = ['CustomTopology']

# Do not prematurely initialize MPI
# This allows launching a Devito program from within another Python program
# that has *already* initialized MPI
Expand Down
12 changes: 10 additions & 2 deletions devito/mpi/halo_scheme.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def __len__(self):
return len(self._mapper)

def __hash__(self):
return (self._mapper.__hash__(), self.honored.__hash__())
return hash((self._mapper.__hash__(), self.honored.__hash__()))

@classmethod
def build(cls, fmapper, honored):
Expand Down Expand Up @@ -582,13 +582,21 @@ def _sympystr(self, printer):
return str(self)

def __hash__(self):
return id(self)
return hash(self.halo_scheme)

def __eq__(self, other):
return isinstance(other, HaloTouch) and self.halo_scheme == other.halo_scheme

func = Reconstructable._rebuild

@property
def fmapper(self):
return self.halo_scheme.fmapper

@property
def dims(self):
return frozenset().union(*[v.dims for v in self.fmapper.values()])


def _uxreplace_dispatch_haloscheme(hs0, rule):
changed = False
Expand Down
15 changes: 9 additions & 6 deletions devito/operations/interpolators.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from devito.finite_differences.differentiable import Mul
from devito.finite_differences.elementary import floor
from devito.symbolics import retrieve_function_carriers, retrieve_functions, INT
from devito.tools import as_tuple, flatten
from devito.tools import as_tuple, flatten, filter_ordered
from devito.types import (ConditionalDimension, Eq, Inc, Evaluable, Symbol,
CustomDimension)
from devito.types.utils import DimensionTuple
Expand Down Expand Up @@ -163,7 +163,7 @@ def r(self):

@cached_property
def _rdim(self):
parent = self.sfunction.dimensions[-1]
parent = self.sfunction._sparse_dim
dims = [CustomDimension("r%s%s" % (self.sfunction.name, d.name),
-self.r+1, self.r, 2*self.r, parent)
for d in self._gdims]
Expand All @@ -184,15 +184,18 @@ def _rdim(self):

def _augment_implicit_dims(self, implicit_dims, extras=None):
if extras is not None:
extra = set([i for v in extras for i in v.dimensions]) - set(self._gdims)
extra = filter_ordered([i for v in extras for i in v.dimensions
if i not in self._gdims and
i not in self.sfunction.dimensions])
extra = tuple(extra)
else:
extra = tuple()

if self.sfunction._sparse_position == -1:
return self.sfunction.dimensions + as_tuple(implicit_dims) + extra
idims = self.sfunction.dimensions + as_tuple(implicit_dims) + extra
else:
return as_tuple(implicit_dims) + self.sfunction.dimensions + extra
idims = extra + as_tuple(implicit_dims) + self.sfunction.dimensions
return tuple(idims)

def _coeff_temps(self, implicit_dims):
return []
Expand Down Expand Up @@ -283,7 +286,7 @@ def _interpolate(self, expr, increment=False, self_subs={}, implicit_dims=None):
variables = list(retrieve_function_carriers(_expr))

# Implicit dimensions
implicit_dims = self._augment_implicit_dims(implicit_dims)
implicit_dims = self._augment_implicit_dims(implicit_dims, variables)

# List of indirection indices for all adjacent grid points
idx_subs, temps = self._interp_idx(variables, implicit_dims=implicit_dims)
Expand Down
10 changes: 8 additions & 2 deletions devito/passes/iet/mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,9 @@ def rule1(dep, candidates, loc_dims):
for q in d._defines])

for n, i in enumerate(iters):
if i not in scopes:
continue

candidates = [i.dim._defines for i in iters[n:]]

all_candidates = set().union(*candidates)
Expand Down Expand Up @@ -251,9 +254,10 @@ def _mark_overlappable(iet):
found = []
for hs in FindNodes(HaloSpot).visit(iet):
expressions = FindNodes(Expression).visit(hs)
scope = Scope([i.expr for i in expressions])
if not expressions:
continue

test = True
scope = Scope([i.expr for i in expressions])

# Comp/comm overlaps is legal only if the OWNED regions can grow
# arbitrarly, which means all of the dependences must be carried
Expand All @@ -270,6 +274,8 @@ def _mark_overlappable(iet):
# f[x, y] = ...
test = False
break
else:
test = True

# Heuristic: avoid comp/comm overlap for sparse Iteration nests
if test:
Expand Down
4 changes: 4 additions & 0 deletions devito/types/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1453,6 +1453,10 @@ def _hashable_content(self):
def indices(self):
return DimensionTuple(*super().indices, getters=self.function.dimensions)

@cached_property
def dimensions(self):
return self.function.dimensions

@property
def function(self):
return self.base.function
Expand Down
5 changes: 3 additions & 2 deletions devito/types/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -896,7 +896,7 @@ class SparseTimeFunction(AbstractSparseTimeFunction, SparseFunction):
__rkwargs__ = tuple(filter_ordered(AbstractSparseTimeFunction.__rkwargs__ +
SparseFunction.__rkwargs__))

def interpolate(self, expr, u_t=None, p_t=None, increment=False):
def interpolate(self, expr, u_t=None, p_t=None, increment=False, implicit_dims=None):
"""
Generate equations interpolating an arbitrary expression into ``self``.

Expand All @@ -921,7 +921,8 @@ def interpolate(self, expr, u_t=None, p_t=None, increment=False):
if p_t is not None:
subs = {self.time_dim: p_t}

return super().interpolate(expr, increment=increment, self_subs=subs)
return super().interpolate(expr, increment=increment, self_subs=subs,
implicit_dims=implicit_dims)

def inject(self, field, expr, u_t=None, p_t=None, implicit_dims=None):
"""
Expand Down
12 changes: 4 additions & 8 deletions examples/seismic/tti/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,8 +551,7 @@ def ForwardOperator(model, geometry, space_order=4,

# Source and receivers
expr = src * dt / m if kernel == 'staggered' else src * dt**2 / m
stencils += src.inject(field=u.forward, expr=expr)
stencils += src.inject(field=v.forward, expr=expr)
stencils += src.inject(field=(u.forward, v.forward), expr=expr)
stencils += rec.interpolate(expr=u + v)

# Substitute spacing terms to reduce flops
Expand Down Expand Up @@ -601,8 +600,7 @@ def AdjointOperator(model, geometry, space_order=4,

# Construct expression to inject receiver values
expr = rec * dt / m if kernel == 'staggered' else rec * dt**2 / m
stencils += rec.inject(field=p.backward, expr=expr)
stencils += rec.inject(field=r.backward, expr=expr)
stencils += rec.inject(field=(p.backward, r.backward), expr=expr)

# Create interpolation expression for the adjoint-source
stencils += srca.interpolate(expr=p + r)
Expand Down Expand Up @@ -661,8 +659,7 @@ def JacobianOperator(model, geometry, space_order=4,
eqn2 = FD_kernel(model, du, dv, space_order, qu=lin_usrc, qv=lin_vsrc)

# Construct expression to inject source values, injecting at u0(t+dt)/v0(t+dt)
src_term = src.inject(field=u0.forward, expr=src * dt**2 / m)
src_term += src.inject(field=v0.forward, expr=src * dt**2 / m)
src_term = src.inject(field=(u0.forward, v0.forward), expr=src * dt**2 / m)

# Create interpolation expression for receivers, extracting at du(t)+dv(t)
rec_term = rec.interpolate(expr=du + dv)
Expand Down Expand Up @@ -716,8 +713,7 @@ def JacobianAdjOperator(model, geometry, space_order=4,
dm_update = Inc(dm, - (u0 * du.dt2 + v0 * dv.dt2))

# Add expression for receiver injection
rec_term = rec.inject(field=du.backward, expr=rec * dt**2 / m)
rec_term += rec.inject(field=dv.backward, expr=rec * dt**2 / m)
rec_term = rec.inject(field=(du.backward, dv.backward), expr=rec * dt**2 / m)

# Substitute spacing terms to reduce flops
return Operator(eqn + rec_term + [dm_update], subs=model.spacing_map,
Expand Down
4 changes: 2 additions & 2 deletions tests/test_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -703,8 +703,8 @@ class SparseFirst(SparseFunction):
ds = DefaultDimension("ps", default_value=3)
grid = Grid((11, 11))
dims = grid.dimensions
s = SparseFirst(name="s", grid=grid, npoint=2, dimensions=(dr, ds), shape=(2, 3))
s.coordinates.data[:] = [[.5, .5], [.2, .2]]
s = SparseFirst(name="s", grid=grid, npoint=2, dimensions=(dr, ds), shape=(2, 3),
coordinates=[[.5, .5], [.2, .2]])

# Check dimensions and shape are correctly initialized
assert s.indices[s._sparse_position] == dr
Expand Down
Loading
Loading