Skip to content

Commit e66d41d

Browse files
committed
handle restrictions on MixedMesh
1 parent 285d8e6 commit e66d41d

12 files changed

+622
-35
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
from ufl import (triangle, Mesh, MixedMesh, FunctionSpace, TestFunction, TrialFunction, Coefficient, Constant,
2+
Measure, SpatialCoordinate, FacetNormal, CellVolume, FacetArea, inner, grad, div, split, )
3+
from ufl.algorithms import compute_form_data
4+
from ufl.finiteelement import FiniteElement, MixedElement
5+
from ufl.pullback import identity_pullback, contravariant_piola
6+
from ufl.sobolevspace import H1, HDiv, L2
7+
from ufl.domain import extract_domains
8+
9+
10+
def test_mixed_function_space_with_mixed_mesh_basic():
11+
cell = triangle
12+
elem0 = FiniteElement("Lagrange", cell, 1, (), identity_pullback, H1)
13+
elem1 = FiniteElement("Brezzi-Douglas-Marini", cell, 1, (2, ), contravariant_piola, HDiv)
14+
elem2 = FiniteElement("Discontinuous Lagrange", cell, 0, (), identity_pullback, L2)
15+
elem = MixedElement([elem0, elem1, elem2])
16+
mesh0 = Mesh(FiniteElement("Lagrange", cell, 1, (2, ), identity_pullback, H1), ufl_id=100)
17+
mesh1 = Mesh(FiniteElement("Lagrange", cell, 1, (2, ), identity_pullback, H1), ufl_id=101)
18+
mesh2 = Mesh(FiniteElement("Lagrange", cell, 1, (2, ), identity_pullback, H1), ufl_id=102)
19+
domain = MixedMesh(mesh0, mesh1, mesh2)
20+
V = FunctionSpace(domain, elem)
21+
u = TrialFunction(V)
22+
v = TestFunction(V)
23+
f = Coefficient(V, count=1000)
24+
g = Coefficient(V, count=2000)
25+
u0, u1, u2 = split(u)
26+
v0, v1, v2 = split(v)
27+
f0, f1, f2 = split(f)
28+
g0, g1, g2 = split(g)
29+
dx1 = Measure("dx", mesh1)
30+
ds2 = Measure("ds", mesh2)
31+
x = SpatialCoordinate(mesh1)
32+
form = x[1] * f0 * inner(grad(u0), v1) * dx1(999) + div(f1) * g2 * inner(u1, grad(v2)) * ds2(888)
33+
fd = compute_form_data(form,
34+
do_apply_function_pullbacks=True,
35+
do_apply_integral_scaling=True,
36+
do_apply_geometry_lowering=True,
37+
preserve_geometry_types=(CellVolume, FacetArea),
38+
do_apply_restrictions=True,
39+
do_estimate_degrees=True,
40+
complex_mode=False)
41+
id0, id1 = fd.integral_data
42+
assert fd.preprocessed_form.arguments() == (v, u)
43+
assert fd.reduced_coefficients == [f, g]
44+
assert form.coefficients()[fd.original_coefficient_positions[0]] is f
45+
assert form.coefficients()[fd.original_coefficient_positions[1]] is g
46+
assert id0.domain is mesh1
47+
assert id0.integral_type == 'cell'
48+
assert id0.subdomain_id == (999, )
49+
assert fd.original_form.domain_numbering()[id0.domain] == 0
50+
assert id0.integral_coefficients == set([f])
51+
assert id0.enabled_coefficients == [True, False]
52+
assert id1.domain is mesh2
53+
assert id1.integral_type == 'exterior_facet'
54+
assert id1.subdomain_id == (888, )
55+
assert fd.original_form.domain_numbering()[id1.domain] == 1
56+
assert id1.integral_coefficients == set([f, g])
57+
assert id1.enabled_coefficients == [True, True]
58+
59+
60+
def test_mixed_function_space_with_mixed_mesh_restriction():
61+
cell = triangle
62+
elem0 = FiniteElement("Lagrange", cell, 1, (), identity_pullback, H1)
63+
elem1 = FiniteElement("Brezzi-Douglas-Marini", cell, 1, (2, ), contravariant_piola, HDiv)
64+
elem2 = FiniteElement("Discontinuous Lagrange", cell, 0, (), identity_pullback, L2)
65+
elem = MixedElement([elem0, elem1, elem2])
66+
mesh0 = Mesh(FiniteElement("Lagrange", cell, 1, (2, ), identity_pullback, H1), ufl_id=100)
67+
mesh1 = Mesh(FiniteElement("Lagrange", cell, 1, (2, ), identity_pullback, H1), ufl_id=101)
68+
mesh2 = Mesh(FiniteElement("Lagrange", cell, 1, (2, ), identity_pullback, H1), ufl_id=102)
69+
domain = MixedMesh(mesh0, mesh1, mesh2)
70+
V = FunctionSpace(domain, elem)
71+
V0 = FunctionSpace(mesh0, elem0)
72+
V1 = FunctionSpace(mesh1, elem1)
73+
V2 = FunctionSpace(mesh2, elem2)
74+
u1 = TrialFunction(V1)
75+
v2 = TestFunction(V2)
76+
f = Coefficient(V, count=1000)
77+
g = Coefficient(V, count=2000)
78+
f0, f1, f2 = split(f)
79+
g0, g1, g2 = split(g)
80+
dS1 = Measure("dS", mesh1)
81+
x2 = SpatialCoordinate(mesh2)
82+
form = inner(x2, g1) * g2 * inner(u1('-'), grad(v2('|'))) * dS1(999)
83+
fd = compute_form_data(form,
84+
do_apply_function_pullbacks=True,
85+
do_apply_integral_scaling=True,
86+
do_apply_geometry_lowering=True,
87+
preserve_geometry_types=(CellVolume, FacetArea),
88+
do_apply_restrictions=True,
89+
do_estimate_degrees=True,
90+
do_split_coefficients=(f, g),
91+
do_assume_single_integral_type=False,
92+
complex_mode=False)
93+
integral_data, = fd.integral_data
94+
assert integral_data.domain_integral_type_map[mesh1] == "interior_facet"
95+
assert integral_data.domain_integral_type_map[mesh2] == "exterior_facet"
96+
97+
98+
def test_mixed_function_space_with_mixed_mesh_signature():
99+
cell = triangle
100+
mesh0 = Mesh(FiniteElement("Lagrange", cell, 1, (2, ), identity_pullback, H1), ufl_id=100)
101+
mesh1 = Mesh(FiniteElement("Lagrange", cell, 1, (2, ), identity_pullback, H1), ufl_id=101)
102+
dx0 = Measure("dx", mesh0)
103+
dx1 = Measure("dx", mesh1)
104+
n0 = FacetNormal(mesh0)
105+
n1 = FacetNormal(mesh1)
106+
form_a = inner(n1, n1) * dx0(999)
107+
form_b = inner(n0, n0) * dx1(999)
108+
assert form_a.signature() == form_b.signature()
109+
assert extract_domains(form_a) == (mesh0, mesh1)
110+
assert extract_domains(form_b) == (mesh1, mesh0)
+211
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,211 @@
1+
"""Apply coefficient split.
2+
3+
This module contains classes and functions to split coefficients defined on mixed function spaces.
4+
"""
5+
6+
import functools
7+
import numpy
8+
from ufl.classes import Restricted
9+
from ufl.corealg.map_dag import map_expr_dag
10+
from ufl.corealg.multifunction import MultiFunction, memoized_handler
11+
from ufl.domain import extract_unique_domain
12+
from ufl.classes import (Coefficient, Form, ReferenceGrad, ReferenceValue,
13+
Indexed, MultiIndex, Index, FixedIndex,
14+
ComponentTensor, ListTensor, Zero,
15+
NegativeRestricted, PositiveRestricted, SingleValueRestricted, ToBeRestricted)
16+
from ufl import indices
17+
from ufl.checks import is_cellwise_constant
18+
from ufl.tensors import as_tensor
19+
20+
21+
class CoefficientSplitter(MultiFunction):
22+
23+
def __init__(self, coefficient_split):
24+
MultiFunction.__init__(self)
25+
self._coefficient_split = coefficient_split
26+
27+
expr = MultiFunction.reuse_if_untouched
28+
29+
def modified_terminal(self, o):
30+
restriction = None
31+
local_derivatives = 0
32+
reference_value = False
33+
t = o
34+
while not t._ufl_is_terminal_:
35+
assert t._ufl_is_terminal_modifier_, f"Got {repr(t)}"
36+
if isinstance(t, ReferenceValue):
37+
assert not reference_value, "Got twice pulled back terminal!"
38+
reference_value = True
39+
t, = t.ufl_operands
40+
elif isinstance(t, ReferenceGrad):
41+
local_derivatives += 1
42+
t, = t.ufl_operands
43+
elif isinstance(t, Restricted):
44+
assert restriction is None, "Got twice restricted terminal!"
45+
restriction = t._side
46+
t, = t.ufl_operands
47+
elif t._ufl_terminal_modifiers_:
48+
raise ValueError("Missing handler for terminal modifier type %s, object is %s." % (type(t), repr(t)))
49+
else:
50+
raise ValueError("Unexpected type %s object %s." % (type(t), repr(t)))
51+
if not isinstance(t, Coefficient):
52+
# Only split coefficients
53+
return o
54+
if t not in self._coefficient_split:
55+
# Only split mixed coefficients
56+
return o
57+
# Reference value expected
58+
assert reference_value
59+
# Derivative indices
60+
beta = indices(local_derivatives)
61+
components = []
62+
for subcoeff in self._coefficient_split[t]:
63+
c = subcoeff
64+
# Apply terminal modifiers onto the subcoefficient
65+
if reference_value:
66+
c = ReferenceValue(c)
67+
for n in range(local_derivatives):
68+
# Return zero if expression is trivially constant. This has to
69+
# happen here because ReferenceGrad has no access to the
70+
# topological dimension of a literal zero.
71+
if is_cellwise_constant(c):
72+
dim = extract_unique_domain(subcoeff).topological_dimension()
73+
c = Zero(c.ufl_shape + (dim,), c.ufl_free_indices, c.ufl_index_dimensions)
74+
else:
75+
c = ReferenceGrad(c)
76+
if restriction == '+':
77+
c = PositiveRestricted(c)
78+
elif restriction == '-':
79+
c = NegativeRestricted(c)
80+
elif restriction == '|':
81+
c = SingleValueRestricted(c)
82+
elif restriction == '?':
83+
c = ToBeRestricted(c)
84+
elif restriction is not None:
85+
raise RuntimeError(f"Got unknown restriction: {restriction}")
86+
# Collect components of the subcoefficient
87+
for alpha in numpy.ndindex(subcoeff.ufl_element().reference_value_shape):
88+
# New modified terminal: component[alpha + beta]
89+
components.append(c[alpha + beta])
90+
# Repack derivative indices to shape
91+
c, = indices(1)
92+
return ComponentTensor(as_tensor(components)[c], MultiIndex((c,) + beta))
93+
94+
positive_restricted = modified_terminal
95+
negative_restricted = modified_terminal
96+
single_value_restricted = modified_terminal
97+
to_be_restricted = modified_terminal
98+
reference_grad = modified_terminal
99+
reference_value = modified_terminal
100+
terminal = modified_terminal
101+
102+
103+
def apply_coefficient_split(expr, coefficient_split):
104+
"""Split mixed coefficients, so mixed elements need not be
105+
implemented.
106+
107+
:arg split: A :py:class:`dict` mapping each mixed coefficient to a
108+
sequence of subcoefficients. If None, calling this
109+
function is a no-op.
110+
"""
111+
if coefficient_split is None:
112+
return expr
113+
splitter = CoefficientSplitter(coefficient_split)
114+
return map_expr_dag(splitter, expr)
115+
116+
117+
class FixedIndexRemover(MultiFunction):
118+
119+
def __init__(self, fimap):
120+
MultiFunction.__init__(self)
121+
self.fimap = fimap
122+
self._object_cache = {}
123+
124+
expr = MultiFunction.reuse_if_untouched
125+
126+
@memoized_handler
127+
def zero(self, o):
128+
free_indices = []
129+
index_dimensions = []
130+
for i, d in zip(o.ufl_free_indices, o.ufl_index_dimensions):
131+
if Index(i) in self.fimap:
132+
ind_j = self.fimap[Index(i)]
133+
if not isinstance(ind_j, FixedIndex):
134+
free_indices.append(ind_j.count())
135+
index_dimensions.append(d)
136+
else:
137+
free_indices.append(i)
138+
index_dimensions.append(d)
139+
return Zero(shape=o.ufl_shape, free_indices=tuple(free_indices), index_dimensions=tuple(index_dimensions))
140+
141+
@memoized_handler
142+
def list_tensor(self, o):
143+
cc = []
144+
for o1 in o.ufl_operands:
145+
comp = map_expr_dag(self, o1)
146+
cc.append(comp)
147+
return ListTensor(*cc)
148+
149+
@memoized_handler
150+
def multi_index(self, o):
151+
return MultiIndex(tuple(self.fimap.get(i, i) for i in o.indices()))
152+
153+
154+
class IndexRemover(MultiFunction):
155+
156+
def __init__(self):
157+
MultiFunction.__init__(self)
158+
self._object_cache = {}
159+
160+
expr = MultiFunction.reuse_if_untouched
161+
162+
@memoized_handler
163+
def _zero_simplify(self, o):
164+
operand, = o.ufl_operands
165+
operand = map_expr_dag(self, operand)
166+
if isinstance(operand, Zero):
167+
return Zero(shape=o.ufl_shape, free_indices=o.ufl_free_indices, index_dimensions=o.ufl_index_dimensions)
168+
else:
169+
return o._ufl_expr_reconstruct_(operand)
170+
171+
@memoized_handler
172+
def indexed(self, o):
173+
o1, i1 = o.ufl_operands
174+
if isinstance(o1, ComponentTensor):
175+
o2, i2 = o1.ufl_operands
176+
assert len(i2.indices()) == len(i1.indices())
177+
fimap = dict(zip(i2.indices(), i1.indices()))
178+
rule = FixedIndexRemover(fimap)
179+
v = map_expr_dag(rule, o2)
180+
return map_expr_dag(self, v)
181+
elif isinstance(o1, ListTensor):
182+
if isinstance(i1[0], FixedIndex):
183+
o1 = o1.ufl_operands[i1[0]._value]
184+
if len(i1) > 1:
185+
i1 = MultiIndex(i1[1:])
186+
return map_expr_dag(self, Indexed(o1, i1))
187+
else:
188+
return map_expr_dag(self, o1)
189+
o1 = map_expr_dag(self, o1)
190+
return Indexed(o1, i1)
191+
192+
# Do something nicer
193+
positive_restricted = _zero_simplify
194+
negative_restricted = _zero_simplify
195+
single_value_restricted = _zero_simplify
196+
to_be_restricted = _zero_simplify
197+
reference_grad = _zero_simplify
198+
reference_value = _zero_simplify
199+
200+
201+
def remove_component_and_list_tensors(o):
202+
if isinstance(o, Form):
203+
integrals = []
204+
for integral in o.integrals():
205+
integrand = remove_component_and_list_tensors(integral.integrand())
206+
if not isinstance(integrand, Zero):
207+
integrals.append(integral.reconstruct(integrand=integrand))
208+
return o._ufl_expr_reconstruct_(integrals)
209+
else:
210+
rule = IndexRemover()
211+
return map_expr_dag(rule, o)

0 commit comments

Comments
 (0)