Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

compiler: Misc compiler tweaks and improvements #2136

Merged
merged 18 commits into from
May 31, 2023
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions devito/core/cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ def _normalize_kwargs(cls, **kwargs):
o['par-dynamic-work'] = oo.pop('par-dynamic-work', cls.PAR_DYNAMIC_WORK)
o['par-nested'] = oo.pop('par-nested', cls.PAR_NESTED)

# Distributed parallelism
o['dist-drop-unwritten'] = oo.pop('dist-drop-unwritten', cls.DIST_DROP_UNWRITTEN)

# Misc
o['expand'] = oo.pop('expand', cls.EXPAND)
o['optcomms'] = oo.pop('optcomms', True)
Expand Down
3 changes: 3 additions & 0 deletions devito/core/gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ def _normalize_kwargs(cls, **kwargs):
o['gpu-fit'] = as_tuple(oo.pop('gpu-fit', cls._normalize_gpu_fit(**kwargs)))
o['gpu-create'] = as_tuple(oo.pop('gpu-create', ()))

# Distributed parallelism
o['dist-drop-unwritten'] = oo.pop('dist-drop-unwritten', cls.DIST_DROP_UNWRITTEN)

# Misc
o['expand'] = oo.pop('expand', cls.EXPAND)
o['optcomms'] = oo.pop('optcomms', True)
Expand Down
16 changes: 11 additions & 5 deletions devito/core/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,12 @@ class BasicOperator(Operator):
The supported MPI modes.
"""

DIST_DROP_UNWRITTEN = True
"""
Drop halo exchanges for read-only Function, even in presence of
stencil-like data accesses.
"""

INDEX_MODE = "int64"
"""
The type of the expression used to compute array indices. Either `int64`
Expand Down Expand Up @@ -281,7 +287,7 @@ def _specialize_iet(cls, graph, **kwargs):
# from HaloSpot optimization)
# Note that if MPI is disabled then this pass will act as a no-op
if 'mpi' not in passes:
passes_mapper['mpi'](graph)
passes_mapper['mpi'](graph, **kwargs)

# Run passes
applied = []
Expand All @@ -300,17 +306,17 @@ def _specialize_iet(cls, graph, **kwargs):
if 'init' not in passes:
passes_mapper['init'](graph)

# Enforce pthreads if CPU-GPU orchestration requested
if 'orchestrate' in passes and 'pthreadify' not in passes:
passes_mapper['pthreadify'](graph, sregistry=sregistry)

# Symbol definitions
cls._Target.DataManager(**kwargs).process(graph)

# Linearize n-dimensional Indexeds
if 'linearize' not in passes and options['linearize']:
passes_mapper['linearize'](graph)

# Enforce pthreads if CPU-GPU orchestration requested
if 'orchestrate' in passes and 'pthreadify' not in passes:
passes_mapper['pthreadify'](graph, sregistry=sregistry)

return graph


Expand Down
22 changes: 17 additions & 5 deletions devito/ir/clusters/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
DataSpace, Guards, Properties, Scope, detect_accesses,
detect_io, normalize_properties, normalize_syncs,
sdims_min, sdims_max)
from devito.mpi.halo_scheme import HaloTouch
from devito.mpi.halo_scheme import HaloScheme, HaloTouch
from devito.symbolics import estimate_cost
from devito.tools import as_tuple, flatten, frozendict, infer_dtype

