@@ -49,7 +49,8 @@ class AbstractSparseFunction(DiscreteFunction):
49
49
_sub_functions = ()
50
50
"""SubFunctions encapsulated within this AbstractSparseFunction."""
51
51
52
- __rkwargs__ = DiscreteFunction .__rkwargs__ + ('npoint_global' , 'space_order' )
52
+ __rkwargs__ = (DiscreteFunction .__rkwargs__ +
53
+ ('dimensions' , 'npoint_global' , 'space_order' ))
53
54
54
55
def __init_finalize__ (self , * args , ** kwargs ):
55
56
super ().__init_finalize__ (* args , ** kwargs )
@@ -133,14 +134,17 @@ def __subfunc_setup__(self, key, suffix, dtype=None):
133
134
shape = (self .npoint , self .grid .dim )
134
135
135
136
# Check if already a SubFunction
137
+ d = self .indices [self ._sparse_position ]
136
138
if isinstance (key , SubFunction ):
137
- # Need to rebuild so the dimensions match the parent SparseFunction
138
- indices = (self .indices [self ._sparse_position ], * key .indices [1 :])
139
- return key ._rebuild (* indices , name = name , shape = shape ,
140
- alias = self .alias , halo = None )
141
- elif key is not None and not isinstance (key , Iterable ):
142
- raise ValueError ("`%s` must be either SubFunction "
143
- "or iterable (e.g., list, np.ndarray)" % key )
139
+ if d in key .dimensions and not self .alias :
140
+ # From a reconstruction which leaves `dimensions` intact
141
+ return key
142
+ else :
143
+ # Need to rebuild so the dimensions match the parent
144
+ # SparseFunction, for example we end up here via `.subs(d, new_d)`
145
+ indices = (d , * key .indices [1 :])
146
+ return key ._rebuild (* indices , name = name , shape = shape ,
147
+ alias = self .alias , halo = None )
144
148
145
149
if key is None :
146
150
# Fallback to default behaviour
0 commit comments