Skip to content

Commit

Permalink
compiler: make lower_expr apply the mapper to indices in case functio…
Browse files Browse the repository at this point in the history
…n only appear as index
  • Loading branch information
mloubout committed Nov 2, 2021
1 parent 2d76a50 commit 362ec77
Show file tree
Hide file tree
Showing 5 changed files with 388 additions and 455 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tutorials.yml
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install -e .
pip install matplotlib blosc
pip install matplotlib blosc cycler==0.10
- name: Seismic notebooks
run: |
Expand Down
3 changes: 2 additions & 1 deletion devito/ir/equations/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def lower_exprs(expressions, **kwargs):
dimension_map = {}

# Handle Functions (typical case)
mapper = {f: f.indexify(lshift=True, subs=dimension_map)
mapper = {f: lower_exprs(f.indexify(subs=dimension_map), **kwargs)
for f in retrieve_functions(expr)}

# Handle Indexeds (from index notation)
Expand All @@ -160,6 +160,7 @@ def lower_exprs(expressions, **kwargs):
mapper.update(dimension_map)
# Add the user-supplied substitutions
mapper.update(subs)
# Apply mapper to expression
processed.append(uxreplace(expr, mapper))

if isinstance(expressions, Iterable):
Expand Down
8 changes: 3 additions & 5 deletions devito/types/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1069,7 +1069,7 @@ def _data_alignment(self):
"""
return default_allocator().guaranteed_alignment

def indexify(self, indices=None, lshift=False, subs=None):
def indexify(self, indices=None, subs=None):
"""Create a types.Indexed from the current object."""
if indices is not None:
return Indexed(self.indexed, *indices)
Expand All @@ -1078,11 +1078,9 @@ def indexify(self, indices=None, lshift=False, subs=None):
subs = subs or {}
subs = [{**{d.spacing: 1, -d.spacing: -1}, **subs} for d in self.dimensions]

# Add halo shift
shift = self._size_nodomain.left if lshift else tuple([0]*len(self.dimensions))
# Indices after substitutions
indices = [sympy.sympify((a - o + f).xreplace(s)) for a, o, f, s in
zip(self.args, self.origin, shift, subs)]
indices = [sympy.sympify((a - o).xreplace(s)) for a, o, s in
zip(self.args, self.origin, subs)]
indices = [i.xreplace({k: sympy.Integer(k) for k in i.atoms(sympy.Float)})
for i in indices]
return self.indexed[indices]
Expand Down
Loading

0 comments on commit 362ec77

Please sign in to comment.