17
17
from ufl .algorithms .analysis import extract_arguments
18
18
from ufl .algorithms .map_integrands import map_integrand_dags
19
19
from ufl .algorithms .replace_derivative_nodes import replace_derivative_nodes
20
- from ufl .argument import Argument , BaseArgument
20
+ from ufl .argument import Argument , BaseArgument , Coargument
21
21
from ufl .averaging import CellAvg , FacetAvg
22
22
from ufl .checks import is_cellwise_constant
23
23
from ufl .classes import (
24
24
Abs ,
25
25
CellCoordinate ,
26
26
Coefficient ,
27
+ Cofunction ,
27
28
ComponentTensor ,
28
29
Conj ,
29
30
Constant ,
60
61
from ufl .constantvalue import is_true_ufl_scalar , is_ufl_scalar
61
62
from ufl .core .base_form_operator import BaseFormOperator
62
63
from ufl .core .expr import ufl_err_str
64
+ from ufl .core .external_operator import ExternalOperator
65
+ from ufl .core .interpolate import Interpolate
63
66
from ufl .core .multiindex import FixedIndex , MultiIndex , indices
64
67
from ufl .core .terminal import Terminal
65
68
from ufl .corealg .dag_traverser import DAGTraverser
94
97
Tan ,
95
98
Tanh ,
96
99
)
100
+ from ufl .matrix import Matrix
97
101
from ufl .operators import (
98
102
MaxValue ,
99
103
MinValue ,
@@ -1688,59 +1692,96 @@ def _(self, o: Expr) -> Expr:
1688
1692
return self .independent_operator (o )
1689
1693
1690
1694
1691
- class GateauxDerivativeRuleset (GenericDerivativeRulesetMultiFunction ):
1695
+ class GateauxDerivativeRuleset (GenericDerivativeRuleset ):
1692
1696
"""Apply AFD (Automatic Functional Differentiation) to expression.
1693
1697
1694
1698
Implements rules for the Gateaux derivative D_w[v](...) defined as
1695
1699
D_w[v](e) = d/dtau e(w+tau v)|tau=0.
1696
1700
"""
1697
1701
1698
- def __init__ (self , coefficients , arguments , coefficient_derivatives , pending_operations ):
1702
+ def __init__ (
1703
+ self ,
1704
+ coefficients ,
1705
+ arguments ,
1706
+ coefficient_derivatives ,
1707
+ pending_operations ,
1708
+ compress = True ,
1709
+ vcache = None ,
1710
+ rcache = None ,
1711
+ ):
1699
1712
"""Initialise."""
1700
- GenericDerivativeRulesetMultiFunction .__init__ (self , var_shape = ())
1701
-
1713
+ super ().__init__ ((), compress = compress , vcache = vcache , rcache = rcache )
1702
1714
# Type checking
1703
1715
if not isinstance (coefficients , ExprList ):
1704
1716
raise ValueError ("Expecting a ExprList of coefficients." )
1705
1717
if not isinstance (arguments , ExprList ):
1706
1718
raise ValueError ("Expecting a ExprList of arguments." )
1707
1719
if not isinstance (coefficient_derivatives , ExprMapping ):
1708
1720
raise ValueError ("Expecting a coefficient-coefficient ExprMapping." )
1709
-
1710
1721
# The coefficient(s) to differentiate w.r.t. and the
1711
1722
# argument(s) s.t. D_w[v](e) = d/dtau e(w+tau v)|tau=0
1712
1723
self ._w = coefficients .ufl_operands
1713
1724
self ._v = arguments .ufl_operands
1714
1725
self ._w2v = {w : v for w , v in zip (self ._w , self ._v )}
1715
-
1716
1726
# Build more convenient dict {f: df/dw} for each coefficient f
1717
1727
# where df/dw is nonzero
1718
1728
cd = coefficient_derivatives .ufl_operands
1719
1729
self ._cd = {cd [2 * i ]: cd [2 * i + 1 ] for i in range (len (cd ) // 2 )}
1720
-
1721
1730
# Record the operations delayed to the derivative expansion phase:
1722
1731
# Example: dN(u)/du where `N` is an ExternalOperator and `u` a Coefficient
1723
1732
self .pending_operations = pending_operations
1724
1733
1725
- # Explicitly defining dg/dw == 0
1726
- geometric_quantity = GenericDerivativeRulesetMultiFunction .independent_terminal
1734
+ # Work around singledispatchmethod inheritance issue;
1735
+ # see https://bugs.python.org/issue36457.
1736
+ @singledispatchmethod
1737
+ def process (self , o : Expr ) -> Expr :
1738
+ """Process ``o``.
1739
+
1740
+ Args:
1741
+ o: `Expr` to be processed.
1742
+
1743
+ Returns:
1744
+ Processed object.
1745
+
1746
+ """
1747
+ return super ().process (o )
1727
1748
1728
- def cell_avg (self , o , fp ):
1749
+ # --- Specialized rules for geometric quantities
1750
+
1751
+ @process .register (GeometricQuantity )
1752
+ def _ (self , o : Expr ) -> Expr :
1753
+ # Explicitly defining dg/dw == 0
1754
+ return self .independent_terminal (o )
1755
+
1756
+ @process .register (CellAvg )
1757
+ @DAGTraverser .postorder
1758
+ def _ (self , o : Expr , fp ) -> Expr :
1729
1759
"""Differentiate a cell_avg."""
1730
1760
# Cell average of a single function and differentiation
1731
1761
# commutes, D_f[v](cell_avg(f)) = cell_avg(v)
1732
1762
return cell_avg (fp )
1733
1763
1734
- def facet_avg (self , o , fp ):
1764
+ @process .register (FacetAvg )
1765
+ @DAGTraverser .postorder
1766
+ def _ (self , o : Expr , fp ) -> Expr :
1735
1767
"""Differentiate a facet_avg."""
1736
1768
# Facet average of a single function and differentiation
1737
1769
# commutes, D_f[v](facet_avg(f)) = facet_avg(v)
1738
1770
return facet_avg (fp )
1739
1771
1740
- # Explicitly defining da/dw == 0
1741
- argument = GenericDerivativeRulesetMultiFunction .independent_terminal
1772
+ @process .register (Argument )
1773
+ def _ (self , o : Expr ) -> Expr :
1774
+ # Explicitly defining da/dw == 0
1775
+ return self ._process_argument (o )
1742
1776
1743
- def coefficient (self , o ):
1777
+ def _process_argument (self , o : Expr ) -> Expr :
1778
+ return self .independent_terminal (o )
1779
+
1780
+ @process .register (Coefficient )
1781
+ def _ (self , o : Expr ) -> Expr :
1782
+ return self ._process_coefficient (o )
1783
+
1784
+ def _process_coefficient (self , o : Expr ) -> Expr :
1744
1785
"""Differentiate a coefficient."""
1745
1786
# Define dw/dw := d/ds [w + s v] = v
1746
1787
@@ -1788,7 +1829,8 @@ def coefficient(self, o):
1788
1829
dosum += prod
1789
1830
return dosum
1790
1831
1791
- def reference_value (self , o ):
1832
+ @process .register (ReferenceValue )
1833
+ def _ (self , o : Expr ) -> Expr :
1792
1834
"""Differentiate a reference_value."""
1793
1835
raise NotImplementedError (
1794
1836
"Currently no support for ReferenceValue in CoefficientDerivative."
@@ -1808,7 +1850,8 @@ def reference_value(self, o):
1808
1850
# else:
1809
1851
# return self.independent_terminal(o)
1810
1852
1811
- def reference_grad (self , o ):
1853
+ @process .register (ReferenceGrad )
1854
+ def _ (self , o : Expr ) -> Expr :
1812
1855
"""Differentiate a reference_grad."""
1813
1856
raise NotImplementedError (
1814
1857
"Currently no support for ReferenceGrad in CoefficientDerivative."
@@ -1819,7 +1862,8 @@ def reference_grad(self, o):
1819
1862
# this to allow the user to write
1820
1863
# derivative(...ReferenceValue...,...).
1821
1864
1822
- def grad (self , g ):
1865
+ @process .register (Grad )
1866
+ def _ (self , g : Expr ) -> Expr :
1823
1867
"""Differentiate a grad."""
1824
1868
# If we hit this type, it has already been propagated to a
1825
1869
# coefficient (or grad of a coefficient) or a base form operator, # FIXME: Assert
@@ -1976,48 +2020,54 @@ def compute_gprimeterm(ngrads, vval, vcomp, wshape, wcomp):
1976
2020
1977
2021
return gprimesum
1978
2022
1979
- def coordinate_derivative (self , o ):
2023
+ @process .register (CoordinateDerivative )
2024
+ def _ (self , o : Expr ) -> Expr :
1980
2025
"""Differentiate a coordinate_derivative."""
1981
2026
o = o .ufl_operands
1982
2027
return CoordinateDerivative (map_expr_dag (self , o [0 ]), o [1 ], o [2 ], o [3 ])
1983
2028
1984
- def base_form_operator (self , o , * dfs ):
2029
+ @process .register (BaseFormOperator )
2030
+ @DAGTraverser .postorder
2031
+ def _ (self , o : Expr , * dfs ) -> Expr :
1985
2032
"""Differentiate a base_form_operator.
1986
2033
1987
2034
If d_coeff = 0 => BaseFormOperator's derivative is taken wrt a
1988
2035
variable => we call the appropriate handler. Otherwise =>
1989
2036
differentiation done wrt the BaseFormOperator (dF/dN[Nhat]) =>
1990
2037
we treat o as a Coefficient.
1991
2038
"""
1992
- d_coeff = self .coefficient (o )
2039
+ d_coeff = self ._process_coefficient (o )
1993
2040
# It also handles the non-scalar case
1994
2041
if d_coeff == 0 :
1995
2042
self .pending_operations += (o ,)
1996
2043
return d_coeff
1997
2044
1998
2045
# -- Handlers for BaseForm objects -- #
1999
2046
2000
- def cofunction (self , o ):
2047
+ @process .register (Cofunction )
2048
+ def _ (self , o : Expr ) -> Expr :
2001
2049
"""Differentiate a cofunction."""
2002
2050
# Same rule than for Coefficient except that we use a Coargument.
2003
2051
# The coargument is already attached to the class (self._v)
2004
2052
# which `self.coefficient` relies on.
2005
- dc = self .coefficient (o )
2053
+ dc = self ._process_coefficient (o )
2006
2054
if dc == 0 :
2007
2055
# Convert ufl.Zero into ZeroBaseForm
2008
2056
return ZeroBaseForm (o .arguments () + self ._v )
2009
2057
return dc
2010
2058
2011
- def coargument (self , o ):
2059
+ @process .register (Coargument )
2060
+ def _ (self , o : Expr ) -> Expr :
2012
2061
"""Differentiate a coargument."""
2013
2062
# Same rule than for Argument (da/dw == 0).
2014
- dc = self .argument (o )
2063
+ dc = self ._process_argument (o )
2015
2064
if dc == 0 :
2016
2065
# Convert ufl.Zero into ZeroBaseForm
2017
2066
return ZeroBaseForm (o .arguments () + self ._v )
2018
2067
return dc
2019
2068
2020
- def matrix (self , M ):
2069
+ @process .register (Matrix )
2070
+ def _ (self , M : Expr ) -> Expr :
2021
2071
"""Differentiate a matrix."""
2022
2072
# Matrix rule: D_w[v](M) = v if M == w else 0
2023
2073
# We can't differentiate wrt a matrix so always return zero in
@@ -2032,10 +2082,25 @@ class BaseFormOperatorDerivativeRuleset(GateauxDerivativeRuleset):
2032
2082
D_w[v](B) = d/dtau B(w+tau v)|tau=0 where B is a ufl.BaseFormOperator.
2033
2083
"""
2034
2084
2035
- def __init__ (self , coefficients , arguments , coefficient_derivatives , pending_operations ):
2085
+ def __init__ (
2086
+ self ,
2087
+ coefficients ,
2088
+ arguments ,
2089
+ coefficient_derivatives ,
2090
+ pending_operations ,
2091
+ compress = True ,
2092
+ vcache = None ,
2093
+ rcache = None ,
2094
+ ):
2036
2095
"""Initialise."""
2037
- GateauxDerivativeRuleset .__init__ (
2038
- self , coefficients , arguments , coefficient_derivatives , pending_operations
2096
+ super ().__init__ (
2097
+ coefficients ,
2098
+ arguments ,
2099
+ coefficient_derivatives ,
2100
+ pending_operations ,
2101
+ compress = compress ,
2102
+ vcache = vcache ,
2103
+ rcache = rcache ,
2039
2104
)
2040
2105
2041
2106
def pending_operations_recording (base_form_operator_handler ):
@@ -2056,13 +2121,30 @@ def wrapper(self, base_form_op, *dfs):
2056
2121
# calling the appropriate handler.
2057
2122
if expression != base_form_op :
2058
2123
self .pending_operations += (base_form_op ,)
2059
- return self .coefficient (base_form_op )
2124
+ return self ._process_coefficient (base_form_op )
2060
2125
return base_form_operator_handler (self , base_form_op , * dfs )
2061
2126
2062
2127
return wrapper
2063
2128
2129
+ # Work around singledispatchmethod inheritance issue;
2130
+ # see https://bugs.python.org/issue36457.
2131
+ @singledispatchmethod
2132
+ def process (self , o : Expr ) -> Expr :
2133
+ """Process ``o``.
2134
+
2135
+ Args:
2136
+ o: `Expr` to be processed.
2137
+
2138
+ Returns:
2139
+ Processed object.
2140
+
2141
+ """
2142
+ return super ().process (o )
2143
+
2144
+ @process .register (Interpolate )
2145
+ @DAGTraverser .postorder
2064
2146
@pending_operations_recording
2065
- def interpolate (self , i_op , dw ):
2147
+ def _ (self , i_op , dw ):
2066
2148
"""Differentiate an interpolate."""
2067
2149
# Interpolate rule: D_w[v](i_op(w, v*)) = i_op(v, v*), by linearity of Interpolate!
2068
2150
if not dw :
@@ -2072,6 +2154,8 @@ def interpolate(self, i_op, dw):
2072
2154
return ZeroBaseForm (i_op .arguments () + self ._v )
2073
2155
return i_op ._ufl_expr_reconstruct_ (expr = dw )
2074
2156
2157
+ @process .register (ExternalOperator )
2158
+ @DAGTraverser .postorder
2075
2159
@pending_operations_recording
2076
2160
def external_operator (self , N , * dfs ):
2077
2161
"""Differentiate an external_operator."""
@@ -2121,6 +2205,7 @@ def __init__(self):
2121
2205
self ._grad_ruleset_dict = {}
2122
2206
self ._reference_grad_ruleset_dict = {}
2123
2207
self ._variable_ruleset_dict = {}
2208
+ self ._dag_traverser_dict = {}
2124
2209
2125
2210
def terminal (self , o ):
2126
2211
"""Apply to a terminal."""
@@ -2158,11 +2243,16 @@ def coefficient_derivative(self, o, f, dummy_w, dummy_v, dummy_cd):
2158
2243
pending_operations = BaseFormOperatorDerivativeRecorder (
2159
2244
f , w , arguments = v , coefficient_derivatives = cd
2160
2245
)
2161
- rules = GateauxDerivativeRuleset (w , v , cd , pending_operations )
2162
- key = (GateauxDerivativeRuleset , w , v , cd )
2246
+ key = (GateauxDerivativeRuleset , w , v , cd , f )
2163
2247
# We need to go through the dag first to record the pending
2164
2248
# operations
2165
- mapped_expr = map_expr_dag (rules , f , vcache = self .vcaches [key ], rcache = self .rcaches [key ])
2249
+ dag_traverser = self ._dag_traverser_dict .setdefault (
2250
+ key ,
2251
+ GateauxDerivativeRuleset (w , v , cd , pending_operations ),
2252
+ )
2253
+ # If f has been seen by the traverser, it immediately returns
2254
+ # the cached value.
2255
+ mapped_expr = dag_traverser (f )
2166
2256
# Need to account for pending operations that have been stored
2167
2257
# in other integrands
2168
2258
self .pending_operations += pending_operations
@@ -2171,11 +2261,6 @@ def coefficient_derivative(self, o, f, dummy_w, dummy_v, dummy_cd):
2171
2261
def base_form_operator_derivative (self , o , f , dummy_w , dummy_v , dummy_cd ):
2172
2262
"""Apply to a base_form_operator_derivative."""
2173
2263
dummy , w , v , cd = o .ufl_operands
2174
- pending_operations = BaseFormOperatorDerivativeRecorder (
2175
- f , w , arguments = v , coefficient_derivatives = cd
2176
- )
2177
- rules = BaseFormOperatorDerivativeRuleset (w , v , cd , pending_operations = pending_operations )
2178
- key = (BaseFormOperatorDerivativeRuleset , w , v , cd )
2179
2264
if isinstance (f , ZeroBaseForm ):
2180
2265
(arg ,) = v .ufl_operands
2181
2266
arguments = f .arguments ()
@@ -2185,10 +2270,19 @@ def base_form_operator_derivative(self, o, f, dummy_w, dummy_v, dummy_cd):
2185
2270
if isinstance (arg , BaseArgument ):
2186
2271
arguments += (arg ,)
2187
2272
return ZeroBaseForm (arguments )
2273
+ pending_operations = BaseFormOperatorDerivativeRecorder (
2274
+ f , w , arguments = v , coefficient_derivatives = cd
2275
+ )
2276
+ key = (BaseFormOperatorDerivativeRuleset , w , v , cd , f )
2188
2277
# We need to go through the dag first to record the pending operations
2189
- mapped_expr = map_expr_dag (rules , f , vcache = self .vcaches [key ], rcache = self .rcaches [key ])
2190
-
2191
- mapped_f = rules .coefficient (f )
2278
+ dag_traverser = self ._dag_traverser_dict .setdefault (
2279
+ key ,
2280
+ BaseFormOperatorDerivativeRuleset (w , v , cd , pending_operations ),
2281
+ )
2282
+ # If f has been seen by the traverser, it immediately returns
2283
+ # the cached value.
2284
+ mapped_expr = dag_traverser (f )
2285
+ mapped_f = dag_traverser ._process_coefficient (f )
2192
2286
if mapped_f != 0 :
2193
2287
# If dN/dN needs to return an Argument in N space
2194
2288
# with N a BaseFormOperator.
0 commit comments