Skip to content

Commit

Permalink
compiler: Prevent reconstructions (eg unpickling) from unpicking opti…
Browse files Browse the repository at this point in the history
…mizations
  • Loading branch information
FabioLuporini committed May 15, 2023
1 parent 168df51 commit 80dd5d4
Show file tree
Hide file tree
Showing 7 changed files with 99 additions and 9 deletions.
6 changes: 5 additions & 1 deletion devito/operator/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from devito.mpi import MPI
from devito.parameters import configuration
from devito.passes import (Graph, lower_index_derivatives, generate_implicit,
generate_macros)
generate_macros, unevaluate)
from devito.symbolics import estimate_cost
from devito.tools import (DAG, OrderedSet, Signer, ReducerMap, as_tuple, flatten,
filter_sorted, frozendict, is_integer, split, timed_pass,
Expand Down Expand Up @@ -368,6 +368,10 @@ def _lower_clusters(cls, expressions, profiler=None, **kwargs):
# Lower all remaining high order symbolic objects
clusters = lower_index_derivatives(clusters, **kwargs)

# Make sure no reconstructions can unpick any of the symbolic
# optimizations performed so far
clusters = unevaluate(clusters)

return ClusterGroup(clusters)

# Compilation -- ScheduleTree level
Expand Down
1 change: 1 addition & 0 deletions devito/passes/clusters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@
from .implicit import * # noqa
from .misc import * # noqa
from .derivatives import * # noqa
from .unevaluate import * # noqa
33 changes: 33 additions & 0 deletions devito/passes/clusters/unevaluate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import sympy

from devito.ir import cluster_pass
from devito.symbolics import reuse_if_untouched, q_leaf
from devito.symbolics.unevaluation import Add, Mul, Pow

__all__ = ['unevaluate']


@cluster_pass
def unevaluate(cluster):
exprs = [_unevaluate(e) for e in cluster.exprs]

return cluster.rebuild(exprs=exprs)


mapper = {
sympy.Add: Add,
sympy.Mul: Mul,
sympy.Pow: Pow
}


def _unevaluate(expr):
if q_leaf(expr):
return expr

args = [_unevaluate(a) for a in expr.args]

try:
return mapper[expr.func](*args)
except KeyError:
return reuse_if_untouched(expr, args)
21 changes: 21 additions & 0 deletions devito/symbolics/unevaluation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import sympy

__all__ = ['Add', 'Mul', 'Pow']


class UnevaluableMixin(object):

def __new__(cls, *args, evaluate=None, **kwargs):
return cls.__base__.__new__(cls, *args, evaluate=False, **kwargs)


class Add(sympy.Add, UnevaluableMixin):
__new__ = UnevaluableMixin.__new__


class Mul(sympy.Mul, UnevaluableMixin):
__new__ = UnevaluableMixin.__new__


class Pow(sympy.Pow, UnevaluableMixin):
__new__ = UnevaluableMixin.__new__
2 changes: 1 addition & 1 deletion tests/test_dse.py
Original file line number Diff line number Diff line change
Expand Up @@ -1754,7 +1754,7 @@ def test_hoisting_symbolic_divs(self):
op = Operator(eq)

assert op._profiler._sections['section0'].sops == 1
assert op.body.body[-1].body[0].body[0].expr.rhs == s0**-s1
assert str(op.body.body[-1].body[0].body[0].expr.rhs) == str(s0**-s1)

@pytest.mark.parametrize('rotate', [False, True])
def test_drop_redundants_after_fusion(self, rotate):
Expand Down
14 changes: 8 additions & 6 deletions tests/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1612,8 +1612,10 @@ def test_expressions_imperfect_loops(self):
assert outer[0] == middle[0] == inner[0]
assert middle[1] == inner[1]
assert outer[-1].nodes[0].exprs[0].expr.rhs == diff2sympy(indexify(eq0.rhs))
assert middle[-1].nodes[0].exprs[0].expr.rhs == diff2sympy(indexify(eq1.rhs))
assert inner[-1].nodes[0].exprs[0].expr.rhs == diff2sympy(indexify(eq2.rhs))
assert (str(middle[-1].nodes[0].exprs[0].expr.rhs) ==
str(diff2sympy(indexify(eq1.rhs))))
assert (str(inner[-1].nodes[0].exprs[0].expr.rhs) ==
str(diff2sympy(indexify(eq2.rhs))))

def test_equations_emulate_bc(self):
"""
Expand Down Expand Up @@ -1647,8 +1649,8 @@ def test_different_section_nests(self):
op = Operator([eq1, eq2], opt='noop')
trees = retrieve_iteration_tree(op)
assert len(trees) == 2
assert trees[0][-1].nodes[0].exprs[0].expr.rhs == eq1.rhs
assert trees[1][-1].nodes[0].exprs[0].expr.rhs == eq2.rhs
assert str(trees[0][-1].nodes[0].exprs[0].expr.rhs) == str(eq1.rhs)
assert str(trees[1][-1].nodes[0].exprs[0].expr.rhs) == str(eq2.rhs)

@pytest.mark.parametrize('exprs', [
['Eq(ti0[x,y,z], ti0[x,y,z] + t0*2.)', 'Eq(ti0[0,0,z], 0.)'],
Expand All @@ -1673,8 +1675,8 @@ def test_directly_indexed_expression(self, exprs):

trees = retrieve_iteration_tree(op)
assert len(trees) == 2
assert trees[0][-1].nodes[0].exprs[0].expr.rhs == eqs[0].rhs
assert trees[1][-1].nodes[0].exprs[0].expr.rhs == eqs[1].rhs
assert str(trees[0][-1].nodes[0].exprs[0].expr.rhs) == str(eqs[0].rhs)
assert str(trees[1][-1].nodes[0].exprs[0].expr.rhs) == str(eqs[1].rhs)

@pytest.mark.parametrize('shape', [(11, 11), (11, 11, 11)])
def test_equations_mixed_functions(self, shape):
Expand Down
31 changes: 30 additions & 1 deletion tests/test_pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from conftest import skipif
from devito import (Constant, Eq, Function, TimeFunction, SparseFunction, Grid,
Dimension, SubDimension, ConditionalDimension, IncrDimension,
TimeDimension, SteppingDimension, Operator, MPI, Min)
TimeDimension, SteppingDimension, Operator, MPI, Min, solve)
from devito.ir import GuardFactor
from devito.data import LEFT, OWNED
from devito.mpi.halo_scheme import Halo
Expand Down Expand Up @@ -695,6 +695,35 @@ def test_full_model():
# FIXME: fails randomly when using data.flatten() AND numpy is using MKL


def test_usave_sampled():
grid = Grid(shape=(10, 10, 10))
u = TimeFunction(name="u", grid=grid, time_order=2, space_order=8)

time_range = TimeAxis(start=0, stop=1000, step=1)

factor = Constant(name="factor", value=10, dtype=np.int32)
time_sub = ConditionalDimension(name="time_sub", parent=grid.time_dim,
factor=factor)

u0_save = TimeFunction(name="u0_save", grid=grid, time_dim=time_sub)
src = RickerSource(name="src", grid=grid, time_range=time_range, f0=10)

pde = u.dt2 - u.laplace
stencil = Eq(u.forward, solve(pde, u.forward))

src_term = src.inject(field=u.forward, expr=src)

eqn = [stencil] + src_term
eqn += [Eq(u0_save, u)]
op_fwd = Operator(eqn)

tmp_pickle_op_fn = "tmp_operator.pickle"
pickle.dump(op_fwd, open(tmp_pickle_op_fn, "wb"))
op_new = pickle.load(open(tmp_pickle_op_fn, "rb"))

assert str(op_fwd) == str(op_new)


def test_elemental():
"""
Tests that elemental function doesn't get reconstructed differently
Expand Down

0 comments on commit 80dd5d4

Please sign in to comment.