-
Notifications
You must be signed in to change notification settings - Fork 230
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
compiler: Revamp CIRE exploiting EvalDerivative #1688
Changes from all commits
41f5eda
6027a0e
ba51d19
3f1eb42
05ff539
0d38bfc
0fad12f
b765021
f99fb6b
60fa2a1
850aa6f
5bcc4da
e11f526
1761daa
076bc8f
6bccd04
24c9d63
b004148
9b2a32c
f87d46f
b5b8630
605d183
9478fd6
437d4ff
428a218
d94ea89
4f25a16
30276f8
6dfb790
39d3643
a01f44d
81f1135
29100de
c68cc44
1b8931d
86d450a
4f3cd32
cadb4e1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,10 @@ | ||
from functools import partial | ||
|
||
import numpy as np | ||
|
||
from devito.core.operator import CoreOperator, CustomOperator | ||
from devito.exceptions import InvalidOperator | ||
from devito.passes.equations import buffering, collect_derivatives | ||
from devito.passes.clusters import (Lift, blocking, cire, cse, eliminate_arrays, | ||
extract_increments, factorize, fuse, optimize_pows) | ||
from devito.passes.clusters import (Lift, blocking, cire, cse, extract_increments, | ||
factorize, fuse, optimize_pows) | ||
from devito.passes.iet import (CTarget, OmpTarget, avoid_denormals, mpiize, | ||
optimize_halospots, hoist_prodders, relax_incr_dimensions) | ||
from devito.tools import timed_pass | ||
|
@@ -24,17 +22,17 @@ class Cpu64OperatorMixin(object): | |
3 => "blocks", "sub-blocks", and "sub-sub-blocks", ... | ||
""" | ||
|
||
CIRE_MINCOST_INV = 50 | ||
CIRE_MINGAIN = 10 | ||
""" | ||
Minimum operation count of a Dimension-invariant aliasing expression to be | ||
optimized away. Dimension-invariant aliases are lifted outside of one or more | ||
invariant loop(s), so they require tensor temporaries that can be potentially | ||
very large (e.g., the whole domain in the case of time-invariant aliases). | ||
Minimum operation count reduction for a redundant expression to be optimized | ||
away. Higher (lower) values make a redundant expression less (more) likely to | ||
be optimized away. | ||
""" | ||
|
||
CIRE_MINCOST_SOPS = 10 | ||
CIRE_SCHEDULE = 'automatic' | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
""" | ||
Minimum operation count of a sum-of-product aliasing expression to be optimized away. | ||
Strategy used to schedule derivatives across loops. This impacts the operational | ||
intensity of the generated kernel. | ||
""" | ||
|
||
PAR_COLLAPSE_NCORES = 4 | ||
|
@@ -88,15 +86,9 @@ def _normalize_kwargs(cls, **kwargs): | |
o['min-storage'] = oo.pop('min-storage', False) | ||
o['cire-rotate'] = oo.pop('cire-rotate', False) | ||
o['cire-maxpar'] = oo.pop('cire-maxpar', False) | ||
o['cire-maxalias'] = oo.pop('cire-maxalias', False) | ||
o['cire-ftemps'] = oo.pop('cire-ftemps', False) | ||
o['cire-mincost'] = { | ||
'invariants': { | ||
'scalar': np.inf, | ||
'tensor': oo.pop('cire-mincost-inv', cls.CIRE_MINCOST_INV), | ||
}, | ||
'sops': oo.pop('cire-mincost-sops', cls.CIRE_MINCOST_SOPS) | ||
} | ||
o['cire-mingain'] = oo.pop('cire-mingain', cls.CIRE_MINGAIN) | ||
o['cire-schedule'] = oo.pop('cire-schedule', cls.CIRE_SCHEDULE) | ||
|
||
# Shared-memory parallelism | ||
o['par-collapse-ncores'] = oo.pop('par-collapse-ncores', cls.PAR_COLLAPSE_NCORES) | ||
|
@@ -173,18 +165,16 @@ def _specialize_clusters(cls, clusters, **kwargs): | |
# Blocking to improve data locality | ||
clusters = blocking(clusters, options) | ||
|
||
# Reduce flops (potential arithmetic alterations) | ||
# Reduce flops | ||
clusters = extract_increments(clusters, sregistry) | ||
clusters = cire(clusters, 'sops', sregistry, options, platform) | ||
clusters = factorize(clusters) | ||
clusters = optimize_pows(clusters) | ||
|
||
# The previous passes may have created fusion opportunities, which in | ||
# turn may enable further optimizations | ||
# The previous passes may have created fusion opportunities | ||
clusters = fuse(clusters) | ||
clusters = eliminate_arrays(clusters) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Note for reviewers: yep, we can drop a whole pass, as thanks to the many generalisations it has become unnecessary (while being much weaker then what we have now) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. By |
||
|
||
# Reduce flops (no arithmetic alterations) | ||
# Reduce flops | ||
clusters = cse(clusters, sregistry) | ||
|
||
return clusters | ||
|
@@ -260,10 +250,8 @@ def _specialize_clusters(cls, clusters, **kwargs): | |
clusters = factorize(clusters) | ||
clusters = optimize_pows(clusters) | ||
|
||
# The previous passes may have created fusion opportunities, which in | ||
# turn may enable further optimizations | ||
# The previous passes may have created fusion opportunities | ||
clusters = fuse(clusters) | ||
clusters = eliminate_arrays(clusters) | ||
|
||
# Reduce flops (no arithmetic alterations) | ||
clusters = cse(clusters, sregistry) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,8 +6,7 @@ | |
from devito.exceptions import InvalidOperator | ||
from devito.passes.equations import collect_derivatives, buffering | ||
from devito.passes.clusters import (Lift, Streaming, Tasker, blocking, cire, cse, | ||
eliminate_arrays, extract_increments, factorize, | ||
fuse, optimize_pows) | ||
extract_increments, factorize, fuse, optimize_pows) | ||
from devito.passes.iet import (DeviceOmpTarget, DeviceAccTarget, optimize_halospots, | ||
mpiize, hoist_prodders, is_on_device) | ||
from devito.tools import as_tuple, timed_pass | ||
|
@@ -26,17 +25,17 @@ class DeviceOperatorMixin(object): | |
3 => "blocks", "sub-blocks", and "sub-sub-blocks", ... | ||
""" | ||
|
||
CIRE_MINCOST_INV = 50 | ||
CIRE_MINGAIN = 10 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same as cpu |
||
""" | ||
Minimum operation count of a Dimension-invariant aliasing expression to be | ||
optimized away. Dimension-invariant aliases are lifted outside of one or more | ||
invariant loop(s), so they require tensor temporaries that can be potentially | ||
very large (e.g., the whole domain in the case of time-invariant aliases). | ||
Minimum operation count reduction for a redundant expression to be optimized | ||
away. Higher (lower) values make a redundant expression less (more) likely to | ||
be optimized away. | ||
""" | ||
|
||
CIRE_MINCOST_SOPS = 10 | ||
CIRE_SCHEDULE = 'automatic' | ||
""" | ||
Minimum operation count of a sum-of-product aliasing expression to be optimized away. | ||
Strategy used to schedule derivatives across loops. This impacts the operational | ||
intensity of the generated kernel. | ||
""" | ||
|
||
PAR_CHUNK_NONAFFINE = 3 | ||
|
@@ -69,15 +68,9 @@ def _normalize_kwargs(cls, **kwargs): | |
o['min-storage'] = False | ||
o['cire-rotate'] = False | ||
o['cire-maxpar'] = oo.pop('cire-maxpar', True) | ||
o['cire-maxalias'] = oo.pop('cire-maxalias', False) | ||
o['cire-ftemps'] = oo.pop('cire-ftemps', False) | ||
o['cire-mincost'] = { | ||
'invariants': { | ||
'scalar': 1, | ||
'tensor': oo.pop('cire-mincost-inv', cls.CIRE_MINCOST_INV), | ||
}, | ||
'sops': oo.pop('cire-mincost-sops', cls.CIRE_MINCOST_SOPS) | ||
} | ||
o['cire-mingain'] = oo.pop('cire-mingain', cls.CIRE_MINGAIN) | ||
o['cire-schedule'] = oo.pop('cire-schedule', cls.CIRE_SCHEDULE) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Wouldn't it be easier to have a an abstract operator class to avoid this code duplication in cpu and gpu? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. right There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I thought about it (obviously) but in reality there are quite a few differences between the two, so I'd rather keep it separated |
||
|
||
# GPU parallelism | ||
o['par-collapse-ncores'] = 1 # Always use a collapse clause | ||
|
@@ -156,19 +149,17 @@ def _specialize_clusters(cls, clusters, **kwargs): | |
clusters = cire(clusters, 'invariants', sregistry, options, platform) | ||
clusters = Lift().process(clusters) | ||
|
||
# Reduce flops (potential arithmetic alterations) | ||
# Reduce flops | ||
clusters = extract_increments(clusters, sregistry) | ||
clusters = cire(clusters, 'sops', sregistry, options, platform) | ||
clusters = factorize(clusters) | ||
clusters = optimize_pows(clusters) | ||
|
||
# Reduce flops (no arithmetic alterations) | ||
clusters = cse(clusters, sregistry) | ||
|
||
# Lifting may create fusion opportunities, which in turn may enable | ||
# further optimizations | ||
# The previous passes may have created fusion opportunities | ||
clusters = fuse(clusters) | ||
clusters = eliminate_arrays(clusters) | ||
|
||
# Reduce flops | ||
clusters = cse(clusters, sregistry) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Note for reviewers: note that now There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why was it worng? |
||
|
||
return clusters | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,7 +7,7 @@ | |
from devito.finite_differences.finite_difference import (generic_derivative, | ||
first_derivative, | ||
cross_derivative) | ||
from devito.finite_differences.differentiable import Differentiable, EvalDerivative | ||
from devito.finite_differences.differentiable import Differentiable | ||
from devito.finite_differences.tools import direct, transpose | ||
from devito.tools import as_mapper, as_tuple, filter_ordered, frozendict | ||
from devito.types.utils import DimensionTuple | ||
|
@@ -333,9 +333,6 @@ def _eval_fd(self, expr): | |
- 3: Evaluate remaining terms (as `g` may need to be evaluated | ||
at a different point). | ||
- 4: Apply substitutions. | ||
- 5: Cast to an object of type `EvalDerivative` so that we know | ||
the argument stems from a `Derivative. This may be useful for | ||
later compilation passes. | ||
""" | ||
# Step 1: Evaluate derivatives within expression | ||
expr = getattr(expr, '_eval_deriv', expr) | ||
|
@@ -359,8 +356,4 @@ def _eval_fd(self, expr): | |
for e in self._ppsubs: | ||
res = res.xreplace(e) | ||
|
||
# Step 5: Cast to EvaluatedDerivative | ||
assert res.is_Add | ||
res = EvalDerivative(*res.args, evaluate=False) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Note for reviewers: now unnecessary as done directly inside finite_difference.py |
||
|
||
return res |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,17 +2,19 @@ | |
from functools import singledispatch | ||
|
||
import sympy | ||
from sympy.core.add import _addsort | ||
from sympy.core.mul import _mulsort | ||
from sympy.core.decorators import call_highest_priority | ||
from sympy.core.evalf import evalf_table | ||
|
||
from cached_property import cached_property | ||
from devito.finite_differences.tools import make_shift_x0 | ||
from devito.logger import warning | ||
from devito.tools import filter_ordered, flatten | ||
from devito.tools import filter_ordered, flatten, split | ||
from devito.types.lazy import Evaluable | ||
from devito.types.utils import DimensionTuple | ||
|
||
__all__ = ['Differentiable'] | ||
__all__ = ['Differentiable', 'EvalDerivative'] | ||
|
||
|
||
class Differentiable(sympy.Expr, Evaluable): | ||
|
@@ -300,6 +302,11 @@ class DifferentiableOp(Differentiable): | |
__sympy_class__ = None | ||
|
||
def __new__(cls, *args, **kwargs): | ||
# Do not re-evaluate if any of the args is an EvalDerivative, | ||
# since the integrity of these objects must be preserved | ||
if any(isinstance(i, EvalDerivative) for i in args): | ||
kwargs['evaluate'] = False | ||
|
||
obj = cls.__base__.__new__(cls, *args, **kwargs) | ||
|
||
# Unfortunately SymPy may build new sympy.core objects (e.g., sympy.Add), | ||
|
@@ -363,12 +370,54 @@ def _eval_at(self, func): | |
|
||
class Add(DifferentiableOp, sympy.Add): | ||
__sympy_class__ = sympy.Add | ||
__new__ = DifferentiableOp.__new__ | ||
|
||
def __new__(cls, *args, **kwargs): | ||
# Here, often we get `evaluate=False` to prevent SymPy evaluation (e.g., | ||
# when `cls==EvalDerivative`), but in all cases we at least apply a small | ||
# set of basic simplifications | ||
|
||
# (a+b)+c -> a+b+c (flattening) | ||
nested, others = split(args, lambda e: isinstance(e, Add)) | ||
args = flatten(e.args for e in nested) + list(others) | ||
|
||
# a+0 -> a | ||
args = [i for i in args if i != 0] | ||
|
||
# Reorder for homogeneity with pure SymPy types | ||
_addsort(args) | ||
|
||
return super().__new__(cls, *args, **kwargs) | ||
|
||
|
||
class Mul(DifferentiableOp, sympy.Mul): | ||
__sympy_class__ = sympy.Mul | ||
__new__ = DifferentiableOp.__new__ | ||
|
||
def __new__(cls, *args, **kwargs): | ||
# A Mul, being a DifferentiableOp, may not trigger evaluation upon | ||
# construction (e.g., when an EvalDerivative is present among its | ||
# arguments), so here we apply a small set of basic simplifications | ||
# to avoid generating functional, but also ugly, code | ||
|
||
# (a*b)*c -> a*b*c (flattening) | ||
nested, others = split(args, lambda e: isinstance(e, Mul)) | ||
args = flatten(e.args for e in nested) + list(others) | ||
|
||
# a*0 -> 0 | ||
mloubout marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if any(i == 0 for i in args): | ||
return sympy.S.Zero | ||
|
||
# a*1 -> a | ||
args = [i for i in args if i != 1] | ||
|
||
# a*-1*-1 -> a | ||
nminus = len([i for i in args if i == sympy.S.NegativeOne]) | ||
if nminus % 2 == 0: | ||
args = [i for i in args if i != sympy.S.NegativeOne] | ||
|
||
# Reorder for homogeneity with pure SymPy types | ||
_mulsort(args) | ||
|
||
return super().__new__(cls, *args, **kwargs) | ||
|
||
@property | ||
def _gather_for_diff(self): | ||
|
@@ -411,17 +460,46 @@ def _gather_for_diff(self): | |
class Pow(DifferentiableOp, sympy.Pow): | ||
_fd_priority = 0 | ||
__sympy_class__ = sympy.Pow | ||
__new__ = DifferentiableOp.__new__ | ||
|
||
|
||
class Mod(DifferentiableOp, sympy.Mod): | ||
__sympy_class__ = sympy.Mod | ||
__new__ = DifferentiableOp.__new__ | ||
|
||
|
||
class EvalDerivative(DifferentiableOp, sympy.Add): | ||
__sympy_class__ = sympy.Add | ||
__new__ = DifferentiableOp.__new__ | ||
|
||
is_commutative = True | ||
|
||
def __new__(cls, *args, base=None, **kwargs): | ||
kwargs['evaluate'] = False | ||
|
||
# a+0 -> a | ||
args = [i for i in args if i != 0] | ||
|
||
# Reorder for homogeneity with pure SymPy types | ||
_addsort(args) | ||
|
||
obj = super().__new__(cls, *args, **kwargs) | ||
|
||
try: | ||
obj.base = base | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what's the need for that one? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. CIRE won't optimize EvalDerivatives if the base is a pure Function, such as |
||
except AttributeError: | ||
# This might happen if e.g. one attempts a (re)construction with | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Actually there may be even weirder case. Like consider first order derivative with coefficient There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I tried, but I wasn't able to produce them, not sure why... any chance you can try writing a reproducer? |
||
# one sole argument. The (re)constructed EvalDerivative degenerates | ||
# to an object of different type, in classic SymPy style. That's fine | ||
assert len(args) <= 1 | ||
assert not obj.is_Add | ||
return obj | ||
|
||
return obj | ||
|
||
@property | ||
def func(self): | ||
return lambda *a, **kw: EvalDerivative(*a, base=self.base, **kw) | ||
|
||
def _new_rawargs(self, *args, **kwargs): | ||
kwargs.pop('is_commutative', None) | ||
return self.func(*args, **kwargs) | ||
|
||
|
||
class diffify(object): | ||
|
@@ -502,6 +580,9 @@ def _diff2sympy(obj): | |
except AttributeError: | ||
# Not of type DifferentiableOp | ||
pass | ||
except TypeError: | ||
# Won't lower (e.g., EvalDerivative) | ||
pass | ||
if flag: | ||
return obj.func(*args, evaluate=False), True | ||
else: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we want backward compatibility?
And still use it if set (or throw a warning).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
good point, I'm going to deprecate stuff and raise warnings accordingly