diff --git a/devito/passes/iet/asynchrony.py b/devito/passes/iet/asynchrony.py index 01635945cc6..29b9cce8a4c 100644 --- a/devito/passes/iet/asynchrony.py +++ b/devito/passes/iet/asynchrony.py @@ -12,7 +12,7 @@ from devito.symbolics import (CondEq, CondNe, FieldFromComposite, FieldFromPointer, Null) from devito.tools import split -from devito.types import (Lock, Pointer, PThreadArray, QueueID, SharedData, Symbol, +from devito.types import (Lock, Pointer, PThreadArray, QueueID, SharedData, Temp, VolatileInt) __all__ = ['pthreadify'] @@ -194,7 +194,7 @@ def lower_async_calls(iet, track=None, sregistry=None): condition = CondNe(FieldFromComposite(sdata.symbolic_flag, sdata[d]), 1) activation = [While(condition)] else: - d = Symbol(name=sregistry.make_name(prefix=threads.index.name)) + d = Temp(name=sregistry.make_name(prefix=threads.index.name)) condition = CondNe(FieldFromComposite(sdata.symbolic_flag, sdata[d]), 1) activation = [DummyExpr(d, 0), While(condition, DummyExpr(d, (d + 1) % threads.size))] @@ -203,12 +203,9 @@ def lower_async_calls(iet, track=None, sregistry=None): activation.append( DummyExpr(FieldFromComposite(sdata.symbolic_flag, sdata[d]), 2) ) - activation = List( - header=[c.Line(), c.Comment("Activate `%s`" % threads.name)], - body=activation, - footer=c.Line() - ) - mapper[n] = activation + name = sregistry.make_name(prefix='activate') + efunc = efuncs[name] = make_callable(name, activation) + mapper[n] = Call(name, efunc.parameters) if mapper: # Inject activation diff --git a/devito/passes/iet/orchestration.py b/devito/passes/iet/orchestration.py index 024417f2e6a..39fd286f1ae 100644 --- a/devito/passes/iet/orchestration.py +++ b/devito/passes/iet/orchestration.py @@ -35,10 +35,9 @@ def __init__(self, sregistry): def _make_waitlock(self, iet, sync_ops, *args): waitloop = List( - header=c.Comment("Wait for `%s` to be copied to the host" % + header=c.Comment("Wait for `%s` to be transferred" % ",".join(s.target.name for s in sync_ops)), body=BusyWait(Or(*[CondEq(s.handle, 0) for s in sync_ops])), - footer=c.Line() ) iet = List(body=(waitloop,) + iet.body) @@ -50,12 +49,13 @@ def _make_releaselock(self, iet, sync_ops, *args): pre.append(BusyWait(Or(*[CondNe(s.handle, 2) for s in sync_ops]))) pre.extend(DummyExpr(s.handle, 0) for s in sync_ops) - iet = List( - header=c.Comment("Release lock(s) as soon as possible"), - body=pre + [iet] - ) + name = self.sregistry.make_name(prefix="release_lock") + parameters = derive_parameters(pre, ordering='canonical') + efunc = Callable(name, pre, 'void', parameters, 'static') - return iet, [] + iet = List(body=[Call(name, efunc.parameters)] + [iet]) + + return iet, [efunc] def _make_withlock(self, iet, sync_ops, layer): body, prefix = withlock(layer, iet, sync_ops, self.lang, self.sregistry) diff --git a/tests/test_gpu_common.py b/tests/test_gpu_common.py index 1e2f6249e48..2c34582634c 100644 --- a/tests/test_gpu_common.py +++ b/tests/test_gpu_common.py @@ -175,9 +175,10 @@ def test_tasking_in_isolation(self, opt): assert len(sections) == 3 assert str(sections[0].body[0].body[0].body[0].body[0]) == 'while(lock0[0] == 0);' body = sections[2].body[0].body[0] + body = op._func_table['release_lock0'].root.body assert str(body.body[0].condition) == 'Ne(lock0[0], 2)' assert str(body.body[1]) == 'lock0[0] = 0;' - body = body.body[2] + body = op._func_table['activate0'].root.body assert str(body.body[0].condition) == 'Ne(sdata0[0].flag, 1)' assert str(body.body[1]) == 'sdata0[0].time = time;' assert str(body.body[2]) == 'sdata0[0].flag = 2;' @@ -207,7 +208,7 @@ def test_tasking_fused(self): assert len(retrieve_iteration_tree(op)) == 3 locks = [i for i in FindSymbols().visit(op) if isinstance(i, Lock)] assert len(locks) == 1 # Only 1 because it's only `tmp` that needs protection - assert len(op._func_table) == 3 + assert len(op._func_table) == 5 exprs = FindNodes(Expression).visit(op._func_table['copy_to_host0'].root) b = 17 if configuration['language'] == 'openacc' else 16 # No `qid` w/ OMP assert str(exprs[b]) == 'const int deviceid = sdata->deviceid;' @@ -244,26 +245,15 @@ def test_tasking_unfused_two_locks(self): # Check generated code assert len(retrieve_iteration_tree(op)) == 3 - assert len([i for i in FindSymbols().visit(op) if isinstance(i, Lock)]) == 3 + assert len([i for i in FindSymbols().visit(op) if isinstance(i, Lock)]) == 4 sections = FindNodes(Section).visit(op) assert len(sections) == 4 assert (str(sections[1].body[0].body[0].body[0].body[0]) == 'while(lock0[0] == 0 || lock1[0] == 0);') # Wait-lock body = sections[2].body[0].body[0] - assert str(body.body[0].condition) == 'Ne(lock0[0], 2)' - assert str(body.body[1]) == 'lock0[0] = 0;' # Set-lock - body = body.body[2] - assert str(body.body[0].condition) == 'Ne(sdata0[0].flag, 1)' # Wait-thread - assert str(body.body[1]) == 'sdata0[0].time = time;' - assert str(body.body[2]) == 'sdata0[0].flag = 2;' - body = sections[3].body[0].body[0] - assert str(body.body[0].condition) == 'Ne(lock1[0], 2)' - assert str(body.body[1]) == 'lock1[0] = 0;' # Set-lock - body = body.body[2] - assert str(body.body[0].condition) == 'Ne(sdata1[0].flag, 1)' # Wait-thread - assert str(body.body[1]) == 'sdata1[0].time = time;' - assert str(body.body[2]) == 'sdata1[0].flag = 2;' - assert len(op._func_table) == 3 + assert str(body.body[0]) == 'release_lock0(lock0);' + assert str(body.body[1]) == 'activate0(time,sdata0);' + assert len(op._func_table) == 5 exprs = FindNodes(Expression).visit(op._func_table['copy_to_host0'].root) b = 18 if configuration['language'] == 'openacc' else 17 # No `qid` w/ OMP assert str(exprs[b]) == 'lock0[0] = 1;' @@ -300,15 +290,15 @@ def test_tasking_forcefuse(self): assert len(sections) == 3 assert (str(sections[1].body[0].body[0].body[0].body[0]) == 'while(lock0[0] == 0 || lock1[0] == 0);') # Wait-lock - body = sections[2].body[0].body[0] + body = op._func_table['release_lock0'].root.body assert str(body.body[0].condition) == 'Ne(lock0[0], 2) | Ne(lock1[0], 2)' assert str(body.body[1]) == 'lock0[0] = 0;' # Set-lock assert str(body.body[2]) == 'lock1[0] = 0;' # Set-lock - body = body.body[3] + body = op._func_table['activate0'].root.body assert str(body.body[0].condition) == 'Ne(sdata0[0].flag, 1)' # Wait-thread assert str(body.body[1]) == 'sdata0[0].time = time;' assert str(body.body[2]) == 'sdata0[0].flag = 2;' - assert len(op._func_table) == 3 + assert len(op._func_table) == 5 exprs = FindNodes(Expression).visit(op._func_table['copy_to_host0'].root) b = 21 if configuration['language'] == 'openacc' else 20 # No `qid` w/ OMP assert str(exprs[b]) == 'lock0[0] = 1;' @@ -369,11 +359,12 @@ def test_tasking_multi_output(self): assert len(sections) == 2 assert str(sections[0].body[0].body[0].body[0].body[0]) ==\ 'while(lock0[t2] == 0);' + body = op1._func_table['release_lock0'].root.body for i in range(3): - assert 'lock0[t' in str(sections[1].body[0].body[0].body[1 + i]) # Set-lock - assert str(sections[1].body[0].body[0].body[4].body[-1]) ==\ - 'sdata0[wi0].flag = 2;' - assert len(op1._func_table) == 3 + assert 'lock0[t' in str(body.body[1 + i]) # Set-lock + body = op1._func_table['activate0'].root.body + assert str(body.body[-1]) == 'sdata0[wi0].flag = 2;' + assert len(op1._func_table) == 5 exprs = FindNodes(Expression).visit(op1._func_table['copy_to_host0'].root) b = 21 if configuration['language'] == 'openacc' else 20 # No `qid` w/ OMP for i in range(3): @@ -427,7 +418,7 @@ def test_streaming_basic(self, opt, ntmps): op = Operator(eqn, opt=opt) # Check generated code - assert len(op._func_table) == 7 + assert len(op._func_table) == 9 assert len([i for i in FindSymbols().visit(op) if i.is_Array]) == ntmps op.apply(time_M=nt-2) @@ -436,7 +427,7 @@ def test_streaming_basic(self, opt, ntmps): assert np.all(u.data[1] == 36) @pytest.mark.parametrize('opt,ntmps', [ - (('buffering', 'streaming', 'orchestrate'), 10), + (('buffering', 'streaming', 'orchestrate'), 11), (('buffering', 'streaming', 'fuse', 'orchestrate', {'fuse-tasks': True}), 7), ]) def test_streaming_two_buffers(self, opt, ntmps): @@ -456,7 +447,7 @@ def test_streaming_two_buffers(self, opt, ntmps): op = Operator(eqn, opt=opt) # Check generated code - assert len(op._func_table) == 7 + assert len(op._func_table) == 9 assert len([i for i in FindSymbols().visit(op) if i.is_Array]) == ntmps op.apply(time_M=nt-2) @@ -592,7 +583,7 @@ def test_streaming_multi_input(self, opt, ntmps): op1 = Operator(eqn, opt=opt) # Check generated code - assert len(op1._func_table) == 7 + assert len(op1._func_table) == 9 assert len([i for i in FindSymbols().visit(op1) if i.is_Array]) == ntmps op0.apply(time_M=nt-2, dt=0.1) @@ -679,7 +670,7 @@ def test_streaming_postponed_deletion(self, opt, ntmps): op1 = Operator(eqns, opt=opt) # Check generated code - assert len(op1._func_table) == 7 + assert len(op1._func_table) == 9 assert len([i for i in FindSymbols().visit(op1) if i.is_Array]) == ntmps op0.apply(time_M=nt-1) @@ -748,7 +739,7 @@ def test_composite_buffering_tasking_multi_output(self): assert len(retrieve_iteration_tree(op1)) == 6 assert len(retrieve_iteration_tree(op2)) == 4 symbols = FindSymbols().visit(op1) - assert len([i for i in symbols if isinstance(i, Lock)]) == 3 + assert len([i for i in symbols if isinstance(i, Lock)]) == 4 threads = [i for i in symbols if isinstance(i, PThreadArray)] assert len(threads) == 3 assert threads[0].size.size == async_degree @@ -762,7 +753,7 @@ def test_composite_buffering_tasking_multi_output(self): # It is true that the usave and vsave eqns are separated in two different # loop nests, but they eventually get mapped to the same pair of efuncs, # since devito attempts to maximize code reuse - assert len(op1._func_table) == 6 + assert len(op1._func_table) == 8 # Check output op0.apply(time_M=nt-1) @@ -802,7 +793,7 @@ def test_composite_full_0(self): assert len(retrieve_iteration_tree(op0)) == 1 assert len(retrieve_iteration_tree(op1)) == 3 symbols = FindSymbols().visit(op1) - assert len([i for i in symbols if isinstance(i, Lock)]) == 3 + assert len([i for i in symbols if isinstance(i, Lock)]) == 4 threads = [i for i in symbols if isinstance(i, PThreadArray)] assert len(threads) == 3 assert all(i.size == 1 for i in threads) @@ -930,7 +921,7 @@ def test_save_multi_output(self): # The `usave` and `vsave` eqns are in separate tasks, but the tasks # are identical, so they get mapped to the same efuncs (init + copy) # There also are two extra functions to allocate and free arrays - assert len(op._func_table) == 6 + assert len(op._func_table) == 8 op.apply(time_M=nt-1) @@ -986,7 +977,7 @@ def test_save_w_nonaffine_time(self): # We just check the generated code here assert len([i for i in FindSymbols().visit(op) if isinstance(i, Lock)]) == 1 - assert len(op._func_table) == 6 + assert len(op._func_table) == 8 def test_save_w_subdims(self): nt = 10 @@ -1043,7 +1034,7 @@ def test_streaming_w_shifting(self, opt, ntmps): op = Operator(eqns, opt=opt) # Check generated code - assert len(op._func_table) == 7 + assert len(op._func_table) == 9 assert len([i for i in FindSymbols().visit(op) if i.is_Array]) == ntmps # From time_m=15 to time_M=35 with a factor=5 -- it means that, thanks @@ -1098,11 +1089,11 @@ def test_streaming_complete(self): {'fuse-tasks': True})) # Check generated code - assert len(op1._func_table) == 11 - assert len([i for i in FindSymbols().visit(op1) if i.is_Array]) == 7 - assert len(op2._func_table) == 11 - assert len([i for i in FindSymbols().visit(op2) if i.is_Array]) == 7 - assert len(op3._func_table) == 8 + assert len(op1._func_table) == 14 + assert len([i for i in FindSymbols().visit(op1) if i.is_Array]) == 8 + assert len(op2._func_table) == 14 + assert len([i for i in FindSymbols().visit(op2) if i.is_Array]) == 8 + assert len(op3._func_table) == 10 assert len([i for i in FindSymbols().visit(op3) if i.is_Array]) == 7 op0.apply(time_m=15, time_M=35, save_shift=0)