Skip to content

Commit

Permalink
Merge pull request #2263 from devitocodes/better-codegen
Browse files Browse the repository at this point in the history
compiler: Improve quality of generated code
  • Loading branch information
mloubout authored Nov 21, 2023
2 parents 21ef56c + 4dfc228 commit 7a86b36
Show file tree
Hide file tree
Showing 38 changed files with 774 additions and 574 deletions.
11 changes: 5 additions & 6 deletions FAQ.md
Original file line number Diff line number Diff line change
Expand Up @@ -84,19 +84,18 @@ int Kernel(struct dataobj *restrict f_vec, struct dataobj *restrict g_vec, const
{
float (*restrict f) __attribute__ ((aligned (64))) = (float (*)) f_vec->data;
float (*restrict g) __attribute__ ((aligned (64))) = (float (*)) g_vec->data;

/* Flush denormal numbers to zero in hardware */
_MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_ON);
_MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON);
struct timeval start_section0, end_section0;
gettimeofday(&start_section0, NULL);
/* Begin section0 */

START(section0)
for (int x = x_m; x <= x_M; x += 1)
{
g[x + 8] = (3.57142857e-3F*(f[x + 4] - f[x + 12]) + 3.80952381e-2F*(-f[x + 5] + f[x + 11]) + 2.0e-1F*(f[x + 6] - f[x + 10]) + 8.0e-1F*(-f[x + 7] + f[x + 9]))/h_x;
}
/* End section0 */
gettimeofday(&end_section0, NULL);
timers->section0 += (double)(end_section0.tv_sec-start_section0.tv_sec)+(double)(end_section0.tv_usec-start_section0.tv_usec)/1000000;
STOP(section0,timers)

