Skip to content

Commit

Permalink
compiler: Add dist-drop-unwritten opt-option
Browse files Browse the repository at this point in the history
  • Loading branch information
FabioLuporini committed May 25, 2023
1 parent 7f1b6a7 commit df820d6
Show file tree
Hide file tree
Showing 6 changed files with 95 additions and 26 deletions.
3 changes: 3 additions & 0 deletions devito/core/cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ def _normalize_kwargs(cls, **kwargs):
o['par-dynamic-work'] = oo.pop('par-dynamic-work', cls.PAR_DYNAMIC_WORK)
o['par-nested'] = oo.pop('par-nested', cls.PAR_NESTED)

# Distributed parallelism
o['dist-drop-unwritten'] = oo.pop('dist-drop-unwritten', cls.DIST_DROP_UNWRITTEN)

# Misc
o['expand'] = oo.pop('expand', cls.EXPAND)
o['optcomms'] = oo.pop('optcomms', True)
Expand Down
3 changes: 3 additions & 0 deletions devito/core/gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ def _normalize_kwargs(cls, **kwargs):
o['gpu-fit'] = as_tuple(oo.pop('gpu-fit', cls._normalize_gpu_fit(**kwargs)))
o['gpu-create'] = as_tuple(oo.pop('gpu-create', ()))

# Distributed parallelism
o['dist-drop-unwritten'] = oo.pop('dist-drop-unwritten', cls.DIST_DROP_UNWRITTEN)

# Misc
o['expand'] = oo.pop('expand', cls.EXPAND)
o['optcomms'] = oo.pop('optcomms', True)
Expand Down
8 changes: 7 additions & 1 deletion devito/core/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,12 @@ class BasicOperator(Operator):
The supported MPI modes.
"""

DIST_DROP_UNWRITTEN = True
"""
Drop halo exchanges for read-only Functions, even in the presence of
stencil-like data accesses.
"""

INDEX_MODE = "int64"
"""
The type of the expression used to compute array indices. Either `int64`
Expand Down Expand Up @@ -281,7 +287,7 @@ def _specialize_iet(cls, graph, **kwargs):
# from HaloSpot optimization)
# Note that if MPI is disabled then this pass will act as a no-op
if 'mpi' not in passes:
passes_mapper['mpi'](graph)
passes_mapper['mpi'](graph, **kwargs)

# Run passes
applied = []
Expand Down
54 changes: 30 additions & 24 deletions devito/passes/iet/mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,15 @@


@iet_pass
def optimize_halospots(iet):
def optimize_halospots(iet, **kwargs):
"""
Optimize the HaloSpots in ``iet``. HaloSpots may be dropped, merged and moved
around in order to improve the halo exchange performance.
"""
iet = _drop_halospots(iet)
iet = _hoist_halospots(iet)
iet = _merge_halospots(iet)
iet = _drop_if_unwritten(iet)
iet = _drop_if_unwritten(iet, **kwargs)
iet = _mark_overlappable(iet)

return iet, {}
Expand All @@ -39,9 +39,9 @@ def _drop_halospots(iet):

# If all HaloSpot reads pertain to reductions, then the HaloSpot is useless
for hs, expressions in MapNodes(HaloSpot, Expression).visit(iet).items():
for f in hs.fmapper:
scope = Scope([i.expr for i in expressions])
if all(i.is_reduction for i in scope.reads.get(f, [])):
scope = Scope([i.expr for i in expressions])
for f, v in scope.reads.items():
if f in hs.fmapper and all(i.is_reduction for i in v):
mapper[hs].add(f)

# Transform the IET introducing the "reduced" HaloSpots
Expand Down Expand Up @@ -73,7 +73,7 @@ def rule1(dep, candidates, loc_dims):
# A reduction isn't a stopper to hoisting
return dep.write is not None and dep.write.is_reduction

hoist_rules = [rule0, rule1]
rules = [rule0, rule1]

# Precompute scopes to save time
scopes = {i: Scope([e.expr for e in v]) for i, v in MapNodes().visit(iet).items()}
Expand All @@ -92,13 +92,16 @@ def rule1(dep, candidates, loc_dims):
for n, i in enumerate(iters):
candidates = [i.dim._defines for i in iters[n:]]

test = True
all_candidates = set().union(*candidates)
reads = scopes[i].getreads(f)
if any(set(a.ispace.dimensions) & all_candidates
for a in reads):
continue

for dep in scopes[i].d_flow.project(f):
if any(rule(dep, candidates, loc_dims) for rule in hoist_rules):
continue
test = False
break
if test:
if not any(r(dep, candidates, loc_dims) for r in rules):
break
else:
hsmapper[hs] = hsmapper[hs].drop(f)
imapper[i].append(hs.halo_scheme.project(f))
break
Expand Down Expand Up @@ -148,7 +151,7 @@ def rule2(dep, hs, loc_indices):
return any(dep.distance_mapper[d] == 0 and dep.source[d] is not v
for d, v in loc_indices.items())

merge_rules = [rule0, rule1, rule2]
rules = [rule0, rule1, rule2]

# Analysis
mapper = {}
Expand All @@ -165,13 +168,10 @@ def rule2(dep, hs, loc_indices):
mapper[hs] = hs.halo_scheme

for f, v in hs.fmapper.items():
test = True
for dep in scope.d_flow.project(f):
if any(rule(dep, hs, v.loc_indices) for rule in merge_rules):
continue
test = False
break
if test:
if not any(r(dep, hs, v.loc_indices) for r in rules):
break
else:
try:
mapper[hs0] = HaloScheme.union([mapper[hs0],
hs.halo_scheme.project(f)])
Expand All @@ -191,21 +191,27 @@ def rule2(dep, hs, loc_indices):
return iet


