Skip to content

Commit 0f65307

Browse files
authored
typing for numpy 1.20 (#4878)
* typing for numpy 1.20 * [skip-ci] add whats-new.rst * update formatting * -> np.dtype * fix bug, use Mapping, check for dict-like * enable typing CI * fixes * remove some unnecessary ignores again
1 parent 348eb48 commit 0f65307

File tree

12 files changed

+62
-25
lines changed

12 files changed

+62
-25
lines changed

.github/workflows/ci-additional.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ jobs:
161161
name: Type checking (mypy)
162162
runs-on: "ubuntu-latest"
163163
needs: detect-ci-trigger
164-
if: false && needs.detect-ci-trigger.outputs.triggered == 'false'
164+
if: needs.detect-ci-trigger.outputs.triggered == 'false'
165165
defaults:
166166
run:
167167
shell: bash -l {0}

doc/whats-new.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,7 @@ Internal Changes
190190
in ipython (:issue:`4741`, :pull:`4742`). By `Richard Kleijn <https://github.com/rhkleijn>`_.
191191
- Added the ``set_close`` method to ``Dataset`` and ``DataArray`` for beckends to specify how to voluntary release
192192
all resources. (:pull:`#4809`), By `Alessandro Amici <https://github.com/alexamici>`_.
193+
- Update type hints to work with numpy v1.20 (:pull:`4878`). By `Mathias Hauser <https://github.com/mathause>`_.
193194
- Ensure warnings cannot be turned into exceptions in :py:func:`testing.assert_equal` and
194195
the other ``assert_*`` functions (:pull:`4864`). By `Mathias Hauser <https://github.com/mathause>`_.
195196
- Performance improvement when constructing DataArrays. Significantly speeds up repr for Datasets with large number of variables.

xarray/core/accessor_dt.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
is_np_datetime_like,
1010
is_np_timedelta_like,
1111
)
12+
from .npcompat import DTypeLike
1213
from .pycompat import is_duck_dask_array
1314

1415

@@ -178,8 +179,9 @@ class Properties:
178179
def __init__(self, obj):
179180
self._obj = obj
180181

181-
def _tslib_field_accessor( # type: ignore
182-
name: str, docstring: str = None, dtype: np.dtype = None
182+
@staticmethod
183+
def _tslib_field_accessor(
184+
name: str, docstring: str = None, dtype: DTypeLike = None
183185
):
184186
def f(self, dtype=dtype):
185187
if dtype is None:

xarray/core/common.py

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
Tuple,
1717
TypeVar,
1818
Union,
19+
overload,
1920
)
2021

2122
import numpy as np
@@ -35,6 +36,8 @@
3536

3637
if TYPE_CHECKING:
3738
from .dataarray import DataArray
39+
from .dataset import Dataset
40+
from .variable import Variable
3841
from .weighted import Weighted
3942

4043
T_DataWithCoords = TypeVar("T_DataWithCoords", bound="DataWithCoords")
@@ -1501,7 +1504,26 @@ def __getitem__(self, value):
15011504
raise NotImplementedError()
15021505

15031506

