Skip to content

Commit

Permalink
Use jax.tree_util.$tree_fn instead of deprecated jax.$tree_fn alias.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 464072217
  • Loading branch information
GeorgOstrovski authored and DistraxDev committed Jul 29, 2022
1 parent 599692e commit 615605e
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 3 deletions.
3 changes: 2 additions & 1 deletion distrax/_src/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,8 @@ def batch_shape(self) -> Tuple[int, ...]:
lambda s, e: s.shape[:s.ndim - len(e)], sample_spec, self.event_shape)

# Get flat batch shapes.
batch_shapes = jax.tree_structure(sample_spec).flatten_up_to(batch_shapes)
batch_shapes = jax.tree_util.tree_structure(sample_spec).flatten_up_to(
batch_shapes)
if not batch_shapes:
return ()

Expand Down
4 changes: 2 additions & 2 deletions distrax/_src/utils/jittable.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __new__(cls, *args, **kwargs):
return object.__new__(registered_cls)

def tree_flatten(self):
leaves, treedef = jax.tree_flatten(self.__dict__)
leaves, treedef = jax.tree_util.tree_flatten(self.__dict__)
switch = list(map(_is_jax_data, leaves))
children = [leaf if s else None for leaf, s in zip(leaves, switch)]
metadata = [None if s else leaf for leaf, s in zip(leaves, switch)]
Expand All @@ -43,7 +43,7 @@ def tree_unflatten(cls, aux_data, children):
metadata, switch, treedef = aux_data
leaves = [j if s else p for j, p, s in zip(children, metadata, switch)]
obj = object.__new__(cls)
obj.__dict__ = jax.tree_unflatten(treedef, leaves)
obj.__dict__ = jax.tree_util.tree_unflatten(treedef, leaves)
return obj


Expand Down

0 comments on commit 615605e

Please sign in to comment.