Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

compiler: Add optimization option to fuse WithLocks tasks #1736

Merged
merged 4 commits into from
Sep 7, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions devito/core/cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ def _normalize_kwargs(cls, **kwargs):
# Buffering
o['buf-async-degree'] = oo.pop('buf-async-degree', None)

# Fusion
o['fuse-tasks'] = oo.pop('fuse-tasks', False)

# Blocking
o['blockinner'] = oo.pop('blockinner', False)
o['blocklevels'] = oo.pop('blocklevels', cls.BLOCK_LEVELS)
Expand Down Expand Up @@ -298,13 +301,13 @@ def callback(f):
'blocking': lambda i: blocking(i, options),
'factorize': factorize,
'fission': fission,
'fuse': fuse,
'fuse': lambda i: fuse(i, options=options),
'lift': lambda i: Lift().process(cire(i, 'invariants', sregistry,
options, platform)),
'cire-sops': lambda i: cire(i, 'sops', sregistry, options, platform),
'cse': lambda i: cse(i, sregistry),
'opt-pows': optimize_pows,
'topofuse': lambda i: fuse(i, toposort=True)
'topofuse': lambda i: fuse(i, toposort=True, options=options)
}

@classmethod
Expand Down
9 changes: 6 additions & 3 deletions devito/core/gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ def _normalize_kwargs(cls, **kwargs):
# Buffering
o['buf-async-degree'] = oo.pop('buf-async-degree', None)

# Fusion
o['fuse-tasks'] = oo.pop('fuse-tasks', False)

# Blocking
o['blockinner'] = oo.pop('blockinner', True)
o['blocklevels'] = oo.pop('blocklevels', cls.BLOCK_LEVELS)
Expand Down Expand Up @@ -148,7 +151,7 @@ def _specialize_clusters(cls, clusters, **kwargs):
sregistry = kwargs['sregistry']

# Toposort+Fusion (the former to expose more fusion opportunities)
clusters = fuse(clusters, toposort=True)
clusters = fuse(clusters, toposort=True, options=options)

# Fission to increase parallelism
clusters = fission(clusters)
Expand Down Expand Up @@ -245,13 +248,13 @@ def callback(f):
'streaming': Streaming(reads_if_on_host).process,
'factorize': factorize,
'fission': fission,
'fuse': fuse,
'fuse': lambda i: fuse(i, options=options),
'lift': lambda i: Lift().process(cire(i, 'invariants', sregistry,
options, platform)),
'cire-sops': lambda i: cire(i, 'sops', sregistry, options, platform),
'cse': lambda i: cse(i, sregistry),
'opt-pows': optimize_pows,
'topofuse': lambda i: fuse(i, toposort=True)
'topofuse': lambda i: fuse(i, toposort=True, options=options)
}

@classmethod
Expand Down
4 changes: 2 additions & 2 deletions devito/ir/iet/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
'MetaCall', 'PointerCast', 'ForeignExpression', 'HaloSpot', 'IterationTree',
'ExpressionBundle', 'AugmentedExpression', 'Increment', 'Return', 'While',
'ParallelIteration', 'ParallelBlock', 'Dereference', 'Lambda', 'SyncSpot',
'PragmaTransfer', 'DummyExpr', 'BlankLine', 'ParallelTree', 'BusyWait',
'CallableBody']
'Pragma', 'PragmaTransfer', 'DummyExpr', 'BlankLine', 'ParallelTree',
'BusyWait', 'CallableBody']

# First-class IET nodes

Expand Down
31 changes: 22 additions & 9 deletions devito/passes/clusters/misc.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from collections import Counter
from collections import Counter, defaultdict
from itertools import groupby, product