1504-
def full_like(other, fill_value, dtype: DTypeLike = None):
1507+
@overload
1508+
def full_like(
1509+
other: "Dataset",
1510+
fill_value,
1511+
dtype: Union[DTypeLike, Mapping[Hashable, DTypeLike]] = None,
1512+
) -> "Dataset":
1513+
...
1514+
1515+
1516+
@overload
1517+
def full_like(other: "DataArray", fill_value, dtype: DTypeLike = None) -> "DataArray":
1518+
...
1519+
1520+
1521+
@overload
1522+
def full_like(other: "Variable", fill_value, dtype: DTypeLike = None) -> "Variable":
1523+
...
1524+
1525+
1526+
def full_like(other, fill_value, dtype=None):
15051527
"""Return a new object with the same shape and type as a given object.
15061528
15071529
Parameters
@@ -1618,15 +1640,22 @@ def full_like(other, fill_value, dtype: DTypeLike = None):
16181640
f"fill_value must be scalar or, for datasets, a dict-like. Received {fill_value} instead."
16191641
)
16201642

1643+
if not isinstance(other, Dataset) and isinstance(dtype, Mapping):
1644+
raise ValueError(
1645+
"'dtype' cannot be dict-like when passing a DataArray or Variable"
1646+
)
1647+
16211648
if isinstance(other, Dataset):
16221649
if not isinstance(fill_value, dict):
16231650
fill_value = {k: fill_value for k in other.data_vars.keys()}
16241651

1625-
if not isinstance(dtype, dict):
1626-
dtype = {k: dtype for k in other.data_vars.keys()}
1652+
if not isinstance(dtype, Mapping):
1653+
dtype_ = {k: dtype for k in other.data_vars.keys()}
1654+
else:
1655+
dtype_ = dtype
16271656

16281657
data_vars = {
1629-
k: _full_like_variable(v, fill_value.get(k, dtypes.NA), dtype.get(k, None))
1658+
k: _full_like_variable(v, fill_value.get(k, dtypes.NA), dtype_.get(k, None))
16301659
for k, v in other.data_vars.items()
16311660
}
16321661
return Dataset(data_vars, coords=other.coords, attrs=other.attrs)

xarray/core/dataset.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4311,7 +4311,7 @@ def dropna(
43114311
subset = iter(self.data_vars)
43124312

43134313
count = np.zeros(self.dims[dim], dtype=np.int64)
4314-
size = 0
4314+
size = np.int_(0) # for type checking
43154315

43164316
for k in subset:
43174317
array = self._variables[k]
@@ -6370,7 +6370,7 @@ def polyfit(
63706370
lhs = np.vander(x, order)
63716371

63726372
if rcond is None:
6373-
rcond = x.shape[0] * np.core.finfo(x.dtype).eps
6373+
rcond = x.shape[0] * np.core.finfo(x.dtype).eps # type: ignore
63746374

63756375
# Weights:
63766376
if w is not None:
@@ -6414,7 +6414,7 @@ def polyfit(
64146414
# deficient ranks nor does it output the "full" info (issue dask/dask#6516)
64156415
skipna_da = True
64166416
elif skipna is None:
6417-
skipna_da = np.any(da.isnull())
6417+
skipna_da = bool(np.any(da.isnull()))
64186418

64196419
dims_to_stack = [dimname for dimname in da.dims if dimname != dim]
64206420
stacked_coords: Dict[Hashable, DataArray] = {}

xarray/core/formatting.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -189,9 +189,8 @@ def format_array_flat(array, max_width: int):
189189
(max_possibly_relevant < array.size) or (cum_len > max_width).any()
190190
):
191191
padding = " ... "
192-
count = min(
193-
array.size, max(np.argmax(cum_len + len(padding) - 1 > max_width), 2)
194-
)
192+
max_len = max(np.argmax(cum_len + len(padding) - 1 > max_width), 2) # type: ignore
193+
count = min(array.size, max_len)
195194
else:
196195
count = array.size
197196
padding = "" if (count <= 1) else " "

xarray/core/indexing.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from collections import defaultdict
55
from contextlib import suppress
66
from datetime import timedelta
7-
from typing import Any, Callable, Iterable, Sequence, Tuple, Union
7+
from typing import Any, Callable, Iterable, List, Sequence, Tuple, Union
88

99
import numpy as np
1010
import pandas as pd
@@ -1010,7 +1010,7 @@ def _decompose_outer_indexer(
10101010
return indexer, BasicIndexer(())
10111011
assert isinstance(indexer, (OuterIndexer, BasicIndexer))
10121012

1013-
backend_indexer = []
1013+
backend_indexer: List[Any] = []
10141014
np_indexer = []
10151015
# make indexer positive
10161016
pos_indexer = []
@@ -1397,17 +1397,17 @@ def __init__(self, array: Any, dtype: DTypeLike = None):
13971397
self.array = utils.safe_cast_to_index(array)
13981398
if dtype is None:
13991399
if isinstance(array, pd.PeriodIndex):
1400-
dtype = np.dtype("O")
1400+
dtype_ = np.dtype("O")
14011401
elif hasattr(array, "categories"):
14021402
# category isn't a real numpy dtype
1403-
dtype = array.categories.dtype
1403+
dtype_ = array.categories.dtype
14041404
elif not utils.is_valid_numpy_dtype(array.dtype):
1405-
dtype = np.dtype("O")
1405+
dtype_ = np.dtype("O")
14061406
else:
1407-
dtype = array.dtype
1407+
dtype_ = array.dtype
14081408
else:
1409-
dtype = np.dtype(dtype)
1410-
self._dtype = dtype
1409+
dtype_ = np.dtype(dtype)
1410+
self._dtype = dtype_
14111411

14121412
@property
14131413
def dtype(self) -> np.dtype:

xarray/core/npcompat.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,12 +75,12 @@ def moveaxis(a, source, destination):
7575
return result
7676

7777

78-
# Type annotations stubs.
78+
# Type annotations stubs
7979
try:
8080
from numpy.typing import DTypeLike
8181
except ImportError:
8282
# fall back for numpy < 1.20
83-
DTypeLike = Union[np.dtype, str]
83+
DTypeLike = Union[np.dtype, str] # type: ignore
8484

8585

8686
# from dask/array/utils.py

xarray/core/nputils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import numpy as np
44
import pandas as pd
5-
from numpy.core.multiarray import normalize_axis_index
5+
from numpy.core.multiarray import normalize_axis_index # type: ignore
66

77
try:
88
import bottleneck as bn

xarray/tests/test_cftime_offsets.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -479,7 +479,7 @@ def test_minus_offset(a, b):
479479

480480
@pytest.mark.parametrize(
481481
("a", "b"),
482-
list(zip(np.roll(_EQ_TESTS_A, 1), _EQ_TESTS_B))
482+
list(zip(np.roll(_EQ_TESTS_A, 1), _EQ_TESTS_B)) # type: ignore
483483
+ [(YearEnd(month=1), YearEnd(month=2))],
484484
ids=_id_func,
485485
)

0 commit comments

Comments
 (0)