From 7af3cc98a0ff9a1e549409dac3342077c9b78938 Mon Sep 17 00:00:00 2001 From: mloubout Date: Wed, 17 Jul 2024 09:40:08 -0400 Subject: [PATCH 1/2] api: Support combination of condition and factor for ConditionalDimension --- devito/ir/equations/equation.py | 7 ++++-- devito/types/dimension.py | 17 ++++++++++++- devito/types/relational.py | 43 ++++++++++++++++++++++++++++++++- tests/test_dimension.py | 29 ++++++++++++++++++++++ 4 files changed, 92 insertions(+), 4 deletions(-) diff --git a/devito/ir/equations/equation.py b/devito/ir/equations/equation.py index 8e478fe1d7..e85b3b1845 100644 --- a/devito/ir/equations/equation.py +++ b/devito/ir/equations/equation.py @@ -191,9 +191,12 @@ def __new__(cls, *args, **kwargs): if d.condition is None: conditionals[d] = GuardFactor(d) else: - conditionals[d] = diff2sympy(lower_exprs(d.condition)) + cond = diff2sympy(lower_exprs(d.condition)) + if d.factor is not None: + cond = sympy.And(cond, GuardFactor(d)) + conditionals[d] = cond if d.factor is not None: - expr = uxreplace(expr, {d: IntDiv(d.index, d.factor)}) + expr = uxreplace(expr, {d: IntDiv(d.fact_index, d.factor)}) conditionals = frozendict(conditionals) # Lower all Differentiable operations into SymPy operations diff --git a/devito/types/dimension.py b/devito/types/dimension.py index 17981587a1..d1109e7c33 100644 --- a/devito/types/dimension.py +++ b/devito/types/dimension.py @@ -13,6 +13,8 @@ from devito.types.args import ArgProvider from devito.types.basic import Symbol, DataSymbol, Scalar from devito.types.constant import Constant +from devito.types.relational import relational_min + __all__ = ['Dimension', 'SpaceDimension', 'TimeDimension', 'DefaultDimension', 'CustomDimension', 'SteppingDimension', 'SubDimension', @@ -954,6 +956,7 @@ def __init_finalize__(self, name, parent=None, factor=None, condition=None, self._factor = factor else: raise ValueError("factor must be an integer or integer Constant") + self._condition = condition self._indirect = indirect @@ -974,6 +977,18 @@ def condition(self): def indirect(self): return self._indirect + @property + def fact_index(self): + if self.condition is None or self._factor is None: + return self.index + + # This is the corner case where both a condition and a factor are provided + # the index will need to be `self.parent - min(self.condition)` to avoid + # shifted indexing. E.g if you have `factor=2` and `condition=Ge(time, 10)` + # then the lowered index needs to be `(time - 10)/ 2` + ltkn = relational_min(self.condition, self.parent) + return self.index - ltkn + @cached_property def free_symbols(self): retval = set(super().free_symbols) @@ -1009,7 +1024,7 @@ def _arg_defaults(self, _min=None, size=None, alias=None): # `factor` endpoints are legal, so we return them all. It's then # up to the caller to decide which one to pick upon reduction dim = alias or self - if dim._factor is None or size is None: + if dim.condition is not None or size is None: return defaults try: # Is it a symbolic factor? diff --git a/devito/types/relational.py b/devito/types/relational.py index 4aec1deb12..a7737a0023 100644 --- a/devito/types/relational.py +++ b/devito/types/relational.py @@ -1,8 +1,9 @@ """User API to specify relationals.""" +from functools import singledispatch import sympy -__all__ = ['Le', 'Lt', 'Ge', 'Gt', 'Ne'] +__all__ = ['Le', 'Lt', 'Ge', 'Gt', 'Ne', 'relational_min'] class AbstractRel: @@ -208,3 +209,43 @@ def __new__(cls, lhs, rhs=0, subdomain=None, **kwargs): ops = {Ge: Lt, Gt: Le, Le: Gt, Lt: Ge} rev = {Ge: Le, Gt: Lt, Lt: Gt, Le: Ge} + + +def relational_min(expr, s): + """ + Infer the minimum valid value for symbol `s` in the expression `expr`. + For example + - if `expr` is `s < 10`, then the minimum valid value for `s` is 0 + - if `expr` is `s >= 10`, then the minimum valid value for `s` is 10 + """ + assert expr.has(s), "Symbol %s not found in expression %s" % (s, expr) + + return _relational_min(expr, s) + + +@singledispatch +def _relational_min(s, expr): + return 0 + + +@_relational_min.register(sympy.And) +def _(expr, s): + return max([_relational_min(e, s) for e in expr.args]) + + +@_relational_min.register(Gt) +@_relational_min.register(Lt) +def _(expr, s): + if s == expr.gts: + return expr.lts + 1 + else: + return 0 + + +@_relational_min.register(Ge) +@_relational_min.register(Le) +def _(expr, s): + if s == expr.gts: + return expr.lts + else: + return 0 diff --git a/tests/test_dimension.py b/tests/test_dimension.py index 18482ca65d..1cf98d2e9a 100644 --- a/tests/test_dimension.py +++ b/tests/test_dimension.py @@ -1013,6 +1013,35 @@ def test_issue_1753(self): op.apply(time_M=1) assert np.all(np.flatnonzero(f.data) == [3, 30]) + def test_issue_2273(self): + grid = Grid(shape=(11, 11)) + time = grid.time_dim + + nt = 200 + bounds = (10, 100) + factor = 5 + + condition = And(Ge(time, bounds[0]), Le(time, bounds[1])) + + time_under = ConditionalDimension(name='timeu', parent=time, + factor=factor, condition=condition) + buffer_size = (bounds[1] - bounds[0] + factor) // factor + 1 + + rec = SparseTimeFunction(name='rec', grid=grid, npoint=1, nt=nt, + coordinates=[(.5, .5)]) + rec.data[:] = 1.0 + + u = TimeFunction(name='u', grid=grid, space_order=2) + usaved = TimeFunction(name='usaved', grid=grid, space_order=2, + time_dim=time_under, save=buffer_size) + + eq = [Eq(u.forward, u)] + rec.inject(field=u.forward, expr=rec) + [Eq(usaved, u)] + + op = Operator(eq) + op(time_m=0, time_M=nt-1) + expected = np.linspace(bounds[0], bounds[1], num=buffer_size-1) + assert np.allclose(usaved.data[:-1, 5, 5], expected) + def test_subsampled_fd(self): """ Test that the FD shortcuts are handled correctly with ConditionalDimensions From a112c16d631bf9981768eb19f960c3e4a9cc0f59 Mon Sep 17 00:00:00 2001 From: mloubout Date: Wed, 17 Jul 2024 10:51:42 -0400 Subject: [PATCH 2/2] api: make sure reconstructed factor are None --- devito/ir/equations/equation.py | 12 ++++++++---- devito/types/dimension.py | 15 +-------------- devito/types/relational.py | 3 ++- 3 files changed, 11 insertions(+), 19 deletions(-) diff --git a/devito/ir/equations/equation.py b/devito/ir/equations/equation.py index e85b3b1845..487fcc7109 100644 --- a/devito/ir/equations/equation.py +++ b/devito/ir/equations/equation.py @@ -8,7 +8,7 @@ Stencil, detect_io, detect_accesses) from devito.symbolics import IntDiv, uxreplace from devito.tools import Pickable, Tag, frozendict -from devito.types import Eq, Inc, ReduceMax, ReduceMin +from devito.types import Eq, Inc, ReduceMax, ReduceMin, relational_min __all__ = ['LoweredEq', 'ClusterizedEq', 'DummyEq', 'OpInc', 'OpMin', 'OpMax'] @@ -192,11 +192,15 @@ def __new__(cls, *args, **kwargs): conditionals[d] = GuardFactor(d) else: cond = diff2sympy(lower_exprs(d.condition)) - if d.factor is not None: + if d._factor is not None: cond = sympy.And(cond, GuardFactor(d)) conditionals[d] = cond - if d.factor is not None: - expr = uxreplace(expr, {d: IntDiv(d.fact_index, d.factor)}) + # Replace dimension with index + index = d.index + if d.condition is not None: + index = index - relational_min(d.condition, d.parent) + expr = uxreplace(expr, {d: IntDiv(index, d.factor)}) + conditionals = frozendict(conditionals) # Lower all Differentiable operations into SymPy operations diff --git a/devito/types/dimension.py b/devito/types/dimension.py index d1109e7c33..ae850f1b34 100644 --- a/devito/types/dimension.py +++ b/devito/types/dimension.py @@ -13,7 +13,6 @@ from devito.types.args import ArgProvider from devito.types.basic import Symbol, DataSymbol, Scalar from devito.types.constant import Constant -from devito.types.relational import relational_min __all__ = ['Dimension', 'SpaceDimension', 'TimeDimension', 'DefaultDimension', @@ -948,7 +947,7 @@ def __init_finalize__(self, name, parent=None, factor=None, condition=None, super().__init_finalize__(name, parent) # Always make the factor symbolic to allow overrides with different factor. - if factor is None: + if factor is None or factor == 1: self._factor = None elif is_integer(factor): self._factor = Constant(name="%sf" % name, value=factor, dtype=np.int32) @@ -977,18 +976,6 @@ def condition(self): def indirect(self): return self._indirect - @property - def fact_index(self): - if self.condition is None or self._factor is None: - return self.index - - # This is the corner case where both a condition and a factor are provided - # the index will need to be `self.parent - min(self.condition)` to avoid - # shifted indexing. E.g if you have `factor=2` and `condition=Ge(time, 10)` - # then the lowered index needs to be `(time - 10)/ 2` - ltkn = relational_min(self.condition, self.parent) - return self.index - ltkn - @cached_property def free_symbols(self): retval = set(super().free_symbols) diff --git a/devito/types/relational.py b/devito/types/relational.py index a7737a0023..36b208653d 100644 --- a/devito/types/relational.py +++ b/devito/types/relational.py @@ -218,7 +218,8 @@ def relational_min(expr, s): - if `expr` is `s < 10`, then the minimum valid value for `s` is 0 - if `expr` is `s >= 10`, then the minimum valid value for `s` is 10 """ - assert expr.has(s), "Symbol %s not found in expression %s" % (s, expr) + if not expr.has(s): + return 0 return _relational_min(expr, s)