Skip to content

Commit 5ce66dc

Browse files
author
jax authors
committed
Merge pull request #22734 from jakevdp:array-api-methods
PiperOrigin-RevId: 657299896
2 parents 7fd9302 + 00ba7a6 commit 5ce66dc

File tree

5 files changed

+47
-50
lines changed

5 files changed

+47
-50
lines changed

jax/experimental/array_api/_utility_functions.py jax/_src/numpy/array_api_metadata.py

+36-9
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023 The JAX Authors.
1+
# Copyright 2024 The JAX Authors.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -12,23 +12,50 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
"""
16+
This module contains metadata related to the `Python array API`_.
17+
18+
.. _Python array API: https://data-apis.org/array-api/
19+
"""
1520
from __future__ import annotations
1621

1722
import jax
1823
from jax._src.sharding import Sharding
1924
from jax._src.lib import xla_client as xc
2025
from jax._src import dtypes as _dtypes, config
2126

22-
# TODO(micky774): Add to jax.numpy.util when finalizing jax.experimental.array_api
23-
# deprecation
24-
class __array_namespace_info__:
2527

26-
def __init__(self):
27-
self._capabilities = {
28-
"boolean indexing": True,
29-
"data-dependent shapes": False,
30-
}
28+
# TODO(jakevdp, vfdev-5): export this in jax.numpy once migration is complete.
29+
__array_api_version__ = '2023.12'
30+
31+
32+
# TODO(jakevdp, vfdev-5): export this in jax.numpy once migration is complete.
33+
def __array_namespace_info__() -> ArrayNamespaceInfo:
34+
return ArrayNamespaceInfo()
35+
36+
37+
def __array_namespace__(self, /, *, api_version: None | str = None):
38+
"""Return the `Python array API`_ namespace for JAX.
39+
40+
.. _Python array API: https://data-apis.org/array-api/
41+
"""
42+
if api_version is not None and api_version != __array_api_version__:
43+
raise ValueError(f"{api_version=!r} is not available; "
44+
f"available versions are: {[__array_api_version__]}")
45+
# TODO(jakevdp, vfdev-5): change this to jax.numpy once migration is complete.
46+
import jax.experimental.array_api
47+
return jax.experimental.array_api # pytype: disable=module-attr
48+
49+
50+
class ArrayNamespaceInfo:
51+
"""Metadata for the `Python array API`_
3152
53+
.. _Python array API: https://data-apis.org/array-api/
54+
"""
55+
_capabilities = {
56+
"boolean indexing": True,
57+
"data-dependent shapes": False,
58+
}
3259

3360
def _build_dtype_dict(self):
3461
array_api_types = {

jax/_src/numpy/array_methods.py

+2
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
from jax._src.array import ArrayImpl
4040
from jax._src.lax import lax as lax_internal
4141
from jax._src.lib import xla_client as xc
42+
from jax._src.numpy import array_api_metadata
4243
from jax._src.numpy import lax_numpy
4344
from jax._src.numpy import reductions
4445
from jax._src.numpy import ufuncs
@@ -665,6 +666,7 @@ def max(self, values, *, indices_are_sorted=False, unique_indices=False,
665666
}
666667

667668
_array_methods = {
669+
"__array_namespace__": array_api_metadata.__array_namespace__,
668670
"all": reductions.all,
669671
"any": reductions.any,
670672
"argmax": lax_numpy.argmax,

jax/experimental/array_api/__init__.py

+5-10
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,14 @@
3636

3737
from __future__ import annotations
3838

39-
from jax.experimental.array_api._version import __array_api_version__ as __array_api_version__
40-
4139
from jax.experimental.array_api import fft as fft
4240
from jax.experimental.array_api import linalg as linalg
4341

42+
from jax._src.numpy.array_api_metadata import (
43+
__array_api_version__ as __array_api_version__,
44+
__array_namespace_info__ as __array_namespace_info__,
45+
)
46+
4447
from jax.numpy import (
4548
abs as abs,
4649
acos as acos,
@@ -197,11 +200,3 @@
197200
clip as clip,
198201
hypot as hypot,
199202
)
200-
201-
from jax.experimental.array_api._utility_functions import (
202-
__array_namespace_info__ as __array_namespace_info__,
203-
)
204-
205-
from jax.experimental.array_api import _array_methods
206-
_array_methods.add_array_object_methods()
207-
del _array_methods

jax/experimental/array_api/_array_methods.py

-31
This file was deleted.

jax/numpy/__init__.py

+4
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,10 @@
274274
except ImportError:
275275
pass
276276

277+
from jax._src.numpy.array_api_metadata import (
278+
__array_api_version__ as __array_api_version__
279+
)
280+
277281
from jax._src.numpy.index_tricks import (
278282
c_ as c_,
279283
index_exp as index_exp,

0 commit comments

Comments
 (0)