-
Notifications
You must be signed in to change notification settings - Fork 229
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
API: Revamp sparse subfunction #2374
Changes from all commits
d1a5e32
d088c09
f07151d
ec8f0aa
c8acf3a
67ce156
e05c7e5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,4 @@ | ||
from collections import OrderedDict | ||
try: | ||
from collections.abc import Iterable | ||
except ImportError: | ||
# Before python 3.10 | ||
from collections import Iterable | ||
from itertools import product | ||
|
||
import sympy | ||
|
@@ -34,6 +29,14 @@ | |
_default_radius = {'linear': 1, 'sinc': 4} | ||
|
||
|
||
class SparseSubFunction(SubFunction): | ||
|
||
def _arg_apply(self, dataobj, **kwargs): | ||
if self.parent is not None: | ||
return self.parent._dist_subfunc_gather(dataobj, self) | ||
return super()._arg_apply(dataobj, **kwargs) | ||
|
||
|
||
class AbstractSparseFunction(DiscreteFunction): | ||
|
||
""" | ||
|
@@ -49,7 +52,8 @@ class AbstractSparseFunction(DiscreteFunction): | |
_sub_functions = () | ||
"""SubFunctions encapsulated within this AbstractSparseFunction.""" | ||
|
||
__rkwargs__ = DiscreteFunction.__rkwargs__ + ('npoint_global', 'space_order') | ||
__rkwargs__ = (DiscreteFunction.__rkwargs__ + | ||
('dimensions', 'npoint_global', 'space_order')) | ||
|
||
def __init_finalize__(self, *args, **kwargs): | ||
super().__init_finalize__(*args, **kwargs) | ||
|
@@ -61,10 +65,22 @@ def __init_finalize__(self, *args, **kwargs): | |
|
||
@classmethod | ||
def __indices_setup__(cls, *args, **kwargs): | ||
dimensions = as_tuple(kwargs.get('dimensions')) | ||
if not dimensions: | ||
dimensions = (Dimension(name='p_%s' % kwargs["name"]),) | ||
# Need this not to break MatrixSparseFunction | ||
try: | ||
_sub_funcs = tuple(cls._sub_functions) | ||
except TypeError: | ||
_sub_funcs = () | ||
# If a subfunction provided use the sparse dimension | ||
for f in _sub_funcs: | ||
try: | ||
sparse_dim = kwargs[f].indices[0] | ||
break | ||
except (KeyError, AttributeError): | ||
continue | ||
else: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Wrong indent level for this There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Huh, TIL There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes and it's quite handy |
||
sparse_dim = Dimension(name='p_%s' % kwargs["name"]) | ||
mloubout marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
dimensions = as_tuple(kwargs.get('dimensions', sparse_dim)) | ||
if args: | ||
return tuple(dimensions), tuple(args) | ||
else: | ||
|
@@ -120,6 +136,22 @@ def __subfunc_setup__(self, key, suffix, dtype=None): | |
if key is not None and not isinstance(key, SubFunction): | ||
key = np.array(key) | ||
|
||
# Check if already a SubFunction | ||
if isinstance(key, SubFunction): | ||
d = self.indices[self._sparse_position] | ||
if d in key.indices: | ||
# Can use as is, dimension already matches | ||
if self.alias: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I like to keep the else for simple blocks and unindent when it's a decent code chunck that is the "main" part |
||
return key._rebuild(alias=self.alias, name=name) | ||
else: | ||
return key | ||
else: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same |
||
# Need to rebuild so the dimensions match the parent | ||
# SparseFunction, for example we end up here via `.subs(d, new_d)` | ||
indices = (d, *key.indices[1:]) | ||
return key._rebuild(*indices, name=name, alias=self.alias) | ||
|
||
# Given an array or nothing, create dimension and SubFunction | ||
if key is not None: | ||
dimensions = (self._sparse_dim, Dimension(name='d')) | ||
if key.ndim > 2: | ||
|
@@ -132,22 +164,13 @@ def __subfunc_setup__(self, key, suffix, dtype=None): | |
dimensions = (self._sparse_dim, Dimension(name='d')) | ||
shape = (self.npoint, self.grid.dim) | ||
|
||
# Check if already a SubFunction | ||
if isinstance(key, SubFunction): | ||
# Need to rebuild so the dimensions match the parent SparseFunction | ||
indices = (self.indices[self._sparse_position], *key.indices[1:]) | ||
return key._rebuild(*indices, name=name, shape=shape, | ||
alias=self.alias, halo=None) | ||
elif key is not None and not isinstance(key, Iterable): | ||
raise ValueError("`%s` must be either SubFunction " | ||
"or iterable (e.g., list, np.ndarray)" % key) | ||
|
||
if key is None: | ||
# Fallback to default behaviour | ||
dtype = dtype or self.dtype | ||
else: | ||
if (shape != key.shape and key.shape != (shape[1],)) and \ | ||
self._distributor.nprocs == 1: | ||
if shape != key.shape and \ | ||
key.shape != (shape[1],) and \ | ||
self._distributor.nprocs == 1: | ||
raise ValueError("Incompatible shape for %s, `%s`; expected `%s`" % | ||
(suffix, key.shape[:2], shape)) | ||
|
||
|
@@ -157,7 +180,7 @@ def __subfunc_setup__(self, key, suffix, dtype=None): | |
else: | ||
dtype = dtype or self.dtype | ||
|
||
sf = SubFunction( | ||
sf = SparseSubFunction( | ||
name=name, dtype=dtype, dimensions=dimensions, | ||
shape=shape, space_order=0, initializer=key, alias=self.alias, | ||
distributor=self._distributor, parent=self | ||
|
@@ -584,12 +607,6 @@ def _dist_scatter(self, data=None): | |
mapper.update(self._dist_subfunc_scatter(getattr(self, i))) | ||
return mapper | ||
|
||
def _dist_gather(self, data, *subfunc): | ||
self._dist_data_gather(data) | ||
for (sg, s) in zip(subfunc, self._sub_functions): | ||
if getattr(self, s) is not None: | ||
self._dist_subfunc_gather(sg, getattr(self, s)) | ||
|
||
def _eval_at(self, func): | ||
return self | ||
|
||
|
@@ -637,11 +654,11 @@ def _arg_values(self, **kwargs): | |
|
||
return values | ||
|
||
def _arg_apply(self, dataobj, *subfuncs, alias=None): | ||
def _arg_apply(self, dataobj, alias=None): | ||
key = alias if alias is not None else self | ||
if isinstance(key, AbstractSparseFunction): | ||
# Gather into `self.data` | ||
key._dist_gather(dataobj, *subfuncs) | ||
key._dist_data_gather(dataobj) | ||
elif self._distributor.nprocs > 1: | ||
raise NotImplementedError("Don't know how to gather data from an " | ||
"object of type `%s`" % type(key)) | ||
|
@@ -698,7 +715,7 @@ def __indices_setup__(cls, *args, **kwargs): | |
dimensions = as_tuple(kwargs.get('dimensions')) | ||
if not dimensions: | ||
dimensions = (kwargs['grid'].time_dim, | ||
Dimension(name='p_%s' % kwargs["name"]),) | ||
*super().__indices_setup__(*args, **kwargs)[0]) | ||
|
||
if args: | ||
return tuple(dimensions), tuple(args) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nitpicking: we can probably save a line here by compacting these four lines into three, but obviously this is for another day