Skip to content

Commit f3c9eb4

Browse files
committed
docs: fix docs again
1 parent 1810177 commit f3c9eb4

File tree

8 files changed

+65
-30
lines changed

8 files changed

+65
-30
lines changed

.gitignore

+2
Original file line numberDiff line numberDiff line change
@@ -35,3 +35,5 @@ tmp/
3535
*.egg
3636
dist/
3737
.DS_STORE
38+
venv
39+
.venv

.pre-commit-config.yaml

+10
Original file line numberDiff line numberDiff line change
@@ -48,3 +48,13 @@ repos:
4848
rev: 23.7.0
4949
hooks:
5050
- id: black
51+
52+
- repo: https://github.com/pre-commit/mirrors-mypy
53+
rev: "v1.0.0"
54+
hooks:
55+
- id: mypy
56+
additional_dependencies: [typing_extensions>=4.4.0]
57+
args:
58+
- --ignore-missing-imports
59+
- --config=pyproject.toml
60+
files: ".*(_draft.*)$"

pyproject.toml

+9
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,12 @@ build-backend = "setuptools.build_meta"
3232

3333
[tool.black]
3434
line-length = 88
35+
36+
[tool.mypy]
37+
python_version = "3.9"
38+
mypy_path = "$MYPY_CONFIG_FILE_DIR/src/array_api_stubs/_draft/"
39+
files = [
40+
"src/array_api_stubs/_draft/**/*.py"
41+
]
42+
follow_imports = "silent"
43+
disable_error_code = "empty-body,type-var"

src/_array_api_conf.py

+4
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,11 @@
6666
]
6767
nitpick_ignore_regex = [
6868
("py:class", ".*array"),
69+
("py:class", ".*Array"),
6970
("py:class", ".*device"),
71+
("py:class", ".*Device"),
7072
("py:class", ".*dtype"),
73+
("py:class", ".*DType"),
7174
("py:class", ".*NestedSequence"),
7275
("py:class", ".*SupportsBufferProtocol"),
7376
("py:class", ".*PyCapsule"),
@@ -84,6 +87,7 @@
8487
"array": "array",
8588
"Device": "device",
8689
"Dtype": "dtype",
90+
"DType": "dtype",
8791
}
8892

8993
# Make autosummary show the signatures of functions in the tables using actual

src/array_api_stubs/_draft/_types.py

+12-10
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
"Info",
3232
]
3333

34-
from dataclasses import dataclass
3534
from typing import (
3635
Any,
3736
List,
@@ -45,10 +44,13 @@
4544
Protocol,
4645
)
4746
from enum import Enum
47+
from .data_types import DType
4848

49-
array = TypeVar("array", bound="array_")
49+
array = TypeVar("array", bound="Array")
5050
device = TypeVar("device")
51-
dtype = TypeVar("dtype")
51+
dtype = TypeVar("dtype", bound=DType)
52+
device_ = TypeVar("device_") # only used in this file
53+
dtype_ = TypeVar("dtype_", bound=DType) # only used in this file
5254
SupportsDLPack = TypeVar("SupportsDLPack")
5355
SupportsBufferProtocol = TypeVar("SupportsBufferProtocol")
5456
PyCapsule = TypeVar("PyCapsule")
@@ -88,7 +90,7 @@ def __len__(self, /) -> int:
8890
...
8991

9092

91-
class Info(Protocol):
93+
class Info(Protocol[device]):
9294
"""Namespace returned by `__array_namespace_info__`."""
9395

9496
def capabilities(self) -> Capabilities:
@@ -147,12 +149,12 @@ def dtypes(
147149
)
148150

149151

150-
class _array(Protocol[array, dtype, device]):
152+
class Array(Protocol[array, dtype_, device_, PyCapsule]): # type: ignore
151153
def __init__(self: array) -> None:
152154
"""Initialize the attributes for the array object class."""
153155

154156
@property
155-
def dtype(self: array) -> dtype:
157+
def dtype(self: array) -> dtype_:
156158
"""
157159
Data type of the array elements.
158160
@@ -163,7 +165,7 @@ def dtype(self: array) -> dtype:
163165
"""
164166