from devito.ir.clusters import Cluster, ClusterGroup, Queue
Expand Down Expand Up @@ -91,9 +91,13 @@ class Fusion(Queue):
Fuse Clusters with compatible IterationSpace.
"""

def __init__(self, toposort):
super(Fusion, self).__init__()
def __init__(self, toposort, options=None):
options = options or {}

self.toposort = toposort
self.fusetasks = options.get('fuse-tasks', False)

super().__init__()

def _make_key_hook(self, cgroup, level):
assert level > 0
Expand Down Expand Up @@ -137,15 +141,24 @@ def _key(self, c):

key = (frozenset(c.itintervals), c.guards)

# We allow fusing Clusters/ClusterGroups with WaitLocks over different Locks,
# while the WithLocks are to be kept separated (i.e. they remain separate tasks)
# We allow fusing Clusters/ClusterGroups even in the presence of WaitLocks and,
# if the `fuse-tasks` option is enabled, WithLocks, but not with any other SyncOps
if isinstance(c, Cluster):
sync_locks = (c.sync_locks,)
else:
sync_locks = c.sync_locks
for i in sync_locks:
key += (frozendict({k: frozenset(type(i) if i.is_WaitLock else i for i in v)
for k, v in i.items()}),)
mapper = defaultdict(set)
for k, v in i.items():
for s in v:
if s.is_WaitLock or \
(self.fusetasks and s.is_WithLock):
mapper[k].add(type(s))
else:
mapper[k].add(s)
mapper[k] = frozenset(mapper[k])
mapper = frozendict(mapper)
key += (mapper,)

return key

Expand Down Expand Up @@ -243,14 +256,14 @@ def _build_dag(self, cgroups, prefix):


@timed_pass()
def fuse(clusters, toposort=False):
def fuse(clusters, toposort=False, options=None):
"""
Clusters fusion.

If ``toposort=True``, then the Clusters are reordered to maximize the likelihood
of fusion; the new ordering is computed such that all data dependencies are honored.
"""
return Fusion(toposort=toposort).process(clusters)
return Fusion(toposort, options).process(clusters)


@cluster_pass(mode='all')
Expand Down
15 changes: 11 additions & 4 deletions devito/passes/iet/langbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,13 @@ def _map_present(cls, f, imask=None):
"""
raise NotImplementedError

@classmethod
def _map_wait(cls, queueid=None):
"""
Explicitly wait on event.
"""
raise NotImplementedError

