Skip to content

Commit

Permalink
compiler: Patch compare_ops for IndexDerivatives
Browse files Browse the repository at this point in the history
  • Loading branch information
FabioLuporini authored and mloubout committed Oct 28, 2023
1 parent 5e4e14c commit 0744570
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 0 deletions.
5 changes: 5 additions & 0 deletions devito/symbolics/inspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@ def compare_ops(e1, e2):
if type(e1) is type(e2) and len(e1.args) == len(e2.args):
if e1.is_Atom:
return True if e1 == e2 else False
elif isinstance(e1, IndexDerivative) and isinstance(e2, IndexDerivative):
if e1.mapper == e2.mapper:
return compare_ops(e1.base, e2.base)
else:
return False
elif e1.is_Indexed and e2.is_Indexed:
return True if e1.base == e2.base else False
else:
Expand Down
29 changes: 29 additions & 0 deletions tests/test_unexpansion.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
import pytest

from conftest import assert_structure, get_params, get_arrays, check_array
from devito import (Buffer, Eq, Function, TimeFunction, Grid, Operator,
Expand Down Expand Up @@ -61,6 +62,34 @@ def test_numeric_coeffs(self):
# Compound expression
Operator(Eq(u, (v*u.dx).dy, coefficients=coeffs), opt=opt).cfunction

@pytest.mark.parametrize('coeffs,expected', [
((7, 7, 7), 1), # We've had a bug triggered by identical coeffs
((5, 7, 9), 3),
])
def test_multiple_cross_derivs(self, coeffs, expected):
grid = Grid(shape=(11, 11, 11), extent=(10., 10., 10.))
x, y, z = grid.dimensions

p = TimeFunction(name='p', grid=grid, space_order=4,
coefficients='symbolic')

c0, c1, c2 = coeffs
coeffs0 = np.full(5, c0)
coeffs1 = np.full(5, c1)
coeffs2 = np.full(5, c2)

subs = Substitutions(Coefficient(1, p, x, coeffs0),
Coefficient(1, p, y, coeffs1),
Coefficient(1, p, z, coeffs2))

eq = Eq(p.forward, p.dy.dz + p.dx.dy, coefficients=subs)

op = Operator(eq, opt=('advanced', {'expand': False}))
op.cfunction

# w0, w1, ...
assert len(op._globals) == expected


class Test1Pass(object):

Expand Down

0 comments on commit 0744570

Please sign in to comment.