Skip to content

Commit

Permalink
compiler: Optimize the -1*-1*... case when building Mul
Browse files Browse the repository at this point in the history
  • Loading branch information
FabioLuporini committed May 12, 2021
1 parent 4f3cd32 commit cadb4e1
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions devito/finite_differences/differentiable.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,7 @@ class Add(DifferentiableOp, sympy.Add):
def __new__(cls, *args, **kwargs):
# Here, often we get `evaluate=False` to prevent SymPy evaluation (e.g.,
# when `cls==EvalDerivative`), but in all cases we at least apply a small
# sets of basic optimizations
# set of basic simplifications

# (a+b)+c -> a+b+c (flattening)
nested, others = split(args, lambda e: isinstance(e, Add))
Expand All @@ -393,9 +393,10 @@ class Mul(DifferentiableOp, sympy.Mul):
__sympy_class__ = sympy.Mul

def __new__(cls, *args, **kwargs):
# A DifferentiableOp may not trigger evaluation upon construction
# (e.g., if an EvalDerivative is present among the arguments)
# So we treat some special cases here
# A Mul, being a DifferentiableOp, may not trigger evaluation upon
# construction (e.g., when an EvalDerivative is present among its
# arguments), so here we apply a small set of basic simplifications
# to avoid generating functional, but also ugly, code

# (a*b)*c -> a*b*c (flattening)
nested, others = split(args, lambda e: isinstance(e, Mul))
Expand All @@ -408,6 +409,11 @@ def __new__(cls, *args, **kwargs):
# a*1 -> a
args = [i for i in args if i != 1]

# a*-1*-1 -> a
nminus = len([i for i in args if i == sympy.S.NegativeOne])
if nminus % 2 == 0:
args = [i for i in args if i != sympy.S.NegativeOne]

# Reorder for homogeneity with pure SymPy types
_mulsort(args)

Expand Down

0 comments on commit cadb4e1

Please sign in to comment.