Skip to content

Commit

Permalink
fix some mypy errors
Browse files Browse the repository at this point in the history
  • Loading branch information
rhkleijn committed Jul 3, 2022
1 parent 7d598fb commit 72ea77f
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 21 deletions.
14 changes: 8 additions & 6 deletions xarray/core/_reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Callable, Hashable, Sequence
from typing import TYPE_CHECKING, Any, Callable, Hashable, Sequence, TypeVar

from . import duck_array_ops
from .options import OPTIONS
Expand All @@ -12,7 +12,9 @@
if TYPE_CHECKING:
from .dataarray import DataArray
from .dataset import Dataset
from .types import T_DataArray, T_Dataset

T_Dataset = TypeVar("T_Dataset", bound="DatasetReductions")
T_DataArray = TypeVar("T_DataArray", bound="DataArrayReductions")

try:
import flox
Expand All @@ -24,15 +26,15 @@ class DatasetReductions:
__slots__ = ()

def reduce(
self,
self: T_Dataset,
func: Callable[..., Any],
dim: None | Hashable | Sequence[Hashable] = None,
*,
axis: None | int | Sequence[int] = None,
keep_attrs: bool | None = None,
keepdims: bool = False,
**kwargs: Any,
) -> Dataset:
) -> T_Dataset:
raise NotImplementedError()

def count(
Expand Down Expand Up @@ -1034,15 +1036,15 @@ class DataArrayReductions:
__slots__ = ()

def reduce(
self,
self: T_DataArray,
func: Callable[..., Any],
dim: None | Hashable | Sequence[Hashable] = None,
*,
axis: None | int | Sequence[int] = None,
keep_attrs: bool | None = None,
keepdims: bool = False,
**kwargs: Any,
) -> DataArray:
) -> T_DataArray:
raise NotImplementedError()

def count(
Expand Down
9 changes: 5 additions & 4 deletions xarray/core/rolling.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
Hashable,
Iterator,
Mapping,
Sequence,
TypeVar,
)

Expand Down Expand Up @@ -468,9 +469,8 @@ def reduce(
obj, rolling_dim, keep_attrs=keep_attrs, fill_value=fillna
)

result = windows.reduce(
func, dim=list(rolling_dim.values()), keep_attrs=keep_attrs, **kwargs
)
dim: Sequence[Hashable] = list(rolling_dim.values())
result = windows.reduce(func, dim=dim, keep_attrs=keep_attrs, **kwargs)

# Find valid windows based on count.
counts = self._counts(keep_attrs=False)
Expand All @@ -487,14 +487,15 @@ def _counts(self, keep_attrs: bool | None) -> DataArray:
# array is faster to be reduced than object array.
# The use of skipna==False is also faster since it does not need to
# copy the strided array.
dim: Sequence[Hashable] = list(rolling_dim.values())
counts = (
self.obj.notnull(keep_attrs=keep_attrs)
.rolling(
{d: w for d, w in zip(self.dim, self.window)},
center={d: self.center[i] for i, d in enumerate(self.dim)},
)
.construct(rolling_dim, fill_value=False, keep_attrs=keep_attrs)
.sum(dim=list(rolling_dim.values()), skipna=False, keep_attrs=keep_attrs)
.sum(dim=dim, skipna=False, keep_attrs=keep_attrs)
)
return counts

Expand Down
25 changes: 14 additions & 11 deletions xarray/util/generate_reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Callable, Hashable, Sequence
from typing import TYPE_CHECKING, Any, Callable, Hashable, Sequence, TypeVar
from . import duck_array_ops
from .options import OPTIONS
Expand All @@ -31,7 +31,9 @@
if TYPE_CHECKING:
from .dataarray import DataArray
from .dataset import Dataset
from .types import T_DataArray, T_Dataset
T_Dataset = TypeVar("T_Dataset", bound="DatasetReductions")
T_DataArray = TypeVar("T_DataArray", bound="DataArrayReductions")
try:
import flox
Expand All @@ -44,15 +46,15 @@ class {obj}{cls}Reductions:
__slots__ = ()
def reduce(
self,
self{self_type_snippet},
func: Callable[..., Any],
dim: None | Hashable | Sequence[Hashable] = None,
*,
axis: None | int | Sequence[int] = None,
keep_attrs: bool | None = None,
keepdims: bool = False,
**kwargs: Any,
) -> {obj}:
) -> {return_type}:
raise NotImplementedError()"""

GROUPBY_PREAMBLE = """
Expand Down Expand Up @@ -246,7 +248,13 @@ def __init__(
self.docref = docref
self.docref_description = docref_description
self.example_call_preamble = example_call_preamble
self.preamble = definition_preamble.format(obj=datastructure.name, cls=cls)
self.common_kwargs = dict(
obj=self.datastructure.name,
cls=cls,
self_type_snippet=": " + self.self_type if self.self_type else "",
return_type=self.self_type if self.self_type else self.datastructure.name,
)
self.preamble = definition_preamble.format(**self.common_kwargs)
if not see_also_obj:
self.see_also_obj = self.datastructure.name
else:
Expand All @@ -258,12 +266,7 @@ def generate_methods(self):
yield self.generate_method(method)

def generate_method(self, method):
template_kwargs = dict(
obj=self.datastructure.name,
method=method.name,
self_type_snippet=": " + self.self_type if self.self_type else "",
return_type=self.self_type if self.self_type else self.datastructure.name,
)
template_kwargs = dict(self.common_kwargs, method=method.name)

if method.extra_kwargs:
extra_kwargs = "\n " + "\n ".join(
Expand Down

0 comments on commit 72ea77f

Please sign in to comment.