Skip to content

Commit

Permalink
deps: make pyrevolve optional throughout
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed Mar 30, 2023
1 parent ed9fb9a commit c682d38
Show file tree
Hide file tree
Showing 9 changed files with 68 additions and 51 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/examples.yml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ jobs:
- name: Install dependencies
run: |
pip install -e .[tests]
pip install matplotlib
pip install matplotlib pyrevolve
- name: Tests in examples
run: py.test --cov --cov-config=.coveragerc --cov-report=xml examples/
Expand Down
6 changes: 5 additions & 1 deletion devito/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,11 @@
from devito.data.allocators import * # noqa
from devito.logger import error, warning, info, set_log_level # noqa
from devito.mpi import MPI # noqa
from devito.checkpointing import DevitoCheckpoint, CheckpointOperator # noqa
from devito.checkpointing import pyrevolve # noqa
try:
from devito.checkpointing import DevitoCheckpoint, CheckpointOperator # noqa
except ImportError:
pass

# Imports required to initialize Devito
from devito.arch import compiler_registry, platform_registry
Expand Down
6 changes: 5 additions & 1 deletion devito/checkpointing/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,5 @@
from .checkpoint import * # noqa
try:
import pyrevolve as pyrevolve
from .checkpoint import * # noqa
except ImportError:
pyrevolve = None
21 changes: 11 additions & 10 deletions examples/seismic/acoustic/wavesolver.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from devito import Function, TimeFunction, DevitoCheckpoint, CheckpointOperator
import devito
from devito import Function, TimeFunction, pyrevolve
from devito.tools import memoized_meth
from examples.seismic.acoustic.operators import (
ForwardOperator, AdjointOperator, GradientOperator, BornOperator
)
from pyrevolve import Revolver


