Skip to content

Commit

Permalink
compiler: Tweak check_stability to ensure cleanup is performed
Browse files Browse the repository at this point in the history
  • Loading branch information
FabioLuporini committed Mar 21, 2024
1 parent 7b7b1eb commit adc0389
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 11 deletions.
17 changes: 12 additions & 5 deletions devito/ir/iet/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
__all__ = ['Node', 'MultiTraversable', 'Block', 'Expression', 'Callable',
'Call', 'ExprStmt', 'Conditional', 'Iteration', 'List', 'Section',
'TimedList', 'Prodder', 'MetaCall', 'PointerCast', 'HaloSpot',
'Definition', 'ExpressionBundle', 'AugmentedExpression',
'Definition', 'ExpressionBundle', 'AugmentedExpression', 'Break',
'Increment', 'Return', 'While', 'ListMajor', 'ParallelIteration',
'ParallelBlock', 'Dereference', 'Lambda', 'SyncSpot', 'Pragma',
'DummyExpr', 'BlankLine', 'ParallelTree', 'BusyWait', 'UsingNamespace',
Expand Down Expand Up @@ -432,8 +432,8 @@ def is_initializable(self):
"""
True if it can be an initializing assignment, False otherwise.
"""
return ((self.is_scalar and not self.is_reduction) or
(self.is_tensor and isinstance(self.expr.rhs, ListInitializer)))
return (((self.is_scalar and not self.is_reduction) or
(self.is_tensor and isinstance(self.expr.rhs, ListInitializer))))

@property
def defines(self):
Expand Down Expand Up @@ -796,17 +796,19 @@ class CallableBody(MultiTraversable):
Data deallocations for `body`.
errors : list of Nodes, optional
Error handling for `body`.
retstmt : Node, optional
The return statement for `body`.
"""

is_CallableBody = True

_traversable = ['unpacks', 'init', 'standalones', 'allocs', 'stacks',
'casts', 'bundles', 'maps', 'strides', 'objs', 'body',
'unmaps', 'unbundles', 'frees', 'errors']
'unmaps', 'unbundles', 'frees', 'errors', 'retstmt']

def __init__(self, body, init=(), standalones=(), unpacks=(), strides=(),
allocs=(), stacks=(), casts=(), bundles=(), objs=(), maps=(),
unmaps=(), unbundles=(), frees=(), errors=()):
unmaps=(), unbundles=(), frees=(), errors=(), retstmt=()):
# Sanity check
assert not isinstance(body, CallableBody), "CallableBody's cannot be nested"

Expand All @@ -826,6 +828,7 @@ def __init__(self, body, init=(), standalones=(), unpacks=(), strides=(),
self.unbundles = as_tuple(unbundles)
self.frees = as_tuple(frees)
self.errors = as_tuple(errors)
self.retstmt = as_tuple(retstmt)

def __repr__(self):
return ("<CallableBody <unpacks=%d, allocs=%d, casts=%d, maps=%d, "
Expand Down Expand Up @@ -1385,6 +1388,10 @@ def __repr__(self):
return ""


class Break(Node):
pass


class Return(Node):

def __init__(self, value=None):
Expand Down
11 changes: 10 additions & 1 deletion devito/ir/iet/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,9 @@ def visit_Section(self, o):
body = flatten(self._visit(i) for i in o.children)
return c.Module(body)

def visit_Break(self, o):
return c.Statement('break')

def visit_Return(self, o):
v = 'return'
if o.value is not None:
Expand Down Expand Up @@ -679,7 +682,13 @@ def visit_Operator(self, o, mode='all'):
# Kernel signature and body
body = flatten(self._visit(i) for i in o.children)
signature = self._gen_signature(o)
retval = [c.Line(), c.Statement("return 0")]

# Honor the `retstmt` flag if set
if o.body.retstmt:
retval = []
else:
retval = [c.Line(), c.Statement("return 0")]

kernel = c.FunctionBody(signature, c.Block(body + retval))

# Elemental functions
Expand Down
34 changes: 29 additions & 5 deletions devito/passes/iet/errors.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import cgen as c
import numpy as np
from sympy import Not
from sympy import Expr, Not, S

from devito.ir.iet import (Call, Conditional, EntryFunction, Iteration, List,
Return, FindNodes, FindSymbols, Transformer,
from devito.ir.iet import (Call, Conditional, DummyExpr, EntryFunction, Iteration,
List, Break, Return, FindNodes, FindSymbols, Transformer,
make_callable)
from devito.passes.iet.engine import iet_pass
from devito.symbolics import CondEq, DefFunction
from devito.types import Eq, Inc, Symbol
from devito.tools import dtype_to_ctype
from devito.types import Eq, Inc, LocalObject, Symbol

__all__ = ['check_stability', 'error_mapper']

Expand Down Expand Up @@ -69,18 +70,41 @@ def _check_stability(iet, wmovs=(), rcompile=None, sregistry=None):
name = sregistry.make_name(prefix='check')
check = Symbol(name=name, dtype=np.int32)

retval = Retval(name='retval')

errctl = Conditional(CondEq(n.dim % 100, 0), List(body=[
Call(efunc.name, efunc.parameters, retobj=check),
Conditional(Not(check), Return(error_mapper['Stability']))
Conditional(Not(check), List(body=[
DummyExpr(retval, error_mapper['Stability']),
Break()
]))
]))
errctl = List(header=c.Comment("Stability check"), body=[errctl])
mapper[n] = n._rebuild(nodes=n.nodes + (errctl,))

# One check is enough
break
else:
return iet, {}

iet = Transformer(mapper).visit(iet)

# We now must return a suitable error code
body = iet.body._rebuild(
body=(DummyExpr(retval, 0, init=True),) + iet.body.body,
retstmt=Return(retval)
)
iet = iet._rebuild(body=body)

return iet, {'efuncs': efuncs, 'includes': includes}


class Retval(LocalObject, Expr):

dtype = dtype_to_ctype(np.int32)
default_initvalue = S.Zero


error_mapper = {
'Stability': 100,
'KernelLaunch': 200,
Expand Down

0 comments on commit adc0389

Please sign in to comment.