Skip to content

Commit faffbc4

Browse files
committed
add CoefficientSplitter
1 parent 83b8213 commit faffbc4

File tree

1 file changed

+134
-0
lines changed

1 file changed

+134
-0
lines changed
+134
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
# using code from TSFC.
2+
3+
from functools import singledispatchmethod
4+
import numpy as np
5+
from ufl.classes import (
6+
Coefficient,
7+
ComponentTensor,
8+
Expr,
9+
MultiIndex,
10+
NegativeRestricted,
11+
PositiveRestricted,
12+
ReferenceGrad,
13+
ReferenceValue,
14+
Restricted,
15+
Terminal,
16+
Zero,
17+
)
18+
from ufl.corealg.dag_visitor import DAGVisitor
19+
from ufl.core.multiindex import indices
20+
from ufl.tensors import as_tensor
21+
22+
23+
class CoefficientSplitter(DAGVisitor):
24+
25+
def __init__(self, coefficient_split):
26+
"""Split mixed coefficients.
27+
28+
Args:
29+
coefficient_split: `dict` that maps mixed coefficients to their components.
30+
reference_value: If `ReferenceValue` has been applied.
31+
reference_grad: Number of `ReferenceGrad`s that have been applied.
32+
restricted: '+', '-', or None.
33+
cache: `dict` for caching DAG nodes.
34+
35+
Returns:
36+
This node wrapped with `ReferenceValue` (if ``reference_value``),
37+
`ReferenceGrad` (``reference_grad`` times), and `Restricted` (if
38+
``restricted`` is '+' or '-'). The underlying terminal will be
39+
decomposed into components according to ``coefficient_split``.
40+
41+
"""
42+
super().__init__()
43+
self._coefficient_split = coefficient_split
44+
45+
@singledispatchmethod
46+
def process(self, node, *args):
47+
"""Handle base case."""
48+
raise AssertionError(f"UFL node expected: got {node}")
49+
50+
@process.register(Expr)
51+
def _(self, node, *args):
52+
"""Handle Expr."""
53+
return self.reuse_if_untouched(node, *args)
54+
55+
@process.register(ReferenceValue)
56+
def _(self, node, reference_value: bool, reference_grad: int, restricted: str):
57+
"""Handle ReferenceValue."""
58+
if reference_value:
59+
raise RuntimeError(f"Can not apply ReferenceValue on a ReferenceValue: got {node}")
60+
op, = node.ufl_operands
61+
if not op._ufl_terminal_modifiers_:
62+
raise ValueError(f"Must be a terminal modifier: {op!r}.")
63+
return self(op, True, reference_grad, restricted)
64+
65+
@process.register(ReferenceGrad)
66+
def _(self, node, reference_value: bool, reference_grad: int, restricted: str):
67+
"""Handle ReferenceGrad."""
68+
op, = node.ufl_operands
69+
if not op._ufl_terminal_modifiers_:
70+
raise ValueError(f"Must be a terminal modifier: {op!r}.")
71+
return self(op, reference_value, reference_grad + 1, restricted)
72+
73+
@process.register(Restricted)
74+
def _(self, node, reference_value: bool, reference_grad: int, restricted: str):
75+
"""Handle Restricted."""
76+
if restricted is not None:
77+
raise RuntimeError(f"Can not apply Restricted on a Restricted: got {node}")
78+
op, = node.ufl_operands
79+
if not op._ufl_terminal_modifiers_:
80+
raise ValueError(f"Must be a terminal modifier: {op!r}.")
81+
return self(op, reference_value, reference_grad, node._side)
82+
83+
@process.register(Terminal)
84+
def _(self, node, reference_value: bool, reference_grad: int, restricted: str):
85+
"""Handle Terminal."""
86+
return self._handle_terminal(node, reference_value, reference_grad, restricted)
87+
88+
@process.register(Coefficient)
89+
def _(self, node, reference_value: bool, reference_grad: int, restricted: str):
90+
"""Handle Coefficient."""
91+
if node not in self._coefficient_split:
92+
return self._handle_terminal(node, reference_value, reference_grad, restricted)
93+
if not reference_value:
94+
raise RuntimeError(f"ReferenceValue expected: got {o}")
95+
beta = indices(reference_grad)
96+
components = []
97+
for coeff in self._coefficient_split[node]:
98+
c = self._handle_terminal(coeff, reference_value, reference_grad, restricted)
99+
for alpha in np.ndindex(coeff.ufl_element().reference_value_shape):
100+
components.append(c[alpha + beta])
101+
# Repack derivative indices to shape
102+
i, = indices(1)
103+
return ComponentTensor(as_tensor(components)[i], MultiIndex((i,) + beta))
104+
105+
def _handle_terminal(self, node, reference_value: bool, reference_grad: int, restricted: str):
106+
c = node
107+
if reference_value:
108+
c = ReferenceValue(c)
109+
for _ in range(reference_grad):
110+
c = ReferenceGrad(c)
111+
if restricted == "+":
112+
c = PositiveRestricted(c)
113+
elif restricted == "-":
114+
c = NegativeRestricted(c)
115+
elif restricted is not None:
116+
raise RuntimeError(f"Got unknown restriction: {restricted}")
117+
return c
118+
119+
120+
def apply_coefficient_split(expr: Expr, coefficient_split: dict):
121+
"""Split mixed coefficients.
122+
123+
Args:
124+
expr: UFL expression.
125+
coefficient_split: `dict` that maps mixed coefficients to their components.
126+
127+
Returns:
128+
``expr`` with uderlying mixed coefficients split according to ``coefficient_split``.
129+
130+
"""
131+
reference_value = False
132+
reference_grad = 0
133+
restricted = None
134+
return CoefficientSplitter(coefficient_split)(expr, reference_value, reference_grad, restricted)

0 commit comments

Comments
 (0)