Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added tests/data/r3d_bump.slf
Binary file not shown.
20 changes: 18 additions & 2 deletions tests/io_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,13 @@
import xarray as xr
from scipy.spatial import Delaunay

BUMP = pytest.mark.parametrize(
"slf_in",
[
pytest.param("tests/data/r3d_bump.slf", id="3D"),
],
)

TIDAL_FLATS = pytest.mark.parametrize(
"slf_in",
[
Expand Down Expand Up @@ -80,7 +87,6 @@ def test_to_netcdf(tmp_path, slf_in):
@TIDAL_FLATS
def test_to_selafin(tmp_path, slf_in):
with xr.open_dataset(slf_in, engine="selafin") as ds_slf:

# Remove some data which is rebuilt
del ds_slf.attrs["date_start"]

Expand All @@ -98,7 +104,6 @@ def test_to_selafin(tmp_path, slf_in):
@TIDAL_FLATS
def test_to_selafin_eager_mode(tmp_path, slf_in):
with xr.open_dataset(slf_in, lazy_loading=False, engine="selafin") as ds_slf:

# Remove some data which is rebuilt
del ds_slf.attrs["date_start"]

Expand Down Expand Up @@ -167,3 +172,14 @@ def test_from_scratch(tmp_path):
def test_dim(slf_in):
with xr.open_dataset(slf_in, engine="selafin") as ds:
repr(ds)


@BUMP
def test_eager_vs_lazy(slf_in):
with xr.load_dataset(slf_in, engine="selafin") as ds_eager:
z_levels_eager = ds_eager.Z.isel(time=0).drop_vars("time")
dz_eager = z_levels_eager.diff(dim="plan")
with xr.open_dataset(slf_in, engine="selafin") as ds_lazy:
z_levels_lazy = ds_lazy.Z.isel(time=0).drop_vars("time")
dz_lazy = z_levels_lazy.diff(dim="plan")
xr.testing.assert_allclose(dz_eager, dz_lazy, rtol=1e-3)
99 changes: 49 additions & 50 deletions xarray_selafin/xarray_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,9 @@ def write_serafin(fout, ds):
else:
temp[iv] = ds.isel(time=it)[var].values
if slf_header.nb_planes > 1:
temp[iv] = np.reshape(np.ravel(temp[iv]), (slf_header.nb_planes, slf_header.nb_nodes_2d))
temp[iv] = np.reshape(
np.ravel(temp[iv]), (slf_header.nb_planes, slf_header.nb_nodes_2d)
)
resout.write_entire_frame(
slf_header,
t_,
Expand All @@ -190,70 +192,66 @@ def __getitem__(self, key):
)

def _raw_indexing_method(self, key):
if isinstance(key, tuple):
ndim = len(self.shape)
if ndim not in (2, 3):
raise NotImplementedError(f"Unsupported SELAFIN shape {self.shape}")

if not isinstance(key, tuple):
raise NotImplementedError("SELAFIN access must use tuple indexing")

# --- Parse keys
if ndim == 3:
# 3D file
if len(key) == 3:
time_key, node_key, plan_key = key
time_key, plan_key, node_key = key
elif len(key) == 2:
time_key, node_key = key
plan_key = None
plan_key = slice(None)
else:
raise NotImplementedError
else:
raise NotImplementedError

# Convert time_key and node_key to ranges to handle steps and to list indices for SELAFIN reader
if isinstance(time_key, slice):
time_indices = range(*time_key.indices(self.shape[0]))
elif isinstance(time_key, int):
time_indices = [time_key]
raise NotImplementedError("Only (time, plan, node) or (time, node) supported for 3D files")
else:
raise ValueError("time_key must be an integer or slice")
# 2D file
if len(key) == 2:
time_key, node_key = key
elif len(key) == 1:
time_key = key[0]
node_key = slice(None)
else:
raise NotImplementedError("Only (time, node) supported for 2D files")

# --- helper
def _range_from_key(k, n):
if isinstance(k, slice):
return range(*k.indices(n))
elif isinstance(k, int):
return [k]
else:
raise ValueError("index must be int or slice")

if isinstance(node_key, slice):
node_indices = range(*node_key.indices(self.shape[1]))
elif isinstance(node_key, int):
node_indices = [node_key]
else:
raise ValueError("node_key must be an integer or slice")
time_indices = _range_from_key(time_key, self.shape[0])

if plan_key is not None:
if isinstance(plan_key, slice):
plan_indices = range(*plan_key.indices(self.shape[2]))
elif isinstance(plan_key, int):
plan_indices = [plan_key]
else:
raise ValueError("plan_key must be an integer or slice")
data_shape = (len(time_indices), len(node_indices), len(plan_indices))
if ndim == 3:
plan_indices = _range_from_key(plan_key, self.shape[1])
node_indices = _range_from_key(node_key, self.shape[2])
data_shape = (len(time_indices), len(plan_indices), len(node_indices))
else:
node_indices = _range_from_key(node_key, self.shape[1])
data_shape = (len(time_indices), len(node_indices))

# Initialize data array to hold the result
data = np.empty(data_shape, dtype=self.dtype)

# Iterate over the time indices to read the required time steps
for it, t in enumerate(time_indices):
temp = self.slf_reader.read_var_in_frame(t, self.var) # shape = (nb_nodes,)
temp = np.reshape(temp, self.shape[1:]) # shape = (nb_nodes_2d, nb_planes)
if node_key == slice(None) and plan_key == slice(None): # speedup if not selection
data[it] = temp
temp = self.slf_reader.read_var_in_frame(t, self.var)
if ndim == 3:
temp = np.reshape(temp, (self.shape[1], self.shape[2])) # (nplan, nnode)
values = temp[np.ix_(plan_indices, node_indices)]
data[it, :, :] = values
else:
if plan_key is None:
data[it] = temp[node_indices]
else:
values = temp[node_indices][:, plan_indices]
data[it] = np.reshape(values, (len(plan_indices), len(node_indices))).T
temp = np.asarray(temp) # (nnode,)
values = temp[node_indices]
data[it, :] = values

# Remove dimension if key was an integer
if isinstance(node_key, int):
if plan_key is None:
data = data[:, 0]
else:
data = data[:, 0, :]
if isinstance(plan_key, int):
data = data[..., 0]
if isinstance(time_key, int):
data = data[0, ...]
return data
return data.squeeze()


class SelafinBackendEntrypoint(BackendEntrypoint):
Expand Down Expand Up @@ -321,6 +319,7 @@ def open_dataset(
# Avoid a ResourceWarning (unclosed file)
def close():
slf.__exit__()

ds.set_close(close)

ds.attrs["title"] = slf.header.title.decode(Serafin.SLF_EIT).strip()
Expand Down