diff --git a/devito/builtins/initializers.py b/devito/builtins/initializers.py index 0f1e9ea5c26..f338e194e1e 100644 --- a/devito/builtins/initializers.py +++ b/devito/builtins/initializers.py @@ -366,7 +366,7 @@ def initialize_function(function, data, nbl, mapper=None, mode='constant', raise NotImplementedError("TimeFunctions are not currently supported.") if nbl == 0: - for f in functions: + for f, data in zip(functions, datas): if isinstance(data, dv.Function): f.data[:] = data.data[:] else: @@ -382,7 +382,7 @@ def initialize_function(function, data, nbl, mapper=None, mode='constant', assert len(lhss) == len(rhss) == len(optionss) - name = name or 'pad_%s' % '_'.join(f.name for f in functions) + name = name or 'initialize_%s' % '_'.join(f.name for f in functions) assign(lhss, rhss, options=optionss, name=name, **kwargs) if pad_halo: diff --git a/devito/types/basic.py b/devito/types/basic.py index 55fe1d07a13..385aa399e25 100644 --- a/devito/types/basic.py +++ b/devito/types/basic.py @@ -1245,8 +1245,15 @@ def indexify(self, indices=None, subs=None): subs = [{**{d.spacing: 1, -d.spacing: -1}, **subs} for d in self.dimensions] # Indices after substitutions - indices = [sympy.sympify(a.subs(d, d - o).xreplace(s)) for a, d, o, s in - zip(self.args, self.dimensions, self.origin, subs)] + indices = [] + for a, d, o, s in zip(self.args, self.dimensions, self.origin, subs): + if d in a.args: + # Shift by origin d -> d - o. + indices.append(sympy.sympify(a.subs(d, d - o).xreplace(s))) + else: + # Dimension has been removed, e.g. u[10], plain shift by origin + indices.append(sympy.sympify(a - o).xreplace(s)) + indices = [i.xreplace({k: sympy.Integer(k) for k in i.atoms(sympy.Float)}) for i in indices] diff --git a/tests/test_builtins.py b/tests/test_builtins.py index e62c8d58db5..4ffe02b552f 100644 --- a/tests/test_builtins.py +++ b/tests/test_builtins.py @@ -333,8 +333,8 @@ def test_batching(self): a = np.arange(16).reshape((4, 4)) f = Function(name='f', grid=grid, dtype=np.int32) - g = Function(name='g', grid=grid, dtype=np.int32) - h = Function(name='h', grid=grid, dtype=np.int32) + g = Function(name='g', grid=grid, dtype=np.float32) + h = Function(name='h', grid=grid, dtype=np.float64) initialize_function([f, g, h], [a, a, a], 4, mode='reflect') diff --git a/tests/test_symbolics.py b/tests/test_symbolics.py index 27d40c93c2a..ce5bb3cdf26 100644 --- a/tests/test_symbolics.py +++ b/tests/test_symbolics.py @@ -132,6 +132,17 @@ def test_indexed(): assert ub.indexed.free_symbols == {ub.indexed} +def test_indexed_staggered(): + grid = Grid(shape=(10, 10)) + x, y = grid.dimensions + hx, hy = x.spacing, y.spacing + + u = Function(name='u', grid=grid, staggered=(x, y)) + u0 = u.subs({x: 1, y: 2}) + assert u0.indices == (1 + hx / 2, 2 + hy / 2) + assert u0.indexify().indices == (1, 2) + + def test_bundle(): grid = Grid(shape=(4, 4))