Skip to content

Commit

Permalink
Merge pull request #1736 from devitocodes/fuse-withlocks
Browse files Browse the repository at this point in the history
compiler: Add optimization option to fuse WithLocks tasks
  • Loading branch information
FabioLuporini authored Sep 7, 2021
2 parents 0b84fe1 + f114164 commit 0483968
Show file tree
Hide file tree
Showing 9 changed files with 134 additions and 43 deletions.
7 changes: 5 additions & 2 deletions devito/core/cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ def _normalize_kwargs(cls, **kwargs):
# Buffering
o['buf-async-degree'] = oo.pop('buf-async-degree', None)

# Fusion
o['fuse-tasks'] = oo.pop('fuse-tasks', False)

# Blocking
o['blockinner'] = oo.pop('blockinner', False)
o['blocklevels'] = oo.pop('blocklevels', cls.BLOCK_LEVELS)
Expand Down Expand Up @@ -298,13 +301,13 @@ def callback(f):
'blocking': lambda i: blocking(i, options),
'factorize': factorize,
'fission': fission,
'fuse': fuse,
'fuse': lambda i: fuse(i, options=options),
'lift': lambda i: Lift().process(cire(i, 'invariants', sregistry,
options, platform)),
'cire-sops': lambda i: cire(i, 'sops', sregistry, options, platform),
'cse': lambda i: cse(i, sregistry),
'opt-pows': optimize_pows,
'topofuse': lambda i: fuse(i, toposort=True)
'topofuse': lambda i: fuse(i, toposort=True, options=options)
}

@classmethod
Expand Down
9 changes: 6 additions & 3 deletions devito/core/gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ def _normalize_kwargs(cls, **kwargs):
# Buffering
o['buf-async-degree'] = oo.pop('buf-async-degree', None)

# Fusion
o['fuse-tasks'] = oo.pop('fuse-tasks', False)

# Blocking
o['blockinner'] = oo.pop('blockinner', True)
o['blocklevels'] = oo.pop('blocklevels', cls.BLOCK_LEVELS)
Expand Down Expand Up @@ -148,7 +151,7 @@ def _specialize_clusters(cls, clusters, **kwargs):
sregistry = kwargs['sregistry']

# Toposort+Fusion (the former to expose more fusion opportunities)
clusters = fuse(clusters, toposort=True)
clusters = fuse(clusters, toposort=True, options=options)

# Fission to increase parallelism
clusters = fission(clusters)
Expand Down Expand Up @@ -245,13 +248,13 @@ def callback(f):
'streaming': Streaming(reads_if_on_host).process,
'factorize': factorize,
'fission': fission,
'fuse': fuse,
'fuse': lambda i: fuse(i, options=options),
'lift': lambda i: Lift().process(cire(i, 'invariants', sregistry,
options, platform)),
'cire-sops': lambda i: cire(i, 'sops', sregistry, options, platform),
'cse': lambda i: cse(i, sregistry),
'opt-pows': optimize_pows,
'topofuse': lambda i: fuse(i, toposort=True)
'topofuse': lambda i: fuse(i, toposort=True, options=options)
}

@classmethod
Expand Down
4 changes: 2 additions & 2 deletions devito/ir/iet/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
'MetaCall', 'PointerCast', 'ForeignExpression', 'HaloSpot', 'IterationTree',
'ExpressionBundle', 'AugmentedExpression', 'Increment', 'Return', 'While',
'ParallelIteration', 'ParallelBlock', 'Dereference', 'Lambda', 'SyncSpot',
'PragmaTransfer', 'DummyExpr', 'BlankLine', 'ParallelTree', 'BusyWait',
'CallableBody']
'Pragma', 'PragmaTransfer', 'DummyExpr', 'BlankLine', 'ParallelTree',
'BusyWait', 'CallableBody']

# First-class IET nodes

Expand Down
31 changes: 22 additions & 9 deletions devito/passes/clusters/misc.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from collections import Counter
from collections import Counter, defaultdict
from itertools import groupby, product

from devito.ir.clusters import Cluster, ClusterGroup, Queue
Expand Down Expand Up @@ -91,9 +91,13 @@ class Fusion(Queue):
Fuse Clusters with compatible IterationSpace.
"""

def __init__(self, toposort):
super(Fusion, self).__init__()
def __init__(self, toposort, options=None):
options = options or {}

self.toposort = toposort
self.fusetasks = options.get('fuse-tasks', False)

super().__init__()

def _make_key_hook(self, cgroup, level):
assert level > 0
Expand Down Expand Up @@ -137,15 +141,24 @@ def _key(self, c):

key = (frozenset(c.itintervals), c.guards)

# We allow fusing Clusters/ClusterGroups with WaitLocks over different Locks,
# while the WithLocks are to be kept separated (i.e. they remain separate tasks)
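# We allow fusing Clusters/ClusterGroups in the presence of WaitLocks and, if
# the `fuse-tasks` option is enabled, of WithLocks too (only their type enters
# the key); any other SyncOp enters the key as is, so it must match exactly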
if isinstance(c, Cluster):
sync_locks = (c.sync_locks,)
else:
sync_locks = c.sync_locks
for i in sync_locks:
key += (frozendict({k: frozenset(type(i) if i.is_WaitLock else i for i in v)
for k, v in i.items()}),)
mapper = defaultdict(set)
for k, v in i.items():
for s in v:
if s.is_WaitLock or \
(self.fusetasks and s.is_WithLock):
mapper[k].add(type(s))
else:
mapper[k].add(s)
mapper[k] = frozenset(mapper[k])
mapper = frozendict(mapper)
key += (mapper,)

return key

Expand Down Expand Up @@ -243,14 +256,14 @@ def _build_dag(self, cgroups, prefix):


@timed_pass()
def fuse(clusters, toposort=False):
def fuse(clusters, toposort=False, options=None):
"""
Clusters fusion.
If ``toposort=True``, then the Clusters are reordered to maximize the likelihood
of fusion; the new ordering is computed such that all data dependencies are honored.
"""
return Fusion(toposort=toposort).process(clusters)
return Fusion(toposort, options).process(clusters)


@cluster_pass(mode='all')
Expand Down
15 changes: 11 additions & 4 deletions devito/passes/iet/langbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,13 @@ def _map_present(cls, f, imask=None):
"""
raise NotImplementedError

@classmethod
def _map_wait(cls, queueid=None):
    """
    Explicitly wait on event.

    This base implementation always raises NotImplementedError.

    Parameters
    ----------
    queueid : optional
        NOTE(review): presumably the identifier of the asynchronous queue
        to wait on -- confirm against the overriding implementations.

    Raises
    ------
    NotImplementedError
        Always.
    """
    raise NotImplementedError

@classmethod
def _map_update(cls, f, imask=None):
"""
Expand All @@ -86,9 +93,9 @@ def _map_update_host(cls, f, imask=None, queueid=None):
raise NotImplementedError

@classmethod
def _map_update_wait_host(cls, f, imask=None, queueid=None):
def _map_update_host_async(cls, f, imask=None, queueid=None):
"""
Copy Function from device to host memory and explicitly wait.
Asynchronously copy Function from device to host memory.
"""
raise NotImplementedError

Expand All @@ -100,9 +107,9 @@ def _map_update_device(cls, f, imask=None, queueid=None):
raise NotImplementedError

@classmethod
def _map_update_wait_device(cls, f, imask=None, queueid=None):
def _map_update_device_async(cls, f, imask=None, queueid=None):
"""
Copy Function from host to device memory and explicitly wait.
Asynchronously copy Function from host to device memory.
"""
raise NotImplementedError

Expand Down
20 changes: 10 additions & 10 deletions devito/passes/iet/languages/openacc.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,18 +93,18 @@ class AccBB(PragmaLangBB):
c.Pragma('acc enter data create(%s%s)' % (i, j)),
'map-present': lambda i, j:
c.Pragma('acc data present(%s%s)' % (i, j)),
'map-wait': lambda i:
c.Pragma('acc wait(%s)' % i),
'map-update': lambda i, j:
c.Pragma('acc exit data copyout(%s%s)' % (i, j)),
'map-update-host': lambda i, j:
c.Pragma('acc update self(%s%s)' % (i, j)),
'map-update-wait-host': lambda i, j, k:
(c.Pragma('acc update self(%s%s) async(%s)' % (i, j, k)),
c.Pragma('acc wait(%s)' % k)),
'map-update-host-async': lambda i, j, k:
c.Pragma('acc update self(%s%s) async(%s)' % (i, j, k)),
'map-update-device': lambda i, j:
c.Pragma('acc update device(%s%s)' % (i, j)),
'map-update-wait-device': lambda i, j, k:
(c.Pragma('acc update device(%s%s) async(%s)' % (i, j, k)),
c.Pragma('acc wait(%s)' % k)),
'map-update-device-async': lambda i, j, k:
c.Pragma('acc update device(%s%s) async(%s)' % (i, j, k)),
'map-release': lambda i, j, k:
c.Pragma('acc exit data delete(%s%s)%s' % (i, j, k)),
'map-exit-delete': lambda i, j, k:
Expand Down Expand Up @@ -147,14 +147,14 @@ def _map_delete(cls, f, imask=None, devicerm=None):
return cls.mapper['map-exit-delete'](f.name, sections, cond)

@classmethod
def _map_update_wait_host(cls, f, imask=None, queueid=None):
def _map_update_host_async(cls, f, imask=None, queueid=None):
sections = cls._make_sections_from_imask(f, imask)
return cls.mapper['map-update-wait-host'](f.name, sections, queueid)
return cls.mapper['map-update-host-async'](f.name, sections, queueid)

@classmethod
def _map_update_wait_device(cls, f, imask=None, queueid=None):
def _map_update_device_async(cls, f, imask=None, queueid=None):
sections = cls._make_sections_from_imask(f, imask)
return cls.mapper['map-update-wait-device'](f.name, sections, queueid)
return cls.mapper['map-update-device-async'](f.name, sections, queueid)


class DeviceAccizer(PragmaDeviceAwareTransformer):
Expand Down
29 changes: 18 additions & 11 deletions devito/passes/iet/orchestration.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from devito.data import FULL
from devito.ir.iet import (Call, Callable, Conditional, List, SyncSpot, FindNodes,
Transformer, BlankLine, BusyWait, PragmaTransfer,
Transformer, BlankLine, BusyWait, Pragma, PragmaTransfer,
DummyExpr, derive_parameters, make_thread_ctx)
from devito.passes.iet.engine import iet_pass
from devito.passes.iet.langbase import LangBB
Expand Down Expand Up @@ -54,17 +54,21 @@ def _make_withlock(self, iet, sync_ops, pieces, root):
# will never be more than 2 threads in flight concurrently
npthreads = min(i.size for i in locks)

preactions = []
postactions = []
preactions = [BlankLine]
for s in sync_ops:
imask = [s.handle.indices[d] if d.root in s.lock.locked_dimensions else FULL
for d in s.target.dimensions]
update = PragmaTransfer(self.lang._map_update_wait_host, s.target,
update = PragmaTransfer(self.lang._map_update_host_async, s.target,
imask=imask, queueid=SharedData._field_id)
preactions.append(List(body=[BlankLine, update, DummyExpr(s.handle, 1)]))
postactions.append(DummyExpr(s.handle, 2))
preactions.append(update)
wait = self.lang._map_wait(SharedData._field_id)
if wait is not None:
preactions.append(Pragma(wait))
preactions.extend([DummyExpr(s.handle, 1) for s in sync_ops])
preactions.append(BlankLine)
postactions.insert(0, BlankLine)

postactions = [BlankLine]
postactions.extend([DummyExpr(s.handle, 2) for s in sync_ops])

# Turn `iet` into a ThreadFunction so that it can be executed
# asynchronously by a pthread in the `npthreads` pool
Expand Down Expand Up @@ -120,7 +124,7 @@ def _make_fetchupdate(self, iet, sync_ops, pieces, *args):
def _make_prefetchupdate(self, iet, sync_ops, pieces, root):
fid = SharedData._field_id

postactions = []
postactions = [BlankLine]
for s in sync_ops:
# `pcond` is not None, but we won't use it here because the condition
# is actually already encoded in `iet` itself (it stems from the
Expand All @@ -129,8 +133,11 @@ def _make_prefetchupdate(self, iet, sync_ops, pieces, root):

imask = [(s.tstore, s.size) if d.root is s.dim.root else FULL
for d in s.dimensions]
postactions.append(PragmaTransfer(self.lang._map_update_wait_device,
postactions.append(PragmaTransfer(self.lang._map_update_device_async,
s.target, imask=imask, queueid=fid))
wait = self.lang._map_wait(fid)
if wait is not None:
postactions.append(Pragma(wait))

# Turn prefetch IET into a ThreadFunction
name = self.sregistry.make_name(prefix='prefetch_host_to_device')
Expand All @@ -156,8 +163,8 @@ def _make_waitprefetch(self, iet, sync_ops, pieces, *args):
ff = SharedData._field_flag

waits = []
for s in sync_ops:
sdata, threads = pieces.objs.get(s)
objs = filter_ordered(pieces.objs.get(s) for s in sync_ops)
for sdata, threads in objs:
wait = BusyWait(CondNe(FieldFromComposite(ff, sdata[threads.index]), 1))
waits.append(wait)

Expand Down
12 changes: 10 additions & 2 deletions devito/passes/iet/parpragma.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,6 +536,14 @@ def _map_alloc(cls, f, imask=None):
def _map_present(cls, f, imask=None):
return

@classmethod
def _map_wait(cls, queueid=None):
    """
    Build the construct that explicitly waits on ``queueid``.

    Parameters
    ----------
    queueid : optional
        Value forwarded, unchanged, to the ``'map-wait'`` entry of
        ``cls.mapper``.

    Returns
    -------
    The object produced by ``cls.mapper['map-wait'](queueid)``, or None if
    ``cls.mapper`` has no ``'map-wait'`` entry.
    """
    # Only the lookup is guarded: a KeyError raised while *building* the
    # construct is a genuine error and must not be mistaken for "this
    # language has no wait construct"
    try:
        make_wait = cls.mapper['map-wait']
    except KeyError:
        # Not all languages may provide an explicit wait construct
        return None
    return make_wait(queueid)

@classmethod
def _map_update(cls, f, imask=None):
sections = cls._make_sections_from_imask(f, imask)
Expand All @@ -546,14 +554,14 @@ def _map_update_host(cls, f, imask=None, queueid=None):
sections = cls._make_sections_from_imask(f, imask)
return cls.mapper['map-update-host'](f.name, sections)

_map_update_wait_host = _map_update_host
_map_update_host_async = _map_update_host

@classmethod
def _map_update_device(cls, f, imask=None, queueid=None):
sections = cls._make_sections_from_imask(f, imask)
return cls.mapper['map-update-device'](f.name, sections)

_map_update_wait_device = _map_update_device
_map_update_device_async = _map_update_device

@classmethod
def _map_release(cls, f, imask=None, devicerm=None):
Expand Down
50 changes: 50 additions & 0 deletions tests/test_gpu_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,56 @@ def test_tasking_unfused_two_locks(self):
assert np.all(u.data[nt-1] == 9)
assert np.all(v.data[nt-1] == 9)

def test_tasking_forcefuse(self):
    """
    Exercise the `fuse-tasks` optimization option.

    An Operator writing to two saved TimeFunctions (`u` and `v`) is built
    with `opt=('tasking', 'fuse', 'orchestrate', {'fuse-tasks': True})`.
    The test inspects the generated code -- two Locks are expected, and the
    writes to both `u` and `v` are expected inside the
    `copy_device_to_host0` entry of `op._func_table` -- then runs the
    Operator and checks the last saved time slice of `u` and `v`.
    """
    nt = 10
    bundle0 = Bundle()
    grid = Grid(shape=(10, 10, 10), subdomains=bundle0)

    tmp0 = Function(name='tmp0', grid=grid)
    tmp1 = Function(name='tmp1', grid=grid)
    # `u` and `v` are created with `save=nt`; `w` is not
    u = TimeFunction(name='u', grid=grid, save=nt)
    v = TimeFunction(name='v', grid=grid, save=nt)
    w = TimeFunction(name='w', grid=grid)

    eqns = [Eq(w.forward, w + 1),
            Eq(tmp0, w.forward),
            Eq(tmp1, w.forward),
            Eq(u.forward, tmp0, subdomain=bundle0),
            Eq(v.forward, tmp1, subdomain=bundle0)]

    op = Operator(eqns, opt=('tasking', 'fuse', 'orchestrate', {'fuse-tasks': True}))

    # Check generated code
    assert len(retrieve_iteration_tree(op)) == 5
    # One Lock per saved TimeFunction
    assert len([i for i in FindSymbols().visit(op) if isinstance(i, Lock)]) == 2
    sections = FindNodes(Section).visit(op)
    assert len(sections) == 3
    # A single busy-wait covers both locks
    assert (str(sections[1].body[0].body[0].body[0].body[0]) ==
            'while(lock0[0] == 0 || lock1[0] == 0);')  # Wait-lock
    body = sections[2].body[0].body[0]
    assert (str(body.body[1].condition) ==
            'Ne(lock0[0], 2) | '
            'Ne(lock1[0], 2) | '
            'Ne(FieldFromComposite(sdata0[wi0]), 1)')  # Wait-thread
    assert (str(body.body[1].body[0]) ==
            'wi0 = (wi0 + 1)%(npthreads0);')
    assert str(body.body[2]) == 'sdata0[wi0].time = time;'
    assert str(body.body[3]) == 'lock0[0] = 0;'  # Set-lock
    assert str(body.body[4]) == 'lock1[0] = 0;'  # Set-lock
    assert str(body.body[5]) == 'sdata0[wi0].flag = 2;'
    assert len(op._func_table) == 2
    # Both locks are set, and both `u` and `v` are written, within the same
    # `copy_device_to_host0` function
    exprs = FindNodes(Expression).visit(op._func_table['copy_device_to_host0'].root)
    assert len(exprs) == 22
    assert str(exprs[15]) == 'lock0[0] = 1;'
    assert str(exprs[16]) == 'lock1[0] = 1;'
    assert exprs[17].write is u
    assert exprs[18].write is v

    op.apply(time_M=nt-2)

    # Check numerical output
    assert np.all(u.data[nt-1] == 9)
    assert np.all(v.data[nt-1] == 9)

@pytest.mark.parametrize('opt', [
('tasking', 'orchestrate'),
('tasking', 'streaming', 'orchestrate'),
Expand Down

0 comments on commit 0483968

Please sign in to comment.