Skip to content

Commit a59c592

Browse files
committed
GateauxDerivativeRuleset
1 parent 34bbd94 commit a59c592

File tree

1 file changed

+117
-42
lines changed

1 file changed

+117
-42
lines changed

ufl/algorithms/apply_derivatives.py

+117-42
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,14 @@
1717
from ufl.algorithms.analysis import extract_arguments
1818
from ufl.algorithms.map_integrands import map_integrand_dags
1919
from ufl.algorithms.replace_derivative_nodes import replace_derivative_nodes
20-
from ufl.argument import Argument, BaseArgument
20+
from ufl.argument import Argument, BaseArgument, Coargument
2121
from ufl.averaging import CellAvg, FacetAvg
2222
from ufl.checks import is_cellwise_constant
2323
from ufl.classes import (
2424
Abs,
2525
CellCoordinate,
2626
Coefficient,
27+
Cofunction,
2728
ComponentTensor,
2829
Conj,
2930
Constant,
@@ -60,6 +61,8 @@
6061
from ufl.constantvalue import is_true_ufl_scalar, is_ufl_scalar
6162
from ufl.core.base_form_operator import BaseFormOperator
6263
from ufl.core.expr import ufl_err_str
64+
from ufl.core.external_operator import ExternalOperator
65+
from ufl.core.interpolate import Interpolate
6366
from ufl.core.multiindex import FixedIndex, MultiIndex, indices
6467
from ufl.core.terminal import Terminal
6568
from ufl.corealg.dag_traverser import DAGTraverser
@@ -94,6 +97,7 @@
9497
Tan,
9598
Tanh,
9699
)
100+
from ufl.matrix import Matrix
97101
from ufl.operators import (
98102
MaxValue,
99103
MinValue,
@@ -1688,59 +1692,90 @@ def _(self, o: Expr) -> Expr:
16881692
return self.independent_operator(o)
16891693

16901694

1691-
class GateauxDerivativeRuleset(GenericDerivativeRulesetMultiFunction):
1695+
class GateauxDerivativeRuleset(GenericDerivativeRuleset):
16921696
"""Apply AFD (Automatic Functional Differentiation) to expression.
16931697
16941698
Implements rules for the Gateaux derivative D_w[v](...) defined as
16951699
D_w[v](e) = d/dtau e(w+tau v)|tau=0.
16961700
"""
16971701

1698-
def __init__(self, coefficients, arguments, coefficient_derivatives, pending_operations):
1702+
def __init__(
1703+
self, coefficients, arguments, coefficient_derivatives, pending_operations,
1704+
compress=True, vcache=None, rcache=None,
1705+
):
16991706
"""Initialise."""
1700-
GenericDerivativeRulesetMultiFunction.__init__(self, var_shape=())
1701-
1707+
super().__init__((), compress=compress, vcache=vcache, rcache=rcache)
17021708
# Type checking
17031709
if not isinstance(coefficients, ExprList):
17041710
raise ValueError("Expecting a ExprList of coefficients.")
17051711
if not isinstance(arguments, ExprList):
17061712
raise ValueError("Expecting a ExprList of arguments.")
17071713
if not isinstance(coefficient_derivatives, ExprMapping):
17081714
raise ValueError("Expecting a coefficient-coefficient ExprMapping.")
1709-
17101715
# The coefficient(s) to differentiate w.r.t. and the
17111716
# argument(s) s.t. D_w[v](e) = d/dtau e(w+tau v)|tau=0
17121717
self._w = coefficients.ufl_operands
17131718
self._v = arguments.ufl_operands
17141719
self._w2v = {w: v for w, v in zip(self._w, self._v)}
1715-
17161720
# Build more convenient dict {f: df/dw} for each coefficient f
17171721
# where df/dw is nonzero
17181722
cd = coefficient_derivatives.ufl_operands
17191723
self._cd = {cd[2 * i]: cd[2 * i + 1] for i in range(len(cd) // 2)}
1720-
17211724
# Record the operations delayed to the derivative expansion phase:
17221725
# Example: dN(u)/du where `N` is an ExternalOperator and `u` a Coefficient
17231726
self.pending_operations = pending_operations
17241727

1725-
# Explicitly defining dg/dw == 0
1726-
geometric_quantity = GenericDerivativeRulesetMultiFunction.independent_terminal
1728+
# Work around singledispatchmethod inheritance issue;
1729+
# see https://bugs.python.org/issue36457.
1730+
@singledispatchmethod
1731+
def process(self, o: Expr) -> Expr:
1732+
"""Process ``o``.
1733+
1734+
Args:
1735+
o: `Expr` to be processed.
1736+
1737+
Returns:
1738+
Processed object.
1739+
1740+
"""
1741+
return super().process(o)
17271742

1728-
def cell_avg(self, o, fp):
1743+
# --- Specialized rules for geometric quantities
1744+
1745+
@process.register(GeometricQuantity)
1746+
def _(self, o: Expr) -> Expr:
1747+
# Explicitly defining dg/dw == 0
1748+
return self.independent_terminal(o)
1749+
1750+
@process.register(CellAvg)
1751+
@DAGTraverser.postorder
1752+
def _(self, o: Expr, fp) -> Expr:
17291753
"""Differentiate a cell_avg."""
17301754
# Cell average of a single function and differentiation
17311755
# commutes, D_f[v](cell_avg(f)) = cell_avg(v)
17321756
return cell_avg(fp)
17331757

1734-
def facet_avg(self, o, fp):
1758+
@process.register(FacetAvg)
1759+
@DAGTraverser.postorder
1760+
def _(self, o: Expr, fp) -> Expr:
17351761
"""Differentiate a facet_avg."""
17361762
# Facet average of a single function and differentiation
17371763
# commutes, D_f[v](facet_avg(f)) = facet_avg(v)
17381764
return facet_avg(fp)
17391765

1740-
# Explicitly defining da/dw == 0
1741-
argument = GenericDerivativeRulesetMultiFunction.independent_terminal
1766+
@process.register(Argument)
1767+
def _(self, o: Expr) -> Expr:
1768+
# Explicitly defining da/dw == 0
1769+
return self._process_argument(o)
17421770

1743-
def coefficient(self, o):
1771+
def _process_argument(self, o: Expr) -> Expr:
1772+
return self.independent_terminal(o)
1773+
1774+
@process.register(Coefficient)
1775+
def _(self, o: Expr) -> Expr:
1776+
return self._process_coefficient(o)
1777+
1778+
def _process_coefficient(self, o: Expr) -> Expr:
17441779
"""Differentiate a coefficient."""
17451780
# Define dw/dw := d/ds [w + s v] = v
17461781

@@ -1788,7 +1823,8 @@ def coefficient(self, o):
17881823
dosum += prod
17891824
return dosum
17901825

1791-
def reference_value(self, o):
1826+
@process.register(ReferenceValue)
1827+
def _(self, o: Expr) -> Expr:
17921828
"""Differentiate a reference_value."""
17931829
raise NotImplementedError(
17941830
"Currently no support for ReferenceValue in CoefficientDerivative."
@@ -1808,7 +1844,8 @@ def reference_value(self, o):
18081844
# else:
18091845
# return self.independent_terminal(o)
18101846

1811-
def reference_grad(self, o):
1847+
@process.register(ReferenceGrad)
1848+
def _(self, o: Expr) -> Expr:
18121849
"""Differentiate a reference_grad."""
18131850
raise NotImplementedError(
18141851
"Currently no support for ReferenceGrad in CoefficientDerivative."
@@ -1819,7 +1856,8 @@ def reference_grad(self, o):
18191856
# this to allow the user to write
18201857
# derivative(...ReferenceValue...,...).
18211858

1822-
def grad(self, g):
1859+
@process.register(Grad)
1860+
def _(self, g: Expr) -> Expr:
18231861
"""Differentiate a grad."""
18241862
# If we hit this type, it has already been propagated to a
18251863
# coefficient (or grad of a coefficient) or a base form operator, # FIXME: Assert
@@ -1976,48 +2014,54 @@ def compute_gprimeterm(ngrads, vval, vcomp, wshape, wcomp):
19762014

19772015
return gprimesum
19782016

1979-
def coordinate_derivative(self, o):
2017+
@process.register(CoordinateDerivative)
2018+
def _(self, o: Expr) -> Expr:
19802019
"""Differentiate a coordinate_derivative."""
19812020
o = o.ufl_operands
19822021
return CoordinateDerivative(map_expr_dag(self, o[0]), o[1], o[2], o[3])
19832022

1984-
def base_form_operator(self, o, *dfs):
2023+
@process.register(BaseFormOperator)
2024+
@DAGTraverser.postorder
2025+
def _(self, o: Expr, *dfs) -> Expr:
19852026
"""Differentiate a base_form_operator.
19862027
19872028
If d_coeff = 0 => BaseFormOperator's derivative is taken wrt a
19882029
variable => we call the appropriate handler. Otherwise =>
19892030
differentiation done wrt the BaseFormOperator (dF/dN[Nhat]) =>
19902031
we treat o as a Coefficient.
19912032
"""
1992-
d_coeff = self.coefficient(o)
2033+
d_coeff = self._process_coefficient(o)
19932034
# It also handles the non-scalar case
19942035
if d_coeff == 0:
19952036
self.pending_operations += (o,)
19962037
return d_coeff
19972038

19982039
# -- Handlers for BaseForm objects -- #
19992040

2000-
def cofunction(self, o):
2041+
@process.register(Cofunction)
2042+
def _(self, o: Expr) -> Expr:
20012043
"""Differentiate a cofunction."""
20022044
# Same rule than for Coefficient except that we use a Coargument.
20032045
# The coargument is already attached to the class (self._v)
20042046
# which `self.coefficient` relies on.
2005-
dc = self.coefficient(o)
2047+
dc = self._process_coefficient(o)
20062048
if dc == 0:
20072049
# Convert ufl.Zero into ZeroBaseForm
20082050
return ZeroBaseForm(o.arguments() + self._v)
20092051
return dc
20102052

2011-
def coargument(self, o):
2053+
@process.register(Coargument)
2054+
def _(self, o: Expr) -> Expr:
20122055
"""Differentiate a coargument."""
20132056
# Same rule than for Argument (da/dw == 0).
2014-
dc = self.argument(o)
2057+
dc = self._process_argument(o)
20152058
if dc == 0:
20162059
# Convert ufl.Zero into ZeroBaseForm
20172060
return ZeroBaseForm(o.arguments() + self._v)
20182061
return dc
20192062

2020-
def matrix(self, M):
2063+
@process.register(Matrix)
2064+
def _(self, M: Expr) -> Expr:
20212065
"""Differentiate a matrix."""
20222066
# Matrix rule: D_w[v](M) = v if M == w else 0
20232067
# We can't differentiate wrt a matrix so always return zero in
@@ -2032,10 +2076,14 @@ class BaseFormOperatorDerivativeRuleset(GateauxDerivativeRuleset):
20322076
D_w[v](B) = d/dtau B(w+tau v)|tau=0 where B is a ufl.BaseFormOperator.
20332077
"""
20342078

2035-
def __init__(self, coefficients, arguments, coefficient_derivatives, pending_operations):
2079+
def __init__(
2080+
self, coefficients, arguments, coefficient_derivatives, pending_operations,
2081+
compress=True, vcache=None, rcache=None,
2082+
):
20362083
"""Initialise."""
2037-
GateauxDerivativeRuleset.__init__(
2038-
self, coefficients, arguments, coefficient_derivatives, pending_operations
2084+
super().__init__(
2085+
coefficients, arguments, coefficient_derivatives, pending_operations,
2086+
compress=compress, vcache=vcache, rcache=rcache,
20392087
)
20402088

20412089
def pending_operations_recording(base_form_operator_handler):
@@ -2056,13 +2104,30 @@ def wrapper(self, base_form_op, *dfs):
20562104
# calling the appropriate handler.
20572105
if expression != base_form_op:
20582106
self.pending_operations += (base_form_op,)
2059-
return self.coefficient(base_form_op)
2107+
return self._process_coefficient(base_form_op)
20602108
return base_form_operator_handler(self, base_form_op, *dfs)
20612109

20622110
return wrapper
20632111

2112+
# Work around singledispatchmethod inheritance issue;
2113+
# see https://bugs.python.org/issue36457.
2114+
@singledispatchmethod
2115+
def process(self, o: Expr) -> Expr:
2116+
"""Process ``o``.
2117+
2118+
Args:
2119+
o: `Expr` to be processed.
2120+
2121+
Returns:
2122+
Processed object.
2123+
2124+
"""
2125+
return super().process(o)
2126+
2127+
@process.register(Interpolate)
2128+
@DAGTraverser.postorder
20642129
@pending_operations_recording
2065-
def interpolate(self, i_op, dw):
2130+
def _(self, i_op, dw):
20662131
"""Differentiate an interpolate."""
20672132
# Interpolate rule: D_w[v](i_op(w, v*)) = i_op(v, v*), by linearity of Interpolate!
20682133
if not dw:
@@ -2072,6 +2137,8 @@ def interpolate(self, i_op, dw):
20722137
return ZeroBaseForm(i_op.arguments() + self._v)
20732138
return i_op._ufl_expr_reconstruct_(expr=dw)
20742139

2140+
@process.register(ExternalOperator)
2141+
@DAGTraverser.postorder
20752142
@pending_operations_recording
20762143
def external_operator(self, N, *dfs):
20772144
"""Differentiate an external_operator."""
@@ -2121,6 +2188,7 @@ def __init__(self):
21212188
self._grad_ruleset_dict = {}
21222189
self._reference_grad_ruleset_dict = {}
21232190
self._variable_ruleset_dict = {}
2191+
self._dag_traverser_dict = {}
21242192

21252193
def terminal(self, o):
21262194
"""Apply to a terminal."""
@@ -2158,11 +2226,15 @@ def coefficient_derivative(self, o, f, dummy_w, dummy_v, dummy_cd):
21582226
pending_operations = BaseFormOperatorDerivativeRecorder(
21592227
f, w, arguments=v, coefficient_derivatives=cd
21602228
)
2161-
rules = GateauxDerivativeRuleset(w, v, cd, pending_operations)
2162-
key = (GateauxDerivativeRuleset, w, v, cd)
2229+
key = (GateauxDerivativeRuleset, w, v, cd, f)
21632230
# We need to go through the dag first to record the pending
21642231
# operations
2165-
mapped_expr = map_expr_dag(rules, f, vcache=self.vcaches[key], rcache=self.rcaches[key])
2232+
dag_traverser = self._dag_traverser_dict.setdefault(
2233+
key, GateauxDerivativeRuleset(w, v, cd, pending_operations),
2234+
)
2235+
# If f has been seen by the traverser, it immediately returns
2236+
# the cached value.
2237+
mapped_expr = dag_traverser(f)
21662238
# Need to account for pending operations that have been stored
21672239
# in other integrands
21682240
self.pending_operations += pending_operations
@@ -2171,11 +2243,6 @@ def coefficient_derivative(self, o, f, dummy_w, dummy_v, dummy_cd):
21712243
def base_form_operator_derivative(self, o, f, dummy_w, dummy_v, dummy_cd):
21722244
"""Apply to a base_form_operator_derivative."""
21732245
dummy, w, v, cd = o.ufl_operands
2174-
pending_operations = BaseFormOperatorDerivativeRecorder(
2175-
f, w, arguments=v, coefficient_derivatives=cd
2176-
)
2177-
rules = BaseFormOperatorDerivativeRuleset(w, v, cd, pending_operations=pending_operations)
2178-
key = (BaseFormOperatorDerivativeRuleset, w, v, cd)
21792246
if isinstance(f, ZeroBaseForm):
21802247
(arg,) = v.ufl_operands
21812248
arguments = f.arguments()
@@ -2185,10 +2252,18 @@ def base_form_operator_derivative(self, o, f, dummy_w, dummy_v, dummy_cd):
21852252
if isinstance(arg, BaseArgument):
21862253
arguments += (arg,)
21872254
return ZeroBaseForm(arguments)
2255+
pending_operations = BaseFormOperatorDerivativeRecorder(
2256+
f, w, arguments=v, coefficient_derivatives=cd
2257+
)
2258+
key = (BaseFormOperatorDerivativeRuleset, w, v, cd, f)
21882259
# We need to go through the dag first to record the pending operations
2189-
mapped_expr = map_expr_dag(rules, f, vcache=self.vcaches[key], rcache=self.rcaches[key])
2190-
2191-
mapped_f = rules.coefficient(f)
2260+
dag_traverser = self._dag_traverser_dict.setdefault(
2261+
key, BaseFormOperatorDerivativeRuleset(w, v, cd, pending_operations),
2262+
)
2263+
# If f has been seen by the traverser, it immediately returns
2264+
# the cached value.
2265+
mapped_expr = dag_traverser(f)
2266+
mapped_f = dag_traverser._process_coefficient(f)
21922267
if mapped_f != 0:
21932268
# If dN/dN needs to return an Argument in N space
21942269
# with N a BaseFormOperator.

0 commit comments

Comments
 (0)