Skip to content

Commit b169d60

Browse files
authored
Merge pull request #56 from oceanmodeling/jurjendejong
fix: plan/node more robust indexing + ci test
2 parents 3dc3285 + 30f56f5 commit b169d60

File tree

3 files changed

+67
-52
lines changed

3 files changed

+67
-52
lines changed

tests/data/r3d_bump.slf

757 KB
Binary file not shown.

tests/io_test.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,13 @@
44
import xarray as xr
55
from scipy.spatial import Delaunay
66

7+
BUMP = pytest.mark.parametrize(
8+
"slf_in",
9+
[
10+
pytest.param("tests/data/r3d_bump.slf", id="3D"),
11+
],
12+
)
13+
714
TIDAL_FLATS = pytest.mark.parametrize(
815
"slf_in",
916
[
@@ -80,7 +87,6 @@ def test_to_netcdf(tmp_path, slf_in):
8087
@TIDAL_FLATS
8188
def test_to_selafin(tmp_path, slf_in):
8289
with xr.open_dataset(slf_in, engine="selafin") as ds_slf:
83-
8490
# Remove some data which is rebuilt
8591
del ds_slf.attrs["date_start"]
8692

@@ -98,7 +104,6 @@ def test_to_selafin(tmp_path, slf_in):
98104
@TIDAL_FLATS
99105
def test_to_selafin_eager_mode(tmp_path, slf_in):
100106
with xr.open_dataset(slf_in, lazy_loading=False, engine="selafin") as ds_slf:
101-
102107
# Remove some data which is rebuilt
103108
del ds_slf.attrs["date_start"]
104109

@@ -167,3 +172,14 @@ def test_from_scratch(tmp_path):
167172
def test_dim(slf_in):
168173
with xr.open_dataset(slf_in, engine="selafin") as ds:
169174
repr(ds)
175+
176+
177+
@BUMP
178+
def test_eager_vs_lazy(slf_in):
179+
with xr.load_dataset(slf_in, engine="selafin") as ds_eager:
180+
z_levels_eager = ds_eager.Z.isel(time=0).drop_vars("time")
181+
dz_eager = z_levels_eager.diff(dim="plan")
182+
with xr.open_dataset(slf_in, engine="selafin") as ds_lazy:
183+
z_levels_lazy = ds_lazy.Z.isel(time=0).drop_vars("time")
184+
dz_lazy = z_levels_lazy.diff(dim="plan")
185+
xr.testing.assert_allclose(dz_eager, dz_lazy, rtol=1e-3)

xarray_selafin/xarray_backend.py

Lines changed: 49 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,9 @@ def write_serafin(fout, ds):
165165
else:
166166
temp[iv] = ds.isel(time=it)[var].values
167167
if slf_header.nb_planes > 1:
168-
temp[iv] = np.reshape(np.ravel(temp[iv]), (slf_header.nb_planes, slf_header.nb_nodes_2d))
168+
temp[iv] = np.reshape(
169+
np.ravel(temp[iv]), (slf_header.nb_planes, slf_header.nb_nodes_2d)
170+
)
169171
resout.write_entire_frame(
170172
slf_header,
171173
t_,
@@ -190,70 +192,66 @@ def __getitem__(self, key):
190192
)
191193

192194
def _raw_indexing_method(self, key):
193-
if isinstance(key, tuple):
195+
ndim = len(self.shape)
196+
if ndim not in (2, 3):
197+
raise NotImplementedError(f"Unsupported SELAFIN shape {self.shape}")
198+
199+
if not isinstance(key, tuple):
200+
raise NotImplementedError("SELAFIN access must use tuple indexing")
201+
202+
# --- Parse keys
203+
if ndim == 3:
204+
# 3D file
194205
if len(key) == 3:
195-
time_key, node_key, plan_key = key
206+
time_key, plan_key, node_key = key
196207
elif len(key) == 2:
197208
time_key, node_key = key
198-
plan_key = None
209+
plan_key = slice(None)
199210
else:
200-
raise NotImplementedError
201-
else:
202-
raise NotImplementedError
203-
204-
# Convert time_key and node_key to ranges to handle steps and to list indices for SELAFIN reader
205-
if isinstance(time_key, slice):
206-
time_indices = range(*time_key.indices(self.shape[0]))
207-
elif isinstance(time_key, int):
208-
time_indices = [time_key]
211+
raise NotImplementedError("Only (time, plan, node) or (time, node) supported for 3D files")
209212
else:
210-
raise ValueError("time_key must be an integer or slice")
213+
# 2D file
214+
if len(key) == 2:
215+
time_key, node_key = key
216+
elif len(key) == 1:
217+
time_key = key[0]
218+
node_key = slice(None)
219+
else:
220+
raise NotImplementedError("Only (time, node) supported for 2D files")
221+
222+
# --- helper
223+
def _range_from_key(k, n):
224+
if isinstance(k, slice):
225+
return range(*k.indices(n))
226+
elif isinstance(k, int):
227+
return [k]
228+
else:
229+
raise ValueError("index must be int or slice")
211230

212-
if isinstance(node_key, slice):
213-
node_indices = range(*node_key.indices(self.shape[1]))
214-
elif isinstance(node_key, int):
215-
node_indices = [node_key]
216-
else:
217-
raise ValueError("node_key must be an integer or slice")
231+
time_indices = _range_from_key(time_key, self.shape[0])
218232

219-
if plan_key is not None:
220-
if isinstance(plan_key, slice):
221-
plan_indices = range(*plan_key.indices(self.shape[2]))
222-
elif isinstance(plan_key, int):
223-
plan_indices = [plan_key]
224-
else:
225-
raise ValueError("plan_key must be an integer or slice")
226-
data_shape = (len(time_indices), len(node_indices), len(plan_indices))
233+
if ndim == 3:
234+
plan_indices = _range_from_key(plan_key, self.shape[1])
235+
node_indices = _range_from_key(node_key, self.shape[2])
236+
data_shape = (len(time_indices), len(plan_indices), len(node_indices))
227237
else:
238+
node_indices = _range_from_key(node_key, self.shape[1])
228239
data_shape = (len(time_indices), len(node_indices))
229240

230-
# Initialize data array to hold the result
231241
data = np.empty(data_shape, dtype=self.dtype)
232242

233-
# Iterate over the time indices to read the required time steps
234243
for it, t in enumerate(time_indices):
235-
temp = self.slf_reader.read_var_in_frame(t, self.var) # shape = (nb_nodes,)
236-
temp = np.reshape(temp, self.shape[1:]) # shape = (nb_nodes_2d, nb_planes)
237-
if node_key == slice(None) and plan_key == slice(None): # speedup if not selection
238-
data[it] = temp
244+
temp = self.slf_reader.read_var_in_frame(t, self.var)
245+
if ndim == 3:
246+
temp = np.reshape(temp, (self.shape[1], self.shape[2])) # (nplan, nnode)
247+
values = temp[np.ix_(plan_indices, node_indices)]
248+
data[it, :, :] = values
239249
else:
240-
if plan_key is None:
241-
data[it] = temp[node_indices]
242-
else:
243-
values = temp[node_indices][:, plan_indices]
244-
data[it] = np.reshape(values, (len(plan_indices), len(node_indices))).T
250+
temp = np.asarray(temp) # (nnode,)
251+
values = temp[node_indices]
252+
data[it, :] = values
245253

246-
# Remove dimension if key was an integer
247-
if isinstance(node_key, int):
248-
if plan_key is None:
249-
data = data[:, 0]
250-
else:
251-
data = data[:, 0, :]
252-
if isinstance(plan_key, int):
253-
data = data[..., 0]
254-
if isinstance(time_key, int):
255-
data = data[0, ...]
256-
return data
254+
return data.squeeze()
257255

258256

259257
class SelafinBackendEntrypoint(BackendEntrypoint):
@@ -321,6 +319,7 @@ def open_dataset(
321319
# Avoid a ResourceWarning (unclosed file)
322320
def close():
323321
slf.__exit__()
322+
324323
ds.set_close(close)
325324

326325
ds.attrs["title"] = slf.header.title.decode(Serafin.SLF_EIT).strip()

0 commit comments

Comments
 (0)