Skip to content

Commit

Permalink
compiler: Abstract orchestration routines
Browse files: browse the repository at this point in the history
  • Loading branch information
FabioLuporini committed Nov 10, 2023
1 parent c39112a commit 2cbc4f4
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 55 deletions.
13 changes: 5 additions & 8 deletions devito/passes/iet/asynchrony.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from devito.symbolics import (CondEq, CondNe, FieldFromComposite, FieldFromPointer,
Null)
from devito.tools import split
from devito.types import (Lock, Pointer, PThreadArray, QueueID, SharedData, Symbol,
from devito.types import (Lock, Pointer, PThreadArray, QueueID, SharedData, Temp,
VolatileInt)

__all__ = ['pthreadify']
Expand Down Expand Up @@ -194,7 +194,7 @@ def lower_async_calls(iet, track=None, sregistry=None):
condition = CondNe(FieldFromComposite(sdata.symbolic_flag, sdata[d]), 1)
activation = [While(condition)]
else:
d = Symbol(name=sregistry.make_name(prefix=threads.index.name))
d = Temp(name=sregistry.make_name(prefix=threads.index.name))
condition = CondNe(FieldFromComposite(sdata.symbolic_flag, sdata[d]), 1)
activation = [DummyExpr(d, 0),
While(condition, DummyExpr(d, (d + 1) % threads.size))]
Expand All @@ -203,12 +203,9 @@ def lower_async_calls(iet, track=None, sregistry=None):
activation.append(
DummyExpr(FieldFromComposite(sdata.symbolic_flag, sdata[d]), 2)
)
activation = List(
header=[c.Line(), c.Comment("Activate `%s`" % threads.name)],
body=activation,
footer=c.Line()
)
mapper[n] = activation
name = sregistry.make_name(prefix='activate')
efunc = efuncs[name] = make_callable(name, activation)
mapper[n] = Call(name, efunc.parameters)

if mapper:
# Inject activation
Expand Down
14 changes: 7 additions & 7 deletions devito/passes/iet/orchestration.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,9 @@ def __init__(self, sregistry):

def _make_waitlock(self, iet, sync_ops, *args):
waitloop = List(
header=c.Comment("Wait for `%s` to be copied to the host" %
header=c.Comment("Wait for `%s` to be transferred" %
",".join(s.target.name for s in sync_ops)),
body=BusyWait(Or(*[CondEq(s.handle, 0) for s in sync_ops])),
footer=c.Line()
)

iet = List(body=(waitloop,) + iet.body)
Expand All @@ -50,12 +49,13 @@ def _make_releaselock(self, iet, sync_ops, *args):
pre.append(BusyWait(Or(*[CondNe(s.handle, 2) for s in sync_ops])))
pre.extend(DummyExpr(s.handle, 0) for s in sync_ops)

iet = List(
header=c.Comment("Release lock(s) as soon as possible"),
body=pre + [iet]
)
name = self.sregistry.make_name(prefix="release_lock")
parameters = derive_parameters(pre, ordering='canonical')
efunc = Callable(name, pre, 'void', parameters, 'static')

return iet, []
iet = List(body=[Call(name, efunc.parameters)] + [iet])

return iet, [efunc]

def _make_withlock(self, iet, sync_ops, layer):
body, prefix = withlock(layer, iet, sync_ops, self.lang, self.sregistry)
Expand Down
71 changes: 31 additions & 40 deletions tests/test_gpu_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,9 +175,10 @@ def test_tasking_in_isolation(self, opt):
assert len(sections) == 3
assert str(sections[0].body[0].body[0].body[0].body[0]) == 'while(lock0[0] == 0);'
body = sections[2].body[0].body[0]
body = op._func_table['release_lock0'].root.body
assert str(body.body[0].condition) == 'Ne(lock0[0], 2)'
assert str(body.body[1]) == 'lock0[0] = 0;'
body = body.body[2]
body = op._func_table['activate0'].root.body
assert str(body.body[0].condition) == 'Ne(sdata0[0].flag, 1)'
assert str(body.body[1]) == 'sdata0[0].time = time;'
assert str(body.body[2]) == 'sdata0[0].flag = 2;'
Expand Down Expand Up @@ -207,7 +208,7 @@ def test_tasking_fused(self):
assert len(retrieve_iteration_tree(op)) == 3
locks = [i for i in FindSymbols().visit(op) if isinstance(i, Lock)]
assert len(locks) == 1 # Only 1 because it's only `tmp` that needs protection
assert len(op._func_table) == 3
assert len(op._func_table) == 5
exprs = FindNodes(Expression).visit(op._func_table['copy_to_host0'].root)
b = 17 if configuration['language'] == 'openacc' else 16 # No `qid` w/ OMP
assert str(exprs[b]) == 'const int deviceid = sdata->deviceid;'
Expand Down Expand Up @@ -244,26 +245,15 @@ def test_tasking_unfused_two_locks(self):

# Check generated code
assert len(retrieve_iteration_tree(op)) == 3
assert len([i for i in FindSymbols().visit(op) if isinstance(i, Lock)]) == 3
assert len([i for i in FindSymbols().visit(op) if isinstance(i, Lock)]) == 4
sections = FindNodes(Section).visit(op)
assert len(sections) == 4
assert (str(sections[1].body[0].body[0].body[0].body[0]) ==
'while(lock0[0] == 0 || lock1[0] == 0);') # Wait-lock
body = sections[2].body[0].body[0]
assert str(body.body[0].condition) == 'Ne(lock0[0], 2)'
assert str(body.body[1]) == 'lock0[0] = 0;' # Set-lock
body = body.body[2]
assert str(body.body[0].condition) == 'Ne(sdata0[0].flag, 1)' # Wait-thread
assert str(body.body[1]) == 'sdata0[0].time = time;'
assert str(body.body[2]) == 'sdata0[0].flag = 2;'
body = sections[3].body[0].body[0]
assert str(body.body[0].condition) == 'Ne(lock1[0], 2)'
assert str(body.body[1]) == 'lock1[0] = 0;' # Set-lock
body = body.body[2]
assert str(body.body[0].condition) == 'Ne(sdata1[0].flag, 1)' # Wait-thread
assert str(body.body[1]) == 'sdata1[0].time = time;'
assert str(body.body[2]) == 'sdata1[0].flag = 2;'
assert len(op._func_table) == 3
assert str(body.body[0]) == 'release_lock0(lock0);'
assert str(body.body[1]) == 'activate0(time,sdata0);'
assert len(op._func_table) == 5
exprs = FindNodes(Expression).visit(op._func_table['copy_to_host0'].root)
b = 18 if configuration['language'] == 'openacc' else 17 # No `qid` w/ OMP
assert str(exprs[b]) == 'lock0[0] = 1;'
Expand Down Expand Up @@ -300,15 +290,15 @@ def test_tasking_forcefuse(self):
assert len(sections) == 3
assert (str(sections[1].body[0].body[0].body[0].body[0]) ==
'while(lock0[0] == 0 || lock1[0] == 0);') # Wait-lock
body = sections[2].body[0].body[0]
body = op._func_table['release_lock0'].root.body
assert str(body.body[0].condition) == 'Ne(lock0[0], 2) | Ne(lock1[0], 2)'
assert str(body.body[1]) == 'lock0[0] = 0;' # Set-lock
assert str(body.body[2]) == 'lock1[0] = 0;' # Set-lock
body = body.body[3]
body = op._func_table['activate0'].root.body
assert str(body.body[0].condition) == 'Ne(sdata0[0].flag, 1)' # Wait-thread
assert str(body.body[1]) == 'sdata0[0].time = time;'
assert str(body.body[2]) == 'sdata0[0].flag = 2;'
assert len(op._func_table) == 3
assert len(op._func_table) == 5
exprs = FindNodes(Expression).visit(op._func_table['copy_to_host0'].root)
b = 21 if configuration['language'] == 'openacc' else 20 # No `qid` w/ OMP
assert str(exprs[b]) == 'lock0[0] = 1;'
Expand Down Expand Up @@ -369,11 +359,12 @@ def test_tasking_multi_output(self):
assert len(sections) == 2
assert str(sections[0].body[0].body[0].body[0].body[0]) ==\
'while(lock0[t2] == 0);'
body = op1._func_table['release_lock0'].root.body
for i in range(3):
assert 'lock0[t' in str(sections[1].body[0].body[0].body[1 + i]) # Set-lock
assert str(sections[1].body[0].body[0].body[4].body[-1]) ==\
'sdata0[wi0].flag = 2;'
assert len(op1._func_table) == 3
assert 'lock0[t' in str(body.body[1 + i]) # Set-lock
body = op1._func_table['activate0'].root.body
assert str(body.body[-1]) == 'sdata0[wi0].flag = 2;'
assert len(op1._func_table) == 5
exprs = FindNodes(Expression).visit(op1._func_table['copy_to_host0'].root)
b = 21 if configuration['language'] == 'openacc' else 20 # No `qid` w/ OMP
for i in range(3):
Expand Down Expand Up @@ -427,7 +418,7 @@ def test_streaming_basic(self, opt, ntmps):
op = Operator(eqn, opt=opt)

# Check generated code
assert len(op._func_table) == 7
assert len(op._func_table) == 9
assert len([i for i in FindSymbols().visit(op) if i.is_Array]) == ntmps

op.apply(time_M=nt-2)
Expand All @@ -436,7 +427,7 @@ def test_streaming_basic(self, opt, ntmps):
assert np.all(u.data[1] == 36)

@pytest.mark.parametrize('opt,ntmps', [
(('buffering', 'streaming', 'orchestrate'), 10),
(('buffering', 'streaming', 'orchestrate'), 11),
(('buffering', 'streaming', 'fuse', 'orchestrate', {'fuse-tasks': True}), 7),
])
def test_streaming_two_buffers(self, opt, ntmps):
Expand All @@ -456,7 +447,7 @@ def test_streaming_two_buffers(self, opt, ntmps):
op = Operator(eqn, opt=opt)

# Check generated code
assert len(op._func_table) == 7
assert len(op._func_table) == 9
assert len([i for i in FindSymbols().visit(op) if i.is_Array]) == ntmps

op.apply(time_M=nt-2)
Expand Down Expand Up @@ -592,7 +583,7 @@ def test_streaming_multi_input(self, opt, ntmps):
op1 = Operator(eqn, opt=opt)

# Check generated code
assert len(op1._func_table) == 7
assert len(op1._func_table) == 9
assert len([i for i in FindSymbols().visit(op1) if i.is_Array]) == ntmps

op0.apply(time_M=nt-2, dt=0.1)
Expand Down Expand Up @@ -679,7 +670,7 @@ def test_streaming_postponed_deletion(self, opt, ntmps):
op1 = Operator(eqns, opt=opt)

# Check generated code
assert len(op1._func_table) == 7
assert len(op1._func_table) == 9
assert len([i for i in FindSymbols().visit(op1) if i.is_Array]) == ntmps

op0.apply(time_M=nt-1)
Expand Down Expand Up @@ -748,7 +739,7 @@ def test_composite_buffering_tasking_multi_output(self):
assert len(retrieve_iteration_tree(op1)) == 6
assert len(retrieve_iteration_tree(op2)) == 4
symbols = FindSymbols().visit(op1)
assert len([i for i in symbols if isinstance(i, Lock)]) == 3
assert len([i for i in symbols if isinstance(i, Lock)]) == 4
threads = [i for i in symbols if isinstance(i, PThreadArray)]
assert len(threads) == 3
assert threads[0].size.size == async_degree
Expand All @@ -762,7 +753,7 @@ def test_composite_buffering_tasking_multi_output(self):
# It is true that the usave and vsave eqns are separated in two different
# loop nests, but they eventually get mapped to the same pair of efuncs,
# since devito attempts to maximize code reuse
assert len(op1._func_table) == 6
assert len(op1._func_table) == 8

# Check output
op0.apply(time_M=nt-1)
Expand Down Expand Up @@ -802,7 +793,7 @@ def test_composite_full_0(self):
assert len(retrieve_iteration_tree(op0)) == 1
assert len(retrieve_iteration_tree(op1)) == 3
symbols = FindSymbols().visit(op1)
assert len([i for i in symbols if isinstance(i, Lock)]) == 3
assert len([i for i in symbols if isinstance(i, Lock)]) == 4
threads = [i for i in symbols if isinstance(i, PThreadArray)]
assert len(threads) == 3
assert all(i.size == 1 for i in threads)
Expand Down Expand Up @@ -930,7 +921,7 @@ def test_save_multi_output(self):
# The `usave` and `vsave` eqns are in separate tasks, but the tasks
# are identical, so they get mapped to the same efuncs (init + copy)
# There are also two extra functions to allocate and free arrays
assert len(op._func_table) == 6
assert len(op._func_table) == 8

op.apply(time_M=nt-1)

Expand Down Expand Up @@ -986,7 +977,7 @@ def test_save_w_nonaffine_time(self):

# We just check the generated code here
assert len([i for i in FindSymbols().visit(op) if isinstance(i, Lock)]) == 1
assert len(op._func_table) == 6
assert len(op._func_table) == 8

def test_save_w_subdims(self):
nt = 10
Expand Down Expand Up @@ -1043,7 +1034,7 @@ def test_streaming_w_shifting(self, opt, ntmps):
op = Operator(eqns, opt=opt)

# Check generated code
assert len(op._func_table) == 7
assert len(op._func_table) == 9
assert len([i for i in FindSymbols().visit(op) if i.is_Array]) == ntmps

# From time_m=15 to time_M=35 with a factor=5 -- it means that, thanks
Expand Down Expand Up @@ -1098,11 +1089,11 @@ def test_streaming_complete(self):
{'fuse-tasks': True}))

# Check generated code
assert len(op1._func_table) == 11
assert len([i for i in FindSymbols().visit(op1) if i.is_Array]) == 7
assert len(op2._func_table) == 11
assert len([i for i in FindSymbols().visit(op2) if i.is_Array]) == 7
assert len(op3._func_table) == 8
assert len(op1._func_table) == 14
assert len([i for i in FindSymbols().visit(op1) if i.is_Array]) == 8
assert len(op2._func_table) == 14
assert len([i for i in FindSymbols().visit(op2) if i.is_Array]) == 8
assert len(op3._func_table) == 10
assert len([i for i in FindSymbols().visit(op3) if i.is_Array]) == 7

op0.apply(time_m=15, time_M=35, save_shift=0)
Expand Down

0 comments on commit 2cbc4f4

Please sign in to comment.