Skip to content

Commit

Permalink
compiler: Patch GuardBoundNext pickling
Browse files Browse the repository at this point in the history
  • Loading branch information
FabioLuporini committed Dec 14, 2023
1 parent 56d1b03 commit 136a9b4
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 2 deletions.
4 changes: 3 additions & 1 deletion devito/ir/support/guards.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ class GuardBoundGt(BaseGuardBound, Gt):
# *** GuardBoundNext


class BaseGuardBoundNext(Guard):
class BaseGuardBoundNext(Guard, Pickable):

"""
A guard to avoid out-of-bounds iteration.
Expand All @@ -118,6 +118,8 @@ class BaseGuardBoundNext(Guard):
given `direction`.
"""

__rargs__ = ('d', 'direction')

def __new__(cls, d, direction, **kwargs):
assert isinstance(d, Dimension)
assert isinstance(direction, IterationDirection)
Expand Down
24 changes: 23 additions & 1 deletion tests/test_pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
Dimension, SubDimension, ConditionalDimension, IncrDimension,
TimeDimension, SteppingDimension, Operator, MPI, Min, solve,
PrecomputedSparseTimeFunction)
from devito.ir import GuardFactor
from devito.ir import Backward, GuardFactor, GuardBound, GuardBoundNext
from devito.data import LEFT, OWNED
from devito.mpi.halo_scheme import Halo
from devito.mpi.routines import (MPIStatusObject, MPIMsgEnriched, MPIRequestObject,
Expand Down Expand Up @@ -388,6 +388,28 @@ def test_guard_factor(self, pickle):

assert str(gf) == str(new_gf)

def test_guard_bound(self, pickle):
d = Dimension(name='d')

gb = GuardBound(d, 3)

pkl_gb = pickle.dumps(gb)
new_gb = pickle.loads(pkl_gb)

assert str(gb) == str(new_gb)

def test_guard_bound_next(self, pickle):
d = Dimension(name='d')
cd = ConditionalDimension(name='cd', parent=d, factor=4)

for i in [d, cd]:
gbn = GuardBoundNext(i, Backward)

pkl_gbn = pickle.dumps(gbn)
new_gbn = pickle.loads(pkl_gbn)

assert str(gbn) == str(new_gbn)

def test_temp_function(self, pickle):
grid = Grid(shape=(3, 3))
d = Dimension(name='d')
Expand Down

0 comments on commit 136a9b4

Please sign in to comment.