Skip to content

Commit

Permalink
api: fix corner case staggered fd for centered x0
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed May 20, 2024
1 parent 9e220ae commit e636dd5
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 6 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pytest-core-nompi.yml
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ jobs:

- name: pytest-ubuntu-py312-gcc13-omp
python-version: '3.12'
os: ubuntu-20.04
os: ubuntu-latest
arch: "gcc-13"
language: "openmp"
sympy: "1.11"
Expand Down
8 changes: 6 additions & 2 deletions devito/finite_differences/derivative.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,8 +356,12 @@ def _eval_at(self, func):
setup where one could have Eq(u(x + h_x/2), v(x).dx)) in which case v(x).dx
has to be computed at x=x + h_x/2.
"""
# If an x0 already exists do not overwrite it
x0 = self.x0 or func.indices_ref._getters
# If an x0 already exists or evaluating at the same function (i.e u = u.dx)
# do not overwrite it
if self.x0 or func is self.expr:
return self

x0 = func.indices_ref._getters
if self.expr.is_Add:
# If `expr` has both staggered and non-staggered terms such as
# `(u(x + h_x/2) + v(x)).dx` then we exploit linearity of FD to split
Expand Down
11 changes: 8 additions & 3 deletions devito/finite_differences/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,12 +284,17 @@ def generate_indices(expr, dim, order, side=None, matvec=None, x0=None):
iexpr = x0 + d * dim.spacing
return IndexSet(dim, expr=iexpr), x0

# Shift for side
side = side or centered

# Evaluation point relative to the expression's grid
mid = (x0 - expr.indices_ref[dim]).subs({dim: 0, dim.spacing: 1})

# Centered scheme for staggered field create artifacts if not shifted
if expr.is_Staggered and mid == 0 and not dim.is_Time and side is not centered:
mid = 0.5 if x0 is dim else -0.5
x0 = x0 + mid * dim.spacing

# Shift for side
side = side or centered

# Indices range
o_min = int(np.ceil(mid - order/2)) + side.val
o_max = int(np.floor(mid + order/2)) + side.val
Expand Down
14 changes: 14 additions & 0 deletions tests/test_derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,20 @@ def test_new_x0_eval_at(self):
v = Function(name="v", grid=grid, space_order=2)
assert u.dx(x0=x - x.spacing/2)._eval_at(v).x0 == {x: x - x.spacing/2}

@pytest.mark.parametrize('stagg', [True, False])
def test_eval_at_centered(self, stagg):
grid = Grid((10,))
x = grid.dimensions[0]
stagg = NODE if stagg else x
x0 = x if stagg else x + .5 * x.spacing

u = Function(name="u", grid=grid, space_order=2, staggered=stagg)
v = Function(name="v", grid=grid, space_order=2, staggered=stagg)

assert u.dx._eval_at(v).evaluate == u.dx(x0=x0).evaluate
assert v.dx._eval_at(u).evaluate == v.dx(x0=x0).evaluate
assert u.dx._eval_at(u).evaluate == u.dx.evaluate

def test_fd_new_lo(self):
grid = Grid((10,))
x = grid.dimensions[0]
Expand Down

0 comments on commit e636dd5

Please sign in to comment.