diff --git a/xarray/core/_reductions.py b/xarray/core/_reductions.py index d312bd41ec5..b4a67ea43bf 100644 --- a/xarray/core/_reductions.py +++ b/xarray/core/_reductions.py @@ -3,7 +3,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Callable, Hashable, Sequence +from typing import TYPE_CHECKING, Any, Callable, Hashable, Sequence, TypeVar from . import duck_array_ops from .options import OPTIONS @@ -12,7 +12,9 @@ if TYPE_CHECKING: from .dataarray import DataArray from .dataset import Dataset - from .types import T_DataArray, T_Dataset + + T_Dataset = TypeVar("T_Dataset", bound="DatasetReductions") + T_DataArray = TypeVar("T_DataArray", bound="DataArrayReductions") try: import flox @@ -24,7 +26,7 @@ class DatasetReductions: __slots__ = () def reduce( - self, + self: T_Dataset, func: Callable[..., Any], dim: None | Hashable | Sequence[Hashable] = None, *, @@ -32,7 +34,7 @@ def reduce( keep_attrs: bool | None = None, keepdims: bool = False, **kwargs: Any, - ) -> Dataset: + ) -> T_Dataset: raise NotImplementedError() def count( @@ -1034,7 +1036,7 @@ class DataArrayReductions: __slots__ = () def reduce( - self, + self: T_DataArray, func: Callable[..., Any], dim: None | Hashable | Sequence[Hashable] = None, *, @@ -1042,7 +1044,7 @@ def reduce( keep_attrs: bool | None = None, keepdims: bool = False, **kwargs: Any, - ) -> DataArray: + ) -> T_DataArray: raise NotImplementedError() def count( diff --git a/xarray/core/rolling.py b/xarray/core/rolling.py index aef290f6d7f..9600009aee6 100644 --- a/xarray/core/rolling.py +++ b/xarray/core/rolling.py @@ -12,6 +12,7 @@ Hashable, Iterator, Mapping, + Sequence, TypeVar, ) @@ -468,9 +469,8 @@ def reduce( obj, rolling_dim, keep_attrs=keep_attrs, fill_value=fillna ) - result = windows.reduce( - func, dim=list(rolling_dim.values()), keep_attrs=keep_attrs, **kwargs - ) + dim: Sequence[Hashable] = list(rolling_dim.values()) + result = windows.reduce(func, dim=dim, keep_attrs=keep_attrs, **kwargs) # Find valid windows based on count. counts = self._counts(keep_attrs=False) @@ -487,6 +487,7 @@ def _counts(self, keep_attrs: bool | None) -> DataArray: # array is faster to be reduced than object array. # The use of skipna==False is also faster since it does not need to # copy the strided array. + dim: Sequence[Hashable] = list(rolling_dim.values()) counts = ( self.obj.notnull(keep_attrs=keep_attrs) .rolling( @@ -494,7 +495,7 @@ def _counts(self, keep_attrs: bool | None) -> DataArray: center={d: self.center[i] for i, d in enumerate(self.dim)}, ) .construct(rolling_dim, fill_value=False, keep_attrs=keep_attrs) - .sum(dim=list(rolling_dim.values()), skipna=False, keep_attrs=keep_attrs) + .sum(dim=dim, skipna=False, keep_attrs=keep_attrs) ) return counts diff --git a/xarray/util/generate_reductions.py b/xarray/util/generate_reductions.py index 63e08f2f570..ad225b55481 100644 --- a/xarray/util/generate_reductions.py +++ b/xarray/util/generate_reductions.py @@ -22,7 +22,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Callable, Hashable, Sequence +from typing import TYPE_CHECKING, Any, Callable, Hashable, Sequence, TypeVar from . import duck_array_ops from .options import OPTIONS @@ -31,7 +31,9 @@ if TYPE_CHECKING: from .dataarray import DataArray from .dataset import Dataset - from .types import T_DataArray, T_Dataset + + T_Dataset = TypeVar("T_Dataset", bound="DatasetReductions") + T_DataArray = TypeVar("T_DataArray", bound="DataArrayReductions") try: import flox @@ -44,7 +46,7 @@ class {obj}{cls}Reductions: __slots__ = () def reduce( - self, + self{self_type_snippet}, func: Callable[..., Any], dim: None | Hashable | Sequence[Hashable] = None, *, @@ -52,7 +54,7 @@ def reduce( keep_attrs: bool | None = None, keepdims: bool = False, **kwargs: Any, - ) -> {obj}: + ) -> {return_type}: raise NotImplementedError()""" GROUPBY_PREAMBLE = """ @@ -246,7 +248,13 @@ def __init__( self.docref = docref self.docref_description = docref_description self.example_call_preamble = example_call_preamble - self.preamble = definition_preamble.format(obj=datastructure.name, cls=cls) + self.common_kwargs = dict( + obj=self.datastructure.name, + cls=cls, + self_type_snippet=": " + self.self_type if self.self_type else "", + return_type=self.self_type if self.self_type else self.datastructure.name, + ) + self.preamble = definition_preamble.format(**self.common_kwargs) if not see_also_obj: self.see_also_obj = self.datastructure.name else: @@ -258,12 +266,7 @@ def generate_methods(self): yield self.generate_method(method) def generate_method(self, method): - template_kwargs = dict( - obj=self.datastructure.name, - method=method.name, - self_type_snippet=": " + self.self_type if self.self_type else "", - return_type=self.self_type if self.self_type else self.datastructure.name, - ) + template_kwargs = dict(self.common_kwargs, method=method.name) if method.extra_kwargs: extra_kwargs = "\n " + "\n ".join(