Skip to content

Commit

Permalink
Merge pull request #1971 from devitocodes/array-struct
Browse files Browse the repository at this point in the history
compiler: Infer type generation from _C_ctype
  • Loading branch information
FabioLuporini authored Jul 28, 2022
2 parents 1d98eb1 + 9a5e60d commit 41b303a
Show file tree
Hide file tree
Showing 18 changed files with 412 additions and 437 deletions.
17 changes: 14 additions & 3 deletions devito/ir/iet/efunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
from devito.ir.iet.utils import derive_parameters
from devito.tools import as_tuple

__all__ = ['ElementalFunction', 'ElementalCall', 'make_efunc', 'EntryFunction',
'AsyncCallable', 'AsyncCall', 'ThreadCallable', 'DeviceFunction',
'DeviceCall']
__all__ = ['ElementalFunction', 'ElementalCall', 'make_efunc', 'make_callable',
'EntryFunction', 'AsyncCallable', 'AsyncCall', 'ThreadCallable',
'DeviceFunction', 'DeviceCall']


# ElementalFunction machinery
Expand Down Expand Up @@ -89,6 +89,17 @@ def make_efunc(name, iet, dynamic_parameters=None, retval='void', prefix='static
dynamic_parameters)


# Callable machinery


def make_callable(name, iet, retval='void', prefix='static'):
"""
Utility function to create a Callable from an IET.
"""
parameters = derive_parameters(iet)
return Callable(name, iet, retval, parameters=parameters, prefix=prefix)


# EntryFunction machinery

class EntryFunction(Callable):
Expand Down
38 changes: 15 additions & 23 deletions devito/ir/iet/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@
Property, Forward, detect_io)
from devito.symbolics import ListInitializer, CallFromPointer, ccode
from devito.tools import Signer, Tag, as_tuple, filter_ordered, filter_sorted, flatten
from devito.types.basic import AbstractFunction, AbstractSymbol
from devito.types import Indexed, LocalObject, Symbol
from devito.types.basic import AbstractFunction, AbstractObject, AbstractSymbol
from devito.types import Indexed, Symbol

