diff --git a/doc/whats-new.rst b/doc/whats-new.rst
index af47a096697..acb81f3692a 100644
--- a/doc/whats-new.rst
+++ b/doc/whats-new.rst
@@ -34,6 +34,9 @@ Bug fixes
 - Fix error when encoding an empty :py:class:`numpy.datetime64` array
   (:issue:`10722`, :pull:`10723`).
   By `Spencer Clark <https://github.com/spencerkclark>`_.
+- Fix error from ``to_netcdf(..., compute=False)`` when using Dask Distributed
+  (:issue:`10725`).
+  By `Stephan Hoyer <https://github.com/shoyer>`_.
 - Propagate coordinate attrs in :py:meth:`xarray.Dataset.map`
   (:issue:`9317`, :pull:`10602`).
   By `Justus Magin <https://github.com/keewis>`_.
diff --git a/xarray/backends/api.py b/xarray/backends/api.py
index 48d1a4c5135..6822fbf4f57 100644
--- a/xarray/backends/api.py
+++ b/xarray/backends/api.py
@@ -1858,6 +1858,20 @@ def open_mfdataset(
     return combined
 
 
+def _get_netcdf_autoclose(dataset: Dataset, engine: T_NetcdfEngine) -> bool:
+    """Should we close files after each write operation?"""
+    scheduler = get_dask_scheduler()
+    have_chunks = any(v.chunks is not None for v in dataset.variables.values())
+
+    autoclose = have_chunks and scheduler in ["distributed", "multiprocessing"]
+    if autoclose and engine == "scipy":
+        raise NotImplementedError(
+            f"Writing netCDF files with the {engine} backend "
+            f"is not currently supported with dask's {scheduler} scheduler"
+        )
+    return autoclose
+
+
 WRITEABLE_STORES: dict[T_NetcdfEngine, Callable] = {
     "netcdf4": backends.NetCDF4DataStore.open,
     "scipy": backends.ScipyDataStore,
@@ -2064,16 +2078,7 @@ def to_netcdf(
     # sanitize unlimited_dims
     unlimited_dims = _sanitize_unlimited_dims(dataset, unlimited_dims)
 
-    # handle scheduler specific logic
-    scheduler = get_dask_scheduler()
-    have_chunks = any(v.chunks is not None for v in dataset.variables.values())
-
-    autoclose = have_chunks and scheduler in ["distributed", "multiprocessing"]
-    if autoclose and engine == "scipy":
-        raise NotImplementedError(
-            f"Writing netCDF files with the {engine} backend "
-            f"is not currently supported with dask's {scheduler} scheduler"
-        )
+    autoclose = _get_netcdf_autoclose(dataset, engine)
 
     if path_or_file is None:
         if not compute:
@@ -2116,7 +2121,7 @@
         writes = writer.sync(compute=compute)
 
     finally:
-        if not multifile:
+        if not multifile and not autoclose:  # type: ignore[redundant-expr,unused-ignore]
             if compute:
                 store.close()
             else:
diff --git a/xarray/tests/test_distributed.py b/xarray/tests/test_distributed.py
index 9ae83bc2664..db17a2c13df 100644
--- a/xarray/tests/test_distributed.py
+++ b/xarray/tests/test_distributed.py
@@ -85,11 +85,13 @@ def tmp_netcdf_filename(tmpdir):
 
 
 @pytest.mark.parametrize("engine,nc_format", ENGINES_AND_FORMATS)
+@pytest.mark.parametrize("compute", [True, False])
 def test_dask_distributed_netcdf_roundtrip(
     loop,  # noqa: F811
     tmp_netcdf_filename,
     engine,
     nc_format,
+    compute,
 ):
     if engine not in ENGINES:
         pytest.skip("engine not available")
@@ -107,7 +109,11 @@ def test_dask_distributed_netcdf_roundtrip(
                     )
                 return
 
-            original.to_netcdf(tmp_netcdf_filename, engine=engine, format=nc_format)
+            result = original.to_netcdf(
+                tmp_netcdf_filename, engine=engine, format=nc_format, compute=compute
+            )
+            if not compute:
+                result.compute()
 
             with xr.open_dataset(
                 tmp_netcdf_filename, chunks=chunks, engine=engine
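Below is a minimal sketch of the write path this patch fixes, mirroring the
updated test: a delayed ``to_netcdf(..., compute=False)`` under the dask
distributed scheduler, computed afterwards. The local cluster setup, the
dataset, the chunking, and the ``out.nc`` path are illustrative only and not
part of the change:

    import xarray as xr
    from dask.distributed import Client, LocalCluster

    cluster = LocalCluster(n_workers=2)
    client = Client(cluster)

    # A chunked dataset, so the write is dispatched through dask.
    ds = xr.Dataset({"a": ("x", list(range(10)))}).chunk({"x": 5})

    # With compute=False this returns a dask delayed object; this pattern
    # previously errored with the distributed scheduler (:issue:`10725`).
    delayed = ds.to_netcdf("out.nc", engine="netcdf4", compute=False)
    delayed.compute()  # the actual write happens here

    client.close()
    cluster.close()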