return 0;
}
```
Expand Down
2 changes: 1 addition & 1 deletion devito/builtins/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def __init__(self, *functions, op=dv.mpi.MPI.SUM, dtype=None):
def __enter__(self):
i = dv.Dimension(name='mri',)
self.n = dv.Function(name='n', shape=(1,), dimensions=(i,),
grid=self.grid, dtype=self.dtype)
grid=self.grid, dtype=self.dtype, space='host')
self.n.data[0] = 0
return self

Expand Down
12 changes: 7 additions & 5 deletions devito/core/gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,16 +102,18 @@ def _normalize_gpu_fit(cls, **kwargs):
return cls.GPU_FIT

@classmethod
def _rcompile_wrapper(cls, **kwargs):
options = kwargs['options']
def _rcompile_wrapper(cls, **kwargs0):
options = kwargs0['options']

def wrapper(expressions, kwargs=kwargs, mode='default'):
def wrapper(expressions, mode='default', **kwargs1):
if mode == 'host':
kwargs = {
kwargs = {**{
'platform': 'cpu64',
'language': 'C' if options['par-disabled'] else 'openmp',
'compiler': 'custom',
}
}, **kwargs1}
else:
kwargs = {**kwargs0, **kwargs1}
return rcompile(expressions, kwargs)

return wrapper
Expand Down
52 changes: 47 additions & 5 deletions devito/ir/iet/efunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

__all__ = ['ElementalFunction', 'ElementalCall', 'make_efunc', 'make_callable',
'EntryFunction', 'AsyncCallable', 'AsyncCall', 'ThreadCallable',
'DeviceFunction', 'DeviceCall']
'DeviceFunction', 'DeviceCall', 'KernelLaunch', 'CommCallable']


# ElementalFunction machinery
Expand Down Expand Up @@ -105,7 +105,7 @@ def make_callable(name, iet, retval='void', prefix='static'):
"""
Utility function to create a Callable from an IET.
"""
parameters = derive_parameters(iet)
parameters = derive_parameters(iet, ordering='canonical')
return Callable(name, iet, retval, parameters=parameters, prefix=prefix)


Expand Down Expand Up @@ -137,8 +137,16 @@ class ThreadCallable(Callable):
A Callable executed asynchronously by a thread.
"""

def __init__(self, name, body, parameters=None, prefix='static'):
super().__init__(name, body, 'void*', parameters=parameters, prefix=prefix)
def __init__(self, name, body, parameters):
super().__init__(name, body, 'void*', parameters=parameters, prefix='static')

# Sanity checks
# By construction, the first unpack statement of a ThreadCallable must
# be the PointerCast that makes `sdata` available in the local scope
assert len(body.unpacks) > 0
v = body.unpacks[0]
assert v.is_PointerCast
self.sdata = v.function


# DeviceFunction machinery
Expand All @@ -157,7 +165,7 @@ def __init__(self, name, body, retval='void', parameters=None, prefix='__global_
class DeviceCall(Call):

"""
A call to an external function executed asynchronously on a device.
A call to a function executed asynchronously on a device.
"""

def __init__(self, name, arguments=None, **kwargs):
Expand All @@ -175,3 +183,37 @@ def __init__(self, name, arguments=None, **kwargs):
processed.append(a)

super().__init__(name, arguments=processed, **kwargs)


class KernelLaunch(DeviceCall):

    """
    A call to an asynchronous device kernel.

    The name, arguments and writes are forwarded to DeviceCall; on top of
    that, the node records the launch configuration it was built with:
    ``grid``, ``block``, ``shm`` (0 unless specified) and, optionally,
    ``stream``.
    """

    def __init__(self, name, grid, block, shm=0, stream=None,
                 arguments=None, writes=None):
        super().__init__(name, arguments=arguments, writes=writes)

        # Launch configuration, stored verbatim
        self.grid, self.block = grid, block
        self.shm, self.stream = shm, stream

    def __repr__(self):
        written = ','.join(str(w.name) for w in self.writes)
        return 'Launch[%s]<<<(%s)>>>' % (self.name, written)

    @cached_property
    def functions(self):
        # `grid` and `block` are always appended to the inherited functions;
        # a stream contributes its underlying `function` only when provided
        extra = [self.grid, self.block]
        if self.stream is not None:
            extra.append(self.stream.function)
        return super().functions + tuple(extra)


# Other relevant Callable subclasses

class CommCallable(Callable):

    """
    A Callable subclass that adds no attributes or methods of its own.

    NOTE(review): presumably a marker type for communication-related
    Callables -- confirm against the code that instantiates it.
    """

    pass
18 changes: 9 additions & 9 deletions devito/ir/iet/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
Symbol)
from devito.types.object import AbstractObject, LocalObject

__all__ = ['Node', 'Block', 'Expression', 'Callable', 'Call',
__all__ = ['Node', 'Block', 'Expression', 'Callable', 'Call', 'ExprStmt',
'Conditional', 'Iteration', 'List', 'Section', 'TimedList', 'Prodder',
'MetaCall', 'PointerCast', 'HaloSpot', 'Definition', 'ExpressionBundle',
'AugmentedExpression', 'Increment', 'Return', 'While',
Expand Down Expand Up @@ -722,7 +722,7 @@ def all_parameters(self):
@property
def functions(self):
return tuple(i.function for i in self.all_parameters
if isinstance(i.function, AbstractFunction))
if isinstance(i.function, (AbstractFunction, AbstractObject)))

@property
def defines(self):
Expand Down Expand Up @@ -850,20 +850,20 @@ def __init__(self, timer, lname, body):
self._name = lname
self._timer = timer

super().__init__(header=c.Line('START_TIMER(%s)' % lname),
super().__init__(header=c.Line('START(%s)' % lname),
body=body,
footer=c.Line('STOP_TIMER(%s,%s)' % (lname, timer.name)))
footer=c.Line('STOP(%s,%s)' % (lname, timer.name)))

@classmethod
def _start_timer_header(cls):
return ('START_TIMER(S)', ('struct timeval start_ ## S , end_ ## S ; '
'gettimeofday(&start_ ## S , NULL);'))
return ('START(S)', ('struct timeval start_ ## S , end_ ## S ; '
'gettimeofday(&start_ ## S , NULL);'))

@classmethod
def _stop_timer_header(cls):
return ('STOP_TIMER(S,T)', ('gettimeofday(&end_ ## S, NULL); T->S += (double)'
'(end_ ## S .tv_sec-start_ ## S.tv_sec)+(double)'
'(end_ ## S .tv_usec-start_ ## S .tv_usec)/1000000;'))
return ('STOP(S,T)', ('gettimeofday(&end_ ## S, NULL); T->S += (double)'
'(end_ ## S .tv_sec-start_ ## S.tv_sec)+(double)'
'(end_ ## S .tv_usec-start_ ## S .tv_usec)/1000000;'))

@property
def name(self):
Expand Down
13 changes: 12 additions & 1 deletion devito/ir/iet/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,13 @@ def filter_iterations(tree, key=lambda i: i):
return filtered


def derive_parameters(iet, drop_locals=False):
def derive_parameters(iet, drop_locals=False, ordering='default'):
"""
Derive all input parameters (function call arguments) from an IET
by collecting all symbols not defined in the tree itself.
"""
assert ordering in ('default', 'canonical')

# Extract all candidate parameters
candidates = FindSymbols().visit(iet)

Expand All @@ -122,6 +124,15 @@ def derive_parameters(iet, drop_locals=False):
if drop_locals:
parameters = [p for p in parameters if not (p.is_ArrayBasic or p.is_LocalObject)]

# NOTE: This is requested by the caller when the parameters are used to
# construct Callables whose signature only depends on the object types,
# rather than on their name
# TODO: It should maybe be done systematically... but it's gonna change a huge
# amount of tests and examples; plus, it might break compatibility for those
# using devito as a library-generator to be embedded within legacy codes
if ordering == 'canonical':
parameters = sorted(parameters, key=lambda p: str(type(p)))

return parameters


Expand Down
89 changes: 65 additions & 24 deletions devito/ir/iet/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
from devito.ir.support.space import Backward
from devito.symbolics import ListInitializer, ccode, uxreplace
from devito.tools import (GenericVisitor, as_tuple, ctypes_to_cstr, filter_ordered,
filter_sorted, flatten, is_external_ctype, c_restrict_void_p)
filter_sorted, flatten, is_external_ctype,
c_restrict_void_p, sorted_priority)
from devito.types.basic import AbstractFunction, Basic
from devito.types import (ArrayObject, CompositeObject, Dimension, Pointer,
IndexedData, DeviceMap)
Expand Down Expand Up @@ -224,7 +225,7 @@ def _gen_struct_decl(self, obj, masked=()):

def _gen_value(self, obj, level=2, masked=()):
qualifiers = [v for k, v in self._qualifiers_mapper.items()
if getattr(obj, k, False) and v not in masked]
if getattr(obj.function, k, False) and v not in masked]

if (obj._mem_stack or obj._mem_constant) and level == 2:
strtype = obj._C_typedata
Expand All @@ -233,7 +234,8 @@ def _gen_value(self, obj, level=2, masked=()):
strtype = ctypes_to_cstr(obj._C_ctype)
strshape = ''
if isinstance(obj, (AbstractFunction, IndexedData)) and level >= 1:
strtype = '%s%s' % (strtype, self._restrict_keyword)
if not obj._mem_stack:
strtype = '%s%s' % (strtype, self._restrict_keyword)
strtype = ' '.join(qualifiers + [strtype])

strname = obj._C_name
Expand Down Expand Up @@ -324,7 +326,9 @@ def _blankline_logic(self, children):
prev is ExpressionBundle and
all(i.dim.is_Stencil for i in g)):
rebuilt.extend(g)
elif prev in candidates and k in candidates:
elif (prev in candidates and k in candidates) or \
(prev is not None and k is Section) or \
(prev is Section):
rebuilt.append(BlankLine)
rebuilt.extend(g)
else:
Expand Down Expand Up @@ -375,7 +379,7 @@ def visit_PointerCast(self, o):
else:
rshape = '*'
lvalue = c.Value(i._C_typedata, '*%s' % v)
if o.alignment:
if o.alignment and f._data_alignment:
lvalue = c.AlignedAttribute(f._data_alignment, lvalue)

# rvalue
Expand Down Expand Up @@ -406,15 +410,13 @@ def visit_Dereference(self, o):
shape = ''.join("[%s]" % ccode(i) for i in a0.symbolic_shape[1:])
rvalue = '(%s (*)%s) %s[%s]' % (i._C_typedata, shape, a1.name,
a1.dim.name)
lvalue = c.AlignedAttribute(
a0._data_alignment,
c.Value(i._C_typedata, '(*restrict %s)%s' % (a0.name, shape))
)
lvalue = c.Value(i._C_typedata,
'(*restrict %s)%s' % (a0.name, shape))
else:
rvalue = '(%s *) %s[%s]' % (i._C_typedata, a1.name, a1.dim.name)
lvalue = c.AlignedAttribute(
a0._data_alignment, c.Value(i._C_typedata, '*restrict %s' % a0.name)
)
lvalue = c.Value(i._C_typedata, '*restrict %s' % a0.name)
if a0._data_alignment:
lvalue = c.AlignedAttribute(a0._data_alignment, lvalue)
else:
rvalue = '%s->%s' % (a1.name, a0._C_name)
lvalue = self._gen_value(a0, 0)
Expand All @@ -430,13 +432,7 @@ def visit_List(self, o):

def visit_Section(self, o):
body = flatten(self._visit(i) for i in o.children)
if o.is_subsection:
header = []
footer = []
else:
header = [c.Comment("Begin %s" % o.name)]
footer = [c.Comment("End %s" % o.name)]
return c.Module(header + body + footer)
return c.Module(body)

def visit_Return(self, o):
v = 'return'
Expand Down Expand Up @@ -576,6 +572,19 @@ def visit_HaloSpot(self, o):
body = flatten(self._visit(i) for i in o.children)
return c.Collection(body)

def visit_KernelLaunch(self, o):
    """
    Generate the statement ``name<<<launch_config>>>(arguments)`` for a
    KernelLaunch node.

    The launch configuration is emitted positionally, in the order
    ``grid,block[,shm[,stream]]``. Since the entries are positional, a
    stream can only be emitted if the `shm` slot before it is filled; when
    `o.shm` is None but `o.stream` is given, ``0`` is emitted for `shm` so
    that the stream does not end up in the `shm` position.
    """
    arguments = ','.join(self._args_call(o.arguments))

    launch_args = [o.grid, o.block]
    if o.stream is not None:
        # Fourth positional entry: the third one must be spelled out
        launch_args.extend([0 if o.shm is None else o.shm, o.stream])
    elif o.shm is not None:
        launch_args.append(o.shm)
    launch_config = ','.join(str(i) for i in launch_args)

    return c.Statement('%s<<<%s>>>(%s)' % (o.name, launch_config, arguments))

# Operator-handle machinery

def _operator_includes(self, o):
Expand Down Expand Up @@ -625,10 +634,10 @@ def visit_Operator(self, o, mode='all'):
# Elemental functions
esigns = []
efuncs = [blankline]
for i in o._func_table.values():
if i.local:
esigns.append(self._gen_signature(i.root))
efuncs.extend([self._visit(i.root), blankline])
items = [i.root for i in o._func_table.values() if i.local]
for i in sorted_efuncs(items):
esigns.append(self._gen_signature(i))
efuncs.extend([self._visit(i), blankline])

# Definitions
headers = [c.Define(*i) for i in o._headers] + [blankline]
Expand Down Expand Up @@ -1177,13 +1186,19 @@ def visit_Conditional(self, o):
condition = uxreplace(o.condition, self.mapper)
then_body = self._visit(o.then_body)
else_body = self._visit(o.else_body)
return o._rebuild(condition=condition, then_body=then_body, else_body=else_body)
return o._rebuild(condition=condition, then_body=then_body,
else_body=else_body)

def visit_PointerCast(self, o):
    """
    Rebuild a PointerCast, replacing its `function` and `obj` with the
    corresponding mapper entries; either one is kept as is when the mapper
    has no entry for it.
    """
    lookup = self.mapper.get
    return o._rebuild(function=lookup(o.function, o.function),
                      obj=lookup(o.obj, o.obj))

def visit_Dereference(self, o):
    """
    Rebuild a Dereference, replacing its `pointee` and `pointer` with the
    corresponding mapper entries; either one is kept as is when the mapper
    has no entry for it.
    """
    lookup = self.mapper.get
    return o._rebuild(pointee=lookup(o.pointee, o.pointee),
                      pointer=lookup(o.pointer, o.pointer))

def visit_Pragma(self, o):
    """
    Rebuild a Pragma after applying `uxreplace` with the mapper to each of
    its arguments.
    """
    replaced = []
    for arg in o.arguments:
        replaced.append(uxreplace(arg, self.mapper))
    return o._rebuild(arguments=replaced)
Expand All @@ -1200,8 +1215,21 @@ def visit_HaloSpot(self, o):
body = self._visit(o.body)
return o._rebuild(halo_scheme=halo_scheme, body=body)

def visit_While(self, o, **kwargs):
    """
    Rebuild a While node: the condition goes through `uxreplace` with the
    mapper, and the body is visited recursively. Any extra keyword
    arguments are accepted but not used.
    """
    return o._rebuild(condition=uxreplace(o.condition, self.mapper),
                      body=self._visit(o.body))

visit_ThreadedProdder = visit_Call

def visit_KernelLaunch(self, o):
    """
    Rebuild a KernelLaunch: each argument goes through `uxreplace` with the
    mapper, while `grid`, `block` and `stream` are swapped for their mapper
    entries when present (and kept as is otherwise).
    """
    subs = self.mapper
    arguments = [uxreplace(a, subs) for a in o.arguments]
    launch = {k: subs.get(v, v) for k, v in
              (('grid', o.grid), ('block', o.block), ('stream', o.stream))}
    return o._rebuild(**launch, arguments=arguments)


# Utils

Expand Down Expand Up @@ -1253,3 +1281,16 @@ def generate(self):
if self.cast:
tip = '(%s)%s' % (self.cast, tip)
yield tip


def sorted_efuncs(efuncs):
    """
    Order `efuncs` through `sorted_priority`, using a per-type priority:
    3 for DeviceFunction, 2 for ThreadCallable, 1 for both ElementalFunction
    and CommCallable.

    NOTE(review): how the priority values translate into the final ordering
    is decided by `sorted_priority` -- confirm there.
    """
    # Function-scope import, as in the original
    # NOTE(review): presumably to avoid a circular import -- confirm
    from devito.ir.iet.efunc import (CommCallable, DeviceFunction,
                                     ThreadCallable, ElementalFunction)

    ranking = dict([
        (DeviceFunction, 3),
        (ThreadCallable, 2),
        (ElementalFunction, 1),
        (CommCallable, 1),
    ])
    return sorted_priority(efuncs, ranking)
Loading

0 comments on commit 7a86b36

Please sign in to comment.