Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

compiler: Improve quality of generated code #2263

Merged
merged 18 commits into from
Nov 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions FAQ.md
Original file line number Diff line number Diff line change
Expand Up @@ -84,19 +84,18 @@ int Kernel(struct dataobj *restrict f_vec, struct dataobj *restrict g_vec, const
{
float (*restrict f) __attribute__ ((aligned (64))) = (float (*)) f_vec->data;
float (*restrict g) __attribute__ ((aligned (64))) = (float (*)) g_vec->data;

/* Flush denormal numbers to zero in hardware */
_MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_ON);
_MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON);
struct timeval start_section0, end_section0;
gettimeofday(&start_section0, NULL);
/* Begin section0 */

START(section0)
for (int x = x_m; x <= x_M; x += 1)
{
g[x + 8] = (3.57142857e-3F*(f[x + 4] - f[x + 12]) + 3.80952381e-2F*(-f[x + 5] + f[x + 11]) + 2.0e-1F*(f[x + 6] - f[x + 10]) + 8.0e-1F*(-f[x + 7] + f[x + 9]))/h_x;
}
/* End section0 */
gettimeofday(&end_section0, NULL);
timers->section0 += (double)(end_section0.tv_sec-start_section0.tv_sec)+(double)(end_section0.tv_usec-start_section0.tv_usec)/1000000;
STOP(section0,timers)

