Skip to content

Commit d3e72fb

Browse files
committed
handle restrictions on MixedMesh
1 parent 52bd425 commit d3e72fb

12 files changed

+559
-38
lines changed

test/test_mixed_function_space_with_mixed_mesh.py

+50-5
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from ufl import (triangle, Mesh, MixedMesh, FunctionSpace, TestFunction, TrialFunction, Coefficient,
2-
Measure, SpatialCoordinate, FacetNormal, CellVolume, FacetArea, inner, grad, split, )
2+
Measure, SpatialCoordinate, FacetNormal, CellVolume, FacetArea, inner, grad, div, split, )
33
from ufl.algorithms import compute_form_data
44
from ufl.finiteelement import FiniteElement, MixedElement
55
from ufl.pullback import identity_pullback, contravariant_piola
@@ -27,8 +27,9 @@ def test_mixed_function_space_with_mixed_mesh_basic():
2727
f0, f1, f2 = split(f)
2828
g0, g1, g2 = split(g)
2929
dx1 = Measure("dx", mesh1)
30+
ds2 = Measure("ds", mesh2)
3031
x = SpatialCoordinate(mesh1)
31-
form = x[1] * f0 * inner(grad(u0), v1) * dx1(999)
32+
form = x[1] * f0 * inner(grad(u0), v1) * dx1(999) + div(f1) * g2 * inner(u1, grad(v2)) * ds2(888)
3233
fd = compute_form_data(form,
3334
do_apply_function_pullbacks=True,
3435
do_apply_integral_scaling=True,
@@ -37,16 +38,60 @@ def test_mixed_function_space_with_mixed_mesh_basic():
3738
do_apply_restrictions=True,
3839
do_estimate_degrees=True,
3940
complex_mode=False)
40-
id0, = fd.integral_data
41+
id0, id1 = fd.integral_data
4142
assert fd.preprocessed_form.arguments() == (v, u)
42-
assert fd.reduced_coefficients == [f]
43+
assert fd.reduced_coefficients == [f, g]
4344
assert form.coefficients()[fd.original_coefficient_positions[0]] is f
45+
assert form.coefficients()[fd.original_coefficient_positions[1]] is g
4446
assert id0.domain is mesh1
4547
assert id0.integral_type == 'cell'
4648
assert id0.subdomain_id == (999, )
4749
assert fd.original_form.domain_numbering()[id0.domain] == 0
4850
assert id0.integral_coefficients == set([f])
49-
assert id0.enabled_coefficients == [True]
51+
assert id0.enabled_coefficients == [True, False]
52+
assert id1.domain is mesh2
53+
assert id1.integral_type == 'exterior_facet'
54+
assert id1.subdomain_id == (888, )
55+
assert fd.original_form.domain_numbering()[id1.domain] == 1
56+
assert id1.integral_coefficients == set([f, g])
57+
assert id1.enabled_coefficients == [True, True]
58+
59+
60+
def test_mixed_function_space_with_mixed_mesh_restriction():
61+
cell = triangle
62+
elem0 = FiniteElement("Lagrange", cell, 1, (), identity_pullback, H1)
63+
elem1 = FiniteElement("Brezzi-Douglas-Marini", cell, 1, (2, ), contravariant_piola, HDiv)
64+
elem2 = FiniteElement("Discontinuous Lagrange", cell, 0, (), identity_pullback, L2)
65+
elem = MixedElement([elem0, elem1, elem2])
66+
mesh0 = Mesh(FiniteElement("Lagrange", cell, 1, (2, ), identity_pullback, H1), ufl_id=100)
67+
mesh1 = Mesh(FiniteElement("Lagrange", cell, 1, (2, ), identity_pullback, H1), ufl_id=101)
68+
mesh2 = Mesh(FiniteElement("Lagrange", cell, 1, (2, ), identity_pullback, H1), ufl_id=102)
69+
domain = MixedMesh(mesh0, mesh1, mesh2)
70+
V = FunctionSpace(domain, elem)
71+
V1 = FunctionSpace(mesh1, elem1)
72+
V2 = FunctionSpace(mesh2, elem2)
73+
u1 = TrialFunction(V1)
74+
v2 = TestFunction(V2)
75+
f = Coefficient(V, count=1000)
76+
g = Coefficient(V, count=2000)
77+
f0, f1, f2 = split(f)
78+
g0, g1, g2 = split(g)
79+
dS1 = Measure("dS", mesh1)
80+
x2 = SpatialCoordinate(mesh2)
81+
form = inner(x2, g1) * g2 * inner(u1('-'), grad(v2('|'))) * dS1(999)
82+
fd = compute_form_data(form,
83+
do_apply_function_pullbacks=True,
84+
do_apply_integral_scaling=True,
85+
do_apply_geometry_lowering=True,
86+
preserve_geometry_types=(CellVolume, FacetArea),
87+
do_apply_restrictions=True,
88+
do_estimate_degrees=True,
89+
do_split_coefficients=(f, g),
90+
do_assume_single_integral_type=False,
91+
complex_mode=False)
92+
integral_data, = fd.integral_data
93+
assert integral_data.domain_integral_type_map[mesh1] == "interior_facet"
94+
assert integral_data.domain_integral_type_map[mesh2] == "exterior_facet"
5095

