Skip to content

Commit

Permalink
tests: add staggered origin in Abs MFE
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed May 11, 2023
1 parent eb66c63 commit 5d9a2a2
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 9 deletions.
14 changes: 6 additions & 8 deletions devito/operations/interpolators.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,16 +188,14 @@ def _interpolation_indices(self, variables, offset=0, field_offset=0,
condition = sympy.And(lb, ub, evaluate=False)
mapper[d] = ConditionalDimension(p.name, self.sfunction._sparse_dim,
condition=condition, indirect=True)

# Apply mapper to each variable with origin correction before the
# dimension get replaced.
vmapper = {}
for v in variables:
try:
vmapper[v] = v.subs({k: c - v.origin[k] for k, c in mapper.items()})
except KeyError:
vmapper[v] = v.subs(mapper)
# Dimensions get replaced
subs = {v: v.subs({k: c - v.origin.get(k, 0) for k, c in mapper.items()})
for v in variables}

# Track Indexed substitutions
idx_subs.append(vmapper)
idx_subs.append(subs)

# Temporaries for the position
temps = [Eq(v, k, implicit_dims=implicit_dims)
Expand Down
3 changes: 3 additions & 0 deletions devito/tools/data_structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@ def __getnewargs_ex__(self):
# objects with varying number of attributes
return (tuple(self), dict(self.__dict__))

def get(self, key, val):
return self._getters.get(key, val)


class ReducerMap(MultiDict):

Expand Down
14 changes: 13 additions & 1 deletion tests/test_symbolics.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
Operator, SubDimension, norm, Le, Ge, Gt, Lt, Abs, sin, cos, Min, Max)
from devito.ir import Expression, FindNodes
from devito.symbolics import (retrieve_functions, retrieve_indexed, evalrel, # noqa
CallFromPointer, Cast, FieldFromPointer,
CallFromPointer, Cast, FieldFromPointer, INT,
FieldFromComposite, IntDiv, ccode, uxreplace)
from devito.types import Array, Bundle, LocalObject, Object, Symbol as dSymbol

Expand All @@ -31,6 +31,18 @@ def test_float_indices():
assert indices == 1


def test_func_of_indices():
"""
Test that origin is correctly processed with functions
"""
grid = Grid((10,))
x = grid.dimensions[0]
u = Function(name="u", grid=grid, space_order=2, staggered=x)
us = u.subs({u.indices[0]: INT(Abs(u.indices[0]))})
assert us.indices[0] == INT(Abs(x + x.spacing/2))
assert us.indexify().indices[0] == INT(Abs(x))


@pytest.mark.parametrize('dtype,expected', [
(np.float32, "float r0 = 1.0F/h_x;"),
(np.float64, "double r0 = 1.0/h_x;")
Expand Down

0 comments on commit 5d9a2a2

Please sign in to comment.