Skip to content

Commit

Permalink
compiler: make lower_expr apply the mapper to indices in case functio…
Browse files Browse the repository at this point in the history
…n only appear as index
  • Loading branch information
mloubout committed Nov 1, 2021
1 parent 2d76a50 commit bb9b307
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 3 deletions.
2 changes: 1 addition & 1 deletion devito/ir/equations/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,14 +152,14 @@ def lower_exprs(expressions, **kwargs):
# Apply substitutions, if necessary
if dimension_map:
indices = [j.xreplace(dimension_map) for j in indices]

mapper[i] = f.indexed[indices]

# Add dimensions map to the mapper in case dimensions are used
# as an expression, i.e. Eq(u, x, subdomain=xleft)
mapper.update(dimension_map)
# Add the user-supplied substitutions
mapper.update(subs)
# Apply mapper to expression
processed.append(uxreplace(expr, mapper))

if isinstance(expressions, Iterable):
Expand Down
7 changes: 5 additions & 2 deletions devito/types/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1076,13 +1076,16 @@ def indexify(self, indices=None, lshift=False, subs=None):

# Substitution for each index (spacing only used in own dimension)
subs = subs or {}
subs = [{**{d.spacing: 1, -d.spacing: -1}, **subs} for d in self.dimensions]
subsi = [{**{d.spacing: 1, -d.spacing: -1}, **subs} for d in self.dimensions]

# Add halo shift
shift = self._size_nodomain.left if lshift else tuple([0]*len(self.dimensions))
# Indexify indices in case it's indexed by functions
indices = [i.indexify(lshift=lshift, subs=subs) if i.is_Function else i
for i in self.args]
# Indices after substitutions
indices = [sympy.sympify((a - o + f).xreplace(s)) for a, o, f, s in
zip(self.args, self.origin, shift, subs)]
zip(indices, self.origin, shift, subsi)]
indices = [i.xreplace({k: sympy.Integer(k) for k in i.atoms(sympy.Float)})
for i in indices]
return self.indexed[indices]
Expand Down
20 changes: 20 additions & 0 deletions tests/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,26 @@ def test_nested_lowering(self):
assert np.all(u0.data[:2, :2] == 1) and np.all(u0.data[1:3, 1:3] == 1)
assert np.all(u0.data[2:3, 3] == 2) and np.all(u0.data[3, 2:3] == 2)

def test_nested_lowering_indexify(self):
"""
Tests that nested function are lowered if only used as index.
"""
grid = Grid(shape=(4, 4), dtype=np.int32)
x, y = grid.dimensions

u0 = Function(name="u0", grid=grid)
u1 = Function(name="u1", grid=grid, dtype=np.int32)

u0.data[:, :] = 2
u1.data[:, :] = 1

eq0 = Eq(u0._subs(x, u1), 2*u0)

op0 = Operator(eq0)
op0.apply()
assert np.all(np.all(u0.data[i, :] == 2) for i in [0, 2, 3])
assert np.all(u0.data[1, :] == 4)


class TestArithmetic(object):

Expand Down

0 comments on commit bb9b307

Please sign in to comment.