Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

compiler: Infer type generation from _C_ctype #1971

Merged
merged 13 commits into from
Jul 28, 2022
Merged
17 changes: 14 additions & 3 deletions devito/ir/iet/efunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
from devito.ir.iet.utils import derive_parameters
from devito.tools import as_tuple

__all__ = ['ElementalFunction', 'ElementalCall', 'make_efunc', 'EntryFunction',
'AsyncCallable', 'AsyncCall', 'ThreadCallable', 'DeviceFunction',
'DeviceCall']
__all__ = ['ElementalFunction', 'ElementalCall', 'make_efunc', 'make_callable',
'EntryFunction', 'AsyncCallable', 'AsyncCall', 'ThreadCallable',
'DeviceFunction', 'DeviceCall']


# ElementalFunction machinery
Expand Down Expand Up @@ -89,6 +89,17 @@ def make_efunc(name, iet, dynamic_parameters=None, retval='void', prefix='static
dynamic_parameters)


# Callable machinery


def make_callable(name, iet, retval='void', prefix='static'):
"""
Utility function to create a Callable from an IET.
"""
parameters = derive_parameters(iet)
return Callable(name, iet, retval, parameters=parameters, prefix=prefix)


# EntryFunction machinery

class EntryFunction(Callable):
Expand Down
38 changes: 15 additions & 23 deletions devito/ir/iet/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@
Property, Forward, detect_io)
from devito.symbolics import ListInitializer, CallFromPointer, ccode
from devito.tools import Signer, Tag, as_tuple, filter_ordered, filter_sorted, flatten
from devito.types.basic import AbstractFunction, AbstractSymbol
from devito.types import Indexed, LocalObject, Symbol
from devito.types.basic import AbstractFunction, AbstractObject, AbstractSymbol
from devito.types import Indexed, Symbol

