Skip to content

Commit 8109cce

Browse files
authored
Allow ufl.real/imag/conj on vectors and tensors (#496)
* Let conj, real and imag be applied on vector quantities * Add test for vector operations on real/imag/conj
1 parent c97a24f commit 8109cce

File tree

2 files changed

+57
-9
lines changed

2 files changed

+57
-9
lines changed

ffcx/ir/analysis/reconstruct.py

+7-9
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,10 @@ def handle_conditional(o, ops):
3434
return symbols
3535

3636

37-
def handle_conj(o, ops):
38-
if len(ops) != 1:
39-
raise RuntimeError("Expecting one operand")
40-
if o.ufl_shape != ():
41-
raise RuntimeError("Expecting scalar.")
42-
return [o._ufl_expr_reconstruct_(x) for x in ops[0]]
37+
def handle_elementwise_unary(o, ops):
38+
if len(ops) > 1:
39+
raise RuntimeError("Expecting unary operator.")
40+
return [o._ufl_expr_reconstruct_(op) for op in ops[0]]
4341

4442

4543
def handle_division(o, ops):
@@ -142,16 +140,16 @@ def handle_index_sum(o, ops):
142140
ufl.classes.Abs: handle_scalar_nary,
143141
ufl.classes.MinValue: handle_scalar_nary,
144142
ufl.classes.MaxValue: handle_scalar_nary,
145-
ufl.classes.Real: handle_scalar_nary,
146-
ufl.classes.Imag: handle_scalar_nary,
143+
ufl.classes.Real: handle_elementwise_unary,
144+
ufl.classes.Imag: handle_elementwise_unary,
147145
ufl.classes.Power: handle_scalar_nary,
148146
ufl.classes.BesselFunction: handle_scalar_nary,
149147
ufl.classes.Atan2: handle_scalar_nary,
150148
ufl.classes.Product: handle_product,
151149
ufl.classes.Division: handle_division,
152150
ufl.classes.Sum: handle_sum,
153151
ufl.classes.IndexSum: handle_index_sum,
154-
ufl.classes.Conj: handle_conj,
152+
ufl.classes.Conj: handle_elementwise_unary,
155153
ufl.classes.Conditional: handle_conditional,
156154
ufl.classes.Condition: handle_condition}
157155

test/test_jit_forms.py

+50
Original file line numberDiff line numberDiff line change
@@ -645,3 +645,53 @@ def test_prism(compile_args):
645645
ffi.cast('double *', coords.ctypes.data), ffi.NULL, ffi.NULL)
646646

647647
assert np.isclose(sum(b), 0.5)
648+
649+
650+
def test_complex_operations(compile_args):
651+
mode = "double _Complex"
652+
cell = ufl.triangle
653+
c_element = ufl.VectorElement("Lagrange", cell, 1)
654+
mesh = ufl.Mesh(c_element)
655+
element = ufl.VectorElement("DG", cell, 0)
656+
V = ufl.FunctionSpace(mesh, element)
657+
u = ufl.Coefficient(V)
658+
J1 = ufl.real(u)[0] * ufl.imag(u)[1] * ufl.conj(u)[0] * ufl.dx
659+
J2 = ufl.real(u[0]) * ufl.imag(u[1]) * ufl.conj(u[0]) * ufl.dx
660+
forms = [J1, J2]
661+
662+
compiled_forms, module, code = ffcx.codegeneration.jit.compile_forms(
663+
forms, parameters={'scalar_type': mode}, cffi_extra_compile_args=compile_args)
664+
665+
form0 = compiled_forms[0].integrals(module.lib.cell)[0]
666+
form1 = compiled_forms[1].integrals(module.lib.cell)[0]
667+
668+
ffi = module.ffi
669+
np_type = cdtype_to_numpy(mode)
670+
w1 = np.array([3 + 5j, 8 - 7j], dtype=np_type)
671+
c = np.array([], dtype=np_type)
672+
673+
coords = np.array([[0.0, 0.0, 0.0],
674+
[1.0, 0.0, 0.0],
675+
[0.0, 1.0, 0.0]], dtype=np.float64)
676+
J_1 = np.zeros((1), dtype=np_type)
677+
kernel0 = ffi.cast(f"ufcx_tabulate_tensor_{np_type} *", getattr(form0, f"tabulate_tensor_{np_type}"))
678+
kernel0(ffi.cast('{type} *'.format(type=mode), J_1.ctypes.data),
679+
ffi.cast('{type} *'.format(type=mode), w1.ctypes.data),
680+
ffi.cast('{type} *'.format(type=mode), c.ctypes.data),
681+
ffi.cast('double *', coords.ctypes.data), ffi.NULL, ffi.NULL)
682+
683+
expected_result = np.array([0.5 * np.real(w1[0]) * np.imag(w1[1])
684+
* (np.real(w1[0]) - 1j * np.imag(w1[0]))], dtype=np_type)
685+
assert np.allclose(J_1, expected_result)
686+
687+
J_2 = np.zeros((1), dtype=np_type)
688+
689+
kernel1 = ffi.cast(f"ufcx_tabulate_tensor_{np_type} *", getattr(form1, f"tabulate_tensor_{np_type}"))
690+
kernel1(ffi.cast('{type} *'.format(type=mode), J_2.ctypes.data),
691+
ffi.cast('{type} *'.format(type=mode), w1.ctypes.data),
692+
ffi.cast('{type} *'.format(type=mode), c.ctypes.data),
693+
ffi.cast('double *', coords.ctypes.data), ffi.NULL, ffi.NULL)
694+
695+
assert np.allclose(J_2, expected_result)
696+
697+
assert np.allclose(J_1, J_2)

0 commit comments

Comments
 (0)