compiler: Rework multi-level buffering #2225
Changes from all commits: 0feae10, c2c48d0, ae68e5e, 77a283e, 59b57c8, 82e47f9
File 1:
@@ -66,9 +66,16 @@ def __repr__(self):

     @property
     def imask(self):
-        ret = [self.handle.indices[d] if d.root in self.lock.locked_dimensions else FULL
-               for d in self.target.dimensions]
-        return IMask(*ret, getters=self.target.dimensions, function=self.function,
+        ret = []
+        for d in self.target.dimensions:
+            if d.root in self.lock.locked_dimensions:
+                ret.append(self.handle.indices[d])
+            else:
+                ret.append(FULL)
+
+        return IMask(*ret,
+                     getters=self.target.dimensions,
+                     function=self.function,
                      findex=self.findex)
@@ -81,9 +88,19 @@ def __repr__(self):

     @property
     def imask(self):
-        ret = [(self.tindex, self.size) if d.root is self.dim.root else FULL
-               for d in self.target.dimensions]
-        return IMask(*ret, getters=self.target.dimensions, function=self.function,
+        ret = []
+        for d in self.target.dimensions:
+            if d.root is self.dim.root:
+                if self.target.is_regular:
+                    ret.append((self.tindex, self.size))
+                else:
+                    ret.append((0, 1))
+            else:
+                ret.append(FULL)
+
+        return IMask(*ret,
+                     getters=self.target.dimensions,
+                     function=self.function,
                      findex=self.findex)

Review comment on `if self.target.is_regular:`: added this; code path exercised via PRO.
File 2:
@@ -4,7 +4,8 @@

 from devito.ir import (Forward, GuardBoundNext, Queue, Vector, WaitLock, WithLock,
                        FetchUpdate, PrefetchUpdate, ReleaseLock, normalize_syncs)
-from devito.symbolics import uxreplace
+from devito.passes.clusters.utils import is_memcpy
+from devito.symbolics import IntDiv, uxreplace
 from devito.tools import OrderedSet, is_integer, timed_pass
 from devito.types import CustomDimension, Lock
@@ -125,7 +126,7 @@ def callback(self, clusters, prefix):
             assert lock.size == 1
             indices = [0]

-        if is_memcpy(c0):
+        if wraps_memcpy(c0):
             e = c0.exprs[0]
             function = e.lhs.function
             findex = e.lhs.indices[d]
@@ -177,7 +178,7 @@ def callback(self, clusters, prefix):
         for c in clusters:
             dims = self.key(c)
             if d._defines & dims:
-                if is_memcpy(c):
+                if wraps_memcpy(c):
                     # Case 1A (special case, leading to more efficient streaming)
                     self._actions_from_init(c, prefix, actions)
                 else:
@@ -186,7 +187,7 @@ def callback(self, clusters, prefix):

         # Case 2
         else:
-            mapper = OrderedDict([(c, is_memcpy(c)) for c in clusters
+            mapper = OrderedDict([(c, wraps_memcpy(c)) for c in clusters
                                   if d in self.key(c)])

             # Case 2A (special case, leading to more efficient streaming)
@@ -257,7 +258,7 @@ def _actions_from_update_memcpy(self, cluster, clusters, prefix, actions):

         # If fetching into e.g. `ub[sb1]` we'll need to prefetch into e.g. `ub[sb0]`
         tindex0 = e.lhs.indices[d]
-        if is_integer(tindex0):
+        if is_integer(tindex0) or isinstance(tindex0, IntDiv):
             tindex = tindex0
         else:
             assert tindex0.is_Modulo

Review comment on `if is_integer(tindex0) or isinstance(tindex0, IntDiv):`: Maybe add …
Reply: for that you end up in the …; what you see here isn't stuff that the user has written, but rather by-product expressions of other passes such as buffering.
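The exchange above refers to indices such as `time / factor`, which arrive here as by-products of earlier passes (subsampling, buffering) rather than user-written code. Below is a minimal sketch, using plain SymPy's `floor` in place of Devito's `IntDiv` and a made-up `factor` of 4, of why such an index can be accepted alongside plain integers: for any concrete `time` it already evaluates to a fixed buffer slot, unlike a ModuloDimension, which still needs rewriting.

```python
# Hedged sketch: floor(time/4) mimics Devito's IntDiv(time, factor)
from sympy import Symbol, floor

time = Symbol('time', integer=True)
tindex0 = floor(time / 4)

# Each concrete `time` maps to a well-defined slot index
for t in (0, 3, 4, 9):
    print(t, tindex0.subs(time, t))   # -> 0, 0, 1, 2
```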
@@ -321,16 +322,8 @@ def __init__(self, drop=False, syncs=None, insert=None):
         self.insert = insert or []


-def is_memcpy(cluster):
-    """
-    True if `cluster` emulates a memcpy involving an Array, False otherwise.
-    """
+def wraps_memcpy(cluster):
     if len(cluster.exprs) != 1:
         return False

-    a, b = cluster.exprs[0].args
-
-    if not (a.is_Indexed and b.is_Indexed):
-        return False
-
-    return a.function.is_Array or b.function.is_Array
+    return is_memcpy(cluster.exprs[0])
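The predicate that `wraps_memcpy` now delegates to lives in `devito.passes.clusters.utils` and operates on a single expression rather than a whole cluster. Judging from the lines removed above, it presumably looks roughly like the following sketch (reconstructed from the old code, not a verbatim copy of the new module):

```python
def is_memcpy(expr):
    """
    True if `expr` emulates a memcpy involving an Array, False otherwise.
    """
    # A memcpy-like expression copies one Indexed into another...
    a, b = expr.args
    if not (a.is_Indexed and b.is_Indexed):
        return False

    # ...and at least one side is a compiler-generated Array (e.g. a buffer)
    return a.function.is_Array or b.function.is_Array
```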
File 3:
@@ -9,6 +9,7 @@
                         lower_exprs, vmax, vmin)
 from devito.exceptions import InvalidOperator
 from devito.logger import warning
+from devito.passes.clusters.utils import is_memcpy
 from devito.symbolics import IntDiv, retrieve_function_carriers, uxreplace
 from devito.tools import (Bunch, DefaultOrderedDict, Stamp, as_tuple,
                           filter_ordered, flatten, is_integer, timed_pass)
@@ -175,6 +176,12 @@ def callback(self, clusters, prefix, cache=None):
             if b.size == 1 and not init_onread(b.function):
                 continue

+            # Special case: avoid initialization in the case of double
+            # (or multiple levels of) buffering because it will have been
+            # already performed
+            if b.size > 1 and b.multi_buffering:
+                continue
+
             dims = b.function.dimensions
             lhs = b.indexed[dims]._subs(dim, b.firstidx.b)
             rhs = b.function[dims]._subs(dim, b.firstidx.f)

Review comment on `if b.size > 1 and b.multi_buffering:`: code path exercised via PRO.
@@ -347,64 +354,14 @@ def __init__(self, function, dim, d, accessv, cache, options, sregistry):
         self.sub_iterators = defaultdict(list)
         self.subdims_mapper = DefaultOrderedDict(set)

-        # Create the necessary ModuloDimensions for indexing into the buffer
-        # E.g., `u[time,x] + u[time+1,x] -> `ub[sb0,x] + ub[sb1,x]`, where `sb0`
-        # and `sb1` are ModuloDimensions starting at `time` and `time+1` respectively
-        dims = list(function.dimensions)
-        assert dim in function.dimensions
-
-        # Determine the buffer size, and therefore the span of the ModuloDimension,
-        # along the contracting Dimension `d`
-        indices = filter_ordered(i.indices[dim] for i in accessv.accesses)
-        slots = [i.subs({dim: 0, dim.spacing: 1}) for i in indices]
-        try:
-            size = max(slots) - min(slots) + 1
-        except TypeError:
-            # E.g., special case `slots=[-1 + time/factor, 2 + time/factor]`
-            # Resort to the fast vector-based comparison machinery (rather than
-            # the slower sympy.simplify)
-            slots = [Vector(i) for i in slots]
-            size = int((vmax(*slots) - vmin(*slots) + 1)[0])
-
-        if async_degree is not None:
-            if async_degree < size:
-                warning("Ignoring provided asynchronous degree as it'd be "
-                        "too small for the required buffer (provided %d, "
-                        "but need at least %d for `%s`)"
-                        % (async_degree, size, function.name))
-            else:
-                size = async_degree
-
-        # Create `xd` -- a contraction Dimension for `dim`
-        try:
-            xd = sregistry.get('xds', (dim, size))
-        except KeyError:
-            name = sregistry.make_name(prefix='db')
-            v = CustomDimension(name, 0, size-1, size, dim)
-            xd = sregistry.setdefault('xds', (dim, size), v)
-        self.xd = dims[dims.index(dim)] = xd
-
-        # Finally create the ModuloDimensions as children of `xd`
-        if size > 1:
-            # Note: indices are sorted so that the semantic order (sb0, sb1, sb2)
-            # follows SymPy's index ordering (time, time-1, time+1) after modulo
-            # replacement, so that associativity errors are consistent. This very
-            # same strategy is also applied in clusters/algorithms/Stepper
-            p, _ = offset_from_centre(d, indices)
-            indices = sorted(indices,
-                             key=lambda i: -np.inf if i - p == 0 else (i - p))
-            for i in indices:
-                try:
-                    md = sregistry.get('mds', (xd, i))
-                except KeyError:
-                    name = sregistry.make_name(prefix='sb')
-                    v = ModuloDimension(name, xd, i, size)
-                    md = sregistry.setdefault('mds', (xd, i), v)
-                self.index_mapper[i] = md
-                self.sub_iterators[d.root].append(md)
-        else:
-            assert len(indices) == 1
-            self.index_mapper[indices[0]] = 0
+        # Initialize the buffer metadata. Depending on whether it's multi-level
+        # buffering (e.g., double buffering) or first-level, we need to perform
+        # different actions. Multi-level is trivial, because it essentially
+        # inherits metadata from the previous buffering level
+        if self.multi_buffering:
+            self.__init_multi_buffering__()
+        else:
+            self.__init_firstlevel_buffering__(async_degree, sregistry)

         # Track the SubDimensions used to index into `function`
         for e in accessv.mapper:

Review comment on the removed buffer-size block: None of this is gone; simply moved inside.
Review comment on `self.__init_multi_buffering__()`: code path exercised via PRO.
@@ -444,6 +401,11 @@ def __init__(self, function, dim, d, accessv, cache, options, sregistry):
             for i in d0._defines:
                 self.itintervals_mapper.setdefault(i, (interval.relaxed, (), Forward))

+        # The buffer dimensions
+        dims = list(function.dimensions)
+        assert dim in function.dimensions
+        dims[dims.index(dim)] = self.xd
+
         # Finally create the actual buffer
         if function in cache:
             self.buffer = cache[function]
@@ -462,6 +424,85 @@ def __init__(self, function, dim, d, accessv, cache, options, sregistry):
         except TypeError:
             self.buffer = cache[function] = Array(**kwargs)

+    def __init_multi_buffering__(self):
+        try:
+            expr, = self.accessv.exprs
+        except ValueError:
+            assert False
+
+        lhs, rhs = expr.args
+
+        self.xd = lhs.function.indices[self.dim]
+
+        idx0 = lhs.indices[self.dim]
+        idx1 = rhs.indices[self.dim]
+
+        if self.is_read:
+            if is_integer(idx0) or isinstance(idx0, ModuloDimension):
+                # This is just for aesthetics of the generated code
+                self.index_mapper[idx1] = 0
+            else:
+                self.index_mapper[idx1] = idx1
+        else:
+            self.index_mapper[idx0] = idx1
+
+    def __init_firstlevel_buffering__(self, async_degree, sregistry):
+        d = self.d
+        dim = self.dim
+        function = self.function
+
+        indices = filter_ordered(i.indices[dim] for i in self.accessv.accesses)
+        slots = [i.subs({dim: 0, dim.spacing: 1}) for i in indices]
+
+        try:
+            size = max(slots) - min(slots) + 1
+        except TypeError:
+            # E.g., special case `slots=[-1 + time/factor, 2 + time/factor]`
+            # Resort to the fast vector-based comparison machinery (rather than
+            # the slower sympy.simplify)
+            slots = [Vector(i) for i in slots]
+            size = int((vmax(*slots) - vmin(*slots) + 1)[0])
+
+        if async_degree is not None:
+            if async_degree < size:
+                warning("Ignoring provided asynchronous degree as it'd be "
+                        "too small for the required buffer (provided %d, "
+                        "but need at least %d for `%s`)"
+                        % (async_degree, size, function.name))
+            else:
+                size = async_degree
+
+        # Create `xd` -- a contraction Dimension for `dim`
+        try:
+            xd = sregistry.get('xds', (dim, size))
+        except KeyError:
+            name = sregistry.make_name(prefix='db')
+            v = CustomDimension(name, 0, size-1, size, dim)
+            xd = sregistry.setdefault('xds', (dim, size), v)
+        self.xd = xd
+
+        # Create the ModuloDimensions to step through the buffer
+        if size > 1:
+            # Note: indices are sorted so that the semantic order (sb0, sb1, sb2)
+            # follows SymPy's index ordering (time, time-1, time+1) after modulo
+            # replacement, so that associativity errors are consistent. This very
+            # same strategy is also applied in clusters/algorithms/Stepper
+            p, _ = offset_from_centre(d, indices)
+            indices = sorted(indices,
+                             key=lambda i: -np.inf if i - p == 0 else (i - p))
+            for i in indices:
+                try:
+                    md = sregistry.get('mds', (xd, i))
+                except KeyError:
+                    name = sregistry.make_name(prefix='sb')
+                    v = ModuloDimension(name, xd, i, size)
+                    md = sregistry.setdefault('mds', (xd, i), v)
+                self.index_mapper[i] = md
+                self.sub_iterators[d.root].append(md)
+        else:
+            assert len(indices) == 1
+            self.index_mapper[indices[0]] = 0
+
     def __repr__(self):
         return "Buffer[%s,<%s>]" % (self.buffer.name, self.xd)

Review comment on `def __init_multi_buffering__(self):`: this is basically the only true new part.
Review comment on `def __init_firstlevel_buffering__(self, async_degree, sregistry):`: already existing, simply refactored inside this method.
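To make the size computation in `__init_firstlevel_buffering__` concrete: the offsets of the accesses along the stepping Dimension determine how many slots the circular buffer needs. A small worked example with plain SymPy, where `h` stands in for `dim.spacing` and the three accesses are made up for illustration:

```python
from sympy import Symbol

time, h = Symbol('time'), Symbol('h')     # h plays the role of time.spacing

# A stencil touching u[time-1], u[time] and u[time+1]
indices = [time - h, time, time + h]
slots = [i.subs({time: 0, h: 1}) for i in indices]   # -> [-1, 0, 1]
size = max(slots) - min(slots) + 1                   # -> 3

# With size > 1, the pass creates ModuloDimensions (sb0, sb1, sb2)
# cycling through the three buffer slots
print(slots, size)
```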
@@ -497,6 +538,13 @@ def is_writeonly(self):
     def has_uniform_subdims(self):
         return self.subdims_mapper is not None

+    @property
+    def multi_buffering(self):
+        """
+        True if double-buffering or more, False otherwise.
+        """
+        return all(is_memcpy(e) for e in self.accessv.exprs)
+
     @cached_property
     def indexed(self):
         return self.buffer.indexed
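In other words, a buffering candidate counts as multi-level (e.g. double) buffering when every expression touching the Function being buffered is already a plain copy involving an Array, i.e. the by-product of a previous buffering level. A rough, Devito-free sketch of the rotation this boils down to (names and shapes are made up for illustration):

```python
import numpy as np

# Hypothetical first-level buffer with 4 time slots...
u = np.arange(4 * 3, dtype=float).reshape(4, 3)
# ...and a second-level (double) buffer with 2 slots
ub = np.zeros((2, 3))

for time in range(4):
    sb0 = time % 2        # ModuloDimension-like slot selector
    ub[sb0] = u[time]     # the only access to `u` is a memcpy -- precisely
                          # the condition the `multi_buffering` property tests
```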
@@ -517,7 +565,7 @@ def writeto(self):
                 # in principle this could be accessed through a stencil
                 interval = Interval(i.dim, -h.left, h.right, i.stamp)
             except KeyError:
-                assert d is self.xd
+                assert d in self.xd._defines
                 interval, si, direction = Interval(d), (), Forward
             intervals.append(interval)
             sub_iterators[d] = si
@@ -550,6 +598,8 @@ def written(self):
             sub_iterators[d] = si + as_tuple(self.sub_iterators[d])
             directions[d] = direction

+            directions[d.root] = direction
+
         relations = (tuple(i.dim for i in intervals),)
         intervals = IntervalGroup(intervals, relations=relations)
@@ -572,8 +622,8 @@ def readfrom(self):
     @cached_property
     def lastidx(self):
         """
-        A 2-tuple of indices representing, respectively, the "last" write to the
-        buffer and the "last" read from the buffered Function. For example,
+        A 2-tuple of indices representing, respectively, the *last* write to the
+        buffer and the *last* read from the buffered Function. For example,
         `(sb1, time+1)` in the case of a forward-propagating `time` Dimension.
         """
         try:
Review comment: mere refactoring for code homogeneity (see below...)