def _drop_if_unwritten(iet):
def _drop_if_unwritten(iet, options=None, **kwargs):
"""
Drop HaloSpots for unwritten Functions.
Notes
-----
This may be relaxed if Devito+MPI were to be used within existing
legacy codes, which would call the generated library directly.
This may be relaxed if Devito were to be used within existing legacy codes,
which would call the generated library directly.
"""
drop_unwritten = options['dist-drop-unwritten']
if not callable(drop_unwritten):
key = lambda f: drop_unwritten
else:
key = drop_unwritten

# Analysis
writes = {i.write for i in FindNodes(Expression).visit(iet)}
mapper = {}
for hs in FindNodes(HaloSpot).visit(iet):
for f in hs.fmapper:
if f not in writes:
if f not in writes and key(f):
mapper[hs] = mapper.get(hs, hs.halo_scheme).drop(f)

# Post-process analysis
Expand Down Expand Up @@ -321,7 +327,7 @@ def mpiize(graph, **kwargs):
options = kwargs['options']

if options['optcomms']:
optimize_halospots(graph)
optimize_halospots(graph, **kwargs)

mpimode = options['mpi']
if mpimode:
Expand Down
33 changes: 32 additions & 1 deletion tests/test_gpu_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
from conftest import assert_structure
from devito import (Constant, Eq, Inc, Grid, Function, ConditionalDimension,
MatrixSparseTimeFunction, SparseTimeFunction, SubDimension,
SubDomain, SubDomainSet, TimeFunction, Operator, configuration)
SubDomain, SubDomainSet, TimeFunction, Operator, configuration,
switchconfig)
from devito.arch import get_gpu_info
from devito.exceptions import InvalidArgument
from devito.ir import (Conditional, Expression, Section, FindNodes, FindSymbols,
Expand Down Expand Up @@ -1096,6 +1097,36 @@ def test_streaming_split_noleak(self):
assert np.all(u.data[0] == u1.data[0])
assert np.all(u.data[1] == u1.data[1])

@pytest.mark.skip(reason="Unsupported MPI + .dx when streaming backwards")
@pytest.mark.parallel(mode=4)
@switchconfig(safe_math=True)  # Or NVC will crash
def test_streaming_w_mpi(self):
    """
    Compare an unoptimized Operator against one built with buffering,
    streaming and orchestration, where the halo exchanges of the unwritten
    `usave` are retained through a callable `dist-drop-unwritten` option.
    """
    nt = 5
    grid = Grid(shape=(16, 16))

    u = TimeFunction(name='u', grid=grid)
    usave = TimeFunction(name='usave', grid=grid, save=nt, space_order=4)
    vsave = TimeFunction(name='vsave', grid=grid, save=nt, space_order=4)
    # Shares the name of `vsave`; supplied below to `op1.apply` as `vsave=`
    vsave1 = TimeFunction(name='vsave', grid=grid, save=nt, space_order=4)

    eqns = [Eq(u.backward, u + 1.),
            Eq(vsave, usave.dx2)]

    # False only for `usave`, i.e. every Function but `usave` is selected
    key = lambda f: f is not usave

    op0 = Operator(eqns, opt='noop')
    op1 = Operator(eqns, opt=('buffering', 'streaming', 'orchestrate',
                              {'dist-drop-unwritten': key,
                               'gpu-fit': [vsave]}))

    for i in range(nt):
        usave.data[i] = i

    op0.apply()
    op1.apply(vsave=vsave1)

    # `np.all` has no `rtol` keyword and would interpret the second array as
    # `axis`; a tolerance-based comparison requires `np.allclose`
    assert np.allclose(vsave.data, vsave1.data, rtol=1.e-5)

@pytest.mark.parametrize('opt,opt_options,gpu_fit', [
(('buffering', 'streaming', 'orchestrate'), {}, False),
(('buffering', 'streaming', 'orchestrate'), {'linearize': True}, False)
Expand Down
20 changes: 20 additions & 0 deletions tests/test_mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -1347,6 +1347,26 @@ def test_many_functions(self):
assert len(calls) == 2
assert calls[0].ncomps == 7

@pytest.mark.parallel(mode=1)
def test_enforce_haloupdate_if_unwritten_function(self):
    """
    Build an Operator with a callable `dist-drop-unwritten` option that
    returns False for `usave`, and check the number of generated Calls.
    """
    grid = Grid(shape=(16, 16))

    u = TimeFunction(name='u', grid=grid)
    v = TimeFunction(name='v', grid=grid)
    w = TimeFunction(name='w', grid=grid)
    usave = TimeFunction(name='usave', grid=grid, save=10, space_order=4)

    interior = grid.interior
    update_w = Eq(w.forward, v.forward.dx + w + 1., subdomain=interior)
    update_u = Eq(u.forward, u + 1.)
    update_v = Eq(v.forward, u.forward + usave.dx4, subdomain=interior)

    def key(f):
        # True for every Function except `usave`
        return f is not usave

    opt = ('advanced', {'dist-drop-unwritten': key})
    op = Operator([update_w, update_u, update_v], opt=opt)

    halo_calls = FindNodes(Call).visit(op)
    assert len(halo_calls) == 2  # One for `v` and one for `usave`


class TestOperatorAdvanced(object):

Expand Down

0 comments on commit df820d6

Please sign in to comment.