diff --git a/devito/ir/iet/efunc.py b/devito/ir/iet/efunc.py index a711d4b747..a8450bc4aa 100644 --- a/devito/ir/iet/efunc.py +++ b/devito/ir/iet/efunc.py @@ -158,8 +158,10 @@ class DeviceFunction(Callable): A Callable executed asynchronously on a device. """ - def __init__(self, name, body, retval='void', parameters=None, prefix='__global__'): - super().__init__(name, body, retval, parameters=parameters, prefix=prefix) + def __init__(self, name, body, retval='void', parameters=None, + prefix='__global__', templates=None): + super().__init__(name, body, retval, parameters=parameters, prefix=prefix, + templates=templates) class DeviceCall(Call): @@ -171,7 +173,7 @@ class DeviceCall(Call): def __init__(self, name, arguments=None, **kwargs): # Explicitly convert host pointers into device pointers processed = [] - for a in arguments: + for a in as_tuple(arguments): try: f = a.function except AttributeError: @@ -192,8 +194,9 @@ class KernelLaunch(DeviceCall): """ def __init__(self, name, grid, block, shm=0, stream=None, - arguments=None, writes=None): - super().__init__(name, arguments=arguments, writes=writes) + arguments=None, writes=None, templates=None): + super().__init__(name, arguments=arguments, writes=writes, + templates=templates) # Kernel launch arguments self.grid = grid diff --git a/devito/ir/iet/nodes.py b/devito/ir/iet/nodes.py index 36a7a2f836..56d6e12475 100644 --- a/devito/ir/iet/nodes.py +++ b/devito/ir/iet/nodes.py @@ -250,9 +250,9 @@ class Call(ExprStmt, Node): Defaults to False. writes : list, optional The AbstractFunctions that will be written to by the called function. - Explicitly tagging these AbstractFunctions is useful in the case of external - calls, that is whenever the compiler would be unable to retrieve that - information by analysis of the IET graph. + Explicitly tagging these AbstractFunctions is useful in the case of + external calls, that is whenever the compiler would be unable to + retrieve that information by analysis of the IET graph. templates : list of Basic, optional The template arguments of the Call. """ diff --git a/devito/ir/iet/visitors.py b/devito/ir/iet/visitors.py index d94f2f1d78..c043d08cdb 100644 --- a/devito/ir/iet/visitors.py +++ b/devito/ir/iet/visitors.py @@ -313,13 +313,16 @@ def _args_call(self, args): ret.append(ccode(i)) return ret - def _gen_signature(self, o): + def _gen_signature(self, o, is_declaration=False): decls = self._args_decl(o.parameters) prefix = ' '.join(o.prefix + (self._gen_rettype(o.retval),)) signature = c.FunctionDeclaration(c.Value(prefix, o.name), decls) if o.templates: tparams = ', '.join([i.inline() for i in self._args_decl(o.templates)]) - signature = c.Template(tparams, signature) + if is_declaration: + signature = TemplateDecl(tparams, signature) + else: + signature = c.Template(tparams, signature) return signature def _blankline_logic(self, children): @@ -614,8 +617,10 @@ def visit_HaloSpot(self, o): return c.Collection(body) def visit_KernelLaunch(self, o): - arguments = self._args_call(o.arguments) - arguments = ','.join(arguments) + if o.templates: + templates = '<%s>' % ','.join([str(i) for i in o.templates]) + else: + templates = '' launch_args = [o.grid, o.block] if o.shm is not None: @@ -624,7 +629,11 @@ def visit_KernelLaunch(self, o): launch_args.append(o.stream) launch_config = ','.join(str(i) for i in launch_args) - return c.Statement('%s<<<%s>>>(%s)' % (o.name, launch_config, arguments)) + arguments = self._args_call(o.arguments) + arguments = ','.join(arguments) + + return c.Statement('%s%s<<<%s>>>(%s)' + % (o.name, templates, launch_config, arguments)) # Operator-handle machinery @@ -678,7 +687,7 @@ def visit_Operator(self, o, mode='all'): efuncs = [blankline] items = [i.root for i in o._func_table.values() if i.local] for i in sorted_efuncs(items): - esigns.append(self._gen_signature(i)) + esigns.append(self._gen_signature(i, is_declaration=True)) efuncs.extend([self._visit(i), blankline]) # Definitions @@ -1358,6 +1367,14 @@ def generate(self): yield tip +class TemplateDecl(c.Template): + + # Workaround to generate ';' at the end + + def generate(self): + return super().generate(with_semicolon=True) + + def sorted_efuncs(efuncs): from devito.ir.iet.efunc import (CommCallable, DeviceFunction, ThreadCallable, ElementalFunction) diff --git a/tests/test_iet.py b/tests/test_iet.py index 41bae75821..ff963d5518 100644 --- a/tests/test_iet.py +++ b/tests/test_iet.py @@ -7,10 +7,10 @@ from devito import (Eq, Grid, Function, TimeFunction, Operator, Dimension, # noqa switchconfig) -from devito.ir.iet import (Call, Callable, Conditional, DummyExpr, Iteration, List, - KernelLaunch, Lambda, ElementalFunction, CGen, FindSymbols, - filter_iterations, make_efunc, retrieve_iteration_tree, - Transformer) +from devito.ir.iet import (Call, Callable, Conditional, DeviceCall, DummyExpr, + Iteration, List, KernelLaunch, Lambda, ElementalFunction, + CGen, FindSymbols, filter_iterations, make_efunc, + retrieve_iteration_tree, Transformer) from devito.ir import SymbolRegistry from devito.passes.iet.engine import Graph from devito.passes.iet.languages.C import CDataManager @@ -428,13 +428,14 @@ def test_templates_callable(): }""" -def test_templates_call(): +@pytest.mark.parametrize('cls', [Call, DeviceCall]) +def test_templates_call(cls): grid = Grid(shape=(10, 10)) x, y = grid.dimensions u = Function(name='u', grid=grid) - foo = Call('foo', u, templates=[Class('a'), Class('b')]) + foo = cls('foo', u, templates=[Class('a'), Class('b')]) assert str(foo) == "foo(u_vec);" @@ -443,14 +444,15 @@ def test_kernel_launch(): grid = Grid(shape=(10, 10)) u = Function(name='u', grid=grid) + s = Symbol(name='s') class Dim3(LocalObject): dtype = type('dim3', (c_void_p,), {}) kl = KernelLaunch('mykernel', Dim3('mygrid'), Dim3('myblock'), - arguments=(u.indexed,)) + arguments=(u.indexed,), templates=[s]) - assert str(kl) == 'mykernel<<>>(d_u);' + assert str(kl) == 'mykernel<<>>(d_u);' def test_codegen_quality0():