diff --git a/devito/ir/equations/equation.py b/devito/ir/equations/equation.py index e85b3b1845e..487fcc71093 100644 --- a/devito/ir/equations/equation.py +++ b/devito/ir/equations/equation.py @@ -8,7 +8,7 @@ Stencil, detect_io, detect_accesses) from devito.symbolics import IntDiv, uxreplace from devito.tools import Pickable, Tag, frozendict -from devito.types import Eq, Inc, ReduceMax, ReduceMin +from devito.types import Eq, Inc, ReduceMax, ReduceMin, relational_min __all__ = ['LoweredEq', 'ClusterizedEq', 'DummyEq', 'OpInc', 'OpMin', 'OpMax'] @@ -192,11 +192,15 @@ def __new__(cls, *args, **kwargs): conditionals[d] = GuardFactor(d) else: cond = diff2sympy(lower_exprs(d.condition)) - if d.factor is not None: + if d._factor is not None: cond = sympy.And(cond, GuardFactor(d)) conditionals[d] = cond - if d.factor is not None: - expr = uxreplace(expr, {d: IntDiv(d.fact_index, d.factor)}) + # Replace dimension with index + index = d.index + if d.condition is not None: + index = index - relational_min(d.condition, d.parent) + expr = uxreplace(expr, {d: IntDiv(index, d.factor)}) + conditionals = frozendict(conditionals) # Lower all Differentiable operations into SymPy operations diff --git a/devito/types/dimension.py b/devito/types/dimension.py index b6deeb32bde..6dd0d728166 100644 --- a/devito/types/dimension.py +++ b/devito/types/dimension.py @@ -13,7 +13,6 @@ from devito.types.args import ArgProvider from devito.types.basic import Symbol, DataSymbol, Scalar from devito.types.constant import Constant -from devito.types.relational import relational_min __all__ = ['Dimension', 'SpaceDimension', 'TimeDimension', 'DefaultDimension', @@ -866,7 +865,7 @@ def __init_finalize__(self, name, parent=None, factor=None, condition=None, super().__init_finalize__(name, parent) # Always make the factor symbolic to allow overrides with different factor. - if factor is None: + if factor is None or factor == 1: self._factor = None elif is_integer(factor): self._factor = Constant(name="%sf" % name, value=factor, dtype=np.int32) @@ -895,18 +894,6 @@ def condition(self): def indirect(self): return self._indirect - @property - def fact_index(self): - if self.condition is None or self._factor is None: - return self.index - - # This is the corner case where both a condition and a factor are provided - # the index will need to be `self.parent - min(self.condition)` to avoid - # shifted indexing. E.g if you have `factor=2` and `condition=Ge(time, 10)` - # then the lowered index needs to be `(time - 10)/ 2` - ltkn = relational_min(self.condition, self.parent) - return self.index - ltkn - @cached_property def free_symbols(self): retval = set(super().free_symbols) diff --git a/devito/types/relational.py b/devito/types/relational.py index a7737a0023a..36b208653d3 100644 --- a/devito/types/relational.py +++ b/devito/types/relational.py @@ -218,7 +218,8 @@ def relational_min(expr, s): - if `expr` is `s < 10`, then the minimum valid value for `s` is 0 - if `expr` is `s >= 10`, then the minimum valid value for `s` is 10 """ - assert expr.has(s), "Symbol %s not found in expression %s" % (s, expr) + if not expr.has(s): + return 0 return _relational_min(expr, s)