return 0;
}
```
Expand Down
2 changes: 1 addition & 1 deletion devito/builtins/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def __init__(self, *functions, op=dv.mpi.MPI.SUM, dtype=None):
def __enter__(self):
i = dv.Dimension(name='mri',)
self.n = dv.Function(name='n', shape=(1,), dimensions=(i,),
grid=self.grid, dtype=self.dtype)
grid=self.grid, dtype=self.dtype, space='host')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why not local?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

because this one really needs to be on the host as it will ultimately be accessed in user-land

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's not exceptionally elegant, admittedly

self.n.data[0] = 0
return self

Expand Down
12 changes: 7 additions & 5 deletions devito/core/gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,16 +102,18 @@ def _normalize_gpu_fit(cls, **kwargs):
return cls.GPU_FIT

@classmethod
def _rcompile_wrapper(cls, **kwargs):
options = kwargs['options']
def _rcompile_wrapper(cls, **kwargs0):
options = kwargs0['options']

def wrapper(expressions, kwargs=kwargs, mode='default'):
def wrapper(expressions, mode='default', **kwargs1):
if mode == 'host':
kwargs = {
kwargs = {**{
'platform': 'cpu64',
'language': 'C' if options['par-disabled'] else 'openmp',
'compiler': 'custom',
}
}, **kwargs1}
else:
kwargs = {**kwargs0, **kwargs1}
return rcompile(expressions, kwargs)

return wrapper
Expand Down
52 changes: 47 additions & 5 deletions devito/ir/iet/efunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

__all__ = ['ElementalFunction', 'ElementalCall', 'make_efunc', 'make_callable',
'EntryFunction', 'AsyncCallable', 'AsyncCall', 'ThreadCallable',
'DeviceFunction', 'DeviceCall']
'DeviceFunction', 'DeviceCall', 'KernelLaunch', 'CommCallable']


# ElementalFunction machinery
Expand Down Expand Up @@ -105,7 +105,7 @@ def make_callable(name, iet, retval='void', prefix='static'):
"""
Utility function to create a Callable from an IET.
"""
parameters = derive_parameters(iet)
parameters = derive_parameters(iet, ordering='canonical')
return Callable(name, iet, retval, parameters=parameters, prefix=prefix)


Expand Down Expand Up @@ -137,8 +137,16 @@ class ThreadCallable(Callable):
A Callable executed asynchronously by a thread.
"""

def __init__(self, name, body, parameters=None, prefix='static'):
super().__init__(name, body, 'void*', parameters=parameters, prefix=prefix)
def __init__(self, name, body, parameters):
super().__init__(name, body, 'void*', parameters=parameters, prefix='static')

# Sanity checks
# By construction, the first unpack statement of a ThreadCallable must
# be the PointerCast that makes `sdata` available in the local scope
assert len(body.unpacks) > 0
v = body.unpacks[0]
assert v.is_PointerCast
self.sdata = v.function


# DeviceFunction machinery
Expand All @@ -157,7 +165,7 @@ def __init__(self, name, body, retval='void', parameters=None, prefix='__global_
class DeviceCall(Call):

"""
A call to an external function executed asynchronously on a device.
A call to a function executed asynchronously on a device.
"""

def __init__(self, name, arguments=None, **kwargs):
Expand All @@ -175,3 +183,37 @@ def __init__(self, name, arguments=None, **kwargs):
processed.append(a)

super().__init__(name, arguments=processed, **kwargs)


class KernelLaunch(DeviceCall):

    """
    A DeviceCall representing the launch of an asynchronous device kernel.

    In addition to what a DeviceCall carries, the launch parameters
    (`grid`, `block`, `shm`, `stream`) are stored on the node.
    """

    def __init__(self, name, grid, block, shm=0, stream=None,
                 arguments=None, writes=None):
        super().__init__(name, arguments=arguments, writes=writes)

        # Launch parameters, stored verbatim
        self.grid, self.block = grid, block
        self.shm, self.stream = shm, stream

    def __repr__(self):
        written = ','.join(str(w.name) for w in self.writes)
        return 'Launch[%s]<<<(%s)>>>' % (self.name, written)

    @cached_property
    def functions(self):
        # `grid` and `block` are always reported on top of what the parent
        # class provides; the stream contributes its `.function` only if a
        # stream was supplied
        extras = (self.grid, self.block)
        if self.stream is not None:
            extras = extras + (self.stream.function,)
        return super().functions + extras


# Other relevant Callable subclasses

class CommCallable(Callable):

    """
    A Callable subclass that adds no attributes or methods of its own; it
    only provides a distinct type.
    """
18 changes: 9 additions & 9 deletions devito/ir/iet/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
Symbol)
from devito.types.object import AbstractObject, LocalObject

__all__ = ['Node', 'Block', 'Expression', 'Callable', 'Call',
__all__ = ['Node', 'Block', 'Expression', 'Callable', 'Call', 'ExprStmt',
'Conditional', 'Iteration', 'List', 'Section', 'TimedList', 'Prodder',
'MetaCall', 'PointerCast', 'HaloSpot', 'Definition', 'ExpressionBundle',
'AugmentedExpression', 'Increment', 'Return', 'While',
Expand Down Expand Up @@ -722,7 +722,7 @@ def all_parameters(self):
@property
def functions(self):
return tuple(i.function for i in self.all_parameters
if isinstance(i.function, AbstractFunction))
if isinstance(i.function, (AbstractFunction, AbstractObject)))

@property
def defines(self):
Expand Down Expand Up @@ -850,20 +850,20 @@ def __init__(self, timer, lname, body):
self._name = lname
self._timer = timer

super().__init__(header=c.Line('START_TIMER(%s)' % lname),
super().__init__(header=c.Line('START(%s)' % lname),
body=body,
footer=c.Line('STOP_TIMER(%s,%s)' % (lname, timer.name)))
footer=c.Line('STOP(%s,%s)' % (lname, timer.name)))

@classmethod
def _start_timer_header(cls):
return ('START_TIMER(S)', ('struct timeval start_ ## S , end_ ## S ; '
'gettimeofday(&start_ ## S , NULL);'))
return ('START(S)', ('struct timeval start_ ## S , end_ ## S ; '
'gettimeofday(&start_ ## S , NULL);'))

@classmethod
def _stop_timer_header(cls):
return ('STOP_TIMER(S,T)', ('gettimeofday(&end_ ## S, NULL); T->S += (double)'
'(end_ ## S .tv_sec-start_ ## S.tv_sec)+(double)'
'(end_ ## S .tv_usec-start_ ## S .tv_usec)/1000000;'))
return ('STOP(S,T)', ('gettimeofday(&end_ ## S, NULL); T->S += (double)'
'(end_ ## S .tv_sec-start_ ## S.tv_sec)+(double)'
'(end_ ## S .tv_usec-start_ ## S .tv_usec)/1000000;'))

@property
def name(self):
Expand Down
13 changes: 12 additions & 1 deletion devito/ir/iet/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,13 @@ def filter_iterations(tree, key=lambda i: i):
return filtered


def derive_parameters(iet, drop_locals=False):
def derive_parameters(iet, drop_locals=False, ordering='default'):
"""
Derive all input parameters (function call arguments) from an IET
by collecting all symbols not defined in the tree itself.
"""
assert ordering in ('default', 'canonical')

# Extract all candidate parameters
candidates = FindSymbols().visit(iet)

Expand All @@ -122,6 +124,15 @@ def derive_parameters(iet, drop_locals=False):
if drop_locals:
parameters = [p for p in parameters if not (p.is_ArrayBasic or p.is_LocalObject)]

# NOTE: This is requested by the caller when the parameters are used to
# construct Callables whose signature only depends on the object types,
# rather than on their name
# TODO: It should maybe be done systematically... but it's going to change a huge
# number of tests and examples; plus, it might break compatibility for those
# using devito as a library-generator to be embedded within legacy codes
if ordering == 'canonical':
parameters = sorted(parameters, key=lambda p: str(type(p)))

return parameters


Expand Down
89 changes: 65 additions & 24 deletions devito/ir/iet/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
from devito.ir.support.space import Backward
from devito.symbolics import ListInitializer, ccode, uxreplace
from devito.tools import (GenericVisitor, as_tuple, ctypes_to_cstr, filter_ordered,
filter_sorted, flatten, is_external_ctype, c_restrict_void_p)
filter_sorted, flatten, is_external_ctype,
c_restrict_void_p, sorted_priority)
from devito.types.basic import AbstractFunction, Basic
from devito.types import (ArrayObject, CompositeObject, Dimension, Pointer,
IndexedData, DeviceMap)
Expand Down Expand Up @@ -224,7 +225,7 @@ def _gen_struct_decl(self, obj, masked=()):

def _gen_value(self, obj, level=2, masked=()):
qualifiers = [v for k, v in self._qualifiers_mapper.items()
if getattr(obj, k, False) and v not in masked]
if getattr(obj.function, k, False) and v not in masked]

if (obj._mem_stack or obj._mem_constant) and level == 2:
strtype = obj._C_typedata
Expand All @@ -233,7 +234,8 @@ def _gen_value(self, obj, level=2, masked=()):
strtype = ctypes_to_cstr(obj._C_ctype)
strshape = ''
if isinstance(obj, (AbstractFunction, IndexedData)) and level >= 1:
strtype = '%s%s' % (strtype, self._restrict_keyword)
if not obj._mem_stack:
strtype = '%s%s' % (strtype, self._restrict_keyword)
strtype = ' '.join(qualifiers + [strtype])

strname = obj._C_name
Expand Down Expand Up @@ -324,7 +326,9 @@ def _blankline_logic(self, children):
prev is ExpressionBundle and
all(i.dim.is_Stencil for i in g)):
rebuilt.extend(g)
elif prev in candidates and k in candidates:
elif (prev in candidates and k in candidates) or \
(prev is not None and k is Section) or \
(prev is Section):
rebuilt.append(BlankLine)
rebuilt.extend(g)
else:
Expand Down Expand Up @@ -375,7 +379,7 @@ def visit_PointerCast(self, o):
else:
rshape = '*'
lvalue = c.Value(i._C_typedata, '*%s' % v)
if o.alignment:
if o.alignment and f._data_alignment:
lvalue = c.AlignedAttribute(f._data_alignment, lvalue)

# rvalue
Expand Down Expand Up @@ -406,15 +410,13 @@ def visit_Dereference(self, o):
shape = ''.join("[%s]" % ccode(i) for i in a0.symbolic_shape[1:])
rvalue = '(%s (*)%s) %s[%s]' % (i._C_typedata, shape, a1.name,
a1.dim.name)
lvalue = c.AlignedAttribute(
a0._data_alignment,
c.Value(i._C_typedata, '(*restrict %s)%s' % (a0.name, shape))
)
lvalue = c.Value(i._C_typedata,
'(*restrict %s)%s' % (a0.name, shape))
else:
rvalue = '(%s *) %s[%s]' % (i._C_typedata, a1.name, a1.dim.name)
lvalue = c.AlignedAttribute(
a0._data_alignment, c.Value(i._C_typedata, '*restrict %s' % a0.name)
)
lvalue = c.Value(i._C_typedata, '*restrict %s' % a0.name)
if a0._data_alignment:
lvalue = c.AlignedAttribute(a0._data_alignment, lvalue)
else:
rvalue = '%s->%s' % (a1.name, a0._C_name)
lvalue = self._gen_value(a0, 0)
Expand All @@ -430,13 +432,7 @@ def visit_List(self, o):

def visit_Section(self, o):
body = flatten(self._visit(i) for i in o.children)
if o.is_subsection:
header = []
footer = []
else:
header = [c.Comment("Begin %s" % o.name)]
footer = [c.Comment("End %s" % o.name)]
return c.Module(header + body + footer)
return c.Module(body)

def visit_Return(self, o):
v = 'return'
Expand Down Expand Up @@ -576,6 +572,19 @@ def visit_HaloSpot(self, o):
body = flatten(self._visit(i) for i in o.children)
return c.Collection(body)

def visit_KernelLaunch(self, o):
    """
    Generate the `name<<<config>>>(arguments)` statement for a KernelLaunch.

    The launch configuration is emitted positionally, in the order
    `grid, block[, shm[, stream]]`.
    """
    arguments = ','.join(self._args_call(o.arguments))

    launch_args = [o.grid, o.block]
    if o.stream is not None:
        # The configuration is positional: a stream may only appear in the
        # fourth slot, so the `shm` slot must be filled even if `shm` was
        # not provided. `0` is what KernelLaunch itself defaults `shm` to
        launch_args.append(0 if o.shm is None else o.shm)
        launch_args.append(o.stream)
    elif o.shm is not None:
        launch_args.append(o.shm)
    launch_config = ','.join(str(i) for i in launch_args)

    return c.Statement('%s<<<%s>>>(%s)' % (o.name, launch_config, arguments))

# Operator-handle machinery

def _operator_includes(self, o):
Expand Down Expand Up @@ -625,10 +634,10 @@ def visit_Operator(self, o, mode='all'):
# Elemental functions
esigns = []
efuncs = [blankline]
for i in o._func_table.values():
if i.local:
esigns.append(self._gen_signature(i.root))
efuncs.extend([self._visit(i.root), blankline])
items = [i.root for i in o._func_table.values() if i.local]
for i in sorted_efuncs(items):
esigns.append(self._gen_signature(i))
efuncs.extend([self._visit(i), blankline])

# Definitions
headers = [c.Define(*i) for i in o._headers] + [blankline]
Expand Down Expand Up @@ -1177,13 +1186,19 @@ def visit_Conditional(self, o):
condition = uxreplace(o.condition, self.mapper)
then_body = self._visit(o.then_body)
else_body = self._visit(o.else_body)
return o._rebuild(condition=condition, then_body=then_body, else_body=else_body)
return o._rebuild(condition=condition, then_body=then_body,
else_body=else_body)

def visit_PointerCast(self, o):
    # Swap `function` and `obj` for their mapped counterparts, if any;
    # otherwise carry the current ones over unchanged
    lookup = self.mapper.get
    return o._rebuild(function=lookup(o.function, o.function),
                      obj=lookup(o.obj, o.obj))

def visit_Dereference(self, o):
    # Swap `pointee` and `pointer` for their mapped counterparts, if any;
    # otherwise carry the current ones over unchanged
    lookup = self.mapper.get
    return o._rebuild(pointee=lookup(o.pointee, o.pointee),
                      pointer=lookup(o.pointer, o.pointer))

def visit_Pragma(self, o):
    # Apply the substitutions to each of the pragma's arguments
    return o._rebuild(
        arguments=[uxreplace(arg, self.mapper) for arg in o.arguments]
    )
Expand All @@ -1200,8 +1215,21 @@ def visit_HaloSpot(self, o):
body = self._visit(o.body)
return o._rebuild(halo_scheme=halo_scheme, body=body)

def visit_While(self, o, **kwargs):
    # Substitute within the loop condition, then recurse into the body
    return o._rebuild(condition=uxreplace(o.condition, self.mapper),
                      body=self._visit(o.body))

visit_ThreadedProdder = visit_Call

def visit_KernelLaunch(self, o):
    # The call arguments go through `uxreplace`, whereas the launch
    # parameters are replaced only upon an exact hit in the mapper
    arguments = [uxreplace(arg, self.mapper) for arg in o.arguments]

    lookup = self.mapper.get
    launch = {
        'grid': lookup(o.grid, o.grid),
        'block': lookup(o.block, o.block),
        'stream': lookup(o.stream, o.stream),
    }

    return o._rebuild(**launch, arguments=arguments)


# Utils

Expand Down Expand Up @@ -1253,3 +1281,16 @@ def generate(self):
if self.cast:
tip = '(%s)%s' % (self.cast, tip)
yield tip


def sorted_efuncs(efuncs):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might be simpler to define that in efunc to avoid the local import

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

circular import, can't import efunc from visitors

from devito.ir.iet.efunc import (CommCallable, DeviceFunction,
ThreadCallable, ElementalFunction)

priority = {
DeviceFunction: 3,
ThreadCallable: 2,
ElementalFunction: 1,
CommCallable: 1
}
return sorted_priority(efuncs, priority)
Loading
Loading