class AcousticWaveSolver(object):
Expand Down Expand Up @@ -194,19 +194,20 @@ def jacobian_adjoint(self, rec, u, src=None, v=None, grad=None, model=None,
# Pick vp from model unless explicitly provided
kwargs.update(model.physical_params(**kwargs))

if checkpointing:
if checkpointing and pyrevolve is not None:
u = TimeFunction(name='u', grid=self.model.grid,
time_order=2, space_order=self.space_order)
cp = DevitoCheckpoint([u])
cp = devito.DevitoCheckpoint([u])
n_checkpoints = None
wrap_fw = CheckpointOperator(self.op_fwd(save=False),
src=src or self.geometry.src,
u=u, dt=dt, **kwargs)
wrap_rev = CheckpointOperator(self.op_grad(save=False), u=u, v=v,
rec=rec, dt=dt, grad=grad, **kwargs)
wrap_fw = devito.CheckpointOperator(self.op_fwd(save=False),
src=src or self.geometry.src,
u=u, dt=dt, **kwargs)
wrap_rev = devito.CheckpointOperator(self.op_grad(save=False), u=u, v=v,
rec=rec, dt=dt, grad=grad, **kwargs)

# Run forward
wrp = Revolver(cp, wrap_fw, wrap_rev, n_checkpoints, rec.data.shape[0]-2)
wrp = pyrevolve.Revolver(cp, wrap_fw, wrap_rev, n_checkpoints,
rec.data.shape[0]-2)
wrp.apply_forward()
summary = wrp.apply_reverse()
else:
Expand Down
20 changes: 11 additions & 9 deletions examples/seismic/tti/wavesolver.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
# coding: utf-8
from devito import Function, TimeFunction, warning, DevitoCheckpoint, CheckpointOperator
import devito
from devito import Function, TimeFunction, warning, pyrevolve
from devito.tools import memoized_meth
from examples.seismic.tti.operators import ForwardOperator, AdjointOperator
from examples.seismic.tti.operators import JacobianOperator, JacobianAdjOperator
from examples.seismic.tti.operators import particle_velocity_fields
from pyrevolve import Revolver


class AnisotropicWaveSolver(object):
Expand Down Expand Up @@ -350,20 +350,22 @@ def jacobian_adjoint(self, rec, u0, v0, du=None, dv=None, dm=None, model=None,
if self.model.dim < 3:
kwargs.pop('phi', None)

if checkpointing:
if checkpointing and pyrevolve is not None:
u0 = TimeFunction(name='u0', grid=self.model.grid,
time_order=2, space_order=self.space_order)
v0 = TimeFunction(name='v0', grid=self.model.grid,
time_order=2, space_order=self.space_order)
cp = DevitoCheckpoint([u0, v0])
cp = devito.DevitoCheckpoint([u0, v0])
n_checkpoints = None
wrap_fw = CheckpointOperator(self.op_fwd(save=False), src=self.geometry.src,
u=u0, v=v0, dt=dt, **kwargs)
wrap_rev = CheckpointOperator(self.op_jacadj(save=False), u0=u0, v0=v0,
du=du, dv=dv, rec=rec, dm=dm, dt=dt, **kwargs)
wrap_fw = devito.CheckpointOperator(self.op_fwd(save=False), u=u0, v=v0,
dt=dt, src=self.geometry.src, **kwargs)
wrap_rev = devito. CheckpointOperator(self.op_jacadj(save=False), u0=u0,
v0=v0, du=du, dv=dv, rec=rec, dm=dm,
dt=dt, **kwargs)

# Run forward
wrp = Revolver(cp, wrap_fw, wrap_rev, n_checkpoints, rec.data.shape[0]-2)
wrp = pyrevolve.Revolver(cp, wrap_fw, wrap_rev, n_checkpoints,
rec.data.shape[0]-2)
wrp.apply_forward()
summary = wrp.apply_reverse()
else:
Expand Down
22 changes: 11 additions & 11 deletions examples/seismic/viscoacoustic/wavesolver.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from devito import (VectorTimeFunction, TimeFunction, Function, NODE,
DevitoCheckpoint, CheckpointOperator)
import devito
from devito import VectorTimeFunction, TimeFunction, Function, NODE, pyrevolve
from devito.tools import memoized_meth
from examples.seismic import PointSource
from examples.seismic.viscoacoustic.operators import (
ForwardOperator, AdjointOperator, GradientOperator, BornOperator
)
from pyrevolve import Revolver


class ViscoacousticWaveSolver(object):
Expand Down Expand Up @@ -263,7 +262,7 @@ def jacobian_adjoint(self, rec, p, pa=None, grad=None, r=None, va=None, model=No
# Pick vp and physical parameters from model unless explicitly provided
kwargs.update(model.physical_params(**kwargs))

if checkpointing:
if checkpointing and pyrevolve is not None:
if self.time_order == 1:
v = VectorTimeFunction(name="v", grid=self.model.grid,
time_order=self.time_order,
Expand All @@ -278,10 +277,10 @@ def jacobian_adjoint(self, rec, p, pa=None, grad=None, r=None, va=None, model=No
space_order=self.space_order, staggered=NODE)

l = [p, r] + v.values() if self.time_order == 1 else [p, r]
cp = DevitoCheckpoint(l)
cp = devito.DevitoCheckpoint(l)
n_checkpoints = None
wrap_fw = CheckpointOperator(self.op_fwd(save=False),
src=self.geometry.src, p=p, r=r, dt=dt, **kwargs)
wrap_fw = devito.CheckpointOperator(self.op_fwd(save=False), p=p, r=r, dt=dt,
src=self.geometry.src, **kwargs)

ra = TimeFunction(name="ra", grid=self.model.grid, time_order=self.time_order,
space_order=self.space_order, staggered=NODE)
Expand All @@ -295,12 +294,13 @@ def jacobian_adjoint(self, rec, p, pa=None, grad=None, r=None, va=None, model=No
kwargs.update({k.name: k for k in va})
kwargs['time_m'] = 0

wrap_rev = CheckpointOperator(self.op_grad(save=False), p=p, pa=pa, r=ra,
rec=rec, dt=dt, grad=grad, **kwargs)
wrap_rev = devito.CheckpointOperator(self.op_grad(save=False), p=p, pa=pa,
r=ra, rec=rec, dt=dt, grad=grad,
**kwargs)

# Run forward
wrp = Revolver(cp, wrap_fw, wrap_rev, n_checkpoints,
rec.data.shape[0] - (1 if self.time_order == 1 else 2))
ntchk = rec.data.shape[0] - (1 if self.time_order == 1 else 2)
wrp = pyrevolve.Revolver(cp, wrap_fw, wrap_rev, n_checkpoints, ntchk)
wrp.apply_forward()
summary = wrp.apply_reverse()
else:
Expand Down
15 changes: 9 additions & 6 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import sys

from devito import Eq, configuration # noqa
from devito.checkpointing import pyrevolve
from devito.finite_differences.differentiable import EvalDerivative
from devito.arch import Cpu64, Device, sniff_mpi_distro, Arm
from devito.arch.compiler import compiler_registry, IntelCompiler, NvidiaCompiler
Expand All @@ -23,19 +24,17 @@ def skipif(items, whole_module=False):
# Sanity check
accepted = set()
accepted.update({'device', 'device-C', 'device-openmp', 'device-openacc',
'device-aomp', 'cpu64-icc', 'cpu64-nvc', 'cpu64-arm'})
'device-aomp', 'cpu64-icc', 'cpu64-nvc', 'cpu64-arm', 'chkpnt'})
accepted.update({'nompi', 'nodevice'})
unknown = sorted(set(items) - accepted)
if unknown:
raise ValueError("Illegal skipif argument(s) `%s`" % unknown)
skipit = False
for i in items:
# Skip if no MPI
if i == 'nompi':
if MPI is None:
skipit = "mpi4py/MPI not installed"
break
continue
if i == 'nompi' and MPI is None:
skipit = "mpi4py/MPI not installed"
break
# Skip if won't run on GPUs
if i == 'device' and isinstance(configuration['platform'], Device):
skipit = "device `%s` unsupported" % configuration['platform'].name
Expand Down Expand Up @@ -73,6 +72,10 @@ def skipif(items, whole_module=False):
if i == 'cpu64-arm' and isinstance(configuration['platform'], Arm):
skipit = "Arm doesn't support x86-specific instructions"
break
# Skip if pyrevolve not installed
if i == 'chkpnt' and pyrevolve is None:
skipit = "pyrevolve not installed"
break

if skipit is False:
return pytest.mark.skipif(False, reason='')
Expand Down
25 changes: 14 additions & 11 deletions tests/test_checkpointing.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from functools import reduce

import pytest
from pyrevolve import Revolver
import numpy as np

from conftest import skipif
import devito
from devito import (Grid, TimeFunction, Operator, Function, Eq, switchconfig, Constant,
DevitoCheckpoint, CheckpointOperator)
pyrevolve)
from examples.seismic.acoustic.acoustic_example import acoustic_setup


Expand Down Expand Up @@ -107,6 +108,7 @@ def test_segmented_averaging():
assert (f_ref.data_with_halo[1, -1] == 1.).all()


@skipif('chkpnt')
@switchconfig(log_level='WARNING')
@pytest.mark.parametrize('space_order', [4])
@pytest.mark.parametrize('kernel', ['OT2'])
Expand All @@ -129,12 +131,12 @@ def test_forward_with_breaks(shape, kernel, space_order):
dt = solver.model.critical_dt

u = TimeFunction(name='u', grid=grid, time_order=2, space_order=space_order)
cp = DevitoCheckpoint([u])
wrap_fw = CheckpointOperator(solver.op_fwd(save=False), rec=rec,
src=solver.geometry.src, u=u, dt=dt)
wrap_rev = CheckpointOperator(solver.op_grad(save=False), u=u, dt=dt, rec=rec)
cp = devito.DevitoCheckpoint([u])
wrap_fw = devito.CheckpointOperator(solver.op_fwd(save=False), rec=rec,
src=solver.geometry.src, u=u, dt=dt)
wrap_rev = devito.CheckpointOperator(solver.op_grad(save=False), u=u, dt=dt, rec=rec)

wrp = Revolver(cp, wrap_fw, wrap_rev, None, rec._time_range.num-time_order)
wrp = pyrevolve.Revolver(cp, wrap_fw, wrap_rev, None, rec._time_range.num-time_order)
rec1, u1, summary = solver.forward()

wrp.apply_forward()
Expand All @@ -160,6 +162,7 @@ def test_acoustic_save_and_nosave(shape=(50, 50), spacing=(15.0, 15.0), tn=500.,
assert(np.allclose(rec.data, rec_bk))


@skipif('chkpnt')
def test_index_alignment():
""" A much simpler test meant to ensure that the forward and reverse indices are
correctly aligned (i.e. u * v , where u is the forward field and v the reverse field
Expand Down Expand Up @@ -226,13 +229,13 @@ def test_index_alignment():
# change equations to use new symbols
fwd_eqn_2 = Eq(u_nosave.forward, u_nosave + 1.*const)
fwd_op_2 = Operator(fwd_eqn_2)
cp = DevitoCheckpoint([u_nosave])
wrap_fw = CheckpointOperator(fwd_op_2, constant=1)
cp = devito.DevitoCheckpoint([u_nosave])
wrap_fw = devito.CheckpointOperator(fwd_op_2, constant=1)

prod_eqn_2 = Eq(prod, prod + u_nosave * v)
comb_op_2 = Operator([adj_eqn, prod_eqn_2])
wrap_rev = CheckpointOperator(comb_op_2, constant=1)
wrp = Revolver(cp, wrap_fw, wrap_rev, None, nt)
wrap_rev = devito.CheckpointOperator(comb_op_2, constant=1)
wrp = pyrevolve.Revolver(cp, wrap_fw, wrap_rev, None, nt)

# Invocation 4
wrp.apply_forward()
Expand Down
2 changes: 1 addition & 1 deletion tests/test_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

class TestGradient(object):

@skipif('cpu64-icc')
@skipif(['chkpnt', 'cpu64-icc'])
@pytest.mark.parametrize('dtype', [np.float32, np.float64])
@pytest.mark.parametrize('opt', [('advanced', {'openmp': True}),
('noop', {'openmp': True})])
Expand Down

0 comments on commit c682d38

Please sign in to comment.