Skip to content

Commit

Permalink
Remove references to deprecated jax.ShapedArray
Browse files Browse the repository at this point in the history
This is deprecated as of google/jax#15263: most users will never need to use ShapedArray directly, and so having it exposed in the top-level public namespace causes undue confusion.

PiperOrigin-RevId: 520189916
  • Loading branch information
Jake VanderPlas authored and Copybara-Service committed Mar 29, 2023
1 parent e962caa commit 1bb3b89
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion trax/tf_numpy/jax_tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2881,7 +2881,7 @@ def body(i, xy):
f = lambda y: lax.fori_loop(0, 5, body, (y, y))
wrapped = linear_util.wrap_init(f)
pv = partial_eval.PartialVal(
(jax.ShapedArray((3, 4), onp.float32), jax.core.unit))
(jax.core.ShapedArray((3, 4), onp.float32), jax.core.unit))
_, _, consts = partial_eval.trace_to_jaxpr(wrapped, [pv])
self.assertFalse(
any(onp.array_equal(x, onp.full((3, 4), 2., dtype=onp.float32))
Expand Down

0 comments on commit 1bb3b89

Please sign in to comment.