Skip to content

Commit

Permalink
api: fix indexification of staggered functions after dimension subs
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed Sep 7, 2023
1 parent 9cb54fc commit 9b7282a
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 6 deletions.
4 changes: 2 additions & 2 deletions devito/builtins/initializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@ def initialize_function(function, data, nbl, mapper=None, mode='constant',
raise NotImplementedError("TimeFunctions are not currently supported.")

if nbl == 0:
for f in functions:
for f, data in zip(functions, datas):
if isinstance(data, dv.Function):
f.data[:] = data.data[:]
else:
Expand All @@ -382,7 +382,7 @@ def initialize_function(function, data, nbl, mapper=None, mode='constant',

assert len(lhss) == len(rhss) == len(optionss)

name = name or 'pad_%s' % '_'.join(f.name for f in functions)
name = name or 'initialize_%s' % '_'.join(f.name for f in functions)
assign(lhss, rhss, options=optionss, name=name, **kwargs)

if pad_halo:
Expand Down
11 changes: 9 additions & 2 deletions devito/types/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1245,8 +1245,15 @@ def indexify(self, indices=None, subs=None):
subs = [{**{d.spacing: 1, -d.spacing: -1}, **subs} for d in self.dimensions]

# Indices after substitutions
indices = [sympy.sympify(a.subs(d, d - o).xreplace(s)) for a, d, o, s in
zip(self.args, self.dimensions, self.origin, subs)]
indices = []
for a, d, o, s in zip(self.args, self.dimensions, self.origin, subs):
if d in a.args:
# Shift by origin d -> d - o.
indices.append(sympy.sympify(a.subs(d, d - o).xreplace(s)))
else:
# Dimension has been removed, e.g. u[10], plain shift by origin
indices.append(sympy.sympify(a - o).xreplace(s))

indices = [i.xreplace({k: sympy.Integer(k) for k in i.atoms(sympy.Float)})
for i in indices]

Expand Down
4 changes: 2 additions & 2 deletions tests/test_builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,8 +333,8 @@ def test_batching(self):
a = np.arange(16).reshape((4, 4))

f = Function(name='f', grid=grid, dtype=np.int32)
g = Function(name='g', grid=grid, dtype=np.int32)
h = Function(name='h', grid=grid, dtype=np.int32)
g = Function(name='g', grid=grid, dtype=np.float32)
h = Function(name='h', grid=grid, dtype=np.float64)

initialize_function([f, g, h], [a, a, a], 4, mode='reflect')

Expand Down
11 changes: 11 additions & 0 deletions tests/test_symbolics.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,17 @@ def test_indexed():
assert ub.indexed.free_symbols == {ub.indexed}


def test_indexed_staggered():
grid = Grid(shape=(10, 10))
x, y = grid.dimensions
hx, hy = x.spacing, y.spacing

u = Function(name='u', grid=grid, staggered=(x, y))
u0 = u.subs({x: 1, y: 2})
assert u0.indices == (1 + hx / 2, 2 + hy / 2)
assert u0.indexify().indices == (1, 2)


def test_bundle():
grid = Grid(shape=(4, 4))

Expand Down

0 comments on commit 9b7282a

Please sign in to comment.