diff --git a/devito/ir/iet/nodes.py b/devito/ir/iet/nodes.py index 56d6e124754..bd7a8b67ea5 100644 --- a/devito/ir/iet/nodes.py +++ b/devito/ir/iet/nodes.py @@ -24,7 +24,7 @@ __all__ = ['Node', 'MultiTraversable', 'Block', 'Expression', 'Callable', 'Call', 'ExprStmt', 'Conditional', 'Iteration', 'List', 'Section', 'TimedList', 'Prodder', 'MetaCall', 'PointerCast', 'HaloSpot', - 'Definition', 'ExpressionBundle', 'AugmentedExpression', + 'Definition', 'ExpressionBundle', 'AugmentedExpression', 'Break', 'Increment', 'Return', 'While', 'ListMajor', 'ParallelIteration', 'ParallelBlock', 'Dereference', 'Lambda', 'SyncSpot', 'Pragma', 'DummyExpr', 'BlankLine', 'ParallelTree', 'BusyWait', 'UsingNamespace', @@ -432,8 +432,10 @@ def is_initializable(self): """ True if it can be an initializing assignment, False otherwise. """ - return ((self.is_scalar and not self.is_reduction) or - (self.is_tensor and isinstance(self.expr.rhs, ListInitializer))) + if self.write.is_LocalObject and self.write._mem_internal_lazy: + return False + return (((self.is_scalar and not self.is_reduction) or + (self.is_tensor and isinstance(self.expr.rhs, ListInitializer)))) @property def defines(self): @@ -796,17 +798,19 @@ class CallableBody(MultiTraversable): Data deallocations for `body`. errors : list of Nodes, optional Error handling for `body`. + retstmt : Node, optional + The return statement for `body`. """ is_CallableBody = True _traversable = ['unpacks', 'init', 'standalones', 'allocs', 'stacks', 'casts', 'bundles', 'maps', 'strides', 'objs', 'body', - 'unmaps', 'unbundles', 'frees', 'errors'] + 'unmaps', 'unbundles', 'frees', 'errors', 'retstmt'] def __init__(self, body, init=(), standalones=(), unpacks=(), strides=(), allocs=(), stacks=(), casts=(), bundles=(), objs=(), maps=(), - unmaps=(), unbundles=(), frees=(), errors=()): + unmaps=(), unbundles=(), frees=(), errors=(), retstmt=()): # Sanity check assert not isinstance(body, CallableBody), "CallableBody's cannot be nested" @@ -826,6 +830,7 @@ def __init__(self, body, init=(), standalones=(), unpacks=(), strides=(), self.unbundles = as_tuple(unbundles) self.frees = as_tuple(frees) self.errors = as_tuple(errors) + self.retstmt = as_tuple(retstmt) def __repr__(self): return ("