31
31
"Info" ,
32
32
]
33
33
34
- from dataclasses import dataclass
35
34
from typing import (
36
35
Any ,
37
36
List ,
45
44
Protocol ,
46
45
)
47
46
from enum import Enum
47
+ from .data_types import DType
48
48
49
- array = TypeVar ("array" , bound = "array_ " )
49
+ array = TypeVar ("array" , bound = "Array " )
50
50
device = TypeVar ("device" )
51
- dtype = TypeVar ("dtype" )
51
+ dtype = TypeVar ("dtype" , bound = DType )
52
+ device_ = TypeVar ("device_" ) # only used in this file
53
+ dtype_ = TypeVar ("dtype_" , bound = DType ) # only used in this file
52
54
SupportsDLPack = TypeVar ("SupportsDLPack" )
53
55
SupportsBufferProtocol = TypeVar ("SupportsBufferProtocol" )
54
56
PyCapsule = TypeVar ("PyCapsule" )
@@ -88,7 +90,7 @@ def __len__(self, /) -> int:
88
90
...
89
91
90
92
91
- class Info (Protocol ):
93
+ class Info (Protocol [ device ] ):
92
94
"""Namespace returned by `__array_namespace_info__`."""
93
95
94
96
def capabilities (self ) -> Capabilities :
@@ -147,12 +149,12 @@ def dtypes(
147
149
)
148
150
149
151
150
- class _array (Protocol [array , dtype , device ]):
152
+ class Array (Protocol [array , dtype_ , device_ , PyCapsule ]): # type: ignore
151
153
def __init__ (self : array ) -> None :
152
154
"""Initialize the attributes for the array object class."""
153
155
154
156
@property
155
- def dtype (self : array ) -> dtype :
157
+ def dtype (self : array ) -> dtype_ :
156
158
"""
157
159
Data type of the array elements.
158
160
@@ -163,7 +165,7 @@ def dtype(self: array) -> dtype:
163
165
"""
164
166
165
167
@property
166
- def device (self : array ) -> device :
168
+ def device (self : array ) -> device_ :
167
169
"""
168
170
Hardware device the array data resides on.
169
171
@@ -625,7 +627,7 @@ def __dlpack_device__(self: array, /) -> Tuple[Enum, int]:
625
627
ONE_API = 14
626
628
"""
627
629
628
- def __eq__ (self : array , other : Union [int , float , bool , array ], / ) -> array :
630
+ def __eq__ (self : array , other : Union [int , float , bool , array ], / ) -> array : # type: ignore
629
631
r"""
630
632
Computes the truth value of ``self_i == other_i`` for each element of an array instance with the respective element of the array ``other``.
631
633
@@ -1072,7 +1074,7 @@ def __mul__(self: array, other: Union[int, float, array], /) -> array:
1072
1074
Added complex data type support.
1073
1075
"""
1074
1076
1075
- def __ne__ (self : array , other : Union [int , float , bool , array ], / ) -> array :
1077
+ def __ne__ (self : array , other : Union [int , float , bool , array ], / ) -> array : # type: ignore
1076
1078
"""
1077
1079
Computes the truth value of ``self_i != other_i`` for each element of an array instance with the respective element of the array ``other``.
1078
1080
@@ -1342,7 +1344,7 @@ def __xor__(self: array, other: Union[int, bool, array], /) -> array:
1342
1344
"""
1343
1345
1344
1346
def to_device (
1345
- self : array , device : device , / , * , stream : Optional [Union [int , Any ]] = None
1347
+ self : array , device : device_ , / , * , stream : Optional [Union [int , Any ]] = None
1346
1348
) -> array :
1347
1349
"""
1348
1350
Copy the array from the device on which it currently resides to the specified ``device``.
0 commit comments