diff --git a/devito/core/autotuning.py b/devito/core/autotuning.py index cb2209690a..23dac69d61 100644 --- a/devito/core/autotuning.py +++ b/devito/core/autotuning.py @@ -1,7 +1,6 @@ from collections import OrderedDict from itertools import combinations, product from functools import total_ordering -from sympy import SympifyError, sympify from devito.arch import KNL, KNL7210 from devito.ir import Backward, retrieve_iteration_tree @@ -9,6 +8,7 @@ from devito.mpi.distributed import MPI, MPINeighborhood from devito.mpi.routines import MPIMsgEnriched from devito.parameters import configuration +from devito.symbolics import normalize_args from devito.tools import filter_ordered, flatten, is_integer, prod from devito.types import Timer @@ -274,13 +274,7 @@ def calculate_nblocks(tree, blockable): def generate_block_shapes(blockable, args, level): - # Make sure all params are substitutable (ie cannot raise SympifyError) - rargs = {} - for k, v in args.items(): - try: - rargs[k] = sympify(v, strict=True) - except SympifyError: - continue + args = normalize_args(args) if not blockable: raise ValueError @@ -292,7 +286,7 @@ def generate_block_shapes(blockable, args, level): # Generate level-0 block shapes level_0 = [d for d, v in mapper.items() if v == 0] # Max attemptable block shape - max_bs = tuple((d.step, d.symbolic_size.subs(rargs)) for d in level_0) + max_bs = tuple((d.step, d.symbolic_size.subs(args)) for d in level_0) # Defaults (basic mode) ret = [tuple((d.step, v) for d in level_0) for v in options['blocksize-l0']] # Always try the entire iteration space (degenerate block) diff --git a/devito/symbolics/manipulation.py b/devito/symbolics/manipulation.py index 694f85b69e..eb35c68e91 100644 --- a/devito/symbolics/manipulation.py +++ b/devito/symbolics/manipulation.py @@ -2,7 +2,7 @@ from collections.abc import Iterable from functools import singledispatch -from sympy import Pow, Add, Mul, Min, Max +from sympy import Pow, Add, Mul, Min, Max, SympifyError, sympify from sympy.core.add import _addsort from sympy.core.mul import _mulsort @@ -16,7 +16,8 @@ from devito.types.relational import Le, Lt, Gt, Ge __all__ = ['xreplace_indices', 'pow_to_mul', 'indexify', 'subs_op_args', - 'uxreplace', 'Uxmapper', 'reuse_if_untouched', 'evalrel'] + 'normalize_args', 'uxreplace', 'Uxmapper', 'reuse_if_untouched', + 'evalrel'] def uxreplace(expr, rule): @@ -269,6 +270,22 @@ def subs_op_args(expr, args): return expr.subs({i.name: args[i.name] for i in expr.free_symbols if i.name in args}) +def normalize_args(args): + """ + Produce a new `args` dictionary in which only the actually substitutable + arguments are retained. This ensures that none of the subsequent substitutions + will throw `SympifyError` exceptions. + """ + retval = {} + for k, v in args.items(): + try: + retval[k] = sympify(v, strict=True) + except SympifyError: + continue + + return retval + + def reuse_if_untouched(expr, args, evaluate=False): """ Reconstruct `expr` iff any of the provided `args` is different than