Skip to content

Commit

Permalink
api: switch to recommended scipy i0 when available
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed Apr 3, 2024
1 parent a83fac7 commit 353af4f
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 15 deletions.
8 changes: 8 additions & 0 deletions .github/workflows/pytest-core-mpi.yml
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,12 @@ jobs:
- name: gcc
arch: gcc
os: ubuntu-latest
mpiflag: ""
- name: icx
arch: icx
os: ubuntu-latest
# Need safe math for icx due to inaccuracy with mpi+sinc interpolation
mpiflag: "-e DEVITO_SAFE_MATH=1"

steps:
- name: Checkout devito
Expand All @@ -85,3 +88,8 @@ jobs:
- name: Test with pytest
run: |
docker run --rm -e CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }} -e OMP_NUM_THREADS=1 --name testrun devito_img pytest tests/test_mpi.py
- name: Test examples with MPI
run: |
docker run --rm ${{ matrix.mpiflag }} -e DEVITO_MPI=1 -e OMP_NUM_THREADS=1 --name examplerun devito_img mpiexec -n 2 pytest examples/seismic/acoustic
docker run --rm -e DEVITO_MPI=1 -e OMP_NUM_THREADS=1 --name examplerun devito_img mpiexec -n 2 pytest examples/seismic/tti
28 changes: 20 additions & 8 deletions devito/operations/interpolators.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,14 @@
import numpy as np
from cached_property import cached_property

try:
from scipy.special import i0
except ImportError:
from numpy import i0

from devito.finite_differences.differentiable import Mul
from devito.finite_differences.elementary import floor
from devito.logger import warning
from devito.symbolics import retrieve_function_carriers, retrieve_functions, INT
from devito.tools import as_tuple, flatten, filter_ordered
from devito.types import (ConditionalDimension, Eq, Inc, Evaluable, Symbol,
Expand Down Expand Up @@ -216,7 +222,7 @@ def _positions(self, implicit_dims):
return [Eq(v, INT(floor(k)), implicit_dims=implicit_dims)
for k, v in self.sfunction._position_map.items()]

def _interp_idx(self, variables, implicit_dims=None, pos_only=None):
def _interp_idx(self, variables, implicit_dims=None, pos_only=()):
"""
Generate interpolation indices for the DiscreteFunctions in ``variables``.
"""
Expand All @@ -238,10 +244,8 @@ def _interp_idx(self, variables, implicit_dims=None, pos_only=None):
# Position only replacement, not radius dependent.
# E.g src.inject(vp(x)*src) needs to use vp[posx] at all points
# not vp[posx + rx]
if pos_only is not None:
idx_subs.update({v: v.subs({k: p
for (k, p) in zip(mapper, pos)})
for v in pos_only})
idx_subs.update({v: v.subs({k: p for (k, p) in zip(mapper, pos)})
for v in pos_only})

return idx_subs, temps

Expand Down Expand Up @@ -368,7 +372,7 @@ def _inject(self, field, expr, implicit_dims=None):
implicit_dims = implicit_dims + tuple(r.parent for r in self._rdim)

# List of indirection indices for all adjacent grid points
idx_subs, temps = self._interp_idx(list(fields), implicit_dims=implicit_dims,
idx_subs, temps = self._interp_idx(fields, implicit_dims=implicit_dims,
pos_only=variables)

# Substitute coordinate base symbols into the interpolation coefficients
Expand Down Expand Up @@ -465,6 +469,14 @@ class SincInterpolator(PrecomputedInterpolator):
4: 4.14, 5: 5.26, 6: 6.40,
7: 7.51, 8: 8.56, 9: 9.56, 10: 10.64}

def __init__(self, sfunction):
if i0 is np.i0:
warning("""
Using `numpy.i0`. We (and numpy) recommend to install scipy to improve the performance
of the SincInterpolator that uses i0 (Bessel function).
""")
super().__init__(sfunction)

@cached_property
def interpolation_coeffs(self):
coeffs = {}
Expand All @@ -487,7 +499,7 @@ def _weights(self):
def _arg_defaults(self, coords=None, sfunc=None):
args = {}
b = self._b_table[self.r]
b0 = np.i0(b)
b0 = i0(b)
if coords is None or sfunc is None:
raise ValueError("No coordinates or sparse function provided")
# Coords to indices
Expand All @@ -499,7 +511,7 @@ def _arg_defaults(self, coords=None, sfunc=None):
data = np.zeros((coords.shape[0], 2*self.r), dtype=sfunc.dtype)
for ri in range(2*self.r):
rpos = ri - self.r + 1 - coords[:, j]
num = np.i0(b*np.sqrt(1 - (rpos/self.r)**2))
num = i0(b*np.sqrt(1 - (rpos/self.r)**2))
data[:, ri] = num / b0 * np.sinc(rpos)
args[self.interpolation_coeffs[r].name] = data

Expand Down
12 changes: 6 additions & 6 deletions devito/types/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -822,16 +822,16 @@ def __init_finalize__(self, *args, **kwargs):
self._coordinates = self.__subfunc_setup__(coordinates, 'coords')
self._dist_origin = {self._coordinates: self.grid.origin_offset}

def __interp_setup__(self, interp='linear', r=None, **kwargs):
self.interpolation = interp
self.interpolator = _interpolators[interp](self)
self._radius = r or _default_radius[interp]
if interp == 'sinc':
def __interp_setup__(self, interpolation='linear', r=None, **kwargs):
self.interpolation = interpolation
self.interpolator = _interpolators[interpolation](self)
self._radius = r or _default_radius[interpolation]
if interpolation == 'sinc':
if self._radius < 2:
raise ValueError("'sinc' interpolator requires a radius of at least 2")
elif self._radius > 10:
raise ValueError("'sinc' interpolator requires a radius of at most 10")
elif interp == 'linear' and self._radius != 1:
elif interpolation == 'linear' and self._radius != 1:
self._radius = 1

@cached_property
Expand Down
2 changes: 1 addition & 1 deletion examples/seismic/acoustic/acoustic_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,4 +103,4 @@ def test_isoacoustic(fs, normrec, dtype, interp):
run(shape=shape, spacing=spacing, nbl=args.nbl, tn=tn, fs=args.fs,
space_order=args.space_order, preset=preset, kernel=args.kernel,
autotune=args.autotune, opt=args.opt, full_run=args.full,
checkpointing=args.checkpointing, dtype=args.dtype)
checkpointing=args.checkpointing, dtype=args.dtype, interpolation=args.interp)
2 changes: 2 additions & 0 deletions examples/seismic/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,4 +253,6 @@ def __call__(self, parser, args, values, option_string=None):
type=float, help="Simulation time in millisecond")
parser.add_argument("-dtype", action=_dtype_store, dest="dtype", default=np.float32,
choices=['float32', 'float64'])
parser.add_argument("-interp", dest="interp", default="linear",
choices=['linear', 'sinc'])
return parser

0 comments on commit 353af4f

Please sign in to comment.