Skip to content

Commit

Permalink
ir: make lower_expr apply the mapper to indices in case function only…
Browse files Browse the repository at this point in the history
… appear as index
  • Loading branch information
mloubout committed Nov 1, 2021
1 parent 2d76a50 commit 8532296
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 0 deletions.
3 changes: 3 additions & 0 deletions devito/ir/equations/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,9 @@ def lower_exprs(expressions, **kwargs):
mapper.update(dimension_map)
# Add the user-supplied substitutions
mapper.update(subs)
# Apply mapper to itself for nested indexing
mapper = {k: uxreplace(v, mapper) for k, v in mapper.items()}
# Apply mapper to expression
processed.append(uxreplace(expr, mapper))

if isinstance(expressions, Iterable):
Expand Down
20 changes: 20 additions & 0 deletions tests/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,26 @@ def test_nested_lowering(self):
assert np.all(u0.data[:2, :2] == 1) and np.all(u0.data[1:3, 1:3] == 1)
assert np.all(u0.data[2:3, 3] == 2) and np.all(u0.data[3, 2:3] == 2)

def test_nested_lowering_indexify(self):
"""
Tests that nested function are lowered if only used as index.
"""
grid = Grid(shape=(4, 4), dtype=np.int32)
x, y = grid.dimensions

u0 = Function(name="u0", grid=grid)
u1 = Function(name="u1", grid=grid, dtype=np.int32)

u0.data[:, :] = 2
u1.data[:, :] = 1

eq0 = Eq(u0._subs(x, u1), 2*u0)

op0 = Operator(eq0)
op0.apply()
assert np.all(np.all(u0.data[i, :] == 2) for i in [0,2,3])
assert np.all(u0.data[1, :] == 4)


class TestArithmetic(object):

Expand Down

0 comments on commit 8532296

Please sign in to comment.