
Change to use h5netcdf backend for reads (#132)
* Replace netCDF4 package dependency with h5netcdf

* Update read/write engine to h5netcdf

This commit adjusts the xarray `open_mfdataset()` and `to_netcdf()` calls to use
the h5netcdf engine for reading and writing operations. The corresponding tests
were updated to match.

It also updates the `info` module and tests to reflect the use of the h5netcdf
package.

* Improve code comment clarity and sentence structure

* Enable parallel reading by default

The default configuration for reading files has been changed from sequential to
parallel. The change was made by updating the default value passed to
`config.get()` for the "parallel read" field in `extract.py`. This should
improve performance for operations that read multiple files.

* Restore netcdf4 package for use as write engine

When h5netcdf is used as the engine for writes, the resulting netCDF4 files are
not read correctly by ncdump or `xarray.open_dataset(..., engine="netcdf4")`.
So, we'll use h5netcdf to read and netcdf4 to write; a sketch of the combined
pattern follows these notes.

Reinstating netcdf4 as a dependency also restores the ncdump tool to the
environment.

* Restore netcdf4 package dependency in test env

* Explicitly use netcdf4 to write test datasets

This ensures consistency with dataset writing in extract.write_netcdf().

* Include both netcdf4 & h5netcdf in versions info
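
Taken together, the commits give an asymmetric engine configuration: h5netcdf
for reads, netcdf4 for writes. A minimal sketch of that pattern, assuming
hypothetical file paths and a hypothetical variable name (illustrative only,
not the project's actual code):

    import xarray

    # Read with the h5netcdf engine; parallel reads are now the default
    ds = xarray.open_mfdataset(
        ["day1.nc", "day2.nc"],  # hypothetical dataset paths
        chunks={"time_counter": 4},
        engine="h5netcdf",
        parallel=True,  # was False by default before this change
    )

    # Write with the netcdf4 engine so that ncdump and
    # xarray.open_dataset(..., engine="netcdf4") read the file correctly
    ds[["vosaline"]].to_netcdf(  # hypothetical variable selection
        "extraction.nc",
        format="NETCDF4",
        engine="netcdf4",
    )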
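
A quick round-trip check of the write-engine choice (again a hedged sketch, not
part of the commit):

    import numpy
    import xarray

    ds = xarray.Dataset({"x": ("t", numpy.arange(3.0))})
    ds.to_netcdf("check.nc", format="NETCDF4", engine="netcdf4")
    # A file written by the netcdf4 engine opens with either engine
    xarray.open_dataset("check.nc", engine="h5netcdf").close()
    xarray.open_dataset("check.nc", engine="netcdf4").close()
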
douglatornell authored Oct 8, 2024
1 parent c580720 commit 9db197c
Showing 10 changed files with 24 additions and 11 deletions.
1 change: 1 addition & 0 deletions envs/environment-dev.yaml
@@ -23,6 +23,7 @@ dependencies:
   - click
   - dask
   - flox
+  - h5netcdf
   - netCDF4
   - pip
   - python=3.12

1 change: 1 addition & 0 deletions envs/environment-test.yaml
@@ -17,6 +17,7 @@ dependencies:
   - click
   - dask
   - flox
+  - h5netcdf
   - netCDF4
   - pip
   - python

1 change: 1 addition & 0 deletions envs/environment-user.yaml
@@ -17,6 +17,7 @@ dependencies:
   - click
   - dask
   - flox
+  - h5netcdf
   - netCDF4
   - pip
   - python=3.12

3 changes: 3 additions & 0 deletions envs/requirements.txt
@@ -12,6 +12,7 @@ black==24.4.2
 bokeh==3.4.1
 Bottleneck==1.3.8
 Brotli==1.1.0
+cached-property==1.5.2
 certifi==2024.7.4
 cffi==1.16.0
 cfgv==3.3.1
@@ -36,6 +37,8 @@ flox==0.9.7
 fsspec==2024.5.0
 h11==0.14.0
 h2==4.1.0
+h5netcdf==1.3.0
+h5py==3.11.0
 hatch==1.11.1
 hatchling==1.24.2
 hpack==4.0.0

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -44,7 +44,7 @@ dependencies = [
     "click",
     "dask",
     "flox",
-    "netCDF4",
+    "h5netcdf",
     "pyyaml",
     "rich",
     "structlog",

7 changes: 4 additions & 3 deletions reshapr/core/extract.py
@@ -478,12 +478,12 @@ def open_dataset(ds_paths, chunk_size, config):
     extract_vars = {var for var in config["extract variables"]}
     # Use 1st and last dataset paths to calculate the set of all variables
     # in the dataset, and from that the set of variables to drop.
-    # We need to use variables lists from 1st and last datasets in order to avoid issue #51.
+    # We need to use the variables lists from 1st and last datasets to avoid issue #51.
     for ds_path in (ds_paths[0], ds_paths[-1]):
-        with xarray.open_dataset(ds_path, chunks=chunk_size) as ds:
+        with xarray.open_dataset(ds_path, chunks=chunk_size, engine="h5netcdf") as ds:
             drop_vars.update(var for var in ds.data_vars)
     drop_vars -= extract_vars
-    parallel_read = config.get("parallel read", False)
+    parallel_read = config.get("parallel read", True)
     ds = xarray.open_mfdataset(
         ds_paths,
         chunks=chunk_size,
@@ -1226,6 +1226,7 @@ def write_netcdf(extracted_ds, nc_path, encoding, nc_format, unlimited_dim):
         format=nc_format,
         encoding=encoding,
         unlimited_dims=unlimited_dim,
+        engine="netcdf4",
     )
     logger.info("wrote netCDF4 file", nc_path=os.fspath(nc_path))

3 changes: 2 additions & 1 deletion reshapr/core/info.py
@@ -66,7 +66,8 @@ def _basic_info(console):
     :param :py:class:`rich.console.Console` console:
     """
     versions = {
-        pkg: metadata.version(pkg) for pkg in ("reshapr", "xarray", "dask", "netcdf4")
+        pkg: metadata.version(pkg)
+        for pkg in ("reshapr", "xarray", "dask", "h5netcdf", "netcdf4")
     }
     for pkg, version in versions.items():
         console.print(f"{pkg}, version [magenta]{version}", highlight=False)

2 changes: 2 additions & 0 deletions tests/api/v1/test_extract_api_v1.py
@@ -159,6 +159,7 @@ def test_extract_netcdf(self, tmp_path):
             format="NETCDF4",
             encoding=encoding,
             unlimited_dims="time_counter",
+            engine="netcdf4",
         )

         model_profile_yaml = tmp_path / "test_profile.yaml"
@@ -296,6 +297,7 @@ def test_extract_netcdf_with_resampling(self, tmp_path):
             format="NETCDF4",
             encoding=encoding,
             unlimited_dims="time_counter",
+            engine="netcdf4",
         )

         model_profile_yaml = tmp_path / "test_profile.yaml"

8 changes: 5 additions & 3 deletions tests/core/test_extract.py
@@ -163,6 +163,7 @@ def test_cli_extract(self, tmp_path):
             format="NETCDF4",
             encoding=encoding,
             unlimited_dims="time_counter",
+            engine="netcdf4",
         )

         model_profile_yaml = tmp_path / "test_profile.yaml"
@@ -298,6 +299,7 @@ def test_cli_extract_with_resampling(self, tmp_path):
             format="NETCDF4",
             encoding=encoding,
             unlimited_dims="time_counter",
+            engine="netcdf4",
         )

         model_profile_yaml = tmp_path / "test_profile.yaml"
@@ -1009,7 +1011,7 @@ def fixture_source_dataset(self):
     def test_open_dataset(self, source_dataset, log_output, tmp_path):
         results_archive = tmp_path / "results_archive"
         results_archive.mkdir()
-        source_dataset.to_netcdf(results_archive / "test_dataset.nc")
+        source_dataset.to_netcdf(results_archive / "test_dataset.nc", engine="netcdf4")
         ds_paths = [results_archive / "test_dataset.nc"]
         chunk_size = {
             "time_counter": 4,
@@ -1033,7 +1035,7 @@ def test_open_dataset(self, source_dataset, log_output, tmp_path):
     def test_open_dataset_parallel_read(self, source_dataset, log_output, tmp_path):
         results_archive = tmp_path / "results_archive"
         results_archive.mkdir()
-        source_dataset.to_netcdf(results_archive / "test_dataset.nc")
+        source_dataset.to_netcdf(results_archive / "test_dataset.nc", engine="netcdf4")
         ds_paths = [results_archive / "test_dataset.nc"]
         chunk_size = {
             "time_counter": 4,
@@ -1064,7 +1066,7 @@ def test_exit_when_no_dataset_vars(self, source_dataset, log_output, tmp_path):
         """
         results_archive = tmp_path / "results_archive"
         results_archive.mkdir()
-        source_dataset.to_netcdf(results_archive / "test_dataset.nc")
+        source_dataset.to_netcdf(results_archive / "test_dataset.nc", engine="netcdf4")
         ds_paths = [results_archive / "test_dataset.nc"]
         chunk_size = {
             "time_counter": 4,

7 changes: 4 additions & 3 deletions tests/core/test_info.py
@@ -38,7 +38,8 @@ class TestBasicInfo:
     """

     @pytest.mark.parametrize(
-        "pkg, line", (("reshapr", 0), ("xarray", 1), ("dask", 2), ("netcdf4", 3))
+        "pkg, line",
+        (("reshapr", 0), ("xarray", 1), ("dask", 2), ("h5netcdf", 3), ("netcdf4", 4)),
     )
     def test_pkg_version(self, pkg, line, capsys):
         info.info(cluster_or_model="", time_interval="", vars_group="")
@@ -51,7 +52,7 @@ def test_cluster_configs(self, capsys):

         stdout_lines = capsys.readouterr().out.splitlines()
         expected = {"salish_cluster.yaml"}
-        assert set(line.strip() for line in stdout_lines[6:7]) == expected
+        assert set(line.strip() for line in stdout_lines[7:8]) == expected

     def test_model_profiles(self, capsys):
         info.info(cluster_or_model="", time_interval="", vars_group="")
@@ -69,7 +70,7 @@ def test_model_profiles(self, capsys):
             "HRDPS-2.5km-operational.yaml",
         }
         assert (
-            set(line.strip() for line in stdout_lines[9 : len(expected) + 9])
+            set(line.strip() for line in stdout_lines[10 : len(expected) + 10])
             == expected
         )
