diff --git a/tree_math/_src/structs_test.py b/tree_math/_src/structs_test.py index 879e69c..b3af767 100644 --- a/tree_math/_src/structs_test.py +++ b/tree_math/_src/structs_test.py @@ -41,8 +41,8 @@ class StructsTest(test_util.TestCase): dict(testcase_name='Arrays', x=TestStruct(np.eye(10), np.ones([3, 4, 5]))) ) def testFlattenUnflatten(self, x): - leaves, structure = jax.tree_flatten(x) - y = jax.tree_unflatten(structure, leaves) + leaves, structure = jax.tree.flatten(x) + y = jax.tree.unflatten(structure, leaves) np.testing.assert_allclose(x.a, y.a) np.testing.assert_allclose(x.b, y.b)