5196

5297
def test_mixed_function_space_with_mixed_mesh_signature():
+210
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,210 @@
1+
"""Apply coefficient split.
2+
3+
This module contains classes and functions to split coefficients defined on mixed function spaces.
4+
"""
5+
6+
import numpy
7+
from ufl.classes import Restricted
8+
from ufl.corealg.map_dag import map_expr_dag
9+
from ufl.corealg.multifunction import MultiFunction, memoized_handler
10+
from ufl.domain import extract_unique_domain
11+
from ufl.classes import (Coefficient, Form, ReferenceGrad, ReferenceValue,
12+
Indexed, MultiIndex, Index, FixedIndex,
13+
ComponentTensor, ListTensor, Zero,
14+
NegativeRestricted, PositiveRestricted, SingleValueRestricted, ToBeRestricted)
15+
from ufl import indices
16+
from ufl.checks import is_cellwise_constant
17+
from ufl.tensors import as_tensor
18+
19+
20+
class CoefficientSplitter(MultiFunction):
21+
22+
def __init__(self, coefficient_split):
23+
MultiFunction.__init__(self)
24+
self._coefficient_split = coefficient_split
25+
26+
expr = MultiFunction.reuse_if_untouched
27+
28+
def modified_terminal(self, o):
29+
restriction = None
30+
local_derivatives = 0
31+
reference_value = False
32+
t = o
33+
while not t._ufl_is_terminal_:
34+
assert t._ufl_is_terminal_modifier_, f"Got {repr(t)}"
35+
if isinstance(t, ReferenceValue):
36+
assert not reference_value, "Got twice pulled back terminal!"
37+
reference_value = True
38+
t, = t.ufl_operands
39+
elif isinstance(t, ReferenceGrad):
40+
local_derivatives += 1
41+
t, = t.ufl_operands
42+
elif isinstance(t, Restricted):
43+
assert restriction is None, "Got twice restricted terminal!"
44+
restriction = t._side
45+
t, = t.ufl_operands
46+
elif t._ufl_terminal_modifiers_:
47+
raise ValueError("Missing handler for terminal modifier type %s, object is %s." % (type(t), repr(t)))
48+
else:
49+
raise ValueError("Unexpected type %s object %s." % (type(t), repr(t)))
50+
if not isinstance(t, Coefficient):
51+
# Only split coefficients
52+
return o
53+
if t not in self._coefficient_split:
54+
# Only split mixed coefficients
55+
return o
56+
# Reference value expected
57+
assert reference_value
58+
# Derivative indices
59+
beta = indices(local_derivatives)
60+
components = []
61+
for subcoeff in self._coefficient_split[t]:
62+
c = subcoeff
63+
# Apply terminal modifiers onto the subcoefficient
64+
if reference_value:
65+
c = ReferenceValue(c)
66+
for n in range(local_derivatives):
67+
# Return zero if expression is trivially constant. This has to
68+
# happen here because ReferenceGrad has no access to the
69+
# topological dimension of a literal zero.
70+
if is_cellwise_constant(c):
71+
dim = extract_unique_domain(subcoeff).topological_dimension()
72+
c = Zero(c.ufl_shape + (dim,), c.ufl_free_indices, c.ufl_index_dimensions)
73+
else:
74+
c = ReferenceGrad(c)
75+
if restriction == '+':
76+
c = PositiveRestricted(c)
77+
elif restriction == '-':
78+
c = NegativeRestricted(c)
79+
elif restriction == '|':
80+
c = SingleValueRestricted(c)
81+
elif restriction == '?':
82+
c = ToBeRestricted(c)
83+
elif restriction is not None:
84+
raise RuntimeError(f"Got unknown restriction: {restriction}")
85+
# Collect components of the subcoefficient
86+
for alpha in numpy.ndindex(subcoeff.ufl_element().reference_value_shape):
87+
# New modified terminal: component[alpha + beta]
88+
components.append(c[alpha + beta])
89+
# Repack derivative indices to shape
90+
c, = indices(1)
91+
return ComponentTensor(as_tensor(components)[c], MultiIndex((c,) + beta))
92+
93+
positive_restricted = modified_terminal
94+
negative_restricted = modified_terminal
95+
single_value_restricted = modified_terminal
96+
to_be_restricted = modified_terminal
97+
reference_grad = modified_terminal
98+
reference_value = modified_terminal
99+
terminal = modified_terminal
100+
101+
102+
def apply_coefficient_split(expr, coefficient_split):
103+
"""Split mixed coefficients, so mixed elements need not be
104+
implemented.
105+
106+
:arg split: A :py:class:`dict` mapping each mixed coefficient to a
107+
sequence of subcoefficients. If None, calling this
108+
function is a no-op.
109+
"""
110+
if coefficient_split is None:
111+
return expr
112+
splitter = CoefficientSplitter(coefficient_split)
113+
return map_expr_dag(splitter, expr)
114+
115+
116+
class FixedIndexRemover(MultiFunction):
117+
118+
def __init__(self, fimap):
119+
MultiFunction.__init__(self)
120+
self.fimap = fimap
121+
self._object_cache = {}
122+
123+
expr = MultiFunction.reuse_if_untouched
124+
125+
@memoized_handler
126+
def zero(self, o):
127+
free_indices = []
128+
index_dimensions = []
129+
for i, d in zip(o.ufl_free_indices, o.ufl_index_dimensions):
130+
if Index(i) in self.fimap:
131+
ind_j = self.fimap[Index(i)]
132+
if not isinstance(ind_j, FixedIndex):
133+
free_indices.append(ind_j.count())
134+
index_dimensions.append(d)
135+
else:
136+
free_indices.append(i)
137+
index_dimensions.append(d)
138+
return Zero(shape=o.ufl_shape, free_indices=tuple(free_indices), index_dimensions=tuple(index_dimensions))
139+
140+
@memoized_handler
141+
def list_tensor(self, o):
142+
cc = []
143+
for o1 in o.ufl_operands:
144+
comp = map_expr_dag(self, o1)
145+
cc.append(comp)
146+
return ListTensor(*cc)
147+
148+
@memoized_handler
149+
def multi_index(self, o):
150+
return MultiIndex(tuple(self.fimap.get(i, i) for i in o.indices()))
151+
152+
153+
class IndexRemover(MultiFunction):
154+
155+
def __init__(self):
156+
MultiFunction.__init__(self)
157+
self._object_cache = {}
158+
159+
expr = MultiFunction.reuse_if_untouched
160+
161+
@memoized_handler
162+
def _zero_simplify(self, o):
163+
operand, = o.ufl_operands
164+
operand = map_expr_dag(self, operand)
165+
if isinstance(operand, Zero):
166+
return Zero(shape=o.ufl_shape, free_indices=o.ufl_free_indices, index_dimensions=o.ufl_index_dimensions)
167+
else:
168+
return o._ufl_expr_reconstruct_(operand)
169+
170+
@memoized_handler
171+
def indexed(self, o):
172+
o1, i1 = o.ufl_operands
173+
if isinstance(o1, ComponentTensor):
174+
o2, i2 = o1.ufl_operands
175+
assert len(i2.indices()) == len(i1.indices())
176+
fimap = dict(zip(i2.indices(), i1.indices()))
177+
rule = FixedIndexRemover(fimap)
178+
v = map_expr_dag(rule, o2)
179+
return map_expr_dag(self, v)
180+
elif isinstance(o1, ListTensor):
181+
if isinstance(i1[0], FixedIndex):
182+
o1 = o1.ufl_operands[i1[0]._value]
183+
if len(i1) > 1:
184+
i1 = MultiIndex(i1[1:])
185+
return map_expr_dag(self, Indexed(o1, i1))
186+
else:
187+
return map_expr_dag(self, o1)
188+
o1 = map_expr_dag(self, o1)
189+
return Indexed(o1, i1)
190+
191+
# Do something nicer
192+
positive_restricted = _zero_simplify
193+
negative_restricted = _zero_simplify
194+
single_value_restricted = _zero_simplify
195+
to_be_restricted = _zero_simplify
196+
reference_grad = _zero_simplify
197+
reference_value = _zero_simplify
198+
199+
200+
def remove_component_and_list_tensors(o):
201+
if isinstance(o, Form):
202+
integrals = []
203+
for integral in o.integrals():
204+
integrand = remove_component_and_list_tensors(integral.integrand())
205+
if not isinstance(integrand, Zero):
206+
integrals.append(integral.reconstruct(integrand=integrand))
207+
return o._ufl_expr_reconstruct_(integrals)
208+
else:
209+
rule = IndexRemover()
210+
return map_expr_dag(rule, o)

0 commit comments

Comments
 (0)