|
| 1 | +"""Apply coefficient split. |
| 2 | +
|
| 3 | +This module contains classes and functions to split coefficients defined on mixed function spaces. |
| 4 | +""" |
| 5 | + |
| 6 | +import numpy |
| 7 | +from ufl.classes import Restricted |
| 8 | +from ufl.corealg.map_dag import map_expr_dag |
| 9 | +from ufl.corealg.multifunction import MultiFunction, memoized_handler |
| 10 | +from ufl.domain import extract_unique_domain |
| 11 | +from ufl.classes import (Coefficient, Form, ReferenceGrad, ReferenceValue, |
| 12 | + Indexed, MultiIndex, Index, FixedIndex, |
| 13 | + ComponentTensor, ListTensor, Zero, |
| 14 | + NegativeRestricted, PositiveRestricted, SingleValueRestricted, ToBeRestricted) |
| 15 | +from ufl import indices |
| 16 | +from ufl.checks import is_cellwise_constant |
| 17 | +from ufl.tensors import as_tensor |
| 18 | + |
| 19 | + |
| 20 | +class CoefficientSplitter(MultiFunction): |
| 21 | + |
| 22 | + def __init__(self, coefficient_split): |
| 23 | + MultiFunction.__init__(self) |
| 24 | + self._coefficient_split = coefficient_split |
| 25 | + |
| 26 | + expr = MultiFunction.reuse_if_untouched |
| 27 | + |
| 28 | + def modified_terminal(self, o): |
| 29 | + restriction = None |
| 30 | + local_derivatives = 0 |
| 31 | + reference_value = False |
| 32 | + t = o |
| 33 | + while not t._ufl_is_terminal_: |
| 34 | + assert t._ufl_is_terminal_modifier_, f"Got {repr(t)}" |
| 35 | + if isinstance(t, ReferenceValue): |
| 36 | + assert not reference_value, "Got twice pulled back terminal!" |
| 37 | + reference_value = True |
| 38 | + t, = t.ufl_operands |
| 39 | + elif isinstance(t, ReferenceGrad): |
| 40 | + local_derivatives += 1 |
| 41 | + t, = t.ufl_operands |
| 42 | + elif isinstance(t, Restricted): |
| 43 | + assert restriction is None, "Got twice restricted terminal!" |
| 44 | + restriction = t._side |
| 45 | + t, = t.ufl_operands |
| 46 | + elif t._ufl_terminal_modifiers_: |
| 47 | + raise ValueError("Missing handler for terminal modifier type %s, object is %s." % (type(t), repr(t))) |
| 48 | + else: |
| 49 | + raise ValueError("Unexpected type %s object %s." % (type(t), repr(t))) |
| 50 | + if not isinstance(t, Coefficient): |
| 51 | + # Only split coefficients |
| 52 | + return o |
| 53 | + if t not in self._coefficient_split: |
| 54 | + # Only split mixed coefficients |
| 55 | + return o |
| 56 | + # Reference value expected |
| 57 | + assert reference_value |
| 58 | + # Derivative indices |
| 59 | + beta = indices(local_derivatives) |
| 60 | + components = [] |
| 61 | + for subcoeff in self._coefficient_split[t]: |
| 62 | + c = subcoeff |
| 63 | + # Apply terminal modifiers onto the subcoefficient |
| 64 | + if reference_value: |
| 65 | + c = ReferenceValue(c) |
| 66 | + for n in range(local_derivatives): |
| 67 | + # Return zero if expression is trivially constant. This has to |
| 68 | + # happen here because ReferenceGrad has no access to the |
| 69 | + # topological dimension of a literal zero. |
| 70 | + if is_cellwise_constant(c): |
| 71 | + dim = extract_unique_domain(subcoeff).topological_dimension() |
| 72 | + c = Zero(c.ufl_shape + (dim,), c.ufl_free_indices, c.ufl_index_dimensions) |
| 73 | + else: |
| 74 | + c = ReferenceGrad(c) |
| 75 | + if restriction == '+': |
| 76 | + c = PositiveRestricted(c) |
| 77 | + elif restriction == '-': |
| 78 | + c = NegativeRestricted(c) |
| 79 | + elif restriction == '|': |
| 80 | + c = SingleValueRestricted(c) |
| 81 | + elif restriction == '?': |
| 82 | + c = ToBeRestricted(c) |
| 83 | + elif restriction is not None: |
| 84 | + raise RuntimeError(f"Got unknown restriction: {restriction}") |
| 85 | + # Collect components of the subcoefficient |
| 86 | + for alpha in numpy.ndindex(subcoeff.ufl_element().reference_value_shape): |
| 87 | + # New modified terminal: component[alpha + beta] |
| 88 | + components.append(c[alpha + beta]) |
| 89 | + # Repack derivative indices to shape |
| 90 | + c, = indices(1) |
| 91 | + return ComponentTensor(as_tensor(components)[c], MultiIndex((c,) + beta)) |
| 92 | + |
| 93 | + positive_restricted = modified_terminal |
| 94 | + negative_restricted = modified_terminal |
| 95 | + single_value_restricted = modified_terminal |
| 96 | + to_be_restricted = modified_terminal |
| 97 | + reference_grad = modified_terminal |
| 98 | + reference_value = modified_terminal |
| 99 | + terminal = modified_terminal |
| 100 | + |
| 101 | + |
| 102 | +def apply_coefficient_split(expr, coefficient_split): |
| 103 | + """Split mixed coefficients, so mixed elements need not be |
| 104 | + implemented. |
| 105 | +
|
| 106 | + :arg split: A :py:class:`dict` mapping each mixed coefficient to a |
| 107 | + sequence of subcoefficients. If None, calling this |
| 108 | + function is a no-op. |
| 109 | + """ |
| 110 | + if coefficient_split is None: |
| 111 | + return expr |
| 112 | + splitter = CoefficientSplitter(coefficient_split) |
| 113 | + return map_expr_dag(splitter, expr) |
| 114 | + |
| 115 | + |
| 116 | +class FixedIndexRemover(MultiFunction): |
| 117 | + |
| 118 | + def __init__(self, fimap): |
| 119 | + MultiFunction.__init__(self) |
| 120 | + self.fimap = fimap |
| 121 | + self._object_cache = {} |
| 122 | + |
| 123 | + expr = MultiFunction.reuse_if_untouched |
| 124 | + |
| 125 | + @memoized_handler |
| 126 | + def zero(self, o): |
| 127 | + free_indices = [] |
| 128 | + index_dimensions = [] |
| 129 | + for i, d in zip(o.ufl_free_indices, o.ufl_index_dimensions): |
| 130 | + if Index(i) in self.fimap: |
| 131 | + ind_j = self.fimap[Index(i)] |
| 132 | + if not isinstance(ind_j, FixedIndex): |
| 133 | + free_indices.append(ind_j.count()) |
| 134 | + index_dimensions.append(d) |
| 135 | + else: |
| 136 | + free_indices.append(i) |
| 137 | + index_dimensions.append(d) |
| 138 | + return Zero(shape=o.ufl_shape, free_indices=tuple(free_indices), index_dimensions=tuple(index_dimensions)) |
| 139 | + |
| 140 | + @memoized_handler |
| 141 | + def list_tensor(self, o): |
| 142 | + cc = [] |
| 143 | + for o1 in o.ufl_operands: |
| 144 | + comp = map_expr_dag(self, o1) |
| 145 | + cc.append(comp) |
| 146 | + return ListTensor(*cc) |
| 147 | + |
| 148 | + @memoized_handler |
| 149 | + def multi_index(self, o): |
| 150 | + return MultiIndex(tuple(self.fimap.get(i, i) for i in o.indices())) |
| 151 | + |
| 152 | + |
| 153 | +class IndexRemover(MultiFunction): |
| 154 | + |
| 155 | + def __init__(self): |
| 156 | + MultiFunction.__init__(self) |
| 157 | + self._object_cache = {} |
| 158 | + |
| 159 | + expr = MultiFunction.reuse_if_untouched |
| 160 | + |
| 161 | + @memoized_handler |
| 162 | + def _zero_simplify(self, o): |
| 163 | + operand, = o.ufl_operands |
| 164 | + operand = map_expr_dag(self, operand) |
| 165 | + if isinstance(operand, Zero): |
| 166 | + return Zero(shape=o.ufl_shape, free_indices=o.ufl_free_indices, index_dimensions=o.ufl_index_dimensions) |
| 167 | + else: |
| 168 | + return o._ufl_expr_reconstruct_(operand) |
| 169 | + |
| 170 | + @memoized_handler |
| 171 | + def indexed(self, o): |
| 172 | + o1, i1 = o.ufl_operands |
| 173 | + if isinstance(o1, ComponentTensor): |
| 174 | + o2, i2 = o1.ufl_operands |
| 175 | + assert len(i2.indices()) == len(i1.indices()) |
| 176 | + fimap = dict(zip(i2.indices(), i1.indices())) |
| 177 | + rule = FixedIndexRemover(fimap) |
| 178 | + v = map_expr_dag(rule, o2) |
| 179 | + return map_expr_dag(self, v) |
| 180 | + elif isinstance(o1, ListTensor): |
| 181 | + if isinstance(i1[0], FixedIndex): |
| 182 | + o1 = o1.ufl_operands[i1[0]._value] |
| 183 | + if len(i1) > 1: |
| 184 | + i1 = MultiIndex(i1[1:]) |
| 185 | + return map_expr_dag(self, Indexed(o1, i1)) |
| 186 | + else: |
| 187 | + return map_expr_dag(self, o1) |
| 188 | + o1 = map_expr_dag(self, o1) |
| 189 | + return Indexed(o1, i1) |
| 190 | + |
| 191 | + # Do something nicer |
| 192 | + positive_restricted = _zero_simplify |
| 193 | + negative_restricted = _zero_simplify |
| 194 | + single_value_restricted = _zero_simplify |
| 195 | + to_be_restricted = _zero_simplify |
| 196 | + reference_grad = _zero_simplify |
| 197 | + reference_value = _zero_simplify |
| 198 | + |
| 199 | + |
| 200 | +def remove_component_and_list_tensors(o): |
| 201 | + if isinstance(o, Form): |
| 202 | + integrals = [] |
| 203 | + for integral in o.integrals(): |
| 204 | + integrand = remove_component_and_list_tensors(integral.integrand()) |
| 205 | + if not isinstance(integrand, Zero): |
| 206 | + integrals.append(integral.reconstruct(integrand=integrand)) |
| 207 | + return o._ufl_expr_reconstruct_(integrals) |
| 208 | + else: |
| 209 | + rule = IndexRemover() |
| 210 | + return map_expr_dag(rule, o) |
0 commit comments