Skip to content

Commit ff8e8ad

Browse files
Jake VanderPlasjax authors
Jake VanderPlas
authored and
jax authors
committed
revert #22734
Reverts 5ce66dc PiperOrigin-RevId: 657638187
1 parent 256956a commit ff8e8ad

File tree

5 files changed

+50
-47
lines changed

5 files changed

+50
-47
lines changed

jax/_src/numpy/array_methods.py

-2
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@
3939
from jax._src.array import ArrayImpl
4040
from jax._src.lax import lax as lax_internal
4141
from jax._src.lib import xla_client as xc
42-
from jax._src.numpy import array_api_metadata
4342
from jax._src.numpy import lax_numpy
4443
from jax._src.numpy import reductions
4544
from jax._src.numpy import ufuncs
@@ -666,7 +665,6 @@ def max(self, values, *, indices_are_sorted=False, unique_indices=False,
666665
}
667666

668667
_array_methods = {
669-
"__array_namespace__": array_api_metadata.__array_namespace__,
670668
"all": reductions.all,
671669
"any": reductions.any,
672670
"argmax": lax_numpy.argmax,

jax/experimental/array_api/__init__.py

+10-5
Original file line numberDiff line numberDiff line change
@@ -36,14 +36,11 @@
3636

3737
from __future__ import annotations
3838

39+
from jax.experimental.array_api._version import __array_api_version__ as __array_api_version__
40+
3941
from jax.experimental.array_api import fft as fft
4042
from jax.experimental.array_api import linalg as linalg
4143

42-
from jax._src.numpy.array_api_metadata import (
43-
__array_api_version__ as __array_api_version__,
44-
__array_namespace_info__ as __array_namespace_info__,
45-
)
46-
4744
from jax.numpy import (
4845
abs as abs,
4946
acos as acos,
@@ -200,3 +197,11 @@
200197
clip as clip,
201198
hypot as hypot,
202199
)
200+
201+
from jax.experimental.array_api._utility_functions import (
202+
__array_namespace_info__ as __array_namespace_info__,
203+
)
204+
205+
from jax.experimental.array_api import _array_methods
206+
_array_methods.add_array_object_methods()
207+
del _array_methods
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# Copyright 2023 The JAX Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
import jax
18+
from jax._src.array import ArrayImpl
19+
from jax.experimental.array_api._version import __array_api_version__
20+
21+
22+
def _array_namespace(self, /, *, api_version: None | str = None):
23+
if api_version is not None and api_version != __array_api_version__:
24+
raise ValueError(f"{api_version=!r} is not available; "
25+
f"available versions are: {[__array_api_version__]}")
26+
return jax.experimental.array_api
27+
28+
29+
def add_array_object_methods():
30+
# TODO(jakevdp): set on tracers as well?
31+
setattr(ArrayImpl, "__array_namespace__", _array_namespace)

jax/_src/numpy/array_api_metadata.py jax/experimental/array_api/_utility_functions.py

+9-36
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2024 The JAX Authors.
1+
# Copyright 2023 The JAX Authors.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -12,50 +12,23 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
"""
16-
This module contains metadata related to the `Python array API`_.
17-
18-
.. _Python array API: https://data-apis.org/array-api/
19-
"""
2015
from __future__ import annotations
2116

2217
import jax
2318
from jax._src.sharding import Sharding
2419
from jax._src.lib import xla_client as xc
2520
from jax._src import dtypes as _dtypes, config
2621

22+
# TODO(micky774): Add to jax.numpy.util when finalizing jax.experimental.array_api
23+
# deprecation
24+
class __array_namespace_info__:
2725

28-
# TODO(jakevdp, vfdev-5): export this in jax.numpy once migration is complete.
29-
__array_api_version__ = '2023.12'
30-
31-
32-
# TODO(jakevdp, vfdev-5): export this in jax.numpy once migration is complete.
33-
def __array_namespace_info__() -> ArrayNamespaceInfo:
34-
return ArrayNamespaceInfo()
35-
36-
37-
def __array_namespace__(self, /, *, api_version: None | str = None):
38-
"""Return the `Python array API`_ namespace for JAX.
39-
40-
.. _Python array API: https://data-apis.org/array-api/
41-
"""
42-
if api_version is not None and api_version != __array_api_version__:
43-
raise ValueError(f"{api_version=!r} is not available; "
44-
f"available versions are: {[__array_api_version__]}")
45-
# TODO(jakevdp, vfdev-5): change this to jax.numpy once migration is complete.
46-
import jax.experimental.array_api
47-
return jax.experimental.array_api # pytype: disable=module-attr
48-
49-
50-
class ArrayNamespaceInfo:
51-
"""Metadata for the `Python array API`_
26+
def __init__(self):
27+
self._capabilities = {
28+
"boolean indexing": True,
29+
"data-dependent shapes": False,
30+
}
5231

53-
.. _Python array API: https://data-apis.org/array-api/
54-
"""
55-
_capabilities = {
56-
"boolean indexing": True,
57-
"data-dependent shapes": False,
58-
}
5932

6033
def _build_dtype_dict(self):
6134
array_api_types = {

jax/numpy/__init__.py

-4
Original file line numberDiff line numberDiff line change
@@ -274,10 +274,6 @@
274274
except ImportError:
275275
pass
276276

277-
from jax._src.numpy.array_api_metadata import (
278-
__array_api_version__ as __array_api_version__
279-
)
280-
281277
from jax._src.numpy.index_tricks import (
282278
c_ as c_,
283279
index_exp as index_exp,

0 commit comments

Comments
 (0)