diff --git a/devito/core/cpu.py b/devito/core/cpu.py index 1e71e45428..7137c08b06 100644 --- a/devito/core/cpu.py +++ b/devito/core/cpu.py @@ -60,6 +60,9 @@ def _normalize_kwargs(cls, **kwargs): o['par-dynamic-work'] = oo.pop('par-dynamic-work', cls.PAR_DYNAMIC_WORK) o['par-nested'] = oo.pop('par-nested', cls.PAR_NESTED) + # Distributed parallelism + o['dist-drop-unwritten'] = oo.pop('dist-drop-unwritten', cls.DIST_DROP_UNWRITTEN) + # Misc o['expand'] = oo.pop('expand', cls.EXPAND) o['optcomms'] = oo.pop('optcomms', True) diff --git a/devito/core/gpu.py b/devito/core/gpu.py index fbe71b08b8..8a75b2857a 100644 --- a/devito/core/gpu.py +++ b/devito/core/gpu.py @@ -75,6 +75,9 @@ def _normalize_kwargs(cls, **kwargs): o['gpu-fit'] = as_tuple(oo.pop('gpu-fit', cls._normalize_gpu_fit(**kwargs))) o['gpu-create'] = as_tuple(oo.pop('gpu-create', ())) + # Distributed parallelism + o['dist-drop-unwritten'] = oo.pop('dist-drop-unwritten', cls.DIST_DROP_UNWRITTEN) + # Misc o['expand'] = oo.pop('expand', cls.EXPAND) o['optcomms'] = oo.pop('optcomms', True) diff --git a/devito/core/operator.py b/devito/core/operator.py index 6620b0c60d..3237dae5cb 100644 --- a/devito/core/operator.py +++ b/devito/core/operator.py @@ -100,6 +100,12 @@ class BasicOperator(Operator): The supported MPI modes. """ + DIST_DROP_UNWRITTEN = True + """ + Drop halo exchanges for read-only Function, even in presence of + stencil-like data accesses. + """ + INDEX_MODE = "int64" """ The type of the expression used to compute array indices. 
Either `int64` @@ -281,7 +287,7 @@ def _specialize_iet(cls, graph, **kwargs): # from HaloSpot optimization) # Note that if MPI is disabled then this pass will act as a no-op if 'mpi' not in passes: - passes_mapper['mpi'](graph) + passes_mapper['mpi'](graph, **kwargs) # Run passes applied = [] @@ -300,10 +306,6 @@ def _specialize_iet(cls, graph, **kwargs): if 'init' not in passes: passes_mapper['init'](graph) - # Enforce pthreads if CPU-GPU orchestration requested - if 'orchestrate' in passes and 'pthreadify' not in passes: - passes_mapper['pthreadify'](graph, sregistry=sregistry) - # Symbol definitions cls._Target.DataManager(**kwargs).process(graph) @@ -311,6 +313,10 @@ def _specialize_iet(cls, graph, **kwargs): if 'linearize' not in passes and options['linearize']: passes_mapper['linearize'](graph) + # Enforce pthreads if CPU-GPU orchestration requested + if 'orchestrate' in passes and 'pthreadify' not in passes: + passes_mapper['pthreadify'](graph, sregistry=sregistry) + return graph diff --git a/devito/ir/clusters/cluster.py b/devito/ir/clusters/cluster.py index 3dd4f9b555..857e62f7c6 100644 --- a/devito/ir/clusters/cluster.py +++ b/devito/ir/clusters/cluster.py @@ -9,7 +9,7 @@ DataSpace, Guards, Properties, Scope, detect_accesses, detect_io, normalize_properties, normalize_syncs, sdims_min, sdims_max) -from devito.mpi.halo_scheme import HaloTouch +from devito.mpi.halo_scheme import HaloScheme, HaloTouch from devito.symbolics import estimate_cost from devito.tools import as_tuple, flatten, frozendict, infer_dtype @@ -26,7 +26,7 @@ class Cluster(object): exprs : expr-like or list of expr-like An ordered sequence of expressions computing a tensor. ispace : IterationSpace, optional - The cluster iteration space. + The Cluster iteration space. guards : dict, optional Mapper from Dimensions to expr-like, representing the conditions under which the Cluster should be computed. 
@@ -37,9 +37,12 @@ class Cluster(object): Mapper from Dimensions to lists of SyncOps, that is ordered sequences of synchronization operations that must be performed in order to compute the Cluster asynchronously. + halo_scheme : HaloScheme, optional + The halo exchanges required by the Cluster. """ - def __init__(self, exprs, ispace=None, guards=None, properties=None, syncs=None): + def __init__(self, exprs, ispace=None, guards=None, properties=None, syncs=None, + halo_scheme=None): ispace = ispace or IterationSpace([]) self._exprs = tuple(ClusterizedEq(e, ispace=ispace) for e in as_tuple(exprs)) @@ -57,6 +60,8 @@ def __init__(self, exprs, ispace=None, guards=None, properties=None, syncs=None) properties = properties.drop(d) self._properties = properties + self._halo_scheme = halo_scheme + def __repr__(self): return "Cluster([%s])" % ('\n' + ' '*9).join('%s' % i for i in self.exprs) @@ -91,7 +96,9 @@ def from_clusters(cls, *clusters): raise ValueError("Cannot build a Cluster from Clusters with " "non-compatible synchronization operations") - return Cluster(exprs, ispace, guards, properties, syncs) + halo_scheme = HaloScheme.union([c.halo_scheme for c in clusters]) + + return Cluster(exprs, ispace, guards, properties, syncs, halo_scheme) def rebuild(self, *args, **kwargs): """ @@ -110,7 +117,8 @@ def rebuild(self, *args, **kwargs): ispace=kwargs.get('ispace', self.ispace), guards=kwargs.get('guards', self.guards), properties=kwargs.get('properties', self.properties), - syncs=kwargs.get('syncs', self.syncs)) + syncs=kwargs.get('syncs', self.syncs), + halo_scheme=kwargs.get('halo_scheme', self.halo_scheme)) @property def exprs(self): @@ -144,6 +152,10 @@ def properties(self): def syncs(self): return self._syncs + @property + def halo_scheme(self): + return self._halo_scheme + @cached_property def free_symbols(self): return set().union(*[e.free_symbols for e in self.exprs]) diff --git a/devito/ir/iet/nodes.py b/devito/ir/iet/nodes.py index aba0d0d98c..fc9e907d75 100644 
--- a/devito/ir/iet/nodes.py +++ b/devito/ir/iet/nodes.py @@ -288,10 +288,13 @@ def functions(self): retval.append(s.function) except AttributeError: continue + if self.base is not None: retval.append(self.base.function) + if self.retobj is not None: retval.append(self.retobj.function) + return tuple(filter_ordered(retval)) @cached_property @@ -309,10 +312,15 @@ def expr_symbols(self): retval.extend(i.free_symbols) except AttributeError: pass + if self.base is not None: retval.append(self.base) - if self.retobj is not None: + + if isinstance(self.retobj, Indexed): + retval.extend(self.retobj.free_symbols) + elif self.retobj is not None: retval.append(self.retobj) + return tuple(filter_ordered(retval)) @property @@ -744,6 +752,8 @@ class CallableBody(Node): maps : Transfer or list of Transfer, optional Data maps for `body` (a data map may e.g. trigger a data transfer from host to device). + strides : list of Nodes, optional + Statements defining symbols used to access linearized arrays. objs : list of Definitions, optional Object definitions for `body`. 
unmaps : Transfer or list of Transfer, optional @@ -756,19 +766,21 @@ class CallableBody(Node): is_CallableBody = True - _traversable = ['unpacks', 'init', 'allocs', 'casts', 'bundles', 'maps', 'objs', - 'body', 'unmaps', 'unbundles', 'frees'] + _traversable = ['unpacks', 'init', 'allocs', 'casts', 'bundles', 'maps', + 'strides', 'objs', 'body', 'unmaps', 'unbundles', 'frees'] - def __init__(self, body, init=(), unpacks=(), allocs=(), casts=(), + def __init__(self, body, init=(), unpacks=(), strides=(), allocs=(), casts=(), bundles=(), objs=(), maps=(), unmaps=(), unbundles=(), frees=()): # Sanity check assert not isinstance(body, CallableBody), "CallableBody's cannot be nested" self.body = as_tuple(body) - self.init = as_tuple(init) + self.unpacks = as_tuple(unpacks) + self.init = as_tuple(init) self.allocs = as_tuple(allocs) self.casts = as_tuple(casts) + self.strides = as_tuple(strides) self.bundles = as_tuple(bundles) self.maps = as_tuple(maps) self.objs = as_tuple(objs) @@ -1201,6 +1213,10 @@ def __getattr__(self, name): def functions(self): return as_tuple(self.nthreads) + @property + def expr_symbols(self): + return as_tuple(self.nthreads) + @property def root(self): return self.body[0] diff --git a/devito/ir/stree/algorithms.py b/devito/ir/stree/algorithms.py index 109d258050..e6aebf44f5 100644 --- a/devito/ir/stree/algorithms.py +++ b/devito/ir/stree/algorithms.py @@ -1,4 +1,3 @@ -from collections import defaultdict from itertools import groupby from anytree import findall @@ -8,10 +7,9 @@ from devito.ir.stree.tree import (ScheduleTree, NodeIteration, NodeConditional, NodeSync, NodeExprs, NodeSection, NodeHalo) from devito.ir.support import (SEQUENTIAL, Any, Interval, IterationInterval, - IterationSpace, normalize_properties) + IterationSpace, normalize_properties, normalize_syncs) from devito.mpi.halo_scheme import HaloScheme from devito.tools import Bunch, DefaultOrderedDict -from devito.types.dimension import BOTTOM __all__ = ['stree_build'] @@ -20,39 
+18,32 @@ def stree_build(clusters, profiler=None, **kwargs): """ Create a ScheduleTree from a ClusterGroup. """ - clusters, hsmap = preprocess(clusters) + clusters = preprocess(clusters, **kwargs) stree = ScheduleTree() section = None - base = IterationInterval(Interval(BOTTOM), [], Any) prev = Cluster(None) - mapper = DefaultOrderedDict(lambda: Bunch(top=None, bottom=None)) - mapper[base] = Bunch(top=stree, bottom=stree) + mapper = DefaultOrderedDict(lambda: Bunch(top=None, middle=None, bottom=None)) + mapper[base] = Bunch(top=stree, middle=stree, bottom=stree) for c in clusters: - if reuse_subtree(c, prev): + if reuse_whole_subtree(c, prev): tip = mapper[base].bottom maybe_reusable = prev.itintervals else: # Add any guards/Syncs outside of the outermost Iteration - tip = augment_subtree(c, None, stree) + tip = augment_whole_subtree(c, stree, mapper, base) maybe_reusable = [] - # Is there a HaloTouch to attach? - try: - hs = hsmap[c] - except KeyError: - hs = None - index = 0 for it0, it1 in zip(c.itintervals, maybe_reusable): if it0 != it1: break d = it0.dim - if needs_nodehalo(d, hs): + if needs_nodehalo(d, c.halo_scheme): break index += 1 @@ -66,12 +57,15 @@ def stree_build(clusters, profiler=None, **kwargs): mapper[it0].top.properties, c.properties[it0.dim] ) - if reuse_subtree(c, prev, d): + if reuse_whole_subtree(c, prev, d): tip = mapper[it0].bottom + elif reuse_partial_subtree(c, prev, d): + tip = mapper[it0].middle + tip = augment_partial_subtree(c, tip, mapper, it0) + break else: tip = mapper[it0].top - tip = augment_subtree(c, d, tip) - mapper[it0].bottom = tip + tip = augment_whole_subtree(c, tip, mapper, it0) break # Nested sub-trees, instead, will not be used anymore @@ -84,13 +78,12 @@ def stree_build(clusters, profiler=None, **kwargs): d = it.dim tip = NodeIteration(c.ispace.project([d]), tip, c.properties.get(d, ())) mapper[it].top = tip - tip = augment_subtree(c, d, tip) - mapper[it].bottom = tip + tip = augment_whole_subtree(c, tip, mapper, it) 
# Attach NodeHalo if necessary for it, v in mapper.items(): - if needs_nodehalo(it.dim, hs): - v.bottom.parent = NodeHalo(hs, v.bottom.parent) + if needs_nodehalo(it.dim, c.halo_scheme): + v.bottom.parent = NodeHalo(c.halo_scheme, v.bottom.parent) break # Add in NodeExprs @@ -132,52 +125,85 @@ def stree_build(clusters, profiler=None, **kwargs): return stree -# *** Utility functions to construct the ScheduleTree +# *** Utilities to construct the ScheduleTree +base = IterationInterval(Interval(None), [], Any) -def preprocess(clusters): + +def preprocess(clusters, options=None, **kwargs): """ - Remove the HaloTouches from `clusters` and create a mapping associating + Remove the HaloTouch's from `clusters` and create a mapping associating each removed HaloTouch to the first Cluster necessitating it. """ - processed = [] - hsmap = defaultdict(list) - queue = [] - + processed = [] for c in clusters: if c.is_halo_touch: - queue.append(HaloScheme.union(e.rhs.halo_scheme for e in c.exprs)) + hs = HaloScheme.union(e.rhs.halo_scheme for e in c.exprs) + queue.append(c.rebuild(halo_scheme=hs)) else: dims = set(c.ispace.promote(lambda d: d.is_Block).itdimensions) - for hs in list(queue): - if hs.distributed_aindices & dims: - queue.remove(hs) - hsmap[c].append(hs) + found = [] + for c1 in list(queue): + distributed_aindices = c1.halo_scheme.distributed_aindices + + diff = dims - distributed_aindices + intersection = dims & distributed_aindices + + if all(c1.guards.get(d) == c.guards.get(d) for d in diff) and \ + len(intersection) > 0: + found.append(c1) + queue.remove(c1) + + syncs = normalize_syncs(c.syncs, *[c1.syncs for c1 in found]) + halo_scheme = HaloScheme.union([c1.halo_scheme for c1 in found]) + + processed.append(c.rebuild(syncs=syncs, halo_scheme=halo_scheme)) + + # Sanity check! 
+ try: + assert not queue + except AssertionError: + if options['mpi']: + raise RuntimeError("Unsupported MPI for the given equations") - processed.append(c) + return processed - hsmap = {c: HaloScheme.union(hss) for c, hss in hsmap.items()} - return processed, hsmap +def reuse_partial_subtree(c0, c1, d=None): + return c0.guards.get(d) == c1.guards.get(d) -def reuse_subtree(c0, c1, d=None): +def reuse_whole_subtree(c0, c1, d=None): return (c0.guards.get(d) == c1.guards.get(d) and c0.syncs.get(d) == c1.syncs.get(d)) -def augment_subtree(cluster, d, tip): - if d in cluster.guards: - tip = NodeConditional(cluster.guards[d], tip) +def augment_partial_subtree(cluster, tip, mapper, it=None): + d = it.dim + if d in cluster.syncs: tip = NodeSync(cluster.syncs[d], tip) + + mapper[it].bottom = tip + return tip +def augment_whole_subtree(cluster, tip, mapper, it): + d = it.dim + + if d in cluster.guards: + tip = NodeConditional(cluster.guards[d], tip) + + mapper[it].middle = mapper[it].bottom = tip + + return augment_partial_subtree(cluster, tip, mapper, it) + + def needs_nodehalo(d, hs): - return hs and d._defines.intersection(hs.distributed_aindices) + return d and hs and d._defines.intersection(hs.distributed_aindices) def reuse_section(candidate, section): diff --git a/devito/ir/support/properties.py b/devito/ir/support/properties.py index d0fefc37df..f4ab873575 100644 --- a/devito/ir/support/properties.py +++ b/devito/ir/support/properties.py @@ -175,7 +175,9 @@ def affine(self, dims): m[d] = set(self.get(d, [])) | {AFFINE} return Properties(m) - def sequentialize(self, dims): + def sequentialize(self, dims=None): + if dims is None: + dims = list(self) m = dict(self) for d in as_tuple(dims): m[d] = normalize_properties(set(self.get(d, [])), {SEQUENTIAL}) diff --git a/devito/ir/support/utils.py b/devito/ir/support/utils.py index e0bcd8d956..5f08f48020 100644 --- a/devito/ir/support/utils.py +++ b/devito/ir/support/utils.py @@ -1,13 +1,12 @@ from collections import 
defaultdict -from devito.symbolics import (CallFromPointer, retrieve_indexed, retrieve_terminals, - uxreplace) +from devito.symbolics import CallFromPointer, retrieve_indexed, retrieve_terminals from devito.tools import DefaultOrderedDict, as_tuple, flatten, filter_sorted, split from devito.types import (Dimension, DimensionTuple, Indirection, ModuloDimension, StencilDimension) __all__ = ['AccessMode', 'Stencil', 'IMask', 'detect_accesses', 'detect_io', - 'pull_dims', 'shift_back', 'sdims_min', 'sdims_max'] + 'pull_dims', 'sdims_min', 'sdims_max'] class AccessMode(object): @@ -270,26 +269,6 @@ def pull_dims(exprs, flag=True): return dims -def shift_back(objects): - """ - Translate the Indexeds in a given collection of objects along all - Dimensions by the left-halo. - - This is useful to create input for recursive compilation from pre-existing, - and therefore already lowered/indefixied, input. - """ - processed = [] - for o in as_tuple(objects): - subs = {} - for i in retrieve_indexed(o): - subs[i] = i.subs({v: v - s.left - for v, s in zip(i.indices, i.function._size_halo)}) - - processed.append(uxreplace(o, subs)) - - return processed - - # *** Utility functions for expressions that potentially contain StencilDimensions def sdims_min(expr): diff --git a/devito/mpi/halo_scheme.py b/devito/mpi/halo_scheme.py index 47f0df29ba..d6c9f8933f 100644 --- a/devito/mpi/halo_scheme.py +++ b/devito/mpi/halo_scheme.py @@ -120,6 +120,10 @@ def union(self, halo_schemes): """ Create a new HaloScheme from the union of a set of HaloSchemes. 
""" + halo_schemes = [hs for hs in halo_schemes if hs is not None] + if not halo_schemes: + return None + fmapper = {} honored = {} for i in as_tuple(halo_schemes): diff --git a/devito/operator/profiling.py b/devito/operator/profiling.py index 80251aa135..207e44351f 100644 --- a/devito/operator/profiling.py +++ b/devito/operator/profiling.py @@ -225,7 +225,7 @@ def summary(self, args, dtype, reduce_over=None): # Number of FLOPs performed try: ops = int(subs_op_args(data.ops, args)) - except TypeError: + except (AttributeError, TypeError): # E.g., a section comprising just function calls, or at least # a sequence of unrecognized or non-conventional expr statements ops = np.nan @@ -236,7 +236,7 @@ def summary(self, args, dtype, reduce_over=None): # Compulsory traffic traffic = float(subs_op_args(data.traffic, args)*dtype().itemsize) - except TypeError: + except (AttributeError, TypeError): # E.g., the section has a dynamic loop size points = np.nan diff --git a/devito/passes/clusters/asynchrony.py b/devito/passes/clusters/asynchrony.py index 343ce6f89b..cb874d5295 100644 --- a/devito/passes/clusters/asynchrony.py +++ b/devito/passes/clusters/asynchrony.py @@ -323,7 +323,7 @@ def __init__(self, drop=False, syncs=None, insert=None): def is_memcpy(cluster): """ - True if `cluster` emulates a memcpy involving a mapped Array, False otherwise. + True if `cluster` emulates a memcpy involving an Array, False otherwise. 
""" if len(cluster.exprs) != 1: return False @@ -333,5 +333,4 @@ def is_memcpy(cluster): if not (a.is_Indexed and b.is_Indexed): return False - return ((a.function.is_Array and a.function._mem_mapped) or - (b.function.is_Array and b.function._mem_mapped)) + return a.function.is_Array or b.function.is_Array diff --git a/devito/passes/clusters/buffering.py b/devito/passes/clusters/buffering.py index 1b8025305f..e5826397db 100644 --- a/devito/passes/clusters/buffering.py +++ b/devito/passes/clusters/buffering.py @@ -101,11 +101,22 @@ def callback(f): return None assert callable(callback) + v0 = kwargs.get('opt_init_onread', True) + if callable(v0): + init_onread = v0 + else: + init_onread = lambda f: v0 + v1 = kwargs.get('opt_init_onwrite', False) + if callable(v1): + init_onwrite = v1 + else: + init_onwrite = lambda f: v1 + options = { 'buf-async-degree': options['buf-async-degree'], 'buf-fuse-tasks': options['fuse-tasks'], - 'buf-init-onread': kwargs.get('opt_init_onread', True), - 'buf-init-onwrite': kwargs.get('opt_init_onwrite', False), + 'buf-init-onread': init_onread, + 'buf-init-onwrite': init_onwrite, 'buf-callback': kwargs.get('opt_buffer'), } @@ -133,71 +144,66 @@ def callback(self, clusters, prefix, cache=None): return clusters d = prefix[-1].dim + try: + pd = prefix[-2].dim + except IndexError: + pd = None # Locate all Function accesses within the provided `clusters` accessmap = AccessMapper(clusters) - # Create the buffers - buffers = BufferBatch() - for f, accessv in accessmap.items(): - # Has a buffer already been produced for `f`? - if f in cache: - continue + init_onread = self.options['buf-init-onread'] + init_onwrite = self.options['buf-init-onwrite'] + # Create and initialize buffers + init = [] + buffers = [] + for f, accessv in accessmap.items(): # Is `f` really a buffering candidate? 
dim = self.callback0(f) - if dim is None or d not in dim._defines: + if dim is None or not d._defines & dim._defines: continue - b = cache[f] = buffers.make(f, dim, accessv, self.options, self.sregistry) + b = Buffer(f, dim, d, accessv, cache, self.options, self.sregistry) + buffers.append(b) - if not buffers: - return clusters - - try: - pd = prefix[-2].dim - except IndexError: - pd = None - - # Create Eqs to initialize buffers. Note: a buffer needs to be initialized - # only if the buffered Function is read in at least one place or in the case - # of non-uniform SubDimensions, to avoid uninitialized values to be copied-back - # into the buffered Function - init_onread = self.options['buf-init-onread'] - init_onwrite = self.options['buf-init-onwrite'] - init = [] - for b in buffers: if b.is_read or not b.has_uniform_subdims: # Special case: avoid initialization if not strictly necessary # See docstring for more info about what this implies - if b.size == 1 and not init_onread: + if b.size == 1 and not init_onread(b.function): continue dims = b.function.dimensions - lhs = b.indexed[[b.initmap.get(d, Map(d, d)).b for d in dims]] - rhs = b.function[[b.initmap.get(d, Map(d, d)).f for d in dims]] + lhs = b.indexed[dims]._subs(dim, b.firstidx.b) + rhs = b.function[dims]._subs(dim, b.firstidx.f) - elif b.is_write and init_onwrite: + elif b.is_write and init_onwrite(b.function): dims = b.buffer.dimensions lhs = b.buffer.indexify() rhs = 0 else: + # NOTE: a buffer must be initialized only if the buffered + # Function is read in at least one place or in the case of + # non-uniform SubDimensions, to avoid uninitialized values to + # be copied-back into the buffered Function continue expr = lower_exprs(Eq(lhs, rhs)) ispace = b.writeto - guards = {pd: GuardBound(d.root.symbolic_min, d.root.symbolic_max) - for d in b.contraction_mapper} + guards = {pd: GuardBound(dim.root.symbolic_min, dim.root.symbolic_max)} properties = {d: {AFFINE, PARALLEL} for d in ispace.itdimensions} 
init.append(Cluster(expr, ispace, guards=guards, properties=properties)) + if not buffers: + return clusters + # Substitution rules to replace buffered Functions with buffers subs = {} for b in buffers: for a in b.accessv.accesses: - subs[a] = b.indexed[[b.index_mapper_flat.get(i, i) for i in a.indices]] + subs[a] = b.indexed[[b.index_mapper.get(i, i) for i in a.indices]] processed = [] for c in clusters: @@ -213,8 +219,8 @@ def callback(self, clusters, prefix, cache=None): continue dims = b.function.dimensions - lhs = b.indexed[[b.lastmap.get(d, Map(d, d)).b for d in dims]] - rhs = b.function[[b.lastmap.get(d, Map(d, d)).f for d in dims]] + lhs = b.indexed[dims]._subs(b.dim, b.lastidx.b) + rhs = b.function[dims]._subs(b.dim, b.lastidx.f) expr = lower_exprs(Eq(lhs, rhs)) ispace = b.readfrom @@ -245,8 +251,8 @@ def callback(self, clusters, prefix, cache=None): continue dims = b.function.dimensions - lhs = b.function[[b.lastmap.get(d, Map(d, d)).f for d in dims]] - rhs = b.indexed[[b.lastmap.get(d, Map(d, d)).b for d in dims]] + lhs = b.function[dims]._subs(b.dim, b.lastidx.f) + rhs = b.indexed[dims]._subs(b.dim, b.lastidx.b) expr = lower_exprs(uxreplace(Eq(lhs, rhs), b.subdims_mapper)) ispace = b.written @@ -276,12 +282,10 @@ def callback(self, clusters, prefix, cache=None): else: continue - contracted = set().union(*[d._defines for d in b.contraction_mapper]) - processed1 = [] for c in processed: if b.buffer in c.functions: - key1 = lambda d: d not in contracted + key1 = lambda d: d not in b.dim._defines dims = c.ispace.project(key1).itdimensions ispace = c.ispace.lift(dims, key0()) processed1.append(c.rebuild(ispace=ispace)) @@ -292,24 +296,6 @@ def callback(self, clusters, prefix, cache=None): return init + processed -class BufferBatch(list): - - def __init__(self): - super().__init__() - - def make(self, *args): - """ - Create a Buffer. See Buffer.__doc__. 
- """ - b = Buffer(*args) - self.append(b) - return b - - @property - def functions(self): - return {b.function for b in self} - - class Buffer(object): """ @@ -319,27 +305,31 @@ class Buffer(object): ---------- function : DiscreteFunction The object for which the buffer is created. + dim : Dimension + The Dimension in `function` to be replaced with a ModuloDimension. d : Dimension - The Dimension in `function` to be contracted, that is to be replaced - with a ModuloDimension. + The iteration Dimension from which `function` was extracted. accessv : AccessValue All accesses involving `function`. + cache : dict + Mapper between buffered Functions and previously created Buffers. options : dict, optional The compilation options. See `buffering.__doc__`. sregistry : SymbolRegistry The symbol registry, to create unique names for buffers and Dimensions. """ - def __init__(self, function, d, accessv, options, sregistry): + def __init__(self, function, dim, d, accessv, cache, options, sregistry): # Parse compilation options async_degree = options['buf-async-degree'] callback = options['buf-callback'] self.function = function + self.dim = dim + self.d = d self.accessv = accessv - self.contraction_mapper = {} - self.index_mapper = defaultdict(dict) + self.index_mapper = {} self.sub_iterators = defaultdict(list) self.subdims_mapper = DefaultOrderedDict(set) @@ -347,12 +337,12 @@ def __init__(self, function, d, accessv, options, sregistry): # E.g., `u[time,x] + u[time+1,x] -> `ub[sb0,x] + ub[sb1,x]`, where `sb0` # and `sb1` are ModuloDimensions starting at `time` and `time+1` respectively dims = list(function.dimensions) - assert d in function.dimensions + assert dim in function.dimensions # Determine the buffer size, and therefore the span of the ModuloDimension, # along the contracting Dimension `d` - indices = filter_ordered(i.indices[d] for i in accessv.accesses) - slots = [i.subs({d: 0, d.spacing: 1}) for i in indices] + indices = filter_ordered(i.indices[dim] for i in 
accessv.accesses) + slots = [i.subs({dim: 0, dim.spacing: 1}) for i in indices] try: size = max(slots) - min(slots) + 1 except TypeError: @@ -371,16 +361,16 @@ def __init__(self, function, d, accessv, options, sregistry): else: size = async_degree - # Replace `d` with a suitable CustomDimension `bd` + # Create `xd` -- a contraction Dimension for `dim` try: - bd = sregistry.get('bds', (d, size)) + xd = sregistry.get('xds', (dim, size)) except KeyError: name = sregistry.make_name(prefix='db') - v = CustomDimension(name, 0, size-1, size, d) - bd = sregistry.setdefault('bds', (d, size), v) - self.contraction_mapper[d] = dims[dims.index(d)] = bd + v = CustomDimension(name, 0, size-1, size, dim) + xd = sregistry.setdefault('xds', (dim, size), v) + self.xd = dims[dims.index(dim)] = xd - # Finally create the ModuloDimensions as children of `bd` + # Finally create the ModuloDimensions as children of `xd` if size > 1: # Note: indices are sorted so that the semantic order (sb0, sb1, sb2) # follows SymPy's index ordering (time, time-1, time+1) after modulo @@ -391,37 +381,38 @@ def __init__(self, function, d, accessv, options, sregistry): key=lambda i: -np.inf if i - p == 0 else (i - p)) for i in indices: try: - md = sregistry.get('mds', (bd, i)) + md = sregistry.get('mds', (xd, i)) except KeyError: name = sregistry.make_name(prefix='sb') - v = ModuloDimension(name, bd, i, size) - md = sregistry.setdefault('mds', (bd, i), v) - self.index_mapper[d][i] = md + v = ModuloDimension(name, xd, i, size) + md = sregistry.setdefault('mds', (xd, i), v) + self.index_mapper[i] = md self.sub_iterators[d.root].append(md) else: assert len(indices) == 1 - self.index_mapper[d][indices[0]] = 0 + self.index_mapper[indices[0]] = 0 # Track the SubDimensions used to index into `function` for e in accessv.mapper: m = {i.root: i for i in e.free_symbols if isinstance(i, Dimension) and (i.is_Sub or not i.is_Derived)} - for d, v in m.items(): - self.subdims_mapper[d].add(v) + for d0, d1 in m.items(): + 
self.subdims_mapper[d0].add(d1) if any(len(v) > 1 for v in self.subdims_mapper.values()): # Non-uniform SubDimensions. At this point we're going to raise # an exception. It's either illegal or still unsupported for v in self.subdims_mapper.values(): for d0, d1 in combinations(v, 2): if d0.overlap(d1): - raise InvalidOperator("Cannot apply `buffering` to `%s` as it " - "is accessed over the overlapping " - " SubDimensions `<%s, %s>`" % - (function, d0, d1)) + raise InvalidOperator( + "Cannot apply `buffering` to `%s` as it is accessed " + "over the overlapping SubDimensions `<%s, %s>`" % + (function, d0, d1)) raise NotImplementedError("`buffering` does not support multiple " "non-overlapping SubDimensions yet.") else: - self.subdims_mapper = {d: v.pop() for d, v in self.subdims_mapper.items()} + self.subdims_mapper = {d0: v.pop() + for d0, v in self.subdims_mapper.items()} # Build and sanity-check the buffer IterationIntervals self.itintervals_mapper = {} @@ -429,37 +420,39 @@ def __init__(self, function, d, accessv, options, sregistry): for i in e.ispace.itintervals: v = self.itintervals_mapper.setdefault(i.dim, i.args) if v != self.itintervals_mapper[i.dim]: - raise NotImplementedError("Cannot apply `buffering` as the buffered " - "function `%s` is accessed over multiple, " - "non-compatible iteration spaces along the " - "Dimension `%s`" % (function.name, i.dim)) + raise NotImplementedError( + "Cannot apply `buffering` as the buffered function `%s` is " + "accessed over multiple, non-compatible iteration spaces " + "along the Dimension `%s`" % (function.name, i.dim)) # Also add IterationIntervals for initialization along `x`, should `xi` be # the only written Dimension in the `x` hierarchy - for d, (interval, _, _) in list(self.itintervals_mapper.items()): - for i in d._defines: + for d0, (interval, _, _) in list(self.itintervals_mapper.items()): + for i in d0._defines: self.itintervals_mapper.setdefault(i, (interval.relaxed, (), Forward)) # Finally create the 
actual buffer - kwargs = { - 'name': sregistry.make_name(prefix='%sb' % function.name), - 'dimensions': dims, - 'dtype': function.dtype, - 'halo': function.halo, - 'space': 'mapped', - 'mapped': function - } - try: - self.buffer = callback(function, **kwargs) - except TypeError: - self.buffer = Array(**kwargs) + if function in cache: + self.buffer = cache[function] + else: + kwargs = { + 'name': sregistry.make_name(prefix='%sb' % function.name), + 'dimensions': dims, + 'dtype': function.dtype, + 'halo': function.halo, + 'space': 'mapped', + 'mapped': function + } + try: + self.buffer = cache[function] = callback(function, **kwargs) + except TypeError: + self.buffer = cache[function] = Array(**kwargs) def __repr__(self): - return "Buffer[%s,<%s>]" % (self.buffer.name, - ','.join(str(i) for i in self.contraction_mapper)) + return "Buffer[%s,<%s>]" % (self.buffer.name, self.xd) @property def size(self): - return np.prod([v.symbolic_size for v in self.contraction_mapper.values()]) + return self.xd.symbolic_size @property def firstread(self): @@ -493,13 +486,6 @@ def has_uniform_subdims(self): def indexed(self): return self.buffer.indexed - @cached_property - def index_mapper_flat(self): - ret = {} - for mapper in self.index_mapper.values(): - ret.update(mapper) - return ret - @cached_property def writeto(self): """ @@ -516,8 +502,7 @@ def writeto(self): # in principle this could be accessed through a stencil interval = interval.translate(v0=-h.left, v1=h.right) except KeyError: - # E.g., the contraction Dimension `db0` - assert d in self.contraction_mapper.values() + assert d is self.xd interval, si, direction = Interval(d, 0, 0), (), Forward intervals.append(interval) sub_iterators[d] = si @@ -537,13 +522,13 @@ def written(self): intervals = [] sub_iterators = {} directions = {} - for dd in self.function.dimensions: + for dd in self.buffer.dimensions: d = dd.xreplace(self.subdims_mapper) try: interval, si, direction = self.itintervals_mapper[d] except KeyError: # 
E.g., d=time_sub - assert d.is_NonlinearDerived + assert d.is_NonlinearDerived or d.is_Custom d = d.root interval, si, direction = self.itintervals_mapper[d] intervals.append(interval) @@ -561,11 +546,8 @@ def readfrom(self): The `readfrom` IterationSpace, that is the iteration space that must be iterated over to update the buffer with the buffered Function values. """ - cdims = set().union(*[d._defines - for d in flatten(self.contraction_mapper.items())]) - - ispace0 = self.written.project(lambda d: d in cdims) - ispace1 = self.writeto.project(lambda d: d not in cdims) + ispace0 = self.written.project(lambda d: d in self.xd._defines) + ispace1 = self.writeto.project(lambda d: d not in self.xd._defines) extra = (ispace0.itdimensions + ispace1.itdimensions,) ispace = IterationSpace.union(ispace0, ispace1, relations=extra) @@ -573,49 +555,40 @@ def readfrom(self): return ispace @cached_property - def lastmap(self): + def lastidx(self): """ - A mapper from contracted Dimensions to a 2-tuple of indices representing, - respectively, the "last" write to the buffer and the "last" read from the - buffered Function. For example, `{time: (sb1, time+1)}`. + A 2-tuple of indices representing, respectively, the "last" write to the + buffer and the "last" read from the buffered Function. For example, + `(sb1, time+1)` in the case of a forward-propagating `time` Dimension. 
""" - mapper = {} - for d, m in self.index_mapper.items(): - try: - func = max if self.written.directions[d.root] is Forward else min - v = func(m) - except TypeError: - func = vmax if self.written.directions[d.root] is Forward else vmin - v = func(*[Vector(i) for i in m])[0] - mapper[d] = Map(m[v], v) + try: + func = max if self.written[self.d].direction is Forward else min + v = func(self.index_mapper) + except TypeError: + func = vmax if self.written[self.d].direction is Forward else vmin + v = func(*[Vector(i) for i in self.index_mapper])[0] - return mapper + return Map(self.index_mapper[v], v) @cached_property - def initmap(self): + def firstidx(self): """ - A mapper from contracted Dimensions to indices representing the min points - for buffer initialization. For example, in the case of a forward-propagating - `time` Dimension, we could have `{time: (time_m + db0) % 2, (time_m + db0)}`; - likewise, for backwards, `{time: (time_M - 2 + db0) % 4, time_M - 2 + db0}`. + A 2-tuple of indices representing the min points for buffer initialization. + For example, in the case of a forward-propagating `time` Dimension, we + could have `((time_m + db0) % 2, (time_m + db0))`; likewise, for + backwards, `((time_M - 2 + db0) % 4, time_M - 2 + db0)`. """ - mapper = {} - for d, bd in self.contraction_mapper.items(): - indices = list(self.index_mapper[d]) - - # The buffer is initialized at `d_m(d_M) - offset`. E.g., a buffer with - # six slots, used to replace a buffered Function accessed at `d-3`, `d` - # and `d + 2`, will have `offset = 3` - p, offset = offset_from_centre(d, indices) + # The buffer is initialized at `d_m(d_M) - offset`. 
E.g., a buffer with + # six slots, used to replace a buffered Function accessed at `d-3`, `d` + # and `d + 2`, will have `offset = 3` + p, offset = offset_from_centre(self.dim, list(self.index_mapper)) - if self.written.directions[d.root] is Forward: - v = p.subs(d.root, d.root.symbolic_min) - offset + bd - else: - v = p.subs(d.root, d.root.symbolic_max) - offset + bd - - mapper[d] = Map(v % bd.symbolic_size, v) + if self.written[self.dim].direction is Forward: + v = p._subs(self.dim.root, self.dim.root.symbolic_min) - offset + self.xd + else: + v = p._subs(self.dim.root, self.dim.root.symbolic_max) - offset + self.xd - return mapper + return Map(v % self.xd.symbolic_size, v) class AccessValue(object): diff --git a/devito/passes/iet/asynchrony.py b/devito/passes/iet/asynchrony.py index 7e1ff6e36b..1350091aca 100644 --- a/devito/passes/iet/asynchrony.py +++ b/devito/passes/iet/asynchrony.py @@ -7,6 +7,7 @@ Conditional, Dereference, DummyExpr, FindNodes, FindSymbols, Iteration, List, PointerCast, Return, ThreadCallable, Transformer, While, maybe_alias) +from devito.passes.iet.definitions import DataManager from devito.passes.iet.engine import iet_pass from devito.symbolics import (CondEq, CondNe, FieldFromComposite, FieldFromPointer, Null) @@ -22,6 +23,7 @@ def pthreadify(graph, **kwargs): lower_async_callables(graph, track=track, root=graph.root, **kwargs) lower_async_calls(graph, track=track, **kwargs) + DataManager(**kwargs).place_definitions(graph) @iet_pass @@ -45,10 +47,11 @@ def lower_async_callables(iet, track=None, root=None, sregistry=None): # The `cfields` are the constant fields, that is the fields whose value # definitely never changes across different executions of `ìet`; the # `ncfields` are instead the non-constant fields, that is the fields whose - # value may or may not change across different calls to `iet` + # value may or may not change across different calls to `iet`. 
Clearly objects + # passed by pointer don't really matter fields = iet.parameters defines = FindSymbols('defines').visit(root.body) - ncfields, cfields = split(fields, lambda i: i in defines) + ncfields, cfields = split(fields, lambda i: i in defines and i.is_Symbol) # Postprocess `ncfields` ncfields = sanitize_ncfields(ncfields) diff --git a/devito/passes/iet/definitions.py b/devito/passes/iet/definitions.py index 43139d667a..e4195c95c6 100644 --- a/devito/passes/iet/definitions.py +++ b/devito/passes/iet/definitions.py @@ -326,7 +326,7 @@ def place_definitions(self, iet, globs=None, **kwargs): # Process all other definitions, essentially all temporary objects # created by the compiler up to this point (Array, LocalObject, etc.) storage = Storage() - defines = FindSymbols('defines-aliases').visit(iet) + defines = FindSymbols('defines-aliases|globals').visit(iet) for i in FindSymbols().visit(iet): if i in defines: @@ -343,7 +343,7 @@ def place_definitions(self, iet, globs=None, **kwargs): self._alloc_mapped_array_on_high_bw_mem(iet, i, storage) elif i._mem_stack: self._alloc_array_on_low_lat_mem(iet, i, storage) - else: + elif globs is not None: # Track, to be handled by the EntryFunction being a global obj! globs.add(i) elif i.is_ObjectArray: @@ -384,10 +384,15 @@ def place_casts(self, iet, **kwargs): # (ii) Declaring a raw pointer, e.g., `float * r0 = NULL; *malloc(&(r0), ...) defines = set(FindSymbols('defines|globals').visit(iet)) bases = sorted({i.base for i in indexeds}, key=lambda i: i.name) + + # Some objects don't distinguish their _C_symbol because they are known, + # by construction, not to require it, thus making the generated code + # cleaner. 
These objects don't need a cast + bases = [i for i in bases if i.name != i.function._C_name] + + # Create and attach the type casts casts = tuple(self.lang.PointerCast(i.function, obj=i) for i in bases if i not in defines) - - # Incorporate the newly created casts if casts: iet = iet._rebuild(body=iet.body._rebuild(casts=casts + iet.body.casts)) diff --git a/devito/passes/iet/engine.py b/devito/passes/iet/engine.py index 72ef3f324e..c29c625122 100644 --- a/devito/passes/iet/engine.py +++ b/devito/passes/iet/engine.py @@ -407,6 +407,11 @@ def update_args(root, efuncs, dag): defines = FindSymbols('defines').visit(root.body) drop_args = [a for a in root.parameters if a in defines] + # 3) removed a symbol that was previously necessary (e.g., `x_size` after + # linearization) + symbols = FindSymbols('basics').visit(root.body) + drop_args.extend(a for a in root.parameters if a.is_Symbol and a not in symbols) + if not (new_args or drop_args): return efuncs diff --git a/devito/passes/iet/linearization.py b/devito/passes/iet/linearization.py index d861408bff..9866e755b7 100644 --- a/devito/passes/iet/linearization.py +++ b/devito/passes/iet/linearization.py @@ -186,8 +186,8 @@ def linearize_accesses(iet, key0, tracker=None, sregistry=None, options=None, # 5) Attach `stmts0` and `stmts1` to `iet` if stmts0: assert len(stmts1) > 0 - stmts = filter_ordered(stmts0) + [BlankLine] + stmts1 + [BlankLine] - body = iet.body._rebuild(body=tuple(stmts) + iet.body.body) + stmts = filter_ordered(stmts0) + [BlankLine] + stmts1 + body = iet.body._rebuild(strides=stmts) iet = iet._rebuild(body=body) else: assert len(stmts1) == 0 diff --git a/devito/passes/iet/mpi.py b/devito/passes/iet/mpi.py index 7f237b0489..ebafdfb460 100644 --- a/devito/passes/iet/mpi.py +++ b/devito/passes/iet/mpi.py @@ -14,7 +14,7 @@ @iet_pass -def optimize_halospots(iet): +def optimize_halospots(iet, **kwargs): """ Optimize the HaloSpots in ``iet``. 
HaloSpots may be dropped, merged and moved around in order to improve the halo exchange performance. @@ -22,7 +22,7 @@ def optimize_halospots(iet): iet = _drop_halospots(iet) iet = _hoist_halospots(iet) iet = _merge_halospots(iet) - iet = _drop_if_unwritten(iet) + iet = _drop_if_unwritten(iet, **kwargs) iet = _mark_overlappable(iet) return iet, {} @@ -39,9 +39,9 @@ def _drop_halospots(iet): # If all HaloSpot reads pertain to reductions, then the HaloSpot is useless for hs, expressions in MapNodes(HaloSpot, Expression).visit(iet).items(): - for f in hs.fmapper: - scope = Scope([i.expr for i in expressions]) - if all(i.is_reduction for i in scope.reads.get(f, [])): + scope = Scope([i.expr for i in expressions]) + for f, v in scope.reads.items(): + if f in hs.fmapper and all(i.is_reduction for i in v): mapper[hs].add(f) # Transform the IET introducing the "reduced" HaloSpots @@ -73,7 +73,7 @@ def rule1(dep, candidates, loc_dims): # A reduction isn't a stopper to hoisting return dep.write is not None and dep.write.is_reduction - hoist_rules = [rule0, rule1] + rules = [rule0, rule1] # Precompute scopes to save time scopes = {i: Scope([e.expr for e in v]) for i, v in MapNodes().visit(iet).items()} @@ -92,13 +92,16 @@ def rule1(dep, candidates, loc_dims): for n, i in enumerate(iters): candidates = [i.dim._defines for i in iters[n:]] - test = True + all_candidates = set().union(*candidates) + reads = scopes[i].getreads(f) + if any(set(a.ispace.dimensions) & all_candidates + for a in reads): + continue + for dep in scopes[i].d_flow.project(f): - if any(rule(dep, candidates, loc_dims) for rule in hoist_rules): - continue - test = False - break - if test: + if not any(r(dep, candidates, loc_dims) for r in rules): + break + else: hsmapper[hs] = hsmapper[hs].drop(f) imapper[i].append(hs.halo_scheme.project(f)) break @@ -148,7 +151,7 @@ def rule2(dep, hs, loc_indices): return any(dep.distance_mapper[d] == 0 and dep.source[d] is not v for d, v in loc_indices.items()) - 
merge_rules = [rule0, rule1, rule2] + rules = [rule0, rule1, rule2] # Analysis mapper = {} @@ -165,13 +168,10 @@ def rule2(dep, hs, loc_indices): mapper[hs] = hs.halo_scheme for f, v in hs.fmapper.items(): - test = True for dep in scope.d_flow.project(f): - if any(rule(dep, hs, v.loc_indices) for rule in merge_rules): - continue - test = False - break - if test: + if not any(r(dep, hs, v.loc_indices) for r in rules): + break + else: try: mapper[hs0] = HaloScheme.union([mapper[hs0], hs.halo_scheme.project(f)]) @@ -191,21 +191,27 @@ def rule2(dep, hs, loc_indices): return iet -def _drop_if_unwritten(iet): +def _drop_if_unwritten(iet, options=None, **kwargs): """ Drop HaloSpots for unwritten Functions. Notes ----- - This may be relaxed if Devito+MPI were to be used within existing - legacy codes, which would call the generated library directly. + This may be relaxed if Devito were to be used within existing legacy codes, + which would call the generated library directly. """ + drop_unwritten = options['dist-drop-unwritten'] + if not callable(drop_unwritten): + key = lambda f: drop_unwritten + else: + key = drop_unwritten + # Analysis writes = {i.write for i in FindNodes(Expression).visit(iet)} mapper = {} for hs in FindNodes(HaloSpot).visit(iet): for f in hs.fmapper: - if f not in writes: + if f not in writes and key(f): mapper[hs] = mapper.get(hs, hs.halo_scheme).drop(f) # Post-process analysis @@ -321,7 +327,7 @@ def mpiize(graph, **kwargs): options = kwargs['options'] if options['optcomms']: - optimize_halospots(graph) + optimize_halospots(graph, **kwargs) mpimode = options['mpi'] if mpimode: diff --git a/devito/passes/iet/parpragma.py b/devito/passes/iet/parpragma.py index 819246a844..eccb506e7b 100644 --- a/devito/passes/iet/parpragma.py +++ b/devito/passes/iet/parpragma.py @@ -471,6 +471,11 @@ def _make_parallel(self, iet): if not candidates: continue + # Ignore if already a ParallelIteration (e.g., by-product of + # recursive compilation) + if any(isinstance(n, 
ParallelIteration) for n in candidates): + continue + # Outer parallelism root, partree = self._make_partree(candidates, index=i) if partree is None or root in mapper: diff --git a/devito/symbolics/extended_sympy.py b/devito/symbolics/extended_sympy.py index 879f6b7250..87a56c13b7 100644 --- a/devito/symbolics/extended_sympy.py +++ b/devito/symbolics/extended_sympy.py @@ -513,8 +513,7 @@ class DefFunction(Function, Pickable): https://github.com/sympy/sympy/issues/4297 """ - __rargs__ = ('name',) - __rkwargs__ = ('arguments',) + __rargs__ = ('name', 'arguments') def __new__(cls, name, arguments=None, **kwargs): _arguments = [] @@ -550,6 +549,8 @@ def __str__(self): def _sympystr(self, printer): return str(self) + func = Pickable._rebuild + # Pickling support __reduce_ex__ = Pickable.__reduce_ex__ diff --git a/devito/symbolics/manipulation.py b/devito/symbolics/manipulation.py index eb35c68e91..b4ea829b70 100644 --- a/devito/symbolics/manipulation.py +++ b/devito/symbolics/manipulation.py @@ -2,11 +2,11 @@ from collections.abc import Iterable from functools import singledispatch -from sympy import Pow, Add, Mul, Min, Max, SympifyError, sympify +from sympy import Pow, Add, Mul, Min, Max, SympifyError, Tuple, sympify from sympy.core.add import _addsort from sympy.core.mul import _mulsort -from devito.symbolics.extended_sympy import rfunc +from devito.symbolics.extended_sympy import DefFunction, rfunc from devito.symbolics.queries import q_leaf from devito.symbolics.search import retrieve_indexed, retrieve_functions from devito.tools import as_list, as_tuple, flatten, split, transitive_closure @@ -22,7 +22,7 @@ def uxreplace(expr, rule): """ - An alternative to SymPy's ``xreplace`` for when the caller can guarantee + An alternative to SymPy's `xreplace` for when the caller can guarantee that no re-evaluations are necessary or when re-evaluations should indeed be avoided at all costs (e.g., to prevent SymPy from unpicking Devito transformations, such as factorization). 
@@ -33,12 +33,15 @@ def uxreplace(expr, rule): By avoiding re-evaluations, this function is typically much quicker than SymPy's xreplace. - A further feature of ``uxreplace`` consists of enabling the substitution - of compound nodes. Consider the expression `a*b*c*d`; if one wants to replace - `b*c` with say `e*f`, then the following mapper may be passed: - `{a*b*c*d: {b: e*f, c: None}}`. This way, only the `b` and `c` pertaining to - `a*b*c*d` will be affected, and in particular `c` will be dropped, while `b` - will be replaced by `e*f`, thus obtaining `a*d*e*f`. + A further feature of `uxreplace` is the support for compound nodes. + Consider the expression `a*b*c*d`; if one wants to replace `b*c` with say + `e*f`, then the following mapper may be passed: `{a*b*c*d: {b: e*f, c: None}}`. + This way, only the `b` and `c` pertaining to `a*b*c*d` will be affected, and + in particular `c` will be dropped, while `b` will be replaced by `e*f`, thus + obtaining `a*d*e*f`. + + Finally, `uxreplace` supports Reconstructable objects, that is, it searches + for replacement opportunities inside the Reconstructable's `__rkwargs__`. 
""" return _uxreplace(expr, rule)[0] @@ -66,35 +69,73 @@ def _uxreplace(expr, rule): changed = False if rule: - for a in eargs: - try: - ax, flag = _uxreplace(a, rule) - args.append(ax) - changed |= flag - except AttributeError: - # E.g., un-sympified numbers - args.append(a) + eargs, flag = _uxreplace_dispatch(eargs, rule) + args.extend(eargs) + changed |= flag + + # If a Reconstructable object, we need to parse the kwargs as well + if _uxreplace_registry.dispatchable(expr): + v = {i: getattr(expr, i) for i in expr.__rkwargs__} + kwargs, flag = _uxreplace_dispatch(v, rule) + else: + kwargs, flag = {}, False + changed |= flag + if changed: - return _uxreplace_handle(expr, args), True + return _uxreplace_handle(expr, args, kwargs), True return expr, False @singledispatch -def _uxreplace_handle(expr, args): +def _uxreplace_dispatch(unknown, rule): + return unknown, False + + +@_uxreplace_dispatch.register(Basic) +def _(expr, rule): + return _uxreplace(expr, rule) + + +@_uxreplace_dispatch.register(tuple) +@_uxreplace_dispatch.register(Tuple) +@_uxreplace_dispatch.register(list) +def _(iterable, rule): + ret = [] + changed = False + for a in iterable: + ax, flag = _uxreplace(a, rule) + ret.append(ax) + changed |= flag + return iterable.__class__(ret), changed + + +@_uxreplace_dispatch.register(dict) +def _(mapper, rule): + ret = {} + changed = False + for k, v in mapper.items(): + vx, flag = _uxreplace_dispatch(v, rule) + ret[k] = vx + changed |= flag + return ret, changed + + +@singledispatch +def _uxreplace_handle(expr, args, kwargs): return expr.func(*args) @_uxreplace_handle.register(Min) @_uxreplace_handle.register(Max) @_uxreplace_handle.register(Pow) -def _(expr, args): +def _(expr, args, kwargs): evaluate = all(i.is_Number for i in args) return expr.func(*args, evaluate=evaluate) @_uxreplace_handle.register(Add) -def _(expr, args): +def _(expr, args, kwargs): if all(i.is_commutative for i in args): _addsort(args) _eval_numbers(expr, args) @@ -104,7 +145,7 @@ 
def _(expr, args): @_uxreplace_handle.register(Mul) -def _(expr, args): +def _(expr, args, kwargs): if all(i.is_commutative for i in args): _mulsort(args) _eval_numbers(expr, args) @@ -113,21 +154,39 @@ def _(expr, args): return expr._new_rawargs(*args) -@_uxreplace_handle.register(ComponentAccess) -@_uxreplace_handle.register(Eq) -def _(expr, args): - # Handler for all other Reconstructable objects - kwargs = {i: getattr(expr, i) for i in expr.__rkwargs__} +def _uxreplace_handle_reconstructable(expr, args, kwargs): return expr.func(*args, **kwargs) -def _eval_numbers(expr, args): +class UxreplaceRegistry(list): + """ - Helper function for in-place reduction of the expr arguments. + A registry used by `uxreplace` to handle Reconstructable objects. These + may differ from canonical SymPy objects since: + + * They carry one or more fields (the so called "__rkwargs__") in addition + to the classic SymPy arguments. + * An `__rkwargs__`, in turn, may or may not be a SymPy object. + + The user may then use this registry to register callbacks to be used when + one such Reconstructable object is encountered. 
""" - numbers, others = split(args, lambda i: i.is_Number) - if len(numbers) > 1: - args[:] = [expr.func(*numbers)] + others + + def register(self, cls, rkwargs_callback_mapper=None): + self.append(cls) + _uxreplace_handle.register(cls, _uxreplace_handle_reconstructable) + + for kls, callback in (rkwargs_callback_mapper or {}).items(): + _uxreplace_dispatch.register(kls, callback) + + def dispatchable(self, obj): + return isinstance(obj, tuple(self)) + + +_uxreplace_registry = UxreplaceRegistry() +_uxreplace_registry.register(Eq) +_uxreplace_registry.register(DefFunction) +_uxreplace_registry.register(ComponentAccess) class Uxmapper(dict): @@ -203,6 +262,15 @@ def xreplace_indices(exprs, mapper, key=None): return replaced if isinstance(exprs, Iterable) else replaced[0] +def _eval_numbers(expr, args): + """ + Helper function for in-place reduction of the expr arguments. + """ + numbers, others = split(args, lambda i: i.is_Number) + if len(numbers) > 1: + args[:] = [expr.func(*numbers)] + others + + def pow_to_mul(expr): if q_leaf(expr) or isinstance(expr, Basic): return expr diff --git a/tests/test_gpu_common.py b/tests/test_gpu_common.py index 89ee0e43dd..616c916059 100644 --- a/tests/test_gpu_common.py +++ b/tests/test_gpu_common.py @@ -2,12 +2,15 @@ import numpy as np import scipy.sparse +from conftest import assert_structure from devito import (Constant, Eq, Inc, Grid, Function, ConditionalDimension, MatrixSparseTimeFunction, SparseTimeFunction, SubDimension, - SubDomain, SubDomainSet, TimeFunction, Operator, configuration) + SubDomain, SubDomainSet, TimeFunction, Operator, configuration, + switchconfig) from devito.arch import get_gpu_info from devito.exceptions import InvalidArgument -from devito.ir import Expression, Section, FindNodes, FindSymbols, retrieve_iteration_tree +from devito.ir import (Conditional, Expression, Section, FindNodes, FindSymbols, + retrieve_iteration_tree) from devito.passes.iet.languages.openmp import OmpIteration from devito.types 
import DeviceID, DeviceRM, Lock, NPThreads, PThreadArray @@ -1094,6 +1097,36 @@ def test_streaming_split_noleak(self): assert np.all(u.data[0] == u1.data[0]) assert np.all(u.data[1] == u1.data[1]) + @pytest.mark.skip(reason="Unsupported MPI + .dx when streaming backwards") + @pytest.mark.parallel(mode=4) + @switchconfig(safe_math=True) # Or NVC will crash + def test_streaming_w_mpi(self): + nt = 5 + grid = Grid(shape=(16, 16)) + + u = TimeFunction(name='u', grid=grid) + usave = TimeFunction(name='usave', grid=grid, save=nt, space_order=4) + vsave = TimeFunction(name='vsave', grid=grid, save=nt, space_order=4) + vsave1 = TimeFunction(name='vsave', grid=grid, save=nt, space_order=4) + + eqns = [Eq(u.backward, u + 1.), + Eq(vsave, usave.dx2)] + + key = lambda f: f is not usave + + op0 = Operator(eqns, opt='noop') + op1 = Operator(eqns, opt=('buffering', 'streaming', 'orchestrate', + {'dist-drop-unwritten': key, + 'gpu-fit': [vsave]})) + + for i in range(nt): + usave.data[i] = i + + op0.apply() + op1.apply(vsave=vsave1) + + assert np.allclose(vsave.data, vsave1.data, rtol=1.e-5) + @pytest.mark.parametrize('opt,opt_options,gpu_fit', [ (('buffering', 'streaming', 'orchestrate'), {}, False), (('buffering', 'streaming', 'orchestrate'), {'linearize': True}, False) @@ -1208,6 +1241,40 @@ def test_place_transfers(self): assert 'update from(u' not in str(op) assert 'map(release: u' not in str(op) + def test_fuse_compatible_guards(self): + nt = 10 + grid = Grid(shape=(8, 8)) + time_dim = grid.time_dim + + factor = Constant(name='factor', value=2, dtype=np.int32) + time_sub = ConditionalDimension(name="time_sub", parent=time_dim, factor=factor) + + f = TimeFunction(name='f', grid=grid) + fsave = TimeFunction(name='fsave', grid=grid, + save=int(nt//factor.data), time_dim=time_sub) + gsave = TimeFunction(name='gsave', grid=grid, + save=int(nt//factor.data), time_dim=time_sub) + + eqns = [Eq(f.forward, f + 1.), + Eq(fsave, f.forward), + Eq(gsave, f.forward)] + + op = Operator(eqns, 
opt=('buffering', 'tasking', 'orchestrate', + {'gpu-fit': [gsave]})) + + op.apply(time_M=nt-1) + + assert all(np.all(fsave.data[i] == 2*i + 1) for i in range(fsave.save)) + assert all(np.all(gsave.data[i] == 2*i + 1) for i in range(gsave.save)) + + # Check generated code + assert_structure(op, ['t,x,y', 't', 't,x,y', 't,x,y'], + 't,x,y,x,y,x,y') + nodes = FindNodes(Conditional).visit(op) + assert len(nodes) == 2 + assert len(nodes[1].then_body) == 3 + assert len(retrieve_iteration_tree(nodes[1])) == 2 + class TestAPI(object): diff --git a/tests/test_iet.py b/tests/test_iet.py index 1d44e4b3ff..eb6ddcc7d7 100644 --- a/tests/test_iet.py +++ b/tests/test_iet.py @@ -9,6 +9,8 @@ from devito.ir.iet import (Call, Callable, Conditional, DummyExpr, Iteration, List, Lambda, ElementalFunction, CGen, FindSymbols, filter_iterations, make_efunc, retrieve_iteration_tree) +from devito.ir import SymbolRegistry +from devito.passes.iet.engine import Graph from devito.passes.iet.languages.C import CDataManager from devito.symbolics import Byref, FieldFromComposite, InlineIf, Macro from devito.tools import as_tuple @@ -326,3 +328,25 @@ def test_templates(): { u(x, y) = 1; }""" + + +def test_codegen_quality0(): + grid = Grid(shape=(4, 4, 4)) + _, y, z = grid.dimensions + + a = Array(name='a', dimensions=grid.dimensions) + + expr = DummyExpr(a.indexed, 1) + foo = Callable('foo', expr, 'void', + parameters=[a, y.symbolic_size, z.symbolic_size]) + + # Emulate what the compiler would do + graph = Graph(foo) + + CDataManager(sregistry=SymbolRegistry()).process(graph) + + foo1 = graph.root + + assert len(foo.parameters) == 3 + assert len(foo1.parameters) == 1 + assert foo1.parameters[0] is a diff --git a/tests/test_linearize.py b/tests/test_linearize.py index 8919df3b44..45347bf605 100644 --- a/tests/test_linearize.py +++ b/tests/test_linearize.py @@ -298,9 +298,9 @@ def test_strides_forwarding0(): foo = graph.root bar = graph.efuncs['bar'] - assert foo.body.body[0].write.name == 'y_fsz0' - 
assert foo.body.body[2].write.name == 'y_stride0' - assert len(foo.body.body[4].arguments) == 2 + assert foo.body.strides[0].write.name == 'y_fsz0' + assert foo.body.strides[2].write.name == 'y_stride0' + assert len(foo.body.body[0].arguments) == 2 assert len(bar.parameters) == 2 assert bar.parameters[1].name == 'y_stride0' @@ -333,9 +333,10 @@ def test_strides_forwarding1(): assert len(foo.body.body) == 1 assert foo.body.body[0].is_Call - assert len(bar.body.body) == 5 - assert bar.body.body[0].write.name == 'y_fsz0' - assert bar.body.body[2].write.name == 'y_stride0' + assert len(bar.body.body) == 1 + assert len(bar.body.strides) == 3 + assert bar.body.strides[0].write.name == 'y_fsz0' + assert bar.body.strides[2].write.name == 'y_stride0' def test_strides_forwarding2(): @@ -380,9 +381,9 @@ def test_strides_forwarding2(): assert all(i.is_Call for i in root.body.body) for foo in [foo0, foo1]: - assert foo.body.body[0].write.name == 'y_fsz0' - assert foo.body.body[2].write.name == 'y_stride0' - assert len(foo.body.body[4].arguments) == 2 + assert foo.body.strides[0].write.name == 'y_fsz0' + assert foo.body.strides[2].write.name == 'y_stride0' + assert len(foo.body.body[0].arguments) == 2 for bar in [bar0, bar1]: assert len(bar.parameters) == 2 @@ -414,10 +415,10 @@ def test_strides_forwarding3(): root = graph.root bar = graph.efuncs['bar'] - assert root.body.body[0].write.name == 'y_fsz0' - assert root.body.body[0].write.dtype is np.int64 - assert root.body.body[2].write.name == 'y_stride0' - assert root.body.body[2].write.dtype is np.int64 + assert root.body.strides[0].write.name == 'y_fsz0' + assert root.body.strides[0].write.dtype is np.int64 + assert root.body.strides[2].write.name == 'y_stride0' + assert root.body.strides[2].write.dtype is np.int64 assert bar.parameters[1].name == 'y_stride0' @@ -445,9 +446,9 @@ def test_strides_forwarding4(): root = graph.root bar = graph.efuncs['bar'] - assert root.body.body[0].write.name == 'y_fsz0' - assert 
root.body.body[2].write.name == 'y_stride0' - assert root.body.body[4].arguments[1].name == 'y_stride0' + assert root.body.strides[0].write.name == 'y_fsz0' + assert root.body.strides[2].write.name == 'y_stride0' + assert root.body.body[0].arguments[1].name == 'y_stride0' assert bar.parameters[1].name == 'y_stride0' @@ -518,8 +519,8 @@ def test_call_retval_indexed(): foo = graph.root - assert foo.body.body[0].write.name == 'y_fsz0' - assert foo.body.body[2].write.name == 'y_stride0' + assert foo.body.strides[0].write.name == 'y_fsz0' + assert foo.body.strides[2].write.name == 'y_stride0' assert str(foo.body.body[-1]) == 'vL0(x, y) = bar(f);' @@ -549,8 +550,8 @@ def test_bundle(): assert f not in bar.parameters assert g not in bar.parameters - assert foo.body.body[0].write.name == 'y_fsz0' - y_stride0 = foo.body.body[2].write + assert foo.body.strides[0].write.name == 'y_fsz0' + y_stride0 = foo.body.strides[2].write assert y_stride0.name == 'y_stride0' assert y_stride0 in bar.parameters diff --git a/tests/test_mpi.py b/tests/test_mpi.py index 965ca64055..985ab09592 100644 --- a/tests/test_mpi.py +++ b/tests/test_mpi.py @@ -1347,6 +1347,26 @@ def test_many_functions(self): assert len(calls) == 2 assert calls[0].ncomps == 7 + @pytest.mark.parallel(mode=1) + def test_enforce_haloupdate_if_unwritten_function(self): + grid = Grid(shape=(16, 16)) + + u = TimeFunction(name='u', grid=grid) + v = TimeFunction(name='v', grid=grid) + w = TimeFunction(name='w', grid=grid) + usave = TimeFunction(name='usave', grid=grid, save=10, space_order=4) + + eqns = [Eq(w.forward, v.forward.dx + w + 1., subdomain=grid.interior), + Eq(u.forward, u + 1.), + Eq(v.forward, u.forward + usave.dx4, subdomain=grid.interior)] + + key = lambda f: f is not usave + + op = Operator(eqns, opt=('advanced', {'dist-drop-unwritten': key})) + + calls = FindNodes(Call).visit(op) + assert len(calls) == 2 # One for `v` and one for `usave` + class TestOperatorAdvanced(object): diff --git a/tests/test_symbolics.py 
b/tests/test_symbolics.py index 8015bb403a..807ed35216 100644 --- a/tests/test_symbolics.py +++ b/tests/test_symbolics.py @@ -9,8 +9,9 @@ Operator, SubDimension, norm, Le, Ge, Gt, Lt, Abs, sin, cos, Min, Max) from devito.ir import Expression, FindNodes from devito.symbolics import (retrieve_functions, retrieve_indexed, evalrel, # noqa - CallFromPointer, Cast, FieldFromPointer, INT, - FieldFromComposite, IntDiv, ccode, uxreplace) + CallFromPointer, Cast, DefFunction, FieldFromPointer, + INT, FieldFromComposite, IntDiv, ccode, uxreplace) +from devito.tools import as_tuple from devito.types import Array, Bundle, LocalObject, Object, Symbol as dSymbol @@ -368,6 +369,36 @@ def test_uxreplace(expr, subs, expected): assert uxreplace(eval(expr), eval(subs)) == eval(expected) +def test_uxreplace_custom_reconstructable(): + + class MyDefFunction(DefFunction): + __rargs__ = ('name', 'arguments') + __rkwargs__ = ('p0', 'p1', 'p2') + + def __new__(cls, name=None, arguments=None, p0=None, p1=None, p2=None): + obj = super().__new__(cls, name=name, arguments=arguments) + obj.p0 = p0 + obj.p1 = as_tuple(p1) + obj.p2 = p2 + return obj + + grid = Grid(shape=(4, 4)) + + f = Function(name='f', grid=grid) + g = Function(name='g', grid=grid) + + func = MyDefFunction(name='foo', arguments=f.indexify(), + p0=f, p1=f, p2='bar') + + mapper = {f: g, f.indexify(): g.indexify()} + func1 = uxreplace(func, mapper) + + assert func1.arguments == (g.indexify(),) + assert func1.p0 is g + assert func1.p1 == (g,) + assert func1.p2 == 'bar' + + def test_minmax(): grid = Grid(shape=(5, 5)) x, y = grid.dimensions