Skip to content

Commit

Permalink
compiler: improve interpolation parallelism
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed Sep 19, 2023
1 parent f7ab007 commit 963ec1b
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 33 deletions.
5 changes: 5 additions & 0 deletions devito/operations/interpolators.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,7 @@ def _inject(self, field, expr, implicit_dims=None):
# Make iterable to support inject((u, v), expr=expr)
# or inject((u, v), expr=(expr1, expr2))
fields, exprs = as_tuple(field), as_tuple(expr)

# Provide either one expr per field or on expr for all fields
if len(fields) > 1:
if len(exprs) == 1:
Expand All @@ -323,6 +324,10 @@ def _inject(self, field, expr, implicit_dims=None):

# Implicit dimensions
implicit_dims = self._augment_implicit_dims(implicit_dims, variables)
# Move all temporaries inside inner loop to improve parallelism
# Can only be done for inject as interpolation need a temporary
# summing temp that wouldn't allow collapsing
implicit_dims = implicit_dims + tuple(r.parent for r in self._rdim)

variables = variables + list(fields)

Expand Down
11 changes: 2 additions & 9 deletions tests/test_dle.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,19 +187,12 @@ def test_cache_blocking_structure_optrelax():

op = Operator(eqns, opt=('advanced', {'blockrelax': True}))

bns, _ = assert_blocking(op, {'x0_blk0', 'p_src0_blk0', 'p_src1_blk0'})
bns, _ = assert_blocking(op, {'x0_blk0', 'p_src0_blk0'})

iters = FindNodes(Iteration).visit(bns['p_src0_blk0'])
assert len(iters) == 2
assert iters[0].dim.is_Block
assert iters[1].dim.is_Block

iters = FindNodes(Iteration).visit(bns['p_src1_blk0'])
assert len(iters) == 5
assert iters[0].dim.is_Block
assert iters[1].dim.is_Block
for i in range(2, 5):
assert not iters[i].dim.is_Block


def test_cache_blocking_structure_optrelax_customdim():
Expand Down Expand Up @@ -965,7 +958,7 @@ def test_parallel_prec_inject(self):
iterations = FindNodes(Iteration).visit(op0)

assert not iterations[0].pragmas
assert 'omp for collapse' in iterations[2].pragmas[0].value
assert 'omp for collapse' in iterations[1].pragmas[0].value


class TestNestedParallelism(object):
Expand Down
29 changes: 14 additions & 15 deletions tests/test_dse.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,9 @@ def test_scheduling_after_rewrite():
trees = retrieve_iteration_tree(op)

# Check loop nest structure
assert all(i.dim is j for i, j in zip(trees[1], grid.dimensions)) # time invariant
assert trees[2].root.dim is grid.time_dim
assert all(trees[2].root.dim is tree.root.dim for tree in trees[2:])
assert all(i.dim is j for i, j in zip(trees[0], grid.dimensions)) # time invariant
assert trees[1].root.dim is grid.time_dim
assert all(trees[1].root.dim is tree.root.dim for tree in trees[1:])


