Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

mpi: Instrument compute0 core after specialising as ComputeCall #2143

Merged
merged 3 commits into from
Jun 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 44 additions & 42 deletions devito/ir/iet/efunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,42 @@

# ElementalFunction machinery

class ElementalCall(Call):

    """
    A Call whose arguments may be overridden, or incremented, through a
    mapper from "dynamic parameter" keys to positions in the argument list.

    Parameters
    ----------
    name : str
        Forwarded as is to ``Call.__init__``.
    arguments : list or tuple, optional
        The initial arguments. They are copied into a fresh list before any
        dynamic value is applied, so the caller's sequence is not modified.
    mapper : dict, optional
        Maps each dynamic parameter key to an iterable of indices into
        ``arguments``.
    dynamic_args_mapper : dict, optional
        Maps a dynamic parameter key to the value(s) to be placed at the
        positions that ``mapper`` associates with that key.
    incr : bool, optional
        If False (default), the dynamic values replace the arguments at the
        mapped positions; otherwise they are added to them.
    retobj, is_indirect : optional
        Forwarded as is to ``Call.__init__``.

    Raises
    ------
    ValueError
        If a key of ``dynamic_args_mapper`` is not in ``mapper``, or if the
        number of values supplied for a key differs from the number of
        positions mapped to it.
    """

    def __init__(self, name, arguments=None, mapper=None, dynamic_args_mapper=None,
                 incr=False, retobj=None, is_indirect=False):
        self._mapper = mapper or {}

        arguments = list(as_tuple(arguments))
        dynamic_args_mapper = dynamic_args_mapper or {}
        for k, v in dynamic_args_mapper.items():
            tv = as_tuple(v)

            # Sanity check
            if k not in self._mapper:
                # The 1-tuple keeps the formatting well-defined even if `k`
                # is itself a tuple
                raise ValueError("`%s` is not a dynamic parameter" % (k,))
            if len(self._mapper[k]) != len(tv):
                raise ValueError("Expected %d values for dynamic parameter `%s`, given %d"
                                 % (len(self._mapper[k]), k, len(tv)))
            # Create the argument list
            for i, j in zip(self._mapper[k], tv):
                arguments[i] = j if incr is False else (arguments[i] + j)

        super(ElementalCall, self).__init__(name, arguments, retobj, is_indirect)

    def _rebuild(self, *args, dynamic_args_mapper=None, incr=False,
                 retobj=None, **kwargs):
        """
        Rebuild via the parent ``_rebuild``, always passing
        ``dynamic_args_mapper``, ``incr`` and ``retobj`` explicitly (by
        default None, False and None respectively).
        """
        # This guarantees that `ec._rebuild(arguments=ec.arguments) == ec`
        # NOTE(review): presumably relies on how the parent `_rebuild` merges
        # the supplied kwargs with the stored ones -- confirm against `Call`
        return super(ElementalCall, self)._rebuild(
            *args, dynamic_args_mapper=dynamic_args_mapper, incr=incr,
            retobj=retobj, **kwargs
        )

    @cached_property
    def dynamic_defaults(self):
        """
        Map each dynamic parameter key to the tuple of arguments currently
        found at its mapped positions.
        """
        return {k: tuple(self.arguments[i] for i in v) for k, v in self._mapper.items()}


class ElementalFunction(Callable):

"""
Expand All @@ -21,6 +57,7 @@ class ElementalFunction(Callable):
supplying bounds and step increment for each Dimension listed in
``dynamic_parameters``.
"""
_Call_cls = ElementalCall

is_ElementalFunction = True

Expand All @@ -47,53 +84,18 @@ def dynamic_defaults(self):

def make_call(self, dynamic_args_mapper=None, incr=False, retobj=None,
is_indirect=False):
return ElementalCall(self.name, list(self.parameters), dict(self._mapper),
dynamic_args_mapper, incr, retobj, is_indirect)