@classmethod
def _map_update(cls, f, imask=None):
"""
Expand All @@ -86,9 +93,9 @@ def _map_update_host(cls, f, imask=None, queueid=None):
raise NotImplementedError

@classmethod
def _map_update_wait_host(cls, f, imask=None, queueid=None):
def _map_update_host_async(cls, f, imask=None, queueid=None):
"""
Copy Function from device to host memory and explicitly wait.
Asynchronously copy Function from device to host memory.
"""
raise NotImplementedError

Expand All @@ -100,9 +107,9 @@ def _map_update_device(cls, f, imask=None, queueid=None):
raise NotImplementedError

@classmethod
def _map_update_wait_device(cls, f, imask=None, queueid=None):
def _map_update_device_async(cls, f, imask=None, queueid=None):
"""
Copy Function from host to device memory and explicitly wait.
Asynchronously copy Function from host to device memory.
"""
raise NotImplementedError

Expand Down
20 changes: 10 additions & 10 deletions devito/passes/iet/languages/openacc.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,18 +93,18 @@ class AccBB(PragmaLangBB):
c.Pragma('acc enter data create(%s%s)' % (i, j)),
'map-present': lambda i, j:
c.Pragma('acc data present(%s%s)' % (i, j)),
'map-wait': lambda i:
c.Pragma('acc wait(%s)' % i),
'map-update': lambda i, j:
c.Pragma('acc exit data copyout(%s%s)' % (i, j)),
'map-update-host': lambda i, j:
c.Pragma('acc update self(%s%s)' % (i, j)),
'map-update-wait-host': lambda i, j, k:
(c.Pragma('acc update self(%s%s) async(%s)' % (i, j, k)),
c.Pragma('acc wait(%s)' % k)),
'map-update-host-async': lambda i, j, k:
c.Pragma('acc update self(%s%s) async(%s)' % (i, j, k)),
'map-update-device': lambda i, j:
c.Pragma('acc update device(%s%s)' % (i, j)),
'map-update-wait-device': lambda i, j, k:
(c.Pragma('acc update device(%s%s) async(%s)' % (i, j, k)),
c.Pragma('acc wait(%s)' % k)),
'map-update-device-async': lambda i, j, k:
c.Pragma('acc update device(%s%s) async(%s)' % (i, j, k)),
'map-release': lambda i, j, k:
c.Pragma('acc exit data delete(%s%s)%s' % (i, j, k)),
'map-exit-delete': lambda i, j, k:
Expand Down Expand Up @@ -147,14 +147,14 @@ def _map_delete(cls, f, imask=None, devicerm=None):
return cls.mapper['map-exit-delete'](f.name, sections, cond)

@classmethod
def _map_update_wait_host(cls, f, imask=None, queueid=None):
def _map_update_host_async(cls, f, imask=None, queueid=None):
sections = cls._make_sections_from_imask(f, imask)
return cls.mapper['map-update-wait-host'](f.name, sections, queueid)
return cls.mapper['map-update-host-async'](f.name, sections, queueid)

@classmethod
def _map_update_wait_device(cls, f, imask=None, queueid=None):
def _map_update_device_async(cls, f, imask=None, queueid=None):
sections = cls._make_sections_from_imask(f, imask)
return cls.mapper['map-update-wait-device'](f.name, sections, queueid)
return cls.mapper['map-update-device-async'](f.name, sections, queueid)


class DeviceAccizer(PragmaDeviceAwareTransformer):
Expand Down
29 changes: 18 additions & 11 deletions devito/passes/iet/orchestration.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from devito.data import FULL
from devito.ir.iet import (Call, Callable, Conditional, List, SyncSpot, FindNodes,
Transformer, BlankLine, BusyWait, PragmaTransfer,
Transformer, BlankLine, BusyWait, Pragma, PragmaTransfer,
DummyExpr, derive_parameters, make_thread_ctx)
from devito.passes.iet.engine import iet_pass
from devito.passes.iet.langbase import LangBB
Expand Down Expand Up @@ -54,17 +54,21 @@ def _make_withlock(self, iet, sync_ops, pieces, root):
# will never be more than 2 threads in flight concurrently
npthreads = min(i.size for i in locks)

preactions = []
postactions = []
preactions = [BlankLine]
for s in sync_ops:
imask = [s.handle.indices[d] if d.root in s.lock.locked_dimensions else FULL
for d in s.target.dimensions]
update = PragmaTransfer(self.lang._map_update_wait_host, s.target,
update = PragmaTransfer(self.lang._map_update_host_async, s.target,
imask=imask, queueid=SharedData._field_id)
preactions.append(List(body=[BlankLine, update, DummyExpr(s.handle, 1)]))
postactions.append(DummyExpr(s.handle, 2))
preactions.append(update)
wait = self.lang._map_wait(SharedData._field_id)
if wait is not None:
preactions.append(Pragma(wait))
preactions.extend([DummyExpr(s.handle, 1) for s in sync_ops])
preactions.append(BlankLine)
postactions.insert(0, BlankLine)

postactions = [BlankLine]
postactions.extend([DummyExpr(s.handle, 2) for s in sync_ops])

# Turn `iet` into a ThreadFunction so that it can be executed
# asynchronously by a pthread in the `npthreads` pool
Expand Down Expand Up @@ -120,7 +124,7 @@ def _make_fetchupdate(self, iet, sync_ops, pieces, *args):
def _make_prefetchupdate(self, iet, sync_ops, pieces, root):
fid = SharedData._field_id

postactions = []
postactions = [BlankLine]
for s in sync_ops:
# `pcond` is not None, but we won't use it here because the condition
# is actually already encoded in `iet` itself (it stems from the
Expand All @@ -129,8 +133,11 @@ def _make_prefetchupdate(self, iet, sync_ops, pieces, root):

imask = [(s.tstore, s.size) if d.root is s.dim.root else FULL
for d in s.dimensions]
postactions.append(PragmaTransfer(self.lang._map_update_wait_device,
postactions.append(PragmaTransfer(self.lang._map_update_device_async,
s.target, imask=imask, queueid=fid))
wait = self.lang._map_wait(fid)
if wait is not None:
postactions.append(Pragma(wait))

# Turn prefetch IET into a ThreadFunction
name = self.sregistry.make_name(prefix='prefetch_host_to_device')
Expand All @@ -156,8 +163,8 @@ def _make_waitprefetch(self, iet, sync_ops, pieces, *args):
ff = SharedData._field_flag

waits = []
for s in sync_ops:
sdata, threads = pieces.objs.get(s)
objs = filter_ordered(pieces.objs.get(s) for s in sync_ops)
for sdata, threads in objs:
wait = BusyWait(CondNe(FieldFromComposite(ff, sdata[threads.index]), 1))
waits.append(wait)

Expand Down
12 changes: 10 additions & 2 deletions devito/passes/iet/parpragma.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,6 +536,14 @@ def _map_alloc(cls, f, imask=None):
def _map_present(cls, f, imask=None):
return

@classmethod
def _map_wait(cls, queueid=None):
try:
return cls.mapper['map-wait'](queueid)
except KeyError:
# Not all languages may provide an explicit wait construct
return None

@classmethod
def _map_update(cls, f, imask=None):
sections = cls._make_sections_from_imask(f, imask)
Expand All @@ -546,14 +554,14 @@ def _map_update_host(cls, f, imask=None, queueid=None):
sections = cls._make_sections_from_imask(f, imask)
return cls.mapper['map-update-host'](f.name, sections)

_map_update_wait_host = _map_update_host
_map_update_host_async = _map_update_host

@classmethod
def _map_update_device(cls, f, imask=None, queueid=None):
sections = cls._make_sections_from_imask(f, imask)
return cls.mapper['map-update-device'](f.name, sections)

_map_update_wait_device = _map_update_device
_map_update_device_async = _map_update_device

@classmethod
def _map_release(cls, f, imask=None, devicerm=None):
Expand Down
50 changes: 50 additions & 0 deletions tests/test_gpu_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,56 @@ def test_tasking_unfused_two_locks(self):
assert np.all(u.data[nt-1] == 9)
assert np.all(v.data[nt-1] == 9)

def test_tasking_forcefuse(self):
nt = 10
bundle0 = Bundle()
grid = Grid(shape=(10, 10, 10), subdomains=bundle0)

tmp0 = Function(name='tmp0', grid=grid)
tmp1 = Function(name='tmp1', grid=grid)
u = TimeFunction(name='u', grid=grid, save=nt)
v = TimeFunction(name='v', grid=grid, save=nt)
w = TimeFunction(name='w', grid=grid)

eqns = [Eq(w.forward, w + 1),
Eq(tmp0, w.forward),
Eq(tmp1, w.forward),
Eq(u.forward, tmp0, subdomain=bundle0),
Eq(v.forward, tmp1, subdomain=bundle0)]

op = Operator(eqns, opt=('tasking', 'fuse', 'orchestrate', {'fuse-tasks': True}))

# Check generated code
assert len(retrieve_iteration_tree(op)) == 5
assert len([i for i in FindSymbols().visit(op) if isinstance(i, Lock)]) == 2
sections = FindNodes(Section).visit(op)
assert len(sections) == 3
assert (str(sections[1].body[0].body[0].body[0].body[0]) ==
georgebisbas marked this conversation as resolved.
Show resolved Hide resolved
'while(lock0[0] == 0 || lock1[0] == 0);') # Wait-lock
body = sections[2].body[0].body[0]
assert (str(body.body[1].condition) ==
'Ne(lock0[0], 2) | '
'Ne(lock1[0], 2) | '
'Ne(FieldFromComposite(sdata0[wi0]), 1)') # Wait-thread
assert (str(body.body[1].body[0]) ==
'wi0 = (wi0 + 1)%(npthreads0);')
assert str(body.body[2]) == 'sdata0[wi0].time = time;'
assert str(body.body[3]) == 'lock0[0] = 0;' # Set-lock
assert str(body.body[4]) == 'lock1[0] = 0;' # Set-lock
assert str(body.body[5]) == 'sdata0[wi0].flag = 2;'
assert len(op._func_table) == 2
exprs = FindNodes(Expression).visit(op._func_table['copy_device_to_host0'].root)
assert len(exprs) == 22
assert str(exprs[15]) == 'lock0[0] = 1;'
assert str(exprs[16]) == 'lock1[0] = 1;'
assert exprs[17].write is u
assert exprs[18].write is v

op.apply(time_M=nt-2)

assert np.all(u.data[nt-1] == 9)
assert np.all(v.data[nt-1] == 9)

@pytest.mark.parametrize('opt', [
('tasking', 'orchestrate'),
('tasking', 'streaming', 'orchestrate'),
Expand Down