diff --git a/xarray/namedarray/_typing.py b/xarray/namedarray/_typing.py index 57b17385558..2b9fbab76d4 100644 --- a/xarray/namedarray/_typing.py +++ b/xarray/namedarray/_typing.py @@ -68,8 +68,8 @@ def dtype(self) -> _DType_co: ... _IntOrUnknown = int _Shape = tuple[_IntOrUnknown, ...] _ShapeLike = Union[SupportsIndex, Sequence[SupportsIndex]] -_ShapeType = TypeVar("_ShapeType", bound=Any) -_ShapeType_co = TypeVar("_ShapeType_co", bound=Any, covariant=True) +_ShapeType = TypeVar("_ShapeType", bound=_Shape) +_ShapeType_co = TypeVar("_ShapeType_co", bound=_Shape, covariant=True) _Axis = int _Axes = tuple[_Axis, ...] @@ -117,7 +117,7 @@ class _array(Protocol[_ShapeType_co, _DType_co]): """ @property - def shape(self) -> _Shape: ... + def shape(self) -> _ShapeType_co: ... @property def dtype(self) -> _DType_co: ... diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index c5841f6913e..38cfa2a54d6 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -249,14 +249,14 @@ class NamedArray(NamedArrayAggregations, Generic[_ShapeType_co, _DType_co]): __slots__ = ("_data", "_dims", "_attrs") - _data: duckarray[Any, _DType_co] + _data: duckarray[_ShapeType_co, _DType_co] _dims: _Dims _attrs: dict[Any, Any] | None def __init__( self, dims: _DimsLike, - data: duckarray[Any, _DType_co], + data: duckarray[_ShapeType_co, _DType_co], attrs: _AttrsLike = None, ): self._data = data @@ -291,7 +291,7 @@ def _new( def _new( self, dims: _DimsLike | Default = _default, - data: duckarray[Any, _DType] | Default = _default, + data: duckarray[_ShapeType, _DType] | Default = _default, attrs: _AttrsLike | Default = _default, ) -> NamedArray[_ShapeType, _DType] | NamedArray[_ShapeType_co, _DType_co]: """ @@ -446,7 +446,7 @@ def dtype(self) -> _DType_co: return self._data.dtype @property - def shape(self) -> _Shape: + def shape(self) -> _ShapeType_co: """ Get the shape of the array. @@ -849,9 +849,9 @@ def to_numpy(self) -> np.ndarray[Any, Any]: # TODO an entrypoint so array libraries can choose coercion method? return to_numpy(self._data) - def as_numpy(self) -> Self: + def as_numpy(self) -> NamedArray[Any, Any]: """Coerces wrapped data into a numpy array, returning a Variable.""" - return self._replace(data=self.to_numpy()) + return self._new(data=self.to_numpy()) def reduce( self, @@ -1162,3 +1162,102 @@ def _raise_if_any_duplicate_dimensions( raise ValueError( f"{err_context} cannot handle duplicate dimensions, but dimensions {repeated_dims} appear more than once on this object's dims: {dims}" ) + + +# # %% function should pass + +# data = np.array([1, 2, 3], dtype=np.dtype(np.int64)) +# # data: duckarray[Any, np.dtype[np.int64]] = np.array([1, 2, 3], dtype=np.dtype(np.int64)) +# reveal_type(data) + + +# def test( +# data: duckarray[_ShapeType_co, _DType_co] +# ) -> duckarray[_ShapeType_co, _DType_co]: +# return data + + +# def test2( +# data: _arrayfunction[_ShapeType, _DType] +# ) -> _arrayfunction[_ShapeType, _DType]: +# return data + + +# b = test(data) +# reveal_type(b) +# c = test2(data) +# reveal_type(c) +# a = NamedArray(("time",), data=data) +# reveal_type(a) + + +# # %% Class should pass +# from typing import Generic, TypeVar, Protocol, Union + +# _ST = TypeVar("_ST", bound=Any, covariant=True) +# _DT = TypeVar("_DT", bound=Any, covariant=True) + + +# # Valid numpy protocol: +# class ArrayA(Protocol[_ST, _DT]): +# @property +# def dtype(self) -> _DT: ... +# @property +# def shape(self) -> _ST: ... + + +# class TestArray(Generic[_ST, _DT]): +# __slots__ = ("_data",) + +# _data: ArrayA[_ST, _DT] + +# def __init__(self, data: ArrayA[_ST, _DT]): +# self._data = data + + +# ta = TestArray(data) +# reveal_type(ta) + + +# # %% Class should pass +# # Not valid numpy protocol: +# class ArrayB(Protocol[_ST, _DT]): +# @property +# def dtype(self) -> _DT: ... +# @property +# def shape(self) -> _ST: ... +# def b(self) -> int: ... + + +# duckiearray = Union[ArrayA[_ST, _DT], ArrayB[_ST, _DT]] + + +# class TestArray2(Generic[_ST, _DT]): +# __slots__ = ("_data",) + +# _data: duckiearray[_ST, _DT] + +# def __init__(self, data: duckiearray[_ST, _DT]): +# self._data = data + + +# ta2 = TestArray2(data) +# reveal_type(ta2) + + +# # %% Class should pass +# class TestArray3(Generic[_ST, _DT]): +# __slots__ = ("_data",) + +# _data: duckarray[_ST, _DT] + +# def __init__(self, data: duckarray[_ST, _DT]): +# self._data = data + + +# ta3 = TestArray3(data) +# reveal_type(ta3) +# # %% Namedarray should pass + +# narr = NamedArray(("time",), data) +# reveal_type(narr) diff --git a/xarray/tests/test_namedarray.py b/xarray/tests/test_namedarray.py index 8ccf8c541b7..bb5252ae493 100644 --- a/xarray/tests/test_namedarray.py +++ b/xarray/tests/test_namedarray.py @@ -31,22 +31,22 @@ _DType, _IndexKeyLike, _IntOrUnknown, - _Shape, _ShapeLike, + _ShapeType, duckarray, ) class CustomArrayBase(Generic[_ShapeType_co, _DType_co]): - def __init__(self, array: duckarray[Any, _DType_co]) -> None: - self.array: duckarray[Any, _DType_co] = array + def __init__(self, array: duckarray[_ShapeType_co, _DType_co]) -> None: + self.array: duckarray[_ShapeType_co, _DType_co] = array @property def dtype(self) -> _DType_co: return self.array.dtype @property - def shape(self) -> _Shape: + def shape(self) -> _ShapeType_co: return self.array.shape @@ -79,9 +79,11 @@ def __array_namespace__(self) -> ModuleType: return np -def check_duck_array_typevar(a: duckarray[Any, _DType]) -> duckarray[Any, _DType]: +def check_duck_array_typevar( + a: duckarray[_ShapeType, _DType] +) -> duckarray[_ShapeType, _DType]: # Mypy checks a is valid: - b: duckarray[Any, _DType] = a + b: duckarray[_ShapeType, _DType] = a # Runtime check if valid: if isinstance(b, _arrayfunction_or_api):