Skip to content

Commit

Permalink
types: added new types to tag metadata to dimensions generated by CIRE
Browse files Browse the repository at this point in the history
  • Loading branch information
Edward Caunt committed Nov 29, 2023
1 parent 21a30b5 commit 42da86c
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 16 deletions.
13 changes: 7 additions & 6 deletions devito/passes/clusters/aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from devito.tools import (Stamp, as_mapper, as_tuple, flatten, frozendict,
is_integer, generator, split, timed_pass)
from devito.types import (Eq, Symbol, Temp, TempArray, TempFunction,
ModuloDimension, CustomDimension, IncrDimension,
CireModuloDimension, CireCustomDimension, IncrDimension,
StencilDimension, Indexed, Hyperplane)
from devito.types.grid import MultiSubDimension

Expand Down Expand Up @@ -762,6 +762,7 @@ def optimize_schedule_rotations(schedule, sregistry):

try:
candidate = k[ridx]
print(candidate, type(candidate))
except IndexError:
# Degenerate alias (a scalar)
processed.extend(g)
Expand All @@ -781,10 +782,10 @@ def optimize_schedule_rotations(schedule, sregistry):
iis = candidate.lower
iib = candidate.upper

ii = ModuloDimension('%sii' % d, ds, iis, incr=iib)
cd = CustomDimension(name='%s%s' % (d, d), symbolic_min=ii, symbolic_max=iib,
symbolic_size=n)
dsi = ModuloDimension('%si' % ds, cd, cd + ds - iis, n)
ii = CireModuloDimension('%sii' % d, ds, iis, incr=iib)
cd = CireCustomDimension(name='%s%s' % (d, d), symbolic_min=ii, symbolic_max=iib,
symbolic_size=n)
dsi = CireModuloDimension('%si' % ds, cd, cd + ds - iis, n)

mapper = OrderedDict()
for i in g:
Expand All @@ -796,7 +797,7 @@ def optimize_schedule_rotations(schedule, sregistry):
md = mapper[v]
except KeyError:
name = sregistry.make_name(prefix='%sr' % d.name)
md = mapper.setdefault(v, ModuloDimension(name, ds, v, n))
md = mapper.setdefault(v, CireModuloDimension(name, ds, v, n))
mds.append(md)
indicess = [indices[:ridx] + [md] + indices[ridx + 1:]
for md, indices in zip(mds, i.indicess)]
Expand Down
22 changes: 15 additions & 7 deletions devito/passes/iet/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,14 @@ def remove_redundant_moddims(iet):


def abridge_dim_names(iet):
def replace_parent(dim, mapper):
# Get args for rebuild
args = tuple(mapper[dim.parent] if a == 'parent'
else getattr(dim, a) for a in dim.__rargs__)
kwargs = {a: getattr(dim, a) for a in dim.__rkwargs__}

return dim._rebuild(*args, **kwargs)

subs = {}
mapper = {}
umapper = {}
Expand All @@ -232,16 +240,16 @@ def abridge_dim_names(iet):

for i in uindices:
if i.parent in mapper:
# Get args for rebuild
args = tuple(mapper[i.parent] if a == 'parent'
else getattr(i, a) for a in i.__rargs__)
kwargs = {a: getattr(i, a) for a in i.__rkwargs__}

umapper[i] = i._rebuild(*args, **kwargs)
umapper[i] = replace_parent(i, mapper)

for it in tree:
cire_dims = [d for d in FindSymbols('dimensions').visit(it) if d.is_Cire]
cire_dims = [d for d in cire_dims if d.parent in {**mapper, **umapper}]
cmapper = {d: replace_parent(d, {**mapper, **umapper}) for d in cire_dims
if d.parent in {**mapper, **umapper}}

for i in it:
subs[i] = Uxreplace({**mapper, **umapper}).visit(i)
subs[i] = Uxreplace({**mapper, **umapper, **cmapper}).visit(i)

if subs:
iet = Transformer(subs, nested=True).visit(iet)
Expand Down
46 changes: 45 additions & 1 deletion devito/types/dimension.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
__all__ = ['Dimension', 'SpaceDimension', 'TimeDimension', 'DefaultDimension',
'CustomDimension', 'SteppingDimension', 'SubDimension', 'ConditionalDimension',
'ModuloDimension', 'IncrDimension', 'BlockDimension', 'StencilDimension',
'Spacing', 'dimensions']
'Spacing', 'dimensions', 'CireCustomDimension', 'CireModuloDimension']


Thickness = namedtuple('Thickness', 'left right')
Expand Down Expand Up @@ -108,6 +108,7 @@ class Dimension(ArgProvider):
is_Modulo = False
is_Incr = False
is_Block = False
is_Cire = False

# Prioritize self's __add__ and __sub__ to construct AffineIndexAccessFunction
_op_priority = sympy.Expr._op_priority + 1.
Expand Down Expand Up @@ -1074,6 +1075,27 @@ def __sub__(self, other):
return super().__sub__(other)


class CireModuloDimension(ModuloDimension):
"""
A ModuloDimension generated by CIRE. Contains metadata to aid shortening of names
in the generated code.
"""

__rkwargs__ = ModuloDimension.__rkwargs__ + ('short_name',)

is_Cire = True

def __init_finalize__(self, name, parent, offset=None, modulo=None, incr=None,
origin=None, short_name=None, **kwargs):
super().__init_finalize__(name, parent, offset=offset, modulo=modulo, incr=incr,
origin=origin, **kwargs)
self._short_name = short_name

@property
def short_name(self):
return self._short_name if self._short_name is None else self.name


class AbstractIncrDimension(DerivedDimension):

"""
Expand Down Expand Up @@ -1380,6 +1402,28 @@ def _arg_check(self, *args):
return


class CireCustomDimension(CustomDimension):
"""
A CustomDimension generated by CIRE. Contains metadata to aid shortening of names
in the generated code.
"""

__rkwargs__ = CustomDimension.__rkwargs__ + ('short_name',)

is_Cire = True

def __init_finalize__(self, name, symbolic_min=None, symbolic_max=None,
symbolic_size=None, parent=None, short_name=None, **kwargs):
super().__init_finalize__(name, symbolic_min=symbolic_min,
symbolic_max=symbolic_max, symbolic_size=symbolic_size,
parent=parent, **kwargs)
self._short_name = short_name

@property
def short_name(self):
return self._short_name if self._short_name is None else self.name


class DynamicDimensionMixin(object):

"""
Expand Down
2 changes: 0 additions & 2 deletions tests/test_dse.py
Original file line number Diff line number Diff line change
Expand Up @@ -718,8 +718,6 @@ def test_full_shape_w_subdims(self, rotate):
op0(time_M=1)
op1(time_M=1, u=u1)
assert np.all(u.data == u1.data)
if rotate:
raise ValueError

@pytest.mark.parametrize('rotate', [False, True])
def test_mixed_shapes(self, rotate):
Expand Down

0 comments on commit 42da86c

Please sign in to comment.