Skip to content

Commit

Permalink
compiler: make lower_expr apply the mapper to indices in case functio…
Browse files Browse the repository at this point in the history
…n only appear as index
  • Loading branch information
mloubout committed Nov 1, 2021
1 parent 2d76a50 commit 24f4e0f
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 1 deletion.
4 changes: 4 additions & 0 deletions .github/workflows/pytest-gpu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,10 @@ jobs:
echo "PATH=$PATH" >> $GITHUB_ENV
echo "LD_LIBRARY_PATH=$LD_LIBRARY_PATH" >> $GITHUB_ENV
- name: Check pip
run: |
type -a pip3
- name: Install dependencies
run: |
pip3 install --upgrade pip
Expand Down
1 change: 1 addition & 0 deletions devito/ir/equations/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ def lower_exprs(expressions, **kwargs):
mapper.update(dimension_map)
# Add the user-supplied substitutions
mapper.update(subs)
# Apply mapper to expression
processed.append(uxreplace(expr, mapper))

if isinstance(expressions, Iterable):
Expand Down
2 changes: 1 addition & 1 deletion devito/symbolics/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def _uxreplace(expr, rule):
if not isinstance(v, dict):
return v, True
args, eargs = split(expr.args, lambda i: i in v)
args = [v[i] for i in args if v[i] is not None]
args = [_uxreplace(v[i], rule)[0] for i in args if v[i] is not None]
changed = True
else:
args, eargs = [], expr.args
Expand Down
26 changes: 26 additions & 0 deletions tests/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,32 @@ def test_nested_lowering(self):
assert np.all(u0.data[:2, :2] == 1) and np.all(u0.data[1:3, 1:3] == 1)
assert np.all(u0.data[2:3, 3] == 2) and np.all(u0.data[3, 2:3] == 2)

def test_nested_lowering_indexify(self):
"""
Tests that nested function are lowered if only used as index.
"""
grid = Grid(shape=(4, 4), dtype=np.int32)
x, y = grid.dimensions

u0 = Function(name="u0", grid=grid)
u1 = Function(name="u1", grid=grid, dtype=np.int32)

u0.data[:, :] = 2
u1.data[:, :] = 1

# Function as index only
eq0 = Eq(u0._subs(x, u1), 2*u0)
# Function as part of expression as index only
eq1 = Eq(u0._subs(x, u1 + 1), 4*u0)

op0 = Operator(eq0)
op0.apply()
op1 = Operator(eq1)
op1.apply()
assert np.all(np.all(u0.data[i, :] == 2) for i in [0, 3])
assert np.all(u0.data[1, :] == 4)
assert np.all(u0.data[2, :] == 8)


class TestArithmetic(object):

Expand Down

0 comments on commit 24f4e0f

Please sign in to comment.