Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MPI: Fix sparse subfunction handling when used without parent #2278

Merged
merged 4 commits into from
Dec 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion devito/builtins/initializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ class ObjectiveDomain(dv.SubDomain):
name = 'objective_domain'

def __init__(self, lw):
super(ObjectiveDomain, self).__init__()
super().__init__()
self.lw = lw

def define(self, dimensions):
Expand Down
2 changes: 1 addition & 1 deletion devito/data/allocators.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ def initialize(cls):
cls.lib = lib

def __init__(self, node):
super(NumaAllocator, self).__init__()
super().__init__()
self._node = node

def _alloc_C_libcall(self, size, ctype):
Expand Down
14 changes: 7 additions & 7 deletions devito/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ def __getitem__(self, glb_idx, comm_type, gather_rank=None):
# Retrieve the pertinent local data prior to MPI send/receive operations
data_idx = loc_data_idx(loc_idx)
self._index_stash = flip_idx(glb_idx, self._decomposition)
local_val = super(Data, self).__getitem__(data_idx)
local_val = super().__getitem__(data_idx)
self._index_stash = None

comm = self._distributor.comm
Expand Down Expand Up @@ -314,7 +314,7 @@ def __getitem__(self, glb_idx, comm_type, gather_rank=None):
return None
else:
self._index_stash = glb_idx
retval = super(Data, self).__getitem__(loc_idx)
retval = super().__getitem__(loc_idx)
self._index_stash = None
return retval

Expand All @@ -328,9 +328,9 @@ def __setitem__(self, glb_idx, val, comm_type):
if index_is_basic(loc_idx):
# Won't go through `__getitem__` as it's basic indexing mode,
# so we should just propage `loc_idx`
super(Data, self).__setitem__(loc_idx, val)
super().__setitem__(loc_idx, val)
else:
super(Data, self).__setitem__(glb_idx, val)
super().__setitem__(glb_idx, val)
elif isinstance(val, Data) and val._is_distributed:
if comm_type is index_by_index:
glb_idx, val = self._process_args(glb_idx, val)
Expand All @@ -353,7 +353,7 @@ def __setitem__(self, glb_idx, val, comm_type):
self.__setitem__(idx_global[j], data_global[j])
elif self._is_distributed:
# `val` is decomposed, `self` is decomposed -> local set
super(Data, self).__setitem__(glb_idx, val)
super().__setitem__(glb_idx, val)
else:
# `val` is decomposed, `self` is replicated -> gatherall-like
raise NotImplementedError
Expand Down Expand Up @@ -389,13 +389,13 @@ def __setitem__(self, glb_idx, val, comm_type):
else:
# `val` is replicated`, `self` is replicated -> plain ndarray.__setitem__
pass
super(Data, self).__setitem__(glb_idx, val)
super().__setitem__(glb_idx, val)
elif isinstance(val, Iterable):
if self._is_mpi_distributed:
raise NotImplementedError("With MPI, data can only be set "
"via scalars, numpy arrays or "
"other data ")
super(Data, self).__setitem__(glb_idx, val)
super().__setitem__(glb_idx, val)
else:
raise ValueError("Cannot insert obj of type `%s` into a Data" % type(val))

Expand Down
2 changes: 1 addition & 1 deletion devito/data/decomposition.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def __new__(cls, items, local):
raise TypeError("Illegal Decomposition element type")
if not is_integer(local) and (0 <= local < len(items)):
raise ValueError("`local` must be an index in ``items``.")
obj = super(Decomposition, cls).__new__(cls, [np.array(i) for i in items])
obj = super().__new__(cls, [np.array(i) for i in items])
obj._local = local
return obj

Expand Down
2 changes: 1 addition & 1 deletion devito/data/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def __str__(self):
class DataSide(Tag):

def __init__(self, name, val, flipto=None):
super(DataSide, self).__init__(name, val)
super().__init__(name, val)
self.flipto = flipto
if flipto is not None:
flipto.flipto = self
Expand Down
8 changes: 4 additions & 4 deletions devito/finite_differences/differentiable.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def _fd_priority(self):
return .75 if self.is_TimeDependent else .5

def __hash__(self):
return super(Differentiable, self).__hash__()
return super().__hash__()

def __getattr__(self, name):
"""
Expand Down Expand Up @@ -245,7 +245,7 @@ def __neg__(self):
return Mul(sympy.S.NegativeOne, self)

def __eq__(self, other):
ret = super(Differentiable, self).__eq__(other)
ret = super().__eq__(other)
if ret is NotImplemented or not ret:
# Non comparable or not equal as sympy objects
return False
Expand Down Expand Up @@ -734,7 +734,7 @@ class IndexDerivative(IndexSum):
__rargs__ = ('expr', 'mapper')

def __new__(cls, expr, mapper, **kwargs):
dimensions = as_tuple(mapper.values())
dimensions = as_tuple(set(mapper.values()))

# Detect the Weights among the arguments
weightss = []
Expand Down Expand Up @@ -799,7 +799,7 @@ def _evaluate(self, **kwargs):
mapper = {w.subs(d, i): f.weights[n] for n, i in enumerate(d.range)}
expr = expr.xreplace(mapper)

return expr
return EvalDerivative(expr, base=self.base)


# SymPy args ordering is the same for Derivatives and IndexDerivatives
Expand Down
2 changes: 1 addition & 1 deletion devito/ir/clusters/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,7 @@ class ClusterGroup(tuple):
"""

