Skip to content

Commit 23ff475

Browse files
committed
compiler: Tweak SparseFunction reconstruction
1 parent 35d0317 commit 23ff475

File tree

3 files changed

+29
-15
lines changed

3 files changed

+29
-15
lines changed

devito/types/sparse.py

+12-8
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,8 @@ class AbstractSparseFunction(DiscreteFunction):
4949
_sub_functions = ()
5050
"""SubFunctions encapsulated within this AbstractSparseFunction."""
5151

52-
__rkwargs__ = DiscreteFunction.__rkwargs__ + ('npoint_global', 'space_order')
52+
__rkwargs__ = (DiscreteFunction.__rkwargs__ +
53+
('dimensions', 'npoint_global', 'space_order'))
5354

5455
def __init_finalize__(self, *args, **kwargs):
5556
super().__init_finalize__(*args, **kwargs)
@@ -133,14 +134,17 @@ def __subfunc_setup__(self, key, suffix, dtype=None):
133134
shape = (self.npoint, self.grid.dim)
134135

135136
# Check if already a SubFunction
137+
d = self.indices[self._sparse_position]
136138
if isinstance(key, SubFunction):
137-
# Need to rebuild so the dimensions match the parent SparseFunction
138-
indices = (self.indices[self._sparse_position], *key.indices[1:])
139-
return key._rebuild(*indices, name=name, shape=shape,
140-
alias=self.alias, halo=None)
141-
elif key is not None and not isinstance(key, Iterable):
142-
raise ValueError("`%s` must be either SubFunction "
143-
"or iterable (e.g., list, np.ndarray)" % key)
139+
if d in key.dimensions and not self.alias:
140+
# From a reconstruction which leaves `dimensions` intact
141+
return key
142+
else:
143+
# Need to rebuild so the dimensions match the parent
144+
# SparseFunction, for example we end up here via `.subs(d, new_d)`
145+
indices = (d, *key.indices[1:])
146+
return key._rebuild(*indices, name=name, shape=shape,
147+
alias=self.alias, halo=None)
144148

145149
if key is None:
146150
# Fallback to default behaviour

tests/test_mpi.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -2633,14 +2633,17 @@ def run_adjoint_F(self, nd):
26332633
solver = acoustic_setup(shape=shape, spacing=[15. for _ in shape],
26342634
tn=tn, space_order=so, nrec=nrec,
26352635
preset='layers-isotropic', dtype=np.float64)
2636+
26362637
# Run forward operator
2637-
rec, u, _ = solver.forward()
2638+
src = solver.geometry.src
2639+
rec, u, _ = solver.forward(src=src)
26382640

26392641
assert np.isclose(norm(u) / Eu, 1.0)
26402642
assert np.isclose(norm(rec) / Erec, 1.0)
26412643

26422644
# Run adjoint operator
2643-
srca, v, _ = solver.adjoint(rec=rec)
2645+
srca = src.func(name='srca')
2646+
srca, v, _ = solver.adjoint(srca=srca, rec=rec)
26442647

26452648
assert np.isclose(norm(v) / Ev, 1.0)
26462649
assert np.isclose(norm(srca) / Esrca, 1.0)

tests/test_sparse.py

+12-5
Original file line numberDiff line numberDiff line change
@@ -417,14 +417,13 @@ def test_rebuild(self, sptype):
417417
assert getattr(sp, subf).name.startswith("s_")
418418

419419
# Rebuild with different name, this should drop the function
420-
# and create new data
420+
# and create new data, while the coordinates and more generally all
421+
# SubFunctions remain the same
421422
sp2 = sp._rebuild(name="sr")
422-
423-
# Check new subfunction
424423
for subf in sp2._sub_functions:
425424
if getattr(sp2, subf) is not None:
426-
assert getattr(sp2, subf).name.startswith("sr_")
427-
assert np.all(getattr(sp2, subf).data == 0)
425+
assert getattr(sp2, subf).name.startswith("s_")
426+
assert np.all(getattr(sp2, subf).data == getattr(sp, subf).data)
428427

429428
# Rebuild with different name as an alias
430429
sp2 = sp._rebuild(name="sr2", alias=True)
@@ -433,6 +432,14 @@ def test_rebuild(self, sptype):
433432
assert getattr(sp2, subf).name.startswith("sr2_")
434433
assert getattr(sp2, subf).data is None
435434

435+
# Rebuild with different name and dimensions. This is expected to recreate
436+
# the SubFunctions as well
437+
sp2 = sp._rebuild(name="sr3", dimensions=None)
438+
for subf in sp2._sub_functions:
439+
if getattr(sp2, subf) is not None:
440+
assert getattr(sp2, subf).name.startswith("sr3_")
441+
assert np.all(getattr(sp2, subf).data == 0)
442+
436443
@pytest.mark.parametrize('sptype', _sptypes)
437444
def test_subs(self, sptype):
438445
grid = Grid((3, 3, 3))

0 commit comments

Comments
 (0)