Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

compiler: Patch pickling of GuardFactor and reconstruction #2126

Merged
merged 2 commits into from
May 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions devito/ir/support/guards.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from devito.ir.support.space import Forward, IterationDirection
from devito.symbolics import CondEq, CondNe
from devito.tools import as_tuple, frozendict
from devito.tools import Pickable, as_tuple, frozendict
from devito.types import Dimension

__all__ = ['GuardFactor', 'GuardBound', 'GuardBoundNext', 'BaseGuardBound',
Expand All @@ -33,7 +33,7 @@ def negated(self):
# *** GuardFactor


class GuardFactor(Guard, CondEq):
class GuardFactor(Guard, CondEq, Pickable):

"""
A guard for factor-based ConditionalDimensions.
Expand All @@ -42,6 +42,8 @@ class GuardFactor(Guard, CondEq):
symbolic relational `d.parent % k == 0`.
"""

__rargs__ = ('d',)

def __new__(cls, d, **kwargs):
assert d.is_Conditional

Expand Down
6 changes: 5 additions & 1 deletion devito/operator/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from devito.mpi import MPI
from devito.parameters import configuration
from devito.passes import (Graph, lower_index_derivatives, generate_implicit,
generate_macros)
generate_macros, unevaluate)
from devito.symbolics import estimate_cost
from devito.tools import (DAG, OrderedSet, Signer, ReducerMap, as_tuple, flatten,
filter_sorted, frozendict, is_integer, split, timed_pass,
Expand Down Expand Up @@ -368,6 +368,10 @@ def _lower_clusters(cls, expressions, profiler=None, **kwargs):
# Lower all remaining high order symbolic objects
clusters = lower_index_derivatives(clusters, **kwargs)

# Make sure no reconstructions can unpick any of the symbolic
# optimizations performed so far
clusters = unevaluate(clusters)

return ClusterGroup(clusters)

# Compilation -- ScheduleTree level
Expand Down
1 change: 1 addition & 0 deletions devito/passes/clusters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@
from .implicit import * # noqa
from .misc import * # noqa
from .derivatives import * # noqa
from .unevaluate import * # noqa
33 changes: 33 additions & 0 deletions devito/passes/clusters/unevaluate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import sympy

from devito.ir import cluster_pass
from devito.symbolics import reuse_if_untouched, q_leaf
from devito.symbolics.unevaluation import Add, Mul, Pow

__all__ = ['unevaluate']


@cluster_pass
def unevaluate(cluster):
exprs = [_unevaluate(e) for e in cluster.exprs]

return cluster.rebuild(exprs=exprs)


mapper = {
sympy.Add: Add,
sympy.Mul: Mul,
sympy.Pow: Pow
}


def _unevaluate(expr):
if q_leaf(expr):
return expr

args = [_unevaluate(a) for a in expr.args]

try:
return mapper[expr.func](*args)
except KeyError:
return reuse_if_untouched(expr, args)
21 changes: 21 additions & 0 deletions devito/symbolics/unevaluation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import sympy

__all__ = ['Add', 'Mul', 'Pow']


class UnevaluableMixin(object):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nitpickin, maybe FrozenExpr ?


def __new__(cls, *args, evaluate=None, **kwargs):
return cls.__base__.__new__(cls, *args, evaluate=False, **kwargs)


class Add(sympy.Add, UnevaluableMixin):
__new__ = UnevaluableMixin.__new__


class Mul(sympy.Mul, UnevaluableMixin):
__new__ = UnevaluableMixin.__new__


class Pow(sympy.Pow, UnevaluableMixin):
__new__ = UnevaluableMixin.__new__
2 changes: 1 addition & 1 deletion tests/test_dse.py
Original file line number Diff line number Diff line change
Expand Up @@ -1754,7 +1754,7 @@ def test_hoisting_symbolic_divs(self):
op = Operator(eq)

assert op._profiler._sections['section0'].sops == 1
assert op.body.body[-1].body[0].body[0].expr.rhs == s0**-s1
assert str(op.body.body[-1].body[0].body[0].expr.rhs) == str(s0**-s1)

@pytest.mark.parametrize('rotate', [False, True])
def test_drop_redundants_after_fusion(self, rotate):
Expand Down
14 changes: 8 additions & 6 deletions tests/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1612,8 +1612,10 @@ def test_expressions_imperfect_loops(self):
assert outer[0] == middle[0] == inner[0]
assert middle[1] == inner[1]
assert outer[-1].nodes[0].exprs[0].expr.rhs == diff2sympy(indexify(eq0.rhs))
assert middle[-1].nodes[0].exprs[0].expr.rhs == diff2sympy(indexify(eq1.rhs))
assert inner[-1].nodes[0].exprs[0].expr.rhs == diff2sympy(indexify(eq2.rhs))
assert (str(middle[-1].nodes[0].exprs[0].expr.rhs) ==
str(diff2sympy(indexify(eq1.rhs))))
assert (str(inner[-1].nodes[0].exprs[0].expr.rhs) ==
str(diff2sympy(indexify(eq2.rhs))))

def test_equations_emulate_bc(self):
"""
Expand Down Expand Up @@ -1647,8 +1649,8 @@ def test_different_section_nests(self):
op = Operator([eq1, eq2], opt='noop')
trees = retrieve_iteration_tree(op)
assert len(trees) == 2
assert trees[0][-1].nodes[0].exprs[0].expr.rhs == eq1.rhs
assert trees[1][-1].nodes[0].exprs[0].expr.rhs == eq2.rhs
assert str(trees[0][-1].nodes[0].exprs[0].expr.rhs) == str(eq1.rhs)
assert str(trees[1][-1].nodes[0].exprs[0].expr.rhs) == str(eq2.rhs)

@pytest.mark.parametrize('exprs', [
['Eq(ti0[x,y,z], ti0[x,y,z] + t0*2.)', 'Eq(ti0[0,0,z], 0.)'],
Expand All @@ -1673,8 +1675,8 @@ def test_directly_indexed_expression(self, exprs):

trees = retrieve_iteration_tree(op)
assert len(trees) == 2
assert trees[0][-1].nodes[0].exprs[0].expr.rhs == eqs[0].rhs
assert trees[1][-1].nodes[0].exprs[0].expr.rhs == eqs[1].rhs
assert str(trees[0][-1].nodes[0].exprs[0].expr.rhs) == str(eqs[0].rhs)
assert str(trees[1][-1].nodes[0].exprs[0].expr.rhs) == str(eqs[1].rhs)

@pytest.mark.parametrize('shape', [(11, 11), (11, 11, 11)])
def test_equations_mixed_functions(self, shape):
Expand Down
44 changes: 43 additions & 1 deletion tests/test_pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from conftest import skipif
from devito import (Constant, Eq, Function, TimeFunction, SparseFunction, Grid,
Dimension, SubDimension, ConditionalDimension, IncrDimension,
TimeDimension, SteppingDimension, Operator, MPI, Min)
TimeDimension, SteppingDimension, Operator, MPI, Min, solve)
from devito.ir import GuardFactor
from devito.data import LEFT, OWNED
from devito.mpi.halo_scheme import Halo
from devito.mpi.routines import (MPIStatusObject, MPIMsgEnriched, MPIRequestObject,
Expand Down Expand Up @@ -251,6 +252,18 @@ def test_shared_data():
assert indexed.name == new_indexed.name


def test_guard_factor():
d = Dimension(name='d')
cd = ConditionalDimension(name='cd', parent=d, factor=4)

gf = GuardFactor(cd)

pkl_gf = pickle.dumps(gf)
new_gf = pickle.loads(pkl_gf)

assert gf == new_gf


def test_receiver():
grid = Grid(shape=(3,))
time_range = TimeAxis(start=0., stop=1000., step=0.1)
Expand Down Expand Up @@ -682,6 +695,35 @@ def test_full_model():
# FIXME: fails randomly when using data.flatten() AND numpy is using MKL


def test_usave_sampled():
grid = Grid(shape=(10, 10, 10))
u = TimeFunction(name="u", grid=grid, time_order=2, space_order=8)

time_range = TimeAxis(start=0, stop=1000, step=1)

factor = Constant(name="factor", value=10, dtype=np.int32)
time_sub = ConditionalDimension(name="time_sub", parent=grid.time_dim,
factor=factor)

u0_save = TimeFunction(name="u0_save", grid=grid, time_dim=time_sub)
src = RickerSource(name="src", grid=grid, time_range=time_range, f0=10)

pde = u.dt2 - u.laplace
stencil = Eq(u.forward, solve(pde, u.forward))

src_term = src.inject(field=u.forward, expr=src)

eqn = [stencil] + src_term
eqn += [Eq(u0_save, u)]
op_fwd = Operator(eqn)

tmp_pickle_op_fn = "tmp_operator.pickle"
pickle.dump(op_fwd, open(tmp_pickle_op_fn, "wb"))
op_new = pickle.load(open(tmp_pickle_op_fn, "rb"))

assert str(op_fwd) == str(op_new)


def test_elemental():
"""
Tests that elemental function doesn't get reconstructed differently
Expand Down