__all__ = ['Node', 'Block', 'Expression', 'Element', 'Callable', 'Call',
__all__ = ['Node', 'Block', 'Expression', 'Callable', 'Call',
'Conditional', 'Iteration', 'List', 'Section', 'TimedList', 'Prodder',
'MetaCall', 'PointerCast', 'HaloSpot', 'Definition', 'ExpressionBundle',
'AugmentedExpression', 'Increment', 'Return', 'While',
Expand Down Expand Up @@ -119,6 +119,9 @@ def args_frozen(self):
def __str__(self):
return str(self.ccode)

def __repr__(self):
return self.__class__.__name__

@property
def functions(self):
"""All AbstractFunction objects used by this node."""
Expand Down Expand Up @@ -207,22 +210,6 @@ def __init__(self, header=None, body=None, footer=None):
self.footer = as_tuple(footer)


class Element(Node):

"""
A generic node. Can be a comment, a statement, ... or anything that cannot
be expressed through an IET type.
"""

def __init__(self, element):
assert isinstance(element, (c.Comment, c.Statement, c.Value, c.Initializer,
c.Pragma, c.Line, c.Assign, c.POD))
self.element = element

def __repr__(self):
return "Element::\n\t%s" % (self.element)


class Call(ExprStmt, Node):

"""
Expand Down Expand Up @@ -287,7 +274,7 @@ def children(self):
def functions(self):
retval = []
for i in self.arguments:
if isinstance(i, (AbstractFunction, Indexed, IndexedBase, LocalObject)):
if isinstance(i, (AbstractFunction, Indexed, IndexedBase, AbstractObject)):
retval.append(i.function)
elif isinstance(i, Call):
retval.extend(i.functions)
Expand All @@ -299,7 +286,7 @@ def functions(self):
for s in v:
try:
# `try-except` necessary for e.g. Macro
if isinstance(s.function, (AbstractFunction, LocalObject)):
if isinstance(s.function, (AbstractFunction, AbstractObject)):
retval.append(s.function)
except AttributeError:
continue
Expand All @@ -315,7 +302,7 @@ def expr_symbols(self):
for i in self.arguments:
if isinstance(i, AbstractFunction):
continue
elif isinstance(i, (Indexed, IndexedBase, LocalObject, Symbol)):
elif isinstance(i, (Indexed, IndexedBase, AbstractObject, Symbol)):
retval.append(i)
elif isinstance(i, Call):
retval.extend(i.expr_symbols)
Expand Down Expand Up @@ -1322,12 +1309,17 @@ def __repr__(self):
return ""


class Return(Node):

def __init__(self, value=None):
self.value = value


def DummyExpr(*args, init=False):
return Expression(DummyEq(*args), init=init)


BlankLine = CBlankLine()
Return = lambda i='': Element(c.Statement('return%s' % ((' %s' % i) if i else i)))


# Nodes required for distributed-memory halo exchange
Expand Down
92 changes: 57 additions & 35 deletions devito/ir/iet/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,6 @@ def visit_Generable(self, o):
body = ' %s' % str(o) if self.verbose else ''
return self.indent + '<C.%s%s>' % (o.__class__.__name__, body)

def visit_Element(self, o):
body = ' %s' % str(o.element) if self.verbose else ''
return self.indent + '<Element%s>' % body

def visit_Callable(self, o):
self._depth += 1
body = self._visit(o.children)
Expand Down Expand Up @@ -244,10 +240,11 @@ def visit_tuple(self, o):

def visit_PointerCast(self, o):
f = o.function
i = f.indexed

if f.is_PointerArray:
# lvalue
lvalue = c.Value(f._C_typedata, '**%s' % f.name)
lvalue = c.Value(i._C_typedata, '**%s' % f.name)

# rvalue
if isinstance(o.obj, ArrayObject):
Expand All @@ -256,7 +253,7 @@ def visit_PointerCast(self, o):
v = f._C_name
else:
assert False
rvalue = '(%s**) %s' % (f._C_typedata, v)
rvalue = '(%s**) %s' % (i._C_typedata, v)

else:
# lvalue
Expand All @@ -267,48 +264,50 @@ def visit_PointerCast(self, o):
if o.flat is None:
shape = ''.join("[%s]" % ccode(i) for i in o.castshape)
rshape = '(*)%s' % shape
lvalue = c.Value(f._C_typedata, '(*restrict %s)%s' % (v, shape))
lvalue = c.Value(i._C_typedata, '(*restrict %s)%s' % (v, shape))
else:
rshape = '*'
lvalue = c.Value(f._C_typedata, '*%s' % v)
lvalue = c.Value(i._C_typedata, '*%s' % v)
if o.alignment:
lvalue = c.AlignedAttribute(f._data_alignment, lvalue)

# rvalue
if f.is_DiscreteFunction:
if f.is_DiscreteFunction or \
(f.is_Array and f._mem_mapped):
if isinstance(o.obj, IndexedData):
v = f._C_field_data
elif isinstance(o.obj, DeviceMap):
v = f._C_field_dmap
else:
assert False

rvalue = '(%s %s) %s->%s' % (f._C_typedata, rshape, f._C_name, v)
rvalue = '(%s %s) %s->%s' % (i._C_typedata, rshape, f._C_name, v)
else:
if isinstance(o.obj, Pointer):
v = o.obj.name
else:
v = f._C_name

rvalue = '(%s %s) %s' % (f._C_typedata, rshape, v)
rvalue = '(%s %s) %s' % (i._C_typedata, rshape, v)

return c.Initializer(lvalue, rvalue)

def visit_Dereference(self, o):
a0, a1 = o.functions
if a1.is_PointerArray or a1.is_TempFunction:
i = a1.indexed
if o.flat is None:
shape = ''.join("[%s]" % ccode(i) for i in a0.symbolic_shape[1:])
rvalue = '(%s (*)%s) %s[%s]' % (a1._C_typedata, shape, a1.name,
rvalue = '(%s (*)%s) %s[%s]' % (i._C_typedata, shape, a1.name,
a1.dim.name)
lvalue = c.AlignedAttribute(
a0._data_alignment,
c.Value(a0._C_typedata, '(*restrict %s)%s' % (a0.name, shape))
c.Value(i._C_typedata, '(*restrict %s)%s' % (a0.name, shape))
)
else:
rvalue = '(%s *) %s[%s]' % (a1._C_typedata, a1.name, a1.dim.name)
rvalue = '(%s *) %s[%s]' % (i._C_typedata, a1.name, a1.dim.name)
lvalue = c.AlignedAttribute(
a0._data_alignment, c.Value(a0._C_typedata, '*restrict %s' % a0.name)
a0._data_alignment, c.Value(i._C_typedata, '*restrict %s' % a0.name)
)
else:
rvalue = '%s->%s' % (a1.name, a0._C_name)
Expand All @@ -333,32 +332,37 @@ def visit_Section(self, o):
footer = [c.Comment("End %s" % o.name)]
return c.Module(header + body + footer)

def visit_Element(self, o):
return o.element
def visit_Return(self, o):
v = 'return'
if o.value is not None:
v += ' %s' % o.value
return c.Statement(v)

def visit_Definition(self, o):
f = o.function

if f.is_AbstractFunction:
if f.is_PointerArray:
v = "*"
else:
v = ""
if o.shape is None:
v = "%s*%s" % (v, f._C_name)
else:
v = "%s%s%s" % (v, f._C_name, o.shape)
if o.qualifier is not None:
v = "%s %s" % (v, o.qualifier)
elif f.is_LocalObject:
v = f._C_name
if o.shape is not None:
v = o.shape
else:
assert False
v = ""

if o.constructor_args:
v = MultilineCall(v, o.constructor_args, True)
if o.qualifier is not None:
v = "%s %s" % (v, o.qualifier)

v = c.Value(f._C_typedata, v)
v = "%s%s" % (f._C_name, v)

if f._mem_stack:
v = c.Value(f._C_typedata, v)
else:
if o.constructor_args:
v = MultilineCall(v, o.constructor_args, True)

# Aesthetics: `float * a` -> `float *a`
try:
t, stars = f._C_typename.split()
v = c.Value(t, '%s%s' % (stars, v))
except ValueError:
v = c.Value(f._C_typename, v)

if o.initvalue is not None:
v = c.Initializer(v, o.initvalue)
Expand Down Expand Up @@ -507,7 +511,8 @@ def _operator_typedecls(self, o, mode='all'):
else:
xfilter = lambda i: not isinstance(i, public_types)

typedecls = [i._C_typedecl for i in o.parameters
candidates = o.parameters + tuple(o._dspace.parts)
typedecls = [i._C_typedecl for i in candidates
if xfilter(i) and i._C_typedecl is not None]
for i in o._func_table.values():
if not i.local:
Expand Down Expand Up @@ -964,6 +969,23 @@ class Uxreplace(Transformer):
def visit_Expression(self, o):
return o._rebuild(expr=uxreplace(o.expr, self.mapper))

def visit_Definition(self, o):
try:
return o._rebuild(function=self.mapper[o.function])
except KeyError:
return o

def visit_Return(self, o):
try:
return o._rebuild(value=self.mapper[o.value])
except KeyError:
return o

def visit_Callable(self, o):
body = self._visit(o.body)
parameters = [self.mapper.get(i, i) for i in o.parameters]
return o._rebuild(body=body, parameters=parameters)

def visit_Call(self, o):
arguments = [uxreplace(i, self.mapper) for i in o.arguments]
return o._rebuild(arguments=arguments)
Expand Down
Loading