Skip to content

Commit 12e585f

Browse files
authored
Fix unique tables (#413)
* Fix bug * Fix unique tables handling
1 parent 97ac1b3 commit 12e585f

File tree

2 files changed

+26
-45
lines changed

2 files changed

+26
-45
lines changed

ffcx/ir/elementtables.py

+4-13
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,6 @@ def build_optimized_tables(
287287
ufl.algorithms.analysis.extract_sub_elements(all_elements))
288288
element_numbers = {element: i for i, element in enumerate(unique_elements)}
289289

290-
tables = existing_tables
291290
mt_tables = {}
292291

293292
for mt in modified_terminals:
@@ -365,20 +364,12 @@ def build_optimized_tables(
365364
tbl = tbl[:1, :, :, :]
366365

367366
# Check for existing identical table
368-
xname_found = False
369-
for xname in tables:
370-
if equal_tables(tbl, tables[xname]):
371-
xname_found = True
367+
for table_name in existing_tables:
368+
if equal_tables(tbl, existing_tables[table_name]):
369+
name = table_name
370+
tbl = existing_tables[name]
372371
break
373372

374-
if xname_found:
375-
name = xname
376-
# Retrieve existing table
377-
tbl = tables[name]
378-
else:
379-
# Store new table
380-
tables[name] = tbl
381-
382373
cell_offset = 0
383374
basix_element = create_element(element)
384375

ffcx/ir/integral.py

+22-32
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,6 @@
99
import itertools
1010
import logging
1111

12-
import numpy
13-
1412
import ufl
1513
from ffcx.ir.analysis.factorization import \
1614
compute_argument_factorization
@@ -83,7 +81,7 @@ def compute_integral_ir(cell, integral_type, entitytype, integrands, argument_sh
8381
for i, v in S.nodes.items()
8482
if is_modified_terminal(v['expression'])}
8583

86-
mt_unique_table_reference = build_optimized_tables(
84+
mt_table_reference = build_optimized_tables(
8785
quadrature_rule,
8886
cell,
8987
integral_type,
@@ -92,18 +90,20 @@ def compute_integral_ir(cell, integral_type, entitytype, integrands, argument_sh
9290
ir["unique_tables"],
9391
rtol=p["table_rtol"],
9492
atol=p["table_atol"])
95-
unique_tables = {v.name: v.values for v in mt_unique_table_reference.values()}
96-
unique_table_types = {v.name: v.ttype for v in mt_unique_table_reference.values()}
93+
94+
# Fetch unique tables for this quadrature rule
95+
table_types = {v.name: v.ttype for v in mt_table_reference.values()}
96+
tables = {v.name: v.values for v in mt_table_reference.values()}
9797

9898
S_targets = [i for i, v in S.nodes.items() if v.get('target', False)]
9999

100-
if 'zeros' in unique_table_types.values() and len(S_targets) == 1:
100+
if 'zeros' in table_types.values() and len(S_targets) == 1:
101101
# If there are any 'zero' tables, replace symbolically and rebuild graph
102102
#
103103
# TODO: Implement zero table elimination for non-scalar graphs
104104
for i, mt in initial_terminals.items():
105105
# Set modified terminals with zero tables to zero
106-
tr = mt_unique_table_reference.get(mt)
106+
tr = mt_table_reference.get(mt)
107107
if tr is not None and tr.ttype == "zeros":
108108
S.nodes[i]['expression'] = ufl.as_ufl(0.0)
109109

@@ -161,12 +161,12 @@ def compute_integral_ir(cell, integral_type, entitytype, integrands, argument_sh
161161
if is_modified_terminal(expr):
162162
mt = analyse_modified_terminal(expr)
163163
F.nodes[i]['mt'] = mt
164-
tr = mt_unique_table_reference.get(mt)
164+
tr = mt_table_reference.get(mt)
165165
if tr is not None:
166166
F.nodes[i]['tr'] = tr
167167

168168
# Attach 'status' to each node: 'inactive', 'piecewise' or 'varying'
169-
analyse_dependencies(F, mt_unique_table_reference)
169+
analyse_dependencies(F, mt_table_reference)
170170

171171
# Output diagnostic graph as pdf
172172
if visualise:
@@ -208,8 +208,8 @@ def compute_integral_ir(cell, integral_type, entitytype, integrands, argument_sh
208208
# Check if each *each* factor corresponding to this argument is piecewise
209209
all_factors_piecewise = all(F.nodes[ifi[0]]["status"] == 'piecewise' for ifi in fi_ci)
210210
block_is_permuted = False
211-
for n in unames:
212-
if unique_tables[n].shape[0] > 1:
211+
for name in unames:
212+
if tables[name].shape[0] > 1:
213213
block_is_permuted = True
214214
ma_data = []
215215
for i, ma in enumerate(ma_indices):
@@ -240,28 +240,18 @@ def compute_integral_ir(cell, integral_type, entitytype, integrands, argument_sh
240240
for mad in blockdata.ma_data:
241241
active_table_names.add(mad.tabledata.name)
242242

243-
# Record all table types before dropping tables
244-
ir["unique_table_types"].update(unique_table_types)
245-
246-
# Drop tables not referenced from modified terminals
247-
# and tables of zeros and ones
248-
unused_ttypes = ("zeros", "ones")
249-
keep_table_names = set()
250243
for name in active_table_names:
251-
ttype = ir["unique_table_types"][name]
252-
if ttype not in unused_ttypes:
253-
if name in unique_tables:
254-
keep_table_names.add(name)
255-
unique_tables = {name: unique_tables[name] for name in keep_table_names}
256-
257-
# Add to global set of all tables
258-
for name, table in unique_tables.items():
259-
tbl = ir["unique_tables"].get(name)
260-
if tbl is not None and not numpy.allclose(
261-
tbl, table, rtol=p["table_rtol"], atol=p["table_atol"]):
262-
raise RuntimeError("Table values mismatch with same name.")
263-
264-
ir["unique_tables"].update(unique_tables)
244+
# Drop tables not referenced from modified terminals
245+
if name not in tables.keys():
246+
del tables[name]
247+
# Drop tables which are inlined
248+
if table_types[name] in ("zeros", "ones"):
249+
del tables[name]
250+
del table_types[name]
251+
252+
# Add tables and types for this quadrature rule to global tables dict
253+
ir["unique_tables"].update(tables)
254+
ir["unique_table_types"].update(table_types)
265255

266256
# Build IR dict for the given expressions
267257
# Store final ir for this num_points

0 commit comments

Comments
 (0)