__all__ = ['Node', 'Block', 'Expression', 'Element', 'Callable', 'Call',
__all__ = ['Node', 'Block', 'Expression', 'Callable', 'Call',
'Conditional', 'Iteration', 'List', 'Section', 'TimedList', 'Prodder',
'MetaCall', 'PointerCast', 'HaloSpot', 'Definition', 'ExpressionBundle',
'AugmentedExpression', 'Increment', 'Return', 'While',
Expand Down Expand Up @@ -119,6 +119,9 @@ def args_frozen(self):
def __str__(self):
return str(self.ccode)

def __repr__(self):
return self.__class__.__name__

@property
def functions(self):
"""All AbstractFunction objects used by this node."""
Expand Down Expand Up @@ -207,22 +210,6 @@ def __init__(self, header=None, body=None, footer=None):
self.footer = as_tuple(footer)


class Element(Node):

"""
A generic node. Can be a comment, a statement, ... or anything that cannot
be expressed through an IET type.
"""

def __init__(self, element):
assert isinstance(element, (c.Comment, c.Statement, c.Value, c.Initializer,
c.Pragma, c.Line, c.Assign, c.POD))
self.element = element

def __repr__(self):
return "Element::\n\t%s" % (self.element)


class Call(ExprStmt, Node):

"""
Expand Down Expand Up @@ -287,7 +274,7 @@ def children(self):
def functions(self):
retval = []
for i in self.arguments:
if isinstance(i, (AbstractFunction, Indexed, IndexedBase, LocalObject)):
if isinstance(i, (AbstractFunction, Indexed, IndexedBase, AbstractObject)):
retval.append(i.function)
elif isinstance(i, Call):
retval.extend(i.functions)
Expand All @@ -299,7 +286,7 @@ def functions(self):
for s in v:
try:
# `try-except` necessary for e.g. Macro
if isinstance(s.function, (AbstractFunction, LocalObject)):
if isinstance(s.function, (AbstractFunction, AbstractObject)):
retval.append(s.function)
except AttributeError:
continue
Expand All @@ -315,7 +302,7 @@ def expr_symbols(self):
for i in self.arguments:
if isinstance(i, AbstractFunction):
continue
elif isinstance(i, (Indexed, IndexedBase, LocalObject, Symbol)):
elif isinstance(i, (Indexed, IndexedBase, AbstractObject, Symbol)):
retval.append(i)
elif isinstance(i, Call):
retval.extend(i.expr_symbols)
Expand Down Expand Up @@ -1322,12 +1309,17 @@ def __repr__(self):
return ""


class Return(Node):

def __init__(self, value=None):
self.value = value


def DummyExpr(*args, init=False):
return Expression(DummyEq(*args), init=init)


BlankLine = CBlankLine()
Return = lambda i='': Element(c.Statement('return%s' % ((' %s' % i) if i else i)))


# Nodes required for distributed-memory halo exchange
Expand Down
92 changes: 57 additions & 35 deletions devito/ir/iet/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,6 @@ def visit_Generable(self, o):
body = ' %s' % str(o) if self.verbose else ''
return self.indent + '<C.%s%s>' % (o.__class__.__name__, body)

def visit_Element(self, o):
body = ' %s' % str(o.element) if self.verbose else ''
return self.indent + '<Element%s>' % body

def visit_Callable(self, o):
self._depth += 1
body = self._visit(o.children)
Expand Down Expand Up @@ -244,10 +240,11 @@ def visit_tuple(self, o):

def visit_PointerCast(self, o):
f = o.function
i = f.indexed

if f.is_PointerArray:
# lvalue
lvalue = c.Value(f._C_typedata, '**%s' % f.name)
lvalue = c.Value(i._C_typedata, '**%s' % f.name)

# rvalue
if isinstance(o.obj, ArrayObject):
Expand All @@ -256,7 +253,7 @@ def visit_PointerCast(self, o):
v = f._C_name
else:
assert False
rvalue = '(%s**) %s' % (f._C_typedata, v)
rvalue = '(%s**) %s' % (i._C_typedata, v)

else:
# lvalue
Expand All @@ -267,48 +264,50 @@ def visit_PointerCast(self, o):
if o.flat is None:
shape = ''.join("[%s]" % ccode(i) for i in o.castshape)
rshape = '(*)%s' % shape
lvalue = c.Value(f._C_typedata, '(*restrict %s)%s' % (v, shape))
lvalue = c.Value(i._C_typedata, '(*restrict %s)%s' % (v, shape))
else:
rshape = '*'
lvalue = c.Value(f._C_typedata, '*%s' % v)
lvalue = c.Value(i._C_typedata, '*%s' % v)
if o.alignment:
lvalue = c.AlignedAttribute(f._data_alignment, lvalue)

# rvalue
if f.is_DiscreteFunction:
if f.is_DiscreteFunction or \
(f.is_Array and f._mem_mapped):
if isinstance(o.obj, IndexedData):
v = f._C_field_data
elif isinstance(o.obj, DeviceMap):
v = f._C_field_dmap
else:
assert False

rvalue = '(%s %s) %s->%s' % (f._C_typedata, rshape, f._C_name, v)
rvalue = '(%s %s) %s->%s' % (i._C_typedata, rshape, f._C_name, v)
else:
if isinstance(o.obj, Pointer):
v = o.obj.name
else:
v = f._C_name

rvalue = '(%s %s) %s' % (f._C_typedata, rshape, v)
rvalue = '(%s %s) %s' % (i._C_typedata, rshape, v)

return c.Initializer(lvalue, rvalue)

def visit_Dereference(self, o):
a0, a1 = o.functions
if a1.is_PointerArray or a1.is_TempFunction:
i = a1.indexed
if o.flat is None:
shape = ''.join("[%s]" % ccode(i) for i in a0.symbolic_shape[1:])
rvalue = '(%s (*)%s) %s[%s]' % (a1._C_typedata, shape, a1.name,
rvalue = '(%s (*)%s) %s[%s]' % (i._C_typedata, shape, a1.name,
a1.dim.name)
lvalue = c.AlignedAttribute(
a0._data_alignment,
c.Value(a0._C_typedata, '(*restrict %s)%s' % (a0.name, shape))
c.Value(i._C_typedata, '(*restrict %s)%s' % (a0.name, shape))
)
else:
rvalue = '(%s *) %s[%s]' % (a1._C_typedata, a1.name, a1.dim.name)
rvalue = '(%s *) %s[%s]' % (i._C_typedata, a1.name, a1.dim.name)
lvalue = c.AlignedAttribute(
a0._data_alignment, c.Value(a0._C_typedata, '*restrict %s' % a0.name)
a0._data_alignment, c.Value(i._C_typedata, '*restrict %s' % a0.name)
)
else:
rvalue = '%s->%s' % (a1.name, a0._C_name)
Expand All @@ -333,32 +332,37 @@ def visit_Section(self, o):
footer = [c.Comment("End %s" % o.name)]
return c.Module(header + body + footer)

def visit_Element(self, o):
return o.element
def visit_Return(self, o):
v = 'return'
if o.value is not None:
v += ' %s' % o.value
return c.Statement(v)

def visit_Definition(self, o):
f = o.function

if f.is_AbstractFunction:
if f.is_PointerArray:
v = "*"
else:
v = ""
if o.shape is None:
v = "%s*%s" % (v, f._C_name)
else:
v = "%s%s%s" % (v, f._C_name, o.shape)
if o.qualifier is not None:
v = "%s %s" % (v, o.qualifier)
elif f.is_LocalObject:
v = f._C_name
if o.shape is not None:
v = o.shape
else:
assert False
v = ""

if o.constructor_args:
v = MultilineCall(v, o.constructor_args, True)
if o.qualifier is not None:
v = "%s %s" % (v, o.qualifier)

v = c.Value(f._C_typedata, v)
v = "%s%s" % (f._C_name, v)

if f._mem_stack:
v = c.Value(f._C_typedata, v)
else:
if o.constructor_args:
v = MultilineCall(v, o.constructor_args, True)

# Aesthetics: `float * a` -> `float *a`
try:
t, stars = f._C_typename.split()
v = c.Value(t, '%s%s' % (stars, v))
except ValueError:
v = c.Value(f._C_typename, v)

if o.initvalue is not None:
v = c.Initializer(v, o.initvalue)
Expand Down Expand Up @@ -507,7 +511,8 @@ def _operator_typedecls(self, o, mode='all'):
else:
xfilter = lambda i: not isinstance(i, public_types)

typedecls = [i._C_typedecl for i in o.parameters
candidates = o.parameters + tuple(o._dspace.parts)
typedecls = [i._C_typedecl for i in candidates
if xfilter(i) and i._C_typedecl is not None]
for i in o._func_table.values():
if not i.local:
Expand Down Expand Up @@ -964,6 +969,23 @@ class Uxreplace(Transformer):
def visit_Expression(self, o):
return o._rebuild(expr=uxreplace(o.expr, self.mapper))

def visit_Definition(self, o):
try:
return o._rebuild(function=self.mapper[o.function])
except KeyError:
return o

def visit_Return(self, o):
try:
return o._rebuild(value=self.mapper[o.value])
except KeyError:
return o

def visit_Callable(self, o):
body = self._visit(o.body)
parameters = [self.mapper.get(i, i) for i in o.parameters]
return o._rebuild(body=body, parameters=parameters)

def visit_Call(self, o):
arguments = [uxreplace(i, self.mapper) for i in o.arguments]
return o._rebuild(arguments=arguments)
Expand Down
Loading

0 comments on commit 41b303a

Please sign in to comment.