def __new__(cls, clusters, ispace=None):
obj = super(ClusterGroup, cls).__new__(cls, flatten(as_tuple(clusters)))
obj = super().__new__(cls, flatten(as_tuple(clusters)))
obj._ispace = ispace
return obj

Expand Down
2 changes: 1 addition & 1 deletion devito/ir/clusters/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def __init__(self):
self.scopes = {}

def __init__(self, state=None):
super(QueueStateful, self).__init__()
super().__init__()
self.state = state or QueueStateful.State()

def _fetch_scope(self, clusters):
Expand Down
2 changes: 1 addition & 1 deletion devito/ir/equations/equation.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def __new__(cls, *args, **kwargs):
rhs = diff2sympy(expr.rhs)

# Finally create the LoweredEq with all metadata attached
expr = super(LoweredEq, cls).__new__(cls, expr.lhs, rhs, evaluate=False)
expr = super().__new__(cls, expr.lhs, rhs, evaluate=False)

expr._ispace = ispace
expr._conditionals = conditionals
Expand Down
6 changes: 3 additions & 3 deletions devito/ir/iet/efunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,12 @@ def __init__(self, name, arguments=None, mapper=None, dynamic_args_mapper=None,
for i, j in zip(self._mapper[k], tv):
arguments[i] = j if incr is False else (arguments[i] + j)

super(ElementalCall, self).__init__(name, arguments, retobj, is_indirect)
super().__init__(name, arguments, retobj, is_indirect)

def _rebuild(self, *args, dynamic_args_mapper=None, incr=False,
retobj=None, **kwargs):
# This guarantees that `ec._rebuild(arguments=ec.arguments) == ec`
return super(ElementalCall, self)._rebuild(
return super()._rebuild(
*args, dynamic_args_mapper=dynamic_args_mapper, incr=incr,
retobj=retobj, **kwargs
)
Expand All @@ -63,7 +63,7 @@ class ElementalFunction(Callable):

def __init__(self, name, body, retval='void', parameters=None, prefix=('static',),
dynamic_parameters=None):
super(ElementalFunction, self).__init__(name, body, retval, parameters, prefix)
super().__init__(name, body, retval, parameters, prefix)

self._mapper = {}
for i in as_tuple(dynamic_parameters):
Expand Down
8 changes: 4 additions & 4 deletions devito/ir/iet/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class Node(Signer):
"""

def __new__(cls, *args, **kwargs):
obj = super(Node, cls).__new__(cls)
obj = super().__new__(cls)
argnames, _, _, defaultvalues, _, _, _ = inspect.getfullargspec(cls.__init__)
try:
defaults = dict(zip(argnames[-len(defaultvalues):], defaultvalues))
Expand Down Expand Up @@ -1064,7 +1064,7 @@ class Section(List):
is_Section = True

def __init__(self, name, body=None, is_subsection=False):
super(Section, self).__init__(body=body)
super().__init__(body=body)
self.name = name
self.is_subsection = is_subsection

Expand All @@ -1085,7 +1085,7 @@ class ExpressionBundle(List):
is_ExpressionBundle = True

def __init__(self, ispace, ops, traffic, body=None):
super(ExpressionBundle, self).__init__(body=body)
super().__init__(body=body)
self.ispace = ispace
self.ops = ops
self.traffic = traffic
Expand Down Expand Up @@ -1332,7 +1332,7 @@ class HaloSpot(Node):
_traversable = ['body']

def __init__(self, body, halo_scheme):
super(HaloSpot, self).__init__()
super().__init__()

if isinstance(body, Node):
self._body = body
Expand Down
4 changes: 2 additions & 2 deletions devito/ir/iet/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@ def dimensions(self):
return [i.dim for i in self]

def __repr__(self):
return "IterationTree%s" % super(IterationTree, self).__repr__()
return "IterationTree%s" % super().__repr__()

def __getitem__(self, key):
ret = super(IterationTree, self).__getitem__(key)
ret = super().__getitem__(key)
return IterationTree(ret) if isinstance(key, slice) else ret


Expand Down
10 changes: 5 additions & 5 deletions devito/ir/iet/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ class PrintAST(Visitor):
"""

def __init__(self, verbose=True):
super(PrintAST, self).__init__()
super().__init__()
self.verbose = verbose

@classmethod
Expand Down Expand Up @@ -802,7 +802,7 @@ def default_retval(cls):
"""

def __init__(self, parent_type=None, child_types=None, mode=None):
super(MapNodes, self).__init__()
super().__init__()
if parent_type is None:
self.parent_type = Iteration
elif parent_type == 'any':
Expand Down Expand Up @@ -958,7 +958,7 @@ def default_retval(cls):
}

def __init__(self, match, mode='type'):
super(FindNodes, self).__init__()
super().__init__()
self.match = match
self.rule = self.rules[mode]

Expand Down Expand Up @@ -1038,7 +1038,7 @@ class IsPerfectIteration(Visitor):
"""

def __init__(self, depth=None):
super(IsPerfectIteration, self).__init__()
super().__init__()

assert depth is None or isinstance(depth, Iteration)
self.depth = depth
Expand Down Expand Up @@ -1091,7 +1091,7 @@ class Transformer(Visitor):
"""

def __init__(self, mapper, nested=False):
super(Transformer, self).__init__()
super().__init__()
self.mapper = mapper
self.nested = nested

Expand Down
8 changes: 4 additions & 4 deletions devito/ir/stree/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class NodeIteration(ScheduleTree):
is_Iteration = True

def __init__(self, ispace, parent=None, properties=None):
super(NodeIteration, self).__init__(parent)
super().__init__(parent)
self.ispace = ispace
self.properties = properties

Expand Down Expand Up @@ -78,7 +78,7 @@ class NodeConditional(ScheduleTree):
is_Conditional = True

def __init__(self, guard, parent=None):
super(NodeConditional, self).__init__(parent)
super().__init__(parent)
self.guard = guard

@property
Expand All @@ -91,7 +91,7 @@ class NodeSync(ScheduleTree):
is_Sync = True

def __init__(self, sync_ops, parent=None):
super(NodeSync, self).__init__(parent)
super().__init__(parent)
self.sync_ops = sync_ops

@property
Expand All @@ -104,7 +104,7 @@ class NodeExprs(ScheduleTree):
is_Exprs = True

def __init__(self, exprs, ispace, dspace, ops, traffic, parent=None):
super(NodeExprs, self).__init__(parent)
super().__init__(parent)
self.exprs = exprs
self.ispace = ispace
self.dspace = dspace
Expand Down
4 changes: 2 additions & 2 deletions devito/ir/support/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -799,11 +799,11 @@ def inplace(self, dim=None):

def __add__(self, other):
assert isinstance(other, DependenceGroup)
return DependenceGroup(super(DependenceGroup, self).__or__(other))
return DependenceGroup(super().__or__(other))

def __sub__(self, other):
assert isinstance(other, DependenceGroup)
return DependenceGroup(super(DependenceGroup, self).__sub__(other))
return DependenceGroup(super().__sub__(other))

def project(self, function):
"""
Expand Down
2 changes: 1 addition & 1 deletion devito/ir/support/properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ class Property(Tag):
_KNOWN = []

def __init__(self, name, val=None):
super(Property, self).__init__(name, val)
super().__init__(name, val)
Property._KNOWN.append(self)


Expand Down
18 changes: 9 additions & 9 deletions devito/ir/support/space.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ class Interval(AbstractInterval):
is_Defined = True

def __init__(self, dim, lower=0, upper=0, stamp=S0):
super(Interval, self).__init__(dim, stamp)
super().__init__(dim, stamp)

try:
self.lower = int(lower)
Expand All @@ -147,7 +147,7 @@ def __eq__(self, o):
if self is o:
return True

return (super(Interval, self).__eq__(o) and
return (super().__eq__(o) and
self.lower == o.lower and
self.upper == o.upper)

Expand Down Expand Up @@ -526,16 +526,16 @@ def expand(self, d=None):

def index(self, key):
if isinstance(key, Interval):
return super(IntervalGroup, self).index(key)
return super().index(key)
elif isinstance(key, Dimension):
return super(IntervalGroup, self).index(self[key])
return super().index(self[key])
raise ValueError("Expected Interval or Dimension, got `%s`" % type(key))

def __getitem__(self, key):
if is_integer(key):
return super(IntervalGroup, self).__getitem__(key)
return super().__getitem__(key)
elif isinstance(key, slice):
retval = super(IntervalGroup, self).__getitem__(key)
retval = super().__getitem__(key)
return IntervalGroup(retval, relations=self.relations, mode=self.mode)

if not self.is_well_defined:
Expand Down Expand Up @@ -699,7 +699,7 @@ def __eq__(self, other):
self.parts == other.parts)

def __hash__(self):
return hash((super(DataSpace, self).__hash__(), self.parts))
return hash((super().__hash__(), self.parts))

@classmethod
def union(cls, *others):
Expand Down Expand Up @@ -753,7 +753,7 @@ class IterationSpace(Space):
"""

def __init__(self, intervals, sub_iterators=None, directions=None):
super(IterationSpace, self).__init__(intervals)
super().__init__(intervals)

# Normalize sub-iterators
sub_iterators = dict([(k, tuple(filter_ordered(as_tuple(v))))
Expand Down Expand Up @@ -788,7 +788,7 @@ def __lt__(self, other):
return len(self.itintervals) < len(other.itintervals)

def __hash__(self):
return hash((super(IterationSpace, self).__hash__(), self.sub_iterators,
return hash((super().__hash__(), self.sub_iterators,
self.directions))

def __contains__(self, d):
Expand Down
Loading
Loading