Skip to content

Commit

Permalink
compiler: Add fuse-tasks optimization option
Browse files Browse the repository at this point in the history
  • Loading branch information
dummy committed Aug 12, 2021
1 parent b49845d commit f114164
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 14 deletions.
7 changes: 5 additions & 2 deletions devito/core/cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ def _normalize_kwargs(cls, **kwargs):
# Buffering
o['buf-async-degree'] = oo.pop('buf-async-degree', None)

# Fusion
o['fuse-tasks'] = oo.pop('fuse-tasks', False)

# Blocking
o['blockinner'] = oo.pop('blockinner', False)
o['blocklevels'] = oo.pop('blocklevels', cls.BLOCK_LEVELS)
Expand Down Expand Up @@ -298,13 +301,13 @@ def callback(f):
'blocking': lambda i: blocking(i, options),
'factorize': factorize,
'fission': fission,
'fuse': fuse,
'fuse': lambda i: fuse(i, options=options),
'lift': lambda i: Lift().process(cire(i, 'invariants', sregistry,
options, platform)),
'cire-sops': lambda i: cire(i, 'sops', sregistry, options, platform),
'cse': lambda i: cse(i, sregistry),
'opt-pows': optimize_pows,
'topofuse': lambda i: fuse(i, toposort=True)
'topofuse': lambda i: fuse(i, toposort=True, options=options)
}

@classmethod
Expand Down
9 changes: 6 additions & 3 deletions devito/core/gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ def _normalize_kwargs(cls, **kwargs):
# Buffering
o['buf-async-degree'] = oo.pop('buf-async-degree', None)

# Fusion
o['fuse-tasks'] = oo.pop('fuse-tasks', False)

# Blocking
o['blockinner'] = oo.pop('blockinner', True)
o['blocklevels'] = oo.pop('blocklevels', cls.BLOCK_LEVELS)
Expand Down Expand Up @@ -148,7 +151,7 @@ def _specialize_clusters(cls, clusters, **kwargs):
sregistry = kwargs['sregistry']

# Toposort+Fusion (the former to expose more fusion opportunities)
clusters = fuse(clusters, toposort=True)
clusters = fuse(clusters, toposort=True, options=options)

# Fission to increase parallelism
clusters = fission(clusters)
Expand Down Expand Up @@ -245,13 +248,13 @@ def callback(f):
'streaming': Streaming(reads_if_on_host).process,
'factorize': factorize,
'fission': fission,
'fuse': fuse,
'fuse': lambda i: fuse(i, options=options),
'lift': lambda i: Lift().process(cire(i, 'invariants', sregistry,
options, platform)),
'cire-sops': lambda i: cire(i, 'sops', sregistry, options, platform),
'cse': lambda i: cse(i, sregistry),
'opt-pows': optimize_pows,
'topofuse': lambda i: fuse(i, toposort=True)
'topofuse': lambda i: fuse(i, toposort=True, options=options)
}

@classmethod
Expand Down
31 changes: 22 additions & 9 deletions devito/passes/clusters/misc.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from collections import Counter
from collections import Counter, defaultdict
from itertools import groupby, product