class ElementalCall(Call):

    """
    A Call whose arguments may be overridden, or incremented, through a
    mapper from "dynamic parameter" keys to positions in the argument list.

    Parameters
    ----------
    name : str
        Forwarded as is to ``Call.__init__``.
    arguments : list or tuple, optional
        The initial arguments. They are copied into a fresh list before any
        dynamic value is applied, so the caller's sequence is not modified.
    mapper : dict, optional
        Maps each dynamic parameter key to an iterable of indices into
        ``arguments``.
    dynamic_args_mapper : dict, optional
        Maps a dynamic parameter key to the value(s) to be placed at the
        positions that ``mapper`` associates with that key.
    incr : bool, optional
        If False (default), the dynamic values replace the arguments at the
        mapped positions; otherwise they are added to them.
    retobj, is_indirect : optional
        Forwarded as is to ``Call.__init__``.

    Raises
    ------
    ValueError
        If a key of ``dynamic_args_mapper`` is not in ``mapper``, or if the
        number of values supplied for a key differs from the number of
        positions mapped to it.
    """

    def __init__(self, name, arguments=None, mapper=None, dynamic_args_mapper=None,
                 incr=False, retobj=None, is_indirect=False):
        self._mapper = mapper or {}

        arguments = list(as_tuple(arguments))
        dynamic_args_mapper = dynamic_args_mapper or {}
        for k, v in dynamic_args_mapper.items():
            tv = as_tuple(v)

            # Sanity check
            if k not in self._mapper:
                # The 1-tuple keeps the formatting well-defined even if `k`
                # is itself a tuple
                raise ValueError("`%s` is not a dynamic parameter" % (k,))
            if len(self._mapper[k]) != len(tv):
                raise ValueError("Expected %d values for dynamic parameter `%s`, given %d"
                                 % (len(self._mapper[k]), k, len(tv)))
            # Create the argument list
            for i, j in zip(self._mapper[k], tv):
                arguments[i] = j if incr is False else (arguments[i] + j)

        super(ElementalCall, self).__init__(name, arguments, retobj, is_indirect)

    def _rebuild(self, *args, dynamic_args_mapper=None, incr=False,
                 retobj=None, **kwargs):
        """
        Rebuild via the parent ``_rebuild``, always passing
        ``dynamic_args_mapper``, ``incr`` and ``retobj`` explicitly (by
        default None, False and None respectively).
        """
        # This guarantees that `ec._rebuild(arguments=ec.arguments) == ec`
        # NOTE(review): presumably relies on how the parent `_rebuild` merges
        # the supplied kwargs with the stored ones -- confirm against `Call`
        return super(ElementalCall, self)._rebuild(
            *args, dynamic_args_mapper=dynamic_args_mapper, incr=incr,
            retobj=retobj, **kwargs
        )

    @cached_property
    def dynamic_defaults(self):
        """
        Map each dynamic parameter key to the tuple of arguments currently
        found at its mapped positions.
        """
        return {k: tuple(self.arguments[i] for i in v) for k, v in self._mapper.items()}
return self._Call_cls(self.name, list(self.parameters), dict(self._mapper),
dynamic_args_mapper, incr, retobj, is_indirect)


def make_efunc(name, iet, dynamic_parameters=None, retval='void', prefix='static'):
def make_efunc(name, iet, dynamic_parameters=None, retval='void', prefix='static',
efunc_type=ElementalFunction):
"""
Shortcut to create an ElementalFunction.
"""
return ElementalFunction(name, iet, retval=retval,
parameters=derive_parameters(iet), prefix=prefix,
dynamic_parameters=dynamic_parameters)
return efunc_type(name, iet, retval=retval,
parameters=derive_parameters(iet), prefix=prefix,
dynamic_parameters=dynamic_parameters)


# Callable machinery
Expand Down
16 changes: 13 additions & 3 deletions devito/mpi/routines.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from devito.ir.iet import (Call, Callable, Conditional, ElementalFunction,
Expression, ExpressionBundle, AugmentedExpression,
Iteration, List, Prodder, Return, make_efunc, FindNodes,
Transformer)
Transformer, ElementalCall)
from devito.mpi import MPI
from devito.symbolics import (Byref, CondNe, FieldFromPointer, FieldFromComposite,
IndexedPointer, Macro, cast_mapper, subs_op_args)
Expand Down Expand Up @@ -572,6 +572,14 @@ def _make_haloupdate(self, f, hse, key, sendrecv, **kwargs):
return HaloUpdate('haloupdate%s' % key, iet, parameters)


class ComputeCall(ElementalCall):

    """
    An ElementalCall subtype that adds no state or behaviour of its own; it
    only provides a distinct class, so that such calls can be told apart
    from other ElementalCalls by type.
    """


class ComputeFunction(ElementalFunction):

    """
    An ElementalFunction that sets ``_Call_cls`` to ComputeCall.
    """

    # NOTE(review): presumably the class used by `make_call` to build the
    # call object -- confirm against ElementalFunction
    _Call_cls = ComputeCall


class OverlapHaloExchangeBuilder(DiagHaloExchangeBuilder):

