Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

compiler: Generate less integer arithmetic #2301

Merged
merged 1 commit into from
Feb 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions devito/ir/clusters/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,9 +297,9 @@ def dspace(self):
continue

intervals = [Interval(d,
min([minimum(i) for i in offs]),
max([maximum(i) for i in offs]))
for d, offs in v.items()]
min([minimum(i, ispace=self.ispace) for i in o]),
max([maximum(i, ispace=self.ispace) for i in o]))
for d, o in v.items()]
intervals = IntervalGroup(intervals)

# Factor in the IterationSpace -- if the min/max points aren't zero,
Expand Down
24 changes: 18 additions & 6 deletions devito/ir/support/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,30 +308,42 @@ def _relational(expr, callback, udims=None):
return expr.subs(mapper)


def minimum(expr, udims=None):
def minimum(expr, udims=None, ispace=None):
"""
Substitute the unbounded Dimensions in `expr` with their minimum point.

Unbounded Dimensions whose possible minimum value is not known are ignored.
"""
return _relational(expr, lambda e: e._min, udims)
def callback(sd):
try:
return sd._min + ispace[sd].lower
except (TypeError, KeyError):
return sd._min

return _relational(expr, callback, udims)


def maximum(expr, udims=None):
def maximum(expr, udims=None, ispace=None):
"""
Substitute the unbounded Dimensions in `expr` with their maximum point.

Unbounded Dimensions whose possible maximum value is not known are ignored.
"""
return _relational(expr, lambda e: e._max, udims)
def callback(sd):
try:
return sd._max + ispace[sd].upper
except (TypeError, KeyError):
return sd._max

return _relational(expr, callback, udims)


def extrema(expr):
def extrema(expr, ispace=None):
"""
The minimum and maximum extrema assumed by `expr` once the unbounded
Dimensions are resolved.
"""
return Extrema(minimum(expr), maximum(expr))
return Extrema(minimum(expr, ispace=ispace), maximum(expr, ispace=ispace))


def minmax_index(expr, d):
Expand Down
18 changes: 12 additions & 6 deletions devito/passes/clusters/derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,12 @@ def _core(expr, c, weights, reusables, mapper, sregistry):
extra = (c.ispace.itdims + dims,)
ispace = IterationSpace.union(c.ispace, ispace0, relations=extra)

# Set the IterationSpace along the StencilDimensions to start from 0
# (rather than the default `d._min`) to minimize the amount of integer
# arithmetic to calculate the various index access functions
for d in dims:
ispace = ispace.translate(d, -d._min)

try:
s = reusables.pop()
assert s.dtype is w.dtype
Expand All @@ -125,12 +131,12 @@ def _core(expr, c, weights, reusables, mapper, sregistry):
ispace1 = ispace.project(lambda d: d is not dims[-1])
processed.insert(0, c.rebuild(exprs=expr0, ispace=ispace1))

# Transform e.g. `w[i0] -> w[i0 + 2]` for alignment with the
# StencilDimensions starting points
subs = {expr.weights:
expr.weights.subs(d, d - d._min)
for d in dims}
expr1 = Inc(s, uxreplace(expr.expr, subs))
# Transform e.g. `r0[x + i0 + 2, y] -> r0[x + i0, y, z]` for alignment
# with the shifted `ispace`
base = expr.base
for d in dims:
base = base.subs(d, d + d._min)
expr1 = Inc(s, base*expr.weights)
processed.append(c.rebuild(exprs=expr1, ispace=ispace))

# Track lowered IndexDerivative for subsequent optimization by the caller
Expand Down
2 changes: 1 addition & 1 deletion tests/test_unexpansion.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def test_multiple_cross_derivs(self, coeffs, expected):

# w0, w1, ...
functions = FindSymbols().visit(op)
weights = [f for f in functions if isinstance(f, Weights)]
weights = {f for f in functions if isinstance(f, Weights)}
assert len(weights) == expected


Expand Down
Loading