From 351d79d7cb919900bab85e5e9a44b8ca497dad08 Mon Sep 17 00:00:00 2001 From: mloubout Date: Tue, 5 Mar 2024 10:56:18 -0600 Subject: [PATCH] api: fix EvalDerivative and expand arithmetic --- devito/finite_differences/derivative.py | 5 +- devito/finite_differences/differentiable.py | 4 +- devito/finite_differences/rsfd.py | 93 +++++++++---------- devito/finite_differences/tools.py | 2 + devito/types/dense.py | 6 +- .../06_elastic_varying_parameters.ipynb | 82 +++++++--------- examples/userapi/01_dsl.ipynb | 4 +- tests/test_derivatives.py | 42 +++++++-- 8 files changed, 124 insertions(+), 114 deletions(-) diff --git a/devito/finite_differences/derivative.py b/devito/finite_differences/derivative.py index 20ab964a366..9e134541707 100644 --- a/devito/finite_differences/derivative.py +++ b/devito/finite_differences/derivative.py @@ -7,7 +7,7 @@ from .finite_difference import generic_derivative, first_derivative, cross_derivative from .differentiable import Differentiable from .tools import direct, transpose -from .rsfd import difrot +from .rsfd import d45 from devito.tools import as_mapper, as_tuple, filter_ordered, frozendict from devito.types.utils import DimensionTuple @@ -396,8 +396,7 @@ def _eval_fd(self, expr, **kwargs): if self.method == 'RSFD': assert len(self.dims) == 1 assert self.deriv_order == 1 - fdfunc = difrot[expr.grid.dim]['d%s' % self.dims[0].name] - res = fdfunc(expr, self.x0, expand=expand) + res = d45(expr, self.dims[0], x0=self.x0, expand=expand) elif self.side is not None and self.deriv_order == 1: assert self.method == 'FD' res = first_derivative(expr, self.dims[0], self.fd_order, diff --git a/devito/finite_differences/differentiable.py b/devito/finite_differences/differentiable.py index dd908c5e83e..ab4813c322d 100644 --- a/devito/finite_differences/differentiable.py +++ b/devito/finite_differences/differentiable.py @@ -672,7 +672,7 @@ def _evaluate(self, **kwargs): expr = self.expr._evaluate(**kwargs) if not kwargs.get('expand', True): - return self.func(expr, self.dimensions) + return self._rebuild(expr) values = product(*[list(d.range) for d in self.dimensions]) terms = [] @@ -834,7 +834,7 @@ def _evaluate(self, **kwargs): mapper = {w.subs(d, i): f.weights[n] for n, i in enumerate(d.range)} expr = expr.xreplace(mapper) - return EvalDerivative(expr, base=self.base) + return EvalDerivative(*expr.args, base=self.base) class DiffDerivative(IndexDerivative, DifferentiableOp): diff --git a/devito/finite_differences/rsfd.py b/devito/finite_differences/rsfd.py index 69e96c23f22..3a98336910d 100644 --- a/devito/finite_differences/rsfd.py +++ b/devito/finite_differences/rsfd.py @@ -4,9 +4,8 @@ from devito.types.dimension import StencilDimension from .differentiable import Weights, DiffDerivative from .tools import generate_indices_staggered, fd_weights_registry -from .elementary import sqrt -__all__ = ['drot', 'dxrot', 'dyrot', 'dzrot'] +__all__ = ['drot', 'd45'] smapper = {1: (1, 1, 1), 2: (1, 1, -1), 3: (1, -1, 1), 4: (1, -1, -1)} @@ -20,12 +19,12 @@ def shift(sign, x0): def drot(expr, dim, dir=1, x0=None): """ - Finite difference approximation of the derivative along d1 + Finite difference approximation of the derivative along dir of a Function `f` at point `x0`. Rotated finite differences based on: https://www.sciencedirect.com/science/article/pii/S0165212599000232 - The rotated axis (the four diagonal of a cube) are: + The rotated axis (the four diagonals of a cube) are: d1 = dx/dr x + dz/dl y + dy/dl z d2 = dx/dl x + dz/dl y - dy/dr z d3 = dx/dr x - dz/dl y + dy/dl z @@ -39,8 +38,8 @@ def drot(expr, dim, dir=1, x0=None): if dir > 2 and ndim == 2: return 0 - # Spacing along diagonal - r = sqrt(sum(d.spacing**2 for d in expr.grid.dimensions)) + # RSFD scaling + s = 2**(expr.grid.dim - 1) # Center point and indices start, indices = generate_indices_staggered(expr, dim, expr.space_order, x0=x0) @@ -63,15 +62,17 @@ def drot(expr, dim, dir=1, x0=None): signs = smapper[dir][::(1 if ndim == 3 else 2)] # Direction substitutions - dim_mapper = {d: d + signs[di]*i*d.spacing - shift(signs[di], mid)*d.spacing - for (di, d) in enumerate(expr.grid.dimensions)} + dim_mapper = {} + for (di, d) in enumerate(expr.grid.dimensions): + s0 = 0 if mid == adim_start else shift(signs[di], mid)*d.spacing + dim_mapper[d] = d + signs[di]*i*d.spacing - s0 # Create IndexDerivative ui = expr.subs(dim_mapper) - deriv = DiffDerivative(w0*ui, {d: i for d in expr.grid.dimensions}) + deriv = DiffDerivative(w0*ui/(s*dim.spacing), {d: i for d in expr.grid.dimensions}) - return deriv/r + return deriv grid_node = lambda grid: {d: d for d in grid.dimensions} @@ -97,7 +98,7 @@ def check_staggering(func): - grid.dimension center point and NODE staggering """ @wraps(func) - def wrapper(expr, x0=None, expand=True): + def wrapper(expr, dim, x0=None, expand=True): grid = expr.grid x0 = {k: v for k, v in x0.items() if k.is_Space} if expr.staggered is NODE or expr.staggered is None: @@ -107,52 +108,46 @@ def wrapper(expr, x0=None, expand=True): else: cond = False if cond: - return func(expr, x0=x0, expand=expand) + return func(expr, dim, x0=x0, expand=expand) else: raise ValueError('Invalid staggering or x0 for rotated finite differences') return wrapper @check_staggering -def dxrot(expr, x0=None, expand=True): - x = expr.grid.dimensions[0] - r = sqrt(sum(d.spacing**2 for d in expr.grid.dimensions)) - s = 2**(expr.grid.dim - 1) - dxrsfd = (drot(expr, x, x0=x0, dir=1) + drot(expr, x, x0=x0, dir=2) + - drot(expr, x, x0=x0, dir=3) + drot(expr, x, x0=x0, dir=4)) - dx45 = r / (s * x.spacing) * dxrsfd - if expand: - return dx45.evaluate - else: - return dx45 - +def d45(expr, dim, x0=None, expand=True): + """ + RSFD approximation of the derivative of `expr` along `dim` at point `x0`. + + Parameters + ---------- + expr : expr-like + Expression for which the derivative is produced. + dim : Dimension + The Dimension w.r.t. which to differentiate. + x0 : dict, optional + Origin of the finite-difference. Defaults to 0 for all dimensions. + expand : bool, optional + Expand the expression. Defaults to True. + """ + # Make sure the grid supports RSFD + if expr.grid.dim == 1: + raise ValueError('RSFD only supported in 2D and 3D') -@check_staggering -def dyrot(expr, x0=None, expand=True): - y = expr.grid.dimensions[1] - r = sqrt(sum(d.spacing**2 for d in expr.grid.dimensions)) - s = 2**(expr.grid.dim - 1) - dyrsfd = (drot(expr, y, x0=x0, dir=1) + drot(expr, y, x0=x0, dir=2) - - drot(expr, y, x0=x0, dir=3) - drot(expr, y, x0=x0, dir=4)) - dy45 = r / (s * y.spacing) * dyrsfd - if expand: - return dy45.evaluate - else: - return dy45 + # Diagonals weights + w = dir_weights[(dim.name, expr.grid.dim)] + # RSFD + rsfd = (w[0] * drot(expr, dim, x0=x0, dir=1) + + w[1] * drot(expr, dim, x0=x0, dir=2) + + w[2] * drot(expr, dim, x0=x0, dir=3) + + w[3] * drot(expr, dim, x0=x0, dir=4)) -@check_staggering -def dzrot(expr, x0=None, expand=True): - z = expr.grid.dimensions[-1] - r = sqrt(sum(d.spacing**2 for d in expr.grid.dimensions)) - s = 2**(expr.grid.dim - 1) - dzrsfd = (drot(expr, z, x0=x0, dir=1) - drot(expr, z, x0=x0, dir=2) + - drot(expr, z, x0=x0, dir=3) - drot(expr, z, x0=x0, dir=4)) - dz45 = r / (s * z.spacing) * dzrsfd - if expand: - return dz45.evaluate - else: - return dz45 + # Evaluate + return rsfd._evaluate(expand=expand) -difrot = {2: {'dx': dxrot, 'dy': dzrot}, 3: {'dx': dxrot, 'dy': dyrot, 'dz': dzrot}} +# How to sum d1, d2, d3, d4 depending on the dimension +dir_weights = {('x', 2): (1, 1, 1, 1), ('x', 3): (1, 1, 1, 1), + ('y', 2): (1, -1, 1, -1), ('y', 3): (1, 1, -1, -1), + ('z', 2): (1, -1, 1, -1), ('z', 3): (1, -1, 1, -1)} diff --git a/devito/finite_differences/tools.py b/devito/finite_differences/tools.py index 82de0f08bc9..0b6f9b9fef3 100644 --- a/devito/finite_differences/tools.py +++ b/devito/finite_differences/tools.py @@ -65,6 +65,8 @@ def wrapper(expr, *args, **kwargs): "with symbolic coefficients is not currently " "supported") kwargs['coefficients'] = 'symbolic' + else: + kwargs['coefficients'] = expr.coefficients return func(expr, *args, **kwargs) return wrapper diff --git a/devito/types/dense.py b/devito/types/dense.py index e651ee9a164..37d41aba3d2 100644 --- a/devito/types/dense.py +++ b/devito/types/dense.py @@ -2,10 +2,10 @@ from ctypes import POINTER, Structure, c_int, c_ulong, c_void_p, cast, byref from functools import wraps, reduce from operator import mul +import warnings import numpy as np import sympy -import warnings from psutil import virtual_memory from cached_property import cached_property @@ -44,6 +44,8 @@ class DiscreteFunction(AbstractFunction, ArgProvider, Differentiable): Users should not instantiate this class directly. Use Function or SparseFunction (or their subclasses) instead. """ + + # Default method for the finite difference approximation weights computation. _default_fd = 'taylor' # Required by SymPy, otherwise the presence of __getitem__ will make SymPy @@ -71,7 +73,7 @@ def __init_finalize__(self, *args, function=None, **kwargs): # Symbolic (finite difference) coefficients self._coefficients = kwargs.get('coefficients', self._default_fd) - if self._coefficients not in fd_weights_registry.keys(): + if self._coefficients not in fd_weights_registry: if self._coefficients == 'standard': self._coefficients = 'taylor' warnings.warn("The `standard` mode is deprecated and will be removed in " diff --git a/examples/seismic/tutorials/06_elastic_varying_parameters.ipynb b/examples/seismic/tutorials/06_elastic_varying_parameters.ipynb index 54ca308d6a3..5e77fec46b1 100644 --- a/examples/seismic/tutorials/06_elastic_varying_parameters.ipynb +++ b/examples/seismic/tutorials/06_elastic_varying_parameters.ipynb @@ -392,22 +392,22 @@ "name": "stderr", "output_type": "stream", "text": [ - "Operator `Kernel` ran in 0.24 s\n" + "Operator `Kernel` ran in 0.23 s\n" ] }, { "data": { "text/plain": [ "PerformanceSummary([(PerfKey(name='section0', rank=None),\n", - " PerfEntry(time=0.20362499999999983, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n", + " PerfEntry(time=0.20520499999999994, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n", " (PerfKey(name='section1', rank=None),\n", - " PerfEntry(time=0.006682000000000003, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n", + " PerfEntry(time=0.005932999999999998, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n", " (PerfKey(name='section2', rank=None),\n", - " PerfEntry(time=0.007740000000000005, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n", + " PerfEntry(time=0.006593000000000015, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n", " (PerfKey(name='section3', rank=None),\n", - " PerfEntry(time=0.007099999999999985, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n", + " PerfEntry(time=0.005976000000000029, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n", " (PerfKey(name='section4', rank=None),\n", - " PerfEntry(time=0.007404000000000003, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[]))])" + " PerfEntry(time=0.00602000000000003, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[]))])" ] }, "execution_count": 14, @@ -511,22 +511,22 @@ "name": "stderr", "output_type": "stream", "text": [ - "Operator `Kernel` ran in 0.28 s\n" + "Operator `Kernel` ran in 0.27 s\n" ] }, { "data": { "text/plain": [ "PerformanceSummary([(PerfKey(name='section0', rank=None),\n", - " PerfEntry(time=0.22185899999999997, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n", + " PerfEntry(time=0.23217499999999996, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n", " (PerfKey(name='section1', rank=None),\n", - " PerfEntry(time=0.011772000000000072, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n", + " PerfEntry(time=0.008015000000000006, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n", " (PerfKey(name='section2', rank=None),\n", - " PerfEntry(time=0.013601000000000004, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n", + " PerfEntry(time=0.012265999999999938, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n", " (PerfKey(name='section3', rank=None),\n", - " PerfEntry(time=0.011722999999999928, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n", + " PerfEntry(time=0.007797000000000036, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n", " (PerfKey(name='section4', rank=None),\n", - " PerfEntry(time=0.012186999999999974, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[]))])" + " PerfEntry(time=0.009437999999999978, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[]))])" ] }, "execution_count": 17, @@ -699,20 +699,20 @@ "name": "stderr", "output_type": "stream", "text": [ - "Operator `Kernel` ran in 0.22 s\n" + "Operator `Kernel` ran in 0.24 s\n" ] }, { "data": { "text/plain": [ "PerformanceSummary([(PerfKey(name='section0', rank=None),\n", - " PerfEntry(time=0.19614199999999998, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n", + " PerfEntry(time=0.21262299999999978, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n", " (PerfKey(name='section1', rank=None),\n", - " PerfEntry(time=0.00614999999999999, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n", + " PerfEntry(time=0.008176999999999999, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n", " (PerfKey(name='section2', rank=None),\n", - " PerfEntry(time=0.006364000000000024, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n", + " PerfEntry(time=0.008674000000000015, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n", " (PerfKey(name='section3', rank=None),\n", - " PerfEntry(time=0.006212000000000028, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[]))])" + " PerfEntry(time=0.008485000000000003, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[]))])" ] }, "execution_count": 24, @@ -796,20 +796,20 @@ "name": "stderr", "output_type": "stream", "text": [ - "Operator `Kernel` ran in 0.26 s\n" + "Operator `Kernel` ran in 0.25 s\n" ] }, { "data": { "text/plain": [ "PerformanceSummary([(PerfKey(name='section0', rank=None),\n", - " PerfEntry(time=0.21812400000000012, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n", + " PerfEntry(time=0.2165199999999999, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n", " (PerfKey(name='section1', rank=None),\n", - " PerfEntry(time=0.011863999999999993, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n", + " PerfEntry(time=0.00951100000000004, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n", " (PerfKey(name='section2', rank=None),\n", - " PerfEntry(time=0.013489000000000025, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n", + " PerfEntry(time=0.01492799999999995, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n", " (PerfKey(name='section3', rank=None),\n", - " PerfEntry(time=0.012101000000000027, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[]))])" + " PerfEntry(time=0.00861800000000002, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[]))])" ] }, "execution_count": 26, @@ -946,28 +946,16 @@ { "data": { "text/latex": [ - "$\\displaystyle \\frac{\\sqrt{h_{x}^{2} + h_{y}^{2}} \\left(\\frac{- \\frac{f(x - h_x, y - h_y)}{2} + \\frac{f(x + h_x, y + h_y)}{2}}{\\sqrt{h_{x}^{2} + h_{y}^{2}}} + \\frac{- \\frac{f(x - h_x, y + 2*h_y)}{2} + \\frac{f(x + h_x, y)}{2}}{\\sqrt{h_{x}^{2} + h_{y}^{2}}}\\right)}{2 h_{x}}$" + "$\\displaystyle \\left(- \\frac{f(x - h_x, y - h_y)}{4 h_{x}} + \\frac{f(x + h_x, y + h_y)}{4 h_{x}}\\right) + \\left(- \\frac{f(x - h_x, y + h_y)}{4 h_{x}} + \\frac{f(x + h_x, y - h_y)}{4 h_{x}}\\right)$" ], "text/plain": [ - " ⎛ f(x - hₓ, y - h_y) f(x + hₓ, y + h_y) f(x - hₓ, y + 2\n", - " ____________ ⎜- ────────────────── + ────────────────── - ───────────────\n", - " ╱ 2 2 ⎜ 2 2 2 \n", - "╲╱ hₓ + h_y ⋅⎜───────────────────────────────────────── + ─────────────────\n", - " ⎜ ____________ ___\n", - " ⎜ ╱ 2 2 ╱ \n", - " ⎝ ╲╱ hₓ + h_y ╲╱ hₓ\n", - "──────────────────────────────────────────────────────────────────────────────\n", - " 2⋅hₓ \n", + " f(x - hₓ, y - h_y) f(x + hₓ, y + h_y) f(x - hₓ, y + h_y) f(x + hₓ, y\n", + "- ────────────────── + ────────────────── + - ────────────────── + ───────────\n", + " 4⋅hₓ 4⋅hₓ 4⋅hₓ 4⋅hₓ\n", "\n", - "⋅h_y) f(x + hₓ, y)⎞\n", - "───── + ────────────⎟\n", - " 2 ⎟\n", - "────────────────────⎟\n", - "_________ ⎟\n", - "2 2 ⎟\n", - " + h_y ⎠\n", - "─────────────────────\n", - " " + " - h_y)\n", + "───────\n", + " " ] }, "execution_count": 32, @@ -1028,22 +1016,22 @@ "name": "stderr", "output_type": "stream", "text": [ - "Operator `Kernel` ran in 0.61 s\n" + "Operator `Kernel` ran in 0.57 s\n" ] }, { "data": { "text/plain": [ "PerformanceSummary([(PerfKey(name='section0', rank=None),\n", - " PerfEntry(time=0.5591969999999996, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n", + " PerfEntry(time=0.5155960000000004, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n", " (PerfKey(name='section1', rank=None),\n", - " PerfEntry(time=0.011598000000000063, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n", + " PerfEntry(time=0.011958000000000087, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n", " (PerfKey(name='section2', rank=None),\n", - " PerfEntry(time=0.013013999999999998, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n", + " PerfEntry(time=0.013548000000000008, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n", " (PerfKey(name='section3', rank=None),\n", - " PerfEntry(time=0.011741999999999987, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n", + " PerfEntry(time=0.012433999999999917, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n", " (PerfKey(name='section4', rank=None),\n", - " PerfEntry(time=0.011816999999999944, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[]))])" + " PerfEntry(time=0.012537999999999919, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[]))])" ] }, "execution_count": 34, diff --git a/examples/userapi/01_dsl.ipynb b/examples/userapi/01_dsl.ipynb index beedb0b08c6..dec8e756340 100644 --- a/examples/userapi/01_dsl.ipynb +++ b/examples/userapi/01_dsl.ipynb @@ -539,10 +539,10 @@ { "data": { "text/latex": [ - "$\\displaystyle \\frac{\\sqrt{h_{x}^{2} + h_{y}^{2}} \\left(\\frac{f(x, y) - f(x - h_x, y - h_y)}{\\sqrt{h_{x}^{2} + h_{y}^{2}}} + \\frac{f(x, y + h_y) - f(x - h_x, y + 2*h_y)}{\\sqrt{h_{x}^{2} + h_{y}^{2}}}\\right)}{2 h_{x}}$" + "$\\displaystyle \\left(\\frac{f(x, y)}{2 h_{x}} - \\frac{f(x - h_x, y - h_y)}{2 h_{x}}\\right) + \\left(\\frac{f(x, y)}{2 h_{x}} - \\frac{f(x - h_x, y + h_y)}{2 h_{x}}\\right)$" ], "text/plain": [ - "sqrt(h_x**2 + h_y**2)*((f(x, y) - f(x - h_x, y - h_y))/sqrt(h_x**2 + h_y**2) + (f(x, y + h_y) - f(x - h_x, y + 2*h_y))/sqrt(h_x**2 + h_y**2))/(2*h_x)" + "f(x, y)/(2*h_x) - f(x - h_x, y - h_y)/(2*h_x) + f(x, y)/(2*h_x) - f(x - h_x, y + h_y)/(2*h_x)" ] }, "execution_count": 15, diff --git a/tests/test_derivatives.py b/tests/test_derivatives.py index af4ae0e4eed..8fd43630e56 100644 --- a/tests/test_derivatives.py +++ b/tests/test_derivatives.py @@ -6,7 +6,8 @@ ConditionalDimension, left, right, centered, div, grad) from devito.finite_differences import Derivative, Differentiable from devito.finite_differences.differentiable import (Add, EvalDerivative, IndexSum, - IndexDerivative, Weights) + IndexDerivative, Weights, + DiffDerivative) from devito.symbolics import indexify, retrieve_indexed from devito.types.dimension import StencilDimension @@ -768,7 +769,7 @@ def test_index_derivative(self): idxder = IndexDerivative(ui*w, {x: i}) - assert idxder.evaluate == -0.5*u + 0.5*ui.subs(i, 2) + assert simplify(idxder.evaluate - (-0.5*u + 0.5*ui.subs(i, 2))) == 0 # Make sure subs works as expected v = Function(name="v", grid=grid, space_order=so) @@ -787,10 +788,33 @@ def test_dx2(self): assert isinstance(term0, EvalDerivative) term1 = f.dx2._evaluate(expand=False) - assert isinstance(term1, IndexDerivative) + assert isinstance(term1, DiffDerivative) assert term1.depth == 1 term1 = term1.evaluate - assert isinstance(term1, Add) # devito.fd.Add + + assert isinstance(term1, EvalDerivative) # devito.fd.Add + + # Check that the first partially evaluated then fully evaluated + # `term1` matches up the fully evaluated `term0` + assert EvalDerivative(*term0.args) == term1 + + def test_dx45(self): + grid = Grid(shape=(4, 4)) + + f = TimeFunction(name='f', grid=grid, space_order=4) + + term0 = f.dx45.evaluate + assert len(term0.args) == 2 + + term1 = f.dx45._evaluate(expand=False) + assert len(term1.args) == 2 + for i in range(2): + assert isinstance(term0.args[i], EvalDerivative) + assert isinstance(term1.args[i], DiffDerivative) + assert term1.args[i].depth == 1 + + term1 = term1.evaluate + assert isinstance(term1, Add) # devito.fd.EvalDerivative # Check that the first partially evaluated then fully evaluated # `term1` matches up the fully evaluated `term0` @@ -805,14 +829,14 @@ def test_dxdy(self): assert isinstance(term0, EvalDerivative) term1 = f.dx.dy._evaluate(expand=False) - assert isinstance(term1, IndexDerivative) + assert isinstance(term1, DiffDerivative) assert term1.depth == 2 term1 = term1.evaluate - assert isinstance(term1, Add) # devito.fd.Add + assert isinstance(term1, EvalDerivative) # devito.fd.Add # Through expansion and casting we also check that `term0` # is indeed mathematically equivalent to `term1` - assert Add(*term0.expand().args) == term1.expand() + assert EvalDerivative(*term0.expand().args) == term1.expand() def test_dxdy_v2(self): grid = Grid(shape=(4, 4)) @@ -820,7 +844,7 @@ def test_dxdy_v2(self): f = TimeFunction(name='f', grid=grid, space_order=4) term1 = f.dxdy._evaluate(expand=False) - assert len(term1.find(IndexDerivative)) == 2 + assert len(term1.find(DiffDerivative)) == 2 def test_transpose(self): grid = Grid(shape=(4, 4)) @@ -830,7 +854,7 @@ def test_transpose(self): f = TimeFunction(name='f', grid=grid, space_order=4) term = f.dx.T._evaluate(expand=False) - assert isinstance(term, IndexDerivative) + assert isinstance(term, DiffDerivative) i0, = term.dimensions assert term.base == f.subs(x, x + i0*h_x)