Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use jax.tree_util.$tree_fn instead of deprecated jax.$tree_fn alias. #181

Merged
merged 1 commit into from
Jul 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion distrax/_src/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,8 @@ def batch_shape(self) -> Tuple[int, ...]:
lambda s, e: s.shape[:s.ndim - len(e)], sample_spec, self.event_shape)

# Get flat batch shapes.
batch_shapes = jax.tree_structure(sample_spec).flatten_up_to(batch_shapes)
batch_shapes = jax.tree_util.tree_structure(sample_spec).flatten_up_to(
batch_shapes)
if not batch_shapes:
return ()

Expand Down
4 changes: 2 additions & 2 deletions distrax/_src/utils/jittable.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __new__(cls, *args, **kwargs):
return object.__new__(registered_cls)

def tree_flatten(self):
leaves, treedef = jax.tree_flatten(self.__dict__)
leaves, treedef = jax.tree_util.tree_flatten(self.__dict__)
switch = list(map(_is_jax_data, leaves))
children = [leaf if s else None for leaf, s in zip(leaves, switch)]
metadata = [None if s else leaf for leaf, s in zip(leaves, switch)]
Expand All @@ -43,7 +43,7 @@ def tree_unflatten(cls, aux_data, children):
metadata, switch, treedef = aux_data
leaves = [j if s else p for j, p, s in zip(children, metadata, switch)]
obj = object.__new__(cls)
obj.__dict__ = jax.tree_unflatten(treedef, leaves)
obj.__dict__ = jax.tree_util.tree_unflatten(treedef, leaves)
return obj


Expand Down