Skip to content

Commit

Permalink
api: make sure reconstructed factor are None
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed Jul 17, 2024
1 parent efe4344 commit 546b61d
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 19 deletions.
12 changes: 8 additions & 4 deletions devito/ir/equations/equation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
Stencil, detect_io, detect_accesses)
from devito.symbolics import IntDiv, uxreplace
from devito.tools import Pickable, Tag, frozendict
from devito.types import Eq, Inc, ReduceMax, ReduceMin
from devito.types import Eq, Inc, ReduceMax, ReduceMin, relational_min

__all__ = ['LoweredEq', 'ClusterizedEq', 'DummyEq', 'OpInc', 'OpMin', 'OpMax']

Expand Down Expand Up @@ -192,11 +192,15 @@ def __new__(cls, *args, **kwargs):
conditionals[d] = GuardFactor(d)
else:
cond = diff2sympy(lower_exprs(d.condition))
if d.factor is not None:
if d._factor is not None:
cond = sympy.And(cond, GuardFactor(d))
conditionals[d] = cond
if d.factor is not None:
expr = uxreplace(expr, {d: IntDiv(d.fact_index, d.factor)})
# Replace dimension with index
index = d.index
if d.condition is not None:
index = index - relational_min(d.condition, d.parent)
expr = uxreplace(expr, {d: IntDiv(index, d.factor)})

conditionals = frozendict(conditionals)

# Lower all Differentiable operations into SymPy operations
Expand Down
15 changes: 1 addition & 14 deletions devito/types/dimension.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from devito.types.args import ArgProvider
from devito.types.basic import Symbol, DataSymbol, Scalar
from devito.types.constant import Constant
from devito.types.relational import relational_min


__all__ = ['Dimension', 'SpaceDimension', 'TimeDimension', 'DefaultDimension',
Expand Down Expand Up @@ -866,7 +865,7 @@ def __init_finalize__(self, name, parent=None, factor=None, condition=None,
super().__init_finalize__(name, parent)

# Always make the factor symbolic to allow overrides with different factor.
if factor is None:
if factor is None or factor == 1:
self._factor = None
elif is_integer(factor):
self._factor = Constant(name="%sf" % name, value=factor, dtype=np.int32)
Expand Down Expand Up @@ -895,18 +894,6 @@ def condition(self):
def indirect(self):
return self._indirect

@property
def fact_index(self):
if self.condition is None or self._factor is None:
return self.index

# This is the corner case where both a condition and a factor are provided
# the index will need to be `self.parent - min(self.condition)` to avoid
# shifted indexing. E.g if you have `factor=2` and `condition=Ge(time, 10)`
# then the lowered index needs to be `(time - 10)/ 2`
ltkn = relational_min(self.condition, self.parent)
return self.index - ltkn

@cached_property
def free_symbols(self):
retval = set(super().free_symbols)
Expand Down
3 changes: 2 additions & 1 deletion devito/types/relational.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,8 @@ def relational_min(expr, s):
- if `expr` is `s < 10`, then the minimum valid value for `s` is 0
- if `expr` is `s >= 10`, then the minimum valid value for `s` is 10
"""
assert expr.has(s), "Symbol %s not found in expression %s" % (s, expr)
if not expr.has(s):
return 0

return _relational_min(expr, s)

Expand Down

0 comments on commit 546b61d

Please sign in to comment.