Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

compiler: Improve lowering of IndexDerivatives #2112

Merged
merged 9 commits into from
Jun 5, 2023
18 changes: 5 additions & 13 deletions devito/ir/clusters/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ def callback(self, clusters, prefix):
mapper[size][si].add(iaf)

# Construct the ModuloDimensions
mds = OrderedDict()
mds = []
for size, v in mapper.items():
for si, iafs in list(v.items()):
# Offsets are sorted so that the semantic order (t0, t1, t2) follows
Expand All @@ -290,15 +290,10 @@ def callback(self, clusters, prefix):
# sorting offsets {-1, 0, 1} as {0, -1, 1} assigning -inf to 0
siafs = sorted(iafs, key=lambda i: -np.inf if i - si == 0 else (i - si))

# Create the ModuloDimensions. Note that if `size < len(iafs)` then
# the same ModuloDimension may be used for multiple offsets
for iaf in siafs[:size]:
for iaf in siafs:
name = '%s%d' % (si.name, len(mds))
offset = uxreplace(iaf, {si: d.root})
md = ModuloDimension(name, si, offset, size, origin=iaf)

key = lambda i: i.subs(si, 0) % size
mds[md] = [i for i in siafs if key(i) == key(iaf)]
mds.append(ModuloDimension(name, si, offset, size, origin=iaf))

# Replacement rule for ModuloDimensions
def rule(size, e):
Expand All @@ -320,11 +315,8 @@ def rule(size, e):
exprs = c.exprs
groups = as_mapper(mds, lambda d: d.modulo)
for size, v in groups.items():
mapper = {}
for md in v:
mapper.update({i: md for i in mds[md]})

func = partial(xreplace_indices, mapper=mapper, key=partial(rule, size))
subs = {md.origin: md for md in v}
func = partial(xreplace_indices, mapper=subs, key=partial(rule, size))
exprs = [e.apply(func) for e in exprs]

# Augment IterationSpace
Expand Down
13 changes: 11 additions & 2 deletions devito/ir/support/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,8 +411,12 @@ def distance(self, other):
# Indexed representing an arbitrary access along `x`, within the `t`
# IterationSpace, while the sink lives within the `tx` IterationSpace
if len(self.itintervals[n:]) != len(other.itintervals[n:]):
ret.append(S.Infinity)
return Vector(*ret)
v = Vector(*ret)
if v != 0:
return v
else:
ret.append(S.Infinity)
return Vector(*ret)

