Skip to content

Commit

Permalink
Remove users of jax.api.* symbols, in preparation for removing the de…
Browse files Browse the repository at this point in the history
…precated jax.api name.

In most cases, use the public jax.* name instead.

PiperOrigin-RevId: 395767146
  • Loading branch information
hawkinsp authored and DistraxDev committed Sep 9, 2021
1 parent 272ccb8 commit 416f1bc
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions distrax/_src/utils/transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@
appear in the output. This is an experimental feature.
"""

import functools

from absl import logging
import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -110,7 +112,8 @@ def register_inverse(primitive, inverse_left, inverse_right=None):
def inv(fun):
"""Returns the inverse of `fun` such that (inv(fun) o fun)(x) = x."""
jaxpr_fn = _invertible_jaxpr_and_constants(fun)
@jax.api.wraps(fun) # pylint: disable=no-value-for-parameter

@functools.wraps(fun) # pylint: disable=no-value-for-parameter
def wrapped(*args, **kwargs):
jaxpr, consts = jaxpr_fn(*args, **kwargs)
out = _interpret_inverse(jaxpr, consts, *args)
Expand Down Expand Up @@ -193,7 +196,7 @@ def _invertible_jaxpr_and_constants(fun):
"""Returns a transformation from function invocation to invertible jaxpr."""
jaxpr_maker = jax.make_jaxpr(fun)

@jax.api.wraps(fun) # pylint: disable=no-value-for-parameter
@functools.wraps(fun) # pylint: disable=no-value-for-parameter
def jaxpr_const_maker(*args, **kwargs):
typed_jaxpr = jaxpr_maker(*args, **kwargs)
return typed_jaxpr.jaxpr, typed_jaxpr.literals
Expand Down

0 comments on commit 416f1bc

Please sign in to comment.