Skip to content

Commit 4ced3a7

Browse files
committed
remove apply_default_restriction rebase
1 parent f9412a0 commit 4ced3a7

File tree

2 files changed

+92
-23
lines changed

2 files changed

+92
-23
lines changed

ufl/algorithms/apply_restrictions.py

+60-16
Original file line numberDiff line numberDiff line change
@@ -26,21 +26,23 @@
2626
class RestrictionPropagator(MultiFunction):
2727
"""Restriction propagator."""
2828

29-
def __init__(self, side=None, assume_single_integral_type=True, apply_default=True):
29+
def __init__(self, side=None, assume_single_integral_type=True, apply_default=True, default_restriction=None):
3030
"""Initialise."""
3131
MultiFunction.__init__(self)
3232
self.current_restriction = side
33-
self.default_restriction = "+" if assume_single_integral_type else "?"
33+
if default_restriction is None:
34+
default_restriction = "+" if assume_single_integral_type else "?"
35+
self.default_restriction = default_restriction
3436
self.apply_default = apply_default
3537
# Caches for propagating the restriction with map_expr_dag
3638
self.vcaches = {"+": {}, "-": {}, "|": {}, "?": {}}
3739
self.rcaches = {"+": {}, "-": {}, "|": {}, "?": {}}
3840
if self.current_restriction is None:
3941
self._rp = {
40-
"+": RestrictionPropagator("+", assume_single_integral_type, apply_default),
41-
"-": RestrictionPropagator("-", assume_single_integral_type, apply_default),
42-
"|": RestrictionPropagator("|", assume_single_integral_type, apply_default),
43-
"?": RestrictionPropagator("?", assume_single_integral_type, apply_default),
42+
"+": RestrictionPropagator("+", assume_single_integral_type, apply_default, default_restriction),
43+
"-": RestrictionPropagator("-", assume_single_integral_type, apply_default, default_restriction),
44+
"|": RestrictionPropagator("|", assume_single_integral_type, apply_default, default_restriction),
45+
"?": RestrictionPropagator("?", assume_single_integral_type, apply_default, default_restriction),
4446
}
4547
self.assume_single_integral_type = assume_single_integral_type
4648

@@ -71,6 +73,9 @@ def _require_restriction(self, o):
7173
if self.current_restriction is not None:
7274
return o(self.current_restriction)
7375
elif not self.assume_single_integral_type:
76+
# If integration if over interior facet of meshA and exterior facet of meshB,
77+
# arguments (say) on meshA must be restricted, but those on meshB do not
78+
# need to be.
7479
return o
7580
else:
7681
raise ValueError(f"Discontinuous type {o._ufl_class_.__name__} must be restricted.")
@@ -84,7 +89,19 @@ def _default_restricted(self, o):
8489
domain = extract_unique_domain(o, expand_mixed_mesh=False)
8590
if isinstance(domain, MixedMesh):
8691
raise RuntimeError(f"Not expecting a terminal object on a mixed mesh at this stage: found {repr(o)}")
87-
return o(self.default_restriction[domain])
92+
if isinstance(self.default_restriction, dict):
93+
if domain not in self.default_restriction:
94+
raise RuntimeError(f"Integral type on {domain} not known")
95+
r = self.default_restriction[domain]
96+
if r is None:
97+
return o
98+
elif r in ["+", "-"]:
99+
return o(r)
100+
else:
101+
raise RuntimeError(f"Unknown default restriction {r} on domain {domain}")
102+
else:
103+
# conventional "+" default:
104+
return o(self.default_restriction)
88105
else:
89106
return o
90107

