Skip to content

Commit

Permalink
compiler: prevent temporary for local reductions
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed Sep 27, 2023
1 parent a7d9c49 commit 867f4d2
Show file tree
Hide file tree
Showing 6 changed files with 48 additions and 7 deletions.
2 changes: 1 addition & 1 deletion devito/builtins/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def __init__(self, *functions, op=dv.mpi.MPI.SUM, dtype=None):
self.op = op

def __enter__(self):
i = dv.Dimension(name='i',)
i = dv.Dimension(name='mri',)
self.n = dv.Function(name='n', shape=(1,), dimensions=(i,),
grid=self.grid, dtype=self.dtype)
self.n.data[0] = 0
Expand Down
5 changes: 3 additions & 2 deletions devito/ir/clusters/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,7 @@ def normalize_reductions(cluster, sregistry, options):

processed = []
for e in cluster.exprs:
if e.is_Reduction and (e.lhs.is_Indexed or cluster.is_sparse):
if e.is_Reduction and e.lhs.is_Indexed and cluster.is_sparse:
# Transform `e` such that we reduce into a scalar (ultimately via
# atomic ops, though this part is carried out by a much later pass)
# For example, given `i = m[p_src]` (i.e., indirection array), turn:
Expand All @@ -471,7 +471,8 @@ def normalize_reductions(cluster, sregistry, options):
processed.extend([e.func(v, e.rhs, operation=None),
e.func(e.lhs, v)])

elif e.is_Reduction and e.lhs.is_Symbol and opt_mapify_reduce:
elif e.is_Reduction and e.lhs.is_Symbol and opt_mapify_reduce \
and not cluster.is_sparse:
# Transform `e` into what is in essence an explicit map-reduce
# For example, turn:
# `s += f(u[x], v[x], ...)`
Expand Down
13 changes: 12 additions & 1 deletion devito/ir/clusters/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,18 @@ def is_dense(self):

@property
def is_sparse(self):
return not self.is_dense
"""
A cluster is sparse if it represent a sparse operation i.e if both
* The cluster contains sparse functions
* The cluster uses dense functions
If only the first case is true, the cluster only contains operation on the sparse
function itself without indirection and therefore only contains dense operations.
"""
return (any(f.is_SparseFunction for f in self.functions) and
len([f for f in self.functions
if (f.is_Function and not f.is_SparseFunction)]) > 0)

@property
def is_halo_touch(self):
Expand Down
5 changes: 5 additions & 0 deletions devito/ir/equations/equation.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,11 @@ def apply(self, func):
kwargs['conditionals'] = {k: func(v) for k, v in self.conditionals.items()}
return self.func(*args, **kwargs)

def __repr__(self):
if not self.is_Reduction:
return super().__repr__()
else:
return '%s = %s(%s, %s)' % (self.lhs, self.operation, self.lhs, self.rhs)
# Pickling support
__reduce_ex__ = Pickable.__reduce_ex__

Expand Down
4 changes: 2 additions & 2 deletions devito/types/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,11 +338,11 @@ def guard(self, expr=None):
condition=condition, indirect=True)

if expr is None:
out = self.indexify().xreplace({self._sparse_dim: cd})
out = self.indexify()._subs(self._sparse_dim, cd)
else:
functions = {f for f in retrieve_function_carriers(expr)
if f.is_SparseFunction}
out = indexify(expr).xreplace({f._sparse_dim: cd for f in functions})
out = indexify(expr).subs({f._sparse_dim: cd for f in functions})

return out, temps

Expand Down
26 changes: 25 additions & 1 deletion tests/test_dse.py
Original file line number Diff line number Diff line change
Expand Up @@ -2669,13 +2669,37 @@ def test_sparse_const(self):

u = TimeFunction(name="u", grid=grid)
src = PrecomputedSparseTimeFunction(name="src", grid=grid, npoint=1, nt=11,
r=2, interpolation_coeffs=np.ones((1, 3, 2)))
r=2, interpolation_coeffs=np.ones((1, 3, 2)),
gridpoints=[[5, 5, 5]])
u.data.fill(1.)

op = Operator(src.interpolate(u))

cond = FindNodes(Conditional).visit(op)
assert len(cond) == 1
assert len(cond[0].args['then_body'][0].exprs) == 1
assert all(e.is_scalar for e in cond[0].args['then_body'][0].exprs)

op()
assert np.all(src.data == 8)

def test_reduction_local(self):
grid = Grid((11, 11))
d = Dimension("i")
n = Function(name="n", dimensions=(d,), shape=(1,))
u = Function(name="u", grid=grid)
u.data.fill(1.)

op = Operator(Inc(n[0], u))
op()

cond = FindNodes(Expression).visit(op)
iterations = FindNodes(Iteration).visit(op)
# Should not creat any temporary for the reduction
assert len(cond) == 1
assert "reduction(+:n[0])" in iterations[0].pragmas[0].value
assert n.data[0] == 11*11


class TestIsoAcoustic(object):

Expand Down

0 comments on commit 867f4d2

Please sign in to comment.