Skip to content

Commit

Permalink
Replace deprecated jax.tree_* functions with jax.tree.*
Browse files Browse the repository at this point in the history
The top-level `jax.tree_*` aliases have long been deprecated, and will soon be removed. Alternate APIs are in `jax.tree_util`, with shorter aliases in the `jax.tree` submodule, added in JAX version 0.4.25.

PiperOrigin-RevId: 633773679
  • Loading branch information
Jake VanderPlas authored and tree-math authors committed May 15, 2024
1 parent 0727453 commit f08636e
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions tree_math/_src/structs_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ class StructsTest(test_util.TestCase):
dict(testcase_name='Arrays', x=TestStruct(np.eye(10), np.ones([3, 4, 5])))
)
def testFlattenUnflatten(self, x):
leaves, structure = jax.tree_flatten(x)
y = jax.tree_unflatten(structure, leaves)
leaves, structure = jax.tree.flatten(x)
y = jax.tree.unflatten(structure, leaves)
np.testing.assert_allclose(x.a, y.a)
np.testing.assert_allclose(x.b, y.b)

Expand Down

0 comments on commit f08636e

Please sign in to comment.