from devito.ir.clusters import Cluster, ClusterGroup, Queue
Expand Down Expand Up @@ -91,9 +91,13 @@ class Fusion(Queue):
Fuse Clusters with compatible IterationSpace.
"""

def __init__(self, toposort):
super(Fusion, self).__init__()
def __init__(self, toposort, options=None):
options = options or {}

self.toposort = toposort
self.fusetasks = options.get('fuse-tasks', False)

super().__init__()

def _make_key_hook(self, cgroup, level):
assert level > 0
Expand Down Expand Up @@ -137,15 +141,24 @@ def _key(self, c):

key = (frozenset(c.itintervals), c.guards)

# We allow fusing Clusters/ClusterGroups with WaitLocks over different Locks,
# while the WithLocks are to be kept separated (i.e. the remain separate tasks)
# We allow fusing Clusters/ClusterGroups even in presence of WaitLocks and
# WithLocks, but not with any other SyncOps
if isinstance(c, Cluster):
sync_locks = (c.sync_locks,)
else:
sync_locks = c.sync_locks
for i in sync_locks:
key += (frozendict({k: frozenset(type(i) if i.is_WaitLock else i for i in v)
for k, v in i.items()}),)
mapper = defaultdict(set)
for k, v in i.items():
for s in v:
if s.is_WaitLock or \
(self.fusetasks and s.is_WithLock):
mapper[k].add(type(s))
else:
mapper[k].add(s)
mapper[k] = frozenset(mapper[k])
mapper = frozendict(mapper)
key += (mapper,)

return key

Expand Down Expand Up @@ -243,14 +256,14 @@ def _build_dag(self, cgroups, prefix):


@timed_pass()
def fuse(clusters, toposort=False):
def fuse(clusters, toposort=False, options=None):
"""
Clusters fusion.
If ``toposort=True``, then the Clusters are reordered to maximize the likelihood
of fusion; the new ordering is computed such that all data dependencies are honored.
"""
return Fusion(toposort=toposort).process(clusters)
return Fusion(toposort, options).process(clusters)


@cluster_pass(mode='all')
Expand Down
50 changes: 50 additions & 0 deletions tests/test_gpu_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,56 @@ def test_tasking_unfused_two_locks(self):
assert np.all(u.data[nt-1] == 9)
assert np.all(v.data[nt-1] == 9)

def test_tasking_forcefuse(self):
nt = 10
bundle0 = Bundle()
grid = Grid(shape=(10, 10, 10), subdomains=bundle0)

tmp0 = Function(name='tmp0', grid=grid)
tmp1 = Function(name='tmp1', grid=grid)
u = TimeFunction(name='u', grid=grid, save=nt)
v = TimeFunction(name='v', grid=grid, save=nt)
w = TimeFunction(name='w', grid=grid)

eqns = [Eq(w.forward, w + 1),
Eq(tmp0, w.forward),
Eq(tmp1, w.forward),
Eq(u.forward, tmp0, subdomain=bundle0),
Eq(v.forward, tmp1, subdomain=bundle0)]

op = Operator(eqns, opt=('tasking', 'fuse', 'orchestrate', {'fuse-tasks': True}))

# Check generated code
assert len(retrieve_iteration_tree(op)) == 5
assert len([i for i in FindSymbols().visit(op) if isinstance(i, Lock)]) == 2
sections = FindNodes(Section).visit(op)
assert len(sections) == 3
assert (str(sections[1].body[0].body[0].body[0].body[0]) ==
'while(lock0[0] == 0 || lock1[0] == 0);') # Wait-lock
body = sections[2].body[0].body[0]
assert (str(body.body[1].condition) ==
'Ne(lock0[0], 2) | '
'Ne(lock1[0], 2) | '
'Ne(FieldFromComposite(sdata0[wi0]), 1)') # Wait-thread
assert (str(body.body[1].body[0]) ==
'wi0 = (wi0 + 1)%(npthreads0);')
assert str(body.body[2]) == 'sdata0[wi0].time = time;'
assert str(body.body[3]) == 'lock0[0] = 0;' # Set-lock
assert str(body.body[4]) == 'lock1[0] = 0;' # Set-lock
assert str(body.body[5]) == 'sdata0[wi0].flag = 2;'
assert len(op._func_table) == 2
exprs = FindNodes(Expression).visit(op._func_table['copy_device_to_host0'].root)
assert len(exprs) == 22
assert str(exprs[15]) == 'lock0[0] = 1;'
assert str(exprs[16]) == 'lock1[0] = 1;'
assert exprs[17].write is u
assert exprs[18].write is v

op.apply(time_M=nt-2)

assert np.all(u.data[nt-1] == 9)
assert np.all(v.data[nt-1] == 9)

@pytest.mark.parametrize('opt', [
('tasking', 'orchestrate'),
('tasking', 'streaming', 'orchestrate'),
Expand Down

0 comments on commit f114164

Please sign in to comment.