"""
Expand Down Expand Up @@ -647,7 +655,8 @@ def _make_compute(self, hs, key, *args):
if hs.body.is_Call:
return None
else:
return make_efunc('compute%d' % key, hs.body, hs.arguments)
return make_efunc('compute%d' % key, hs.body, hs.arguments,
efunc_type=ComputeFunction)

def _call_compute(self, hs, compute, *args):
if compute is None:
Expand Down Expand Up @@ -952,7 +961,8 @@ def _make_compute(self, hs, key, msgs, callpoke):
mapper = {i: List(body=[callpoke, i]) for i in
FindNodes(ExpressionBundle).visit(hs.body)}
iet = Transformer(mapper).visit(hs.body)
return make_efunc('compute%d' % key, iet, hs.arguments)
return make_efunc('compute%d' % key, iet, hs.arguments,
efunc_type=ComputeFunction)

def _make_poke(self, hs, key, msgs):
lflag = Symbol(name='lflag')
Expand Down
4 changes: 2 additions & 2 deletions devito/operator/profiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from devito.ir.support import IntervalGroup
from devito.logger import warning, error
from devito.mpi import MPI
from devito.mpi.routines import MPICall, MPIList, RemainderCall
from devito.mpi.routines import MPICall, MPIList, RemainderCall, ComputeCall
from devito.parameters import configuration
from devito.symbolics import subs_op_args
from devito.tools import DefaultOrderedDict, flatten
Expand Down Expand Up @@ -332,7 +332,7 @@ class AdvancedProfilerVerbose2(AdvancedProfilerVerbose):

@property
def trackable_subsections(self):
return (MPICall, BusyWait)
return (MPICall, BusyWait, ComputeCall)


class AdvisorProfiler(AdvancedProfiler):
Expand Down
6 changes: 4 additions & 2 deletions devito/passes/iet/instrument.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from devito.ir.iet import (BusyWait, FindNodes, FindSymbols, MapNodes, Section,
TimedList, Transformer)
from devito.mpi.routines import (HaloUpdateCall, HaloWaitCall, MPICall, MPIList,
HaloUpdateList, HaloWaitList, RemainderCall)
HaloUpdateList, HaloWaitList, RemainderCall,
ComputeCall)
from devito.passes.iet.engine import iet_pass
from devito.types import Timer

Expand Down Expand Up @@ -36,14 +37,15 @@ def track_subsections(iet, **kwargs):
HaloUpdateCall: 'haloupdate',
HaloWaitCall: 'halowait',
RemainderCall: 'remainder',
ComputeCall: 'compute',
HaloUpdateList: 'haloupdate',
HaloWaitList: 'halowait',
BusyWait: 'busywait'
}

mapper = {}

for NodeType in [MPIList, MPICall, BusyWait]:
for NodeType in [MPIList, MPICall, BusyWait, ComputeCall]:
for k, v in MapNodes(Section, NodeType).visit(iet).items():
for i in v:
if i in mapper or not any(issubclass(i.__class__, n)
Expand Down
21 changes: 20 additions & 1 deletion tests/test_mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@
from devito.ir.iet import (Call, Conditional, Iteration, FindNodes, FindSymbols,
retrieve_iteration_tree)
from devito.mpi import MPI
from devito.mpi.routines import HaloUpdateCall, HaloUpdateList, MPICall
from devito.mpi.routines import HaloUpdateCall, HaloUpdateList, MPICall, ComputeCall
from devito.mpi.distributed import CustomTopology
from devito.tools import Bunch

from examples.seismic.acoustic import acoustic_setup

pytestmark = skipif(['nompi'], whole_module=True)
Expand Down Expand Up @@ -1400,6 +1401,7 @@ def test_min_code_size(self):
assert len(op._func_table) == 7
assert len(calls) == 4
assert 'haloupdate1' not in op._func_table
assert len(FindNodes(ComputeCall).visit(op)) == 1

@pytest.mark.parallel(mode=[(1, 'diag2')])
def test_many_functions(self):
Expand All @@ -1418,6 +1420,23 @@ def test_many_functions(self):
assert len(calls) == 2
assert calls[0].ncomps == 7

@switchconfig(profiling='advanced2')
@pytest.mark.parallel(mode=[(1, 'full')])
def test_profiled_regions(self):
    """
    Check the names of the sections exposed by the profiler when
    ``profiling='advanced2'``, which must include ``compute0``.
    """
    grid = Grid(shape=(10, 10, 10))

    f = TimeFunction(name='f', grid=grid, space_order=2)
    g = TimeFunction(name='g', grid=grid, space_order=2)

    # One equation per function, in the order f, g
    op = Operator([Eq(u.forward, u.dx2 + 1.) for u in (f, g)])

    expected = ['section0', 'haloupdate0', 'halowait0', 'remainder0', 'compute0']
    assert op._profiler.all_sections == expected

@pytest.mark.parallel(mode=1)
def test_enforce_haloupdate_if_unwritten_function(self):
grid = Grid(shape=(16, 16))
Expand Down