From a5611fce3af4d7dd7f9ce50744d8a231dbf0b228 Mon Sep 17 00:00:00 2001 From: mloubout Date: Wed, 3 May 2023 09:50:46 -0400 Subject: [PATCH] checkpointing: switch to noop operators if unavailable --- devito/__init__.py | 6 +++-- devito/checkpointing/__init__.py | 23 ++++++++++++++++++-- examples/seismic/acoustic/wavesolver.py | 20 ++++++++--------- examples/seismic/tti/wavesolver.py | 20 ++++++++--------- examples/seismic/viscoacoustic/wavesolver.py | 21 +++++++++--------- tests/conftest.py | 6 ++--- tests/test_checkpointing.py | 21 +++++++++--------- 7 files changed, 66 insertions(+), 51 deletions(-) diff --git a/devito/__init__.py b/devito/__init__.py index 25a4d52d9c9..cd61626b28f 100644 --- a/devito/__init__.py +++ b/devito/__init__.py @@ -23,11 +23,13 @@ from devito.data.allocators import * # noqa from devito.logger import error, warning, info, set_log_level # noqa from devito.mpi import MPI # noqa -from devito.checkpointing import pyrevolve # noqa try: from devito.checkpointing import DevitoCheckpoint, CheckpointOperator # noqa + from pyrevolve import Revolver except ImportError: - pass + from devito.checkpointing import NoopCheckpoint as DevitoCheckpoint # noqa + from devito.checkpointing import NoopCheckpointOperator as CheckpointOperator # noqa + from devito.checkpointing import NoopRevolver as Revolver # noqa # Imports required to initialize Devito from devito.arch import compiler_registry, platform_registry diff --git a/devito/checkpointing/__init__.py b/devito/checkpointing/__init__.py index 5eee68e3753..22bca408b78 100644 --- a/devito/checkpointing/__init__.py +++ b/devito/checkpointing/__init__.py @@ -1,5 +1,24 @@ try: - import pyrevolve as pyrevolve + import pyrevolve as pyrevolve # noqa from .checkpoint import * # noqa except ImportError: - pyrevolve = None + pass + + +class Noop(object): + """ Dummy replacement in case pyrevolve isn't available. """ + + def __init__(self, *args, **kwargs): + raise ImportError("Missing required `pyrevolve`; cannot use checkpointing") + + +class NoopCheckpointOperator(Noop): + pass + + +class NoopCheckpoint(Noop): + pass + + +class NoopRevolver(Noop): + pass diff --git a/examples/seismic/acoustic/wavesolver.py b/examples/seismic/acoustic/wavesolver.py index 73550fbc6cd..3636a86616d 100644 --- a/examples/seismic/acoustic/wavesolver.py +++ b/examples/seismic/acoustic/wavesolver.py @@ -1,5 +1,4 @@ -import devito -from devito import Function, TimeFunction, pyrevolve +from devito import Function, TimeFunction, DevitoCheckpoint, CheckpointOperator, Revolver from devito.tools import memoized_meth from examples.seismic.acoustic.operators import ( ForwardOperator, AdjointOperator, GradientOperator, BornOperator @@ -194,20 +193,19 @@ def jacobian_adjoint(self, rec, u, src=None, v=None, grad=None, model=None, # Pick vp from model unless explicitly provided kwargs.update(model.physical_params(**kwargs)) - if checkpointing and pyrevolve is not None: + if checkpointing: u = TimeFunction(name='u', grid=self.model.grid, time_order=2, space_order=self.space_order) - cp = devito.DevitoCheckpoint([u]) + cp = DevitoCheckpoint([u]) n_checkpoints = None - wrap_fw = devito.CheckpointOperator(self.op_fwd(save=False), - src=src or self.geometry.src, - u=u, dt=dt, **kwargs) - wrap_rev = devito.CheckpointOperator(self.op_grad(save=False), u=u, v=v, - rec=rec, dt=dt, grad=grad, **kwargs) + wrap_fw = CheckpointOperator(self.op_fwd(save=False), + src=src or self.geometry.src, + u=u, dt=dt, **kwargs) + wrap_rev = CheckpointOperator(self.op_grad(save=False), u=u, v=v, + rec=rec, dt=dt, grad=grad, **kwargs) # Run forward - wrp = pyrevolve.Revolver(cp, wrap_fw, wrap_rev, n_checkpoints, - rec.data.shape[0]-2) + wrp = Revolver(cp, wrap_fw, wrap_rev, n_checkpoints, rec.data.shape[0]-2) wrp.apply_forward() summary = wrp.apply_reverse() else: diff --git a/examples/seismic/tti/wavesolver.py b/examples/seismic/tti/wavesolver.py index 9f7ea93b7a5..5aef4274019 100644 --- a/examples/seismic/tti/wavesolver.py +++ b/examples/seismic/tti/wavesolver.py @@ -1,6 +1,6 @@ # coding: utf-8 -import devito -from devito import Function, TimeFunction, warning, pyrevolve +from devito import (Function, TimeFunction, warning, + DevitoCheckpoint, CheckpointOperator, Revolver) from devito.tools import memoized_meth from examples.seismic.tti.operators import ForwardOperator, AdjointOperator from examples.seismic.tti.operators import JacobianOperator, JacobianAdjOperator @@ -350,22 +350,20 @@ def jacobian_adjoint(self, rec, u0, v0, du=None, dv=None, dm=None, model=None, if self.model.dim < 3: kwargs.pop('phi', None) - if checkpointing and pyrevolve is not None: + if checkpointing: u0 = TimeFunction(name='u0', grid=self.model.grid, time_order=2, space_order=self.space_order) v0 = TimeFunction(name='v0', grid=self.model.grid, time_order=2, space_order=self.space_order) - cp = devito.DevitoCheckpoint([u0, v0]) + cp = DevitoCheckpoint([u0, v0]) n_checkpoints = None - wrap_fw = devito.CheckpointOperator(self.op_fwd(save=False), u=u0, v=v0, - dt=dt, src=self.geometry.src, **kwargs) - wrap_rev = devito. CheckpointOperator(self.op_jacadj(save=False), u0=u0, - v0=v0, du=du, dv=dv, rec=rec, dm=dm, - dt=dt, **kwargs) + wrap_fw = CheckpointOperator(self.op_fwd(save=False), src=self.geometry.src, + u=u0, v=v0, dt=dt, **kwargs) + wrap_rev = CheckpointOperator(self.op_jacadj(save=False), u0=u0, v0=v0, + du=du, dv=dv, rec=rec, dm=dm, dt=dt, **kwargs) # Run forward - wrp = pyrevolve.Revolver(cp, wrap_fw, wrap_rev, n_checkpoints, - rec.data.shape[0]-2) + wrp = Revolver(cp, wrap_fw, wrap_rev, n_checkpoints, rec.data.shape[0]-2) wrp.apply_forward() summary = wrp.apply_reverse() else: diff --git a/examples/seismic/viscoacoustic/wavesolver.py b/examples/seismic/viscoacoustic/wavesolver.py index f9a52c4f6c2..05125807b89 100755 --- a/examples/seismic/viscoacoustic/wavesolver.py +++ b/examples/seismic/viscoacoustic/wavesolver.py @@ -1,5 +1,5 @@ -import devito -from devito import VectorTimeFunction, TimeFunction, Function, NODE, pyrevolve +from devito import (VectorTimeFunction, TimeFunction, Function, NODE, + DevitoCheckpoint, CheckpointOperator, Revolver) from devito.tools import memoized_meth from examples.seismic import PointSource from examples.seismic.viscoacoustic.operators import ( @@ -262,7 +262,7 @@ def jacobian_adjoint(self, rec, p, pa=None, grad=None, r=None, va=None, model=No # Pick vp and physical parameters from model unless explicitly provided kwargs.update(model.physical_params(**kwargs)) - if checkpointing and pyrevolve is not None: + if checkpointing: if self.time_order == 1: v = VectorTimeFunction(name="v", grid=self.model.grid, time_order=self.time_order, @@ -277,10 +277,10 @@ def jacobian_adjoint(self, rec, p, pa=None, grad=None, r=None, va=None, model=No space_order=self.space_order, staggered=NODE) l = [p, r] + v.values() if self.time_order == 1 else [p, r] - cp = devito.DevitoCheckpoint(l) + cp = DevitoCheckpoint(l) n_checkpoints = None - wrap_fw = devito.CheckpointOperator(self.op_fwd(save=False), p=p, r=r, dt=dt, - src=self.geometry.src, **kwargs) + wrap_fw = CheckpointOperator(self.op_fwd(save=False), + src=self.geometry.src, p=p, r=r, dt=dt, **kwargs) ra = TimeFunction(name="ra", grid=self.model.grid, time_order=self.time_order, space_order=self.space_order, staggered=NODE) @@ -294,13 +294,12 @@ def jacobian_adjoint(self, rec, p, pa=None, grad=None, r=None, va=None, model=No kwargs.update({k.name: k for k in va}) kwargs['time_m'] = 0 - wrap_rev = devito.CheckpointOperator(self.op_grad(save=False), p=p, pa=pa, - r=ra, rec=rec, dt=dt, grad=grad, - **kwargs) + wrap_rev = CheckpointOperator(self.op_grad(save=False), p=p, pa=pa, r=ra, + rec=rec, dt=dt, grad=grad, **kwargs) # Run forward - ntchk = rec.data.shape[0] - (1 if self.time_order == 1 else 2) - wrp = pyrevolve.Revolver(cp, wrap_fw, wrap_rev, n_checkpoints, ntchk) + wrp = Revolver(cp, wrap_fw, wrap_rev, n_checkpoints, + rec.data.shape[0] - (1 if self.time_order == 1 else 2)) wrp.apply_forward() summary = wrp.apply_reverse() else: diff --git a/tests/conftest.py b/tests/conftest.py index 734bfb3cb6c..55afa7cede5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,8 +4,8 @@ import pytest import sys -from devito import Eq, configuration # noqa -from devito.checkpointing import pyrevolve +from devito import Eq, configuration, Revolver # noqa +from devito.checkpointing import NoopRevolver from devito.finite_differences.differentiable import EvalDerivative from devito.arch import Cpu64, Device, sniff_mpi_distro, Arm from devito.arch.compiler import compiler_registry, IntelCompiler, NvidiaCompiler @@ -73,7 +73,7 @@ def skipif(items, whole_module=False): skipit = "Arm doesn't support x86-specific instructions" break # Skip if pyrevolve not installed - if i == 'chkpnt' and pyrevolve is None: + if i == 'chkpnt' and Revolver is NoopRevolver: skipit = "pyrevolve not installed" break diff --git a/tests/test_checkpointing.py b/tests/test_checkpointing.py index ab295fdffcd..75cca861cc3 100644 --- a/tests/test_checkpointing.py +++ b/tests/test_checkpointing.py @@ -4,9 +4,8 @@ import numpy as np from conftest import skipif -import devito from devito import (Grid, TimeFunction, Operator, Function, Eq, switchconfig, Constant, - pyrevolve) + Revolver, CheckpointOperator, DevitoCheckpoint) from examples.seismic.acoustic.acoustic_example import acoustic_setup @@ -131,12 +130,12 @@ def test_forward_with_breaks(shape, kernel, space_order): dt = solver.model.critical_dt u = TimeFunction(name='u', grid=grid, time_order=2, space_order=space_order) - cp = devito.DevitoCheckpoint([u]) - wrap_fw = devito.CheckpointOperator(solver.op_fwd(save=False), rec=rec, - src=solver.geometry.src, u=u, dt=dt) - wrap_rev = devito.CheckpointOperator(solver.op_grad(save=False), u=u, dt=dt, rec=rec) + cp = DevitoCheckpoint([u]) + wrap_fw = CheckpointOperator(solver.op_fwd(save=False), rec=rec, + src=solver.geometry.src, u=u, dt=dt) + wrap_rev = CheckpointOperator(solver.op_grad(save=False), u=u, dt=dt, rec=rec) - wrp = pyrevolve.Revolver(cp, wrap_fw, wrap_rev, None, rec._time_range.num-time_order) + wrp = Revolver(cp, wrap_fw, wrap_rev, None, rec._time_range.num-time_order) rec1, u1, summary = solver.forward() wrp.apply_forward() @@ -229,13 +228,13 @@ def test_index_alignment(): # change equations to use new symbols fwd_eqn_2 = Eq(u_nosave.forward, u_nosave + 1.*const) fwd_op_2 = Operator(fwd_eqn_2) - cp = devito.DevitoCheckpoint([u_nosave]) - wrap_fw = devito.CheckpointOperator(fwd_op_2, constant=1) + cp = DevitoCheckpoint([u_nosave]) + wrap_fw = CheckpointOperator(fwd_op_2, constant=1) prod_eqn_2 = Eq(prod, prod + u_nosave * v) comb_op_2 = Operator([adj_eqn, prod_eqn_2]) - wrap_rev = devito.CheckpointOperator(comb_op_2, constant=1) - wrp = pyrevolve.Revolver(cp, wrap_fw, wrap_rev, None, nt) + wrap_rev = CheckpointOperator(comb_op_2, constant=1) + wrp = Revolver(cp, wrap_fw, wrap_rev, None, nt) # Invocation 4 wrp.apply_forward()