26
26
class RestrictionPropagator (MultiFunction ):
27
27
"""Restriction propagator."""
28
28
29
- def __init__ (self , side = None , assume_single_integral_type = True , apply_default = True ):
29
+ def __init__ (self , side = None , assume_single_integral_type = True , apply_default = True , default_restriction = None ):
30
30
"""Initialise."""
31
31
MultiFunction .__init__ (self )
32
32
self .current_restriction = side
33
- self .default_restriction = "+" if assume_single_integral_type else "?"
33
+ if default_restriction is None :
34
+ default_restriction = "+" if assume_single_integral_type else "?"
35
+ self .default_restriction = default_restriction
34
36
self .apply_default = apply_default
35
37
# Caches for propagating the restriction with map_expr_dag
36
38
self .vcaches = {"+" : {}, "-" : {}, "|" : {}, "?" : {}}
37
39
self .rcaches = {"+" : {}, "-" : {}, "|" : {}, "?" : {}}
38
40
if self .current_restriction is None :
39
41
self ._rp = {
40
- "+" : RestrictionPropagator ("+" , assume_single_integral_type , apply_default ),
41
- "-" : RestrictionPropagator ("-" , assume_single_integral_type , apply_default ),
42
- "|" : RestrictionPropagator ("|" , assume_single_integral_type , apply_default ),
43
- "?" : RestrictionPropagator ("?" , assume_single_integral_type , apply_default ),
42
+ "+" : RestrictionPropagator ("+" , assume_single_integral_type , apply_default , default_restriction ),
43
+ "-" : RestrictionPropagator ("-" , assume_single_integral_type , apply_default , default_restriction ),
44
+ "|" : RestrictionPropagator ("|" , assume_single_integral_type , apply_default , default_restriction ),
45
+ "?" : RestrictionPropagator ("?" , assume_single_integral_type , apply_default , default_restriction ),
44
46
}
45
47
self .assume_single_integral_type = assume_single_integral_type
46
48
@@ -71,6 +73,9 @@ def _require_restriction(self, o):
71
73
if self .current_restriction is not None :
72
74
return o (self .current_restriction )
73
75
elif not self .assume_single_integral_type :
76
+ # If integration if over interior facet of meshA and exterior facet of meshB,
77
+ # arguments (say) on meshA must be restricted, but those on meshB do not
78
+ # need to be.
74
79
return o
75
80
else :
76
81
raise ValueError (f"Discontinuous type { o ._ufl_class_ .__name__ } must be restricted." )
@@ -84,7 +89,19 @@ def _default_restricted(self, o):
84
89
domain = extract_unique_domain (o , expand_mixed_mesh = False )
85
90
if isinstance (domain , MixedMesh ):
86
91
raise RuntimeError (f"Not expecting a terminal object on a mixed mesh at this stage: found { repr (o )} " )
87
- return o (self .default_restriction [domain ])
92
+ if isinstance (self .default_restriction , dict ):
93
+ if domain not in self .default_restriction :
94
+ raise RuntimeError (f"Integral type on { domain } not known" )
95
+ r = self .default_restriction [domain ]
96
+ if r is None :
97
+ return o
98
+ elif r in ["+" , "-" ]:
99
+ return o (r )
100
+ else :
101
+ raise RuntimeError (f"Unknown default restriction { r } on domain { domain } " )
102
+ else :
103
+ # conventional "+" default:
104
+ return o (self .default_restriction )
88
105
else :
89
106
return o
90
107
@@ -93,12 +110,26 @@ def _opposite(self, o):
93
110
94
111
If the current restriction is different swap the sign, require a side to be set.
95
112
"""
96
- if self .current_restriction is None :
97
- raise ValueError (f"Discontinuous type { o ._ufl_class_ .__name__ } must be restricted." )
98
- elif self .current_restriction == self .default_restriction :
99
- return o (self .default_restriction )
113
+ if isinstance (self .default_restriction , dict ):
114
+ domain = extract_unique_domain (o , expand_mixed_mesh = False )
115
+ if isinstance (domain , MixedMesh ):
116
+ raise RuntimeError (f"Not expecting a terminal object on a mixed mesh at this stage: found { repr (o )} " )
117
+ if domain not in self .default_restriction :
118
+ raise RuntimeError (f"Integral type on { domain } not known" )
119
+ r = self .default_restriction [domain ]
100
120
else :
101
- return - o (self .default_restriction )
121
+ r = self .default_restriction
122
+ if r is None :
123
+ if self .current_restriction is not None :
124
+ raise ValueError (f"Expecting current_restriction None: got { self .current_restriction } " )
125
+ return o
126
+ else :
127
+ if self .current_restriction is None :
128
+ raise ValueError (f"Discontinuous type { o ._ufl_class_ .__name__ } must be restricted." )
129
+ elif self .current_restriction == r :
130
+ return o (self .default_restriction )
131
+ else :
132
+ return - o (self .default_restriction )
102
133
103
134
def _missing_rule (self , o ):
104
135
"""Raise an error."""
@@ -206,7 +237,7 @@ def facet_normal(self, o):
206
237
return self ._require_restriction (o )
207
238
208
239
209
- def apply_restrictions (expression , assume_single_integral_type = True , apply_default = True ):
240
+ def apply_restrictions (expression , assume_single_integral_type = True , apply_default = True , default_restriction = None ):
210
241
"""Propagate restriction nodes to wrap differential terminals directly."""
211
242
if assume_single_integral_type :
212
243
integral_types = [
@@ -217,7 +248,7 @@ def apply_restrictions(expression, assume_single_integral_type=True, apply_defau
217
248
# the integral type of a given function; e.g., the former can be
218
249
# ``exterior_facet`` and the latter ``interior_facet``.
219
250
integral_types = None
220
- rules = RestrictionPropagator (assume_single_integral_type = assume_single_integral_type , apply_default = apply_default )
251
+ rules = RestrictionPropagator (assume_single_integral_type = assume_single_integral_type , apply_default = apply_default , default_restriction = default_restriction )
221
252
if isinstance (expression , FormData ):
222
253
for integral_data in expression .integral_data :
223
254
integral_data .integrals = tuple (
@@ -347,15 +378,28 @@ def to_be_restricted(self, o):
347
378
return mt
348
379
elif integral_type == "exterior_facet" :
349
380
return SingleValueRestricted (mt )
350
- elif integral_type == "interial_facet " :
381
+ elif integral_type == "interior_facet " :
351
382
return PositiveRestricted (mt )
352
383
else :
353
384
raise RuntimeError (f"Unknown integral type: { integral_type } " )
354
385
355
386
356
387
def replace_to_be_restricted (integral_data ):
357
388
new_integrals = []
358
- rule = ToBeRestrectedReplacer (integral_data .domain_integral_type_map )
389
+ #rule = ToBeRestrectedReplacer(integral_data.domain_integral_type_map)
390
+ rule = RestrictionPropagator (
391
+ side = None ,
392
+ assume_single_integral_type = False ,
393
+ apply_default = True ,
394
+ default_restriction = {
395
+ domain : {
396
+ "cell" : None ,
397
+ "exterior_facet" : None ,
398
+ "interior_facet" : "+" ,
399
+ }[integral_type ]
400
+ for domain , integral_type in integral_data .domain_integral_type_map .items ()
401
+ },
402
+ )
359
403
for integral in integral_data .integrals :
360
404
integrand = map_expr_dag (rule , integral .integrand ())
361
405
new_integrals .append (integral .reconstruct (integrand = integrand ))
0 commit comments