class ElementalCall(Call):

    """
    A Call whose arguments may be customised through "dynamic" parameters.

    Parameters
    ----------
    name
        Forwarded as is to the parent ``Call`` constructor.
    arguments : optional
        The default arguments; normalised to a list via ``as_tuple``.
    mapper : dict, optional
        Maps each dynamic parameter to the positions, within ``arguments``,
        of the entries it controls.
    dynamic_args_mapper : dict, optional
        Maps dynamic parameters to the value(s) to be placed at the positions
        given by ``mapper``. A single value is accepted in place of a tuple.
    incr : bool, optional
        If False (default), the supplied values replace the default
        arguments; any other value makes them be added to the defaults.
    retobj, is_indirect : optional
        Forwarded as is to the parent ``Call`` constructor.

    Raises
    ------
    ValueError
        If a key of ``dynamic_args_mapper`` does not appear in ``mapper``, or
        if the number of supplied values differs from the number of positions
        registered in ``mapper`` for that key.
    """

    def __init__(self, name, arguments=None, mapper=None, dynamic_args_mapper=None,
                 incr=False, retobj=None, is_indirect=False):
        self._mapper = mapper or {}

        arguments = list(as_tuple(arguments))
        dynamic_args_mapper = dynamic_args_mapper or {}
        for k, v in dynamic_args_mapper.items():
            tv = as_tuple(v)

            # Sanity check
            if k not in self._mapper:
                # The placeholder is required: without it the `%` operator
                # itself fails with a TypeError and the ValueError is never
                # raised
                raise ValueError("`%s` is not a dynamic parameter" % (k,))
            if len(self._mapper[k]) != len(tv):
                raise ValueError("Expected %d values for dynamic parameter `%s`, given %d"
                                 % (len(self._mapper[k]), k, len(tv)))
            # Create the argument list
            for i, j in zip(self._mapper[k], tv):
                arguments[i] = j if incr is False else (arguments[i] + j)

        super(ElementalCall, self).__init__(name, arguments, retobj, is_indirect)

    def _rebuild(self, *args, dynamic_args_mapper=None, incr=False,
                 retobj=None, **kwargs):
        # `dynamic_args_mapper`, `incr` and `retobj` are always forwarded
        # explicitly (with their defaults, unless overridden by the caller).
        # This guarantees that `ec._rebuild(arguments=ec.arguments) == ec`
        return super(ElementalCall, self)._rebuild(
            *args, dynamic_args_mapper=dynamic_args_mapper, incr=incr,
            retobj=retobj, **kwargs
        )

    @cached_property
    def dynamic_defaults(self):
        """
        Map each dynamic parameter to the tuple of arguments found at the
        positions registered for it in the mapper.
        """
        return {k: tuple(self.arguments[i] for i in v) for k, v in self._mapper.items()}
def make_efunc(name, iet, dynamic_parameters=None, retval='void', prefix='static',
               efunc_type=ElementalFunction):
    """
    Shortcut to create an elemental function of type ``efunc_type``
    (``ElementalFunction`` by default).

    The parameters of the new object are obtained via ``derive_parameters(iet)``;
    ``retval``, ``prefix`` and ``dynamic_parameters`` are forwarded unchanged
    to the ``efunc_type`` constructor.
    """
    parameters = derive_parameters(iet)
    kwargs = {
        'retval': retval,
        'parameters': parameters,
        'prefix': prefix,
        'dynamic_parameters': dynamic_parameters,
    }
    return efunc_type(name, iet, **kwargs)
@property
def trackable_subsections(self):
    """
    The node classes exposed as trackable subsections: ``MPICall``,
    ``BusyWait`` and ``ComputeCall``.
    """
    tracked = (MPICall, BusyWait, ComputeCall)
    return tracked
@switchconfig(profiling='advanced2')
@pytest.mark.parallel(mode=[(1, 'full')])
def test_profiled_regions(self):
    """
    Check the sections recorded by the profiler for an Operator built from
    two equations, with `advanced2` profiling and the `full` MPI mode.
    """
    grid = Grid(shape=(10, 10, 10))

    # Same construction order as the equations below: `f` first, then `g`
    functions = [TimeFunction(name=i, grid=grid, space_order=2) for i in ('f', 'g')]

    op = Operator([Eq(u.forward, u.dx2 + 1.) for u in functions])

    expected = ['section0', 'haloupdate0', 'halowait0', 'remainder0', 'compute0']
    assert op._profiler.all_sections == expected