@pytest.mark.parametrize('exprs,expected,min_cost', [
Expand Down Expand Up @@ -1687,7 +1687,7 @@ def test_drop_redundants_after_fusion(self, rotate):
op = Operator(eqns, opt=('advanced', {'cire-rotate': rotate}))

arrays = [i for i in FindSymbols().visit(op) if i.is_Array]
assert len(arrays) == 4
assert len(arrays) == 2
assert all(i._mem_heap and not i._mem_external for i in arrays)

def test_full_shape_big_temporaries(self):
Expand Down Expand Up @@ -2711,11 +2711,10 @@ def test_fullopt(self):
assert np.isclose(summary0[('section0', None)].oi, 2.851, atol=0.001)

assert summary1[('section0', None)].ops == 9
assert summary1[('section1', None)].ops == 9
assert summary1[('section2', None)].ops == 31
assert summary1[('section3', None)].ops == 26
assert summary1[('section4', None)].ops == 22
assert np.isclose(summary1[('section2', None)].oi, 1.767, atol=0.001)
assert summary1[('section1', None)].ops == 31
assert summary1[('section2', None)].ops == 88
assert summary1[('section3', None)].ops == 22
assert np.isclose(summary1[('section1', None)].oi, 1.767, atol=0.001)

assert np.allclose(u0.data, u1.data, atol=10e-5)
assert np.allclose(rec0.data, rec1.data, atol=10e-5)
Expand Down Expand Up @@ -2775,8 +2774,8 @@ def test_fullopt(self):
assert np.allclose(self.tti_noopt[1].data, rec.data, atol=10e-1)

# Check expected opcount/oi
assert summary[('section3', None)].ops == 92
assert np.isclose(summary[('section3', None)].oi, 2.074, atol=0.001)
assert summary[('section2', None)].ops == 92
assert np.isclose(summary[('section2', None)].oi, 2.074, atol=0.001)

# With optimizations enabled, there should be exactly four BlockDimensions
op = wavesolver.op_fwd()
Expand All @@ -2794,7 +2793,7 @@ def test_fullopt(self):
# 3 Arrays are defined globally for the sparse positions temporaries
# and two additional bock-sized Arrays are defined locally
arrays = [i for i in FindSymbols().visit(op) if i.is_Array]
extra_arrays = 2+3+3
extra_arrays = 2+3
assert len(arrays) == 4 + extra_arrays
assert all(i._mem_heap and not i._mem_external for i in arrays)
bns, pbs = assert_blocking(op, {'x0_blk0'})
Expand Down Expand Up @@ -2830,7 +2829,7 @@ def test_fullopt_w_mpi(self):
def test_opcounts(self, space_order, expected):
op = self.tti_operator(opt='advanced', space_order=space_order)
sections = list(op.op_fwd()._profiler._sections.values())
assert sections[3].sops == expected
assert sections[2].sops == expected

@switchconfig(profiling='advanced')
@pytest.mark.parametrize('space_order,expected', [
Expand All @@ -2840,8 +2839,8 @@ def test_opcounts_adjoint(self, space_order, expected):
wavesolver = self.tti_operator(opt=('advanced', {'openmp': False}))
op = wavesolver.op_adj()

assert op._profiler._sections['section3'].sops == expected
assert len([i for i in FindSymbols().visit(op) if i.is_Array]) == 7+3+3
assert op._profiler._sections['section2'].sops == expected
assert len([i for i in FindSymbols().visit(op) if i.is_Array]) == 7+3


class TestTTIv2(object):
Expand Down
18 changes: 10 additions & 8 deletions tests/test_gpu_openacc.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,15 +102,15 @@ def test_tile_insteadof_collapse(self, par_tile):
opt=('advanced', {'par-tile': par_tile}))

trees = retrieve_iteration_tree(op)
assert len(trees) == 6
assert len(trees) == 4

assert trees[1][1].pragmas[0].value ==\
assert trees[0][1].pragmas[0].value ==\
'acc parallel loop tile(32,4,4) present(u)'
assert trees[2][1].pragmas[0].value ==\
assert trees[1][1].pragmas[0].value ==\
'acc parallel loop tile(32,4) present(u)'
# Only the AFFINE Iterations are tiled
assert trees[4][1].pragmas[0].value ==\
'acc parallel loop present(src,src_coords,u) deviceptr(r1,r2,r3)'
assert trees[3][1].pragmas[0].value ==\
'acc parallel loop collapse(4) present(src,src_coords,u)'

@pytest.mark.parametrize('par_tile', [((32, 4, 4), (8, 8)), ((32, 4), (8, 8)),
((32, 4, 4), (8, 8, 8))])
Expand All @@ -130,12 +130,14 @@ def test_multiple_tile_sizes(self, par_tile):
opt=('advanced', {'par-tile': par_tile}))

trees = retrieve_iteration_tree(op)
assert len(trees) == 6
assert len(trees) == 4

assert trees[1][1].pragmas[0].value ==\
assert trees[0][1].pragmas[0].value ==\
'acc parallel loop tile(32,4,4) present(u)'
assert trees[2][1].pragmas[0].value ==\
assert trees[1][1].pragmas[0].value ==\
'acc parallel loop tile(8,8) present(u)'
assert trees[3][1].pragmas[0].value ==\
'acc parallel loop collapse(4) present(src,src_coords,u)'

def test_multi_tile_blocking_structure(self):
grid = Grid(shape=(8, 8, 8))
Expand Down
3 changes: 2 additions & 1 deletion tests/test_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from sympy import Float

from devito import (Grid, Operator, Dimension, SparseFunction, SparseTimeFunction,
Function, TimeFunction, DefaultDimension, Eq,
Function, TimeFunction, DefaultDimension, Eq, switchconfig,
PrecomputedSparseFunction, PrecomputedSparseTimeFunction,
MatrixSparseTimeFunction)
from examples.seismic import (demo_model, TimeAxis, RickerSource, Receiver,
Expand Down Expand Up @@ -736,6 +736,7 @@ class SparseFirst(SparseFunction):
assert np.allclose(s.data, expected)


@switchconfig(safe_math=True)
def test_inject_function():
nt = 11

Expand Down

0 comments on commit 963ec1b

Please sign in to comment.