Skip to content

Commit

Permalink
Merge pull request #2225 from devitocodes/fix-timederiv-ctf
Browse files Browse the repository at this point in the history
compiler: Rework multi-level buffering
  • Loading branch information
FabioLuporini authored Oct 6, 2023
2 parents 2d355c2 + 82e47f9 commit e6cd0b0
Show file tree
Hide file tree
Showing 7 changed files with 166 additions and 84 deletions.
29 changes: 23 additions & 6 deletions devito/ir/support/syncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,16 @@ def __repr__(self):

@property
def imask(self):
    """
    The index mask for this sync operation: the handle's concrete index
    along each locked Dimension, FULL along every other Dimension of the
    target.
    """
    ret = []
    for d in self.target.dimensions:
        if d.root in self.lock.locked_dimensions:
            ret.append(self.handle.indices[d])
        else:
            ret.append(FULL)

    return IMask(*ret,
                 getters=self.target.dimensions,
                 function=self.function,
                 findex=self.findex)


Expand All @@ -81,9 +88,19 @@ def __repr__(self):

@property
def imask(self):
    """
    The index mask for this sync operation: a `(tindex, size)` pair along
    the synchronized Dimension's root, FULL along every other Dimension of
    the target.
    """
    ret = []
    for d in self.target.dimensions:
        if d.root is self.dim.root:
            if self.target.is_regular:
                ret.append((self.tindex, self.size))
            else:
                # Irregular targets fall back to a single slot
                # NOTE(review): hardcoded (0, 1) -- confirm against callers
                ret.append((0, 1))
        else:
            ret.append(FULL)

    return IMask(*ret,
                 getters=self.target.dimensions,
                 function=self.function,
                 findex=self.findex)


Expand Down
23 changes: 8 additions & 15 deletions devito/passes/clusters/asynchrony.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

from devito.ir import (Forward, GuardBoundNext, Queue, Vector, WaitLock, WithLock,
                       FetchUpdate, PrefetchUpdate, ReleaseLock, normalize_syncs)
from devito.passes.clusters.utils import is_memcpy
from devito.symbolics import IntDiv, uxreplace
from devito.tools import OrderedSet, is_integer, timed_pass
from devito.types import CustomDimension, Lock

Expand Down Expand Up @@ -125,7 +126,7 @@ def callback(self, clusters, prefix):
assert lock.size == 1
indices = [0]

if is_memcpy(c0):
if wraps_memcpy(c0):
e = c0.exprs[0]
function = e.lhs.function
findex = e.lhs.indices[d]
Expand Down Expand Up @@ -177,7 +178,7 @@ def callback(self, clusters, prefix):
for c in clusters:
dims = self.key(c)
if d._defines & dims:
if is_memcpy(c):
if wraps_memcpy(c):
# Case 1A (special case, leading to more efficient streaming)
self._actions_from_init(c, prefix, actions)
else:
Expand All @@ -186,7 +187,7 @@ def callback(self, clusters, prefix):

