Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

api: Support combination of condition and factor for ConditionalDimension #2413

Merged
merged 2 commits into from
Jul 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions devito/ir/equations/equation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
Stencil, detect_io, detect_accesses)
from devito.symbolics import IntDiv, uxreplace
from devito.tools import Pickable, Tag, frozendict
from devito.types import Eq, Inc, ReduceMax, ReduceMin
from devito.types import Eq, Inc, ReduceMax, ReduceMin, relational_min

__all__ = ['LoweredEq', 'ClusterizedEq', 'DummyEq', 'OpInc', 'OpMin', 'OpMax']

Expand Down Expand Up @@ -191,9 +191,16 @@ def __new__(cls, *args, **kwargs):
if d.condition is None:
conditionals[d] = GuardFactor(d)
else:
conditionals[d] = diff2sympy(lower_exprs(d.condition))
if d.factor is not None:
expr = uxreplace(expr, {d: IntDiv(d.index, d.factor)})
cond = diff2sympy(lower_exprs(d.condition))
if d._factor is not None:
cond = sympy.And(cond, GuardFactor(d))
conditionals[d] = cond
# Replace dimension with index
index = d.index
if d.condition is not None:
index = index - relational_min(d.condition, d.parent)
expr = uxreplace(expr, {d: IntDiv(index, d.factor)})

conditionals = frozendict(conditionals)

# Lower all Differentiable operations into SymPy operations
Expand Down
6 changes: 4 additions & 2 deletions devito/types/dimension.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from devito.types.basic import Symbol, DataSymbol, Scalar
from devito.types.constant import Constant


__all__ = ['Dimension', 'SpaceDimension', 'TimeDimension', 'DefaultDimension',
'CustomDimension', 'SteppingDimension', 'SubDimension',
'MultiSubDimension', 'ConditionalDimension', 'ModuloDimension',
Expand Down Expand Up @@ -946,14 +947,15 @@ def __init_finalize__(self, name, parent=None, factor=None, condition=None,
super().__init_finalize__(name, parent)

# Always make the factor symbolic to allow overrides with different factor.
if factor is None:
if factor is None or factor == 1:
self._factor = None
elif is_integer(factor):
self._factor = Constant(name="%sf" % name, value=factor, dtype=np.int32)
elif factor.is_Constant and is_integer(factor.data):
self._factor = factor
else:
raise ValueError("factor must be an integer or integer Constant")

self._condition = condition
self._indirect = indirect

Expand Down Expand Up @@ -1009,7 +1011,7 @@ def _arg_defaults(self, _min=None, size=None, alias=None):
# `factor` endpoints are legal, so we return them all. It's then
# up to the caller to decide which one to pick upon reduction
dim = alias or self
if dim._factor is None or size is None:
if dim.condition is not None or size is None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so we don't check _factor anymore here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, basically if there is a condition then you can't safely use the factor to infer arg bounds

return defaults
try:
# Is it a symbolic factor?
Expand Down
44 changes: 43 additions & 1 deletion devito/types/relational.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
"""User API to specify relationals."""
from functools import singledispatch

import sympy

__all__ = ['Le', 'Lt', 'Ge', 'Gt', 'Ne']
__all__ = ['Le', 'Lt', 'Ge', 'Gt', 'Ne', 'relational_min']


class AbstractRel:
Expand Down Expand Up @@ -208,3 +209,44 @@ def __new__(cls, lhs, rhs=0, subdomain=None, **kwargs):

ops = {Ge: Lt, Gt: Le, Le: Gt, Lt: Ge}
rev = {Ge: Le, Gt: Lt, Lt: Gt, Le: Ge}


def relational_min(expr, s):
georgebisbas marked this conversation as resolved.
Show resolved Hide resolved
"""
Infer the minimum valid value for symbol `s` in the expression `expr`.
For example
- if `expr` is `s < 10`, then the minimum valid value for `s` is 0
- if `expr` is `s >= 10`, then the minimum valid value for `s` is 10
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

or at least this rang a bell when reading the docstring here

"""
if not expr.has(s):
return 0

return _relational_min(expr, s)


@singledispatch
def _relational_min(s, expr):
return 0


@_relational_min.register(sympy.And)
def _(expr, s):
return max([_relational_min(e, s) for e in expr.args])


@_relational_min.register(Gt)
@_relational_min.register(Lt)
def _(expr, s):
if s == expr.gts:
return expr.lts + 1
else:
return 0


@_relational_min.register(Ge)
@_relational_min.register(Le)
def _(expr, s):
if s == expr.gts:
return expr.lts
else:
return 0
29 changes: 29 additions & 0 deletions tests/test_dimension.py
Original file line number Diff line number Diff line change
Expand Up @@ -1013,6 +1013,35 @@ def test_issue_1753(self):
op.apply(time_M=1)
assert np.all(np.flatnonzero(f.data) == [3, 30])

def test_issue_2273(self):
grid = Grid(shape=(11, 11))
time = grid.time_dim

nt = 200
bounds = (10, 100)
factor = 5

condition = And(Ge(time, bounds[0]), Le(time, bounds[1]))

time_under = ConditionalDimension(name='timeu', parent=time,
factor=factor, condition=condition)
buffer_size = (bounds[1] - bounds[0] + factor) // factor + 1

rec = SparseTimeFunction(name='rec', grid=grid, npoint=1, nt=nt,
coordinates=[(.5, .5)])
rec.data[:] = 1.0

u = TimeFunction(name='u', grid=grid, space_order=2)
usaved = TimeFunction(name='usaved', grid=grid, space_order=2,
time_dim=time_under, save=buffer_size)

eq = [Eq(u.forward, u)] + rec.inject(field=u.forward, expr=rec) + [Eq(usaved, u)]

op = Operator(eq)
op(time_m=0, time_M=nt-1)
expected = np.linspace(bounds[0], bounds[1], num=buffer_size-1)
assert np.allclose(usaved.data[:-1, 5, 5], expected)

def test_subsampled_fd(self):
"""
Test that the FD shortcuts are handled correctly with ConditionalDimensions
Expand Down
Loading