Skip to content

Commit

Permalink
compiler: make lower_expr apply the mapper to indices in case functio…
Browse files Browse the repository at this point in the history
…n only appear as index
  • Loading branch information
mloubout committed Nov 1, 2021
1 parent 2d76a50 commit 1b47f15
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 6 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/pytest-gpu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,10 @@ jobs:
echo "PATH=$PATH" >> $GITHUB_ENV
echo "LD_LIBRARY_PATH=$LD_LIBRARY_PATH" >> $GITHUB_ENV
- name: Check pip
run: |
type -a pip3
- name: Install dependencies
run: |
pip3 install --upgrade pip
Expand Down
3 changes: 2 additions & 1 deletion devito/ir/equations/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def lower_exprs(expressions, **kwargs):
dimension_map = {}

# Handle Functions (typical case)
mapper = {f: f.indexify(lshift=True, subs=dimension_map)
mapper = {f: lower_exprs(f.indexify(subs=dimension_map), **kwargs)
for f in retrieve_functions(expr)}

# Handle Indexeds (from index notation)
Expand All @@ -160,6 +160,7 @@ def lower_exprs(expressions, **kwargs):
mapper.update(dimension_map)
# Add the user-supplied substitutions
mapper.update(subs)
# Apply mapper to expression
processed.append(uxreplace(expr, mapper))

if isinstance(expressions, Iterable):
Expand Down
8 changes: 3 additions & 5 deletions devito/types/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1069,7 +1069,7 @@ def _data_alignment(self):
"""
return default_allocator().guaranteed_alignment

def indexify(self, indices=None, lshift=False, subs=None):
def indexify(self, indices=None, subs=None):
"""Create a types.Indexed from the current object."""
if indices is not None:
return Indexed(self.indexed, *indices)
Expand All @@ -1078,11 +1078,9 @@ def indexify(self, indices=None, lshift=False, subs=None):
subs = subs or {}
subs = [{**{d.spacing: 1, -d.spacing: -1}, **subs} for d in self.dimensions]

# Add halo shift
shift = self._size_nodomain.left if lshift else tuple([0]*len(self.dimensions))
# Indices after substitutions
indices = [sympy.sympify((a - o + f).xreplace(s)) for a, o, f, s in
zip(self.args, self.origin, shift, subs)]
indices = [sympy.sympify((a - o).xreplace(s)) for a, o, s in
zip(self.args, self.origin, subs)]
indices = [i.xreplace({k: sympy.Integer(k) for k in i.atoms(sympy.Float)})
for i in indices]
return self.indexed[indices]
Expand Down
28 changes: 28 additions & 0 deletions tests/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,34 @@ def test_nested_lowering(self):
assert np.all(u0.data[:2, :2] == 1) and np.all(u0.data[1:3, 1:3] == 1)
assert np.all(u0.data[2:3, 3] == 2) and np.all(u0.data[3, 2:3] == 2)

def test_nested_lowering_indexify(self):
"""
Tests that nested function are lowered if only used as index.
"""
grid = Grid(shape=(4, 4), dtype=np.int32)
x, y = grid.dimensions

u0 = Function(name="u0", grid=grid)
u1 = Function(name="u1", grid=grid)
u2 = Function(name="u2", grid=grid)

u0.data[:, :] = 2
u1.data[:, :] = 1
u2.data[:, :] = 1

# Function as index only
eq0 = Eq(u0._subs(x, u1), 2*u0)
# Function as part of expression as index only
eq1 = Eq(u0._subs(x, u1._subs(y, u2) + 1), 4*u0)

op0 = Operator(eq0)
op0.apply()
op1 = Operator(eq1)
op1.apply()
assert np.all(np.all(u0.data[i, :] == 2) for i in [0, 3])
assert np.all(u0.data[1, :] == 4)
assert np.all(u0.data[2, :] == 8)


class TestArithmetic(object):

Expand Down

0 comments on commit 1b47f15

Please sign in to comment.