# Case 2
else:
mapper = OrderedDict([(c, is_memcpy(c)) for c in clusters
mapper = OrderedDict([(c, wraps_memcpy(c)) for c in clusters
if d in self.key(c)])

# Case 2A (special case, leading to more efficient streaming)
Expand Down Expand Up @@ -257,7 +258,7 @@ def _actions_from_update_memcpy(self, cluster, clusters, prefix, actions):

# If fetching into e.g. `ub[sb1]` we'll need to prefetch into e.g. `ub[sb0]`
tindex0 = e.lhs.indices[d]
if is_integer(tindex0):
if is_integer(tindex0) or isinstance(tindex0, IntDiv):
tindex = tindex0
else:
assert tindex0.is_Modulo
Expand Down Expand Up @@ -321,16 +322,8 @@ def __init__(self, drop=False, syncs=None, insert=None):
self.insert = insert or []


def wraps_memcpy(cluster):
    """
    True if `cluster` consists of exactly one memcpy-like expression
    (see `is_memcpy`), False otherwise.
    """
    if len(cluster.exprs) != 1:
        return False

    return is_memcpy(cluster.exprs[0])
170 changes: 110 additions & 60 deletions devito/passes/clusters/buffering.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
lower_exprs, vmax, vmin)
from devito.exceptions import InvalidOperator
from devito.logger import warning
from devito.passes.clusters.utils import is_memcpy
from devito.symbolics import IntDiv, retrieve_function_carriers, uxreplace
from devito.tools import (Bunch, DefaultOrderedDict, Stamp, as_tuple,
filter_ordered, flatten, is_integer, timed_pass)
Expand Down Expand Up @@ -175,6 +176,12 @@ def callback(self, clusters, prefix, cache=None):
if b.size == 1 and not init_onread(b.function):
continue

# Special case: avoid initialization in the case of double
# (or multiple levels of) buffering because it will have been
# already performed
if b.size > 1 and b.multi_buffering:
continue

dims = b.function.dimensions
lhs = b.indexed[dims]._subs(dim, b.firstidx.b)
rhs = b.function[dims]._subs(dim, b.firstidx.f)
Expand Down Expand Up @@ -347,64 +354,14 @@ def __init__(self, function, dim, d, accessv, cache, options, sregistry):
self.sub_iterators = defaultdict(list)
self.subdims_mapper = DefaultOrderedDict(set)

# Create the necessary ModuloDimensions for indexing into the buffer
# E.g., `u[time,x] + u[time+1,x] -> `ub[sb0,x] + ub[sb1,x]`, where `sb0`
# and `sb1` are ModuloDimensions starting at `time` and `time+1` respectively
dims = list(function.dimensions)
assert dim in function.dimensions

# Determine the buffer size, and therefore the span of the ModuloDimension,
# along the contracting Dimension `d`
indices = filter_ordered(i.indices[dim] for i in accessv.accesses)
slots = [i.subs({dim: 0, dim.spacing: 1}) for i in indices]
try:
size = max(slots) - min(slots) + 1
except TypeError:
# E.g., special case `slots=[-1 + time/factor, 2 + time/factor]`
# Resort to the fast vector-based comparison machinery (rather than
# the slower sympy.simplify)
slots = [Vector(i) for i in slots]
size = int((vmax(*slots) - vmin(*slots) + 1)[0])

if async_degree is not None:
if async_degree < size:
warning("Ignoring provided asynchronous degree as it'd be "
"too small for the required buffer (provided %d, "
"but need at least %d for `%s`)"
% (async_degree, size, function.name))
else:
size = async_degree

# Create `xd` -- a contraction Dimension for `dim`
try:
xd = sregistry.get('xds', (dim, size))
except KeyError:
name = sregistry.make_name(prefix='db')
v = CustomDimension(name, 0, size-1, size, dim)
xd = sregistry.setdefault('xds', (dim, size), v)
self.xd = dims[dims.index(dim)] = xd

# Finally create the ModuloDimensions as children of `xd`
if size > 1:
# Note: indices are sorted so that the semantic order (sb0, sb1, sb2)
# follows SymPy's index ordering (time, time-1, time+1) after modulo
# replacement, so that associativity errors are consistent. This very
# same strategy is also applied in clusters/algorithms/Stepper
p, _ = offset_from_centre(d, indices)
indices = sorted(indices,
key=lambda i: -np.inf if i - p == 0 else (i - p))
for i in indices:
try:
md = sregistry.get('mds', (xd, i))
except KeyError:
name = sregistry.make_name(prefix='sb')
v = ModuloDimension(name, xd, i, size)
md = sregistry.setdefault('mds', (xd, i), v)
self.index_mapper[i] = md
self.sub_iterators[d.root].append(md)
# Initialize the buffer metadata. Depending on whether it's multi-level
# buffering (e.g., double buffering) or first-level, we need to perform
# different actions. Multi-level is trivial, because it essentially
# inherits metadata from the previous buffering level
if self.multi_buffering:
self.__init_multi_buffering__()
else:
assert len(indices) == 1
self.index_mapper[indices[0]] = 0
self.__init_firstlevel_buffering__(async_degree, sregistry)

# Track the SubDimensions used to index into `function`
for e in accessv.mapper:
Expand Down Expand Up @@ -444,6 +401,11 @@ def __init__(self, function, dim, d, accessv, cache, options, sregistry):
for i in d0._defines:
self.itintervals_mapper.setdefault(i, (interval.relaxed, (), Forward))

# The buffer dimensions
dims = list(function.dimensions)
assert dim in function.dimensions
dims[dims.index(dim)] = self.xd

# Finally create the actual buffer
if function in cache:
self.buffer = cache[function]
Expand All @@ -462,6 +424,85 @@ def __init__(self, function, dim, d, accessv, cache, options, sregistry):
except TypeError:
self.buffer = cache[function] = Array(**kwargs)

def __init_multi_buffering__(self):
    """
    Initialize the buffering metadata for the multi-level (e.g., double)
    buffering case, in which the indexing info is inherited from the
    previous buffering level rather than computed from scratch.
    """
    try:
        expr, = self.accessv.exprs
    except ValueError:
        # Multi-level buffering is only taken for memcpy-like accesses,
        # which carry a single expression, so this should be unreachable
        # NOTE(review): `assert False` is stripped under `-O`; consider
        # raising an explicit internal error instead
        assert False

    lhs, rhs = expr.args

    # Inherit the contraction Dimension from the previous buffering level,
    # i.e. the Dimension the lhs function is already indexed with
    self.xd = lhs.function.indices[self.dim]

    idx0 = lhs.indices[self.dim]
    idx1 = rhs.indices[self.dim]

    if self.is_read:
        if is_integer(idx0) or isinstance(idx0, ModuloDimension):
            # This is just for aesthetics of the generated code
            self.index_mapper[idx1] = 0
        else:
            self.index_mapper[idx1] = idx1
    else:
        # Write case: map the written index to the read-side index
        self.index_mapper[idx0] = idx1

def __init_firstlevel_buffering__(self, async_degree, sregistry):
    """
    Initialize the buffering metadata for first-level buffering: compute
    the buffer size along the contracting Dimension, create (or fetch from
    `sregistry`) the contraction Dimension `xd`, and create the
    ModuloDimensions used to step through the buffer.
    """
    d = self.d
    dim = self.dim
    function = self.function

    # The span of the accessed indices along `dim` determines the buffer size
    indices = filter_ordered(i.indices[dim] for i in self.accessv.accesses)
    slots = [i.subs({dim: 0, dim.spacing: 1}) for i in indices]

    try:
        size = max(slots) - min(slots) + 1
    except TypeError:
        # E.g., special case `slots=[-1 + time/factor, 2 + time/factor]`
        # Resort to the fast vector-based comparison machinery (rather than
        # the slower sympy.simplify)
        slots = [Vector(i) for i in slots]
        size = int((vmax(*slots) - vmin(*slots) + 1)[0])

    # A user-provided async degree may enlarge (never shrink) the buffer
    if async_degree is not None:
        if async_degree < size:
            warning("Ignoring provided asynchronous degree as it'd be "
                    "too small for the required buffer (provided %d, "
                    "but need at least %d for `%s`)"
                    % (async_degree, size, function.name))
        else:
            size = async_degree

    # Create `xd` -- a contraction Dimension for `dim`
    try:
        xd = sregistry.get('xds', (dim, size))
    except KeyError:
        name = sregistry.make_name(prefix='db')
        v = CustomDimension(name, 0, size-1, size, dim)
        xd = sregistry.setdefault('xds', (dim, size), v)
    self.xd = xd

    # Create the ModuloDimensions to step through the buffer
    if size > 1:
        # Note: indices are sorted so that the semantic order (sb0, sb1, sb2)
        # follows SymPy's index ordering (time, time-1, time+1) after modulo
        # replacement, so that associativity errors are consistent. This very
        # same strategy is also applied in clusters/algorithms/Stepper
        p, _ = offset_from_centre(d, indices)
        indices = sorted(indices,
                         key=lambda i: -np.inf if i - p == 0 else (i - p))
        for i in indices:
            try:
                md = sregistry.get('mds', (xd, i))
            except KeyError:
                name = sregistry.make_name(prefix='sb')
                v = ModuloDimension(name, xd, i, size)
                md = sregistry.setdefault('mds', (xd, i), v)
            self.index_mapper[i] = md
            self.sub_iterators[d.root].append(md)
    else:
        # Size-1 buffer: a single index, no stepping needed
        assert len(indices) == 1
        self.index_mapper[indices[0]] = 0

def __repr__(self):
    # E.g., `Buffer[ub,<db0>]` -- the backing Array name plus the
    # contraction Dimension used to index into it
    return "Buffer[%s,<%s>]" % (self.buffer.name, self.xd)

Expand Down Expand Up @@ -497,6 +538,13 @@ def is_writeonly(self):
def has_uniform_subdims(self):
return self.subdims_mapper is not None

@property
def multi_buffering(self):
    """
    True if double-buffering or more, False otherwise.

    This is the case when every access expression is memcpy-like, i.e. a
    plain Indexed-to-Indexed copy involving an Array (see `is_memcpy`).
    """
    return all(is_memcpy(e) for e in self.accessv.exprs)

@cached_property
def indexed(self):
return self.buffer.indexed
Expand All @@ -517,7 +565,7 @@ def writeto(self):
# in principle this could be accessed through a stencil
interval = Interval(i.dim, -h.left, h.right, i.stamp)
except KeyError:
assert d is self.xd
assert d in self.xd._defines
interval, si, direction = Interval(d), (), Forward
intervals.append(interval)
sub_iterators[d] = si
Expand Down Expand Up @@ -550,6 +598,8 @@ def written(self):
sub_iterators[d] = si + as_tuple(self.sub_iterators[d])
directions[d] = direction

directions[d.root] = direction

relations = (tuple(i.dim for i in intervals),)
intervals = IntervalGroup(intervals, relations=relations)

Expand All @@ -572,8 +622,8 @@ def readfrom(self):
@cached_property
def lastidx(self):
"""
A 2-tuple of indices representing, respectively, the "last" write to the
buffer and the "last" read from the buffered Function. For example,
A 2-tuple of indices representing, respectively, the *last* write to the
buffer and the *last* read from the buffered Function. For example,
`(sb1, time+1)` in the case of a forward-propagating `time` Dimension.
"""
try:
Expand Down
8 changes: 7 additions & 1 deletion devito/passes/clusters/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ def callback(self, clusters, prefix):
# No iteration space to be lifted from
return clusters

hope_invariant = prefix[-1].dim._defines
dim = prefix[-1].dim
hope_invariant = dim._defines
outer = set().union(*[i.dim._defines for i in prefix[:-1]])

lifted = []
Expand All @@ -43,6 +44,11 @@ def callback(self, clusters, prefix):
processed.append(c)
continue

# Synchronization operations prevent lifting
if c.syncs.get(dim):
processed.append(c)
continue

# Is `c` a real candidate -- is there at least one invariant Dimension?
if any(d._defines & hope_invariant for d in c.used_dimensions):
processed.append(c)
Expand Down
14 changes: 13 additions & 1 deletion devito/passes/clusters/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from devito.symbolics import uxreplace
from devito.types import Symbol, Wildcard

__all__ = ['makeit_ssa', 'is_memcpy']


def makeit_ssa(exprs):
Expand Down Expand Up @@ -36,3 +36,15 @@ def makeit_ssa(exprs):
else:
processed.append(e.func(e.lhs, rhs))
return processed


def is_memcpy(expr):
    """
    True if `expr` implements a memcpy involving an Array, False otherwise.
    """
    lhs, rhs = expr.args

    # A memcpy-like expression copies between two Indexed objects, at least
    # one of which is backed by an Array
    if lhs.is_Indexed and rhs.is_Indexed:
        return lhs.function.is_Array or rhs.function.is_Array

    return False
4 changes: 4 additions & 0 deletions devito/types/dimension.py
Original file line number Diff line number Diff line change
Expand Up @@ -1251,6 +1251,10 @@ def __init_finalize__(self, name, symbolic_min=None, symbolic_max=None,
def is_Derived(self):
return self._parent is not None

@property
def is_NonlinearDerived(self):
    """
    True if this is a derived Dimension whose parent is itself
    nonlinear-derived, False otherwise.
    """
    # Non-derived Dimensions can never be nonlinear-derived
    if not self.is_Derived:
        return False
    return self.parent.is_NonlinearDerived

@property
def parent(self):
return self._parent
Expand Down
Loading

0 comments on commit e6cd0b0

Please sign in to comment.