17
17
from ufl .algorithms .analysis import extract_arguments
18
18
from ufl .algorithms .map_integrands import map_integrand_dags
19
19
from ufl .algorithms .replace_derivative_nodes import replace_derivative_nodes
20
- from ufl .argument import Argument , BaseArgument
20
+ from ufl .argument import Argument , BaseArgument , Coargument
21
21
from ufl .averaging import CellAvg , FacetAvg
22
22
from ufl .checks import is_cellwise_constant
23
23
from ufl .classes import (
24
24
Abs ,
25
25
CellCoordinate ,
26
26
Coefficient ,
27
+ Cofunction ,
27
28
ComponentTensor ,
28
29
Conj ,
29
30
Constant ,
60
61
from ufl .constantvalue import is_true_ufl_scalar , is_ufl_scalar
61
62
from ufl .core .base_form_operator import BaseFormOperator
62
63
from ufl .core .expr import ufl_err_str
64
+ from ufl .core .external_operator import ExternalOperator
65
+ from ufl .core .interpolate import Interpolate
63
66
from ufl .core .multiindex import FixedIndex , MultiIndex , indices
64
67
from ufl .core .terminal import Terminal
65
68
from ufl .corealg .dag_traverser import DAGTraverser
94
97
Tan ,
95
98
Tanh ,
96
99
)
100
+ from ufl .matrix import Matrix
97
101
from ufl .operators import (
98
102
MaxValue ,
99
103
MinValue ,
@@ -1688,59 +1692,90 @@ def _(self, o: Expr) -> Expr:
1688
1692
return self .independent_operator (o )
1689
1693
1690
1694
1691
- class GateauxDerivativeRuleset (GenericDerivativeRulesetMultiFunction ):
1695
+ class GateauxDerivativeRuleset (GenericDerivativeRuleset ):
1692
1696
"""Apply AFD (Automatic Functional Differentiation) to expression.
1693
1697
1694
1698
Implements rules for the Gateaux derivative D_w[v](...) defined as
1695
1699
D_w[v](e) = d/dtau e(w+tau v)|tau=0.
1696
1700
"""
1697
1701
1698
- def __init__ (self , coefficients , arguments , coefficient_derivatives , pending_operations ):
1702
+ def __init__ (
1703
+ self , coefficients , arguments , coefficient_derivatives , pending_operations ,
1704
+ compress = True , vcache = None , rcache = None ,
1705
+ ):
1699
1706
"""Initialise."""
1700
- GenericDerivativeRulesetMultiFunction .__init__ (self , var_shape = ())
1701
-
1707
+ super ().__init__ ((), compress = compress , vcache = vcache , rcache = rcache )
1702
1708
# Type checking
1703
1709
if not isinstance (coefficients , ExprList ):
1704
1710
raise ValueError ("Expecting a ExprList of coefficients." )
1705
1711
if not isinstance (arguments , ExprList ):
1706
1712
raise ValueError ("Expecting a ExprList of arguments." )
1707
1713
if not isinstance (coefficient_derivatives , ExprMapping ):
1708
1714
raise ValueError ("Expecting a coefficient-coefficient ExprMapping." )
1709
-
1710
1715
# The coefficient(s) to differentiate w.r.t. and the
1711
1716
# argument(s) s.t. D_w[v](e) = d/dtau e(w+tau v)|tau=0
1712
1717
self ._w = coefficients .ufl_operands
1713
1718
self ._v = arguments .ufl_operands
1714
1719
self ._w2v = {w : v for w , v in zip (self ._w , self ._v )}
1715
-
1716
1720
# Build more convenient dict {f: df/dw} for each coefficient f
1717
1721
# where df/dw is nonzero
1718
1722
cd = coefficient_derivatives .ufl_operands
1719
1723
self ._cd = {cd [2 * i ]: cd [2 * i + 1 ] for i in range (len (cd ) // 2 )}
1720
-
1721
1724
# Record the operations delayed to the derivative expansion phase:
1722
1725
# Example: dN(u)/du where `N` is an ExternalOperator and `u` a Coefficient
1723
1726
self .pending_operations = pending_operations
1724
1727
1725
- # Explicitly defining dg/dw == 0
1726
- geometric_quantity = GenericDerivativeRulesetMultiFunction .independent_terminal
1728
+ # Work around singledispatchmethod inheritance issue;
1729
+ # see https://bugs.python.org/issue36457.
1730
+ @singledispatchmethod
1731
+ def process (self , o : Expr ) -> Expr :
1732
+ """Process ``o``.
1733
+
1734
+ Args:
1735
+ o: `Expr` to be processed.
1736
+
1737
+ Returns:
1738
+ Processed object.
1739
+
1740
+ """
1741
+ return super ().process (o )
1727
1742
1728
- def cell_avg (self , o , fp ):
1743
+ # --- Specialized rules for geometric quantities
1744
+
1745
+ @process .register (GeometricQuantity )
1746
+ def _ (self , o : Expr ) -> Expr :
1747
+ # Explicitly defining dg/dw == 0
1748
+ return self .independent_terminal (o )
1749
+
1750
+ @process .register (CellAvg )
1751
+ @DAGTraverser .postorder
1752
+ def _ (self , o : Expr , fp ) -> Expr :
1729
1753
"""Differentiate a cell_avg."""
1730
1754
# Cell average of a single function and differentiation
1731
1755
# commutes, D_f[v](cell_avg(f)) = cell_avg(v)
1732
1756
return cell_avg (fp )
1733
1757
1734
- def facet_avg (self , o , fp ):
1758
+ @process .register (FacetAvg )
1759
+ @DAGTraverser .postorder
1760
+ def _ (self , o : Expr , fp ) -> Expr :
1735
1761
"""Differentiate a facet_avg."""
1736
1762
# Facet average of a single function and differentiation
1737
1763
# commutes, D_f[v](facet_avg(f)) = facet_avg(v)
1738
1764
return facet_avg (fp )
1739
1765
1740
- # Explicitly defining da/dw == 0
1741
- argument = GenericDerivativeRulesetMultiFunction .independent_terminal
1766
+ @process .register (Argument )
1767
+ def _ (self , o : Expr ) -> Expr :
1768
+ # Explicitly defining da/dw == 0
1769
+ return self ._process_argument (o )
1742
1770
1743
- def coefficient (self , o ):
1771
+ def _process_argument (self , o : Expr ) -> Expr :
1772
+ return self .independent_terminal (o )
1773
+
1774
+ @process .register (Coefficient )
1775
+ def _ (self , o : Expr ) -> Expr :
1776
+ return self ._process_coefficient (o )
1777
+
1778
+ def _process_coefficient (self , o : Expr ) -> Expr :
1744
1779
"""Differentiate a coefficient."""
1745
1780
# Define dw/dw := d/ds [w + s v] = v
1746
1781
@@ -1788,7 +1823,8 @@ def coefficient(self, o):
1788
1823
dosum += prod
1789
1824
return dosum
1790
1825
1791
- def reference_value (self , o ):
1826
+ @process .register (ReferenceValue )
1827
+ def _ (self , o : Expr ) -> Expr :
1792
1828
"""Differentiate a reference_value."""
1793
1829
raise NotImplementedError (
1794
1830
"Currently no support for ReferenceValue in CoefficientDerivative."
@@ -1808,7 +1844,8 @@ def reference_value(self, o):
1808
1844
# else:
1809
1845
# return self.independent_terminal(o)
1810
1846
1811
- def reference_grad (self , o ):
1847
+ @process .register (ReferenceGrad )
1848
+ def _ (self , o : Expr ) -> Expr :
1812
1849
"""Differentiate a reference_grad."""
1813
1850
raise NotImplementedError (
1814
1851
"Currently no support for ReferenceGrad in CoefficientDerivative."
@@ -1819,7 +1856,8 @@ def reference_grad(self, o):
1819
1856
# this to allow the user to write
1820
1857
# derivative(...ReferenceValue...,...).
1821
1858
1822
- def grad (self , g ):
1859
+ @process .register (Grad )
1860
+ def _ (self , g : Expr ) -> Expr :
1823
1861
"""Differentiate a grad."""
1824
1862
# If we hit this type, it has already been propagated to a
1825
1863
# coefficient (or grad of a coefficient) or a base form operator, # FIXME: Assert
@@ -1976,48 +2014,54 @@ def compute_gprimeterm(ngrads, vval, vcomp, wshape, wcomp):
1976
2014
1977
2015
return gprimesum
1978
2016
1979
- def coordinate_derivative (self , o ):
2017
+ @process .register (CoordinateDerivative )
2018
+ def _ (self , o : Expr ) -> Expr :
1980
2019
"""Differentiate a coordinate_derivative."""
1981
2020
o = o .ufl_operands
1982
2021
return CoordinateDerivative (map_expr_dag (self , o [0 ]), o [1 ], o [2 ], o [3 ])
1983
2022
1984
- def base_form_operator (self , o , * dfs ):
2023
+ @process .register (BaseFormOperator )
2024
+ @DAGTraverser .postorder
2025
+ def _ (self , o : Expr , * dfs ) -> Expr :
1985
2026
"""Differentiate a base_form_operator.
1986
2027
1987
2028
If d_coeff = 0 => BaseFormOperator's derivative is taken wrt a
1988
2029
variable => we call the appropriate handler. Otherwise =>
1989
2030
differentiation done wrt the BaseFormOperator (dF/dN[Nhat]) =>
1990
2031
we treat o as a Coefficient.
1991
2032
"""
1992
- d_coeff = self .coefficient (o )
2033
+ d_coeff = self ._process_coefficient (o )
1993
2034
# It also handles the non-scalar case
1994
2035
if d_coeff == 0 :
1995
2036
self .pending_operations += (o ,)
1996
2037
return d_coeff
1997
2038
1998
2039
# -- Handlers for BaseForm objects -- #
1999
2040
2000
- def cofunction (self , o ):
2041
+ @process .register (Cofunction )
2042
+ def _ (self , o : Expr ) -> Expr :
2001
2043
"""Differentiate a cofunction."""
2002
2044
# Same rule than for Coefficient except that we use a Coargument.
2003
2045
# The coargument is already attached to the class (self._v)
2004
2046
# which `self.coefficient` relies on.
2005
- dc = self .coefficient (o )
2047
+ dc = self ._process_coefficient (o )
2006
2048
if dc == 0 :
2007
2049
# Convert ufl.Zero into ZeroBaseForm
2008
2050
return ZeroBaseForm (o .arguments () + self ._v )
2009
2051
return dc
2010
2052
2011
- def coargument (self , o ):
2053
+ @process .register (Coargument )
2054
+ def _ (self , o : Expr ) -> Expr :
2012
2055
"""Differentiate a coargument."""
2013
2056
# Same rule than for Argument (da/dw == 0).
2014
- dc = self .argument (o )
2057
+ dc = self ._process_argument (o )
2015
2058
if dc == 0 :
2016
2059
# Convert ufl.Zero into ZeroBaseForm
2017
2060
return ZeroBaseForm (o .arguments () + self ._v )
2018
2061
return dc
2019
2062
2020
- def matrix (self , M ):
2063
+ @process .register (Matrix )
2064
+ def _ (self , M : Expr ) -> Expr :
2021
2065
"""Differentiate a matrix."""
2022
2066
# Matrix rule: D_w[v](M) = v if M == w else 0
2023
2067
# We can't differentiate wrt a matrix so always return zero in
@@ -2032,10 +2076,14 @@ class BaseFormOperatorDerivativeRuleset(GateauxDerivativeRuleset):
2032
2076
D_w[v](B) = d/dtau B(w+tau v)|tau=0 where B is a ufl.BaseFormOperator.
2033
2077
"""
2034
2078
2035
- def __init__ (self , coefficients , arguments , coefficient_derivatives , pending_operations ):
2079
+ def __init__ (
2080
+ self , coefficients , arguments , coefficient_derivatives , pending_operations ,
2081
+ compress = True , vcache = None , rcache = None ,
2082
+ ):
2036
2083
"""Initialise."""
2037
- GateauxDerivativeRuleset .__init__ (
2038
- self , coefficients , arguments , coefficient_derivatives , pending_operations
2084
+ super ().__init__ (
2085
+ coefficients , arguments , coefficient_derivatives , pending_operations ,
2086
+ compress = compress , vcache = vcache , rcache = rcache ,
2039
2087
)
2040
2088
2041
2089
def pending_operations_recording (base_form_operator_handler ):
@@ -2056,13 +2104,30 @@ def wrapper(self, base_form_op, *dfs):
2056
2104
# calling the appropriate handler.
2057
2105
if expression != base_form_op :
2058
2106
self .pending_operations += (base_form_op ,)
2059
- return self .coefficient (base_form_op )
2107
+ return self ._process_coefficient (base_form_op )
2060
2108
return base_form_operator_handler (self , base_form_op , * dfs )
2061
2109
2062
2110
return wrapper
2063
2111
2112
+ # Work around singledispatchmethod inheritance issue;
2113
+ # see https://bugs.python.org/issue36457.
2114
+ @singledispatchmethod
2115
+ def process (self , o : Expr ) -> Expr :
2116
+ """Process ``o``.
2117
+
2118
+ Args:
2119
+ o: `Expr` to be processed.
2120
+
2121
+ Returns:
2122
+ Processed object.
2123
+
2124
+ """
2125
+ return super ().process (o )
2126
+
2127
+ @process .register (Interpolate )
2128
+ @DAGTraverser .postorder
2064
2129
@pending_operations_recording
2065
- def interpolate (self , i_op , dw ):
2130
+ def _ (self , i_op , dw ):
2066
2131
"""Differentiate an interpolate."""
2067
2132
# Interpolate rule: D_w[v](i_op(w, v*)) = i_op(v, v*), by linearity of Interpolate!
2068
2133
if not dw :
@@ -2072,6 +2137,8 @@ def interpolate(self, i_op, dw):
2072
2137
return ZeroBaseForm (i_op .arguments () + self ._v )
2073
2138
return i_op ._ufl_expr_reconstruct_ (expr = dw )
2074
2139
2140
+ @process .register (ExternalOperator )
2141
+ @DAGTraverser .postorder
2075
2142
@pending_operations_recording
2076
2143
def external_operator (self , N , * dfs ):
2077
2144
"""Differentiate an external_operator."""
@@ -2121,6 +2188,7 @@ def __init__(self):
2121
2188
self ._grad_ruleset_dict = {}
2122
2189
self ._reference_grad_ruleset_dict = {}
2123
2190
self ._variable_ruleset_dict = {}
2191
+ self ._dag_traverser_dict = {}
2124
2192
2125
2193
def terminal (self , o ):
2126
2194
"""Apply to a terminal."""
@@ -2158,11 +2226,15 @@ def coefficient_derivative(self, o, f, dummy_w, dummy_v, dummy_cd):
2158
2226
pending_operations = BaseFormOperatorDerivativeRecorder (
2159
2227
f , w , arguments = v , coefficient_derivatives = cd
2160
2228
)
2161
- rules = GateauxDerivativeRuleset (w , v , cd , pending_operations )
2162
- key = (GateauxDerivativeRuleset , w , v , cd )
2229
+ key = (GateauxDerivativeRuleset , w , v , cd , f )
2163
2230
# We need to go through the dag first to record the pending
2164
2231
# operations
2165
- mapped_expr = map_expr_dag (rules , f , vcache = self .vcaches [key ], rcache = self .rcaches [key ])
2232
+ dag_traverser = self ._dag_traverser_dict .setdefault (
2233
+ key , GateauxDerivativeRuleset (w , v , cd , pending_operations ),
2234
+ )
2235
+ # If f has been seen by the traverser, it immediately returns
2236
+ # the cached value.
2237
+ mapped_expr = dag_traverser (f )
2166
2238
# Need to account for pending operations that have been stored
2167
2239
# in other integrands
2168
2240
self .pending_operations += pending_operations
@@ -2171,11 +2243,6 @@ def coefficient_derivative(self, o, f, dummy_w, dummy_v, dummy_cd):
2171
2243
def base_form_operator_derivative (self , o , f , dummy_w , dummy_v , dummy_cd ):
2172
2244
"""Apply to a base_form_operator_derivative."""
2173
2245
dummy , w , v , cd = o .ufl_operands
2174
- pending_operations = BaseFormOperatorDerivativeRecorder (
2175
- f , w , arguments = v , coefficient_derivatives = cd
2176
- )
2177
- rules = BaseFormOperatorDerivativeRuleset (w , v , cd , pending_operations = pending_operations )
2178
- key = (BaseFormOperatorDerivativeRuleset , w , v , cd )
2179
2246
if isinstance (f , ZeroBaseForm ):
2180
2247
(arg ,) = v .ufl_operands
2181
2248
arguments = f .arguments ()
@@ -2185,10 +2252,18 @@ def base_form_operator_derivative(self, o, f, dummy_w, dummy_v, dummy_cd):
2185
2252
if isinstance (arg , BaseArgument ):
2186
2253
arguments += (arg ,)
2187
2254
return ZeroBaseForm (arguments )
2255
+ pending_operations = BaseFormOperatorDerivativeRecorder (
2256
+ f , w , arguments = v , coefficient_derivatives = cd
2257
+ )
2258
+ key = (BaseFormOperatorDerivativeRuleset , w , v , cd , f )
2188
2259
# We need to go through the dag first to record the pending operations
2189
- mapped_expr = map_expr_dag (rules , f , vcache = self .vcaches [key ], rcache = self .rcaches [key ])
2190
-
2191
- mapped_f = rules .coefficient (f )
2260
+ dag_traverser = self ._dag_traverser_dict .setdefault (
2261
+ key , BaseFormOperatorDerivativeRuleset (w , v , cd , pending_operations ),
2262
+ )
2263
+ # If f has been seen by the traverser, it immediately returns
2264
+ # the cached value.
2265
+ mapped_expr = dag_traverser (f )
2266
+ mapped_f = dag_traverser ._process_coefficient (f )
2192
2267
if mapped_f != 0 :
2193
2268
# If dN/dN needs to return an Argument in N space
2194
2269
# with N a BaseFormOperator.
0 commit comments