Commit ee4e622

jakeharmon8 authored and copybara-github committed
Update references to JAX's GitHub repo
JAX has moved from https://github.com/google/jax to https://github.com/jax-ml/jax

PiperOrigin-RevId: 702886981
1 parent 8ca8408 commit ee4e622

11 files changed, +24 -24 lines changed

README.md (+1 -1)

@@ -116,7 +116,7 @@ You can learn here how Trax works, how to create new models and how to train the

 The basic units flowing through Trax models are *tensors* - multi-dimensional arrays, sometimes also known as numpy arrays, due to the most widely used package for tensor operations -- `numpy`. You should take a look at the [numpy guide](https://numpy.org/doc/stable/user/quickstart.html) if you don't know how to operate on tensors: Trax also uses the numpy API for that.

-In Trax we want numpy operations to run very fast, making use of GPUs and TPUs to accelerate them. We also want to automatically compute gradients of functions on tensors. This is done in the `trax.fastmath` package thanks to its backends -- [JAX](https://github.com/google/jax) and [TensorFlow numpy](https://tensorflow.org/guide/tf_numpy).
+In Trax we want numpy operations to run very fast, making use of GPUs and TPUs to accelerate them. We also want to automatically compute gradients of functions on tensors. This is done in the `trax.fastmath` package thanks to its backends -- [JAX](https://github.com/jax-ml/jax) and [TensorFlow numpy](https://tensorflow.org/guide/tf_numpy).


 ```python

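Since this hunk touches the README's description of `trax.fastmath`, here is a minimal sketch of what that package provides -- assuming a standard `trax` installation with the default `jax` backend; the imports follow Trax's own intro material, but treat the snippet as illustrative only, not as part of this commit:

```python
# Minimal sketch: numpy-style ops plus automatic gradients via trax.fastmath.
from trax import fastmath
from trax.fastmath import numpy as fastnp  # numpy-compatible API

x = fastnp.arange(4.0)   # a tensor, i.e. a multi-dimensional array
y = fastnp.sum(x ** 2)   # accelerated numpy-style operation

# Gradient of a scalar-valued function of a tensor.
grad_fn = fastmath.grad(lambda t: fastnp.sum(t ** 2))
print(grad_fn(x))        # expected: [0. 2. 4. 6.]
```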
docs/source/notebooks/tf_numpy_and_keras.ipynb (+1 -1)

@@ -26,7 +26,7 @@
 "\n",
 "In Trax, all computations rely on accelerated math operations happening in the `fastmath` module. This module can use different backends for acceleration. One of them is [TensorFlow NumPy](https://www.tensorflow.org/api_docs/python/tf/experimental/numpy) which uses [TensorFlow 2](https://www.tensorflow.org/) to accelerate the computations.\n",
 "\n",
-"The backend can be set using a call to `trax.fastmath.set_backend` as you'll see below. Currently available backends are `jax` (default), `tensorflow-numpy` and `numpy` (for debugging). The `tensorflow-numpy` backend uses [TensorFlow Numpy](https://www.tensorflow.org/api_docs/python/tf/experimental/numpy) for executing `fastmath` functions on TensorFlow, while the `jax` backend calls [JAX](https://github.com/google/jax) which lowers to TensorFlow XLA.\n",
+"The backend can be set using a call to `trax.fastmath.set_backend` as you'll see below. Currently available backends are `jax` (default), `tensorflow-numpy` and `numpy` (for debugging). The `tensorflow-numpy` backend uses [TensorFlow Numpy](https://www.tensorflow.org/api_docs/python/tf/experimental/numpy) for executing `fastmath` functions on TensorFlow, while the `jax` backend calls [JAX](https://github.com/jax-ml/jax) which lowers to TensorFlow XLA.\n",
 "\n",
 "You may see that `tensorflow-numpy` and `jax` backends show different speed and memory characteristics. You may also see different error messages when debugging since it might expose you to the internals of the backends. However for the most part, users can choose a backend and not worry about the internal details of these backends.\n",
 "\n",

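The backend switch described in this notebook cell looks roughly like the sketch below. The backend names come from the text above; the exact call form (`set_backend` taking a backend name string) is an assumption about the Trax API rather than something shown in this diff:

```python
# Minimal sketch: selecting the fastmath backend described above.
import trax
from trax.fastmath import numpy as fastnp

# Switch from the default `jax` backend to the TensorFlow NumPy backend.
trax.fastmath.set_backend('tensorflow-numpy')

x = fastnp.ones((2, 3))
print(fastnp.sum(x))  # fastmath ops now execute via TensorFlow NumPy
```
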
docs/source/notebooks/trax_intro.ipynb (+1 -1)

@@ -228,7 +228,7 @@
 "\n",
 "The basic units flowing through Trax models are *tensors* - multi-dimensional arrays, sometimes also known as numpy arrays, due to the most widely used package for tensor operations -- `numpy`. You should take a look at the [numpy guide](https://numpy.org/doc/stable/user/quickstart.html) if you don't know how to operate on tensors: Trax also uses the numpy API for that.\n",
 "\n",
-"In Trax we want numpy operations to run very fast, making use of GPUs and TPUs to accelerate them. We also want to automatically compute gradients of functions on tensors. This is done in the `trax.fastmath` package thanks to its backends -- [JAX](https://github.com/google/jax) and [TensorFlow numpy](https://tensorflow.org)."
+"In Trax we want numpy operations to run very fast, making use of GPUs and TPUs to accelerate them. We also want to automatically compute gradients of functions on tensors. This is done in the `trax.fastmath` package thanks to its backends -- [JAX](https://github.com/jax-ml/jax) and [TensorFlow numpy](https://tensorflow.org)."
 ]
 },
 {

trax/intro.ipynb (+1 -1)

@@ -228,7 +228,7 @@
 "\n",
 "The basic units flowing through Trax models are *tensors* - multi-dimensional arrays, sometimes also known as numpy arrays, due to the most widely used package for tensor operations -- `numpy`. You should take a look at the [numpy guide](https://numpy.org/doc/stable/user/quickstart.html) if you don't know how to operate on tensors: Trax also uses the numpy API for that.\n",
 "\n",
-"In Trax we want numpy operations to run very fast, making use of GPUs and TPUs to accelerate them. We also want to automatically compute gradients of functions on tensors. This is done in the `trax.fastmath` package thanks to its backends -- [JAX](https://github.com/google/jax) and [TensorFlow numpy](https://tensorflow.org)."
+"In Trax we want numpy operations to run very fast, making use of GPUs and TPUs to accelerate them. We also want to automatically compute gradients of functions on tensors. This is done in the `trax.fastmath` package thanks to its backends -- [JAX](https://github.com/jax-ml/jax) and [TensorFlow numpy](https://tensorflow.org)."
 ]
 },
 {

trax/supervised/training_test.py (+4 -4)

@@ -143,7 +143,7 @@ def test_loop_with_initialized_model(self):

   def test_train_save_restore_dense(self):
     """Saves and restores a checkpoint to check for equivalence."""
-    self.skipTest('Broken by https://github.com/google/jax/pull/11234')
+    self.skipTest('Broken by https://github.com/jax-ml/jax/pull/11234')
     train_data = data.Serial(lambda _: _very_simple_data(),
                              data.CountAndSkip('simple_data'))
     task = training.TrainTask(

@@ -327,7 +327,7 @@ def test_restores_step(self):

   def test_restores_memory_efficient_from_standard(self):
     """Training restores step from directory where it saved it."""
-    self.skipTest('Broken by https://github.com/google/jax/pull/11234')
+    self.skipTest('Broken by https://github.com/jax-ml/jax/pull/11234')
     model = tl.Serial(tl.Dense(4), tl.Dense(1))
     task_std = training.TrainTask(
         _very_simple_data(), tl.L2Loss(), optimizers.Adam(.0001))

@@ -345,7 +345,7 @@ def test_restores_memory_efficient_from_standard(self):

   def test_restores_from_smaller_model(self):
     """Training restores from a checkpoint created with smaller model."""
-    self.skipTest('Broken by https://github.com/google/jax/pull/11234')
+    self.skipTest('Broken by https://github.com/jax-ml/jax/pull/11234')
     model1 = tl.Serial(tl.Dense(1))
     task = training.TrainTask(
         _very_simple_data(), tl.L2Loss(), optimizers.Adam(.01))

@@ -374,7 +374,7 @@ def test_restore_fails_different_model(self):

   def test_restores_step_bfloat16(self):
     """Training restores step from directory where it saved it, w/ bfloat16."""
-    self.skipTest('Broken by https://github.com/google/jax/pull/11234')
+    self.skipTest('Broken by https://github.com/jax-ml/jax/pull/11234')
     model = tl.Serial(tl.Dense(1, use_bfloat16=True))
     # We'll also use Adafactor with bfloat16 to check restoring bfloat slots.
     opt = optimizers.Adafactor(.01, do_momentum=True, momentum_in_bfloat16=True)

trax/tf_numpy/jax_tests/lax_numpy_einsum_test.py (+1 -1)

@@ -89,7 +89,7 @@ def test_two_operands_5(self):
     self._check(s, x, y)

   def test_two_operands_6(self):
-    # based on https://github.com/google/jax/issues/37#issuecomment-448572187
+    # based on https://github.com/jax-ml/jax/issues/37#issuecomment-448572187
     r = self.rng()
     x = r.randn(2, 1)
     y = r.randn(2, 3, 4)

trax/tf_numpy/jax_tests/lax_numpy_indexing_test.py (+1 -1)

@@ -720,7 +720,7 @@ def testFloatIndexingError(self):
     with self.assertRaisesRegex(IndexError, error_regex):
       npe.jit(lambda idx: jnp.zeros((2, 2))[idx])((0, 0.))

-  def testIndexOutOfBounds(self):  # https://github.com/google/jax/issues/2245
+  def testIndexOutOfBounds(self):  # https://github.com/jax-ml/jax/issues/2245
     array = jnp.ones(5)
     self.assertAllClose(array, array[:10], check_dtypes=True)

trax/tf_numpy/jax_tests/lax_numpy_test.py (+8 -8)

@@ -2051,14 +2051,14 @@ def testAstype(self):
   # TODO(mattjj): test other ndarray-like method overrides

   def testOnpMean(self):
-    # from https://github.com/google/jax/issues/125
+    # from https://github.com/jax-ml/jax/issues/125
     x = lnp.add(lnp.eye(3, dtype=lnp.float_), 0.)
     ans = onp.mean(x)
     self.assertAllClose(ans, onp.array(1./3), check_dtypes=False)

   @jtu.disable
   def testArangeOnFloats(self):
-    # from https://github.com/google/jax/issues/145
+    # from https://github.com/jax-ml/jax/issues/145
     expected = onp.arange(0.0, 1.0, 0.1, dtype=lnp.float_)
     ans = lnp.arange(0.0, 1.0, 0.1)
     self.assertAllClose(expected, ans, check_dtypes=True)

@@ -2407,7 +2407,7 @@ def testSymmetrizeDtypePromotion(self):

   @jtu.disable
   def testIssue347(self):
-    # https://github.com/google/jax/issues/347
+    # https://github.com/jax-ml/jax/issues/347
     def test_fail(x):
       x = lnp.sqrt(lnp.sum(x ** 2, axis=1))
       ones = lnp.ones_like(x)

@@ -2419,7 +2419,7 @@ def test_fail(x):
     assert not onp.any(onp.isnan(result))

   def testIssue453(self):
-    # https://github.com/google/jax/issues/453
+    # https://github.com/jax-ml/jax/issues/453
     a = onp.arange(6) + 1
     ans = lnp.reshape(a, (3, 2), order='F')
     expected = onp.reshape(a, (3, 2), order='F')

@@ -2432,7 +2432,7 @@ def testIssue453(self):
        (bool, lnp.bool_), (complex, lnp.complex_)]
       for op in ["atleast_1d", "atleast_2d", "atleast_3d"]))
   def testAtLeastNdLiterals(self, pytype, dtype, op):
-    # Fixes: https://github.com/google/jax/issues/634
+    # Fixes: https://github.com/jax-ml/jax/issues/634
     onp_fun = lambda arg: getattr(onp, op)(arg).astype(dtype)
     lnp_fun = lambda arg: getattr(lnp, op)(arg)
     args_maker = lambda: [pytype(2)]

@@ -2550,7 +2550,7 @@ def testMathSpecialFloatValues(self, op, dtype):
         rtol=tol)

   def testIssue883(self):
-    # from https://github.com/google/jax/issues/883
+    # from https://github.com/jax-ml/jax/issues/883

     @partial(npe.jit, static_argnums=(1,))
     def f(x, v):

@@ -2907,7 +2907,7 @@ def testDisableNumpyRankPromotionBroadcasting(self):
       FLAGS.jax_numpy_rank_promotion = prev_flag

   def testStackArrayArgument(self):
-    # tests https://github.com/google/jax/issues/1271
+    # tests https://github.com/jax-ml/jax/issues/1271
     @npe.jit
     def foo(x):
       return lnp.stack(x)

@@ -3120,7 +3120,7 @@ def testOpGradSpecialValue(self, op, special_value, order):

   @jtu.disable
   def testTakeAlongAxisIssue1521(self):
-    # https://github.com/google/jax/issues/1521
+    # https://github.com/jax-ml/jax/issues/1521
     idx = lnp.repeat(lnp.arange(3), 10).reshape((30, 1))

     def f(x):

trax/tf_numpy/jax_tests/test_util.py (+1 -1)

@@ -316,7 +316,7 @@ def test_method_wrapper(self, *args, **kwargs):
     return test_method_wrapper
   return skip

-# TODO(phawkins): workaround for bug https://github.com/google/jax/issues/432
+# TODO(phawkins): workaround for bug https://github.com/jax-ml/jax/issues/432
 # Delete this code after the minimum jaxlib version is 0.1.46 or greater.
 skip_on_mac_linalg_bug = partial(
     unittest.skipIf,

trax/tf_numpy/jax_tests/vmap_test.py (+4 -4)

@@ -28,7 +28,7 @@
 class VmapTest(tf.test.TestCase, parameterized.TestCase):

   def test_vmap_in_axes_list(self):
-    # https://github.com/google/jax/issues/2367
+    # https://github.com/jax-ml/jax/issues/2367
     dictionary = {'a': 5., 'b': tf_np.ones(2)}
     x = tf_np.zeros(3)
     y = tf_np.arange(3.)

@@ -41,7 +41,7 @@ def f(dct, x, y):
     self.assertAllClose(out1, out2)

   def test_vmap_in_axes_tree_prefix_error(self):
-    # https://github.com/google/jax/issues/795
+    # https://github.com/jax-ml/jax/issues/795
     self.assertRaisesRegex(
         ValueError,
         'vmap in_axes specification must be a tree prefix of the corresponding '

@@ -63,14 +63,14 @@ def test_vmap_out_axes_leaf_types(self):
         tf_np.array([1., 2.]))

   def test_vmap_unbatched_object_passthrough_issue_183(self):
-    # https://github.com/google/jax/issues/183
+    # https://github.com/jax-ml/jax/issues/183
     fun = lambda f, x: f(x)
     vfun = extensions.vmap(fun, (None, 0))
     ans = vfun(lambda x: x + 1, tf_np.arange(3))
     self.assertAllClose(ans, np.arange(1, 4))

   def test_vmap_mismatched_axis_sizes_error_message_issue_705(self):
-    # https://github.com/google/jax/issues/705
+    # https://github.com/jax-ml/jax/issues/705
     with self.assertRaisesRegex(
         ValueError, 'vmap must have at least one non-None value in in_axes'):
       # If the output is mapped, there must be a non-None in_axes

trax/tf_numpy_and_keras.ipynb (+1 -1)

@@ -26,7 +26,7 @@
 "\n",
 "In Trax, all computations rely on accelerated math operations happening in the `fastmath` module. This module can use different backends for acceleration. One of them is [TensorFlow NumPy](https://www.tensorflow.org/api_docs/python/tf/experimental/numpy) which uses [TensorFlow 2](https://www.tensorflow.org/) to accelerate the computations.\n",
 "\n",
-"The backend can be set using a call to `trax.fastmath.set_backend` as you'll see below. Currently available backends are `jax` (default), `tensorflow-numpy` and `numpy` (for debugging). The `tensorflow-numpy` backend uses [TensorFlow Numpy](https://www.tensorflow.org/api_docs/python/tf/experimental/numpy) for executing `fastmath` functions on TensorFlow, while the `jax` backend calls [JAX](https://github.com/google/jax) which lowers to TensorFlow XLA.\n",
+"The backend can be set using a call to `trax.fastmath.set_backend` as you'll see below. Currently available backends are `jax` (default), `tensorflow-numpy` and `numpy` (for debugging). The `tensorflow-numpy` backend uses [TensorFlow Numpy](https://www.tensorflow.org/api_docs/python/tf/experimental/numpy) for executing `fastmath` functions on TensorFlow, while the `jax` backend calls [JAX](https://github.com/jax-ml/jax) which lowers to TensorFlow XLA.\n",
 "\n",
 "You may see that `tensorflow-numpy` and `jax` backends show different speed and memory characteristics. You may also see different error messages when debugging since it might expose you to the internals of the backends. However for the most part, users can choose a backend and not worry about the internal details of these backends.\n",
 "\n",
