Skip to content

Commit

Permalink
Add extra overload for to_netcdf (#8268)
Browse files Browse the repository at this point in the history
* Add extra overload for to_netcdf

The current signature does not match with pyright
if a non literal bool is passed.

* fix typing

* add entry to whats-new

---------

Co-authored-by: Michael Niklas <[email protected]>
  • Loading branch information
jenshnielsen and headtr1ck authored Dec 3, 2023
1 parent d44bfd7 commit b8b7857
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 8 deletions.
3 changes: 3 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ Bug fixes
- Static typing of ``p0`` and ``bounds`` arguments of :py:func:`xarray.DataArray.curvefit` and :py:func:`xarray.Dataset.curvefit`
was changed to ``Mapping`` (:pull:`8502`).
By `Michael Niklas <https://github.com/headtr1ck>`_.
- Fix typing of :py:func:`xarray.DataArray.to_netcdf` and :py:func:`xarray.Dataset.to_netcdf`
when ``compute`` is evaluated to bool instead of a Literal (:pull:`8268`).
By `Jens Hedegaard Nielsen <https://github.com/jenshnielsen>`_.

Documentation
~~~~~~~~~~~~~
Expand Down
56 changes: 56 additions & 0 deletions xarray/backends/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1160,6 +1160,62 @@ def to_netcdf(
...


# if compute cannot be evaluated at type check time
# we may get back either Delayed or None
@overload
def to_netcdf(
dataset: Dataset,
path_or_file: str | os.PathLike,
mode: Literal["w", "a"] = "w",
format: T_NetcdfTypes | None = None,
group: str | None = None,
engine: T_NetcdfEngine | None = None,
encoding: Mapping[Hashable, Mapping[str, Any]] | None = None,
unlimited_dims: Iterable[Hashable] | None = None,
compute: bool = False,
multifile: Literal[False] = False,
invalid_netcdf: bool = False,
) -> Delayed | None:
...


# if multifile cannot be evaluated at type check time
# we may get back either writer and datastore or Delayed or None
@overload
def to_netcdf(
dataset: Dataset,
path_or_file: str | os.PathLike,
mode: Literal["w", "a"] = "w",
format: T_NetcdfTypes | None = None,
group: str | None = None,
engine: T_NetcdfEngine | None = None,
encoding: Mapping[Hashable, Mapping[str, Any]] | None = None,
unlimited_dims: Iterable[Hashable] | None = None,
compute: bool = False,
multifile: bool = False,
invalid_netcdf: bool = False,
) -> tuple[ArrayWriter, AbstractDataStore] | Delayed | None:
...


# Any
@overload
def to_netcdf(
dataset: Dataset,
path_or_file: str | os.PathLike | None,
mode: Literal["w", "a"] = "w",
format: T_NetcdfTypes | None = None,
group: str | None = None,
engine: T_NetcdfEngine | None = None,
encoding: Mapping[Hashable, Mapping[str, Any]] | None = None,
unlimited_dims: Iterable[Hashable] | None = None,
compute: bool = False,
multifile: bool = False,
invalid_netcdf: bool = False,
) -> tuple[ArrayWriter, AbstractDataStore] | bytes | Delayed | None:
...


def to_netcdf(
dataset: Dataset,
path_or_file: str | os.PathLike | None = None,
Expand Down
25 changes: 21 additions & 4 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -3914,6 +3914,23 @@ def to_netcdf(
) -> bytes:
...

# compute=False returns dask.Delayed
@overload
def to_netcdf(
self,
path: str | PathLike,
mode: Literal["w", "a"] = "w",
format: T_NetcdfTypes | None = None,
group: str | None = None,
engine: T_NetcdfEngine | None = None,
encoding: Mapping[Hashable, Mapping[str, Any]] | None = None,
unlimited_dims: Iterable[Hashable] | None = None,
*,
compute: Literal[False],
invalid_netcdf: bool = False,
) -> Delayed:
...

# default return None
@overload
def to_netcdf(
Expand All @@ -3930,7 +3947,8 @@ def to_netcdf(
) -> None:
...

# compute=False returns dask.Delayed
# if compute cannot be evaluated at type check time
# we may get back either Delayed or None
@overload
def to_netcdf(
self,
Expand All @@ -3941,10 +3959,9 @@ def to_netcdf(
engine: T_NetcdfEngine | None = None,
encoding: Mapping[Hashable, Mapping[str, Any]] | None = None,
unlimited_dims: Iterable[Hashable] | None = None,
*,
compute: Literal[False],
compute: bool = True,
invalid_netcdf: bool = False,
) -> Delayed:
) -> Delayed | None:
...

def to_netcdf(
Expand Down
25 changes: 21 additions & 4 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2159,6 +2159,23 @@ def to_netcdf(
) -> bytes:
...

# compute=False returns dask.Delayed
@overload
def to_netcdf(
self,
path: str | PathLike,
mode: Literal["w", "a"] = "w",
format: T_NetcdfTypes | None = None,
group: str | None = None,
engine: T_NetcdfEngine | None = None,
encoding: Mapping[Any, Mapping[str, Any]] | None = None,
unlimited_dims: Iterable[Hashable] | None = None,
*,
compute: Literal[False],
invalid_netcdf: bool = False,
) -> Delayed:
...

# default return None
@overload
def to_netcdf(
Expand All @@ -2175,7 +2192,8 @@ def to_netcdf(
) -> None:
...

# compute=False returns dask.Delayed
# if compute cannot be evaluated at type check time
# we may get back either Delayed or None
@overload
def to_netcdf(
self,
Expand All @@ -2186,10 +2204,9 @@ def to_netcdf(
engine: T_NetcdfEngine | None = None,
encoding: Mapping[Any, Mapping[str, Any]] | None = None,
unlimited_dims: Iterable[Hashable] | None = None,
*,
compute: Literal[False],
compute: bool = True,
invalid_netcdf: bool = False,
) -> Delayed:
) -> Delayed | None:
...

def to_netcdf(
Expand Down

0 comments on commit b8b7857

Please sign in to comment.