Expand All @@ -26,7 +26,7 @@ class Cluster(object):
exprs : expr-like or list of expr-like
An ordered sequence of expressions computing a tensor.
ispace : IterationSpace, optional
The cluster iteration space.
The Cluster iteration space.
guards : dict, optional
Mapper from Dimensions to expr-like, representing the conditions under
which the Cluster should be computed.
Expand All @@ -37,9 +37,12 @@ class Cluster(object):
Mapper from Dimensions to lists of SyncOps, that is ordered sequences of
synchronization operations that must be performed in order to compute the
Cluster asynchronously.
halo_scheme : HaloScheme, optional
The halo exchanges required by the Cluster.
"""

def __init__(self, exprs, ispace=None, guards=None, properties=None, syncs=None):
def __init__(self, exprs, ispace=None, guards=None, properties=None, syncs=None,
halo_scheme=None):
ispace = ispace or IterationSpace([])

self._exprs = tuple(ClusterizedEq(e, ispace=ispace) for e in as_tuple(exprs))
Expand All @@ -57,6 +60,8 @@ def __init__(self, exprs, ispace=None, guards=None, properties=None, syncs=None)
properties = properties.drop(d)
self._properties = properties

self._halo_scheme = halo_scheme

def __repr__(self):
return "Cluster([%s])" % ('\n' + ' '*9).join('%s' % i for i in self.exprs)

Expand Down Expand Up @@ -91,7 +96,9 @@ def from_clusters(cls, *clusters):
raise ValueError("Cannot build a Cluster from Clusters with "
"non-compatible synchronization operations")

return Cluster(exprs, ispace, guards, properties, syncs)
halo_scheme = HaloScheme.union([c.halo_scheme for c in clusters])

return Cluster(exprs, ispace, guards, properties, syncs, halo_scheme)

def rebuild(self, *args, **kwargs):
"""
Expand All @@ -110,7 +117,8 @@ def rebuild(self, *args, **kwargs):
ispace=kwargs.get('ispace', self.ispace),
guards=kwargs.get('guards', self.guards),
properties=kwargs.get('properties', self.properties),
syncs=kwargs.get('syncs', self.syncs))
syncs=kwargs.get('syncs', self.syncs),
halo_scheme=kwargs.get('halo_scheme', self.halo_scheme))

@property
def exprs(self):
Expand Down Expand Up @@ -144,6 +152,10 @@ def properties(self):
def syncs(self):
return self._syncs

@property
def halo_scheme(self):
return self._halo_scheme

@cached_property
def free_symbols(self):
return set().union(*[e.free_symbols for e in self.exprs])
Expand Down
26 changes: 21 additions & 5 deletions devito/ir/iet/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,10 +288,13 @@ def functions(self):
retval.append(s.function)
except AttributeError:
continue

if self.base is not None:
retval.append(self.base.function)

if self.retobj is not None:
retval.append(self.retobj.function)

return tuple(filter_ordered(retval))

@cached_property
Expand All @@ -309,10 +312,15 @@ def expr_symbols(self):
retval.extend(i.free_symbols)
except AttributeError:
pass

if self.base is not None:
retval.append(self.base)
if self.retobj is not None:

if isinstance(self.retobj, Indexed):
retval.extend(self.retobj.free_symbols)
elif self.retobj is not None:
retval.append(self.retobj)

return tuple(filter_ordered(retval))

@property
Expand Down Expand Up @@ -744,6 +752,8 @@ class CallableBody(Node):
maps : Transfer or list of Transfer, optional
Data maps for `body` (a data map may e.g. trigger a data transfer from
host to device).
strides : list of Nodes, optional
Statements defining symbols used to access linearized arrays.
objs : list of Definitions, optional
Object definitions for `body`.
unmaps : Transfer or list of Transfer, optional
Expand All @@ -756,19 +766,21 @@ class CallableBody(Node):

is_CallableBody = True

_traversable = ['unpacks', 'init', 'allocs', 'casts', 'bundles', 'maps', 'objs',
'body', 'unmaps', 'unbundles', 'frees']
_traversable = ['unpacks', 'init', 'allocs', 'casts', 'bundles', 'maps',
'strides', 'objs', 'body', 'unmaps', 'unbundles', 'frees']

def __init__(self, body, init=(), unpacks=(), allocs=(), casts=(),
def __init__(self, body, init=(), unpacks=(), strides=(), allocs=(), casts=(),
bundles=(), objs=(), maps=(), unmaps=(), unbundles=(), frees=()):
# Sanity check
assert not isinstance(body, CallableBody), "CallableBody's cannot be nested"

self.body = as_tuple(body)
self.init = as_tuple(init)

self.unpacks = as_tuple(unpacks)
self.init = as_tuple(init)
self.allocs = as_tuple(allocs)
self.casts = as_tuple(casts)
self.strides = as_tuple(strides)
self.bundles = as_tuple(bundles)
self.maps = as_tuple(maps)
self.objs = as_tuple(objs)
Expand Down Expand Up @@ -1201,6 +1213,10 @@ def __getattr__(self, name):
def functions(self):
return as_tuple(self.nthreads)

@property
def expr_symbols(self):
return as_tuple(self.nthreads)

@property
def root(self):
return self.body[0]
Expand Down
Loading