@@ -93,12 +110,26 @@ def _opposite(self, o):
93110
94111
If the current restriction is different swap the sign, require a side to be set.
95112
"""
96-
if self.current_restriction is None:
97-
raise ValueError(f"Discontinuous type {o._ufl_class_.__name__} must be restricted.")
98-
elif self.current_restriction == self.default_restriction:
99-
return o(self.default_restriction)
113+
if isinstance(self.default_restriction, dict):
114+
domain = extract_unique_domain(o, expand_mixed_mesh=False)
115+
if isinstance(domain, MixedMesh):
116+
raise RuntimeError(f"Not expecting a terminal object on a mixed mesh at this stage: found {repr(o)}")
117+
if domain not in self.default_restriction:
118+
raise RuntimeError(f"Integral type on {domain} not known")
119+
r = self.default_restriction[domain]
100120
else:
101-
return -o(self.default_restriction)
121+
r = self.default_restriction
122+
if r is None:
123+
if self.current_restriction is not None:
124+
raise ValueError(f"Expecting current_restriction None: got {self.current_restriction}")
125+
return o
126+
else:
127+
if self.current_restriction is None:
128+
raise ValueError(f"Discontinuous type {o._ufl_class_.__name__} must be restricted.")
129+
elif self.current_restriction == r:
130+
return o(self.default_restriction)
131+
else:
132+
return -o(self.default_restriction)
102133

103134
def _missing_rule(self, o):
104135
"""Raise an error."""
@@ -206,7 +237,7 @@ def facet_normal(self, o):
206237
return self._require_restriction(o)
207238

208239

209-
def apply_restrictions(expression, assume_single_integral_type=True, apply_default=True):
240+
def apply_restrictions(expression, assume_single_integral_type=True, apply_default=True, default_restriction=None):
210241
"""Propagate restriction nodes to wrap differential terminals directly."""
211242
if assume_single_integral_type:
212243
integral_types = [
@@ -217,7 +248,7 @@ def apply_restrictions(expression, assume_single_integral_type=True, apply_defau
217248
# the integral type of a given function; e.g., the former can be
218249
# ``exterior_facet`` and the latter ``interior_facet``.
219250
integral_types = None
220-
rules = RestrictionPropagator(assume_single_integral_type=assume_single_integral_type, apply_default=apply_default)
251+
rules = RestrictionPropagator(assume_single_integral_type=assume_single_integral_type, apply_default=apply_default, default_restriction=default_restriction)
221252
if isinstance(expression, FormData):
222253
for integral_data in expression.integral_data:
223254
integral_data.integrals = tuple(
@@ -347,15 +378,28 @@ def to_be_restricted(self, o):
347378
return mt
348379
elif integral_type == "exterior_facet":
349380
return SingleValueRestricted(mt)
350-
elif integral_type == "interial_facet":
381+
elif integral_type == "interior_facet":
351382
return PositiveRestricted(mt)
352383
else:
353384
raise RuntimeError(f"Unknown integral type: {integral_type}")
354385

355386

356387
def replace_to_be_restricted(integral_data):
357388
new_integrals = []
358-
rule = ToBeRestrectedReplacer(integral_data.domain_integral_type_map)
389+
#rule = ToBeRestrectedReplacer(integral_data.domain_integral_type_map)
390+
rule = RestrictionPropagator(
391+
side=None,
392+
assume_single_integral_type=False,
393+
apply_default=True,
394+
default_restriction={
395+
domain: {
396+
"cell": None,
397+
"exterior_facet": None,
398+
"interior_facet": "+",
399+
}[integral_type]
400+
for domain, integral_type in integral_data.domain_integral_type_map.items()
401+
},
402+
)
359403
for integral in integral_data.integrals:
360404
integrand = map_expr_dag(rule, integral.integrand())
361405
new_integrals.append(integral.reconstruct(integrand=integrand))

ufl/algorithms/compute_form_data.py

+32-7
Original file line numberDiff line numberDiff line change
@@ -338,13 +338,6 @@ def compute_form_data(
338338

339339
form = apply_coordinate_derivatives(form)
340340

341-
# Propagate restrictions to terminals
342-
if do_apply_restrictions:
343-
if do_assume_single_integral_type:
344-
form = apply_restrictions(form, apply_default=do_apply_default_restrictions)
345-
else:
346-
form = apply_restrictions(form, assume_single_integral_type=have_single_domain, apply_default=False)
347-
348341
# If in real mode, remove any complex nodes introduced during form processing.
349342
if not complex_mode:
350343
form = remove_complex_nodes(form)
@@ -353,6 +346,38 @@ def compute_form_data(
353346
# Most of the heavy lifting is done above in group_form_integrals.
354347
self.integral_data = build_integral_data(form.integrals())
355348

349+
# Propagate restrictions to terminals
350+
if do_apply_restrictions:
351+
if do_assume_single_integral_type or have_single_domain:
352+
for itg_data in self.integral_data:
353+
new_integrals = []
354+
for integral in itg_data.integrals:
355+
new_integral = apply_restrictions(
356+
integral,
357+
apply_default=do_apply_default_restrictions,
358+
default_restriction={
359+
itg_data.domain: {
360+
"cell": None,
361+
"exterior_facet": None,
362+
"interior_facet": "+",
363+
}[itg_data.integral_type]
364+
},
365+
)
366+
new_integrals.append(new_integral)
367+
itg_data.integrals = new_integrals
368+
else:
369+
#form = apply_restrictions(form, assume_single_integral_type=have_single_domain, apply_default=False)
370+
for itg_data in self.integral_data:
371+
new_integrals = []
372+
for integral in itg_data.integrals:
373+
new_integral = apply_restrictions(
374+
integral,
375+
assume_single_integral_type=have_single_domain,
376+
apply_default=False,
377+
)
378+
new_integrals.append(new_integral)
379+
itg_data.integrals = new_integrals
380+
356381
# --- Create replacements for arguments and coefficients
357382

358383
# Figure out which form coefficients each integral should enable

0 commit comments

Comments
 (0)