# It still could be an imaginary dependence, e.g. `a[3] -> a[4]` or, more
# nasty, `a[i+1, 3] -> a[i, 4]`
Expand Down Expand Up @@ -562,6 +566,11 @@ def is_lex_equal(self):
"""
return self.source.timestamp == self.sink.timestamp

@cached_property
def is_lex_ne(self):
"""True if the source's and sink's timestamps differ, False otherwise."""
return self.source.timestamp != self.sink.timestamp

@cached_property
def is_lex_negative(self):
"""
Expand Down
5 changes: 4 additions & 1 deletion devito/operator/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from devito.mpi import MPI
from devito.parameters import configuration
from devito.passes import (Graph, lower_index_derivatives, generate_implicit,
generate_macros, unevaluate)
generate_macros, minimize_symbols, unevaluate)
from devito.symbolics import estimate_cost
from devito.tools import (DAG, OrderedSet, Signer, ReducerMap, as_tuple, flatten,
filter_sorted, frozendict, is_integer, split, timed_pass,
Expand Down Expand Up @@ -458,6 +458,9 @@ def _lower_iet(cls, uiet, profiler=None, **kwargs):
# Extract the necessary macros from the symbolic objects
generate_macros(graph)

# Target-independent optimizations
minimize_symbols(graph)

return graph.root, graph

# Read-only properties exposed to the outside world
Expand Down
83 changes: 63 additions & 20 deletions devito/passes/clusters/derivatives.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from devito.finite_differences import IndexDerivative
from devito.ir import Interval, IterationSpace
from devito.ir import Interval, IterationSpace, Queue
from devito.passes.clusters.misc import fuse
from devito.symbolics import (retrieve_dimensions, reuse_if_untouched, q_leaf,
uxreplace)
Expand All @@ -11,43 +11,44 @@

@timed_pass()
def lower_index_derivatives(clusters, mode=None, **kwargs):
clusters, weights = _lower_index_derivatives(clusters, **kwargs)
clusters, weights, mapper = _lower_index_derivatives(clusters, **kwargs)

if not weights:
return clusters

if mode != 'noop':
clusters = fuse(clusters, toposort='maximal')

clusters = CDE(mapper).process(clusters)

return clusters


def _lower_index_derivatives(clusters, sregistry=None, **kwargs):
processed = []
weights = {}
processed = []
mapper = {}

def dump(exprs, c):
if exprs:
processed.append(c.rebuild(exprs=exprs))
exprs[:] = []

for c in clusters:

exprs = []
seen = {}
for e in c.exprs:
expr, v = _lower_index_derivatives_core(e, c, weights, seen, sregistry)
expr, v = _core(e, c, weights, mapper, sregistry)
if v:
dump(exprs, c)
processed.extend(v)
exprs.append(expr)
processed.extend(v)

dump(exprs, c)

return processed, weights
return processed, weights, mapper


def _lower_index_derivatives_core(expr, c, weights, seen, sregistry):
def _core(expr, c, weights, mapper, sregistry):
"""
Recursively carry out the core of `lower_index_derivatives`.
"""
Expand All @@ -57,7 +58,7 @@ def _lower_index_derivatives_core(expr, c, weights, seen, sregistry):
args = []
processed = []
for a in expr.args:
e, clusters = _lower_index_derivatives_core(a, c, weights, seen, sregistry)
e, clusters = _core(a, c, weights, mapper, sregistry)
args.append(e)
processed.extend(clusters)

Expand All @@ -76,12 +77,6 @@ def _lower_index_derivatives_core(expr, c, weights, seen, sregistry):
w = weights[k] = w0._rebuild(name=name)
expr = uxreplace(expr, {w0.indexed: w.indexed})

# Have I seen this IndexDerivative already?
try:
return seen[expr], []
except KeyError:
pass

dims = retrieve_dimensions(expr, deep=True)
dims = filter_ordered(d for d in dims if isinstance(d, StencilDimension))

Expand All @@ -91,7 +86,7 @@ def _lower_index_derivatives_core(expr, c, weights, seen, sregistry):
# upper and lower offsets, we honor it
dims = tuple(d for d in dims if d not in c.ispace)

intervals = [Interval(d, 0, 0) for d in dims]
intervals = [Interval(d) for d in dims]
ispace0 = IterationSpace(intervals)

extra = (c.ispace.itdimensions + dims,)
Expand All @@ -103,13 +98,61 @@ def _lower_index_derivatives_core(expr, c, weights, seen, sregistry):
ispace1 = ispace.project(lambda d: d is not dims[-1])
processed.insert(0, c.rebuild(exprs=expr0, ispace=ispace1))

# Track IndexDerivative to avoid intra-Cluster duplicates
seen[expr] = s

# Transform e.g. `w[i0] -> w[i0 + 2]` for alignment with the
# StencilDimensions starting points
subs = {expr.weights: expr.weights.subs(d, d - d._min) for d in dims}
expr1 = Inc(s, uxreplace(expr.expr, subs))
processed.append(c.rebuild(exprs=expr1, ispace=ispace))

# Track lowered IndexDerivative for subsequent optimization by the caller
mapper.setdefault(expr1.rhs, []).append(s)

return s, processed


class CDE(Queue):

"""
Common derivative elimination.
"""

def __init__(self, mapper):
super().__init__()

self.mapper = {k: v for k, v in mapper.items() if len(v) > 1}

def process(self, clusters):
return self._process_fdta(clusters, 1, subs0={}, seen=set())

def callback(self, clusters, prefix, subs0=None, seen=None):
subs = {}
processed = []
for c in clusters:
if c in seen:
processed.append(c)
continue

exprs = []
for e in c.exprs:
k, v = e.args

if k in subs0:
continue

try:
subs0[k] = subs[v]
continue
except KeyError:
pass

if v in self.mapper:
subs[v] = k
exprs.append(e)
else:
exprs.append(uxreplace(e, {**subs0, **subs}))

processed.append(c.rebuild(exprs=exprs))

seen.update(processed)

return processed
63 changes: 47 additions & 16 deletions devito/passes/clusters/misc.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from collections import Counter, defaultdict
from itertools import groupby, product

from devito.finite_differences import IndexDerivative
from devito.ir.clusters import Cluster, ClusterGroup, Queue, cluster_pass
from devito.ir.support import (SEQUENTIAL, SEPARABLE, Scope, ReleaseLock,
WaitLock, WithLock, FetchUpdate, PrefetchUpdate)
Expand Down Expand Up @@ -145,19 +146,34 @@ def callback(self, cgroups, prefix):
else:
return [ClusterGroup(processed, prefix)]

def _key(self, c):
# Two Clusters/ClusterGroups are fusion candidates if their key is identical
class Key(tuple):

key = (frozenset(c.ispace.itintervals),)
"""
A fusion Key for a Cluster (ClusterGroup) is a hashable tuple such that
two Clusters (ClusterGroups) are topo-fusible if and only if their Key is
identical.

A Key contains several elements that can logically be split into two
groups -- the `strict` and the `weak` components of the Key.
Two Clusters (ClusterGroups) having same `strict` but different `weak` parts
are, as by definition, not fusible; however, since at least their `strict`
parts match, they can at least be topologically reordered.
"""

# If there are writes to thread-shared object, make it part of the key.
# This will promote fusion of non-adjacent Clusters writing to (some form of)
# shared memory, which in turn will minimize the number of necessary barriers
key += (any(f._mem_shared for f in c.scope.writes),)
# Same story for reads from thread-shared objects
key += (any(f._mem_shared for f in c.scope.reads),)
def __new__(cls, strict, weak):
obj = super().__new__(cls, strict + weak)
obj.strict = tuple(strict)
obj.weak = tuple(weak)

return obj

def _key(self, c):
strict = []

key += (c.guards if any(c.guards) else None,)
strict.extend([
frozenset(c.ispace.itintervals),
c.guards if any(c.guards) else None
])

# We allow fusing Clusters/ClusterGroups even in presence of WaitLocks and
# WithLocks, but not with any other SyncOps
Expand All @@ -180,13 +196,28 @@ def _key(self, c):
mapper[k].add(type(s))
else:
mapper[k].add(s)
mapper[k] = frozenset(mapper[k])
if any(mapper.values()):
mapper = frozendict(mapper)
key += (mapper,)
if k in mapper:
mapper[k] = frozenset(mapper[k])
strict.append(frozendict(mapper))

weak = []

# Clusters representing HaloTouches should get merged, if possible
key += (c.is_halo_touch,)
weak.append(c.is_halo_touch)

# If there are writes to thread-shared object, make it part of the key.
# This will promote fusion of non-adjacent Clusters writing to (some form of)
# shared memory, which in turn will minimize the number of necessary barriers
# Same story for reads from thread-shared objects
weak.extend([
any(f._mem_shared for f in c.scope.writes),
any(f._mem_shared for f in c.scope.reads)
])

# Promoting adjacency of IndexDerivatives will maximize their reuse
weak.append(any(e.find(IndexDerivative) for e in c.exprs))

key = self.Key(strict, weak)

return key

Expand Down Expand Up @@ -236,7 +267,7 @@ def dump():
def _toposort(self, cgroups, prefix):
# Are there any ClusterGroups that could potentially be fused? If
# not, do not waste time computing a new topological ordering
counter = Counter(self._key(cg) for cg in cgroups)
counter = Counter(self._key(cg).strict for cg in cgroups)
if not any(v > 1 for it, v in counter.most_common()):
return ClusterGroup(cgroups, prefix)

Expand Down
2 changes: 1 addition & 1 deletion devito/passes/iet/linearization.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ def _(f, indexeds, tracker, strides, sregistry):

if len(i.indices) == i.function.ndim:
v = tuple(strides.values())[-n:]
subs[i] = FIndexed(i, pname, strides=v)
subs[i] = FIndexed.from_indexed(i, pname, strides=v)
else:
# Honour custom indexing
subs[i] = i.base[sum(i.indices)]
Expand Down
Loading