Skip to content

Commit

Permalink
api: apply staggered shift at interpolation replacement before dimens…
Browse files Browse the repository at this point in the history
…ion is lost
  • Loading branch information
mloubout committed May 9, 2023
1 parent 874a688 commit 72b0a3a
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 4 deletions.
1 change: 1 addition & 0 deletions devito/finite_differences/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,7 @@ def generate_indices_staggered(expr, dim, order, side=None, x0=None):
ind0 = expr.indices_ref[dim]
except AttributeError:
ind0 = start

if start != ind0:
if order < 2:
indices = [start - diff/2, start + diff/2]
Expand Down
11 changes: 9 additions & 2 deletions devito/operations/interpolators.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,9 +188,16 @@ def _interpolation_indices(self, variables, offset=0, field_offset=0,
condition = sympy.And(lb, ub, evaluate=False)
mapper[d] = ConditionalDimension(p.name, self.sfunction._sparse_dim,
condition=condition, indirect=True)

# Apply mapper to each variable with origin correction before the
# dimension get replaced.
vmapper = {}
for v in variables:
try:
vmapper[v] = v.subs({k: c - v.origin[k] for k, c in mapper.items()})
except KeyError:
vmapper[v] = v.subs(mapper)
# Track Indexed substitutions
idx_subs.append(mapper)
idx_subs.append(vmapper)

# Temporaries for the position
temps = [Eq(v, k, implicit_dims=implicit_dims)
Expand Down
4 changes: 3 additions & 1 deletion devito/types/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -945,7 +945,8 @@ def origin(self):
f(x) : origin = 0
f(x + hx/2) : origin = hx/2
"""
return tuple(r - d for d, r in zip(self.dimensions, self.indices_ref))
return DimensionTuple(*(r-d for d, r in zip(self.dimensions, self.indices_ref)),
getters=self.dimensions)

@property
def dimensions(self):
Expand Down Expand Up @@ -1201,6 +1202,7 @@ def indexify(self, indices=None, subs=None):
zip(self.args, self.dimensions, self.origin, subs)]
indices = [i.xreplace({k: sympy.Integer(k) for k in i.atoms(sympy.Float)})
for i in indices]
# print(self, self.indexed[indices])
return self.indexed[indices]

def __getitem__(self, index):
Expand Down
2 changes: 1 addition & 1 deletion devito/types/dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -1059,7 +1059,7 @@ def __indices_setup__(cls, **kwargs):
mapper = {d: d for d in dimensions}
for s in as_tuple(staggered):
c, s = s.as_coeff_Mul()
mapper.update({s: s + c * s.spacing/2})
mapper.update({s: s + c * s.spacing / 2})
staggered_indices = mapper.values()

return tuple(dimensions), tuple(staggered_indices)
Expand Down

0 comments on commit 72b0a3a

Please sign in to comment.