|
1 |
| -# Copyright 2023 The JAX Authors. |
| 1 | +# Copyright 2024 The JAX Authors. |
2 | 2 | #
|
3 | 3 | # Licensed under the Apache License, Version 2.0 (the "License");
|
4 | 4 | # you may not use this file except in compliance with the License.
|
|
12 | 12 | # See the License for the specific language governing permissions and
|
13 | 13 | # limitations under the License.
|
14 | 14 |
|
| 15 | +""" |
| 16 | +This module contains metadata related to the `Python array API`_. |
| 17 | +
|
| 18 | +.. _Python array API: https://data-apis.org/array-api/ |
| 19 | +""" |
15 | 20 | from __future__ import annotations
|
16 | 21 |
|
17 | 22 | import jax
|
18 | 23 | from jax._src.sharding import Sharding
|
19 | 24 | from jax._src.lib import xla_client as xc
|
20 | 25 | from jax._src import dtypes as _dtypes, config
|
21 | 26 |
|
22 |
| -# TODO(micky774): Add to jax.numpy.util when finalizing jax.experimental.array_api |
23 |
| -# deprecation |
24 |
| -class __array_namespace_info__: |
25 | 27 |
|
26 |
| - def __init__(self): |
27 |
| - self._capabilities = { |
28 |
| - "boolean indexing": True, |
29 |
| - "data-dependent shapes": False, |
30 |
| - } |
| 28 | +# TODO(jakevdp, vfdev-5): export this in jax.numpy once migration is complete. |
| 29 | +__array_api_version__ = '2023.12' |
| 30 | + |
| 31 | + |
| 32 | +# TODO(jakevdp, vfdev-5): export this in jax.numpy once migration is complete. |
| 33 | +def __array_namespace_info__() -> ArrayNamespaceInfo: |
| 34 | + return ArrayNamespaceInfo() |
| 35 | + |
| 36 | + |
| 37 | +def __array_namespace__(self, /, *, api_version: None | str = None): |
| 38 | + """Return the `Python array API`_ namespace for JAX. |
| 39 | +
|
| 40 | + .. _Python array API: https://data-apis.org/array-api/ |
| 41 | + """ |
| 42 | + if api_version is not None and api_version != __array_api_version__: |
| 43 | + raise ValueError(f"{api_version=!r} is not available; " |
| 44 | + f"available versions are: {[__array_api_version__]}") |
| 45 | + # TODO(jakevdp, vfdev-5): change this to jax.numpy once migration is complete. |
| 46 | + import jax.experimental.array_api |
| 47 | + return jax.experimental.array_api # pytype: disable=module-attr |
| 48 | + |
| 49 | + |
| 50 | +class ArrayNamespaceInfo: |
| 51 | + """Metadata for the `Python array API`_ |
31 | 52 |
|
| 53 | + .. _Python array API: https://data-apis.org/array-api/ |
| 54 | + """ |
| 55 | + _capabilities = { |
| 56 | + "boolean indexing": True, |
| 57 | + "data-dependent shapes": False, |
| 58 | + } |
32 | 59 |
|
33 | 60 | def _build_dtype_dict(self):
|
34 | 61 | array_api_types = {
|
|
0 commit comments