165167
@property
166-
def device(self: array) -> device:
168+
def device(self: array) -> device_:
167169
"""
168170
Hardware device the array data resides on.
169171
@@ -625,7 +627,7 @@ def __dlpack_device__(self: array, /) -> Tuple[Enum, int]:
625627
ONE_API = 14
626628
"""
627629

628-
def __eq__(self: array, other: Union[int, float, bool, array], /) -> array:
630+
def __eq__(self: array, other: Union[int, float, bool, array], /) -> array: # type: ignore
629631
r"""
630632
Computes the truth value of ``self_i == other_i`` for each element of an array instance with the respective element of the array ``other``.
631633
@@ -1072,7 +1074,7 @@ def __mul__(self: array, other: Union[int, float, array], /) -> array:
10721074
Added complex data type support.
10731075
"""
10741076

1075-
def __ne__(self: array, other: Union[int, float, bool, array], /) -> array:
1077+
def __ne__(self: array, other: Union[int, float, bool, array], /) -> array: # type: ignore
10761078
"""
10771079
Computes the truth value of ``self_i != other_i`` for each element of an array instance with the respective element of the array ``other``.
10781080
@@ -1342,7 +1344,7 @@ def __xor__(self: array, other: Union[int, bool, array], /) -> array:
13421344
"""
13431345

13441346
def to_device(
1345-
self: array, device: device, /, *, stream: Optional[Union[int, Any]] = None
1347+
self: array, device: device_, /, *, stream: Optional[Union[int, Any]] = None
13461348
) -> array:
13471349
"""
13481350
Copy the array from the device on which it currently resides to the specified ``device``.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from ._types import Array
2+
3+
# for documentation
4+
array = Array
5+
6+
__all__ = ["array"]
+19-17
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,24 @@
1-
__all__ = ["__eq__"]
1+
from __future__ import annotations
22

3+
__all__ = ["DType"]
34

4-
from ._types import dtype
55

6+
from typing import Protocol
67

7-
def __eq__(self: dtype, other: dtype, /) -> bool:
8-
"""
9-
Computes the truth value of ``self == other`` in order to test for data type object equality.
108

11-
Parameters
12-
----------
13-
self: dtype
14-
data type instance. May be any supported data type.
15-
other: dtype
16-
other data type instance. May be any supported data type.
17-
18-
Returns
19-
-------
20-
out: bool
21-
a boolean indicating whether the data type objects are equal.
22-
"""
9+
class DType(Protocol):
10+
def __eq__(self, other: DType, /) -> bool:
11+
"""
12+
Computes the truth value of ``self == other`` in order to test for data type object equality.
13+
Parameters
14+
----------
15+
self: dtype
16+
data type instance. May be any supported data type.
17+
other: dtype
18+
other data type instance. May be any supported data type.
19+
Returns
20+
-------
21+
out: bool
22+
a boolean indicating whether the data type objects are equal.
23+
"""
24+
...

src/array_api_stubs/_draft/linalg.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,7 @@ def matrix_norm(
301301
/,
302302
*,
303303
keepdims: bool = False,
304-
ord: Optional[Union[int, float, Literal[inf, -inf, "fro", "nuc"]]] = "fro",
304+
ord: Optional[Union[int, float, Literal[inf, -inf, "fro", "nuc"]]] = "fro", # type: ignore
305305
) -> array:
306306
"""
307307
Computes the matrix norm of a matrix (or a stack of matrices) ``x``.
@@ -781,7 +781,7 @@ def trace(x: array, /, *, offset: int = 0, dtype: Optional[dtype] = None) -> arr
781781
"""
782782

783783

784-
def vecdot(x1: array, x2: array, /, *, axis: int = None) -> array:
784+
def vecdot(x1: array, x2: array, /, *, axis: Optional[int] = None) -> array:
785785
"""Alias for :func:`~array_api.vecdot`."""
786786

787787

@@ -791,7 +791,7 @@ def vector_norm(
791791
*,
792792
axis: Optional[Union[int, Tuple[int, ...]]] = None,
793793
keepdims: bool = False,
794-
ord: Union[int, float, Literal[inf, -inf]] = 2,
794+
ord: Union[int, float, Literal[inf, -inf]] = 2, # type: ignore
795795
) -> array:
796796
r"""
797797
Computes the vector norm of a vector (or batch of vectors) ``x``.

0 commit comments

Comments
 (0)