Skip to content

Commit 6cf0882

Browse files
jorgensdmscroggs
andauthored
Start adding some typing and documentation (#751)
* Start adding some typing and documentation * Type hints + various improvements and one bugfix * Revert | to typing unions * Fix typing * Add flag * More legacy typesetting * Ruff format * Fixes * Apply suggestions from code review Co-authored-by: Matthew Scroggs <[email protected]> --------- Co-authored-by: Matthew Scroggs <[email protected]>
1 parent bb5bbbf commit 6cf0882

15 files changed

+140
-73
lines changed

ffcx/analysis.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,12 @@ class UFLData(typing.NamedTuple):
4141

4242
def analyze_ufl_objects(
4343
ufl_objects: list[
44-
ufl.form.Form
45-
| ufl.AbstractFiniteElement
46-
| ufl.Mesh
47-
| tuple[ufl.core.expr.Expr, npt.NDArray[np.floating]]
44+
typing.Union[
45+
ufl.form.Form,
46+
ufl.AbstractFiniteElement,
47+
ufl.Mesh,
48+
tuple[ufl.core.expr.Expr, npt.NDArray[np.floating]],
49+
]
4850
],
4951
scalar_type: npt.DTypeLike,
5052
) -> UFLData:
@@ -246,7 +248,7 @@ def _analyze_form(
246248

247249

248250
def _has_custom_integrals(
249-
o: ufl.integral.Integral | ufl.classes.Form | list | tuple,
251+
o: typing.Union[ufl.integral.Integral, ufl.classes.Form, list, tuple],
250252
) -> bool:
251253
"""Check for custom integrals."""
252254
if isinstance(o, ufl.integral.Integral):

ffcx/codegeneration/C/expressions.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from __future__ import annotations
99

1010
import logging
11+
import typing
1112

1213
import numpy as np
1314

@@ -38,7 +39,7 @@ def generator(ir: ExpressionIR, options):
3839
backend = FFCXBackend(ir, options)
3940
eg = ExpressionGenerator(ir, backend)
4041

41-
d: dict[str, str | int] = {}
42+
d: dict[str, typing.Union[str, int]] = {}
4243
d["name_from_uflfile"] = ir.name_from_uflfile
4344
d["factory_name"] = factory_name
4445
parts = eg.generate()

ffcx/codegeneration/C/form.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from __future__ import annotations
1414

1515
import logging
16+
import typing
1617

1718
import numpy as np
1819

@@ -28,7 +29,7 @@ def generator(ir: FormIR, options):
2829
logger.info(f"--- rank: {ir.rank}")
2930
logger.info(f"--- name: {ir.name}")
3031

31-
d: dict[str, int | str] = {}
32+
d: dict[str, typing.Union[int, str]] = {}
3233
d["factory_name"] = ir.name
3334
d["name_from_uflfile"] = ir.name_from_uflfile
3435
d["signature"] = f'"{ir.signature}"'

ffcx/codegeneration/access.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import ufl
1414

1515
import ffcx.codegeneration.lnodes as L
16+
from ffcx.definitions import entity_types
1617
from ffcx.ir.analysis.modified_terminals import ModifiedTerminal
1718
from ffcx.ir.elementtables import UniqueTableReferenceT
1819
from ffcx.ir.representationutils import QuadratureRule
@@ -23,7 +24,9 @@
2324
class FFCXBackendAccess:
2425
"""FFCx specific formatter class."""
2526

26-
def __init__(self, entity_type: str, integral_type: str, symbols, options):
27+
entity_type: entity_types
28+
29+
def __init__(self, entity_type: entity_types, integral_type: str, symbols, options):
2730
"""Initialise."""
2831
# Store ir and options
2932
self.entity_type = entity_type
@@ -88,6 +91,8 @@ def coefficient(
8891

8992
num_dofs = tabledata.values.shape[3]
9093
begin = tabledata.offset
94+
assert begin is not None
95+
assert tabledata.block_size is not None
9196
end = begin + tabledata.block_size * (num_dofs - 1) + 1
9297

9398
if ttype == "ones" and (end - begin) == 1:
@@ -406,7 +411,7 @@ def _pass(self, *args, **kwargs):
406411
def table_access(
407412
self,
408413
tabledata: UniqueTableReferenceT,
409-
entity_type: str,
414+
entity_type: entity_types,
410415
restriction: str,
411416
quadrature_index: L.MultiIndex,
412417
dof_index: L.MultiIndex,
@@ -415,7 +420,7 @@ def table_access(
415420
416421
Args:
417422
tabledata: Table data object
418-
entity_type: Entity type ("cell", "facet", "vertex")
423+
entity_type: Entity type
419424
restriction: Restriction ("+", "-")
420425
quadrature_index: Quadrature index
421426
dof_index: Dof index
@@ -446,6 +451,7 @@ def table_access(
446451
], symbols
447452
else:
448453
FE = []
454+
assert tabledata.tensor_factors is not None
449455
for i in range(dof_index.dim):
450456
factor = tabledata.tensor_factors[i]
451457
iq_i = quadrature_index.local_index(i)

ffcx/codegeneration/backend.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77

88
from __future__ import annotations
99

10+
import typing
11+
1012
from ffcx.codegeneration.access import FFCXBackendAccess
1113
from ffcx.codegeneration.definitions import FFCXBackendDefinitions
1214
from ffcx.codegeneration.symbols import FFCXBackendSymbols
@@ -16,7 +18,7 @@
1618
class FFCXBackend:
1719
"""Class collecting all aspects of the FFCx backend."""
1820

19-
def __init__(self, ir: IntegralIR | ExpressionIR, options):
21+
def __init__(self, ir: typing.Union[IntegralIR, ExpressionIR], options):
2022
"""Initialise."""
2123
coefficient_numbering = ir.expression.coefficient_numbering
2224
coefficient_offsets = ir.expression.coefficient_offsets

ffcx/codegeneration/codegeneration.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,9 @@ class CodeBlocks(typing.NamedTuple):
3939
file_post: list[tuple[str, str]]
4040

4141

42-
def generate_code(ir: DataIR, options: dict[str, int | float | npt.DTypeLike]) -> CodeBlocks:
42+
def generate_code(
43+
ir: DataIR, options: dict[str, typing.Union[int, float, npt.DTypeLike]]
44+
) -> CodeBlocks:
4345
"""Generate code blocks from intermediate representation."""
4446
logger.info(79 * "*")
4547
logger.info("Compiler stage 3: Generating code")

ffcx/codegeneration/definitions.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import ufl
1212

1313
import ffcx.codegeneration.lnodes as L
14+
from ffcx.definitions import entity_types
1415
from ffcx.ir.analysis.modified_terminals import ModifiedTerminal
1516
from ffcx.ir.elementtables import UniqueTableReferenceT
1617
from ffcx.ir.representationutils import QuadratureRule
@@ -50,7 +51,9 @@ def create_dof_index(tabledata, dof_index_symbol):
5051
class FFCXBackendDefinitions:
5152
"""FFCx specific code definitions."""
5253

53-
def __init__(self, entity_type: str, integral_type: str, access, options):
54+
entity_type: entity_types
55+
56+
def __init__(self, entity_type: entity_types, integral_type: str, access, options):
5457
"""Initialise."""
5558
# Store ir and options
5659
self.integral_type = integral_type
@@ -130,6 +133,8 @@ def coefficient(
130133
num_dofs = tabledata.values.shape[3]
131134
bs = tabledata.block_size
132135
begin = tabledata.offset
136+
assert bs is not None
137+
assert begin is not None
133138
end = begin + bs * (num_dofs - 1) + 1
134139

135140
if ttype == "zeros":

ffcx/codegeneration/jit.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import sysconfig
1717
import tempfile
1818
import time
19+
import typing
1920
from contextlib import redirect_stdout
2021
from pathlib import Path
2122

@@ -152,7 +153,7 @@ def _compilation_signature(cffi_extra_compile_args, cffi_debug):
152153
def compile_forms(
153154
forms: list[ufl.Form],
154155
options: dict = {},
155-
cache_dir: Path | None = None,
156+
cache_dir: typing.Optional[Path] = None,
156157
timeout: int = 10,
157158
cffi_extra_compile_args: list[str] = [],
158159
cffi_verbose: bool = False,
@@ -231,7 +232,7 @@ def compile_forms(
231232
def compile_expressions(
232233
expressions: list[tuple[ufl.Expr, npt.NDArray[np.floating]]],
233234
options: dict = {},
234-
cache_dir: Path | None = None,
235+
cache_dir: typing.Optional[Path] = None,
235236
timeout: int = 10,
236237
cffi_extra_compile_args: list[str] = [],
237238
cffi_verbose: bool = False,

ffcx/codegeneration/symbols.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import ufl
1111

1212
import ffcx.codegeneration.lnodes as L
13+
from ffcx.definitions import entity_types
1314

1415
logger = logging.getLogger("ffcx")
1516

@@ -95,7 +96,7 @@ def __init__(self, coefficient_numbering, coefficient_offsets, original_constant
9596
# Table for chunk of custom quadrature points (physical coordinates).
9697
self.custom_points_table = L.Symbol("points_chunk", dtype=L.DataType.REAL)
9798

98-
def entity(self, entity_type, restriction):
99+
def entity(self, entity_type: entity_types, restriction):
99100
"""Entity index for lookup in element tables."""
100101
if entity_type == "cell":
101102
# Always 0 for cells (even with restriction)
@@ -176,7 +177,7 @@ def constant_index_access(self, constant, index):
176177
return c[offset + index]
177178

178179
# TODO: Remove this, use table_access instead
179-
def element_table(self, tabledata, entity_type, restriction):
180+
def element_table(self, tabledata, entity_type: entity_types, restriction):
180181
"""Get an element table."""
181182
entity = self.entity(entity_type, restriction)
182183

ffcx/definitions.py

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
"""Module for storing type definitions used in the FFCx code base."""
2+
3+
from typing import Literal
4+
5+
entity_types = Literal["cell", "facet", "vertex"]

ffcx/ir/analysis/graph.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
"""Linearized data structure for the computational graph."""
77

88
import logging
9+
import typing
910

1011
import numpy as np
1112
import ufl
@@ -73,7 +74,7 @@ def build_graph_vertices(expressions, skip_terminal_modifiers=False):
7374
return G
7475

7576

76-
def build_scalar_graph(expression):
77+
def build_scalar_graph(expression) -> ExpressionGraph:
7778
"""Build list representation of expression graph covering the given expressions."""
7879
# Populate with vertices
7980
G = build_graph_vertices([expression], skip_terminal_modifiers=False)
@@ -86,7 +87,7 @@ def build_scalar_graph(expression):
8687
G = build_graph_vertices(scalar_expressions, skip_terminal_modifiers=True)
8788

8889
# Compute graph edges
89-
V_deps = []
90+
V_deps: list[typing.Union[tuple[()], list[int]]] = []
9091
for i, v in G.nodes.items():
9192
expr = v["expression"]
9293
if expr._ufl_is_terminal_ or expr._ufl_is_terminal_modifier_:

0 commit comments

Comments
 (0)