Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unify Expression and IntegralIR with a common base #680

Merged
merged 11 commits into from
May 21, 2024
60 changes: 31 additions & 29 deletions ffcx/codegeneration/C/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
"""Generate UFC code for an expression."""

from __future__ import annotations

import logging

import numpy as np
Expand All @@ -14,17 +16,19 @@
from ffcx.codegeneration.C.c_implementation import CFormatter
from ffcx.codegeneration.expression_generator import ExpressionGenerator
from ffcx.codegeneration.utils import dtype_to_c_type, dtype_to_scalar_dtype
from ffcx.ir.representation import ExpressionIR

logger = logging.getLogger("ffcx")


def generator(ir, options):
def generator(ir: ExpressionIR, options):
"""Generate UFC code for an expression."""
logger.info("Generating code for expression:")
logger.info(f"--- points: {ir.points}")
logger.info(f"--- name: {ir.name}")

factory_name = ir.name
assert len(ir.expression.integrand) == 1, "Expressions only support single quadrature rule"
points = next(iter(ir.expression.integrand)).points
logger.info(f"--- points: {points}")
factory_name = ir.expression.name
logger.info(f"--- name: {factory_name}")

# Format declaration
declaration = expressions_template.declaration.format(
Expand All @@ -34,58 +38,56 @@ def generator(ir, options):
backend = FFCXBackend(ir, options)
eg = ExpressionGenerator(ir, backend)

d = {}
d: dict[str, str | int] = {}
d["name_from_uflfile"] = ir.name_from_uflfile
d["factory_name"] = ir.name

d["factory_name"] = factory_name
parts = eg.generate()

CF = CFormatter(options["scalar_type"])
d["tabulate_expression"] = CF.c_format(parts)

if len(ir.original_coefficient_positions) > 0:
d["original_coefficient_positions"] = f"original_coefficient_positions_{ir.name}"
d["original_coefficient_positions"] = f"original_coefficient_positions_{factory_name}"
values = ", ".join(str(i) for i in ir.original_coefficient_positions)
sizes = len(ir.original_coefficient_positions)
d["original_coefficient_positions_init"] = (
f"static int original_coefficient_positions_{ir.name}[{sizes}] = {{{values}}};"
f"static int original_coefficient_positions_{factory_name}[{sizes}] = {{{values}}};"
)
else:
d["original_coefficient_positions"] = "NULL"
d["original_coefficient_positions_init"] = ""

values = ", ".join(str(p) for p in ir.points.flatten())
sizes = ir.points.size
d["points_init"] = f"static double points_{ir.name}[{sizes}] = {{{values}}};"
d["points"] = f"points_{ir.name}"
values = ", ".join(str(p) for p in points.flatten())
sizes = points.size
d["points_init"] = f"static double points_{factory_name}[{sizes}] = {{{values}}};"
d["points"] = f"points_{factory_name}"

if len(ir.expression_shape) > 0:
values = ", ".join(str(i) for i in ir.expression_shape)
sizes = len(ir.expression_shape)
d["value_shape_init"] = f"static int value_shape_{ir.name}[{sizes}] = {{{values}}};"
d["value_shape"] = f"value_shape_{ir.name}"
if len(ir.expression.shape) > 0:
values = ", ".join(str(i) for i in ir.expression.shape)
sizes = len(ir.expression.shape)
d["value_shape_init"] = f"static int value_shape_{factory_name}[{sizes}] = {{{values}}};"
d["value_shape"] = f"value_shape_{factory_name}"
else:
d["value_shape_init"] = ""
d["value_shape"] = "NULL"

d["num_components"] = len(ir.expression_shape)
d["num_coefficients"] = len(ir.coefficient_numbering)
d["num_components"] = len(ir.expression.shape)
d["num_coefficients"] = len(ir.expression.coefficient_numbering)
d["num_constants"] = len(ir.constant_names)
d["num_points"] = ir.points.shape[0]
d["entity_dimension"] = ir.points.shape[1]
d["num_points"] = points.shape[0]
d["entity_dimension"] = points.shape[1]
d["scalar_type"] = dtype_to_c_type(options["scalar_type"])
d["geom_type"] = dtype_to_c_type(dtype_to_scalar_dtype(options["scalar_type"]))
d["np_scalar_type"] = np.dtype(options["scalar_type"]).name

d["rank"] = len(ir.tensor_shape)
d["rank"] = len(ir.expression.tensor_shape)

if len(ir.coefficient_names) > 0:
values = ", ".join(f'"{name}"' for name in ir.coefficient_names)
sizes = len(ir.coefficient_names)
d["coefficient_names_init"] = (
f"static const char* coefficient_names_{ir.name}[{sizes}] = {{{values}}};"
f"static const char* coefficient_names_{factory_name}[{sizes}] = {{{values}}};"
)
d["coefficient_names"] = f"coefficient_names_{ir.name}"
d["coefficient_names"] = f"coefficient_names_{factory_name}"
else:
d["coefficient_names_init"] = ""
d["coefficient_names"] = "NULL"
Expand All @@ -94,9 +96,9 @@ def generator(ir, options):
values = ", ".join(f'"{name}"' for name in ir.constant_names)
sizes = len(ir.constant_names)
d["constant_names_init"] = (
f"static const char* constant_names_{ir.name}[{sizes}] = {{{values}}};"
f"static const char* constant_names_{factory_name}[{sizes}] = {{{values}}};"
)
d["constant_names"] = f"constant_names_{ir.name}"
d["constant_names"] = f"constant_names_{factory_name}"
else:
d["constant_names_init"] = ""
d["constant_names"] = "NULL"
Expand Down
17 changes: 10 additions & 7 deletions ffcx/codegeneration/C/form.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,40 +10,43 @@
# old implementation in FFC
"""Generate UFC code for a form."""

from __future__ import annotations

import logging

import numpy as np

from ffcx.codegeneration.C import form_template
from ffcx.ir.representation import FormIR

logger = logging.getLogger("ffcx")


def generator(ir, options):
def generator(ir: FormIR, options):
"""Generate UFC code for a form."""
logger.info("Generating code for form:")
logger.info(f"--- rank: {ir.rank}")
logger.info(f"--- name: {ir.name}")

d = {}
d: dict[str, int | str] = {}
d["factory_name"] = ir.name
d["name_from_uflfile"] = ir.name_from_uflfile
d["signature"] = f'"{ir.signature}"'
d["rank"] = ir.rank
d["num_coefficients"] = ir.num_coefficients
d["num_constants"] = ir.num_constants

if len(ir.original_coefficient_position) > 0:
values = ", ".join(str(i) for i in ir.original_coefficient_position)
sizes = len(ir.original_coefficient_position)
if len(ir.original_coefficient_positions) > 0:
values = ", ".join(str(i) for i in ir.original_coefficient_positions)
sizes = len(ir.original_coefficient_positions)

d["original_coefficient_position_init"] = (
f"int original_coefficient_position_{ir.name}[{sizes}] = {{{values}}};"
)
d["original_coefficient_position"] = f"original_coefficient_position_{ir.name}"
d["original_coefficient_positions"] = f"original_coefficient_position_{ir.name}"
else:
d["original_coefficient_position_init"] = ""
d["original_coefficient_position"] = "NULL"
d["original_coefficient_positions"] = "NULL"

if len(ir.coefficient_names) > 0:
values = ", ".join(f'"{name}"' for name in ir.coefficient_names)
Expand Down
2 changes: 1 addition & 1 deletion ffcx/codegeneration/C/form_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
.rank = {rank},
.num_coefficients = {num_coefficients},
.num_constants = {num_constants},
.original_coefficient_position = {original_coefficient_position},
.original_coefficient_positions = {original_coefficient_positions},

.coefficient_name_map = {coefficient_names},
.constant_name_map = {constant_names},
Expand Down
12 changes: 6 additions & 6 deletions ffcx/codegeneration/C/integrals.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@
def generator(ir: IntegralIR, options):
"""Generate C code for an integral."""
logger.info("Generating code for integral:")
logger.info(f"--- type: {ir.integral_type}")
logger.info(f"--- name: {ir.name}")
logger.info(f"--- type: {ir.expression.integral_type}")
logger.info(f"--- name: {ir.expression.name}")

"""Generate code for an integral."""
factory_name = ir.name
factory_name = ir.expression.name

# Format declaration
declaration = ufcx_integrals.declaration.format(factory_name=factory_name)
Expand All @@ -51,9 +51,9 @@ def generator(ir: IntegralIR, options):
values = ", ".join("1" if i else "0" for i in ir.enabled_coefficients)
sizes = len(ir.enabled_coefficients)
code["enabled_coefficients_init"] = (
f"bool enabled_coefficients_{ir.name}[{sizes}] = {{{values}}};"
f"bool enabled_coefficients_{ir.expression.name}[{sizes}] = {{{values}}};"
)
code["enabled_coefficients"] = f"enabled_coefficients_{ir.name}"
code["enabled_coefficients"] = f"enabled_coefficients_{ir.expression.name}"
else:
code["enabled_coefficients_init"] = ""
code["enabled_coefficients"] = "NULL"
Expand All @@ -74,7 +74,7 @@ def generator(ir: IntegralIR, options):
enabled_coefficients=code["enabled_coefficients"],
enabled_coefficients_init=code["enabled_coefficients_init"],
tabulate_tensor=code["tabulate_tensor"],
needs_facet_permutations="true" if ir.needs_facet_permutations else "false",
needs_facet_permutations="true" if ir.expression.needs_facet_permutations else "false",
scalar_type=dtype_to_c_type(options["scalar_type"]),
geom_type=dtype_to_c_type(dtype_to_scalar_dtype(options["scalar_type"])),
coordinate_element_hash=f"UINT64_C({element_hash})",
Expand Down
14 changes: 7 additions & 7 deletions ffcx/codegeneration/access.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@
class FFCXBackendAccess:
"""FFCx specific formatter class."""

def __init__(self, ir, symbols, options):
def __init__(self, entity_type: str, integral_type: str, symbols, options):
"""Initialise."""
# Store ir and options
self.entitytype = ir.entitytype
self.integral_type = ir.integral_type
self.entity_type = entity_type
self.integral_type = integral_type
self.symbols = symbols
self.options = options

Expand Down Expand Up @@ -72,7 +72,7 @@ def get(
break

if handler:
return handler(mt, tabledata, quadrature_rule)
return handler(mt, tabledata, quadrature_rule) # type: ignore
else:
raise RuntimeError(f"Not handled: {type(e)}")

Expand Down Expand Up @@ -400,7 +400,7 @@ def _pass(self, *args, **kwargs):
def table_access(
self,
tabledata: UniqueTableReferenceT,
entitytype: str,
entity_type: str,
restriction: str,
quadrature_index: L.MultiIndex,
dof_index: L.MultiIndex,
Expand All @@ -409,12 +409,12 @@ def table_access(

Args:
tabledata: Table data object
entitytype: Entity type ("cell", "facet", "vertex")
entity_type: Entity type ("cell", "facet", "vertex")
restriction: Restriction ("+", "-")
quadrature_index: Quadrature index
dof_index: Dof index
"""
entity = self.symbols.entity(entitytype, restriction)
entity = self.symbols.entity(entity_type, restriction)
iq_global_index = quadrature_index.global_index
ic_global_index = dof_index.global_index
qp = 0 # quadrature permutation
Expand Down
19 changes: 13 additions & 6 deletions ffcx/codegeneration/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,30 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
"""Collection of FFCx specific pieces for the code generation phase."""

from __future__ import annotations

from ffcx.codegeneration.access import FFCXBackendAccess
from ffcx.codegeneration.definitions import FFCXBackendDefinitions
from ffcx.codegeneration.symbols import FFCXBackendSymbols
from ffcx.ir.representation import ExpressionIR, IntegralIR


class FFCXBackend:
"""Class collecting all aspects of the FFCx backend."""

def __init__(self, ir, options):
def __init__(self, ir: IntegralIR | ExpressionIR, options):
"""Initialise."""
coefficient_numbering = ir.coefficient_numbering
coefficient_offsets = ir.coefficient_offsets
coefficient_numbering = ir.expression.coefficient_numbering
coefficient_offsets = ir.expression.coefficient_offsets

original_constant_offsets = ir.original_constant_offsets
original_constant_offsets = ir.expression.original_constant_offsets

self.symbols = FFCXBackendSymbols(
coefficient_numbering, coefficient_offsets, original_constant_offsets
)
self.access = FFCXBackendAccess(ir, self.symbols, options)
self.definitions = FFCXBackendDefinitions(ir, self.access, options)
self.access = FFCXBackendAccess(
ir.expression.entity_type, ir.expression.integral_type, self.symbols, options
)
self.definitions = FFCXBackendDefinitions(
ir.expression.entity_type, ir.expression.integral_type, self.access, options
)
18 changes: 10 additions & 8 deletions ffcx/codegeneration/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,17 +50,14 @@ def create_dof_index(tabledata, dof_index_symbol):
class FFCXBackendDefinitions:
"""FFCx specific code definitions."""

def __init__(self, ir, access, options):
def __init__(self, entity_type: str, integral_type: str, access, options):
"""Initialise."""
# Store ir and options
self.integral_type = ir.integral_type
self.entitytype = ir.entitytype
self.integral_type = integral_type
self.entity_type = entity_type
self.access = access
self.symbols = access.symbols
self.options = options

self.ir = ir

# called, depending on the first argument type.
self.handler_lookup = {
ufl.coefficient.Coefficient: self.coefficient,
Expand All @@ -80,6 +77,11 @@ def __init__(self, ir, access, options):
ufl.geometry.FacetOrientation: self.pass_through,
}

@property
def symbols(self):
"""Return formatter."""
return self.access.symbols

def get(
self,
mt: ModifiedTerminal,
Expand Down Expand Up @@ -141,7 +143,7 @@ def coefficient(
assert begin < end

# Get access to element table
FE, tables = self.access.table_access(tabledata, self.entitytype, mt.restriction, iq, ic)
FE, tables = self.access.table_access(tabledata, self.entity_type, mt.restriction, iq, ic)
dof_access: L.ArrayAccess = self.symbols.coefficient_dof_access(
mt.terminal, (ic.global_index) * bs + begin
)
Expand Down Expand Up @@ -190,7 +192,7 @@ def _define_coordinate_dofs_lincomb(
iq_symbol = self.symbols.quadrature_loop_index
ic = create_dof_index(tabledata, ic_symbol)
iq = create_quadrature_index(quadrature_rule, iq_symbol)
FE, tables = self.access.table_access(tabledata, self.entitytype, mt.restriction, iq, ic)
FE, tables = self.access.table_access(tabledata, self.entity_type, mt.restriction, iq, ic)

dof_access = L.Symbol("coordinate_dofs", dtype=L.DataType.REAL)

Expand Down
Loading
Loading