Skip to content

Commit

Permalink
Merge pull request #2376 from devitocodes/revamp-async-codegen-final
Browse files Browse the repository at this point in the history
compiler: Revamp code generation for asynchronous operations
  • Loading branch information
FabioLuporini authored Jun 18, 2024
2 parents 210f899 + 0decbb2 commit be7c403
Show file tree
Hide file tree
Showing 54 changed files with 2,094 additions and 1,587 deletions.
1 change: 1 addition & 0 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,7 @@ def pytest_runtest_call(item):

elif item.get_closest_marker("parallel"):
# Spawn parallel processes to run test

outcome = parallel(item, item.funcargs['mode'])
if outcome:
pytest.skip(f"{item} success in parallel")
Expand Down
22 changes: 20 additions & 2 deletions devito/core/cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from devito.core.operator import CoreOperator, CustomOperator, ParTile
from devito.exceptions import InvalidOperator
from devito.operator.operator import rcompile
from devito.passes import stream_dimensions
from devito.passes.equations import collect_derivatives
from devito.passes.clusters import (Lift, blocking, buffering, cire, cse,
factorize, fission, fuse, optimize_pows,
Expand Down Expand Up @@ -92,6 +94,22 @@ def _normalize_kwargs(cls, **kwargs):

return kwargs

@classmethod
def _rcompile_wrapper(cls, **kwargs0):
    """
    Build a callable that forwards expressions to `rcompile`.

    The 'options' entry of `kwargs0` supplies the default options, while
    the remaining entries supply the default keyword arguments. Options
    and keyword arguments passed to the returned callable override these
    defaults.
    """
    default_options = kwargs0.pop('options')

    def wrapper(expressions, options=None, **kwargs1):
        # Call-time values win over the defaults captured above
        merged_options = dict(default_options)
        merged_options.update(options or {})

        merged_kwargs = dict(kwargs0)
        merged_kwargs.update(kwargs1)

        # User-provided openmp flag has precedence over defaults
        if not merged_options['openmp']:
            merged_kwargs['language'] = 'C'

        return rcompile(expressions, merged_kwargs, merged_options)

    return wrapper


# Mode level

Expand Down Expand Up @@ -240,9 +258,9 @@ def _make_clusters_passes_mapper(cls, **kwargs):

# Callback used by `buffering`; it mimics `is_on_device`, which is used
# on device backends
def callback(f):
def callback(f, *args):
if f.is_TimeFunction and f.save is not None:
return f.time_dim
return stream_dimensions(f)
else:
return None

Expand Down
79 changes: 32 additions & 47 deletions devito/core/gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
from devito.core.operator import CoreOperator, CustomOperator, ParTile
from devito.exceptions import InvalidOperator
from devito.operator.operator import rcompile
from devito.passes import is_on_device
from devito.passes import is_on_device, stream_dimensions
from devito.passes.equations import collect_derivatives
from devito.passes.clusters import (Lift, Streaming, Tasker, blocking, buffering,
cire, cse, factorize, fission, fuse,
from devito.passes.clusters import (Lift, tasking, memcpy_prefetch, blocking,
buffering, cire, cse, factorize, fission, fuse,
optimize_pows)
from devito.passes.iet import (DeviceOmpTarget, DeviceAccTarget, mpiize,
hoist_prodders, linearize, pthreadify,
Expand Down Expand Up @@ -116,11 +116,15 @@ def _normalize_gpu_fit(cls, oo, **kwargs):
return as_tuple(cls.GPU_FIT)

@classmethod
def _rcompile_wrapper(cls, **kwargs):
def wrapper(expressions, mode='default', **options):
def _rcompile_wrapper(cls, **kwargs0):
options0 = kwargs0.pop('options')

def wrapper(expressions, mode='default', options=None, **kwargs1):
options = {**options0, **(options or {})}
kwargs = {**kwargs0, **kwargs1}

if mode == 'host':
par_disabled = kwargs['options']['par-disabled']
par_disabled = options['par-disabled']
target = {
'platform': 'cpu64',
'language': 'C' if par_disabled else 'openmp',
Expand Down Expand Up @@ -266,15 +270,14 @@ def _make_clusters_passes_mapper(cls, **kwargs):
platform = kwargs['platform']
sregistry = kwargs['sregistry']

# Callbacks used by `buffering`, `Tasking` and `Streaming`
callback = lambda f: on_host(f, options)
runs_on_host, reads_if_on_host = make_callbacks(options)
callback = lambda f: not is_on_device(f, options['gpu-fit'])
stream_key = stream_wrap(callback)

return {
'buffering': lambda i: buffering(i, callback, sregistry, options),
'blocking': lambda i: blocking(i, sregistry, options),
'tasking': Tasker(runs_on_host, sregistry).process,
'streaming': Streaming(reads_if_on_host, sregistry).process,
'buffering': lambda i: buffering(i, stream_key, sregistry, options),
'tasking': lambda i: tasking(i, stream_key, sregistry),
'streaming': lambda i: memcpy_prefetch(i, stream_key, sregistry),
'factorize': factorize,
'fission': fission,
'fuse': lambda i: fuse(i, options=options),
Expand All @@ -294,7 +297,7 @@ def _make_iet_passes_mapper(cls, **kwargs):
sregistry = kwargs['sregistry']

parizer = cls._Target.Parizer(sregistry, options, platform, compiler)
orchestrator = cls._Target.Orchestrator(sregistry)
orchestrator = cls._Target.Orchestrator(**kwargs)

return {
'parallel': parizer.make_parallel,
Expand Down Expand Up @@ -419,39 +422,21 @@ def _make_iet_passes_mapper(cls, **kwargs):
assert not (set(_known_passes) & set(DeviceCustomOperator._known_passes_disabled))


# Utils

def on_host(f, options):
# A Dimension in `f` defining an IterationSpace that definitely
# gets executed on the host, regardless of whether it's parallel
# or sequential
if not is_on_device(f, options['gpu-fit']):
return f.time_dim
else:
return None


def make_callbacks(options, key=None):
"""
Options-dependent callbacks used by various compiler passes.
"""

if key is None:
key = lambda f: on_host(f, options)

def runs_on_host(c):
# The only situation in which a Cluster doesn't get offloaded to
# the device is when it writes to a host Function
retval = {key(f) for f in c.scope.writes} - {None}
retval = set().union(*[d._defines for d in retval])
return retval

def reads_if_on_host(c):
if not runs_on_host(c):
retval = {key(f) for f in c.scope.reads} - {None}
retval = set().union(*[d._defines for d in retval])
return retval
# *** Utils

def stream_wrap(callback):
def stream_key(items, *args):
"""
Given one or more Functions `f(d_1, ...d_n)`, return the Dimensions
`(d_i, ..., d_n)` requiring data streaming.
"""
found = [f for f in as_tuple(items) if callback(f)]
retval = {stream_dimensions(f) for f in found}
if len(retval) > 1:
raise ValueError("Cannot determine homogenous stream Dimensions")
elif len(retval) == 1:
return retval.pop()
else:
return set()
return None

return runs_on_host, reads_if_on_host
return stream_key
7 changes: 5 additions & 2 deletions devito/ir/clusters/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def clusterize(exprs, **kwargs):
clusters = Schedule().process(clusters)

# Handle SteppingDimensions
clusters = Stepper().process(clusters)
clusters = Stepper(**kwargs).process(clusters)

# Handle ConditionalDimensions
clusters = guard(clusters)
Expand Down Expand Up @@ -273,6 +273,9 @@ class Stepper(Queue):
sub-iterators induced by a SteppingDimension.
"""

def __init__(self, sregistry=None, **kwargs):
    """
    Record `sregistry` on the instance; `callback` later uses it to
    generate names via `make_name`.

    Any additional keyword argument is accepted and ignored.
    """
    self.sregistry = sregistry

def callback(self, clusters, prefix):
if not prefix:
return clusters
Expand Down Expand Up @@ -326,7 +329,7 @@ def callback(self, clusters, prefix):
siafs = sorted(iafs, key=key)

for iaf in siafs:
name = '%s%d' % (si.name, len(mds))
name = self.sregistry.make_name(prefix='t')
offset = uxreplace(iaf, {si: d.root})
mds.append(ModuloDimension(name, si, offset, size, origin=iaf))

Expand Down
22 changes: 15 additions & 7 deletions devito/ir/clusters/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@
from devito.ir.equations import ClusterizedEq
from devito.ir.support import (PARALLEL, PARALLEL_IF_PVT, BaseGuardBoundNext,
Forward, Interval, IntervalGroup, IterationSpace,
DataSpace, Guards, Properties, Scope, WithLock,
PrefetchUpdate, detect_accesses, detect_io,
DataSpace, Guards, Properties, Scope, WaitLock,
WithLock, PrefetchUpdate, detect_accesses, detect_io,
normalize_properties, normalize_syncs, minimum,
maximum, null_ispace)
from devito.mpi.halo_scheme import HaloScheme, HaloTouch
from devito.mpi.reduction_scheme import DistReduce
from devito.symbolics import estimate_cost
from devito.tools import as_tuple, flatten, frozendict, infer_dtype
from devito.tools import as_tuple, flatten, infer_dtype
from devito.types import WeakFence, CriticalRegion

__all__ = ["Cluster", "ClusterGroup"]
Expand Down Expand Up @@ -49,9 +49,8 @@ def __init__(self, exprs, ispace=null_ispace, guards=None, properties=None,
self._exprs = tuple(ClusterizedEq(e, ispace=ispace) for e in as_tuple(exprs))
self._ispace = ispace
self._guards = Guards(guards or {})
self._syncs = frozendict(syncs or {})
self._syncs = normalize_syncs(syncs or {})

# Normalize properties
properties = Properties(properties or {})
self._properties = tailor_properties(properties, ispace)

Expand Down Expand Up @@ -279,6 +278,15 @@ def is_async(self):
return any(isinstance(s, (WithLock, PrefetchUpdate))
for s in flatten(self.syncs.values()))

@property
def is_wait(self):
    """
    True if at least one of the Cluster's synchronization operations is
    a WaitLock (i.e., the Cluster waits on a lock), False otherwise.
    """
    for sync_op in flatten(self.syncs.values()):
        if isinstance(sync_op, WaitLock):
            return True
    return False

@cached_property
def dtype(self):
"""
Expand Down Expand Up @@ -341,9 +349,9 @@ def dspace(self):
if len(ret) != 1:
continue
if ret.pop().direction is Forward:
intervals = intervals.translate(d, v1=-1)
intervals = intervals.translate(d._defines, v1=-1)
else:
intervals = intervals.translate(d, 1)
intervals = intervals.translate(d._defines, 1)
for d in self.properties:
if self.properties.is_inbound(d):
intervals = intervals.zero(d._defines)
Expand Down
8 changes: 6 additions & 2 deletions devito/ir/iet/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,12 @@ def iet_build(stree):
body = Conditional(i.guard, queues.pop(i))

elif i.is_Iteration:
body = Iteration(queues.pop(i), i.dim, i.limits, direction=i.direction,
properties=i.properties, uindices=i.sub_iterators)
if i.dim.is_Virtual:
body = List(body=queues.pop(i))
else:
body = Iteration(queues.pop(i), i.dim, i.limits,
direction=i.direction, properties=i.properties,
uindices=i.sub_iterators)

elif i.is_Section:
body = Section('section%d' % nsections, body=queues.pop(i))
Expand Down
7 changes: 7 additions & 0 deletions devito/ir/iet/efunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,13 @@ def functions(self):
launch_args += (self.stream.function,)
return super().functions + launch_args

@cached_property
def expr_symbols(self):
    """
    The symbols reported by the parent class, followed by the launch
    configuration symbols: the grid, the block and, if one is set, the
    stream.
    """
    launch_symbols = [self.grid, self.block]
    if self.stream is not None:
        launch_symbols.append(self.stream)
    return super().expr_symbols + tuple(launch_symbols)


# Other relevant Callable subclasses

Expand Down
17 changes: 16 additions & 1 deletion devito/ir/iet/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from devito.ir.equations import DummyEq, OpInc, OpMin, OpMax
from devito.ir.support import (INBOUND, SEQUENTIAL, PARALLEL, PARALLEL_IF_ATOMIC,
PARALLEL_IF_PVT, VECTORIZED, AFFINE, Property,
Forward, detect_io)
Forward, WithLock, PrefetchUpdate, detect_io)
from devito.symbolics import ListInitializer, CallFromPointer, ccode
from devito.tools import (Signer, as_tuple, filter_ordered, filter_sorted, flatten,
ctypes_to_cstr)
Expand Down Expand Up @@ -1378,6 +1378,21 @@ def __init__(self, sync_ops, body=None):
def __repr__(self):
return "<SyncSpot (%s)>" % ",".join(str(i) for i in self.sync_ops)

@property
def is_async_op(self):
    """
    True if at least one of the sync operations is a WithLock or a
    PrefetchUpdate, that is an asynchronous operation; False otherwise.
    When False, the SyncSpot may for example represent a wait on a lock.
    """
    for sync_op in self.sync_ops:
        if isinstance(sync_op, (WithLock, PrefetchUpdate)):
            return True
    return False

@property
def functions(self):
    """
    The lock, function and target of every sync operation, collected in
    order, flattened, stripped of None entries and passed through
    `filter_ordered`; returned as a tuple.
    """
    triples = [(s.lock, s.function, s.target) for s in self.sync_ops]
    found = (f for f in flatten(triples) if f is not None)
    return tuple(filter_ordered(found))


class CBlankLine(List):

Expand Down
3 changes: 2 additions & 1 deletion devito/ir/iet/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,8 @@ def derive_parameters(iet, drop_locals=False, ordering='default'):

# Maybe filter out all other compiler-generated objects
if drop_locals:
parameters = [p for p in parameters if not (p.is_ArrayBasic or p.is_LocalObject)]
parameters = [p for p in parameters
if not (p.is_ArrayBasic or p.is_LocalObject)]

# NOTE: This is requested by the caller when the parameters are used to
# construct Callables whose signature only depends on the object types,
Expand Down
21 changes: 20 additions & 1 deletion devito/ir/iet/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -1293,7 +1293,26 @@ def visit_Pragma(self, o):
def visit_PragmaTransfer(self, o):
function = uxreplace(o.function, self.mapper)
arguments = [uxreplace(i, self.mapper) for i in o.arguments]
return o._rebuild(function=function, arguments=arguments)
if o.imask is None:
return o._rebuild(function=function, arguments=arguments)

# An `imask` may be None, a list of symbols/numbers, or a list of
# 2-tuples representing ranges
imask = []
for v in o.imask:
try:
i, j = v
imask.append((uxreplace(i, self.mapper),
uxreplace(j, self.mapper)))
except TypeError:
imask.append(uxreplace(v, self.mapper))
return o._rebuild(function=function, imask=imask, arguments=arguments)

def visit_ParallelTree(self, o):
    """
    Rebuild a ParallelTree after visiting its prefix and body; the
    `nthreads` entry is replaced by its mapper substitution, if any,
    otherwise it is kept as is.
    """
    rebuilt = {
        'prefix': self._visit(o.prefix),
        'body': self._visit(o.body),
        'nthreads': self.mapper.get(o.nthreads, o.nthreads),
    }
    return o._rebuild(**rebuilt)

def visit_HaloSpot(self, o):
hs = o.halo_scheme
Expand Down
Loading

0 comments on commit be7c403

Please sign in to comment.