From df588c57a8c5ea89c4ba26f6cf9125cdb6075b19 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Wed, 14 Feb 2024 15:12:26 +0000 Subject: [PATCH] compiler: Revamp linearize --- devito/passes/iet/linearization.py | 362 +++++++++++++++++++---------- tests/test_linearize.py | 14 +- 2 files changed, 243 insertions(+), 133 deletions(-) diff --git a/devito/passes/iet/linearization.py b/devito/passes/iet/linearization.py index 5ec2a41a9d..89bf82f863 100644 --- a/devito/passes/iet/linearization.py +++ b/devito/passes/iet/linearization.py @@ -12,12 +12,10 @@ from devito.passes.iet.engine import iet_pass from devito.passes.iet.parpragma import PragmaIteration from devito.symbolics import DefFunction, MacroArgument, ccode -from devito.tools import Bunch, filter_ordered, prod +from devito.tools import Bunch, filter_ordered, flatten, prod from devito.types import Array, Bundle, Symbol, FIndexed, Indexed, Wildcard -from devito.types.basic import IndexedData from devito.types.dense import DiscreteFunction - __all__ = ['linearize'] @@ -27,47 +25,49 @@ def linearize(graph, **kwargs): access function, such as `a[i, j] -> a[i*n + j]`. The row-major format of the underlying Function objects is honored. """ + sregistry = kwargs.get('sregistry') options = kwargs.get('options', {}) - lmode = kwargs.pop('lmode', options.get('linearize')) - if not lmode: + + mode = options.get('linearize') + maybe_callback = kwargs.pop('callback', mode) + + if not maybe_callback: return + elif callable(maybe_callback): + key = lambda f: maybe_callback(f) and f.ndim > 1 and not f._mem_stack + else: + key = lambda f: f.is_AbstractFunction and f.ndim > 1 and not f._mem_stack + + if options['index-mode'] == 'int32': + dtype = np.int32 + else: + dtype = np.int64 - # All information that need to be propagated across callables - tracker = Bunch( - sizes={}, - strides={}, - undef=defaultdict(set), - subs={}, - headers={} - ) + # NOTE: Even if `mode=False`, `key` may still want to enforce linearization + # of some Functions, so it takes precedence and we then attempt to linearize + if not mode or isinstance(mode, bool): + mode = 'basic' - linearization(graph, lmode=lmode, tracker=tracker, **kwargs) + tracker = Tracker(mode, dtype, sregistry) + + linearization(graph, key=key, tracker=tracker, **kwargs) # Sanity check assert not tracker.undef @iet_pass -def linearization(iet, lmode=None, tracker=None, **kwargs): +def linearization(iet, key=None, tracker=None, **kwargs): """ Carry out the actual work of `linearize`. """ - # `lmode` may be a callback describing what Function types, and under what - # conditions, should linearization be applied - if callable(lmode): - key0 = lambda f: lmode(f) and f.ndim > 1 - else: - # Default - key0 = lambda f: f.is_AbstractFunction and f.ndim > 1 - key = lambda f: key0(f) and not f._mem_stack - - iet = linearize_accesses(iet, key, tracker, **kwargs) + iet = linearize_accesses(iet, key, tracker) iet = linearize_pointers(iet, key) iet = linearize_transfers(iet, **kwargs) iet = linearize_clauses(iet, **kwargs) # Postprocess headers - headers = [(ccode(define), ccode(expr)) for define, expr in tracker.headers.values()] + headers = sorted((ccode(i), ccode(e)) for i, e in tracker.headers.values()) return iet, {'headers': headers} @@ -88,60 +88,153 @@ def key1(f, d): return False -def linearize_accesses(iet, key0, tracker=None, sregistry=None, options=None, - **kwargs): +class Size(Symbol): + """ - Turn Indexeds into FIndexeds and create the necessary access Macros. + Symbol representing the size of a Function dimension. """ - if options['index-mode'] == 'int32': - dtype = np.int32 - else: - dtype = np.int64 - # 1) What `iet` *needs* - indexeds = FindSymbols('indexeds').visit(iet) - needs = filter_ordered(i.function for i in indexeds if key0(i.function)) - needs = sorted(needs, key=lambda f: len(f.dimensions), reverse=True) + pass + - # Update unique sizes table - # E.g. `{(x, 8, grid) -> x_fsz0}` - for f in needs: - # NOTE: the outermost dimension is unnecessary +class Stride(Symbol): + + """ + Symbol representing the stride of a Function dimension. + + The stride is the the number of elements to skip over in the array's linear + memory layout to move one step along that dimension. For example, consider + a 2D array with rows and columns. The stride for the row dimension tells + you how many elements you need to move forward in memory to get from one + row to the next. Similarly, the stride for the column dimension tells you + the number of elements you need to skip to move from one column to the next + in memory. The stride, therefore, is the product of the size of the + dimensions that come after the dimension in question. + """ + + pass + + +class Dist(Symbol): + + """ + Symbol representing the distance between a reference point and another point. + """ + + pass + + +class Tracker(object): + + def __init__(self, mode, dtype, sregistry): + self.mode = mode + self.dtype = dtype + self.sregistry = sregistry + + self.sizes = {} + self.strides = {} + self.strides_dynamic = {} # Strides varying regularly across iterations + self.dists = {} + self.undef = defaultdict(set) + self.headers = {} + + def add(self, f): + # Update unique sizes table for d in f.dimensions[1:]: k = key1(f, d) - if k and k not in tracker.sizes: - name = sregistry.make_name(prefix='%s_fsz' % d.name) - fsz = Symbol(name=name, dtype=dtype, is_const=True) - tracker.sizes[k] = fsz - - # Update unique strides table - # E.g., `{x_stride0 -> (y_fsz0, z_fsz0)}` - mapper = defaultdict(dict) - for f in needs: + if not k or k in self.sizes: + continue + name = self.sregistry.make_name(prefix='%s_fsz' % d.name) + self.sizes[k] = Size(name=name, dtype=self.dtype, is_const=True) + + # Update unique strides table for n, d in enumerate(f.dimensions[1:], 1): try: - k = tuple(tracker.sizes[key1(f, d1)] for d1 in f.dimensions[n:]) + k = tuple(self.get_size(f, d1) for d1 in f.dimensions[n:]) except KeyError: continue - if k not in tracker.strides: - name = sregistry.make_name(prefix='%s_stride' % d.name) - stride = Symbol(name=name, dtype=dtype, is_const=True) - tracker.strides[k] = stride - mapper[f][d] = tracker.strides[k] + if k in self.strides: + continue + name = self.sregistry.make_name(prefix='%s_stride' % d.name) + self.strides[k] = Stride(name=name, dtype=self.dtype, is_const=True) + + def update(self, functions): + for f in functions: + self.add(f) + + def add_strides_dynamic(self, f, strides): + assert not f.is_regular + assert len(strides) == f.ndim - 1 + self.strides_dynamic[f] = tuple(strides) + + def add_header(self, b, define, expr): + # NOTE: the input must be an IndexedBase and not a Function because + # we might require either the `.base` or the `.dmap`, or perhaps both + v = Bunch(define=define, expr=expr) + self.headers[b] = v + return v + + def needs_header(self, b): + return b.function.ndim > 1 and b not in self.headers + + def get_header(self, b): + return self.headers.get(b) + + def get_size(self, f, d): + return self.sizes[key1(f, d)] + + def get_sizes(self, f): + return tuple(self.get_size(f, d) for d in f.dimensions[1:]) + + def map_strides(self, f): + dims = list(f.dimensions[1:]) + if f.is_regular: + sizes = self.get_sizes(f) + return {d: self.strides[sizes[n:]] for n, d in enumerate(dims)} + elif f in self.strides_dynamic: + return {d: i for d, i in zip(dims, self.strides_dynamic[f])} + else: + return {} - # Update unique access macros - # E.g. `{f(x, y) -> foo}`, `foo(Indexed) -> f[(x)*y_stride0 + (y)]` - for f in needs: - # All the Indexeds demanding an access macro. For a given `f`, more than - # one access macros may be required -- in particular, it will be more - # than one if `f` was optimized by the `split_pointers` pass - v = [i for i in indexeds if i.function is f] + def make_dist(self, v): + try: + dist = self.dists[v] + except KeyError: + name = self.sregistry.make_name(prefix='dst') + dist = self.dists[v] = Dist(name=name, dtype=self.dtype, is_const=True) + return dist + + +def linearize_accesses(iet, key0, tracker=None): + """ + Turn Indexeds into FIndexeds and create the necessary access Macros. + """ + # 1) What `iet` *needs* + indexeds = FindSymbols('indexeds').visit(iet) + needs = filter_ordered(i.function for i in indexeds if key0(i.function)) + needs = sorted(needs, key=lambda f: len(f.dimensions), reverse=True) + + # Update unique sizes and strides + tracker.update(needs) - _generate_macro(f, v, tracker, mapper[f], sregistry) + # Pick the chosen linearization mode + generate_linearization = linearization_registry[tracker.mode] - # Turn Indexeds into FIndexeds + # Generate unique access macros + # E.g. `{f(x, y) -> foo}`, `foo(Indexed) -> f[(x)*y_stride0 + (y)]` + # And linearize Indexeds # E.g. `u[t2, x+8, y+9, z+7] -> uL(t2, x+8, y+9, z+7)` - iet = Uxreplace(tracker.subs).visit(iet) + subs = {} + for i in indexeds: + f = i.function + if f not in needs: + continue + + v = generate_linearization(f, i, tracker) + if v is not None: + subs[i] = v + + iet = Uxreplace(subs).visit(iet) # 2) What `iet` *offers* # E.g. `{x_fsz0 -> u_vec->size[1]}` @@ -150,47 +243,53 @@ def linearize_accesses(iet, key0, tracker=None, sregistry=None, options=None, instances = {} for f in offers: for d in f.dimensions[1:]: - k = key1(f, d) try: - fsz = tracker.sizes[k] + fsz = tracker.get_size(f, d) + except KeyError: + continue + try: val = _generate_fsz(f, d) if val.free_symbols.issubset(f.bound_symbols): - v = DummyExpr(fsz, val, init=True) - else: - v = None - except (AttributeError, KeyError): - v = None - if v is not None: - instances[fsz] = v - - # 3) What we need to construct are `iet`'s needs plus the callee delegations - candidates = set().union(*[v.values() for v in mapper.values()]) + instances[fsz] = DummyExpr(fsz, val, init=True) + except AttributeError: + pass + + # 3) Which symbols (Strides, Dists) does `iet` need? + symbols = flatten(i.free_symbols for i in subs.values()) + candidates = {i for i in symbols if isinstance(i, (Stride, Dist))} for n in FindNodes(Call).visit(iet): + # Also consider the strides needed by the callee candidates.update(tracker.undef.pop(n.name, [])) - # 4) Construct what needs to *and* can be constructed - mapper = dict(zip(tracker.strides.values(), tracker.strides.keys())) - stmts0, stmts1 = [], [] + # 4) What `strides` can indeed be constructed? + mapper = {} + for sizes, stride in tracker.strides.items(): + if stride in candidates: + if set(sizes).issubset(instances): + mapper[stride] = sizes + else: + tracker.undef[iet.name].add(stride) + + # 5) Construct what needs to *and* can be constructed + stmts, stmts1 = [], [] for stride, sizes in mapper.items(): - if stride not in candidates: - continue - try: - stmts0.extend([instances[size] for size in sizes]) - except KeyError: - # `stride` would be needed by `iet`, but we don't have the means - # to define it, hence we delegate to the caller - tracker.undef[iet.name].add(stride) - continue + stmts.extend([instances[size] for size in sizes]) stmts1.append(DummyExpr(stride, prod(sizes), init=True)) + stmts = filter_ordered(stmts) - # 5) Attach `stmts0` and `stmts1` to `iet` - if stmts0: - assert len(stmts1) > 0 - stmts = filter_ordered(stmts0) + [BlankLine] + stmts1 + stmts2 = [] + for v, dist in tracker.dists.items(): + if dist in candidates: + stmts2.append(DummyExpr(dist, v, init=True)) + + for i in [stmts1, stmts2]: + if i and stmts: + stmts.append(BlankLine) + stmts.extend(i) + + if stmts: body = iet.body._rebuild(strides=stmts) iet = iet._rebuild(body=body) - else: - assert len(stmts1) == 0 return iet @@ -219,50 +318,61 @@ def _(f, d): @singledispatch -def _generate_macro(f, indexeds, tracker, strides, sregistry): - pass +def _generate_header_basic(f, tracker): + return + +@_generate_header_basic.register(DiscreteFunction) +@_generate_header_basic.register(Array) +@_generate_header_basic.register(Bundle) +def _(f, i, tracker): + b = i.base + + if not tracker.needs_header(b): + return tracker.get_header(b) -@_generate_macro.register(DiscreteFunction) -@_generate_macro.register(Array) -@_generate_macro.register(Bundle) -def _(f, indexeds, tracker, strides, sregistry): # Generate e.g. `usave[(time)*xi_slc0 + (xi)*yi_slc0 + (yi)]` - assert len(strides) == len(f.dimensions) - 1 + + strides = tracker.map_strides(f) + macroargnames = [d.name for d in f.dimensions] macroargs = [MacroArgument(i) for i in macroargnames] items = [m*strides[d] for m, d in zip(macroargs, f.dimensions[1:])] items.append(MacroArgument(f.dimensions[-1].name)) - headers = tracker.headers - subs = tracker.subs - for i in indexeds: - if i in subs: - continue + pname = tracker.sregistry.make_name(prefix='%sL' % f.name) + define = DefFunction(pname, macroargnames) + expr = Indexed(b, Add(*items, evaluate=False)) - n = len(i.indices) - if n == 1: - continue + return tracker.add_header(b, define, expr) - if i.base not in headers: - pname = sregistry.make_name(prefix='%sL' % f.name) - value = Add(*items[-n:], evaluate=False) - expr = Indexed(IndexedData(i.base, None, f), value) - define = DefFunction(pname, macroargnames[-n:]) +@singledispatch +def _generate_linearization_basic(f, i, tracker): + assert False - headers[i.base] = (define, expr) - else: - define, _ = headers[i.base] - pname = str(define.name) - if len(i.indices) == i.function.ndim: - v = tuple(strides.values())[-n:] - subs[i] = FIndexed.from_indexed(i, pname, strides=v) - else: - # Honour custom indexing - subs[i] = i.base[sum(i.indices)] +@_generate_linearization_basic.register(DiscreteFunction) +@_generate_linearization_basic.register(Array) +@_generate_linearization_basic.register(Bundle) +def _(f, i, tracker): + header = _generate_header_basic(f, i, tracker) + + n = len(i.indices) + + if header and n == i.function.ndim: + pname = header.define.name.value + strides = tuple(tracker.map_strides(f).values())[-n:] + return FIndexed.from_indexed(i, pname, strides=strides) + else: + # Honour custom indexing + return i.base[sum(i.indices)] + + +linearization_registry = { + 'basic': _generate_linearization_basic, +} def linearize_pointers(iet, key): diff --git a/tests/test_linearize.py b/tests/test_linearize.py index 045cb1955a..7f87eecedd 100644 --- a/tests/test_linearize.py +++ b/tests/test_linearize.py @@ -290,7 +290,7 @@ def test_strides_forwarding0(): graph = Graph(foo) graph.efuncs['bar'] = bar - linearize(graph, lmode=True, options={'index-mode': 'int32'}, + linearize(graph, callback=True, options={'index-mode': 'int32'}, sregistry=SymbolRegistry()) # Since `f` is passed via `f.indexed`, we expect the stride exprs to be @@ -320,7 +320,7 @@ def test_strides_forwarding1(): graph = Graph(foo) graph.efuncs['bar'] = bar - linearize(graph, lmode=True, options={'index-mode': 'int32'}, + linearize(graph, callback=True, options={'index-mode': 'int32'}, sregistry=SymbolRegistry()) # Despite `a` is passed via `a.indexed`, and since it's an Array (which @@ -368,7 +368,7 @@ def test_strides_forwarding2(): graph.efuncs['foo0'] = foo0 graph.efuncs['foo1'] = foo1 - linearize(graph, lmode=True, options={'index-mode': 'int32'}, + linearize(graph, callback=True, options={'index-mode': 'int32'}, sregistry=SymbolRegistry()) # Both foo's are expected to define `a`! @@ -408,7 +408,7 @@ def test_strides_forwarding3(): graph = Graph(root) graph.efuncs['bar'] = bar - linearize(graph, lmode=True, options={'index-mode': 'int64'}, + linearize(graph, callback=True, options={'index-mode': 'int64'}, sregistry=SymbolRegistry()) # Both foo's are expected to define `a`! @@ -440,7 +440,7 @@ def test_strides_forwarding4(): graph = Graph(root) graph.efuncs['bar'] = bar - linearize(graph, lmode=True, options={'index-mode': 'int64'}, + linearize(graph, callback=True, options={'index-mode': 'int64'}, sregistry=SymbolRegistry()) root = graph.root @@ -514,7 +514,7 @@ def test_call_retval_indexed(): # Emulate what the compiler would do graph = Graph(foo) - linearize(graph, lmode=True, options={'index-mode': 'int64'}, + linearize(graph, callback=True, options={'index-mode': 'int64'}, sregistry=SymbolRegistry()) foo = graph.root @@ -540,7 +540,7 @@ def test_bundle(): graph = Graph(foo) graph.efuncs['bar'] = bar - linearize(graph, lmode=True, options={'index-mode': 'int64'}, + linearize(graph, callback=True, options={'index-mode': 'int64'}, sregistry=SymbolRegistry()) foo = graph.root