Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use float type that matches scalar type #502

Merged
merged 18 commits into from
Jun 24, 2022
Prev Previous commit
Next Next commit
Tidy up
garth-wells committed Jun 23, 2022
commit db95ef1718a5a9c13aaa9c37900c08e48763e7c2
4 changes: 3 additions & 1 deletion ffcx/codegeneration/definitions.py
Original file line number Diff line number Diff line change
@@ -204,9 +204,11 @@ def jacobian(self, e, mt, tabledata, quadrature_rule, access):
if mt.restriction == "-":
offset = num_scalar_dofs * dim

value_type = scalar_to_value_type(self.parameters["scalar_type"])

code = []
body = [L.AssignAdd(access, dof_access[ic * dim + begin + offset] * FE[ic])]
code += [L.VariableDecl("double", access, 0.0)]
code += [L.VariableDecl(f"{value_type}", access, 0.0)]
code += [L.ForRange(ic, 0, num_scalar_dofs, body)]

return [], code
28 changes: 11 additions & 17 deletions ffcx/codegeneration/integrals.py
Original file line number Diff line number Diff line change
@@ -178,24 +178,24 @@ def generate(self):

parts = []
scalar_type = self.backend.access.parameters["scalar_type"]
float_type = scalar_to_value_type(scalar_type)
value_type = scalar_to_value_type(scalar_type)
alignment = self.ir.params['assume_aligned']
if alignment != -1:
scalar_type = self.backend.access.parameters["scalar_type"]
parts += [L.VerbatimStatement(f"A = ({scalar_type}*)__builtin_assume_aligned(A, {alignment});"),
L.VerbatimStatement(f"w = (const {scalar_type}*)__builtin_assume_aligned(w, {alignment});"),
L.VerbatimStatement(f"c = (const {scalar_type}*)__builtin_assume_aligned(c, {alignment});"),
L.VerbatimStatement(f"coordinate_dofs = (const {float_type}*)__builtin_assume_aligned(coordinate_dofs, {alignment});")] # noqa
L.VerbatimStatement(f"coordinate_dofs = (const {value_type}*)__builtin_assume_aligned(coordinate_dofs, {alignment});")] # noqa

# Generate the tables of quadrature points and weights
parts += self.generate_quadrature_tables(float_type)
parts += self.generate_quadrature_tables(value_type)

# Generate the tables of basis function values and
# pre-integrated blocks
parts += self.generate_element_tables(float_type)
parts += self.generate_element_tables(value_type)

# Generate the tables of geometry data that are needed
parts += self.generate_geometry_tables(float_type)
parts += self.generate_geometry_tables(value_type)

# Loop generation code will produce parts to go before
# quadloops, to define the quadloops, and to go after the
@@ -226,7 +226,7 @@ def generate(self):

return L.StatementList(parts)

def generate_quadrature_tables(self, float_type):
def generate_quadrature_tables(self, value_type: str):
"""Generate static tables of quadrature points and weights."""
L = self.backend.language

@@ -246,7 +246,7 @@ def generate_quadrature_tables(self, float_type):

# Generate quadrature weights array
wsym = self.backend.symbols.weights_table(quadrature_rule)
parts += [L.ArrayDecl(f"static const {float_type}", wsym, num_points,
parts += [L.ArrayDecl(f"static const {value_type}", wsym, num_points,
quadrature_rule.weights, padlen=padlen)]

# Add leading comment if there are any tables
@@ -288,12 +288,9 @@ def generate_element_tables(self, float_type: str):
"""Generate static tables with precomputed element basisfunction values in quadrature points."""
L = self.backend.language
parts = []

tables = self.ir.unique_tables
table_types = self.ir.unique_table_types

padlen = self.ir.params["padlen"]

if self.ir.integral_type in ufl.custom_integral_types:
# Define only piecewise tables
table_names = [name for name in sorted(tables) if table_types[name] in piecewise_ttypes]
@@ -308,20 +305,18 @@ def generate_element_tables(self, float_type: str):
# Add leading comment if there are any tables
parts = L.commented_code_list(parts, [
"Precomputed values of basis functions and precomputations",
"FE* dimensions: [permutation][entities][points][dofs]",
])
"FE* dimensions: [permutation][entities][points][dofs]"])
return parts

def declare_table(self, name, table, padlen, float_type: str):
def declare_table(self, name, table, padlen, value_type: str):
"""Declare a table.

If the dof dimensions of the table have dof rotations, apply
these rotations.

"""
L = self.backend.language

return [L.ArrayDecl(f"static const {float_type}", name, table.shape, table, padlen=padlen)]
return [L.ArrayDecl(f"static const {value_type}", name, table.shape, table, padlen=padlen)]

def generate_quadrature_loop(self, quadrature_rule: QuadratureRule):
"""Generate quadrature loop with for this quadrature_rule."""
@@ -333,8 +328,7 @@ def generate_quadrature_loop(self, quadrature_rule: QuadratureRule):

# Generate dofblock parts, some of this will be placed before or
# after quadloop
preparts, quadparts = \
self.generate_dofblock_partition(quadrature_rule)
preparts, quadparts = self.generate_dofblock_partition(quadrature_rule)
body += quadparts

# Wrap body in loop or scope