Skip to content

Commit

Permalink
Merge pull request #2413 from devitocodes/fix-2273
Browse files Browse the repository at this point in the history
api: Support combination of condition and factor for ConditionalDimension
  • Loading branch information
mloubout authored Jul 17, 2024
2 parents 0706ced + a112c16 commit 830a8cf
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 7 deletions.
15 changes: 11 additions & 4 deletions devito/ir/equations/equation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
Stencil, detect_io, detect_accesses)
from devito.symbolics import IntDiv, uxreplace
from devito.tools import Pickable, Tag, frozendict
from devito.types import Eq, Inc, ReduceMax, ReduceMin
from devito.types import Eq, Inc, ReduceMax, ReduceMin, relational_min

__all__ = ['LoweredEq', 'ClusterizedEq', 'DummyEq', 'OpInc', 'OpMin', 'OpMax']

Expand Down Expand Up @@ -191,9 +191,16 @@ def __new__(cls, *args, **kwargs):
if d.condition is None:
conditionals[d] = GuardFactor(d)
else:
conditionals[d] = diff2sympy(lower_exprs(d.condition))
if d.factor is not None:
expr = uxreplace(expr, {d: IntDiv(d.index, d.factor)})
cond = diff2sympy(lower_exprs(d.condition))
if d._factor is not None:
cond = sympy.And(cond, GuardFactor(d))
conditionals[d] = cond
# Replace dimension with index
index = d.index
if d.condition is not None:
index = index - relational_min(d.condition, d.parent)
expr = uxreplace(expr, {d: IntDiv(index, d.factor)})

conditionals = frozendict(conditionals)

# Lower all Differentiable operations into SymPy operations
Expand Down
6 changes: 4 additions & 2 deletions devito/types/dimension.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from devito.types.basic import Symbol, DataSymbol, Scalar
from devito.types.constant import Constant


__all__ = ['Dimension', 'SpaceDimension', 'TimeDimension', 'DefaultDimension',
'CustomDimension', 'SteppingDimension', 'SubDimension',
'MultiSubDimension', 'ConditionalDimension', 'ModuloDimension',
Expand Down Expand Up @@ -946,14 +947,15 @@ def __init_finalize__(self, name, parent=None, factor=None, condition=None,
super().__init_finalize__(name, parent)

# Always make the factor symbolic to allow overrides with different factor.
if factor is None:
if factor is None or factor == 1:
self._factor = None
elif is_integer(factor):
self._factor = Constant(name="%sf" % name, value=factor, dtype=np.int32)
elif factor.is_Constant and is_integer(factor.data):
self._factor = factor
else:
raise ValueError("factor must be an integer or integer Constant")

self._condition = condition
self._indirect = indirect

Expand Down Expand Up @@ -1009,7 +1011,7 @@ def _arg_defaults(self, _min=None, size=None, alias=None):
# `factor` endpoints are legal, so we return them all. It's then
# up to the caller to decide which one to pick upon reduction
dim = alias or self
if dim._factor is None or size is None:
if dim.condition is not None or size is None:
return defaults
try:
# Is it a symbolic factor?
Expand Down
44 changes: 43 additions & 1 deletion devito/types/relational.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
"""User API to specify relationals."""
from functools import singledispatch

import sympy

__all__ = ['Le', 'Lt', 'Ge', 'Gt', 'Ne']
__all__ = ['Le', 'Lt', 'Ge', 'Gt', 'Ne', 'relational_min']


class AbstractRel:
Expand Down Expand Up @@ -208,3 +209,44 @@ def __new__(cls, lhs, rhs=0, subdomain=None, **kwargs):

ops = {Ge: Lt, Gt: Le, Le: Gt, Lt: Ge}
rev = {Ge: Le, Gt: Lt, Lt: Gt, Le: Ge}


def relational_min(expr, s):
"""
Infer the minimum valid value for symbol `s` in the expression `expr`.
For example
- if `expr` is `s < 10`, then the minimum valid value for `s` is 0
- if `expr` is `s >= 10`, then the minimum valid value for `s` is 10
"""
if not expr.has(s):
return 0

return _relational_min(expr, s)


@singledispatch
def _relational_min(s, expr):
return 0


@_relational_min.register(sympy.And)
def _(expr, s):
return max([_relational_min(e, s) for e in expr.args])


@_relational_min.register(Gt)
@_relational_min.register(Lt)
def _(expr, s):
if s == expr.gts:
return expr.lts + 1
else:
return 0


@_relational_min.register(Ge)
@_relational_min.register(Le)
def _(expr, s):
if s == expr.gts:
return expr.lts
else:
return 0
29 changes: 29 additions & 0 deletions tests/test_dimension.py
Original file line number Diff line number Diff line change
Expand Up @@ -1013,6 +1013,35 @@ def test_issue_1753(self):
op.apply(time_M=1)
assert np.all(np.flatnonzero(f.data) == [3, 30])

def test_issue_2273(self):
grid = Grid(shape=(11, 11))
time = grid.time_dim

nt = 200
bounds = (10, 100)
factor = 5

condition = And(Ge(time, bounds[0]), Le(time, bounds[1]))

time_under = ConditionalDimension(name='timeu', parent=time,
factor=factor, condition=condition)
buffer_size = (bounds[1] - bounds[0] + factor) // factor + 1

rec = SparseTimeFunction(name='rec', grid=grid, npoint=1, nt=nt,
coordinates=[(.5, .5)])
rec.data[:] = 1.0

u = TimeFunction(name='u', grid=grid, space_order=2)
usaved = TimeFunction(name='usaved', grid=grid, space_order=2,
time_dim=time_under, save=buffer_size)

eq = [Eq(u.forward, u)] + rec.inject(field=u.forward, expr=rec) + [Eq(usaved, u)]

op = Operator(eq)
op(time_m=0, time_M=nt-1)
expected = np.linspace(bounds[0], bounds[1], num=buffer_size-1)
assert np.allclose(usaved.data[:-1, 5, 5], expected)

def test_subsampled_fd(self):
"""
Test that the FD shortcuts are handled correctly with ConditionalDimensions
Expand Down

0 comments on commit 830a8cf

Please sign in to comment.