diff --git a/devito/ir/iet/efunc.py b/devito/ir/iet/efunc.py index a711d4b7479..e340a06d769 100644 --- a/devito/ir/iet/efunc.py +++ b/devito/ir/iet/efunc.py @@ -158,8 +158,10 @@ class DeviceFunction(Callable): A Callable executed asynchronously on a device. """ - def __init__(self, name, body, retval='void', parameters=None, prefix='__global__'): - super().__init__(name, body, retval, parameters=parameters, prefix=prefix) + def __init__(self, name, body, retval='void', parameters=None, + prefix='__global__', templates=None): + super().__init__(name, body, retval, parameters=parameters, prefix=prefix, + templates=templates) class DeviceCall(Call): @@ -192,8 +194,9 @@ class KernelLaunch(DeviceCall): """ def __init__(self, name, grid, block, shm=0, stream=None, - arguments=None, writes=None): - super().__init__(name, arguments=arguments, writes=writes) + arguments=None, writes=None, templates=None): + super().__init__(name, arguments=arguments, writes=writes, + templates=templates) # Kernel launch arguments self.grid = grid diff --git a/devito/ir/iet/nodes.py b/devito/ir/iet/nodes.py index 36a7a2f8361..56d6e124754 100644 --- a/devito/ir/iet/nodes.py +++ b/devito/ir/iet/nodes.py @@ -250,9 +250,9 @@ class Call(ExprStmt, Node): Defaults to False. writes : list, optional The AbstractFunctions that will be written to by the called function. - Explicitly tagging these AbstractFunctions is useful in the case of external - calls, that is whenever the compiler would be unable to retrieve that - information by analysis of the IET graph. + Explicitly tagging these AbstractFunctions is useful in the case of + external calls, that is whenever the compiler would be unable to + retrieve that information by analysis of the IET graph. templates : list of Basic, optional The template arguments of the Call. """ diff --git a/devito/ir/iet/visitors.py b/devito/ir/iet/visitors.py index d94f2f1d78e..c043d08cdb8 100644 --- a/devito/ir/iet/visitors.py +++ b/devito/ir/iet/visitors.py @@ -313,13 +313,16 @@ def _args_call(self, args): ret.append(ccode(i)) return ret - def _gen_signature(self, o): + def _gen_signature(self, o, is_declaration=False): decls = self._args_decl(o.parameters) prefix = ' '.join(o.prefix + (self._gen_rettype(o.retval),)) signature = c.FunctionDeclaration(c.Value(prefix, o.name), decls) if o.templates: tparams = ', '.join([i.inline() for i in self._args_decl(o.templates)]) - signature = c.Template(tparams, signature) + if is_declaration: + signature = TemplateDecl(tparams, signature) + else: + signature = c.Template(tparams, signature) return signature def _blankline_logic(self, children): @@ -614,8 +617,10 @@ def visit_HaloSpot(self, o): return c.Collection(body) def visit_KernelLaunch(self, o): - arguments = self._args_call(o.arguments) - arguments = ','.join(arguments) + if o.templates: + templates = '<%s>' % ','.join([str(i) for i in o.templates]) + else: + templates = '' launch_args = [o.grid, o.block] if o.shm is not None: @@ -624,7 +629,11 @@ def visit_KernelLaunch(self, o): launch_args.append(o.stream) launch_config = ','.join(str(i) for i in launch_args) - return c.Statement('%s<<<%s>>>(%s)' % (o.name, launch_config, arguments)) + arguments = self._args_call(o.arguments) + arguments = ','.join(arguments) + + return c.Statement('%s%s<<<%s>>>(%s)' + % (o.name, templates, launch_config, arguments)) # Operator-handle machinery @@ -678,7 +687,7 @@ def visit_Operator(self, o, mode='all'): efuncs = [blankline] items = [i.root for i in o._func_table.values() if i.local] for i in sorted_efuncs(items): - esigns.append(self._gen_signature(i)) + esigns.append(self._gen_signature(i, is_declaration=True)) efuncs.extend([self._visit(i), blankline]) # Definitions @@ -1358,6 +1367,14 @@ def generate(self): yield tip +class TemplateDecl(c.Template): + + # Workaround to generate ';' at the end + + def generate(self): + return super().generate(with_semicolon=True) + + def sorted_efuncs(efuncs): from devito.ir.iet.efunc import (CommCallable, DeviceFunction, ThreadCallable, ElementalFunction)