Skip to content

Commit b154ee8

Browse files
committed
GateauxDerivativeRuleset
1 parent 34bbd94 commit b154ee8

File tree

1 file changed

+136
-42
lines changed

1 file changed

+136
-42
lines changed

ufl/algorithms/apply_derivatives.py

+136-42
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,14 @@
1717
from ufl.algorithms.analysis import extract_arguments
1818
from ufl.algorithms.map_integrands import map_integrand_dags
1919
from ufl.algorithms.replace_derivative_nodes import replace_derivative_nodes
20-
from ufl.argument import Argument, BaseArgument
20+
from ufl.argument import Argument, BaseArgument, Coargument
2121
from ufl.averaging import CellAvg, FacetAvg
2222
from ufl.checks import is_cellwise_constant
2323
from ufl.classes import (
2424
Abs,
2525
CellCoordinate,
2626
Coefficient,
27+
Cofunction,
2728
ComponentTensor,
2829
Conj,
2930
Constant,
@@ -60,6 +61,8 @@
6061
from ufl.constantvalue import is_true_ufl_scalar, is_ufl_scalar
6162
from ufl.core.base_form_operator import BaseFormOperator
6263
from ufl.core.expr import ufl_err_str
64+
from ufl.core.external_operator import ExternalOperator
65+
from ufl.core.interpolate import Interpolate
6366
from ufl.core.multiindex import FixedIndex, MultiIndex, indices
6467
from ufl.core.terminal import Terminal
6568
from ufl.corealg.dag_traverser import DAGTraverser
@@ -94,6 +97,7 @@
9497
Tan,
9598
Tanh,
9699
)
100+
from ufl.matrix import Matrix
97101
from ufl.operators import (
98102
MaxValue,
99103
MinValue,
@@ -1688,59 +1692,96 @@ def _(self, o: Expr) -> Expr:
16881692
return self.independent_operator(o)
16891693

16901694

1691-
class GateauxDerivativeRuleset(GenericDerivativeRulesetMultiFunction):
1695+
class GateauxDerivativeRuleset(GenericDerivativeRuleset):
16921696
"""Apply AFD (Automatic Functional Differentiation) to expression.
16931697
16941698
Implements rules for the Gateaux derivative D_w[v](...) defined as
16951699
D_w[v](e) = d/dtau e(w+tau v)|tau=0.
16961700
"""
16971701

1698-
def __init__(self, coefficients, arguments, coefficient_derivatives, pending_operations):
1702+
def __init__(
1703+
self,
1704+
coefficients,
1705+
arguments,
1706+
coefficient_derivatives,
1707+
pending_operations,
1708+
compress=True,
1709+
vcache=None,
1710+
rcache=None,
1711+
):
16991712
"""Initialise."""
1700-
GenericDerivativeRulesetMultiFunction.__init__(self, var_shape=())
1701-
1713+
super().__init__((), compress=compress, vcache=vcache, rcache=rcache)
17021714
# Type checking
17031715
if not isinstance(coefficients, ExprList):
17041716
raise ValueError("Expecting a ExprList of coefficients.")
17051717
if not isinstance(arguments, ExprList):
17061718
raise ValueError("Expecting a ExprList of arguments.")
17071719
if not isinstance(coefficient_derivatives, ExprMapping):
17081720
raise ValueError("Expecting a coefficient-coefficient ExprMapping.")
1709-
17101721
# The coefficient(s) to differentiate w.r.t. and the
17111722
# argument(s) s.t. D_w[v](e) = d/dtau e(w+tau v)|tau=0
17121723
self._w = coefficients.ufl_operands
17131724
self._v = arguments.ufl_operands
17141725
self._w2v = {w: v for w, v in zip(self._w, self._v)}
1715-
17161726
# Build more convenient dict {f: df/dw} for each coefficient f
17171727
# where df/dw is nonzero
17181728
cd = coefficient_derivatives.ufl_operands
17191729
self._cd = {cd[2 * i]: cd[2 * i + 1] for i in range(len(cd) // 2)}
1720-
17211730
# Record the operations delayed to the derivative expansion phase:
17221731
# Example: dN(u)/du where `N` is an ExternalOperator and `u` a Coefficient
17231732
self.pending_operations = pending_operations
17241733

1725-
# Explicitly defining dg/dw == 0
1726-
geometric_quantity = GenericDerivativeRulesetMultiFunction.independent_terminal
1734+
# Work around singledispatchmethod inheritance issue;
1735+
# see https://bugs.python.org/issue36457.
1736+
@singledispatchmethod
1737+
def process(self, o: Expr) -> Expr:
1738+
"""Process ``o``.
1739+
1740+
Args:
1741+
o: `Expr` to be processed.
1742+
1743+
Returns:
1744+
Processed object.
1745+
1746+
"""
1747+
return super().process(o)
17271748

1728-
def cell_avg(self, o, fp):
1749+
# --- Specialized rules for geometric quantities
1750+
1751+
@process.register(GeometricQuantity)
1752+
def _(self, o: Expr) -> Expr:
1753+
# Explicitly defining dg/dw == 0
1754+
return self.independent_terminal(o)
1755+
1756+
@process.register(CellAvg)
1757+
@DAGTraverser.postorder
1758+
def _(self, o: Expr, fp) -> Expr:
17291759
"""Differentiate a cell_avg."""
17301760
# Cell average of a single function and differentiation
17311761
# commutes, D_f[v](cell_avg(f)) = cell_avg(v)
17321762
return cell_avg(fp)
17331763

1734-
def facet_avg(self, o, fp):
1764+
@process.register(FacetAvg)
1765+
@DAGTraverser.postorder
1766+
def _(self, o: Expr, fp) -> Expr:
17351767
"""Differentiate a facet_avg."""
17361768
# Facet average of a single function and differentiation
17371769
# commutes, D_f[v](facet_avg(f)) = facet_avg(v)
17381770
return facet_avg(fp)
17391771

1740-
# Explicitly defining da/dw == 0
1741-
argument = GenericDerivativeRulesetMultiFunction.independent_terminal
1772+
@process.register(Argument)
1773+
def _(self, o: Expr) -> Expr:
1774+
# Explicitly defining da/dw == 0
1775+
return self._process_argument(o)
17421776

1743-
def coefficient(self, o):
1777+
def _process_argument(self, o: Expr) -> Expr:
1778+
return self.independent_terminal(o)
1779+
1780+
@process.register(Coefficient)
1781+
def _(self, o: Expr) -> Expr:
1782+
return self._process_coefficient(o)
1783+
1784+
def _process_coefficient(self, o: Expr) -> Expr:
17441785
"""Differentiate a coefficient."""
17451786
# Define dw/dw := d/ds [w + s v] = v
17461787

@@ -1788,7 +1829,8 @@ def coefficient(self, o):
17881829
dosum += prod
17891830
return dosum
17901831

1791-
def reference_value(self, o):
1832+
@process.register(ReferenceValue)
1833+
def _(self, o: Expr) -> Expr:
17921834
"""Differentiate a reference_value."""
17931835
raise NotImplementedError(
17941836
"Currently no support for ReferenceValue in CoefficientDerivative."
@@ -1808,7 +1850,8 @@ def reference_value(self, o):
18081850
# else:
18091851
# return self.independent_terminal(o)
18101852

1811-
def reference_grad(self, o):
1853+
@process.register(ReferenceGrad)
1854+
def _(self, o: Expr) -> Expr:
18121855
"""Differentiate a reference_grad."""
18131856
raise NotImplementedError(
18141857
"Currently no support for ReferenceGrad in CoefficientDerivative."
@@ -1819,7 +1862,8 @@ def reference_grad(self, o):
18191862
# this to allow the user to write
18201863
# derivative(...ReferenceValue...,...).
18211864

1822-
def grad(self, g):
1865+
@process.register(Grad)
1866+
def _(self, g: Expr) -> Expr:
18231867
"""Differentiate a grad."""
18241868
# If we hit this type, it has already been propagated to a
18251869
# coefficient (or grad of a coefficient) or a base form operator, # FIXME: Assert
@@ -1976,48 +2020,54 @@ def compute_gprimeterm(ngrads, vval, vcomp, wshape, wcomp):
19762020

19772021
return gprimesum
19782022

1979-
def coordinate_derivative(self, o):
2023+
@process.register(CoordinateDerivative)
2024+
def _(self, o: Expr) -> Expr:
19802025
"""Differentiate a coordinate_derivative."""
19812026
o = o.ufl_operands
19822027
return CoordinateDerivative(map_expr_dag(self, o[0]), o[1], o[2], o[3])
19832028

1984-
def base_form_operator(self, o, *dfs):
2029+
@process.register(BaseFormOperator)
2030+
@DAGTraverser.postorder
2031+
def _(self, o: Expr, *dfs) -> Expr:
19852032
"""Differentiate a base_form_operator.
19862033
19872034
If d_coeff = 0 => BaseFormOperator's derivative is taken wrt a
19882035
variable => we call the appropriate handler. Otherwise =>
19892036
differentiation done wrt the BaseFormOperator (dF/dN[Nhat]) =>
19902037
we treat o as a Coefficient.
19912038
"""
1992-
d_coeff = self.coefficient(o)
2039+
d_coeff = self._process_coefficient(o)
19932040
# It also handles the non-scalar case
19942041
if d_coeff == 0:
19952042
self.pending_operations += (o,)
19962043
return d_coeff
19972044

19982045
# -- Handlers for BaseForm objects -- #
19992046

2000-
def cofunction(self, o):
2047+
@process.register(Cofunction)
2048+
def _(self, o: Expr) -> Expr:
20012049
"""Differentiate a cofunction."""
20022050
# Same rule than for Coefficient except that we use a Coargument.
20032051
# The coargument is already attached to the class (self._v)
20042052
# which `self.coefficient` relies on.
2005-
dc = self.coefficient(o)
2053+
dc = self._process_coefficient(o)
20062054
if dc == 0:
20072055
# Convert ufl.Zero into ZeroBaseForm
20082056
return ZeroBaseForm(o.arguments() + self._v)
20092057
return dc
20102058

2011-
def coargument(self, o):
2059+
@process.register(Coargument)
2060+
def _(self, o: Expr) -> Expr:
20122061
"""Differentiate a coargument."""
20132062
# Same rule than for Argument (da/dw == 0).
2014-
dc = self.argument(o)
2063+
dc = self._process_argument(o)
20152064
if dc == 0:
20162065
# Convert ufl.Zero into ZeroBaseForm
20172066
return ZeroBaseForm(o.arguments() + self._v)
20182067
return dc
20192068

2020-
def matrix(self, M):
2069+
@process.register(Matrix)
2070+
def _(self, M: Expr) -> Expr:
20212071
"""Differentiate a matrix."""
20222072
# Matrix rule: D_w[v](M) = v if M == w else 0
20232073
# We can't differentiate wrt a matrix so always return zero in
@@ -2032,10 +2082,25 @@ class BaseFormOperatorDerivativeRuleset(GateauxDerivativeRuleset):
20322082
D_w[v](B) = d/dtau B(w+tau v)|tau=0 where B is a ufl.BaseFormOperator.
20332083
"""
20342084

2035-
def __init__(self, coefficients, arguments, coefficient_derivatives, pending_operations):
2085+
def __init__(
2086+
self,
2087+
coefficients,
2088+
arguments,
2089+
coefficient_derivatives,
2090+
pending_operations,
2091+
compress=True,
2092+
vcache=None,
2093+
rcache=None,
2094+
):
20362095
"""Initialise."""
2037-
GateauxDerivativeRuleset.__init__(
2038-
self, coefficients, arguments, coefficient_derivatives, pending_operations
2096+
super().__init__(
2097+
coefficients,
2098+
arguments,
2099+
coefficient_derivatives,
2100+
pending_operations,
2101+
compress=compress,
2102+
vcache=vcache,
2103+
rcache=rcache,
20392104
)
20402105

20412106
def pending_operations_recording(base_form_operator_handler):
@@ -2056,13 +2121,30 @@ def wrapper(self, base_form_op, *dfs):
20562121
# calling the appropriate handler.
20572122
if expression != base_form_op:
20582123
self.pending_operations += (base_form_op,)
2059-
return self.coefficient(base_form_op)
2124+
return self._process_coefficient(base_form_op)
20602125
return base_form_operator_handler(self, base_form_op, *dfs)
20612126

20622127
return wrapper
20632128

2129+
# Work around singledispatchmethod inheritance issue;
2130+
# see https://bugs.python.org/issue36457.
2131+
@singledispatchmethod
2132+
def process(self, o: Expr) -> Expr:
2133+
"""Process ``o``.
2134+
2135+
Args:
2136+
o: `Expr` to be processed.
2137+
2138+
Returns:
2139+
Processed object.
2140+
2141+
"""
2142+
return super().process(o)
2143+
2144+
@process.register(Interpolate)
2145+
@DAGTraverser.postorder
20642146
@pending_operations_recording
2065-
def interpolate(self, i_op, dw):
2147+
def _(self, i_op, dw):
20662148
"""Differentiate an interpolate."""
20672149
# Interpolate rule: D_w[v](i_op(w, v*)) = i_op(v, v*), by linearity of Interpolate!
20682150
if not dw:
@@ -2072,6 +2154,8 @@ def interpolate(self, i_op, dw):
20722154
return ZeroBaseForm(i_op.arguments() + self._v)
20732155
return i_op._ufl_expr_reconstruct_(expr=dw)
20742156

2157+
@process.register(ExternalOperator)
2158+
@DAGTraverser.postorder
20752159
@pending_operations_recording
20762160
def external_operator(self, N, *dfs):
20772161
"""Differentiate an external_operator."""
@@ -2121,6 +2205,7 @@ def __init__(self):
21212205
self._grad_ruleset_dict = {}
21222206
self._reference_grad_ruleset_dict = {}
21232207
self._variable_ruleset_dict = {}
2208+
self._dag_traverser_dict = {}
21242209

21252210
def terminal(self, o):
21262211
"""Apply to a terminal."""
@@ -2158,11 +2243,16 @@ def coefficient_derivative(self, o, f, dummy_w, dummy_v, dummy_cd):
21582243
pending_operations = BaseFormOperatorDerivativeRecorder(
21592244
f, w, arguments=v, coefficient_derivatives=cd
21602245
)
2161-
rules = GateauxDerivativeRuleset(w, v, cd, pending_operations)
2162-
key = (GateauxDerivativeRuleset, w, v, cd)
2246+
key = (GateauxDerivativeRuleset, w, v, cd, f)
21632247
# We need to go through the dag first to record the pending
21642248
# operations
2165-
mapped_expr = map_expr_dag(rules, f, vcache=self.vcaches[key], rcache=self.rcaches[key])
2249+
dag_traverser = self._dag_traverser_dict.setdefault(
2250+
key,
2251+
GateauxDerivativeRuleset(w, v, cd, pending_operations),
2252+
)
2253+
# If f has been seen by the traverser, it immediately returns
2254+
# the cached value.
2255+
mapped_expr = dag_traverser(f)
21662256
# Need to account for pending operations that have been stored
21672257
# in other integrands
21682258
self.pending_operations += pending_operations
@@ -2171,11 +2261,6 @@ def coefficient_derivative(self, o, f, dummy_w, dummy_v, dummy_cd):
21712261
def base_form_operator_derivative(self, o, f, dummy_w, dummy_v, dummy_cd):
21722262
"""Apply to a base_form_operator_derivative."""
21732263
dummy, w, v, cd = o.ufl_operands
2174-
pending_operations = BaseFormOperatorDerivativeRecorder(
2175-
f, w, arguments=v, coefficient_derivatives=cd
2176-
)
2177-
rules = BaseFormOperatorDerivativeRuleset(w, v, cd, pending_operations=pending_operations)
2178-
key = (BaseFormOperatorDerivativeRuleset, w, v, cd)
21792264
if isinstance(f, ZeroBaseForm):
21802265
(arg,) = v.ufl_operands
21812266
arguments = f.arguments()
@@ -2185,10 +2270,19 @@ def base_form_operator_derivative(self, o, f, dummy_w, dummy_v, dummy_cd):
21852270
if isinstance(arg, BaseArgument):
21862271
arguments += (arg,)
21872272
return ZeroBaseForm(arguments)
2273+
pending_operations = BaseFormOperatorDerivativeRecorder(
2274+
f, w, arguments=v, coefficient_derivatives=cd
2275+
)
2276+
key = (BaseFormOperatorDerivativeRuleset, w, v, cd, f)
21882277
# We need to go through the dag first to record the pending operations
2189-
mapped_expr = map_expr_dag(rules, f, vcache=self.vcaches[key], rcache=self.rcaches[key])
2190-
2191-
mapped_f = rules.coefficient(f)
2278+
dag_traverser = self._dag_traverser_dict.setdefault(
2279+
key,
2280+
BaseFormOperatorDerivativeRuleset(w, v, cd, pending_operations),
2281+
)
2282+
# If f has been seen by the traverser, it immediately returns
2283+
# the cached value.
2284+
mapped_expr = dag_traverser(f)
2285+
mapped_f = dag_traverser._process_coefficient(f)
21922286
if mapped_f != 0:
21932287
# If dN/dN needs to return an Argument in N space
21942288
# with N a BaseFormOperator.

0 commit comments

Comments
 (0)