From b8b7857c9c7980b717fb8f7ae3a1e023a28971d3 Mon Sep 17 00:00:00 2001 From: Jens Hedegaard Nielsen Date: Sun, 3 Dec 2023 19:24:54 +0100 Subject: [PATCH] Add extra overload for to_netcdf (#8268) * Add extra overload for to_netcdf The current signature does not match with pyright if a non literal bool is passed. * fix typing * add entry to whats-new --------- Co-authored-by: Michael Niklas --- doc/whats-new.rst | 3 +++ xarray/backends/api.py | 56 ++++++++++++++++++++++++++++++++++++++++ xarray/core/dataarray.py | 25 +++++++++++++++--- xarray/core/dataset.py | 25 +++++++++++++++--- 4 files changed, 101 insertions(+), 8 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index bfe0a1b9be5..c907458a6a2 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -65,6 +65,9 @@ Bug fixes - Static typing of ``p0`` and ``bounds`` arguments of :py:func:`xarray.DataArray.curvefit` and :py:func:`xarray.Dataset.curvefit` was changed to ``Mapping`` (:pull:`8502`). By `Michael Niklas `_. +- Fix typing of :py:func:`xarray.DataArray.to_netcdf` and :py:func:`xarray.Dataset.to_netcdf` + when ``compute`` is evaluated to bool instead of a Literal (:pull:`8268`). + By `Jens Hedegaard Nielsen `_. Documentation ~~~~~~~~~~~~~ diff --git a/xarray/backends/api.py b/xarray/backends/api.py index c59f2f8d81b..1d538bf94ed 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -1160,6 +1160,62 @@ def to_netcdf( ... +# if compute cannot be evaluated at type check time +# we may get back either Delayed or None +@overload +def to_netcdf( + dataset: Dataset, + path_or_file: str | os.PathLike, + mode: Literal["w", "a"] = "w", + format: T_NetcdfTypes | None = None, + group: str | None = None, + engine: T_NetcdfEngine | None = None, + encoding: Mapping[Hashable, Mapping[str, Any]] | None = None, + unlimited_dims: Iterable[Hashable] | None = None, + compute: bool = False, + multifile: Literal[False] = False, + invalid_netcdf: bool = False, +) -> Delayed | None: + ... + + +# if multifile cannot be evaluated at type check time +# we may get back either writer and datastore or Delayed or None +@overload +def to_netcdf( + dataset: Dataset, + path_or_file: str | os.PathLike, + mode: Literal["w", "a"] = "w", + format: T_NetcdfTypes | None = None, + group: str | None = None, + engine: T_NetcdfEngine | None = None, + encoding: Mapping[Hashable, Mapping[str, Any]] | None = None, + unlimited_dims: Iterable[Hashable] | None = None, + compute: bool = False, + multifile: bool = False, + invalid_netcdf: bool = False, +) -> tuple[ArrayWriter, AbstractDataStore] | Delayed | None: + ... + + +# Any +@overload +def to_netcdf( + dataset: Dataset, + path_or_file: str | os.PathLike | None, + mode: Literal["w", "a"] = "w", + format: T_NetcdfTypes | None = None, + group: str | None = None, + engine: T_NetcdfEngine | None = None, + encoding: Mapping[Hashable, Mapping[str, Any]] | None = None, + unlimited_dims: Iterable[Hashable] | None = None, + compute: bool = False, + multifile: bool = False, + invalid_netcdf: bool = False, +) -> tuple[ArrayWriter, AbstractDataStore] | bytes | Delayed | None: + ... + + def to_netcdf( dataset: Dataset, path_or_file: str | os.PathLike | None = None, diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 1d7e82d3044..c8cc579c8b7 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -3914,6 +3914,23 @@ def to_netcdf( ) -> bytes: ... + # compute=False returns dask.Delayed + @overload + def to_netcdf( + self, + path: str | PathLike, + mode: Literal["w", "a"] = "w", + format: T_NetcdfTypes | None = None, + group: str | None = None, + engine: T_NetcdfEngine | None = None, + encoding: Mapping[Hashable, Mapping[str, Any]] | None = None, + unlimited_dims: Iterable[Hashable] | None = None, + *, + compute: Literal[False], + invalid_netcdf: bool = False, + ) -> Delayed: + ... + # default return None @overload def to_netcdf( @@ -3930,7 +3947,8 @@ def to_netcdf( ) -> None: ... - # compute=False returns dask.Delayed + # if compute cannot be evaluated at type check time + # we may get back either Delayed or None @overload def to_netcdf( self, @@ -3941,10 +3959,9 @@ def to_netcdf( engine: T_NetcdfEngine | None = None, encoding: Mapping[Hashable, Mapping[str, Any]] | None = None, unlimited_dims: Iterable[Hashable] | None = None, - *, - compute: Literal[False], + compute: bool = True, invalid_netcdf: bool = False, - ) -> Delayed: + ) -> Delayed | None: ... def to_netcdf( diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index b8093d3dd78..b430b8fddd8 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -2159,6 +2159,23 @@ def to_netcdf( ) -> bytes: ... + # compute=False returns dask.Delayed + @overload + def to_netcdf( + self, + path: str | PathLike, + mode: Literal["w", "a"] = "w", + format: T_NetcdfTypes | None = None, + group: str | None = None, + engine: T_NetcdfEngine | None = None, + encoding: Mapping[Any, Mapping[str, Any]] | None = None, + unlimited_dims: Iterable[Hashable] | None = None, + *, + compute: Literal[False], + invalid_netcdf: bool = False, + ) -> Delayed: + ... + # default return None @overload def to_netcdf( @@ -2175,7 +2192,8 @@ def to_netcdf( ) -> None: ... - # compute=False returns dask.Delayed + # if compute cannot be evaluated at type check time + # we may get back either Delayed or None @overload def to_netcdf( self, @@ -2186,10 +2204,9 @@ def to_netcdf( engine: T_NetcdfEngine | None = None, encoding: Mapping[Any, Mapping[str, Any]] | None = None, unlimited_dims: Iterable[Hashable] | None = None, - *, - compute: Literal[False], + compute: bool = True, invalid_netcdf: bool = False, - ) -> Delayed: + ) -> Delayed | None: ... def to_netcdf(