Skip to content

Commit

Permalink
Merge pull request #2301 from devitocodes/less-int-arithm
Browse files Browse the repository at this point in the history
compiler: Generate less integer arithmetic
  • Loading branch information
FabioLuporini authored Feb 2, 2024
2 parents 0ccf5fd + ce8613a commit 0f249f2
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 16 deletions.
6 changes: 3 additions & 3 deletions devito/ir/clusters/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,9 +297,9 @@ def dspace(self):
continue

intervals = [Interval(d,
min([minimum(i) for i in offs]),
max([maximum(i) for i in offs]))
for d, offs in v.items()]
min([minimum(i, ispace=self.ispace) for i in o]),
max([maximum(i, ispace=self.ispace) for i in o]))
for d, o in v.items()]
intervals = IntervalGroup(intervals)

# Factor in the IterationSpace -- if the min/max points aren't zero,
Expand Down
24 changes: 18 additions & 6 deletions devito/ir/support/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,30 +308,42 @@ def _relational(expr, callback, udims=None):
return expr.subs(mapper)


def minimum(expr, udims=None):
def minimum(expr, udims=None, ispace=None):
"""
Substitute the unbounded Dimensions in `expr` with their minimum point.
Unbounded Dimensions whose possible minimum value is not known are ignored.
"""
return _relational(expr, lambda e: e._min, udims)
def callback(sd):
try:
return sd._min + ispace[sd].lower
except (TypeError, KeyError):
return sd._min

return _relational(expr, callback, udims)


def maximum(expr, udims=None):
def maximum(expr, udims=None, ispace=None):
"""
Substitute the unbounded Dimensions in `expr` with their maximum point.
Unbounded Dimensions whose possible maximum value is not known are ignored.
"""
return _relational(expr, lambda e: e._max, udims)
def callback(sd):
try:
return sd._max + ispace[sd].upper
except (TypeError, KeyError):
return sd._max

return _relational(expr, callback, udims)


def extrema(expr):
def extrema(expr, ispace=None):
"""
The minimum and maximum extrema assumed by `expr` once the unbounded
Dimensions are resolved.
"""
return Extrema(minimum(expr), maximum(expr))
return Extrema(minimum(expr, ispace=ispace), maximum(expr, ispace=ispace))


def minmax_index(expr, d):
Expand Down
18 changes: 12 additions & 6 deletions devito/passes/clusters/derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,12 @@ def _core(expr, c, weights, reusables, mapper, sregistry):
extra = (c.ispace.itdims + dims,)
ispace = IterationSpace.union(c.ispace, ispace0, relations=extra)

# Set the IterationSpace along the StencilDimensions to start from 0
# (rather than the default `d._min`) to minimize the amount of integer
# arithmetic to calculate the various index access functions
for d in dims:
ispace = ispace.translate(d, -d._min)

try:
s = reusables.pop()
assert s.dtype is w.dtype
Expand All @@ -125,12 +131,12 @@ def _core(expr, c, weights, reusables, mapper, sregistry):
ispace1 = ispace.project(lambda d: d is not dims[-1])
processed.insert(0, c.rebuild(exprs=expr0, ispace=ispace1))

# Transform e.g. `w[i0] -> w[i0 + 2]` for alignment with the
# StencilDimensions starting points
subs = {expr.weights:
expr.weights.subs(d, d - d._min)
for d in dims}
expr1 = Inc(s, uxreplace(expr.expr, subs))
# Transform e.g. `r0[x + i0 + 2, y] -> r0[x + i0, y, z]` for alignment
# with the shifted `ispace`
base = expr.base
for d in dims:
base = base.subs(d, d + d._min)
expr1 = Inc(s, base*expr.weights)
processed.append(c.rebuild(exprs=expr1, ispace=ispace))

# Track lowered IndexDerivative for subsequent optimization by the caller
Expand Down
2 changes: 1 addition & 1 deletion tests/test_unexpansion.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def test_multiple_cross_derivs(self, coeffs, expected):

# w0, w1, ...
functions = FindSymbols().visit(op)
weights = [f for f in functions if isinstance(f, Weights)]
weights = {f for f in functions if isinstance(f, Weights)}
assert len(weights) == expected


Expand Down

0 comments on commit 0f249f2

Please sign in to comment.